Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add safe versions of asin and acos #388

Merged
merged 8 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ using .OperatorsModule:
safe_log10,
safe_log1p,
safe_sqrt,
safe_asin,
safe_acos,
safe_acosh,
safe_atanh,
neg,
greater,
cond,
Expand Down
65 changes: 46 additions & 19 deletions src/Operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using SpecialFunctions: SpecialFunctions
using DynamicQuantities: UnionAbstractQuantity
using SpecialFunctions: erf, erfc
using Base: @deprecate
using DynamicDiff: ForwardDiff
using ..ProgramConstantsModule: DATA_TYPE
using ...UtilsModule: @ignore
#TODO - actually add these operators to the module!
Expand All @@ -19,15 +20,25 @@ gamma(x) = SpecialFunctions.gamma(x)
atanh_clip(x) = atanh(mod(x + oneunit(x), oneunit(x) + oneunit(x)) - oneunit(x)) * one(x)
# == atanh((x + 1) % 2 - 1)

const Dual = ForwardDiff.Dual

# Implicitly defined:
#binary: mod
#unary: exp, abs, log1p, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, erf, erfc, gamma, relu, round, floor, ceil, round, sign.

const FloatOrDual = Union{AbstractFloat,Dual{<:Any,<:AbstractFloat}}

# Use some fast operators from https://github.com/JuliaLang/julia/blob/81597635c4ad1e8c2e1c5753fda4ec0e7397543f/base/fastmath.jl
# Define allowed operators. Any julia operator can also be used.
# TODO: Add all of these operators to the precompilation.
# TODO: Since simplification is done in DynamicExpressions.jl, are these names correct anymore?
function safe_pow(x::T, y::T)::T where {T<:Union{AbstractFloat,UnionAbstractQuantity}}
function safe_pow(
x::T1, y::T2
) where {
T1<:Union{FloatOrDual,UnionAbstractQuantity},
T2<:Union{FloatOrDual,UnionAbstractQuantity},
}
T = promote_type(T1, T2)
if isinteger(y)
y < zero(y) && iszero(x) && return T(NaN)
else
Expand All @@ -36,29 +47,32 @@ function safe_pow(x::T, y::T)::T where {T<:Union{AbstractFloat,UnionAbstractQuan
end
return x^y
end
function safe_log(x::T)::T where {T<:AbstractFloat}
x <= zero(x) && return T(NaN)
return log(x)
function safe_log(x::T)::T where {T<:FloatOrDual}
return x > zero(x) ? log(x) : T(NaN)
end
function safe_log2(x::T)::T where {T<:FloatOrDual}
return x > zero(x) ? log2(x) : T(NaN)
end
function safe_log10(x::T)::T where {T<:FloatOrDual}
return x > zero(x) ? log10(x) : T(NaN)
end
function safe_log1p(x::T)::T where {T<:FloatOrDual}
return x > -oneunit(x) ? log1p(x) : T(NaN)
end
function safe_log2(x::T)::T where {T<:AbstractFloat}
x <= zero(x) && return T(NaN)
return log2(x)
function safe_asin(x::T)::T where {T<:FloatOrDual}
return -oneunit(x) <= x <= oneunit(x) ? asin(x) : T(NaN)
end
function safe_log10(x::T)::T where {T<:AbstractFloat}
x <= zero(x) && return T(NaN)
return log10(x)
function safe_acos(x::T)::T where {T<:FloatOrDual}
return -oneunit(x) <= x <= oneunit(x) ? acos(x) : T(NaN)
end
function safe_log1p(x::T)::T where {T<:AbstractFloat}
x <= -oneunit(x) && return T(NaN)
return log1p(x)
function safe_acosh(x::T)::T where {T<:FloatOrDual}
return x >= oneunit(x) ? acosh(x) : T(NaN)
end
function safe_acosh(x::T)::T where {T<:AbstractFloat}
x < oneunit(x) && return T(NaN)
return acosh(x)
function safe_atanh(x::T)::T where {T<:FloatOrDual}
return -oneunit(x) <= x <= oneunit(x) ? atanh(x) : T(NaN)
end
function safe_sqrt(x::T)::T where {T<:AbstractFloat}
x < zero(x) && return T(NaN)
return sqrt(x)
function safe_sqrt(x::T)::T where {T<:FloatOrDual}
return x >= zero(x) ? sqrt(x) : T(NaN)
end
# TODO: Should the above be made more generic, for, e.g., compatibility with units?

Expand All @@ -75,6 +89,9 @@ safe_log(x) = log(x)
safe_log2(x) = log2(x)
safe_log10(x) = log10(x)
safe_log1p(x) = log1p(x)
safe_asin(x) = asin(x)
safe_acos(x) = acos(x)
safe_atanh(x) = atanh(x)
safe_acosh(x) = acosh(x)
safe_sqrt(x) = sqrt(x)

Expand Down Expand Up @@ -103,7 +120,10 @@ DE.get_op_name(::typeof(safe_log)) = "log"
DE.get_op_name(::typeof(safe_log2)) = "log2"
DE.get_op_name(::typeof(safe_log10)) = "log10"
DE.get_op_name(::typeof(safe_log1p)) = "log1p"
DE.get_op_name(::typeof(safe_asin)) = "asin"
DE.get_op_name(::typeof(safe_acos)) = "acos"
DE.get_op_name(::typeof(safe_acosh)) = "acosh"
DE.get_op_name(::typeof(safe_atanh)) = "atanh"
DE.get_op_name(::typeof(safe_sqrt)) = "sqrt"

# Expression algebra
Expand All @@ -112,7 +132,10 @@ DE.declare_operator_alias(::typeof(safe_log), ::Val{1}) = log
DE.declare_operator_alias(::typeof(safe_log2), ::Val{1}) = log2
DE.declare_operator_alias(::typeof(safe_log10), ::Val{1}) = log10
DE.declare_operator_alias(::typeof(safe_log1p), ::Val{1}) = log1p
DE.declare_operator_alias(::typeof(safe_asin), ::Val{1}) = asin
DE.declare_operator_alias(::typeof(safe_acos), ::Val{1}) = acos
DE.declare_operator_alias(::typeof(safe_acosh), ::Val{1}) = acosh
DE.declare_operator_alias(::typeof(safe_atanh), ::Val{1}) = atanh
DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt

# Deprecated operations:
Expand All @@ -123,13 +146,17 @@ DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt
@ignore pow(x, y) = safe_pow(x, y)
@ignore pow_abs(x, y) = safe_pow(x, y)

# Actual mappings used for evaluation
get_safe_op(op::F) where {F<:Function} = op
get_safe_op(::typeof(^)) = safe_pow
get_safe_op(::typeof(log)) = safe_log
get_safe_op(::typeof(log2)) = safe_log2
get_safe_op(::typeof(log10)) = safe_log10
get_safe_op(::typeof(log1p)) = safe_log1p
get_safe_op(::typeof(asin)) = safe_asin
get_safe_op(::typeof(acos)) = safe_acos
get_safe_op(::typeof(sqrt)) = safe_sqrt
get_safe_op(::typeof(acosh)) = safe_acosh
get_safe_op(::typeof(atanh)) = safe_atanh

end
24 changes: 17 additions & 7 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ using ..OperatorsModule:
safe_log2,
safe_log1p,
safe_sqrt,
safe_asin,
safe_acos,
safe_acosh,
atanh_clip
safe_atanh
using ..MutationWeightsModule: AbstractMutationWeights, MutationWeights, mutations
import ..OptionsStructModule: Options
using ..OptionsStructModule: ComplexityMapping, operator_specialization
Expand Down Expand Up @@ -139,7 +141,7 @@ end
]
end

function binopmap(op::F) where {F}
function binopmap(@nospecialize(op))
if op == plus
return +
elseif op == mult
Expand All @@ -155,14 +157,14 @@ function binopmap(op::F) where {F}
end
return op
end
function inverse_binopmap(op::F) where {F}
function inverse_binopmap(@nospecialize(op))
if op == safe_pow
return ^
end
return op
end

function unaopmap(op::F) where {F}
function unaopmap(@nospecialize(op))
if op == log
return safe_log
elseif op == log10
Expand All @@ -173,14 +175,18 @@ function unaopmap(op::F) where {F}
return safe_log1p
elseif op == sqrt
return safe_sqrt
elseif op == asin
return safe_asin
elseif op == acos
return safe_acos
elseif op == acosh
return safe_acosh
elseif op == atanh
return atanh_clip
return safe_atanh
end
return op
end
function inverse_unaopmap(op::F) where {F}
function inverse_unaopmap(@nospecialize(op))
if op == safe_log
return log
elseif op == safe_log10
Expand All @@ -191,9 +197,13 @@ function inverse_unaopmap(op::F) where {F}
return log1p
elseif op == safe_sqrt
return sqrt
elseif op == safe_asin
return asin
elseif op == safe_acos
return acos
elseif op == safe_acosh
return acosh
elseif op == atanh_clip
elseif op == safe_atanh
return atanh
end
return op
Expand Down
6 changes: 6 additions & 0 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ export Population,
safe_log2,
safe_log10,
safe_log1p,
safe_asin,
safe_acos,
safe_acosh,
safe_atanh,
safe_sqrt,
neg,
greater,
Expand Down Expand Up @@ -247,7 +250,10 @@ using .CoreModule:
safe_log10,
safe_log1p,
safe_sqrt,
safe_asin,
safe_acos,
safe_acosh,
safe_atanh,
neg,
greater,
cond,
Expand Down
4 changes: 1 addition & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ end
@eval @run_package_tests filter = ti -> !isdisjoint(ti.tags, $tags_to_run) verbose = true

# TODO: This is a very slow test
@testitem "Test custom operators and additional types" tags = [:part2] begin
include("test_operators.jl")
end
include("test_operators.jl")

@testitem "Test tree construction and scoring" tags = [:part3] begin
include("test_tree_construction.jl")
Expand Down
Loading
Loading