[BUG]: when using loss_function, Template Expression cannot have operators outside of binary_operators #380

Open
Moelf opened this issue Dec 8, 2024 · 5 comments
Labels
bug Something isn't working

Moelf commented Dec 8, 2024

What happened?

Consider this:

using SymbolicRegression, MLJ

atlas_template = TemplateStructure{(:b, :e)}(
    ((; b, e), (x,)) -> b(x)^e(x)
)

function test_loss(tree, dataset::Dataset{T,L}, options)::L where {T,L}
    prediction, flag = eval_tree_array(tree, dataset.X, options)
    if !flag
        return L(Inf)
    end
    return sum((prediction .- dataset.y) .^ 2)
end

sr_model2 = SRRegressor(
    niterations=150,
    binary_operators=[+, -, *],
    expression_type=TemplateExpression,
    expression_options=(; structure=atlas_template),
    loss_function=test_loss
)

# X and ys are the training data (not shown in the report)
sr_mach2 = machine(sr_model2, X, ys)

fit!(sr_mach2)

Right now, fitting fails with the error shown under "Relevant log output" below.

Version

v1.2.0

Operating System

Linux

Interface

Jupyter Notebook

Relevant log output

[ Info: Training machine(SRRegressor(defaults = nothing, …), …).
┌ Error: Problem fitting the machine machine(SRRegressor(defaults = nothing, …), …). 
└ @ MLJBase ~/.julia/packages/MLJBase/7nGJF/src/machines.jl:694
[ Info: Running type checks... 
[ Info: Type checks okay. 

Operator ^ not found in operators for expression type ComposableExpression{Float64, DynamicExpressions.NodeModule.Node{Float64}, @NamedTuple{operators::DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{}}, variable_names::Nothing}} with binary operators (+, -, *)

Stacktrace:
  [1] apply_operator(op::typeof(^), l::ComposableExpression{Float64, DynamicExpressions.NodeModule.Node{Float64}, @NamedTuple{operators::DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{}}, variable_names::Nothing}}, r::ComposableExpression{Float64, DynamicExpressions.NodeModule.Node{Float64}, @NamedTuple{operators::DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{}}, variable_names::Nothing}})
    @ DynamicExpressions.ExpressionAlgebraModule ~/.julia/packages/DynamicExpressions/fxs7F/src/ExpressionAlgebra.jl:85
  [2] ^(l::ComposableExpression{Float64, DynamicExpressions.NodeModule.Node{Float64}, @NamedTuple{operators::DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{}}, variable_names::Nothing}}, r::ComposableExpression{Float64, DynamicExpressions.NodeModule.Node{Float64}, @NamedTuple{operators::DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{}}, variable_names::Nothing}})
    @ DynamicExpressions.ExpressionAlgebraModule ~/.julia/packages/DynamicExpressions/fxs7F/src/ExpressionAlgebra.jl:119
  [3] (::var"#15#16")(::@NamedTuple{b::ComposableExpression{Float64, DynamicExpressions.NodeModule.Node{Float64}, @NamedTuple{operators::DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{}}, variable_names::Nothing}}, e::ComposableExpression{Float64, DynamicExpressions.NodeModule.Node{Float64}, @NamedTuple{operators::DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{}}, variable_names::Nothing}}}, ::Vector{ComposableExpression{Float32, DynamicExpressions.NodeModule.Node{Float32}, @NamedTuple{operators::DynamicExpressions.OperatorEnumModule.OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*)}, Tuple{}}, variable_names::Nothing}}})
    @ Main ./In[28]:2

Extra Info

Removing the custom loss_function fixes this.

Moelf added the bug label Dec 8, 2024
Moelf changed the title from "[BUG]: loss_function cannot have operators outside of binary_operators" to "[BUG]: when using loss_function, Template Expression cannot have operators outside of binary_operators" Dec 8, 2024
MilesCranmer (Owner) commented:

Sorry, loss_function and TemplateExpression aren't compatible at the moment. This should raise a better error though.

MilesCranmer (Owner) commented:

Actually, I guess they are compatible. But it requires your template structure to use only operators that are in the operator set; otherwise get_tree won't be able to assemble the TemplateExpression into a single binary tree.
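To make this mechanism concrete, here is a toy, self-contained Julia sketch. It is not SymbolicRegression.jl or DynamicExpressions.jl code; `ToyNode`, `Leaf`, `BinNode`, `combine_nodes` and `toy_evaluate` are hypothetical names for illustration. Flattening a template into one tree means turning the combiner (here `b(x)^e(x)`) into a new tree node, and that node can only store an operator drawn from the operator set, which mirrors the "Operator ^ not found" error in the log.

```julia
# Toy expression tree with a fixed set of allowed binary operators.
# Hypothetical types/functions for illustration only; not library internals.

abstract type ToyNode end

struct Leaf <: ToyNode
    name::Symbol          # variable name, e.g. :x1
end

struct BinNode <: ToyNode
    op::Function          # must be one of the allowed binary operators
    l::ToyNode
    r::ToyNode
end

# Join two subtrees under a new node. This only succeeds when `op` is in
# the operator set, because the new node has to store `op` itself.
function combine_nodes(op::Function, l::ToyNode, r::ToyNode, binops::Tuple)
    op in binops || error("Operator $op not found in binary operators $binops")
    return BinNode(op, l, r)
end

# Recursively evaluate a toy tree for one set of variable values.
toy_evaluate(n::Leaf, x::Dict{Symbol,Float64}) = x[n.name]
toy_evaluate(n::BinNode, x::Dict{Symbol,Float64}) =
    n.op(toy_evaluate(n.l, x), toy_evaluate(n.r, x))

binops = (+, -, *)
b = Leaf(:x1)  # stand-in for the subexpression b(x)
e = Leaf(:x2)  # stand-in for the subexpression e(x)

ok = combine_nodes(*, b, e, binops)  # works: * is in the operator set
println(toy_evaluate(ok, Dict(:x1 => 2.0, :x2 => 3.0)))  # prints 6.0

try
    combine_nodes(^, b, e, binops)   # fails: ^ is not in the operator set
catch err
    println(sprint(showerror, err))
end
```

In the toy, adding `^` to `binops` makes the second call succeed. The sketch does not establish how the real library resolves `^` internally.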

But a better solution would simply be to allow for loss_function to take the TemplateExpression directly.

MilesCranmer (Owner) commented:

In principle it's easy to implement this. I just have no idea what to do for the API. Should it be a different parameter, like loss_function_expression? Or should there be some flag you set that says the full expression should be passed, rather than the tree? Or something else? Feels like there's not really a clean way to add this at the moment.

Moelf (Author) commented Dec 9, 2024

I guess I don't understand why the loss function needs to be expressed in terms of SR expressions. If the user doesn't need to take derivatives, can we not treat the loss function as a black box and just evaluate it?

MilesCranmer (Owner) commented:

Yeah it can definitely be a black box. The error you saw is just when it was trying to extract a single AbstractExpressionNode from the template expression, to pass to the loss.

The difference is whether the user needs to define it in terms of AbstractExpressionNode input (current API), or AbstractExpression (proposed API). The latter would let you skip this issue. I just don’t know the right API for declaring one way or the other.
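For contrast, here is a toy Julia sketch of the proposed direction: a loss that receives the whole template-like object and simply evaluates it as a black box. `ToyTemplate`, `toy_predict` and `expression_loss` are hypothetical stand-ins, not SymbolicRegression.jl types or functions, and the real `TemplateExpression` evaluation works differently. The point is that evaluating `b(x)^e(x)` on plain numbers never consults an operator set, whereas the current API first has to extract a single node from the template.

```julia
# Toy "template": named subexpressions plus a combiner in plain Julia.
# Hypothetical illustration only; not the SymbolicRegression.jl API.
struct ToyTemplate{S<:NamedTuple,C<:Function}
    subexprs::S   # each subexpression is a function of one input row
    combiner::C   # combines the subexpressions for one input row
end

# Evaluate the template row by row by calling the combiner directly.
# Plain Float64 values support ^, so no operator-set lookup happens here.
toy_predict(t::ToyTemplate, X::AbstractMatrix) =
    [t.combiner(t.subexprs, X[i, :]) for i in axes(X, 1)]

# Black-box, expression-level loss: it receives the whole template and only
# evaluates it, so it never needs to flatten it into a single tree.
function expression_loss(t::ToyTemplate, X::AbstractMatrix, y::AbstractVector)
    prediction = toy_predict(t, X)
    all(isfinite, prediction) || return Inf
    return sum((prediction .- y) .^ 2)
end

t = ToyTemplate(
    (; b = (x -> x[1] + 1.0), e = (x -> 2.0 * x[1])),
    (s, x) -> s.b(x)^s.e(x),
)
X = reshape([1.0, 2.0], 2, 1)   # two rows, one feature
y = [4.0, 81.0]                 # (1+1)^2 and (2+1)^4
println(expression_loss(t, X, y))  # prints 0.0
```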
