From 21ec0e008d1fb465e12c6642781b7399fb7f69eb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 16:19:17 +0000 Subject: [PATCH 1/6] Improve numerical stability of Cholesky invlink --- Project.toml | 2 +- src/bijectors/corr.jl | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 97a65aff..84b1b77a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.15.2" +version = "0.15.3" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 93a5b089..ad440fde 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -348,13 +348,12 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) T = float(eltype(W)) logJ = zero(T) - idx = 1 @inbounds for j in 1:K log_remainder = zero(T) # log of proportion of unit vector remaining for i in 1:(j - 1) z = tanh(Y[i, j]) W[i, j] = z * exp(log_remainder) - log_remainder += log1p(-z^2) / 2 + log_remainder += log(2 / (exp(Y[i, j]) + exp(-Y[i, j]))) logJ += log_remainder end logJ += log_remainder @@ -380,10 +379,10 @@ function _inv_link_chol_lkj(y::AbstractVector) log_remainder = zero(T) # log of proportion of unit vector remaining for i in 1:(j - 1) z = tanh(y[idx]) - idx += 1 W[i, j] = z * exp(log_remainder) - log_remainder += log1p(-z^2) / 2 + log_remainder += log(2 / (exp(y[idx]) + exp(-y[idx]))) logJ += log_remainder + idx += 1 end logJ += log_remainder W[j, j] = exp(log_remainder) @@ -497,7 +496,7 @@ function _logabsdetjac_inv_chol(y::AbstractVector) tmp = zero(result) for _ in 1:(j - 1) z = tanh(y[idx]) - logz = log(1 - z^2) + logz = 2 * log(2 / (exp(y[idx]) + exp(-y[idx]))) result += logz + (tmp / 2) tmp += logz idx += 1 From 571aefa3db9f090db853b71143536da75a6b2696 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 18:35:08 +0000 Subject: [PATCH 2/6] Simplify further --- src/bijectors/corr.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index ad440fde..af46b1b9 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -353,7 +353,7 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) for i in 1:(j - 1) z = tanh(Y[i, j]) W[i, j] = z * exp(log_remainder) - log_remainder += log(2 / (exp(Y[i, j]) + exp(-Y[i, j]))) + log_remainder += IrrationalConstants.logtwo + Y[i, j] - LogExpFunctions.log1pexp(2 * Y[i, j]) logJ += log_remainder end logJ += log_remainder @@ -380,7 +380,7 @@ function _inv_link_chol_lkj(y::AbstractVector) for i in 1:(j - 1) z = tanh(y[idx]) W[i, j] = z * exp(log_remainder) - log_remainder += log(2 / (exp(y[idx]) + exp(-y[idx]))) + log_remainder += IrrationalConstants.logtwo + y[idx] - LogExpFunctions.log1pexp(2 * y[idx]) logJ += log_remainder idx += 1 end @@ -495,8 +495,7 @@ function _logabsdetjac_inv_chol(y::AbstractVector) @inbounds for j in 2:K tmp = zero(result) for _ in 1:(j - 1) - z = tanh(y[idx]) - logz = 2 * log(2 / (exp(y[idx]) + exp(-y[idx]))) + logz = 2 * (IrrationalConstants.logtwo + y[idx] - LogExpFunctions.log1pexp(2 * y[idx])) result += logz + (tmp / 2) tmp += logz idx += 1 From 81bedf24c5f227e89c19c6f7407d02b4f79eaf19 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 18:40:17 +0000 Subject: [PATCH 3/6] Use logcosh Co-authored-by: David Widmann --- src/bijectors/corr.jl | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index af46b1b9..cc90523a 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -353,7 +353,7 @@ function _inv_link_chol_lkj(Y::AbstractMatrix) for i in 1:(j - 1) z = tanh(Y[i, j]) W[i, j] = z * exp(log_remainder) - log_remainder += IrrationalConstants.logtwo + Y[i, j] - LogExpFunctions.log1pexp(2 * Y[i, j]) + log_remainder -= LogExpFunctions.logcosh(Y[i, j]) logJ += log_remainder end logJ += log_remainder @@ -380,7 +380,7 @@ function _inv_link_chol_lkj(y::AbstractVector) for i in 1:(j - 1) z = tanh(y[idx]) W[i, j] = z * exp(log_remainder) - log_remainder += IrrationalConstants.logtwo + y[idx] - LogExpFunctions.log1pexp(2 * y[idx]) + log_remainder -= LogExpFunctions.logcosh(y[idx]) logJ += log_remainder idx += 1 end @@ -460,13 +460,8 @@ function _logabsdetjac_inv_corr(Y::AbstractMatrix) K = LinearAlgebra.checksquare(Y) result = float(zero(eltype(Y))) - for j in 2:K, i in 1:(j - 1) - @inbounds abs_y_i_j = abs(Y[i, j]) - result += - (K - i + 1) * ( - IrrationalConstants.logtwo - - (abs_y_i_j + LogExpFunctions.log1pexp(-2 * abs_y_i_j)) - ) + @inbounds for j in 2:K, i in 1:(j - 1) + result += (K - i + 1) * (-LogExpFunctions.logcosh(Y[i, j])) end return result end @@ -495,7 +490,7 @@ function _logabsdetjac_inv_chol(y::AbstractVector) @inbounds for j in 2:K tmp = zero(result) for _ in 1:(j - 1) - logz = 2 * (IrrationalConstants.logtwo + y[idx] - LogExpFunctions.log1pexp(2 * y[idx])) + logz = -2 * LogExpFunctions.logcosh(y[idx]) result += logz + (tmp / 2) tmp += logz idx += 1 From f2b58cdd69fe49ee792f373326f009be9bf42fd5 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 18:48:58 +0000 Subject: [PATCH 4/6] Swap over one more occurrence of logcosh --- src/bijectors/corr.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index cc90523a..5cde58b7 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -461,7 +461,7 @@ function _logabsdetjac_inv_corr(Y::AbstractMatrix) result = float(zero(eltype(Y))) @inbounds for j in 2:K, i in 1:(j - 1) - result += (K - i + 1) * (-LogExpFunctions.logcosh(Y[i, j])) + result -= (K - i + 1) * LogExpFunctions.logcosh(Y[i, j]) end return result end @@ -471,13 +471,8 @@ function _logabsdetjac_inv_corr(y::AbstractVector) result = float(zero(eltype(y))) for (i, y_i) in enumerate(y) - abs_y_i = abs(y_i) row_idx = vec_to_triu1_row_index(i) - result += - (K - row_idx + 1) * ( - IrrationalConstants.logtwo - - (abs_y_i + LogExpFunctions.log1pexp(-2 * abs_y_i)) - ) + result -= (K - row_idx + 1) * LogExpFunctions.logcosh(y_i) end return result end From f2987ebc9450c0201769fed9931fc2affdbebd3e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sat, 30 Nov 2024 19:32:06 +0000 Subject: [PATCH 5/6] Update Stan documentation link --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 5cde58b7..58aa98fc 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -2,7 +2,7 @@ CorrBijector <: Bijector A bijector implementation of Stan's parametrization method for Correlation matrix: -https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html +https://mc-stan.org/docs/reference-manual/transforms.html#correlation-matrix-transform.section Basically, a unconstrained strictly upper triangular matrix `y` is transformed to a correlation matrix by following readable but not that efficient form: From 72599f7184914acf137ada8643bd8b8a5684c8e6 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 1 Dec 2024 01:32:41 +0000 Subject: [PATCH 6/6] Simplify loop in _logabsdetjac_inv_chol --- src/bijectors/corr.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 58aa98fc..c511de0b 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -485,9 +485,9 @@ function _logabsdetjac_inv_chol(y::AbstractVector) @inbounds for j in 2:K tmp = zero(result) for _ in 1:(j - 1) - logz = -2 * LogExpFunctions.logcosh(y[idx]) - result += logz + (tmp / 2) - tmp += logz + logcoshy = LogExpFunctions.logcosh(y[idx]) + tmp -= logcoshy + result += tmp - logcoshy idx += 1 end end