Skip to content

Commit

Permalink
Merge pull request #260 from MilesCranmer/cond-operator
Browse files Browse the repository at this point in the history
Create `cond` operator
  • Loading branch information
MilesCranmer authored Dec 11, 2023
2 parents debe0de + f555298 commit 59abbd8
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 132 deletions.
3 changes: 1 addition & 2 deletions src/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import .OperatorsModule:
cube,
pow,
safe_pow,
div,
safe_log,
safe_log2,
safe_log10,
Expand All @@ -29,7 +28,7 @@ import .OperatorsModule:
safe_acosh,
neg,
greater,
greater,
cond,
relu,
logical_or,
logical_and,
Expand Down
75 changes: 29 additions & 46 deletions src/Operators.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
module OperatorsModule

using SpecialFunctions: SpecialFunctions
using DynamicQuantities: UnionAbstractQuantity
import SpecialFunctions: erf, erfc
import Base: @deprecate
import ..ProgramConstantsModule: DATA_TYPE
#TODO - actually add these operators to the module!

# TODO: Should this be limited to AbstractFloat instead?
function gamma(x::T)::T where {T<:DATA_TYPE}
out = SpecialFunctions.gamma(x)
return isinf(out) ? T(NaN) : out
end
gamma(x) = SpecialFunctions.gamma(x)

atanh_clip(x) = atanh(mod(x + 1, 2) - 1)
atanh_clip(x) = atanh(mod(x + oneunit(x), oneunit(x) + oneunit(x)) - oneunit(x)) * one(x)
# == atanh((x + 1) % 2 - 1)

# Implicitly defined:
#binary: mod
Expand All @@ -22,94 +25,74 @@ atanh_clip(x) = atanh(mod(x + 1, 2) - 1)
# Define allowed operators. Any julia operator can also be used.
# TODO: Add all of these operators to the precompilation.
# TODO: Since simplification is done in DynamicExpressions.jl, are these names correct anymore?
function plus(x::T, y::T)::T where {T<:DATA_TYPE}
return x + y #Do not change the name of this operator.
end
function sub(x::T, y::T)::T where {T<:DATA_TYPE}
return x - y #Do not change the name of this operator.
end
function mult(x::T, y::T)::T where {T<:DATA_TYPE}
return x * y #Do not change the name of this operator.
end
function square(x::T)::T where {T<:DATA_TYPE}
return x * x
end
function cube(x::T)::T where {T<:DATA_TYPE}
return x^3
end
function safe_pow(x::T, y::T)::T where {T<:AbstractFloat}
function safe_pow(x::T, y::T)::T where {T<:Union{AbstractFloat,UnionAbstractQuantity}}
if isinteger(y)
y < T(0) && x == T(0) && return T(NaN)
y < zero(y) && iszero(x) && return T(NaN)
else
y > T(0) && x < T(0) && return T(NaN)
y < T(0) && x <= T(0) && return T(NaN)
y > zero(y) && x < zero(x) && return T(NaN)
y < zero(y) && x <= zero(x) && return T(NaN)
end
return x^y
end
function div(x::T, y::T)::T where {T<:AbstractFloat}
return x / y
end
function safe_log(x::T)::T where {T<:AbstractFloat}
x <= T(0) && return T(NaN)
x <= zero(x) && return T(NaN)
return log(x)
end
function safe_log2(x::T)::T where {T<:AbstractFloat}
x <= T(0) && return T(NaN)
x <= zero(x) && return T(NaN)
return log2(x)
end
function safe_log10(x::T)::T where {T<:AbstractFloat}
x <= T(0) && return T(NaN)
x <= zero(x) && return T(NaN)
return log10(x)
end
function safe_log1p(x::T)::T where {T<:AbstractFloat}
x <= T(-1) && return T(NaN)
x <= -oneunit(x) && return T(NaN)
return log1p(x)
end
function safe_acosh(x::T)::T where {T<:AbstractFloat}
x < T(1) && return T(NaN)
x < oneunit(x) && return T(NaN)
return acosh(x)
end
function safe_sqrt(x::T)::T where {T<:AbstractFloat}
x < T(0) && return T(NaN)
x < zero(x) && return T(NaN)
return sqrt(x)
end
# TODO: Should the above be made more generic, for, e.g., compatibility with units?

# Generics (and SIMD)
# Do not change the names of these operators, as
# they have special use in simplifications and printing.
square(x) = x * x
cube(x) = x * x * x
plus(x, y) = x + y
sub(x, y) = x - y
mult(x, y) = x * y
# Generics (for SIMD)
safe_pow(x, y) = x^y
div(x, y) = x / y
safe_log(x) = log(x)
safe_log2(x) = log2(x)
safe_log10(x) = log10(x)
safe_log1p(x) = log1p(x)
safe_acosh(x) = acosh(x)
safe_sqrt(x) = sqrt(x)

function neg(x::T)::T where {T}
function neg(x)
return -x
end

function greater(x::T, y::T)::T where {T}
return convert(T, (x > y))
end
function greater(x, y)
return (x > y)
return (x > y) * one(x)
end
function relu(x::T)::T where {T}
return (x + abs(x)) / T(2)
function cond(x, y)
return (x > zero(x)) * y
end

function logical_or(x::T, y::T)::T where {T}
return convert(T, (x > convert(T, 0) || y > convert(T, 0)))
function relu(x)
return (x > zero(x)) * x
end

# (Just use multiplication normally)
function logical_and(x::T, y::T)::T where {T}
return convert(T, (x > convert(T, 0) && y > convert(T, 0)))
function logical_or(x, y)
return ((x > zero(x)) | (y > zero(y))) * one(x)
end
function logical_and(x, y)
return ((x > zero(x)) & (y > zero(y))) * one(x)
end

# Deprecated operations:
Expand Down
9 changes: 4 additions & 5 deletions src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import ..OperatorsModule:
safe_pow,
mult,
sub,
div,
safe_log,
safe_log10,
safe_log2,
Expand Down Expand Up @@ -85,7 +84,7 @@ function build_constraints(
return una_constraints, bin_constraints
end

function binopmap(op)
function binopmap(op::F) where {F}
if op == plus
return +
elseif op == mult
Expand All @@ -101,14 +100,14 @@ function binopmap(op)
end
return op
end
function inverse_binopmap(op)
function inverse_binopmap(op::F) where {F}
if op == safe_pow
return ^
end
return op
end

function unaopmap(op)
function unaopmap(op::F) where {F}
if op == log
return safe_log
elseif op == log10
Expand All @@ -126,7 +125,7 @@ function unaopmap(op)
end
return op
end
function inverse_unaopmap(op)
function inverse_unaopmap(op::F) where {F}
if op == safe_log
return log
elseif op == safe_log10
Expand Down
5 changes: 2 additions & 3 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ export Population,
cube,
pow,
safe_pow,
div,
safe_log,
safe_log2,
safe_log10,
Expand All @@ -53,6 +52,7 @@ export Population,
safe_sqrt,
neg,
greater,
cond,
relu,
logical_or,
logical_and,
Expand Down Expand Up @@ -176,7 +176,6 @@ import .CoreModule:
cube,
pow,
safe_pow,
div,
safe_log,
safe_log2,
safe_log10,
Expand All @@ -185,7 +184,7 @@ import .CoreModule:
safe_acosh,
neg,
greater,
greater,
cond,
relu,
logical_or,
logical_and,
Expand Down
Loading

0 comments on commit 59abbd8

Please sign in to comment.