Skip to content

Commit 229ea4c

Browse files
committed
Loop in reverse direction to improve performance
1 parent f7bc09e commit 229ea4c

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

src/bijectors/corr.jl

+14-12
Original file line numberDiff line numberDiff line change
@@ -293,20 +293,20 @@ which is the above implementation.
293293
function _link_chol_lkj(W::AbstractMatrix)
294294
K = LinearAlgebra.checksquare(W)
295295

296-
y = similar(W) # z is also UpperTriangular.
296+
y = similar(W) # W is upper triangular.
297297
# Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero.
298298

299299
@inbounds for j in 1:K
300-
for i in 1:(j - 1)
301-
remainder_norm = norm(W[i:end, j])
302-
z = W[i, j] / remainder_norm
300+
remainder_sq = W[j, j]^2
301+
for i in (j - 1):-1:1
302+
remainder_sq += W[i, j]^2
303+
z = W[i, j] / sqrt(remainder_sq)
303304
y[i, j] = atanh(z)
304305
end
305306
for i in j:K
306307
y[i, j] = 0
307308
end
308309
end
309-
310310
return y
311311
end
312312

@@ -316,16 +316,18 @@ function _link_chol_lkj_from_upper(W::AbstractMatrix)
316316

317317
y = similar(W, N)
318318

319-
idx = 1
319+
starting_idx = 1
320320
@inbounds for j in 2:K
321-
y[idx] = atanh(W[1, j])
322-
idx += 1
323-
for i in 2:(j - 1)
324-
remainder_norm = norm(W[i:end, j])
325-
z = W[i, j] / remainder_norm
321+
y[starting_idx] = atanh(W[1, j])
322+
starting_idx += 1
323+
remainder_sq = W[j, j]^2
324+
for i in (j - 1):-1:2
325+
idx = starting_idx + i - 2
326+
remainder_sq += W[i, j]^2
327+
z = W[i, j] / sqrt(remainder_sq)
326328
y[idx] = atanh(z)
327-
idx += 1
328329
end
330+
starting_idx += length((j - 1):-1:2)
329331
end
330332

331333
return y

0 commit comments

Comments
 (0)