From de1677ecf2e5732c2d01f8f6bd80784b112ef3e9 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 13 Feb 2024 17:37:42 -0500 Subject: [PATCH] materialize the multi-step scheme --- Project.toml | 4 +++- docs/src/basics/faq.md | 4 ++-- docs/src/basics/sparsity_detection.md | 4 ++-- docs/src/tutorials/large_systems.md | 18 +++++++++--------- src/NonlinearSolve.jl | 10 +++++----- src/algorithms/multistep.jl | 9 +++++---- src/descent/multistep.jl | 26 ++++++++++++++++++++++---- src/utils.jl | 19 +++++++++++++++++++ 8 files changed, 67 insertions(+), 27 deletions(-) diff --git a/Project.toml b/Project.toml index 5c1501bb0..403831c1d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "3.6.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" @@ -55,12 +56,13 @@ NonlinearSolveZygoteExt = "Zygote" [compat] ADTypes = "0.2.6" +Accessors = "0.1" Aqua = "0.8" ArrayInterface = "7.7" BandedMatrices = "1.4" BenchmarkTools = "1.4" -ConcreteStructs = "0.2.3" CUDA = "5.1" +ConcreteStructs = "0.2.3" DiffEqBase = "6.146.0" Enzyme = "0.11.11" FastBroadcast = "0.2.8" diff --git a/docs/src/basics/faq.md b/docs/src/basics/faq.md index e40b57b33..2144fcba4 100644 --- a/docs/src/basics/faq.md +++ b/docs/src/basics/faq.md @@ -72,7 +72,7 @@ differentiate the function based on the input types. However, this function has `xx = [1.0, 2.0, 3.0, 4.0]` followed by a `xx[1] = var[1] - v_true[1]` where `var` might be a Dual number. This causes the error. To fix it: - 1. Specify the `autodiff` to be `AutoFiniteDiff` +1. Specify the `autodiff` to be `AutoFiniteDiff` ```@example dual_error_faq sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiters = 10000, @@ -81,7 +81,7 @@ sol = solve(prob_oop, LevenbergMarquardt(; autodiff = AutoFiniteDiff()); maxiter This worked but, Finite Differencing is not the recommended approach in any scenario. - 2. Rewrite the function to use +2. Rewrite the function to use [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl) or write it as ```@example dual_error_faq diff --git a/docs/src/basics/sparsity_detection.md b/docs/src/basics/sparsity_detection.md index 222aebe19..e23d42dab 100644 --- a/docs/src/basics/sparsity_detection.md +++ b/docs/src/basics/sparsity_detection.md @@ -34,7 +34,7 @@ prob = NonlinearProblem( If the `colorvec` is not provided, then it is computed on demand. !!! note - + One thing to be careful about in this case is that `colorvec` is dependent on the autodiff backend used. Forward Mode and Finite Differencing will assume that the colorvec is the column colorvec, while Reverse Mode will assume that the colorvec is the @@ -76,7 +76,7 @@ loaded, we default to using `SymbolicsSparsityDetection()`, else we default to u options if those are provided. !!! warning - + If you provide a non-sparse AD, and provide a `sparsity` or `jac_prototype` then we will use dense AD. This is because, if you provide a specific AD type, we assume that you know what you are doing and want to override the default choice of `nothing`. diff --git a/docs/src/tutorials/large_systems.md b/docs/src/tutorials/large_systems.md index aedd58445..d6f3c96fb 100644 --- a/docs/src/tutorials/large_systems.md +++ b/docs/src/tutorials/large_systems.md @@ -2,15 +2,15 @@ This tutorial is for getting into the extra features of using NonlinearSolve.jl. Solving ill-conditioned nonlinear systems requires specializing the linear solver on properties of -the Jacobian in order to cut down on the ``\mathcal{O}(n^3)`` linear solve and the -``\mathcal{O}(n^2)`` back-solves. This tutorial is designed to explain the advanced usage of +the Jacobian in order to cut down on the `\mathcal{O}(n^3)` linear solve and the +`\mathcal{O}(n^2)` back-solves. This tutorial is designed to explain the advanced usage of NonlinearSolve.jl by solving the steady state stiff Brusselator partial differential equation (BRUSS) using NonlinearSolve.jl. ## Definition of the Brusselator Equation !!! note - + Feel free to skip this section: it simply defines the example problem. The Brusselator PDE is defined as follows: @@ -118,11 +118,11 @@ However, if you know the sparsity of your problem, then you can pass a different type. For example, a `SparseMatrixCSC` will give a sparse matrix. Other sparse matrix types include: - - Bidiagonal - - Tridiagonal - - SymTridiagonal - - BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl)) - - BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl)) +- Bidiagonal +- Tridiagonal +- SymTridiagonal +- BandedMatrix ([BandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BandedMatrices.jl)) +- BlockBandedMatrix ([BlockBandedMatrices.jl](https://github.com/JuliaLinearAlgebra/BlockBandedMatrices.jl)) ## Approximate Sparsity Detection & Sparse Jacobians @@ -213,7 +213,7 @@ choices, see the `linsolve` choices are any valid [LinearSolve.jl](https://linearsolve.sciml.ai/dev/) solver. !!! note - + Switching to a Krylov linear solver will automatically change the nonlinear problem solver into Jacobian-free mode, dramatically reducing the memory required. This can be overridden by adding `concrete_jac=true` to the algorithm. diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index bd39e63d0..784b00b76 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -8,9 +8,9 @@ import Reexport: @reexport import PrecompileTools: @recompile_invalidations, @compile_workload, @setup_workload @recompile_invalidations begin - using ADTypes, ConcreteStructs, DiffEqBase, FastBroadcast, FastClosures, LazyArrays, - LineSearches, LinearAlgebra, LinearSolve, MaybeInplace, Preferences, Printf, - SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools + using Accessors, ADTypes, ConcreteStructs, DiffEqBase, FastBroadcast, FastClosures, + LazyArrays, LineSearches, LinearAlgebra, LinearSolve, MaybeInplace, Preferences, + Printf, SciMLBase, SimpleNonlinearSolve, SparseArrays, SparseDiffTools import ArrayInterface: undefmatrix, can_setindex, restructure, fast_scalar_indexing import DiffEqBase: AbstractNonlinearTerminationMode, @@ -142,7 +142,7 @@ end # Core Algorithms export NewtonRaphson, PseudoTransient, Klement, Broyden, LimitedMemoryBroyden, DFSane, - MultiStepNonlinearSolver + MultiStepNonlinearSolver export GaussNewton, LevenbergMarquardt, TrustRegion export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg @@ -156,7 +156,7 @@ export GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, Genera # Descent Algorithms export NewtonDescent, SteepestDescent, Dogleg, DampedNewtonDescent, - GeodesicAcceleration, GenericMultiStepDescent + GeodesicAcceleration, GenericMultiStepDescent ## Multistep Algorithms export MultiStepSchemes diff --git a/src/algorithms/multistep.jl b/src/algorithms/multistep.jl index 35b204094..d1f087fe3 100644 --- a/src/algorithms/multistep.jl +++ b/src/algorithms/multistep.jl @@ -1,7 +1,8 @@ function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing, - scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing) - descent = GenericMultiStepDescent(; scheme, linsolve, precs) - # TODO: Use the scheme as the name - return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = :MultiStepNonlinearSolver, + scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing, + vjp_autodiff = nothing) + scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff)) + descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs) + return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme), descent, jacobian_ad = autodiff) end diff --git a/src/descent/multistep.jl b/src/descent/multistep.jl index 2879a9bef..e92653eb8 100644 --- a/src/descent/multistep.jl +++ b/src/descent/multistep.jl @@ -7,32 +7,47 @@ typically the last names of the authors of the paper that introduced the method. """ module MultiStepSchemes +using ConcreteStructs + abstract type AbstractMultiStepScheme end function Base.show(io::IO, mss::AbstractMultiStepScheme) print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])") end +alg_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = alg_steps(T()) + struct __PotraPtak3 <: AbstractMultiStepScheme end const PotraPtak3 = __PotraPtak3() -alg_steps(::__PotraPtak3) = 1 +alg_steps(::__PotraPtak3) = 2 -struct __SinghSharma4 <: AbstractMultiStepScheme end +@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme + vjp_autodiff = nothing +end const SinghSharma4 = __SinghSharma4() alg_steps(::__SinghSharma4) = 3 -struct __SinghSharma5 <: AbstractMultiStepScheme end +@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme + vjp_autodiff = nothing +end const SinghSharma5 = __SinghSharma5() alg_steps(::__SinghSharma5) = 3 -struct __SinghSharma7 <: AbstractMultiStepScheme end +@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme + vjp_autodiff = nothing +end const SinghSharma7 = __SinghSharma7() alg_steps(::__SinghSharma7) = 4 +@generated function display_name(alg::T) where {T <: AbstractMultiStepScheme} + res = Symbol(first(split(last(split(string(T), ".")), "{"; limit = 2))[3:end]) + return :($(Meta.quot(res))) +end + end const MSS = MultiStepSchemes @@ -43,6 +58,8 @@ const MSS = MultiStepSchemes precs = DEFAULT_PRECS end +Base.show(io::IO, alg::GenericMultiStepDescent) = print(io, "$(alg.scheme)()") + supports_line_search(::GenericMultiStepDescent) = false supports_trust_region(::GenericMultiStepDescent) = false @@ -51,6 +68,7 @@ supports_trust_region(::GenericMultiStepDescent) = false p δu δus + extras scheme::S lincache timer diff --git a/src/utils.jl b/src/utils.jl index 7f4c2c439..e5595ea0d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -158,3 +158,22 @@ Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the i """ @inline pickchunksize(x) = pickchunksize(length(x)) @inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x) + +""" + apply_patch(scheme, patch::NamedTuple{names}) + +Applies the patch to the scheme, returning the new scheme. If some of the `names` are not, +present in the scheme, they are ignored. +""" +@generated function apply_patch(scheme, patch::NamedTuple{names}) where {names} + exprs = [] + for name in names + hasfield(scheme, name) || continue + push!(exprs, quote + lens = PropertyLens{$(Meta.quot(name))}() + return set(scheme, lens, getfield(patch, $(Meta.quot(name)))) + end) + end + push!(exprs, :(return scheme)) + return Expr(:block, exprs...) +end