Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement logabsdetjac for Inverse{<:TruncatedBijector} for better numerical stability #325

Open
wants to merge 15 commits into
base: master
Choose a base branch
from

Conversation

acertain
Copy link

Example of previous badness: logabsdetjac(inverse(bijector(Uniform(-1,1))), 80) = -Inf (is now -79.30685281944005)

Formula stolen from stan, I didn't check its correctness.

src/bijectors/truncated.jl Outdated Show resolved Hide resolved
src/bijectors/truncated.jl Outdated Show resolved Hide resolved
src/bijectors/truncated.jl Outdated Show resolved Hide resolved
src/bijectors/truncated.jl Outdated Show resolved Hide resolved
src/bijectors/truncated.jl Outdated Show resolved Hide resolved
src/bijectors/truncated.jl Outdated Show resolved Hide resolved
src/bijectors/truncated.jl Outdated Show resolved Hide resolved
…merical stability

Example of previous badness:  logabsdetjac(inverse(bijector(Uniform(-1,1))), 80) = -Inf (is now -79.30685281944005)
@acertain
Copy link
Author

@devmotion thanks for the suggestions, I've implemented them

@yebai
Copy link
Member

yebai commented Aug 29, 2024

@acertain this seems to break a correctness test for Julia 1.6. Do you want to take a look?

@penelopeysm
Copy link
Member

penelopeysm commented Nov 30, 2024

We dropped 1.6 already, and the remaining tests pass (known Enzyme crashes notwithstanding).

I've bumped the minor version since this technically exposes a new function.


I'm wondering how we can implement a good test for this. We can't do this for example

d = Uniform(-1, 1); b = bijector(d); y = 80; x = inverse(b)(y)

@test logpdf(d, inverse(b)(y)) + logabsdetjacinv(b, y)  logpdf_with_trans(d, x, true)

Thanks to this PR the LHS evaluates to a finite value, but the comparison fails because because x evaluates to 1.0 and logpdf_with_trans returns infinite. (In theory, x should be 0.9999999....... but floats aren't precise enough to capture that.)

Maybe we just check isfinite() for the LHS terms? Or we could do the equality check iff (b ∘ inverse(b))(y) == y?

The existing tests already check for numerical accuracy on a wide range of distributions and non-pathological values, so we should already be fairly confident that this PR does not cause any regressions on these.

# logpdf corresponds to logpdf_with_trans
d = dist
b = @inferred bijector(d)
x = rand(d)
y = @inferred b(x)
@test logpdf(d, inverse(b)(y)) + logabsdetjacinv(b, y)
logpdf_with_trans(d, x, true)
@test logpdf(d, x) - logabsdetjac(b, x) logpdf_with_trans(d, x, true)

@devmotion
Copy link
Member

Probably easiest to compare with BigFloats. E.g., on the master branch the example above gives

julia> using Bijectors

julia> d = Uniform(big(-1.0), big(1.0));

julia> b = bijector(d);

julia> y = big(80.0);

julia> x = inverse(b)(y)
0.9999999999999999999999999999999999639029722430916965537574328529994522280756386

julia> logpdf_with_trans(d, x, true)
-80.000000000000000000000000000000000036097027594005675893754704116470597853965

Project.toml Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants