From f7bc09efb0b7b5e9f51d599ea92ce6f1cbc45340 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 1 Dec 2024 02:49:28 +0000 Subject: [PATCH 1/4] Recalculate sum-of-squares in Cholesky forward link --- src/bijectors/corr.jl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index c511de0b..ffb8e3e3 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -297,11 +297,10 @@ function _link_chol_lkj(W::AbstractMatrix) # Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero. @inbounds for j in 1:K - remainder_sq = one(eltype(W)) for i in 1:(j - 1) - z = W[i, j] / sqrt(remainder_sq) + remainder_norm = norm(W[i:end, j]) + z = W[i, j] / remainder_norm y[i, j] = atanh(z) - remainder_sq -= W[i, j]^2 end for i in j:K y[i, j] = 0 @@ -321,11 +320,10 @@ function _link_chol_lkj_from_upper(W::AbstractMatrix) @inbounds for j in 2:K y[idx] = atanh(W[1, j]) idx += 1 - remainder_sq = 1 - W[1, j]^2 for i in 2:(j - 1) - z = W[i, j] / sqrt(remainder_sq) + remainder_norm = norm(W[i:end, j]) + z = W[i, j] / remainder_norm y[idx] = atanh(z) - remainder_sq -= W[i, j]^2 idx += 1 end end From 229ea4c2fe644914d7e9a335742a84d66c7fa365 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 1 Dec 2024 19:23:46 +0000 Subject: [PATCH 2/4] Loop in reverse direction to improve performance --- src/bijectors/corr.jl | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index ffb8e3e3..efac3949 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -293,20 +293,20 @@ which is the above implementation. function _link_chol_lkj(W::AbstractMatrix) K = LinearAlgebra.checksquare(W) - y = similar(W) # z is also UpperTriangular. + y = similar(W) # W is upper triangular. # Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero. @inbounds for j in 1:K - for i in 1:(j - 1) - remainder_norm = norm(W[i:end, j]) - z = W[i, j] / remainder_norm + remainder_sq = W[j, j]^2 + for i in (j - 1):-1:1 + remainder_sq += W[i, j]^2 + z = W[i, j] / sqrt(remainder_sq) y[i, j] = atanh(z) end for i in j:K y[i, j] = 0 end end - return y end @@ -316,16 +316,18 @@ function _link_chol_lkj_from_upper(W::AbstractMatrix) y = similar(W, N) - idx = 1 + starting_idx = 1 @inbounds for j in 2:K - y[idx] = atanh(W[1, j]) - idx += 1 - for i in 2:(j - 1) - remainder_norm = norm(W[i:end, j]) - z = W[i, j] / remainder_norm + y[starting_idx] = atanh(W[1, j]) + starting_idx += 1 + remainder_sq = W[j, j]^2 + for i in (j - 1):-1:2 + idx = starting_idx + i - 2 + remainder_sq += W[i, j]^2 + z = W[i, j] / sqrt(remainder_sq) y[idx] = atanh(z) - idx += 1 end + starting_idx += length((j - 1):-1:2) end return y From c43d2e8c0f4ce39802657b54ab484c139f8df590 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 7 Dec 2024 17:43:46 +0000 Subject: [PATCH 3/4] Use asinh --- src/bijectors/corr.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index efac3949..af368df0 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -299,9 +299,9 @@ function _link_chol_lkj(W::AbstractMatrix) @inbounds for j in 1:K remainder_sq = W[j, j]^2 for i in (j - 1):-1:1 - remainder_sq += W[i, j]^2 z = W[i, j] / sqrt(remainder_sq) - y[i, j] = atanh(z) + y[i, j] = asinh(z) + remainder_sq += W[i, j]^2 end for i in j:K y[i, j] = 0 @@ -323,9 +323,9 @@ function _link_chol_lkj_from_upper(W::AbstractMatrix) remainder_sq = W[j, j]^2 for i in (j - 1):-1:2 idx = starting_idx + i - 2 - remainder_sq += W[i, j]^2 z = W[i, j] / sqrt(remainder_sq) - y[idx] = atanh(z) + y[idx] = asinh(z) + remainder_sq += W[i, j]^2 end starting_idx += length((j - 1):-1:2) end From 4ff6ddc9ae2c5aa9c04371ab295fb2ab1753d91f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 8 Dec 2024 12:08:35 +0000 Subject: [PATCH 4/4] Fix ForwardDiff test --- test/transform.jl | 42 +++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/test/transform.jl b/test/transform.jl index 85477535..6d7b5e77 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -237,18 +237,46 @@ end end @testset "LKJCholesky" begin + # Convert Cholesky factor to its free parameters, i.e. its off-diagonal elements + function chol_3by3_to_free_params(x::Cholesky) + if x.uplo == :U + return [x.U[1, 2], x.U[1, 3], x.U[2, 3]] + else + return [x.L[2, 1], x.L[3, 1], x.L[3, 2]] + end + # TODO: Generalise to arbitrary dimension using this code: + # inds = [ + # LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if + # (uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1]) + # ] + end + + # Reconstruct Cholesky factor from its free parameters + # Note that x[i, i] is always positive so we don't need to worry about the sign + function free_params_to_chol_3by3(free_params::AbstractVector, uplo::Symbol) + x = UpperTriangular(zeros(eltype(free_params), 3, 3)) + x[1, 1] = 1 + x[1, 2] = free_params[1] + x[1, 3] = free_params[2] + x[2, 2] = sqrt(1 - free_params[1]^2) + x[2, 3] = free_params[3] + x[3, 3] = sqrt(1 - free_params[2]^2 - free_params[3]^2) + if uplo == :U + return Cholesky(x) + else + return Cholesky(transpose(x)) + end + end + @testset "uplo: $uplo" for uplo in [:L, :U] dist = LKJCholesky(3, 1, uplo) single_sample_tests(dist) x = rand(dist) - - inds = [ - LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if - (uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1]) - ] - J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL) - J = J[:, inds] + # Here, we need to pass ForwardDiff only the free parameters of the + # Cholesky factor so that we get a square Jacobian matrix + free_params = chol_3by3_to_free_params(x) + J = ForwardDiff.jacobian(z -> link(dist, free_params_to_chol_3by3(z, uplo)), free_params) logpdf_turing = logpdf_with_trans(dist, x, true) @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing end