Rewrite TemplateExpression to enable hierarchical expressions #365

Merged
merged 59 commits into from
Nov 7, 2024
Changes from 52 commits
Commits
effab2c
feat: create `ComposableExpression`
MilesCranmer Oct 30, 2024
42935fc
feat: tweak names of internal types
MilesCranmer Oct 30, 2024
e3dad4f
test: composable expression
MilesCranmer Oct 30, 2024
a1e192c
feat: init hierarchical expression
MilesCranmer Nov 1, 2024
bc48fcc
feat: enable `VectorWrapper` for other operators
MilesCranmer Nov 1, 2024
c3aa38b
fix: HierarchicalExpression instabilities
MilesCranmer Nov 1, 2024
756a2d9
fix: need to freeze operators in HierarchicalExpression and Composabl…
MilesCranmer Nov 2, 2024
eca9b91
feat: validation of inferred constraints
MilesCranmer Nov 2, 2024
05419e4
feat: make hierarchical expressions compatible
MilesCranmer Nov 2, 2024
81b1870
feat: better printing for HierarchicalExpression
MilesCranmer Nov 2, 2024
b517a8d
feat: info dump at end of search
MilesCranmer Nov 2, 2024
d4c84dc
fix: correct return type for `get_tree`
MilesCranmer Nov 2, 2024
91fbee9
fix: JET error
MilesCranmer Nov 2, 2024
b7f8622
test: other validity checks
MilesCranmer Nov 2, 2024
15a6159
feat: print with `=` to not have breaks
MilesCranmer Nov 2, 2024
a05bb16
feat: ensure we save the full expression string
MilesCranmer Nov 2, 2024
609b7da
fix: switch to `pretty` over `raw`
MilesCranmer Nov 2, 2024
cf631f8
refactor!: fully deprecate varMap
MilesCranmer Nov 2, 2024
9a4fddd
refactor: clean up imports with ExplicitImports
MilesCranmer Nov 2, 2024
cdfaaca
fix: fix old use of `pretty`
MilesCranmer Nov 2, 2024
1d22351
fix: validate degree 2 nans
MilesCranmer Nov 2, 2024
8050cd3
fix: map to safe operators within ComposableExpression
MilesCranmer Nov 2, 2024
ee19066
feat: allow custom complexity functions
MilesCranmer Nov 2, 2024
4609e03
test: custom complexity function
MilesCranmer Nov 2, 2024
dea324d
refactor: force specialization for composable expression
MilesCranmer Nov 2, 2024
555a8dd
refactor: fewer closures
MilesCranmer Nov 2, 2024
666babd
feat: expose VectorWrapper
MilesCranmer Nov 2, 2024
8c3596b
refactor: more efficient mutations for hierarchical
MilesCranmer Nov 2, 2024
aa3b435
test: fix pretty print format
MilesCranmer Nov 2, 2024
9445bf4
refactor: name `ValidVector`
MilesCranmer Nov 2, 2024
2f9d17e
docs: document ValidVector and ComposableExpression
MilesCranmer Nov 2, 2024
dc2d509
feat!: move `HierarchicalExpression` into place of `TemplateExpression`
MilesCranmer Nov 2, 2024
fb1b733
style: formatting of template expression
MilesCranmer Nov 2, 2024
31cc3d2
test: improve coverage for TemplateExpression
MilesCranmer Nov 2, 2024
9a78079
test: fix composable expression test
MilesCranmer Nov 2, 2024
4b99b67
test: fix complexity tests
MilesCranmer Nov 2, 2024
403f614
test: errors for TemplateStructure
MilesCranmer Nov 2, 2024
c1e403f
docs: update changelog
MilesCranmer Nov 3, 2024
4a7bf35
fix: validate keys of `num_features`
MilesCranmer Nov 3, 2024
8e438ca
fix: move back NodeSampler to exports
MilesCranmer Nov 3, 2024
ca4c8d9
test: remove old reference to test file
MilesCranmer Nov 3, 2024
436ff23
docs: update docs for TemplateStructure
MilesCranmer Nov 3, 2024
6ef8bcf
feat: return `nothing` for invalid result rather than `NaN`
MilesCranmer Nov 3, 2024
7a0dbc2
test: fix missing `node_type`
MilesCranmer Nov 3, 2024
05f8678
docs: update docs for TemplateStructure
MilesCranmer Nov 3, 2024
d59de9b
fix: move `_info_dump` to end for precompilation
MilesCranmer Nov 3, 2024
e5f5105
docs: improve readability of example
MilesCranmer Nov 3, 2024
e5bfeff
fix: left arg in ComposableExpression
MilesCranmer Nov 2, 2024
fd29e64
refactor: top-level `get_safe_op`
MilesCranmer Nov 4, 2024
1fcd544
docs: tweak order
MilesCranmer Nov 4, 2024
f75f1ee
refactor: remove some unused constants
MilesCranmer Nov 4, 2024
8889de7
test: weaken test condition
MilesCranmer Nov 4, 2024
069c25c
docs: add parametrized function example
MilesCranmer Nov 7, 2024
0df7014
refactor!: rename `classes` to `class`
MilesCranmer Nov 7, 2024
afe6de1
docs: add more deps
MilesCranmer Nov 7, 2024
a59d77a
fix: reference to classes
MilesCranmer Nov 7, 2024
e7e2e0b
test: coverage of complexity mapping
MilesCranmer Nov 7, 2024
9f0261d
fix: copying complexity function to worker
MilesCranmer Nov 7, 2024
113f2c6
test: fix `get_tree`
MilesCranmer Nov 7, 2024
104 changes: 50 additions & 54 deletions CHANGELOG.md
@@ -91,68 +91,61 @@ A `TemplateExpression` is constructed by specifying:
For example, you can create a `TemplateExpression` that enforces
the constraint: `sin(f(x1, x2)) + g(x3)^2` - where we evolve `f` and `g` simultaneously.

Let's see some code for this. First, we define some base expressions for each input feature:
To do this, we first describe the structure using `TemplateStructure`,
which takes a single closure mapping a named tuple of
`ComposableExpression` objects and a tuple of features to the combined output:

```julia
using SymbolicRegression

options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
operators = options.operators
variable_names = ["x1", "x2", "x3"]

# Base expressions:
x1 = Expression(Node{Float64}(; feature=1); operators, variable_names)
x2 = Expression(Node{Float64}(; feature=2); operators, variable_names)
x3 = Expression(Node{Float64}(; feature=3); operators, variable_names)
structure = TemplateStructure{(:f, :g)}(
((; f, g), (x1, x2, x3)) -> sin(f(x1, x2)) + g(x3)^2
)
```

To build a `TemplateExpression`, we specify the structure using
a `TemplateStructure` object. This class has several fields:
This defines how the `TemplateExpression` should be
evaluated numerically on a given input.

- `combine`: Optional function taking a `NamedTuple` of function keys => expressions,
returning a single expression. Fallback method used by `get_tree`
on a `TemplateExpression` to generate a single `Expression`.
- `combine_vectors`: Optional function taking a `NamedTuple` of function keys => vectors,
returning a single vector. Used for evaluating the expression tree.
You may optionally define a method with a second argument `X` for if you wish
to include the data matrix `X` (of shape `[num_features, num_rows]`) in the
computation.
- `combine_strings`: Optional function taking a `NamedTuple` of function keys => strings,
returning a single string. Used for printing the expression tree.
- `variable_constraints`: Optional `NamedTuple` that defines which variables each sub-expression is allowed to access.
For example, requesting `f(x1, x2)` and `g(x3)` would be equivalent to `(; f=[1, 2], g=[3])`.

Let's see an example:
The number of arguments allowed by each expression object
is inferred using this closure, though it can also
be passed explicitly with the `num_features` kwarg.

```julia

# Combine f and g into a single scalar expression:
structure = TemplateStructure(;
combine_strings=e -> "sin(" * e.f * ") + (" * e.g * ")^2",
combine_vectors=e -> map((f, g) -> sin(f) + g * g, e.f, e.g),
variable_constraints = (; f=[1, 2], g=[3]), # We constrain it to f(x1, x2) and g(x3)
)
operators = Options(binary_operators=(+, -, *, /)).operators
variable_names = ["x1", "x2", "x3"]
x1 = ComposableExpression(Node{Float64}(; feature=1); operators, variable_names)
x2 = ComposableExpression(Node{Float64}(; feature=2); operators, variable_names)
x3 = ComposableExpression(Node{Float64}(; feature=3); operators, variable_names)
```

This defines how the `TemplateExpression` should be evaluated numerically on a given input,
and also how it should be represented as a string:
Note that using `x1` here refers to the
_relative_ argument to the expression.
So the node with feature equal to 1 will reference
the first argument, regardless of what it is.

```julia
julia> f_example = x1 - x2 * x2; # Normal `Expression` object

julia> g_example = 1.5 * x3;

julia> # Create TemplateExpression from these sub-expressions:
st_expr = TemplateExpression((; f=f_example, g=g_example); structure, operators, variable_names);
st_expr = TemplateExpression(
(; f=x1 - x2 * x2, g=1.5 * x1);
structure,
operators,
variable_names
) # Prints as: f = #1 - (#2 * #2); g = 1.5 * #1

# Evaluation combines evaluation of `f` and `g`, and combines them
# with the structure function:
st_expr([0.0; 1.0; 2.0;;])
```
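
The relative-argument convention can be illustrated without the package at all. The following is a plain-Julia analogy, not the `ComposableExpression` implementation: ordinary closures stand in for `f` and `g`, and `combine` is a hypothetical name for the structure function.

```julia
# Plain-Julia analogy (not SymbolicRegression.jl API): closures stand in for
# the sub-expressions. Arguments are positional, so the first argument of `g`
# plays the role of `#1` no matter which feature is passed in.
f = (a, b) -> a - b * b   # analogue of `x1 - x2 * x2`
g = a -> 1.5 * a          # analogue of `1.5 * x1`

# Hypothetical stand-in for the structure function.
combine((; f, g), (x1, x2, x3)) = sin(f(x1, x2)) + g(x3)^2

# One row of data: x1 = 0.0, x2 = 1.0, x3 = 2.0
result = combine((; f, g), (0.0, 1.0, 2.0))
println(result)
```

Here `g` is written in terms of its first argument, yet it receives the third feature because the structure calls `g(x3)`. The result equals `sin(-1.0) + 9.0`.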

julia> st_expr # Prints using `my_structure`!
sin(x1 - (x2 * x2)) + 1.5 * x3^2
This also works with hierarchical expressions! For example,

julia> st_expr([0.0; 1.0; 2.0;;]) # Combines evaluation of `f` and `g` via `my_structure`!
1-element Vector{Float64}:
8.158529015192103
```julia
structure = TemplateStructure{(:f, :g)}(
((; f, g), (x1, x2, x3)) -> f(x1, g(x2), x3^2) - g(x3)
)
```

this is a valid structure!

We can also use this `TemplateExpression` in SymbolicRegression.jl searches!

<details>
@@ -168,11 +161,17 @@ This also has our variable mapping, which says
we are fitting `f(x1, x2)`, `g1(x3)`, and `g2(x3)`:

```julia
structure = TemplateStructure(;
combine_strings=e -> "( " * e.f * " + " * e.g1 * ", " * e.f * " + " * e.g2 * " )",
combine_vectors=e -> map(i -> (e.f[i] + e.g1[i], e.f[i] + e.g2[i]), eachindex(e.f)),
variable_constraints = (; f=[1, 2], g1=[3], g2=[3]),
)
function my_structure((; f, g1, g2), (x1, x2, x3))
_f = f(x1, x2)
_g1 = g1(x3)
_g2 = g2(x3)

# We use `.x` to get the underlying vector
out = map((fi, g1i, g2i) -> (fi + g1i, fi + g2i), _f.x, _g1.x, _g2.x)
# And `.valid` to see whether the evaluations succeeded
return ValidVector(out, _f.valid && _g1.valid && _g2.valid)
end
structure = TemplateStructure{(:f, :g1, :g2)}(my_structure)
```
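
The `.x` / `.valid` pattern above can be sketched in isolation. This is a minimal stand-in, assuming only the two fields mentioned in the text; `SketchValidVector` and `pair_outputs` are hypothetical names, not part of SymbolicRegression.jl.

```julia
# Hypothetical stand-in for `ValidVector`; only the `.x` and `.valid`
# fields mentioned in the text are modeled.
struct SketchValidVector{V<:AbstractVector}
    x::V
    valid::Bool
end

# Pair up two evaluations row by row; the result is valid only if both inputs are.
function pair_outputs(a::SketchValidVector, b::SketchValidVector)
    out = map(tuple, a.x, b.x)
    return SketchValidVector(out, a.valid && b.valid)
end

good = SketchValidVector([1.0, 2.0], true)
bad = SketchValidVector([NaN, 0.0], false)

ok = pair_outputs(good, good)      # both inputs valid
failed = pair_outputs(good, bad)   # one invalid input marks the whole output invalid
println(ok.valid, " ", failed.valid)
```

The point of the sketch is that validity is combined with `&&`, so a single failed sub-expression evaluation marks the combined output as invalid.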

Now, our dataset is a regular 2D array of inputs for `X`.
@@ -182,10 +181,7 @@ But our `y` is actually a _vector of 2-tuples_!
X = rand(100, 3) .* 10

y = [
(
sin(X[i, 1]) + X[i, 3]^2,
sin(X[i, 1]) + X[i, 3]
)
(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3])
for i in eachindex(axes(X, 1))
]
```
2 changes: 1 addition & 1 deletion Project.toml
@@ -47,7 +47,7 @@ Dates = "1"
DifferentiationInterface = "0.5, 0.6"
DispatchDoctor = "^0.4.17"
Distributed = "<0.0.1, 1"
DynamicExpressions = "1.4"
DynamicExpressions = "1.5.0"
DynamicQuantities = "1"
Enzyme = "0.12"
JSON3 = "1"
27 changes: 8 additions & 19 deletions docs/src/types.md
@@ -63,32 +63,21 @@ These types allow you to define expressions with parameters that can be tuned to
## Template Expressions

Template expressions allow you to specify predefined structures and constraints for your expressions.
These use the new `TemplateStructure` type to define how expressions should be combined and evaluated.
These use `ComposableExpression` objects as their internal expression type, which makes them
flexible for creating a structure out of a single function.

These use the `TemplateStructure` type to define how expressions should be combined and evaluated.

```@docs
TemplateExpression
TemplateStructure
```

Example usage:

```julia
# Define a template structure
structure = TemplateStructure(
combine=e -> e.f + e.g, # Create normal `Expression`
combine_vectors=e -> (e.f .+ e.g), # Output vector
combine_strings=e -> "($e.f) + ($e.g)", # Output string
variable_constraints=(; f=[1, 2], g=[3]) # Constrain dependencies
)

# Use in options
model = SRRegressor(;
expression_type=TemplateExpression,
expression_options=(; structure=structure)
)
```
Composable expressions allow you to combine multiple expressions together.

The `variable_constraints` field allows you to specify which variables can be used in different parts of the expression.
```@docs
ComposableExpression
```

## Population

36 changes: 22 additions & 14 deletions examples/template_expression.jl
@@ -6,33 +6,41 @@ using Test: @test
options = Options(; binary_operators=(+, *, /, -), unary_operators=(sin, cos))
operators = options.operators
variable_names = (i -> "x$i").(1:3)
x1, x2, x3 = (i -> Expression(Node(Float64; feature=i); operators, variable_names)).(1:3)

structure = TemplateStructure{(:f, :g1, :g2)}(;
combine_vectors=e -> map((f, g1, g2) -> (f + g1, f + g2), e.f, e.g1, e.g2),
combine_strings=e -> "( $(e.f) + $(e.g1), $(e.f) + $(e.g2) )",
variable_constraints=(; f=[1, 2], g1=[3], g2=[3]),
x1, x2, x3 =
(i -> ComposableExpression(Node(Float64; feature=i); operators, variable_names)).(1:3)

structure = TemplateStructure{(:f, :g1, :g2)}(
((; f, g1, g2), (x1, x2, x3)) -> let
_f = f(x1, x2)
_g1 = g1(x3)
_g2 = g2(x3)
_out1 = _f + _g1
_out2 = _f + _g2
ValidVector(map(tuple, _out1.x, _out2.x), _out1.valid && _out2.valid)
end,
)

st_expr = TemplateExpression((; f=x1, g1=x3, g2=x3); structure, operators, variable_names)

X = rand(100, 3) .* 10
x1 = rand(100)
x2 = rand(100)
x3 = rand(100)

# Our dataset is a vector of 2-tuples
y = [(sin(X[i, 1]) + X[i, 3]^2, sin(X[i, 1]) + X[i, 3]) for i in eachindex(axes(X, 1))]
y = [(sin(x1[i]) + x3[i]^2, sin(x1[i]) + x3[i]) for i in eachindex(x1, x2, x3)]

model = SRRegressor(;
binary_operators=(+, *),
unary_operators=(sin,),
maxsize=15,
maxsize=20,
expression_type=TemplateExpression,
expression_options=(; structure),
# The elementwise needs to operate directly on each row of `y`:
elementwise_loss=((x1, x2), (y1, y2)) -> (y1 - x1)^2 + (y2 - x2)^2,
early_stop_condition=(loss, complexity) -> loss < 1e-5 && complexity <= 7,
early_stop_condition=(loss, complexity) -> loss < 1e-6 && complexity <= 7,
)

mach = machine(model, X, y)
mach = machine(model, [x1 x2 x3], y)
fit!(mach)

# Check the performance of the model
@@ -48,6 +56,6 @@ best_f = get_contents(best_expr).f
best_g1 = get_contents(best_expr).g1
best_g2 = get_contents(best_expr).g2

@test best_f(X') ≈ (@. sin(X[:, 1]))
@test best_g1(X') ≈ (@. X[:, 3] * X[:, 3])
@test best_g2(X') ≈ (@. X[:, 3])
@test best_f(x1, x2) ≈ @. sin.(x1)
@test best_g1(x3) ≈ (@. x3 * x3)
@test best_g2(x3) ≈ (@. x3)