Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Move NonlinearSolvePolyAlgorithm to Base #494

Merged
merged 5 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ NLSolvers = "0.5"
NLsolve = "4.5"
NaNMath = "1"
NonlinearProblemLibrary = "0.1.2"
NonlinearSolveBase = "1.2"
NonlinearSolveBase = "1.3"
NonlinearSolveFirstOrder = "1"
NonlinearSolveQuasiNewton = "1"
NonlinearSolveSpectralMethods = "1"
Expand Down
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NonlinearSolveBase"
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.0"
version = "1.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 3 additions & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ include("linear_solve.jl")
include("timer_outputs.jl")
include("tracing.jl")
include("wrappers.jl")
include("polyalg.jl")

include("descent/common.jl")
include("descent/newton.jl")
Expand Down Expand Up @@ -81,4 +82,6 @@ export RelTerminationMode, AbsTerminationMode,
export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogleg,
GeodesicAcceleration

export NonlinearSolvePolyAlgorithm

end
202 changes: 202 additions & 0 deletions lib/NonlinearSolveBase/src/polyalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
"""
NonlinearSolvePolyAlgorithm(algs; start_index::Int = 1)

A general way to define PolyAlgorithms for `NonlinearProblem` and
`NonlinearLeastSquaresProblem`. This is a container for a tuple of algorithms that will be
tried in order until one succeeds. If none succeed, then the algorithm with the lowest
residual is returned.

### Arguments

- `algs`: a tuple of algorithms to try in-order! (If this is not a Tuple, then the
returned algorithm is not type-stable).

### Keyword Arguments

- `start_index`: the index to start at. Defaults to `1`.

### Example

```julia
using NonlinearSolve

alg = NonlinearSolvePolyAlgorithm((NewtonRaphson(), Broyden()))
```
"""
@concrete struct NonlinearSolvePolyAlgorithm <: AbstractNonlinearSolveAlgorithm
static_length <: Val
algs <: Tuple
start_index::Int
end

function NonlinearSolvePolyAlgorithm(algs; start_index::Int = 1)
@assert 0 < start_index ≤ length(algs)
algs = Tuple(algs)
return NonlinearSolvePolyAlgorithm(Val(length(algs)), algs, start_index)
end

@concrete mutable struct NonlinearSolvePolyAlgorithmCache <: AbstractNonlinearSolveCache
static_length <: Val
prob <: AbstractNonlinearProblem

caches <: Tuple
alg <: NonlinearSolvePolyAlgorithm

best::Int
current::Int
nsteps::Int

stats::NLStats
total_time::Float64
maxtime

retcode::ReturnCode.T
force_stop::Bool

maxiters::Int
internalnorm

u0
u0_aliased
alias_u0::Bool
end

function SII.symbolic_container(cache::NonlinearSolvePolyAlgorithmCache)
return cache.caches[cache.current]
end
SII.state_values(cache::NonlinearSolvePolyAlgorithmCache) = cache.u0

function Base.show(io::IO, ::MIME"text/plain", cache::NonlinearSolvePolyAlgorithmCache)
println(io, "NonlinearSolvePolyAlgorithmCache with \
$(Utils.unwrap_val(cache.static_length)) algorithms:")
best_alg = ifelse(cache.best == -1, "nothing", cache.best)
println(io, " Best Algorithm: $(best_alg)")
println(
io, " Current Algorithm: [$(cache.current) / $(Utils.unwrap_val(cache.static_length))]"
)
println(io, " nsteps: $(cache.nsteps)")
println(io, " retcode: $(cache.retcode)")
print(io, " Current Cache: ")
NonlinearSolveBase.show_nonlinearsolve_cache(io, cache.caches[cache.current], 4)
end

function InternalAPI.reinit!(
cache::NonlinearSolvePolyAlgorithmCache, args...; p = cache.p, u0 = cache.u0
)
foreach(cache.caches) do cache
InternalAPI.reinit!(cache, args...; p, u0)
end
cache.current = cache.alg.start_index
InternalAPI.reinit!(cache.stats)
cache.nsteps = 0
cache.total_time = 0.0
end

function SciMLBase.__init(
prob::AbstractNonlinearProblem, alg::NonlinearSolvePolyAlgorithm, args...;
stats = NLStats(0, 0, 0, 0, 0), maxtime = nothing, maxiters = 1000,
internalnorm = L2_NORM, alias_u0 = false, verbose = true, kwargs...
)
if alias_u0 && !ArrayInterface.ismutable(prob.u0)
verbose && @warn "`alias_u0` has been set to `true`, but `u0` is \
immutable (checked using `ArrayInterface.ismutable`)."
alias_u0 = false # If immutable don't care about aliasing
end

u0 = prob.u0
u0_aliased = alias_u0 ? copy(u0) : u0
alias_u0 && (prob = SciMLBase.remake(prob; u0 = u0_aliased))

return NonlinearSolvePolyAlgorithmCache(
alg.static_length, prob,
map(alg.algs) do solver
SciMLBase.__init(
prob, solver, args...;
stats, maxtime, internalnorm, alias_u0, verbose, kwargs...
)
end,
alg, -1, alg.start_index, 0, stats, 0.0, maxtime,
ReturnCode.Default, false, maxiters, internalnorm,
u0, u0_aliased, alias_u0
)
end

@generated function InternalAPI.step!(
cache::NonlinearSolvePolyAlgorithmCache{Val{N}}, args...; kwargs...
) where {N}
calls = []
cache_syms = [gensym("cache") for i in 1:N]
for i in 1:N
push!(calls,
quote
$(cache_syms[i]) = cache.caches[$(i)]
if $(i) == cache.current
InternalAPI.step!($(cache_syms[i]), args...; kwargs...)
$(cache_syms[i]).nsteps += 1
if !NonlinearSolveBase.not_terminated($(cache_syms[i]))
if SciMLBase.successful_retcode($(cache_syms[i]).retcode)
cache.best = $(i)
cache.force_stop = true
cache.retcode = $(cache_syms[i]).retcode
else
cache.current = $(i + 1)
end
end
return
end
end)
end

push!(calls, quote
if !(1 ≤ cache.current ≤ length(cache.caches))
minfu, idx = findmin_caches(cache.prob, cache.caches)
cache.best = idx
cache.retcode = cache.caches[idx].retcode
cache.force_stop = true
return
end
end)

return Expr(:block, calls...)
end

# Original is often determined on runtime information especially for PolyAlgorithms so it
# is best to never specialize on that
function build_solution_less_specialize(
prob::AbstractNonlinearProblem, alg, u, resid;
retcode = ReturnCode.Default, original = nothing, left = nothing,
right = nothing, stats = nothing, trace = nothing, kwargs...
)
return SciMLBase.NonlinearSolution{
eltype(eltype(u)), ndims(u), typeof(u), typeof(resid), typeof(prob),
typeof(alg), Any, typeof(left), typeof(stats), typeof(trace)
}(
u, resid, prob, alg, retcode, original, left, right, stats, trace
)
end

function findmin_caches(prob::AbstractNonlinearProblem, caches)
resids = map(caches) do cache
cache === nothing && return nothing
return NonlinearSolveBase.get_fu(cache)
end
return findmin_resids(prob, resids)
end

@views function findmin_resids(prob::AbstractNonlinearProblem, caches)
norm_fn = prob isa NonlinearLeastSquaresProblem ? Base.Fix2(norm, 2) :
Base.Fix2(norm, Inf)
idx = findfirst(Base.Fix2(!==, nothing), caches)
# This is an internal function so we assume that inputs are consistent and there is
# atleast one non-`nothing` value
fx_idx = norm_fn(caches[idx])
idx == length(caches) && return fx_idx, idx
fmin = @closure xᵢ -> begin
xᵢ === nothing && return oftype(fx_idx, Inf)
fx = norm_fn(xᵢ)
return ifelse(isnan(fx), oftype(fx, Inf), fx)
end
x_min, x_min_idx = findmin(fmin, caches[(idx + 1):length(caches)])
x_min < fx_idx && return x_min, x_min_idx + idx
return fx_idx, idx
end
Loading
Loading