diff --git a/Project.toml b/Project.toml index f88f6da..af52198 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SteadyStateDiffEq" uuid = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f" -version = "2.3.0" +version = "2.3.1" [deps] ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" diff --git a/src/SteadyStateDiffEq.jl b/src/SteadyStateDiffEq.jl index 9b7855b..745e158 100644 --- a/src/SteadyStateDiffEq.jl +++ b/src/SteadyStateDiffEq.jl @@ -1,9 +1,19 @@ module SteadyStateDiffEq -using Reexport +using Reexport: @reexport @reexport using DiffEqBase -using DiffEqCallbacks, ConcreteStructs, LinearAlgebra, SciMLBase +using ConcreteStructs: @concrete +using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode, + AbstractSafeNonlinearTerminationMode, + AbstractSafeBestNonlinearTerminationMode, + NonlinearSafeTerminationReturnCode, NormTerminationMode +using DiffEqCallbacks: TerminateSteadyState +using LinearAlgebra: norm +using SciMLBase: SciMLBase, CallbackSet, NonlinearProblem, ODEProblem, + ReturnCode, SteadyStateProblem, get_du, init, isinplace + +const infnorm = Base.Fix2(norm, Inf) include("algorithms.jl") include("solve.jl") diff --git a/src/algorithms.jl b/src/algorithms.jl index 6957e17..d43aedb 100644 --- a/src/algorithms.jl +++ b/src/algorithms.jl @@ -1,4 +1,4 @@ -abstract type SteadyStateDiffEqAlgorithm <: DiffEqBase.AbstractSteadyStateAlgorithm end +abstract type SteadyStateDiffEqAlgorithm <: SciMLBase.AbstractSteadyStateAlgorithm end """ SSRootfind(alg = nothing) diff --git a/src/solve.jl b/src/solve.jl index ca5b6ea..ed906f7 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -1,7 +1,7 @@ -function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::SSRootfind, +function SciMLBase.__solve(prob::SciMLBase.AbstractSteadyStateProblem, alg::SSRootfind, args...; kwargs...) nlprob = NonlinearProblem(prob) - nlsol = DiffEqBase.__solve(nlprob, alg.alg, args...; kwargs...) + nlsol = SciMLBase.__solve(nlprob, alg.alg, args...; kwargs...) return SciMLBase.build_solution(prob, SSRootfind(nlsol.alg), nlsol.u, nlsol.resid; nlsol.retcode, nlsol.stats, nlsol.left, nlsol.right, original = nlsol) end @@ -9,12 +9,11 @@ end __get_tspan(u0, alg::DynamicSS) = __get_tspan(u0, alg.tspan) __get_tspan(u0, tspan::Tuple) = tspan function __get_tspan(u0, tspan::Number) - return convert.(DiffEqBase.value(real(eltype(u0))), - (DiffEqBase.value(zero(tspan)), tspan)) + return convert.( + DiffEqBase.value(real(eltype(u0))), (DiffEqBase.value(zero(tspan)), tspan)) end -infnorm(x) = norm(x,Inf) -function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::DynamicSS, +function SciMLBase.__solve(prob::SciMLBase.AbstractSteadyStateProblem, alg::DynamicSS, args...; abstol = 1e-8, reltol = 1e-6, odesolve_kwargs = (;), save_idxs = nothing, termination_condition = NormTerminationMode(infnorm), kwargs...) @@ -54,7 +53,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::Dy # Construct and solve the ODEProblem odeprob = ODEProblem{isinplace(prob)}(f, prob.u0, tspan, prob.p) - odesol = DiffEqBase.__solve(odeprob, alg.alg, args...; abstol, reltol, kwargs..., + odesol = SciMLBase.__solve(odeprob, alg.alg, args...; abstol, reltol, kwargs..., odesolve_kwargs..., callback, save_end = true) resid, u, retcode = __get_result_from_sol(termination_condition, tc_cache, odesol) @@ -68,8 +67,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::Dy retcode, original = odesol) end -function __get_result_from_sol(::DiffEqBase.AbstractNonlinearTerminationMode, tc_cache, - odesol) +function __get_result_from_sol(::AbstractNonlinearTerminationMode, tc_cache, odesol) u, t = last(odesol.u), last(odesol.t) du = odesol(t, Val{1}) return (du, u, @@ -77,18 +75,15 @@ function __get_result_from_sol(::DiffEqBase.AbstractNonlinearTerminationMode, tc ReturnCode.Failure)) end -function __get_result_from_sol(::DiffEqBase.AbstractSafeNonlinearTerminationMode, tc_cache, - odesol) +function __get_result_from_sol(::AbstractSafeNonlinearTerminationMode, tc_cache, odesol) u, t = last(odesol.u), last(odesol.t) du = odesol(t, Val{1}) - if tc_cache.retcode == DiffEqBase.NonlinearSafeTerminationReturnCode.Success + if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success retcode_tc = ReturnCode.Success - elseif tc_cache.retcode == - DiffEqBase.NonlinearSafeTerminationReturnCode.PatienceTermination + elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination retcode_tc = ReturnCode.ConvergenceFailure - elseif tc_cache.retcode == - DiffEqBase.NonlinearSafeTerminationReturnCode.ProtectiveTermination + elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination retcode_tc = ReturnCode.Unstable else retcode_tc = ReturnCode.Default @@ -105,18 +100,15 @@ function __get_result_from_sol(::DiffEqBase.AbstractSafeNonlinearTerminationMode return du, u, retcode end -function __get_result_from_sol(::DiffEqBase.AbstractSafeBestNonlinearTerminationMode, - tc_cache, odesol) +function __get_result_from_sol(::AbstractSafeBestNonlinearTerminationMode, tc_cache, odesol) u, t = tc_cache.u, only(DiffEqBase.get_saved_values(tc_cache)) du = odesol(t, Val{1}) - if tc_cache.retcode == DiffEqBase.NonlinearSafeTerminationReturnCode.Success + if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success retcode_tc = ReturnCode.Success - elseif tc_cache.retcode == - DiffEqBase.NonlinearSafeTerminationReturnCode.PatienceTermination + elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination retcode_tc = ReturnCode.ConvergenceFailure - elseif tc_cache.retcode == - DiffEqBase.NonlinearSafeTerminationReturnCode.ProtectiveTermination + elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination retcode_tc = ReturnCode.Unstable else retcode_tc = ReturnCode.Default diff --git a/test/core.jl b/test/core.jl index 7725265..92470d1 100644 --- a/test/core.jl +++ b/test/core.jl @@ -23,13 +23,13 @@ end du = zeros(2) p = nothing - sol = solve(prob, DynamicSS(Rodas5()); abstol = 1e-9, reltol = 1e-9) + sol = solve(prob, DynamicSS(Tsit5()); abstol = 1e-9, reltol = 1e-9) @test SciMLBase.successful_retcode(sol.retcode) f(du, sol.u, p, 0) @test du≈[0, 0] atol=1e-7 - sol = solve(prob, DynamicSS(Rodas5(), tspan = 1e-3)) + sol = solve(prob, DynamicSS(Tsit5(), tspan = 1e-3)) @test sol.retcode != ReturnCode.Success sol = solve(prob, DynamicSS(CVODE_BDF()), dt = 1.0) @@ -77,8 +77,10 @@ sol2 = solve(prob, DynamicSS(Tsit5()); abstol = 1e-4) for termination_condition in [ NormTerminationMode(SteadyStateDiffEq.infnorm), RelTerminationMode(), RelNormTerminationMode(SteadyStateDiffEq.infnorm), - AbsTerminationMode(), AbsNormTerminationMode(SteadyStateDiffEq.infnorm), RelSafeTerminationMode(SteadyStateDiffEq.infnorm), - AbsSafeTerminationMode(SteadyStateDiffEq.infnorm), RelSafeBestTerminationMode(SteadyStateDiffEq.infnorm), AbsSafeBestTerminationMode(SteadyStateDiffEq.infnorm) + AbsTerminationMode(), AbsNormTerminationMode(SteadyStateDiffEq.infnorm), + RelSafeTerminationMode(SteadyStateDiffEq.infnorm), + AbsSafeTerminationMode(SteadyStateDiffEq.infnorm), RelSafeBestTerminationMode(SteadyStateDiffEq.infnorm), + AbsSafeBestTerminationMode(SteadyStateDiffEq.infnorm) ] sol_tc = solve(prob, DynamicSS(Tsit5()); termination_condition) @show sol_tc.retcode, termination_condition @@ -104,7 +106,7 @@ prob = SteadyStateProblem(f, u0) saved_values = SavedValues(Float64, Vector{Float64}) cb = SavingCallback((u, t, integrator) -> copy(u), saved_values, save_everystep = true, save_start = true) -sol = solve(prob, DynamicSS(Rodas5()); callback = cb, save_everystep = true, +sol = solve(prob, DynamicSS(Tsit5()); callback = cb, save_everystep = true, save_start = true) @test SciMLBase.successful_retcode(sol.retcode) @test isapprox(saved_values.saveval[end], sol.u)