Skip to content

Commit

Permalink
Merge pull request #271 from MilesCranmer/graph-nodes
Browse files Browse the repository at this point in the history
Program synthesis/graph expressions
  • Loading branch information
MilesCranmer authored Mar 20, 2024
2 parents b26f7f6 + 89daeff commit ae848a3
Show file tree
Hide file tree
Showing 49 changed files with 1,279 additions and 971 deletions.
18 changes: 11 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymbolicRegression"
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
authors = ["MilesCranmer <[email protected]>"]
version = "0.23.3"
version = "0.24.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down Expand Up @@ -36,19 +36,21 @@ SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils"

[compat]
Aqua = "0.7"
Bumper = "0.6"
Compat = "^4.2"
DynamicExpressions = "0.13"
DynamicQuantities = "0.10"
DynamicExpressions = "0.16"
DynamicQuantities = "0.10, 0.11, 0.12, 0.13"
JSON3 = "1"
LineSearches = "7"
LoopVectorization = "0.12"
LossFunctions = "0.10, 0.11"
MLJModelInterface = "1.5, 1.6, 1.7, 1.8"
MLJModelInterface = "~1.5, ~1.6, ~1.7, ~1.8"
MacroTools = "0.4, 0.5"
Optim = "0.19, 1.1 - 1.7.6"
Optim = "~1.8, ~1.9"
PackageExtensionCompat = "1"
Pkg = "1"
PrecompileTools = "1"
ProgressBars = "1.4"
ProgressBars = "~1.4"
Reexport = "1"
SpecialFunctions = "0.10.1, 1, 2"
StatsBase = "0.33, 0.34"
Expand All @@ -58,9 +60,11 @@ julia = "1.6"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand All @@ -70,4 +74,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "SafeTestsets", "Aqua", "ForwardDiff", "LinearAlgebra", "JSON3", "MLJBase", "MLJTestInterface", "Suppressor", "SymbolicUtils", "Zygote"]
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "JSON3", "MLJBase", "MLJTestInterface", "Suppressor", "SymbolicUtils", "Zygote"]
2 changes: 2 additions & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
11 changes: 11 additions & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using SymbolicRegression.AdaptiveParsimonyModule: RunningSearchStatistics
using SymbolicRegression.PopulationModule: best_of_sample
using SymbolicRegression.ConstantOptimizationModule: optimize_constants
using SymbolicRegression.CheckConstraintsModule: check_constraints
using Bumper, LoopVectorization

function create_search_benchmark()
suite = BenchmarkGroup()
Expand All @@ -27,9 +28,19 @@ function create_search_benchmark()
maxsize=30,
verbosity=0,
progress=false,
mutation_weights=MutationWeights(),
loss=(pred, target) -> (pred - target)^2,
extra_kws...,
)
if hasfield(MutationWeights, :swap_operands)
option_kws.mutation_weights.swap_operands = 0.0
end
if hasfield(MutationWeights, :form_connection)
option_kws.mutation_weights.form_connection = 0.0
end
if hasfield(MutationWeights, :break_connection)
option_kws.mutation_weights.break_connection = 0.0
end
seeds = 1:3
niterations = 30
# We create an equation that cannot be found exactly, so the search
Expand Down
3 changes: 1 addition & 2 deletions example.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using SymbolicRegression, SymbolicUtils
using SymbolicRegression

X = randn(Float32, 5, 100)
y = 2 * cos.(X[4, :]) + X[1, :] .^ 2 .- 2
Expand All @@ -18,7 +18,6 @@ trees = [member.tree for member in dominating]
tree = trees[end]
output, did_succeed = eval_tree_array(tree, X, options)

eqn = node_to_symbolic(dominating[end].tree, options)
println("Complexity\tMSE\tEquation")

for member in dominating
Expand Down
28 changes: 17 additions & 11 deletions ext/SymbolicRegressionSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
module SymbolicRegressionSymbolicUtilsExt

using SymbolicUtils: Symbolic
using SymbolicRegression: Node, Options
using SymbolicRegression: AbstractExpressionNode, Node, Options
using SymbolicRegression.MLJInterfaceModule: AbstractSRRegressor, get_options

import SymbolicRegression: node_to_symbolic, symbolic_to_node

"""
node_to_symbolic(tree::Node, options::Options; kws...)
node_to_symbolic(tree::AbstractExpressionNode, options::Options; kws...)
Convert an expression to SymbolicUtils.jl form.
"""
# Convenience method: pull the operator set out of `options.operators` and forward,
# together with any keyword arguments, to the `node_to_symbolic` method that takes
# the operators directly (that method is not defined in this view).
function node_to_symbolic(tree::Node, options::Options; kws...)
function node_to_symbolic(tree::AbstractExpressionNode, options::Options; kws...)
return node_to_symbolic(tree, options.operators; kws...)
end
# Convenience method for a regressor: obtain its options via `get_options(m)` and
# delegate to the `node_to_symbolic` call that takes those options.
function node_to_symbolic(tree::Node, m::AbstractSRRegressor; kws...)
function node_to_symbolic(tree::AbstractExpressionNode, m::AbstractSRRegressor; kws...)
return node_to_symbolic(tree, get_options(m); kws...)
end

Expand All @@ -30,20 +30,26 @@ function symbolic_to_node(eqn::Symbolic, m::AbstractSRRegressor; kws...)
return symbolic_to_node(eqn, get_options(m); kws...)
end

# Convert an expression tree to `Symbolic`, taking the operator set from
# `options.operators` and forwarding all keyword arguments unchanged.
function Base.convert(::Type{Symbolic}, tree::Node, options::Options; kws...)
function Base.convert(
::Type{Symbolic}, tree::AbstractExpressionNode, options::Options; kws...
)
return convert(Symbolic, tree, options.operators; kws...)
end
# Convert an expression tree to `Symbolic` for a regressor: look up the regressor's
# options with `get_options(m)` and delegate to the `convert` call taking options.
function Base.convert(::Type{Symbolic}, tree::Node, m::AbstractSRRegressor; kws...)
function Base.convert(
::Type{Symbolic}, tree::AbstractExpressionNode, m::AbstractSRRegressor; kws...
)
return convert(Symbolic, tree, get_options(m); kws...)
end

# Convert a number or `Symbolic` expression into an expression node, taking the
# operator set from `options.operators`. The newer signature is parametric in the
# requested node type `N` and passes that same `N` on to the inner `convert`.
function Base.convert(::Type{Node}, x::Union{Number,Symbolic}, options::Options; kws...)
return convert(Node, x, options.operators; kws...)
function Base.convert(
::Type{N}, x::Union{Number,Symbolic}, options::Options; kws...
) where {N<:AbstractExpressionNode}
return convert(N, x, options.operators; kws...)
end
# Convert a number or `Symbolic` expression into an expression node for a regressor:
# the options come from `get_options(m)` and the call is delegated to the `convert`
# call taking options. The newer signature keeps the requested node type `N`.
function Base.convert(
::Type{Node}, x::Union{Number,Symbolic}, m::AbstractSRRegressor; kws...
)
return convert(Node, x, get_options(m); kws...)
::Type{N}, x::Union{Number,Symbolic}, m::AbstractSRRegressor; kws...
) where {N<:AbstractExpressionNode}
return convert(N, x, get_options(m); kws...)
end

end
2 changes: 1 addition & 1 deletion src/AdaptiveParsimony.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ equations, for use in adaptive losses and parsimony.
- `normalized_frequencies::Vector{Float64}`: This is the same as `frequencies`, but
normalized to sum to 1.0. This is updated once in a while.
"""
mutable struct RunningSearchStatistics
struct RunningSearchStatistics
window_size::Int
frequencies::Vector{Float64}
normalized_frequencies::Vector{Float64} # Stores `frequencies`, but normalized (updated once in a while)
Expand Down
23 changes: 16 additions & 7 deletions src/CheckConstraints.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
module CheckConstraintsModule

using DynamicExpressions: Node, count_depth, tree_mapreduce
using DynamicExpressions: AbstractExpressionNode, count_depth, tree_mapreduce
using ..UtilsModule: vals
using ..CoreModule: Options
using ..ComplexityModule: compute_complexity, past_complexity_limit

# Check if any binary operators are overly complex
function flag_bin_operator_complexity(tree::Node, op, cons, options::Options)::Bool
function flag_bin_operator_complexity(
tree::AbstractExpressionNode, op, cons, options::Options
)::Bool
any(tree) do subtree
if subtree.degree == 2 && subtree.op == op
cons[1] > -1 &&
Expand All @@ -24,7 +26,9 @@ end
Check if any unary operators are overly complex.
This assumes you have already checked whether the constraint is > -1.
"""
function flag_una_operator_complexity(tree::Node, op, cons, options::Options)::Bool
function flag_una_operator_complexity(
tree::AbstractExpressionNode, op, cons, options::Options
)::Bool
any(tree) do subtree
if subtree.degree == 1 && tree.op == op
past_complexity_limit(subtree.l, options, cons) && return true
Expand All @@ -34,19 +38,21 @@ function flag_una_operator_complexity(tree::Node, op, cons, options::Options)::B
end

# Measure how deeply operator `op` (with arity `degree`) is nested within `tree`.
# Each leaf contributes 0, each branch contributes 1 when it matches `degree`/`op`,
# and a node's total is its own contribution plus the largest total among its
# children. The result is therefore the largest number of matching nodes found on
# any root-to-leaf path, minus one when the root itself matches.
# NOTE(review): this is a rendered diff, so both the removed `tree,` argument line
# and its replacement (`tree;` followed by `break_sharing=Val(true),`) appear below.
function count_max_nestedness(tree, degree, op)
# TODO: Update this to correctly share nodes
nestedness = tree_mapreduce(
t -> 0, # Leaves
t -> (t.degree == degree && t.op == op) ? 1 : 0, # Branches
(p, c...) -> p + max(c...), # Reduce
tree,
tree;
# NOTE(review): presumably makes shared subtrees get visited as independent
# copies instead of being cached -- confirm against DynamicExpressions.
break_sharing=Val(true),
)
# Remove count of self:
is_self = tree.degree == degree && tree.op == op
return nestedness - (is_self ? 1 : 0)
end

"""Check if there are any illegal combinations of operators"""
function flag_illegal_nests(tree::Node, options::Options)::Bool
function flag_illegal_nests(tree::AbstractExpressionNode, options::Options)::Bool
# We search from the top first, then from child nodes at end.
(nested_constraints = options.nested_constraints) === nothing && return false
for (degree, op_idx, op_constraint) in nested_constraints
Expand All @@ -65,7 +71,10 @@ end

"""Check if user-passed constraints are violated or not"""
function check_constraints(
tree::Node, options::Options, maxsize::Int, cursize::Union{Int,Nothing}=nothing
tree::AbstractExpressionNode,
options::Options,
maxsize::Int,
cursize::Union{Int,Nothing}=nothing,
)::Bool
((cursize === nothing) ? compute_complexity(tree, options) : cursize) > maxsize &&
return false
Expand All @@ -84,7 +93,7 @@ function check_constraints(
return true
end

# Convenience method: check `tree` using `options.maxsize` as the size limit.
check_constraints(tree::Node, options::Options)::Bool =
check_constraints(tree::AbstractExpressionNode, options::Options)::Bool =
check_constraints(tree, options, options.maxsize)

end
22 changes: 15 additions & 7 deletions src/Complexity.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
module ComplexityModule

using DynamicExpressions: Node, count_nodes, tree_mapreduce
using DynamicExpressions: AbstractExpressionNode, count_nodes, tree_mapreduce
using ..CoreModule: Options, ComplexityMapping

# Return `true` when the complexity of `tree`, as computed by `compute_complexity`
# with the given `options`, is strictly greater than `limit`.
function past_complexity_limit(tree::Node, options::Options{CT}, limit)::Bool where {CT}
function past_complexity_limit(
tree::AbstractExpressionNode, options::Options{CT}, limit
)::Bool where {CT}
return compute_complexity(tree, options) > limit
end

Expand All @@ -14,16 +16,20 @@ By default, this is the number of nodes in a tree.
However, it could use the custom settings in options.complexity_mapping
if these are defined.
"""
# In the newer signature, `break_sharing` (default `Val(false)`) is handed through
# unchanged to `_compute_complexity` or `count_nodes`, whichever branch is taken.
function compute_complexity(tree::Node, options::Options{CT})::Int where {CT}
function compute_complexity(
tree::AbstractExpressionNode, options::Options{CT}; break_sharing=Val(false)
)::Int where {CT}
# The custom mapping is used only when `options.complexity_mapping.use` is set.
if options.complexity_mapping.use
raw_complexity = _compute_complexity(tree, options)
raw_complexity = _compute_complexity(tree, options; break_sharing)
# `_compute_complexity` returns a `CT`; round it to honour the `Int` return type.
return round(Int, raw_complexity)
else
# Otherwise the complexity is simply the number of nodes.
return count_nodes(tree)
return count_nodes(tree; break_sharing)
end
end

function _compute_complexity(tree::Node, options::Options{CT})::CT where {CT}
function _compute_complexity(
tree::AbstractExpressionNode, options::Options{CT}; break_sharing=Val(false)
)::CT where {CT}
cmap = options.complexity_mapping
constant_complexity = cmap.constant_complexity
variable_complexity = cmap.variable_complexity
Expand All @@ -34,7 +40,9 @@ function _compute_complexity(tree::Node, options::Options{CT})::CT where {CT}
t -> t.degree == 1 ? unaop_complexities[t.op] : binop_complexities[t.op],
+,
tree,
CT,
CT;
break_sharing=break_sharing,
f_on_shared=(result, is_shared) -> is_shared ? result : zero(CT),
)
end

Expand Down
Loading

0 comments on commit ae848a3

Please sign in to comment.