Skip to content

Commit

Permalink
Merge pull request #374 from MilesCranmer/fix-predict
Browse files Browse the repository at this point in the history
fix: `predict` for TemplateExpressions
  • Loading branch information
MilesCranmer authored Nov 28, 2024
2 parents 6d2a72d + 808bd10 commit 3f4b201
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
12 changes: 11 additions & 1 deletion examples/template_expression.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using SymbolicRegression
using Random: rand
using MLJBase: machine, fit!, report
using MLJBase: machine, fit!, report, predict
using Test: @test

options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
Expand Down Expand Up @@ -59,3 +59,13 @@ best_g2 = get_contents(best_expr).g2
@test best_f(x1, x2) @. sin.(x1)
@test best_g1(x3) (@. x3 * x3)
@test best_g2(x3) (@. x3)

# Test prediction
x_test = rand(10, 3)
y_test = [
(sin(x_test[i, 1]) + x_test[i, 3]^2, sin(x_test[i, 1]) + x_test[i, 3]) for
i in eachindex(axes(x_test, 1))
]
predictions = predict(mach, (data=x_test, idx=length(r.equations)))
@test map(first, predictions) map(first, y_test)
@test map(last, predictions) map(last, y_test)
4 changes: 4 additions & 0 deletions src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,10 @@ function contains_features_greater_than(tree::AbstractExpressionNode, max_featur
end
end

function Base.isempty(ex::TemplateExpression)
return all(isempty, values(get_contents(ex)))
end

# TODO: Add custom behavior to adjust what feature nodes can be generated

end

0 comments on commit 3f4b201

Please sign in to comment.