Skip to content

Commit

Permalink
feat: create safe_atanh operator
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 14, 2024
1 parent 58ade27 commit f5a41e7
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 7 deletions.
8 changes: 8 additions & 0 deletions src/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ function safe_acosh(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:Abstract
x < oneunit(x) && return T(NaN)
return acosh(x)
end
function safe_atanh(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}}
-oneunit(x) <= x <= oneunit(x) || return T(NaN)
return atanh(x)
end
function safe_sqrt(x::T)::T where {T<:Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}}
x < zero(x) && return T(NaN)
return sqrt(x)
Expand All @@ -90,6 +94,7 @@ safe_log10(x) = log10(x)
safe_log1p(x) = log1p(x)
safe_asin(x) = asin(x)
safe_acos(x) = acos(x)
safe_atanh(x) = atanh(x)
safe_acosh(x) = acosh(x)
safe_sqrt(x) = sqrt(x)

Expand Down Expand Up @@ -121,6 +126,7 @@ DE.get_op_name(::typeof(safe_log1p)) = "log1p"
DE.get_op_name(::typeof(safe_asin)) = "asin"
DE.get_op_name(::typeof(safe_acos)) = "acos"
DE.get_op_name(::typeof(safe_acosh)) = "acosh"
DE.get_op_name(::typeof(safe_atanh)) = "atanh"
DE.get_op_name(::typeof(safe_sqrt)) = "sqrt"

# Expression algebra
Expand All @@ -132,6 +138,7 @@ DE.declare_operator_alias(::typeof(safe_log1p), ::Val{1}) = log1p
DE.declare_operator_alias(::typeof(safe_asin), ::Val{1}) = asin
DE.declare_operator_alias(::typeof(safe_acos), ::Val{1}) = acos
DE.declare_operator_alias(::typeof(safe_acosh), ::Val{1}) = acosh
DE.declare_operator_alias(::typeof(safe_atanh), ::Val{1}) = atanh
DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt

# Deprecated operations:
Expand All @@ -153,5 +160,6 @@ get_safe_op(::typeof(asin)) = safe_asin
get_safe_op(::typeof(acos)) = safe_acos
get_safe_op(::typeof(sqrt)) = safe_sqrt
get_safe_op(::typeof(acosh)) = safe_acosh
get_safe_op(::typeof(atanh)) = safe_atanh

end
24 changes: 17 additions & 7 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ using ..OperatorsModule:
safe_log2,
safe_log1p,
safe_sqrt,
safe_asin,
safe_acos,
safe_acosh,
atanh_clip
safe_atanh
using ..MutationWeightsModule: AbstractMutationWeights, MutationWeights, mutations
import ..OptionsStructModule: Options
using ..OptionsStructModule: ComplexityMapping, operator_specialization
Expand Down Expand Up @@ -139,7 +141,7 @@ end
]
end

function binopmap(op::F) where {F}
function binopmap(@nospecialize(op))
if op == plus
return +
elseif op == mult
Expand All @@ -155,14 +157,14 @@ function binopmap(op::F) where {F}
end
return op
end
function inverse_binopmap(op::F) where {F}
function inverse_binopmap(@nospecialize(op))
if op == safe_pow
return ^
end
return op
end

function unaopmap(op::F) where {F}
function unaopmap(@nospecialize(op))
if op == log
return safe_log
elseif op == log10
Expand All @@ -173,14 +175,18 @@ function unaopmap(op::F) where {F}
return safe_log1p
elseif op == sqrt
return safe_sqrt
elseif op == asin
return safe_asin
elseif op == acos
return safe_acos
elseif op == acosh
return safe_acosh
elseif op == atanh
return atanh_clip
return safe_atanh
end
return op
end
function inverse_unaopmap(op::F) where {F}
function inverse_unaopmap(@nospecialize(op))
if op == safe_log
return log
elseif op == safe_log10
Expand All @@ -191,9 +197,13 @@ function inverse_unaopmap(op::F) where {F}
return log1p
elseif op == safe_sqrt
return sqrt
elseif op == safe_asin
return asin
elseif op == safe_acos
return acos
elseif op == safe_acosh
return acosh
elseif op == atanh_clip
elseif op == safe_atanh
return atanh
end
return op
Expand Down
6 changes: 6 additions & 0 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ export Population,
safe_log2,
safe_log10,
safe_log1p,
safe_asin,
safe_acos,
safe_acosh,
safe_atanh,
safe_sqrt,
neg,
greater,
Expand Down Expand Up @@ -247,7 +250,10 @@ using .CoreModule:
safe_log10,
safe_log1p,
safe_sqrt,
safe_asin,
safe_acos,
safe_acosh,
safe_atanh,
neg,
greater,
cond,
Expand Down
6 changes: 6 additions & 0 deletions test/test_operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ include("test_params.jl")
@test abs(safe_log1p(val) - log1p(val)) < 1e-6
@test abs(safe_acosh(val2) - acosh(val2)) < 1e-6
@test isnan(safe_acosh(-val2))
@test abs(safe_asin(val) - asin(val)) < 1e-6
@test isnan(safe_asin(val2))
@test abs(safe_acos(val) - acos(val)) < 1e-6
@test isnan(safe_acos(val2))
@test abs(safe_atanh(val) - atanh(val)) < 1e-6
@test isnan(safe_atanh(val2))
@test neg(-val) == val
@test safe_sqrt(val) == sqrt(val)
@test isnan(safe_sqrt(-val))
Expand Down

0 comments on commit f5a41e7

Please sign in to comment.