Skip to content

Commit

Permalink
feat: conditionally widen MLJ scitype
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 21, 2024
1 parent 4336690 commit f4a3bb8
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ using ..ComplexityModule: compute_complexity
using ..HallOfFameModule: HallOfFame, format_hall_of_fame
using ..UtilsModule: subscriptify, @ignore
using ..LoggingModule: AbstractSRLogger
using ..TemplateExpressionModule: TemplateExpression

import ..equation_search

Expand All @@ -50,8 +51,9 @@ end

"""Generate an `SRRegressor` struct containing all the fields in `Options`."""
function modelexpr(model_name::Symbol)
struct_def = :(Base.@kwdef mutable struct $(model_name){D<:AbstractDimensions,L} <:
AbstractSRRegressor
struct_def = :(Base.@kwdef mutable struct $(model_name){
D<:AbstractDimensions,L,E<:AbstractExpression
} <: AbstractSRRegressor
niterations::Int = 100
parallelism::Symbol = :multithreading
numprocs::Union{Int,Nothing} = nothing
Expand All @@ -62,7 +64,7 @@ function modelexpr(model_name::Symbol)
logger::Union{AbstractSRLogger,Nothing} = nothing
runtests::Bool = true
run_id::Union{String,Nothing} = nothing
loss_type::L = Nothing
loss_type::Type{L} = Nothing
selection_method::Function = choose_best
dimensions_type::Type{D} = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}
end)
Expand All @@ -71,7 +73,14 @@ function modelexpr(model_name::Symbol)

# Add everything from `Options` constructor directly to struct:
for (i, option) in enumerate(DEFAULT_OPTIONS)
if getsymb(first(option.args)) == :expression_type
continue
end
insert!(fields, i, Expr(:(=), option.args...))
if getsymb(first(option.args)) == :node_type
# Manually add `expression_type` above, so it can be depended on by `node_type`
insert!(fields, i - 1, :(expression_type::Type{E} = Expression))
end
end

# We also need to create the `get_options` function, based on this:
Expand Down Expand Up @@ -597,7 +606,7 @@ const input_scitype = Union{
MMI.metadata_model(
SRRegressor;
input_scitype,
target_scitype=AbstractVector{<:Any},
target_scitype=AbstractVector{<:MMI.Continuous},
supports_weights=true,
reports_feature_importances=false,
load_path="SymbolicRegression.MLJInterfaceModule.SRRegressor",
Expand All @@ -606,13 +615,22 @@ MMI.metadata_model(
MMI.metadata_model(
MultitargetSRRegressor;
input_scitype,
target_scitype=Union{MMI.Table(Any),AbstractMatrix{<:Any}},
target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}},
supports_weights=true,
reports_feature_importances=false,
load_path="SymbolicRegression.MLJInterfaceModule.MultitargetSRRegressor",
human_name="Multi-Target Symbolic Regression via Evolutionary Search",
)

function MMI.target_scitype(::Type{<:SRRegressor{<:Any,<:Any,<:TemplateExpression}})
return AbstractVector{<:Any}
end
function MMI.target_scitype(
::Type{<:MultitargetSRRegressor{<:Any,<:Any,<:TemplateExpression}}
)
return Union{MMI.Table(Any),AbstractMatrix{<:Any}}
end

function tag_with_docstring(model_name::Symbol, description::String, bottom_matter::String)
docstring = """$(MMI.doc_header(eval(model_name)))
Expand Down

0 comments on commit f4a3bb8

Please sign in to comment.