Skip to content

Commit

Permalink
test: differential operator in template expression
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 12, 2024
1 parent edcb9ed commit bfa0a72
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
4 changes: 1 addition & 3 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,6 @@ using DynamicExpressions:
with_contents,
with_metadata
using DynamicExpressions: with_type_parameters
# TODO: Reexport D once DynamicAutodiff is registered
# @reexport using DynamicAutodiff: D
@reexport using LossFunctions:
MarginLoss,
DistanceLoss,
Expand Down Expand Up @@ -160,7 +158,7 @@ using DynamicExpressions: with_type_parameters
LogitDistLoss,
QuantileLoss,
LogCoshLoss
using DynamicAutodiff: D
using DynamicDiff: D
using Compat: @compat, Fix

#! format: off
Expand Down
4 changes: 2 additions & 2 deletions src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module TemplateExpressionModule

using Random: AbstractRNG
using Compat: Fix
using DynamicAutodiff: DynamicAutodiff
using DynamicDiff: DynamicDiff
using DispatchDoctor: @unstable, @stable
using StyledStrings: @styled_str, annotatedstring
using DynamicExpressions:
Expand Down Expand Up @@ -94,7 +94,7 @@ end

# We pass through the derivative operators, since
# we just want to record the number of arguments.
DynamicAutodiff.D(f::ArgumentRecorder, ::Integer) = f
DynamicDiff.D(f::ArgumentRecorder, ::Integer) = f

"""Infers number of features used by each subexpression, by passing in test data."""
function infer_variable_constraints(::Val{K}, combiner::F) where {K,F}
Expand Down
22 changes: 22 additions & 0 deletions test/test_composable_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,25 @@ end
expr = TemplateExpression((; f=c1, g=x1 + x2); structure, operators, variable_names)
@test expr(X) [6.0] # 3 + (1 + 2)
end

@testitem "Test TemplateExpression with differential operator" tags = [:part3] begin
using SymbolicRegression
using SymbolicRegression: D
using DynamicExpressions: OperatorEnum

operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(sin, cos))
variable_names = ["x1", "x2"]
x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names)
x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names)
x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names)

structure = TemplateStructure{(:f, :g)}(
((; f, g), (x1, x2, x3)) -> f(x1) + D(g, 1)(x2, x3)
)
expr = TemplateExpression(
(; f=x1, g=cos(x1 - x2) + 2.5 * x1); structure, operators, variable_names
)
# Truth: x1 - sin(x2 - x3) + 2.5
X = stack(([1.0, 2.0], [3.0, 4.0], [5.0, 6.0]); dims=1)
@test expr(X) [1.0, 2.0] .- sin.([3.0, 4.0] .- [5.0, 6.0]) .+ 2.5
end

0 comments on commit bfa0a72

Please sign in to comment.