Skip to content

Commit

Permalink
feat: heavier automatic simplification in derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 8, 2024
1 parent 2c2d9fc commit 72e21fe
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 27 deletions.
89 changes: 66 additions & 23 deletions src/ComposableExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ for op in (
end
#! format: on

Base.@enum ConstantDerivative::UInt8 Zero One NegOne Other

"""
D(ex::AbstractComposableExpression, feature::Integer)
Expand All @@ -264,14 +266,19 @@ function D(ex::AbstractComposableExpression, feature::Integer)
nuna = length(operators.unaops)
tree = DE.get_contents(ex)
operators_with_derivatives = _expand_operators(operators)
evaluates_to_zero = ntuple(
i -> operators_with_derivatives.binops[i] == _zero, Val(3 * nbin)
)
evaluates_to_one = ntuple(
i -> operators_with_derivatives.binops[i] == _one, Val(3 * nbin)
evaluates_to_constant = map(
op -> if op == _zero
Zero
elseif op == _one
One
elseif op == _n_one
NegOne
else
Other
end, operators_with_derivatives.binops
)
ctx = SymbolicDerivativeContext(;
feature, plus_idx, mult_idx, nbin, nuna, evaluates_to_zero, evaluates_to_one
feature, plus_idx, mult_idx, nbin, nuna, evaluates_to_constant
)
d_tree = _symbolic_derivative(tree, ctx)
return with_metadata(
Expand All @@ -285,8 +292,7 @@ Base.@kwdef struct SymbolicDerivativeContext{TUP}
mult_idx::Int
nbin::Int
nuna::Int
evaluates_to_zero::TUP
evaluates_to_one::TUP
evaluates_to_constant::TUP
end

function _symbolic_derivative(
Expand All @@ -309,51 +315,83 @@ function _symbolic_derivative(
f_prime_op = tree.op + ctx.nuna

### We do some simplification based on zero/one derivatives ###
if ctx.evaluates_to_zero[f_prime_op]
return constructorof(N)(; val=zero(T))
g_prime = _symbolic_derivative(tree.l, ctx)
if g_prime.degree == 0 && g_prime.constant && iszero(g_prime.val)
return g_prime
else
g_prime = _symbolic_derivative(tree.l, ctx)
if ctx.evaluates_to_one[f_prime_op]
return g_prime
f_prime = constructorof(N)(; op=f_prime_op, l=tree.l)

if g_prime.degree == 0 && g_prime.constant && isone(g_prime.val)
return f_prime
else
f_prime = constructorof(N)(; op=f_prime_op, l=tree.l)
return constructorof(N)(; op=ctx.mult_idx, l=f_prime, r=g_prime)
end
end
else # tree.degree == 2
# f(g(x), h(x)) => f^(1,0)(g(x), h(x)) * g'(x) + f^(0,1)(g(x), h(x)) * h'(x)
f_prime_left_op = tree.op + ctx.nbin
f_prime_right_op = tree.op + 2 * ctx.nbin
f_prime_left_evaluates_to = ctx.evaluates_to_constant[f_prime_left_op]
f_prime_right_evaluates_to = ctx.evaluates_to_constant[f_prime_right_op]

### We do some simplification based on zero/one derivatives ###
first_term = if ctx.evaluates_to_zero[f_prime_left_op]
first_term = if f_prime_left_evaluates_to == Zero

# Simplify and just give zero
constructorof(N)(; val=zero(T))
else
g_prime = _symbolic_derivative(tree.l, ctx)
if ctx.evaluates_to_one[f_prime_left_op]

if f_prime_left_evaluates_to == One ||
(g_prime.degree == 0 && g_prime.constant && iszero(g_prime.val))
# Simplify and just give g_prime
g_prime
else
f_prime_left = constructorof(N)(; op=f_prime_left_op, l=tree.l, r=tree.r)
constructorof(N)(; op=ctx.mult_idx, l=f_prime_left, r=g_prime)
f_prime_left = if f_prime_left_evaluates_to == NegOne
constructorof(N)(; val=-one(T))
else
constructorof(N)(; op=f_prime_left_op, l=tree.l, r=tree.r)
end

if g_prime.degree == 0 && g_prime.constant && isone(g_prime.val)
f_prime_left
else
constructorof(N)(; op=ctx.mult_idx, l=f_prime_left, r=g_prime)
end
end
end

second_term = if ctx.evaluates_to_zero[f_prime_right_op]
second_term = if f_prime_right_evaluates_to == Zero
# Simplify and just give zero
constructorof(N)(; val=zero(T))
else
h_prime = _symbolic_derivative(tree.r, ctx)
if ctx.evaluates_to_one[f_prime_right_op]
if f_prime_right_evaluates_to == One ||
(h_prime.degree == 0 && h_prime.constant && iszero(h_prime.val))
# Simplify and just give h_prime
h_prime
else
f_prime_right = constructorof(N)(; op=f_prime_right_op, l=tree.l, r=tree.r)
constructorof(N)(; op=ctx.mult_idx, l=f_prime_right, r=h_prime)
f_prime_right = if f_prime_right_evaluates_to == NegOne
constructorof(N)(; val=-one(T))
else
constructorof(N)(; op=f_prime_right_op, l=tree.l, r=tree.r)
end
if h_prime.degree == 0 && h_prime.constant && isone(h_prime.val)
f_prime_right
else
constructorof(N)(; op=ctx.mult_idx, l=f_prime_right, r=h_prime)
end
end
end
return constructorof(N)(; op=ctx.plus_idx, l=first_term, r=second_term)

# Simplify if either term is zero
if first_term.degree == 0 && first_term.constant && iszero(first_term.val)
return second_term
elseif second_term.degree == 0 && second_term.constant && iszero(second_term.val)
return first_term
else
return constructorof(N)(; op=ctx.plus_idx, l=first_term, r=second_term)
end
end
end

Expand Down Expand Up @@ -437,6 +475,11 @@ operator_derivative(::DivMonomial{C,XP,YNP}, ::Val{2}, ::Val{2}) where {C,XP,YNP
DivMonomial{-C * YNP,XP,YNP + 1}()
#! format: on

DE.get_op_name(::typeof(first tuple)) = "first"
DE.get_op_name(::typeof(last tuple)) = "last"
DE.get_op_name(::typeof((-) sin)) = "-sin"
DE.get_op_name(::typeof((-) cos)) = "-cos"

function _expand_operators(operators::OperatorEnum)
unaops = operators.unaops
binops = operators.binops
Expand Down
32 changes: 28 additions & 4 deletions test/test_composable_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ end

@testitem "Test symbolic derivatives" tags = [:part2] begin
using SymbolicRegression: ComposableExpression, Node, D
using DynamicExpressions: OperatorEnum
using DynamicExpressions: OperatorEnum, @declare_expression_operator, AbstractExpression
using Zygote: gradient

# Basic setup
Expand Down Expand Up @@ -351,13 +351,37 @@ end
@test D(D(sin(x1) * cos(x2), 1), 2)([1.0], [2.0]) [cos(1.0) * -sin(2.0)]

# Printing should also be nice:
@test repr(D(x1 * x2, 1)) == "(∂₁*(x1, x2) * 1.0) + (∂₂*(x1, x2) * 0.0)"
@test repr(D(x1 * x2, 1)) == "last(x1, x2)"

# We also have special behavior when there is no dependence:
@test repr(D(sin(x2), 1)) == "0.0"
@test repr(D(x2 + sin(x2), 1)) == "0.0"
@test repr(D(x2 + sin(x2) - x1, 1)) ==
"(∂₁-(x2 + sin(x2), x1) * 0.0) + (∂₂-(x2 + sin(x2), x1) * 1.0)"
@test repr(D(x2 + sin(x2) - x1, 1)) == "-1.0"

# But still nice printing for things like -sin:
@test repr(D(D(sin(x1), 1), 1)) == "-sin(x1)"

# Without generating weird additional operators:
@test repr(D(D(D(sin(x1), 1), 1), 1)) == "-cos(x1)"

# Custom functions have nice printing:
my_op(x) = sin(x)
@declare_expression_operator(my_op, 1)
my_bin_op(x, y) = x + y
@declare_expression_operator(my_bin_op, 2)
operators = OperatorEnum(;
binary_operators=(+, -, *, /, my_bin_op), unary_operators=(my_op,)
)

x = ComposableExpression(Node(Float64; feature=1); operators, variable_names)
y = ComposableExpression(Node(Float64; feature=2); operators, variable_names)

@test repr(D(my_op(x), 1)) == "∂my_op(x1)"
@test repr(D(D(my_op(x), 1), 1)) == "∂∂my_op(x1)"

@test repr(D(my_bin_op(x, y), 1)) == "∂₁my_bin_op(x1, x2)"
@test repr(D(my_bin_op(x, y), 2)) == "∂₂my_bin_op(x1, x2)"
@test repr(D(my_bin_op(x, x - y), 2)) == "∂₂my_bin_op(x1, x1 - x2) * -1.0"
end

@testitem "Test template structure with derivatives" tags = [:part2] begin
Expand Down

0 comments on commit 72e21fe

Please sign in to comment.