Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement logabsdetjac for Inverse{<:TruncatedBijector} for better numerical stability #325

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
19 changes: 19 additions & 0 deletions src/bijectors/truncated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,25 @@ end

with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x)

function truncated_inv_logabsdetjac(y, a, b)
lowerbounded, upperbounded = isfinite(a), isfinite(b)
if lowerbounded && upperbounded
abs_y = abs(y)
return log(b - a) - abs_y + 2 * LogExpFunctions.log1pexp(-abs_y)
elseif lowerbounded || upperbounded
return convert(promote_type(typeof(y), typeof(a), typeof(b)), y)
else
return zero(y)
acertain marked this conversation as resolved.
Show resolved Hide resolved
end
end

function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y)
a, b = ib.orig.lb, ib.orig.ub
return truncated_inv_logabsdetjac.(y, a, b)
end

with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y) = transform(ib, y), logabsdetjac(ib, y)

# It's only monotonically decreasing if it's only upper-bounded.
# In the multivariate case, we can only say something reasonable if entries are monotonic.
function is_monotonically_increasing(b::TruncatedBijector)
Expand Down
Loading