diff --git a/examples/template_expression.jl b/examples/template_expression.jl index 5b172922..6bf6d135 100644 --- a/examples/template_expression.jl +++ b/examples/template_expression.jl @@ -1,6 +1,6 @@ using SymbolicRegression using Random: rand -using MLJBase: machine, fit!, report +using MLJBase: machine, fit!, report, predict using Test: @test options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) @@ -59,3 +59,13 @@ best_g2 = get_contents(best_expr).g2 @test best_f(x1, x2) ≈ @. sin.(x1) @test best_g1(x3) ≈ (@. x3 * x3) @test best_g2(x3) ≈ (@. x3) + +# Test prediction +x_test = rand(10, 3) +y_test = [ + (sin(x_test[i, 1]) + x_test[i, 3]^2, sin(x_test[i, 1]) + x_test[i, 3]) for + i in eachindex(axes(x_test, 1)) +] +predictions = predict(mach, (data=x_test, idx=length(r.equations))) +@test map(first, predictions) ≈ map(first, y_test) +@test map(last, predictions) ≈ map(last, y_test) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 45c8750f..b4ec83e1 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -493,6 +493,10 @@ function contains_features_greater_than(tree::AbstractExpressionNode, max_featur end end +function Base.isempty(ex::TemplateExpression) + return all(isempty, values(get_contents(ex))) +end + # TODO: Add custom behavior to adjust what feature nodes can be generated end