From effab2c58415c131fbf13ab82d4d1cfcfd8ef822 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 30 Oct 2024 19:38:45 +0000 Subject: [PATCH 01/59] feat: create `ComposableExpression` --- src/ComposableExpression.jl | 80 +++++++++++++++++++++++++++++++++++++ src/SymbolicRegression.jl | 3 ++ 2 files changed, 83 insertions(+) create mode 100644 src/ComposableExpression.jl diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl new file mode 100644 index 000000000..c2538dd55 --- /dev/null +++ b/src/ComposableExpression.jl @@ -0,0 +1,80 @@ +module ComposableExpressionModule + +using DynamicExpressions: + AbstractExpression, + AbstractExpressionNode, + AbstractOperatorEnum, + Metadata, + eval_tree_array, + DynamicExpressions as DE +using DynamicExpressions.InterfacesModule: + ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments + +abstract type AbstractComposableExpression{T,N} <: AbstractExpression{T,N} end + +struct ComposableExpression{T,N<:AbstractExpressionNode{T},D<:NamedTuple} <: + AbstractComposableExpression{T,N} + tree::N + metadata::Metadata{D} +end + +@inline function ComposableExpression( + tree::AbstractExpressionNode{T}; metadata... +) where {T} + d = (; metadata...) 
+ return ComposableExpression(tree, Metadata(d)) +end + +DE.get_metadata(ex::AbstractComposableExpression) = ex.metadata +DE.get_contents(ex::AbstractComposableExpression) = ex.tree +DE.get_tree(ex::AbstractComposableExpression) = ex.tree + +function DE.get_operators( + ex::AbstractComposableExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) + return something(operators, DE.get_metadata(ex).operators) +end +function DE.get_variable_names( + ex::AbstractComposableExpression, + variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, +) + return something(variable_names, DE.get_metadata(ex).variable_names) +end + +@implements( + ExpressionInterface{all_ei_methods_except(())}, ComposableExpression, [Arguments()] +) + +struct ResultOk{A<:AbstractVector} + value::A + ok::Bool +end +ResultOk(x::Tuple{Vararg{Any,2}}) = ResultOk(x...) + +function (ex::AbstractComposableExpression)(x) + return error("ComposableExpression does not support input of type $(typeof(x))") +end +function (ex::AbstractComposableExpression)(x::AbstractVector, _xs::AbstractVector...) + xs = (x, _xs...) + # Wrap it up for the recursive call + xs = ntuple(i -> ResultOk(xs[i], true), Val(length(xs))) + result_ok = ex(xs...) + # Unwrap it + if result_ok.ok + return result_ok.value + else + nan = convert(eltype(result_ok.value), NaN) + return result_ok.value .* nan + end +end +function (ex::AbstractComposableExpression)(x::ResultOk, _xs::ResultOk...) + xs = (x, _xs...) 
+ ok = all(xi -> xi.ok, xs) + if !ok + return ResultOk(first(xs).value, false) + end + X = Matrix(stack(map(xi -> xi.value, xs)...)') + return ResultOk(eval_tree_array(ex, X)) +end + +end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 75bdbfa19..df46d107f 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -14,6 +14,7 @@ export Population, ParametricExpression, TemplateExpression, TemplateStructure, + ComposableExpression, NodeSampler, AbstractExpression, AbstractExpressionNode, @@ -222,6 +223,7 @@ using DispatchDoctor: @stable include("SearchUtils.jl") include("ExpressionBuilder.jl") include("TemplateExpression.jl") + include("ComposableExpression.jl") include("ParametricExpression.jl") end @@ -317,6 +319,7 @@ using .SearchUtilsModule: get_cur_maxsize, update_hall_of_fame! using .TemplateExpressionModule: TemplateExpression, TemplateStructure +using .ComposableExpressionModule: ComposableExpression using .ExpressionBuilderModule: embed_metadata, strip_metadata @stable default_mode = "disable" begin From 42935fc7ff02549754604bb0224d052271996a12 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 30 Oct 2024 19:40:35 +0000 Subject: [PATCH 02/59] feat: tweak names of internal types --- src/ComposableExpression.jl | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index c2538dd55..56ba1abdc 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -45,11 +45,11 @@ end ExpressionInterface{all_ei_methods_except(())}, ComposableExpression, [Arguments()] ) -struct ResultOk{A<:AbstractVector} +struct VectorWrapper{A<:AbstractVector} value::A - ok::Bool + valid::Bool end -ResultOk(x::Tuple{Vararg{Any,2}}) = ResultOk(x...) +VectorWrapper(x::Tuple{Vararg{Any,2}}) = VectorWrapper(x...) 
function (ex::AbstractComposableExpression)(x) return error("ComposableExpression does not support input of type $(typeof(x))") @@ -57,24 +57,24 @@ end function (ex::AbstractComposableExpression)(x::AbstractVector, _xs::AbstractVector...) xs = (x, _xs...) # Wrap it up for the recursive call - xs = ntuple(i -> ResultOk(xs[i], true), Val(length(xs))) - result_ok = ex(xs...) + xs = ntuple(i -> VectorWrapper(xs[i], true), Val(length(xs))) + result = ex(xs...) # Unwrap it - if result_ok.ok - return result_ok.value + if result.valid + return result.value else - nan = convert(eltype(result_ok.value), NaN) - return result_ok.value .* nan + nan = convert(eltype(result.value), NaN) + return result.value .* nan end end -function (ex::AbstractComposableExpression)(x::ResultOk, _xs::ResultOk...) +function (ex::AbstractComposableExpression)(x::VectorWrapper, _xs::VectorWrapper...) xs = (x, _xs...) - ok = all(xi -> xi.ok, xs) - if !ok - return ResultOk(first(xs).value, false) + valid = all(xi -> xi.valid, xs) + if !valid + return VectorWrapper(first(xs).value, false) end X = Matrix(stack(map(xi -> xi.value, xs)...)') - return ResultOk(eval_tree_array(ex, X)) + return VectorWrapper(eval_tree_array(ex, X)) end end From e3dad4f25b3dd7c8d619bf93c9e020e2dba4ea29 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 30 Oct 2024 20:50:42 +0000 Subject: [PATCH 03/59] test: composable expression --- src/ComposableExpression.jl | 13 ++++++++++++- test/runtests.jl | 1 + test/test_composable_expression.jl | 25 +++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) create mode 100644 test/test_composable_expression.jl diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index 56ba1abdc..3f1f83ef9 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -41,6 +41,17 @@ function DE.get_variable_names( return something(variable_names, DE.get_metadata(ex).variable_names) end +function 
DE.get_scalar_constants(ex::AbstractComposableExpression) + return DE.get_scalar_constants(DE.get_contents(ex)) +end +function DE.set_scalar_constants!(ex::AbstractComposableExpression, constants, refs) + return DE.set_scalar_constants!(DE.get_contents(ex), constants, refs) +end + +function Base.copy(ex::AbstractComposableExpression) + return ComposableExpression(copy(ex.tree), copy(ex.metadata)) +end + @implements( ExpressionInterface{all_ei_methods_except(())}, ComposableExpression, [Arguments()] ) @@ -73,7 +84,7 @@ function (ex::AbstractComposableExpression)(x::VectorWrapper, _xs::VectorWrapper if !valid return VectorWrapper(first(xs).value, false) end - X = Matrix(stack(map(xi -> xi.value, xs)...)') + X = Matrix(stack(map(xi -> xi.value, xs))') return VectorWrapper(eval_tree_array(ex, X)) end diff --git a/test/runtests.jl b/test/runtests.jl index fcc2c5b08..dacb374a0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -165,6 +165,7 @@ end include("test_pretty_printing.jl") include("test_expression_builder.jl") +include("test_composable_expression.jl") @testitem "Aqua tests" tags = [:part2, :aqua] begin include("test_aqua.jl") diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl new file mode 100644 index 000000000..e32d49fed --- /dev/null +++ b/test/test_composable_expression.jl @@ -0,0 +1,25 @@ + +@testitem "Test ComposableExpression" tags = [:part2] begin + using SymbolicRegression: ComposableExpression, Node + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + ex = ComposableExpression(Node(Float64; feature=1); operators, variable_names) + x = randn(32) + y = randn(32) + + @test ex(x, y) == x +end + +@testitem "Test interface" tags = [:part2] begin + using SymbolicRegression: ComposableExpression + using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface + using DynamicExpressions: 
OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + f = ComposableExpression(Node(Float64; feature=1); operators, variable_names) + + @test Interfaces.test(ExpressionInterface, ComposableExpression, [f]) +end From a1e192caddfb9f1a77921d7dfd1d33cb950d9d54 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 1 Nov 2024 22:07:58 +0000 Subject: [PATCH 04/59] feat: init hierarchical expression --- src/ComposableExpression.jl | 56 ++++- src/HierarchicalExpression.jl | 392 ++++++++++++++++++++++++++++++++++ src/SymbolicRegression.jl | 3 +- 3 files changed, 448 insertions(+), 3 deletions(-) create mode 100644 src/HierarchicalExpression.jl diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index 3f1f83ef9..bd6c843dd 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -6,6 +6,7 @@ using DynamicExpressions: AbstractOperatorEnum, Metadata, eval_tree_array, + set_node!, DynamicExpressions as DE using DynamicExpressions.InterfacesModule: ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments @@ -83,9 +84,60 @@ function (ex::AbstractComposableExpression)(x::VectorWrapper, _xs::VectorWrapper valid = all(xi -> xi.valid, xs) if !valid return VectorWrapper(first(xs).value, false) + else + X = Matrix(stack(map(xi -> xi.value, xs))') + return VectorWrapper(eval_tree_array(ex, X)) + end +end +function (ex::AbstractComposableExpression)( + x::AbstractComposableExpression, _xs::AbstractComposableExpression... +) + xs = (x, _xs...) + # To do this, we basically want to put the tree of x + # into the position of variable 1, and so on! + tree = copy(get_contents(ex)) + xs_trees = map(get_contents, xs) + # TODO: This is a bit dangerous, no? 
We are assuming + # that `foreach` won't try to go down the copied trees + foreach(tree) do node + if node.degree == 0 && !node.constant + set_node!(node, copy(xs_trees[node.feature])) + end + end + return with_contents(ex, tree) +end + +# Basically we want to vectorize every single operation on VectorWrapper, +# so that the user can use it easily. + +#! format: off +# First, binary operators: +for op in ( + :*, :/, :+, :-, :^, :÷, :mod, :log, + :atan, :atand, :copysign, :flipsign, + :&, :|, :⊻, ://, :\, +) + @eval function Base.$(op)(x::VectorWrapper, y::VectorWrapper) + return VectorWrapper(@. Base.$(op)(x.value, y.value)) + end +end + +for op in ( + :sin, :cos, :tan, :sinh, :cosh, :tanh, :asin, :acos, + :asinh, :acosh, :atanh, :sec, :csc, :cot, :asec, :acsc, :acot, :sech, :csch, + :coth, :asech, :acsch, :acoth, :sinc, :cosc, :cosd, :cotd, :cscd, :secd, + :sinpi, :cospi, :sind, :tand, :acosd, :acotd, :acscd, :asecd, :asind, + :log, :log2, :log10, :log1p, :exp, :exp2, :exp10, :expm1, :frexp, :exponent, + :float, :abs, :real, :imag, :conj, :unsigned, + :nextfloat, :prevfloat, :transpose, :significand, + :modf, :rem, :floor, :ceil, :round, :trunc, + :inv, :sqrt, :cbrt, :abs2, :angle, :factorial, + :(!), :-, :+, :sign, :identity, +) + @eval function Base.$(op)(x::VectorWrapper) + return VectorWrapper(@. Base.$(op)(x.value)) end - X = Matrix(stack(map(xi -> xi.value, xs))') - return VectorWrapper(eval_tree_array(ex, X)) end +#! 
format: on end diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl new file mode 100644 index 000000000..3d399238b --- /dev/null +++ b/src/HierarchicalExpression.jl @@ -0,0 +1,392 @@ +module HierarchicalExpressionModule + +using Random: AbstractRNG +using Compat: Fix +using DispatchDoctor: @unstable +using StyledStrings: @styled_str, annotatedstring +using DynamicExpressions: + DynamicExpressions as DE, + AbstractStructuredExpression, + AbstractExpressionNode, + AbstractExpression, + AbstractOperatorEnum, + OperatorEnum, + Expression, + Metadata, + get_contents, + with_contents, + get_metadata, + get_operators, + get_variable_names, + get_tree, + node_type, + eval_tree_array, + count_nodes +using DynamicExpressions.InterfacesModule: + ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments + +using ..CoreModule: + AbstractOptions, Dataset, CoreModule as CM, AbstractMutationWeights, has_units +using ..ConstantOptimizationModule: ConstantOptimizationModule as CO +using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE +using ..MutationFunctionsModule: MutationFunctionsModule as MF +using ..ExpressionBuilderModule: ExpressionBuilderModule as EB +using ..DimensionalAnalysisModule: DimensionalAnalysisModule as DA +using ..CheckConstraintsModule: CheckConstraintsModule as CC +using ..ComplexityModule: ComplexityModule +using ..LossFunctionsModule: LossFunctionsModule as LF +using ..MutateModule: MutateModule as MM +using ..PopMemberModule: PopMember +using ..ComposableExpressionModule: AbstractComposableExpression, VectorWrapper + +""" + HierarchicalStructure{K,S,N,E,C} <: Function + +A struct that defines a prescribed structure for a `HierarchicalExpression`, +including functions that define the result in different contexts. + +The `K` parameter is used to specify the symbols representing the inner expressions. 
+If not declared using the constructor `HierarchicalStructure{K}(...)`, the keys of the +`variable_constraints` `NamedTuple` will be used to infer this. + +# Fields +- `combine`: Required function taking a `NamedTuple` of function keys => expressions, + returning a single expression. Fallback method used by `get_tree` + on a `HierarchicalExpression` to generate a single `Expression`. +""" +struct HierarchicalStructure{K,E<:Function} <: Function + combine::E +end + +function HierarchicalStructure{K}(combine::E; combine_strings=nothing) where {K,E<:Function} + return HierarchicalStructure{K}(combine; combine_strings=nothing) +end + +function combine(template::HierarchicalStructure, args...) + return template.combine(args...) +end + +get_function_keys(::HierarchicalStructure{K}) where {K} = K + +""" + HierarchicalExpression{T,F,N,E,TS,D} <: AbstractStructuredExpression{T,F,N,E,D} + +A symbolic expression that allows the combination of multiple sub-expressions +in a structured way, with constraints on variable usage. + +`HierarchicalExpression` is designed for symbolic regression tasks where +domain-specific knowledge or constraints must be imposed on the model's structure. + +# Constructor + +- `HierarchicalExpression(trees; structure, operators, variable_names)` + - `trees`: A `NamedTuple` holding the sub-expressions (e.g., `f = Expression(...)`, `g = Expression(...)`). + - `structure`: A `HierarchicalStructure` which holds functions that define how the sub-expressions are combined + in different contexts. + - `operators`: An `OperatorEnum` that defines the allowed operators for the sub-expressions. + - `variable_names`: An optional `Vector` of `String` that defines the names of the variables in the dataset. 
+ +# Example + +Let's create an example `HierarchicalExpression` that combines two sub-expressions `f(x1, x2)` and `g(x3)`: + +```julia +# Define operators and variable names +options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) +operators = options.operators +variable_names = ["x1", "x2", "x3"] + +# Create sub-expressions +x1 = Expression(Node{Float64}(; feature=1); operators, variable_names) +x2 = Expression(Node{Float64}(; feature=2); operators, variable_names) +x3 = Expression(Node{Float64}(; feature=3); operators, variable_names) + +# Create HierarchicalExpression +example_expr = (; f=x1, g=x3) +st_expr = HierarchicalExpression( + example_expr; + structure=HierarchicalStructure{(:f, :g)}( + ((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)^2 + ), + operators, + variable_names, +) +``` + +When fitting a model in SymbolicRegression.jl, you would provide the `HierarchicalExpression` +as the `expression_type` argument, and then pass `expression_options=(; structure=HierarchicalStructure(...))` +as additional options. The `variable_constraints` will constrain `f` to only have access to `x1` and `x2`, +and `g` to only have access to `x3`. 
+""" +struct HierarchicalExpression{ + T, + F<:HierarchicalStructure, + N<:AbstractExpressionNode{T}, + E<:AbstractComposableExpression{T,N}, + TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}}, + D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V}, +} <: AbstractStructuredExpression{T,F,N,E,D} + trees::TS + metadata::Metadata{D} + + function HierarchicalExpression( + trees::TS, metadata::Metadata{D} + ) where { + TS, + F<:HierarchicalStructure, + D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V}, + } + @assert keys(trees) == get_function_keys(metadata.structure) + E = typeof(first(values(trees))) + N = node_type(E) + return new{eltype(N),F,N,E,TS,D}(trees, metadata) + end +end + +function HierarchicalExpression( + trees::NamedTuple{<:Any,<:NTuple{<:Any,<:AbstractExpression}}; + structure::F, + operators::Union{AbstractOperatorEnum,Nothing}=nothing, + variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, +) where {F<:HierarchicalStructure} + example_tree = first(values(trees))::AbstractExpression + operators = get_operators(example_tree, operators) + variable_names = get_variable_names(example_tree, variable_names) + metadata = (; structure, operators, variable_names) + return HierarchicalExpression(trees, Metadata(metadata)) +end + +@unstable DE.constructorof(::Type{<:HierarchicalExpression}) = HierarchicalExpression + +@implements( + ExpressionInterface{all_ei_methods_except(())}, HierarchicalExpression, [Arguments()] +) + +function combine(ex::HierarchicalExpression, args...) + return combine(get_metadata(ex).structure, args...) 
+end +function get_function_keys(ex::HierarchicalExpression) + return get_function_keys(get_metadata(ex).structure) +end + +function EB.create_expression( + t::AbstractExpressionNode{T}, + options::AbstractOptions, + dataset::Dataset{T,L}, + ::Type{<:AbstractExpressionNode}, + ::Type{E}, + ::Val{embed}=Val(false), +) where {T,L,embed,E<:HierarchicalExpression} + function_keys = get_function_keys(options.expression_options.structure) + + # NOTE: We need to copy over the operators so we can call the structure function + operators = options.operators + variable_names = embed ? dataset.variable_names : nothing + inner_expressions = ntuple( + _ -> ComposableExpression(copy(t); operators, variable_names), length(function_keys) + ) + # TODO: Generalize to other inner expression types + return DE.constructorof(E)( + NamedTuple{function_keys}(inner_expressions); + EB.init_params(options, dataset, nothing, Val(embed))..., + ) +end +function EB.extra_init_params( + ::Type{E}, + prototype::Union{Nothing,AbstractExpression}, + options::AbstractOptions, + dataset::Dataset{T}, + ::Val{embed}, +) where {T,embed,E<:HierarchicalExpression} + # We also need to include the operators here to be consistent with `create_expression`. + return (; options.operators, options.expression_options...) +end +function EB.sort_params(params::NamedTuple, ::Type{<:HierarchicalExpression}) + return (; params.structure, params.operators, params.variable_names) +end + +function ComplexityModule.compute_complexity( + tree::HierarchicalExpression, options::AbstractOptions; break_sharing=Val(false) +) + # Rather than including the complexity of the combined tree, + # we only sum the complexity of each inner expression, which will be smaller. 
+ return sum( + ex -> ComplexityModule.compute_complexity(ex, options; break_sharing), + values(get_contents(tree)), + ) +end + +_color_string(s::AbstractString, c::Symbol) = styled"{$c:$s}" +function DE.string_tree( + tree::HierarchicalExpression, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + kws..., +) + raw_contents = get_contents(tree) + function_keys = keys(raw_contents) + colors = Base.Iterators.cycle((:magenta, :green, :red, :blue, :yellow, :cyan)) + inner_strings = NamedTuple{function_keys}( + map(ex -> DE.string_tree(ex, operators; kws...), values(raw_contents)) + ) + colored_strings = NamedTuple{function_keys}(map(_color_string, inner_strings, colors)) + return join((annotatedstring(k, " = ", v) for (k, v) in pairs(colored_strings)), "\n") +end +function DE.eval_tree_array( + tree::HierarchicalExpression{T}, + cX::AbstractMatrix{T}, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + kws..., +) where {T} + raw_contents = get_contents(tree) + result = combine(tree, raw_contents, map(x -> VectorWrapper(x, true), eachrow(cX))) + return result.value, result.valid +end +function (ex::HierarchicalExpression)( + X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... +) + result, valid = DE.eval_tree_array(ex, X, operators; kws...) 
+ if valid + return result + else + nan = convert(eltype(result), NaN) + return result .* nan + end +end +@unstable IDE.expected_array_type(::AbstractMatrix, ::Type{<:HierarchicalExpression}) = Any + +function DA.violates_dimensional_constraints( + @nospecialize(tree::HierarchicalExpression), + dataset::Dataset, + @nospecialize(options::AbstractOptions) +) + @assert !has_units(dataset) + return false +end +function MM.condition_mutation_weights!( + @nospecialize(weights::AbstractMutationWeights), + @nospecialize(member::P), + @nospecialize(options::AbstractOptions), + curmaxsize::Int, +) where {T,L,N<:HierarchicalExpression,P<:PopMember{T,L,N}} + # HACK TODO + return nothing +end + +""" +We need full specialization for constrained expressions, as they rely on subexpressions being combined. +""" +function CM.operator_specialization( + ::Type{O}, ::Type{<:HierarchicalExpression} +) where {O<:OperatorEnum} + return O +end + +""" +We pick a random subexpression to mutate, +and also return the symbol we mutated on so that we can put it back together later. 
+""" +function MF.get_contents_for_mutation(ex::HierarchicalExpression, rng::AbstractRNG) + raw_contents = get_contents(ex) + function_keys = keys(raw_contents) + + # Sample weighted by number of nodes in each subexpression + num_nodes = map(count_nodes, values(raw_contents)) + weights = map(Base.Fix2(/, sum(num_nodes)), num_nodes) + cumsum_weights = cumsum(weights) + rand_val = rand(rng) + idx = findfirst(Base.Fix2(>=, rand_val), cumsum_weights)::Int + + key_to_mutate = function_keys[idx] + return raw_contents[key_to_mutate], key_to_mutate +end + +"""See `get_contents_for_mutation(::HierarchicalExpression, ::AbstractRNG)`.""" +function MF.with_contents_for_mutation( + ex::HierarchicalExpression, new_inner_contents, context::Symbol +) + raw_contents = get_contents(ex) + raw_contents_keys = keys(raw_contents) + new_contents = NamedTuple{raw_contents_keys}( + ntuple(length(raw_contents_keys)) do i + if raw_contents_keys[i] == context + new_inner_contents + else + raw_contents[raw_contents_keys[i]] + end + end, + ) + return with_contents(ex, new_contents) +end + +"""We combine the operators of each inner expression.""" +function DE.combine_operators( + ex::HierarchicalExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) where {T,N} + raw_contents = get_contents(ex) + function_keys = keys(raw_contents) + new_contents = NamedTuple{function_keys}( + map(Base.Fix2(DE.combine_operators, operators), values(raw_contents)) + ) + return with_contents(ex, new_contents) +end + +"""We simplify each inner expression.""" +function DE.simplify_tree!( + ex::HierarchicalExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing +) where {T,N} + raw_contents = get_contents(ex) + function_keys = keys(raw_contents) + new_contents = NamedTuple{function_keys}( + map(Base.Fix2(DE.simplify_tree!, operators), values(raw_contents)) + ) + return with_contents(ex, new_contents) +end + +function CO.count_constants_for_optimization(ex::HierarchicalExpression) + 
return sum(CO.count_constants_for_optimization, values(get_contents(ex))) +end + +# function CC.check_constraints( +# ex::HierarchicalExpression, +# options::AbstractOptions, +# maxsize::Int, +# cursize::Union{Int,Nothing}=nothing, +# )::Bool +# raw_contents = get_contents(ex) +# variable_constraints = get_metadata(ex).structure.variable_constraints + +# # First, we check the variable constraints at the top level: +# has_invalid_variables = any(keys(raw_contents)) do key +# tree = raw_contents[key] +# allowed_variables = variable_constraints[key] +# contains_other_features_than(tree, allowed_variables) +# end +# if has_invalid_variables +# return false +# end + +# # We also check the combined complexity: +# ((cursize === nothing) ? ComplexityModule.compute_complexity(ex, options) : cursize) > +# maxsize && return false + +# # Then, we check other constraints for inner expressions: +# for t in values(raw_contents) +# if !CC.check_constraints(t, options, maxsize, nothing) +# return false +# end +# end +# return true +# # TODO: The concept of `cursize` doesn't really make sense here. 
+# end +# function contains_other_features_than(tree::AbstractExpression, features) +# return contains_other_features_than(get_tree(tree), features) +# end +# function contains_other_features_than(tree::AbstractExpressionNode, features) +# any(tree) do node +# node.degree == 0 && !node.constant && node.feature ∉ features +# end +# end + +# TODO: Add custom behavior to adjust what feature nodes can be generated + +end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index df46d107f..4428a73a7 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -222,8 +222,9 @@ using DispatchDoctor: @stable include("Migration.jl") include("SearchUtils.jl") include("ExpressionBuilder.jl") - include("TemplateExpression.jl") include("ComposableExpression.jl") + include("TemplateExpression.jl") + include("HierarchicalExpression.jl") include("ParametricExpression.jl") end From bc48fcc493bf395d4a1da550f24ab11b86472f62 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 1 Nov 2024 23:14:08 +0000 Subject: [PATCH 05/59] feat: enable `VectorWrapper` for other operators --- src/ComposableExpression.jl | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index bd6c843dd..07fd3b5d0 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -7,6 +7,8 @@ using DynamicExpressions: Metadata, eval_tree_array, set_node!, + get_contents, + with_contents, DynamicExpressions as DE using DynamicExpressions.InterfacesModule: ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments @@ -110,6 +112,19 @@ end # Basically we want to vectorize every single operation on VectorWrapper, # so that the user can use it easily. +function apply_operator(op::F, x...) 
where {F<:Function} + if all(_is_valid, x) + vx = map(_get_value, x) + return VectorWrapper(op.(vx...), true) + else + return VectorWrapper(_get_value(first(x)), false) + end +end +_is_valid(x::VectorWrapper) = x.valid +_is_valid(x) = true +_get_value(x::VectorWrapper) = x.value +_get_value(x) = x + #! format: off # First, binary operators: for op in ( @@ -117,8 +132,10 @@ for op in ( :atan, :atand, :copysign, :flipsign, :&, :|, :⊻, ://, :\, ) - @eval function Base.$(op)(x::VectorWrapper, y::VectorWrapper) - return VectorWrapper(@. Base.$(op)(x.value, y.value)) + @eval begin + Base.$(op)(x::VectorWrapper, y::VectorWrapper) = apply_operator(Base.$(op), x, y) + Base.$(op)(x::VectorWrapper, y::Number) = apply_operator(Base.$(op), x, y) + Base.$(op)(x::Number, y::VectorWrapper) = apply_operator(Base.$(op), x, y) end end @@ -134,9 +151,7 @@ for op in ( :inv, :sqrt, :cbrt, :abs2, :angle, :factorial, :(!), :-, :+, :sign, :identity, ) - @eval function Base.$(op)(x::VectorWrapper) - return VectorWrapper(@. Base.$(op)(x.value)) - end + @eval Base.$(op)(x::VectorWrapper) = apply_operator(Base.$(op), x) end #! 
format: on From c3aa38be7e542655312db78936c955c373182c0e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Fri, 1 Nov 2024 23:15:51 +0000 Subject: [PATCH 06/59] fix: HierarchicalExpression instabilities --- src/HierarchicalExpression.jl | 12 +++++--- src/SymbolicRegression.jl | 3 ++ test/test_composable_expression.jl | 45 +++++++++++++++++++++++++++++- 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl index 3d399238b..abf8c5d17 100644 --- a/src/HierarchicalExpression.jl +++ b/src/HierarchicalExpression.jl @@ -58,8 +58,8 @@ struct HierarchicalStructure{K,E<:Function} <: Function combine::E end -function HierarchicalStructure{K}(combine::E; combine_strings=nothing) where {K,E<:Function} - return HierarchicalStructure{K}(combine; combine_strings=nothing) +function HierarchicalStructure{K}(combine::E) where {K,E<:Function} + return HierarchicalStructure{K,E}(combine) end function combine(template::HierarchicalStructure, args...) 
@@ -229,7 +229,9 @@ function DE.string_tree( map(ex -> DE.string_tree(ex, operators; kws...), values(raw_contents)) ) colored_strings = NamedTuple{function_keys}(map(_color_string, inner_strings, colors)) - return join((annotatedstring(k, " = ", v) for (k, v) in pairs(colored_strings)), "\n") + return join( + (annotatedstring(k, " = ", v) for (k, v) in pairs(colored_strings)), styled"\n" + ) end function DE.eval_tree_array( tree::HierarchicalExpression{T}, @@ -238,7 +240,9 @@ function DE.eval_tree_array( kws..., ) where {T} raw_contents = get_contents(tree) - result = combine(tree, raw_contents, map(x -> VectorWrapper(x, true), eachrow(cX))) + result = combine( + tree, raw_contents, map(x -> VectorWrapper(copy(x), true), eachrow(cX)) + ) return result.value, result.valid end function (ex::HierarchicalExpression)( diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 4428a73a7..ca0fb1ed6 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -14,6 +14,8 @@ export Population, ParametricExpression, TemplateExpression, TemplateStructure, + HierarchicalExpression, + HierarchicalStructure, ComposableExpression, NodeSampler, AbstractExpression, @@ -320,6 +322,7 @@ using .SearchUtilsModule: get_cur_maxsize, update_hall_of_fame! 
using .TemplateExpressionModule: TemplateExpression, TemplateStructure +using .HierarchicalExpressionModule: HierarchicalExpression, HierarchicalStructure using .ComposableExpressionModule: ComposableExpression using .ExpressionBuilderModule: embed_metadata, strip_metadata diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index e32d49fed..27b809acf 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -1,4 +1,3 @@ - @testitem "Test ComposableExpression" tags = [:part2] begin using SymbolicRegression: ComposableExpression, Node using DynamicExpressions: OperatorEnum @@ -23,3 +22,47 @@ end @test Interfaces.test(ExpressionInterface, ComposableExpression, [f]) end + +@testitem "Printing and evaluation of HierarchicalExpression" begin + using SymbolicRegression + + structure = HierarchicalStructure{(:f, :g)}( + ((; f, g), (x1, x2, x3)) -> let + sin(f(x1, x2)) + g(x3)^2 + end + ) + operators = Options().operators + variable_names = ["x1", "x2", "x3"] + + x1, x2, x3 = [ + ComposableExpression(Node{Float64}(; feature=i); operators, variable_names) for + i in 1:3 + ] + f = x1 * x2 + g = x1 + expr = HierarchicalExpression((; f, g); structure, operators, variable_names) + + # Default printing strategy: + @test String(string_tree(expr)) == "f = x1 * x2\ng = x1" + + x1_val = randn(5) + x2_val = randn(5) + + # The feature indicates the index passed as argument: + @test x1(x1_val) ≈ x1_val + @test x2(x1_val, x2_val) ≈ x2_val + @test x1(x2_val) ≈ x2_val + + # Composing expressions and then calling: + @test String(string_tree((x1 * x2)(x3, x3))) == "x3 * x3" + + # Can evaluate with `sin` even though it's not in the allowed operators! + X = randn(3, 5) + x1_val = X[1, :] + x2_val = X[2, :] + x3_val = X[3, :] + @test expr(X) ≈ @. 
sin(x1_val * x2_val) + x3_val^2 + + # This is even though `g` is defined on `x1` only: + @test g(x3_val) ≈ x3_val +end From 756a2d95b724ca76269f2066bf65c0e7d06e9da0 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 11:39:15 +0000 Subject: [PATCH 07/59] fix: need to freeze operators in HierarchicalExpression and ComposableExpression --- src/ComposableExpression.jl | 11 +++++++---- src/HierarchicalExpression.jl | 11 +++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index 07fd3b5d0..6dbcdfbec 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -15,8 +15,11 @@ using DynamicExpressions.InterfacesModule: abstract type AbstractComposableExpression{T,N} <: AbstractExpression{T,N} end -struct ComposableExpression{T,N<:AbstractExpressionNode{T},D<:NamedTuple} <: - AbstractComposableExpression{T,N} +struct ComposableExpression{ + T, + N<:AbstractExpressionNode{T}, + D<:@NamedTuple{operators::O, variable_names::V} where {O<:AbstractOperatorEnum,V}, +} <: AbstractComposableExpression{T,N} tree::N metadata::Metadata{D} end @@ -35,13 +38,13 @@ DE.get_tree(ex::AbstractComposableExpression) = ex.tree function DE.get_operators( ex::AbstractComposableExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing ) - return something(operators, DE.get_metadata(ex).operators) + return @something(operators, DE.get_metadata(ex).operators) end function DE.get_variable_names( ex::AbstractComposableExpression, variable_names::Union{Nothing,AbstractVector{<:AbstractString}}=nothing, ) - return something(variable_names, DE.get_metadata(ex).variable_names) + return @something(variable_names, DE.get_metadata(ex).variable_names, Some(nothing)) end function DE.get_scalar_constants(ex::AbstractComposableExpression) diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl index abf8c5d17..4c529d3c7 100644 --- a/src/HierarchicalExpression.jl +++ 
b/src/HierarchicalExpression.jl @@ -37,7 +37,7 @@ using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF using ..MutateModule: MutateModule as MM using ..PopMemberModule: PopMember -using ..ComposableExpressionModule: AbstractComposableExpression, VectorWrapper +using ..ComposableExpressionModule: ComposableExpression, VectorWrapper """ HierarchicalStructure{K,S,N,E,C} <: Function @@ -122,9 +122,11 @@ struct HierarchicalExpression{ T, F<:HierarchicalStructure, N<:AbstractExpressionNode{T}, - E<:AbstractComposableExpression{T,N}, + E<:ComposableExpression{T,N}, TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}}, - D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V}, + D<:@NamedTuple{ + structure::F, operators::O, variable_names::V + } where {O<:AbstractOperatorEnum,V}, } <: AbstractStructuredExpression{T,F,N,E,D} trees::TS metadata::Metadata{D} @@ -183,7 +185,8 @@ function EB.create_expression( operators = options.operators variable_names = embed ? 
dataset.variable_names : nothing inner_expressions = ntuple( - _ -> ComposableExpression(copy(t); operators, variable_names), length(function_keys) + _ -> ComposableExpression(copy(t); operators, variable_names), + Val(length(function_keys)), ) # TODO: Generalize to other inner expression types return DE.constructorof(E)( From eca9b91db9ec637963b86646c951fe6aeeb48d82 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 13:52:32 +0000 Subject: [PATCH 08/59] feat: validation of inferred constraints --- src/HierarchicalExpression.jl | 61 +++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 10 deletions(-) diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl index 4c529d3c7..b5de19559 100644 --- a/src/HierarchicalExpression.jl +++ b/src/HierarchicalExpression.jl @@ -50,24 +50,63 @@ If not declared using the constructor `HierarchicalStructure{K}(...)`, the keys `variable_constraints` `NamedTuple` will be used to infer this. # Fields -- `combine`: Required function taking a `NamedTuple` of function keys => expressions, - returning a single expression. Fallback method used by `get_tree` - on a `HierarchicalExpression` to generate a single `Expression`. +- `combine`: Required function taking a `NamedTuple` of callable expressions (with keys `K`), + and a tuple representing the data. For example, `((; f, g), (x1, x2, x3)) -> f(x1, x2) + g(x3)` + would be a valid `combine` function. You may also re-use the callable expressions and + use different inputs, such as `((; f, g), (x1, x2)) -> f(x1 + g(x2)) - g(x1)` is + another valid choice. +- `num_features`: Optional `NamedTuple` of function keys => integers representing the number of + features used by each expression. If not provided, it will be inferred using the `combine` + function. For example, if `f` takes two arguments, and `g` takes one, then + `num_features = (; f=2, g=1)`. 
""" -struct HierarchicalStructure{K,E<:Function} <: Function +struct HierarchicalStructure{K,E<:Function,NF<:NamedTuple} <: Function combine::E + num_features::NF end -function HierarchicalStructure{K}(combine::E) where {K,E<:Function} - return HierarchicalStructure{K,E}(combine) +function HierarchicalStructure{K}(combine::E, num_features=nothing) where {K,E<:Function} + num_features = @something(num_features, infer_variable_constraints(Val(K), combine)) + return HierarchicalStructure{K,E,typeof(num_features)}(combine, num_features) end -function combine(template::HierarchicalStructure, args...) +@unstable function combine(template::HierarchicalStructure, args...) return template.combine(args...) end get_function_keys(::HierarchicalStructure{K}) where {K} = K +function _record_composable_expression!(variable_constraints, ::Val{k}, args...) where {k} + vc = variable_constraints[k][] + if vc == -1 + variable_constraints[k][] = length(args) + elseif vc != length(args) + throw(ArgumentError("Inconsistent number of arguments passed to $k")) + end + return first(args) +end + +"""Infers number of features used by each subexpression, by passing in test data.""" +function infer_variable_constraints(::Val{K}, combiner::F) where {K,F} + variable_constraints = NamedTuple{K}(map(_ -> Ref(-1), K)) + # Now, we need to evaluate the `combine` function to see how many + # features are used for each function call. If unset, we record it. + # If set, we validate. 
+ inner = Fix{1}(_record_composable_expression!, variable_constraints) + _recorders_of_composable_expressions = NamedTuple{K}(map(k -> Fix{1}(inner, Val(k)), K)) + # We use an evaluation to get the variable constraints + combiner( + _recorders_of_composable_expressions, + Base.Iterators.repeated(VectorWrapper(ones(Float64, 1), true)), + ) + inferred = NamedTuple{K}(map(x -> x[], values(variable_constraints))) + if any(==(-1), values(inferred)) + failed_keys = filter(k -> inferred[k] == -1, K) + throw(ArgumentError("Failed to infer number of features used by $failed_keys")) + end + return inferred +end + """ HierarchicalExpression{T,F,N,E,TS,D} <: AbstractStructuredExpression{T,F,N,E,D} @@ -164,7 +203,7 @@ end ExpressionInterface{all_ei_methods_except(())}, HierarchicalExpression, [Arguments()] ) -function combine(ex::HierarchicalExpression, args...) +@unstable function combine(ex::HierarchicalExpression, args...) return combine(get_metadata(ex).structure, args...) end function get_function_keys(ex::HierarchicalExpression) @@ -232,8 +271,10 @@ function DE.string_tree( map(ex -> DE.string_tree(ex, operators; kws...), values(raw_contents)) ) colored_strings = NamedTuple{function_keys}(map(_color_string, inner_strings, colors)) - return join( - (annotatedstring(k, " = ", v) for (k, v) in pairs(colored_strings)), styled"\n" + return annotatedstring( + join( + (annotatedstring(k, " = ", v) for (k, v) in pairs(colored_strings)), styled"\n" + ), ) end function DE.eval_tree_array( From 05419e46681a780c8e92716b1797fef2c50f2003 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 15:11:24 +0000 Subject: [PATCH 09/59] feat: make hierarchical expressions compatible --- src/ComposableExpression.jl | 26 +++++++++++ src/HierarchicalExpression.jl | 85 ++++++++++++++++++----------------- src/TemplateExpression.jl | 4 +- 3 files changed, 73 insertions(+), 42 deletions(-) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index 6dbcdfbec..0b8fb4067 
100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -1,10 +1,14 @@ module ComposableExpressionModule +using DispatchDoctor: @unstable using DynamicExpressions: AbstractExpression, + Expression, AbstractExpressionNode, AbstractOperatorEnum, Metadata, + constructorof, + get_metadata, eval_tree_array, set_node!, get_contents, @@ -13,6 +17,8 @@ using DynamicExpressions: using DynamicExpressions.InterfacesModule: ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments +using ..ConstantOptimizationModule: ConstantOptimizationModule as CO + abstract type AbstractComposableExpression{T,N} <: AbstractExpression{T,N} end struct ComposableExpression{ @@ -31,6 +37,8 @@ end return ComposableExpression(tree, Metadata(d)) end +@unstable DE.constructorof(::Type{<:ComposableExpression}) = ComposableExpression + DE.get_metadata(ex::AbstractComposableExpression) = ex.metadata DE.get_contents(ex::AbstractComposableExpression) = ex.tree DE.get_tree(ex::AbstractComposableExpression) = ex.tree @@ -58,6 +66,24 @@ function Base.copy(ex::AbstractComposableExpression) return ComposableExpression(copy(ex.tree), copy(ex.metadata)) end +function Base.convert(::Type{E}, ex::AbstractComposableExpression) where {E<:Expression} + return constructorof(E)(get_contents(ex), get_metadata(ex)) +end + +for name in (:combine_operators, :simplify_tree!) 
+ @eval function DE.$name( + ex::AbstractComposableExpression{T,N}, + operators::Union{AbstractOperatorEnum,Nothing}=nothing, + ) where {T,N} + inner_ex = DE.$name(convert(Expression, ex), operators) + return with_contents(ex, inner_ex) + end +end + +function CO.count_constants_for_optimization(ex::AbstractComposableExpression) + return CO.count_constants_for_optimization(convert(Expression, ex)) +end + @implements( ExpressionInterface{all_ei_methods_except(())}, ComposableExpression, [Arguments()] ) diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl index b5de19559..eba7cd787 100644 --- a/src/HierarchicalExpression.jl +++ b/src/HierarchicalExpression.jl @@ -284,6 +284,9 @@ function DE.eval_tree_array( kws..., ) where {T} raw_contents = get_contents(tree) + if has_invalid_variables(tree) + return (cX[1, :], false) + end result = combine( tree, raw_contents, map(x -> VectorWrapper(copy(x), true), eachrow(cX)) ) @@ -394,46 +397,48 @@ function CO.count_constants_for_optimization(ex::HierarchicalExpression) return sum(CO.count_constants_for_optimization, values(get_contents(ex))) end -# function CC.check_constraints( -# ex::HierarchicalExpression, -# options::AbstractOptions, -# maxsize::Int, -# cursize::Union{Int,Nothing}=nothing, -# )::Bool -# raw_contents = get_contents(ex) -# variable_constraints = get_metadata(ex).structure.variable_constraints - -# # First, we check the variable constraints at the top level: -# has_invalid_variables = any(keys(raw_contents)) do key -# tree = raw_contents[key] -# allowed_variables = variable_constraints[key] -# contains_other_features_than(tree, allowed_variables) -# end -# if has_invalid_variables -# return false -# end - -# # We also check the combined complexity: -# ((cursize === nothing) ? 
ComplexityModule.compute_complexity(ex, options) : cursize) > -# maxsize && return false - -# # Then, we check other constraints for inner expressions: -# for t in values(raw_contents) -# if !CC.check_constraints(t, options, maxsize, nothing) -# return false -# end -# end -# return true -# # TODO: The concept of `cursize` doesn't really make sense here. -# end -# function contains_other_features_than(tree::AbstractExpression, features) -# return contains_other_features_than(get_tree(tree), features) -# end -# function contains_other_features_than(tree::AbstractExpressionNode, features) -# any(tree) do node -# node.degree == 0 && !node.constant && node.feature ∉ features -# end -# end +function CC.check_constraints( + ex::HierarchicalExpression, + options::AbstractOptions, + maxsize::Int, + cursize::Union{Int,Nothing}=nothing, +)::Bool + # First, we check the variable constraints at the top level: + if has_invalid_variables(ex) + return false + end + + # We also check the combined complexity: + @something(cursize, ComplexityModule.compute_complexity(ex, options)) > maxsize && + return false + + # Then, we check other constraints for inner expressions: + raw_contents = get_contents(ex) + for t in values(raw_contents) + if !CC.check_constraints(t, options, maxsize, nothing) + return false + end + end + return true + # TODO: The concept of `cursize` doesn't really make sense here. 
+end +function has_invalid_variables(ex::HierarchicalExpression) + raw_contents = get_contents(ex) + num_features = get_metadata(ex).structure.num_features + any(keys(raw_contents)) do key + tree = raw_contents[key] + max_feature = num_features[key] + contains_features_greater_than(tree, max_feature) + end +end +function contains_features_greater_than(tree::AbstractExpression, max_feature) + return contains_features_greater_than(get_tree(tree), max_feature) +end +function contains_features_greater_than(tree::AbstractExpressionNode, max_feature) + any(tree) do node + node.degree == 0 && !node.constant && node.feature > max_feature + end +end # TODO: Add custom behavior to adjust what feature nodes can be generated diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 586589fab..42a045c3f 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -507,8 +507,8 @@ function CC.check_constraints( end # We also check the combined complexity: - ((cursize === nothing) ? 
ComplexityModule.compute_complexity(ex, options) : cursize) > - maxsize && return false + @something(cursize, ComplexityModule.compute_complexity(ex, options)) > maxsize && + return false # Then, we check other constraints for inner expressions: for t in values(raw_contents) From 81b18702f1b478f9ba917f5979c42a5a21058de5 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 17:31:40 +0000 Subject: [PATCH 10/59] feat: better printing for HierarchicalExpression --- src/HierarchicalExpression.jl | 46 ++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl index eba7cd787..9d91fe628 100644 --- a/src/HierarchicalExpression.jl +++ b/src/HierarchicalExpression.jl @@ -210,6 +210,23 @@ function get_function_keys(ex::HierarchicalExpression) return get_function_keys(get_metadata(ex).structure) end +function DE.get_tree(ex::HierarchicalExpression{<:Any,<:Any,<:Any,E}) where {E} + raw_contents = get_contents(ex) + total_num_features = max(values(get_metadata(ex).structure.num_features)...) 
+ example_inner_ex = first(values(raw_contents)) + example_tree = get_contents(example_inner_ex)::AbstractExpressionNode + + variable_trees = [ + DE.constructorof(typeof(example_tree))(; feature=i) for i in 1:total_num_features + ] + variable_expressions = [ + with_contents(inner_ex, variable_tree) for + (inner_ex, variable_tree) in zip(values(raw_contents), variable_trees) + ] + + return combine(get_metadata(ex).structure, raw_contents, variable_expressions) +end + function EB.create_expression( t::AbstractExpressionNode{T}, options::AbstractOptions, @@ -258,24 +275,41 @@ function ComplexityModule.compute_complexity( ) end +# Rather than using iterator with repeat, just make a tuple: +function _colors(::Val{n}) where {n} + return ntuple( + (i -> (:magenta, :green, :red, :blue, :yellow, :cyan)[mod1(i, n)]), Val(n) + ) +end + _color_string(s::AbstractString, c::Symbol) = styled"{$c:$s}" function DE.string_tree( tree::HierarchicalExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing; + variable_names=nothing, kws..., ) raw_contents = get_contents(tree) function_keys = keys(raw_contents) - colors = Base.Iterators.cycle((:magenta, :green, :red, :blue, :yellow, :cyan)) + num_features = get_metadata(tree).structure.num_features + total_num_features = max(values(num_features)...) 
+ colors = _colors(Val(length(function_keys))) + variable_names = ["#" * string(i) for i in 1:total_num_features] inner_strings = NamedTuple{function_keys}( - map(ex -> DE.string_tree(ex, operators; kws...), values(raw_contents)) + map( + ex -> DE.string_tree(ex, operators; variable_names, kws...), + values(raw_contents), + ), ) - colored_strings = NamedTuple{function_keys}(map(_color_string, inner_strings, colors)) - return annotatedstring( - join( - (annotatedstring(k, " = ", v) for (k, v) in pairs(colored_strings)), styled"\n" + strings = NamedTuple{function_keys}( + map( + (k, s, c) -> annotatedstring(string(k) * " = ", _color_string(s, c)), + function_keys, + values(inner_strings), + colors, ), ) + return annotatedstring(join(strings, styled"\n")) end function DE.eval_tree_array( tree::HierarchicalExpression{T}, From b517a8da417c72dee41705642d77e6d78159c020 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 17:47:36 +0000 Subject: [PATCH 11/59] feat: info dump at end of search --- src/ProgressBars.jl | 7 ++++++- src/SymbolicRegression.jl | 44 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/ProgressBars.jl b/src/ProgressBars.jl index 1b6bc402d..7214399f2 100644 --- a/src/ProgressBars.jl +++ b/src/ProgressBars.jl @@ -1,7 +1,7 @@ module ProgressBarsModule using Compat: Fix -using ProgressMeter: Progress, next! +using ProgressMeter: ProgressMeter, Progress, next!, finish! 
using StyledStrings: @styled_str, annotatedstring using ..UtilsModule: AnnotatedString @@ -26,6 +26,11 @@ function barlen(pbar::WrappedProgressBar)::Int return @something(pbar.bar.barlen, displaysize(stdout)[2]) end +function ProgressMeter.finish!(pbar::WrappedProgressBar) + ProgressMeter.finish!(pbar.bar) + return nothing +end + """Iterate a progress bar.""" function manually_iterate!(pbar::WrappedProgressBar) width = barlen(pbar) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index ca0fb1ed6..d53c4d098 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -288,7 +288,7 @@ using .HallOfFameModule: HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve using .MutateModule: mutate!, condition_mutation_weights!, MutationResult using .SingleIterationModule: s_r_cycle, optimize_and_simplify_population -using .ProgressBarsModule: WrappedProgressBar +using .ProgressBarsModule: WrappedProgressBar, finish! using .RecorderModule: @recorder, find_iteration_from_record using .MigrationModule: migrate! 
using .SearchUtilsModule: @@ -539,6 +539,7 @@ end _warmup_search!(state, datasets, ropt, options) _main_search_loop!(state, datasets, ropt, options) _tear_down!(state, ropt, options) + _info_dump(state, datasets, ropt, options) return _format_output(state, datasets, ropt, options) end @@ -1016,6 +1017,9 @@ function _main_search_loop!( end ################################################################ end + if ropt.progress + finish!(progress_bar) + end return nothing end function _tear_down!( @@ -1112,4 +1116,42 @@ redirect_stdout(devnull) do end end +function _info_dump( + state::AbstractSearchState, + datasets::Vector{D}, + ropt::AbstractRuntimeOptions, + options::AbstractOptions, +) where {D<:Dataset} + ropt.verbosity <= 0 && return nothing + + nout = length(state.halls_of_fame) + if nout > 1 + @info "Final populations:" + else + @info "Final population:" + end + for (j, (hall_of_fame, dataset)) in enumerate(zip(state.halls_of_fame, datasets)) + if nout > 1 + @info "Output $j:" + end + equation_strings = string_dominating_pareto_curve( + hall_of_fame, dataset, options; width=options.terminal_width + ) + println(equation_strings) + end + + if options.save_to_file + output_directory = joinpath( + something(options.output_directory, "outputs"), ropt.run_id + ) + @info "Results saved to:" + for j in 1:nout + filename = nout > 1 ? 
"hall_of_fame_output$(j).csv" : "hall_of_fame.csv" + output_file = joinpath(output_directory, filename) + println(" - ", output_file) + end + end + return nothing +end + end #module SR From d4c84dc4f2c4e4dd622f3104c870e0a10651bb62 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 18:21:27 +0000 Subject: [PATCH 12/59] fix: correct return type for `get_tree` --- Project.toml | 2 +- src/HierarchicalExpression.jl | 4 +++- src/SymbolicRegression.jl | 9 +++++++- test/test_composable_expression.jl | 36 +++++++++++++++++++++++++++--- 4 files changed, 45 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index ba95ee1f4..820dc1048 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ Dates = "1" DifferentiationInterface = "0.5, 0.6" DispatchDoctor = "^0.4.17" Distributed = "<0.0.1, 1" -DynamicExpressions = "1.4" +DynamicExpressions = "1.4.1" DynamicQuantities = "1" Enzyme = "0.12" JSON3 = "1" diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl index 9d91fe628..5375d8bbb 100644 --- a/src/HierarchicalExpression.jl +++ b/src/HierarchicalExpression.jl @@ -224,7 +224,9 @@ function DE.get_tree(ex::HierarchicalExpression{<:Any,<:Any,<:Any,E}) where {E} (inner_ex, variable_tree) in zip(values(raw_contents), variable_trees) ] - return combine(get_metadata(ex).structure, raw_contents, variable_expressions) + return DE.get_tree( + combine(get_metadata(ex).structure, raw_contents, variable_expressions) + ) end function EB.create_expression( diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index d53c4d098..f07542a95 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -1135,7 +1135,14 @@ function _info_dump( @info "Output $j:" end equation_strings = string_dominating_pareto_curve( - hall_of_fame, dataset, options; width=options.terminal_width + hall_of_fame, + dataset, + options; + width=@something( + options.terminal_width, + ropt.progress ? 
displaysize(stdout)[2] : nothing, + Some(nothing) + ) ) println(equation_strings) end diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index 27b809acf..a0bac36c0 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -11,16 +11,46 @@ @test ex(x, y) == x end -@testitem "Test interface" tags = [:part2] begin +@testitem "Test interface for ComposableExpression" tags = [:part2] begin using SymbolicRegression: ComposableExpression using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface using DynamicExpressions: OperatorEnum operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) variable_names = (i -> "x$i").(1:3) - f = ComposableExpression(Node(Float64; feature=1); operators, variable_names) + x1 = ComposableExpression(Node(Float64; feature=1); operators, variable_names) + x2 = ComposableExpression(Node(Float64; feature=2); operators, variable_names) - @test Interfaces.test(ExpressionInterface, ComposableExpression, [f]) + f = x1 * sin(x2) + g = f(f, f) + + @test string_tree(f) == "x1 * sin(x2)" + @test string_tree(g) == "(x1 * sin(x2)) * sin(x1 * sin(x2))" + + @test Interfaces.test(ExpressionInterface, ComposableExpression, [f, g]) +end + +@testitem "Test interface for HierarchicalExpression" tags = [:part2] begin + using SymbolicRegression + using SymbolicRegression: HierarchicalExpression + using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + x1 = ComposableExpression(Node(Float64; feature=1); operators, variable_names) + x2 = ComposableExpression(Node(Float64; feature=2); operators, variable_names) + + structure = HierarchicalStructure{(:f, :g)}( + ((; f, g), (x1, x2)) -> f(f(f(x1))) - f(g(x2, x1)) + ) + @test structure.num_features == (; f=1, g=2) + + 
expr = HierarchicalExpression((; f=x1, g=x2 * x2); structure, operators, variable_names) + + @test String(string_tree(expr)) == "f = #1\ng = #2 * #2" + @test string_tree(get_tree(expr)) == "x1 - (x1 * x1)" + @test Interfaces.test(ExpressionInterface, HierarchicalExpression, [expr]) end @testitem "Printing and evaluation of HierarchicalExpression" begin From 15a6159fb3da55ea6f958281a5aa7de129adf0c3 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 19:58:13 +0000 Subject: [PATCH 13/59] feat: print with `=` to not have breaks --- src/SearchUtils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index ed433df65..a09893e62 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -470,7 +470,7 @@ function print_search_state( 100.0 * cycles_elapsed / total_cycles / nout ) - print("="^twidth * "\n") + print("═"^twidth * "\n") for (j, (hall_of_fame, dataset)) in enumerate(zip(hall_of_fames, datasets)) if nout > 1 @printf("Best equations for output %d\n", j) @@ -479,7 +479,7 @@ function print_search_state( hall_of_fame, dataset, options; width=width ) print(equation_strings * "\n") - print("="^twidth * "\n") + print("═"^twidth * "\n") end return print("Press 'q' and then <enter> to stop execution early.\n") end From a05bb16c91288577ecfa55a76216e93b0bdaed73 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 19:58:47 +0000 Subject: [PATCH 14/59] feat: ensure we save the full expression string --- src/SearchUtils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index a09893e62..745d7581f 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -580,7 +580,7 @@ function save_to_file( complexities[i] = compute_complexity(member, options) losses[i] = member.loss strings[i] = string_tree( - member.tree, options; variable_names=dataset.variable_names + member.tree, options; variable_names=dataset.variable_names, pretty=false ) end 
From 609b7da0bb2966997334ebdf26e621a798ad8cb2 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 20:00:19 +0000 Subject: [PATCH 15/59] fix: switch to `pretty` over `raw` --- Project.toml | 2 +- src/Dataset.jl | 2 -- src/HallOfFame.jl | 4 ++-- src/HierarchicalExpression.jl | 21 +++++++++++++++------ src/InterfaceDynamicExpressions.jl | 10 +++++----- src/MLJInterface.jl | 7 +++++-- src/deprecates.jl | 4 +--- test/test_units.jl | 12 ++++++------ 8 files changed, 35 insertions(+), 27 deletions(-) diff --git a/Project.toml b/Project.toml index 820dc1048..ce0f2f397 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ Dates = "1" DifferentiationInterface = "0.5, 0.6" DispatchDoctor = "^0.4.17" Distributed = "<0.0.1, 1" -DynamicExpressions = "1.4.1" +DynamicExpressions = "1.5.0" DynamicQuantities = "1" Enzyme = "0.12" JSON3 = "1" diff --git a/src/Dataset.jl b/src/Dataset.jl index 49a452938..4a44180ae 100644 --- a/src/Dataset.jl +++ b/src/Dataset.jl @@ -102,13 +102,11 @@ function Dataset( X_units::Union{AbstractVector,Nothing}=nothing, y_units=nothing, # Deprecated: - varMap=nothing, kws..., ) where {T<:DATA_TYPE,L} Base.require_one_based_indexing(X) y !== nothing && Base.require_one_based_indexing(y) # Deprecation warning: - variable_names = deprecate_varmap(variable_names, varMap, :Dataset) if haskey(kws, :loss_type) Base.depwarn( "The `loss_type` keyword argument is deprecated. Pass as an argument instead.", diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index d09990ad7..d65fc73a4 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -133,7 +133,7 @@ const HEADER = let end function string_dominating_pareto_curve( - hallOfFame, dataset, options; width::Union{Integer,Nothing}=nothing + hallOfFame, dataset, options; width::Union{Integer,Nothing}=nothing, pretty::Bool=true ) terminal_width = (width === nothing) ? 
100 : max(100, width::Integer) _buffer = IOBuffer() @@ -150,7 +150,7 @@ function string_dominating_pareto_curve( display_variable_names=dataset.display_variable_names, X_sym_units=dataset.X_sym_units, y_sym_units=dataset.y_sym_units, - raw=false, + pretty, ) y_prefix = dataset.y_variable_name unit_str = format_dimensions(dataset.y_sym_units) diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl index 5375d8bbb..4d0673ef5 100644 --- a/src/HierarchicalExpression.jl +++ b/src/HierarchicalExpression.jl @@ -206,9 +206,6 @@ end @unstable function combine(ex::HierarchicalExpression, args...) return combine(get_metadata(ex).structure, args...) end -function get_function_keys(ex::HierarchicalExpression) - return get_function_keys(get_metadata(ex).structure) -end function DE.get_tree(ex::HierarchicalExpression{<:Any,<:Any,<:Any,E}) where {E} raw_contents = get_contents(ex) @@ -288,6 +285,7 @@ _color_string(s::AbstractString, c::Symbol) = styled"{$c:$s}" function DE.string_tree( tree::HierarchicalExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing; + pretty::Bool, variable_names=nothing, kws..., ) @@ -299,19 +297,30 @@ function DE.string_tree( variable_names = ["#" * string(i) for i in 1:total_num_features] inner_strings = NamedTuple{function_keys}( map( - ex -> DE.string_tree(ex, operators; variable_names, kws...), + ex -> DE.string_tree(ex, operators; pretty, variable_names, kws...), values(raw_contents), ), ) strings = NamedTuple{function_keys}( map( - (k, s, c) -> annotatedstring(string(k) * " = ", _color_string(s, c)), + (k, s, c) -> let + prefix = if !pretty || length(function_keys) == 1 + "" + elseif k == first(function_keys) + "╭ " + elseif k == last(function_keys) + "╰ " + else + "├ " + end + annotatedstring(prefix * string(k) * " = ", _color_string(s, c)) + end, function_keys, values(inner_strings), colors, ), ) - return annotatedstring(join(strings, styled"\n")) + return annotatedstring(join(strings, pretty ? 
styled"\n" : "; ")) end function DE.eval_tree_array( tree::HierarchicalExpression{T}, diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index 86f14d3be..7ccb1ab84 100644 --- a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -180,23 +180,21 @@ Convert an equation to a string. @inline function DE.string_tree( tree::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions; - raw::Bool=true, + pretty::Bool=false, X_sym_units=nothing, y_sym_units=nothing, variable_names=nothing, display_variable_names=variable_names, - varMap=nothing, kws..., ) - variable_names = deprecate_varmap(variable_names, varMap, :string_tree) - - if raw + if !pretty tree = tree isa GraphNode ? convert(Node, tree) : tree return DE.string_tree( tree, DE.get_operators(tree, options); f_variable=string_variable_raw, variable_names, + pretty, ) end @@ -213,6 +211,7 @@ Convert an equation to a string. ) end, variable_names=display_variable_names, + pretty, kws..., ) else @@ -222,6 +221,7 @@ Convert an equation to a string. 
f_variable=string_variable, f_constant=Fix{2}(Fix{3}(string_constant, ""), options.v_print_precision), variable_names=display_variable_names, + pretty, kws..., ) end diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index 395837ef2..ff66819ef 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -459,11 +459,14 @@ function _predict(m::M, fitresult, Xnew, idx, classes) where {M<:AbstractSRRegre end function get_equation_strings_for(::SRRegressor, trees, options, variable_names) - return (t -> string_tree(t, options; variable_names=variable_names)).(trees) + return ( + t -> string_tree(t, options; variable_names=variable_names, pretty=false) + ).(trees) end function get_equation_strings_for(::MultitargetSRRegressor, trees, options, variable_names) return [ - (t -> string_tree(t, options; variable_names=variable_names)).(ts) for ts in trees + (t -> string_tree(t, options; variable_names=variable_names, pretty=false)).(ts) for + ts in trees ] end diff --git a/src/deprecates.jl b/src/deprecates.jl index c8e0b4d57..6b6fb29ac 100644 --- a/src/deprecates.jl +++ b/src/deprecates.jl @@ -4,7 +4,7 @@ import .HallOfFameModule: calculate_pareto_frontier import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size @deprecate( - calculate_pareto_frontier(X, y, hallOfFame, options; weights=nothing, varMap=nothing), + calculate_pareto_frontier(X, y, hallOfFame, options; weights=nothing), calculate_pareto_frontier(hallOfFame) ) @deprecate( @@ -41,7 +41,6 @@ import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size loss_type::Type=Nothing, # Deprecated: multithreaded=nothing, - varMap=nothing, ) where {T<:DATA_TYPE}, equation_search( X, @@ -58,7 +57,6 @@ import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size saved_state, loss_type, multithreaded, - varMap, ) ) diff --git a/test/test_units.jl b/test/test_units.jl index a586f5e3c..da7f45fa3 100644 --- a/test/test_units.jl +++ b/test/test_units.jl @@ -337,15 +337,15 @@ 
end @test string_tree(tree, options) == "(1.0 * (x1 + ((x2 * x3) * 5.32))) - cos(1.5 * (x1 - 0.5))" - @test string_tree(tree, options; raw=false) == + @test string_tree(tree, options; pretty=true) == "(1 * (x₁ + ((x₂ * x₃) * 5.32))) - cos(1.5 * (x₁ - 0.5))" @test string_tree( - tree, options; raw=false, display_variable_names=dataset.display_variable_names + tree, options; pretty=true, display_variable_names=dataset.display_variable_names ) == "(1 * (x₁ + ((x₂ * x₃) * 5.32))) - cos(1.5 * (x₁ - 0.5))" @test string_tree( tree, options; - raw=false, + pretty=true, display_variable_names=dataset.display_variable_names, X_sym_units=dataset.X_sym_units, y_sym_units=dataset.y_sym_units, @@ -355,7 +355,7 @@ end @test string_tree( x5 * 3.2, options; - raw=false, + pretty=true, display_variable_names=dataset.display_variable_names, X_sym_units=dataset.X_sym_units, y_sym_units=dataset.y_sym_units, @@ -366,7 +366,7 @@ end @test string_tree( x5 * 3.2, options; - raw=false, + pretty=true, display_variable_names=dataset2.display_variable_names, X_sym_units=dataset2.X_sym_units, y_sym_units=dataset2.y_sym_units, @@ -381,7 +381,7 @@ end @test string_tree( x5 * 3.2, options; - raw=false, + pretty=true, display_variable_names=dataset2.display_variable_names, X_sym_units=dataset2.X_sym_units, y_sym_units=dataset2.y_sym_units, From cf631f8b7a5f336a7aa6e946a2824f5654cfcc4a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 20:03:54 +0000 Subject: [PATCH 16/59] refactor!: fully deprecate varMap --- src/SymbolicRegression.jl | 11 ----------- test/test_composable_expression.jl | 6 ++---- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index f07542a95..ce0da290d 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -190,15 +190,6 @@ catch VersionNumber(0, 0, 0) end -function deprecate_varmap(variable_names, varMap, func_name) - if varMap !== nothing - Base.depwarn("`varMap` is deprecated; 
use `variable_names` instead", func_name) - @assert variable_names === nothing "Cannot pass both `varMap` and `variable_names`" - variable_names = varMap - end - return variable_names -end - using DispatchDoctor: @stable @stable default_mode = "disable" begin @@ -451,7 +442,6 @@ function equation_search( v_dim_out::Val{DIM_OUT}=Val(nothing), # Deprecated: multithreaded=nothing, - varMap=nothing, ) where {T<:DATA_TYPE,L,DIM_OUT} if multithreaded !== nothing error( @@ -459,7 +449,6 @@ function equation_search( "Choose one of :multithreaded, :multiprocessing, or :serial.", ) end - variable_names = deprecate_varmap(variable_names, varMap, :equation_search) if weights !== nothing @assert length(weights) == length(y) diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index a0bac36c0..281be8aea 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -49,7 +49,7 @@ end expr = HierarchicalExpression((; f=x1, g=x2 * x2); structure, operators, variable_names) @test String(string_tree(expr)) == "f = #1\ng = #2 * #2" - @test string_tree(get_tree(expr)) == "x1 - (x1 * x1)" + @test string_tree(get_tree(expr), operators) == "x1 - (x1 * x1)" @test Interfaces.test(ExpressionInterface, HierarchicalExpression, [expr]) end @@ -57,9 +57,7 @@ end using SymbolicRegression structure = HierarchicalStructure{(:f, :g)}( - ((; f, g), (x1, x2, x3)) -> let - sin(f(x1, x2)) + g(x3)^2 - end + ((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)^2 ) operators = Options().operators variable_names = ["x1", "x2", "x3"] From cdfaaca2a8a31574e0fd5a130010caeb821c8f21 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 20:09:42 +0000 Subject: [PATCH 17/59] fix: fix old use of `pretty` --- src/HierarchicalExpression.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl index 4d0673ef5..ed72cba3d 100644 --- a/src/HierarchicalExpression.jl +++ 
b/src/HierarchicalExpression.jl @@ -285,7 +285,7 @@ _color_string(s::AbstractString, c::Symbol) = styled"{$c:$s}" function DE.string_tree( tree::HierarchicalExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing; - pretty::Bool, + pretty::Bool=false, variable_names=nothing, kws..., ) From ee190666a3d2a63491f1713c3865e5e7d4bcfe01 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 20:30:42 +0000 Subject: [PATCH 18/59] feat: allow custom complexity functions --- src/Complexity.jl | 9 +++++++-- src/Configure.jl | 19 ++++++++++++++++--- src/Options.jl | 26 ++++++++++++++++++++------ src/OptionsStruct.jl | 2 +- src/Population.jl | 2 +- src/SymbolicRegression.jl | 1 + 6 files changed, 46 insertions(+), 13 deletions(-) diff --git a/src/Complexity.jl b/src/Complexity.jl index dec8fb63e..54b101898 100644 --- a/src/Complexity.jl +++ b/src/Complexity.jl @@ -20,12 +20,17 @@ if these are defined. function compute_complexity( tree::AbstractExpression, options::AbstractOptions; break_sharing=Val(false) ) - return compute_complexity(get_tree(tree), options; break_sharing) + if options.complexity_mapping isa Function + return options.complexity_mapping(tree)::Int + else + return compute_complexity(get_tree(tree), options; break_sharing) + end end function compute_complexity( tree::AbstractExpressionNode, options::AbstractOptions; break_sharing=Val(false) )::Int - if options.complexity_mapping.use + complexity_mapping = options.complexity_mapping + if complexity_mapping isa ComplexityMapping && complexity_mapping.use raw_complexity = _compute_complexity( tree, options.complexity_mapping; break_sharing ) diff --git a/src/Configure.jl b/src/Configure.jl index d8f029bfa..6f72ced4f 100644 --- a/src/Configure.jl +++ b/src/Configure.jl @@ -120,7 +120,12 @@ function move_functions_to_workers( ) where {T} # All the types of functions we need to move to workers: function_sets = ( - :unaops, :binops, :elementwise_loss, :early_stop_condition, :loss_function + :unaops, 
+ :binops, + :elementwise_loss, + :early_stop_condition, + :loss_function, + :complexity_mapping, ) for function_set in function_sets @@ -152,6 +157,12 @@ function move_functions_to_workers( end ops = (options.loss_function,) example_inputs = (Node(T; val=zero(T)), dataset, options) + elseif function_set == :complexity_mapping + if options.complexity_mapping isa Union{ComplexityMapping,Function} + continue + end + ops = (options.complexity_mapping,) + example_inputs = (create_expression(zero(T), options, dataset),) else error("Invalid function set: $function_set") end @@ -171,7 +182,9 @@ function move_functions_to_workers( end end -function copy_definition_to_workers(op, procs, options::AbstractOptions, verbosity) +function copy_definition_to_workers( + @nospecialize(op), procs, @nospecialize(options::AbstractOptions), verbosity +) name = nameof(op) verbosity > 0 && @info "Copying definition of $op to workers..." src_ms = methods(op).ms @@ -195,7 +208,7 @@ function test_function_on_workers(example_inputs, op, procs) end function activate_env_on_workers( - procs, project_path::String, options::AbstractOptions, verbosity + procs, project_path::String, @nospecialize(options::AbstractOptions), verbosity ) verbosity > 0 && @info "Activating environment on workers." @everywhere procs begin diff --git a/src/Options.jl b/src/Options.jl index aa247709e..48df70bd9 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -326,6 +326,10 @@ const OPTION_DESCRIPTIONS = """- `defaults`: What set of defaults to use for `Op - `complexity_of_variables`: What complexity should be assigned to use of a variable, which can also be a vector indicating different per-variable complexity. By default, this is 1. +- `complexity_mapping`: Alternatively, you can pass a function that takes + the expression as input and returns the complexity. Make sure that + this operates on `AbstractExpression` (and unpacks to `AbstractExpressionNode`), + and returns an integer. 
- `alpha`: The probability of accepting an equation mutation during regularized evolution is given by exp(-delta_loss/(alpha * T)), where T goes from 1 to 0. Thus, alpha=infinite is the same as no annealing. @@ -474,6 +478,7 @@ $(OPTION_DESCRIPTIONS) @nospecialize(complexity_of_operators = nothing), @nospecialize(complexity_of_constants::Union{Nothing,Real} = nothing), @nospecialize(complexity_of_variables::Union{Nothing,Real,AbstractVector} = nothing), + ### complexity_mapping @nospecialize(warmup_maxsize_by::Union{Real,Nothing} = nothing), ### use_frequency ### use_frequency_in_tournament @@ -541,6 +546,7 @@ $(OPTION_DESCRIPTIONS) ## 3. The Objective: dimensionless_constants_only::Bool=false, ## 4. Working with Complexities: + complexity_mapping::Union{Function,ComplexityMapping,Nothing}=nothing, use_frequency::Bool=true, use_frequency_in_tournament::Bool=true, should_simplify::Union{Nothing,Bool}=nothing, @@ -689,6 +695,11 @@ $(OPTION_DESCRIPTIONS) error("You cannot specify both `elementwise_loss` and `loss_function`.") end end + if complexity_mapping !== nothing + @assert complexity_of_operators === nothing && + complexity_of_constants === nothing && + complexity_of_variables === nothing + end ################################# #### Supply defaults ############ @@ -761,12 +772,15 @@ $(OPTION_DESCRIPTIONS) una_constraints, bin_constraints, unary_operators, binary_operators ) - complexity_mapping = ComplexityMapping( - complexity_of_operators, - complexity_of_variables, - complexity_of_constants, - binary_operators, - unary_operators, + complexity_mapping = @something( + complexity_mapping, + ComplexityMapping( + complexity_of_operators, + complexity_of_variables, + complexity_of_constants, + binary_operators, + unary_operators, + ) ) if maxdepth === nothing diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index b39dbf0b5..22871606d 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -178,7 +178,7 @@ all properties of `Options` available for 
internal methods in SymbolicRegression abstract type AbstractOptions end struct Options{ - CM<:ComplexityMapping, + CM<:Union{ComplexityMapping,Function}, OP<:AbstractOperatorEnum, N<:AbstractExpressionNode, E<:AbstractExpression, diff --git a/src/Population.jl b/src/Population.jl index 6b9173c5c..ce76ee326 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -205,7 +205,7 @@ function record_population(pop::Population, options::AbstractOptions)::RecordTyp return RecordType( "population" => [ RecordType( - "tree" => string_tree(member.tree, options), + "tree" => string_tree(member.tree, options; pretty=false), "loss" => member.loss, "score" => member.score, "complexity" => compute_complexity(member, options), diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index ce0da290d..e8a13c6ce 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -231,6 +231,7 @@ using .CoreModule: Dataset, AbstractOptions, Options, + ComplexityMapping, AbstractMutationWeights, MutationWeights, is_weighted, From 4609e031d93c689d63eacf1886863f32f5dda46e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 20:38:41 +0000 Subject: [PATCH 19/59] test: custom complexity function --- src/Dataset.jl | 2 - src/InterfaceDynamicExpressions.jl | 2 - src/MLJInterface.jl | 3 +- test/runtests.jl | 4 +- test/test_complexity.jl | 128 +++++++++++++++++------------ 5 files changed, 80 insertions(+), 59 deletions(-) diff --git a/src/Dataset.jl b/src/Dataset.jl index 4a44180ae..2635fb995 100644 --- a/src/Dataset.jl +++ b/src/Dataset.jl @@ -6,8 +6,6 @@ using ..UtilsModule: subscriptify, get_base_type using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE using ...InterfaceDynamicQuantitiesModule: get_si_units, get_sym_units -import ...deprecate_varmap - """ Dataset{T<:DATA_TYPE,L<:LOSS_TYPE} diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index 7ccb1ab84..6bcf2d82a 100644 --- 
a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -15,8 +15,6 @@ using ..CoreModule: AbstractOptions, Dataset using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap using ..UtilsModule: subscriptify -import ..deprecate_varmap - """ eval_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::AbstractOptions; kws...) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index ff66819ef..b98a98f5b 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -24,7 +24,8 @@ using DynamicQuantities: dimension using LossFunctions: SupervisedLoss using ..InterfaceDynamicQuantitiesModule: get_dimensions_type -using ..CoreModule: Options, Dataset, AbstractMutationWeights, MutationWeights, LOSS_TYPE +using ..CoreModule: + Options, Dataset, AbstractMutationWeights, MutationWeights, LOSS_TYPE, ComplexityMapping using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame diff --git a/test/runtests.jl b/test/runtests.jl index dacb374a0..9c5ade786 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -77,9 +77,7 @@ end include("test_nested_constraints.jl") end -@testitem "Test complexity evaluation" tags = [:part3] begin - include("test_complexity.jl") -end +include("test_complexity.jl") @testitem "Test options" tags = [:part1] begin include("test_options.jl") diff --git a/test/test_complexity.jl b/test/test_complexity.jl index deaad6813..130271c2a 100644 --- a/test/test_complexity.jl +++ b/test/test_complexity.jl @@ -1,54 +1,80 @@ -println("Testing custom complexities.") -using SymbolicRegression, Test +@testitem "Test complexity evaluation" tags = [:part3] begin + using SymbolicRegression -x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") + x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") -# First, test regular complexities: -function make_options(; kw...) 
- return Options(; binary_operators=(+, -, *, /, ^), unary_operators=(cos, sin), kw...) + # First, test regular complexities: + function make_options(; kw...) + return Options(; + binary_operators=(+, -, *, /, ^), unary_operators=(cos, sin), kw... + ) + end + options = make_options() + @extend_operators options + tree = sin((x1 + x2 + x3)^2.3) + @test compute_complexity(tree, options) == 8 + + options = make_options(; complexity_of_operators=[sin => 3]) + @test compute_complexity(tree, options) == 10 + options = make_options(; complexity_of_operators=[sin => 3, (+) => 2]) + @test compute_complexity(tree, options) == 12 + + # Real numbers: + options = make_options(; complexity_of_operators=[sin => 3, (+) => 2, (^) => 3.2]) + @test compute_complexity(tree, options) == round(Int, 12 + (3.2 - 1)) +end + +@testitem "Test other things about complexity" tags = [:part3] begin + using SymbolicRegression + + x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") + + function make_options(; kw...) + return Options(; + binary_operators=(+, -, *, /, ^), unary_operators=(cos, sin), kw... 
+ ) + end + + options = make_options(; + complexity_of_operators=[sin => 3, (+) => 2], complexity_of_variables=2 + ) + @test compute_complexity(tree, options) == 12 + 3 * 1 + options = make_options(; + complexity_of_operators=[sin => 3, (+) => 2], + complexity_of_variables=2, + complexity_of_constants=2, + ) + @test compute_complexity(tree, options) == 12 + 3 * 1 + 1 + options = make_options(; + complexity_of_operators=[sin => 3, (+) => 2], + complexity_of_variables=2, + complexity_of_constants=2.6, + ) + @test compute_complexity(tree, options) == 12 + 3 * 1 + 1 + 1 + + # Custom variables + options = make_options(; + complexity_of_variables=[1, 2, 3], complexity_of_operators=[(+) => 5, (*) => 2] + ) + x1, x2, x3 = [Node{Float64}(; feature=i) for i in 1:3] + tree = x1 + x2 * x3 + @test compute_complexity(tree, options) == 1 + 5 + 2 + 2 + 3 + options = make_options(; + complexity_of_variables=2, complexity_of_operators=[(+) => 5, (*) => 2] + ) + @test compute_complexity(tree, options) == 2 + 5 + 2 + 2 + 2 +end + +@testitem "Custom complexity mapping" tags = [:part3] begin + using SymbolicRegression + + function custom_complexity(tree) + @test tree isa AbstractExpression + return 10 + end + + options = Options(; complexity_mapping=custom_complexity) + variable_names = ["x1"] + x1 = Expression(Node{Float64}(; feature=1); options.operators, variable_names) + @test compute_complexity(x1, options) == 10 end -options = make_options() -@extend_operators options -tree = sin((x1 + x2 + x3)^2.3) -@test compute_complexity(tree, options) == 8 - -options = make_options(; complexity_of_operators=[sin => 3]) -@test compute_complexity(tree, options) == 10 -options = make_options(; complexity_of_operators=[sin => 3, (+) => 2]) -@test compute_complexity(tree, options) == 12 - -# Real numbers: -options = make_options(; complexity_of_operators=[sin => 3, (+) => 2, (^) => 3.2]) -@test compute_complexity(tree, options) == round(Int, 12 + (3.2 - 1)) - -# Now, test other things, like 
variables and constants: -options = make_options(; - complexity_of_operators=[sin => 3, (+) => 2], complexity_of_variables=2 -) -@test compute_complexity(tree, options) == 12 + 3 * 1 -options = make_options(; - complexity_of_operators=[sin => 3, (+) => 2], - complexity_of_variables=2, - complexity_of_constants=2, -) -@test compute_complexity(tree, options) == 12 + 3 * 1 + 1 -options = make_options(; - complexity_of_operators=[sin => 3, (+) => 2], - complexity_of_variables=2, - complexity_of_constants=2.6, -) -@test compute_complexity(tree, options) == 12 + 3 * 1 + 1 + 1 - -# Custom variables -options = make_options(; - complexity_of_variables=[1, 2, 3], complexity_of_operators=[(+) => 5, (*) => 2] -) -x1, x2, x3 = [Node{Float64}(; feature=i) for i in 1:3] -tree = x1 + x2 * x3 -@test compute_complexity(tree, options) == 1 + 5 + 2 + 2 + 3 -options = make_options(; - complexity_of_variables=2, complexity_of_operators=[(+) => 5, (*) => 2] -) -@test compute_complexity(tree, options) == 2 + 5 + 2 + 2 + 2 - -println("Passed.") From 666babd1da702735a736de336413e9fa5099fefe Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 20:58:48 +0000 Subject: [PATCH 20/59] feat: expose VectorWrapper --- src/SymbolicRegression.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index e8a13c6ce..607185d91 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -314,7 +314,8 @@ using .SearchUtilsModule: get_cur_maxsize, update_hall_of_fame! 
using .TemplateExpressionModule: TemplateExpression, TemplateStructure -using .HierarchicalExpressionModule: HierarchicalExpression, HierarchicalStructure +using .HierarchicalExpressionModule: + HierarchicalExpression, HierarchicalStructure, VectorWrapper using .ComposableExpressionModule: ComposableExpression using .ExpressionBuilderModule: embed_metadata, strip_metadata From 8c3596b4bf7164688e0f550c4c15598258c48b96 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 21:19:01 +0000 Subject: [PATCH 21/59] refactor: more efficient mutations for hierarchical --- src/Configure.jl | 2 +- src/Core.jl | 2 +- src/Dataset.jl | 4 ++++ src/HierarchicalExpression.jl | 9 ++++++++- src/Mutate.jl | 9 +++++++-- src/OptionsStruct.jl | 2 ++ src/SearchUtils.jl | 6 +++--- src/SymbolicRegression.jl | 3 ++- 8 files changed, 28 insertions(+), 9 deletions(-) diff --git a/src/Configure.jl b/src/Configure.jl index 6f72ced4f..61c66d3e5 100644 --- a/src/Configure.jl +++ b/src/Configure.jl @@ -299,7 +299,7 @@ function test_entire_pipeline( population_size=20, nlength=3, options=options, - nfeatures=dataset.nfeatures, + nfeatures=max_features(dataset, options), ) tmp_pop = s_r_cycle( dataset, diff --git a/src/Core.jl b/src/Core.jl index 6000412ce..8e56c0334 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -12,7 +12,7 @@ include("Options.jl") using .ProgramConstantsModule: MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE -using .DatasetModule: Dataset, is_weighted, has_units +using .DatasetModule: Dataset, is_weighted, has_units, max_features using .MutationWeightsModule: AbstractMutationWeights, MutationWeights, sample_mutation using .OptionsStructModule: AbstractOptions, diff --git a/src/Dataset.jl b/src/Dataset.jl index 2635fb995..65cc909cb 100644 --- a/src/Dataset.jl +++ b/src/Dataset.jl @@ -235,4 +235,8 @@ _fill!(x::NamedTuple, val) = foreach(v -> _fill!(v, val), values(x)) _fill!(::Nothing, val) = nothing _fill!(x, val) = x +function 
max_features(dataset::Dataset, _) + return dataset.nfeatures +end + end diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl index ed72cba3d..81807fe15 100644 --- a/src/HierarchicalExpression.jl +++ b/src/HierarchicalExpression.jl @@ -26,7 +26,7 @@ using DynamicExpressions.InterfacesModule: ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments using ..CoreModule: - AbstractOptions, Dataset, CoreModule as CM, AbstractMutationWeights, has_units + AbstractOptions, Options, Dataset, CoreModule as CM, AbstractMutationWeights, has_units using ..ConstantOptimizationModule: ConstantOptimizationModule as CO using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..MutationFunctionsModule: MutationFunctionsModule as MF @@ -377,6 +377,13 @@ function CM.operator_specialization( return O end +function CM.max_features( + dataset::Dataset, options::Options{<:Any,<:Any,<:Any,<:HierarchicalExpression} +) + num_features = options.expression_options.structure.num_features + return max(values(num_features)...) +end + """ We pick a random subexpression to mutate, and also return the symbol we mutated on so that we can put it back together later. 
diff --git a/src/Mutate.jl b/src/Mutate.jl index 7b828f6f3..8c24de72d 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -9,7 +9,12 @@ using DynamicExpressions: simplify_tree!, combine_operators using ..CoreModule: - AbstractOptions, AbstractMutationWeights, Dataset, RecordType, sample_mutation + AbstractOptions, + AbstractMutationWeights, + Dataset, + RecordType, + sample_mutation, + max_features using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: score_func, score_func_batched using ..CheckConstraintsModule: check_constraints @@ -173,7 +178,7 @@ function next_generation( member.score, member.loss end - nfeatures = dataset.nfeatures + nfeatures = max_features(dataset, options) weights = copy(options.mutation_weights) diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index 22871606d..4f78a53a2 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -6,6 +6,8 @@ using DynamicExpressions: AbstractOperatorEnum, AbstractExpressionNode, AbstractExpression, OperatorEnum using LossFunctions: SupervisedLoss +using ..DatasetModule: Dataset +import ..DatasetModule: max_features import ..MutationWeightsModule: AbstractMutationWeights """ diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 745d7581f..6f7d3a991 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -13,7 +13,7 @@ using Compat: Fix using DynamicExpressions: AbstractExpression, string_tree using ..UtilsModule: subscriptify -using ..CoreModule: Dataset, AbstractOptions, Options, MAX_DEGREE, RecordType +using ..CoreModule: Dataset, AbstractOptions, Options, MAX_DEGREE, RecordType, max_features using ..ComplexityModule: compute_complexity using ..PopulationModule: Population using ..PopMemberModule: PopMember @@ -269,7 +269,7 @@ function init_dummy_pops( first(datasets); population_size=1, options=options, - nfeatures=first(datasets).nfeatures, + nfeatures=max_features(first(datasets), options), ) # ^ Due to occasional inference issue, we manually specify the return 
type return [ @@ -281,7 +281,7 @@ function init_dummy_pops( datasets[j]; population_size=1, options=options, - nfeatures=datasets[j].nfeatures, + nfeatures=max_features(datasets[j], options), ) end for i in 1:npops ] for j in 1:length(datasets) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 607185d91..3fbdec635 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -234,6 +234,7 @@ using .CoreModule: ComplexityMapping, AbstractMutationWeights, MutationWeights, + max_features, is_weighted, sample_mutation, plus, @@ -714,7 +715,7 @@ function _initialize_search!( population_size=options.population_size, nlength=3, options=options, - nfeatures=datasets[j].nfeatures, + nfeatures=max_features(datasets[j], options), ), HallOfFame(options, datasets[j]), RecordType(), From aa3b4359d563e53eab7f487c5c4d1c23386caa8f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 21:20:35 +0000 Subject: [PATCH 22/59] test: fix pretty print format --- test/test_composable_expression.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index 281be8aea..32083b578 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -48,7 +48,8 @@ end expr = HierarchicalExpression((; f=x1, g=x2 * x2); structure, operators, variable_names) - @test String(string_tree(expr)) == "f = #1\ng = #2 * #2" + @test String(string_tree(expr)) == "f = #1; g = #2 * #2" + @test String(string_tree(expr; pretty=true)) == "f = #1\ng = #2 * #2" @test string_tree(get_tree(expr), operators) == "x1 - (x1 * x1)" @test Interfaces.test(ExpressionInterface, HierarchicalExpression, [expr]) end From 9445bf4acdaf90253cbefde791f81468bce1d6eb Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 21:41:30 +0000 Subject: [PATCH 23/59] refactor: name `ValidVector` --- src/ComposableExpression.jl | 30 +++++++++++++++--------------- 
src/HierarchicalExpression.jl | 8 +++----- src/SymbolicRegression.jl | 2 +- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index 0b8fb4067..d993b205d 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -88,11 +88,11 @@ end ExpressionInterface{all_ei_methods_except(())}, ComposableExpression, [Arguments()] ) -struct VectorWrapper{A<:AbstractVector} +struct ValidVector{A<:AbstractVector} value::A valid::Bool end -VectorWrapper(x::Tuple{Vararg{Any,2}}) = VectorWrapper(x...) +ValidVector(x::Tuple{Vararg{Any,2}}) = ValidVector(x...) function (ex::AbstractComposableExpression)(x) return error("ComposableExpression does not support input of type $(typeof(x))") @@ -100,7 +100,7 @@ end function (ex::AbstractComposableExpression)(x::AbstractVector, _xs::AbstractVector...) xs = (x, _xs...) # Wrap it up for the recursive call - xs = ntuple(i -> VectorWrapper(xs[i], true), Val(length(xs))) + xs = ntuple(i -> ValidVector(xs[i], true), Val(length(xs))) result = ex(xs...) # Unwrap it if result.valid @@ -110,14 +110,14 @@ function (ex::AbstractComposableExpression)(x::AbstractVector, _xs::AbstractVect return result.value .* nan end end -function (ex::AbstractComposableExpression)(x::VectorWrapper, _xs::VectorWrapper...) +function (ex::AbstractComposableExpression)(x::ValidVector, _xs::ValidVector...) xs = (x, _xs...) 
valid = all(xi -> xi.valid, xs) if !valid - return VectorWrapper(first(xs).value, false) + return ValidVector(first(xs).value, false) else X = Matrix(stack(map(xi -> xi.value, xs))') - return VectorWrapper(eval_tree_array(ex, X)) + return ValidVector(eval_tree_array(ex, X)) end end function (ex::AbstractComposableExpression)( @@ -138,20 +138,20 @@ function (ex::AbstractComposableExpression)( return with_contents(ex, tree) end -# Basically we want to vectorize every single operation on VectorWrapper, +# Basically we want to vectorize every single operation on ValidVector, # so that the user can use it easily. function apply_operator(op::F, x...) where {F<:Function} if all(_is_valid, x) vx = map(_get_value, x) - return VectorWrapper(op.(vx...), true) + return ValidVector(op.(vx...), true) else - return VectorWrapper(_get_value(first(x)), false) + return ValidVector(_get_value(first(x)), false) end end -_is_valid(x::VectorWrapper) = x.valid +_is_valid(x::ValidVector) = x.valid _is_valid(x) = true -_get_value(x::VectorWrapper) = x.value +_get_value(x::ValidVector) = x.value _get_value(x) = x #! format: off @@ -162,9 +162,9 @@ for op in ( :&, :|, :⊻, ://, :\, ) @eval begin - Base.$(op)(x::VectorWrapper, y::VectorWrapper) = apply_operator(Base.$(op), x, y) - Base.$(op)(x::VectorWrapper, y::Number) = apply_operator(Base.$(op), x, y) - Base.$(op)(x::Number, y::VectorWrapper) = apply_operator(Base.$(op), x, y) + Base.$(op)(x::ValidVector, y::ValidVector) = apply_operator(Base.$(op), x, y) + Base.$(op)(x::ValidVector, y::Number) = apply_operator(Base.$(op), x, y) + Base.$(op)(x::Number, y::ValidVector) = apply_operator(Base.$(op), x, y) end end @@ -180,7 +180,7 @@ for op in ( :inv, :sqrt, :cbrt, :abs2, :angle, :factorial, :(!), :-, :+, :sign, :identity, ) - @eval Base.$(op)(x::VectorWrapper) = apply_operator(Base.$(op), x) + @eval Base.$(op)(x::ValidVector) = apply_operator(Base.$(op), x) end #! 
format: on diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl index 81807fe15..bcff05fbe 100644 --- a/src/HierarchicalExpression.jl +++ b/src/HierarchicalExpression.jl @@ -37,7 +37,7 @@ using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF using ..MutateModule: MutateModule as MM using ..PopMemberModule: PopMember -using ..ComposableExpressionModule: ComposableExpression, VectorWrapper +using ..ComposableExpressionModule: ComposableExpression, ValidVector """ HierarchicalStructure{K,S,N,E,C} <: Function @@ -97,7 +97,7 @@ function infer_variable_constraints(::Val{K}, combiner::F) where {K,F} # We use an evaluation to get the variable constraints combiner( _recorders_of_composable_expressions, - Base.Iterators.repeated(VectorWrapper(ones(Float64, 1), true)), + Base.Iterators.repeated(ValidVector(ones(Float64, 1), true)), ) inferred = NamedTuple{K}(map(x -> x[], values(variable_constraints))) if any(==(-1), values(inferred)) @@ -332,9 +332,7 @@ function DE.eval_tree_array( if has_invalid_variables(tree) return (cX[1, :], false) end - result = combine( - tree, raw_contents, map(x -> VectorWrapper(copy(x), true), eachrow(cX)) - ) + result = combine(tree, raw_contents, map(x -> ValidVector(copy(x), true), eachrow(cX))) return result.value, result.valid end function (ex::HierarchicalExpression)( diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 3fbdec635..89176002e 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -316,7 +316,7 @@ using .SearchUtilsModule: update_hall_of_fame! 
using .TemplateExpressionModule: TemplateExpression, TemplateStructure using .HierarchicalExpressionModule: - HierarchicalExpression, HierarchicalStructure, VectorWrapper + HierarchicalExpression, HierarchicalStructure, ValidVector using .ComposableExpressionModule: ComposableExpression using .ExpressionBuilderModule: embed_metadata, strip_metadata From 2f9d17e3ec5834544d9f2a7a7ac17204e5d4124a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 22:01:30 +0000 Subject: [PATCH 24/59] docs: document ValidVector and ComposableExpression --- src/ComposableExpression.jl | 43 +++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index d993b205d..3d3e8dc95 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -21,6 +21,30 @@ using ..ConstantOptimizationModule: ConstantOptimizationModule as CO abstract type AbstractComposableExpression{T,N} <: AbstractExpression{T,N} end +""" + ComposableExpression{T,N,D} <: AbstractComposableExpression{T,N} <: AbstractExpression{T,N} + +A symbolic expression representing a mathematical formula as an expression tree (`tree::N`) with associated metadata (`metadata::Metadata{D}`). Used to construct and manipulate expressions in symbolic regression tasks. 
+ +Example: + +Create variables `x1` and `x2`, and build an expression `f = x1 * sin(x2)`: + +```julia +operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) +variable_names = ["x1", "x2"] +x1 = ComposableExpression(Node(Float64; feature=1); operators, variable_names) +x2 = ComposableExpression(Node(Float64; feature=2); operators, variable_names) +f = x1 * sin(x2) +# ^This now references the first and second arguments of things passed to it: + +f(x1, x1) # == x1 * sin(x1) +f(randn(5), randn(5)) # == randn(5) .* sin.(randn(5)) + +# You can even pass it to itself: +f(f, f) # == (x1 * sin(x2)) * sin((x1 * sin(x2))) +``` +""" struct ComposableExpression{ T, N<:AbstractExpressionNode{T}, @@ -88,6 +112,25 @@ end ExpressionInterface{all_ei_methods_except(())}, ComposableExpression, [Arguments()] ) +""" + ValidVector{A<:AbstractVector} + +A wrapper for an AbstractVector paired with a validity flag (valid::Bool). +It represents a vector along with a boolean indicating whether the data is valid. +This is useful in computations where certain operations might produce invalid data +(e.g., division by zero), allowing the validity to propagate through calculations. +Operations on `ValidVector` instances automatically handle the valid flag: if all +operands are valid, the result is valid; if any operand is invalid, the result is +marked invalid. + +You will need to work with this to do highly custom operations with +`ComposableExpression` and `HierarchicalExpression`. + +# Fields: + +- `value::A`: The vector data. +- `valid::Bool`: Indicates if the data is valid. 
+""" struct ValidVector{A<:AbstractVector} value::A valid::Bool From dc2d50975077a4089b45de1ec290ad96dbd1e784 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 22:31:54 +0000 Subject: [PATCH 25/59] feat!: move `HierarchicalExpression` into place of `TemplateExpression` --- examples/template_expression.jl | 31 +- examples/template_expression_complex.jl | 96 ++--- src/ComposableExpression.jl | 18 +- src/HierarchicalExpression.jl | 495 ------------------------ src/SymbolicRegression.jl | 32 +- src/TemplateExpression.jl | 340 ++++++++-------- test/test_composable_expression.jl | 19 +- test/test_template_expression.jl | 227 ----------- 8 files changed, 253 insertions(+), 1005 deletions(-) delete mode 100644 src/HierarchicalExpression.jl delete mode 100644 test/test_template_expression.jl diff --git a/examples/template_expression.jl b/examples/template_expression.jl index 8c2465b1a..4f02fd754 100644 --- a/examples/template_expression.jl +++ b/examples/template_expression.jl @@ -6,20 +6,27 @@ using Test: @test options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) operators = options.operators variable_names = (i -> "x$i").(1:3) -x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - -structure = TemplateStructure{(:f, :g1, :g2)}(; - combine_vectors=e -> map((f, g1, g2) -> (f + g1, f + g2), e.f, e.g1, e.g2), - combine_strings=e -> "( $(e.f) + $(e.g1), $(e.f) + $(e.g2) )", - variable_constraints=(; f=[1, 2], g1=[3], g2=[3]), +x1, x2, x3 = (i -> ComposableExpression(Node(Float64; feature=i); operators, variable_names)).(1:3) + +structure = TemplateStructure{(:f, :g1, :g2)}( + ((; f, g1, g2), (x1, x2, x3)) -> let + _f = f(x1, x2) + _g1 = g1(x3) + _g2 = g2(x3) + _out1 = _f + _g1 + _out2 = _f + _g2 + ValidVector(map(tuple, _out1.x, _out2.x), _out1.valid && _out2.valid) + end, ) st_expr = TemplateExpression((; f=x1, g1=x3, g2=x3); structure, operators, variable_names) -X = rand(100, 3) .* 10 +x1 = 
rand(100) +x2 = rand(100) +x3 = rand(100) # Our dataset is a vector of 2-tuples -y = [(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1))] +y = [(sin(x1[i]) + x3[i]^2, sin(x1[i]) + x3[i]) for i in eachindex(x1, x2, x3)] model = SRRegressor(; binary_operators=(+, *), @@ -32,7 +39,7 @@ model = SRRegressor(; early_stop_condition=(loss, complexity) -> loss < 1e-5 && complexity <= 7, ) -mach = machine(model, X, y) +mach = machine(model, [x1 x2 x3], y) fit!(mach) # Check the performance of the model @@ -48,6 +55,6 @@ best_f = get_contents(best_expr).f best_g1 = get_contents(best_expr).g1 best_g2 = get_contents(best_expr).g2 -@test best_f(X') ≈ (@. sin(X[:, 1])) -@test best_g1(X') ≈ (@. X[:, 3] * X[:, 3]) -@test best_g2(X') ≈ (@. X[:, 3]) +@test best_f(x1, x2) ≈ @. sin.(x1) +@test best_g1(x3) ≈ (@. x3 * x3) +@test best_g2(x3) ≈ (@. x3) diff --git a/examples/template_expression_complex.jl b/examples/template_expression_complex.jl index b3794e823..e4e8921ce 100644 --- a/examples/template_expression_complex.jl +++ b/examples/template_expression_complex.jl @@ -29,7 +29,9 @@ the components of a particle's motion under magnetic and drag forces. We'll see Let's get started! =# -using SymbolicRegression, Random +using SymbolicRegression +using SymbolicRegression: ValidVector +using Random using MLJBase: machine, fit!, predict, report #= @@ -168,61 +170,63 @@ variable_names = ["t", "v_x", "v_y", "v_z", "T"] Template expressions require you to define a _structure_ function, which describes how to combine the sub-expressions into a single expression, numerically evaluate them, and print them. +These are evaluated using `ComposableExpression` for the individual +subexpressions (which allow them to be composed into new expressions), +and `ValidVector` for carrying through evaluation results. -First, let's just make a function that prints the expression: +Let's define our structure function. 
Note that this takes two arguments, +one being a named tuple of our expressions (`::ComposableExpression`), and the other being a tuple +of the input variables (`::ValidVector`). =# -function combine_strings(e) - ## e is a named tuple of strings representing each formula - return " ╭ 𝐁 = [ " * e.B_x * " , " * e.B_y * " , " * e.B_z * " ]\n ╰ 𝐅 = (" * e.F_d_scale * ") * 𝐯" - ## (Note that string interpolation will erase the colors, so use `*` instead) -end - -#= -So, this will just print the separate B and F_d expressions we've learned. - -Then, let's define an expression that takes the numerical values -evaluated in the TemplateExpression, and combines them into the resultant -force vector. Inside this function, we can do whatever we want. -=# -function combine_vectors(e, X) - ## This time, e is a named tuple of *vectors*, representing the batched - ## evaluation of each formula. - - ## First, extract the 3D velocity vectors from the input matrix: - v = [(X[2, i], X[3, i], X[4, i]) for i in eachindex(axes(X, 2))] - - ## Use this to compute the full drag force: - F_d = [F_d_scale_i .* vi for (F_d_scale_i, vi) in zip(e.F_d_scale, v)] - - ## Collect the magnetic field components that we've learned into the vector: - B = [(bx, by, bz) for (bx, by, bz) in zip(e.B_x, e.B_y, e.B_z)] - - ## Using this, we compute the magnetic force with a cross product: +function compute_force((; B_x, B_y, B_z, F_d_scale), (t, v_x, v_y, v_z, T)) + ## First, we evaluate each subexpression on the variables we wish + ## to have each depend on: + _B_x = B_x(t) + _B_y = B_y(t) + _B_z = B_z(t) + _F_d_scale = F_d_scale(T) + ## Note that we can also evaluate an expression multiple times, + ## including in a hierarchy! + + ## Now, let's do the same computation we did above to + ## get the total force vectors. 
Note that the evaluation + ## output is wrapped in `ValidVector`, so we need + ## to extract the `.x` to get raw vectors: + B = [(bx, by, bz) for (bx, by, bz) in zip(_B_x.x, _B_y.x, _B_z.x)] + v = [(vx, vy, vz) for (vx, vy, vz) in zip(v_x.x, v_y.x, v_z.x)] + + + ## Now, let's compute the drag force using our model: + F_d = [fd_scale .* vi for (vi, fd_scale) in zip(v, _F_d_scale.x)] + + ## Now, the magnetic force: F_mag = [cross(vi, Bi) for (vi, Bi) in zip(v, B)] ## Finally, we combine the drag and magnetic forces into the total force: - return [Force((fd .+ fm)...) for (fd, fm) in zip(F_d, F_mag)] + F = map((fd, fm) -> Force((fd .+ fm)...), F_d, F_mag) + + ## The output of this function needs to be another `ValidVector`, + ## which carries through the validity of the evaluation. We compute + ## this below. + ValidVector(F, _B_x.valid && _B_y.valid && _B_z.valid && _F_d_scale.valid) + ## (Note that if you were doing operations that could not handle NaNs, + ## you may need to return early - just be sure to also return the `ValidVector`!) end #= -For the functions we wish to learn, we can constraint what variables -each of them depends on, explicitly. Let's say B only depends on time, -and the drag force scale only depends on temperature (we explicitly -multiply the velocity in). -=# -variable_constraints = (; B_x=[1], B_y=[1], B_z=[1], F_d_scale=[5]) +Note above that we have constrained what variables each subexpression depends on. -#= -Now, we can create our template expression: +We have constrained the magnetic field to only depend on time, +and the drag force scale to only depend on temperature. +The other variables we simply pass through and use in the evaluation.
+ +Now, we can create our template expression, with the +subexpression symbols we wish to learn: =# -structure = TemplateStructure{(:B_x, :B_y, :B_z, :F_d_scale)}(; - combine_strings=combine_strings, - combine_vectors=combine_vectors, - variable_constraints=variable_constraints, -) +structure = TemplateStructure{(:B_x, :B_y, :B_z, :F_d_scale)}(compute_force) #= -Let's look at an example of how this would be used +First, let's look at an example of how this would be used in a TemplateExpression, for some guess at the form of the solution: =# @@ -243,7 +247,7 @@ ex = TemplateExpression( So we can see that it prints the expression as we've defined it. Now, we can create a regressor that builds template expressions -which follow this structure: +which follow this structure! =# model = SRRegressor(; binary_operators=(+, -, *, /), @@ -252,7 +256,7 @@ model = SRRegressor(; maxsize=35, expression_type=TemplateExpression, expression_options=(; structure=structure), - ## The elementwise needs to operate directly on each row of `y`: + ## Note that the elementwise loss needs to operate directly on each row of `y`: elementwise_loss=(F1, F2) -> (F1.x - F2.x)^2 + (F1.y - F2.y)^2 + (F1.z - F2.z)^2, batching=true, batch_size=30, diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index 3d3e8dc95..7e5178c66 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -124,15 +124,15 @@ operands are valid, the result is valid; if any operand is invalid, the result i marked invalid. You will need to work with this to do highly custom operations with -`ComposableExpression` and `HierarchicalExpression`. +`ComposableExpression` and `TemplateExpression`. # Fields: -- `value::A`: The vector data. +- `x::A`: The vector data. - `valid::Bool`: Indicates if the data is valid. """ struct ValidVector{A<:AbstractVector} - value::A + x::A valid::Bool end ValidVector(x::Tuple{Vararg{Any,2}}) = ValidVector(x...) 
@@ -147,19 +147,19 @@ function (ex::AbstractComposableExpression)(x::AbstractVector, _xs::AbstractVect result = ex(xs...) # Unwrap it if result.valid - return result.value + return result.x else - nan = convert(eltype(result.value), NaN) - return result.value .* nan + nan = convert(eltype(result.x), NaN) + return result.x .* nan end end function (ex::AbstractComposableExpression)(x::ValidVector, _xs::ValidVector...) xs = (x, _xs...) valid = all(xi -> xi.valid, xs) if !valid - return ValidVector(first(xs).value, false) + return ValidVector(first(xs).x, false) else - X = Matrix(stack(map(xi -> xi.value, xs))') + X = Matrix(stack(map(xi -> xi.x, xs))') return ValidVector(eval_tree_array(ex, X)) end end @@ -194,7 +194,7 @@ function apply_operator(op::F, x...) where {F<:Function} end _is_valid(x::ValidVector) = x.valid _is_valid(x) = true -_get_value(x::ValidVector) = x.value +_get_value(x::ValidVector) = x.x _get_value(x) = x #! format: off diff --git a/src/HierarchicalExpression.jl b/src/HierarchicalExpression.jl deleted file mode 100644 index bcff05fbe..000000000 --- a/src/HierarchicalExpression.jl +++ /dev/null @@ -1,495 +0,0 @@ -module HierarchicalExpressionModule - -using Random: AbstractRNG -using Compat: Fix -using DispatchDoctor: @unstable -using StyledStrings: @styled_str, annotatedstring -using DynamicExpressions: - DynamicExpressions as DE, - AbstractStructuredExpression, - AbstractExpressionNode, - AbstractExpression, - AbstractOperatorEnum, - OperatorEnum, - Expression, - Metadata, - get_contents, - with_contents, - get_metadata, - get_operators, - get_variable_names, - get_tree, - node_type, - eval_tree_array, - count_nodes -using DynamicExpressions.InterfacesModule: - ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments - -using ..CoreModule: - AbstractOptions, Options, Dataset, CoreModule as CM, AbstractMutationWeights, has_units -using ..ConstantOptimizationModule: ConstantOptimizationModule as CO -using 
..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE -using ..MutationFunctionsModule: MutationFunctionsModule as MF -using ..ExpressionBuilderModule: ExpressionBuilderModule as EB -using ..DimensionalAnalysisModule: DimensionalAnalysisModule as DA -using ..CheckConstraintsModule: CheckConstraintsModule as CC -using ..ComplexityModule: ComplexityModule -using ..LossFunctionsModule: LossFunctionsModule as LF -using ..MutateModule: MutateModule as MM -using ..PopMemberModule: PopMember -using ..ComposableExpressionModule: ComposableExpression, ValidVector - -""" - HierarchicalStructure{K,S,N,E,C} <: Function - -A struct that defines a prescribed structure for a `HierarchicalExpression`, -including functions that define the result in different contexts. - -The `K` parameter is used to specify the symbols representing the inner expressions. -If not declared using the constructor `HierarchicalStructure{K}(...)`, the keys of the -`variable_constraints` `NamedTuple` will be used to infer this. - -# Fields -- `combine`: Required function taking a `NamedTuple` of callable expressions (with keys `K`), - and a tuple representing the data. For example, `((; f, g), (x1, x2, x3)) -> f(x1, x2) + g(x3)` - would be a valid `combine` function. You may also re-use the callable expressions and - use different inputs, such as `((; f, g), (x1, x2)) -> f(x1 + g(x2)) - g(x1)` is - another valid choice. -- `num_features`: Optional `NamedTuple` of function keys => integers representing the number of - features used by each expression. If not provided, it will be inferred using the `combine` - function. For example, if `f` takes two arguments, and `g` takes one, then - `num_features = (; f=2, g=1)`. 
-""" -struct HierarchicalStructure{K,E<:Function,NF<:NamedTuple} <: Function - combine::E - num_features::NF -end - -function HierarchicalStructure{K}(combine::E, num_features=nothing) where {K,E<:Function} - num_features = @something(num_features, infer_variable_constraints(Val(K), combine)) - return HierarchicalStructure{K,E,typeof(num_features)}(combine, num_features) -end - -@unstable function combine(template::HierarchicalStructure, args...) - return template.combine(args...) -end - -get_function_keys(::HierarchicalStructure{K}) where {K} = K - -function _record_composable_expression!(variable_constraints, ::Val{k}, args...) where {k} - vc = variable_constraints[k][] - if vc == -1 - variable_constraints[k][] = length(args) - elseif vc != length(args) - throw(ArgumentError("Inconsistent number of arguments passed to $k")) - end - return first(args) -end - -"""Infers number of features used by each subexpression, by passing in test data.""" -function infer_variable_constraints(::Val{K}, combiner::F) where {K,F} - variable_constraints = NamedTuple{K}(map(_ -> Ref(-1), K)) - # Now, we need to evaluate the `combine` function to see how many - # features are used for each function call. If unset, we record it. - # If set, we validate. 
- inner = Fix{1}(_record_composable_expression!, variable_constraints) - _recorders_of_composable_expressions = NamedTuple{K}(map(k -> Fix{1}(inner, Val(k)), K)) - # We use an evaluation to get the variable constraints - combiner( - _recorders_of_composable_expressions, - Base.Iterators.repeated(ValidVector(ones(Float64, 1), true)), - ) - inferred = NamedTuple{K}(map(x -> x[], values(variable_constraints))) - if any(==(-1), values(inferred)) - failed_keys = filter(k -> inferred[k] == -1, K) - throw(ArgumentError("Failed to infer number of features used by $failed_keys")) - end - return inferred -end - -""" - HierarchicalExpression{T,F,N,E,TS,D} <: AbstractStructuredExpression{T,F,N,E,D} - -A symbolic expression that allows the combination of multiple sub-expressions -in a structured way, with constraints on variable usage. - -`HierarchicalExpression` is designed for symbolic regression tasks where -domain-specific knowledge or constraints must be imposed on the model's structure. - -# Constructor - -- `HierarchicalExpression(trees; structure, operators, variable_names)` - - `trees`: A `NamedTuple` holding the sub-expressions (e.g., `f = Expression(...)`, `g = Expression(...)`). - - `structure`: A `HierarchicalStructure` which holds functions that define how the sub-expressions are combined - in different contexts. - - `operators`: An `OperatorEnum` that defines the allowed operators for the sub-expressions. - - `variable_names`: An optional `Vector` of `String` that defines the names of the variables in the dataset. 
- -# Example - -Let's create an example `HierarchicalExpression` that combines two sub-expressions `f(x1, x2)` and `g(x3)`: - -```julia -# Define operators and variable names -options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) -operators = options.operators -variable_names = ["x1", "x2", "x3"] - -# Create sub-expressions -x1 = Expression(Node{Float64}(; feature=1); operators, variable_names) -x2 = Expression(Node{Float64}(; feature=2); operators, variable_names) -x3 = Expression(Node{Float64}(; feature=3); operators, variable_names) - -# Create HierarchicalExpression -example_expr = (; f=x1, g=x3) -st_expr = HierarchicalExpression( - example_expr; - structure=HierarchicalStructure{(:f, :g)}( - ((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)^2 - ), - operators, - variable_names, -) -``` - -When fitting a model in SymbolicRegression.jl, you would provide the `HierarchicalExpression` -as the `expression_type` argument, and then pass `expression_options=(; structure=HierarchicalStructure(...))` -as additional options. The `variable_constraints` will constraint `f` to only have access to `x1` and `x2`, -and `g` to only have access to `x3`. 
-""" -struct HierarchicalExpression{ - T, - F<:HierarchicalStructure, - N<:AbstractExpressionNode{T}, - E<:ComposableExpression{T,N}, - TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}}, - D<:@NamedTuple{ - structure::F, operators::O, variable_names::V - } where {O<:AbstractOperatorEnum,V}, -} <: AbstractStructuredExpression{T,F,N,E,D} - trees::TS - metadata::Metadata{D} - - function HierarchicalExpression( - trees::TS, metadata::Metadata{D} - ) where { - TS, - F<:HierarchicalStructure, - D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V}, - } - @assert keys(trees) == get_function_keys(metadata.structure) - E = typeof(first(values(trees))) - N = node_type(E) - return new{eltype(N),F,N,E,TS,D}(trees, metadata) - end -end - -function HierarchicalExpression( - trees::NamedTuple{<:Any,<:NTuple{<:Any,<:AbstractExpression}}; - structure::F, - operators::Union{AbstractOperatorEnum,Nothing}=nothing, - variable_names::Union{AbstractVector{<:AbstractString},Nothing}=nothing, -) where {F<:HierarchicalStructure} - example_tree = first(values(trees))::AbstractExpression - operators = get_operators(example_tree, operators) - variable_names = get_variable_names(example_tree, variable_names) - metadata = (; structure, operators, variable_names) - return HierarchicalExpression(trees, Metadata(metadata)) -end - -@unstable DE.constructorof(::Type{<:HierarchicalExpression}) = HierarchicalExpression - -@implements( - ExpressionInterface{all_ei_methods_except(())}, HierarchicalExpression, [Arguments()] -) - -@unstable function combine(ex::HierarchicalExpression, args...) - return combine(get_metadata(ex).structure, args...) -end - -function DE.get_tree(ex::HierarchicalExpression{<:Any,<:Any,<:Any,E}) where {E} - raw_contents = get_contents(ex) - total_num_features = max(values(get_metadata(ex).structure.num_features)...) 
- example_inner_ex = first(values(raw_contents)) - example_tree = get_contents(example_inner_ex)::AbstractExpressionNode - - variable_trees = [ - DE.constructorof(typeof(example_tree))(; feature=i) for i in 1:total_num_features - ] - variable_expressions = [ - with_contents(inner_ex, variable_tree) for - (inner_ex, variable_tree) in zip(values(raw_contents), variable_trees) - ] - - return DE.get_tree( - combine(get_metadata(ex).structure, raw_contents, variable_expressions) - ) -end - -function EB.create_expression( - t::AbstractExpressionNode{T}, - options::AbstractOptions, - dataset::Dataset{T,L}, - ::Type{<:AbstractExpressionNode}, - ::Type{E}, - ::Val{embed}=Val(false), -) where {T,L,embed,E<:HierarchicalExpression} - function_keys = get_function_keys(options.expression_options.structure) - - # NOTE: We need to copy over the operators so we can call the structure function - operators = options.operators - variable_names = embed ? dataset.variable_names : nothing - inner_expressions = ntuple( - _ -> ComposableExpression(copy(t); operators, variable_names), - Val(length(function_keys)), - ) - # TODO: Generalize to other inner expression types - return DE.constructorof(E)( - NamedTuple{function_keys}(inner_expressions); - EB.init_params(options, dataset, nothing, Val(embed))..., - ) -end -function EB.extra_init_params( - ::Type{E}, - prototype::Union{Nothing,AbstractExpression}, - options::AbstractOptions, - dataset::Dataset{T}, - ::Val{embed}, -) where {T,embed,E<:HierarchicalExpression} - # We also need to include the operators here to be consistent with `create_expression`. - return (; options.operators, options.expression_options...) 
-end -function EB.sort_params(params::NamedTuple, ::Type{<:HierarchicalExpression}) - return (; params.structure, params.operators, params.variable_names) -end - -function ComplexityModule.compute_complexity( - tree::HierarchicalExpression, options::AbstractOptions; break_sharing=Val(false) -) - # Rather than including the complexity of the combined tree, - # we only sum the complexity of each inner expression, which will be smaller. - return sum( - ex -> ComplexityModule.compute_complexity(ex, options; break_sharing), - values(get_contents(tree)), - ) -end - -# Rather than using iterator with repeat, just make a tuple: -function _colors(::Val{n}) where {n} - return ntuple( - (i -> (:magenta, :green, :red, :blue, :yellow, :cyan)[mod1(i, n)]), Val(n) - ) -end - -_color_string(s::AbstractString, c::Symbol) = styled"{$c:$s}" -function DE.string_tree( - tree::HierarchicalExpression, - operators::Union{AbstractOperatorEnum,Nothing}=nothing; - pretty::Bool=false, - variable_names=nothing, - kws..., -) - raw_contents = get_contents(tree) - function_keys = keys(raw_contents) - num_features = get_metadata(tree).structure.num_features - total_num_features = max(values(num_features)...) - colors = _colors(Val(length(function_keys))) - variable_names = ["#" * string(i) for i in 1:total_num_features] - inner_strings = NamedTuple{function_keys}( - map( - ex -> DE.string_tree(ex, operators; pretty, variable_names, kws...), - values(raw_contents), - ), - ) - strings = NamedTuple{function_keys}( - map( - (k, s, c) -> let - prefix = if !pretty || length(function_keys) == 1 - "" - elseif k == first(function_keys) - "╭ " - elseif k == last(function_keys) - "╰ " - else - "├ " - end - annotatedstring(prefix * string(k) * " = ", _color_string(s, c)) - end, - function_keys, - values(inner_strings), - colors, - ), - ) - return annotatedstring(join(strings, pretty ? 
styled"\n" : "; ")) -end -function DE.eval_tree_array( - tree::HierarchicalExpression{T}, - cX::AbstractMatrix{T}, - operators::Union{AbstractOperatorEnum,Nothing}=nothing; - kws..., -) where {T} - raw_contents = get_contents(tree) - if has_invalid_variables(tree) - return (cX[1, :], false) - end - result = combine(tree, raw_contents, map(x -> ValidVector(copy(x), true), eachrow(cX))) - return result.value, result.valid -end -function (ex::HierarchicalExpression)( - X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... -) - result, valid = DE.eval_tree_array(ex, X, operators; kws...) - if valid - return result - else - nan = convert(eltype(result), NaN) - return result .* nan - end -end -@unstable IDE.expected_array_type(::AbstractMatrix, ::Type{<:HierarchicalExpression}) = Any - -function DA.violates_dimensional_constraints( - @nospecialize(tree::HierarchicalExpression), - dataset::Dataset, - @nospecialize(options::AbstractOptions) -) - @assert !has_units(dataset) - return false -end -function MM.condition_mutation_weights!( - @nospecialize(weights::AbstractMutationWeights), - @nospecialize(member::P), - @nospecialize(options::AbstractOptions), - curmaxsize::Int, -) where {T,L,N<:HierarchicalExpression,P<:PopMember{T,L,N}} - # HACK TODO - return nothing -end - -""" -We need full specialization for constrained expressions, as they rely on subexpressions being combined. -""" -function CM.operator_specialization( - ::Type{O}, ::Type{<:HierarchicalExpression} -) where {O<:OperatorEnum} - return O -end - -function CM.max_features( - dataset::Dataset, options::Options{<:Any,<:Any,<:Any,<:HierarchicalExpression} -) - num_features = options.expression_options.structure.num_features - return max(values(num_features)...) -end - -""" -We pick a random subexpression to mutate, -and also return the symbol we mutated on so that we can put it back together later. 
-""" -function MF.get_contents_for_mutation(ex::HierarchicalExpression, rng::AbstractRNG) - raw_contents = get_contents(ex) - function_keys = keys(raw_contents) - - # Sample weighted by number of nodes in each subexpression - num_nodes = map(count_nodes, values(raw_contents)) - weights = map(Base.Fix2(/, sum(num_nodes)), num_nodes) - cumsum_weights = cumsum(weights) - rand_val = rand(rng) - idx = findfirst(Base.Fix2(>=, rand_val), cumsum_weights)::Int - - key_to_mutate = function_keys[idx] - return raw_contents[key_to_mutate], key_to_mutate -end - -"""See `get_contents_for_mutation(::HierarchicalExpression, ::AbstractRNG)`.""" -function MF.with_contents_for_mutation( - ex::HierarchicalExpression, new_inner_contents, context::Symbol -) - raw_contents = get_contents(ex) - raw_contents_keys = keys(raw_contents) - new_contents = NamedTuple{raw_contents_keys}( - ntuple(length(raw_contents_keys)) do i - if raw_contents_keys[i] == context - new_inner_contents - else - raw_contents[raw_contents_keys[i]] - end - end, - ) - return with_contents(ex, new_contents) -end - -"""We combine the operators of each inner expression.""" -function DE.combine_operators( - ex::HierarchicalExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing -) where {T,N} - raw_contents = get_contents(ex) - function_keys = keys(raw_contents) - new_contents = NamedTuple{function_keys}( - map(Base.Fix2(DE.combine_operators, operators), values(raw_contents)) - ) - return with_contents(ex, new_contents) -end - -"""We simplify each inner expression.""" -function DE.simplify_tree!( - ex::HierarchicalExpression{T,N}, operators::Union{AbstractOperatorEnum,Nothing}=nothing -) where {T,N} - raw_contents = get_contents(ex) - function_keys = keys(raw_contents) - new_contents = NamedTuple{function_keys}( - map(Base.Fix2(DE.simplify_tree!, operators), values(raw_contents)) - ) - return with_contents(ex, new_contents) -end - -function CO.count_constants_for_optimization(ex::HierarchicalExpression) - 
return sum(CO.count_constants_for_optimization, values(get_contents(ex))) -end - -function CC.check_constraints( - ex::HierarchicalExpression, - options::AbstractOptions, - maxsize::Int, - cursize::Union{Int,Nothing}=nothing, -)::Bool - # First, we check the variable constraints at the top level: - if has_invalid_variables(ex) - return false - end - - # We also check the combined complexity: - @something(cursize, ComplexityModule.compute_complexity(ex, options)) > maxsize && - return false - - # Then, we check other constraints for inner expressions: - raw_contents = get_contents(ex) - for t in values(raw_contents) - if !CC.check_constraints(t, options, maxsize, nothing) - return false - end - end - return true - # TODO: The concept of `cursize` doesn't really make sense here. -end -function has_invalid_variables(ex::HierarchicalExpression) - raw_contents = get_contents(ex) - num_features = get_metadata(ex).structure.num_features - any(keys(raw_contents)) do key - tree = raw_contents[key] - max_feature = num_features[key] - contains_features_greater_than(tree, max_feature) - end -end -function contains_features_greater_than(tree::AbstractExpression, max_feature) - return contains_features_greater_than(get_tree(tree), max_feature) -end -function contains_features_greater_than(tree::AbstractExpressionNode, max_feature) - any(tree) do node - node.degree == 0 && !node.constant && node.feature > max_feature - end -end - -# TODO: Add custom behavior to adjust what feature nodes can be generated - -end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 89176002e..91dd9e698 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -14,17 +14,13 @@ export Population, ParametricExpression, TemplateExpression, TemplateStructure, - HierarchicalExpression, - HierarchicalStructure, + ValidVector, ComposableExpression, - NodeSampler, AbstractExpression, AbstractExpressionNode, EvalOptions, SRRegressor, MultitargetSRRegressor, - LOSS_TYPE, - 
DATA_TYPE, #Functions: equation_search, @@ -44,7 +40,6 @@ export Population, set_node!, copy_node, node_to_symbolic, - node_type, symbolic_to_node, simplify_tree!, tree_mapreduce, @@ -162,16 +157,17 @@ using DynamicExpressions: with_type_parameters LogCoshLoss using Compat: @compat, Fix -@compat public AbstractOptions, -AbstractRuntimeOptions, -RuntimeOptions, -AbstractMutationWeights, -mutate!, -condition_mutation_weights!, -sample_mutation, -MutationResult, -AbstractSearchState, -SearchState +#! format: off +@compat( + public, + ( + AbstractOptions, AbstractRuntimeOptions, RuntimeOptions, + AbstractMutationWeights, mutate!, condition_mutation_weights!, + sample_mutation, MutationResult, AbstractSearchState, SearchState, + NodeSampler, LOSS_TYPE, DATA_TYPE, node_type, + ) +) +#! format: on # ^ We can add new functions here based on requests from users. # However, I don't want to add many functions without knowing what # users will actually want to overload. @@ -217,7 +213,6 @@ using DispatchDoctor: @stable include("ExpressionBuilder.jl") include("ComposableExpression.jl") include("TemplateExpression.jl") - include("HierarchicalExpression.jl") include("ParametricExpression.jl") end @@ -315,8 +310,7 @@ using .SearchUtilsModule: get_cur_maxsize, update_hall_of_fame! 
using .TemplateExpressionModule: TemplateExpression, TemplateStructure -using .HierarchicalExpressionModule: - HierarchicalExpression, HierarchicalStructure, ValidVector +using .TemplateExpressionModule: TemplateExpression, TemplateStructure, ValidVector using .ComposableExpressionModule: ComposableExpression using .ExpressionBuilderModule: embed_metadata, strip_metadata diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 42a045c3f..6a4eebc79 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -3,7 +3,7 @@ module TemplateExpressionModule using Random: AbstractRNG using Compat: Fix using DispatchDoctor: @unstable -using StyledStrings: @styled_str +using StyledStrings: @styled_str, annotatedstring using DynamicExpressions: DynamicExpressions as DE, AbstractStructuredExpression, @@ -26,7 +26,7 @@ using DynamicExpressions.InterfacesModule: ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments using ..CoreModule: - AbstractOptions, Dataset, CoreModule as CM, AbstractMutationWeights, has_units + AbstractOptions, Options, Dataset, CoreModule as CM, AbstractMutationWeights, has_units using ..ConstantOptimizationModule: ConstantOptimizationModule as CO using ..InterfaceDynamicExpressionsModule: InterfaceDynamicExpressionsModule as IDE using ..MutationFunctionsModule: MutationFunctionsModule as MF @@ -37,131 +37,76 @@ using ..ComplexityModule: ComplexityModule using ..LossFunctionsModule: LossFunctionsModule as LF using ..MutateModule: MutateModule as MM using ..PopMemberModule: PopMember +using ..ComposableExpressionModule: ComposableExpression, ValidVector """ TemplateStructure{K,S,N,E,C} <: Function A struct that defines a prescribed structure for a `TemplateExpression`, -including functions that define the result of combining sub-expressions in different contexts. +including functions that define the result in different contexts. 
The `K` parameter is used to specify the symbols representing the inner expressions. If not declared using the constructor `TemplateStructure{K}(...)`, the keys of the `variable_constraints` `NamedTuple` will be used to infer this. # Fields -- `combine`: Optional function taking a `NamedTuple` of function keys => expressions, - returning a single expression. Fallback method used by `get_tree` - on a `TemplateExpression` to generate a single `Expression`. -- `combine_vectors`: Optional function taking a `NamedTuple` of function keys => vectors, - returning a single vector. Used for evaluating the expression tree. - You may optionally define a method with a second argument `X` for if you wish - to include the data matrix `X` (of shape `[num_features, num_rows]`) in the - computation. -- `combine_strings`: Optional function taking a `NamedTuple` of function keys => strings, - returning a single string. Used for printing the expression tree. -- `variable_constraints`: Optional `NamedTuple` that defines which variables each sub-expression is allowed to access. - For example, requesting `f(x1, x2)` and `g(x3)` would be equivalent to `(; f=[1, 2], g=[3])`. +- `combine`: Required function taking a `NamedTuple` of callable expressions (with keys `K`), + and a tuple representing the data. For example, `((; f, g), (x1, x2, x3)) -> f(x1, x2) + g(x3)` + would be a valid `combine` function. You may also re-use the callable expressions and + use different inputs, such as `((; f, g), (x1, x2)) -> f(x1 + g(x2)) - g(x1)` is + another valid choice. +- `num_features`: Optional `NamedTuple` of function keys => integers representing the number of + features used by each expression. If not provided, it will be inferred using the `combine` + function. For example, if `f` takes two arguments, and `g` takes one, then + `num_features = (; f=2, g=1)`. 
""" -struct TemplateStructure{ - K, - E<:Union{Nothing,Function}, - N<:Union{Nothing,Function}, - S<:Union{Nothing,Function}, - C<:Union{Nothing,NamedTuple{<:Any,<:Tuple{Vararg{Vector{Int}}}}}, -} <: Function +struct TemplateStructure{K,E<:Function,NF<:NamedTuple} <: Function combine::E - combine_vectors::N - combine_strings::S - variable_constraints::C + num_features::NF end -function TemplateStructure{K}(combine::E; kws...) where {K,E<:Function} - return TemplateStructure{K}(; combine, kws...) +function TemplateStructure{K}(combine::E, num_features=nothing) where {K,E<:Function} + num_features = @something(num_features, infer_variable_constraints(Val(K), combine)) + return TemplateStructure{K,E,typeof(num_features)}(combine, num_features) end -function TemplateStructure{K}(; kws...) where {K} - return TemplateStructure(; _function_keys=Val(K), kws...) -end -function TemplateStructure(combine::E; kws...) where {E<:Function} - return TemplateStructure(; combine, kws...) + +@unstable function combine(template::TemplateStructure, args...) + return template.combine(args...) 
end -function TemplateStructure(; - combine::E=nothing, - combine_vectors::N=nothing, - combine_strings::S=nothing, - variable_constraints::C=nothing, - _function_keys::Val{K}=Val(nothing), -) where { - K, - E<:Union{Nothing,Function}, - N<:Union{Nothing,Function}, - S<:Union{Nothing,Function}, - C<:Union{Nothing,NamedTuple{<:Any,<:Tuple{Vararg{Vector{Int}}}}}, -} - Kout = if K !== nothing && variable_constraints !== nothing - K != keys(variable_constraints) && - throw(ArgumentError("`K` must match the keys of `variable_constraints`.")) - K - elseif K !== nothing - K - elseif variable_constraints !== nothing - keys(variable_constraints) - else - throw( - ArgumentError( - "If `variable_constraints` is not provided, " * - "you must initialize `TemplateStructure` with " * - "`TemplateStructure{K}(...)`, for tuple of symbols `K`.", - ), - ) + +get_function_keys(::TemplateStructure{K}) where {K} = K + +function _record_composable_expression!(variable_constraints, ::Val{k}, args...) where {k} + vc = variable_constraints[k][] + if vc == -1 + variable_constraints[k][] = length(args) + elseif vc != length(args) + throw(ArgumentError("Inconsistent number of arguments passed to $k")) end - return TemplateStructure{Kout,E,N,S,C}( - combine, combine_vectors, combine_strings, variable_constraints + return first(args) +end + +"""Infers number of features used by each subexpression, by passing in test data.""" +function infer_variable_constraints(::Val{K}, combiner::F) where {K,F} + variable_constraints = NamedTuple{K}(map(_ -> Ref(-1), K)) + # Now, we need to evaluate the `combine` function to see how many + # features are used for each function call. If unset, we record it. + # If set, we validate. 
+ inner = Fix{1}(_record_composable_expression!, variable_constraints) + _recorders_of_composable_expressions = NamedTuple{K}(map(k -> Fix{1}(inner, Val(k)), K)) + # We use an evaluation to get the variable constraints + combiner( + _recorders_of_composable_expressions, + Base.Iterators.repeated(ValidVector(ones(Float64, 1), true)), ) -end -# TODO: This interface is ugly. Part of this is due to AbstractStructuredExpression, -# which was not written with this `TemplateStructure` in mind, but just with a -# single callable function. - -function combine(template::TemplateStructure, nt::NamedTuple) - return (template.combine::Function)(nt)::AbstractExpression -end -function combine_vectors( - template::TemplateStructure, nt::NamedTuple, X::Union{AbstractMatrix,Nothing}=nothing -) - combiner = template.combine_vectors::Function - if X !== nothing && hasmethod(combiner, typeof((nt, X))) - # TODO: Refactor this - return combiner(nt, X)::AbstractVector - else - return combiner(nt)::AbstractVector + inferred = NamedTuple{K}(map(x -> x[], values(variable_constraints))) + if any(==(-1), values(inferred)) + failed_keys = filter(k -> inferred[k] == -1, K) + throw(ArgumentError("Failed to infer number of features used by $failed_keys")) end -end -function combine_strings(template::TemplateStructure, nt::NamedTuple) - return (template.combine_strings::Function)(nt)::AbstractString + return inferred end -function (template::TemplateStructure)( - nt::NamedTuple{<:Any,<:Tuple{AbstractExpression,Vararg{AbstractExpression}}} -) - return combine(template, nt) -end -function (template::TemplateStructure)( - nt::NamedTuple{<:Any,<:Tuple{AbstractVector,Vararg{AbstractVector}}}, - X::Union{AbstractMatrix,Nothing}=nothing, -) - return combine_vectors(template, nt, X) -end -function (template::TemplateStructure)( - nt::NamedTuple{<:Any,<:Tuple{AbstractString,Vararg{AbstractString}}} -) - return combine_strings(template, nt) -end - -can_combine(template::TemplateStructure) = template.combine 
!== nothing -can_combine_vectors(template::TemplateStructure) = template.combine_vectors !== nothing -can_combine_strings(template::TemplateStructure) = template.combine_strings !== nothing -get_function_keys(::TemplateStructure{K}) where {K} = K - """ TemplateExpression{T,F,N,E,TS,D} <: AbstractStructuredExpression{T,F,N,E,D} @@ -199,20 +144,8 @@ x3 = Expression(Node{Float64}(; feature=3); operators, variable_names) example_expr = (; f=x1, g=x3) st_expr = TemplateExpression( example_expr; - structure=TemplateStructure{(:f, :g)}(nt -> sin(nt.f) + nt.g * nt.g), - operators, - variable_names, -) -``` - -We can also define constraints on which variables each sub-expression is allowed to access: - -```julia -variable_constraints = (; f=[1, 2], g=[3]) -st_expr = TemplateExpression( - example_expr; - structure=TemplateStructure( - nt -> sin(nt.f) + nt.g * nt.g; variable_constraints + structure=TemplateStructure{(:f, :g)}( + ((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)^2 ), operators, variable_names, @@ -228,9 +161,11 @@ struct TemplateExpression{ T, F<:TemplateStructure, N<:AbstractExpressionNode{T}, - E<:Expression{T,N}, # TODO: Generalize this + E<:ComposableExpression{T,N}, TS<:NamedTuple{<:Any,<:NTuple{<:Any,E}}, - D<:@NamedTuple{structure::F, operators::O, variable_names::V} where {O,V}, + D<:@NamedTuple{ + structure::F, operators::O, variable_names::V + } where {O<:AbstractOperatorEnum,V}, } <: AbstractStructuredExpression{T,F,N,E,D} trees::TS metadata::Metadata{D} @@ -268,28 +203,28 @@ end ExpressionInterface{all_ei_methods_except(())}, TemplateExpression, [Arguments()] ) -function combine(ex::TemplateExpression, nt::NamedTuple) - return combine(get_metadata(ex).structure, nt) -end -function combine_vectors( - ex::TemplateExpression, nt::NamedTuple, X::Union{AbstractMatrix,Nothing}=nothing -) - return combine_vectors(get_metadata(ex).structure, nt, X) -end -function combine_strings(ex::TemplateExpression, nt::NamedTuple) - return 
combine_strings(get_metadata(ex).structure, nt) +@unstable function combine(ex::TemplateExpression, args...) + return combine(get_metadata(ex).structure, args...) end -function can_combine(ex::TemplateExpression) - return can_combine(get_metadata(ex).structure) -end -function can_combine_vectors(ex::TemplateExpression) - return can_combine_vectors(get_metadata(ex).structure) -end -function can_combine_strings(ex::TemplateExpression) - return can_combine_strings(get_metadata(ex).structure) +function DE.get_tree(ex::TemplateExpression{<:Any,<:Any,<:Any,E}) where {E} + raw_contents = get_contents(ex) + total_num_features = max(values(get_metadata(ex).structure.num_features)...) + example_inner_ex = first(values(raw_contents)) + example_tree = get_contents(example_inner_ex)::AbstractExpressionNode + + variable_trees = [ + DE.constructorof(typeof(example_tree))(; feature=i) for i in 1:total_num_features + ] + variable_expressions = [ + with_contents(inner_ex, variable_tree) for + (inner_ex, variable_tree) in zip(values(raw_contents), variable_trees) + ] + + return DE.get_tree( + combine(get_metadata(ex).structure, raw_contents, variable_expressions) + ) end -get_function_keys(ex::TemplateExpression) = get_function_keys(get_metadata(ex).structure) function EB.create_expression( t::AbstractExpressionNode{T}, @@ -305,7 +240,8 @@ function EB.create_expression( operators = options.operators variable_names = embed ? 
dataset.variable_names : nothing inner_expressions = ntuple( - _ -> Expression(copy(t); operators, variable_names), length(function_keys) + _ -> ComposableExpression(copy(t); operators, variable_names), + Val(length(function_keys)), ) # TODO: Generalize to other inner expression types return DE.constructorof(E)( @@ -338,25 +274,53 @@ function ComplexityModule.compute_complexity( ) end +# Rather than using iterator with repeat, just make a tuple: +function _colors(::Val{n}) where {n} + return ntuple( + (i -> (:magenta, :green, :red, :blue, :yellow, :cyan)[mod1(i, n)]), Val(n) + ) +end + _color_string(s::AbstractString, c::Symbol) = styled"{$c:$s}" function DE.string_tree( - tree::TemplateExpression, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... + tree::TemplateExpression, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + pretty::Bool=false, + variable_names=nothing, + kws..., ) raw_contents = get_contents(tree) - if can_combine_strings(tree) - function_keys = keys(raw_contents) - colors = Base.Iterators.cycle((:magenta, :green, :red, :blue, :yellow, :cyan)) - inner_strings = NamedTuple{function_keys}( - map(ex -> DE.string_tree(ex, operators; kws...), values(raw_contents)) - ) - colored_strings = NamedTuple{function_keys}( - map(_color_string, inner_strings, colors) - ) - return combine_strings(tree, colored_strings) - else - @assert can_combine(tree) - return DE.string_tree(combine(tree, raw_contents), operators; kws...) - end + function_keys = keys(raw_contents) + num_features = get_metadata(tree).structure.num_features + total_num_features = max(values(num_features)...) 
+ colors = _colors(Val(length(function_keys))) + variable_names = ["#" * string(i) for i in 1:total_num_features] + inner_strings = NamedTuple{function_keys}( + map( + ex -> DE.string_tree(ex, operators; pretty, variable_names, kws...), + values(raw_contents), + ), + ) + strings = NamedTuple{function_keys}( + map( + (k, s, c) -> let + prefix = if !pretty || length(function_keys) == 1 + "" + elseif k == first(function_keys) + "╭ " + elseif k == last(function_keys) + "╰ " + else + "├ " + end + annotatedstring(prefix * string(k) * " = ", _color_string(s, c)) + end, + function_keys, + values(inner_strings), + colors, + ), + ) + return annotatedstring(join(strings, pretty ? styled"\n" : "; ")) end function DE.eval_tree_array( tree::TemplateExpression{T}, @@ -365,32 +329,21 @@ function DE.eval_tree_array( kws..., ) where {T} raw_contents = get_contents(tree) - if can_combine_vectors(tree) - # Raw numerical results of each inner expression: - outs = map( - ex -> DE.eval_tree_array(ex, cX, operators; kws...), values(raw_contents) - ) - # Combine them using the structure function: - results = NamedTuple{keys(raw_contents)}(map(first, outs)) - return combine_vectors(tree, results, cX), all(last, outs) - else - @assert can_combine(tree) - return DE.eval_tree_array(combine(tree, raw_contents), cX, operators; kws...) + if has_invalid_variables(tree) + return (cX[1, :], false) end + result = combine(tree, raw_contents, map(x -> ValidVector(copy(x), true), eachrow(cX))) + return result.x, result.valid end function (ex::TemplateExpression)( X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... ) - raw_contents = get_contents(ex) - if can_combine_vectors(ex) - results = NamedTuple{keys(raw_contents)}( - map(ex -> ex(X, operators; kws...), values(raw_contents)) - ) - return combine_vectors(ex, results, X) + result, valid = DE.eval_tree_array(ex, X, operators; kws...) 
+ if valid + return result else - @assert can_combine(ex) - callable = combine(ex, raw_contents) - return callable(X, operators; kws...) + nan = convert(eltype(result), NaN) + return result .* nan end end @unstable IDE.expected_array_type(::AbstractMatrix, ::Type{<:TemplateExpression}) = Any @@ -422,6 +375,13 @@ function CM.operator_specialization( return O end +function CM.max_features( + dataset::Dataset, options::Options{<:Any,<:Any,<:Any,<:TemplateExpression} +) + num_features = options.expression_options.structure.num_features + return max(values(num_features)...) +end + """ We pick a random subexpression to mutate, and also return the symbol we mutated on so that we can put it back together later. @@ -493,16 +453,8 @@ function CC.check_constraints( maxsize::Int, cursize::Union{Int,Nothing}=nothing, )::Bool - raw_contents = get_contents(ex) - variable_constraints = get_metadata(ex).structure.variable_constraints - # First, we check the variable constraints at the top level: - has_invalid_variables = any(keys(raw_contents)) do key - tree = raw_contents[key] - allowed_variables = variable_constraints[key] - contains_other_features_than(tree, allowed_variables) - end - if has_invalid_variables + if has_invalid_variables(ex) return false end @@ -511,6 +463,7 @@ function CC.check_constraints( return false # Then, we check other constraints for inner expressions: + raw_contents = get_contents(ex) for t in values(raw_contents) if !CC.check_constraints(t, options, maxsize, nothing) return false @@ -519,12 +472,21 @@ function CC.check_constraints( return true # TODO: The concept of `cursize` doesn't really make sense here. 
end -function contains_other_features_than(tree::AbstractExpression, features) - return contains_other_features_than(get_tree(tree), features) +function has_invalid_variables(ex::TemplateExpression) + raw_contents = get_contents(ex) + num_features = get_metadata(ex).structure.num_features + any(keys(raw_contents)) do key + tree = raw_contents[key] + max_feature = num_features[key] + contains_features_greater_than(tree, max_feature) + end +end +function contains_features_greater_than(tree::AbstractExpression, max_feature) + return contains_features_greater_than(get_tree(tree), max_feature) end -function contains_other_features_than(tree::AbstractExpressionNode, features) +function contains_features_greater_than(tree::AbstractExpressionNode, max_feature) any(tree) do node - node.degree == 0 && !node.constant && node.feature ∉ features + node.degree == 0 && !node.constant && node.feature > max_feature end end diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index 32083b578..69a31c951 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -1,3 +1,6 @@ +@testitem "Integration Test with fit! 
and Performance Check" tags = [:part3] begin + include("../examples/template_expression.jl") +end @testitem "Test ComposableExpression" tags = [:part2] begin using SymbolicRegression: ComposableExpression, Node using DynamicExpressions: OperatorEnum @@ -30,9 +33,9 @@ end @test Interfaces.test(ExpressionInterface, ComposableExpression, [f, g]) end -@testitem "Test interface for HierarchicalExpression" tags = [:part2] begin +@testitem "Test interface for TemplateExpression" tags = [:part2] begin using SymbolicRegression - using SymbolicRegression: HierarchicalExpression + using SymbolicRegression: TemplateExpression using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface using DynamicExpressions: OperatorEnum @@ -41,23 +44,23 @@ end x1 = ComposableExpression(Node(Float64; feature=1); operators, variable_names) x2 = ComposableExpression(Node(Float64; feature=2); operators, variable_names) - structure = HierarchicalStructure{(:f, :g)}( + structure = TemplateStructure{(:f, :g)}( ((; f, g), (x1, x2)) -> f(f(f(x1))) - f(g(x2, x1)) ) @test structure.num_features == (; f=1, g=2) - expr = HierarchicalExpression((; f=x1, g=x2 * x2); structure, operators, variable_names) + expr = TemplateExpression((; f=x1, g=x2 * x2); structure, operators, variable_names) @test String(string_tree(expr)) == "f = #1; g = #2 * #2" @test String(string_tree(expr; pretty=true)) == "f = #1\ng = #2 * #2" @test string_tree(get_tree(expr), operators) == "x1 - (x1 * x1)" - @test Interfaces.test(ExpressionInterface, HierarchicalExpression, [expr]) + @test Interfaces.test(ExpressionInterface, TemplateExpression, [expr]) end -@testitem "Printing and evaluation of HierarchicalExpression" begin +@testitem "Printing and evaluation of TemplateExpression" begin using SymbolicRegression - structure = HierarchicalStructure{(:f, :g)}( + structure = TemplateStructure{(:f, :g)}( ((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)^2 ) operators = Options().operators @@ -69,7 +72,7 @@ end ] f = x1 * 
x2 g = x1 - expr = HierarchicalExpression((; f, g); structure, operators, variable_names) + expr = TemplateExpression((; f, g); structure, operators, variable_names) # Default printing strategy: @test String(string_tree(expr)) == "f = x1 * x2\ng = x1" diff --git a/test/test_template_expression.jl b/test/test_template_expression.jl deleted file mode 100644 index 04836cf15..000000000 --- a/test/test_template_expression.jl +++ /dev/null @@ -1,227 +0,0 @@ -@testitem "Basic utility of the TemplateExpression" tags = [:part3] begin - using SymbolicRegression - using SymbolicRegression: SymbolicRegression as SR - using SymbolicRegression.CheckConstraintsModule: check_constraints - using DynamicExpressions: OperatorEnum - - options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) - operators = options.operators - variable_names = ["x1", "x2", "x3"] - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - - # For combining expressions to a single expression: - structure = TemplateStructure(; - combine=e -> sin(e.f) + e.g * e.g, - combine_vectors=e -> (@. 
sin(e.f) + e.g^2), - combine_strings=e -> "sin($(e.f)) + $(e.g)^2", - variable_constraints=(; f=[1, 2], g=[3]), - ) - - @test structure isa TemplateStructure{(:f, :g)} - - st_expr = TemplateExpression((; f=x1, g=cos(x3)); structure, operators, variable_names) - @test string_tree(st_expr) == "sin(x1) + cos(x3)^2" - operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(cos, sin)) - - # Changing the operators will change how the expression is interpreted for - # parts that are already evaluated: - @test string_tree(st_expr, operators) == "sin(x1) + sin(x3)^2" - - # We can evaluate with this too: - cX = [1.0 2.0; 3.0 4.0; 5.0 6.0] - out = st_expr(cX) - @test out ≈ [sin(1.0) + cos(5.0)^2, sin(2.0) + cos(6.0)^2] - - # And also check the contents: - @test check_constraints(st_expr, options, 100) - - # We can see that violating the constraints will cause a violation: - new_expr = with_contents(st_expr, (; f=x3, g=cos(x3))) - @test !check_constraints(new_expr, options, 100) - new_expr = with_contents(st_expr, (; f=x2, g=cos(x3))) - @test check_constraints(new_expr, options, 100) - new_expr = with_contents(st_expr, (; f=x2, g=cos(x1))) - @test !check_constraints(new_expr, options, 100) - - # Checks the size of each individual expression: - new_expr = with_contents(st_expr, (; f=x2, g=cos(x3))) - - @test compute_complexity(new_expr, options) == 3 - @test check_constraints(new_expr, options, 3) - @test !check_constraints(new_expr, options, 2) -end -@testitem "Expression interface" tags = [:part3] begin - using SymbolicRegression - using DynamicExpressions: OperatorEnum - using DynamicExpressions.InterfacesModule: Interfaces, ExpressionInterface - - operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) - variable_names = (i -> "x$i").(1:3) - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - - # For combining expressions to a single expression: - structure = TemplateStructure{(:f, 
:g)}(; - combine=e -> sin(e.f) + e.g * e.g, - combine_strings=e -> "sin($(e.f)) + $(e.g)^2", - combine_vectors=e -> (@. sin(e.f) + e.g^2), - variable_constraints=(; f=[1, 2], g=[3]), - ) - st_expr = TemplateExpression((; f=x1, g=x3); structure, operators, variable_names) - @test Interfaces.test(ExpressionInterface, TemplateExpression, [st_expr]) -end -@testitem "Utilising TemplateExpression to build vector expressions" tags = [:part3] begin - using SymbolicRegression - using Random: rand - - # Define the structure function, which returns a tuple: - structure = TemplateStructure{(:f, :g1, :g2, :g3)}(; - combine_strings=e -> "( $(e.f) + $(e.g1), $(e.f) + $(e.g2), $(e.f) + $(e.g3) )", - combine_vectors=e -> - map((f, g1, g2, g3) -> (f + g1, f + g2, f + g3), e.f, e.g1, e.g2, e.g3), - ) - - # Set up operators and variable names - options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) - variable_names = (i -> "x$i").(1:3) - - # Create expressions - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); options.operators, variable_names)).(1:3) - - # Test with vector inputs: - nt_vector = NamedTuple{(:f, :g1, :g2, :g3)}((1:3, 4:6, 7:9, 10:12)) - @test structure(nt_vector) == [(5, 8, 11), (7, 10, 13), (9, 12, 15)] - - # And string inputs: - nt_string = NamedTuple{(:f, :g1, :g2, :g3)}(("x1", "x2", "x3", "x2")) - @test structure(nt_string) == "( x1 + x2, x1 + x3, x1 + x2 )" - - # Now, using TemplateExpression: - st_expr = TemplateExpression( - (; f=x1, g1=x2, g2=x3, g3=x2); structure, options.operators, variable_names - ) - @test string_tree(st_expr) == "( x1 + x2, x1 + x3, x1 + x2 )" - - # We can directly call it: - cX = [1.0 2.0; 3.0 4.0; 5.0 6.0] - out = st_expr(cX) - @test out == [(1 + 3, 1 + 5, 1 + 3), (2 + 4, 2 + 6, 2 + 4)] -end -@testitem "TemplateExpression getters" tags = [:part3] begin - using SymbolicRegression - using DynamicExpressions: get_operators, get_variable_names - - operators = - Options(; binary_operators=(+, *, /, -), 
unary_operators=(sin, cos)).operators - variable_names = (i -> "x$i").(1:3) - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - - structure = TemplateStructure(; - combine=e -> e.f, variable_constraints=(; f=[1, 2], g1=[3], g2=[3], g3=[3]) - ) - - st_expr = TemplateExpression( - (; f=x1, g1=x3, g2=x3, g3=x3); structure, operators, variable_names - ) - - @test st_expr isa TemplateExpression - @test get_operators(st_expr) == operators - @test get_variable_names(st_expr) == variable_names - @test get_metadata(st_expr).structure == structure -end -@testitem "Integration Test with fit! and Performance Check" tags = [:part3] begin - include("../examples/template_expression.jl") -end -@testitem "TemplateExpression with only combine function" tags = [:part3] begin - using SymbolicRegression - using SymbolicRegression.TemplateExpressionModule: - can_combine_vectors, can_combine, get_function_keys - using SymbolicRegression.InterfaceDynamicExpressionsModule: expected_array_type - using DynamicExpressions: constructorof - - # Set up basic operators and variables - options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) - operators = options.operators - variable_names = ["x1", "x2", "x3"] - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - - # Create a TemplateStructure with only combine (no combine_vectors) - structure = TemplateStructure(; - combine=e -> sin(e.f) + e.g * e.g, # Only define combine - variable_constraints=(; f=[1, 2], g=[3]), - ) - - # Create the TemplateExpression - st_expr = TemplateExpression((; f=x1, g=cos(x3)); structure, operators, variable_names) - - @test constructorof(typeof(st_expr)) === TemplateExpression - @test get_function_keys(st_expr) == (:f, :g) - - # Test evaluation - cX = [1.0 2.0; 3.0 4.0; 5.0 6.0] - out = st_expr(cX) - out_2, complete = eval_tree_array(st_expr, cX) - - # The expression should evaluate by first combining to a 
single expression, - # then evaluating that expression - expected = sin.(cX[1, :]) .+ cos.(cX[3, :]) .^ 2 - @test out ≈ expected - - @test complete - @test out_2 ≈ expected - - # Verify that can_combine_vectors is false but can_combine is true - @test !can_combine_vectors(st_expr) - @test can_combine(st_expr) - - @test expected_array_type(cX, typeof(st_expr)) === Any - - @test string_tree(st_expr) == "sin(x1) + (cos(x3) * cos(x3))" -end -@testitem "TemplateExpression with data in combine_vectors" tags = [:part3] begin - using SymbolicRegression - - options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos, exp)) - operators = options.operators - variable_names = ["x1", "x2", "x3"] - x1, x2, x3 = - (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3) - f = exp(2.5 * x3) - g = x1 - structure = TemplateStructure(; - combine_vectors=(e, X) -> e.f .+ X[2, :], variable_constraints=(; f=[3], g=[1]) - ) - st_expr = TemplateExpression((; f, g); structure, operators, variable_names) - X = randn(3, 100) - @test st_expr(X) ≈ @. exp(2.5 * X[3, :]) + X[2, :] -end -@testitem "TemplateStructure constructors" tags = [:part3] begin - using SymbolicRegression - - operators = Options(; binary_operators=(+, *, /, -)).operators - variable_names = ["x1", "x2"] - - # Create simple expressions with constant values - f = Expression(Node(Float64; val=1.0); operators, variable_names) - g = Expression(Node(Float64; val=2.0); operators, variable_names) - - # Test TemplateStructure{K}(combine; kws...) - st1 = TemplateStructure{(:f, :g)}(e -> e.f + e.g) - @test st1.combine((; f, g)) == f + g - - # Test TemplateStructure(combine; kws...) 
- st2 = TemplateStructure(e -> e.f + e.g; variable_constraints=(; f=[1], g=[2])) - @test st2.combine((; f, g)) == f + g - - # Test error when no K or variable_constraints provided - @test_throws ArgumentError TemplateStructure(e -> e.f + e.g) - @test_throws ArgumentError( - "If `variable_constraints` is not provided, " * - "you must initialize `TemplateStructure` with " * - "`TemplateStructure{K}(...)`, for tuple of symbols `K`.", - ) TemplateStructure(e -> e.f + e.g) -end From fb1b7334d87e0595c7dca15986bb0cedb8fc306a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 22:32:24 +0000 Subject: [PATCH 26/59] style: formatting of template expression --- examples/template_expression.jl | 3 ++- examples/template_expression_complex.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/template_expression.jl b/examples/template_expression.jl index 4f02fd754..e4140a6a7 100644 --- a/examples/template_expression.jl +++ b/examples/template_expression.jl @@ -6,7 +6,8 @@ using Test: @test options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) operators = options.operators variable_names = (i -> "x$i").(1:3) -x1, x2, x3 = (i -> ComposableExpression(Node(Float64; feature=i); operators, variable_names)).(1:3) +x1, x2, x3 = + (i -> ComposableExpression(Node(Float64; feature=i); operators, variable_names)).(1:3) structure = TemplateStructure{(:f, :g1, :g2)}( ((; f, g1, g2), (x1, x2, x3)) -> let diff --git a/examples/template_expression_complex.jl b/examples/template_expression_complex.jl index e4e8921ce..5fdadbfd6 100644 --- a/examples/template_expression_complex.jl +++ b/examples/template_expression_complex.jl @@ -187,7 +187,7 @@ function compute_force((; B_x, B_y, B_z, F_d_scale), (t, v_x, v_y, v_z, T)) _F_d_scale = F_d_scale(T) ## Note that we can also evaluate an expression multiple times, ## including in a hierarchy! - + ## Now, let's do the same computation we did above to ## get the total force vectors. 
Note that the evaluation ## output is wrapped in `ValidVector`, so we need From 436ff2390c00a89132fae180ba98406f6291b4a5 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 3 Nov 2024 14:18:38 +0000 Subject: [PATCH 27/59] docs: update docs for TemplateStructure --- src/TemplateExpression.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 6a4eebc79..4c66f3a79 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -40,7 +40,7 @@ using ..PopMemberModule: PopMember using ..ComposableExpressionModule: ComposableExpression, ValidVector """ - TemplateStructure{K,S,N,E,C} <: Function + TemplateStructure{K,E,NF} <: Function A struct that defines a prescribed structure for a `TemplateExpression`, including functions that define the result in different contexts. @@ -50,11 +50,11 @@ If not declared using the constructor `TemplateStructure{K}(...)`, the keys of t `variable_constraints` `NamedTuple` will be used to infer this. # Fields -- `combine`: Required function taking a `NamedTuple` of callable expressions (with keys `K`), - and a tuple representing the data. For example, `((; f, g), (x1, x2, x3)) -> f(x1, x2) + g(x3)` - would be a valid `combine` function. You may also re-use the callable expressions and - use different inputs, such as `((; f, g), (x1, x2)) -> f(x1 + g(x2)) - g(x1)` is - another valid choice. +- `combine`: Required function taking a `NamedTuple` of `ComposableExpression`s (sharing the keys `K`), + and then tuple representing the data of `ValidVector`s. For example, + `((; f, g), (x1, x2, x3)) -> f(x1, x2) + g(x3)` would be a valid `combine` function. You may also + re-use the callable expressions and use different inputs, such as + `((; f, g), (x1, x2)) -> f(x1 + g(x2)) - g(x1)` is another valid choice. - `num_features`: Optional `NamedTuple` of function keys => integers representing the number of features used by each expression. 
If not provided, it will be inferred using the `combine` function. For example, if `f` takes two arguments, and `g` takes one, then From 6ef8bcfa88470223acbb97c9ec13472fba7e7611 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 3 Nov 2024 14:40:37 +0000 Subject: [PATCH 28/59] feat: return `nothing` for invalid result rather than `NaN` --- examples/template_expression_complex.jl | 6 +-- src/InterfaceDynamicExpressions.jl | 39 +++++++++++------- src/LossFunctions.jl | 47 +++++++++++++++------- src/TemplateExpression.jl | 53 ++++++++++++++----------- 4 files changed, 90 insertions(+), 55 deletions(-) diff --git a/examples/template_expression_complex.jl b/examples/template_expression_complex.jl index 5fdadbfd6..47b3145e7 100644 --- a/examples/template_expression_complex.jl +++ b/examples/template_expression_complex.jl @@ -197,7 +197,7 @@ function compute_force((; B_x, B_y, B_z, F_d_scale), (t, v_x, v_y, v_z, T)) ## Now, let's compute the drag force using our model: - F_d = [_F_d_scale.x .* vi for (vi, _F_d_scale) in zip(v, _F_d_scale)] + F_d = [_F_d_scale .* vi for (vi, _F_d_scale) in zip(v, _F_d_scale.x)] ## Now, the magnetic force: F_mag = [cross(vi, Bi) for (vi, Bi) in zip(v, B)] @@ -233,8 +233,8 @@ the solution: options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos, sqrt, exp)) ## The inner operators are an `DynamicExpressions.OperatorEnum` which is used by `Expression`: operators = options.operators -t = Expression(Node{Float64}(; feature=1); operators, variable_names) -T = Expression(Node{Float64}(; feature=5); operators, variable_names) +t = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) +T = ComposableExpression(Node{Float64}(; feature=5); operators, variable_names) B_x = B_y = B_z = 2.1 * cos(t) F_d_scale = 1.0 * sqrt(T) diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index 6bcf2d82a..28a284f72 100644 --- a/src/InterfaceDynamicExpressions.jl +++ 
b/src/InterfaceDynamicExpressions.jl @@ -1,6 +1,7 @@ module InterfaceDynamicExpressionsModule using Printf: @sprintf +using DispatchDoctor: @stable using Compat: Fix using DynamicExpressions: DynamicExpressions as DE, @@ -48,23 +49,31 @@ which speed up evaluation significantly. or nan was encountered, and a large loss should be assigned to the equation. """ -function DE.eval_tree_array( - tree::Union{AbstractExpressionNode,AbstractExpression}, - X::AbstractMatrix, - options::AbstractOptions; - kws..., -) - A = expected_array_type(X, typeof(tree)) - out, complete = DE.eval_tree_array( - tree, - X, - DE.get_operators(tree, options); - turbo=options.turbo, - bumper=options.bumper, +@stable( + default_mode = "disable", + default_union_limit = 2, + function DE.eval_tree_array( + tree::Union{AbstractExpressionNode,AbstractExpression}, + X::AbstractMatrix, + options::AbstractOptions; kws..., ) - return out::A, complete::Bool -end + A = expected_array_type(X, typeof(tree)) + out, complete = DE.eval_tree_array( + tree, + X, + DE.get_operators(tree, options); + turbo=options.turbo, + bumper=options.bumper, + kws..., + ) + if isnothing(out) + return nothing, false + else + return out::A, complete::Bool + end + end +) """Improve type inference by telling Julia the expected array returned.""" function expected_array_type(X::AbstractArray, ::Type) diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index 637bb0fa4..ee9fbd496 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -1,5 +1,6 @@ module LossFunctionsModule +using DispatchDoctor: @stable using StatsBase: StatsBase using DynamicExpressions: AbstractExpression, AbstractExpressionNode, get_tree, eval_tree_array @@ -42,20 +43,38 @@ end end end -function eval_tree_dispatch( - tree::AbstractExpression, dataset::Dataset, options::AbstractOptions, idx -) - A = expected_array_type(dataset.X, typeof(tree)) - out, complete = eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) - return out::A, 
complete::Bool -end -function eval_tree_dispatch( - tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions, idx +@stable( + default_mode = "disable", + default_union_limit = 2, + begin + function eval_tree_dispatch( + tree::AbstractExpression, dataset::Dataset, options::AbstractOptions, idx + ) + A = expected_array_type(dataset.X, typeof(tree)) + out, complete = eval_tree_array( + tree, maybe_getindex(dataset.X, :, idx), options + ) + if isnothing(out) + return out, false + else + return out::A, complete::Bool + end + end + function eval_tree_dispatch( + tree::AbstractExpressionNode, dataset::Dataset, options::AbstractOptions, idx + ) + A = expected_array_type(dataset.X, typeof(tree)) + out, complete = eval_tree_array( + tree, maybe_getindex(dataset.X, :, idx), options + ) + if isnothing(out) + return out, false + else + return out::A, complete::Bool + end + end + end ) - A = expected_array_type(dataset.X, typeof(tree)) - out, complete = eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options) - return out::A, complete::Bool -end # Evaluate the loss of a particular expression on the input dataset. function _eval_loss( @@ -66,7 +85,7 @@ function _eval_loss( idx, )::L where {T<:DATA_TYPE,L<:LOSS_TYPE} (prediction, completion) = eval_tree_dispatch(tree, dataset, options, idx) - if !completion + if !completion || isnothing(prediction) return L(Inf) end diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 4c66f3a79..b39116529 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -2,7 +2,7 @@ module TemplateExpressionModule using Random: AbstractRNG using Compat: Fix -using DispatchDoctor: @unstable +using DispatchDoctor: @unstable, @stable using StyledStrings: @styled_str, annotatedstring using DynamicExpressions: DynamicExpressions as DE, @@ -322,30 +322,37 @@ function DE.string_tree( ) return annotatedstring(join(strings, pretty ? 
styled"\n" : "; ")) end -function DE.eval_tree_array( - tree::TemplateExpression{T}, - cX::AbstractMatrix{T}, - operators::Union{AbstractOperatorEnum,Nothing}=nothing; - kws..., -) where {T} - raw_contents = get_contents(tree) - if has_invalid_variables(tree) - return (cX[1, :], false) +@stable( + default_mode = "disable", + default_union_limit = 2, + begin + function DE.eval_tree_array( + tree::TemplateExpression{T}, + cX::AbstractMatrix{T}, + operators::Union{AbstractOperatorEnum,Nothing}=nothing; + kws..., + ) where {T} + raw_contents = get_contents(tree) + if has_invalid_variables(tree) + return (nothing, false) + end + result = combine( + tree, raw_contents, map(x -> ValidVector(copy(x), true), eachrow(cX)) + ) + return result.x, result.valid + end + function (ex::TemplateExpression)( + X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... + ) + result, valid = DE.eval_tree_array(ex, X, operators; kws...) + if valid + return result + else + return nothing + end + end end - result = combine(tree, raw_contents, map(x -> ValidVector(copy(x), true), eachrow(cX))) - return result.x, result.valid -end -function (ex::TemplateExpression)( - X, operators::Union{AbstractOperatorEnum,Nothing}=nothing; kws... ) - result, valid = DE.eval_tree_array(ex, X, operators; kws...) 
- if valid - return result - else - nan = convert(eltype(result), NaN) - return result .* nan - end -end @unstable IDE.expected_array_type(::AbstractMatrix, ::Type{<:TemplateExpression}) = Any function DA.violates_dimensional_constraints( From 7a0dbc290b459a60ccab654673243a2382bd5ade Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 3 Nov 2024 14:42:31 +0000 Subject: [PATCH 29/59] test: fix missing `node_type` --- test/test_mixed_utils.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_mixed_utils.jl b/test/test_mixed_utils.jl index 2ad9e7636..25af052c7 100644 --- a/test/test_mixed_utils.jl +++ b/test/test_mixed_utils.jl @@ -1,6 +1,5 @@ -using SymbolicRegression -using SymbolicRegression: string_tree -using Random, Bumper, LoopVectorization +using SymbolicRegression, Random, Bumper, LoopVectorization +using SymbolicRegression: string_tree, node_type include("test_params.jl") From 05f8678b6e04c8d0187a5ab0c94edd2aed68fa0b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 3 Nov 2024 14:48:54 +0000 Subject: [PATCH 30/59] docs: update docs for TemplateStructure --- docs/src/types.md | 33 ++++++++++++--------------------- 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/docs/src/types.md b/docs/src/types.md index bf954dfac..3e0d01c9b 100644 --- a/docs/src/types.md +++ b/docs/src/types.md @@ -60,36 +60,27 @@ ParametricNode These types allow you to define expressions with parameters that can be tuned to fit the data better. You can specify the maximum number of parameters using the `expression_options` argument in `SRRegressor`. +## Composable Expressions + +Composable expressions allow you to combine multiple expressions together. + +```@docs +ComposableExpression +``` + ## Template Expressions Template expressions allow you to specify predefined structures and constraints for your expressions. -These use the new `TemplateStructure` type to define how expressions should be combined and evaluated. 
+These use `ComposableExpression`s as their internal expression type, which makes them +flexible for creating a structure out of a single function. + +These use the `TemplateStructure` type to define how expressions should be combined and evaluated. ```@docs TemplateExpression TemplateStructure ``` -Example usage: - -```julia -# Define a template structure -structure = TemplateStructure( - combine=e -> e.f + e.g, # Create normal `Expression` - combine_vectors=e -> (e.f .+ e.g), # Output vector - combine_strings=e -> "($e.f) + ($e.g)", # Output string - variable_constraints=(; f=[1, 2], g=[3]) # Constrain dependencies -) - -# Use in options -model = SRRegressor(; - expression_type=TemplateExpression, - expression_options=(; structure=structure) -) -``` - -The `variable_constraints` field allows you to specify which variables can be used in different parts of the expression. - ## Population Groups of equations are given as a population, which is From d59de9b202c548a5cffaad23af8e65f6e1c2b14e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 3 Nov 2024 15:17:58 +0000 Subject: [PATCH 31/59] fix: move `_info_dump` to end for precompilation --- src/SymbolicRegression.jl | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 91dd9e698..d2a8bb9c1 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -1085,23 +1085,6 @@ end end return (out_pop, best_seen, record, num_evals) end - -include("MLJInterface.jl") -using .MLJInterfaceModule: SRRegressor, MultitargetSRRegressor - -# Hack to get static analysis to work from within tests: -@ignore include("../test/runtests.jl") - -# TODO: Hack to force ConstructionBase version -using ConstructionBase: ConstructionBase as _ - -include("precompile.jl") -redirect_stdout(devnull) do - redirect_stderr(devnull) do - do_precompilation(Val(:precompile)) - end -end - function _info_dump(
state::AbstractSearchState, datasets::Vector{D}, @@ -1147,4 +1130,20 @@ function _info_dump( return nothing end +include("MLJInterface.jl") +using .MLJInterfaceModule: SRRegressor, MultitargetSRRegressor + +# Hack to get static analysis to work from within tests: +@ignore include("../test/runtests.jl") + +# TODO: Hack to force ConstructionBase version +using ConstructionBase: ConstructionBase as _ + +include("precompile.jl") +redirect_stdout(devnull) do + redirect_stderr(devnull) do + do_precompilation(Val(:precompile)) + end +end + end #module SR From e5f51058521440e8e9cf93cd8405399e33541197 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 3 Nov 2024 15:25:40 +0000 Subject: [PATCH 32/59] docs: improve readability of example --- examples/template_expression_complex.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/template_expression_complex.jl b/examples/template_expression_complex.jl index 47b3145e7..c5b255c9e 100644 --- a/examples/template_expression_complex.jl +++ b/examples/template_expression_complex.jl @@ -203,7 +203,7 @@ function compute_force((; B_x, B_y, B_z, F_d_scale), (t, v_x, v_y, v_z, T)) F_mag = [cross(vi, Bi) for (vi, Bi) in zip(v, B)] ## Finally, we combine the drag and magnetic forces into the total force: - F = map((fd, fm) -> Force((fd .+ fm)...), F_d, F_mag) + F = [Force((fd .+ fm)...) for (fd, fm) in zip(F_d, F_mag)] ## The output of this function needs to be another `ValidVector`, ## which carries through the validity of the evaluation. 
We compute From e5bfeffba254ec4d3d0d7aa1d9fefc6b17c78a31 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 17:12:46 +0000 Subject: [PATCH 33/59] fix: left arg in ComposableExpression --- src/ComposableExpression.jl | 4 ++- test/test_composable_expression.jl | 42 +++++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index 7e5178c66..2f04ae2f2 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -189,7 +189,9 @@ function apply_operator(op::F, x...) where {F<:Function} vx = map(_get_value, x) return ValidVector(op.(vx...), true) else - return ValidVector(_get_value(first(x)), false) + example_vector = + something(map(xi -> xi isa ValidVector ? xi : nothing, x)...)::ValidVector + return ValidVector(_get_value(example_vector), false) end end _is_valid(x::ValidVector) = x.valid diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index 69a31c951..5c11265b0 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -57,7 +57,7 @@ end @test Interfaces.test(ExpressionInterface, TemplateExpression, [expr]) end -@testitem "Printing and evaluation of TemplateExpression" begin +@testitem "Printing and evaluation of TemplateExpression" tags = [:part2] begin using SymbolicRegression structure = TemplateStructure{(:f, :g)}( @@ -98,3 +98,43 @@ end # This is even though `g` is defined on `x1` only: @test g(x3_val) ≈ x3_val end + +@testitem "Test error handling" tags = [:part2] begin + using SymbolicRegression + using SymbolicRegression: ComposableExpression, Node, ValidVector + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + ex = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + + # Test error for unsupported input type with specific 
message + @test_throws "ComposableExpression does not support input of type String" ex( + "invalid input" + ) + + # Test ValidVector operations with numbers + x = ValidVector([1.0, 2.0, 3.0], true) + + # Test binary operations between ValidVector and Number + @test (x + 2.0).x ≈ [3.0, 4.0, 5.0] + @test (2.0 + x).x ≈ [3.0, 4.0, 5.0] + @test (x * 2.0).x ≈ [2.0, 4.0, 6.0] + @test (2.0 * x).x ≈ [2.0, 4.0, 6.0] + + # Test unary operations on ValidVector + @test sin(x).x ≈ sin.([1.0, 2.0, 3.0]) + @test cos(x).x ≈ cos.([1.0, 2.0, 3.0]) + @test abs(x).x ≈ [1.0, 2.0, 3.0] + @test (-x).x ≈ [-1.0, -2.0, -3.0] + + # Test propagation of invalid flag + invalid_x = ValidVector([1.0, 2.0, 3.0], false) + @test !((invalid_x + 2.0).valid) + @test !((2.0 + invalid_x).valid) + @test !(sin(invalid_x).valid) + + # Test that regular numbers are considered valid + @test (x + 2).valid + @test sin(x).valid +end From 91fbee96127bbe94d0fb4e8efe8b6111d941a272 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 19:01:58 +0000 Subject: [PATCH 34/59] fix: JET error --- src/SymbolicRegression.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index d2a8bb9c1..84341970a 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -933,7 +933,7 @@ function _main_search_loop!( options, total_cycles, cycles_remaining=state.cycles_remaining[j] ) move_window!(state.all_running_search_statistics[j]) - if progress_bar !== nothing + if !isnothing(progress_bar) head_node_occupation = estimate_work_fraction(resource_monitor) update_progress_bar!( progress_bar, @@ -1003,7 +1003,7 @@ function _main_search_loop!( end ################################################################ end - if ropt.progress + if !isnothing(progress_bar) finish!(progress_bar) end return nothing From b7f862213632f745911853fe03040a161172b159 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 19:29:23 +0000 Subject: 
[PATCH 35/59] test: other validity checks --- test/test_composable_expression.jl | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index 5c11265b0..b61b907d7 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -138,3 +138,24 @@ end @test (x + 2).valid @test sin(x).valid end +@testitem "Test validity propagation with NaN" tags = [:part2] begin + using SymbolicRegression: ComposableExpression, Node, ValidVector + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) + variable_names = (i -> "x$i").(1:3) + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) + x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names) + + ex = 1.0 + x2 / x1 + + @test ex([1.0], [2.0]) ≈ [3.0] + + @test ex([1.0, 1.0], [2.0, 2.0]) |> Base.Fix1(count, isnan) == 0 + @test ex([1.0, 0.0], [2.0, 2.0]) |> Base.Fix1(count, isnan) == 2 + + x1_val = ValidVector([1.0, 2.0], false) + x2_val = ValidVector([1.0, 2.0], false) + @test ex(x1_val, x2_val).valid == false +end From 9a4fddd4a642a18f2c63920a40d8739c9c2f697b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 20:05:06 +0000 Subject: [PATCH 36/59] refactor: clean up imports with ExplicitImports --- src/AdaptiveParsimony.jl | 2 +- src/ExpressionBuilder.jl | 14 ++------------ src/HallOfFame.jl | 3 +-- src/Mutate.jl | 2 -- src/OptionsStruct.jl | 2 -- src/ParametricExpression.jl | 5 +---- src/ProgressBars.jl | 2 +- src/SearchUtils.jl | 3 +-- src/SymbolicRegression.jl | 3 ++- src/TemplateExpression.jl | 2 -- 10 files changed, 9 insertions(+), 29 deletions(-) diff --git a/src/AdaptiveParsimony.jl b/src/AdaptiveParsimony.jl index e3fded95c..f45891faa 100644 --- a/src/AdaptiveParsimony.jl +++ 
b/src/AdaptiveParsimony.jl @@ -1,6 +1,6 @@ module AdaptiveParsimonyModule -using ..CoreModule: AbstractOptions, MAX_DEGREE +using ..CoreModule: AbstractOptions """ RunningSearchStatistics diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl index d7bc5f5d6..00264be6a 100644 --- a/src/ExpressionBuilder.jl +++ b/src/ExpressionBuilder.jl @@ -7,19 +7,9 @@ module ExpressionBuilderModule using DispatchDoctor: @unstable using Compat: Fix using DynamicExpressions: - AbstractExpressionNode, - AbstractExpression, - Expression, - constructorof, - get_tree, - get_contents, - get_metadata, - with_contents, - with_metadata, - count_scalar_constants, - eval_tree_array + AbstractExpressionNode, AbstractExpression, constructorof, with_metadata using StatsBase: StatsBase -using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE +using ..CoreModule: AbstractOptions, Dataset using ..HallOfFameModule: HallOfFame using ..PopulationModule: Population using ..PopMemberModule: PopMember diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index d65fc73a4..44acf2ea7 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -3,8 +3,7 @@ module HallOfFameModule using StyledStrings: @styled_str using DynamicExpressions: AbstractExpression, string_tree using ..UtilsModule: split_string, AnnotatedIOBuffer, dump_buffer -using ..CoreModule: - MAX_DEGREE, AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression +using ..CoreModule: AbstractOptions, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression using ..ComplexityModule: compute_complexity using ..PopMemberModule: PopMember using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING diff --git a/src/Mutate.jl b/src/Mutate.jl index 8c24de72d..e3ab8993f 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -2,7 +2,6 @@ module MutateModule using DynamicExpressions: AbstractExpression, - with_contents, get_tree, preserve_sharing, count_scalar_constants, @@ -21,7 +20,6 @@ using ..CheckConstraintsModule: 
check_constraints using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..PopMemberModule: PopMember using ..MutationFunctionsModule: - gen_random_tree_fixed_size, mutate_constant, mutate_operator, swap_operands, diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index 4f78a53a2..22871606d 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -6,8 +6,6 @@ using DynamicExpressions: AbstractOperatorEnum, AbstractExpressionNode, AbstractExpression, OperatorEnum using LossFunctions: SupervisedLoss -using ..DatasetModule: Dataset -import ..DatasetModule: max_features import ..MutationWeightsModule: AbstractMutationWeights """ diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index f98a1de08..0f9c92c94 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -6,15 +6,12 @@ module ParametricExpressionModule using DynamicExpressions: DynamicExpressions as DE, - AbstractExpression, ParametricExpression, ParametricNode, get_metadata, - with_metadata, get_contents, with_contents, - get_tree, - eval_tree_array + get_tree using StatsBase: StatsBase using Random: default_rng, AbstractRNG diff --git a/src/ProgressBars.jl b/src/ProgressBars.jl index 7214399f2..c32b0c82f 100644 --- a/src/ProgressBars.jl +++ b/src/ProgressBars.jl @@ -1,7 +1,7 @@ module ProgressBarsModule using Compat: Fix -using ProgressMeter: ProgressMeter, Progress, next!, finish! +using ProgressMeter: ProgressMeter, Progress, next! 
using StyledStrings: @styled_str, annotatedstring using ..UtilsModule: AnnotatedString diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 6f7d3a991..835191250 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -9,11 +9,10 @@ using Distributed: Distributed, @spawnat, Future, procs, addprocs using StatsBase: mean using StyledStrings: @styled_str using DispatchDoctor: @unstable -using Compat: Fix using DynamicExpressions: AbstractExpression, string_tree using ..UtilsModule: subscriptify -using ..CoreModule: Dataset, AbstractOptions, Options, MAX_DEGREE, RecordType, max_features +using ..CoreModule: Dataset, AbstractOptions, Options, RecordType, max_features using ..ComplexityModule: compute_complexity using ..PopulationModule: Population using ..PopMemberModule: PopMember diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 84341970a..f06e81f5c 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -86,6 +86,7 @@ using Pkg: Pkg using TOML: parsefile using Random: seed!, shuffle! using Reexport +using ProgressMeter: finish! using DynamicExpressions: Node, GraphNode, @@ -276,7 +277,7 @@ using .HallOfFameModule: HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve using .MutateModule: mutate!, condition_mutation_weights!, MutationResult using .SingleIterationModule: s_r_cycle, optimize_and_simplify_population -using .ProgressBarsModule: WrappedProgressBar, finish! +using .ProgressBarsModule: WrappedProgressBar using .RecorderModule: @recorder, find_iteration_from_record using .MigrationModule: migrate! 
using .SearchUtilsModule: diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index b39116529..05d91a132 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -11,7 +11,6 @@ using DynamicExpressions: AbstractExpression, AbstractOperatorEnum, OperatorEnum, - Expression, Metadata, get_contents, with_contents, @@ -20,7 +19,6 @@ using DynamicExpressions: get_variable_names, get_tree, node_type, - eval_tree_array, count_nodes using DynamicExpressions.InterfacesModule: ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments From 1d22351c67b7eee78b7ec07fb6148d046586b1d5 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 20:12:03 +0000 Subject: [PATCH 37/59] fix: validate degree 2 nans --- src/ComposableExpression.jl | 4 +++- test/test_composable_expression.jl | 30 ++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index 2f04ae2f2..d80484972 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -16,6 +16,7 @@ using DynamicExpressions: DynamicExpressions as DE using DynamicExpressions.InterfacesModule: ExpressionInterface, Interfaces, @implements, all_ei_methods_except, Arguments +using DynamicExpressions.ValueInterfaceModule: is_valid_array using ..ConstantOptimizationModule: ConstantOptimizationModule as CO @@ -187,7 +188,8 @@ end function apply_operator(op::F, x...) where {F<:Function} if all(_is_valid, x) vx = map(_get_value, x) - return ValidVector(op.(vx...), true) + result = op.(vx...) + return ValidVector(result, is_valid_array(result)) else example_vector = something(map(xi -> xi isa ValidVector ? 
xi : nothing, x)...)::ValidVector diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index b61b907d7..f63c1efe0 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -159,3 +159,33 @@ end x2_val = ValidVector([1.0, 2.0], false) @test ex(x1_val, x2_val).valid == false end + +@testitem "Test nothing return and type inference for TemplateExpression" tags = [:part2] begin + using SymbolicRegression + using Test: @inferred + + # Create a template expression that divides by x1 + structure = TemplateStructure{(:f,)}(((; f), (x1, x2)) -> 1.0 + f(x1) / x1) + operators = Options(; binary_operators=(+, -, *, /)).operators + variable_names = ["x1", "x2"] + + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) + expr = TemplateExpression((; f=x1); structure, operators, variable_names) + + # Test division by zero returns nothing + X = [0.0 1.0]' + @test expr(X) === nothing + + # Test type inference + X_good = [1.0 2.0]' + @test @inferred(Union{Nothing,Vector{Float64}}, expr(X_good)) ≈ [2.0] + + # Test type inference with ValidVector input + x1_val = ValidVector([1.0], true) + x2_val = ValidVector([2.0], true) + @test @inferred(ValidVector{Vector{Float64}}, x1(x1_val, x2_val)).x ≈ [1.0] + + x2_val_false = ValidVector([2.0], false) + @test @inferred(x1(x1_val, x2_val_false)).valid == false +end From 8050cd326d7231d47913fe8aa372d5790ed13053 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 20:28:25 +0000 Subject: [PATCH 38/59] fix: map to safe operators within ComposableExpression --- src/ComposableExpression.jl | 6 ++++-- src/Core.jl | 1 + src/Operators.jl | 9 +++++++++ test/test_composable_expression.jl | 18 ++++++++++++++++++ 4 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index d80484972..e1212aac4 100644 --- 
a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -19,6 +19,7 @@ using DynamicExpressions.InterfacesModule: using DynamicExpressions.ValueInterfaceModule: is_valid_array using ..ConstantOptimizationModule: ConstantOptimizationModule as CO +using ..CoreModule: get_safe_op abstract type AbstractComposableExpression{T,N} <: AbstractExpression{T,N} end @@ -185,10 +186,11 @@ end # Basically we want to vectorize every single operation on ValidVector, # so that the user can use it easily. -function apply_operator(op::F, x...) where {F<:Function} +function apply_operator(op::F, x::Vararg{Any,N}) where {F<:Function,N} if all(_is_valid, x) vx = map(_get_value, x) - result = op.(vx...) + safe_op = get_safe_op(op) + result = safe_op.(vx...) return ValidVector(result, is_valid_array(result)) else example_vector = diff --git a/src/Core.jl b/src/Core.jl index 8e56c0334..994ab23f2 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -21,6 +21,7 @@ using .OptionsStructModule: specialized_options, operator_specialization using .OperatorsModule: + get_safe_op, plus, sub, mult, diff --git a/src/Operators.jl b/src/Operators.jl index f99cc3bed..b38ccf97f 100644 --- a/src/Operators.jl +++ b/src/Operators.jl @@ -123,4 +123,13 @@ DE.declare_operator_alias(::typeof(safe_sqrt), ::Val{1}) = sqrt @ignore pow(x, y) = safe_pow(x, y) @ignore pow_abs(x, y) = safe_pow(x, y) +get_safe_op(op::F) where {F<:Function} = op +get_safe_op(::typeof(^)) = safe_pow +get_safe_op(::typeof(log)) = safe_log +get_safe_op(::typeof(log2)) = safe_log2 +get_safe_op(::typeof(log10)) = safe_log10 +get_safe_op(::typeof(log1p)) = safe_log1p +get_safe_op(::typeof(sqrt)) = safe_sqrt +get_safe_op(::typeof(acosh)) = safe_acosh + end diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index f63c1efe0..8d4a504f8 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -189,3 +189,21 @@ end x2_val_false = ValidVector([2.0], false) @test 
@inferred(x1(x1_val, x2_val_false)).valid == false end +@testitem "Test compatibility with power laws" tags = [:part3] begin + using SymbolicRegression + using DynamicExpressions: OperatorEnum + + operators = OperatorEnum(; binary_operators=(+, -, *, /, ^)) + variable_names = ["x1", "x2"] + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) + + structure = TemplateStructure{(:f,)}(((; f), (x1, x2)) -> f(x1)^f(x2)) + expr = TemplateExpression((; f=x1); structure, operators, variable_names) + + # There shouldn't be an error when we evaluate with invalid + # expressions, even though the source of the NaN comes from the structure + # function itself: + X = -rand(2, 32) + @test expr(X) === nothing +end From dea324dd8aa0ca7d20a098f01d72efbc337e630d Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 20:42:12 +0000 Subject: [PATCH 39/59] refactor: force specialization for composable expression --- src/ComposableExpression.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index e1212aac4..9aa6af25c 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -142,20 +142,25 @@ ValidVector(x::Tuple{Vararg{Any,2}}) = ValidVector(x...) function (ex::AbstractComposableExpression)(x) return error("ComposableExpression does not support input of type $(typeof(x))") end -function (ex::AbstractComposableExpression)(x::AbstractVector, _xs::AbstractVector...) - xs = (x, _xs...) +function (ex::AbstractComposableExpression)( + x::AbstractVector, _xs::Vararg{AbstractVector,N} +) where {N} + __xs = (x, _xs...) # Wrap it up for the recursive call - xs = ntuple(i -> ValidVector(xs[i], true), Val(length(xs))) + xs = map(Base.Fix2(ValidVector, true), __xs) result = ex(xs...) # Unwrap it if result.valid return result.x else + # TODO: Make this more general. 
Like checking if the eltype is numeric. nan = convert(eltype(result.x), NaN) return result.x .* nan end end -function (ex::AbstractComposableExpression)(x::ValidVector, _xs::ValidVector...) +function (ex::AbstractComposableExpression)( + x::ValidVector, _xs::Vararg{ValidVector,N} +) where {N} xs = (x, _xs...) valid = all(xi -> xi.valid, xs) if !valid @@ -166,8 +171,8 @@ function (ex::AbstractComposableExpression)(x::ValidVector, _xs::ValidVector...) end end function (ex::AbstractComposableExpression)( - x::AbstractComposableExpression, _xs::AbstractComposableExpression... -) + x::AbstractComposableExpression, _xs::Vararg{AbstractComposableExpression,N} +) where {N} xs = (x, _xs...) # To do this, we basically want to put the tree of x # into the position of variable 1, and so on! From 555a8dd3bfbd7607c46fedaf1e4af64cbc03e6da Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 20:49:22 +0000 Subject: [PATCH 40/59] refactor: fewer closures --- src/ComposableExpression.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/ComposableExpression.jl b/src/ComposableExpression.jl index 9aa6af25c..866b56e5d 100644 --- a/src/ComposableExpression.jl +++ b/src/ComposableExpression.jl @@ -150,23 +150,24 @@ function (ex::AbstractComposableExpression)( xs = map(Base.Fix2(ValidVector, true), __xs) result = ex(xs...) # Unwrap it - if result.valid - return result.x + if _is_valid(result) + return _get_value(result) else # TODO: Make this more general. Like checking if the eltype is numeric. - nan = convert(eltype(result.x), NaN) - return result.x .* nan + x = _get_value(result) + nan = convert(eltype(x), NaN) + return x .* nan end end function (ex::AbstractComposableExpression)( x::ValidVector, _xs::Vararg{ValidVector,N} ) where {N} xs = (x, _xs...) 
- valid = all(xi -> xi.valid, xs) + valid = all(_is_valid, xs) if !valid - return ValidVector(first(xs).x, false) + return ValidVector(_get_value(first(xs)), false) else - X = Matrix(stack(map(xi -> xi.x, xs))') + X = Matrix(stack(map(_get_value, xs))') return ValidVector(eval_tree_array(ex, X)) end end From 31cc3d28b6c29b5f4343c55963e0803d8c871be6 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 22:57:52 +0000 Subject: [PATCH 41/59] test: improve coverage for TemplateExpression --- test/test_composable_expression.jl | 63 ++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index 8d4a504f8..9a9a129ac 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -207,3 +207,66 @@ end X = -rand(2, 32) @test expr(X) === nothing end + +@testitem "Test constraints checking in TemplateExpression" tags = [:part2] begin + using SymbolicRegression + using SymbolicRegression: CheckConstraintsModule as CC + + # Create a template expression with nested exponentials + options = Options(; + binary_operators=(+, -, *, /), + unary_operators=(exp,), + nested_constraints=[exp => [exp => 1]], # Only allow one nested exp + ) + operators = options.operators + variable_names = ["x1", "x2"] + + # Create a valid inner expression + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + valid_expr = exp(x1) # One exp is ok + + # Create an invalid inner expression with too many nested exp + invalid_expr = exp(exp(exp(x1))) + # Three nested exp's violates constraint + + @test CC.check_constraints(valid_expr, options, 20) + @test !CC.check_constraints(invalid_expr, options, 20) +end + +@testitem "Test feature constraints in TemplateExpression" tags = [:part1] begin + using SymbolicRegression + using DynamicExpressions: Node + + operators = Options(; binary_operators=(+, -, *, /)).operators + variable_names = ["x1", "x2", 
"x3"] + + # Create a structure where f only gets access to x1, x2 + # and g only gets access to x3 + structure = TemplateStructure{(:f, :g)}(((; f, g), (x1, x2, x3)) -> f(x1, x2) + g(x3)) + + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) + x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names) + + # Test valid case - each function only uses allowed features + valid_f = x1 + x2 + valid_g = x1 + valid_template = TemplateExpression( + (; f=valid_f, g=valid_g); structure, operators, variable_names + ) + @test valid_template([1.0 2.0 3.0]') ≈ [6.0] # (1 + 2) + 3 + + # Test invalid case - f tries to use x3 which it shouldn't have access to + invalid_f = x1 + x3 + invalid_template = TemplateExpression( + (; f=invalid_f, g=valid_g); structure, operators, variable_names + ) + @test invalid_template([1.0 2.0 3.0]') === nothing + + # Test invalid case - g tries to use x2 which it shouldn't have access to + invalid_g = x2 + invalid_template2 = TemplateExpression( + (; f=valid_f, g=invalid_g); structure, operators, variable_names + ) + @test invalid_template2([1.0 2.0 3.0]') === nothing +end From 9a78079946d3b23421a0c3b8e4d42dcf98896ef1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 22:59:50 +0000 Subject: [PATCH 42/59] test: fix composable expression test --- test/test_composable_expression.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index 9a9a129ac..910a2a0a2 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -52,7 +52,7 @@ end expr = TemplateExpression((; f=x1, g=x2 * x2); structure, operators, variable_names) @test String(string_tree(expr)) == "f = #1; g = #2 * #2" - @test String(string_tree(expr; pretty=true)) == "f = #1\ng = #2 * #2" + @test String(string_tree(expr; 
pretty=true)) == "╭ f = #1\n╰ g = #2 * #2" @test string_tree(get_tree(expr), operators) == "x1 - (x1 * x1)" @test Interfaces.test(ExpressionInterface, TemplateExpression, [expr]) end From 4b99b67ecd2aed884cec0323e83d93902ebd9650 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 23:17:03 +0000 Subject: [PATCH 43/59] test: fix complexity tests --- test/test_complexity.jl | 1 + test/test_composable_expression.jl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_complexity.jl b/test/test_complexity.jl index 130271c2a..e3a3efd87 100644 --- a/test/test_complexity.jl +++ b/test/test_complexity.jl @@ -38,6 +38,7 @@ end options = make_options(; complexity_of_operators=[sin => 3, (+) => 2], complexity_of_variables=2 ) + tree = sin((x1 + x2 + x3)^2.3) @test compute_complexity(tree, options) == 12 + 3 * 1 options = make_options(; complexity_of_operators=[sin => 3, (+) => 2], diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index 910a2a0a2..56a78abb6 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -75,7 +75,7 @@ end expr = TemplateExpression((; f, g); structure, operators, variable_names) # Default printing strategy: - @test String(string_tree(expr)) == "f = x1 * x2\ng = x1" + @test String(string_tree(expr)) == "f = #1 * #2; g = #1" x1_val = randn(5) x2_val = randn(5) From 403f61476f29322d749904829a0a16d96dded479 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sat, 2 Nov 2024 23:44:38 +0000 Subject: [PATCH 44/59] test: errors for TemplateStructure --- test/test_composable_expression.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/test_composable_expression.jl b/test/test_composable_expression.jl index 56a78abb6..fc17a6b5b 100644 --- a/test/test_composable_expression.jl +++ b/test/test_composable_expression.jl @@ -270,3 +270,27 @@ end ) @test invalid_template2([1.0 2.0 3.0]') === nothing end +@testitem "Test invalid 
structure" tags = [:part3] begin + using SymbolicRegression + + operators = Options(; binary_operators=(+, -, *, /)).operators + variable_names = ["x1", "x2", "x3"] + + x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) + x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) + x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names) + + @test_throws ArgumentError TemplateStructure{(:f,)}( + ((; f), (x1, x2)) -> f(x1) + f(x1, x2) + ) + @test_throws "Inconsistent number of arguments passed to f" TemplateStructure{(:f,)}( + ((; f), (x1, x2)) -> f(x1) + f(x1, x2) + ) + + @test_throws ArgumentError TemplateStructure{(:f, :g)}(((; f, g), (x1, x2)) -> f(x1)) + @test_throws "Failed to infer number of features used by (:g,)" TemplateStructure{( + :f, :g + )}( + ((; f, g), (x1, x2)) -> f(x1) + ) +end From c1e403f64c9e5186e38f45e02ab5d23860bd2a34 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 3 Nov 2024 00:22:53 +0000 Subject: [PATCH 45/59] docs: update changelog --- CHANGELOG.md | 104 +++++++++++++++++++++++++-------------------------- 1 file changed, 50 insertions(+), 54 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 170c6dfc5..4bcd9da91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,68 +91,61 @@ A `TemplateExpression` is constructed by specifying: For example, you can create a `TemplateExpression` that enforces the constraint: `sin(f(x1, x2)) + g(x3)^2` - where we evolve `f` and `g` simultaneously. -Let's see some code for this. 
First, we define some base expressions for each input feature: +To do this, we first describe the structure using `TemplateStructure` +that takes a single closure function that maps a named tuple of +`ComposableExpression` expressions and a tuple of features: ```julia using SymbolicRegression -options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos)) -operators = options.operators -variable_names = ["x1", "x2", "x3"] - -# Base expressions: -x1 = Expression(Node{Float64}(; feature=1); operators, variable_names) -x2 = Expression(Node{Float64}(; feature=2); operators, variable_names) -x3 = Expression(Node{Float64}(; feature=3); operators, variable_names) +structure = TemplateStructure{(:f, :g)}( + ((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)^2 +) ``` -To build a `TemplateExpression`, we specify the structure using -a `TemplateStructure` object. This class has several fields: +This defines how the `TemplateExpression` should be +evaluated numerically on a given input. -- `combine`: Optional function taking a `NamedTuple` of function keys => expressions, - returning a single expression. Fallback method used by `get_tree` - on a `TemplateExpression` to generate a single `Expression`. -- `combine_vectors`: Optional function taking a `NamedTuple` of function keys => vectors, - returning a single vector. Used for evaluating the expression tree. - You may optionally define a method with a second argument `X` for if you wish - to include the data matrix `X` (of shape `[num_features, num_rows]`) in the - computation. -- `combine_strings`: Optional function taking a `NamedTuple` of function keys => strings, - returning a single string. Used for printing the expression tree. -- `variable_constraints`: Optional `NamedTuple` that defines which variables each sub-expression is allowed to access. - For example, requesting `f(x1, x2)` and `g(x3)` would be equivalent to `(; f=[1, 2], g=[3])`. 
- -Let's see an example: +The number of arguments allowed by each expression object +is inferred using this closure, though it can also +be passed explicitly with the `num_features` kwarg. ```julia - -# Combine f and g them into a single scalar expression: -structure = TemplateStructure(; - combine_strings=e -> "sin(" * e.f * ") + (" * e.g * ")^2", - combine_vectors=e -> map((f, g) -> sin(f) + g * g, e.f, e.g), - variable_constraints = (; f=[1, 2], g=[3]), # We constrain it to f(x1, x2) and g(x3) -) +operators = Options(binary_operators=(+, -, *, /)).operators +variable_names = ["x1", "x2", "x3"] +x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names) +x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names) +x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names) ``` -This defines how the `TemplateExpression` should be evaluated numerically on a given input, -and also how it should be represented as a string: +Note that using `x1` here refers to the +_relative_ argument to the expression. +So the node with feature equal to 1 will reference +the first argument, regardless of what it is. ```julia -julia> f_example = x1 - x2 * x2; # Normal `Expression` object - -julia> g_example = 1.5 * x3; - -julia> # Create TemplateExpression from these sub-expressions: - st_expr = TemplateExpression((; f=f_example, g=g_example); structure, operators, variable_names); +st_expr = TemplateExpression( + (; f=x1 - x2 * x2, g=1.5 * x1); + structure, + operators, + variable_names +) # Prints as: f = #1 - (#2 * #2); g = 1.5 * #1 + +# Evaluation combines evaluation of `f` and `g`, and combines them +# with the structure function: +st_expr([0.0; 1.0; 2.0;;]) +``` -julia> st_expr # Prints using `my_structure`! -sin(x1 - (x2 * x2)) + 1.5 * x3^2 +This also works with hierarchical expressions! For example, -julia> st_expr([0.0; 1.0; 2.0;;]) # Combines evaluation of `f` and `g` via `my_structure`!
-1-element Vector{Float64}: - 8.158529015192103 +```julia +structure = TemplateStructure{(:f, :g)}( + ((; f, g), (x1, x2, x3)) -> f(x1, g(x2), x3^2) - g(x3) +) ``` +this is a valid structure! + We can also use this `TemplateExpression` in SymbolicRegression.jl searches!
@@ -168,11 +161,17 @@ This also has our variable mapping, which says we are fitting `f(x1, x2)`, `g1(x3)`, and `g2(x3)`: ```julia -structure = TemplateStructure(; - combine_strings=e -> "( " * e.f * " + " * e.g1 * ", " * e.f * " + " * e.g2 * " )", - combine_vectors=e -> map(i -> (e.f[i] + e.g1[i], e.f[i] + e.g2[i]), eachindex(e.f)), - variable_constraints = (; f=[1, 2], g1=[3], g2=[3]), -) +function my_structure((; f, g1, g2), (x1, x2, x3)) + _f = f(x1, x2) + _g1 = g1(x3) + _g2 = g2(x3) + + # We use `.x` to get the underlying vector + out = map((fi, g1i, g2i) -> (fi + g1i, fi + g2i), _f.x, _g1.x, _g2.x) + # And `.valid` to see whether the evaluations are valid + return ValidVector(out, _f.valid && _g1.valid && _g2.valid) +end +structure = TemplateStructure{(:f, :g1, :g2)}(my_structure) ``` Now, our dataset is a regular 2D array of inputs for `X`. @@ -182,10 +181,7 @@ But our `y` is actually a _vector of 2-tuples_! X = rand(100, 3) .* 10 y = [ - ( - sin(X[i, 1]) + X[i, 3]^2, - sin(X[i, 1]) + X[i, 3] - ) + (sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1)) ] ``` From 4a7bf35fef9da12187c614abed86e4074118587f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 3 Nov 2024 00:23:16 +0000 Subject: [PATCH 46/59] fix: validate keys of `num_features` --- src/TemplateExpression.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TemplateExpression.jl b/src/TemplateExpression.jl index 05d91a132..39ceab3d6 100644 --- a/src/TemplateExpression.jl +++ b/src/TemplateExpression.jl @@ -58,7 +58,7 @@ If not declared using the constructor `TemplateStructure{K}(...)`, the keys of t function. For example, if `f` takes two arguments, and `g` takes one, then `num_features = (; f=2, g=1)`.
""" -struct TemplateStructure{K,E<:Function,NF<:NamedTuple} <: Function +struct TemplateStructure{K,E<:Function,NF<:NamedTuple{K}} <: Function combine::E num_features::NF end From 8e438cae7dba4c72fd61e96256d1bc653ff8501f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 3 Nov 2024 00:48:41 +0000 Subject: [PATCH 47/59] fix: move back NodeSampler to exports --- src/SymbolicRegression.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index f06e81f5c..b675b7e2e 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -16,6 +16,7 @@ export Population, TemplateStructure, ValidVector, ComposableExpression, + NodeSampler, AbstractExpression, AbstractExpressionNode, EvalOptions, @@ -165,7 +166,7 @@ using Compat: @compat, Fix AbstractOptions, AbstractRuntimeOptions, RuntimeOptions, AbstractMutationWeights, mutate!, condition_mutation_weights!, sample_mutation, MutationResult, AbstractSearchState, SearchState, - NodeSampler, LOSS_TYPE, DATA_TYPE, node_type, + LOSS_TYPE, DATA_TYPE, node_type, ) ) #! format: on From ca4c8d9ccad901502df6e2fdbc295e48d40d542e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 3 Nov 2024 04:35:34 +0000 Subject: [PATCH 48/59] test: remove old reference to test file --- test/runtests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 9c5ade786..db9676b16 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -129,7 +129,6 @@ end ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true" include("../examples/parameterized_function.jl") end -include("test_template_expression.jl") @testitem "Testing whether the recorder works." 
tags = [:part3] begin include("test_recorder.jl") From fd29e64a5d84ac0d6dd431e12f0ae49480f8a63f Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 4 Nov 2024 13:57:28 -0500 Subject: [PATCH 49/59] refactor: top-level `get_safe_op` --- src/SymbolicRegression.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index b675b7e2e..bef4fcb98 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -231,6 +231,7 @@ using .CoreModule: ComplexityMapping, AbstractMutationWeights, MutationWeights, + get_safe_op, max_features, is_weighted, sample_mutation, From 1fcd5442f53c0ff1580b22575a08d83296bd701a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 4 Nov 2024 14:35:58 -0500 Subject: [PATCH 50/59] docs: tweak order --- docs/src/types.md | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/docs/src/types.md b/docs/src/types.md index 3e0d01c9b..baa9e3800 100644 --- a/docs/src/types.md +++ b/docs/src/types.md @@ -60,14 +60,6 @@ ParametricNode These types allow you to define expressions with parameters that can be tuned to fit the data better. You can specify the maximum number of parameters using the `expression_options` argument in `SRRegressor`. -## Composable Expressions - -Composable expressions allow you to combine multiple expressions together. - -```@docs -ComposableExpression -``` - ## Template Expressions Template expressions allow you to specify predefined structures and constraints for your expressions. @@ -81,6 +73,12 @@ TemplateExpression TemplateStructure ``` +Composable expressions allow you to combine multiple expressions together. 
+ +```@docs +ComposableExpression +``` + ## Population Groups of equations are given as a population, which is From f75f1eeb921a2bd6725b97392a5aa586df0deb85 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 4 Nov 2024 14:36:27 -0500 Subject: [PATCH 51/59] refactor: remove some unused constants --- src/Core.jl | 3 +-- src/Dataset.jl | 6 +++--- src/ProgramConstants.jl | 3 --- src/SymbolicRegression.jl | 3 --- 4 files changed, 4 insertions(+), 11 deletions(-) diff --git a/src/Core.jl b/src/Core.jl index 994ab23f2..c442efc73 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -10,8 +10,7 @@ include("OptionsStruct.jl") include("Operators.jl") include("Options.jl") -using .ProgramConstantsModule: - MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE +using .ProgramConstantsModule: RecordType, DATA_TYPE, LOSS_TYPE using .DatasetModule: Dataset, is_weighted, has_units, max_features using .MutationWeightsModule: AbstractMutationWeights, MutationWeights, sample_mutation using .OptionsStructModule: diff --git a/src/Dataset.jl b/src/Dataset.jl index 65cc909cb..2818fd1db 100644 --- a/src/Dataset.jl +++ b/src/Dataset.jl @@ -3,7 +3,7 @@ module DatasetModule using DynamicQuantities: Quantity using ..UtilsModule: subscriptify, get_base_type -using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE +using ..ProgramConstantsModule: DATA_TYPE, LOSS_TYPE using ...InterfaceDynamicQuantitiesModule: get_si_units, get_sym_units """ @@ -125,8 +125,8 @@ function Dataset( ) end - n = size(X, BATCH_DIM) - nfeatures = size(X, FEATURE_DIM) + n = size(X, 2) + nfeatures = size(X, 1) variable_names = @something(variable_names, ["x$(i)" for i in 1:nfeatures]) display_variable_names = @something( display_variable_names, ["x$(subscriptify(i))" for i in 1:nfeatures] diff --git a/src/ProgramConstants.jl b/src/ProgramConstants.jl index 607ce08b2..7ae2ccd7b 100644 --- a/src/ProgramConstants.jl +++ b/src/ProgramConstants.jl @@ -1,8 +1,5 @@ module ProgramConstantsModule 
-const MAX_DEGREE = 2 -const BATCH_DIM = 2 -const FEATURE_DIM = 1 const RecordType = Dict{String,Any} const DATA_TYPE = Number diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index bef4fcb98..62cf8626e 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -219,9 +219,6 @@ using DispatchDoctor: @stable end using .CoreModule: - MAX_DEGREE, - BATCH_DIM, - FEATURE_DIM, DATA_TYPE, LOSS_TYPE, RecordType, From 8889de74cff4e17e3216baa7c4704d5439be66e4 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 4 Nov 2024 17:29:30 -0500 Subject: [PATCH 52/59] test: weaken test condition --- examples/template_expression.jl | 4 ++-- test/test_mlj.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/template_expression.jl b/examples/template_expression.jl index e4140a6a7..5b1729229 100644 --- a/examples/template_expression.jl +++ b/examples/template_expression.jl @@ -32,12 +32,12 @@ y = [(sin(x1[i]) + x3[i]^2, sin(x1[i]) + x3[i]) for i in eachindex(x1, x2, x3)] model = SRRegressor(; binary_operators=(+, *), unary_operators=(sin,), - maxsize=15, + maxsize=20, expression_type=TemplateExpression, expression_options=(; structure), # The elementwise needs to operate directly on each row of `y`: elementwise_loss=((x1, x2), (y1, y2)) -> (y1 - x1)^2 + (y2 - x2)^2, - early_stop_condition=(loss, complexity) -> loss < 1e-5 && complexity <= 7, + early_stop_condition=(loss, complexity) -> loss < 1e-6 && complexity <= 7, ) mach = machine(model, [x1 x2 x3], y) diff --git a/test/test_mlj.jl b/test/test_mlj.jl index a4348fd28..ecc9f5331 100644 --- a/test/test_mlj.jl +++ b/test/test_mlj.jl @@ -144,7 +144,7 @@ end fit!(mach) # Check predictions - @test sum(abs2, predict(mach, X) .- Y) / length(X) < 1e-6 + @test sum(abs2, predict(mach, X) .- Y) / length(X) < 1e-5 # Load the output CSV file for i in 1:3 From 069c25c111b622b7e0910ad4b1b588191f5bc627 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 6 Nov 2024 20:48:31 -0500 Subject: 
[PATCH 53/59] docs: add parametrized function example --- docs/make.jl | 1 + examples/parameterized_function.jl | 93 +++++++++++++++++++++---- examples/template_expression_complex.jl | 4 +- 3 files changed, 83 insertions(+), 15 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 336f2fcb0..678d08f00 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -106,6 +106,7 @@ makedocs(; "Examples" => [ "Short Examples" => "examples.md", "Template Expressions" => "examples/template_expression.md", + "Parameterized Expressions" => "examples/parameterized_function.md", ], "API" => "api.md", "Losses" => "losses.md", diff --git a/examples/parameterized_function.jl b/examples/parameterized_function.jl index 9faefc97c..13d5fa370 100644 --- a/examples/parameterized_function.jl +++ b/examples/parameterized_function.jl @@ -1,21 +1,62 @@ +#literate_begin file="src/examples/parameterized_function.md" +#= +# Learning Parameterized Expressions + +_Note: Parametric expressions are currently considered experimental and may change in the future._ + +Parameterized expressions in SymbolicRegression.jl allow you to discover symbolic expressions that contain +optimizable parameters. This is particularly useful when you have data that follows different patterns +based on some categorical variable, or when you want to learn an expression with constants that should +be optimized during the search. + +In this tutorial, we'll generate synthetic data with class-dependent parameters and use symbolic regression to discover the parameterized expressions. + +## The Problem + +Let's create a synthetic dataset where the underlying function changes based on a class label: + +```math +y = 2\cos(x_2 + 0.1) + x_1^2 - 3.2 \ \ \ \ \text{[class 1]} \\ +\text{OR} \\ +y = 2\cos(x_2 + 1.5) + x_1^2 - 0.5 \ \ \ \ \text{[class 2]} +``` + +We will need to simultaneously learn the symbolic expression and per-class parameters! 
+=# using SymbolicRegression using Random: MersenneTwister using Zygote using MLJBase: machine, fit!, predict, report using Test -rng = MersenneTwister(0) -X = NamedTuple{(:x1, :x2, :x3, :x4, :x5)}(ntuple(_ -> randn(rng, Float32, 30), Val(5))) -X = (; X..., classes=rand(rng, 1:2, 30)) -p1 = [0.0f0, 3.2f0] -p2 = [1.5f0, 0.5f0] +#= +Now, we generate synthetic data, with these 2 different classes. + +Note that the `class` feature is given special treatment for the [`SRRegressor`](@ref) +as a categorical variable: +=# + +X = let rng = MersenneTwister(0), n = 30 + (; x1=randn(rng, n), x2=randn(rng, n), class=rand(rng, 1:2, n)) +end + +#= +Now, we generate target values using the true model that +has class-dependent parameters: +=# +y = let P1 = [0.1, 1.5], P2 = [3.2, 0.5] + [2 * cos(x2 + P1[class]) + x1^2 - P2[class] for (x1, x2, class) in zip(X.x1, X.x2, X.class)] +end -y = [ - 2 * cos(X.x4[i] + p1[X.classes[i]]) + X.x1[i]^2 - p2[X.classes[i]] for - i in eachindex(X.classes) -] +#= +## Setting up the Search -stop_at = Ref(1e-4) +We'll configure the symbolic regression search to: +- Use parameterized expressions with up to 2 parameters +- Use Zygote.jl for automatic differentiation during parameter optimization (important when using parametric expressions, as it is higher dimensional) +=# + +stop_at = Ref(1e-4) #src model = SRRegressor(; niterations=100, @@ -25,12 +66,38 @@ model = SRRegressor(; expression_type=ParametricExpression, expression_options=(; max_parameters=2), autodiff_backend=:Zygote, - parallelism=:multithreading, - early_stop_condition=(loss, _) -> loss < stop_at[], -) + early_stop_condition=(loss, _) -> loss < stop_at[], #src +); + +#= +Now, let's set up the machine and fit it: +=# mach = machine(model, X, y) +#= +At this point, you would run: + +```julia +fit!(mach) +``` + +You can extract the best expression and parameters with: + +```julia +report(mach).equations[end] +``` + +## Key Takeaways + +1. 
[`ParametricExpression`](@ref)s allow us to discover symbolic expressions with optimizable parameters +2. The parameters can capture class-dependent variations in the underlying model + +This approach is particularly useful when you suspect your data follows a common +functional form, but with varying parameters across different conditions or classes! +=# +#literate_end + fit!(mach) idx1 = lastindex(report(mach).equations) ypred1 = predict(mach, (data=X, idx=idx1)) diff --git a/examples/template_expression_complex.jl b/examples/template_expression_complex.jl index c5b255c9e..8fad30577 100644 --- a/examples/template_expression_complex.jl +++ b/examples/template_expression_complex.jl @@ -7,9 +7,9 @@ Template expressions are a powerful feature in SymbolicRegression.jl that allow on the symbolic regression search. Rather than searching for a completely free-form expression, you can specify a template that combines multiple sub-expressions in a prescribed way. -This is particularly useful when: +This is particularly useful when any of the following are true: - You have domain knowledge about the functional form of your solution -- You want to learn vector-valued expressions (e.g., force fields, velocity fields) +- You want to learn expressions for a vector-valued output - You need to enforce constraints on which variables can appear in different parts of the expression - You want to share sub-expressions between multiple components From 0df70145c4734f33ab7f131e63c3e59c6bbdeed7 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 6 Nov 2024 20:50:06 -0500 Subject: [PATCH 54/59] refactor!: rename `classes` to `class` --- src/MLJInterface.jl | 47 ++++++++++++++--------------- src/ParametricExpression.jl | 6 ++-- test/test_expression_builder.jl | 4 +-- test/test_expression_derivatives.jl | 8 ++--- 4 files changed, 32 insertions(+), 33 deletions(-) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index b98a98f5b..84ae4d563 100644 --- a/src/MLJInterface.jl +++
b/src/MLJInterface.jl @@ -142,21 +142,21 @@ function MMI.update( options = old_fitresult === nothing ? get_options(m) : old_fitresult.options return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, nothing) end -function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, classes) - if isnothing(classes) && MMI.istable(X) && haskey(X, :classes) +function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, class) + if isnothing(class) && MMI.istable(X) && haskey(X, :class) if !(X isa NamedTuple) error("Classes can only be specified with named tuples.") end - new_X = Base.structdiff(X, (; X.classes)) - new_classes = X.classes + new_X = Base.structdiff(X, (; X.class)) + new_class = X.class return _update( - m, verbosity, old_fitresult, old_cache, new_X, y, w, options, new_classes + m, verbosity, old_fitresult, old_cache, new_X, y, w, options, new_class ) end if !isnothing(old_fitresult) @assert( - old_fitresult.has_classes == !isnothing(classes), - "If the first fit used classes, the second fit must also use classes." + old_fitresult.has_class == !isnothing(class), + "If the first fit used class, the second fit must also use class." ) end # To speed up iterative fits, we cache the types: @@ -210,7 +210,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, class X_units=X_units_clean, y_units=y_units_clean, verbosity=verbosity, - extra=isnothing(classes) ? (;) : (; classes), + extra=isnothing(class) ? (;) : (; class), # Help out with inference: v_dim_out=isa(m, SRRegressor) ? 
Val(1) : Val(2), ) @@ -221,7 +221,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, class variable_names=variable_names, y_variable_names=y_variable_names, y_is_table=MMI.istable(y), - has_classes=!isnothing(classes), + has_class=!isnothing(class), X_units=X_units_clean, y_units=y_units_clean, types=( @@ -377,17 +377,17 @@ end function eval_tree_mlj( tree::AbstractExpression, X_t, - classes, + class, m::AbstractSRRegressor, ::Type{T}, fitresult, i, prototype, ) where {T} - out, completed = if isnothing(classes) + out, completed = if isnothing(class) eval_tree_array(tree, X_t, fitresult.options) else - eval_tree_array(tree, X_t, classes, fitresult.options) + eval_tree_array(tree, X_t, class, fitresult.options) end if completed return wrap_units(out, fitresult.y_units, i) @@ -397,30 +397,29 @@ function eval_tree_mlj( end function MMI.predict( - m::M, fitresult, Xnew; idx=nothing, classes=nothing + m::M, fitresult, Xnew; idx=nothing, class=nothing ) where {M<:AbstractSRRegressor} - return _predict(m, fitresult, Xnew, idx, classes) + return _predict(m, fitresult, Xnew, idx, class) end -function _predict(m::M, fitresult, Xnew, idx, classes) where {M<:AbstractSRRegressor} +function _predict(m::M, fitresult, Xnew, idx, class) where {M<:AbstractSRRegressor} if Xnew isa NamedTuple && (haskey(Xnew, :idx) || haskey(Xnew, :data)) @assert( haskey(Xnew, :idx) && haskey(Xnew, :data) && length(keys(Xnew)) == 2, "If specifying an equation index during prediction, you must use a named tuple with keys `idx` and `data`." 
) - return _predict(m, fitresult, Xnew.data, Xnew.idx, classes) + return _predict(m, fitresult, Xnew.data, Xnew.idx, class) end - if isnothing(classes) && MMI.istable(Xnew) && haskey(Xnew, :classes) + if isnothing(class) && MMI.istable(Xnew) && haskey(Xnew, :class) if !(Xnew isa NamedTuple) error("Classes can only be specified with named tuples.") end - Xnew2 = Base.structdiff(Xnew, (; Xnew.classes)) - return _predict(m, fitresult, Xnew2, idx, Xnew.classes) + Xnew2 = Base.structdiff(Xnew, (; Xnew.class)) + return _predict(m, fitresult, Xnew2, idx, Xnew.class) end - if fitresult.has_classes + if fitresult.has_class @assert( - !isnothing(classes), - "Classes must be specified if the model was fit with classes." + !isnothing(class), "Classes must be specified if the model was fit with class." ) end @@ -442,12 +441,12 @@ function _predict(m::M, fitresult, Xnew, idx, classes) where {M<:AbstractSRRegre if M <: SRRegressor return eval_tree_mlj( - params.equations[_idx], Xnew_t, classes, m, T, fitresult, nothing, prototype + params.equations[_idx], Xnew_t, class, m, T, fitresult, nothing, prototype ) elseif M <: MultitargetSRRegressor outs = [ eval_tree_mlj( - params.equations[i][_idx[i]], Xnew_t, classes, m, T, fitresult, i, prototype + params.equations[i][_idx[i]], Xnew_t, class, m, T, fitresult, i, prototype ) for i in eachindex(_idx, params.equations) ] out_matrix = reduce(hcat, outs) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 0f9c92c94..58d4d82c4 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -61,7 +61,7 @@ end function DE.eval_tree_array( tree::ParametricExpression, X::AbstractMatrix, - classes::AbstractVector{<:Integer}, + class::AbstractVector{<:Integer}, options::AbstractOptions; kws..., ) @@ -69,7 +69,7 @@ function DE.eval_tree_array( out, complete = DE.eval_tree_array( tree, X, - classes, + class, DE.get_operators(tree, options); turbo=options.turbo, bumper=options.bumper, @@ -84,7 +84,7 @@ 
function LF.eval_tree_dispatch( out, complete = DE.eval_tree_array( tree, LF.maybe_getindex(dataset.X, :, idx), - LF.maybe_getindex(dataset.extra.classes, idx), + LF.maybe_getindex(dataset.extra.class, idx), options.operators, ) return out::A, complete::Bool diff --git a/test/test_expression_builder.jl b/test/test_expression_builder.jl index 37b9291f3..50028ff4b 100644 --- a/test/test_expression_builder.jl +++ b/test/test_expression_builder.jl @@ -15,10 +15,10 @@ ) X = ones(1, 1) * 2 y = ones(1) - dataset = Dataset(X, y; extra=(; classes=[1])) + dataset = Dataset(X, y; extra=(; class=[1])) @test ex isa ParametricExpression - @test ex(dataset.X, dataset.extra.classes) ≈ ones(1, 1) * 6 + @test ex(dataset.X, dataset.extra.class) ≈ ones(1, 1) * 6 # Mistake in that we gave the wrong options! @test_throws( diff --git a/test/test_expression_derivatives.jl b/test/test_expression_derivatives.jl index c8cba75ae..359b405bf 100644 --- a/test/test_expression_derivatives.jl +++ b/test/test_expression_derivatives.jl @@ -84,18 +84,18 @@ end true_params = [0.5 2.0] init_params = [0.1 0.2] init_constants = [2.5, -0.5] - classes = rand(rng, 1:2, 32) + class = rand(rng, 1:2, 32) y = [ - X[1, i] * X[1, i] - cos(2.6 * X[2, i] - 0.2) + true_params[1, classes[i]] for + X[1, i] * X[1, i] - cos(2.6 * X[2, i] - 0.2) + true_params[1, class[i]] for i in 1:32 ] - dataset = Dataset(X, y; extra=(; classes)) + dataset = Dataset(X, y; extra=(; class)) (true_val, (true_d_params, true_d_constants)) = value_and_gradient(AutoZygote(), (init_params, init_constants)) do (params, c) pred = [ - X[1, i] * X[1, i] - cos(c[1] * X[2, i] + c[2]) + params[1, classes[i]] for + X[1, i] * X[1, i] - cos(c[1] * X[2, i] + c[2]) + params[1, class[i]] for i in 1:32 ] sum(abs2, pred .- y) / length(y) From afe6de15354cdcc30aef220e922b91721ed4b05b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 6 Nov 2024 20:55:32 -0500 Subject: [PATCH 55/59] docs: add more deps --- docs/Project.toml | 2 ++ 1 file changed, 2 
insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index 6399bf082..f66ed72c1 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,7 +3,9 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" Gumbo = "708ec375-b3d6-5a57-a7ce-8257bf98657a" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Documenter = "0.27" From a59d77a17b1e5536171a3aa940dc4fa751cbeaac Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 6 Nov 2024 21:20:54 -0500 Subject: [PATCH 56/59] fix: reference to classes --- src/ParametricExpression.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ParametricExpression.jl b/src/ParametricExpression.jl index 58d4d82c4..a5664fd45 100644 --- a/src/ParametricExpression.jl +++ b/src/ParametricExpression.jl @@ -32,7 +32,7 @@ function EB.extra_init_params( ::Val{embed}, ) where {T,embed,E<:ParametricExpression} num_params = options.expression_options.max_parameters - num_classes = length(unique(dataset.extra.classes)) + num_classes = length(unique(dataset.extra.class)) parameter_names = embed ? 
["p$i" for i in 1:num_params] : nothing _parameters = if prototype === nothing randn(T, (num_params, num_classes)) From e7e2e0b15d3ac7b159720586ed1304e9ccf92f8a Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 6 Nov 2024 21:35:30 -0500 Subject: [PATCH 57/59] test: coverage of complexity mapping --- test/test_custom_operators_multiprocessing.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/test_custom_operators_multiprocessing.jl b/test/test_custom_operators_multiprocessing.jl index 2fca2298e..cae405c38 100644 --- a/test/test_custom_operators_multiprocessing.jl +++ b/test/test_custom_operators_multiprocessing.jl @@ -1,5 +1,7 @@ using SymbolicRegression +const used_complexity = Ref(false) + defs = quote _plus(x, y) = x + y _mult(x, y) = x * y @@ -9,12 +11,16 @@ defs = quote _exp(x) = exp(x) early_stop(loss, c) = ((loss <= 1e-10) && (c <= 10)) my_loss(x, y, w) = abs(x - y)^2 * w + my_complexity(ex) = (used_complexity[] = true; length(get_tree(ex))) end # This is needed as workers are initialized in `Core.Main`! 
if (@__MODULE__) != Core.Main Core.eval(Core.Main, defs) - eval(:(using Main: _plus, _mult, _div, _min, _cos, _exp, early_stop, my_loss)) + eval( + :(using Main: + _plus, _mult, _div, _min, _cos, _exp, early_stop, my_loss, my_complexity), + ) else eval(defs) end @@ -28,6 +34,7 @@ options = SymbolicRegression.Options(; populations=20, early_stop_condition=early_stop, elementwise_loss=my_loss, + complexity_mapping=my_complexity, ) hof = equation_search( From 9f0261de8d5fbc9d8dc4979887464a653e6ea7fb Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Wed, 6 Nov 2024 23:40:53 -0500 Subject: [PATCH 58/59] fix: copying complexity function to worker --- src/Configure.jl | 2 +- test/test_custom_operators_multiprocessing.jl | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/Configure.jl b/src/Configure.jl index 61c66d3e5..1151c4cb2 100644 --- a/src/Configure.jl +++ b/src/Configure.jl @@ -158,7 +158,7 @@ function move_functions_to_workers( ops = (options.loss_function,) example_inputs = (Node(T; val=zero(T)), dataset, options) elseif function_set == :complexity_mapping - if options.complexity_mapping isa Union{ComplexityMapping,Function} + if !(options.complexity_mapping isa Function) continue end ops = (options.complexity_mapping,) diff --git a/test/test_custom_operators_multiprocessing.jl b/test/test_custom_operators_multiprocessing.jl index cae405c38..7e18a39c8 100644 --- a/test/test_custom_operators_multiprocessing.jl +++ b/test/test_custom_operators_multiprocessing.jl @@ -1,7 +1,5 @@ using SymbolicRegression -const used_complexity = Ref(false) - defs = quote _plus(x, y) = x + y _mult(x, y) = x * y @@ -11,7 +9,7 @@ defs = quote _exp(x) = exp(x) early_stop(loss, c) = ((loss <= 1e-10) && (c <= 10)) my_loss(x, y, w) = abs(x - y)^2 * w - my_complexity(ex) = (used_complexity[] = true; length(get_tree(ex))) + my_complexity(ex) = length($(get_tree)(ex)) end # This is needed as workers are initialized in `Core.Main`! 
@@ -19,7 +17,16 @@ if (@__MODULE__) != Core.Main Core.eval(Core.Main, defs) eval( :(using Main: - _plus, _mult, _div, _min, _cos, _exp, early_stop, my_loss, my_complexity), + get_tree, + _plus, + _mult, + _div, + _min, + _cos, + _exp, + early_stop, + my_loss, + my_complexity), ) else eval(defs) From 113f2c6426cdb422b980e7d58fbefbd63a87e839 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Thu, 7 Nov 2024 01:30:37 -0500 Subject: [PATCH 59/59] test: fix `get_tree` --- test/test_custom_operators_multiprocessing.jl | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/test/test_custom_operators_multiprocessing.jl b/test/test_custom_operators_multiprocessing.jl index 7e18a39c8..22c978771 100644 --- a/test/test_custom_operators_multiprocessing.jl +++ b/test/test_custom_operators_multiprocessing.jl @@ -1,4 +1,5 @@ using SymbolicRegression +using Test defs = quote _plus(x, y) = x + y @@ -7,9 +8,9 @@ defs = quote _min(x, y) = x - y _cos(x) = cos(x) _exp(x) = exp(x) - early_stop(loss, c) = ((loss <= 1e-10) && (c <= 10)) + early_stop(loss, c) = ((loss <= 1e-10) && (c <= 6)) my_loss(x, y, w) = abs(x - y)^2 * w - my_complexity(ex) = length($(get_tree)(ex)) + my_complexity(ex) = ceil(Int, length($(get_tree)(ex)) / 2) end # This is needed as workers are initialized in `Core.Main`! 
@@ -17,16 +18,7 @@ if (@__MODULE__) != Core.Main Core.eval(Core.Main, defs) eval( :(using Main: - get_tree, - _plus, - _mult, - _div, - _min, - _cos, - _exp, - early_stop, - my_loss, - my_complexity), + _plus, _mult, _div, _min, _cos, _exp, early_stop, my_loss, my_complexity), ) else eval(defs) @@ -39,6 +31,7 @@ options = SymbolicRegression.Options(; binary_operators=(_plus, _mult, _div, _min), unary_operators=(_cos, _exp), populations=20, + maxsize=15, early_stop_condition=early_stop, elementwise_loss=my_loss, complexity_mapping=my_complexity, @@ -55,5 +48,6 @@ hof = equation_search( ) @test any( - early_stop(member.loss, count_nodes(member.tree)) for member in hof.members[hof.exists] + early_stop(member.loss, my_complexity(member.tree)) for + member in hof.members[hof.exists] )