From f5a06cbcd829a4d66ccb765c40669d69b59f3e88 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 21 Oct 2024 20:34:34 -0400 Subject: [PATCH] fix: dispatch forwarddiff on `__init` and `__solve` --- src/NonlinearSolve.jl | 3 +- src/internal/forward_diff.jl | 58 ++++++++++++++++++++++-------------- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 5f21936a1..625477a88 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -67,7 +67,6 @@ include("descent/damped_newton.jl") include("descent/geodesic_acceleration.jl") include("internal/jacobian.jl") -include("internal/forward_diff.jl") include("internal/linear_solve.jl") include("internal/termination.jl") include("internal/tracing.jl") @@ -93,6 +92,8 @@ include("algorithms/levenberg_marquardt.jl") include("algorithms/trust_region.jl") include("algorithms/extension_algs.jl") +include("internal/forward_diff.jl") # we need to define after the algorithms + include("utils.jl") include("default.jl") diff --git a/src/internal/forward_diff.jl b/src/internal/forward_diff.jl index a4238674e..c2adc70e2 100644 --- a/src/internal/forward_diff.jl +++ b/src/internal/forward_diff.jl @@ -1,14 +1,24 @@ -# XXX: dispatch on `__solve` & `__init` -function SciMLBase.solve( - prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, - alg::Union{Nothing, AbstractNonlinearAlgorithm}, - args...; - kwargs...) where {T, V, P, iip} - sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...) - dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p) - return SciMLBase.build_solution( - prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) +const DualNonlinearProblem = NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P} +const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{ + <:Union{Number, <:AbstractArray}, iip, + <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}} where {iip, T, V, P} +const DualAbstractNonlinearProblem = Union{ + DualNonlinearProblem, DualNonlinearLeastSquaresProblem} + +for algType in ( + Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane, + GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, + LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL, + SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL +) + @eval function SciMLBase.__solve( + prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...) + sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...) + dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) + end end @concrete mutable struct NonlinearSolveForwardDiffCache @@ -32,17 +42,21 @@ function reinit_cache!(cache::NonlinearSolveForwardDiffCache; return cache end -function SciMLBase.init( - prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, - <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, - alg::Union{Nothing, AbstractNonlinearAlgorithm}, - args...; - kwargs...) where {T, V, P, iip} - p = __value(prob.p) - newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...) - cache = init(newprob, alg, args...; kwargs...) - return NonlinearSolveForwardDiffCache( - cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)) +for algType in ( + Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane, + SimpleNonlinearSolve.AbstractSimpleNonlinearSolveAlgorithm, + GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, + LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL, + SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL +) + @eval function SciMLBase.__init( + prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...) + p = __value(prob.p) + newprob = NonlinearProblem(prob.f, __value(prob.u0), p; prob.kwargs...) + cache = init(newprob, alg, args...; kwargs...) + return NonlinearSolveForwardDiffCache( + cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)) + end end function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)