diff --git a/src/algorithms/multistep.jl b/src/algorithms/multistep.jl index abb056402..cd5a31890 100644 --- a/src/algorithms/multistep.jl +++ b/src/algorithms/multistep.jl @@ -1,8 +1,10 @@ function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing, scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing, vjp_autodiff = nothing, linesearch = NoLineSearch()) - scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff)) + forward_ad = ifelse(autodiff isa ADTypes.AbstractForwardMode, autodiff, nothing) + scheme_concrete = apply_patch( + scheme, (; autodiff, vjp_autodiff, jvp_autodiff = forward_ad)) descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs) return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme), - descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff) + descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff, forward_ad) end diff --git a/src/descent/damped_newton.jl b/src/descent/damped_newton.jl index a00b480f8..cee437d7e 100644 --- a/src/descent/damped_newton.jl +++ b/src/descent/damped_newton.jl @@ -58,11 +58,7 @@ function __internal_init( shared::Val{N} = Val(1), kwargs...) where {INV, N} length(fu) != length(u) && @assert !INV "Precomputed Inverse for Non-Square Jacobian doesn't make sense." - @bb δu = similar(u) - δus = N ≤ 1 ? nothing : map(2:N) do i - @bb δu_ = similar(u) - end - + δu, δus = @shared_caches N (@bb δu = similar(u)) normal_form_damping = returns_norm_form_damping(alg.damping_fn) normal_form_linsolve = __needs_square_A(alg.linsolve, u) if u isa Number diff --git a/src/descent/dogleg.jl b/src/descent/dogleg.jl index 772f06295..ca7314760 100644 --- a/src/descent/dogleg.jl +++ b/src/descent/dogleg.jl @@ -56,10 +56,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::Dogleg, J, fu, u; linsolve_kwargs, abstol, reltol, shared, kwargs...) cauchy_cache = __internal_init(prob, alg.steepest_descent, J, fu, u; pre_inverted, linsolve_kwargs, abstol, reltol, shared, kwargs...) - @bb δu = similar(u) - δus = N ≤ 1 ? nothing : map(2:N) do i - @bb δu_ = similar(u) - end + δu, δus = @shared_caches N (@bb δu = similar(u)) @bb δu_cache_1 = similar(u) @bb δu_cache_2 = similar(u) @bb δu_cache_mul = similar(u) diff --git a/src/descent/geodesic_acceleration.jl b/src/descent/geodesic_acceleration.jl index 76033da0f..a989c0376 100644 --- a/src/descent/geodesic_acceleration.jl +++ b/src/descent/geodesic_acceleration.jl @@ -89,10 +89,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::GeodesicAccelerati abstol = nothing, reltol = nothing, internalnorm::F = DEFAULT_NORM, kwargs...) where {INV, N, F} T = promote_type(eltype(u), eltype(fu)) - @bb δu = similar(u) - δus = N ≤ 1 ? nothing : map(2:N) do i - @bb δu_ = similar(u) - end + δu, δus = @shared_caches N (@bb δu = similar(u)) descent_cache = __internal_init(prob, alg.descent, J, fu, u; shared = Val(N * 2), pre_inverted, linsolve_kwargs, abstol, reltol, kwargs...) @bb Jv = similar(fu) diff --git a/src/descent/multistep.jl b/src/descent/multistep.jl index 67c756a2c..eae086493 100644 --- a/src/descent/multistep.jl +++ b/src/descent/multistep.jl @@ -15,12 +15,12 @@ function Base.show(io::IO, mss::AbstractMultiStepScheme) print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])") end -alg_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = alg_steps(T()) +newton_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = newton_steps(T()) struct __PotraPtak3 <: AbstractMultiStepScheme end const PotraPtak3 = __PotraPtak3() -alg_steps(::__PotraPtak3) = 2 +newton_steps(::__PotraPtak3) = 2 nintermediates(::__PotraPtak3) = 1 @kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme @@ -28,21 +28,23 @@ nintermediates(::__PotraPtak3) = 1 end const SinghSharma4 = __SinghSharma4() -alg_steps(::__SinghSharma4) = 3 +newton_steps(::__SinghSharma4) = 4 +nintermediates(::__SinghSharma4) = 2 @kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme jvp_autodiff = nothing end const SinghSharma5 = __SinghSharma5() -alg_steps(::__SinghSharma5) = 3 +newton_steps(::__SinghSharma5) = 4 +nintermediates(::__SinghSharma5) = 2 @kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme jvp_autodiff = nothing end const SinghSharma7 = __SinghSharma7() -alg_steps(::__SinghSharma7) = 4 +newton_steps(::__SinghSharma7) = 6 @generated function display_name(alg::T) where {T <: AbstractMultiStepScheme} res = Symbol(first(split(last(split(string(T), ".")), "{"; limit = 2))[3:end]) @@ -75,6 +77,8 @@ supports_trust_region(::GenericMultiStepDescent) = false fus internal_cache internal_caches + extra + extras scheme::S timer nf::Int @@ -91,49 +95,37 @@ function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = ca end function __internal_multistep_caches( - scheme::MSS.__PotraPtak3, alg::GenericMultiStepDescent, - prob, args...; shared::Val{N} = Val(1), kwargs...) where {N} + scheme::Union{MSS.__PotraPtak3, MSS.__SinghSharma4, MSS.__SinghSharma5}, + alg::GenericMultiStepDescent, prob, args...; + shared::Val{N} = Val(1), kwargs...) where {N} internal_descent = NewtonDescent(; alg.linsolve, alg.precs) - internal_cache = __internal_init( + return @shared_caches N __internal_init( prob, internal_descent, args...; kwargs..., shared = Val(2)) - internal_caches = N ≤ 1 ? nothing : - map(2:N) do i - __internal_init(prob, internal_descent, args...; kwargs..., shared = Val(2)) - end - return internal_cache, internal_caches end +__extras_cache(::MSS.AbstractMultiStepScheme, args...; kwargs...) = nothing, nothing + function __internal_init(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem}, alg::GenericMultiStepDescent, J, fu, u; shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;), abstol = nothing, reltol = nothing, timer = get_timer_output(), kwargs...) where {INV, N} - @bb δu = similar(u) - δus = N ≤ 1 ? nothing : map(2:N) do i - @bb δu_ = similar(u) - end - fu_cache = ntuple(MSS.nintermediates(alg.scheme)) do i + δu, δus = @shared_caches N (@bb δu = similar(u)) + fu_cache, fus_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i @bb xx = similar(fu) - end - fus_cache = N ≤ 1 ? nothing : map(2:N) do i - ntuple(MSS.nintermediates(alg.scheme)) do j - @bb xx = similar(fu) - end - end - u_cache = ntuple(MSS.nintermediates(alg.scheme)) do i + end) + u_cache, us_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i @bb xx = similar(u) - end - us_cache = N ≤ 1 ? nothing : map(2:N) do i - ntuple(MSS.nintermediates(alg.scheme)) do j - @bb xx = similar(u) - end - end + end) internal_cache, internal_caches = __internal_multistep_caches( alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs, abstol, reltol, timer, kwargs...) + extra, extras = __extras_cache( + alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs, + abstol, reltol, timer, kwargs...) return GenericMultiStepDescentCache( prob.f, prob.p, δu, δus, u_cache, us_cache, fu_cache, fus_cache, - internal_cache, internal_caches, alg.scheme, timer, 0) + internal_cache, internal_caches, extra, extras, alg.scheme, timer, 0) end function __internal_solve!(cache::GenericMultiStepDescentCache{MSS.__PotraPtak3, INV}, J, diff --git a/src/descent/newton.jl b/src/descent/newton.jl index 26bea6350..52f8e9743 100644 --- a/src/descent/newton.jl +++ b/src/descent/newton.jl @@ -36,10 +36,7 @@ function __internal_init(prob::NonlinearProblem, alg::NewtonDescent, J, fu, u; shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;), abstol = nothing, reltol = nothing, timer = get_timer_output(), kwargs...) where {INV, N} - @bb δu = similar(u) - δus = N ≤ 1 ? nothing : map(2:N) do i - @bb δu_ = similar(u) - end + δu, δus = @shared_caches N (@bb δu = similar(u)) INV && return NewtonDescentCache{true, false}(δu, δus, nothing, nothing, nothing, timer) lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol, linsolve_kwargs...) @@ -64,10 +61,7 @@ function __internal_init(prob::NonlinearLeastSquaresProblem, alg::NewtonDescent, end lincache = LinearSolverCache(alg, alg.linsolve, A, b, _vec(u); abstol, reltol, linsolve_kwargs...) - @bb δu = similar(u) - δus = N ≤ 1 ? nothing : map(2:N) do i - @bb δu_ = similar(u) - end + δu, δus = @shared_caches N (@bb δu = similar(u)) return NewtonDescentCache{false, normal_form}(δu, δus, lincache, JᵀJ, Jᵀfu, timer) end diff --git a/src/descent/steepest.jl b/src/descent/steepest.jl index da7812fa0..9fd7cc9a9 100644 --- a/src/descent/steepest.jl +++ b/src/descent/steepest.jl @@ -34,10 +34,7 @@ end linsolve_kwargs = (;), abstol = nothing, reltol = nothing, timer = get_timer_output(), kwargs...) where {INV, N} INV && @assert length(fu)==length(u) "Non-Square Jacobian Inverse doesn't make sense." - @bb δu = similar(u) - δus = N ≤ 1 ? nothing : map(2:N) do i - @bb δu_ = similar(u) - end + δu, δus = @shared_caches N (@bb δu = similar(u)) if INV lincache = LinearSolverCache(alg, alg.linsolve, transpose(J), _vec(fu), _vec(u); abstol, reltol, linsolve_kwargs...) diff --git a/src/utils.jl b/src/utils.jl index e5595ea0d..5c609fe97 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -177,3 +177,16 @@ present in the scheme, they are ignored. push!(exprs, :(return scheme)) return Expr(:block, exprs...) end + +macro shared_caches(N, expr) + @gensym cache caches + return esc(quote + begin + $(cache) = $(expr) + $(caches) = $(N) ≤ 1 ? nothing : map(2:($(N))) do i + $(expr) + end + ($cache, $caches) + end + end) +end