diff --git a/src/Bijectors.jl b/src/Bijectors.jl index f5d0245b..ea59e59f 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -18,7 +18,8 @@ export TransformDistribution, const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0"))) -_eps(::Type{T}) where {T} = eps(T) +# Workaround for eps(::ForwardDiff.Dual) +_eps(::Type{T}) where {T} = T(eps(T)) _eps(::Type{Real}) = eps(Float64) function __init__() @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" @eval begin @@ -62,8 +63,8 @@ end ############# const TransformDistribution{T<:ContinuousUnivariateDistribution} = Union{T, Truncated{T}} -@inline function _clamp(x::Real, dist::TransformDistribution) - ϵ = eps(x) +@inline function _clamp(x::T, dist::TransformDistribution) where {T <: Real} + ϵ = _eps(T) bounds = (minimum(dist) + ϵ, maximum(dist) - ϵ) clamped_x = ifelse(x < bounds[1], bounds[1], ifelse(x > bounds[2], bounds[2], x)) DEBUG && @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x"