Skip to content

Commit

Permalink
Use macro for shared caches
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 14, 2024
1 parent dccc1dd commit 18934af
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 59 deletions.
6 changes: 4 additions & 2 deletions src/algorithms/multistep.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
function MultiStepNonlinearSolver(; concrete_jac = nothing, linsolve = nothing,
scheme = MSS.PotraPtak3, precs = DEFAULT_PRECS, autodiff = nothing,
vjp_autodiff = nothing, linesearch = NoLineSearch())
scheme_concrete = apply_patch(scheme, (; autodiff, vjp_autodiff))
forward_ad = ifelse(autodiff isa ADTypes.AbstractForwardMode, autodiff, nothing)
scheme_concrete = apply_patch(
scheme, (; autodiff, vjp_autodiff, jvp_autodiff = forward_ad))
descent = GenericMultiStepDescent(; scheme = scheme_concrete, linsolve, precs)
return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = MSS.display_name(scheme),
descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff)
descent, jacobian_ad = autodiff, linesearch, reverse_ad = vjp_autodiff, forward_ad)
end
6 changes: 1 addition & 5 deletions src/descent/damped_newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,7 @@ function __internal_init(
shared::Val{N} = Val(1), kwargs...) where {INV, N}
length(fu) != length(u) &&
@assert !INV "Precomputed Inverse for Non-Square Jacobian doesn't make sense."
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end

δu, δus = @shared_caches N (@bb δu = similar(u))
normal_form_damping = returns_norm_form_damping(alg.damping_fn)
normal_form_linsolve = __needs_square_A(alg.linsolve, u)
if u isa Number
Expand Down
5 changes: 1 addition & 4 deletions src/descent/dogleg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::Dogleg, J, fu, u;
linsolve_kwargs, abstol, reltol, shared, kwargs...)
cauchy_cache = __internal_init(prob, alg.steepest_descent, J, fu, u; pre_inverted,
linsolve_kwargs, abstol, reltol, shared, kwargs...)
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
δu, δus = @shared_caches N (@bb δu = similar(u))
@bb δu_cache_1 = similar(u)
@bb δu_cache_2 = similar(u)
@bb δu_cache_mul = similar(u)
Expand Down
5 changes: 1 addition & 4 deletions src/descent/geodesic_acceleration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,7 @@ function __internal_init(prob::AbstractNonlinearProblem, alg::GeodesicAccelerati
abstol = nothing, reltol = nothing, internalnorm::F = DEFAULT_NORM,
kwargs...) where {INV, N, F}
T = promote_type(eltype(u), eltype(fu))
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
δu, δus = @shared_caches N (@bb δu = similar(u))
descent_cache = __internal_init(prob, alg.descent, J, fu, u; shared = Val(N * 2),
pre_inverted, linsolve_kwargs, abstol, reltol, kwargs...)
@bb Jv = similar(fu)
Expand Down
56 changes: 24 additions & 32 deletions src/descent/multistep.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,36 @@ function Base.show(io::IO, mss::AbstractMultiStepScheme)
print(io, "MultiStepSchemes.$(string(nameof(typeof(mss)))[3:end])")
end

alg_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = alg_steps(T())
newton_steps(::Type{T}) where {T <: AbstractMultiStepScheme} = newton_steps(T())

struct __PotraPtak3 <: AbstractMultiStepScheme end
const PotraPtak3 = __PotraPtak3()

alg_steps(::__PotraPtak3) = 2
newton_steps(::__PotraPtak3) = 2
nintermediates(::__PotraPtak3) = 1

@kwdef @concrete struct __SinghSharma4 <: AbstractMultiStepScheme
jvp_autodiff = nothing
end
const SinghSharma4 = __SinghSharma4()

alg_steps(::__SinghSharma4) = 3
newton_steps(::__SinghSharma4) = 4
nintermediates(::__SinghSharma4) = 2

@kwdef @concrete struct __SinghSharma5 <: AbstractMultiStepScheme
jvp_autodiff = nothing
end
const SinghSharma5 = __SinghSharma5()

alg_steps(::__SinghSharma5) = 3
newton_steps(::__SinghSharma5) = 4
nintermediates(::__SinghSharma5) = 2

@kwdef @concrete struct __SinghSharma7 <: AbstractMultiStepScheme
jvp_autodiff = nothing
end
const SinghSharma7 = __SinghSharma7()

alg_steps(::__SinghSharma7) = 4
newton_steps(::__SinghSharma7) = 6

@generated function display_name(alg::T) where {T <: AbstractMultiStepScheme}
res = Symbol(first(split(last(split(string(T), ".")), "{"; limit = 2))[3:end])
Expand Down Expand Up @@ -75,6 +77,8 @@ supports_trust_region(::GenericMultiStepDescent) = false
fus
internal_cache
internal_caches
extra
extras
scheme::S
timer
nf::Int
Expand All @@ -91,49 +95,37 @@ function __reinit_internal!(cache::GenericMultiStepDescentCache, args...; p = ca
end

function __internal_multistep_caches(
scheme::MSS.__PotraPtak3, alg::GenericMultiStepDescent,
prob, args...; shared::Val{N} = Val(1), kwargs...) where {N}
scheme::Union{MSS.__PotraPtak3, MSS.__SinghSharma4, MSS.__SinghSharma5},
alg::GenericMultiStepDescent, prob, args...;
shared::Val{N} = Val(1), kwargs...) where {N}
internal_descent = NewtonDescent(; alg.linsolve, alg.precs)
internal_cache = __internal_init(
return @shared_caches N __internal_init(
prob, internal_descent, args...; kwargs..., shared = Val(2))
internal_caches = N 1 ? nothing :
map(2:N) do i
__internal_init(prob, internal_descent, args...; kwargs..., shared = Val(2))
end
return internal_cache, internal_caches
end

__extras_cache(::MSS.AbstractMultiStepScheme, args...; kwargs...) = nothing, nothing

function __internal_init(prob::Union{NonlinearProblem, NonlinearLeastSquaresProblem},
alg::GenericMultiStepDescent, J, fu, u; shared::Val{N} = Val(1),
pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
abstol = nothing, reltol = nothing, timer = get_timer_output(),
kwargs...) where {INV, N}
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
fu_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
δu, δus = @shared_caches N (@bb δu = similar(u))
fu_cache, fus_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i
@bb xx = similar(fu)
end
fus_cache = N 1 ? nothing : map(2:N) do i
ntuple(MSS.nintermediates(alg.scheme)) do j
@bb xx = similar(fu)
end
end
u_cache = ntuple(MSS.nintermediates(alg.scheme)) do i
end)
u_cache, us_cache = @shared_caches N (ntuple(MSS.nintermediates(alg.scheme)) do i
@bb xx = similar(u)
end
us_cache = N 1 ? nothing : map(2:N) do i
ntuple(MSS.nintermediates(alg.scheme)) do j
@bb xx = similar(u)
end
end
end)
internal_cache, internal_caches = __internal_multistep_caches(
alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
abstol, reltol, timer, kwargs...)
extra, extras = __extras_cache(
alg.scheme, alg, prob, J, fu, u; shared, pre_inverted, linsolve_kwargs,
abstol, reltol, timer, kwargs...)
return GenericMultiStepDescentCache(
prob.f, prob.p, δu, δus, u_cache, us_cache, fu_cache, fus_cache,
internal_cache, internal_caches, alg.scheme, timer, 0)
internal_cache, internal_caches, extra, extras, alg.scheme, timer, 0)
end

function __internal_solve!(cache::GenericMultiStepDescentCache{MSS.__PotraPtak3, INV}, J,
Expand Down
10 changes: 2 additions & 8 deletions src/descent/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ function __internal_init(prob::NonlinearProblem, alg::NewtonDescent, J, fu, u;
shared::Val{N} = Val(1), pre_inverted::Val{INV} = False, linsolve_kwargs = (;),
abstol = nothing, reltol = nothing, timer = get_timer_output(),
kwargs...) where {INV, N}
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
δu, δus = @shared_caches N (@bb δu = similar(u))
INV && return NewtonDescentCache{true, false}(δu, δus, nothing, nothing, nothing, timer)
lincache = LinearSolverCache(alg, alg.linsolve, J, _vec(fu), _vec(u); abstol, reltol,
linsolve_kwargs...)
Expand All @@ -64,10 +61,7 @@ function __internal_init(prob::NonlinearLeastSquaresProblem, alg::NewtonDescent,
end
lincache = LinearSolverCache(alg, alg.linsolve, A, b, _vec(u); abstol, reltol,
linsolve_kwargs...)
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
δu, δus = @shared_caches N (@bb δu = similar(u))
return NewtonDescentCache{false, normal_form}(δu, δus, lincache, JᵀJ, Jᵀfu, timer)
end

Expand Down
5 changes: 1 addition & 4 deletions src/descent/steepest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ end
linsolve_kwargs = (;), abstol = nothing, reltol = nothing,
timer = get_timer_output(), kwargs...) where {INV, N}
INV && @assert length(fu)==length(u) "Non-Square Jacobian Inverse doesn't make sense."
@bb δu = similar(u)
δus = N 1 ? nothing : map(2:N) do i
@bb δu_ = similar(u)
end
δu, δus = @shared_caches N (@bb δu = similar(u))
if INV
lincache = LinearSolverCache(alg, alg.linsolve, transpose(J), _vec(fu), _vec(u);
abstol, reltol, linsolve_kwargs...)
Expand Down
13 changes: 13 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,16 @@ present in the scheme, they are ignored.
push!(exprs, :(return scheme))
return Expr(:block, exprs...)
end

macro shared_caches(N, expr)
@gensym cache caches
return esc(quote
begin
$(cache) = $(expr)
$(caches) = $(N) 1 ? nothing : map(2:($(N))) do i
$(expr)
end
($cache, $caches)
end
end)
end

0 comments on commit 18934af

Please sign in to comment.