Skip to content

Commit

Permalink
Merge pull request #388 from MilesCranmer/asin
Browse files Browse the repository at this point in the history
- Adds safe versions of `asin` and `acos`
- Also switches `atanh_clip` to `safe_atanh` (small change in behavior)
  • Loading branch information
MilesCranmer authored Dec 14, 2024
2 parents 5755693 + 235bde6 commit c52b36d
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 70 deletions.
3 changes: 3 additions & 0 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ using .OperatorsModule:
safe_log10,
safe_log1p,
safe_sqrt,
safe_asin,
safe_acos,
safe_acosh,
safe_atanh,
neg,
greater,
cond,
Expand Down
65 changes: 46 additions & 19 deletions src/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using SpecialFunctions: SpecialFunctions
using DynamicQuantities: UnionAbstractQuantity
using SpecialFunctions: erf, erfc
using Base: @deprecate
using DynamicDiff: ForwardDiff
using ..ProgramConstantsModule: DATA_TYPE
using ...UtilsModule: @ignore
#TODO - actually add these operators to the module!
Expand All @@ -19,15 +20,25 @@ gamma(x) = SpecialFunctions.gamma(x)
atanh_clip(x) = atanh(mod(x + oneunit(x), oneunit(x) + oneunit(x)) - oneunit(x)) * one(x)
# == atanh((x + 1) % 2 - 1)

const Dual = ForwardDiff.Dual

# Implicitly defined:
#binary: mod
#unary: exp, abs, log1p, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, erf, erfc, gamma, relu, round, floor, ceil, round, sign.

const FloatOrDual = Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}

# Use some fast operators from https://github.com/JuliaLang/julia/blob/81597635c4ad1e8c2e1c5753fda4ec0e7397543f/base/fastmath.jl
# Define allowed operators. Any julia operator can also be used.
# TODO: Add all of these operators to the precompilation.
# TODO: Since simplification is done in DynamicExpressions.jl, are these names correct anymore?
function safe_pow(x::T, y::T)::T where {T<:Union{AbstractFloat,UnionAbstractQuantity}}
function safe_pow(
x::T1, y::T2
) where {
T1<:Union{FloatOrDual,UnionAbstractQuantity},
T2<:Union{FloatOrDual,UnionAbstractQuantity},
}
T = promote_type(T1, T2)
if isinteger(y)
y < zero(y) && iszero(x) && return T(NaN)
else
Expand All @@ -36,29 +47,32 @@ function safe_pow(x::T, y::T)::T where {T<:Union{AbstractFloat,UnionAbstractQuan
end
return x^y
end
function safe_log(x::T)::T where {T<:AbstractFloat}
x <= zero(x) && return T(NaN)
return log(x)
function safe_log(x::T)::T where {T<:FloatOrDual}
return x > zero(x) ? log(x) : T(NaN)
end
function safe_log2(x::T)::T where {T<:FloatOrDual}
return x > zero(x) ? log2(x) : T(NaN)
end
function safe_log10(x::T)::T where {T<:FloatOrDual}
return x > zero(x) ? log10(x) : T(NaN)
end
function safe_log1p(x::T)::T where {T<:FloatOrDual}
return x > -oneunit(x) ? log1p(x) : T(NaN)
end
function safe_log2(x::T)::T where {T<:AbstractFloat}
x <= zero(x) && return T(NaN)
return log2(x)
function safe_asin(x::T)::T where {T<:FloatOrDual}
return -oneunit(x) <= x <= oneunit(x) ? asin(x) : T(NaN)
end
function safe_log10(x::T)::T where {T<:AbstractFloat}
x <= zero(x) && return T(NaN)
return log10(x)
function safe_acos(x::T)::T where {T<:FloatOrDual}
return -oneunit(x) <= x <= oneunit(x) ? acos(x) : T(NaN)
end
function safe_log1p(x::T)::T where {T<:AbstractFloat}
x <= -oneunit(x) && return T(NaN)
return log1p(x)
function safe_acosh(x::T)::T where {T<:FloatOrDual}
return x >= oneunit(x) ? acosh(x) : T(NaN)
end
function safe_acosh(x::T)::T where {T<:AbstractFloat}
x < oneunit(x) && return T(NaN)
return acosh(x)
function safe_atanh(x::T)::T where {T<:FloatOrDual}
return -oneunit(x) <= x <= oneunit(x) ? atanh(x) : T(NaN)
end
function safe_sqrt(x::T)::T where {T<:AbstractFloat}
x < zero(x) && return T(NaN)
return sqrt(x)
function safe_sqrt(x::T)::T where {T<:FloatOrDual}
return x >= zero(x) ? sqrt(x) : T(NaN)
end
# TODO: Should the above be made more generic, for, e.g., compatibility with units?

Expand All @@ -75,6 +89,9 @@ safe_log(x) = log(x)
safe_log2(x) = log2(x)
safe_log10(x) = log10(x)
safe_log1p(x) = log1p(x)
safe_asin(x) = asin(x)
safe_acos(x) = acos(x)
safe_atanh(x) = atanh(x)
safe_acosh(x) = acosh(x)
safe_sqrt(x) = sqrt(x)

Expand Down Expand Up @@ -103,7 +120,10 @@ DE.get_op_name(::typeof(safe_log)) = "log"
DE.get_op_name(::typeof(safe_log2)) = "log2"
DE.get_op_name(::typeof(safe_log10)) = "log10"
DE.get_op_name(::typeof(safe_log1p)) = "log1p"
DE.get_op_name(::typeof(safe_asin)) = "asin"
DE.get_op_name(::typeof(safe_acos)) = "acos"
DE.get_op_name(::typeof(safe_acosh)) = "acosh"
DE.get_op_name(::typeof(safe_atanh)) = "atanh"
DE.get_op_name(::typeof(safe_sqrt)) = "sqrt"

# Expression algebra
Expand All @@ -112,7 +132,10 @@ DE.declare_operator_alias(::typeof(safe_log), ::Val{1}) = log
DE.declare_operator_alias(::typeof(safe_log2), ::Val{1}) = log2
DE.declare_operator_alias(::typeof(safe_log10), ::Val{1}) = log10
DE.declare_operator_alias(::typeof(safe_log1p), ::Val{1}) = log1p
DE.declare_operator_alias(::typeof(safe_asin), ::Val{1}) = asin
DE.declare_operator_alias(::typeof(safe_acos), ::Val{1}) = acos
DE.declare_operator_alias(::typeof(safe_acosh), ::Val{1}) = acosh
DE.declare_operator_alias(::typeof(safe_atanh), ::Val{1}) = atanh
DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt

# Deprecated operations:
Expand All @@ -123,13 +146,17 @@ DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt
@ignore pow(x, y) = safe_pow(x, y)
@ignore pow_abs(x, y) = safe_pow(x, y)

# Actual mappings used for evaluation
get_safe_op(op::F) where {F<:Function} = op
get_safe_op(::typeof(^)) = safe_pow
get_safe_op(::typeof(log)) = safe_log
get_safe_op(::typeof(log2)) = safe_log2
get_safe_op(::typeof(log10)) = safe_log10
get_safe_op(::typeof(log1p)) = safe_log1p
get_safe_op(::typeof(asin)) = safe_asin
get_safe_op(::typeof(acos)) = safe_acos
get_safe_op(::typeof(sqrt)) = safe_sqrt
get_safe_op(::typeof(acosh)) = safe_acosh
get_safe_op(::typeof(atanh)) = safe_atanh

end
24 changes: 17 additions & 7 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ using ..OperatorsModule:
safe_log2,
safe_log1p,
safe_sqrt,
safe_asin,
safe_acos,
safe_acosh,
atanh_clip
safe_atanh
using ..MutationWeightsModule: AbstractMutationWeights, MutationWeights, mutations
import ..OptionsStructModule: Options
using ..OptionsStructModule: ComplexityMapping, operator_specialization
Expand Down Expand Up @@ -139,7 +141,7 @@ end
]
end

function binopmap(op::F) where {F}
function binopmap(@nospecialize(op))
if op == plus
return +
elseif op == mult
Expand All @@ -155,14 +157,14 @@ function binopmap(op::F) where {F}
end
return op
end
function inverse_binopmap(op::F) where {F}
function inverse_binopmap(@nospecialize(op))
if op == safe_pow
return ^
end
return op
end

function unaopmap(op::F) where {F}
function unaopmap(@nospecialize(op))
if op == log
return safe_log
elseif op == log10
Expand All @@ -173,14 +175,18 @@ function unaopmap(op::F) where {F}
return safe_log1p
elseif op == sqrt
return safe_sqrt
elseif op == asin
return safe_asin
elseif op == acos
return safe_acos
elseif op == acosh
return safe_acosh
elseif op == atanh
return atanh_clip
return safe_atanh
end
return op
end
function inverse_unaopmap(op::F) where {F}
function inverse_unaopmap(@nospecialize(op))
if op == safe_log
return log
elseif op == safe_log10
Expand All @@ -191,9 +197,13 @@ function inverse_unaopmap(op::F) where {F}
return log1p
elseif op == safe_sqrt
return sqrt
elseif op == safe_asin
return asin
elseif op == safe_acos
return acos
elseif op == safe_acosh
return acosh
elseif op == atanh_clip
elseif op == safe_atanh
return atanh
end
return op
Expand Down
6 changes: 6 additions & 0 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ export Population,
safe_log2,
safe_log10,
safe_log1p,
safe_asin,
safe_acos,
safe_acosh,
safe_atanh,
safe_sqrt,
neg,
greater,
Expand Down Expand Up @@ -247,7 +250,10 @@ using .CoreModule:
safe_log10,
safe_log1p,
safe_sqrt,
safe_asin,
safe_acos,
safe_acosh,
safe_atanh,
neg,
greater,
cond,
Expand Down
4 changes: 1 addition & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ end
@eval @run_package_tests filter = ti -> !isdisjoint(ti.tags, $tags_to_run) verbose = true

# TODO: This is a very slow test
@testitem "Test custom operators and additional types" tags = [:part2] begin
include("test_operators.jl")
end
include("test_operators.jl")

@testitem "Test tree construction and scoring" tags = [:part3] begin
include("test_tree_construction.jl")
Expand Down
Loading

0 comments on commit c52b36d

Please sign in to comment.