-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement logabsdetjac for Inverse{<:TruncatedBijector} for better numerical stability #325
base: master
Are you sure you want to change the base?
Conversation
…merical stability Example of previous badness: logabsdetjac(inverse(bijector(Uniform(-1,1))), 80) = -Inf (is now -79.30685281944005)
@devmotion thanks for the suggestions, I've implemented them |
@acertain this seems to break a correctness test for Julia 1.6. Do you want to take a look? |
We dropped 1.6 already, and the remaining tests pass (known Enzyme crashes notwithstanding). I've bumped the minor version since this technically exposes a new function. I'm wondering how we can implement a good test for this. We can't do this for example d = Uniform(-1, 1); b = bijector(d); y = 80; x = inverse(b)(y)
@test logpdf(d, inverse(b)(y)) + logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) Thanks to this PR the LHS evaluates to a finite value, but the comparison fails because because Maybe we just check isfinite() for the LHS terms? Or we could do the equality check iff The existing tests already check for numerical accuracy on a wide range of distributions and non-pathological values, so we should already be fairly confident that this PR does not cause any regressions on these. Bijectors.jl/test/interface.jl Lines 78 to 85 in d342371
|
Probably easiest to compare with julia> using Bijectors
julia> d = Uniform(big(-1.0), big(1.0));
julia> b = bijector(d);
julia> y = big(80.0);
julia> x = inverse(b)(y)
0.9999999999999999999999999999999999639029722430916965537574328529994522280756386
julia> logpdf_with_trans(d, x, true)
-80.000000000000000000000000000000000036097027594005675893754704116470597853965 |
Co-authored-by: David Widmann <[email protected]>
Example of previous badness:
logabsdetjac(inverse(bijector(Uniform(-1,1))), 80) = -Inf (is now -79.30685281944005)
Formula stolen from stan, I didn't check its correctness.