Merge pull request #289 from MilesCranmer/fix-selection-method

Make it easier to select expression from Pareto front for evaluation

MilesCranmer authored Feb 17, 2024
2 parents feab045 + b5d56e0 commit ccea30f
Showing 3 changed files with 111 additions and 66 deletions.
16 changes: 10 additions & 6 deletions README.md
@@ -153,16 +153,20 @@ predict(mach, X)
```

This will make predictions using the expression
selected using the function passed to `selection_method`.
By default this selection is made a mix of accuracy and complexity.
For example, we can make predictions using expression 2 with:
selected by `model.selection_method`,
which by default is a mix of accuracy and complexity.

You can override this selection and select an equation from
the Pareto front manually with:

```julia
mach.model.selection_method = Returns(2)
predict(mach, X)
predict(mach, (data=X, idx=2))
```

For fitting multiple outputs, one can use `MultitargetSRRegressor`.
where here we choose to evaluate the second equation.

For fitting multiple outputs, one can use `MultitargetSRRegressor`
(and pass an array of indices to `idx` in `predict` for selecting specific equations).
For a full list of options available to each regressor, see the [API page](https://astroautomata.com/SymbolicRegression.jl/dev/api/).
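
To illustrate the README change above, here is a hedged usage sketch of the new `(data, idx)` form for the multi-target case (a minimal example under assumed data and settings; the variable names and `niterations` value are illustrative, not taken from the diff):

```julia
using MLJ, SymbolicRegression

# Illustrative data with three targets.
X = (a=rand(100), b=rand(100))
y = (y1=2 .* X.a, y2=X.a .+ X.b, y3=X.b .^ 2)

mach = machine(MultitargetSRRegressor(; niterations=20), X, y)
fit!(mach)

# Default: evaluate the expressions chosen by `selection_method`.
yhat = predict(mach, X)

# Override: pass one equation index per target (index 1 always exists,
# being the simplest expression on each Pareto front).
yhat_manual = predict(mach, (data=X, idx=[1, 1, 1]))
```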

### Low-Level Interface
130 changes: 70 additions & 60 deletions src/MLJInterface.jl
@@ -272,22 +272,12 @@ function prediction_warn()
@warn "Evaluation failed either due to NaNs detected or due to unfinished search. Using 0s for prediction."
end

@inline function wrap_units(v, y_units, i::Integer)
if y_units === nothing
return v
else
return (yi -> Quantity(yi, y_units[i])).(v)
end
end
@inline function wrap_units(v, y_units, ::Nothing)
if y_units === nothing
return v
else
return (yi -> Quantity(yi, y_units)).(v)
end
end
wrap_units(v, ::Nothing, ::Integer) = v
wrap_units(v, ::Nothing, ::Nothing) = v
wrap_units(v, y_units, i::Integer) = (yi -> Quantity(yi, y_units[i])).(v)
wrap_units(v, y_units, ::Nothing) = (yi -> Quantity(yi, y_units)).(v)

function prediction_fallback(::Type{T}, m::SRRegressor, Xnew_t, fitresult) where {T}
function prediction_fallback(::Type{T}, ::SRRegressor, Xnew_t, fitresult, _) where {T}
prediction_warn()
out = fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T))
return wrap_units(out, fitresult.y_units, nothing)
@@ -301,11 +291,11 @@ function prediction_fallback(
fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T)), fitresult.y_units, i
) for i in 1:(fitresult.num_targets)
]
out_matrix = reduce(hcat, out_cols)
out_matrix = hcat(out_cols...)
if !fitresult.y_is_table
return out_matrix
else
return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype=prototype)
return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype)
end
end
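
The `wrap_units` change above replaces runtime `nothing` checks with method dispatch, so the unit-free path becomes a plain pass-through. A minimal sketch of that pattern with hypothetical names (not the library's API):

```julia
# Hypothetical stand-ins for Quantity/units, purely to show the dispatch idiom.
struct FakeUnits
    scale::Float64
end

attach_units(v, ::Nothing) = v                           # no units: return as-is
attach_units(v, u::FakeUnits) = (x -> x * u.scale).(v)   # units: wrap each element

attach_units([1.0, 2.0], nothing)          # -> [1.0, 2.0]
attach_units([1.0, 2.0], FakeUnits(10.0))  # -> [10.0, 20.0]
```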

@@ -342,50 +332,58 @@ function MMI.fitted_params(m::AbstractSRRegressor, fitresult)
)
end

function MMI.predict(m::SRRegressor, fitresult, Xnew)
params = full_report(m, fitresult; v_with_strings=Val(false))
Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type)
T = promote_type(eltype(Xnew_t), fitresult.types.T)
if length(params.equations) == 0
return prediction_fallback(T, m, Xnew_t, fitresult)
end
X_units_clean = clean_units(X_units)
validate_variable_names(variable_names, fitresult)
validate_units(X_units_clean, fitresult.X_units)
eq = params.equations[params.best_idx]
out, completed = eval_tree_array(eq, Xnew_t, fitresult.options)
if !completed
return prediction_fallback(T, m, Xnew_t, fitresult)
function eval_tree_mlj(
tree::Node, X_t, m::AbstractSRRegressor, ::Type{T}, fitresult, i, prototype
) where {T}
out, completed = eval_tree_array(tree, X_t, fitresult.options)
if completed
return wrap_units(out, fitresult.y_units, i)
else
return wrap_units(out, fitresult.y_units, nothing)
return prediction_fallback(T, m, X_t, fitresult, prototype)
end
end
function MMI.predict(m::MultitargetSRRegressor, fitresult, Xnew)

function MMI.predict(m::M, fitresult, Xnew; idx=nothing) where {M<:AbstractSRRegressor}
if Xnew isa NamedTuple && (haskey(Xnew, :idx) || haskey(Xnew, :data))
@assert(
haskey(Xnew, :idx) && haskey(Xnew, :data) && length(keys(Xnew)) == 2,
"If specifying an equation index during prediction, you must use a named tuple with keys `idx` and `data`."
)
return MMI.predict(m, fitresult, Xnew.data; idx=Xnew.idx)
end

params = full_report(m, fitresult; v_with_strings=Val(false))
prototype = MMI.istable(Xnew) ? Xnew : nothing
Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type)
T = promote_type(eltype(Xnew_t), fitresult.types.T)

if isempty(params.equations) || any(isempty, params.equations)
@warn "Equations not found. Returning 0s for prediction."
return prediction_fallback(T, m, Xnew_t, fitresult, prototype)
end

X_units_clean = clean_units(X_units)
validate_variable_names(variable_names, fitresult)
validate_units(X_units_clean, fitresult.X_units)
equations = params.equations
if any(t -> length(t) == 0, equations)
return prediction_fallback(T, m, Xnew_t, fitresult, prototype)
end
best_idx = params.best_idx
outs = []
for (i, (best_i, eq)) in enumerate(zip(best_idx, equations))
out, completed = eval_tree_array(eq[best_i], Xnew_t, fitresult.options)
if !completed
return prediction_fallback(T, m, Xnew_t, fitresult, prototype)

idx = idx === nothing ? params.best_idx : idx

if M <: SRRegressor
return eval_tree_mlj(
params.equations[idx], Xnew_t, m, T, fitresult, nothing, prototype
)
elseif M <: MultitargetSRRegressor
outs = [
eval_tree_mlj(
params.equations[i][idx[i]], Xnew_t, m, T, fitresult, i, prototype
) for i in eachindex(idx, params.equations)
]
out_matrix = reduce(hcat, outs)
if !fitresult.y_is_table
return out_matrix
else
return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype)
end
push!(outs, wrap_units(out, fitresult.y_units, i))
end
out_matrix = reduce(hcat, outs)
if !fitresult.y_is_table
return out_matrix
else
return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype=prototype)
end
end

@@ -504,11 +502,14 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt
Note that if you pass complex data `::Complex{L}`, then the loss
type will automatically be set to `L`.
- `selection_method::Function`: Function to select an expression from
the Pareto frontier for use in `predict`. See `SymbolicRegression.MLJInterfaceModule.choose_best`
for an example. This function should return a single integer specifying
the index of the expression to use. By default, `choose_best` maximizes
the Pareto frontier for use in `predict`.
See `SymbolicRegression.MLJInterfaceModule.choose_best` for an example.
This function should return a single integer specifying
the index of the expression to use. By default, this maximizes
the score (a pound-for-pound rating) of expressions reaching the threshold
of 1.5x the minimum loss. To fix the index at `5`, you could just write `Returns(5)`.
of 1.5x the minimum loss. To override this at prediction time, you can pass
a named tuple with keys `data` and `idx` to `predict`. See the Operations
section for details.
- `dimensions_type::AbstractDimensions`: The type of dimensions to use when storing
the units of the data. By default this is `DynamicQuantities.SymbolicDimensions`.
"""
@@ -519,6 +520,9 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt
- `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which
should have same scitype as `X` above. The expression used for prediction is defined
by the `selection_method` function, which can be seen by viewing `report(mach).best_idx`.
- `predict(mach, (data=Xnew, idx=i))`: Return predictions of the target given features
`Xnew`, which should have same scitype as `X` above. By passing a named tuple with keys
`data` and `idx`, you are able to specify the equation you wish to evaluate in `idx`.
$(bottom_matter)
"""
@@ -579,7 +583,8 @@ eval(
Note that unlike other regressors, symbolic regression stores a list of
trained models. The model chosen from this list is defined by the function
`selection_method` keyword argument, which by default balances accuracy
and complexity.
and complexity. You can override this at prediction time by passing a named
tuple with keys `data` and `idx`.
""",
r"^ " => "",
@@ -591,7 +596,8 @@ eval(
The fields of `fitted_params(mach)` are:
- `best_idx::Int`: The index of the best expression in the Pareto frontier,
as determined by the `selection_method` function.
as determined by the `selection_method` function. Override in `predict` by passing
a named tuple with keys `data` and `idx`.
- `equations::Vector{Node{T}}`: The expressions discovered by the search, represented
in a dominating Pareto frontier (i.e., the best expressions found for
each complexity). `T` is equal to the element type
@@ -604,7 +610,8 @@ eval(
The fields of `report(mach)` are:
- `best_idx::Int`: The index of the best expression in the Pareto frontier,
as determined by the `selection_method` function.
as determined by the `selection_method` function. Override in `predict` by passing
a named tuple with keys `data` and `idx`.
- `equations::Vector{Node{T}}`: The expressions discovered by the search, represented
in a dominating Pareto frontier (i.e., the best expressions found for
each complexity).
@@ -701,7 +708,8 @@ eval(
Note that unlike other regressors, symbolic regression stores a list of lists of
trained models. The models chosen from each of these lists is defined by the function
`selection_method` keyword argument, which by default balances accuracy
and complexity.
and complexity. You can override this at prediction time by passing a named
tuple with keys `data` and `idx`.
""",
r"^ " => "",
@@ -713,7 +721,8 @@ eval(
The fields of `fitted_params(mach)` are:
- `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
as determined by the `selection_method` function.
as determined by the `selection_method` function. Override in `predict` by passing
a named tuple with keys `data` and `idx`.
- `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented
in a dominating Pareto frontier (i.e., the best expressions found for
each complexity). The outer vector is indexed by target variable, and the inner
@@ -727,7 +736,8 @@ eval(
The fields of `report(mach)` are:
- `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier,
as determined by the `selection_method` function.
as determined by the `selection_method` function. Override in `predict` by passing
a named tuple with keys `data` and `idx`.
- `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented
in a dominating Pareto frontier (i.e., the best expressions found for
each complexity). The outer vector is indexed by target variable, and the inner
31 changes: 31 additions & 0 deletions test/test_mlj.jl
@@ -43,8 +43,17 @@ end
fit!(mach)
rep = report(mach)
@test occursin("a", rep.equation_strings[rep.best_idx])
ypred_good = predict(mach, X)
@test sum(abs2, predict(mach, X) .- y) / length(y) < 1e-5

@testset "Check that we can choose the equation" begin
ypred_same = predict(mach, (data=X, idx=rep.best_idx))
@test ypred_good == ypred_same

ypred_bad = predict(mach, (data=X, idx=1))
@test ypred_good != ypred_bad
end

@testset "Smoke test SymbolicUtils" begin
eqn = node_to_symbolic(rep.equations[rep.best_idx], model)
n = symbolic_to_node(eqn, model)
@@ -63,6 +72,28 @@ end
@test all(
eq -> occursin("a", eq), [rep.equation_strings[i][rep.best_idx[i]] for i in 1:3]
)
ypred_good = predict(mach, X)

@testset "Test that we can choose the equation" begin
ypred_same = predict(mach, (data=X, idx=rep.best_idx))
@test ypred_good == ypred_same

ypred_bad = predict(mach, (data=X, idx=[1, 1, 1]))
@test ypred_good != ypred_bad

ypred_mixed = predict(mach, (data=X, idx=[rep.best_idx[1], 1, rep.best_idx[3]]))
@test ypred_mixed == hcat(ypred_good[:, 1], ypred_bad[:, 2], ypred_good[:, 3])

@test_throws AssertionError predict(mach, (data=X,))
VERSION >= v"1.8" &&
@test_throws "If specifying an equation index during" predict(
mach, (data=X,)
)
VERSION >= v"1.8" &&
@test_throws "If specifying an equation index during" predict(
mach, (X=X, idx=1)
)
end
end

@testset "Named outputs" begin