Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose recorder function for DynamicAutodiff.jl #377

Merged
merged 15 commits into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
Expand Down Expand Up @@ -48,9 +49,10 @@ Dates = "1"
DifferentiationInterface = "0.5, 0.6"
DispatchDoctor = "^0.4.17"
Distributed = "<0.0.1, 1"
DynamicExpressions = "1.5.0"
DynamicExpressions = "1.6.0"
DynamicQuantities = "1"
Enzyme = "0.12, 0.13"
ForwardDiff = "0.10"
JSON3 = "1"
LineSearches = "7"
Logging = "1"
Expand Down
264 changes: 264 additions & 0 deletions src/ComposableExpression.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
module ComposableExpressionModule

using Compat: Fix
using DispatchDoctor: @unstable
using ForwardDiff: ForwardDiff
using DynamicExpressions:
AbstractExpression,
Expression,
AbstractExpressionNode,
AbstractOperatorEnum,
OperatorEnum,
Metadata,
constructorof,
get_metadata,
eval_tree_array,
set_node!,
get_contents,
with_contents,
with_metadata,
DynamicExpressions as DE
using DynamicExpressions.InterfacesModule:
ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments
Expand Down Expand Up @@ -244,4 +248,264 @@ for op in (
end
#! format: on

Base.@enum ConstantDerivative::UInt8 Zero One NegOne Other

"""
D(ex::AbstractComposableExpression, feature::Integer)

Compute the derivative of `ex` with respect to the `feature`-th variable.
Returns a new `ComposableExpression` with an expanded set of operators.
"""
function D(ex::AbstractComposableExpression, feature::Integer)
metadata = DE.get_metadata(ex)
raw_metadata = getfield(metadata, :_data) # TODO: Upstream this so we can load this
operators = DE.get_operators(ex)
mult_idx = findfirst(==(*), operators.binops)::Integer
plus_idx = findfirst(==(+), operators.binops)::Integer
nbin = length(operators.binops)
nuna = length(operators.unaops)
tree = DE.get_contents(ex)
operators_with_derivatives = _expand_operators(operators)
evaluates_to_constant = map(
op -> if op == _zero
Zero
elseif op == _one
One
elseif op == _n_one
NegOne
else
Other
end, operators_with_derivatives.binops
)
ctx = SymbolicDerivativeContext(;
feature, plus_idx, mult_idx, nbin, nuna, evaluates_to_constant
)
d_tree = _symbolic_derivative(tree, ctx)
return with_metadata(
with_contents(ex, d_tree); raw_metadata..., operators=operators_with_derivatives
)
end

Base.@kwdef struct SymbolicDerivativeContext{TUP}
feature::Int
plus_idx::Int
mult_idx::Int
nbin::Int
nuna::Int
evaluates_to_constant::TUP
end

function _symbolic_derivative(
tree::N, ctx::SymbolicDerivativeContext
) where {T,N<:AbstractExpressionNode{T}}
# NOTE: We cannot mutate the tree here! Since we use it twice.

# Quick test to see if we have any dependence on the feature, so
# we can return 0 for the branch
any_dependence = any(tree) do node
node.degree == 0 && !node.constant && node.feature == ctx.feature
end

if !any_dependence
return constructorof(N)(; val=zero(T))
elseif tree.degree == 0 # && any_dependence
return constructorof(N)(; val=one(T))
elseif tree.degree == 1
# f(g(x)) => f'(g(x)) * g'(x)
f_prime_op = tree.op + ctx.nuna

### We do some simplification based on zero/one derivatives ###
g_prime = _symbolic_derivative(tree.l, ctx)
if g_prime.degree == 0 && g_prime.constant && iszero(g_prime.val)
return g_prime
else
f_prime = constructorof(N)(; op=f_prime_op, l=tree.l)

if g_prime.degree == 0 && g_prime.constant && isone(g_prime.val)
return f_prime
else
return constructorof(N)(; op=ctx.mult_idx, l=f_prime, r=g_prime)
end
end
else # tree.degree == 2
# f(g(x), h(x)) => f^(1,0)(g(x), h(x)) * g'(x) + f^(0,1)(g(x), h(x)) * h'(x)
f_prime_left_op = tree.op + ctx.nbin
f_prime_right_op = tree.op + 2 * ctx.nbin
MilesCranmer marked this conversation as resolved.
Show resolved Hide resolved
f_prime_left_evaluates_to = ctx.evaluates_to_constant[f_prime_left_op]
f_prime_right_evaluates_to = ctx.evaluates_to_constant[f_prime_right_op]

### We do some simplification based on zero/one derivatives ###
first_term = if f_prime_left_evaluates_to == Zero

# Simplify and just give zero
constructorof(N)(; val=zero(T))
else
g_prime = _symbolic_derivative(tree.l, ctx)

if f_prime_left_evaluates_to == One ||
(g_prime.degree == 0 && g_prime.constant && iszero(g_prime.val))
# Simplify and just give g_prime
g_prime
else
f_prime_left = if f_prime_left_evaluates_to == NegOne
constructorof(N)(; val=-one(T))
else
constructorof(N)(; op=f_prime_left_op, l=tree.l, r=tree.r)
end

if g_prime.degree == 0 && g_prime.constant && isone(g_prime.val)
f_prime_left
else
constructorof(N)(; op=ctx.mult_idx, l=f_prime_left, r=g_prime)
end
end
end

second_term = if f_prime_right_evaluates_to == Zero
# Simplify and just give zero
constructorof(N)(; val=zero(T))
else
h_prime = _symbolic_derivative(tree.r, ctx)
if f_prime_right_evaluates_to == One ||
(h_prime.degree == 0 && h_prime.constant && iszero(h_prime.val))
# Simplify and just give h_prime
h_prime
else
f_prime_right = if f_prime_right_evaluates_to == NegOne
constructorof(N)(; val=-one(T))
else
constructorof(N)(; op=f_prime_right_op, l=tree.l, r=tree.r)
end
if h_prime.degree == 0 && h_prime.constant && isone(h_prime.val)
f_prime_right
else
constructorof(N)(; op=ctx.mult_idx, l=f_prime_right, r=h_prime)
end
end
end

# Simplify if either term is zero
if first_term.degree == 0 && first_term.constant && iszero(first_term.val)
return second_term
elseif second_term.degree == 0 && second_term.constant && iszero(second_term.val)
return first_term
else
return constructorof(N)(; op=ctx.plus_idx, l=first_term, r=second_term)
end
end
end

struct OperatorDerivative{F,degree,arg} <: Function
op::F
end

function Base.show(io::IO, g::OperatorDerivative{F,degree,arg}) where {F,degree,arg}
print(io, "∂")
if degree == 2
if arg == 1
print(io, "₁")
elseif arg == 2
print(io, "₂")
end
end
print(io, g.op)
return nothing
end
Base.show(io::IO, ::MIME"text/plain", g::OperatorDerivative) = show(io, g)

# Generic derivatives:
function (d::OperatorDerivative{F,1,1})(x) where {F}
return ForwardDiff.derivative(d.op, x)
end
function (d::OperatorDerivative{F,2,1})(x, y) where {F}
return ForwardDiff.derivative(Fix{2}(d.op, y), x)
end
function (d::OperatorDerivative{F,2,2})(x, y) where {F}
return ForwardDiff.derivative(Fix{1}(d.op, x), y)
end
function operator_derivative(op::F, ::Val{degree}, ::Val{arg}) where {F,degree,arg}
return OperatorDerivative{F,degree,arg}(op)
end

#! format: off
# Special Cases
## Unary
operator_derivative(::typeof(sin), ::Val{1}, ::Val{1}) = cos
operator_derivative(::typeof(cos), ::Val{1}, ::Val{1}) = (-) ∘ sin
operator_derivative(::typeof((-) ∘ sin), ::Val{1}, ::Val{1}) = (-) ∘ cos
operator_derivative(::typeof((-) ∘ cos), ::Val{1}, ::Val{1}) = sin
operator_derivative(::typeof(exp), ::Val{1}, ::Val{1}) = exp

## Binary
# TODO: We assume that left/right are symmetric here!
_zero(x, _) = zero(x)
_one(x, _) = one(x)
_n_one(x, _) = -one(x)
operator_derivative(::typeof(_zero), ::Val{2}, ::Val{1}) = _zero
operator_derivative(::typeof(_zero), ::Val{2}, ::Val{2}) = _zero
operator_derivative(::typeof(_one), ::Val{2}, ::Val{1}) = _zero
operator_derivative(::typeof(_one), ::Val{2}, ::Val{2}) = _zero
operator_derivative(::typeof(_n_one), ::Val{2}, ::Val{1}) = _zero
operator_derivative(::typeof(_n_one), ::Val{2}, ::Val{2}) = _zero

### Addition
operator_derivative(::typeof(+), ::Val{2}, ::Val{1}) = _one
operator_derivative(::typeof(+), ::Val{2}, ::Val{2}) = _one
operator_derivative(::typeof(-), ::Val{2}, ::Val{1}) = _one
operator_derivative(::typeof(-), ::Val{2}, ::Val{2}) = _n_one

### Multiplication
operator_derivative(::typeof(*), ::Val{2}, ::Val{1}) = last ∘ tuple
operator_derivative(::typeof(*), ::Val{2}, ::Val{2}) = first ∘ tuple
operator_derivative(::typeof(first ∘ tuple), ::Val{2}, ::Val{1}) = _one
operator_derivative(::typeof(first ∘ tuple), ::Val{2}, ::Val{2}) = _zero
operator_derivative(::typeof(last ∘ tuple), ::Val{2}, ::Val{1}) = _zero
operator_derivative(::typeof(last ∘ tuple), ::Val{2}, ::Val{2}) = _one

### Division
struct DivMonomial{C,XP,YNP} <: Function end
function (m::DivMonomial{C,XP,YNP})(x, y) where {C,XP,YNP}
return C * (XP == 0 ? one(x) : x^XP) / (y^YNP)
end
operator_derivative(::typeof(/), ::Val{2}, ::Val{1}) = DivMonomial{1,0,1}()
operator_derivative(::typeof(/), ::Val{2}, ::Val{2}) = DivMonomial{-1,1,2}()
operator_derivative(::DivMonomial{C,XP,YNP}, ::Val{2}, ::Val{1}) where {C,XP,YNP} =
iszero(XP) ? _zero : DivMonomial{C * XP,XP - 1,YNP}()
operator_derivative(::DivMonomial{C,XP,YNP}, ::Val{2}, ::Val{2}) where {C,XP,YNP} =
DivMonomial{-C * YNP,XP,YNP + 1}()
#! format: on

DE.get_op_name(::typeof(first ∘ tuple)) = "first"
DE.get_op_name(::typeof(last ∘ tuple)) = "last"
DE.get_op_name(::typeof((-) ∘ sin)) = "-sin"
DE.get_op_name(::typeof((-) ∘ cos)) = "-cos"

function DE.get_op_name(::DivMonomial{C,XP,YNP}) where {C,XP,YNP}
return join(("((x, y) -> ", string(C), "x^", string(XP), "/y^", string(YNP), ")"))
end

function _expand_operators(operators::OperatorEnum)
unaops = operators.unaops
binops = operators.binops
new_unaops = ntuple(
i -> if i <= length(unaops)
unaops[i]
else
operator_derivative(unaops[i - length(unaops)], Val(1), Val(1))
end,
Val(2 * length(unaops)),
)
new_binops = ntuple(
i -> if i <= length(binops)
binops[i]
elseif i <= 2 * length(binops)
operator_derivative(binops[i - length(binops)], Val(2), Val(1))
else
operator_derivative(binops[i - 2 * length(binops)], Val(2), Val(2))
end,
Val(3 * length(binops)),
)
return OperatorEnum(new_binops, new_unaops)
end

end
2 changes: 1 addition & 1 deletion src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ using .SearchUtilsModule:
using .LoggingModule: AbstractSRLogger, SRLogger, get_logger
using .TemplateExpressionModule: TemplateExpression, TemplateStructure
using .TemplateExpressionModule: TemplateExpression, TemplateStructure, ValidVector
using .ComposableExpressionModule: ComposableExpression
using .ComposableExpressionModule: ComposableExpression, D
using .ExpressionBuilderModule: embed_metadata, strip_metadata

@stable default_mode = "disable" begin
Expand Down
15 changes: 14 additions & 1 deletion src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ using ..MutateModule: MutateModule as MM
using ..PopMemberModule: PopMember
using ..ComposableExpressionModule: ComposableExpression, ValidVector

import ..ComposableExpressionModule: D

"""
TemplateStructure{K,E,NF} <: Function

Expand Down Expand Up @@ -85,14 +87,25 @@ function _record_composable_expression!(variable_constraints, ::Val{k}, args...)
return isempty(args) ? 0.0 : first(args)
end

struct ArgumentRecorder{F} <: Function
f::F
end
(f::ArgumentRecorder)(args...) = f.f(args...)

# We pass through the derivative operators, since
# we just want to record the number of arguments.
D(f::ArgumentRecorder, _::Integer) = f

"""Infers number of features used by each subexpression, by passing in test data."""
function infer_variable_constraints(::Val{K}, combiner::F) where {K,F}
variable_constraints = NamedTuple{K}(map(_ -> Ref(-1), K))
# Now, we need to evaluate the `combine` function to see how many
# features are used for each function call. If unset, we record it.
# If set, we validate.
inner = Fix{1}(_record_composable_expression!, variable_constraints)
_recorders_of_composable_expressions = NamedTuple{K}(map(k -> Fix{1}(inner, Val(k)), K))
_recorders_of_composable_expressions = NamedTuple{K}(
map(k -> ArgumentRecorder(Fix{1}(inner, Val(k))), K)
)
# We use an evaluation to get the variable constraints
combiner(
_recorders_of_composable_expressions,
Expand Down
Loading
Loading