Skip to content

Commit

Permalink
fix: dispatch forwarddiff on __init and __solve
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 22, 2024
1 parent 89d76b0 commit f5a06cb
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 23 deletions.
3 changes: 2 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ include("descent/damped_newton.jl")
include("descent/geodesic_acceleration.jl")

include("internal/jacobian.jl")
include("internal/forward_diff.jl")
include("internal/linear_solve.jl")
include("internal/termination.jl")
include("internal/tracing.jl")
Expand All @@ -93,6 +92,8 @@ include("algorithms/levenberg_marquardt.jl")
include("algorithms/trust_region.jl")
include("algorithms/extension_algs.jl")

include("internal/forward_diff.jl") # we need to define after the algorithms

include("utils.jl")
include("default.jl")

Expand Down
58 changes: 36 additions & 22 deletions src/internal/forward_diff.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
# XXX: dispatch on `__solve` & `__init`
function SciMLBase.solve(
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::Union{Nothing, AbstractNonlinearAlgorithm},
args...;
kwargs...) where {T, V, P, iip}
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
const DualNonlinearProblem = NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P}
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem}

for algType in (
Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL
)
@eval function SciMLBase.__solve(
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end
end

@concrete mutable struct NonlinearSolveForwardDiffCache
Expand All @@ -32,17 +42,21 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache;
return cache
end

function SciMLBase.init(
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::Union{Nothing, AbstractNonlinearAlgorithm},
args...;
kwargs...) where {T, V, P, iip}
p = __value(prob.p)
newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p))
for algType in (
Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm,
GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL
)
@eval function SciMLBase.__init(
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...)
p = __value(prob.p)
newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p))
end
end

function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
Expand Down

0 comments on commit f5a06cb

Please sign in to comment.