Skip to content

Commit

Permalink
feat: re-use allocations within mutation loop
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 12, 2024
1 parent 530689f commit bf1c521
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Dates = "1"
DifferentiationInterface = "0.5, 0.6"
DispatchDoctor = "^0.4.17"
Distributed = "<0.0.1, 1"
DynamicExpressions = "1.7.0"
DynamicExpressions = "~1.7"
DynamicQuantities = "1"
Enzyme = "0.12, 0.13"
JSON3 = "1"
Expand Down
8 changes: 8 additions & 0 deletions src/ComposableExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using DynamicExpressions.ValueInterfaceModule: is_valid_array

using ..ConstantOptimizationModule: ConstantOptimizationModule as CO
using ..CoreModule: get_safe_op
using ..ExpressionBuilderModule: ExpressionBuilderModule as EB

abstract type AbstractComposableExpression{T,N} <: AbstractExpression{T,N} end

Expand Down Expand Up @@ -112,6 +113,13 @@ end
function CO.count_constants_for_optimization(ex::AbstractComposableExpression)
return CO.count_constants_for_optimization(convert(Expression, ex))
end
function EB.preallocate_expression(prototype::ComposableExpression, n::Integer)
return (; tree=EB.preallocate_expression(get_contents(prototype), n))
end
function DE.copy_node!(dest::NamedTuple, src::ComposableExpression)
new_tree = DE.copy_node!(dest.tree, get_contents(src))
return DE.with_contents(src, new_tree)
end

@implements(
ExpressionInterface{all_ei_methods_except(())}, ComposableExpression, [Arguments()]
Expand Down
33 changes: 32 additions & 1 deletion src/ExpressionBuilder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@ module ExpressionBuilderModule
using DispatchDoctor: @unstable
using Compat: Fix
using DynamicExpressions:
AbstractExpressionNode, AbstractExpression, constructorof, with_metadata
DynamicExpressions as DE,
AbstractExpressionNode,
AbstractExpression,
Expression,
constructorof,
with_metadata
using StatsBase: StatsBase
using ..CoreModule: AbstractOptions, Dataset
using ..HallOfFameModule: HallOfFame
using ..PopulationModule: Population
using ..PopMemberModule: PopMember

import ..InterfaceDynamicExpressionsModule: preallocate_expression
import DynamicExpressions: get_operators
import ..CoreModule: create_expression

Expand Down Expand Up @@ -186,4 +192,29 @@ end
return get_operators(ex, options.operators)
end

# We don't require users to overload this, as it's not part of the required interface.
# Also, there's no way to generally do this from the required interface, so for backwards
# compatibility, we just return nothing.
function preallocate_expression(::AbstractExpression, n::Integer)
return nothing
end
function preallocate_expression(
prototype::N, n::Integer
) where {T,N<:AbstractExpressionNode{T}}
return N[DE.with_type_parameters(N, T)() for _ in 1:n]
end
function preallocate_expression(prototype::Expression, n::Integer)
return (; tree=preallocate_expression(DE.get_contents(prototype), n))
end

# The fallback is to just copy:
function DE.copy_node!(::Nothing, src::AbstractExpression)
# TODO: This is piracy
return copy(src)
end
function DE.copy_node!(dest::NamedTuple, src::Expression)
tree = DE.copy_node!(dest.tree, DE.get_contents(src))
return DE.with_contents(src, tree)
end

end
3 changes: 3 additions & 0 deletions src/InterfaceDynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,4 +359,7 @@ function DE.EvaluationHelpersModule._grad_evaluator(
)
end

# TODO: Move this to DynamicExpressions.jl
function preallocate_expression end

end
11 changes: 6 additions & 5 deletions src/Mutate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using ..CoreModule:
sample_mutation,
max_features
using ..ComplexityModule: compute_complexity
using ..InterfaceDynamicExpressionsModule: preallocate_expression
using ..LossFunctionsModule: score_func, score_func_batched
using ..CheckConstraintsModule: check_constraints
using ..AdaptiveParsimonyModule: RunningSearchStatistics
Expand Down Expand Up @@ -188,14 +189,14 @@ function next_generation(
successful_mutation = false
attempts = 0
max_attempts = 10
node_buffer = collect(copy(member.tree))
node_storage = preallocate_expression(member.tree, curmaxsize)

#############################################
# Mutations
#############################################
local tree
while (!successful_mutation) && attempts < max_attempts
tree = attempts == 0 ? first(node_buffer) : copy_node!(node_buffer, member.tree)
tree = copy_node!(node_storage, member.tree)

mutation_result = _dispatch_mutations!(
tree,
Expand Down Expand Up @@ -240,7 +241,7 @@ function next_generation(
mutation_accepted = false
return (
PopMember(
copy(member.tree),
copy_node!(node_storage, member.tree),
beforeScore,
beforeLoss,
options,
Expand Down Expand Up @@ -269,7 +270,7 @@ function next_generation(
mutation_accepted = false
return (
PopMember(
copy(member.tree),
copy_node!(node_storage, member.tree),
beforeScore,
beforeLoss,
options,
Expand Down Expand Up @@ -312,7 +313,7 @@ function next_generation(
mutation_accepted = false
return (
PopMember(
copy(member.tree),
copy_node!(node_storage, member.tree),
beforeScore,
beforeLoss,
options,
Expand Down
21 changes: 21 additions & 0 deletions src/ParametricExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,4 +181,25 @@ function MF.mutate_constant(
end
end

function EB.preallocate_expression(prototype::ParametricExpression, n::Integer)
return (;
tree=EB.preallocate_expression(get_contents(prototype), n),
parameters=similar(get_metadata(prototype).parameters),
)
end
function DE.copy_node!(dest::NamedTuple, src::ParametricExpression)
new_tree = DE.copy_node!(dest.tree, get_contents(src))
metadata = DE.get_metadata(src)
new_parameters = dest.parameters
new_parameters .= metadata.parameters
new_metadata = DE.Metadata((;
operators=metadata.operators,
variable_names=metadata.variable_names,
parameters=new_parameters,
parameter_names=metadata.parameter_names,
))
# TODO: Better interface for this^
return DE.with_metadata(DE.with_contents(src, new_tree), new_metadata)
end

end
16 changes: 16 additions & 0 deletions src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,22 @@ function Base.isempty(ex::TemplateExpression)
return all(isempty, values(get_contents(ex)))
end

function EB.preallocate_expression(prototype::TemplateExpression, n::Integer)
raw_contents = get_contents(prototype)
return (;
trees=NamedTuple{keys(raw_contents)}(
map(Base.Fix2(EB.preallocate_expression, n), values(raw_contents))
),
)
end
function DE.copy_node!(dest::NamedTuple, src::TemplateExpression)
raw_contents = get_contents(src)
new_trees = NamedTuple{keys(raw_contents)}(
map(DE.copy_node!, dest.trees, values(raw_contents))
)
return DE.with_contents(src, new_trees)
end

# TODO: Add custom behavior to adjust what feature nodes can be generated

end

0 comments on commit bf1c521

Please sign in to comment.