Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement FillDist and ArrayDist #19

Merged
merged 25 commits into from
Feb 16, 2020
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ version = "0.3.2"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Expand All @@ -18,10 +20,14 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Combinatorics = "0.7"
DiffRules = "0.1, 1.0"
Distributions = "0.22"
FillArrays = "0.8"
FiniteDifferences = "0.9"
ForwardDiff = "0.10.6"
PDMats = "0.9"
SpecialFunctions = "0.8, 0.9, 0.10"
StatsBase = "0.32"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.5"
Zygote = "0.4.7"
Expand Down
14 changes: 11 additions & 3 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@ using PDMats,
StatsFuns

using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
TrackedVecOrMat, track, data
using ZygoteRules: ZygoteRules, pullback
TrackedVecOrMat, track, @grad, data
using ZygoteRules: ZygoteRules, @adjoint, pullback
using LinearAlgebra: copytri!
using Distributions: AbstractMvLogNormal,
ContinuousMultivariateDistribution
using DiffRules, SpecialFunctions, FillArrays
using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
using Base.Iterators: drop

import StatsFuns: logsumexp,
binomlogpdf,
Expand All @@ -35,11 +38,16 @@ export TuringScalMvNormal,
TuringMvLogNormal,
TuringPoissonBinomial,
TuringWishart,
TuringInverseWishart
TuringInverseWishart,
ArrayDist,
FillDist

include("common.jl")
include("univariate.jl")
include("multivariate.jl")
include("matrixvariate.jl")
include("flatten.jl")
include("array_dist.jl")
include("multi.jl")

end
90 changes: 90 additions & 0 deletions src/array_dist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Univariate

const VectorOfUnivariate{
S <: ValueSupport,
Tdist <: UnivariateDistribution{S},
Tdists <: AbstractVector{Tdist},
} = Distributions.Product{S, Tdist, Tdists}
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved

function ArrayDist(dists::AbstractVector{<:Normal{T}}) where {T}
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
if T <: TrackedReal
init_m = vcat(dists[1].μ)
means = mapreduce(vcat, drop(dists, 1); init = init_m) do d
d.μ
end
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
init_v = vcat(dists[1].σ^2)
vars = mapreduce(vcat, drop(dists, 1); init = init_v) do d
d.σ^2
end
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
else
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
means = [d.μ for d in dists]
vars = [d.σ^2 for d in dists]
end

return MvNormal(means, vars)
end
function ArrayDist(dists::AbstractVector{<:UnivariateDistribution})
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
return Distributions.Product(dists)
end
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real})
return sum(logpdf.(dist.v, x))
end
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
# Any other more efficient implementation breaks Zygote
return [logpdf(dist, x[:,i]) for i in 1:size(x, 2)]
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
end
function Distributions.logpdf(
dist::VectorOfUnivariate,
x::AbstractVector{<:AbstractMatrix{<:Real}},
)
return logpdf.(Ref(dist), x)
end

struct MatrixOfUnivariate{
S <: ValueSupport,
Tdist <: UnivariateDistribution{S},
Tdists <: AbstractMatrix{Tdist},
} <: MatrixDistribution{S}
dists::Tdists
end
Base.size(dist::MatrixOfUnivariate) = size(dist.dists)
function ArrayDist(dists::AbstractMatrix{<:UnivariateDistribution})
return MatrixOfUnivariate(dists)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real})
# Broadcasting here breaks Tracker for some reason
return sum(zip(dist.dists, x)) do (dist, x)
logpdf(dist, x)
end
end
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
return rand.(Ref(rng), dist.dists)
end

# Multivariate

struct VectorOfMultivariate{
S <: ValueSupport,
Tdist <: MultivariateDistribution{S},
Tdists <: AbstractVector{Tdist},
} <: MatrixDistribution{S}
dists::Tdists
end
Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist))
Base.length(dist::VectorOfMultivariate) = length(dist.dists)
function ArrayDist(dists::AbstractVector{<:MultivariateDistribution})
return VectorOfMultivariate(dists)
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist))
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
end
function Distributions.logpdf(
dist::VectorOfMultivariate,
x::AbstractVector{<:AbstractVector{<:Real}},
)
return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist))
end
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
init = reshape(rand(rng, dist.dists[1]), :, 1)
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init)
end
56 changes: 47 additions & 9 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
## Generic ##

Base.one(::Irrational) = 1
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved

function Base.fill(
value::TrackedReal,
dims::Vararg{Union{Integer, AbstractUnitRange}},
)
return track(fill, value, dims...)
end
Tracker.@grad function Base.fill(value::Real, dims...)
@grad function Base.fill(value::Real, dims...)
return fill(data(value), dims...), function(Δ)
size(Δ) ≢ dims && error("Dimension mismatch")
return (sum(Δ), map(_->nothing, dims)...)
Expand All @@ -16,15 +18,15 @@ end
## StatsFuns ##

logsumexp(x::TrackedArray) = track(logsumexp, x)
Tracker.@grad function logsumexp(x::TrackedArray)
@grad function logsumexp(x::TrackedArray)
lse = logsumexp(data(x))
return lse, Δ -> (Δ .* exp.(x .- lse),)
end

## Linear algebra ##

LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A)
Tracker.@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),)
end

Expand All @@ -39,27 +41,27 @@ function turing_chol(A::AbstractMatrix, check)
(chol.factors, chol.info)
end
turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check)
Tracker.@grad function turing_chol(A::AbstractMatrix, check)
@grad function turing_chol(A::AbstractMatrix, check)
C, back = pullback(unsafe_cholesky, data(A), data(check))
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
end

unsafe_cholesky(x, check) = cholesky(x, check=check)
ZygoteRules.@adjoint function unsafe_cholesky(Σ::Real, check)
@adjoint function unsafe_cholesky(Σ::Real, check)
C = cholesky(Σ; check=check)
return C, function(Δ::NamedTuple)
issuccess(C) || return (zero(Σ), nothing)
(Δ.factors[1, 1] / (2 * C.U[1, 1]), nothing)
end
end
ZygoteRules.@adjoint function unsafe_cholesky(Σ::Diagonal, check)
@adjoint function unsafe_cholesky(Σ::Diagonal, check)
C = cholesky(Σ; check=check)
return C, function(Δ::NamedTuple)
issuccess(C) || (Diagonal(zero(diag(Δ.factors))), nothing)
(Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing)
end
end
ZygoteRules.@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
C = cholesky(Σ; check=check)
return C, function(Δ::NamedTuple)
issuccess(C) || return (zero(Δ.factors), nothing)
Expand All @@ -78,7 +80,7 @@ end
# Specialised logdet for cholesky to target the triangle directly.
logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)])
logdet_chol_tri(U::TrackedMatrix) = track(logdet_chol_tri, U)
Tracker.@grad function logdet_chol_tri(U::AbstractMatrix)
@grad function logdet_chol_tri(U::AbstractMatrix)
U_data = data(U)
return logdet_chol_tri(U_data), Δ->(Matrix(Diagonal(2 .* Δ ./ diag(U_data))),)
end
Expand All @@ -88,6 +90,7 @@ function LinearAlgebra.logdet(C::Cholesky{<:TrackedReal, <:TrackedMatrix})
end

# Tracker's implementation of ldiv isn't good. We'll use Zygote's instead.

zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B
function zygote_ldiv(A::TrackedMatrix, B::TrackedVecOrMat)
return track(zygote_ldiv, A, B)
Expand All @@ -96,11 +99,46 @@ function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat)
return track(zygote_ldiv, A, B)
end
zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B)
Tracker.@grad function zygote_ldiv(A, B)
@grad function zygote_ldiv(A, B)
Y, back = pullback(\, data(A), data(B))
return Y, Δ->back(data(Δ))
end

function Base.:\(a::Cholesky{<:TrackedReal, <:TrackedArray}, b::AbstractVecOrMat)
return (a.U \ (a.U' \ b))
end

# SpecialFunctions

function SpecialFunctions.logabsgamma(x::TrackedReal)
v = loggamma(x)
return v, sign(data(v))
end
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved

# Some Tracker fixes

for i = 0:2, c = Tracker.combinations([:AbstractArray, :TrackedArray, :TrackedReal, :Number], i), f = [:hcat, :vcat]
if :TrackedReal in c
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
track($f, $(cnames...), x, xs...)
end
end
@grad function vcat(x::Real)
vcat(data(x)), (Δ) -> (Δ[1],)
end
@grad function vcat(x1::Real, x2::Real)
vcat(data(x1), data(x2)), (Δ) -> (Δ[1], Δ[2])
end
@grad function vcat(x1::AbstractVector, x2::Real)
vcat(data(x1), data(x2)), (Δ) -> (Δ[1:length(x1)], Δ[length(x1)+1])
end
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved

# Zygote fill has issues with non-numbers

@adjoint function fill(x::T, dims...) where {T}
function zfill(x, dims...,)
return reshape([x for i in 1:prod(dims)], dims)
end
pullback(zfill, x, dims...)
end
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
76 changes: 76 additions & 0 deletions src/flatten.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
macro register(dist)
return quote
DistributionsAD.eval(getexpr($(esc(dist))))
DistributionsAD.toflatten(::$(esc(dist))) = true
end
end
function getexpr(Tdist)
x = gensym()
fnames = fieldnames(Tdist)
flattened_args = Expr(:tuple, [:(dist.$f) for f in fnames]...)
func = Expr(:->,
Expr(:tuple, fnames..., x),
Expr(:block,
Expr(:call, :logpdf,
Expr(:call, :($Tdist), fnames...),
x,
)
)
)
return :(flatten(dist::$Tdist) = ($func, $flattened_args))
end
const flattened_dists = [ Bernoulli,
BetaBinomial,
Binomial,
Geometric,
NegativeBinomial,
Poisson,
Skellam,
PoissonBinomial,
Arcsine,
Beta,
BetaPrime,
Biweight,
Cauchy,
Chernoff,
Chi,
Chisq,
Cosine,
Epanechnikov,
Erlang,
Exponential,
FDist,
Frechet,
Gamma,
GeneralizedExtremeValue,
GeneralizedPareto,
Gumbel,
InverseGamma,
InverseGaussian,
Kolmogorov,
Laplace,
Levy,
LocationScale,
Logistic,
LogitNormal,
LogNormal,
Normal,
NormalCanon,
NormalInverseGaussian,
Pareto,
PGeneralizedGaussian,
Rayleigh,
SymTriangularDist,
TDist,
TriangularDist,
Triweight,
Categorical,
Truncated,
]
for T in flattened_dists
@eval toflatten(::$T) = true
end
toflatten(::Distribution) = false
for T in flattened_dists
eval(getexpr(T))
end
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved
4 changes: 2 additions & 2 deletions src/matrixvariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,10 @@ end

## Adjoints

ZygoteRules.@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real})
@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real})
return pullback(TuringWishart, df, S)
end
ZygoteRules.@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real})
@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real})
return pullback(TuringInverseWishart, df, S)
end

Expand Down
Loading