From 213328d160be66f9102d2738a45e3b72fa62b2b4 Mon Sep 17 00:00:00 2001 From: acertain Date: Sat, 17 Aug 2024 16:49:31 -0600 Subject: [PATCH 01/11] implement logabsdetjac for Inverse{<:TruncatedBijector} for better numerical stability Example of previous badness: logabsdetjac(inverse(bijector(Uniform(-1,1))), 80) = -Inf (is now -79.30685281944005) --- src/bijectors/truncated.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index d468bbe9..9517807e 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -68,6 +68,25 @@ end with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x) +function truncated_inv_logabsdetjac(y, a, b) + lowerbounded, upperbounded = isfinite(a), isfinite(b) + if lowerbounded && upperbounded + abs_y = abs(y) + return log(b - a) - abs_y + 2 * LogExpFunctions.log1pexp(-abs_y) + elseif lowerbounded || upperbounded + return convert(promote_type(typeof(y), typeof(a), typeof(b)), y) + else + return zero(y) + end +end + +function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y) + a, b = ib.orig.lb, ib.orig.ub + return truncated_inv_logabsdetjac.(y, a, b) +end + +with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y) = transform(ib, y), logabsdetjac(ib, y) + # It's only monotonically decreasing if it's only upper-bounded. # In the multivariate case, we can only say something reasonable if entries are monotonic. function is_monotonically_increasing(b::TruncatedBijector) From 1438fcd820990d7714e3c05f2dbf8aba37a89dff Mon Sep 17 00:00:00 2001 From: acertain Date: Mon, 19 Aug 2024 16:10:47 -0600 Subject: [PATCH 02/11] promote at start, try to fix test --- src/bijectors/truncated.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index 9517807e..cfef407b 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -69,12 +69,13 @@ end with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x) function truncated_inv_logabsdetjac(y, a, b) + y, a, b = promote(y, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded abs_y = abs(y) return log(b - a) - abs_y + 2 * LogExpFunctions.log1pexp(-abs_y) elseif lowerbounded || upperbounded - return convert(promote_type(typeof(y), typeof(a), typeof(b)), y) + return y else return zero(y) end @@ -82,10 +83,12 @@ end function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y) a, b = ib.orig.lb, ib.orig.ub - return truncated_inv_logabsdetjac.(y, a, b) + return sum(truncated_inv_logabsdetjac.(y, a, b)) end -with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y) = transform(ib, y), logabsdetjac(ib, y) +function with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y) + return transform(ib, y), logabsdetjac(ib, y) +end # It's only monotonically decreasing if it's only upper-bounded. # In the multivariate case, we can only say something reasonable if entries are monotonic. From ab97706894063a6144cdd79c423f5e7ff7ec64d1 Mon Sep 17 00:00:00 2001 From: acertain Date: Mon, 19 Aug 2024 17:35:58 -0600 Subject: [PATCH 03/11] fix? --- src/bijectors/truncated.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index cfef407b..b810870b 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -72,8 +72,7 @@ function truncated_inv_logabsdetjac(y, a, b) y, a, b = promote(y, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded - abs_y = abs(y) - return log(b - a) - abs_y + 2 * LogExpFunctions.log1pexp(-abs_y) + return log(b - a) + y - 2 * LogExpFunctions.log1pexp(y) elseif lowerbounded || upperbounded return y else From 88d23b536668b52728531135c8e46e14264efd29 Mon Sep 17 00:00:00 2001 From: acertain Date: Mon, 19 Aug 2024 18:19:59 -0600 Subject: [PATCH 04/11] fix test --- test/bijectors/ordered.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/bijectors/ordered.jl b/test/bijectors/ordered.jl index b2115fe2..bbf3866d 100644 --- a/test/bijectors/ordered.jl +++ b/test/bijectors/ordered.jl @@ -78,7 +78,7 @@ end @testset "correctness" begin num_samples = 10_000 - num_adapts = 1_000 + num_adapts = 5_000 @testset "k = $k" for k in [2, 3, 5] @testset "$(typeof(dist))" for dist in [ # Unconstrained From 994aa5febeed82179ee37472821a745c43be95e6 Mon Sep 17 00:00:00 2001 From: acertain Date: Tue, 20 Aug 2024 11:10:10 -0600 Subject: [PATCH 05/11] back to abs formula --- src/bijectors/truncated.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index b810870b..dc8b6211 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -72,7 +72,8 @@ function truncated_inv_logabsdetjac(y, a, b) y, a, b = promote(y, a, b) lowerbounded, upperbounded = isfinite(a), isfinite(b) if lowerbounded && upperbounded - return log(b - a) + y - 2 * LogExpFunctions.log1pexp(y) + abs_y = abs(y) + return log(b - a) - abs_y - 2 * LogExpFunctions.log1pexp(-abs_y) elseif lowerbounded || upperbounded return y else From cdc6128a763bdf01c016fbecde412de990124c7e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 00:29:28 +0000 Subject: [PATCH 06/11] Bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 18586cd5..2db72fbc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.15" +version = "0.15.1" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From 23705a351ae3536f24f3419228c86ac1558910bd Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 00:52:04 +0000 Subject: [PATCH 07/11] Actually bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2db72fbc..61f00465 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.15.1" +version = "0.16" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From ee08b22e3e6190d5aea5515f65d4d45a7502bbb7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 00:52:32 +0000 Subject: [PATCH 08/11] Add test for Uniform(-1, 1), y=80 --- test/interface.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/interface.jl b/test/interface.jl index 44a73878..1c2da33e 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -96,6 +96,15 @@ contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs)) logabsdetjac(inverse(b), y) atol = 1e-6 end end + + @testset "numerical stability with large numbers: Bijectors.jl#325" begin + d = Uniform(big(-1.0), big(1.0)) + b = bijector(d) + y = big(80) + x = inverse(b)(y) + @test logpdf(d, inverse(b)(y)) + logabsdetjacinv(b, y) ≈ + logpdf_with_trans(d, x, true) + end end @testset "Truncated" begin From 7be7ea9c13ca9a8592abd44829e38296f3037d1b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 00:59:54 +0000 Subject: [PATCH 09/11] Tweak test to be more discerning --- test/interface.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 1c2da33e..e5b6737a 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -97,13 +97,17 @@ contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs)) end end - @testset "numerical stability with large numbers: Bijectors.jl#325" begin - d = Uniform(big(-1.0), big(1.0)) + @testset "logabsdetjac numerical stability: Bijectors.jl#325" begin + d = Uniform(-1, 1) b = bijector(d) - y = big(80) - x = inverse(b)(y) - @test logpdf(d, inverse(b)(y)) + logabsdetjacinv(b, y) ≈ - logpdf_with_trans(d, x, true) + y = 80 + # x needs higher precision to be calculated correctly, otherwise + # logpdf_with_trans returns -Inf + d_big = Uniform(big(-1.0), big(1.0)) + b_big = bijector(d_big) + x_big = inverse(b_big)(big(y)) + @test logpdf(d, x_big) + logabsdetjacinv(b, y) ≈ + logpdf_with_trans(d_big, x_big, true) atol = 1e-14 end end From e82d0d708373e27ef9e1ec14ef20592be334f906 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 01:03:37 +0000 Subject: [PATCH 10/11] Test forward logabsdetjac too --- test/interface.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index e5b6737a..f5ac7457 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -106,7 +106,9 @@ contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs)) d_big = Uniform(big(-1.0), big(1.0)) b_big = bijector(d_big) x_big = inverse(b_big)(big(y)) - @test logpdf(d, x_big) + logabsdetjacinv(b, y) ≈ + @test logpdf(d_big, x_big) + logabsdetjacinv(b, y) ≈ + logpdf_with_trans(d_big, x_big, true) atol = 1e-14 + @test logpdf(d_big, x_big) - logabsdetjac(b, x_big) ≈ logpdf_with_trans(d_big, x_big, true) atol = 1e-14 end end From 67b699ab2b06b9610b840b77b2750c460ad8eab7 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 18:46:07 +0000 Subject: [PATCH 11/11] Update Project.toml Co-authored-by: David Widmann --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 61f00465..84b1b77a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.16" +version = "0.15.3" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"