From 42dfd8d145cc182adc08e72ab2ed4093f1eb152e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 14 Dec 2024 02:46:16 -0800 Subject: [PATCH 1/8] feat: add safe versions of `asin` and `acos` --- src/Operators.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/Operators.jl b/src/Operators.jl index b38ccf97f..11b00fa69 100644 --- a/src/Operators.jl +++ b/src/Operators.jl @@ -52,6 +52,14 @@ function safe_log1p(x::T)::T where {T<:AbstractFloat} x <= -oneunit(x) && return T(NaN) return log1p(x) end +function safe_asin(x::T)::T where {T<:AbstractFloat} + -oneunit(x) <= x <= oneunit(x) || return T(NaN) + return asin(x) +end +function safe_acos(x::T)::T where {T<:AbstractFloat} + -oneunit(x) <= x <= oneunit(x) || return T(NaN) + return acos(x) +end function safe_acosh(x::T)::T where {T<:AbstractFloat} x < oneunit(x) && return T(NaN) return acosh(x) @@ -75,6 +83,8 @@ safe_log(x) = log(x) safe_log2(x) = log2(x) safe_log10(x) = log10(x) safe_log1p(x) = log1p(x) +safe_asin(x) = asin(x) +safe_acos(x) = acos(x) safe_acosh(x) = acosh(x) safe_sqrt(x) = sqrt(x) @@ -103,6 +113,8 @@ DE.get_op_name(::typeof(safe_log)) = "log" DE.get_op_name(::typeof(safe_log2)) = "log2" DE.get_op_name(::typeof(safe_log10)) = "log10" DE.get_op_name(::typeof(safe_log1p)) = "log1p" +DE.get_op_name(::typeof(safe_asin)) = "asin" +DE.get_op_name(::typeof(safe_acos)) = "acos" DE.get_op_name(::typeof(safe_acosh)) = "acosh" DE.get_op_name(::typeof(safe_sqrt)) = "sqrt" @@ -112,6 +124,8 @@ DE.declare_operator_alias(::typeof(safe_log), ::Val{1}) = log DE.declare_operator_alias(::typeof(safe_log2), ::Val{1}) = log2 DE.declare_operator_alias(::typeof(safe_log10), ::Val{1}) = log10 DE.declare_operator_alias(::typeof(safe_log1p), ::Val{1}) = log1p +DE.declare_operator_alias(::typeof(safe_asin), ::Val{1}) = asin +DE.declare_operator_alias(::typeof(safe_acos), ::Val{1}) = acos DE.declare_operator_alias(::typeof(safe_acosh), ::Val{1}) = acosh DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt @@ -123,12 +137,15 @@ DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt @ignore pow(x, y) = safe_pow(x, y) @ignore pow_abs(x, y) = safe_pow(x, y) +# Actual mappings used for evaluation get_safe_op(op::F) where {F<:Function} = op get_safe_op(::typeof(^)) = safe_pow get_safe_op(::typeof(log)) = safe_log get_safe_op(::typeof(log2)) = safe_log2 get_safe_op(::typeof(log10)) = safe_log10 get_safe_op(::typeof(log1p)) = safe_log1p +get_safe_op(::typeof(asin)) = safe_asin +get_safe_op(::typeof(acos)) = safe_acos get_safe_op(::typeof(sqrt)) = safe_sqrt get_safe_op(::typeof(acosh)) = safe_acosh From 58ade27e87eb9f2cb7b28342a6834974cf4b296d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 14 Dec 2024 10:03:13 -0800 Subject: [PATCH 2/8] feat: make safe operators compatible with ForwardDiff --- src/Core.jl | 3 +++ src/Operators.jl | 23 ++++++++++++++--------- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/Core.jl b/src/Core.jl index c442efc73..bc9210aed 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -33,7 +33,10 @@ using .OperatorsModule: safe_log10, safe_log1p, safe_sqrt, + safe_asin, + safe_acos, safe_acosh, + safe_atanh, neg, greater, cond, diff --git a/src/Operators.jl b/src/Operators.jl index 11b00fa69..c3b73ecb4 100644 --- a/src/Operators.jl +++ b/src/Operators.jl @@ -5,6 +5,7 @@ using SpecialFunctions: SpecialFunctions using DynamicQuantities: UnionAbstractQuantity using SpecialFunctions: erf, erfc using Base: @deprecate +using DynamicDiff: ForwardDiff using ..ProgramConstantsModule: DATA_TYPE using ...UtilsModule: @ignore #TODO - actually add these operators to the module! @@ -19,6 +20,8 @@ gamma(x) = SpecialFunctions.gamma(x) atanh_clip(x) = atanh(mod(x + oneunit(x), oneunit(x) + oneunit(x)) - oneunit(x)) * one(x) # == atanh((x + 1) % 2 - 1) +const Dual = ForwardDiff.Dual + # Implicitly defined: #binary: mod #unary: exp, abs, log1p, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, erf, erfc, gamma, relu, round, floor, ceil, round, sign. @@ -27,7 +30,9 @@ atanh_clip(x) = atanh(mod(x + oneunit(x), oneunit(x) + oneunit(x)) - oneunit(x)) # Define allowed operators. Any julia operator can also be used. # TODO: Add all of these operators to the precompilation. # TODO: Since simplification is done in DynamicExpressions.jl, are these names correct anymore? -function safe_pow(x::T, y::T)::T where {T<:Union{AbstractFloat,UnionAbstractQuantity}} +function safe_pow( + x::T, y::T +)::T where {T<:Union{AbstractFloat,UnionAbstractQuantity,Dual{<:AbstractFloat}}} if isinteger(y) y < zero(y) && iszero(x) && return T(NaN) else @@ -36,35 +41,35 @@ function safe_pow(x::T, y::T)::T where {T<:Union{AbstractFloat,UnionAbstractQuan end return x^y end -function safe_log(x::T)::T where {T<:AbstractFloat} +function safe_log(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} x <= zero(x) && return T(NaN) return log(x) end -function safe_log2(x::T)::T where {T<:AbstractFloat} +function safe_log2(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} x <= zero(x) && return T(NaN) return log2(x) end -function safe_log10(x::T)::T where {T<:AbstractFloat} +function safe_log10(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} x <= zero(x) && return T(NaN) return log10(x) end -function safe_log1p(x::T)::T where {T<:AbstractFloat} +function safe_log1p(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} x <= -oneunit(x) && return T(NaN) return log1p(x) end -function safe_asin(x::T)::T where {T<:AbstractFloat} +function safe_asin(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} -oneunit(x) <= x <= oneunit(x) || return T(NaN) return asin(x) end -function safe_acos(x::T)::T where {T<:AbstractFloat} +function safe_acos(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} -oneunit(x) <= x <= oneunit(x) || return T(NaN) return acos(x) end -function safe_acosh(x::T)::T where {T<:AbstractFloat} +function safe_acosh(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} x < oneunit(x) && return T(NaN) return acosh(x) end -function safe_sqrt(x::T)::T where {T<:AbstractFloat} +function safe_sqrt(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} x < zero(x) && return T(NaN) return sqrt(x) end From f5a41e7963321ce2b2576daf1a8e64d692a1b7f5 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 14 Dec 2024 10:05:16 -0800 Subject: [PATCH 3/8] feat: create `safe_atanh` operator --- src/Operators.jl | 8 ++++++++ src/Options.jl | 24 +++++++++++++++++------- src/SymbolicRegression.jl | 6 ++++++ test/test_operators.jl | 6 ++++++ 4 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/Operators.jl b/src/Operators.jl index c3b73ecb4..3bf0120ad 100644 --- a/src/Operators.jl +++ b/src/Operators.jl @@ -69,6 +69,10 @@ function safe_acosh(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:Abstract x < oneunit(x) && return T(NaN) return acosh(x) end +function safe_atanh(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} + -oneunit(x) <= x <= oneunit(x) || return T(NaN) + return atanh(x) +end function safe_sqrt(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} x < zero(x) && return T(NaN) return sqrt(x) @@ -90,6 +94,7 @@ safe_log10(x) = log10(x) safe_log1p(x) = log1p(x) safe_asin(x) = asin(x) safe_acos(x) = acos(x) +safe_atanh(x) = atanh(x) safe_acosh(x) = acosh(x) safe_sqrt(x) = sqrt(x) @@ -121,6 +126,7 @@ DE.get_op_name(::typeof(safe_log1p)) = "log1p" DE.get_op_name(::typeof(safe_asin)) = "asin" DE.get_op_name(::typeof(safe_acos)) = "acos" DE.get_op_name(::typeof(safe_acosh)) = "acosh" +DE.get_op_name(::typeof(safe_atanh)) = "atanh" DE.get_op_name(::typeof(safe_sqrt)) = "sqrt" # Expression algebra @@ -132,6 +138,7 @@ DE.declare_operator_alias(::typeof(safe_log1p), ::Val{1}) = log1p DE.declare_operator_alias(::typeof(safe_asin), ::Val{1}) = asin DE.declare_operator_alias(::typeof(safe_acos), ::Val{1}) = acos DE.declare_operator_alias(::typeof(safe_acosh), ::Val{1}) = acosh +DE.declare_operator_alias(::typeof(safe_atanh), ::Val{1}) = atanh DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt # Deprecated operations: @@ -153,5 +160,6 @@ get_safe_op(::typeof(asin)) = safe_asin get_safe_op(::typeof(acos)) = safe_acos get_safe_op(::typeof(sqrt)) = safe_sqrt get_safe_op(::typeof(acosh)) = safe_acosh +get_safe_op(::typeof(atanh)) = safe_atanh end diff --git a/src/Options.jl b/src/Options.jl index d7fc61cf8..b6f1dd847 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -23,8 +23,10 @@ using ..OperatorsModule: safe_log2, safe_log1p, safe_sqrt, + safe_asin, + safe_acos, safe_acosh, - atanh_clip + safe_atanh using ..MutationWeightsModule: AbstractMutationWeights, MutationWeights, mutations import ..OptionsStructModule: Options using ..OptionsStructModule: ComplexityMapping, operator_specialization @@ -139,7 +141,7 @@ end ] end -function binopmap(op::F) where {F} +function binopmap(@nospecialize(op)) if op == plus return + elseif op == mult @@ -155,14 +157,14 @@ function binopmap(op::F) where {F} end return op end -function inverse_binopmap(op::F) where {F} +function inverse_binopmap(@nospecialize(op)) if op == safe_pow return ^ end return op end -function unaopmap(op::F) where {F} +function unaopmap(@nospecialize(op)) if op == log return safe_log elseif op == log10 @@ -173,14 +175,18 @@ function unaopmap(op::F) where {F} return safe_log1p elseif op == sqrt return safe_sqrt + elseif op == asin + return safe_asin + elseif op == acos + return safe_acos elseif op == acosh return safe_acosh elseif op == atanh - return atanh_clip + return safe_atanh end return op end -function inverse_unaopmap(op::F) where {F} +function inverse_unaopmap(@nospecialize(op)) if op == safe_log return log elseif op == safe_log10 @@ -191,9 +197,13 @@ function inverse_unaopmap(op::F) where {F} return log1p elseif op == safe_sqrt return sqrt + elseif op == safe_asin + return asin + elseif op == safe_acos + return acos elseif op == safe_acosh return acosh - elseif op == atanh_clip + elseif op == safe_atanh return atanh end return op diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index cc7decf09..fd95b0851 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -67,7 +67,10 @@ export Population, safe_log2, safe_log10, safe_log1p, + safe_asin, + safe_acos, safe_acosh, + safe_atanh, safe_sqrt, neg, greater, @@ -247,7 +250,10 @@ using .CoreModule: safe_log10, safe_log1p, safe_sqrt, + safe_asin, + safe_acos, safe_acosh, + safe_atanh, neg, greater, cond, diff --git a/test/test_operators.jl b/test/test_operators.jl index 1221ba6c7..c446ec088 100644 --- a/test/test_operators.jl +++ b/test/test_operators.jl @@ -37,6 +37,12 @@ include("test_params.jl") @test abs(safe_log1p(val) - log1p(val)) < 1e-6 @test abs(safe_acosh(val2) - acosh(val2)) < 1e-6 @test isnan(safe_acosh(-val2)) + @test abs(safe_asin(val) - asin(val)) < 1e-6 + @test isnan(safe_asin(val2)) + @test abs(safe_acos(val) - acos(val)) < 1e-6 + @test isnan(safe_acos(val2)) + @test abs(safe_atanh(val) - atanh(val)) < 1e-6 + @test isnan(safe_atanh(val2)) @test neg(-val) == val @test safe_sqrt(val) == sqrt(val) @test isnan(safe_sqrt(-val)) From d2af8dc027840371c778a5cabcf2a5e6240bf34e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 14 Dec 2024 11:27:03 -0800 Subject: [PATCH 4/8] refactor: clean up safe operators --- src/Operators.jl | 55 +++++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/src/Operators.jl b/src/Operators.jl index 3bf0120ad..350718bd8 100644 --- a/src/Operators.jl +++ b/src/Operators.jl @@ -26,13 +26,19 @@ const Dual = ForwardDiff.Dual #binary: mod #unary: exp, abs, log1p, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, erf, erfc, gamma, relu, round, floor, ceil, round, sign. +const FloatOrDual = Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}} + # Use some fast operators from https://github.com/JuliaLang/julia/blob/81597635c4ad1e8c2e1c5753fda4ec0e7397543f/base/fastmath.jl # Define allowed operators. Any julia operator can also be used. # TODO: Add all of these operators to the precompilation. # TODO: Since simplification is done in DynamicExpressions.jl, are these names correct anymore? function safe_pow( - x::T, y::T -)::T where {T<:Union{AbstractFloat,UnionAbstractQuantity,Dual{<:AbstractFloat}}} + x::T1, y::T2 +) where { + T1<:Union{FloatOrDual,UnionAbstractQuantity}, + T2<:Union{FloatOrDual,UnionAbstractQuantity}, +} + T = promote_type(T1, T2) if isinteger(y) y < zero(y) && iszero(x) && return T(NaN) else @@ -41,41 +47,32 @@ function safe_pow( end return x^y end -function safe_log(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} - x <= zero(x) && return T(NaN) - return log(x) +function safe_log(x::T)::T where {T<:FloatOrDual} + return x > zero(x) ? log(x) : T(NaN) end -function safe_log2(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} - x <= zero(x) && return T(NaN) - return log2(x) +function safe_log2(x::T)::T where {T<:FloatOrDual} + return x > zero(x) ? log2(x) : T(NaN) end -function safe_log10(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} - x <= zero(x) && return T(NaN) - return log10(x) +function safe_log10(x::T)::T where {T<:FloatOrDual} + return x > zero(x) ? log10(x) : T(NaN) end -function safe_log1p(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} - x <= -oneunit(x) && return T(NaN) - return log1p(x) +function safe_log1p(x::T)::T where {T<:FloatOrDual} + return x > -oneunit(x) ? log1p(x) : T(NaN) end -function safe_asin(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} - -oneunit(x) <= x <= oneunit(x) || return T(NaN) - return asin(x) +function safe_asin(x::T)::T where {T<:FloatOrDual} + return -oneunit(x) <= x <= oneunit(x) ? asin(x) : T(NaN) end -function safe_acos(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} - -oneunit(x) <= x <= oneunit(x) || return T(NaN) - return acos(x) +function safe_acos(x::T)::T where {T<:FloatOrDual} + return -oneunit(x) <= x <= oneunit(x) ? acos(x) : T(NaN) end -function safe_acosh(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} - x < oneunit(x) && return T(NaN) - return acosh(x) +function safe_acosh(x::T)::T where {T<:FloatOrDual} + return x >= oneunit(x) ? acosh(x) : T(NaN) end -function safe_atanh(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} - -oneunit(x) <= x <= oneunit(x) || return T(NaN) - return atanh(x) +function safe_atanh(x::T)::T where {T<:FloatOrDual} + return -oneunit(x) <= x <= oneunit(x) ? atanh(x) : T(NaN) end -function safe_sqrt(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}} - x < zero(x) && return T(NaN) - return sqrt(x) +function safe_sqrt(x::T)::T where {T<:FloatOrDual} + return x >= zero(x) ? sqrt(x) : T(NaN) end # TODO: Should the above be made more generic, for, e.g., compatibility with units? From e165f27f3c67a22539b87a669e8d5891a029bf6e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 14 Dec 2024 11:29:16 -0800 Subject: [PATCH 5/8] test: clean up operator tests --- test/runtests.jl | 4 +- test/test_operators.jl | 134 +++++++++++++++++++++++++++++++---------- 2 files changed, 104 insertions(+), 34 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index d02a1b9be..049ea6866 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,9 +11,7 @@ end @eval @run_package_tests filter = ti -> !isdisjoint(ti.tags, $tags_to_run) verbose = true # TODO: This is a very slow test -@testitem "Test custom operators and additional types" tags = [:part2] begin - include("test_operators.jl") -end +include("test_operators.jl") @testitem "Test tree construction and scoring" tags = [:part3] begin include("test_tree_construction.jl") diff --git a/test/test_operators.jl b/test/test_operators.jl index c446ec088..038270228 100644 --- a/test/test_operators.jl +++ b/test/test_operators.jl @@ -1,29 +1,28 @@ -using SymbolicRegression -using SymbolicRegression: - plus, - sub, - mult, - square, - cube, - safe_pow, - safe_log, - safe_log2, - safe_log10, - safe_sqrt, - safe_acosh, - neg, - greater, - cond, - relu, - logical_or, - logical_and, - gamma -using Random: MersenneTwister -using Suppressor: @capture_err -using LoopVectorization -include("test_params.jl") - -@testset "Generic operator tests" begin +@testitem "Generic operator tests" tags = [:part2] begin + using SymbolicRegression + using SymbolicRegression: + plus, + sub, + mult, + square, + cube, + safe_pow, + safe_log, + safe_log2, + safe_log10, + safe_sqrt, + safe_acosh, + safe_atanh, + safe_asin, + safe_acos, + neg, + greater, + cond, + relu, + logical_or, + logical_and, + gamma + types_to_test = [Float16, Float32, Float64, BigFloat] for T in types_to_test val = T(0.5) @@ -76,7 +75,11 @@ include("test_params.jl") end end -@testset "Test built-in operators pass validation" begin +@testitem "Built-in operators pass validation" tags = [:part3] begin + using SymbolicRegression + using SymbolicRegression: + plus, sub, mult, square, cube, neg, relu, greater, logical_or, logical_and, cond + types_to_test = [Float16, Float32, Float64, BigFloat] options = Options(; binary_operators=[plus, sub, mult, /, ^, greater, logical_or, logical_and, cond], @@ -89,7 +92,10 @@ end end end -@testset "Test built-in operators pass validation for complex numbers" begin +@testitem "Built-in operators pass validation for complex numbers" tags = [:part2] begin + using SymbolicRegression + using SymbolicRegression: plus, sub, mult, square, cube, neg + types_to_test = [ComplexF16, ComplexF32, ComplexF64] options = Options(; binary_operators=[plus, sub, mult, /, ^], @@ -100,7 +106,10 @@ end end end -@testset "Test incompatibilities are caught" begin +@testitem "Incompatibilities are caught" tags = [:part3] begin + using SymbolicRegression + using SymbolicRegression: greater + options = Options(; binary_operators=[greater]) @test_throws ErrorException SymbolicRegression.assert_operators_well_defined( ComplexF64, options @@ -110,7 +119,9 @@ end ) end -@testset "Operators which return the wrong type should fail" begin +@testitem "Operators with wrong type fail" tags = [:part2] begin + using SymbolicRegression + my_bad_op(x) = 1.0f0 options = Options(; binary_operators=[], unary_operators=[my_bad_op]) @test_throws ErrorException SymbolicRegression.assert_operators_well_defined( @@ -122,7 +133,13 @@ end @test_nowarn SymbolicRegression.assert_operators_well_defined(Float32, options) end -@testset "Turbo mode should be the same" begin +@testitem "Turbo mode matches regular mode" tags = [:part3] begin + using SymbolicRegression + using SymbolicRegression: + plus, sub, mult, square, cube, neg, relu, greater, logical_or, logical_and, cond + using Random: MersenneTwister + using Suppressor: @capture_err + binary_operators = [plus, sub, mult, /, ^, greater, logical_or, logical_and, cond] unary_operators = [square, cube, log, log2, log10, log1p, sqrt, atanh, acosh, neg, relu] options = Options(; binary_operators, unary_operators) @@ -148,3 +165,58 @@ end end end end + +@testitem "Safe operators are compatible with ForwardDiff" tags = [:part2] begin + using SymbolicRegression + using SymbolicRegression: + safe_log, + safe_log2, + safe_log10, + safe_log1p, + safe_sqrt, + safe_asin, + safe_acos, + safe_atanh, + safe_acosh, + safe_pow + using ForwardDiff + + # Test all safe operators + safe_operators = [ + (safe_log, 2.0, -1.0), # (operator, valid_input, invalid_input) + (safe_log2, 2.0, -1.0), + (safe_log10, 2.0, -1.0), + (safe_log1p, 0.5, -2.0), + (safe_sqrt, 2.0, -1.0), + (safe_asin, 0.5, 2.0), + (safe_acos, 0.5, 2.0), + (safe_atanh, 0.5, 2.0), + (safe_acosh, 2.0, 0.5), + ] + + for (op, valid_x, invalid_x) in safe_operators + # Test derivative exists and is correct for valid input + deriv = ForwardDiff.derivative(op, valid_x) + @test !isnan(deriv) + @test !iszero(deriv) # All these operators should have non-zero derivatives at test points + + # Test derivative is 0.0 for invalid input + deriv_invalid = ForwardDiff.derivative(op, invalid_x) + @test iszero(deriv_invalid) + end + + # Test safe_pow separately since it's binary + for x in [0.5, 2.0], y in [2.0, 0.5] + # Test valid derivatives + deriv_x = ForwardDiff.derivative(x -> safe_pow(x, y), x) + deriv_y = ForwardDiff.derivative(y -> safe_pow(x, y), y) + @test !isnan(deriv_x) + @test !isnan(deriv_y) + @test !iszero(deriv_x) # Should be non-zero for our test points + + # Test invalid cases return 0.0 derivatives + @test iszero(ForwardDiff.derivative(x -> safe_pow(x, -1.0), 0.0)) # 0^(-1) + @test iszero(ForwardDiff.derivative(x -> safe_pow(-x, 0.5), 1.0)) # (-x)^0.5 + @test iszero(ForwardDiff.derivative(x -> safe_pow(x, -0.5), 0.0)) # 0^(-0.5) + end +end From d5db24c474bcaa52f639fd9e71506c9e7f2d55fd Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 14 Dec 2024 11:36:18 -0800 Subject: [PATCH 6/8] test: missing LV import --- test/test_operators.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_operators.jl b/test/test_operators.jl index 038270228..2d92a986d 100644 --- a/test/test_operators.jl +++ b/test/test_operators.jl @@ -139,6 +139,7 @@ end plus, sub, mult, square, cube, neg, relu, greater, logical_or, logical_and, cond using Random: MersenneTwister using Suppressor: @capture_err + using LoopVectorization: LoopVectorization as _ binary_operators = [plus, sub, mult, /, ^, greater, logical_or, logical_and, cond] unary_operators = [square, cube, log, log2, log10, log1p, sqrt, atanh, acosh, neg, relu] From db47189f359ef48f6cd2d867515e573bf8c759c4 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 14 Dec 2024 11:39:14 -0800 Subject: [PATCH 7/8] test: missing `test_info` import --- test/test_operators.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_operators.jl b/test/test_operators.jl index 2d92a986d..39c85e33b 100644 --- a/test/test_operators.jl +++ b/test/test_operators.jl @@ -140,6 +140,7 @@ end using Random: MersenneTwister using Suppressor: @capture_err using LoopVectorization: LoopVectorization as _ + include("test_params.jl") binary_operators = [plus, sub, mult, /, ^, greater, logical_or, logical_and, cond] unary_operators = [square, cube, log, log2, log10, log1p, sqrt, atanh, acosh, neg, relu] From 235bde6efd51e6123a12ff798bdf7c2d7c05b4a3 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 14 Dec 2024 11:42:39 -0800 Subject: [PATCH 8/8] test: try to speed up operators test --- test/test_operators.jl | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/test/test_operators.jl b/test/test_operators.jl index 39c85e33b..b7643e707 100644 --- a/test/test_operators.jl +++ b/test/test_operators.jl @@ -145,6 +145,20 @@ end binary_operators = [plus, sub, mult, /, ^, greater, logical_or, logical_and, cond] unary_operators = [square, cube, log, log2, log10, log1p, sqrt, atanh, acosh, neg, relu] options = Options(; binary_operators, unary_operators) + + function test_part(tree, Xpart, options) + y, completed = eval_tree_array(tree, Xpart, options) + completed || return nothing + # We capture any warnings about the LoopVectorization not working + local y_turbo + eval_warnings = @capture_err begin + y_turbo, _ = eval_tree_array(tree, Xpart, options; turbo=true) + end + test_info(@test(y[1] ≈ y_turbo[1] && eval_warnings == "")) do + @info T tree X[:, seed] y y_turbo eval_warnings + end + end + for T in (Float32, Float64), index_bin in 1:length(binary_operators), index_una in 1:length(unary_operators) @@ -154,16 +168,7 @@ end X = rand(MersenneTwister(0), T, 2, 20) for seed in 1:20 Xpart = X[:, [seed]] - y, completed = eval_tree_array(tree, Xpart, options) - completed || continue - local y_turbo - # We capture any warnings about the LoopVectorization not working - eval_warnings = @capture_err begin - y_turbo, _ = eval_tree_array(tree, Xpart, options; turbo=true) - end - test_info(@test y[1] ≈ y_turbo[1] && eval_warnings == "") do - @info T tree X[:, seed] y y_turbo eval_warnings - end + test_part(tree, Xpart, options) end end end