Cholesky numerical stability: inverse transform #356
Conversation
AFAICT it's even simpler, since `LogExpFunctions.logcosh` is supposed to provide a numerically stable and efficient implementation of `log(cosh(x))`.

👀 I'll give that a spin

Looking under the hood, `logcosh` is implemented the same way as above, but the single function call is great 👍
Co-authored-by: David Widmann <[email protected]>
src/bijectors/corr.jl (Outdated)

```diff
@@ -495,8 +490,7 @@ function _logabsdetjac_inv_chol(y::AbstractVector)
     @inbounds for j in 2:K
         tmp = zero(result)
         for _ in 1:(j - 1)
-            z = tanh(y[idx])
-            logz = 2 * log(2 / (exp(y[idx]) + exp(-y[idx])))
+            logz = -2 * LogExpFunctions.logcosh(y[idx])
```
The name `logz` is a bit meaningless now, I guess 🙂
Note: I only have a tiny bit of experience with numerical programming (on matrix exponentials), so this is definitely Not My Area of Expertise and I might be doing something horribly wrong.
This PR attempts to improve the numerical stability of the inverse transform for Cholesky matrices. For the forward transform, see this PR: #357
Description
This PR replaces `log1p(-z^2) / 2`, where `z = tanh(y[idx])`, with `IrrationalConstants.logtwo + y[idx] - LogExpFunctions.log1pexp(2 * y[idx])` (which is the same mathematical expression) in `_inv_link_chol_lkj`.

(Note 1: I tried implementing this directly as `log(2 / (exp(y[idx]) + exp(-y[idx])))`, but this was worse in terms of performance.)

(Note 2: This has since been replaced with a direct call to `LogExpFunctions.logcosh`, which performs the same calculation (see https://github.com/JuliaStats/LogExpFunctions.jl/blob/289114f535827c612ce10c01b8dec9d3a55e4d15/src/basicfuns.jl#L132-L135) except for an additional call to `abs`, which, I guess, makes for better numerical stability for large negative values of y.)

Accuracy 1
First, to make sure there aren't any regressions, we'll:
Code to generate plot
There isn't much between the two implementations: sometimes the old one is better, sometimes the new one. In any case, the differences are very small, so the new implementation roughly breaks even, although I do think the old implementation is very slightly better.
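The identity behind the change can also be verified at the scalar level. A minimal sketch, assuming only base Julia (`mylogcosh` here re-implements the formula that `LogExpFunctions.logcosh` uses, so the snippet has no package dependencies; the names are illustrative):

```julia
# Three equivalent expressions for log(1 - tanh(y)^2):
#   naive:    log1p(-tanh(y)^2)                  (old implementation)
#   expanded: 2 * (log(2) + y - log1p(exp(2y)))  (stable expanded form)
#   stable:   -2 * logcosh(y)                    (form used in this PR)
# mylogcosh mirrors LogExpFunctions.logcosh: |x| + log1p(exp(-2|x|)) - log(2).
mylogcosh(x) = abs(x) + log1p(exp(-2 * abs(x))) - log(2)

naive(y)    = log1p(-tanh(y)^2)
expanded(y) = 2 * (log(2) + y - log1p(exp(2 * y)))
stable(y)   = -2 * mylogcosh(y)

for y in (-2.0, -0.5, 0.0, 0.5, 2.0)
    @assert isapprox(naive(y), stable(y); atol = 1e-12)
    @assert isapprox(expanded(y), stable(y); atol = 1e-12)
end
println("all three forms agree on moderate inputs")
```

All three forms agree to machine precision for moderate `y`; the differences only show up in the tails.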
Accuracy 2
However, when sampling in the unconstrained space, there's no guarantee that the resulting sample will resemble anything like the samples obtained via a forward transformation. This leads to issues like #279.
To test out the numerical stability of invlinking random transformed samples, we can:
Code to generate plot
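The failure mode can also be reproduced in isolation. For large `|y|`, `tanh(y)` rounds to exactly `±1.0` in Float64, so the old expression degenerates to `-Inf`, while the `logcosh`-based form stays accurate. A dependency-free sketch (`mylogcosh` again inlines the `LogExpFunctions.logcosh` formula):

```julia
# tanh saturates: tanh(40.0) == 1.0 exactly in Float64, since the true value
# differs from 1 by ~4e-35, far below eps(1.0) ~ 2.2e-16.
mylogcosh(x) = abs(x) + log1p(exp(-2 * abs(x))) - log(2)

y = 40.0
old = log1p(-tanh(y)^2)   # log1p(-1.0) == -Inf
new = -2 * mylogcosh(y)   # ≈ -78.6137, the mathematically correct value
println((old, new))
```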
As can be seen, the new method leads to much smaller errors (consistently around the magnitude of `eps() ~ 1e-16`), whereas the old method often has errors that are several orders of magnitude larger.

Performance
What next
Note that this PR doesn't actually fully solve #279. That issue arises not because of the inverse transformation, but because of the forward transformation (in the call to `logabsdetjac`). This is a result of further numerical instabilities in other functions, specifically the linking one; #357 contains a potential fix for this.
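To see why the forward direction is the remaining problem, note that once `tanh` saturates, `atanh` cannot recover the input at all. A minimal scalar illustration of the mechanism (assuming the forward link essentially goes through `atanh`; this is not the actual #357 fix):

```julia
# Once y is large enough that tanh(y) rounds to 1.0, the forward link
# (which goes through atanh) maps the corresponding entry to Inf:
y = 40.0
z = tanh(y)         # == 1.0 exactly in Float64
println(atanh(z))   # Inf: the original y = 40.0 is unrecoverable
```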