Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create overloadable utilities: AbstractOptions, AbstractRuntimeOptions, AbstractMutationWeights, AbstractSearchState, and mutate! #353

Merged
merged 16 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ jobs:
- "part2"
- "part3"
julia-version:
- "1.6"
- "1.8"
- "1.10"
- "1"
os:
- ubuntu-latest
Expand All @@ -54,15 +53,6 @@ jobs:
- os: macOS-latest
julia-version: "1"
test: "part3"
- os: ubuntu-latest
julia-version: "~1.11.0-0"
test: "part1"
- os: ubuntu-latest
julia-version: "~1.11.0-0"
test: "part2"
- os: ubuntu-latest
julia-version: "~1.11.0-0"
test: "part3"

steps:
- uses: actions/checkout@v4
Expand Down
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand Down Expand Up @@ -58,7 +57,6 @@ LossFunctions = "0.10, 0.11"
MLJModelInterface = "~1.5, ~1.6, ~1.7, ~1.8, ~1.9, ~1.10, ~1.11"
MacroTools = "0.4, 0.5"
Optim = "~1.8, ~1.9"
PackageExtensionCompat = "1"
Pkg = "<0.0.1, 1"
PrecompileTools = "1"
Printf = "<0.0.1, 1"
Expand All @@ -69,7 +67,7 @@ SpecialFunctions = "0.10.1, 1, 2"
StatsBase = "0.33, 0.34"
SymbolicUtils = "0.19, ^1.0.5, 2, 3"
TOML = "<0.0.1, 1"
julia = "1.6"
julia = "1.10"

[extras]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand Down
53 changes: 53 additions & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using BenchmarkTools
using SymbolicRegression, BenchmarkTools, Random
using SymbolicRegression.AdaptiveParsimonyModule: RunningSearchStatistics
using SymbolicRegression.MutateModule: next_generation
using SymbolicRegression.RecorderModule: RecordType
using SymbolicRegression.PopulationModule: best_of_sample
using SymbolicRegression.ConstantOptimizationModule: optimize_constants
using SymbolicRegression.CheckConstraintsModule: check_constraints
Expand Down Expand Up @@ -93,6 +95,57 @@ function create_utils_benchmark()
)
)

suite["next_generation_x100"] = @benchmarkable(
let
for member in members
next_generation(
dataset,
member,
temperature,
curmaxsize,
rss,
options;
tmp_recorder=recorder,
)
end
end,
setup = (
nfeatures = 1;
dataset = Dataset(randn(nfeatures, 32), randn(32));
mutation_weights = MutationWeights(;
mutate_constant=1.0,
mutate_operator=1.0,
swap_operands=1.0,
rotate_tree=1.0,
add_node=1.0,
insert_node=1.0,
simplify=0.0,
randomize=0.0,
do_nothing=0.0,
form_connection=0.0,
break_connection=0.0,
);
options = Options(;
unary_operators=[sin, cos], binary_operators=[+, -, *, /], mutation_weights
);
recorder = RecordType();
temperature = 1.0;
curmaxsize = 20;
rss = RunningSearchStatistics(; options);
trees = [
gen_random_tree_fixed_size(15, options, nfeatures, Float64) for _ in 1:100
];
expressions = [
Expression(tree; operators=options.operators, variable_names=["x1"]) for
tree in trees
];
members = [
PopMember(dataset, expression, options; deterministic=false) for
expression in expressions
]
)
)

ntrees = 10
suite["optimize_constants_x10"] = @benchmarkable(
foreach(members) do member
Expand Down
27 changes: 18 additions & 9 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
using Documenter
using SymbolicUtils
using SymbolicRegression
using SymbolicRegression: Dataset, update_baseline_loss!
using SymbolicRegression:
AbstractExpression,
ExpressionInterface,
Dataset,
update_baseline_loss!,
AbstractMutationWeights,
AbstractOptions,
mutate!,
condition_mutation_weights!,
sample_mutation,
MutationResult,
AbstractRuntimeOptions,
AbstractSearchState,
@extend_operators
using DynamicExpressions

DocMeta.setdocmeta!(
SymbolicRegression, :DocTestSetup, :(using LossFunctions); recursive=true
Expand Down Expand Up @@ -40,14 +54,8 @@ readme = replace(

# We prepend the `<table>` with a ```@raw html
# and append the `</table>` with a ```:
readme = replace(
readme,
r"<table>" => s"```@raw html\n<table>",
)
readme = replace(
readme,
r"</table>" => s"</table>\n```",
)
readme = replace(readme, r"<table>" => s"```@raw html\n<table>")
readme = replace(readme, r"</table>" => s"</table>\n```")

# Then, we surround ```mermaid\n...\n``` snippets
# with ```@raw html\n<div class="mermaid">\n...\n</div>```:
Expand Down Expand Up @@ -96,6 +104,7 @@ makedocs(;
"API" => "api.md",
"Losses" => "losses.md",
"Types" => "types.md",
"Customization" => "customization.md",
],
)

Expand Down
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ SRRegressor
MultitargetSRRegressor
```

## equation_search
## Low-Level API

```@docs
equation_search
Expand Down
61 changes: 61 additions & 0 deletions docs/src/customization.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Customization

Many parts of SymbolicRegression.jl are designed to be customizable.

The normal way to do this in Julia is to define a new type that subtypes
an abstract type from a package, and then define new methods for the type,
extending internal methods on that type.

## Custom Options

For example, you can define a custom options type:

```@docs
AbstractOptions
```

Any function in SymbolicRegression.jl you can generally define a new method
on your custom options type, to define custom behavior.

## Custom Mutations

You can define custom mutation operators by defining a new method on
`mutate!`, as well as subtyping `AbstractMutationWeights`:

```@docs
mutate!
AbstractMutationWeights
condition_mutation_weights!
sample_mutation
MutationResult
```

## Custom Expressions

You can create your own expression types by defining a new type that extends `AbstractExpression`.

```@docs
AbstractExpression
ExpressionInterface
```

The interface is fairly flexible, and permits you define specific functional forms,
extra parameters, etc. See the documentation of DynamicExpressions.jl for more details on what
methods you need to implement. Then, for SymbolicRegression.jl, you would
pass `expression_type` to the `Options` constructor, as well as any
`expression_options` you need (as a `NamedTuple`).

If needed, you may need to overload `SymbolicRegression.ExpressionBuilder.extra_init_params` in
case your expression needs additional parameters. See the method for `ParametricExpression`
as an example.

## Other Customizations

Other internal abstract types include the following:

```@docs
AbstractRuntimeOptions
AbstractSearchState
```

These let you include custom state variables and runtime options.
18 changes: 0 additions & 18 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,24 +60,6 @@ ParametricNode

These types allow you to define expressions with parameters that can be tuned to fit the data better. You can specify the maximum number of parameters using the `expression_options` argument in `SRRegressor`.

## Custom Expressions

You can create your own expression types by defining a new type that extends `AbstractExpression`.

```@docs
AbstractExpression
```

The interface is fairly flexible, and permits you define specific functional forms,
extra parameters, etc. See the documentation of DynamicExpressions.jl for more details on what
methods you need to implement. Then, for SymbolicRegression.jl, you would
pass `expression_type` to the `Options` constructor, as well as any
`expression_options` you need (as a `NamedTuple`).

If needed, you may need to overload `SymbolicRegression.ExpressionBuilder.extra_init_params` in
case your expression needs additional parameters. See the method for `ParametricExpression`
as an example.

## Population

Groups of equations are given as a population, which is
Expand Down
4 changes: 2 additions & 2 deletions src/AdaptiveParsimony.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module AdaptiveParsimonyModule

using ..CoreModule: Options, MAX_DEGREE
using ..CoreModule: AbstractOptions, MAX_DEGREE

"""
RunningSearchStatistics
Expand All @@ -23,7 +23,7 @@ struct RunningSearchStatistics
normalized_frequencies::Vector{Float64} # Stores `frequencies`, but normalized (updated once in a while)
end

function RunningSearchStatistics(; options::Options, window_size::Int=100000)
function RunningSearchStatistics(; options::AbstractOptions, window_size::Int=100000)
maxsize = options.maxsize
actualMaxsize = maxsize + MAX_DEGREE
init_frequencies = ones(Float64, actualMaxsize)
Expand Down
14 changes: 7 additions & 7 deletions src/CheckConstraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ module CheckConstraintsModule

using DynamicExpressions:
AbstractExpressionNode, AbstractExpression, get_tree, count_depth, tree_mapreduce
using ..CoreModule: Options
using ..CoreModule: AbstractOptions
using ..ComplexityModule: compute_complexity, past_complexity_limit

# Check if any binary operator are overly complex
function flag_bin_operator_complexity(
tree::AbstractExpressionNode, op, cons, options::Options
tree::AbstractExpressionNode, op, cons, options::AbstractOptions
)::Bool
any(tree) do subtree
if subtree.degree == 2 && subtree.op == op
Expand All @@ -27,7 +27,7 @@ Check if any unary operators are overly complex.
This assumes you have already checked whether the constraint is > -1.
"""
function flag_una_operator_complexity(
tree::AbstractExpressionNode, op, cons, options::Options
tree::AbstractExpressionNode, op, cons, options::AbstractOptions
)::Bool
any(tree) do subtree
if subtree.degree == 1 && tree.op == op
Expand All @@ -52,7 +52,7 @@ function count_max_nestedness(tree, degree, op)
end

"""Check if there are any illegal combinations of operators"""
function flag_illegal_nests(tree::AbstractExpressionNode, options::Options)::Bool
function flag_illegal_nests(tree::AbstractExpressionNode, options::AbstractOptions)::Bool
# We search from the top first, then from child nodes at end.
(nested_constraints = options.nested_constraints) === nothing && return false
for (degree, op_idx, op_constraint) in nested_constraints
Expand All @@ -72,7 +72,7 @@ end
"""Check if user-passed constraints are violated or not"""
function check_constraints(
ex::AbstractExpression,
options::Options,
options::AbstractOptions,
maxsize::Int,
cursize::Union{Int,Nothing}=nothing,
)::Bool
Expand All @@ -81,7 +81,7 @@ function check_constraints(
end
function check_constraints(
tree::AbstractExpressionNode,
options::Options,
options::AbstractOptions,
maxsize::Int,
cursize::Union{Int,Nothing}=nothing,
)::Bool
Expand All @@ -103,7 +103,7 @@ function check_constraints(
end

check_constraints(
ex::Union{AbstractExpression,AbstractExpressionNode}, options::Options
ex::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions
)::Bool = check_constraints(ex, options, options.maxsize)

end
8 changes: 4 additions & 4 deletions src/Complexity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ module ComplexityModule

using DynamicExpressions:
AbstractExpression, AbstractExpressionNode, get_tree, count_nodes, tree_mapreduce
using ..CoreModule: Options, ComplexityMapping
using ..CoreModule: AbstractOptions, ComplexityMapping

function past_complexity_limit(
tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options, limit
tree::Union{AbstractExpression,AbstractExpressionNode}, options::AbstractOptions, limit
)::Bool
return compute_complexity(tree, options) > limit
end
Expand All @@ -18,12 +18,12 @@ However, it could use the custom settings in options.complexity_mapping
if these are defined.
"""
function compute_complexity(
tree::AbstractExpression, options::Options; break_sharing=Val(false)
tree::AbstractExpression, options::AbstractOptions; break_sharing=Val(false)
)
return compute_complexity(get_tree(tree), options; break_sharing)
end
function compute_complexity(
tree::AbstractExpressionNode, options::Options; break_sharing=Val(false)
tree::AbstractExpressionNode, options::AbstractOptions; break_sharing=Val(false)
)::Int
if options.complexity_mapping.use
raw_complexity = _compute_complexity(
Expand Down
Loading
Loading