diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl index e50be6c47..c4f1dc901 100644 --- a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl @@ -37,8 +37,8 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( sol = solve(newprob, alg, args...; kwargs...) uu = sol.u - Jₚ = nonlinearsolve_∂f_∂p(prob, prob.f, uu, p) - Jᵤ = nonlinearsolve_∂f_∂u(prob, prob.f, uu, p) + Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, p) + Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, p) z = -Jᵤ \ Jₚ pp = prob.p sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z) @@ -123,8 +123,8 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( end end - Jₚ = nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p) - Jᵤ = nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p) + Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p) + Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p) z = -Jᵤ \ Jₚ pp = prob.p sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z) @@ -140,7 +140,7 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve( return sol, partials end -function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F} +function NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F} if SciMLBase.isinplace(prob) f2 = @closure p -> begin du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p))) @@ -159,7 +159,7 @@ function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F} end end -function nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F} +function NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F} if SciMLBase.isinplace(prob) return ForwardDiff.jacobian( @closure((du, u)->f(du, u, p)), Utils.safe_similar(u), u) diff --git a/lib/NonlinearSolveBase/src/public.jl b/lib/NonlinearSolveBase/src/public.jl index d9014d71e..eceea6d75 100644 --- a/lib/NonlinearSolveBase/src/public.jl +++ b/lib/NonlinearSolveBase/src/public.jl @@ -8,6 +8,8 @@ function get_tolerance end # Forward declarations of functions for forward mode AD function nonlinearsolve_forwarddiff_solve end function nonlinearsolve_dual_solution end +function nonlinearsolve_∂f_∂p end +function nonlinearsolve_∂f_∂u end # Nonlinear Solve Termination Conditions abstract type AbstractNonlinearTerminationMode end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 903a1f33d..853cbdf19 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -22,6 +22,9 @@ using LineSearch: LineSearch, AbstractLineSearchCache, LineSearchesJL, NoLineSea using LinearSolve: LinearSolve, QRFactorization, needs_concrete_A, AbstractFactorization, DefaultAlgorithmChoice, DefaultLinearSolver using MaybeInplace: @bb +using NonlinearSolveBase: NonlinearSolveBase, nonlinearsolve_forwarddiff_solve, + nonlinearsolve_dual_solution, nonlinearsolve_∂f_∂p, + nonlinearsolve_∂f_∂u using Printf: @printf using Preferences: Preferences, @load_preference, @set_preferences! using RecursiveArrayTools: recursivecopy! diff --git a/src/internal/forward_diff.jl b/src/internal/forward_diff.jl index 190c80645..a4238674e 100644 --- a/src/internal/forward_diff.jl +++ b/src/internal/forward_diff.jl @@ -1,15 +1,12 @@ -# Not part of public API but helps reduce code duplication -import SimpleNonlinearSolve: __nlsolve_ad, __nlsolve_dual_soln, __nlsolve_∂f_∂p, - __nlsolve_∂f_∂u - +# XXX: dispatch on `__solve` & `__init` function SciMLBase.solve( prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, <:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}}, alg::Union{Nothing, AbstractNonlinearAlgorithm}, args...; kwargs...) where {T, V, P, iip} - sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...) - dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p) + sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...) + dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p) return SciMLBase.build_solution( prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) end @@ -53,10 +50,10 @@ function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache) prob = cache.prob uu = sol.u - f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, cache.values_p) - f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, cache.values_p) + Jₚ = nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p) + Jᵤ = nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p) - z_arr = -f_x \ f_p + z_arr = -Jᵤ \ Jₚ sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z) if cache.p isa Number @@ -65,7 +62,7 @@ function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache) partials = sum(sumfun, zip(eachcol(z_arr), cache.p)) end - dual_soln = __nlsolve_dual_soln(sol.u, partials, cache.p) + dual_soln = nonlinearsolve_dual_solution(sol.u, partials, cache.p) return SciMLBase.build_solution( prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) end