diff --git a/Project.toml b/Project.toml index 97a65aff..84b1b77a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.15.2" +version = "0.15.3" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index d468bbe9..dc8b6211 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -68,6 +68,28 @@ end with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x) +function truncated_inv_logabsdetjac(y, a, b) + y, a, b = promote(y, a, b) + lowerbounded, upperbounded = isfinite(a), isfinite(b) + if lowerbounded && upperbounded + abs_y = abs(y) + return log(b - a) - abs_y - 2 * LogExpFunctions.log1pexp(-abs_y) + elseif lowerbounded || upperbounded + return y + else + return zero(y) + end +end + +function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y) + a, b = ib.orig.lb, ib.orig.ub + return sum(truncated_inv_logabsdetjac.(y, a, b)) +end + +function with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y) + return transform(ib, y), logabsdetjac(ib, y) +end + # It's only monotonically decreasing if it's only upper-bounded. # In the multivariate case, we can only say something reasonable if entries are monotonic. function is_monotonically_increasing(b::TruncatedBijector) diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl index 60354005..5f03ea41 100644 --- a/test/bijectors/ordered.jl +++ b/test/bijectors/ordered.jl @@ -78,7 +78,7 @@ end @testset "correctness" begin num_samples = 10_000 - num_adapts = 1_000 + num_adapts = 5_000 @testset "k = $k" for k in [2, 3, 5] @testset "$(typeof(dist))" for dist in [ # Unconstrained diff --git a/test/interface.jl b/test/interface.jl index 44a73878..f5ac7457 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -96,6 +96,21 @@ contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs)) logabsdetjac(inverse(b), y) atol = 1e-6 end end + + @testset "logabsdetjac numerical stability: Bijectors.jl#325" begin + d = Uniform(-1, 1) + b = bijector(d) + y = 80 + # x needs higher precision to be calculated correctly, otherwise + # logpdf_with_trans returns -Inf + d_big = Uniform(big(-1.0), big(1.0)) + b_big = bijector(d_big) + x_big = inverse(b_big)(big(y)) + @test logpdf(d_big, x_big) + logabsdetjacinv(b, y) ≈ + logpdf_with_trans(d_big, x_big, true) atol = 1e-14 + @test logpdf(d_big, x_big) - logabsdetjac(b, x_big) ≈ + logpdf_with_trans(d_big, x_big, true) atol = 1e-14 + end end @testset "Truncated" begin