diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 3a8d0a1d..1cc34fe4 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -32,6 +32,7 @@ using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame using ..UtilsModule: subscriptify, @ignore using ..LoggingModule: AbstractSRLogger +using ..TemplateExpressionModule: TemplateExpression import ..equation_search @@ -50,8 +51,9 @@ end """Generate an `SRRegressor` struct containing all the fields in `Options`.""" function modelexpr(model_name::Symbol) - struct_def = :(Base.@kwdef mutable struct $(model_name){D<:AbstractDimensions,L} <: - AbstractSRRegressor + struct_def = :(Base.@kwdef mutable struct $(model_name){ + D<:AbstractDimensions,L,E<:AbstractExpression + } <: AbstractSRRegressor niterations::Int = 100 parallelism::Symbol = :multithreading numprocs::Union{Int,Nothing} = nothing @@ -62,7 +64,7 @@ function modelexpr(model_name::Symbol) logger::Union{AbstractSRLogger,Nothing} = nothing runtests::Bool = true run_id::Union{String,Nothing} = nothing - loss_type::L = Nothing + loss_type::Type{L} = Nothing selection_method::Function = choose_best dimensions_type::Type{D} = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE} end) @@ -71,7 +73,14 @@ function modelexpr(model_name::Symbol) # Add everything from `Options` constructor directly to struct: for (i, option) in enumerate(DEFAULT_OPTIONS) + if getsymb(first(option.args)) == :expression_type + continue + end insert!(fields, i, Expr(:(=), option.args...)) + if getsymb(first(option.args)) == :node_type + # Manually add `expression_type` above, so it can be depended on by `node_type` + insert!(fields, i - 1, :(expression_type::Type{E} = Expression)) + end end # We also need to create the `get_options` function, based on this: @@ -597,7 +606,7 @@ const input_scitype = Union{ MMI.metadata_model( SRRegressor; input_scitype, - target_scitype=AbstractVector{<:Any}, + target_scitype=AbstractVector{<:MMI.Continuous}, supports_weights=true, reports_feature_importances=false, load_path="SymbolicRegression.MLJInterfaceModule.SRRegressor", @@ -606,13 +615,22 @@ MMI.metadata_model( MMI.metadata_model( MultitargetSRRegressor; input_scitype, - target_scitype=Union{MMI.Table(Any),AbstractMatrix{<:Any}}, + target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}}, supports_weights=true, reports_feature_importances=false, load_path="SymbolicRegression.MLJInterfaceModule.MultitargetSRRegressor", human_name="Multi-Target Symbolic Regression via Evolutionary Search", ) +function MMI.target_scitype(::Type{<:SRRegressor{<:Any,<:Any,<:TemplateExpression}}) + return AbstractVector{<:Any} +end +function MMI.target_scitype( + ::Type{<:MultitargetSRRegressor{<:Any,<:Any,<:TemplateExpression}} +) + return Union{MMI.Table(Any),AbstractMatrix{<:Any}} +end + function tag_with_docstring(model_name::Symbol, description::String, bottom_matter::String) docstring = """$(MMI.doc_header(eval(model_name)))