Skip to content

Commit

Permalink
Test operators under units
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 11, 2023
1 parent 93b6d92 commit f555298
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 11 deletions.
24 changes: 13 additions & 11 deletions src/Operators.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module OperatorsModule

using SpecialFunctions: SpecialFunctions
using DynamicQuantities: UnionAbstractQuantity
import SpecialFunctions: erf, erfc
import Base: @deprecate
import ..ProgramConstantsModule: DATA_TYPE
Expand All @@ -13,7 +14,8 @@ function gamma(x::T)::T where {T<:DATA_TYPE}
end
gamma(x) = SpecialFunctions.gamma(x)

atanh_clip(x) = atanh(mod(x + 1, 2) - 1)
atanh_clip(x) = atanh(mod(x + oneunit(x), oneunit(x) + oneunit(x)) - oneunit(x)) * one(x)
# == atanh((x + 1) % 2 - 1)

# Implicitly defined:
#binary: mod
Expand All @@ -23,37 +25,37 @@ atanh_clip(x) = atanh(mod(x + 1, 2) - 1)
# Define allowed operators. Any julia operator can also be used.
# TODO: Add all of these operators to the precompilation.
# TODO: Since simplification is done in DynamicExpressions.jl, are these names correct anymore?
function safe_pow(x::T, y::T)::T where {T<:AbstractFloat}
function safe_pow(x::T, y::T)::T where {T<:Union{AbstractFloat,UnionAbstractQuantity}}
if isinteger(y)
y < T(0) && x == T(0) && return T(NaN)
y < zero(y) && iszero(x) && return T(NaN)
else
y > T(0) && x < T(0) && return T(NaN)
y < T(0) && x <= T(0) && return T(NaN)
y > zero(y) && x < zero(x) && return T(NaN)
y < zero(y) && x <= zero(x) && return T(NaN)
end
return x^y
end
function safe_log(x::T)::T where {T<:AbstractFloat}
x <= T(0) && return T(NaN)
x <= zero(x) && return T(NaN)
return log(x)
end
function safe_log2(x::T)::T where {T<:AbstractFloat}
x <= T(0) && return T(NaN)
x <= zero(x) && return T(NaN)
return log2(x)
end
function safe_log10(x::T)::T where {T<:AbstractFloat}
x <= T(0) && return T(NaN)
x <= zero(x) && return T(NaN)
return log10(x)
end
function safe_log1p(x::T)::T where {T<:AbstractFloat}
x <= T(-1) && return T(NaN)
x <= -oneunit(x) && return T(NaN)
return log1p(x)
end
function safe_acosh(x::T)::T where {T<:AbstractFloat}
x < T(1) && return T(NaN)
x < oneunit(x) && return T(NaN)
return acosh(x)
end
function safe_sqrt(x::T)::T where {T<:AbstractFloat}
x < T(0) && return T(NaN)
x < zero(x) && return T(NaN)
return sqrt(x)
end
# TODO: Should the above be made more generic, for, e.g., compatibility with units?
Expand Down
54 changes: 54 additions & 0 deletions test/test_units.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
using SymbolicRegression
using SymbolicRegression:
square,
cube,
plus,
sub,
mult,
greater,
cond,
relu,
logical_or,
logical_and,
safe_pow,
atanh_clip
using SymbolicRegression.InterfaceDynamicQuantitiesModule: get_units, get_dimensions_type
using SymbolicRegression.MLJInterfaceModule: unwrap_units_single
using SymbolicRegression.DimensionalAnalysisModule:
Expand All @@ -10,6 +23,7 @@ import DynamicQuantities:
QuantityArray,
SymbolicDimensions,
Dimensions,
DimensionError,
@u_str,
@us_str,
uparse,
Expand Down Expand Up @@ -155,6 +169,46 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos]
@test length(valid_trees) > 0
end

@testset "Operator compatibility" begin
## square cube plus sub mult greater cond relu logical_or logical_and safe_pow atanh_clip
# Want to ensure these operators perform correctly in the context of units
@test square(1.0u"m") == 1.0u"m^2"
@test cube(1.0u"m") == 1.0u"m^3"
@test plus(1.0u"m", 1.0u"m") == 2.0u"m"
@test_throws DimensionError plus(1.0u"m", 1.0u"s")
@test sub(1.0u"m", 1.0u"m") == 0.0u"m"
@test_throws DimensionError sub(1.0u"m", 1.0u"s")
@test mult(1.0u"m", 1.0u"m") == 1.0u"m^2"
@test mult(1.0u"m", 1.0u"s") == 1.0u"m*s"
@test greater(1.1u"m", 1.0u"m") == true
@test greater(0.9u"m", 1.0u"m") == false
@test typeof(greater(1.1u"m", 1.0u"m")) === typeof(1.0u"m")
@test_throws DimensionError greater(1.0u"m", 1.0u"s")
@test cond(0.1u"m", 1.5u"m") == 1.5u"m"
@test cond(-0.1u"m", 1.5u"m") == 0.0u"m"
@test cond(-0.1u"s", 1.5u"m") == 0.0u"m"
@test relu(0.1u"m") == 0.1u"m"
@test relu(-0.1u"m") == 0.0u"m"
@test logical_or(0.1u"m", 0.0u"m") == 1.0
@test logical_or(-0.1u"m", 0.0u"m") == 0.0
@test logical_or(-0.5u"m", 1.0u"m") == 1.0
@test logical_or(-0.2u"m", -0.2u"m") == 0.0
@test logical_and(0.1u"m", 0.0u"m") == 0.0
@test logical_and(0.1u"s", 0.0u"m") == 0.0
@test logical_and(-0.1u"m", 0.0u"m") == 0.0
@test logical_and(-0.5u"m", 1.0u"m") == 0.0
@test logical_and(-0.2u"s", -0.2u"m") == 0.0
@test logical_and(0.2u"s", 0.2u"m") == 1.0
@test safe_pow(4.0u"m", 0.5u"1") == 2.0u"m^0.5"
@test isnan(safe_pow(-4.0u"m", 0.5u"1"))
@test typeof(safe_pow(-4.0u"m", 0.5u"1")) === typeof(1.0u"m")
@inferred safe_pow(4.0u"m", 0.5u"1")
@test_throws DimensionError safe_pow(1.0u"m", 1.0u"m")
@test atanh_clip(0.5u"1") == atanh(0.5)
@test atanh_clip(2.5u"1") == atanh(0.5)
@test_throws DimensionError atanh_clip(1.0u"m")
end

@testset "Search with dimensional constraints on output" begin
X = randn(2, 128)
X[2, :] .= X[1, :]
Expand Down

0 comments on commit f555298

Please sign in to comment.