Merge pull request #277 from MilesCranmer/tb-logging
Integration with TensorBoard and other logging utilities
Showing 15 changed files with 529 additions and 48 deletions.
@@ -0,0 +1,207 @@
module LoggingModule

using Base: AbstractLogger
using Logging: Logging as LG
using DynamicExpressions: string_tree

using ..UtilsModule: @ignore
using ..CoreModule: AbstractOptions, Dataset
using ..PopulationModule: Population
using ..HallOfFameModule: HallOfFame
using ..ComplexityModule: compute_complexity
using ..HallOfFameModule: calculate_pareto_frontier
using ..SearchUtilsModule: AbstractSearchState, AbstractRuntimeOptions

import ..SearchUtilsModule: logging_callback!
""" | ||
AbstractSRLogger <: AbstractLogger | ||
Abstract type for symbolic regression loggers. Subtypes must implement: | ||
- `get_logger(logger)`: Return the underlying logger | ||
- `logging_callback!(logger; kws...)`: Callback function for logging. | ||
Called with the current state, datasets, runtime options, and options. If you wish to | ||
reduce the logging frequency, you can increment and monitor a counter within this | ||
function. | ||
""" | ||
abstract type AbstractSRLogger <: AbstractLogger end | ||
|
||
""" | ||
SRLogger(logger::AbstractLogger; log_every_n::Integer=1) | ||
A logger for symbolic regression that wraps another logger. | ||
# Arguments | ||
- `logger`: The base logger to wrap | ||
- `log_interval`: Number of steps between logging events. Default is 1 (log every step). | ||
""" | ||
Base.@kwdef struct SRLogger{L<:AbstractLogger} <: AbstractSRLogger | ||
logger::L | ||
log_interval::Int = 1 | ||
_log_step::Base.RefValue{Int} = Base.RefValue(0) | ||
end | ||
SRLogger(logger::AbstractLogger; kws...) = SRLogger(; logger, kws...) | ||
|
||
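# Illustrative usage (not part of this diff): any `AbstractLogger` from the
# Logging standard library can be wrapped, e.g. to record every 10th step:
#
#     using Logging: ConsoleLogger
#     logger = SRLogger(ConsoleLogger(); log_interval=10)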
function get_logger(logger::SRLogger)
    return logger.logger
end
function should_log(logger::SRLogger)
    return logger.log_interval > 0 && logger._log_step[] % logger.log_interval == 0
end
function increment_log_step!(logger::SRLogger)
    logger._log_step[] += 1
    return nothing
end

function LG.with_logger(f::Function, logger::AbstractSRLogger)
    return LG.with_logger(f, get_logger(logger))
end
""" | ||
logging_callback!(logger::AbstractSRLogger; kws...) | ||
Default logging callback for SymbolicRegression. | ||
To override the default logging behavior, create a new type `MyLogger <: AbstractSRLogger` | ||
and define a method for `SymbolicRegression.logging_callback`. | ||
""" | ||
function logging_callback!( | ||
logger::AbstractSRLogger; | ||
@nospecialize(state::AbstractSearchState), | ||
datasets::AbstractVector{<:Dataset{T,L}}, | ||
@nospecialize(ropt::AbstractRuntimeOptions), | ||
@nospecialize(options::AbstractOptions), | ||
) where {T,L} | ||
if should_log(logger) | ||
data = log_payload(logger, state, datasets, options) | ||
LG.with_logger(logger) do | ||
@info("search", data = data) | ||
end | ||
end | ||
increment_log_step!(logger) | ||
return nothing | ||
end | ||
|
||
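# Illustrative sketch (not part of this diff) of a custom logger. The type
# name `MinLossLogger` is hypothetical; it only relies on the interface above
# and on `calculate_pareto_frontier`, and records just the minimum loss:
#
#     struct MinLossLogger{L<:AbstractLogger} <: AbstractSRLogger
#         logger::L
#     end
#     get_logger(logger::MinLossLogger) = logger.logger
#     function logging_callback!(logger::MinLossLogger; state, datasets, ropt, options)
#         dominating = calculate_pareto_frontier(only(state.halls_of_fame))
#         min_loss = isempty(dominating) ? Inf : dominating[end].loss
#         LG.with_logger(logger) do
#             @info("search", min_loss = min_loss)
#         end
#         return nothing
#     end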
function log_payload(
    logger::AbstractSRLogger,
    @nospecialize(state::AbstractSearchState),
    datasets::AbstractVector{<:Dataset{T,L}},
    @nospecialize(options::AbstractOptions),
) where {T,L}
    d = Ref(Dict{String,Any}())
    for i in eachindex(datasets, state.halls_of_fame)
        out = _log_scalars(;
            pops=state.last_pops[i],
            hall_of_fame=state.halls_of_fame[i],
            dataset=datasets[i],
            options,
        )
        if length(datasets) == 1
            d[] = out
        else
            d[]["output$(i)"] = out
        end
    end
    d[]["num_evals"] = sum(sum, state.num_evals)
    return d[]
end
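# For reference, with a single output dataset the payload assembled above is a
# flat dictionary (values below are illustrative). With multiple outputs, the
# per-output keys are nested under "output1", "output2", ..., while "num_evals"
# stays at the top level:
#
#     Dict(
#         "population"    => Dict("complexities" => [3, 5, 7]),
#         "min_loss"      => 0.01,
#         "pareto_volume" => 1.7,
#         "equations"     => Dict(
#             "complexity=5" => Dict("loss" => 0.01, "equation" => "2.3 * x1"),
#         ),
#         "num_evals"     => 12800.0,
#     )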
function _log_scalars(;
    @nospecialize(pops::AbstractVector{<:Population}),
    @nospecialize(hall_of_fame::HallOfFame{T,L}),
    dataset::Dataset{T,L},
    @nospecialize(options::AbstractOptions),
) where {T,L}
    out = Dict{String,Any}()

    #### Population diagnostics
    out["population"] = Dict([
        "complexities" => let
            complexities = Int[]
            for pop in pops, member in pop.members
                push!(complexities, compute_complexity(member, options))
            end
            complexities
        end
    ])

    #### Summaries
    dominating = calculate_pareto_frontier(hall_of_fame)
    trees = [member.tree for member in dominating]
    losses = L[member.loss for member in dominating]
    complexities = Int[compute_complexity(member, options) for member in dominating]

    out["min_loss"] = length(dominating) > 0 ? dominating[end].loss : L(Inf)
    out["pareto_volume"] = if length(dominating) > 1
        log_losses = @. log10(losses + eps(L))
        log_complexities = @. log10(complexities)

        # Add a point equal to the best loss and largest possible complexity, + 1
        push!(log_losses, minimum(log_losses))
        push!(log_complexities, log10(options.maxsize + 1))

        # Add a point to connect things:
        push!(log_losses, maximum(log_losses))
        push!(log_complexities, maximum(log_complexities))

        xy = cat(log_complexities, log_losses; dims=2)
        hull = convex_hull(xy)
        convex_hull_area(hull)
    else
        0.0
    end

    #### Full Pareto front
    out["equations"] = let
        equations = String[
            string_tree(member.tree, options; variable_names=dataset.variable_names) for
            member in dominating
        ]
        Dict([
            "complexity=" * string(complexities[i_eqn]) =>
                Dict("loss" => losses[i_eqn], "equation" => equations[i_eqn]) for
            i_eqn in eachindex(complexities, losses, equations)
        ])
    end
    return out
end
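# Concrete (illustrative) example of the `pareto_volume` computed above: for a
# front with losses [1.0, 0.1] at complexities [1, 3] and `options.maxsize == 5`,
# the points fed to the hull are, in (log10 complexity, log10 loss) space,
#
#     (0.0, 0.0), (log10(3), -1.0), (log10(6), -1.0), (log10(6), 0.0)
#
# and the resulting hull area (see `convex_hull_area` below) is roughly 0.54.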
"""Uses gift wrapping algorithm to create a convex hull.""" | ||
function convex_hull(xy) | ||
@assert size(xy, 2) == 2 | ||
cur_point = xy[sortperm(xy[:, 1])[1], :] | ||
hull = typeof(cur_point)[] | ||
while true | ||
push!(hull, cur_point) | ||
end_point = xy[1, :] | ||
for candidate_point in eachrow(xy) | ||
if end_point == cur_point || isleftof(candidate_point, (cur_point, end_point)) | ||
end_point = candidate_point | ||
end | ||
end | ||
cur_point = end_point | ||
if end_point == hull[1] | ||
break | ||
end | ||
end | ||
return hull | ||
end | ||
|
||
function isleftof(point, line) | ||
(start_point, end_point) = line | ||
return (end_point[1] - start_point[1]) * (point[2] - start_point[2]) - | ||
(end_point[2] - start_point[2]) * (point[1] - start_point[1]) > 0 | ||
end | ||
|
||
"""Computes the area within a convex hull.""" | ||
function convex_hull_area(hull) | ||
area = 0.0 | ||
for i in eachindex(hull) | ||
j = i == lastindex(hull) ? firstindex(hull) : nextind(hull, i) | ||
area += (hull[i][1] * hull[j][2] - hull[j][1] * hull[i][2]) | ||
end | ||
return abs(area) / 2 | ||
end | ||
|
||
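# Quick sanity check (illustrative, not part of this diff): the hull of the
# unit square recovers an area of 1 with the two helpers above:
#
#     square = [0.0 0.0; 1.0 0.0; 1.0 1.0; 0.0 1.0]
#     convex_hull_area(convex_hull(square))  # ≈ 1.0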
end
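For context, a minimal end-to-end sketch of how this module is intended to be used with TensorBoard, assuming `SRLogger` is exported from SymbolicRegression and that the search entry point accepts a `logger` keyword (neither is shown in this file):

    using SymbolicRegression
    using TensorBoardLogger: TBLogger

    X = randn(Float32, 3, 100)
    y = 2 .* cos.(X[2, :]) .+ X[1, :] .^ 2

    # Wrap a TensorBoard logger and record every other logging step.
    logger = SRLogger(TBLogger("logs/sr_run"); log_interval=2)
    options = Options(; binary_operators=[+, -, *], unary_operators=[cos])

    # The `logger` keyword here is an assumption about the search API.
    hall_of_fame = equation_search(X, y; niterations=10, options, logger)

Each payload is emitted via `@info("search", data = data)` above, which TensorBoardLogger renders as nested scalar tags under the `search` prefix.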