Skip to content

Commit

Permalink
test: clean up operator tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 14, 2024
1 parent d2af8dc commit e165f27
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 34 deletions.
4 changes: 1 addition & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ end
@eval @run_package_tests filter = ti -> !isdisjoint(ti.tags, $tags_to_run) verbose = true

# TODO: This is a very slow test
@testitem "Test custom operators and additional types" tags = [:part2] begin
include("test_operators.jl")
end
include("test_operators.jl")

@testitem "Test tree construction and scoring" tags = [:part3] begin
include("test_tree_construction.jl")
Expand Down
134 changes: 103 additions & 31 deletions test/test_operators.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
using SymbolicRegression
using SymbolicRegression:
plus,
sub,
mult,
square,
cube,
safe_pow,
safe_log,
safe_log2,
safe_log10,
safe_sqrt,
safe_acosh,
neg,
greater,
cond,
relu,
logical_or,
logical_and,
gamma
using Random: MersenneTwister
using Suppressor: @capture_err
using LoopVectorization
include("test_params.jl")

@testset "Generic operator tests" begin
@testitem "Generic operator tests" tags = [:part2] begin
using SymbolicRegression
using SymbolicRegression:
plus,
sub,
mult,
square,
cube,
safe_pow,
safe_log,
safe_log2,
safe_log10,
safe_sqrt,
safe_acosh,
safe_atanh,
safe_asin,
safe_acos,
neg,
greater,
cond,
relu,
logical_or,
logical_and,
gamma

types_to_test = [Float16, Float32, Float64, BigFloat]
for T in types_to_test
val = T(0.5)
Expand Down Expand Up @@ -76,7 +75,11 @@ include("test_params.jl")
end
end

@testset "Test built-in operators pass validation" begin
@testitem "Built-in operators pass validation" tags = [:part3] begin
using SymbolicRegression
using SymbolicRegression:
plus, sub, mult, square, cube, neg, relu, greater, logical_or, logical_and, cond

types_to_test = [Float16, Float32, Float64, BigFloat]
options = Options(;
binary_operators=[plus, sub, mult, /, ^, greater, logical_or, logical_and, cond],
Expand All @@ -89,7 +92,10 @@ end
end
end

@testset "Test built-in operators pass validation for complex numbers" begin
@testitem "Built-in operators pass validation for complex numbers" tags = [:part2] begin
using SymbolicRegression
using SymbolicRegression: plus, sub, mult, square, cube, neg

types_to_test = [ComplexF16, ComplexF32, ComplexF64]
options = Options(;
binary_operators=[plus, sub, mult, /, ^],
Expand All @@ -100,7 +106,10 @@ end
end
end

@testset "Test incompatibilities are caught" begin
@testitem "Incompatibilities are caught" tags = [:part3] begin
using SymbolicRegression
using SymbolicRegression: greater

options = Options(; binary_operators=[greater])
@test_throws ErrorException SymbolicRegression.assert_operators_well_defined(
ComplexF64, options
Expand All @@ -110,7 +119,9 @@ end
)
end

@testset "Operators which return the wrong type should fail" begin
@testitem "Operators with wrong type fail" tags = [:part2] begin
using SymbolicRegression

my_bad_op(x) = 1.0f0
options = Options(; binary_operators=[], unary_operators=[my_bad_op])
@test_throws ErrorException SymbolicRegression.assert_operators_well_defined(
Expand All @@ -122,7 +133,13 @@ end
@test_nowarn SymbolicRegression.assert_operators_well_defined(Float32, options)
end

@testset "Turbo mode should be the same" begin
@testitem "Turbo mode matches regular mode" tags = [:part3] begin
using SymbolicRegression
using SymbolicRegression:
plus, sub, mult, square, cube, neg, relu, greater, logical_or, logical_and, cond
using Random: MersenneTwister
using Suppressor: @capture_err

binary_operators = [plus, sub, mult, /, ^, greater, logical_or, logical_and, cond]
unary_operators = [square, cube, log, log2, log10, log1p, sqrt, atanh, acosh, neg, relu]
options = Options(; binary_operators, unary_operators)
Expand All @@ -148,3 +165,58 @@ end
end
end
end

@testitem "Safe operators are compatible with ForwardDiff" tags = [:part2] begin
using SymbolicRegression
using SymbolicRegression:
safe_log,
safe_log2,
safe_log10,
safe_log1p,
safe_sqrt,
safe_asin,
safe_acos,
safe_atanh,
safe_acosh,
safe_pow
using ForwardDiff

# Test all safe operators
safe_operators = [
(safe_log, 2.0, -1.0), # (operator, valid_input, invalid_input)
(safe_log2, 2.0, -1.0),
(safe_log10, 2.0, -1.0),
(safe_log1p, 0.5, -2.0),
(safe_sqrt, 2.0, -1.0),
(safe_asin, 0.5, 2.0),
(safe_acos, 0.5, 2.0),
(safe_atanh, 0.5, 2.0),
(safe_acosh, 2.0, 0.5),
]

for (op, valid_x, invalid_x) in safe_operators
# Test derivative exists and is correct for valid input
deriv = ForwardDiff.derivative(op, valid_x)
@test !isnan(deriv)
@test !iszero(deriv) # All these operators should have non-zero derivatives at test points

# Test derivative is 0.0 for invalid input
deriv_invalid = ForwardDiff.derivative(op, invalid_x)
@test iszero(deriv_invalid)
end

# Test safe_pow separately since it's binary
for x in [0.5, 2.0], y in [2.0, 0.5]
# Test valid derivatives
deriv_x = ForwardDiff.derivative(x -> safe_pow(x, y), x)
deriv_y = ForwardDiff.derivative(y -> safe_pow(x, y), y)
@test !isnan(deriv_x)
@test !isnan(deriv_y)
@test !iszero(deriv_x) # Should be non-zero for our test points

# Test invalid cases return 0.0 derivatives
@test iszero(ForwardDiff.derivative(x -> safe_pow(x, -1.0), 0.0)) # 0^(-1)
@test iszero(ForwardDiff.derivative(x -> safe_pow(-x, 0.5), 1.0)) # (-x)^0.5
@test iszero(ForwardDiff.derivative(x -> safe_pow(x, -0.5), 0.0)) # 0^(-0.5)
end
end

0 comments on commit e165f27

Please sign in to comment.