Skip to content

Commit

Permalink
Improve numerical stability of Cholesky invlink
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Nov 30, 2024
1 parent f52a9c5 commit 21ec0e0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.15.2"
version = "0.15.3"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
9 changes: 4 additions & 5 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,12 @@ function _inv_link_chol_lkj(Y::AbstractMatrix)
T = float(eltype(W))
logJ = zero(T)

idx = 1
@inbounds for j in 1:K
log_remainder = zero(T) # log of proportion of unit vector remaining
for i in 1:(j - 1)
z = tanh(Y[i, j])
W[i, j] = z * exp(log_remainder)
log_remainder += log1p(-z^2) / 2
log_remainder += log(2 / (exp(Y[i, j]) + exp(-Y[i, j])))
logJ += log_remainder
end
logJ += log_remainder
Expand All @@ -380,10 +379,10 @@ function _inv_link_chol_lkj(y::AbstractVector)
log_remainder = zero(T) # log of proportion of unit vector remaining
for i in 1:(j - 1)
z = tanh(y[idx])
idx += 1
W[i, j] = z * exp(log_remainder)
log_remainder += log1p(-z^2) / 2
log_remainder += log(2 / (exp(y[idx]) + exp(-y[idx])))
logJ += log_remainder
idx += 1
end
logJ += log_remainder
W[j, j] = exp(log_remainder)
Expand Down Expand Up @@ -497,7 +496,7 @@ function _logabsdetjac_inv_chol(y::AbstractVector)
tmp = zero(result)
for _ in 1:(j - 1)
z = tanh(y[idx])
logz = log(1 - z^2)
logz = 2 * log(2 / (exp(y[idx]) + exp(-y[idx])))
result += logz + (tmp / 2)
tmp += logz
idx += 1
Expand Down

0 comments on commit 21ec0e0

Please sign in to comment.