Skip to content

Commit

Permalink
Merge pull request #14 from TuringLang/mt/fix_12
Browse files Browse the repository at this point in the history
Fix numerical error in #12
  • Loading branch information
yebai authored Jan 15, 2019
2 parents 9bc1e3c + 247aace commit dbdac12
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 33 deletions.
65 changes: 35 additions & 30 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,15 @@ end
#############

const TransformDistribution{T<:ContinuousUnivariateDistribution} = Union{T, Truncated{T}}
"""
    _clamp(x::Real, dist::TransformDistribution)

Return `x` restricted to the closed interval `[minimum(dist), maximum(dist)]`.

The input, the bounds and the clamped result are reported through `@debug`.
"""
@inline function _clamp(x::Real, dist::TransformDistribution)
    lower = minimum(dist)
    upper = maximum(dist)
    # Kept as a tuple only so the debug message can print both bounds at once.
    bounds = (lower, upper)
    clamped_x = clamp(x, lower, upper)
    @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x"
    return clamped_x
end

function link(d::TransformDistribution, x::Real)
# Clamp `x` to the bounds of `d` with `_clamp`, then apply `_link` to the result.
function link(d::TransformDistribution, x::Real)
    x_clamped = _clamp(x, d)
    return _link(d, x_clamped)
end
function _link(d::TransformDistribution, x::Real)
a, b = minimum(d), maximum(d)
lowerbounded, upperbounded = isfinite(a), isfinite(b)
if lowerbounded && upperbounded
Expand All @@ -75,7 +82,8 @@ function link(d::TransformDistribution, x::Real)
end
end

function invlink(d::TransformDistribution, y::Real)
# Apply `_invlink` to `y`, then clamp the result to the bounds of `d` with `_clamp`.
function invlink(d::TransformDistribution, y::Real)
    x_raw = _invlink(d, y)
    return _clamp(x_raw, d)
end
function _invlink(d::TransformDistribution, y::Real)
a, b = minimum(d), maximum(d)
lowerbounded, upperbounded = isfinite(a), isfinite(b)
if lowerbounded && upperbounded
Expand All @@ -90,6 +98,10 @@ function invlink(d::TransformDistribution, y::Real)
end

"""
    logpdf_with_trans(d::TransformDistribution, x::Real, transform::Bool)

Forward to `_logpdf_with_trans(d, x, transform)`.

When `transform` is `true`, `x` is first clamped to the bounds of `d` with `_clamp`;
otherwise `x` is passed through untouched.
"""
function logpdf_with_trans(d::TransformDistribution, x::Real, transform::Bool)
    if transform
        return _logpdf_with_trans(d, _clamp(x, d), transform)
    end
    return _logpdf_with_trans(d, x, transform)
end
function _logpdf_with_trans(d::TransformDistribution, x::Real, transform::Bool)
lp = logpdf(d, x)
if transform
a, b = minimum(d), maximum(d)
Expand All @@ -105,19 +117,6 @@ function logpdf_with_trans(d::TransformDistribution, x::Real, transform::Bool)
return lp
end


###############
# -∞ < x < +∞ #
###############

# Distributions that are handled without any change of variables.
const RealDistribution = Union{
    Cauchy,
    Gumbel,
    Laplace,
    Logistic,
    NoncentralT,
    Normal,
    NormalCanon,
    TDist,
}

# Both directions of the transform are the identity for these distributions, and
# the log-density is `logpdf(d, y)` whether or not `transform` is set.
function link(d::RealDistribution, x::Real)
    return x
end
function invlink(d::RealDistribution, y::Real)
    return y
end
function logpdf_with_trans(d::RealDistribution, y::Real, transform::Bool)
    return logpdf(d, y)
end

#########
# 0 < x #
#########
Expand All @@ -127,9 +126,9 @@ const PositiveDistribution = Union{
InverseGaussian, Kolmogorov, LogNormal, NoncentralChisq, NoncentralF, Rayleigh, Weibull,
}

# `link` takes the logarithm of `x`; `invlink` undoes it with `exp`.
function link(d::PositiveDistribution, x::Real)
    return log(x)
end
function invlink(d::PositiveDistribution, y::Real)
    return exp(y)
end
function logpdf_with_trans(d::PositiveDistribution, x::Real, transform::Bool)
# `_link` takes the logarithm of `x`; `_invlink` undoes it with `exp`.
function _link(d::PositiveDistribution, x::Real)
    return log(x)
end
function _invlink(d::PositiveDistribution, y::Real)
    return exp(y)
end
# Return `logpdf(d, x)`, plus `log(x)` when `transform` is `true`.
function _logpdf_with_trans(d::PositiveDistribution, x::Real, transform::Bool)
    lp = logpdf(d, x)
    # Multiplying by the `Bool` keeps the term when `transform` is `true` and
    # zeroes it otherwise, without branching on the flag.
    correction = transform * log(x)
    return lp + correction
end

Expand All @@ -140,9 +139,9 @@ end

const UnitDistribution = Union{Beta, KSOneSided, NoncentralBeta}

# `link` applies `StatsFuns.logit`; `invlink` applies `StatsFuns.logistic`.
function link(d::UnitDistribution, x::Real)
    return StatsFuns.logit(x)
end
function invlink(d::UnitDistribution, y::Real)
    return StatsFuns.logistic(y)
end
function logpdf_with_trans(d::UnitDistribution, x::Real, transform::Bool)
# `_link` applies `StatsFuns.logit`; `_invlink` applies `StatsFuns.logistic`.
function _link(d::UnitDistribution, x::Real)
    return StatsFuns.logit(x)
end
function _invlink(d::UnitDistribution, y::Real)
    return StatsFuns.logistic(y)
end
# Return `logpdf(d, x)`, plus `log(x * (1 - x))` when `transform` is `true`.
function _logpdf_with_trans(d::UnitDistribution, x::Real, transform::Bool)
    lp = logpdf(d, x)
    complement = one(x) - x
    # Multiplying by the `Bool` keeps the term when `transform` is `true` and
    # zeroes it otherwise, without branching on the flag.
    correction = transform * log(x * complement)
    return lp + correction
end

Expand All @@ -152,6 +151,12 @@ end
###########

const SimplexDistribution = Union{Dirichlet}
"""
    _clamp(x::T, dist::SimplexDistribution) where T

Return `x` restricted to the closed interval `[zero(T), one(T)]`.

`dist` is used for dispatch only. The input, the bounds and the clamped result are
reported through `@debug`.
"""
@inline function _clamp(x::T, dist::SimplexDistribution) where T
    lower, upper = zero(T), one(T)
    # Kept as a tuple only so the debug message can print both bounds at once.
    bounds = (lower, upper)
    clamped_x = clamp(x, lower, upper)
    @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x"
    return clamped_x
end

function link(
d::SimplexDistribution,
Expand Down Expand Up @@ -219,18 +224,18 @@ function invlink(

ϵ = _eps(T)
z = StatsFuns.logistic(y[1] + log(one(T) / (K - 1)))
x[1] = (z - ϵ) / (one(T) - 2ϵ)
x[1] = _clamp((z - ϵ) / (one(T) - 2ϵ), d)
sum_tmp = zero(T)
@inbounds for k = 2:(K - 1)
z = StatsFuns.logistic(y[k] + log(one(T) / (K - k)))
sum_tmp += x[k-1]
x[k] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ
x[k] = _clamp((one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ, d)
end
sum_tmp += x[K - 1]
if proj
x[K] = one(T) - sum_tmp
x[K] = _clamp(one(T) - sum_tmp, d)
else
x[K] = one(T) - sum_tmp - y[K]
x[K] = _clamp(one(T) - sum_tmp - y[K], d)
end
return x
end
Expand All @@ -246,17 +251,17 @@ function invlink(
ϵ = _eps(T)
@inbounds for n in 1:size(X, 2)
sum_tmp, z = zero(T), StatsFuns.logistic(Y[1, n] + log(one(T) / (K - 1)))
X[1, n] = (z - ϵ) / (one(T) - 2ϵ)
X[1, n] = _clamp((z - ϵ) / (one(T) - 2ϵ), d)
for k in 2:(K - 1)
z = StatsFuns.logistic(Y[k, n] + log(one(T) / (K - k)))
sum_tmp += X[k - 1]
X[k, n] = (one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ
X[k, n] = _clamp((one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ, d)
end
sum_tmp += X[K - 1, n]
if proj
X[K, n] = one(T) - sum_tmp
X[K, n] = _clamp(one(T) - sum_tmp, d)
else
X[K, n] = one(T) - sum_tmp - Y[K, n]
X[K, n] = _clamp(one(T) - sum_tmp - Y[K, n], d)
end
end

Expand All @@ -279,7 +284,7 @@ function logpdf_with_trans(
lp += log(z + ϵ) + log(one(T) - z + ϵ)
@inbounds for k in 2:(K - 1)
sum_tmp += x[k-1]
z = x[k] / (one(T) - sum_tmp)
z = x[k] / (one(T) - sum_tmp + ϵ)
lp += log(z + ϵ) + log(one(T) - z + ϵ) + log(one(T) - sum_tmp + ϵ)
end
end
Expand Down
4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Bijectors
using Bijectors, Random

Random.seed!(123456)

include("transform.jl")
19 changes: 17 additions & 2 deletions test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,17 @@ function single_sample_tests(dist)

# Check that link is inverse of invlink. Hopefully this just holds given the above...
y = link(dist, x)
@test link(dist, invlink(dist, copy(y))) ≈ y atol=1e-9

if dist isa Dirichlet
# `logit` and `logistic` are not perfect inverses in floating point, so the round trip can drift.
# Example:
# julia> logit(logistic(0.9999999999999998))
# 1.0
# julia> logistic(logit(0.9999999999999998))
# 0.9999999999999998
@test link(dist, invlink(dist, copy(y))) ≈ y atol=0.5
else
@test link(dist, invlink(dist, copy(y))) ≈ y atol=1e-9
end
if dist isa SimplexDistribution
# This should probably be exact.
@test logpdf(dist, x .+ ϵ) == logpdf_with_trans(dist, x, false)
Expand Down Expand Up @@ -140,6 +149,12 @@ let ϵ = eps(Float64)
logpdf_turing = logpdf_with_trans(dist, x, true)
J = jacobian(x->link(dist, x, Val{false}), x)
@test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing

# Issue #12
stepsize = 1e10
dim = length(dist)
x = [logpdf_with_trans(dist, invlink(dist, link(dist, rand(dist)) .+ randn(dim) .* stepsize), true) for _ in 1:1_000]
@test !any(isinf, x) && !any(isnan, x)
else
single_sample_tests(dist, jacobian)
end
Expand Down

0 comments on commit dbdac12

Please sign in to comment.