Skip to content

Commit

Permalink
Help more with type inference in evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Mar 16, 2024
1 parent 6f7d755 commit f3ed53b
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 7 deletions.
7 changes: 6 additions & 1 deletion src/InterfaceDynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@ function eval_tree_array(
)
return eval_tree_array(
tree, X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...
)
)::Tuple{expected_array_type(tree, X, options.operators),Bool}
end

# Improve type inference by telling Julia the expected array returned
function expected_array_type(::AbstractExpressionNode, X::AbstractArray, ::OperatorEnum)
return typeof(similar(X, axes(X, 2)))
end

"""
Expand Down
3 changes: 2 additions & 1 deletion src/Options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ using ..OperatorsModule:
atanh_clip
using ..MutationWeightsModule: MutationWeights, mutations
import ..OptionsStructModule: Options
using ..OptionsStructModule: ComplexityMapping
using ..OptionsStructModule: ComplexityMapping, operator_specialization
using ..UtilsModule: max_ops, @save_kwargs

"""
Expand Down Expand Up @@ -758,6 +758,7 @@ function Options end

options = Options{
eltype(complexity_mapping),
operator_specialization(typeof(operators)),
node_type,
turbo,
bumper,
Expand Down
12 changes: 9 additions & 3 deletions src/OptionsStruct.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module OptionsStructModule

using Optim: Optim
using DynamicExpressions: AbstractOperatorEnum, AbstractExpressionNode, OperatorEnum
using DynamicExpressions:
AbstractOperatorEnum, AbstractExpressionNode, OperatorEnum, GenericOperatorEnum
using LossFunctions: SupervisedLoss

import ..MutationWeightsModule: MutationWeights
Expand Down Expand Up @@ -38,8 +39,13 @@ function ComplexityMapping(;
)
end

struct Options{CT,N<:AbstractExpressionNode,_turbo,_bumper,W}
operators::AbstractOperatorEnum
# Controls level of specialization we compile
operator_specialization(::Type{<:AbstractOperatorEnum}) = AbstractOperatorEnum
operator_specialization(::Type{<:OperatorEnum}) = OperatorEnum
operator_specialization(::Type{<:GenericOperatorEnum}) = GenericOperatorEnum

struct Options{CT,OP<:AbstractOperatorEnum,N<:AbstractExpressionNode,_turbo,_bumper,W}
operators::OP
bin_constraints::Vector{Tuple{Int,Int}}
una_constraints::Vector{Int}
complexity_mapping::ComplexityMapping{CT}
Expand Down
4 changes: 2 additions & 2 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ function _format_output(state::SearchState, ropt::RuntimeOptions)
end

@noinline function _dispatch_s_r_cycle(
in_pop::Population{T,L,P},
in_pop::Population{T,L,N},
dataset::Dataset,
@nospecialize(options::Options);
pop::Int,
Expand All @@ -1102,7 +1102,7 @@ end
verbosity,
cur_maxsize::Int,
running_search_statistics,
) where {T,L,P}
) where {T,L,N}
record = RecordType()
@recorder record["out$(out)_pop$(pop)"] = RecordType(
"iteration$(iteration)" => record_population(in_pop, options)
Expand Down

0 comments on commit f3ed53b

Please sign in to comment.