Skip to content

Commit

Permalink
chore: minor updates to SteadyStateDiff
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 30, 2024
1 parent ea70924 commit f9b37c7
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 32 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "SteadyStateDiffEq"
uuid = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
version = "2.3.0"
version = "2.3.1"

[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Expand Down
14 changes: 12 additions & 2 deletions src/SteadyStateDiffEq.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
module SteadyStateDiffEq

using Reexport
using Reexport: @reexport
@reexport using DiffEqBase

using DiffEqCallbacks, ConcreteStructs, LinearAlgebra, SciMLBase
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase, AbstractNonlinearTerminationMode,
AbstractSafeNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode,
NonlinearSafeTerminationReturnCode, NormTerminationMode
using DiffEqCallbacks: TerminateSteadyState
using LinearAlgebra: norm
using SciMLBase: SciMLBase, CallbackSet, NonlinearProblem, ODEProblem,
ReturnCode, SteadyStateProblem, get_du, init, isinplace

const infnorm = Base.Fix2(norm, Inf)

include("algorithms.jl")
include("solve.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
abstract type SteadyStateDiffEqAlgorithm <: DiffEqBase.AbstractSteadyStateAlgorithm end
abstract type SteadyStateDiffEqAlgorithm <: SciMLBase.AbstractSteadyStateAlgorithm end

"""
SSRootfind(alg = nothing)
Expand Down
38 changes: 15 additions & 23 deletions src/solve.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::SSRootfind,
function SciMLBase.__solve(prob::SciMLBase.AbstractSteadyStateProblem, alg::SSRootfind,
args...; kwargs...)
nlprob = NonlinearProblem(prob)
nlsol = DiffEqBase.__solve(nlprob, alg.alg, args...; kwargs...)
nlsol = SciMLBase.__solve(nlprob, alg.alg, args...; kwargs...)
return SciMLBase.build_solution(prob, SSRootfind(nlsol.alg), nlsol.u, nlsol.resid;
nlsol.retcode, nlsol.stats, nlsol.left, nlsol.right, original = nlsol)
end

__get_tspan(u0, alg::DynamicSS) = __get_tspan(u0, alg.tspan)
__get_tspan(u0, tspan::Tuple) = tspan
function __get_tspan(u0, tspan::Number)
return convert.(DiffEqBase.value(real(eltype(u0))),
(DiffEqBase.value(zero(tspan)), tspan))
return convert.(
DiffEqBase.value(real(eltype(u0))), (DiffEqBase.value(zero(tspan)), tspan))
end

infnorm(x) = norm(x,Inf)
function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::DynamicSS,
function SciMLBase.__solve(prob::SciMLBase.AbstractSteadyStateProblem, alg::DynamicSS,
args...; abstol = 1e-8, reltol = 1e-6, odesolve_kwargs = (;),
save_idxs = nothing, termination_condition = NormTerminationMode(infnorm),
kwargs...)
Expand Down Expand Up @@ -54,7 +53,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::Dy

# Construct and solve the ODEProblem
odeprob = ODEProblem{isinplace(prob)}(f, prob.u0, tspan, prob.p)
odesol = DiffEqBase.__solve(odeprob, alg.alg, args...; abstol, reltol, kwargs...,
odesol = SciMLBase.__solve(odeprob, alg.alg, args...; abstol, reltol, kwargs...,
odesolve_kwargs..., callback, save_end = true)

resid, u, retcode = __get_result_from_sol(termination_condition, tc_cache, odesol)
Expand All @@ -68,27 +67,23 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractSteadyStateProblem, alg::Dy
retcode, original = odesol)
end

function __get_result_from_sol(::DiffEqBase.AbstractNonlinearTerminationMode, tc_cache,
odesol)
function __get_result_from_sol(::AbstractNonlinearTerminationMode, tc_cache, odesol)
u, t = last(odesol.u), last(odesol.t)
du = odesol(t, Val{1})
return (du, u,
ifelse(odesol.retcode == ReturnCode.Terminated, ReturnCode.Success,
ReturnCode.Failure))
end

function __get_result_from_sol(::DiffEqBase.AbstractSafeNonlinearTerminationMode, tc_cache,
odesol)
function __get_result_from_sol(::AbstractSafeNonlinearTerminationMode, tc_cache, odesol)
u, t = last(odesol.u), last(odesol.t)
du = odesol(t, Val{1})

if tc_cache.retcode == DiffEqBase.NonlinearSafeTerminationReturnCode.Success
if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success
retcode_tc = ReturnCode.Success
elseif tc_cache.retcode ==
DiffEqBase.NonlinearSafeTerminationReturnCode.PatienceTermination
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination
retcode_tc = ReturnCode.ConvergenceFailure
elseif tc_cache.retcode ==
DiffEqBase.NonlinearSafeTerminationReturnCode.ProtectiveTermination
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination
retcode_tc = ReturnCode.Unstable
else
retcode_tc = ReturnCode.Default
Expand All @@ -105,18 +100,15 @@ function __get_result_from_sol(::DiffEqBase.AbstractSafeNonlinearTerminationMode
return du, u, retcode
end

function __get_result_from_sol(::DiffEqBase.AbstractSafeBestNonlinearTerminationMode,
tc_cache, odesol)
function __get_result_from_sol(::AbstractSafeBestNonlinearTerminationMode, tc_cache, odesol)
u, t = tc_cache.u, only(DiffEqBase.get_saved_values(tc_cache))
du = odesol(t, Val{1})

if tc_cache.retcode == DiffEqBase.NonlinearSafeTerminationReturnCode.Success
if tc_cache.retcode == NonlinearSafeTerminationReturnCode.Success
retcode_tc = ReturnCode.Success
elseif tc_cache.retcode ==
DiffEqBase.NonlinearSafeTerminationReturnCode.PatienceTermination
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.PatienceTermination
retcode_tc = ReturnCode.ConvergenceFailure
elseif tc_cache.retcode ==
DiffEqBase.NonlinearSafeTerminationReturnCode.ProtectiveTermination
elseif tc_cache.retcode == NonlinearSafeTerminationReturnCode.ProtectiveTermination
retcode_tc = ReturnCode.Unstable
else
retcode_tc = ReturnCode.Default
Expand Down
12 changes: 7 additions & 5 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ end
du = zeros(2)
p = nothing

sol = solve(prob, DynamicSS(Rodas5()); abstol = 1e-9, reltol = 1e-9)
sol = solve(prob, DynamicSS(Tsit5()); abstol = 1e-9, reltol = 1e-9)
@test SciMLBase.successful_retcode(sol.retcode)

f(du, sol.u, p, 0)
@test du[0, 0] atol=1e-7

sol = solve(prob, DynamicSS(Rodas5(), tspan = 1e-3))
sol = solve(prob, DynamicSS(Tsit5(), tspan = 1e-3))
@test sol.retcode != ReturnCode.Success

sol = solve(prob, DynamicSS(CVODE_BDF()), dt = 1.0)
Expand Down Expand Up @@ -77,8 +77,10 @@ sol2 = solve(prob, DynamicSS(Tsit5()); abstol = 1e-4)

for termination_condition in [
NormTerminationMode(SteadyStateDiffEq.infnorm), RelTerminationMode(), RelNormTerminationMode(SteadyStateDiffEq.infnorm),
AbsTerminationMode(), AbsNormTerminationMode(SteadyStateDiffEq.infnorm), RelSafeTerminationMode(SteadyStateDiffEq.infnorm),
AbsSafeTerminationMode(SteadyStateDiffEq.infnorm), RelSafeBestTerminationMode(SteadyStateDiffEq.infnorm), AbsSafeBestTerminationMode(SteadyStateDiffEq.infnorm)
AbsTerminationMode(), AbsNormTerminationMode(SteadyStateDiffEq.infnorm),
RelSafeTerminationMode(SteadyStateDiffEq.infnorm),
AbsSafeTerminationMode(SteadyStateDiffEq.infnorm), RelSafeBestTerminationMode(SteadyStateDiffEq.infnorm),
AbsSafeBestTerminationMode(SteadyStateDiffEq.infnorm)
]
sol_tc = solve(prob, DynamicSS(Tsit5()); termination_condition)
@show sol_tc.retcode, termination_condition
Expand All @@ -104,7 +106,7 @@ prob = SteadyStateProblem(f, u0)
saved_values = SavedValues(Float64, Vector{Float64})
cb = SavingCallback((u, t, integrator) -> copy(u), saved_values, save_everystep = true,
save_start = true)
sol = solve(prob, DynamicSS(Rodas5()); callback = cb, save_everystep = true,
sol = solve(prob, DynamicSS(Tsit5()); callback = cb, save_everystep = true,
save_start = true)
@test SciMLBase.successful_retcode(sol.retcode)
@test isapprox(saved_values.saveval[end], sol.u)

0 comments on commit f9b37c7

Please sign in to comment.