Skip to content

Commit

Permalink
Merge pull request #19 from TuringLang/mt/array_dist_and_multi
Browse files Browse the repository at this point in the history
Implement FillDist and ArrayDist
  • Loading branch information
mohamed82008 authored Feb 16, 2020
2 parents c93d8e7 + 105e5e8 commit b296d39
Show file tree
Hide file tree
Showing 11 changed files with 440 additions and 81 deletions.
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ version = "0.3.2"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Expand All @@ -18,10 +20,14 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Combinatorics = "0.7"
DiffRules = "0.1, 1.0"
Distributions = "0.22"
FillArrays = "0.8"
FiniteDifferences = "0.9"
ForwardDiff = "0.10.6"
PDMats = "0.9"
SpecialFunctions = "0.8, 0.9, 0.10"
StatsBase = "0.32"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.5"
Zygote = "0.4.7"
Expand Down
15 changes: 12 additions & 3 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ using PDMats,
StatsFuns

using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
TrackedVecOrMat, track, data
using ZygoteRules: ZygoteRules, pullback
TrackedVecOrMat, track, @grad, data
using SpecialFunctions: logabsgamma, digamma
using ZygoteRules: ZygoteRules, @adjoint, pullback
using LinearAlgebra: copytri!
using Distributions: AbstractMvLogNormal,
ContinuousMultivariateDistribution
using DiffRules, SpecialFunctions, FillArrays
using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
using Base.Iterators: drop

import StatsFuns: logsumexp,
binomlogpdf,
Expand All @@ -35,11 +39,16 @@ export TuringScalMvNormal,
TuringMvLogNormal,
TuringPoissonBinomial,
TuringWishart,
TuringInverseWishart
TuringInverseWishart,
arraydist,
filldist

include("common.jl")
include("univariate.jl")
include("multivariate.jl")
include("matrixvariate.jl")
include("flatten.jl")
include("arraydist.jl")
include("filldist.jl")

end
76 changes: 76 additions & 0 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Univariate

const VectorOfUnivariate = Distributions.Product

function arraydist(dists::AbstractVector{<:Normal{T}}) where {T}
means = mean.(dists)
vars = var.(dists)
return MvNormal(means, vars)
end
function arraydist(dists::AbstractVector{<:Normal{<:TrackedReal}})
means = vcatmapreduce(mean, dists)
vars = vcatmapreduce(var, dists)
return MvNormal(means, vars)
end
function arraydist(dists::AbstractVector{<:UnivariateDistribution})
return product_distribution(dists)
end
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
return sum(vcatmapreduce(logpdf, dist.v, x))
end
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# eachcol breaks Zygote, so we need an adjoint
return vcatmapreduce((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x))
end
@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# Any other more efficient implementation breaks Zygote
f(dist, x) = [sum(logpdf.(dist.v, view(x, :, i))) for i in 1:size(x, 2)]
return pullback(f, dist, x)
end

struct MatrixOfUnivariate{
S <: ValueSupport,
Tdist <: UnivariateDistribution{S},
Tdists <: AbstractMatrix{Tdist},
} <: MatrixDistribution{S}
dists::Tdists
end
Base.size(dist::MatrixOfUnivariate) = size(dist.dists)
function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
return MatrixOfUnivariate(dists)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
# Broadcasting here breaks Tracker for some reason
# A Zygote adjoint is defined for vcatmapreduce to use broadcasting
return sum(vcatmapreduce(logpdf, dist.dists, x))
end
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
return rand.(Ref(rng), dist.dists)
end

# Multivariate

struct VectorOfMultivariate{
S <: ValueSupport,
Tdist <: MultivariateDistribution{S},
Tdists <: AbstractVector{Tdist},
} <: MatrixDistribution{S}
dists::Tdists
end
Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist))
Base.length(dist::VectorOfMultivariate) = length(dist.dists)
function arraydist(dists::AbstractVector{<:MultivariateDistribution})
return VectorOfMultivariate(dists)
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
# eachcol breaks Zygote, so we define an adjoint
return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x)))
end
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
return pullback(f, dist, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
init = reshape(rand(rng, dist.dists[1]), :, 1)
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init)
end
75 changes: 66 additions & 9 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
## Generic ##

if VERSION < v"1.1"
eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2))
end

Base.one(::Irrational) = true

function vcatmapreduce(f, args...)
init = vcat(f(first.(args)...,))
zipped_args = zip(args...,)
return mapreduce(vcat, drop(zipped_args, 1); init = init) do zarg
f(zarg...,)
end
end
@adjoint function vcatmapreduce(f, args...)
g(f, args...) = f.(args...,)
return pullback(g, f, args...)
end

function Base.fill(
value::TrackedReal,
dims::Vararg{Union{Integer, AbstractUnitRange}},
)
return track(fill, value, dims...)
end
Tracker.@grad function Base.fill(value::Real, dims...)
@grad function Base.fill(value::Real, dims...)
return fill(data(value), dims...), function(Δ)
size(Δ) dims && error("Dimension mismatch")
return (sum(Δ), map(_->nothing, dims)...)
Expand All @@ -16,15 +34,15 @@ end
## StatsFuns ##

logsumexp(x::TrackedArray) = track(logsumexp, x)
Tracker.@grad function logsumexp(x::TrackedArray)
@grad function logsumexp(x::TrackedArray)
lse = logsumexp(data(x))
return lse, Δ ->.* exp.(x .- lse),)
end

## Linear algebra ##

LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A)
Tracker.@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),)
end

Expand All @@ -39,27 +57,27 @@ function turing_chol(A::AbstractMatrix, check)
(chol.factors, chol.info)
end
turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check)
Tracker.@grad function turing_chol(A::AbstractMatrix, check)
@grad function turing_chol(A::AbstractMatrix, check)
C, back = pullback(unsafe_cholesky, data(A), data(check))
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
end

unsafe_cholesky(x, check) = cholesky(x, check=check)
ZygoteRules.@adjoint function unsafe_cholesky::Real, check)
@adjoint function unsafe_cholesky::Real, check)
C = cholesky(Σ; check=check)
return C, function::NamedTuple)
issuccess(C) || return (zero(Σ), nothing)
.factors[1, 1] / (2 * C.U[1, 1]), nothing)
end
end
ZygoteRules.@adjoint function unsafe_cholesky::Diagonal, check)
@adjoint function unsafe_cholesky::Diagonal, check)
C = cholesky(Σ; check=check)
return C, function::NamedTuple)
issuccess(C) || (Diagonal(zero(diag.factors))), nothing)
(Diagonal(diag.factors) .* inv.(2 .* C.factors.diag)), nothing)
end
end
ZygoteRules.@adjoint function unsafe_cholesky::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
@adjoint function unsafe_cholesky::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
C = cholesky(Σ; check=check)
return C, function::NamedTuple)
issuccess(C) || return (zero.factors), nothing)
Expand All @@ -78,7 +96,7 @@ end
# Specialised logdet for cholesky to target the triangle directly.
logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)])
logdet_chol_tri(U::TrackedMatrix) = track(logdet_chol_tri, U)
Tracker.@grad function logdet_chol_tri(U::AbstractMatrix)
@grad function logdet_chol_tri(U::AbstractMatrix)
U_data = data(U)
return logdet_chol_tri(U_data), Δ->(Matrix(Diagonal(2 .* Δ ./ diag(U_data))),)
end
Expand All @@ -88,6 +106,7 @@ function LinearAlgebra.logdet(C::Cholesky{<:TrackedReal, <:TrackedMatrix})
end

# Tracker's implementation of ldiv isn't good. We'll use Zygote's instead.

zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B
function zygote_ldiv(A::TrackedMatrix, B::TrackedVecOrMat)
return track(zygote_ldiv, A, B)
Expand All @@ -96,11 +115,49 @@ function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat)
return track(zygote_ldiv, A, B)
end
zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B)
Tracker.@grad function zygote_ldiv(A, B)
@grad function zygote_ldiv(A, B)
Y, back = pullback(\, data(A), data(B))
return Y, Δ->back(data(Δ))
end

function Base.:\(a::Cholesky{<:TrackedReal, <:TrackedArray}, b::AbstractVecOrMat)
return (a.U \ (a.U' \ b))
end

# SpecialFunctions

SpecialFunctions.logabsgamma(x::TrackedReal) = track(logabsgamma, x)
@grad function SpecialFunctions.logabsgamma(x::Real)
return logabsgamma(data(x)), Δ -> (digamma(data(x)) * Δ[1],)
end
@adjoint function SpecialFunctions.logabsgamma(x::Real)
return logabsgamma(x), Δ -> (digamma(x) * Δ[1],)
end

# Some Tracker fixes

for i = 0:2, c = Tracker.combinations([:AbstractArray, :TrackedArray, :TrackedReal, :Number], i), f = [:hcat, :vcat]
if :TrackedReal in c
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
track($f, $(cnames...), x, xs...)
end
end
@grad function vcat(x::Real)
vcat(data(x)), (Δ) -> (Δ[1],)
end
@grad function vcat(x1::Real, x2::Real)
vcat(data(x1), data(x2)), (Δ) -> (Δ[1], Δ[2])
end
@grad function vcat(x1::AbstractVector, x2::Real)
vcat(data(x1), data(x2)), (Δ) -> (Δ[1:length(x1)], Δ[length(x1)+1])
end

# Zygote fill has issues with non-numbers

@adjoint function fill(x::T, dims...) where {T}
function zfill(x, dims...,)
return reshape([x for i in 1:prod(dims)], dims)
end
pullback(zfill, x, dims...)
end
111 changes: 111 additions & 0 deletions src/filldist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Univariate

const FillVectorOfUnivariate{
S <: ValueSupport,
T <: UnivariateDistribution{S},
Tdists <: Fill{T, 1},
} = VectorOfUnivariate{S, T, Tdists}

function filldist(dist::UnivariateDistribution, N::Int)
return product_distribution(Fill(dist, N))
end
filldist(d::Normal, N::Int) = MvNormal(fill(d.μ, N), d.σ)

function Distributions.logpdf(
dist::FillVectorOfUnivariate,
x::AbstractVector{<:Real},
)
return _logpdf(dist, x)
end
function Distributions.logpdf(
dist::FillVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
return _logpdf(dist, x)
end
@adjoint function Distributions.logpdf(
dist::FillVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
return pullback(_logpdf, dist, x)
end

function _logpdf(
dist::FillVectorOfUnivariate,
x::AbstractVector{<:Real},
)
return _flat_logpdf(dist.v.value, x)
end
function _logpdf(
dist::FillVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
return _flat_logpdf_mat(dist.v.value, x)
end

function _flat_logpdf(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return sum(f.(args..., x))
else
return sum(vcatmapreduce(x -> logpdf(dist, x), x))
end
end
function _flat_logpdf_mat(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return vec(sum(f.(args..., x), dims = 1))
else
temp = vcatmapreduce(x -> logpdf(dist, x), x)
return vec(sum(reshape(temp, size(x)), dims = 1))
end
end

const FillMatrixOfUnivariate{
S <: ValueSupport,
T <: UnivariateDistribution{S},
Tdists <: Fill{T, 2},
} = MatrixOfUnivariate{S, T, Tdists}

function filldist(dist::UnivariateDistribution, N1::Integer, N2::Integer)
return MatrixOfUnivariate(Fill(dist, N1, N2))
end
function Distributions.logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real})
return _flat_logpdf(dist.dists.value, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::FillMatrixOfUnivariate)
return rand(rng, dist.dists.value, length.(dist.dists.axes))
end

# Multivariate

const FillVectorOfMultivariate{
S <: ValueSupport,
T <: MultivariateDistribution{S},
Tdists <: Fill{T, 1},
} = VectorOfMultivariate{S, T, Tdists}

function filldist(dist::MultivariateDistribution, N::Int)
return VectorOfMultivariate(Fill(dist, N))
end
function Distributions.logpdf(
dist::FillVectorOfMultivariate,
x::AbstractMatrix{<:Real},
)
return _logpdf(dist, x)
end
@adjoint function Distributions.logpdf(
dist::FillVectorOfMultivariate,
x::AbstractMatrix{<:Real},
)
return pullback(_logpdf, dist, x)
end
function _logpdf(
dist::FillVectorOfMultivariate,
x::AbstractMatrix{<:Real},
)
return sum(logpdf(dist.dists.value, x))
end
function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate)
return rand(rng, dist.dists.value, length.(dist.dists.axes))
end
Loading

0 comments on commit b296d39

Please sign in to comment.