
Merge pull request #277 from MilesCranmer/tb-logging
Integration with TensorBoard and other logging utilities
MilesCranmer authored Nov 9, 2024
2 parents 6c5c74e + 8c37c26 commit 92f2b33
Showing 15 changed files with 529 additions and 48 deletions.
33 changes: 32 additions & 1 deletion CHANGELOG.md
@@ -23,6 +23,7 @@ Summary of major recent changes, described in more detail below:
- `AbstractSearchState`, for holding custom metadata during searches.
- `AbstractOptions` and `AbstractRuntimeOptions`, for customizing pretty much everything else in the library via multiple dispatch. Please make an issue/PR if you would like any particular internal functions to be declared `public` to enable stability across versions for your tool.
- Many of these were motivated to modularize the implementation of [LaSR](https://github.com/trishullab/LibraryAugmentedSymbolicRegression.jl), an LLM-guided version of SymbolicRegression.jl, so it can sit as a modular layer on top of SymbolicRegression.jl.
- [Added TensorBoardLogger.jl and other logging integrations via `SRLogger`](#added-tensorboardloggerjl-and-other-logging-integrations-via-srlogger)
- Fundamental improvements to the underlying evolutionary algorithm
- New mutation operators introduced, `swap_operands` and `rotate_tree` – both of which seem to help kick the evolution out of local optima.
- New hyperparameter defaults created, based on a Pareto front volume calculation, rather than simply accuracy of the best expression.
@@ -366,7 +367,36 @@ Base.propertynames(options::MyOptions) = (NEW_OPTIONS_KEYS..., fieldnames(Symbol
These new abstractions provide users with greater flexibility in defining the structure and behavior of expressions, nodes, and the search process itself.
These are also of course used as the basis for alternate behavior such as `ParametricExpression` and `TemplateExpression`.

### Fundamental improvements to the underlying evolutionary algorithm
### Added TensorBoardLogger.jl and other logging integrations via `SRLogger`

You can now track the progress of symbolic regression searches using `TensorBoardLogger.jl`, `Wandb.jl`, or other logging backends.

This is done by wrapping any `AbstractLogger` with the new `SRLogger` type, and passing it to the `logger` option in `SRRegressor`
or `equation_search`:

```julia
using SymbolicRegression
using TensorBoardLogger

logger = SRLogger(
    TBLogger("logs/run"),
    log_interval=2, # Log every 2 steps
)

model = SRRegressor(;
    binary_operators=[+, -, *],
    logger=logger,
)
```

The logger will track:

- Loss curves over time at each complexity level
- Population statistics (distribution of complexities)
- Pareto frontier volume (can be used as an overall metric of search performance)
- Full equations at each complexity level

This works with any logger that implements the Julia logging interface.
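
As a minimal sketch (not from this changelog; the toy data, operators, and `niterations` are placeholders), the same wrapper can also hold Julia's built-in `ConsoleLogger` and be passed to `equation_search` directly:

```julia
using SymbolicRegression
using Logging: ConsoleLogger

# Any AbstractLogger works; here we wrap the standard-library ConsoleLogger
logger = SRLogger(ConsoleLogger(); log_interval=10)

# Toy dataset (placeholder): 2 features x 100 samples
X = randn(Float32, 2, 100)
y = 2 .* cos.(X[1, :]) .+ X[2, :] .^ 2

options = Options(; binary_operators=[+, -, *], unary_operators=[cos])

hall_of_fame = equation_search(
    X, y;
    niterations=20,
    options=options,
    logger=logger,
)
```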

### Support for Zygote.jl and Enzyme.jl within the constant optimizer, specified using the `autodiff_backend` option

@@ -405,6 +435,7 @@ A custom run ID can be specified via the new `run_id` parameter passed to `equat
- Option to force dimensionless constants when fitting with dimensional constraints, via the `dimensionless_constants_only` option.
- Default `maxsize` increased from 20 to 30.
- Default `niterations` increased from 10 to 50, as many users seem to be unaware that this is small (and meant for testing), even in publications. I think this 50 is still low, but it should be a more accurate default for those who don't tune.
- `MLJ.fit!(mach)` now records the number of iterations used, and, should `mach.model.niterations` be changed after the fit, the number of iterations passed to `equation_search` will be reduced accordingly.

### Update Guide

2 changes: 2 additions & 0 deletions Project.toml
@@ -14,6 +14,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -52,6 +53,7 @@ DynamicQuantities = "1"
Enzyme = "0.12"
JSON3 = "1"
LineSearches = "7"
Logging = "1"
LossFunctions = "0.10, 0.11"
MLJModelInterface = "~1.5, ~1.6, ~1.7, ~1.8, ~1.9, ~1.10, ~1.11"
MacroTools = "0.4, 0.5"
21 changes: 21 additions & 0 deletions docs/src/api.md
@@ -60,3 +60,24 @@ Note that use of this function requires `SymbolicUtils.jl` to be installed and l
```@docs
calculate_pareto_frontier
```

## Logging

```@docs
SRLogger
```

The `SRLogger` allows you to track the progress of symbolic regression searches.
It can wrap any `AbstractLogger` that implements the Julia logging interface,
such as loggers from TensorBoardLogger.jl or Wandb.jl.

```julia
using TensorBoardLogger

logger = SRLogger(TBLogger("logs/run"), log_interval=2)

model = SRRegressor(;
    logger=logger,
    kws...
)
```
39 changes: 38 additions & 1 deletion docs/src/examples.md
@@ -384,7 +384,44 @@ You can even output custom structs - see the more detailed [Template Expression

Be sure to also check out the [Parametric Expression example](examples/parametric_expression.md).

## 9. Additional features
## 9. Logging with TensorBoard

You can track the progress of symbolic regression searches using TensorBoard or other logging backends. Here's an example that creates a `TBLogger` from TensorBoardLogger.jl and wraps it with [`SRLogger`](@ref):

```julia
using SymbolicRegression
using TensorBoardLogger
using MLJ

logger = SRLogger(TBLogger("logs/sr_run"))

# Create and fit model with logger
model = SRRegressor(
    binary_operators=[+, -, *],
    maxsize=40,
    niterations=100,
    logger=logger
)

X = (a=rand(500), b=rand(500))
y = @. 2 * cos(X.a * 23.5) - X.b^2

mach = machine(model, X, y)
fit!(mach)
```

You can then view the logs with:

```bash
tensorboard --logdir logs
```

The TensorBoard interface will show the loss curves over time (at each complexity level),
as well as the Pareto frontier volume, which can be used as an overall metric of search performance.
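
The same `SRLogger` wrapper should also accept other logging backends mentioned above, such as Wandb.jl. A brief sketch follows (the `WandbLogger` constructor and the project name are assumptions based on that package, not part of this example):

```julia
using SymbolicRegression
using MLJ
using Wandb  # assumed to provide WandbLogger, an AbstractLogger

logger = SRLogger(WandbLogger(; project="sr_demo"); log_interval=10)

model = SRRegressor(;
    binary_operators=[+, -, *],
    logger=logger,
)

# Same toy data as in the TensorBoard example above
X = (a=rand(500), b=rand(500))
y = @. 2 * cos(X.a * 23.5) - X.b^2

mach = machine(model, X, y)
fit!(mach)
```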

## 10. Additional features

For the many other features available in SymbolicRegression.jl,
check out the API page for `Options`. You might also find it useful
9 changes: 5 additions & 4 deletions examples/parameterized_function.jl
@@ -101,14 +101,15 @@ functional form, but with varying parameters across different conditions or clas
fit!(mach)
idx1 = lastindex(report(mach).equations)
ypred1 = predict(mach, (data=X, idx=idx1))
loss1 = sum(i -> abs2(ypred1[i] - y[i]), eachindex(y))
loss1 = sum(i -> abs2(ypred1[i] - y[i]), eachindex(y)) / length(y)

# Should keep all parameters
stop_at[] = 1e-5
stop_at[] = loss1 * 0.999
mach.model.niterations *= 2
fit!(mach)
idx2 = lastindex(report(mach).equations)
ypred2 = predict(mach, (data=X, idx=idx2))
loss2 = sum(i -> abs2(ypred2[i] - y[i]), eachindex(y))
loss2 = sum(i -> abs2(ypred2[i] - y[i]), eachindex(y)) / length(y)

# Should get better:
@test loss1 >= loss2
@test loss1 > loss2
17 changes: 10 additions & 7 deletions src/HallOfFame.jl
@@ -151,13 +151,7 @@ function string_dominating_pareto_curve(
y_sym_units=dataset.y_sym_units,
pretty,
)
y_prefix = dataset.y_variable_name
unit_str = format_dimensions(dataset.y_sym_units)
y_prefix *= unit_str
if dataset.y_sym_units === nothing && dataset.X_sym_units !== nothing
y_prefix *= WILDCARD_UNIT_STRING
end
prefix = y_prefix * " = "
prefix = make_prefix(tree, options, dataset)
eqn_string = prefix * eqn_string
stats_columns_string = @sprintf("%-10d %-8.3e %-8.3e ", complexity, loss, score)
left_cols_width = length(stats_columns_string)
@@ -172,6 +166,15 @@
print(buffer, '─'^(terminal_width - 1))
return dump_buffer(buffer)
end
function make_prefix(::AbstractExpression, ::AbstractOptions, dataset::Dataset)
    y_prefix = dataset.y_variable_name
    unit_str = format_dimensions(dataset.y_sym_units)
    y_prefix *= unit_str
    if dataset.y_sym_units === nothing && dataset.X_sym_units !== nothing
        y_prefix *= WILDCARD_UNIT_STRING
    end
    return y_prefix * " = "
end

function wrap_equation_string(eqn_string, left_cols_width, terminal_width)
dots = "..."
207 changes: 207 additions & 0 deletions src/Logging.jl
@@ -0,0 +1,207 @@
module LoggingModule

using Base: AbstractLogger
using Logging: Logging as LG
using DynamicExpressions: string_tree

using ..UtilsModule: @ignore
using ..CoreModule: AbstractOptions, Dataset
using ..PopulationModule: Population
using ..HallOfFameModule: HallOfFame
using ..ComplexityModule: compute_complexity
using ..HallOfFameModule: calculate_pareto_frontier
using ..SearchUtilsModule: AbstractSearchState, AbstractRuntimeOptions

import ..SearchUtilsModule: logging_callback!

"""
AbstractSRLogger <: AbstractLogger
Abstract type for symbolic regression loggers. Subtypes must implement:
- `get_logger(logger)`: Return the underlying logger
- `logging_callback!(logger; kws...)`: Callback function for logging.
Called with the current state, datasets, runtime options, and options. If you wish to
reduce the logging frequency, you can increment and monitor a counter within this
function.
"""
abstract type AbstractSRLogger <: AbstractLogger end

"""
SRLogger(logger::AbstractLogger; log_every_n::Integer=1)
A logger for symbolic regression that wraps another logger.
# Arguments
- `logger`: The base logger to wrap
- `log_interval`: Number of steps between logging events. Default is 1 (log every step).
"""
Base.@kwdef struct SRLogger{L<:AbstractLogger} <: AbstractSRLogger
    logger::L
    log_interval::Int = 1
    _log_step::Base.RefValue{Int} = Base.RefValue(0)
end
SRLogger(logger::AbstractLogger; kws...) = SRLogger(; logger, kws...)

function get_logger(logger::SRLogger)
    return logger.logger
end
function should_log(logger::SRLogger)
    return logger.log_interval > 0 && logger._log_step[] % logger.log_interval == 0
end
function increment_log_step!(logger::SRLogger)
    logger._log_step[] += 1
    return nothing
end

function LG.with_logger(f::Function, logger::AbstractSRLogger)
    return LG.with_logger(f, get_logger(logger))
end

"""
logging_callback!(logger::AbstractSRLogger; kws...)
Default logging callback for SymbolicRegression.
To override the default logging behavior, create a new type `MyLogger <: AbstractSRLogger`
and define a method for `SymbolicRegression.logging_callback`.
"""
function logging_callback!(
    logger::AbstractSRLogger;
    @nospecialize(state::AbstractSearchState),
    datasets::AbstractVector{<:Dataset{T,L}},
    @nospecialize(ropt::AbstractRuntimeOptions),
    @nospecialize(options::AbstractOptions),
) where {T,L}
    if should_log(logger)
        data = log_payload(logger, state, datasets, options)
        LG.with_logger(logger) do
            @info("search", data = data)
        end
    end
    increment_log_step!(logger)
    return nothing
end

function log_payload(
    logger::AbstractSRLogger,
    @nospecialize(state::AbstractSearchState),
    datasets::AbstractVector{<:Dataset{T,L}},
    @nospecialize(options::AbstractOptions),
) where {T,L}
    d = Ref(Dict{String,Any}())
    for i in eachindex(datasets, state.halls_of_fame)
        out = _log_scalars(;
            pops=state.last_pops[i],
            hall_of_fame=state.halls_of_fame[i],
            dataset=datasets[i],
            options,
        )
        if length(datasets) == 1
            d[] = out
        else
            d[]["output$(i)"] = out
        end
    end
    d[]["num_evals"] = sum(sum, state.num_evals)
    return d[]
end

function _log_scalars(;
    @nospecialize(pops::AbstractVector{<:Population}),
    @nospecialize(hall_of_fame::HallOfFame{T,L}),
    dataset::Dataset{T,L},
    @nospecialize(options::AbstractOptions),
) where {T,L}
    out = Dict{String,Any}()

    #### Population diagnostics
    out["population"] = Dict([
        "complexities" => let
            complexities = Int[]
            for pop in pops, member in pop.members
                push!(complexities, compute_complexity(member, options))
            end
            complexities
        end
    ])

    #### Summaries
    dominating = calculate_pareto_frontier(hall_of_fame)
    trees = [member.tree for member in dominating]
    losses = L[member.loss for member in dominating]
    complexities = Int[compute_complexity(member, options) for member in dominating]

    out["min_loss"] = length(dominating) > 0 ? dominating[end].loss : L(Inf)
    out["pareto_volume"] = if length(dominating) > 1
        log_losses = @. log10(losses + eps(L))
        log_complexities = @. log10(complexities)

        # Add a point equal to the best loss and largest possible complexity, + 1
        push!(log_losses, minimum(log_losses))
        push!(log_complexities, log10(options.maxsize + 1))

        # Add a point to connect things:
        push!(log_losses, maximum(log_losses))
        push!(log_complexities, maximum(log_complexities))

        xy = cat(log_complexities, log_losses; dims=2)
        hull = convex_hull(xy)
        convex_hull_area(hull)
    else
        0.0
    end

    #### Full Pareto front
    out["equations"] = let
        equations = String[
            string_tree(member.tree, options; variable_names=dataset.variable_names) for
            member in dominating
        ]
        Dict([
            "complexity=" * string(complexities[i_eqn]) =>
                Dict("loss" => losses[i_eqn], "equation" => equations[i_eqn]) for
            i_eqn in eachindex(complexities, losses, equations)
        ])
    end
    return out
end

"""Uses gift wrapping algorithm to create a convex hull."""
function convex_hull(xy)
@assert size(xy, 2) == 2
cur_point = xy[sortperm(xy[:, 1])[1], :]
hull = typeof(cur_point)[]
while true
push!(hull, cur_point)
end_point = xy[1, :]
for candidate_point in eachrow(xy)
if end_point == cur_point || isleftof(candidate_point, (cur_point, end_point))
end_point = candidate_point
end
end
cur_point = end_point
if end_point == hull[1]
break
end
end
return hull
end

function isleftof(point, line)
(start_point, end_point) = line
return (end_point[1] - start_point[1]) * (point[2] - start_point[2]) -
(end_point[2] - start_point[2]) * (point[1] - start_point[1]) > 0
end

"""Computes the area within a convex hull."""
function convex_hull_area(hull)
area = 0.0
for i in eachindex(hull)
j = i == lastindex(hull) ? firstindex(hull) : nextind(hull, i)
area += (hull[i][1] * hull[j][2] - hull[j][1] * hull[i][2])
end
return abs(area) / 2
end

end
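
As context for the `AbstractSRLogger` interface defined above, here is a minimal custom-logger sketch. It is not part of this commit: the module paths (`SymbolicRegression.LoggingModule`, `SymbolicRegression.logging_callback!`), the keyword names, and the `MinLossLogger` type are assumptions inferred from the diff.

```julia
using SymbolicRegression
using Logging: Logging, ConsoleLogger

# Hypothetical custom logger that records only the minimum loss per output.
struct MinLossLogger{L<:Base.AbstractLogger} <: SymbolicRegression.LoggingModule.AbstractSRLogger
    logger::L
end

# Required by the AbstractSRLogger interface described in the docstrings above.
SymbolicRegression.LoggingModule.get_logger(logger::MinLossLogger) = logger.logger

# Override the default callback; keyword names follow the default method above.
function SymbolicRegression.logging_callback!(
    logger::MinLossLogger; state, datasets, ropt, options
)
    for i in eachindex(state.halls_of_fame)
        dominating = calculate_pareto_frontier(state.halls_of_fame[i])
        min_loss = isempty(dominating) ? Inf : dominating[end].loss
        Logging.with_logger(logger.logger) do
            @info("search", output = i, min_loss = min_loss)
        end
    end
    return nothing
end
```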