Skip to content

Commit

Permalink
Fix with_type_parameters import
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 19, 2023
1 parent 06cf456 commit 4f8eaeb
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 16 deletions.
3 changes: 2 additions & 1 deletion src/HallOfFame.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module HallOfFameModule

import DynamicExpressions:
AbstractExpressionNode, Node, constructorof, with_type_parameters, string_tree
AbstractExpressionNode, Node, constructorof, string_tree
import DynamicExpressions.EquationModule: with_type_parameters
import ..UtilsModule: split_string
import ..CoreModule: MAX_DEGREE, Options, Dataset, DATA_TYPE, LOSS_TYPE, relu
import ..ComplexityModule: compute_complexity
Expand Down
31 changes: 17 additions & 14 deletions src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module MLJInterfaceModule

using Optim: Optim
import MLJModelInterface as MMI
import DynamicExpressions: eval_tree_array, string_tree, Node
import DynamicExpressions: eval_tree_array, string_tree, AbstractExpressionNode, Node
import DynamicQuantities:
AbstractQuantity,
AbstractDimensions,
Expand Down Expand Up @@ -30,19 +30,20 @@ abstract type AbstractSRRegressor <: MMI.Deterministic end

"""Generate an `SRRegressor` struct containing all the fields in `Options`."""
function modelexpr(model_name::Symbol)
struct_def =
:(Base.@kwdef mutable struct $(model_name){D<:AbstractDimensions,L,use_recorder} <:
AbstractSRRegressor
niterations::Int = 10
parallelism::Symbol = :multithreading
numprocs::Union{Int,Nothing} = nothing
procs::Union{Vector{Int},Nothing} = nothing
addprocs_function::Union{Function,Nothing} = nothing
runtests::Bool = true
loss_type::L = Nothing
selection_method::Function = choose_best
dimensions_type::Type{D} = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}
end)
struct_def = :(Base.@kwdef mutable struct $(model_name){
D<:AbstractDimensions,L,use_recorder,N<:AbstractExpressionNode
} <: AbstractSRRegressor
niterations::Int = 10
node_type::Type{N} = Node
parallelism::Symbol = :multithreading
numprocs::Union{Int,Nothing} = nothing
procs::Union{Vector{Int},Nothing} = nothing
addprocs_function::Union{Function,Nothing} = nothing
runtests::Bool = true
loss_type::L = Nothing
selection_method::Function = choose_best
dimensions_type::Type{D} = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}
end)
# TODO: store `procs` from initial run if parallelism is `:multiprocessing`
fields = last(last(struct_def.args).args).args

Expand Down Expand Up @@ -462,6 +463,8 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt
# TODO: These ones are copied (or written) manually:
append_arguments = """- `niterations::Int=10`: The number of iterations to perform the search.
More iterations will improve the results.
- `node_type::Type{N}=Node`: The type of node to use for the search.
For example, `Node` or `GraphNode`.
- `parallelism=:multithreading`: What parallelism mode to use.
The options are `:multithreading`, `:multiprocessing`, and `:serial`.
By default, multithreading will be used. Multithreading uses less memory,
Expand Down
4 changes: 3 additions & 1 deletion src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ import DynamicExpressions:
Node,
GraphNode,
AbstractExpressionNode,
with_type_parameters,
copy_node,
set_node!,
string_tree,
Expand All @@ -96,6 +95,7 @@ import DynamicExpressions:
simplify_tree,
tree_mapreduce,
set_default_variable_names!
import DynamicExpressions.EquationModule: with_type_parameters
@reexport import LossFunctions:
MarginLoss,
DistanceLoss,
Expand Down Expand Up @@ -266,6 +266,8 @@ which is useful for debugging and profiling.
weight the loss for each `y` by this value (same shape as `y`).
- `options::Options=Options()`: The options for the search, such as
which operators to use, evolution hyperparameters, etc.
- `node_type::Type{N}=Node`: The type of node to use for the search.
For example, `Node` or `GraphNode`.
- `variable_names::Union{Vector{String}, Nothing}=nothing`: The names
of each feature in `X`, which will be used during printing of equations.
- `display_variable_names::Union{Vector{String}, Nothing}=variable_names`: Names
Expand Down

0 comments on commit 4f8eaeb

Please sign in to comment.