@@ -4,44 +4,28 @@ function turing_chol(A::AbstractMatrix, check)
4
4
chol = cholesky (A, check= check)
5
5
(chol. factors, chol. info)
6
6
end
7
- function ChainRules. rrule (:: typeof (turing_chol), A:: AbstractMatrix , check)
8
- factors, info = turing_chol (A, check)
9
- function turing_chol_pullback (Ȳ)
10
- f̄ = Ȳ[1 ]
11
- ∂A = ChainRules. chol_blocked_rev (f̄, factors, 25 , true )
12
- return (ChainRules. NO_FIELDS, ∂A, ChainRules. DoesNotExist ())
13
- end
14
- (factors,info), turing_chol_pullback
15
- end
16
7
function turing_chol_back (A:: AbstractMatrix , check)
17
- C, dC_pullback = rrule (turing_chol , A, check)
8
+ C, chol_pullback = rrule (cholesky , A, Val ( false ), check = check)
18
9
function back (Δ)
19
- _, dC = dC_pullback (Δ)
20
- (dC, nothing )
10
+ Ȳ = Composite {typeof(C)} ((U= Δ[1 ]))
11
+ ∂C = chol_pullback (Ȳ)[2 ]
12
+ (∂C, nothing )
21
13
end
22
- C , back
14
+ (C . factors,C . info) , back
23
15
end
24
16
25
17
function symm_turing_chol (A:: AbstractMatrix , check, uplo)
26
18
chol = cholesky (Symmetric (A, uplo), check= check)
27
19
(chol. factors, chol. info)
28
20
end
29
- function ChainRules. rrule (:: typeof (symm_turing_chol), A:: AbstractMatrix , check, uplo)
30
- factors, info = symm_turing_chol (A, check, uplo)
31
- function symm_turing_chol_pullback (Ȳ)
32
- f̄ = Ȳ[1 ]
33
- ∂A = ChainRules. chol_blocked_rev (f̄, factors, 25 , true )
34
- return (ChainRules. NO_FIELDS, ∂A, ChainRules. DoesNotExist (), ChainRules. DoesNotExist ())
35
- end
36
- return (factors,info), symm_turing_chol_pullback
37
- end
38
21
function symm_turing_chol_back (A:: AbstractMatrix , check, uplo)
39
- C, dC_pullback = rrule (symm_turing_chol, A, check, uplo )
22
+ C, chol_pullback = rrule (cholesky, Symmetric (A,uplo), Val ( false ), check = check )
40
23
function back (Δ)
41
- _, dC = dC_pullback (Δ)
42
- (dC, nothing , nothing )
24
+ Ȳ = Composite {typeof(C)} ((U= Δ[1 ]))
25
+ ∂C = chol_pullback (Ȳ)[2 ]
26
+ (∂C, nothing , nothing )
43
27
end
44
- C , back
28
+ (C . factors, C . info) , back
45
29
end
46
30
47
31
0 commit comments