diff --git a/Project.toml b/Project.toml index efb82c9e..94293355 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,7 @@ SpecialFunctions = "0.8, 0.9, 0.10" StatsBase = "0.32" StatsFuns = "0.8, 0.9" Tracker = "0.2.5" -Zygote = "0.4.7" +Zygote = "0.4.10" ZygoteRules = "0.2" julia = "1" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 8b67d406..062751e1 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -15,7 +15,7 @@ using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray, TrackedVecOrMat, track, @grad, data using SpecialFunctions: logabsgamma, digamma using ZygoteRules: ZygoteRules, @adjoint, pullback -using LinearAlgebra: copytri! +using LinearAlgebra: copytri!, AbstractTriangular using Distributions: AbstractMvLogNormal, ContinuousMultivariateDistribution using DiffRules, SpecialFunctions, FillArrays diff --git a/src/common.jl b/src/common.jl index de9556e9..d3160c3f 100644 --- a/src/common.jl +++ b/src/common.jl @@ -40,9 +40,42 @@ end ## Linear algebra ## -LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A) -@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix) - return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),) +# Work around https://github.com/FluxML/Tracker.jl/pull/9#issuecomment-480051767 + +upper(A::AbstractMatrix) = UpperTriangular(A) +lower(A::AbstractMatrix) = LowerTriangular(A) +function upper(C::Cholesky) + if C.uplo == 'U' + return upper(C.factors) + else + return copy(lower(C.factors)') + end +end +function lower(C::Cholesky) + if C.uplo == 'U' + return copy(upper(C.factors)') + else + return lower(C.factors) + end +end + +LinearAlgebra.LowerTriangular(A::TrackedMatrix) = lower(A) +lower(A::TrackedMatrix) = track(lower, A) +@grad lower(A) = lower(Tracker.data(A)), ∇ -> (lower(∇),) + +LinearAlgebra.UpperTriangular(A::TrackedMatrix) = upper(A) +upper(A::TrackedMatrix) = track(upper, A) +@grad upper(A) = upper(Tracker.data(A)), ∇ -> (upper(∇),) + +function Base.copy( + A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}}, +) where {T <: Real} + return track(copy, A) +end +@grad function Base.copy( + A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}}, +) where {T <: Real} + return copy(data(A)), ∇ -> (copy(∇),) end function LinearAlgebra.cholesky(A::TrackedMatrix; check=true) @@ -57,40 +90,10 @@ function turing_chol(A::AbstractMatrix, check) end turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check) @grad function turing_chol(A::AbstractMatrix, check) - C, back = pullback(unsafe_cholesky, data(A), data(check)) + C, back = pullback(_turing_chol, data(A), data(check)) return (C.factors, C.info), Δ->back((factors=data(Δ[1]),)) end - -unsafe_cholesky(x, check) = cholesky(x, check=check) -@adjoint function unsafe_cholesky(Σ::Real, check) - C = cholesky(Σ; check=check) - return C, function(Δ::NamedTuple) - issuccess(C) || return (zero(Σ), nothing) - (Δ.factors[1, 1] / (2 * C.U[1, 1]), nothing) - end -end -@adjoint function unsafe_cholesky(Σ::Diagonal, check) - C = cholesky(Σ; check=check) - return C, function(Δ::NamedTuple) - issuccess(C) || (Diagonal(zero(diag(Δ.factors))), nothing) - (Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing) - end -end -@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check) - C = cholesky(Σ; check=check) - return C, function(Δ::NamedTuple) - issuccess(C) || return (zero(Δ.factors), nothing) - U, Ū = C.U, Δ.factors - Σ̄ = Ū * U' - Σ̄ = copytri!(Σ̄, 'U') - Σ̄ = ldiv!(U, Σ̄) - BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄) - @inbounds for n in diagind(Σ̄) - Σ̄[n] /= 2 - end - return (UpperTriangular(Σ̄), nothing) - end -end +_turing_chol(x, check) = cholesky(x, check=check) # Specialised logdet for cholesky to target the triangle directly. logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)]) diff --git a/test/others.jl b/test/others.jl index 47354694..362bb698 100644 --- a/test/others.jl +++ b/test/others.jl @@ -1,13 +1,6 @@ using StatsBase: entropy if get_stage() in ("Others", "all") - @testset "unsafe_cholesky" begin - A = rand(3, 3); A = A + A' + 3I - @test Matrix(DistributionsAD.unsafe_cholesky(A, true)) == Matrix(cholesky(A)) - @test !issuccess(DistributionsAD.unsafe_cholesky(rand(3,3), false)) - @test_throws PosDefException DistributionsAD.unsafe_cholesky(rand(3,3), true) - end - @testset "TuringWishart" begin dim = 3 A = Matrix{Float64}(I, dim, dim) diff --git a/test/runtests.jl b/test/runtests.jl index c54684e4..dbdd775b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using DistributionsAD, Test, LinearAlgebra, Combinatorics using ForwardDiff: Dual using StatsFuns: binomlogpdf, logsumexp const FDM = FiniteDifferences -using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringUniform, unsafe_cholesky +using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringUniform using Distributions: meanlogdet include("test_utils.jl")