From c79bd51f7006021e514749a53ba89ae1732cc22e Mon Sep 17 00:00:00 2001 From: Kipton Barros Date: Wed, 15 Nov 2023 08:31:12 -0700 Subject: [PATCH] Print an error if non-positive dt is detected --- src/Integrators.jl | 10 ++++++++-- test/test_jet.jl | 4 ++-- test/test_samplers.jl | 4 ++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/Integrators.jl b/src/Integrators.jl index f32e9e1f8..92cbf8496 100644 --- a/src/Integrators.jl +++ b/src/Integrators.jl @@ -21,7 +21,10 @@ mutable struct Langevin kT :: Float64 end -Langevin(Δt; λ, kT) = Langevin(Δt, λ, kT) +function Langevin(Δt; λ, kT) + Δt <= 0 && error("Select positive Δt") + return Langevin(Δt, λ, kT) +end Base.copy(dyn::Langevin) = Langevin(dyn.Δt, dyn.λ, dyn.kT) @@ -41,7 +44,10 @@ mutable struct ImplicitMidpoint atol :: Float64 end -ImplicitMidpoint(Δt; atol=1e-12) = ImplicitMidpoint(Δt, atol) +function ImplicitMidpoint(Δt; atol=1e-12) + Δt <= 0 && error("Select positive Δt") + return ImplicitMidpoint(Δt, atol) +end diff --git a/test/test_jet.jl b/test/test_jet.jl index 9854821bf..08c1dd2d5 100644 --- a/test/test_jet.jl +++ b/test/test_jet.jl @@ -16,7 +16,7 @@ sampler = LocalSampler(kT=0.2; propose) @test_opt step!(sys, sampler) - langevin = Langevin(0.01, kT=0.2, λ=0.1) + langevin = Langevin(0.01; kT=0.2, λ=0.1) @test_opt step!(sys, langevin) integrator = ImplicitMidpoint(0.01) @@ -48,7 +48,7 @@ end step!(sys, sampler) @test 0 == @allocated step!(sys, sampler) - langevin = Langevin(0.01, kT=0.2, λ=0.1) + langevin = Langevin(0.01; kT=0.2, λ=0.1) step!(sys, langevin) @test 0 == @allocated step!(sys, langevin) diff --git a/test/test_samplers.jl b/test/test_samplers.jl index 84f66c226..783161edc 100644 --- a/test/test_samplers.jl +++ b/test/test_samplers.jl @@ -72,7 +72,7 @@ collect_dur = 100.0 sys = su3_anisotropy_model(; D, L, seed=0) - langevin = Langevin(Δt; kT=0.0, λ) + langevin = Langevin(Δt; kT=0, λ) for kT in kTs langevin.kT = kT @@ -96,7 +96,7 @@ collect_dur = 200.0 sys = su5_anisotropy_model(; D, L, seed=0) - langevin = Langevin(Δt; kT=0.0, λ) + langevin = Langevin(Δt; kT=0, λ) for kT ∈ kTs langevin.kT = kT