Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement FillDist and ArrayDist #19

Merged
merged 25 commits into from
Feb 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ version = "0.3.2"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Expand All @@ -18,10 +20,14 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Combinatorics = "0.7"
DiffRules = "0.1, 1.0"
Distributions = "0.22"
FillArrays = "0.8"
FiniteDifferences = "0.9"
ForwardDiff = "0.10.6"
PDMats = "0.9"
SpecialFunctions = "0.8, 0.9, 0.10"
StatsBase = "0.32"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.5"
Zygote = "0.4.7"
Expand Down
15 changes: 12 additions & 3 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ using PDMats,
StatsFuns

using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
TrackedVecOrMat, track, data
using ZygoteRules: ZygoteRules, pullback
TrackedVecOrMat, track, @grad, data
using SpecialFunctions: logabsgamma, digamma
using ZygoteRules: ZygoteRules, @adjoint, pullback
using LinearAlgebra: copytri!
using Distributions: AbstractMvLogNormal,
ContinuousMultivariateDistribution
using DiffRules, SpecialFunctions, FillArrays
using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
using Base.Iterators: drop

import StatsFuns: logsumexp,
binomlogpdf,
Expand All @@ -35,11 +39,16 @@ export TuringScalMvNormal,
TuringMvLogNormal,
TuringPoissonBinomial,
TuringWishart,
TuringInverseWishart
TuringInverseWishart,
arraydist,
filldist

include("common.jl")
include("univariate.jl")
include("multivariate.jl")
include("matrixvariate.jl")
include("flatten.jl")
include("arraydist.jl")
include("filldist.jl")

end
76 changes: 76 additions & 0 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Univariate

const VectorOfUnivariate = Distributions.Product

function arraydist(dists::AbstractVector{<:Normal{T}}) where {T}
means = mean.(dists)
vars = var.(dists)
return MvNormal(means, vars)
end
function arraydist(dists::AbstractVector{<:Normal{<:TrackedReal}})
means = vcatmapreduce(mean, dists)
vars = vcatmapreduce(var, dists)
return MvNormal(means, vars)
end
function arraydist(dists::AbstractVector{<:UnivariateDistribution})
return product_distribution(dists)
end
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
return sum(vcatmapreduce(logpdf, dist.v, x))
end
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# eachcol breaks Zygote, so we need an adjoint
return vcatmapreduce((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x))
end
@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# Any other more efficient implementation breaks Zygote
f(dist, x) = [sum(logpdf.(dist.v, view(x, :, i))) for i in 1:size(x, 2)]
return pullback(f, dist, x)
end

struct MatrixOfUnivariate{
S <: ValueSupport,
Tdist <: UnivariateDistribution{S},
Tdists <: AbstractMatrix{Tdist},
} <: MatrixDistribution{S}
dists::Tdists
end
Base.size(dist::MatrixOfUnivariate) = size(dist.dists)
function arraydist(dists::AbstractMatrix{<:UnivariateDistribution})
return MatrixOfUnivariate(dists)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
# Broadcasting here breaks Tracker for some reason
# A Zygote adjoint is defined for vcatmapreduce to use broadcasting
return sum(vcatmapreduce(logpdf, dist.dists, x))
end
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
return rand.(Ref(rng), dist.dists)
end

# Multivariate

struct VectorOfMultivariate{
S <: ValueSupport,
Tdist <: MultivariateDistribution{S},
Tdists <: AbstractVector{Tdist},
} <: MatrixDistribution{S}
dists::Tdists
end
Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist))
Base.length(dist::VectorOfMultivariate) = length(dist.dists)
function arraydist(dists::AbstractVector{<:MultivariateDistribution})
return VectorOfMultivariate(dists)
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
# eachcol breaks Zygote, so we define an adjoint
return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x)))
end
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
return pullback(f, dist, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
init = reshape(rand(rng, dist.dists[1]), :, 1)
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init)
end
75 changes: 66 additions & 9 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,30 @@
## Generic ##

if VERSION < v"1.1"
eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2))
end

Base.one(::Irrational) = true

function vcatmapreduce(f, args...)
init = vcat(f(first.(args)...,))
zipped_args = zip(args...,)
return mapreduce(vcat, drop(zipped_args, 1); init = init) do zarg
f(zarg...,)
end
end
@adjoint function vcatmapreduce(f, args...)
g(f, args...) = f.(args...,)
return pullback(g, f, args...)
end

function Base.fill(
value::TrackedReal,
dims::Vararg{Union{Integer, AbstractUnitRange}},
)
return track(fill, value, dims...)
end
Tracker.@grad function Base.fill(value::Real, dims...)
@grad function Base.fill(value::Real, dims...)
return fill(data(value), dims...), function(Δ)
size(Δ) ≢ dims && error("Dimension mismatch")
return (sum(Δ), map(_->nothing, dims)...)
Expand All @@ -16,15 +34,15 @@ end
## StatsFuns ##

logsumexp(x::TrackedArray) = track(logsumexp, x)
Tracker.@grad function logsumexp(x::TrackedArray)
@grad function logsumexp(x::TrackedArray)
lse = logsumexp(data(x))
return lse, Δ -> (Δ .* exp.(x .- lse),)
end

## Linear algebra ##

LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A)
Tracker.@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),)
end

Expand All @@ -39,27 +57,27 @@ function turing_chol(A::AbstractMatrix, check)
(chol.factors, chol.info)
end
turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check)
Tracker.@grad function turing_chol(A::AbstractMatrix, check)
@grad function turing_chol(A::AbstractMatrix, check)
C, back = pullback(unsafe_cholesky, data(A), data(check))
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
end

unsafe_cholesky(x, check) = cholesky(x, check=check)
ZygoteRules.@adjoint function unsafe_cholesky(Σ::Real, check)
@adjoint function unsafe_cholesky(Σ::Real, check)
C = cholesky(Σ; check=check)
return C, function(Δ::NamedTuple)
issuccess(C) || return (zero(Σ), nothing)
(Δ.factors[1, 1] / (2 * C.U[1, 1]), nothing)
end
end
ZygoteRules.@adjoint function unsafe_cholesky(Σ::Diagonal, check)
@adjoint function unsafe_cholesky(Σ::Diagonal, check)
C = cholesky(Σ; check=check)
return C, function(Δ::NamedTuple)
issuccess(C) || (Diagonal(zero(diag(Δ.factors))), nothing)
(Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing)
end
end
ZygoteRules.@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
C = cholesky(Σ; check=check)
return C, function(Δ::NamedTuple)
issuccess(C) || return (zero(Δ.factors), nothing)
Expand All @@ -78,7 +96,7 @@ end
# Specialised logdet for cholesky to target the triangle directly.
logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)])
logdet_chol_tri(U::TrackedMatrix) = track(logdet_chol_tri, U)
Tracker.@grad function logdet_chol_tri(U::AbstractMatrix)
@grad function logdet_chol_tri(U::AbstractMatrix)
U_data = data(U)
return logdet_chol_tri(U_data), Δ->(Matrix(Diagonal(2 .* Δ ./ diag(U_data))),)
end
Expand All @@ -88,6 +106,7 @@ function LinearAlgebra.logdet(C::Cholesky{<:TrackedReal, <:TrackedMatrix})
end

# Tracker's implementation of ldiv isn't good. We'll use Zygote's instead.

zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B
function zygote_ldiv(A::TrackedMatrix, B::TrackedVecOrMat)
return track(zygote_ldiv, A, B)
Expand All @@ -96,11 +115,49 @@ function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat)
return track(zygote_ldiv, A, B)
end
zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B)
Tracker.@grad function zygote_ldiv(A, B)
@grad function zygote_ldiv(A, B)
Y, back = pullback(\, data(A), data(B))
return Y, Δ->back(data(Δ))
end

function Base.:\(a::Cholesky{<:TrackedReal, <:TrackedArray}, b::AbstractVecOrMat)
return (a.U \ (a.U' \ b))
end

# SpecialFunctions

SpecialFunctions.logabsgamma(x::TrackedReal) = track(logabsgamma, x)
@grad function SpecialFunctions.logabsgamma(x::Real)
return logabsgamma(data(x)), Δ -> (digamma(data(x)) * Δ[1],)
end
@adjoint function SpecialFunctions.logabsgamma(x::Real)
return logabsgamma(x), Δ -> (digamma(x) * Δ[1],)
end
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved

# Some Tracker fixes

for i = 0:2, c = Tracker.combinations([:AbstractArray, :TrackedArray, :TrackedReal, :Number], i), f = [:hcat, :vcat]
if :TrackedReal in c
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
track($f, $(cnames...), x, xs...)
end
end
@grad function vcat(x::Real)
vcat(data(x)), (Δ) -> (Δ[1],)
end
@grad function vcat(x1::Real, x2::Real)
vcat(data(x1), data(x2)), (Δ) -> (Δ[1], Δ[2])
end
@grad function vcat(x1::AbstractVector, x2::Real)
vcat(data(x1), data(x2)), (Δ) -> (Δ[1:length(x1)], Δ[length(x1)+1])
end
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved

# Zygote fill has issues with non-numbers

@adjoint function fill(x::T, dims...) where {T}
function zfill(x, dims...,)
return reshape([x for i in 1:prod(dims)], dims)
end
pullback(zfill, x, dims...)
end
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
111 changes: 111 additions & 0 deletions src/filldist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Univariate

const FillVectorOfUnivariate{
S <: ValueSupport,
T <: UnivariateDistribution{S},
Tdists <: Fill{T, 1},
} = VectorOfUnivariate{S, T, Tdists}

function filldist(dist::UnivariateDistribution, N::Int)
return product_distribution(Fill(dist, N))
end
filldist(d::Normal, N::Int) = MvNormal(fill(d.μ, N), d.σ)

function Distributions.logpdf(
dist::FillVectorOfUnivariate,
x::AbstractVector{<:Real},
)
return _logpdf(dist, x)
end
function Distributions.logpdf(
dist::FillVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
return _logpdf(dist, x)
end
@adjoint function Distributions.logpdf(
dist::FillVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
return pullback(_logpdf, dist, x)
end

function _logpdf(
dist::FillVectorOfUnivariate,
x::AbstractVector{<:Real},
)
return _flat_logpdf(dist.v.value, x)
end
function _logpdf(
dist::FillVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
return _flat_logpdf_mat(dist.v.value, x)
end

function _flat_logpdf(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return sum(f.(args..., x))
else
return sum(vcatmapreduce(x -> logpdf(dist, x), x))
end
end
function _flat_logpdf_mat(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return vec(sum(f.(args..., x), dims = 1))
else
temp = vcatmapreduce(x -> logpdf(dist, x), x)
return vec(sum(reshape(temp, size(x)), dims = 1))
end
end

const FillMatrixOfUnivariate{
S <: ValueSupport,
T <: UnivariateDistribution{S},
Tdists <: Fill{T, 2},
} = MatrixOfUnivariate{S, T, Tdists}

function filldist(dist::UnivariateDistribution, N1::Integer, N2::Integer)
return MatrixOfUnivariate(Fill(dist, N1, N2))
end
function Distributions.logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real})
return _flat_logpdf(dist.dists.value, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::FillMatrixOfUnivariate)
return rand(rng, dist.dists.value, length.(dist.dists.axes))
end

# Multivariate

const FillVectorOfMultivariate{
S <: ValueSupport,
T <: MultivariateDistribution{S},
Tdists <: Fill{T, 1},
} = VectorOfMultivariate{S, T, Tdists}

function filldist(dist::MultivariateDistribution, N::Int)
return VectorOfMultivariate(Fill(dist, N))
end
function Distributions.logpdf(
dist::FillVectorOfMultivariate,
x::AbstractMatrix{<:Real},
)
return _logpdf(dist, x)
end
@adjoint function Distributions.logpdf(
dist::FillVectorOfMultivariate,
x::AbstractMatrix{<:Real},
)
return pullback(_logpdf, dist, x)
end
function _logpdf(
dist::FillVectorOfMultivariate,
x::AbstractMatrix{<:Real},
)
return sum(logpdf(dist.dists.value, x))
end
function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate)
return rand(rng, dist.dists.value, length.(dist.dists.axes))
end
Loading