Skip to content

Commit

Permalink
fix: allocate node storage for actual size
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 12, 2024
1 parent 399b86a commit adcdf15
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 7 deletions.
4 changes: 3 additions & 1 deletion src/ComposableExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ end
function CO.count_constants_for_optimization(ex::AbstractComposableExpression)
return CO.count_constants_for_optimization(convert(Expression, ex))
end
function EB.preallocate_expression(prototype::ComposableExpression, n::Integer)
function EB.preallocate_expression(
prototype::ComposableExpression, n::Union{Nothing,Integer}=nothing
)
return (; tree=EB.preallocate_expression(get_contents(prototype), n))
end
function DE.copy_node!(dest::NamedTuple, src::ComposableExpression)
Expand Down
7 changes: 4 additions & 3 deletions src/ExpressionBuilder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,14 @@ end
# We don't require users to overload this, as it's not part of the required interface.
# Also, there's no way to generally do this from the required interface, so for backwards
# compatibility, we just return nothing.
function preallocate_expression(::AbstractExpression, n::Integer)
function preallocate_expression(::AbstractExpression, n::Union{Nothing,Integer}=nothing)
return nothing
end
function preallocate_expression(
prototype::N, n::Integer
prototype::N, n::Union{Nothing,Integer}=nothing
) where {T,N<:AbstractExpressionNode{T}}
return N[DE.with_type_parameters(N, T)() for _ in 1:n]
num_nodes = @something(n, length(prototype))
return N[DE.with_type_parameters(N, T)() for _ in 1:num_nodes]
end
function preallocate_expression(prototype::Expression, n::Integer)
return (; tree=preallocate_expression(DE.get_contents(prototype), n))
Expand Down
2 changes: 1 addition & 1 deletion src/Mutate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ function next_generation(
successful_mutation = false
attempts = 0
max_attempts = 10
node_storage = preallocate_expression(member.tree, length(member.tree))
node_storage = preallocate_expression(member.tree)

#############################################
# Mutations
Expand Down
4 changes: 3 additions & 1 deletion src/ParametricExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ function MF.mutate_constant(
end
end

function EB.preallocate_expression(prototype::ParametricExpression, n::Integer)
function EB.preallocate_expression(
prototype::ParametricExpression, n::Union{Nothing,Integer}=nothing
)
return (;
tree=EB.preallocate_expression(get_contents(prototype), n),
parameters=similar(get_metadata(prototype).parameters),
Expand Down
4 changes: 3 additions & 1 deletion src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,9 @@ function Base.isempty(ex::TemplateExpression)
return all(isempty, values(get_contents(ex)))
end

function EB.preallocate_expression(prototype::TemplateExpression, n::Integer)
function EB.preallocate_expression(
prototype::TemplateExpression, n::Union{Nothing,Integer}=nothing
)
raw_contents = get_contents(prototype)
return (;
trees=NamedTuple{keys(raw_contents)}(
Expand Down

0 comments on commit adcdf15

Please sign in to comment.