Skip to content

Commit

Permalink
refactor: move allocation utility to DynamicExpressions.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 13, 2024
1 parent a397dad commit 19eeb11
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 93 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ Dates = "1"
DifferentiationInterface = "0.5, 0.6"
DispatchDoctor = "^0.4.17"
Distributed = "<0.0.1, 1"
DynamicExpressions = "~1.7"
DynamicExpressions = "~1.8"
DynamicQuantities = "1"
Enzyme = "0.12, 0.13"
JSON3 = "1"
Expand Down
9 changes: 4 additions & 5 deletions src/ComposableExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ using DynamicExpressions.ValueInterfaceModule: is_valid_array

using ..ConstantOptimizationModule: ConstantOptimizationModule as CO
using ..CoreModule: get_safe_op
using ..ExpressionBuilderModule: ExpressionBuilderModule as EB

abstract type AbstractComposableExpression{T,N} <: AbstractExpression{T,N} end

Expand Down Expand Up @@ -113,13 +112,13 @@ end
function CO.count_constants_for_optimization(ex::AbstractComposableExpression)
return CO.count_constants_for_optimization(convert(Expression, ex))
end
function EB.preallocate_expression(
function DE.allocate_container(
prototype::ComposableExpression, n::Union{Nothing,Integer}=nothing
)
return (; tree=EB.preallocate_expression(get_contents(prototype), n))
return (; tree=DE.allocate_container(get_contents(prototype), n))
end
function DE.copy_node!(dest::NamedTuple, src::ComposableExpression)
new_tree = DE.copy_node!(dest.tree, get_contents(src))
function DE.copy_into!(dest::NamedTuple, src::ComposableExpression)
new_tree = DE.copy_into!(dest.tree, get_contents(src))
return DE.with_contents(src, new_tree)
end

Expand Down
36 changes: 1 addition & 35 deletions src/ExpressionBuilder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,13 @@ module ExpressionBuilderModule
using DispatchDoctor: @unstable
using Compat: Fix
using DynamicExpressions:
DynamicExpressions as DE,
AbstractExpressionNode,
AbstractExpression,
Expression,
constructorof,
with_metadata
AbstractExpressionNode, AbstractExpression, constructorof, with_metadata
using StatsBase: StatsBase
using ..CoreModule: AbstractOptions, Dataset
using ..HallOfFameModule: HallOfFame
using ..PopulationModule: Population
using ..PopMemberModule: PopMember

import ..InterfaceDynamicExpressionsModule: preallocate_expression
import DynamicExpressions: get_operators
import ..CoreModule: create_expression

Expand Down Expand Up @@ -192,32 +186,4 @@ end
return get_operators(ex, options.operators)
end

# We don't require users to overload this, as it's not part of the required interface.
# Also, there's no way to generally do this from the required interface, so for backwards
# compatibility, we just return nothing.
function preallocate_expression(::AbstractExpression, n::Union{Nothing,Integer}=nothing)
# return nothing
return error("Should not use this!")
end
function preallocate_expression(
prototype::N, n::Union{Nothing,Integer}=nothing
) where {T,N<:AbstractExpressionNode{T}}
num_nodes = @something(n, length(prototype))
return N[DE.with_type_parameters(N, T)() for _ in 1:num_nodes]
end
function preallocate_expression(prototype::Expression, n::Union{Nothing,Integer}=nothing)
return (; tree=preallocate_expression(DE.get_contents(prototype), n))
end

# The fallback is to just copy:
function DE.copy_node!(::Nothing, src::AbstractExpression)
# TODO: This is piracy
# return copy(src)
return error("Should not use this!") # TODO HACK
end
function DE.copy_node!(dest::NamedTuple, src::Expression)
tree = DE.copy_node!(dest.tree, DE.get_contents(src))
return DE.with_contents(src, tree)
end

end
3 changes: 0 additions & 3 deletions src/InterfaceDynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,4 @@ function DE.EvaluationHelpersModule._grad_evaluator(
)
end

# TODO: Move this to DynamicExpressions.jl
function preallocate_expression end

end
16 changes: 8 additions & 8 deletions src/Mutate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ module MutateModule

using DynamicExpressions:
AbstractExpression,
copy_node!,
copy_into!,
get_tree,
preserve_sharing,
count_scalar_constants,
simplify_tree!,
combine_operators
combine_operators,
allocate_container
using ..CoreModule:
AbstractOptions,
AbstractMutationWeights,
Expand All @@ -16,7 +17,6 @@ using ..CoreModule:
sample_mutation,
max_features
using ..ComplexityModule: compute_complexity
using ..InterfaceDynamicExpressionsModule: preallocate_expression
using ..LossFunctionsModule: score_func, score_func_batched
using ..CheckConstraintsModule: check_constraints
using ..AdaptiveParsimonyModule: RunningSearchStatistics
Expand Down Expand Up @@ -189,14 +189,14 @@ function next_generation(
successful_mutation = false
attempts = 0
max_attempts = 10
node_storage = preallocate_expression(member.tree)
node_storage = allocate_container(member.tree)

#############################################
# Mutations
#############################################
local tree
while (!successful_mutation) && attempts < max_attempts
tree = copy_node!(node_storage, member.tree)
tree = copy_into!(node_storage, member.tree)

mutation_result = _dispatch_mutations!(
tree,
Expand Down Expand Up @@ -241,7 +241,7 @@ function next_generation(
mutation_accepted = false
return (
PopMember(
copy_node!(node_storage, member.tree),
copy_into!(node_storage, member.tree),
beforeScore,
beforeLoss,
options,
Expand Down Expand Up @@ -270,7 +270,7 @@ function next_generation(
mutation_accepted = false
return (
PopMember(
copy_node!(node_storage, member.tree),
copy_into!(node_storage, member.tree),
beforeScore,
beforeLoss,
options,
Expand Down Expand Up @@ -313,7 +313,7 @@ function next_generation(
mutation_accepted = false
return (
PopMember(
copy_node!(node_storage, member.tree),
copy_into!(node_storage, member.tree),
beforeScore,
beforeLoss,
options,
Expand Down
23 changes: 0 additions & 23 deletions src/ParametricExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,27 +181,4 @@ function MF.mutate_constant(
end
end

function EB.preallocate_expression(
prototype::ParametricExpression, n::Union{Nothing,Integer}=nothing
)
return (;
tree=EB.preallocate_expression(get_contents(prototype), n),
parameters=similar(get_metadata(prototype).parameters),
)
end
function DE.copy_node!(dest::NamedTuple, src::ParametricExpression)
new_tree = DE.copy_node!(dest.tree, get_contents(src))
metadata = DE.get_metadata(src)
new_parameters = dest.parameters
new_parameters .= metadata.parameters
new_metadata = DE.Metadata((;
operators=metadata.operators,
variable_names=metadata.variable_names,
parameters=new_parameters,
parameter_names=metadata.parameter_names,
))
# TODO: Better interface for this^
return DE.with_metadata(DE.with_contents(src, new_tree), new_metadata)
end

end
18 changes: 0 additions & 18 deletions src/TemplateExpression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -536,24 +536,6 @@ function Base.isempty(ex::TemplateExpression)
return all(isempty, values(get_contents(ex)))
end

function EB.preallocate_expression(
prototype::TemplateExpression, n::Union{Nothing,Integer}=nothing
)
raw_contents = get_contents(prototype)
return (;
trees=NamedTuple{keys(raw_contents)}(
map(Base.Fix2(EB.preallocate_expression, n), values(raw_contents))
),
)
end
function DE.copy_node!(dest::NamedTuple, src::TemplateExpression)
raw_contents = get_contents(src)
new_trees = NamedTuple{keys(raw_contents)}(
map(DE.copy_node!, values(dest.trees), values(raw_contents))
)
return DE.with_contents(src, new_trees)
end

# TODO: Add custom behavior to adjust what feature nodes can be generated

end

0 comments on commit 19eeb11

Please sign in to comment.