Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Move dual nonlinear solving to NonlinearSolveBase #513

Merged
merged 10 commits into from
Dec 11, 2024
111 changes: 109 additions & 2 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,36 @@ module NonlinearSolveBaseForwardDiffExt

using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using CommonSolve: solve
using CommonSolve: CommonSolve, solve
using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake

using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
AbstractNonlinearSolveAlgorithm, Utils, InternalAPI,
AbstractNonlinearSolveCache

const DI = DifferentiationInterface

# Types of the `alg` argument for which the dual-number `SciMLBase.__solve` and
# `SciMLBase.__init` methods in this extension are generated (one method per entry).
const ALL_SOLVER_TYPES = [Nothing, AbstractNonlinearSolveAlgorithm]

# `NonlinearProblem`s whose third type parameter is a `ForwardDiff.Dual{T, V, P}` or an
# array of such duals; `iip`, `T`, `V` and `P` are left free.
# NOTE(review): the first parameter is presumably the type of `prob.u0` and the third the
# type of `prob.p` — confirm against the `NonlinearProblem` definition in SciMLBase.
const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Same pattern as `DualNonlinearProblem`, for `NonlinearLeastSquaresProblem`.
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
# Union used as the dispatch target by the `__solve` / `__init` methods of this extension.
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

function NonlinearSolveBase.additional_incompatible_backend_check(
prob::AbstractNonlinearProblem, ::Union{AutoForwardDiff, AutoPolyesterForwardDiff})
return !ForwardDiff.can_dual(eltype(prob.u0))
Expand Down Expand Up @@ -102,4 +121,92 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution(
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials)))
end

# Generate one `SciMLBase.__solve` method per entry of `ALL_SOLVER_TYPES`. Each method
# delegates to `nonlinearsolve_forwarddiff_solve`, which returns a solution together with
# partials, and then rebuilds the solution with `u` re-wrapped via
# `nonlinearsolve_dual_solution`.
for algType in ALL_SOLVER_TYPES
    @eval function SciMLBase.__solve(
            prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
    )
        primal_sol, ∂u = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
            prob, alg, args...; kwargs...
        )
        u_dual = NonlinearSolveBase.nonlinearsolve_dual_solution(
            primal_sol.u, ∂u, prob.p
        )
        # Residual, return code, stats and the wrapped original solution are forwarded
        # unchanged from the inner solve.
        return SciMLBase.build_solution(
            prob, alg, u_dual, primal_sol.resid;
            retcode = primal_sol.retcode,
            stats = primal_sol.stats,
            original = primal_sol.original
        )
    end
end

# Cache wrapping an inner solver cache for problems that carry `ForwardDiff.Dual` numbers.
# It is built by the `SciMLBase.__init` methods of this extension and consumed by
# `CommonSolve.solve!` / `InternalAPI.reinit!`.
@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
# Inner cache returned by `init` on the dual-stripped problem.
cache
# The dual-stripped problem (`u0` and `p` passed through `nodual_value`).
prob
# Algorithm argument handed to `__init`.
alg
# Parameters as given on the original problem (may contain duals).
p
# `nodual_value` of `p`.
values_p
# `ForwardDiff.partials` of `p`.
partials_p
end

# Re-initialise the wrapped cache with dual-stripped `p` / `u0`, then refresh the parameter
# bookkeeping stored on the wrapper. `p` defaults to the currently stored parameters and
# `u0` to the current state of the inner cache.
# NOTE(review): positional `args...` are accepted but not forwarded to the inner
# `reinit!` — confirm this is intentional.
function InternalAPI.reinit!(
        cache::NonlinearSolveForwardDiffCache, args...;
        p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
)
    primal_p = nodual_value(p)
    primal_u0 = nodual_value(u0)
    InternalAPI.reinit!(cache.cache; p = primal_p, u0 = primal_u0, kwargs...)

    # Keep the (possibly dual) parameters plus their value / partials parts in sync.
    cache.p = p
    cache.values_p = nodual_value(p)
    cache.partials_p = ForwardDiff.partials(p)
    return cache
end

# Generate one `SciMLBase.__init` method per entry of `ALL_SOLVER_TYPES`.
#
# Each method strips the duals from `prob.u0` and `prob.p`, initialises the regular solver
# cache on the resulting problem, and wraps it in a `NonlinearSolveForwardDiffCache` that
# also remembers the original parameters and their partials.
for algType in ALL_SOLVER_TYPES
    @eval function SciMLBase.__init(
            prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
    )
        p = nodual_value(prob.p)
        newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
        # `init` is not imported by name into this extension (only `CommonSolve` and
        # `solve` are), so the call must be module-qualified to avoid an
        # `UndefVarError` at run time.
        cache = CommonSolve.init(newprob, alg, args...; kwargs...)
        return NonlinearSolveForwardDiffCache(
            cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
        )
    end
end

# Solve the wrapped (dual-free) cache, then attach derivative information to the result.
#
# After the inner solve, the Jacobians of the residual function with respect to `p` and
# `u` are evaluated at the solution, `z_arr = -Jᵤ \ Jₚ` is formed, and each column of it
# is weighted by the partials of the matching entry of the original `cache.p`. The
# returned solution has `u` re-wrapped via `nonlinearsolve_dual_solution`; residual,
# return code, stats and `original` are forwarded from the inner solution.
function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
    # `solve!` is not imported by name into this extension, and defining the method as
    # `CommonSolve.solve!` does not bring the bare name into scope, so the inner call
    # must be module-qualified to avoid an `UndefVarError`.
    sol = CommonSolve.solve!(cache.cache)
    prob = cache.prob
    uu = sol.u

    # Least-squares problems use a generated VJP-based function instead of `prob.f`.
    fn = prob isa NonlinearLeastSquaresProblem ?
         NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f

    Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p)
    Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p)

    z_arr = -Jᵤ \ Jₚ

    # Scale every entry of `z` by the partials carried by the parameter `p`.
    sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
    if cache.p isa Number
        partials = sumfun((z_arr, cache.p))
    else
        # One column of `z_arr` per parameter entry; accumulate the contributions.
        partials = sum(sumfun, zip(eachcol(z_arr), cache.p))
    end

    dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, cache.p)
    return SciMLBase.build_solution(
        prob, cache.alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
    )
end

# Strip one layer of `ForwardDiff.Dual` from a value: duals are replaced by
# `ForwardDiff.value`, arrays of duals are mapped element-wise, and anything else is
# returned unchanged.
function nodual_value(x)
    return x
end
function nodual_value(x::Dual)
    return ForwardDiff.value(x)
end
function nodual_value(x::AbstractArray{<:Dual})
    return map(nodual_value, x)
end

"""
    pickchunksize(x)
    pickchunksize(x::Int)

Return the chunk size to use with ForwardDiff and PolyesterForwardDiff. For an `Int`
input the value is `ForwardDiff.pickchunksize(x)`; any other input is first reduced to
its `length`.
"""
@inline function pickchunksize(x)
    return pickchunksize(length(x))
end
@inline function pickchunksize(x::Int)
    return ForwardDiff.pickchunksize(x)
end
ErikQQY marked this conversation as resolved.
Show resolved Hide resolved

end
3 changes: 2 additions & 1 deletion lib/NonlinearSolveFirstOrder/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
Expand All @@ -86,4 +87,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "ForwardDiff", "Enzyme", "ExplicitImports", "Hwloc", "InteractiveUtils", "LineSearch", "LineSearches", "NonlinearProblemLibrary", "Pkg", "Random", "ReTestItems", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test", "Zygote"]
10 changes: 10 additions & 0 deletions lib/NonlinearSolveFirstOrder/test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,13 @@
@test sol.retcode == ReturnCode.Success
@test jac_calls == 0
end

@testitem "Dual of BigFloat: Issue #512" tags=[:core] begin
    using NonlinearSolveFirstOrder, ForwardDiff

    # In-place residual u .* u .- p, with both the initial guess and the parameter given
    # as `Dual`s over `BigFloat`.
    residual! = (du, u, p) -> du .= u .* u .- p
    fn_iip = NonlinearFunction{true}(residual!)

    u2 = [ForwardDiff.Dual(BigFloat(1.0), 5.0) for _ in 1:3]
    p_dual = ForwardDiff.Dual(BigFloat(2.0), 5.0)
    prob_iip_bf = NonlinearProblem{true}(fn_iip, u2, p_dual)

    sol = solve(prob_iip_bf, NewtonRaphson())
    @test sol.retcode == ReturnCode.Success
end
2 changes: 0 additions & 2 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ const ALL_SOLVER_TYPES = [
NonlinearSolvePolyAlgorithm
]

include("forward_diff.jl")

@setup_workload begin
nonlinear_functions = (
(NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1),
Expand Down
99 changes: 0 additions & 99 deletions src/forward_diff.jl

This file was deleted.

Loading