Cholesky numerical stability: Forward transform #357
base: py/chol-numerical
Conversation
CI is failing because ForwardDiff calculates quite a different Jacobian in this test. Since the new implementation still works correctly on roundtrip transformation, and this PR doesn't touch anything to do with Jacobian calculations, I wonder if this is a ForwardDiff bug?

Bijectors.jl/test/transform.jl Lines 239 to 255 in f52a9c5

Repro:

using Bijectors
using ForwardDiff: ForwardDiff
using LinearAlgebra: logabsdet, I, Cholesky
using Random
uplo = :L
dist = LKJCholesky(3, 1, uplo)
x = rand(Xoshiro(468), dist)
inds = [
    LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
    (uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
]
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
J = J[:, inds]
logabsdet(J)[1]

Before this PR:

julia> x = rand(Xoshiro(468), dist)
Cholesky{Float64, Matrix{Float64}}
L factor:
3×3 LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}:
  1.0        ⋅         ⋅
 -0.23039   0.973098   ⋅
  0.288231 -0.899424  0.328572

julia> logabsdet(J)[1]
2.3239053137427703

With this PR:

julia> logabsdet(J)[1]
0.184638090189601
src/bijectors/corr.jl (Outdated)

@@ -297,11 +297,10 @@ function _link_chol_lkj(W::AbstractMatrix)
    # Some zero filling can be avoided. Though the diagonal still needs to be filled with zeros.

    @inbounds for j in 1:K
        remainder_sq = one(eltype(W))
        for i in 1:(j - 1)
You could save a few operations by reversing the loop and summing remainder_sq incrementally.
Looping is implemented in reverse now. Had to be careful with indices, but it's looking good; the original comment has been updated with benchmarks etc. :) The ForwardDiff test is still wonky, though. Need to dig into that one a bit...
Force-pushed from 322e010 to 229ea4c.
Another point: I wonder if generally it would be better to avoid …
Tried to look into the AD issues. The problem is that the new implementation uses different matrix elements to arrive at the same answer, but ForwardDiff doesn't know that.

using ForwardDiff
using Bijectors
using LinearAlgebra
using Random
using LogExpFunctions
using Test: @test
dist = LKJCholesky(3, 1, 'U')
# Minimised version of new forward transform
function chol_upper_new(W::AbstractMatrix)
    y1 = atanh(W[1, 2])
    y2 = atanh(W[1, 3])
    z = W[2, 3] / W[3, 3]
    # NOTE: If we replace W[3, 3] here with sqrt(1 - W[1, 3]^2 - W[2, 3]^2) then
    # the autodiff works. But the whole problem / the point of this PR was that
    # this is numerically unstable for small values of W[3, 3].
    y3 = asinh(z)
    return [y1, y2, y3]
end
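# Why the two versions agree (my annotation, not part of the original
# comment): under the unit-column constraint, sqrt(1 - W[1, 3]^2) equals
# sqrt(W[2, 3]^2 + W[3, 3]^2), and for t = W[2, 3] / W[3, 3] we have
# tanh(asinh(t)) == t / sqrt(1 + t^2) == W[2, 3] / sqrt(W[2, 3]^2 + W[3, 3]^2),
# hence atanh(W[2, 3] / sqrt(1 - W[1, 3]^2)) == asinh(W[2, 3] / W[3, 3]).
# Quick numeric sanity check:
let w23 = 0.3, w33 = 0.4
    @assert atanh(w23 / hypot(w23, w33)) ≈ asinh(w23 / w33)
end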
# Minimised version of old forward transform
function chol_upper_old(W::AbstractMatrix)
    y1 = atanh(W[1, 2])
    y2 = atanh(W[1, 3])
    z = W[2, 3] / sqrt(1 - W[1, 3]^2)
    y3 = atanh(z)
    return [y1, y2, y3]
end
# Check that they both do the same thing, and they both do the same
# thing as Bijectors.jl forward transform – so they are the same function
for _ in 1:1000
    x_new = rand(dist)
    @test chol_upper_new(x_new.UL) ≈ chol_upper_old(x_new.UL) atol=1e-12
    @test chol_upper_new(x_new.UL) ≈ bijector(dist)(x_new) atol=1e-12
end
# ForwardDiff gives different Jacobians though:
x = rand(Random.Xoshiro(468), dist)
J_FD_new = ForwardDiff.jacobian(chol_upper_new, x.UL)
# 3×9 Matrix{Float64}:
# 0.0 0.0 0.0 1.05605 0.0 0.0 0.0 0.0 0.0
# 0.0 0.0 0.0 0.0 0.0 0.0 1.0906 0.0 0.0
# 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.04432 2.85869
J_FD_old = ForwardDiff.jacobian(chol_upper_old, x.UL)
# 3×9 Matrix{Float64}:
# 0.0 0.0 0.0 1.05605 0.0 0.0 0.0 0.0 0.0
# 0.0 0.0 0.0 0.0 0.0 0.0 1.0906 0.0 0.0
# 0.0 0.0 0.0 0.0 0.0 0.0 -2.50772 8.86963 0.0

So it seems like the issue in the failing test might be the indexing done here:

Bijectors.jl/test/transform.jl Lines 246 to 251 in f52a9c5
With the old implementation, all the other columns are zero, but with the new one this isn't true. 😬

Edit: Fixed the test by giving ForwardDiff only a vector of the free parameters in the Cholesky factor, instead of giving it the whole matrix. It only works on 3×3 for now.
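A minimal sketch of that fix, under my reading of it (reusing dist, x, and chol_upper_new from the snippet above; factor_from_free is an illustrative helper, not the test's actual code): differentiate with respect to only the three free off-diagonal entries and rebuild the factor, so the Jacobian is square and no column selection is needed.

# Rebuild a 3×3 upper factor from its free parameters; the diagonal
# follows from the unit-norm constraint on each column.
function factor_from_free(v::AbstractVector)
    w12, w13, w23 = v
    return [1 w12 w13;
            0 sqrt(1 - w12^2) w23;
            0 0 sqrt(1 - w13^2 - w23^2)]
end

v0 = [x.UL[1, 2], x.UL[1, 3], x.UL[2, 3]]
J_free = ForwardDiff.jacobian(v -> chol_upper_new(factor_from_free(v)), v0)
logabsdet(J_free)[1]   # square 3×3 Jacobian, no manual column indexing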
Remaining failing test:

Bijectors.jl/test/ad/chainrules.jl Lines 86 to 95 in f52a9c5
This is a companion PR to #356. It attempts to solve the following issue, first reported in #279:
Introduction
The forward transform acts on an upper triangular matrix, W, which is supposed to have unit vectors for each column, i.e. sum(W[:, j] .^ 2) should be 1 for each j.

In the forward transform code, remainder_sq is initialised at one and then the squares of each element going down column j are successively subtracted, so remainder_sq is really the sum of squares of the elements not yet seen:

Bijectors.jl/src/bijectors/corr.jl
Lines 321 to 331 in f52a9c5
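For reference, a minimal sketch of that subtract-from-one pattern (my paraphrase of the linked lines, with illustrative names, not the exact source):

# Unconstrained entries for column j of an upper factor W.
function link_col_old(W::AbstractMatrix, j::Int)
    y = similar(W, j - 1)
    remainder_sq = one(eltype(W))
    for i in 1:(j - 1)
        # In exact arithmetic, remainder_sq == sum(W[i:end, j] .^ 2) here,
        # so |z| <= 1; rounding can break this.
        z = W[i, j] / sqrt(remainder_sq)
        y[i] = atanh(z)                   # DomainError if rounding pushes |z| > 1
        remainder_sq -= W[i, j]^2
    end
    return y
end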
Now, in principle, because z^2 = W[i, j]^2 / (sum of W[i:end, j]^2), there is no way that z^2 can be larger than 1. However, because of floating point imprecisions, sometimes this isn't true. This is especially likely to happen if the last element W[j-1, j] is very small. This doesn't tend to happen when W is sampled from LKJCholesky, but it can happen when W is obtained through inverse transformation of some random unconstrained vector, as described in e.g. #279.

A proposed fix, instead of subtracting successive squares from 1, could just declare remainder_sq to be that sum. In practice, this is implemented by looping in reverse (over (j-1):-1:2) and incrementing remainder_sq, following @devmotion's suggestion below. Now, because W[i, j]^2 is part of that sum, z can no longer be larger than 1, and atanh doesn't throw a DomainError.
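A minimal sketch of that reversed loop (illustrative names again; the actual PR loops over (j-1):-1:2 and presumably handles row 1 separately, where the denominator is exactly one):

function link_col_new(W::AbstractMatrix, j::Int)
    y = similar(W, j - 1)
    remainder_sq = W[j, j]^2              # start from the diagonal element
    for i in (j - 1):-1:1
        remainder_sq += W[i, j]^2         # built directly as sum(W[i:j, j] .^ 2)
        # W[i, j]^2 is part of the denominator, so |z| <= 1 by construction.
        z = W[i, j] / sqrt(remainder_sq)
        y[i] = atanh(z)
    end
    return y
end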
Setup code for this comment (collapsed)
Impacts of this change
First, let's check roundtrip transformation on typical samples from LKJCholesky. The numerical accuracy here is actually marginally better than with the existing implementation:
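A sketch of the kind of roundtrip check meant here (assumed setup, not the exact benchmark code):

using Bijectors

dist = LKJCholesky(10, 1)
b = bijector(dist)
binv = inverse(b)
errs = map(1:1_000) do _
    x = rand(dist)
    x2 = binv(b(x))                 # forward then inverse transform
    maximum(abs.(x2.L .- x.L))      # worst-case elementwise error
end
maximum(errs)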
On top of that, it fixes the DomainErrors which occur with random unconstrained inputs:
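A sketch of that failure mode (scale and seed are illustrative): inverse-transform a random unconstrained vector, then transform the result forward again.

using Bijectors, Random

dist = LKJCholesky(10, 1)
b = bijector(dist)
y = 10 .* randn(Xoshiro(468), 45)   # 45 == 10 * 9 / 2 unconstrained parameters
x = inverse(b)(y)                   # a valid factor, possibly with tiny entries
b(x)                                # previously could throw DomainError in atanh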
Performance
Remaining concern: accuracy on pathological samples
It's not great, but considering that the existing implementation errors out entirely, this is still a net win.
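For concreteness, one way to construct such a pathological sample (my construction, not the PR's benchmark): push one unconstrained coordinate to an extreme value so some entries of the factor become tiny, then measure the roundtrip error.

using Bijectors

dist = LKJCholesky(3, 1)
b = bijector(dist)
y = [0.0, 9.0, 0.0]          # extreme unconstrained value
x = inverse(b)(y)            # some entries become very small (~1e-4)
maximum(abs.(b(x) .- y))     # typically much larger than on typical samples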