Skip to content

Commit

Permalink
materialize the multi-step scheme
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 13, 2024
1 parent 2d8a732 commit de1677e
Show file tree
Hide file tree
Showing 8 changed files with 67 additions and 27 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "3.6.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down Expand Up @@ -55,12 +56,13 @@ NonlinearSolveZygoteExt = "Zygote"

[compat]
ADTypes = "0.2.6"
Accessors = "0.1"
Aqua = "0.8"
ArrayInterface = "7.7"
BandedMatrices = "1.4"
BenchmarkTools = "1.4"
ConcreteStructs = "0.2.3"
CUDA = "5.1"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.146.0"
Enzyme = "0.11.11"
FastBroadcast = "0.2.8"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/basics/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ differentiate the function based on the input types. However, this function has
`xx = [1.0, 2.0, 3.0, 4.0]` followed by a `xx[1] = var[1] - v_true[1]` where `var` might
be a Dual number. This causes the error. To fix it:

1. Specify the `autodiff` to be `AutoFiniteDiff`
1. Specify the `autodiff` to be `AutoFiniteDiff`

```@example dual_error_faq
sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiters = 10000,
Expand All @@ -81,7 +81,7 @@ sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiter

This worked but, Finite Differencing is not the recommended approach in any scenario.

2. Rewrite the function to use
2. Rewrite the function to use
[PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl) or write it as

```@example dual_error_faq
Expand Down
4 changes: 2 additions & 2 deletions docs/src/basics/sparsity_detection.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ prob = NonlinearProblem(
If the `colorvec` is not provided, then it is computed on demand.

!!! note

One thing to be careful about in this case is that `colorvec` is dependent on the
autodiff backend used. Forward Mode and Finite Differencing will assume that the
colorvec is the column colorvec, while Reverse Mode will assume that the colorvec is the
Expand Down Expand Up @@ -76,7 +76,7 @@ loaded, we default to using `SymbolicsSparsityDetection()`, else we default to u
options if those are provided.

!!! warning

If you provide a non-sparse AD, and provide a `sparsity` or `jac_prototype` then
we will use dense AD. This is because, if you provide a specific AD type, we assume
that you know what you are doing and want to override the default choice of `nothing`.
18 changes: 9 additions & 9 deletions docs/src/tutorials/large_systems.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

This tutorial is for getting into the extra features of using NonlinearSolve.jl. Solving
ill-conditioned nonlinear systems requires specializing the linear solver on properties of
the Jacobian in order to cut down on the ``\mathcal{O}(n^3)`` linear solve and the
``\mathcal{O}(n^2)`` back-solves. This tutorial is designed to explain the advanced usage of
the Jacobian in order to cut down on the `\mathcal{O}(n^3)` linear solve and the
`\mathcal{O}(n^2)` back-solves. This tutorial is designed to explain the advanced usage of
NonlinearSolve.jl by solving the steady state stiff Brusselator partial differential
equation (BRUSS) using NonlinearSolve.jl.

## Definition of the Brusselator Equation

!!! note

Feel free to skip this section: it simply defines the example problem.

The Brusselator PDE is defined as follows:
Expand Down Expand Up @@ -118,11 +118,11 @@ However, if you know the sparsity of your problem, then you can pass a different
type. For example, a `SparseMatrixCSC` will give a sparse matrix. Other sparse matrix types
include:

- Bidiagonal
- Tridiagonal
- SymTridiagonal
- BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl))
- BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl))
- Bidiagonal
- Tridiagonal
- SymTridiagonal
- BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl))
- BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl))

## Approximate Sparsity Detection & Sparse Jacobians

Expand Down Expand Up @@ -213,7 +213,7 @@ choices, see the
`linsolve` choices are any valid [LinearSolve.jl](https://linearsolve.sciml.ai/dev/) solver.

!!! note

Switching to a Krylov linear solver will automatically change the nonlinear problem
solver into Jacobian-free mode, dramatically reducing the memory required. This can be
overridden by adding `concrete_jac=true` to the algorithm.
Expand Down
10 changes: 5 additions & 5 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import Reexport: @reexport
import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload

@recompile_invalidations begin
using ADTypes, ConcreteStructs, DiffEqBase, FastBroadcast, FastClosures, LazyArrays,
LineSearches, LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf,
SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools
using Accessors, ADTypes, ConcreteStructs, DiffEqBase, FastBroadcast, FastClosures,
LazyArrays, LineSearches, LinearAlgebra, LinearSolve, MaybeInplace, Preferences,
Printf, SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools

import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing
import DiffEqBase: AbstractNonlinearTerminationMode,
Expand Down Expand Up @@ -142,7 +142,7 @@ end

# Core Algorithms
export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane,
MultiStepNonlinearSolver
MultiStepNonlinearSolver
export GaussNewton, LevenbergMarquardt, TrustRegion
export NonlinearSolvePolyAlgorithm,
RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg
Expand All @@ -156,7 +156,7 @@ export GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, Genera

# Descent Algorithms
export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent,
GeodesicAcceleration, GenericMultiStepDescent
GeodesicAcceleration, GenericMultiStepDescent
## Multistep Algorithms
export MultiStepSchemes

Expand Down
9 changes: 5 additions & 4 deletions src/algorithms/multistep.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing,

Check warning on line 1 in src/algorithms/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/multistep.jl#L1

Added line #L1 was not covered by tests
scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing)
descent = GenericMultiStepDescent(; scheme, linsolve, precs)
# TODO: Use the scheme as the name
return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = :MultiStepNonlinearSolver,
scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing,
vjp_autodiff = nothing)
scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff))
descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs)
return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme),

Check warning on line 6 in src/algorithms/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/multistep.jl#L4-L6

Added lines #L4 - L6 were not covered by tests
descent, jacobian_ad = autodiff)
end
26 changes: 22 additions & 4 deletions src/descent/multistep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,47 @@ typically the last names of the authors of the paper that introduced the method.
"""
module MultiStepSchemes

using ConcreteStructs

abstract type AbstractMultiStepScheme end

function Base.show(io::IO, mss::AbstractMultiStepScheme)
print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])")

Check warning on line 15 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L14-L15

Added lines #L14 - L15 were not covered by tests
end

alg_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = alg_steps(T())

Check warning on line 18 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L18

Added line #L18 was not covered by tests

struct __PotraPtak3 <: AbstractMultiStepScheme end
const PotraPtak3 = __PotraPtak3()

alg_steps(::__PotraPtak3) = 1
alg_steps(::__PotraPtak3) = 2

Check warning on line 23 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L23

Added line #L23 was not covered by tests

struct __SinghSharma4 <: AbstractMultiStepScheme end
@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme
vjp_autodiff = nothing
end
const SinghSharma4 = __SinghSharma4()

alg_steps(::__SinghSharma4) = 3

Check warning on line 30 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L30

Added line #L30 was not covered by tests

struct __SinghSharma5 <: AbstractMultiStepScheme end
@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme
vjp_autodiff = nothing
end
const SinghSharma5 = __SinghSharma5()

alg_steps(::__SinghSharma5) = 3

Check warning on line 37 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L37

Added line #L37 was not covered by tests

struct __SinghSharma7 <: AbstractMultiStepScheme end
@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme
vjp_autodiff = nothing
end
const SinghSharma7 = __SinghSharma7()

alg_steps(::__SinghSharma7) = 4

Check warning on line 44 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L44

Added line #L44 was not covered by tests

@generated function display_name(alg::T) where {T <: AbstractMultiStepScheme}
res = Symbol(first(split(last(split(string(T), ".")), "{"; limit = 2))[3:end])
return :($(Meta.quot(res)))

Check warning on line 48 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L46-L48

Added lines #L46 - L48 were not covered by tests
end

end

const MSS = MultiStepSchemes
Expand All @@ -43,6 +58,8 @@ const MSS = MultiStepSchemes
precs = DEFAULT_PRECS
end

Base.show(io::IO, alg::GenericMultiStepDescent) = print(io, "$(alg.scheme)()")

Check warning on line 61 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L61

Added line #L61 was not covered by tests

supports_line_search(::GenericMultiStepDescent) = false
supports_trust_region(::GenericMultiStepDescent) = false

Check warning on line 64 in src/descent/multistep.jl

View check run for this annotation

Codecov / codecov/patch

src/descent/multistep.jl#L63-L64

Added lines #L63 - L64 were not covered by tests

Expand All @@ -51,6 +68,7 @@ supports_trust_region(::GenericMultiStepDescent) = false
p
δu
δus
extras
scheme::S
lincache
timer
Expand Down
19 changes: 19 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,22 @@ Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the i
"""
@inline pickchunksize(x) = pickchunksize(length(x))
@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)

"""
apply_patch(scheme, patch::NamedTuple{names})
Applies the patch to the scheme, returning the new scheme. If some of the `names` are not,
present in the scheme, they are ignored.
"""
@generated function apply_patch(scheme, patch::NamedTuple{names}) where {names}
exprs = []
for name in names
hasfield(scheme, name) || continue
push!(exprs, quote
lens = PropertyLens{$(Meta.quot(name))}()
return set(scheme, lens, getfield(patch, $(Meta.quot(name))))

Check warning on line 174 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L168-L174

Added lines #L168 - L174 were not covered by tests
end)
end
push!(exprs, :(return scheme))
return Expr(:block, exprs...)

Check warning on line 178 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L176-L178

Added lines #L176 - L178 were not covered by tests
end

0 comments on commit de1677e

Please sign in to comment.