Skip to content

Commit

Permalink
Update DistributionsAD support (#89)
Browse files Browse the repository at this point in the history
* update DistribtuionsAD and Zygote

* bump Zygote and DAD comapt versions

* minor fix
  • Loading branch information
mohamed82008 authored Mar 15, 2020
1 parent 86c6a67 commit e4065e9
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 60 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ArgCheck = "1, 2.0"
Combinatorics = "0.7"
Compat = "3.0"
Distributions = "0.21.11, 0.22"
DistributionsAD = "0.4.2"
DistributionsAD = "0.4.3"
ForwardDiff = "0.10.3"
MappedArrays = "0.2.2"
NNlib = "0.6"
Expand All @@ -31,7 +31,7 @@ Requires = "0.5, 1"
Roots = "0.8.4, 1.0"
StatsFuns = "0.8, 0.9.3"
Tracker = "0.2.3"
Zygote = "0.4.7"
Zygote = "0.4.10"
julia = "1"

[extras]
Expand Down
3 changes: 2 additions & 1 deletion src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using LinearAlgebra
using MappedArrays
using Roots
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular

export TransformDistribution,
PositiveDistribution,
Expand Down Expand Up @@ -366,7 +367,7 @@ function _logpdf_with_trans_pd(
T = eltype(X)
Xcf = cholesky(X, check = false)
if !issuccess(Xcf)
Xcf = cholesky(X + (eps(T) * norm(X)) * I)
Xcf = cholesky(X + max(eps(T), eps(T) * norm(X)) * I)
end
lp = getlogp(d, Xcf, X)
if transform && isfinite(lp)
Expand Down
10 changes: 5 additions & 5 deletions src/bijectors/pd.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
struct PDBijector <: Bijector{2} end

function replace_diag(X, y)
f(i, j) = ifelse(i == j, y[i], X[i, j])
return f.(1:size(X, 1), (1:size(X, 2))')
function replace_diag(f, X)
g(i, j) = ifelse(i == j, f(X[i, i]), X[i, j])
return g.(1:size(X, 1), (1:size(X, 2))')
end
function (b::PDBijector)(X::AbstractMatrix{<:Real})
Y = cholesky(X).L
return replace_diag(Y, log.(diag(Y)))
return replace_diag(log, Y)
end
function (ib::Inverse{<:PDBijector})(Y::AbstractMatrix{<:Real})
X = replace_diag(Y, exp.(diag(Y)))
X = replace_diag(exp, Y)
return LowerTriangular(X) * LowerTriangular(X)'
end
47 changes: 33 additions & 14 deletions src/compat/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,22 @@ end
vcat(data(x1), data(x2)), (Δ) -> (Δ[1:length(x1)], Δ[length(x1)+1])
end

LinearAlgebra.LowerTriangular(A::TrackedMatrix) = track(LowerTriangular, A)
@grad LinearAlgebra.LowerTriangular(A::TrackedMatrix) = LowerTriangular(data(A)), Δ->(LowerTriangular(Δ),)

function Base.:*(
A::Adjoint{<:Any, <:LinearAlgebra.AbstractTriangular{<:Any, <:AbstractMatrix}},
B::TrackedVector,
)
return track(*, A, B)
function Base.copy(
A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
) where {T <: Real}
return track(copy, A)
end
@grad function Base.copy(
A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
) where {T <: Real}
return copy(data(A)), ∇ -> (copy(∇),)
end

Base.:*(A::TrackedMatrix, B::AbstractTriangular) = track(*, A, B)
Base.:*(A::AbstractTriangular{T}, B::TrackedVector) where {T} = track(*, A, B)
Base.:*(A::AbstractTriangular{T}, B::TrackedMatrix) where {T} = track(*, A, B)
Base.:*(A::Adjoint{T, <:AbstractTriangular{T}}, B::TrackedMatrix) where {T} = track(*, A, B)
Base.:*(A::Adjoint{T, <:AbstractTriangular{T}}, B::TrackedVector) where {T} = track(*, A, B)

_eps(::Type{<:TrackedReal{T}}) where {T} = eps(T)

Expand Down Expand Up @@ -199,14 +206,26 @@ end
end
end

(b::PDBijector)(X::TrackedMatrix) = track(b, X)
@grad function (b::PDBijector)(X::AbstractMatrix{<:Real})
return pullback(b, data(X))
replace_diag(::typeof(log), X::TrackedMatrix) = track(replace_diag, log, X)
@grad function replace_diag(::typeof(log), X)
Xd = data(X)
f(i, j) = i == j ? log(Xd[i, j]) : Xd[i, j]
out = f.(1:size(Xd, 1), (1:size(Xd, 2))')
out, ∇ -> begin
g(i, j) = i == j ? ∇[i, j]/Xd[i, j] : ∇[i, j]
return (nothing, g.(1:size(Xd, 1), (1:size(Xd, 2))'))
end
end

(ib::Inverse{PDBijector})(X::TrackedMatrix) = track(ib, X)
@grad function (ib::Inverse{PDBijector})(Y::AbstractMatrix{<:Real})
return pullback(ib, data(Y))
replace_diag(::typeof(exp), X::TrackedMatrix) = track(replace_diag, exp, X)
@grad function replace_diag(::typeof(exp), X)
Xd = data(X)
f(i, j) = ifelse(i == j, exp(Xd[i, j]), Xd[i, j])
out = f.(1:size(Xd, 1), (1:size(Xd, 2))')
out, ∇ -> begin
g(i, j) = ifelse(i == j, ∇[i, j]*exp(Xd[i, j]), ∇[i, j])
return (nothing, g.(1:size(Xd, 1), (1:size(Xd, 2))'))
end
end

logabsdetjac(b::SimplexBijector, x::TrackedVecOrMat) = track(logabsdetjac, b, x)
Expand Down
54 changes: 16 additions & 38 deletions src/compat/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,21 @@ end

## Positive definite matrices

@adjoint function replace_diag(X, y)
f(i, j) = ifelse(i == j, y[i], X[i, j])
@adjoint function replace_diag(::typeof(log), X)
f(i, j) = i == j ? log(X[i, j]) : X[i, j]
out = f.(1:size(X, 1), (1:size(X, 2))')
out, ∇ -> (replace_diag(∇, zeros(length(y))), diag(∇))
out, ∇ -> begin
g(i, j) = i == j ? ∇[i, j] / X[i, j] : ∇[i, j]
(nothing, g.(1:size(X, 1), (1:size(X, 2))'))
end
end
@adjoint function replace_diag(::typeof(exp), X)
f(i, j) = ifelse(i == j, exp(X[i, j]), X[i, j])
out = f.(1:size(X, 1), (1:size(X, 2))')
out, ∇ -> begin
g(i, j) = ifelse(i == j, ∇[i, j] * exp(X[i, j]), ∇[i, j])
(nothing, g.(1:size(X, 1), (1:size(X, 2))'))
end
end

@adjoint function _logpdf_with_trans_pd(
Expand All @@ -95,9 +106,9 @@ function _logpdf_with_trans_pd_zygote(
transform::Bool,
)
T = eltype(X)
Xcf = unsafe_cholesky(X, false)
Xcf = cholesky(X, check = false)
if !issuccess(Xcf)
Xcf = unsafe_cholesky(X + (eps(T) * norm(X)) * I, true)
Xcf = cholesky(X + max(eps(T), eps(T) * norm(X)) * I, check = true)
end
lp = getlogp(d, Xcf, X)
if transform && isfinite(lp)
Expand All @@ -110,39 +121,6 @@ function _logpdf_with_trans_pd_zygote(
return lp
end

# Zygote doesn't support kwargs, e.g. cholesky(A, check = false), hence this workaround
# Copied from DistributionsAD
unsafe_cholesky(x, check) = cholesky(x, check=check)
@adjoint function unsafe_cholesky::Real, check)
C = cholesky(Σ; check=check)
return C, function::NamedTuple)
issuccess(C) || return (zero(Σ), nothing)
.factors[1, 1] / (2 * C.U[1, 1]), nothing)
end
end
@adjoint function unsafe_cholesky::Diagonal, check)
C = cholesky(Σ; check=check)
return C, function::NamedTuple)
issuccess(C) || (Diagonal(zero(diag.factors))), nothing)
(Diagonal(diag.factors) .* inv.(2 .* C.factors.diag)), nothing)
end
end
@adjoint function unsafe_cholesky::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
C = cholesky(Σ; check=check)
return C, function::NamedTuple)
issuccess(C) || return (zero.factors), nothing)
U, Ū = C.U, Δ.factors
Σ̄ =* U'
Σ̄ = LinearAlgebra.copytri!(Σ̄, 'U')
Σ̄ = ldiv!(U, Σ̄)
BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄)
@inbounds for n in diagind(Σ̄)
Σ̄[n] /= 2
end
return (UpperTriangular(Σ̄), nothing)
end
end

# Simplex adjoints

@adjoint function _simplex_bijector(X::AbstractVector, b::SimplexBijector)
Expand Down

0 comments on commit e4065e9

Please sign in to comment.