Skip to content

Commit

Permalink
Merge pull request #19 from TuringLang/mt/perf
Browse files Browse the repository at this point in the history
Low hanging performance
  • Loading branch information
mohamed82008 authored Feb 23, 2019
2 parents 3e7b8f8 + c3454de commit d72cb6d
Showing 1 changed file with 36 additions and 34 deletions.
70 changes: 36 additions & 34 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ export TransformDistribution,
invlink,
logpdf_with_trans

const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_BIJECTORS", "0")))

_eps(::Type{T}) where {T} = eps(T)
_eps(::Type{Real}) = eps(Float64)
function __init__()
Expand Down Expand Up @@ -60,11 +62,11 @@ end
#############

const TransformDistribution{T<:ContinuousUnivariateDistribution} = Union{T, Truncated{T}}
@inline function _clamp(x::Real, dist::TransformDistribution)
function _clamp(x::Real, dist::TransformDistribution)
ϵ = eps(x)
bounds = (minimum(dist)+ϵ, maximum(dist)-ϵ)
clamped_x = clamp(x, bounds...)
@debug "x = $x, bounds = $bounds, clamped_x = $clamped_x"
DEBUG && @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x"
return clamped_x
end

Expand Down Expand Up @@ -152,33 +154,33 @@ end
###########

const SimplexDistribution = Union{Dirichlet}
@inline function _clamp(x::T, dist::SimplexDistribution) where T
function _clamp(x::T, dist::SimplexDistribution) where T
bounds = (zero(T), one(T))
clamped_x = clamp(x, bounds...)
@debug "x = $x, bounds = $bounds, clamped_x = $clamped_x"
DEBUG && @debug "x = $x, bounds = $bounds, clamped_x = $clamped_x"
return clamped_x
end

function link(
d::SimplexDistribution,
x::AbstractVector{T},
d::SimplexDistribution,
x::AbstractVector{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
y, K = similar(x), length(x)

ϵ = _eps(T)
sum_tmp = zero(T)
z = x[1] * (one(T) - 2ϵ) + ϵ # z ∈ [ϵ, 1-ϵ]
y[1] = StatsFuns.logit(z) - log(one(T) / (K - 1))
@inbounds for k in 2:(K - 1)
@inbounds z = x[1] * (one(T) - 2ϵ) + ϵ # z ∈ [ϵ, 1-ϵ]
@inbounds y[1] = StatsFuns.logit(z) - log(one(T) / (K - 1))
@inbounds @simd for k in 2:(K - 1)
sum_tmp += x[k - 1]
# z ∈ [ϵ, 1-ϵ]
# x[k] = 0 && sum_tmp = 1 -> z ≈ 1
z = (x[k] + ϵ)*(one(T) - 2ϵ)/(one(T) - sum_tmp + ϵ)
y[k] = StatsFuns.logit(z) - log(one(T) / (K - k))
end
sum_tmp += x[K - 1]
if proj
@inbounds sum_tmp += x[K - 1]
@inbounds if proj
y[K] = zero(T)
else
y[K] = one(T) - sum_tmp - x[K]
Expand All @@ -189,14 +191,14 @@ end

# Vectorised implementation of the above.
function link(
d::SimplexDistribution,
X::AbstractMatrix{T},
d::SimplexDistribution,
X::AbstractMatrix{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
Y, K, N = similar(X), size(X, 1), size(X, 2)

ϵ = _eps(T)
@inbounds for n in 1:size(X, 2)
@inbounds @simd for n in 1:size(X, 2)
sum_tmp = zero(T)
z = X[1, n] * (one(T) - 2ϵ) + ϵ
Y[1, n] = StatsFuns.logit(z) - log(one(T) / (K - 1))
Expand All @@ -217,23 +219,23 @@ function link(
end

function invlink(
d::SimplexDistribution,
y::AbstractVector{T},
d::SimplexDistribution,
y::AbstractVector{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
x, K = similar(y), length(y)

ϵ = _eps(T)
z = StatsFuns.logistic(y[1] + log(one(T) / (K - 1)))
x[1] = _clamp((z - ϵ) / (one(T) - 2ϵ), d)
@inbounds z = StatsFuns.logistic(y[1] + log(one(T) / (K - 1)))
@inbounds x[1] = _clamp((z - ϵ) / (one(T) - 2ϵ), d)
sum_tmp = zero(T)
@inbounds for k = 2:(K - 1)
@inbounds @simd for k = 2:(K - 1)
z = StatsFuns.logistic(y[k] + log(one(T) / (K - k)))
sum_tmp += x[k-1]
x[k] = _clamp((one(T) - sum_tmp + ϵ) / (one(T) - 2ϵ) * z - ϵ, d)
x[k] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, d)
end
sum_tmp += x[K - 1]
if proj
@inbounds sum_tmp += x[K - 1]
@inbounds if proj
x[K] = _clamp(one(T) - sum_tmp, d)
else
x[K] = _clamp(one(T) - sum_tmp - y[K], d)
Expand All @@ -243,14 +245,14 @@ end

# Vectorised implementation of the above.
function invlink(
d::SimplexDistribution,
Y::AbstractMatrix{T},
d::SimplexDistribution,
Y::AbstractMatrix{T},
::Type{Val{proj}} = Val{true}
) where {T<:Real, proj}
X, K, N = similar(Y), size(Y, 1), size(Y, 2)

ϵ = _eps(T)
@inbounds for n in 1:size(X, 2)
@inbounds @simd for n in 1:size(X, 2)
sum_tmp, z = zero(T), StatsFuns.logistic(Y[1, n] + log(one(T) / (K - 1)))
X[1, n] = _clamp((z - ϵ) / (one(T) - 2ϵ), d)
for k in 2:(K - 1)
Expand Down Expand Up @@ -281,9 +283,9 @@ function logpdf_with_trans(
K = length(x)

sum_tmp = zero(eltype(x))
z = x[1]
@inbounds z = x[1]
lp += log(z + ϵ) + log(one(T) - z + ϵ)
@inbounds for k in 2:(K - 1)
@inbounds @simd for k in 2:(K - 1)
sum_tmp += x[k-1]
z = x[k] / (one(T) - sum_tmp + ϵ)
lp += log(z + ϵ) + log(one(T) - z + ϵ) + log(one(T) - sum_tmp + ϵ)
Expand Down Expand Up @@ -322,30 +324,30 @@ const PDMatDistribution = Union{InverseWishart, Wishart}

function link(d::PDMatDistribution, X::AbstractMatrix{T}) where {T<:Real}
Y = cholesky(X).L
for m in 1:size(Y, 1)
@inbounds @simd for m in 1:size(Y, 1)
Y[m, m] = log(Y[m, m])
end
return Matrix(Y)
end

function invlink(d::PDMatDistribution, Y::AbstractMatrix{T}) where {T<:Real}
X, dim = copy(Y), size(Y)
for m in 1:size(X, 1)
@inbounds @simd for m in 1:size(X, 1)
X[m, m] = exp(X[m, m])
end
Z = similar(X)
return mul!(Z, LowerTriangular(X), LowerTriangular(X)')
end

function logpdf_with_trans(
d::PDMatDistribution,
X::AbstractMatrix{<:Real},
d::PDMatDistribution,
X::AbstractMatrix{<:Real},
transform::Bool
)
lp = logpdf(d, X)
if transform && isfinite(lp)
U = cholesky(X).U
for i in 1:dim(d)
@inbounds @simd for i in 1:dim(d)
lp += (dim(d) - i + 2) * log(U[i, i])
end
lp += dim(d) * log(2.0)
Expand Down Expand Up @@ -382,7 +384,7 @@ using Distributions: MultivariateDistribution
link(d::MultivariateDistribution, x::AbstractVector{<:Real}) = copy(x)
function link(d::MultivariateDistribution, X::AbstractMatrix{<:Real})
Y = similar(X)
for n in 1:size(X, 2)
@inbounds @simd for n in 1:size(X, 2)
Y[:, n] = link(d, view(X, :, n))
end
return Y
Expand All @@ -391,7 +393,7 @@ end
invlink(d::MultivariateDistribution, y::AbstractVector{<:Real}) = copy(y)
function invlink(d::MultivariateDistribution, Y::AbstractMatrix{<:Real})
X = similar(Y)
for n in 1:size(Y, 2)
@inbounds @simd for n in 1:size(Y, 2)
X[:, n] = invlink(d, view(Y, :, n))
end
return X
Expand Down

0 comments on commit d72cb6d

Please sign in to comment.