Skip to content

Commit beddedd

Browse files
authored
Don't depend on ChainRules internals (#142)
1 parent ef1ffba commit beddedd

File tree

2 files changed

+11
-27
lines changed

2 files changed

+11
-27
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DistributionsAD"
22
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
3-
version = "0.6.12"
3+
version = "0.6.13"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/common.jl

+10-26
Original file line numberDiff line numberDiff line change
@@ -4,44 +4,28 @@ function turing_chol(A::AbstractMatrix, check)
44
chol = cholesky(A, check=check)
55
(chol.factors, chol.info)
66
end
7-
function ChainRules.rrule(::typeof(turing_chol), A::AbstractMatrix, check)
8-
factors, info = turing_chol(A, check)
9-
function turing_chol_pullback(Ȳ)
10-
= Ȳ[1]
11-
∂A = ChainRules.chol_blocked_rev(f̄, factors, 25, true)
12-
return (ChainRules.NO_FIELDS, ∂A, ChainRules.DoesNotExist())
13-
end
14-
(factors,info), turing_chol_pullback
15-
end
167
function turing_chol_back(A::AbstractMatrix, check)
17-
C, dC_pullback = rrule(turing_chol, A, check)
8+
C, chol_pullback = rrule(cholesky, A, Val(false), check=check)
189
function back(Δ)
19-
_, dC = dC_pullback(Δ)
20-
(dC, nothing)
10+
= Composite{typeof(C)}((U=Δ[1]))
11+
∂C = chol_pullback(Ȳ)[2]
12+
(∂C, nothing)
2113
end
22-
C, back
14+
(C.factors,C.info), back
2315
end
2416

2517
function symm_turing_chol(A::AbstractMatrix, check, uplo)
2618
chol = cholesky(Symmetric(A, uplo), check=check)
2719
(chol.factors, chol.info)
2820
end
29-
function ChainRules.rrule(::typeof(symm_turing_chol), A::AbstractMatrix, check, uplo)
30-
factors, info = symm_turing_chol(A, check, uplo)
31-
function symm_turing_chol_pullback(Ȳ)
32-
= Ȳ[1]
33-
∂A = ChainRules.chol_blocked_rev(f̄, factors, 25, true)
34-
return (ChainRules.NO_FIELDS, ∂A, ChainRules.DoesNotExist(), ChainRules.DoesNotExist())
35-
end
36-
return (factors,info), symm_turing_chol_pullback
37-
end
3821
function symm_turing_chol_back(A::AbstractMatrix, check, uplo)
39-
C, dC_pullback = rrule(symm_turing_chol, A, check, uplo)
22+
C, chol_pullback = rrule(cholesky, Symmetric(A,uplo), Val(false), check=check)
4023
function back(Δ)
41-
_, dC = dC_pullback(Δ)
42-
(dC, nothing, nothing)
24+
= Composite{typeof(C)}((U=Δ[1]))
25+
∂C = chol_pullback(Ȳ)[2]
26+
(∂C, nothing, nothing)
4327
end
44-
C, back
28+
(C.factors, C.info), back
4529
end
4630

4731

0 commit comments

Comments
 (0)