@@ -293,20 +293,20 @@ which is the above implementation.
293
293
function _link_chol_lkj (W:: AbstractMatrix )
294
294
K = LinearAlgebra. checksquare (W)
295
295
296
- y = similar (W) # z is also UpperTriangular.
296
+ y = similar (W) # W is upper triangular.
297
297
# Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero.
298
298
299
299
@inbounds for j in 1 : K
300
- for i in 1 : (j - 1 )
301
- remainder_norm = norm (W[i: end , j])
302
- z = W[i, j] / remainder_norm
300
+ remainder_sq = W[j, j]^ 2
301
+ for i in (j - 1 ): - 1 : 1
302
+ remainder_sq += W[i, j]^ 2
303
+ z = W[i, j] / sqrt (remainder_sq)
303
304
y[i, j] = atanh (z)
304
305
end
305
306
for i in j: K
306
307
y[i, j] = 0
307
308
end
308
309
end
309
-
310
310
return y
311
311
end
312
312
@@ -316,16 +316,18 @@ function _link_chol_lkj_from_upper(W::AbstractMatrix)
316
316
317
317
y = similar (W, N)
318
318
319
- idx = 1
319
+ starting_idx = 1
320
320
@inbounds for j in 2 : K
321
- y[idx] = atanh (W[1 , j])
322
- idx += 1
323
- for i in 2 : (j - 1 )
324
- remainder_norm = norm (W[i: end , j])
325
- z = W[i, j] / remainder_norm
321
+ y[starting_idx] = atanh (W[1 , j])
322
+ starting_idx += 1
323
+ remainder_sq = W[j, j]^ 2
324
+ for i in (j - 1 ): - 1 : 2
325
+ idx = starting_idx + i - 2
326
+ remainder_sq += W[i, j]^ 2
327
+ z = W[i, j] / sqrt (remainder_sq)
326
328
y[idx] = atanh (z)
327
- idx += 1
328
329
end
330
+ starting_idx += length ((j - 1 ): - 1 : 2 )
329
331
end
330
332
331
333
return y
0 commit comments