From 4fcda273c3a317677aa1bc96434597a39b5f459a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 14 Jun 2021 09:17:06 +0100 Subject: [PATCH] Default `bijector` implementation for `TransformedDistribution` (#187) * added default bijector impl for TransformedDistribution * patch-version bump * added a couple of transformed dists to test it * also test multivariate transformed * fix interface tests * same as previous commit but for multivariate distributions --- Project.toml | 2 +- src/transformed_distribution.jl | 1 + test/interface.jl | 16 ++++++++++++---- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 3a2dddf5..3fca586d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.9.5" +version = "0.9.6" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 6a30f7eb..1712ba2e 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -35,6 +35,7 @@ transformed(d) = transformed(d, bijector(d)) Returns the constrained-to-unconstrained bijector for distribution `d`. """ +bijector(td::TransformedDistribution) = bijector(td.dist) ∘ inv(td.transform) bijector(d::DiscreteUnivariateDistribution) = Identity{0}() bijector(d::DiscreteMultivariateDistribution) = Identity{1}() bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d)) diff --git a/test/interface.jl b/test/interface.jl index 7aad5176..102fe23b 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -63,6 +63,8 @@ end Rayleigh(1.0), TDist(2), truncated(Normal(0, 1), -Inf, 2), + transformed(Beta(2,2)), + transformed(Exponential()), ] for dist in uni_dists @@ -106,8 +108,10 @@ end b = bijector(d) x = rand(d) y = b(x) - @test log(abs(ForwardDiff.derivative(b, x))) ≈ logabsdetjac(b, x) - @test log(abs(ForwardDiff.derivative(inv(b), y))) ≈ logabsdetjac(inv(b), y) + # `ForwardDiff.derivative` can lead to some numerical inaccuracy, + # so we use a slightly higher `atol` than default. + @test log(abs(ForwardDiff.derivative(b, x))) ≈ logabsdetjac(b, x) atol=1e-6 + @test log(abs(ForwardDiff.derivative(inv(b), y))) ≈ logabsdetjac(inv(b), y) atol=1e-6 end @testset "$dist: ForwardDiff AD" begin @@ -401,6 +405,8 @@ end MvLogNormal(MvNormal(randn(10), exp.(randn(10)))), Dirichlet([1000 * one(Float64), eps(Float64)]), Dirichlet([eps(Float64), 1000 * one(Float64)]), + transformed(MvNormal(randn(10), exp.(randn(10)))), + transformed(MvLogNormal(MvNormal(randn(10), exp.(randn(10))))) ] for dist in vector_dists @@ -446,9 +452,11 @@ end b = bijector(dist) x = rand(dist) y = b(x) + # `ForwardDiff.derivative` can lead to some numerical inaccuracy, + # so we use a slightly higher `atol` than default. @test b(param(x)) isa TrackedArray - @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) - @test log(abs(det(ForwardDiff.jacobian(inv(b), y)))) ≈ logabsdetjac(inv(b), y) + @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) atol=1e-6 + @test log(abs(det(ForwardDiff.jacobian(inv(b), y)))) ≈ logabsdetjac(inv(b), y) atol=1e-6 end end end