Skip to content

Commit

Permalink
feat: allow argument-less TemplateExpression parts
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Nov 27, 2024
1 parent dedb41a commit f4c0d7c
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
5 changes: 5 additions & 0 deletions src/ComposableExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ function (ex::AbstractComposableExpression)(
return ValidVector(eval_tree_array(ex, X))
end
end
function (ex::AbstractComposableExpression{T})() where {T}
X = Matrix{T}(undef, 0, 1) # Value is irrelevant as it won't be used
out, _ = eval_tree_array(ex, X) # TODO: The valid is not used; not sure how to incorporate
return only(out)::T
end
function (ex::AbstractComposableExpression)(
x::AbstractComposableExpression, _xs::Vararg{AbstractComposableExpression,N}
) where {N}
Expand Down
2 changes: 1 addition & 1 deletion src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function _record_composable_expression!(variable_constraints, ::Val{k}, args...)
elseif vc != length(args)
throw(ArgumentError("Inconsistent number of arguments passed to $k"))
end
return first(args)
return isempty(args) ? 0.0 : first(args)
end

"""Infers number of features used by each subexpression, by passing in test data."""
Expand Down
24 changes: 24 additions & 0 deletions test/test_composable_expression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,27 @@ end
((; f, g), (x1, x2)) -> f(x1)
)
end

@testitem "Test argument-less template structure" tags = [:part2] begin
using SymbolicRegression
using DynamicExpressions: OperatorEnum

operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
variable_names = ["x1", "x2"]
x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names)
x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names)
c1 = ComposableExpression(Node{Float64}(; val=3.0); operators, variable_names)

# We can evaluate an expression with no arguments:
@test c1() == 3.0
@test typeof(c1()) === Float64

# Create a structure where f takes no arguments and g takes two
structure = TemplateStructure{(:f, :g)}(((; f, g), (x1, x2)) -> f() + g(x1, x2))

@test structure.num_features == (; f=0, g=2)

X = [1.0 2.0]'
expr = TemplateExpression((; f=c1, g=x1 + x2); structure, operators, variable_names)
@test expr(X) [6.0] # 3 + (1 + 2)
end

0 comments on commit f4c0d7c

Please sign in to comment.