Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement logabsdetjac for Inverse{<:TruncatedBijector} for better numerical stability #325

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
Open
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.15.2"
version = "0.15.3"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
22 changes: 22 additions & 0 deletions src/bijectors/truncated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,28 @@ end

with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x)

function truncated_inv_logabsdetjac(y, a, b)
y, a, b = promote(y, a, b)
lowerbounded, upperbounded = isfinite(a), isfinite(b)
if lowerbounded && upperbounded
abs_y = abs(y)
return log(b - a) - abs_y - 2 * LogExpFunctions.log1pexp(-abs_y)
elseif lowerbounded || upperbounded
return y
else
return zero(y)
acertain marked this conversation as resolved.
Show resolved Hide resolved
end
end

function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y)
a, b = ib.orig.lb, ib.orig.ub
return sum(truncated_inv_logabsdetjac.(y, a, b))
end

function with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y)
return transform(ib, y), logabsdetjac(ib, y)
end

# It's only monotonically decreasing if it's only upper-bounded.
# In the multivariate case, we can only say something reasonable if entries are monotonic.
function is_monotonically_increasing(b::TruncatedBijector)
Expand Down
2 changes: 1 addition & 1 deletion test/bijectors/ordered.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ end

@testset "correctness" begin
num_samples = 10_000
num_adapts = 1_000
num_adapts = 5_000
@testset "k = $k" for k in [2, 3, 5]
@testset "$(typeof(dist))" for dist in [
# Unconstrained
Expand Down
15 changes: 15 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs))
logabsdetjac(inverse(b), y) atol = 1e-6
end
end

@testset "logabsdetjac numerical stability: Bijectors.jl#325" begin
d = Uniform(-1, 1)
b = bijector(d)
y = 80
# x needs higher precision to be calculated correctly, otherwise
# logpdf_with_trans returns -Inf
d_big = Uniform(big(-1.0), big(1.0))
b_big = bijector(d_big)
x_big = inverse(b_big)(big(y))
@test logpdf(d_big, x_big) + logabsdetjacinv(b, y) ≈
logpdf_with_trans(d_big, x_big, true) atol = 1e-14
@test logpdf(d_big, x_big) - logabsdetjac(b, x_big) ≈
logpdf_with_trans(d_big, x_big, true) atol = 1e-14
end
end

@testset "Truncated" begin
Expand Down
Loading