diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl
index 2916c6922..f09e5272f 100644
--- a/src/InterfaceDynamicExpressions.jl
+++ b/src/InterfaceDynamicExpressions.jl
@@ -58,7 +58,12 @@ function eval_tree_array(
 )
     return eval_tree_array(
         tree, X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...
-    )
+    )::Tuple{expected_array_type(tree, X, options.operators),Bool}
+end
+
+# Improve type inference by telling Julia the expected array returned
+function expected_array_type(::AbstractExpressionNode, X::AbstractArray, ::OperatorEnum)
+    return typeof(similar(X, axes(X, 2)))
 end
 
 """
diff --git a/src/Options.jl b/src/Options.jl
index 5351fdee0..b3053ac91 100644
--- a/src/Options.jl
+++ b/src/Options.jl
@@ -26,7 +26,7 @@ using ..OperatorsModule:
     atanh_clip
 using ..MutationWeightsModule: MutationWeights, mutations
 import ..OptionsStructModule: Options
-using ..OptionsStructModule: ComplexityMapping
+using ..OptionsStructModule: ComplexityMapping, operator_specialization
 using ..UtilsModule: max_ops, @save_kwargs
 
 """
@@ -758,6 +758,7 @@ function Options end
 
     options = Options{
         eltype(complexity_mapping),
+        operator_specialization(typeof(operators)),
         node_type,
         turbo,
         bumper,
diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl
index bf03745c0..b61d731ee 100644
--- a/src/OptionsStruct.jl
+++ b/src/OptionsStruct.jl
@@ -1,7 +1,8 @@
 module OptionsStructModule
 
 using Optim: Optim
-using DynamicExpressions: AbstractOperatorEnum, AbstractExpressionNode, OperatorEnum
+using DynamicExpressions:
+    AbstractOperatorEnum, AbstractExpressionNode, OperatorEnum, GenericOperatorEnum
 using LossFunctions: SupervisedLoss
 
 import ..MutationWeightsModule: MutationWeights
@@ -38,8 +39,13 @@ function ComplexityMapping(;
     )
 end
 
-struct Options{CT,N<:AbstractExpressionNode,_turbo,_bumper,W}
-    operators::AbstractOperatorEnum
+# Controls level of specialization we compile
+operator_specialization(::Type{<:AbstractOperatorEnum}) = AbstractOperatorEnum
+operator_specialization(::Type{<:OperatorEnum}) = OperatorEnum
+operator_specialization(::Type{<:GenericOperatorEnum}) = GenericOperatorEnum
+
+struct Options{CT,OP<:AbstractOperatorEnum,N<:AbstractExpressionNode,_turbo,_bumper,W}
+    operators::OP
     bin_constraints::Vector{Tuple{Int,Int}}
     una_constraints::Vector{Int}
     complexity_mapping::ComplexityMapping{CT}
diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl
index 1ec9d0274..0cd307896 100644
--- a/src/SymbolicRegression.jl
+++ b/src/SymbolicRegression.jl
@@ -1093,7 +1093,7 @@ function _format_output(state::SearchState, ropt::RuntimeOptions)
 end
 
 @noinline function _dispatch_s_r_cycle(
-    in_pop::Population{T,L,P},
+    in_pop::Population{T,L,N},
     dataset::Dataset,
     @nospecialize(options::Options);
     pop::Int,
@@ -1102,7 +1102,7 @@ end
     verbosity,
     cur_maxsize::Int,
     running_search_statistics,
-) where {T,L,P}
+) where {T,L,N}
     record = RecordType()
     @recorder record["out$(out)_pop$(pop)"] = RecordType(
         "iteration$(iteration)" => record_population(in_pop, options)
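
For readers unfamiliar with the `operator_specialization` trick introduced in `OptionsStruct.jl`, here is a minimal standalone sketch of the idea. The names `spec_level` and `Wrapper` are illustrative stand-ins, not part of the package; the real code does the same thing with `operator_specialization` and `Options`.

```julia
using DynamicExpressions: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum

# Map a concrete operator-enum type to the level of specialization we want to keep.
# `<:OperatorEnum` collapses to the `OperatorEnum` UnionAll, so the wrapper type is
# not recompiled for every distinct tuple of operators, while dispatch can still
# tell the numeric `OperatorEnum` path apart from the `GenericOperatorEnum` one.
spec_level(::Type{<:AbstractOperatorEnum}) = AbstractOperatorEnum
spec_level(::Type{<:OperatorEnum}) = OperatorEnum
spec_level(::Type{<:GenericOperatorEnum}) = GenericOperatorEnum

# Hypothetical wrapper standing in for `Options`: the operators sit behind a type
# parameter bounded by the chosen specialization level.
struct Wrapper{OP<:AbstractOperatorEnum}
    operators::OP
end
Wrapper(ops::AbstractOperatorEnum) = Wrapper{spec_level(typeof(ops))}(ops)

ops = OperatorEnum(; binary_operators=[+, *], unary_operators=[cos])
w = Wrapper(ops)
@assert w isa Wrapper{OperatorEnum}  # parameterized by the UnionAll, not typeof(ops)
```

Parameterizing on the UnionAll rather than `typeof(operators)` appears intended to improve inference of `options.operators` (previously an abstractly typed `AbstractOperatorEnum` field) without forcing a fresh `Options` type for every operator set.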