
Cholesky numerical stability: Forward transform #357

Draft · wants to merge 4 commits into base: py/chol-numerical
Conversation

@penelopeysm (Member) commented Dec 1, 2024

This is a companion PR to #356. It attempts to solve the following issue, first reported in #279:

using Bijectors
using Distributions

θ_unconstrained = [
	-1.9887091960524537,
	-13.499454444466279,
	-0.39328331954134665,
	-4.426097270849902,
	13.101175413857023,
	7.66647404712346,
	9.249285786544894,
	4.714877413573335,
	6.233118490809442,
	22.28264809311481
]
n = 5
d = LKJCholesky(n, 10)
b = Bijectors.bijector(d)
b_inv = inverse(b)

θ = b_inv(θ_unconstrained)
Bijectors.logabsdetjac(b, θ)

# ERROR: DomainError with 1.0085229361957693:
# atanh(x) is only defined for |x| ≤ 1.

Introduction

The forward transform acts on an upper triangular matrix, W, which is supposed to have unit vectors for each column, i.e. sum(W[:, j] .^ 2) should be 1 for each j:

julia> s = rand(LKJCholesky(5, 1.0, 'U')).U
5×5 UpperTriangular{Float64, Matrix{Float64}}:
 1.0  0.345448  -0.478      0.455158   0.385151
  ⋅   0.938438  -0.331921  -0.305083  -0.0469749
  ⋅    ⋅         0.813231  -0.397178   0.831726
  ⋅    ⋅          ⋅          0.73621   0.0298828
  ⋅    ⋅          ⋅           ⋅        0.395968

julia> [sum(s[:, i] .^ 2) for i in 1:5]
5-element Vector{Float64}:
 1.0
 1.0
 1.0
 1.0
 1.0000000000000002

In the forward transform code, remainder_sq is initialised at one, and the squares of the elements going down column j are successively subtracted. Since each column is a unit vector, remainder_sq is really the sum of squares of the elements not yet seen.

@inbounds for j in 2:K
    y[idx] = atanh(W[1, j])
    idx += 1
    remainder_sq = 1 - W[1, j]^2
    for i in 2:(j - 1)
        z = W[i, j] / sqrt(remainder_sq)
        y[idx] = atanh(z)
        remainder_sq -= W[i, j]^2
        idx += 1
    end
end

Now, in principle, because z^2 = W[i, j]^2 / sum(W[i:end, j] .^ 2), z^2 can never be larger than 1.

However, because of floating-point imprecision, sometimes this isn't true. It is especially likely to happen when the last element W[j-1, j] is very small. This doesn't tend to happen when W is sampled from LKJCholesky, but it can happen when W is obtained by inverse-transforming a random unconstrained vector, as described in e.g. #279.
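To make the failure mode concrete, here is a small numeric sketch (in plain Python for self-containment; the column values are hypothetical) comparing the subtract-from-one approach against summing the remaining squares directly:

```python
import math

# Hypothetical column of a correlation-matrix Cholesky factor: a unit
# vector [a, b, c] whose trailing entry c is tiny, as can arise when
# inverse-transforming a large unconstrained input.
b, c = 0.1, 1e-12
a = math.sqrt(1 - b**2 - c**2)

# Existing approach: start from 1 and subtract the squares of the entries
# seen so far. Cancellation can make this smaller than b**2, in which
# case z_old exceeds 1 and atanh(z_old) raises a domain error.
remainder_old = 1 - a**2
z_old = b / math.sqrt(remainder_old)

# Proposed approach: sum the squares of the entries not yet seen.
# Since b**2 is itself a term of the sum, z_new is bounded by 1.
remainder_new = b**2 + c**2
z_new = b / math.sqrt(remainder_new)

print(z_old, z_new)  # both close to 1, but only z_new is safely bounded
```

Whether z_old actually crosses 1 depends on how the rounding errors happen to fall, which is why the bug is intermittent.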

A proposed fix: instead of subtracting successive squares from 1, simply declare remainder_sq to be that sum directly:

    @inbounds for j in 2:K
-       remainder_sq = 1 - W[1, j]^2
        for i in 2:(j - 1)
+           remainder_sq = sum(W[i:end, j] .^ 2)
            z = W[i, j] / sqrt(remainder_sq)
            y[idx] = atanh(z)
-           remainder_sq -= W[i, j]^2
            idx += 1
        end
    end

In practice, this is implemented by looping in reverse (over (j-1):-1:2) and incrementing remainder_sq, following @devmotion's suggestion below.

Now, because W[i, j]^2 is part of that sum, z can no longer be larger than 1, and atanh doesn't throw a DomainError.
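The reverse loop can be sketched as follows (a plain-Python illustration on one hypothetical unit column w, not the PR code itself): accumulating remainder_sq from the diagonal upwards reproduces each per-element sum without recomputing it from scratch.

```python
import math

# Hypothetical unit column: 0.6^2 + 0.64^2 + 0.48^2 = 1 (up to rounding).
w = [0.6, 0.64, 0.48]

# Walk from the second-to-last entry up to index 1, accumulating the
# remaining sum of squares; z[i] = w[i] / sqrt(w[i]^2 + ... + w[-1]^2).
remainder_sq = w[-1] ** 2
zs = []
for i in range(len(w) - 2, 0, -1):
    remainder_sq += w[i] ** 2
    zs.append(w[i] / math.sqrt(remainder_sq))

# The single interior entry gives z = 0.64 / sqrt(0.64^2 + 0.48^2) = 0.8,
# so |z| <= 1 holds by construction and atanh(z) is well defined.
print(zs)
```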

Setup code for this comment:
using Bijectors
using LinearAlgebra
using Distributions
using Random
using Plots
using LogExpFunctions

# Using the invlink definition from this PR
_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2
function _inv_link_chol_lkj_new(y::AbstractVector)
    LinearAlgebra.require_one_based_indexing(y)
    K = _triu1_dim_from_length(length(y))
    W = similar(y, K, K)
    T = float(eltype(W))
    logJ = zero(T)
    idx = 1
    @inbounds for j in 1:K
        log_remainder = zero(T)  # log of proportion of unit vector remaining
        for i in 1:(j - 1)
            z = tanh(y[idx])
            W[i, j] = z * exp(log_remainder)
            log_remainder -= LogExpFunctions.logcosh(y[idx])
            logJ += log_remainder
            idx += 1
        end
        logJ += log_remainder
        W[j, j] = exp(log_remainder)
        for i in (j + 1):K
            W[i, j] = 0
        end
    end
    return W, logJ
end

# Existing link implementation
function _link_chol_lkj_from_upper_old(W::AbstractMatrix)
    K = LinearAlgebra.checksquare(W)
    N = ((K - 1) * K) ÷ 2   # {K \choose 2} free parameters
    y = similar(W, N)
    idx = 1
    @inbounds for j in 2:K
        y[idx] = atanh(W[1, j])
        idx += 1
        remainder_sq = 1 - W[1, j]^2
        for i in 2:(j - 1)
            z = W[i, j] / sqrt(remainder_sq)
            y[idx] = atanh(z)
            remainder_sq -= W[i, j]^2
            idx += 1
        end
    end
    return y
end

# New proposal
function _link_chol_lkj_from_upper_new(W::AbstractMatrix)
    K = LinearAlgebra.checksquare(W)
    N = ((K - 1) * K) ÷ 2   # {K \choose 2} free parameters
    y = similar(W, N)
    starting_idx = 1
    @inbounds for j in 2:K
        y[starting_idx] = atanh(W[1, j])
        starting_idx += 1
        remainder_sq = W[j, j]^2
        for i in (j - 1):-1:2
            idx = starting_idx + i - 2
            remainder_sq += W[i, j]^2
            z = W[i, j] / sqrt(remainder_sq)
            y[idx] = atanh(z)
        end
        starting_idx += length(j-1:-1:2)
    end
    return y
end

function plot_maes(samples)
    log_mae_old = log10.([sample[1] for sample in samples])
    log_mae_new = log10.([sample[2] for sample in samples])
    scatter(log_mae_old, log_mae_new, label="")
    lim_min = floor(min(minimum(log_mae_old), minimum(log_mae_new)))
    lim_max = ceil(max(maximum(log_mae_old), maximum(log_mae_new)))
    plot!(lim_min:lim_max, lim_min:lim_max, label="y=x", color=:black)
    xlabel!("log10(maximum abs error old)")
    ylabel!("log10(maximum abs error new)")
end

function test_forward_bijector(f_old, f_new)
    dist = LKJCholesky(5, 1.0, 'U')
    Random.seed!(468)
    samples = map(1:500) do _
        x = rand(dist)
        x_again_old = _inv_link_chol_lkj_new(f_old(x.U))[1]
        x_again_new = _inv_link_chol_lkj_new(f_new(x.U))[1]
        # Return the maximum absolute error between the original sample
        # and sample after roundtrip transformation
        (maximum(abs.(x.U - x_again_old)), maximum(abs.(x.U - x_again_new)))
    end
    return samples
end

Impacts of this change

First, let's check roundtrip transformation on typical samples from LKJCholesky. The numerical accuracy here is actually marginally better than the existing implementation:

julia> plot_maes(test_forward_bijector(_link_chol_lkj_from_upper_old, _link_chol_lkj_from_upper_new))

[Figure: bijector_forward_typical, scatter of log10(maximum abs error), old vs new implementation]

On top of that, it fixes the DomainErrors which occur with random unconstrained inputs:

julia> y = rand(Random.Xoshiro(468), 10) * 16;

julia> x = _inv_link_chol_lkj_new(y)[1];

julia> y_old = _link_chol_lkj_from_upper_old(x)
ERROR: DomainError with 1.000207932997037:
atanh(x) is only defined for |x| ≤ 1.
Stacktrace:
 [1] atanh_domain_error(x::Float64)
   @ Base.Math ./special/hyperbolic.jl:240
 [2] atanh
   @ ./special/hyperbolic.jl:256 [inlined]
 [3] _link_chol_lkj_from_upper_old(W::Matrix{Float64})
   @ Main ./REPL[60]:12
 [4] top-level scope
   @ REPL[111]:1

julia> y_new = _link_chol_lkj_from_upper_new(x)
10-element Vector{Float64}:
  1.7139942709891685
  4.050190371709019
 12.606351374271206
  8.239542965781226
  7.8978551586304855
  6.885928358454504
  7.201266901997009
  4.588778566499247
  5.507106236959028
 11.582258189742753

Performance

julia> using Chairmarks

julia> @be (rand(LKJCholesky(5, 1.0, 'U'))) _link_chol_lkj_from_upper_old(_.U)
Benchmark: 4882 samples with 143 evaluations
 min    108.392 ns (2 allocs: 144 bytes)
 median 118.301 ns (2 allocs: 144 bytes)
 mean   125.184 ns (2 allocs: 144 bytes, 0.20% gc time)
 max    12.373 μs (2 allocs: 144 bytes, 97.83% gc time)

julia> @be (rand(LKJCholesky(5, 1.0, 'U'))) _link_chol_lkj_from_upper_new(_.U)
Benchmark: 2924 samples with 241 evaluations
 min    109.959 ns (2 allocs: 144 bytes)
 median 119.295 ns (2 allocs: 144 bytes)
 mean   129.468 ns (2 allocs: 144 bytes, 0.33% gc time)
 max    9.770 μs (2 allocs: 144 bytes, 97.67% gc time)

Remaining concern: accuracy on pathological samples

The roundtrip accuracy on pathological inputs is not great, but considering that the existing implementation errors outright, this is still a net win.

julia> maximum(abs.(y - y_new))
1.2023473718869582e-6

@penelopeysm (Member, Author) commented Dec 1, 2024

CI is failing because ForwardDiff calculates quite a different Jacobian in this test.

Since the new implementation still works correctly on roundtrip transformation, and this PR doesn't touch anything to do with Jacobian calculations, I wonder if this is a ForwardDiff bug (?)

@testset "LKJCholesky" begin
    @testset "uplo: $uplo" for uplo in [:L, :U]
        dist = LKJCholesky(3, 1, uplo)
        single_sample_tests(dist)
        x = rand(dist)
        inds = [
            LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
            (uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
        ]
        J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
        J = J[:, inds]
        logpdf_turing = logpdf_with_trans(dist, x, true)
        @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing
    end
end

Repro:

using Bijectors
using ForwardDiff: ForwardDiff
using LinearAlgebra: logabsdet, I, Cholesky
using Random

uplo = :L
dist = LKJCholesky(3, 1, uplo)
x = rand(Xoshiro(468), dist)
inds = [
    LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
    (uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
]
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
J = J[:, inds]
logabsdet(J)[1]

Before this PR:

julia> x = rand(Xoshiro(468), dist)
Cholesky{Float64, Matrix{Float64}}
L factor:
3×3 LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}:
  1.0          ⋅          ⋅
 -0.23039     0.973098    ⋅
  0.288231   -0.899424   0.328572

julia> logabsdet(J)[1]
2.3239053137427703

With this PR:

julia> logabsdet(J)[1]
0.184638090189601

@@ -297,11 +297,10 @@ function _link_chol_lkj(W::AbstractMatrix)
     # Some zero filling can be avoided. Though diagonal is still needed to be filled with zero.

     @inbounds for j in 1:K
         remainder_sq = one(eltype(W))
         for i in 1:(j - 1)
@devmotion (Member) commented on this diff:
You could save a few operations by reversing the loop and summing remainder_sq incrementally.

@penelopeysm (Member, Author) commented Dec 1, 2024

Looping is implemented in reverse now. Had to be careful with the indices, but it's looking good; the original comment has been updated with benchmarks etc. :)

The ForwardDiff test is still wonky, though. Need to dig into that one a bit...

@devmotion (Member) commented Dec 1, 2024

Another point: I wonder if generally it would be better to avoid atanh completely, due to its constrained domain and its derivative 1/(1 - x^2) which might be problematic close to 1 and -1. Note that you could rewrite $$\mathrm{atanh}\left(w_{i,j} / \sqrt{\sum_{k=i}^{j} w_{k,j}^2}\right)$$ as $$\mathrm{asinh}\left(w_{i,j}/\sqrt{\sum_{k={i+1}}^j w_{k,j}^2}\right)$$.
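This identity is easy to check numerically. A quick sketch (plain Python; the values of w and the tail sum S are illustrative, with S standing in for the remaining sum of squares below w in the column):

```python
import math

# atanh(w / sqrt(w^2 + S)) == asinh(w / sqrt(S)) for S > 0:
# with x = w / sqrt(w^2 + S), we have x / sqrt(1 - x^2) = w / sqrt(S),
# and atanh(x) = asinh(x / sqrt(1 - x^2)).
w, S = 0.6, 0.64

lhs = math.atanh(w / math.sqrt(w**2 + S))
rhs = math.asinh(w / math.sqrt(S))
print(lhs, rhs)  # both equal log(2) = 0.6931...
```

Note that the asinh form never needs its argument to stay inside (-1, 1), which is exactly why it sidesteps the domain problem.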

@penelopeysm (Member, Author) commented Dec 7, 2024

Tried to look into the AD issues. The problem is that the new implementation uses different matrix elements to arrive at the same answer, but ForwardDiff doesn't know that $\sum_i W_{ij}^2 = 1$, so it generates a different Jacobian.
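This is a generic phenomenon: two functions that agree on a constraint surface can still have different derivatives there, because the derivative probes points off the surface. A minimal sketch (hypothetical functions g1 and g2, with central finite differences standing in for ForwardDiff):

```python
import math

def g1(x, y):
    return x

def g2(x, y):
    # equals g1 whenever x^2 + y^2 = 1 and x >= 0
    return math.sqrt(1 - y**2)

# A point on the unit circle, where the two functions agree:
x, y = 0.6, 0.8
assert abs(g1(x, y) - g2(x, y)) < 1e-12

# Central-difference gradients; the perturbed points leave the circle,
# so the two gradients need not (and do not) agree.
h = 1e-7
grad1 = ((g1(x + h, y) - g1(x - h, y)) / (2 * h),
         (g1(x, y + h) - g1(x, y - h)) / (2 * h))
grad2 = ((g2(x + h, y) - g2(x - h, y)) / (2 * h),
         (g2(x, y + h) - g2(x, y - h)) / (2 * h))

print(grad1)  # approximately (1.0, 0.0)
print(grad2)  # approximately (0.0, -y / sqrt(1 - y^2)) = (0.0, -4/3)
```

In the PR's setting, the constraint is that each column of W is a unit vector, which the old and new link functions exploit differently.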

using ForwardDiff
using Bijectors
using LinearAlgebra
using Random
using LogExpFunctions
using Test: @test

dist = LKJCholesky(3, 1, 'U')

# Minimised version of new forward transform
function chol_upper_new(W::AbstractMatrix)
    y1 = atanh(W[1, 2])
    y2 = atanh(W[1, 3])
    z = W[2, 3] / W[3, 3]
    # NOTE: If we replace W[3, 3] here with sqrt(1 - W[1, 3]^2 - W[2, 3]^2) then
    # the autodiff works. But the whole problem / the point of this PR was that
    # this is numerically unstable for small values of W[3, 3].
    y3 = asinh(z)
    return [y1, y2, y3]
end

# Minimised version of old forward transform
function chol_upper_old(W::AbstractMatrix)
    y1 = atanh(W[1, 2])
    y2 = atanh(W[1, 3])
    z = W[2, 3] / sqrt(1 - W[1, 3]^2)
    y3 = atanh(z)
    return [y1, y2, y3]
end

# Check that they both do the same thing, and they both do the same
# thing as Bijectors.jl forward transform – so they are the same function
for _ in 1:1000
    x_new = rand(dist)
    @test chol_upper_new(x_new.UL)  chol_upper_old(x_new.UL) atol=1e-12
    @test chol_upper_new(x_new.UL)  bijector(dist)(x_new) atol=1e-12
end

# ForwardDiff gives different Jacobians though:
x = rand(Random.Xoshiro(468), dist)

J_FD_new = ForwardDiff.jacobian(chol_upper_new, x.UL)
# 3×9 Matrix{Float64}:
#  0.0  0.0  0.0  1.05605  0.0  0.0  0.0     0.0      0.0
#  0.0  0.0  0.0  0.0      0.0  0.0  1.0906  0.0      0.0
#  0.0  0.0  0.0  0.0      0.0  0.0  0.0     1.04432  2.85869

J_FD_old = ForwardDiff.jacobian(chol_upper_old, x.UL)
# 3×9 Matrix{Float64}:
#  0.0  0.0  0.0  1.05605  0.0  0.0   0.0      0.0      0.0
#  0.0  0.0  0.0  0.0      0.0  0.0   1.0906   0.0      0.0
#  0.0  0.0  0.0  0.0      0.0  0.0  -2.50772  8.86963  0.0

So it seems like the issue in the failing test might be the indexing by inds (which for an upper-triangular 3×3 matrix evaluates to [4, 7, 8]):

inds = [
    LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
    (uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
]
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
J = J[:, inds]

With the old implementation, all the other columns are zero, but with the new one this isn't true. 😬
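As a sanity check on that index set: for a K×K matrix stored column-major (as Julia stores it), the 1-based linear index of entry (i, j) is (j - 1) * K + i, so the strictly upper-triangular entries of a 3×3 matrix sit at linear indices 4, 7 and 8. A one-liner reproducing this (in Python, mirroring Julia's LinearIndices convention):

```python
K = 3
# 1-based, column-major linear indices of the strict upper triangle (i < j):
inds = [(j - 1) * K + i for j in range(1, K + 1) for i in range(1, K + 1) if i < j]
print(inds)  # [4, 7, 8]
```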


Edit: Fixed the test by giving ForwardDiff only a vector of the free parameters in the Cholesky factor, instead of giving it the whole matrix. For now this only works for the 3×3 case.

@penelopeysm (Member, Author) commented:

Remaining failing test:

test_rrule(
    Bijectors._link_chol_lkj_from_upper,
    x.U;
    testset_name="_link_chol_lkj_from_upper on $(typeof(x)) [$i]",
)
test_rrule(
    Bijectors._link_chol_lkj_from_lower,
    x.L;
    testset_name="_link_chol_lkj_from_lower on $(typeof(x)) [$i]",
)
