From ecb7e94245f0f29c027368e109236c986e3c1263 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 13 Jan 2019 03:31:45 +1100 Subject: [PATCH 1/4] fix #12 --- src/Bijectors.jl | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 28f06f39..04a33fc6 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -210,6 +210,8 @@ function link( return Y end +clamp0to1(x::T) where T = clamp(x, zero(T), one(T)) + function invlink( d::SimplexDistribution, y::AbstractVector{T}, @@ -219,18 +221,18 @@ function invlink( ϵ = _eps(T) z = StatsFuns.logistic(y[1] + log(one(T) / (K - 1))) - x[1] = (z - ϵ) / (one(T) - 2ϵ) + x[1] = (z - ϵ) / (one(T) - 2ϵ) |> clamp0to1 sum_tmp = zero(T) @inbounds for k = 2:(K - 1) z = StatsFuns.logistic(y[k] + log(one(T) / (K - k))) sum_tmp += x[k-1] - x[k] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ + x[k] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ |> clamp0to1 end sum_tmp += x[K - 1] if proj - x[K] = one(T) - sum_tmp + x[K] = one(T) - sum_tmp |> clamp0to1 else - x[K] = one(T) - sum_tmp - y[K] + x[K] = one(T) - sum_tmp - y[K] |> clamp0to1 end return x end @@ -246,17 +248,17 @@ function invlink( ϵ = _eps(T) @inbounds for n in 1:size(X, 2) sum_tmp, z = zero(T), StatsFuns.logistic(Y[1, n] + log(one(T) / (K - 1))) - X[1, n] = (z - ϵ) / (one(T) - 2ϵ) + X[1, n] = (z - ϵ) / (one(T) - 2ϵ) |> clamp0to1 for k in 2:(K - 1) z = StatsFuns.logistic(Y[k, n] + log(one(T) / (K - k))) sum_tmp += X[k - 1] - X[k, n] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ + X[k, n] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ |> clamp0to1 end sum_tmp += X[K - 1, n] if proj - X[K, n] = one(T) - sum_tmp + X[K, n] = one(T) - sum_tmp |> clamp0to1 else - X[K, n] = one(T) - sum_tmp - Y[K, n] + X[K, n] = one(T) - sum_tmp - Y[K, n] |> clamp0to1 end end @@ -279,7 +281,7 @@ function logpdf_with_trans( lp += log(z + ϵ) + log(one(T) - z + ϵ) @inbounds for k in 2:(K - 1) sum_tmp += x[k-1] - z = x[k] / 
(one(T) - sum_tmp) + z = x[k] / (one(T) - sum_tmp + ϵ) lp += log(z + ϵ) + log(one(T) - z + ϵ) + log(one(T) - sum_tmp + ϵ) end end From 17eb9e31f5128890b1663a229e9f33bfde2a5670 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 13 Jan 2019 04:02:26 +1100 Subject: [PATCH 2/4] add numerical test and increase tolerance for inverse test --- test/runtests.jl | 4 +++- test/transform.jl | 19 +++++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 3302f893..66760196 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,5 @@ -using Bijectors +using Bijectors, Random + +Random.seed!(123456) include("transform.jl") diff --git a/test/transform.jl b/test/transform.jl index 826f3bce..b7f8cfdb 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -42,8 +42,17 @@ function single_sample_tests(dist) # Check that link is inverse of invlink. Hopefully this just holds given the above... y = link(dist, x) - @test link(dist, invlink(dist, copy(y))) ≈ y atol=1e-9 - + if dist isa Dirichlet + # `logit` and `logistic` are not perfect inverses. This leads to a discrepancy. + # Example: + # julia> logit(logistic(0.9999999999999998)) + # 1.0 + # julia> logistic(logit(0.9999999999999998)) + # 0.9999999999999998 + @test link(dist, invlink(dist, copy(y))) ≈ y atol=0.5 + else + @test link(dist, invlink(dist, copy(y))) ≈ y atol=1e-9 + end if dist isa SimplexDistribution # This should probably be exact.
@test logpdf(dist, x .+ ϵ) == logpdf_with_trans(dist, x, false) @@ -140,6 +149,12 @@ let ϵ = eps(Float64) logpdf_turing = logpdf_with_trans(dist, x, true) J = jacobian(x->link(dist, x, Val{false}), x) @test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing + + # Issue #12 + stepsize = 1e10 + dim = length(dist) + x = [logpdf_with_trans(dist, invlink(dist, link(dist, rand(dist)) .+ randn(dim) .* stepsize), true) for _ in 1:1_000] + @test !any(isinf, x) && !any(isnan, x) else single_sample_tests(dist, jacobian) end From 2ace38abf55a1be101f9e5ca1abc2ced80524cde Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Tue, 15 Jan 2019 12:28:06 +1100 Subject: [PATCH 3/4] make other bounded distributions stable --- src/Bijectors.jl | 59 +++++++++++++++++++++++------------------------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 04a33fc6..9bb02fc8 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -60,8 +60,12 @@ end ############# const TransformDistribution{T<:ContinuousUnivariateDistribution} = Union{T, Truncated{T}} +@inline function _clamp(x::Real, dist::TransformDistribution) + return clamp(x, minimum(dist), maximum(dist)) +end -function link(d::TransformDistribution, x::Real) +link(d::TransformDistribution, x::Real) = _link(d, _clamp(x, d)) +function _link(d::TransformDistribution, x::Real) a, b = minimum(d), maximum(d) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded @@ -75,7 +79,8 @@ function link(d::TransformDistribution, x::Real) end end -function invlink(d::TransformDistribution, y::Real) +invlink(d::TransformDistribution, y::Real) = _clamp(_invlink(d, y), d) +function _invlink(d::TransformDistribution, y::Real) a, b = minimum(d), maximum(d) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded @@ -90,6 +95,10 @@ function invlink(d::TransformDistribution, y::Real) end function logpdf_with_trans(d::TransformDistribution, x::Real, 
transform::Bool) + x = transform ? _clamp(x, d) : x + return _logpdf_with_trans(d, x, transform) +end +function _logpdf_with_trans(d::TransformDistribution, x::Real, transform::Bool) lp = logpdf(d, x) if transform a, b = minimum(d), maximum(d) @@ -105,19 +114,6 @@ function logpdf_with_trans(d::TransformDistribution, x::Real, transform::Bool) return lp end - -############### -# -∞ < x < -∞ # -############### - -const RealDistribution = Union{ - Cauchy, Gumbel, Laplace, Logistic, NoncentralT, Normal, NormalCanon, TDist, -} - -link(d::RealDistribution, x::Real) = x -invlink(d::RealDistribution, y::Real) = y -logpdf_with_trans(d::RealDistribution, y::Real, transform::Bool) = logpdf(d, y) - ######### # 0 < x # ######### @@ -127,9 +123,9 @@ const PositiveDistribution = Union{ InverseGaussian, Kolmogorov, LogNormal, NoncentralChisq, NoncentralF, Rayleigh, Weibull, } -link(d::PositiveDistribution, x::Real) = log(x) -invlink(d::PositiveDistribution, y::Real) = exp(y) -function logpdf_with_trans(d::PositiveDistribution, x::Real, transform::Bool) +_link(d::PositiveDistribution, x::Real) = log(x) +_invlink(d::PositiveDistribution, y::Real) = exp(y) +function _logpdf_with_trans(d::PositiveDistribution, x::Real, transform::Bool) return logpdf(d, x) + transform * log(x) end @@ -140,9 +136,9 @@ end const UnitDistribution = Union{Beta, KSOneSided, NoncentralBeta} -link(d::UnitDistribution, x::Real) = StatsFuns.logit(x) -invlink(d::UnitDistribution, y::Real) = StatsFuns.logistic(y) -function logpdf_with_trans(d::UnitDistribution, x::Real, transform::Bool) +_link(d::UnitDistribution, x::Real) = StatsFuns.logit(x) +_invlink(d::UnitDistribution, y::Real) = StatsFuns.logistic(y) +function _logpdf_with_trans(d::UnitDistribution, x::Real, transform::Bool) return logpdf(d, x) + transform * log(x * (one(x) - x)) end @@ -152,6 +148,9 @@ end ########### const SimplexDistribution = Union{Dirichlet} +@inline function _clamp(x::T, dist::SimplexDistribution) where T + return clamp(x, zero(T), 
one(T)) +end function link( d::SimplexDistribution, @@ -210,8 +209,6 @@ function link( return Y end -clamp0to1(x::T) where T = clamp(x, zero(T), one(T)) - function invlink( d::SimplexDistribution, y::AbstractVector{T}, @@ -221,18 +218,18 @@ function invlink( ϵ = _eps(T) z = StatsFuns.logistic(y[1] + log(one(T) / (K - 1))) - x[1] = (z - ϵ) / (one(T) - 2ϵ) |> clamp0to1 + x[1] = _clamp((z - ϵ) / (one(T) - 2ϵ), d) sum_tmp = zero(T) @inbounds for k = 2:(K - 1) z = StatsFuns.logistic(y[k] + log(one(T) / (K - k))) sum_tmp += x[k-1] - x[k] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ |> clamp0to1 + x[k] = _clamp((one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ, d) end sum_tmp += x[K - 1] if proj - x[K] = one(T) - sum_tmp |> clamp0to1 + x[K] = _clamp(one(T) - sum_tmp, d) else - x[K] = one(T) - sum_tmp - y[K] |> clamp0to1 + x[K] = _clamp(one(T) - sum_tmp - y[K], d) end return x end @@ -248,17 +245,17 @@ function invlink( ϵ = _eps(T) @inbounds for n in 1:size(X, 2) sum_tmp, z = zero(T), StatsFuns.logistic(Y[1, n] + log(one(T) / (K - 1))) - X[1, n] = (z - ϵ) / (one(T) - 2ϵ) |> clamp0to1 + X[1, n] = _clamp((z - ϵ) / (one(T) - 2ϵ), d) for k in 2:(K - 1) z = StatsFuns.logistic(Y[k, n] + log(one(T) / (K - k))) sum_tmp += X[k - 1] - X[k, n] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ |> clamp0to1 + X[k, n] = _clamp((one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ, d) end sum_tmp += X[K - 1, n] if proj - X[K, n] = one(T) - sum_tmp |> clamp0to1 + X[K, n] = _clamp(one(T) - sum_tmp, d) else - X[K, n] = one(T) - sum_tmp - Y[K, n] |> clamp0to1 + X[K, n] = _clamp(one(T) - sum_tmp - Y[K, n], d) end end From 247aace459d74f74537cc7a13b31e987e1683002 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Wed, 16 Jan 2019 08:25:35 +1100 Subject: [PATCH 4/4] add debugging info to clamp --- src/Bijectors.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 9bb02fc8..d9e54401 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl 
@@ -61,7 +61,10 @@ end const TransformDistribution{T<:ContinuousUnivariateDistribution} = Union{T, Truncated{T}} @inline function _clamp(x::Real, dist::TransformDistribution) - return clamp(x, minimum(dist), maximum(dist)) + bounds = (minimum(dist), maximum(dist)) + clamped_x = clamp(x, bounds...) + @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x" + return clamped_x end link(d::TransformDistribution, x::Real) = _link(d, _clamp(x, d)) @@ -149,7 +152,10 @@ end const SimplexDistribution = Union{Dirichlet} @inline function _clamp(x::T, dist::SimplexDistribution) where T - return clamp(x, zero(T), one(T)) + bounds = (zero(T), one(T)) + clamped_x = clamp(x, bounds...) + @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x" + return clamped_x end function link(