
Commit

fix: hessian (#489)
* fix: hessian through nonlinear solvers

* feat: extend gradient support for cached nlls
avik-pal authored Nov 1, 2024
1 parent 748fb09 commit 6b0002b
Showing 8 changed files with 177 additions and 100 deletions.
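
In practical terms, the fix lets ForwardDiff carry first- and second-order duals through a NonlinearLeastSquaresProblem solve, so gradients and Hessians of quantities computed from the solution now work. A minimal sketch of that pattern, mirroring the test added in test/forward_ad_tests.jl below (the toy residual and tolerances are illustrative only, not part of the commit):

    using NonlinearSolve, ForwardDiff

    # Illustrative residual: recover u from the parameters p by least squares.
    residual(u, p) = [u[1] + u[2] - p[1], u[1] * u[2] - p[2]]

    function loss(p)
        prob = NonlinearLeastSquaresProblem(residual, [1.0, 0.5], p)
        sol = solve(prob; abstol = 1e-12, reltol = 1e-12)
        return sum(abs2, sol.u)
    end

    ForwardDiff.gradient(loss, [3.0, 2.0])  # first-order duals through the solve
    ForwardDiff.hessian(loss, [3.0, 2.0])   # second-order duals, the case fixed here
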
2 changes: 1 addition & 1 deletion Project.toml
@@ -89,7 +89,7 @@ NLSolvers = "0.5"
NLsolve = "4.5"
NaNMath = "1"
NonlinearProblemLibrary = "0.1.2"
NonlinearSolveBase = "1"
NonlinearSolveBase = "1.2"
NonlinearSolveFirstOrder = "1"
NonlinearSolveQuasiNewton = "1"
NonlinearSolveSpectralMethods = "1"
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/Project.toml
@@ -1,7 +1,7 @@
name = "NonlinearSolveBase"
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.1.0"
version = "1.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
100 changes: 9 additions & 91 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
@@ -6,7 +6,6 @@ using CommonSolve: solve
using DifferentiationInterface: DifferentiationInterface
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra: mul!
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake

@@ -20,11 +19,14 @@ function NonlinearSolveBase.additional_incompatible_backend_check(
end

Utils.value(::Type{Dual{T, V, N}}) where {T, V, N} = V
Utils.value(x::Dual) = Utils.value(ForwardDiff.value(x))
Utils.value(x::Dual) = ForwardDiff.value(x)
Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)

function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem},
prob::Union{
IntervalNonlinearProblem, NonlinearProblem,
ImmutableNonlinearProblem, NonlinearLeastSquaresProblem
},
alg, args...; kwargs...
)
p = Utils.value(prob.p)
@@ -35,98 +37,14 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
end

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, p)
z = -Jᵤ \ Jₚ
pp = prob.p
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)

if uu isa Number
partials = sum(sumfun, zip(z, pp))
elseif p isa Number
partials = sumfun((z, pp))
else
partials = sum(sumfun, zip(eachcol(z), pp))
end

return sol, partials
end

function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...
)
p = Utils.value(prob.p)
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
sol = solve(newprob, alg, args...; kwargs...)
uu = sol.u

# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
resid = prob.f(u, p)
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
end
end
elseif SciMLBase.has_jac(prob.f)
if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
J = Utils.safe_similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
end
end
else
# For small problems, nesting ForwardDiff is actually quite fast
autodiff = length(uu) + length(sol.resid) ≥ 50 ?
NonlinearSolveBase.select_reverse_mode_autodiff(prob, nothing) :
AutoForwardDiff()

if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
# Using `Constant` led to dual ordering issues
ff = @closure (du, u) -> prob.f(du, u, p)
resid2 = copy(resid)
DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,))
@. du *= 2
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
v = prob.f(u, p)
# Using `Constant` led to dual ordering issues
ff = Base.Fix2(prob.f, p)
res = only(DI.pullback(ff, autodiff, u, (v,)))
ArrayInterface.can_setindex(res) || return 2 .* res
@. res *= 2
return res
end
end
end
fn = prob isa NonlinearLeastSquaresProblem ?
NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f

Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p)
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, p)
z = -Jᵤ \ Jₚ
pp = prob.p
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)
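
The lines above apply the implicit function theorem: at a solution with f(u(p), p) = 0 (or a least-squares stationary point), du/dp = -Jᵤ⁻¹ Jₚ, and sumfun contracts that sensitivity with the ForwardDiff partials carried by p. A self-contained scalar sketch of the same idea, for orientation only (not part of the diff):

    using ForwardDiff

    # Toy problem f(u, p) = u^2 - p = 0 with closed-form solution u*(p) = sqrt(p).
    f(u, p) = u^2 - p
    u_star(p) = sqrt(p)

    p = 2.0
    u = u_star(p)

    Ju = ForwardDiff.derivative(uu -> f(uu, p), u)  # ∂f/∂u at the solution
    Jp = ForwardDiff.derivative(pp -> f(u, pp), p)  # ∂f/∂p at the solution
    dudp = -(Ju \ Jp)                               # implicit function theorem

    dudp ≈ ForwardDiff.derivative(u_star, p)        # true
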
4 changes: 2 additions & 2 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
@@ -5,7 +5,7 @@ using ConcreteStructs: @concrete
using FastClosures: @closure
using Preferences: @load_preference, @set_preferences!

using ADTypes: ADTypes, AbstractADType, AutoSparse, NoSparsityDetector,
using ADTypes: ADTypes, AbstractADType, AutoSparse, AutoForwardDiff, NoSparsityDetector,
KnownJacobianSparsityDetector
using Adapt: WrappedArray
using ArrayInterface: ArrayInterface
@@ -25,7 +25,7 @@ using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
using SymbolicIndexingInterface: SymbolicIndexingInterface

using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind
using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind, mul!
using Markdown: @doc_str
using Printf: @printf

62 changes: 62 additions & 0 deletions lib/NonlinearSolveBase/src/autodiff.jl
@@ -128,3 +128,65 @@ end
is_finite_differences_backend(ad::AbstractADType) = false
is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true
is_finite_differences_backend(::ADTypes.AutoFiniteDifferences) = true

function nlls_generate_vjp_function(prob::NonlinearLeastSquaresProblem, sol, uu)
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if SciMLBase.isinplace(prob)
return @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
return nothing
end
else
return @closure (u, p) -> begin
resid = prob.f(u, p)
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
end
end
elseif SciMLBase.has_jac(prob.f)
if SciMLBase.isinplace(prob)
return @closure (du, u, p) -> begin
J = Utils.safe_similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
end
else
return @closure (u, p) -> begin
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
end
end
else
# For small problems, nesting ForwardDiff is actually quite fast
autodiff = length(uu) + length(sol.resid) ≥ 50 ?
select_reverse_mode_autodiff(prob, nothing) : AutoForwardDiff()

if SciMLBase.isinplace(prob)
return @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
# Using `Constant` led to dual ordering issues
ff = @closure (du, u) -> prob.f(du, u, p)
resid2 = copy(resid)
DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,))
@. du *= 2
return nothing
end
else
return @closure (u, p) -> begin
v = prob.f(u, p)
# Using `Constant` led to dual ordering issues
res = only(DI.pullback(Base.Fix2(prob.f, p), autodiff, u, (v,)))
ArrayInterface.can_setindex(res) || return 2 .* res
@. res *= 2
return res
end
end
end
end
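
Every branch above scales the pullback by 2 because the scalar objective being differentiated is the sum of squared residuals: for g(u) = ‖f(u, p)‖², the gradient is ∇g(u) = 2 (∂f/∂u)ᵀ f(u, p), i.e. the row vector 2 f(u, p)ᵀ J formed in the has_jac branch. A quick sanity check of that identity (the residual f here is a stand-in, not the problem's):

    using ForwardDiff

    f(u, p) = [u[1]^2 - p[1], u[1] * u[2] - p[2], u[2] - 1.0]

    u = [1.3, 0.7]
    p = [2.0, 0.5]

    J = ForwardDiff.jacobian(uu -> f(uu, p), u)
    vjp = vec(2 .* vec(f(u, p))' * J)                         # same form as the has_jac branch
    grad = ForwardDiff.gradient(uu -> sum(abs2, f(uu, p)), u)

    vjp ≈ grad  # true
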
1 change: 1 addition & 0 deletions lib/NonlinearSolveBase/src/public.jl
@@ -10,6 +10,7 @@ function nonlinearsolve_forwarddiff_solve end
function nonlinearsolve_dual_solution end
function nonlinearsolve_∂f_∂p end
function nonlinearsolve_∂f_∂u end
function nlls_generate_vjp_function end

# Nonlinear Solve Termination Conditions
abstract type AbstractNonlinearTerminationMode end
12 changes: 7 additions & 5 deletions src/forward_diff.jl
@@ -48,9 +48,8 @@ function InternalAPI.reinit!(
end

for algType in ALL_SOLVER_TYPES
# XXX: Extend to DualNonlinearLeastSquaresProblem
@eval function SciMLBase.__init(
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
)
p = nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
@@ -64,10 +63,13 @@
function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
sol = solve!(cache.cache)
prob = cache.prob

uu = sol.u
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)

fn = prob isa NonlinearLeastSquaresProblem ?
NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f

Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p)

z_arr = -Jᵤ \ Jₚ

94 changes: 94 additions & 0 deletions test/forward_ad_tests.jl
@@ -124,3 +124,97 @@ end
end
end
end

@testitem "NLLS Hessian SciML/NonlinearSolve.jl#445" tags=[:core] begin
using ForwardDiff, FiniteDiff

function objfn(F, init, params)
th1, th2 = init
px, py, l1, l2 = params
F[1] = l1 * cos(th1) + l2 * cos(th1 + th2) - px
F[2] = l1 * sin(th1) + l2 * sin(th1 + th2) - py
return F
end

function solve_nlprob(pxpy)
px, py = pxpy
theta1 = pi / 4
theta2 = pi / 4
initial_guess = [theta1; theta2]
l1 = 60
l2 = 60
p = [px; py; l1; l2]
prob = NonlinearLeastSquaresProblem(
NonlinearFunction(objfn, resid_prototype = zeros(2)),
initial_guess, p
)
resu = solve(
prob,
reltol = 1e-12, abstol = 1e-12
)
th1, th2 = resu.u
cable1_base = [-90; 0; 0]
cable2_base = [-150; 0; 0]
cable3_base = [150; 0; 0]
cable1_top = [l1 * cos(th1) / 2; l1 * sin(th1) / 2; 0]
cable23_top = [l1 * cos(th1) + l2 * cos(th1 + th2) / 2;
l1 * sin(th1) + l2 * sin(th1 + th2) / 2; 0]
c1_length = sqrt((cable1_top[1] - cable1_base[1])^2 +
(cable1_top[2] - cable1_base[2])^2)
c2_length = sqrt((cable23_top[1] - cable2_base[1])^2 +
(cable23_top[2] - cable2_base[2])^2)
c3_length = sqrt((cable23_top[1] - cable3_base[1])^2 +
(cable23_top[2] - cable3_base[2])^2)
return c1_length + c2_length + c3_length
end

grad1 = ForwardDiff.gradient(solve_nlprob, [34.0, 87.0])
grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob, [34.0, 87.0])

@test grad1 ≈ grad2 atol=1e-3

hess1 = ForwardDiff.hessian(solve_nlprob, [34.0, 87.0])
hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob, [34.0, 87.0])

@test hess1 ≈ hess2 atol=1e-3

function solve_nlprob_with_cache(pxpy)
px, py = pxpy
theta1 = pi / 4
theta2 = pi / 4
initial_guess = [theta1; theta2]
l1 = 60
l2 = 60
p = [px; py; l1; l2]
prob = NonlinearLeastSquaresProblem(
NonlinearFunction(objfn, resid_prototype = zeros(2)),
initial_guess, p
)
cache = init(prob; reltol = 1e-12, abstol = 1e-12)
resu = solve!(cache)
th1, th2 = resu.u
cable1_base = [-90; 0; 0]
cable2_base = [-150; 0; 0]
cable3_base = [150; 0; 0]
cable1_top = [l1 * cos(th1) / 2; l1 * sin(th1) / 2; 0]
cable23_top = [l1 * cos(th1) + l2 * cos(th1 + th2) / 2;
l1 * sin(th1) + l2 * sin(th1 + th2) / 2; 0]
c1_length = sqrt((cable1_top[1] - cable1_base[1])^2 +
(cable1_top[2] - cable1_base[2])^2)
c2_length = sqrt((cable23_top[1] - cable2_base[1])^2 +
(cable23_top[2] - cable2_base[2])^2)
c3_length = sqrt((cable23_top[1] - cable3_base[1])^2 +
(cable23_top[2] - cable3_base[2])^2)
return c1_length + c2_length + c3_length
end

grad1 = ForwardDiff.gradient(solve_nlprob_with_cache, [34.0, 87.0])
grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob_with_cache, [34.0, 87.0])

@test grad1 ≈ grad2 atol=1e-3

hess1 = ForwardDiff.hessian(solve_nlprob_with_cache, [34.0, 87.0])
hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob_with_cache, [34.0, 87.0])

@test hess1 ≈ hess2 atol=1e-3
end

2 comments on commit 6b0002b

@avik-pal
Member Author

@JuliaRegistrator register subdir=lib/NonlinearSolveBase

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/118516

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a NonlinearSolveBase-v1.2.0 -m "<description of version>" 6b0002b7b8cda2524a6211663d03fa64df42fa17
git push origin NonlinearSolveBase-v1.2.0
