Skip to content

Commit

Permalink
Use capital letters for compile-time constants
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 24, 2023
1 parent 62ce580 commit 0fb6f40
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 44 deletions.
84 changes: 40 additions & 44 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ import .CoreModule:
erf,
erfc,
atanh_clip
import .UtilsModule: is_anonymous_function, recursive_merge, json3_write
import .UtilsModule: is_anonymous_function, recursive_merge, json3_write, get_base_type
import .ComplexityModule: compute_complexity
import .CheckConstraintsModule: check_constraints
import .AdaptiveParsimonyModule:
Expand Down Expand Up @@ -349,16 +349,16 @@ function equation_search(
runtests::Bool=true,
saved_state=nothing,
return_state::Union{Bool,Nothing}=nothing,
loss_type::Type{Linit}=Nothing,
loss_type::Type{L}=Nothing,
verbosity::Union{Integer,Nothing}=nothing,
progress::Union{Bool,Nothing}=nothing,
X_units::Union{AbstractVector,Nothing}=nothing,
y_units=nothing,
v_dim_out::Val{dim_out}=Val(nothing),
v_dim_out::Val{DIM_OUT}=Val(nothing),
# Deprecated:
multithreaded=nothing,
varMap=nothing,
) where {T<:DATA_TYPE,Linit,dim_out}
) where {T<:DATA_TYPE,L,DIM_OUT}
if multithreaded !== nothing
error(
"`multithreaded` is deprecated. Use the `parallelism` argument instead. " *
Expand All @@ -371,10 +371,6 @@ function equation_search(
@assert length(weights) == length(y)
weights = reshape(weights, size(y))
end
if T <: Complex && loss_type == Nothing
get_base_type(::Type{Complex{BT}}) where {BT} = BT
loss_type = get_base_type(T)
end

datasets = construct_datasets(
X,
Expand All @@ -385,7 +381,7 @@ function equation_search(
y_variable_names,
X_units,
y_units,
loss_type,
(T <: Complex && L === Nothing) ? get_base_type(T) : L,
)

return equation_search(
Expand All @@ -402,7 +398,7 @@ function equation_search(
return_state=return_state,
verbosity=verbosity,
progress=progress,
v_dim_out=Val(dim_out),
v_dim_out=Val(DIM_OUT),
)
end

Expand Down Expand Up @@ -439,8 +435,8 @@ function equation_search(
return_state::Union{Bool,Nothing}=nothing,
verbosity::Union{Int,Nothing}=nothing,
progress::Union{Bool,Nothing}=nothing,
v_dim_out::Val{dim_out}=Val(nothing),
) where {dim_out,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}}
v_dim_out::Val{DIM_OUT}=Val(nothing),
) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}}
v_concurrency, concurrency = if parallelism in (:multithreading, "multithreading")
(Val(:multithreading), :multithreading)
elseif parallelism in (:multiprocessing, "multiprocessing")
Expand Down Expand Up @@ -476,10 +472,10 @@ function equation_search(
options.return_state
end

v_dim_out = if dim_out === nothing
v_dim_out = if DIM_OUT === nothing
length(datasets) > 1 ? Val(2) : Val(1)
else
Val(dim_out)
Val(DIM_OUT)
end
_numprocs::Int = if numprocs === nothing && procs === nothing
4
Expand Down Expand Up @@ -565,8 +561,8 @@ function equation_search(
end

function _equation_search(
::Val{parallelism},
::Val{dim_out},
::Val{PARALLELISM},
::Val{DIM_OUT},
datasets::Vector{D},
niterations::Int,
options::Options,
Expand All @@ -578,8 +574,8 @@ function _equation_search(
saved_state,
verbosity,
progress,
::Val{should_return_state},
) where {T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L},parallelism,should_return_state,dim_out}
::Val{RETURN_STATE},
) where {T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L},PARALLELISM,RETURN_STATE,DIM_OUT}
stdin_reader = watch_stream(stdin)

if options.define_helper_functions
Expand All @@ -588,10 +584,10 @@ function _equation_search(

example_dataset = datasets[1]
nout = size(datasets, 1)
@assert (nout == 1 || dim_out == 2)
@assert (nout == 1 || DIM_OUT == 2)

if runtests
test_option_configuration(parallelism, datasets, saved_state, options)
test_option_configuration(PARALLELISM, datasets, saved_state, options)
test_dataset_configuration(example_dataset, options, verbosity)
end

Expand All @@ -604,9 +600,9 @@ function _equation_search(
end
# Start a population on every process
# Store the population, hall of fame
WorkerOutputType = if parallelism == :serial
WorkerOutputType = if PARALLELISM == :serial
Tuple{Population{T,L},HallOfFame{T,L},RecordType,Float64}
elseif parallelism == :multiprocessing
elseif PARALLELISM == :multiprocessing
Future
else
Task
Expand All @@ -622,7 +618,7 @@ function _equation_search(
# Initialize storage for workers
tasks = [Task[] for j in 1:nout]
# Set up a channel to send finished populations back to head node
channels = if parallelism == :multiprocessing
channels = if PARALLELISM == :multiprocessing
[[RemoteChannel(1) for i in 1:(options.populations)] for j in 1:nout]
else
[[Channel(1) for i in 1:(options.populations)] for j in 1:nout]
Expand All @@ -645,7 +641,7 @@ function _equation_search(
##########################################################################
### Distributed code:
##########################################################################
if parallelism == :multiprocessing
if PARALLELISM == :multiprocessing
(procs, we_created_procs) = configure_workers(;
procs,
numprocs,
Expand Down Expand Up @@ -680,7 +676,7 @@ function _equation_search(

for j in 1:nout, i in 1:(options.populations)
worker_idx = assign_next_worker!(
worker_assignment; out=j, pop=i, parallelism, procs
worker_assignment; out=j, pop=i, parallelism=PARALLELISM, procs
)
saved_pop = load_saved_population(saved_state; out=j, pop=i)

Expand All @@ -698,7 +694,7 @@ function _equation_search(
begin
(copy_pop, HallOfFame(options, T, L), RecordType(), 0.0)
end,
parallelism = parallelism,
parallelism = PARALLELISM,
worker_idx = worker_idx
)
else
Expand All @@ -720,7 +716,7 @@ function _equation_search(
Float64(options.population_size),
)
end,
parallelism = parallelism,
parallelism = PARALLELISM,
worker_idx = worker_idx
)
# This involves population_size evaluations, on the full dataset:
Expand All @@ -740,7 +736,7 @@ function _equation_search(
curmaxsize = curmaxsizes[j]
@recorder record["out$(j)_pop$(i)"] = RecordType()
worker_idx = assign_next_worker!(
worker_assignment; out=j, pop=i, parallelism, procs
worker_assignment; out=j, pop=i, parallelism=PARALLELISM, procs
)

# TODO - why is this needed??
Expand All @@ -749,7 +745,7 @@ function _equation_search(
last_pop = worker_output[j][i]
updated_pop = @sr_spawner(
begin
in_pop = if parallelism in (:multiprocessing, :multithreading)
in_pop = if PARALLELISM in (:multiprocessing, :multithreading)
fetch(last_pop)[1]
else
last_pop[1]
Expand All @@ -766,7 +762,7 @@ function _equation_search(
running_search_statistics=c_rss,
)
end,
parallelism = parallelism,
parallelism = PARALLELISM,
worker_idx = worker_idx
)
worker_output[j][i] = updated_pop
Expand All @@ -789,7 +785,7 @@ function _equation_search(
print_every_n_seconds = 5
equation_speed = Float32[]

if parallelism in (:multiprocessing, :multithreading)
if PARALLELISM in (:multiprocessing, :multithreading)
for j in 1:nout, i in 1:(options.populations)
# Start listening for each population to finish:
t = @async put!(channels[j][i], fetch(worker_output[j][i]))
Expand Down Expand Up @@ -817,14 +813,14 @@ function _equation_search(
j, i = all_idx[kappa]

# Check if error on population:
if parallelism in (:multiprocessing, :multithreading)
if PARALLELISM in (:multiprocessing, :multithreading)
if istaskfailed(tasks[j][i])
fetch(tasks[j][i])
error("Task failed for population")
end
end
# Non-blocking check if a population is ready:
population_ready = if parallelism in (:multiprocessing, :multithreading)
population_ready = if PARALLELISM in (:multiprocessing, :multithreading)
# TODO: Implement type assertions based on parallelism.
isready(channels[j][i])
else
Expand All @@ -837,7 +833,7 @@ function _equation_search(
start_work_monitor!(resource_monitor)
# Take the fetch operation from the channel since its ready
(cur_pop, best_seen, cur_record, cur_num_evals) =
if parallelism in (:multiprocessing, :multithreading)
if PARALLELISM in (:multiprocessing, :multithreading)
take!(channels[j][i])
else
worker_output[j][i]
Expand Down Expand Up @@ -904,7 +900,7 @@ function _equation_search(
break
end
worker_idx = assign_next_worker!(
worker_assignment; out=j, pop=i, parallelism, procs
worker_assignment; out=j, pop=i, parallelis=PARALLELISM, procs
)
iteration = if is_recording(options)
key = "out$(j)_pop$(i)"
Expand All @@ -929,10 +925,10 @@ function _equation_search(
running_search_statistics=c_rss,
)
end,
parallelism = parallelism,
parallelism = PARALLELISM,
worker_idx = worker_idx
)
if parallelism in (:multiprocessing, :multithreading)
if PARALLELISM in (:multiprocessing, :multithreading)
tasks[j][i] = @async put!(channels[j][i], fetch(worker_output[j][i]))
end

Expand All @@ -950,7 +946,7 @@ function _equation_search(
options,
equation_speed,
head_node_occupation,
parallelism,
PARALLELISM,
)
end
end
Expand Down Expand Up @@ -990,7 +986,7 @@ function _equation_search(
total_cycles,
cycles_remaining,
head_node_occupation,
parallelism,
PARALLELISM,
width=options.terminal_width,
)
end
Expand All @@ -1014,9 +1010,9 @@ function _equation_search(
close_reader!(stdin_reader)

# Safely close all processes or threads
if parallelism == :multiprocessing
if PARALLELISM == :multiprocessing
we_created_procs && rmprocs(procs)
elseif parallelism == :multithreading
elseif PARALLELISM == :multithreading
for j in 1:nout, i in 1:(options.populations)
wait(worker_output[j][i])
end
Expand All @@ -1028,10 +1024,10 @@ function _equation_search(

@recorder json3_write(record, options.recorder_file)

if should_return_state
return (returnPops, (dim_out == 1 ? only(hallOfFame) : hallOfFame))
if RETURN_STATE
return (returnPops, (DIM_OUT == 1 ? only(hallOfFame) : hallOfFame))
else
return (dim_out == 1 ? only(hallOfFame) : hallOfFame)
return (DIM_OUT == 1 ? only(hallOfFame) : hallOfFame)
end
end

Expand Down
2 changes: 2 additions & 0 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ recursive_merge(x::AbstractDict...) = merge(recursive_merge, x...)
recursive_merge(x...) = x[end]
recursive_merge() = error("Unexpected input.")

get_base_type(::Type{Complex{BT}}) where {BT} = BT

const subscripts = ('', '', '', '', '', '', '', '', '', '')
function subscriptify(number::Integer)
return join([subscripts[i + 1] for i in reverse(digits(number))])
Expand Down

0 comments on commit 0fb6f40

Please sign in to comment.