From 213328d160be66f9102d2738a45e3b72fa62b2b4 Mon Sep 17 00:00:00 2001 From: acertain Date: Sat, 17 Aug 2024 16:49:31 -0600 Subject: [PATCH] implement logabsdetjac for Inverse{<:TruncatedBijector} for better numerical stability Example of previous badness: logabsdetjac(inverse(bijector(Uniform(-1,1))), 80) = -Inf (is now -79.30685281944005) --- src/bijectors/truncated.jl | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index d468bbe9..9517807e 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -68,6 +68,25 @@ end with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x) +function truncated_inv_logabsdetjac(y, a, b) + lowerbounded, upperbounded = isfinite(a), isfinite(b) + if lowerbounded && upperbounded + abs_y = abs(y) + return log(b - a) - abs_y + 2 * LogExpFunctions.log1pexp(-abs_y) + elseif lowerbounded || upperbounded + return convert(promote_type(typeof(y), typeof(a), typeof(b)), y) + else + return zero(y) + end +end + +function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y) + a, b = ib.orig.lb, ib.orig.ub + return truncated_inv_logabsdetjac.(y, a, b) +end + +with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y) = transform(ib, y), logabsdetjac(ib, y) + # It's only monotonically decreasing if it's only upper-bounded. # In the multivariate case, we can only say something reasonable if entries are monotonic. function is_monotonically_increasing(b::TruncatedBijector)