Skip to content

Commit

Permalink
fix: forwarddiff support
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 21, 2024
1 parent 26c77a6 commit c73bcf4
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 16 deletions.
12 changes: 6 additions & 6 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
Jₚ = nonlinearsolve_∂f_∂p(prob, prob.f, uu, p)
Jᵤ = nonlinearsolve_∂f_∂u(prob, prob.f, uu, p)
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, p)
z = -Jᵤ \ Jₚ
pp = prob.p
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)
Expand Down Expand Up @@ -123,8 +123,8 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
end
end

Jₚ = nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p)
Jᵤ = nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p)
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p)
z = -Jᵤ \ Jₚ
pp = prob.p
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)
Expand All @@ -140,7 +140,7 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
return sol, partials
end

function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
function NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
if SciMLBase.isinplace(prob)
f2 = @closure p -> begin
du = Utils.safe_similar(u, promote_type(eltype(u), eltype(p)))
Expand All @@ -159,7 +159,7 @@ function nonlinearsolve_∂f_∂p(prob, f::F, u, p) where {F}
end
end

function nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F}
function NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F}
if SciMLBase.isinplace(prob)
return ForwardDiff.jacobian(
@closure((du, u)->f(du, u, p)), Utils.safe_similar(u), u)
Expand Down
2 changes: 2 additions & 0 deletions lib/NonlinearSolveBase/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ function get_tolerance end
# Forward declarations of functions for forward mode AD
function nonlinearsolve_forwarddiff_solve end
function nonlinearsolve_dual_solution end
function nonlinearsolve_∂f_∂p end
function nonlinearsolve_∂f_∂u end

# Nonlinear Solve Termination Conditions
abstract type AbstractNonlinearTerminationMode end
Expand Down
3 changes: 3 additions & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ using LineSearch: LineSearch, AbstractLineSearchCache, LineSearchesJL, NoLineSea
using LinearSolve: LinearSolve, QRFactorization, needs_concrete_A, AbstractFactorization,
DefaultAlgorithmChoice, DefaultLinearSolver
using MaybeInplace: @bb
using NonlinearSolveBase: NonlinearSolveBase, nonlinearsolve_forwarddiff_solve,
nonlinearsolve_dual_solution, nonlinearsolve_∂f_∂p,
nonlinearsolve_∂f_∂u
using Printf: @printf
using Preferences: Preferences, @load_preference, @set_preferences!
using RecursiveArrayTools: recursivecopy!
Expand Down
17 changes: 7 additions & 10 deletions src/internal/forward_diff.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
# Not part of public API but helps reduce code duplication
import SimpleNonlinearSolve: __nlsolve_ad, __nlsolve_dual_soln, __nlsolve_∂f_∂p,
__nlsolve_∂f_∂u

# XXX: dispatch on `__solve` & `__init`
function SciMLBase.solve(
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
alg::Union{Nothing, AbstractNonlinearAlgorithm},
args...;
kwargs...) where {T, V, P, iip}
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end
Expand Down Expand Up @@ -53,10 +50,10 @@ function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
prob = cache.prob

uu = sol.u
f_p = __nlsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
f_x = __nlsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)
Jₚ = nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
Jᵤ = nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)

z_arr = -f_x \ f_p
z_arr = -Jᵤ \ Jₚ

sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
if cache.p isa Number
Expand All @@ -65,7 +62,7 @@ function SciMLBase.solve!(cache::NonlinearSolveForwardDiffCache)
partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
end

dual_soln = __nlsolve_dual_soln(sol.u, partials, cache.p)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, cache.p)
return SciMLBase.build_solution(
prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
end
Expand Down

0 comments on commit c73bcf4

Please sign in to comment.