Skip to content

Commit

Permalink
Merge pull request #20 from TuringLang/mt/perf
Browse files Browse the repository at this point in the history
Squeeze some more performance out
  • Loading branch information
mohamed82008 authored Feb 25, 2019
2 parents d72cb6d + c339f74 commit 9ccc4af
Showing 1 changed file with 21 additions and 32 deletions.
53 changes: 21 additions & 32 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ end
#############

const TransformDistribution{T<:ContinuousUnivariateDistribution} = Union{T, Truncated{T}}
function _clamp(x::Real, dist::TransformDistribution)
@inline function _clamp(x::Real, dist::TransformDistribution)
ϵ = eps(x)
bounds = (minimum(dist)+ϵ, maximum(dist)-ϵ)
clamped_x = clamp(x, bounds...)
bounds = (minimum(dist) + ϵ, maximum(dist) - ϵ)
clamped_x = ifelse(x < bounds[1], bounds[1], ifelse(x > bounds[2], bounds[2], x))
DEBUG && @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x"
return clamped_x
end
Expand Down Expand Up @@ -171,13 +171,13 @@ function link(
ϵ = _eps(T)
sum_tmp = zero(T)
@inbounds z = x[1] * (one(T) - 2ϵ) + ϵ # z ∈ [ϵ, 1-ϵ]
@inbounds y[1] = StatsFuns.logit(z) - log(one(T) / (K - 1))
@inbounds y[1] = StatsFuns.logit(z) + log(T(K - 1))
@inbounds @simd for k in 2:(K - 1)
sum_tmp += x[k - 1]
# z ∈ [ϵ, 1-ϵ]
# x[k] = 0 && sum_tmp = 1 -> z ≈ 1
z = (x[k] + ϵ)*(one(T) - 2ϵ)/(one(T) - sum_tmp + ϵ)
y[k] = StatsFuns.logit(z) - log(one(T) / (K - k))
z = (x[k] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp)
y[k] = StatsFuns.logit(z) + log(T(K - k))
end
@inbounds sum_tmp += x[K - 1]
@inbounds if proj
Expand All @@ -201,11 +201,11 @@ function link(
@inbounds @simd for n in 1:size(X, 2)
sum_tmp = zero(T)
z = X[1, n] * (one(T) - 2ϵ) + ϵ
Y[1, n] = StatsFuns.logit(z) - log(one(T) / (K - 1))
Y[1, n] = StatsFuns.logit(z) + log(T(K - 1))
for k in 2:(K - 1)
sum_tmp += X[k - 1, n]
z = (X[k, n] + ϵ)*(one(T) - 2ϵ)/(one(T) - sum_tmp + ϵ)
Y[k, n] = StatsFuns.logit(z) - log(one(T) / (K - k))
z = (X[k, n] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp)
Y[k, n] = StatsFuns.logit(z) + log(T(K - k))
end
sum_tmp += X[K-1, n]
if proj
Expand All @@ -226,11 +226,11 @@ function invlink(
x, K = similar(y), length(y)

ϵ = _eps(T)
@inbounds z = StatsFuns.logistic(y[1] + log(one(T) / (K - 1)))
@inbounds z = StatsFuns.logistic(y[1] - log(T(K - 1)))
@inbounds x[1] = _clamp((z - ϵ) / (one(T) - 2ϵ), d)
sum_tmp = zero(T)
@inbounds @simd for k = 2:(K - 1)
z = StatsFuns.logistic(y[k] + log(one(T) / (K - k)))
z = StatsFuns.logistic(y[k] - log(T(K - k)))
sum_tmp += x[k-1]
x[k] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, d)
end
Expand All @@ -253,12 +253,12 @@ function invlink(

ϵ = _eps(T)
@inbounds @simd for n in 1:size(X, 2)
sum_tmp, z = zero(T), StatsFuns.logistic(Y[1, n] + log(one(T) / (K - 1)))
sum_tmp, z = zero(T), StatsFuns.logistic(Y[1, n] - log(T(K - 1)))
X[1, n] = _clamp((z - ϵ) / (one(T) - 2ϵ), d)
for k in 2:(K - 1)
z = StatsFuns.logistic(Y[k, n] + log(one(T) / (K - k)))
z = StatsFuns.logistic(Y[k, n] - log(T(K - k)))
sum_tmp += X[k - 1]
X[k, n] = _clamp((one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ, d)
X[k, n] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, d)
end
sum_tmp += X[K - 1, n]
if proj
Expand All @@ -284,11 +284,11 @@ function logpdf_with_trans(

sum_tmp = zero(eltype(x))
@inbounds z = x[1]
lp += log(z + ϵ) + log(one(T) - z + ϵ)
lp += log(z + ϵ) + log((one(T) + ϵ) - z)
@inbounds @simd for k in 2:(K - 1)
sum_tmp += x[k-1]
z = x[k] / (one(T) - sum_tmp + ϵ)
lp += log(z + ϵ) + log(one(T) - z + ϵ) + log(one(T) - sum_tmp + ϵ)
z = x[k] / ((one(T) + ϵ) - sum_tmp)
lp += log(z + ϵ) + log((one(T) + ϵ) - z) + log((one(T) + ϵ) - sum_tmp)
end
end
return lp
Expand Down Expand Up @@ -344,13 +344,14 @@ function logpdf_with_trans(
X::AbstractMatrix{<:Real},
transform::Bool
)
T = eltype(X)
lp = logpdf(d, X)
if transform && isfinite(lp)
U = cholesky(X).U
@inbounds @simd for i in 1:dim(d)
lp += (dim(d) - i + 2) * log(U[i, i])
end
lp += dim(d) * log(2.0)
lp += dim(d) * log(T(2))
end
return lp
end
Expand Down Expand Up @@ -382,22 +383,10 @@ end
using Distributions: MultivariateDistribution

link(d::MultivariateDistribution, x::AbstractVector{<:Real}) = copy(x)
function link(d::MultivariateDistribution, X::AbstractMatrix{<:Real})
Y = similar(X)
@inbounds @simd for n in 1:size(X, 2)
Y[:, n] = link(d, view(X, :, n))
end
return Y
end
link(d::MultivariateDistribution, X::AbstractMatrix{<:Real}) = copy(X)

invlink(d::MultivariateDistribution, y::AbstractVector{<:Real}) = copy(y)
function invlink(d::MultivariateDistribution, Y::AbstractMatrix{<:Real})
X = similar(Y)
@inbounds @simd for n in 1:size(Y, 2)
X[:, n] = invlink(d, view(Y, :, n))
end
return X
end
invlink(d::MultivariateDistribution, Y::AbstractMatrix{<:Real}) = copy(Y)

function logpdf_with_trans(d::MultivariateDistribution, x::AbstractVector{<:Real}, ::Bool)
return logpdf(d, x)
Expand Down

0 comments on commit 9ccc4af

Please sign in to comment.