diff --git a/Project.toml b/Project.toml index 726f47ce..4717f1af 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.8.8" +version = "0.8.9" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" @@ -9,10 +9,10 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" -Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" @@ -23,8 +23,8 @@ Compat = "3" Distributions = "0.23.3, 0.24" MappedArrays = "0.2.2, 0.3" NNlib = "0.6, 0.7" +NonlinearSolve = "0.3" Reexport = "0.2" Requires = "0.5, 1" -Roots = "0.8.4, 1" StatsFuns = "0.8, 0.9.3" julia = "1.3" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index e435ebf7..a37e3ae2 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -33,9 +33,9 @@ using Reexport, Requires using StatsFuns using LinearAlgebra using MappedArrays -using Roots using Base.Iterators: drop using LinearAlgebra: AbstractTriangular +import NonlinearSolve export TransformDistribution, PositiveDistribution, diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index fbb24f41..8332c08d 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -1,7 +1,6 @@ using LinearAlgebra using Random using NNlib: softplus -using Roots # for inverse ################################################################################ # Planar and Radial Flows # @@ -74,7 +73,7 @@ function 
(ib::Inverse{<:PlanarLayer})(y::AbstractVector{<:Real}) # Find the scalar ``alpha`` from A.1. wt_y = dot(w, y) wt_u_hat = dot(w, u_hat) - alpha = find_alpha(y, wt_y, wt_u_hat, b) + alpha = find_alpha(wt_y, wt_u_hat, b) return y .- u_hat .* tanh(alpha * norm(w, 2) + b) end @@ -88,14 +87,14 @@ function (ib::Inverse{<:PlanarLayer})(y::AbstractMatrix{<:Real}) # Find the scalar ``alpha`` from A.1 for each column. wt_u_hat = dot(w, u_hat) alphas = mapvcat(eachcol(y)) do c - find_alpha(c, dot(w, c), wt_u_hat, b) + find_alpha(dot(w, c), wt_u_hat, b) end - return y .- u_hat .* tanh.(alphas' .* norm(w, 2) .+ b) + return y .- u_hat .* tanh.(reshape(alphas, 1, :) .* norm(w, 2) .+ b) end """ - find_alpha(y::AbstractVector{<:Real}, wt_y, wt_u_hat, b) + find_alpha(wt_y, wt_u_hat, b) Compute an (approximate) real-valued solution ``α`` to the equation ```math @@ -110,7 +109,7 @@ For details see appendix A.1 of the reference. D. Rezende, S. Mohamed (2015): Variational Inference with Normalizing Flows. arXiv:1505.05770 """ -function find_alpha(y::AbstractVector{<:Real}, wt_y, wt_u_hat, b) +function find_alpha(wt_y, wt_u_hat, b) # Compute the initial bracket ((-Inf, 0) or (0, Inf)) f0 = wt_u_hat * tanh(b) - wt_y zero_f0 = zero(f0) @@ -119,10 +118,10 @@ function find_alpha(y::AbstractVector{<:Real}, wt_y, wt_u_hat, b) else initial_bracket = (oftype(f0, -Inf), zero_f0) end - alpha = find_zero(initial_bracket) do x + prob = NonlinearSolve.NonlinearProblem{false}(initial_bracket) do x, _ x + wt_u_hat * tanh(x + b) - wt_y end - + alpha = NonlinearSolve.solve(prob, NonlinearSolve.Falsi()).left return alpha end diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index b29ac8d2..ad3e3d57 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -1,7 +1,6 @@ using LinearAlgebra using Random using NNlib: softplus -using Roots # for inverse ################################################################################ # Planar and Radial Flows # 
@@ -77,8 +76,9 @@ function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) # Compute the norm ``r`` from A.2. y_minus_z0 = y .- z0 r = compute_r(y_minus_z0, α, α_plus_β_hat) + γ = (α + r) / (α_plus_β_hat + r) - return z0 .+ ((α + r) / (α_plus_β_hat + r)) .* y_minus_z0 + return z0 .+ γ .* y_minus_z0 end function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real}) @@ -92,8 +92,9 @@ function (ib::Inverse{<:RadialLayer})(y::AbstractMatrix{<:Real}) rs = mapvcat(eachcol(y_minus_z0)) do c return compute_r(c, α, α_plus_β_hat) end + γ = reshape((α .+ rs) ./ (α_plus_β_hat .+ rs), 1, :) - return z0 .+ ((α .+ rs) ./ (α_plus_β_hat .+ rs))' .* y_minus_z0 + return z0 .+ γ .* y_minus_z0 end """ @@ -124,4 +125,3 @@ function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) end logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac -isclosedform(b::Inverse{<:RadialLayer}) = false diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 506d95d7..52c20b1a 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -85,8 +85,8 @@ DistSpec(Poisson, (0.5,), 1), DistSpec(Poisson, (0.5,), [1, 1]), - DistSpec(Skellam, (1.0, 2.0), -2), - DistSpec(Skellam, (1.0, 2.0), [-2, -2]), + DistSpec(Skellam, (1.0, 2.0), -2; broken=(:Zygote,)), + DistSpec(Skellam, (1.0, 2.0), [-2, -2]; broken=(:Zygote,)), DistSpec(PoissonBinomial, ([0.5, 0.5],), 0), DistSpec(PoissonBinomial, ([0.5, 0.5],), [0, 0]), @@ -194,9 +194,8 @@ DistSpec(NormalCanon, (1.0, 2.0), 0.5), - DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5), + DistSpec(NormalInverseGaussian, (1.0, 2.0, 1.0, 1.0), 0.5; broken=(:Zygote,)), - DistSpec(Pareto, (), 1.5), DistSpec(Pareto, (1.0,), 1.5), DistSpec(Pareto, (1.0, 1.0), 1.5), @@ -246,6 +245,9 @@ # Stackoverflow caused by SpecialFunctions.besselix DistSpec(VonMises, (1.0,), 1.0), DistSpec(VonMises, (1, 1), 1), + + # Only some Zygote tests are broken and therefore this cannot be checked + DistSpec(Pareto, (), 
1.5; broken=(:Zygote,)), ] # Tests that have a `broken` field can be executed but, according to FiniteDifferences, @@ -430,7 +432,7 @@ # Skellam only fails in these tests with ReverseDiff # Ref: https://github.com/TuringLang/DistributionsAD.jl/issues/126 - filldist_broken = d.f(d.θ...) isa Skellam ? (:ReverseDiff,) : d.broken + filldist_broken = d.f(d.θ...) isa Skellam ? (d.broken..., :ReverseDiff) : d.broken arraydist_broken = d.broken # Create `filldist` distribution diff --git a/test/interface.jl b/test/interface.jl index cae79ee8..7a382d97 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -176,25 +176,11 @@ end y = @inferred b(x) ys = @inferred b(xs) - - # Computations which do not have closed-form implementations are not necessarily - # differentiable, and so we skip them. - # HACK: In reality, this is just circumventing the fact that Tracker isn't happy with - # the `find_zero` function used by some bijectors. - if isclosedform(b) - @inferred(b(param(xs))) - end + @inferred(b(param(xs))) x_ = @inferred ib(y) xs_ = @inferred ib(ys) - - # Computations which do not have closed-form implementations are not necessarily - # differentiable, and so we skip them. - # HACK: In reality, this is just circumventing the fact that Tracker isn't happy with - # the `find_zero` function used by some bijectors. 
- if isclosedform(ib) - @inferred(ib(param(ys))) - end + @inferred(ib(param(ys))) result = @inferred forward(b, x) results = @inferred forward(b, xs) @@ -226,8 +212,8 @@ end @test length(logabsdetjac(b, xs)) == length(xs) @test length(logabsdetjac(ib, ys)) == length(xs) - isclosedform(b) && @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} - isclosedform(ib) && @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} + @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} + @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} @test size(results.logabsdetjac) == size(xs, ) @test size(iresults.logabsdetjac) == size(ys, ) @@ -240,8 +226,8 @@ end @test logabsdetjac.(ib, ys) == @inferred(logabsdetjac(ib, ys)) @test @inferred(logabsdetjac(ib, ys)) ≈ ib_logjac_ad atol=1e-9 - isclosedform(b) && @test logabsdetjac.(b, param(xs)) == @inferred(logabsdetjac(b, param(xs))) - isclosedform(ib) && @test logabsdetjac.(ib, param(ys)) == @inferred(logabsdetjac(ib, param(ys))) + @test logabsdetjac.(b, param(xs)) == @inferred(logabsdetjac(b, param(xs))) + @test logabsdetjac.(ib, param(ys)) == @inferred(logabsdetjac(ib, param(ys))) @test results.logabsdetjac ≈ vec(logabsdetjac.(b, xs)) @test iresults.logabsdetjac ≈ vec(logabsdetjac.(ib, ys)) @@ -253,8 +239,8 @@ end @test size(logabsdetjac(b, xs)) == (size(xs, 2), ) @test size(logabsdetjac(ib, ys)) == (size(xs, 2), ) - isclosedform(b) && @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} - isclosedform(ib) && @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} + @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} + @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} @test size(results.logabsdetjac) == (size(xs, 2), ) @test size(iresults.logabsdetjac) == (size(ys, 2), ) @@ -266,16 +252,15 @@ end @test results.logabsdetjac ≈ vec(mapslices(z -> logabsdetjac(b, 
z), xs; dims = 1)) @test iresults.logabsdetjac ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) - # some have issues with numerically solving the inverse # FIXME: `SimplexBijector` results in ∞ gradient if not in the domain - if isclosedform(b) && !contains(t -> t isa SimplexBijector, b) + if !contains(t -> t isa SimplexBijector, b) b_logjac_ad = [logabsdet(ForwardDiff.jacobian(b, xs[:, i]))[1] for i = 1:size(xs, 2)] - @test logabsdetjac(b, xs) ≈ b_logjac_ad atol=1e-9 - end + tol = isclosedform(b) ? 1e-9 : 1e-1 + @test logabsdetjac(b, xs) ≈ b_logjac_ad rtol=tol atol=tol - if isclosedform(inv(b)) && !contains(t -> t isa SimplexBijector, b) ib_logjac_ad = [logabsdet(ForwardDiff.jacobian(ib, ys[:, i]))[1] for i = 1:size(ys, 2)] - @test logabsdetjac(ib, ys) ≈ ib_logjac_ad atol=1e-9 + tol = isclosedform(ib) ? 1e-9 : 1e-1 + @test logabsdetjac(ib, ys) ≈ ib_logjac_ad rtol=tol atol=tol end else error("tests not implemented yet")