From 66eae40b121e5492749ca02fdee0d608e9efc8d5 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Fri, 10 Jan 2020 08:26:49 +1100 Subject: [PATCH 01/24] multiple distributions as one --- src/DistributionsAD.jl | 6 +- src/array_dist.jl | 152 +++++++++++++++++++++++++++++++++++++++++ src/multi.jl | 133 ++++++++++++++++++++++++++++++++++++ src/multivariate.jl | 55 +++++++++++++-- 4 files changed, 339 insertions(+), 7 deletions(-) create mode 100644 src/array_dist.jl create mode 100644 src/multi.jl diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 6520414b..ce741198 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -35,11 +35,15 @@ export TuringScalMvNormal, TuringMvLogNormal, TuringPoissonBinomial, TuringWishart, - TuringInverseWishart + TuringInverseWishart, + Multi, + ArrayDist include("common.jl") include("univariate.jl") include("multivariate.jl") include("matrixvariate.jl") +include("multi.jl") +include("array_dist.jl") end diff --git a/src/array_dist.jl b/src/array_dist.jl new file mode 100644 index 00000000..bfc957bb --- /dev/null +++ b/src/array_dist.jl @@ -0,0 +1,152 @@ +# Multivariate continuous + +struct ProductVectorContinuousMultivariate{ + Tdists <: AbstractVector{<:ContinuousMultivariateDistribution}, +} <: ContinuousMatrixDistribution + dists::Tdists +end +Base.size(dist::ProductVectorContinuousMultivariate) = (length(dist.dists[1]), length(dist)) +Base.length(dist::ProductVectorContinuousMultivariate) = length(dist.dists) +function ArrayDist(dists::AbstractVector{<:ContinuousMultivariateDistribution}) + return ProductVectorContinuousMultivariate(dists) +end +function Distributions.logpdf( + dist::ProductVectorContinuousMultivariate, + x::AbstractMatrix{<:Real}, +) + return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist)) +end +function Distributions.logpdf( + dist::ProductVectorContinuousMultivariate, + x::AbstractVector{<:AbstractVector{<:Real}}, +) + return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist)) +end +function Distributions.rand( + rng::Random.AbstractRNG, + dist::ProductVectorContinuousMultivariate, +) + return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 1:length(dist)) +end + +# Multivariate discrete + +struct ProductVectorDiscreteMultivariate{ + Tdists <: AbstractVector{<:DiscreteMultivariateDistribution}, +} <: DiscreteMatrixDistribution + dists::Tdists +end +Base.size(dist::ProductVectorDiscreteMultivariate) = (length(dist.dists[1]), length(dist)) +Base.length(dist::ProductVectorDiscreteMultivariate) = length(dist.dists) +function ArrayDist(dists::AbstractVector{<:DiscreteMultivariateDistribution}) + return ProductVectorDiscreteMultivariate(dists) +end +function Distributions.logpdf( + dist::ProductVectorDiscreteMultivariate, + x::AbstractMatrix{<:Integer}, +) + return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist)) +end +function Distributions.logpdf( + dist::ProductVectorDiscreteMultivariate, + x::AbstractVector{<:AbstractVector{<:Integer}}, +) + return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist)) +end +function Distributions.rand( + rng::Random.AbstractRNG, + dist::ProductVectorDiscreteMultivariate, +) + return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 1:length(dist)) +end + +# Univariate continuous + +struct ProductVectorContinuousUnivariate{ + Tdists <: AbstractVector{<:ContinuousUnivariateDistribution}, +} <: ContinuousMultivariateDistribution + dists::Tdists +end +Base.length(dist::ProductVectorContinuousUnivariate) = length(dist.dists) +Base.size(dist::ProductVectorContinuousUnivariate) = (length(dist),) +function ArrayDist(dists::AbstractVector{<:ContinuousUnivariateDistribution}) + return ProductVectorContinuousUnivariate(dists) +end +function Distributions.logpdf( + dist::ProductVectorContinuousUnivariate, + x::AbstractVector{<:Real}, +) + return sum(logpdf.(dist.dists, x)) +end +function Distributions.rand( + rng::Random.AbstractRNG, + dist::ProductVectorContinuousUnivariate, +) + return rand.(Ref(rng), dist.dists) +end + +struct ProductMatrixContinuousUnivariate{ + Tdists <: AbstractMatrix{<:ContinuousUnivariateDistribution}, +} <: ContinuousMatrixDistribution + dists::Tdists +end +Base.size(dist::ProductMatrixContinuousUnivariate) = size(dist.dists) +function ArrayDist(dists::AbstractMatrix{<:ContinuousUnivariateDistribution}) + return ProductMatrixContinuousUnivariate(dists) +end +function Distributions.logpdf( + dist::ProductMatrixContinuousUnivariate, + x::AbstractMatrix{<:Real}, +) + return sum(logpdf.(dist.dists, x)) +end +function Distributions.rand( + rng::Random.AbstractRNG, + dist::ProductMatrixContinuousUnivariate, +) + return rand.(Ref(rng), dist.dists) +end + +# Univariate discrete + +struct ProductVectorDiscreteUnivariate{ + Tdists <: AbstractVector{<:DiscreteUnivariateDistribution}, +} <: ContinuousMultivariateDistribution + dists::Tdists +end +Base.length(dist::ProductVectorDiscreteUnivariate) = length(dist.dists) +Base.size(dist::ProductVectorDiscreteUnivariate) = (length(dist.dists[1]), length(dist)) +function ArrayDist(dists::AbstractVector{<:DiscreteUnivariateDistribution}) + ProductVectorDiscreteUnivariate(dists) +end +function Distributions.logpdf( + dist::ProductVectorDiscreteUnivariate, + x::AbstractVector{<:Integer}, +) + return sum(logpdf.(dist.dists, x)) +end +function Distributions.rand( + rng::Random.AbstractRNG, + dist::ProductVectorDiscreteUnivariate, +) + return rand.(Ref(rng), dist.dists) +end + +struct ProductMatrixDiscreteUnivariate{ + Tdists <: AbstractMatrix{<:DiscreteUnivariateDistribution}, +} <: DiscreteMatrixDistribution + dists::Tdists +end +Base.size(dists::ProductMatrixDiscreteUnivariate) = size(dist.dists) +function ArrayDist(dists::AbstractMatrix{<:DiscreteUnivariateDistribution}) + return ProductMatrixDiscreteUnivariate(dists) +end +function Distributions.logpdf( + dist::ProductMatrixDiscreteUnivariate, + x::AbstractMatrix{<:Real}, +) + return sum(logpdf.(dist.dists, x)) +end +function Distributions.rand(rng::Random.AbstractRNG, dist::ProductMatrixDiscreteUnivariate) + return rand.(Ref(rng), dist.dists) +end diff --git a/src/multi.jl b/src/multi.jl new file mode 100644 index 00000000..6e77eb07 --- /dev/null +++ b/src/multi.jl @@ -0,0 +1,133 @@ +# Multivariate continuous + +struct MultipleContinuousMultivariate{ + Tdist <: ContinuousMultivariateDistribution +} <: ContinuousMatrixDistribution + dist::Tdist + N::Int +end +Base.size(dist::MultipleContinuousMultivariate) = (length(dist.dist), dist.N) +function Multi(dist::ContinuousMultivariateDistribution, N::Int) + return MultipleContinuousMultivariate(dist, N) +end +function Distributions.logpdf( + dist::MultipleContinuousMultivariate, + x::AbstractMatrix{<:Real} +) + return sum(logpdf(dist.dist, x)) +end +function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleContinuousMultivariate) + return rand(rng, dist.dist, dist.N) +end +Distributions.MvNormal(m, s, N::Int) = MultipleContinuousMultivariate(MvNormal(m, s), N) + + +# Multivariate discrete + +struct MultipleDiscreteMultivariate{ + Tdist <: DiscreteMultivariateDistribution +} <: DiscreteMatrixDistribution + dist::Tdist + N::Int +end +Base.size(dist::MultipleDiscreteMultivariate) = (length(dist.dist), dist.N) +function Multi(dist::DiscreteMultivariateDistribution, N::Int) + return MultipleDiscreteMultivariate(dist, N) +end +function Distributions.logpdf( + dist::MultipleDiscreteMultivariate, + x::AbstractMatrix{<:Integer} +) + return sum(logpdf(dist.dist, x)) +end +function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleDiscreteMultivariate) + return rand(rng, dist.dist, dist.N) +end + +# Univariate continuous + +struct MultipleContinuousUnivariate{ + Tdist <: ContinuousUnivariateDistribution, +} <: ContinuousMultivariateDistribution + dist::Tdist + N::Int +end +Base.length(dist::MultipleContinuousUnivariate) = dist.N +Base.size(dist::MultipleContinuousUnivariate) = (dist.N,) +function Multi(dist::ContinuousUnivariateDistribution, N::Int) + return MultipleContinuousUnivariate(dist, N) +end +function Distributions.logpdf( + dist::MultipleContinuousUnivariate, + x::AbstractVector{<:Real} +) + return sum(logpdf.(dist.dist, x)) +end +function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleContinuousUnivariate) + return rand(rng, dist.dist, dist.N) +end + +struct MatrixContinuousUnivariate{ + Tdist <: ContinuousUnivariateDistribution, + Tsize <: NTuple{2, Integer}, +} <: ContinuousMatrixDistribution + dist::Tdist + S::Tsize +end +Base.size(dist::MatrixContinuousUnivariate) = dist.S +function Multi(dist::ContinuousUnivariateDistribution, N1::Integer, N2::Integer) + return MatrixContinuousUnivariate(dist, (N1, N2)) +end +function Distributions.logpdf( + dist::MatrixContinuousUnivariate, + x::AbstractMatrix{<:Real} +) + return sum(logpdf.(dist.dist, x)) +end +function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixContinuousUnivariate) + return rand(rng, dist.dist, dist.S) +end + +# Univariate discrete + +struct MultipleDiscreteUnivariate{ + Tdist <: DiscreteUnivariateDistribution, +} <: ContinuousMultivariateDistribution + dist::Tdist + N::Int +end +Base.length(dist::MultipleDiscreteUnivariate) = dist.N +Base.size(dist::MultipleDiscreteUnivariate) = (dist.N,) +function Multi(dist::DiscreteUnivariateDistribution, N::Int) + MultipleDiscreteUnivariate(dist, N) +end +function Distributions.logpdf( + dist::MultipleDiscreteUnivariate, + x::AbstractVector{<:Integer} +) + return sum(logpdf.(dist.dist, x)) +end +function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleDiscreteUnivariate) + return rand(rng, dist.dist, dist.N) +end + +struct MatrixDiscreteUnivariate{ + Tdist <: DiscreteUnivariateDistribution, + Tsize <: NTuple{2, Integer}, +} <: DiscreteMatrixDistribution + dist::Tdist + S::Tsize +end +Base.size(dist::MatrixDiscreteUnivariate) = dist.S +function Multi(dist::DiscreteUnivariateDistribution, N1::Integer, N2::Integer) + return MatrixDiscreteUnivariate(dist, (N1, N2)) +end +function Distributions.logpdf( + dist::MatrixDiscreteUnivariate, + x::AbstractMatrix{<:Real} +) + return sum(logpdf.(dist.dist, x)) +end +function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixDiscreteUnivariate) + return rand(rng, dist.dist, dist.S) +end diff --git a/src/multivariate.jl b/src/multivariate.jl index 157133e5..36761e5d 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -50,6 +50,24 @@ function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal) end function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal, n::Int) return d.m .+ d.σ .* randn(rng, length(d), n) + +Base.length(d::TuringDiagMvNormal) = length(d.m) +Base.size(d::TuringDiagMvNormal) = (length(d), length(d)) +function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal) + return d.m .+ d.σ .* randn(rng, length(d)) +end + + +struct TuringScalMvNormal{Tm<:AbstractVector, Tσ<:Real} <: ContinuousMultivariateDistribution + m::Tm + σ::Tσ +end + +Base.length(d::TuringScalMvNormal) = length(d.m) +Base.size(d::TuringScalMvNormal) = (length(d), length(d)) +function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal) + return d.m .+ d.σ .* randn(rng, length(d)) +>>>>>>> multiple distributions as one end for T in (:AbstractVector, :AbstractMatrix) @@ -76,6 +94,31 @@ function _logpdf(d::TuringDenseMvNormal, x::AbstractVector) end function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix) return -((size(x, 1) * log(2π) + logdet(d.C)) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2 +======= +for T in (:TrackedVector, :TrackedMatrix) + @eval function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.ScalMat}, x::$T) + logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x) + end +end + +function _logpdf(d::TuringScalMvNormal, x::AbstractVector) + return -(length(x) * log(2π) + 2 * sum(log(d.σ)) + sum(abs2, (x .- d.m) ./ d.σ)) / 2 +end +function _logpdf(d::TuringScalMvNormal, x::AbstractMatrix) + return -(size(x, 2) * log(2π) .+ 2 * sum(log(d.σ)) .+ sum(abs2, (x .- d.m) ./ d.σ, dims=1)') ./ 2 +end + +function _logpdf(d::TuringDiagMvNormal, x::AbstractVector) + return -(length(x) * log(2π) + 2 * sum(log.(d.σ)) + sum(abs2, (x .- d.m) ./ d.σ)) / 2 +end +function _logpdf(d::TuringDiagMvNormal, x::AbstractMatrix) + return -(size(x, 2) * log(2π) .+ 2 * sum(log.(d.σ)) .+ sum(abs2, (x .- d.m) ./ d.σ, dims=1)') ./ 2 +end +function _logpdf(d::TuringDenseMvNormal, x::AbstractVector) + return -(length(x) * log(2π) + logdet(d.C) + sum(abs2, zygote_ldiv(d.C.U', x .- d.m))) / 2 +end +function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix) + return -(size(x, 2) * log(2π) .+ logdet(d.C) .+ sum(abs2, zygote_ldiv(d.C.U', x .- d.m), dims=1)') ./ 2 end import StatsBase: entropy @@ -91,9 +134,9 @@ MvNormal(A::TrackedMatrix) = TuringMvNormal(A) MvNormal(σ::TrackedVector) = TuringMvNormal(σ) # dense mean, dense covariance -MvNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvNormal(m, A) -MvNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringMvNormal(m, A) -MvNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvNormal(m, A) +MvNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringDenseMvNormal(m, A) +MvNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringDenseMvNormal(m, A) +MvNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringDenseMvNormal(m, A) # dense mean, diagonal covariance function MvNormal( @@ -205,9 +248,9 @@ MvLogNormal(A::TrackedMatrix) = TuringMvLogNormal(TuringMvNormal(A)) MvLogNormal(σ::TrackedVector) = TuringMvLogNormal(TuringMvNormal(σ)) # dense mean, dense covariance -MvLogNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, A)) -MvLogNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, A)) -MvLogNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, A)) +MvLogNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringDenseMvNormal(m, A)) +MvLogNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringMvLogNormal(TuringDenseMvNormal(m, A)) +MvLogNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringDenseMvNormal(m, A)) # dense mean, diagonal covariance function MvLogNormal( From 151a55e5ecc77a75a5d83d7c3fe35979ab092b28 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Fri, 10 Jan 2020 21:38:10 +1100 Subject: [PATCH 02/24] numerous bug and perf fixes --- Project.toml | 3 ++ src/DistributionsAD.jl | 3 ++ src/common.jl | 8 +++++ src/flatten.jl | 72 ++++++++++++++++++++++++++++++++++++++++++ src/multi.jl | 5 +-- src/univariate.jl | 69 ++++++++++++++++++++++++++++------------ 6 files changed, 137 insertions(+), 23 deletions(-) create mode 100644 src/flatten.jl diff --git a/Project.toml b/Project.toml index 01ebc0d0..8bcfcb96 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ version = "0.3.2" [deps] Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -19,10 +20,12 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Combinatorics = "0.7" Distributions = "0.22" +DiffRules = "0.1, 1.0" ForwardDiff = "0.10.6" PDMats = "0.9" SpecialFunctions = "0.8, 0.9, 0.10" StatsFuns = "0.8, 0.9" +SpecialFunctions = "0.8, 0.9" Tracker = "0.2.5" Zygote = "0.4.7" ZygoteRules = "0.2" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index ce741198..1e23d1f0 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -16,6 +16,8 @@ using ZygoteRules: ZygoteRules, pullback using LinearAlgebra: copytri! using Distributions: AbstractMvLogNormal, ContinuousMultivariateDistribution +using DiffRules, SpecialFunctions +using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here import StatsFuns: logsumexp, binomlogpdf, @@ -44,6 +46,7 @@ include("univariate.jl") include("multivariate.jl") include("matrixvariate.jl") include("multi.jl") +include("flatten.jl") include("array_dist.jl") end diff --git a/src/common.jl b/src/common.jl index dc9788d7..47d2d91e 100644 --- a/src/common.jl +++ b/src/common.jl @@ -88,6 +88,7 @@ function LinearAlgebra.logdet(C::Cholesky{<:TrackedReal, <:TrackedMatrix}) end # Tracker's implementation of ldiv isn't good. We'll use Zygote's instead. + zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B function zygote_ldiv(A::TrackedMatrix, B::TrackedVecOrMat) return track(zygote_ldiv, A, B) @@ -104,3 +105,10 @@ end function Base.:\(a::Cholesky{<:TrackedReal, <:TrackedArray}, b::AbstractVecOrMat) return (a.U \ (a.U' \ b)) end + +# SpecialFunctions + +function SpecialFunctions.logabsgamma(x::TrackedReal) + v = loggamma(x) + return v, sign(data(v)) +end diff --git a/src/flatten.jl b/src/flatten.jl new file mode 100644 index 00000000..b2ab2432 --- /dev/null +++ b/src/flatten.jl @@ -0,0 +1,72 @@ +function getexpr(Tdist) + x = gensym() + fnames = fieldnames(Tdist) + flattened_args = Expr(:tuple, [:(dist.$f) for f in fnames]...) + func = Expr(:->, + Expr(:tuple, fnames..., x), + Expr(:block, + Expr(:call, :logpdf, + Expr(:call, :($Tdist), fnames...), + x, + ) + ) + ) + return :(flatten(dist::$Tdist) = ($func, $flattened_args)) +end +for T in ( Bernoulli, + BetaBinomial, + Binomial, + Geometric, + NegativeBinomial, + Poisson, + Skellam, + PoissonBinomial, + Arcsine, + BetaBinomial, + Binomial, + Geometric, + NegativeBinomial, + Poisson, + Skellam, + PoissonBinomial, + Beta, + BetaPrime, + Biweight, + Cauchy, + Chernoff, + Chi, + Chisq, + Cosine, + Epanechnikov, + Erlang, + Exponential, + FDist, + Frechet, + Gamma, + GeneralizedExtremeValue, + GeneralizedPareto, + Gumbel, + InverseGamma, + InverseGaussian, + Kolmogorov, + Laplace, + Levy, + LocationScale, + Logistic, + LogitNormal, + LogNormal, + Normal, + NormalCanon, + NormalInverseGaussian, + Pareto, + PGeneralizedGaussian, + Rayleigh, + SymTriangularDist, + TDist, + TriangularDist, + Triweight, + Categorical, + Truncated, + ) + eval(getexpr(T)) +end diff --git a/src/multi.jl b/src/multi.jl index 6e77eb07..35aacbcf 100644 --- a/src/multi.jl +++ b/src/multi.jl @@ -59,9 +59,10 @@ function Multi(dist::ContinuousUnivariateDistribution, N::Int) end function Distributions.logpdf( dist::MultipleContinuousUnivariate, - x::AbstractVector{<:Real} + x::AbstractVector{<:Real}, ) - return sum(logpdf.(dist.dist, x)) + f, args = flatten(dist.dist) + return sum(f.(args..., x)) end function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleContinuousUnivariate) return rand(rng, dist.dist, dist.N) diff --git a/src/univariate.jl b/src/univariate.jl index 3db7d0df..5e3b4b22 100644 --- a/src/univariate.jl +++ b/src/univariate.jl @@ -46,7 +46,27 @@ ZygoteRules.@adjoint function uniformlogpdf(a, b, x) f = isfinite(l) da = 1/diff n = T(NaN) - return l, Δ->(f ? da : n, f ? -da : n, f ? zero(T) : n) + z = zero(T) + return l, Δ -> (f ? (z, z, z) : (n, n, n)) +end +for T in (:TrackedReal, :Real) + @eval @grad function uniformlogpdf( + a::TrackedReal, + b::TrackedReal, + x::$T, + ) + ad = data(a) + bd = data(b) + T = typeof(a) + l = logpdf(Uniform(ad, bd), x) + f = isfinite(l) + temp = 1/(bd - ad)^2 + dlda = temp + dldb = -temp + n = T(NaN) + z = zero(T) + return l, Δ -> (f ? (dlda * Δ, dldb * Δ, z) : (n, n, n)) + end end ZygoteRules.@adjoint function Distributions.Uniform(args...) return pullback(TuringUniform, args...) @@ -120,25 +140,31 @@ end ## Semicircle ## +function semicircle_dldr(r, x) + diffsq = r^2 - x^2 + return -2 / r + r / diffsq +end +function semicircle_dldx(r, x) + diffsq = r^2 - x^2 + return -x / diffsq +end + logpdf(d::Semicircle{<:Real}, x::TrackedReal) = semicirclelogpdf(d.r, x) logpdf(d::Semicircle{<:TrackedReal}, x::Real) = semicirclelogpdf(d.r, x) logpdf(d::Semicircle{<:TrackedReal}, x::TrackedReal) = semicirclelogpdf(d.r, x) -semicirclelogpdf(r::TrackedReal, x::Real) = track(semicirclelogpdf, r, x) -semicirclelogpdf(r::Real, x::TrackedReal) = track(semicirclelogpdf, r, x) -semicirclelogpdf(r::TrackedReal, x::TrackedReal) = track(semicirclelogpdf, r, x) -Tracker.@grad function semicirclelogpdf(r, x) - rd = data(r) - xd = data(x) - xx, rr = promote(xd, float(rd)) - d = Semicircle(rr) - T = typeof(xx) - l = logpdf(d, xx) - f = isfinite(l) - n = T(NaN) - return l, function (Δ) - diffsq = rr^2 - xx^2 - (f ? Δ*(-2/rr + rr/diffsq) : n, f ? Δ*(-xx/diffsq) : n) - end + +semicirclelogpdf(r, x) = logpdf(Semicircle(r), x) +M, f, arity = DiffRules.@define_diffrule DistributionsAD.semicirclelogpdf(r, x) = + :(semicircle_dldr($r, $x)), :(semicircle_dldx($r, $x)) +da, db = DiffRules.diffrule(M, f, :a, :b) +f = :($M.$f) +@eval begin + @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, Tracker._zero(b)) + @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (Tracker._zero(a), Δ * $db) + $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) + $f(a::TrackedReal, b::Real) = track($f, a, b) + $f(a::Real, b::TrackedReal) = track($f, a, b) end if VERSION < v"1.2" Base.inv(::Irrational{:π}) = 1/π @@ -191,10 +217,11 @@ function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::ForwardDiff.Dual{T}, k::Int) wh FD = ForwardDiff.Dual{T} val_p = ForwardDiff.value(p) val_r = ForwardDiff.value(r) - - Δ_r = ForwardDiff.partials(r) * _nbinomlogpdf_grad_1(val_r, val_p, k) - Δ_p = ForwardDiff.partials(p) * _nbinomlogpdf_grad_2(val_r, val_p, k) - Δ = Δ_p + Δ_r + Δ_r = ForwardDiff.partials(r) + dr = _nbinomlogpdf_grad_1(val_r, val_p, k) + Δ_p = ForwardDiff.partials(p) + dp = _nbinomlogpdf_grad_2(val_r, val_p, k) + Δ = ForwardDiff._mul_partials(Δ_r, Δ_p, dr, dp) return FD(nbinomlogpdf(val_r, val_p, k), Δ) end function nbinomlogpdf(r::Real, p::ForwardDiff.Dual{T}, k::Int) where {T} From 47a4af5fa88943906228bd6bb32799b86fdfe998 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Fri, 10 Jan 2020 11:29:08 +0000 Subject: [PATCH 03/24] remove duplicate items in list --- src/flatten.jl | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/flatten.jl b/src/flatten.jl index b2ab2432..16372a89 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -22,13 +22,6 @@ for T in ( Bernoulli, Skellam, PoissonBinomial, Arcsine, - BetaBinomial, - Binomial, - Geometric, - NegativeBinomial, - Poisson, - Skellam, - PoissonBinomial, Beta, BetaPrime, Biweight, From 1b336f70c7ecfaa3119596eef830b8328f19ba5f Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Thu, 23 Jan 2020 10:29:53 +1100 Subject: [PATCH 04/24] fix rebase and consistency issues --- src/multivariate.jl | 76 +++++++++------------------------------------ src/univariate.jl | 49 +++++++++++------------------ 2 files changed, 33 insertions(+), 92 deletions(-) diff --git a/src/multivariate.jl b/src/multivariate.jl index 36761e5d..3d751164 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -13,11 +13,9 @@ function TuringDenseMvNormal(m::AbstractVector, A::AbstractMatrix) return TuringDenseMvNormal(m, cholesky(A)) end Base.length(d::TuringDenseMvNormal) = length(d.m) -function Distributions.rand(rng::Random.AbstractRNG, d::TuringDenseMvNormal) - return d.m .+ d.C.U' * randn(rng, length(d)) -end -function Distributions.rand(rng::Random.AbstractRNG, d::TuringDenseMvNormal, n::Int) - return d.m .+ d.C.U' * randn(rng, length(d), n) +Distributions.rand(d::TuringDenseMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...) +function Distributions.rand(rng::Random.AbstractRNG, d::TuringDenseMvNormal, n::Int...) + return d.m .+ d.C.U' * randn(rng, length(d), n...) end """ @@ -32,32 +30,13 @@ end Distributions.params(d::TuringDiagMvNormal) = (d.m, d.σ) Distributions.dim(d::TuringDiagMvNormal) = length(d.m) -Base.length(d::TuringDiagMvNormal) = length(d.m) -Base.size(d::TuringDiagMvNormal) = (length(d), ) -function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal, n::Int) - return d.m .+ d.σ .* randn(rng, length(d), n) -end - -struct TuringScalMvNormal{Tm<:AbstractVector, Tσ<:Real} <: ContinuousMultivariateDistribution - m::Tm - σ::Tσ -end - -Base.length(d::TuringScalMvNormal) = length(d.m) -Base.size(d::TuringScalMvNormal) = (length(d), ) -function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal) - return d.m .+ d.σ .* randn(rng, length(d)) -end -function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal, n::Int) - return d.m .+ d.σ .* randn(rng, length(d), n) - Base.length(d::TuringDiagMvNormal) = length(d.m) Base.size(d::TuringDiagMvNormal) = (length(d), length(d)) -function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal) - return d.m .+ d.σ .* randn(rng, length(d)) +Distributions.rand(d::TuringDiagMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...) +function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal, n::Int...) + return d.m .+ d.σ .* randn(rng, length(d), n...) end - struct TuringScalMvNormal{Tm<:AbstractVector, Tσ<:Real} <: ContinuousMultivariateDistribution m::Tm σ::Tσ @@ -65,9 +44,9 @@ end Base.length(d::TuringScalMvNormal) = length(d.m) Base.size(d::TuringScalMvNormal) = (length(d), length(d)) -function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal) - return d.m .+ d.σ .* randn(rng, length(d)) ->>>>>>> multiple distributions as one +Distributions.rand(d::TuringScalMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...) +function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal, n::Int...) + return d.m .+ d.σ .* randn(rng, length(d), n...) end for T in (:AbstractVector, :AbstractMatrix) @@ -94,31 +73,6 @@ function _logpdf(d::TuringDenseMvNormal, x::AbstractVector) end function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix) return -((size(x, 1) * log(2π) + logdet(d.C)) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2 -======= -for T in (:TrackedVector, :TrackedMatrix) - @eval function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.ScalMat}, x::$T) - logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x) - end -end - -function _logpdf(d::TuringScalMvNormal, x::AbstractVector) - return -(length(x) * log(2π) + 2 * sum(log(d.σ)) + sum(abs2, (x .- d.m) ./ d.σ)) / 2 -end -function _logpdf(d::TuringScalMvNormal, x::AbstractMatrix) - return -(size(x, 2) * log(2π) .+ 2 * sum(log(d.σ)) .+ sum(abs2, (x .- d.m) ./ d.σ, dims=1)') ./ 2 -end - -function _logpdf(d::TuringDiagMvNormal, x::AbstractVector) - return -(length(x) * log(2π) + 2 * sum(log.(d.σ)) + sum(abs2, (x .- d.m) ./ d.σ)) / 2 -end -function _logpdf(d::TuringDiagMvNormal, x::AbstractMatrix) - return -(size(x, 2) * log(2π) .+ 2 * sum(log.(d.σ)) .+ sum(abs2, (x .- d.m) ./ d.σ, dims=1)') ./ 2 -end -function _logpdf(d::TuringDenseMvNormal, x::AbstractVector) - return -(length(x) * log(2π) + logdet(d.C) + sum(abs2, zygote_ldiv(d.C.U', x .- d.m))) / 2 -end -function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix) - return -(size(x, 2) * log(2π) .+ logdet(d.C) .+ sum(abs2, zygote_ldiv(d.C.U', x .- d.m), dims=1)') ./ 2 end import StatsBase: entropy @@ -134,9 +88,9 @@ MvNormal(A::TrackedMatrix) = TuringMvNormal(A) MvNormal(σ::TrackedVector) = TuringMvNormal(σ) # dense mean, dense covariance -MvNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringDenseMvNormal(m, A) -MvNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringDenseMvNormal(m, A) -MvNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringDenseMvNormal(m, A) +MvNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvNormal(m, A) +MvNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringMvNormal(m, A) +MvNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvNormal(m, A) # dense mean, diagonal covariance function MvNormal( @@ -248,9 +202,9 @@ MvLogNormal(A::TrackedMatrix) = TuringMvLogNormal(TuringMvNormal(A)) MvLogNormal(σ::TrackedVector) = TuringMvLogNormal(TuringMvNormal(σ)) # dense mean, dense covariance -MvLogNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringDenseMvNormal(m, A)) -MvLogNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringMvLogNormal(TuringDenseMvNormal(m, A)) -MvLogNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringDenseMvNormal(m, A)) +MvLogNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, A)) +MvLogNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, A)) +MvLogNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, A)) # dense mean, diagonal covariance function MvLogNormal( diff --git a/src/univariate.jl b/src/univariate.jl index 5e3b4b22..aa624fd7 100644 --- a/src/univariate.jl +++ b/src/univariate.jl @@ -33,39 +33,25 @@ uniformlogpdf(a::TrackedReal, b::TrackedReal, x::TrackedReal) = track(uniformlog Tracker.@grad function uniformlogpdf(a, b, x) diff = data(b) - data(a) T = typeof(diff) - l = -log(diff) - f = isfinite(l) - da = 1/diff - n = T(NaN) - return l, Δ->(f ? da : n, f ? -da : n, f ? zero(T) : n) + if a <= data(x) <= b && a < b + l = -log(diff) + da = 1/diff^2 + return l, Δ -> (da * Δ, -da * Δ, zero(T) * Δ) + else + n = T(NaN) + return l, Δ -> (n, n, n) + end end ZygoteRules.@adjoint function uniformlogpdf(a, b, x) diff = b - a T = typeof(diff) - l = -log(diff) - f = isfinite(l) - da = 1/diff - n = T(NaN) - z = zero(T) - return l, Δ -> (f ? (z, z, z) : (n, n, n)) -end -for T in (:TrackedReal, :Real) - @eval @grad function uniformlogpdf( - a::TrackedReal, - b::TrackedReal, - x::$T, - ) - ad = data(a) - bd = data(b) - T = typeof(a) - l = logpdf(Uniform(ad, bd), x) - f = isfinite(l) - temp = 1/(bd - ad)^2 - dlda = temp - dldb = -temp + if a <= x <= b && a < b + l = -log(diff) + da = 1/diff^2 + return l, Δ -> (da * Δ, -da * Δ, zero(T) * Δ) + else n = T(NaN) - z = zero(T) - return l, Δ -> (f ? (dlda * Δ, dldb * Δ, z) : (n, n, n)) + return l, Δ -> (n, n, n) end end ZygoteRules.@adjoint function Distributions.Uniform(args...) @@ -159,9 +145,9 @@ M, f, arity = DiffRules.@define_diffrule DistributionsAD.semicirclelogpdf(r, x) da, db = DiffRules.diffrule(M, f, :a, :b) f = :($M.$f) @eval begin - @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) - @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, Tracker._zero(b)) - @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (Tracker._zero(a), Δ * $db) + Tracker.@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + Tracker.@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, Tracker._zero(b)) + Tracker.@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (Tracker._zero(a), Δ * $db) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b) $f(a::Real, b::TrackedReal) = track($f, a, b) @@ -292,6 +278,7 @@ ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray) ((ForwardDiff.jacobian(x -> poissonbinomial_pdf_fft(x), x)::Matrix{T})' * Δ,) end end + # The code below doesn't work because of bugs in Zygote. The above is inefficient. #= ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{<:Real}) From 2fda51776747964f726bc52932268d5078aa6d88 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Thu, 23 Jan 2020 16:38:46 +1100 Subject: [PATCH 05/24] flatten all multi of univariate by default --- src/flatten.jl | 107 +++++++++++++++++++++++++++---------------------- src/multi.jl | 17 +++++--- 2 files changed, 71 insertions(+), 53 deletions(-) diff --git a/src/flatten.jl b/src/flatten.jl index 16372a89..acd862ad 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -1,3 +1,9 @@ +macro register(dist) + return quote + DistributionsAD.eval(getexpr($(esc(dist)))) + DistributionsAD.toflatten(::$(esc(dist))) = true + end +end function getexpr(Tdist) x = gensym() fnames = fieldnames(Tdist) @@ -13,53 +19,58 @@ function getexpr(Tdist) ) return :(flatten(dist::$Tdist) = ($func, $flattened_args)) end -for T in ( Bernoulli, - BetaBinomial, - Binomial, - Geometric, - NegativeBinomial, - Poisson, - Skellam, - PoissonBinomial, - Arcsine, - Beta, - BetaPrime, - Biweight, - Cauchy, - Chernoff, - Chi, - Chisq, - Cosine, - Epanechnikov, - Erlang, - Exponential, - FDist, - Frechet, - Gamma, - GeneralizedExtremeValue, - GeneralizedPareto, - Gumbel, - InverseGamma, - InverseGaussian, - Kolmogorov, - Laplace, - Levy, - LocationScale, - Logistic, - LogitNormal, - LogNormal, - Normal, - NormalCanon, - NormalInverseGaussian, - Pareto, - PGeneralizedGaussian, - Rayleigh, - SymTriangularDist, - TDist, - TriangularDist, - Triweight, - Categorical, - Truncated, - ) +const flattened_dists = [ Bernoulli, + BetaBinomial, + Binomial, + Geometric, + NegativeBinomial, + Poisson, + Skellam, + PoissonBinomial, + Arcsine, + Beta, + BetaPrime, + Biweight, + Cauchy, + Chernoff, + Chi, + Chisq, + Cosine, + Epanechnikov, + Erlang, + Exponential, + FDist, + Frechet, + Gamma, + GeneralizedExtremeValue, + GeneralizedPareto, + Gumbel, + InverseGamma, + InverseGaussian, + Kolmogorov, + Laplace, + Levy, + LocationScale, + Logistic, + LogitNormal, + LogNormal, + Normal, + NormalCanon, + NormalInverseGaussian, + Pareto, + PGeneralizedGaussian, + Rayleigh, + SymTriangularDist, + TDist, + TriangularDist, + Triweight, + Categorical, + Truncated, + ] +for T in flattened_dists + @eval toflatten(::T) = true +end +toflatten(::Distribution) = false +for T in flattened_dists eval(getexpr(T)) end diff --git a/src/multi.jl b/src/multi.jl index 35aacbcf..03a04ee7 100644 --- a/src/multi.jl +++ b/src/multi.jl @@ -61,12 +61,19 @@ function Distributions.logpdf( dist::MultipleContinuousUnivariate, x::AbstractVector{<:Real}, ) - f, args = flatten(dist.dist) - return sum(f.(args..., x)) + return _flat_logpdf(dist.dist, x) end function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleContinuousUnivariate) return rand(rng, dist.dist, dist.N) end +function _flat_logpdf(dist, x) + if toflatten(dist) + f, args = flatten(dist) + return sum(f.(args..., x)) + else + return sum(logpdf.(dist, x)) + end +end struct MatrixContinuousUnivariate{ Tdist <: ContinuousUnivariateDistribution, @@ -83,7 +90,7 @@ function Distributions.logpdf( dist::MatrixContinuousUnivariate, x::AbstractMatrix{<:Real} ) - return sum(logpdf.(dist.dist, x)) + return _flat_logpdf(dist.dist, x) end function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixContinuousUnivariate) return rand(rng, dist.dist, dist.S) @@ -106,7 +113,7 @@ function Distributions.logpdf( dist::MultipleDiscreteUnivariate, x::AbstractVector{<:Integer} ) - return sum(logpdf.(dist.dist, x)) + return _flat_logpdf(dist.dist, x) end function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleDiscreteUnivariate) return rand(rng, dist.dist, dist.N) @@ -127,7 +134,7 @@ function Distributions.logpdf( dist::MatrixDiscreteUnivariate, x::AbstractMatrix{<:Real} ) - return sum(logpdf.(dist.dist, x)) + return _flat_logpdf(dist.dist, x) end function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixDiscreteUnivariate) return rand(rng, dist.dist, dist.S) From ec4a7784bf75668517068fbe86ef0af81869dd9f Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Thu, 23 Jan 2020 17:49:27 +1100 Subject: [PATCH 06/24] bug fix --- src/flatten.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/flatten.jl b/src/flatten.jl index acd862ad..a305d1b5 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -68,7 +68,7 @@ const flattened_dists = [ Bernoulli, Truncated, ] for T in flattened_dists - @eval toflatten(::T) = true + @eval toflatten(::$T) = true end toflatten(::Distribution) = false for T in flattened_dists From f8cf678f4d6a6492a6a8c287eb83feca60416b89 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 1 Feb 2020 07:45:14 +1100 Subject: [PATCH 07/24] minor rebase issue --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8bcfcb96..8d1bf60d 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,6 @@ ForwardDiff = "0.10.6" PDMats = "0.9" SpecialFunctions = "0.8, 0.9, 0.10" StatsFuns = "0.8, 0.9" -SpecialFunctions = "0.8, 0.9" Tracker = "0.2.5" Zygote = "0.4.7" ZygoteRules = "0.2" From 7f7651ba9555b90c60718fca1dd9368a63eaa968 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Mon, 3 Feb 2020 21:24:52 +1100 Subject: [PATCH 08/24] fix Dirichlet --- src/DistributionsAD.jl | 1 + src/multivariate.jl | 88 ++++++++++++++++++++++++++++++++++++++++++ test/distributions.jl | 4 +- 3 files changed, 91 insertions(+), 2 deletions(-) diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 1e23d1f0..37665114 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -18,6 +18,7 @@ using Distributions: AbstractMvLogNormal, ContinuousMultivariateDistribution using DiffRules, SpecialFunctions using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here +using Base.Iterators: drop import StatsFuns: logsumexp, binomlogpdf, diff --git a/src/multivariate.jl b/src/multivariate.jl index 3d751164..3dc2d03c 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -1,3 +1,91 @@ +## Dirichlet ## + +struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution + alpha::TV + alpha0::T + lmnB::T +end +function check(alpha) + all(ai -> ai > 0, alpha) || + throw(ArgumentError("Dirichlet: alpha must be a positive vector.")) +end +Zygote.@nograd DistributionsAD.check + +function TuringDirichlet(alpha::AbstractVector) + check(alpha) + alpha0 = sum(alpha) + lmnB = sum(loggamma, alpha) - loggamma(alpha0) + T = promote_type(typeof(alpha0), typeof(lmnB)) + TV = typeof(alpha) + TuringDirichlet{T, TV}(alpha, alpha0, lmnB) +end + +function TuringDirichlet(d::Integer, alpha::Real) + alpha0 = alpha * d + _alpha = fill(alpha, d) + lmnB = loggamma(alpha) * d - loggamma(alpha0) + T = promote_type(typeof(alpha0), typeof(lmnB)) + TV = typeof(_alpha) + TuringDirichlet{T, TV}(_alpha, alpha0, lmnB) +end +function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer} + Tf = float(T) + TuringDirichlet(convert(AbstractVector{Tf}, alpha)) +end +TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha)) + +Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha) +Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha) + +function Distributions.logpdf(d::TuringDirichlet, x::AbstractVector) + simplex_logpdf(d.alpha, d.lmnB, x) +end +function Distributions.logpdf(d::TuringDirichlet, x::AbstractMatrix) + simplex_logpdf(d.alpha, d.lmnB, x) +end +function Distributions.logpdf(d::Dirichlet{T}, x::TrackedVecOrMat) where {T} + TV = typeof(d.alpha) + logpdf(TuringDirichlet{T, TV}(d.alpha, d.alpha0, d.lmnB), x) +end + +ZygoteRules.@adjoint function Distributions.Dirichlet(alpha) + return pullback(TuringDirichlet, alpha) +end +ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha) + return pullback(TuringDirichlet, d, alpha) +end + +function simplex_logpdf(alpha, lmnB, x::AbstractVector) + sum((alpha .- 1) .* log.(x)) - lmnB +end +function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) + @views init = vcat(sum((alpha .- 1) .* log.(x[:,1]))) + mapreduce(vcat, drop(eachcol(x), 1); init = init) do c + sum((alpha .- 1) .* log.(c)) - lmnB + end +end + +Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector) + simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin + (Δ .* log.(data(x)), -Δ, Δ .* (data(alpha) .- 1)) + end +end +Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) + simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin + (log.(data(x)) * Δ, -sum(Δ), repeat(data(alpha) .- 1, 1, size(x, 2)) * Diagonal(Δ)) + end +end + +ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector) + simplex_logpdf(alpha, lmnB, x), Δ -> (Δ .* log.(x), -Δ, Δ .* (alpha .- 1)) +end + +ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) + simplex_logpdf(alpha, lmnB, x), Δ -> begin + (log.(x) * Δ, -sum(Δ), repeat(alpha .- 1, 1, size(x, 2)) * Diagonal(Δ)) + end +end + ## MvNormal ## """ diff --git a/test/distributions.jl b/test/distributions.jl index 450c689f..22635d50 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -205,6 +205,8 @@ separator() DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat), DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_mat), DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat), + DistSpec(:Dirichlet, (alpha,), dir_val), + DistSpec(:Dirichlet, (alpha,), dir_val), ] broken_mult_cont_dists = [ @@ -215,14 +217,12 @@ separator() DistSpec(:MvNormalCanon, (cov_mat,), norm_val_vec), DistSpec(:MvNormalCanon, (cov_vec,), norm_val_vec), DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val_vec), - DistSpec(:Dirichlet, (alpha,), dir_val), DistSpec(:MvNormalCanon, (mean, cov_mat), norm_val_mat), DistSpec(:MvNormalCanon, (mean, cov_vec), norm_val_mat), DistSpec(:MvNormalCanon, (mean, cov_num), norm_val_mat), DistSpec(:MvNormalCanon, (cov_mat,), norm_val_mat), DistSpec(:MvNormalCanon, (cov_vec,), norm_val_mat), DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val_mat), - DistSpec(:Dirichlet, (alpha,), dir_val), # Test failure DistSpec(:MvNormal, (mean, cov_mat), norm_val_mat), DistSpec(:MvNormal, (cov_mat,), norm_val_mat), From 03295647d3c4492b344287004e1502f5268d9b1f Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 15 Feb 2020 17:05:14 +1100 Subject: [PATCH 09/24] apply David's comments and add tests --- Project.toml | 2 + src/DistributionsAD.jl | 12 +-- src/array_dist.jl | 188 ++++++++++++++--------------------------- src/common.jl | 46 ++++++++-- src/matrixvariate.jl | 4 +- src/multi.jl | 160 +++++++++++------------------------ src/multivariate.jl | 22 ++--- src/univariate.jl | 42 ++++----- test/distributions.jl | 12 +++ 9 files changed, 202 insertions(+), 286 deletions(-) diff --git a/Project.toml b/Project.toml index 8d1bf60d..1d23ad62 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.3.2" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" @@ -21,6 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" Combinatorics = "0.7" Distributions = "0.22" DiffRules = "0.1, 1.0" +FillArrays = "0.8" ForwardDiff = "0.10.6" PDMats = "0.9" SpecialFunctions = "0.8, 0.9, 0.10" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 37665114..67d5cedc 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -11,12 +11,12 @@ using PDMats, StatsFuns using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray, - TrackedVecOrMat, track, data -using ZygoteRules: ZygoteRules, pullback + TrackedVecOrMat, track, @grad, data +using ZygoteRules: ZygoteRules, @adjoint, pullback using LinearAlgebra: copytri! using Distributions: AbstractMvLogNormal, ContinuousMultivariateDistribution -using DiffRules, SpecialFunctions +using DiffRules, SpecialFunctions, FillArrays using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here using Base.Iterators: drop @@ -39,15 +39,15 @@ export TuringScalMvNormal, TuringPoissonBinomial, TuringWishart, TuringInverseWishart, - Multi, - ArrayDist + ArrayDist, + FillDist include("common.jl") include("univariate.jl") include("multivariate.jl") include("matrixvariate.jl") -include("multi.jl") include("flatten.jl") include("array_dist.jl") +include("multi.jl") end diff --git a/src/array_dist.jl b/src/array_dist.jl index bfc957bb..c3c39698 100644 --- a/src/array_dist.jl +++ b/src/array_dist.jl @@ -1,152 +1,90 @@ -# Multivariate continuous +# Univariate -struct ProductVectorContinuousMultivariate{ - Tdists <: AbstractVector{<:ContinuousMultivariateDistribution}, -} <: ContinuousMatrixDistribution - dists::Tdists -end -Base.size(dist::ProductVectorContinuousMultivariate) = (length(dist.dists[1]), length(dist)) -Base.length(dist::ProductVectorContinuousMultivariate) = length(dist.dists) -function ArrayDist(dists::AbstractVector{<:ContinuousMultivariateDistribution}) - return ProductVectorContinuousMultivariate(dists) -end -function Distributions.logpdf( - dist::ProductVectorContinuousMultivariate, - x::AbstractMatrix{<:Real}, -) - return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist)) -end -function Distributions.logpdf( - dist::ProductVectorContinuousMultivariate, - x::AbstractVector{<:AbstractVector{<:Real}}, -) - return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist)) -end -function Distributions.rand( - rng::Random.AbstractRNG, - dist::ProductVectorContinuousMultivariate, -) - return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 1:length(dist)) -end +const VectorOfUnivariate{ + S <: ValueSupport, + Tdist <: UnivariateDistribution{S}, + Tdists <: AbstractVector{Tdist}, +} = Distributions.Product{S, Tdist, Tdists} -# Multivariate discrete +function ArrayDist(dists::AbstractVector{<:Normal{T}}) where {T} + if T <: TrackedReal + init_m = dists[1].μ + means = mapreduce(vcat, drop(dists, 1); init = init_m) do d + d.μ + end + init_v = dists[1].σ^2 + vars = mapreduce(vcat, drop(dists, 1); init = init_v) do d + d.σ^2 + end + else + means = [d.μ for d in dists] + vars = [d.σ^2 for d in dists] + end -struct ProductVectorDiscreteMultivariate{ - Tdists <: AbstractVector{<:DiscreteMultivariateDistribution}, -} <: DiscreteMatrixDistribution - dists::Tdists + return MvNormal(means, vars) end -Base.size(dist::ProductVectorDiscreteMultivariate) = (length(dist.dists[1]), length(dist)) -Base.length(dist::ProductVectorDiscreteMultivariate) = length(dist.dists) -function ArrayDist(dists::AbstractVector{<:DiscreteMultivariateDistribution}) - return ProductVectorDiscreteMultivariate(dists) -end -function Distributions.logpdf( - dist::ProductVectorDiscreteMultivariate, - x::AbstractMatrix{<:Integer}, -) - return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist)) +function ArrayDist(dists::AbstractVector{<:UnivariateDistribution}) + return Distributions.Product(dists) end -function Distributions.logpdf( - dist::ProductVectorDiscreteMultivariate, - x::AbstractVector{<:AbstractVector{<:Integer}}, -) - return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist)) +function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real}) + return sum(logpdf.(dist.v, x)) end -function Distributions.rand( - rng::Random.AbstractRNG, - dist::ProductVectorDiscreteMultivariate, -) - return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 1:length(dist)) -end - -# Univariate continuous - -struct ProductVectorContinuousUnivariate{ - Tdists <: AbstractVector{<:ContinuousUnivariateDistribution}, -} <: ContinuousMultivariateDistribution - dists::Tdists -end -Base.length(dist::ProductVectorContinuousUnivariate) = length(dist.dists) -Base.size(dist::ProductVectorContinuousUnivariate) = (length(dist),) -function ArrayDist(dists::AbstractVector{<:ContinuousUnivariateDistribution}) - return ProductVectorContinuousUnivariate(dists) +function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) + # Any other more efficient implementation breaks Zygote + return [logpdf(dist, x[:,i]) for i in 1:size(x, 2)] end function Distributions.logpdf( - dist::ProductVectorContinuousUnivariate, - x::AbstractVector{<:Real}, -) - return sum(logpdf.(dist.dists, x)) -end -function Distributions.rand( - rng::Random.AbstractRNG, - dist::ProductVectorContinuousUnivariate, + dist::VectorOfUnivariate, + x::AbstractVector{<:AbstractMatrix{<:Real}}, ) - return rand.(Ref(rng), dist.dists) + return logpdf.(Ref(dist), x) end -struct ProductMatrixContinuousUnivariate{ - Tdists <: AbstractMatrix{<:ContinuousUnivariateDistribution}, -} <: ContinuousMatrixDistribution +struct MatrixOfUnivariate{ + S <: ValueSupport, + Tdist <: UnivariateDistribution{S}, + Tdists <: AbstractMatrix{Tdist}, +} <: MatrixDistribution{S} dists::Tdists end -Base.size(dist::ProductMatrixContinuousUnivariate) = size(dist.dists) -function ArrayDist(dists::AbstractMatrix{<:ContinuousUnivariateDistribution}) - return ProductMatrixContinuousUnivariate(dists) +Base.size(dist::MatrixOfUnivariate) = size(dist.dists) +function ArrayDist(dists::AbstractMatrix{<:UnivariateDistribution}) + return MatrixOfUnivariate(dists) end -function Distributions.logpdf( - dist::ProductMatrixContinuousUnivariate, - x::AbstractMatrix{<:Real}, -) - return sum(logpdf.(dist.dists, x)) +function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real}) + # Broadcasting here breaks Tracker for some reason + return sum(zip(dist.dists, x)) do (dist, x) + logpdf(dist, x) + end end -function Distributions.rand( - rng::Random.AbstractRNG, - dist::ProductMatrixContinuousUnivariate, -) +function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) return rand.(Ref(rng), dist.dists) end -# Univariate discrete +# Multivariate continuous -struct ProductVectorDiscreteUnivariate{ - Tdists <: AbstractVector{<:DiscreteUnivariateDistribution}, -} <: ContinuousMultivariateDistribution +struct VectorOfMultivariate{ + S <: ValueSupport, + Tdist <: MultivariateDistribution{S}, + Tdists <: AbstractVector{Tdist}, +} <: MatrixDistribution{S} dists::Tdists end -Base.length(dist::ProductVectorDiscreteUnivariate) = length(dist.dists) -Base.size(dist::ProductVectorDiscreteUnivariate) = (length(dist.dists[1]), length(dist)) -function ArrayDist(dists::AbstractVector{<:DiscreteUnivariateDistribution}) - ProductVectorDiscreteUnivariate(dists) -end -function Distributions.logpdf( - dist::ProductVectorDiscreteUnivariate, - x::AbstractVector{<:Integer}, -) - return sum(logpdf.(dist.dists, x)) -end -function Distributions.rand( - rng::Random.AbstractRNG, - dist::ProductVectorDiscreteUnivariate, -) - return rand.(Ref(rng), dist.dists) -end - -struct ProductMatrixDiscreteUnivariate{ - Tdists <: AbstractMatrix{<:DiscreteUnivariateDistribution}, -} <: DiscreteMatrixDistribution - dists::Tdists +Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist)) +Base.length(dist::VectorOfMultivariate) = length(dist.dists) +function ArrayDist(dists::AbstractVector{<:MultivariateDistribution}) + return VectorOfMultivariate(dists) end -Base.size(dists::ProductMatrixDiscreteUnivariate) = size(dist.dists) -function ArrayDist(dists::AbstractMatrix{<:DiscreteUnivariateDistribution}) - return ProductMatrixDiscreteUnivariate(dists) +function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) + return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist)) end function Distributions.logpdf( - dist::ProductMatrixDiscreteUnivariate, - x::AbstractMatrix{<:Real}, + dist::VectorOfMultivariate, + x::AbstractVector{<:AbstractVector{<:Real}}, ) - return sum(logpdf.(dist.dists, x)) + return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist)) end -function Distributions.rand(rng::Random.AbstractRNG, dist::ProductMatrixDiscreteUnivariate) - return rand.(Ref(rng), dist.dists) +function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) + init = reshape(rand(rng, dist.dists[1]), :, 1) + return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init) end diff --git a/src/common.jl b/src/common.jl index 47d2d91e..49e6cde4 100644 --- a/src/common.jl +++ b/src/common.jl @@ -6,7 +6,7 @@ function Base.fill( ) return track(fill, value, dims...) end -Tracker.@grad function Base.fill(value::Real, dims...) +@grad function Base.fill(value::Real, dims...) return fill(data(value), dims...), function(Δ) size(Δ) ≢ dims && error("Dimension mismatch") return (sum(Δ), map(_->nothing, dims)...) @@ -16,7 +16,7 @@ end ## StatsFuns ## logsumexp(x::TrackedArray) = track(logsumexp, x) -Tracker.@grad function logsumexp(x::TrackedArray) +@grad function logsumexp(x::TrackedArray) lse = logsumexp(data(x)) return lse, Δ -> (Δ .* exp.(x .- lse),) end @@ -24,7 +24,7 @@ end ## Linear algebra ## LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A) -Tracker.@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix) +@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix) return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),) end @@ -39,27 +39,27 @@ function turing_chol(A::AbstractMatrix, check) (chol.factors, chol.info) end turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check) -Tracker.@grad function turing_chol(A::AbstractMatrix, check) +@grad function turing_chol(A::AbstractMatrix, check) C, back = pullback(unsafe_cholesky, data(A), data(check)) return (C.factors, C.info), Δ->back((factors=data(Δ[1]),)) end unsafe_cholesky(x, check) = cholesky(x, check=check) -ZygoteRules.@adjoint function unsafe_cholesky(Σ::Real, check) +@adjoint function unsafe_cholesky(Σ::Real, check) C = cholesky(Σ; check=check) return C, function(Δ::NamedTuple) issuccess(C) || return (zero(Σ), nothing) (Δ.factors[1, 1] / (2 * C.U[1, 1]), nothing) end end -ZygoteRules.@adjoint function unsafe_cholesky(Σ::Diagonal, check) +@adjoint function unsafe_cholesky(Σ::Diagonal, check) C = cholesky(Σ; check=check) return C, function(Δ::NamedTuple) issuccess(C) || (Diagonal(zero(diag(Δ.factors))), nothing) (Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing) end end -ZygoteRules.@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check) +@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check) C = cholesky(Σ; check=check) return C, function(Δ::NamedTuple) issuccess(C) || return (zero(Δ.factors), nothing) @@ -78,7 +78,7 @@ end # Specialised logdet for cholesky to target the triangle directly. logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)]) logdet_chol_tri(U::TrackedMatrix) = track(logdet_chol_tri, U) -Tracker.@grad function logdet_chol_tri(U::AbstractMatrix) +@grad function logdet_chol_tri(U::AbstractMatrix) U_data = data(U) return logdet_chol_tri(U_data), Δ->(Matrix(Diagonal(2 .* Δ ./ diag(U_data))),) end @@ -97,7 +97,7 @@ function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat) return track(zygote_ldiv, A, B) end zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B) -Tracker.@grad function zygote_ldiv(A, B) +@grad function zygote_ldiv(A, B) Y, back = pullback(\, data(A), data(B)) return Y, Δ->back(data(Δ)) end @@ -112,3 +112,31 @@ function SpecialFunctions.logabsgamma(x::TrackedReal) v = loggamma(x) return v, sign(data(v)) end + +# Some Tracker fixes + +for i = 0:2, c = Tracker.combinations([:AbstractArray, :TrackedArray, :TrackedReal, :Number], i), f = [:hcat, :vcat] + if :TrackedReal in c + cnames = map(_ -> gensym(), c) + @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) = + track($f, $(cnames...), x, xs...) + end +end +@grad function vcat(x::Real) + vcat(data(x)), (Δ) -> (Δ[1],) +end +@grad function vcat(x1::Real, x2::Real) + vcat(data(x1), data(x2)), (Δ) -> (Δ[1], Δ[2]) +end +@grad function vcat(x1::AbstractVector, x2::Real) + vcat(data(x1), data(x2)), (Δ) -> (Δ[1:length(x1)], Δ[length(x1)+1]) +end + +# Zygote fill has issues with non-numbers + +@adjoint function fill(x::T, dims...) where {T} + function zfill(x, dims...,) + return reshape([x for i in 1:prod(dims)], dims) + end + pullback(zfill, x, dims...) +end diff --git a/src/matrixvariate.jl b/src/matrixvariate.jl index a57c4849..0bbe16f5 100644 --- a/src/matrixvariate.jl +++ b/src/matrixvariate.jl @@ -201,10 +201,10 @@ end ## Adjoints -ZygoteRules.@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real}) +@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real}) return pullback(TuringWishart, df, S) end -ZygoteRules.@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real}) +@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real}) return pullback(TuringInverseWishart, df, S) end diff --git a/src/multi.jl b/src/multi.jl index 03a04ee7..07b5c6f1 100644 --- a/src/multi.jl +++ b/src/multi.jl @@ -1,70 +1,26 @@ -# Multivariate continuous - -struct MultipleContinuousMultivariate{ - Tdist <: ContinuousMultivariateDistribution -} <: ContinuousMatrixDistribution - dist::Tdist - N::Int -end -Base.size(dist::MultipleContinuousMultivariate) = (length(dist.dist), dist.N) -function Multi(dist::ContinuousMultivariateDistribution, N::Int) - return MultipleContinuousMultivariate(dist, N) -end -function Distributions.logpdf( - dist::MultipleContinuousMultivariate, - x::AbstractMatrix{<:Real} -) - return sum(logpdf(dist.dist, x)) -end -function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleContinuousMultivariate) - return rand(rng, dist.dist, dist.N) -end -Distributions.MvNormal(m, s, N::Int) = MultipleContinuousMultivariate(MvNormal(m, s), N) +# Univariate +const FillVectorOfUnivariate{ + S <: ValueSupport, + T <: UnivariateDistribution{S}, + Tdists <: Fill{T, 1}, +} = VectorOfUnivariate{S, T, Tdists} -# Multivariate discrete - -struct MultipleDiscreteMultivariate{ - Tdist <: DiscreteMultivariateDistribution -} <: DiscreteMatrixDistribution - dist::Tdist - N::Int -end -Base.size(dist::MultipleDiscreteMultivariate) = (length(dist.dist), dist.N) -function Multi(dist::DiscreteMultivariateDistribution, N::Int) - return MultipleDiscreteMultivariate(dist, N) +function FillDist(dist::UnivariateDistribution, N::Int) + return Product(Fill(dist, N)) end +FillDist(d::Normal, N::Int) = MvNormal(fill(d.μ, N), d.σ) function Distributions.logpdf( - dist::MultipleDiscreteMultivariate, - x::AbstractMatrix{<:Integer} + dist::FillVectorOfUnivariate, + x::AbstractVector{<:Real}, ) - return sum(logpdf(dist.dist, x)) -end -function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleDiscreteMultivariate) - return rand(rng, dist.dist, dist.N) -end - -# Univariate continuous - -struct MultipleContinuousUnivariate{ - Tdist <: ContinuousUnivariateDistribution, -} <: ContinuousMultivariateDistribution - dist::Tdist - N::Int -end -Base.length(dist::MultipleContinuousUnivariate) = dist.N -Base.size(dist::MultipleContinuousUnivariate) = (dist.N,) -function Multi(dist::ContinuousUnivariateDistribution, N::Int) - return MultipleContinuousUnivariate(dist, N) + return _flat_logpdf(dist.v.value, x) end function Distributions.logpdf( - dist::MultipleContinuousUnivariate, - x::AbstractVector{<:Real}, + dist::FillVectorOfUnivariate, + x::AbstractMatrix{<:Real}, ) - return _flat_logpdf(dist.dist, x) -end -function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleContinuousUnivariate) - return rand(rng, dist.dist, dist.N) + return _flat_logpdf_mat(dist.v.value, x) end function _flat_logpdf(dist, x) if toflatten(dist) @@ -74,68 +30,48 @@ function _flat_logpdf(dist, x) return sum(logpdf.(dist, x)) end end - -struct MatrixContinuousUnivariate{ - Tdist <: ContinuousUnivariateDistribution, - Tsize <: NTuple{2, Integer}, -} <: ContinuousMatrixDistribution - dist::Tdist - S::Tsize -end -Base.size(dist::MatrixContinuousUnivariate) = dist.S -function Multi(dist::ContinuousUnivariateDistribution, N1::Integer, N2::Integer) - return MatrixContinuousUnivariate(dist, (N1, N2)) -end -function Distributions.logpdf( - dist::MatrixContinuousUnivariate, - x::AbstractMatrix{<:Real} -) - return _flat_logpdf(dist.dist, x) -end -function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixContinuousUnivariate) - return rand(rng, dist.dist, dist.S) +function _flat_logpdf_mat(dist, x) + if toflatten(dist) + f, args = flatten(dist) + return vec(sum(f.(args..., x), dims = 1)) + else + return vec(sum(logpdf.(dist, x), dims = 1)) + end end -# Univariate discrete +const FillMatrixOfUnivariate{ + S <: ValueSupport, + T <: UnivariateDistribution{S}, + Tdists <: Fill{T, 2}, +} = MatrixOfUnivariate{S, T, Tdists} -struct MultipleDiscreteUnivariate{ - Tdist <: DiscreteUnivariateDistribution, -} <: ContinuousMultivariateDistribution - dist::Tdist - N::Int +function FillDist(dist::UnivariateDistribution, N1::Integer, N2::Integer) + return MatrixOfUnivariate(Fill(dist, N1, N2)) end -Base.length(dist::MultipleDiscreteUnivariate) = dist.N -Base.size(dist::MultipleDiscreteUnivariate) = (dist.N,) -function Multi(dist::DiscreteUnivariateDistribution, N::Int) - MultipleDiscreteUnivariate(dist, N) -end -function Distributions.logpdf( - dist::MultipleDiscreteUnivariate, - x::AbstractVector{<:Integer} -) - return _flat_logpdf(dist.dist, x) +function Distributions.logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real}) + return _flat_logpdf(dist.dists.value, x) end -function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleDiscreteUnivariate) - return rand(rng, dist.dist, dist.N) +function Distributions.rand(rng::Random.AbstractRNG, dist::FillMatrixOfUnivariate) + return rand(rng, dist.dists.value, length.(dist.dists.axes)) end -struct MatrixDiscreteUnivariate{ - Tdist <: DiscreteUnivariateDistribution, - Tsize <: NTuple{2, Integer}, -} <: DiscreteMatrixDistribution - dist::Tdist - S::Tsize -end -Base.size(dist::MatrixDiscreteUnivariate) = dist.S -function Multi(dist::DiscreteUnivariateDistribution, N1::Integer, N2::Integer) - return MatrixDiscreteUnivariate(dist, (N1, N2)) +# Multivariate + +const FillVectorOfMultivariate{ + S <: ValueSupport, + T <: MultivariateDistribution{S}, + Tdists <: Fill{T, 1}, +} = VectorOfMultivariate{S, T, Tdists} + +function FillDist(dist::MultivariateDistribution, N::Int) + return VectorOfMultivariate(Fill(dist, N)) end function Distributions.logpdf( - dist::MatrixDiscreteUnivariate, - x::AbstractMatrix{<:Real} + dist::FillVectorOfMultivariate, + x::AbstractMatrix{<:Real}, ) - return _flat_logpdf(dist.dist, x) + return sum(logpdf(dist.dists.value, x)) end -function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixDiscreteUnivariate) - return rand(rng, dist.dist, dist.S) +function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate) + return rand(rng, dist.dists.value, length.(dist.dists.axes)) end diff --git a/src/multivariate.jl b/src/multivariate.jl index 3dc2d03c..3cf05c54 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -48,10 +48,10 @@ function Distributions.logpdf(d::Dirichlet{T}, x::TrackedVecOrMat) where {T} logpdf(TuringDirichlet{T, TV}(d.alpha, d.alpha0, d.lmnB), x) end -ZygoteRules.@adjoint function Distributions.Dirichlet(alpha) +@adjoint function Distributions.Dirichlet(alpha) return pullback(TuringDirichlet, alpha) end -ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha) +@adjoint function Distributions.Dirichlet(d, alpha) return pullback(TuringDirichlet, d, alpha) end @@ -65,23 +65,23 @@ function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) end end -Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector) +@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector) simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin (Δ .* log.(data(x)), -Δ, Δ .* (data(alpha) .- 1)) end end -Tracker.@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) +@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin (log.(data(x)) * Δ, -sum(Δ), repeat(data(alpha) .- 1, 1, size(x, 2)) * Diagonal(Δ)) end end -ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector) - simplex_logpdf(alpha, lmnB, x), Δ -> (Δ .* log.(x), -Δ, Δ .* (alpha .- 1)) +@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector) + return simplex_logpdf(alpha, lmnB, x), Δ -> (Δ .* log.(x), -Δ, Δ .* (alpha .- 1)) end -ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) - simplex_logpdf(alpha, lmnB, x), Δ -> begin +@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) + return simplex_logpdf(alpha, lmnB, x), Δ -> begin (log.(x) * Δ, -sum(Δ), repeat(alpha .- 1, 1, size(x, 2)) * Diagonal(Δ)) end end @@ -353,18 +353,18 @@ MvLogNormal(d::Int, σ::TrackedReal{<:Real}) = TuringMvLogNormal(TuringMvNormal( ## Zygote adjoint -ZygoteRules.@adjoint function Distributions.MvNormal( +@adjoint function Distributions.MvNormal( A::Union{AbstractVector{<:Real}, AbstractMatrix{<:Real}}, ) return pullback(TuringMvNormal, A) end -ZygoteRules.@adjoint function Distributions.MvNormal( +@adjoint function Distributions.MvNormal( m::AbstractVector{<:Real}, A::Union{Real, UniformScaling, AbstractVecOrMat{<:Real}}, ) return pullback(TuringMvNormal, m, A) end -ZygoteRules.@adjoint function Distributions.MvNormal( +@adjoint function Distributions.MvNormal( d::Int, A::Real, ) diff --git a/src/univariate.jl b/src/univariate.jl index aa624fd7..d1c2e3aa 100644 --- a/src/univariate.jl +++ b/src/univariate.jl @@ -30,7 +30,7 @@ end uniformlogpdf(a::Real, b::Real, x::TrackedReal) = track(uniformlogpdf, a, b, x) uniformlogpdf(a::TrackedReal, b::TrackedReal, x::Real) = track(uniformlogpdf, a, b, x) uniformlogpdf(a::TrackedReal, b::TrackedReal, x::TrackedReal) = track(uniformlogpdf, a, b, x) -Tracker.@grad function uniformlogpdf(a, b, x) +@grad function uniformlogpdf(a, b, x) diff = data(b) - data(a) T = typeof(diff) if a <= data(x) <= b && a < b @@ -42,7 +42,7 @@ Tracker.@grad function uniformlogpdf(a, b, x) return l, Δ -> (n, n, n) end end -ZygoteRules.@adjoint function uniformlogpdf(a, b, x) +@adjoint function uniformlogpdf(a, b, x) diff = b - a T = typeof(diff) if a <= x <= b && a < b @@ -54,7 +54,7 @@ ZygoteRules.@adjoint function uniformlogpdf(a, b, x) return l, Δ -> (n, n, n) end end -ZygoteRules.@adjoint function Distributions.Uniform(args...) +@adjoint function Distributions.Uniform(args...) return pullback(TuringUniform, args...) end @@ -67,7 +67,7 @@ function _betalogpdfgrad(α, β, x) dx = (α - 1)/x + (1 - β)/(1 - x) return (dα, dβ, dx) end -ZygoteRules.@adjoint function betalogpdf(α::Real, β::Real, x::Number) +@adjoint function betalogpdf(α::Real, β::Real, x::Number) return betalogpdf(α, β, x), Δ -> (Δ .* _betalogpdfgrad(α, β, x)) end @@ -79,7 +79,7 @@ function _gammalogpdfgrad(k, θ, x) dx = (k - 1)/x - 1/θ return (dk, dθ, dx) end -ZygoteRules.@adjoint function gammalogpdf(k::Real, θ::Real, x::Number) +@adjoint function gammalogpdf(k::Real, θ::Real, x::Number) return gammalogpdf(k, θ, x), Δ -> (Δ .* _gammalogpdfgrad(k, θ, x)) end @@ -92,7 +92,7 @@ function _chisqlogpdfgrad(k, x) dx = (hk - 1)/x - one(hk)/2 return (dk, dx) end -ZygoteRules.@adjoint function chisqlogpdf(k::Real, x::Number) +@adjoint function chisqlogpdf(k::Real, x::Number) return chisqlogpdf(k, x), Δ -> (Δ .* _chisqlogpdfgrad(k, x)) end @@ -109,7 +109,7 @@ function _fdistlogpdfgrad(v1, v2, x) dx = v1 / 2 * (1 / x - temp3) - 1 / x return (dv1, dv2, dx) end -ZygoteRules.@adjoint function fdistlogpdf(v1::Real, v2::Real, x::Number) +@adjoint function fdistlogpdf(v1::Real, v2::Real, x::Number) return fdistlogpdf(v1, v2, x), Δ -> (Δ .* _fdistlogpdfgrad(v1, v2, x)) end @@ -120,7 +120,7 @@ function _tdistlogpdfgrad(v, x) dx = -x * (v + 1) / (v + x^2) return (dv, dx) end -ZygoteRules.@adjoint function tdistlogpdf(v::Real, x::Number) +@adjoint function tdistlogpdf(v::Real, x::Number) return tdistlogpdf(v, x), Δ -> (Δ .* _tdistlogpdfgrad(v, x)) end @@ -145,9 +145,9 @@ M, f, arity = DiffRules.@define_diffrule DistributionsAD.semicirclelogpdf(r, x) da, db = DiffRules.diffrule(M, f, :a, :b) f = :($M.$f) @eval begin - Tracker.@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) - Tracker.@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, Tracker._zero(b)) - Tracker.@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (Tracker._zero(a), Δ * $db) + @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, Tracker._zero(b)) + @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (Tracker._zero(a), Δ * $db) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b) $f(a::Real, b::TrackedReal) = track($f, a, b) @@ -159,11 +159,11 @@ end ## Binomial ## binomlogpdf(n::Int, p::TrackedReal, x::Int) = track(binomlogpdf, n, p, x) -Tracker.@grad function binomlogpdf(n::Int, p::TrackedReal, x::Int) +@grad function binomlogpdf(n::Int, p::TrackedReal, x::Int) return binomlogpdf(n, data(p), x), Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing) end -ZygoteRules.@adjoint function binomlogpdf(n::Int, p::Real, x::Int) +@adjoint function binomlogpdf(n::Int, p::Real, x::Int) return binomlogpdf(n, p, x), Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing) end @@ -186,15 +186,15 @@ _nbinomlogpdf_grad_2(r, p, k) = -k / (1 - p) + r / p nbinomlogpdf(n::TrackedReal, p::TrackedReal, x::Int) = track(nbinomlogpdf, n, p, x) nbinomlogpdf(n::Real, p::TrackedReal, x::Int) = track(nbinomlogpdf, n, p, x) nbinomlogpdf(n::TrackedReal, p::Real, x::Int) = track(nbinomlogpdf, n, p, x) -Tracker.@grad function nbinomlogpdf(r::TrackedReal, p::TrackedReal, k::Int) +@grad function nbinomlogpdf(r::TrackedReal, p::TrackedReal, k::Int) return nbinomlogpdf(data(r), data(p), k), Δ->(Δ * _nbinomlogpdf_grad_1(r, p, k), Δ * _nbinomlogpdf_grad_2(r, p, k), nothing) end -Tracker.@grad function nbinomlogpdf(r::Real, p::TrackedReal, k::Int) +@grad function nbinomlogpdf(r::Real, p::TrackedReal, k::Int) return nbinomlogpdf(data(r), data(p), k), Δ->(Tracker._zero(r), Δ * _nbinomlogpdf_grad_2(r, p, k), nothing) end -Tracker.@grad function nbinomlogpdf(r::TrackedReal, p::Real, k::Int) +@grad function nbinomlogpdf(r::TrackedReal, p::Real, k::Int) return nbinomlogpdf(data(r), data(p), k), Δ->(Δ * _nbinomlogpdf_grad_1(r, p, k), Tracker._zero(p), nothing) end @@ -226,11 +226,11 @@ end ## Poisson ## poislogpdf(v::TrackedReal, x::Int) = track(poislogpdf, v, x) -Tracker.@grad function poislogpdf(v::TrackedReal, x::Int) +@grad function poislogpdf(v::TrackedReal, x::Int) return poislogpdf(data(v), x), Δ->(Δ * (x/v - 1), nothing) end -ZygoteRules.@adjoint function poislogpdf(v::Real, x::Int) +@adjoint function poislogpdf(v::Real, x::Int) return poislogpdf(v, x), Δ->(Δ * (x/v - 1), nothing) end @@ -262,7 +262,7 @@ Base.minimum(d::TuringPoissonBinomial) = 0 Base.maximum(d::TuringPoissonBinomial) = length(d.p) poissonbinomial_pdf_fft(x::TrackedArray) = track(poissonbinomial_pdf_fft, x) -Tracker.@grad function poissonbinomial_pdf_fft(x::TrackedArray) +@grad function poissonbinomial_pdf_fft(x::TrackedArray) x_data = data(x) T = eltype(x_data) fft = poissonbinomial_pdf_fft(x_data) @@ -271,7 +271,7 @@ Tracker.@grad function poissonbinomial_pdf_fft(x::TrackedArray) end end # FIXME: This is inefficient, replace with the commented code below once Zygote supports it. -ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray) +@adjoint function poissonbinomial_pdf_fft(x::AbstractArray) T = eltype(x) fft = poissonbinomial_pdf_fft(x) return fft, Δ -> begin @@ -281,7 +281,7 @@ end # The code below doesn't work because of bugs in Zygote. The above is inefficient. #= -ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{<:Real}) +@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{<:Real}) return pullback(poissonbinomial_pdf_fft_zygote, x) end function poissonbinomial_pdf_fft_zygote(p::AbstractArray{T}) where {T <: Real} diff --git a/test/distributions.jl b/test/distributions.jl index 22635d50..a64649d7 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -174,6 +174,10 @@ separator() test_head("Testing: Multivariate continuous distributions") mult_cont_dists = [ # Vector case + DistSpec(:(() -> FillDist(Beta(), dim)), (), fill(0.5, dim)), + DistSpec(:(() -> ArrayDist(fill(Beta(), dim))), (), fill(0.5, dim)), + DistSpec(:((m, v) -> FillDist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_vec), + DistSpec(:((m, v) -> ArrayDist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_vec), DistSpec(:MvNormal, (mean, cov_mat), norm_val_vec), DistSpec(:MvNormal, (mean, cov_vec), norm_val_vec), DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_vec), @@ -192,6 +196,10 @@ separator() DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_vec), DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_vec), # Matrix case + DistSpec(:(() -> FillDist(Beta(), dim)), (), fill(0.5, dim, dim)), + DistSpec(:(() -> ArrayDist(fill(Beta(), dim))), (), fill(0.5, dim, dim)), + DistSpec(:((m, v) -> FillDist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_mat), + DistSpec(:((m, v) -> ArrayDist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_mat), DistSpec(:MvNormal, (mean, cov_vec), norm_val_mat), DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_mat), DistSpec(:MvNormal, (mean, cov_num), norm_val_mat), @@ -244,6 +252,10 @@ separator() @testset "Matrix-variate continuous distributions" begin test_head("Testing: Matrix-variate continuous distributions") matrix_cont_dists = [ + DistSpec(:(() -> FillDist(Beta(), dim, dim)), (), fill(0.5, dim, dim)), + DistSpec(:(() -> ArrayDist(fill(Beta(), dim, dim))), (), fill(0.5, dim, dim)), + DistSpec(:((m, v) -> FillDist(Normal(m, sqrt(v)), dim, 2)), (1.0, 1.0), norm_val_mat), + DistSpec(:((m, v) -> ArrayDist(fill(Normal(m, sqrt(v)), dim, 2))), (1.0, 1.0), norm_val_mat), DistSpec(:((n1, n2)->MatrixBeta(dim, n1, n2)), (dim, dim), beta_mat), DistSpec(:Wishart, (dim, cov_mat), cov_mat), DistSpec(:InverseWishart, (dim, cov_mat), cov_mat), From 2f97eaa3bac8c64211e9f98e0a3b31ebadb59dc4 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 15 Feb 2020 17:30:41 +1100 Subject: [PATCH 10/24] fix Project.toml --- Project.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1d23ad62..00cbbeff 100644 --- a/Project.toml +++ b/Project.toml @@ -20,12 +20,14 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Combinatorics = "0.7" -Distributions = "0.22" DiffRules = "0.1, 1.0" +Distributions = "0.22" FillArrays = "0.8" +FiniteDifferences = "0.9" ForwardDiff = "0.10.6" PDMats = "0.9" SpecialFunctions = "0.8, 0.9, 0.10" +StatsBase = "0.32" StatsFuns = "0.8, 0.9" Tracker = "0.2.5" Zygote = "0.4.7" From 52d609edb5d9c10197614052fbbca1b24a0ba17d Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 15 Feb 2020 17:34:04 +1100 Subject: [PATCH 11/24] minor type stability fix --- src/array_dist.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_dist.jl b/src/array_dist.jl index c3c39698..5303a33f 100644 --- a/src/array_dist.jl +++ b/src/array_dist.jl @@ -8,11 +8,11 @@ const VectorOfUnivariate{ function ArrayDist(dists::AbstractVector{<:Normal{T}}) where {T} if T <: TrackedReal - init_m = dists[1].μ + init_m = vcat(dists[1].μ) means = mapreduce(vcat, drop(dists, 1); init = init_m) do d d.μ end - init_v = dists[1].σ^2 + init_v = vcat(dists[1].σ^2) vars = mapreduce(vcat, drop(dists, 1); init = init_v) do d d.σ^2 end From 11c536cf8d0726dda26e75d3f233cff7d9ace63b Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 15 Feb 2020 17:42:58 +1100 Subject: [PATCH 12/24] typo --- src/array_dist.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_dist.jl b/src/array_dist.jl index 5303a33f..fd904b7f 100644 --- a/src/array_dist.jl +++ b/src/array_dist.jl @@ -61,7 +61,7 @@ function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) return rand.(Ref(rng), dist.dists) end -# Multivariate continuous +# Multivariate struct VectorOfMultivariate{ S <: ValueSupport, From 6e8bbe118601bb0f22f83f9e92f2089b04ffcdc2 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 15 Feb 2020 17:43:12 +1100 Subject: [PATCH 13/24] add discrete dist tests --- test/distributions.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/distributions.jl b/test/distributions.jl index a64649d7..ab6b62c7 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -155,10 +155,12 @@ separator() end end separator() - +=# @testset "Multivariate discrete distributions" begin test_head("Testing: Multivariate discrete distributions") mult_disc_dists = [ + DistSpec(:((p) -> FillDist(Bernoulli(p), dim)), (0.45,), fill(1, dim)), + DistSpec(:((p) -> ArrayDist(fill(Bernoulli(p), dim))), (0.45,), fill(1, dim)), DistSpec(:((p) -> Multinomial(2, p / sum(p))), (fill(0.5, 2),), [2, 0]), ] for d in mult_disc_dists From 0c0302cda2e1aec209f112b9b7dce52c4aed9f9e Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 15 Feb 2020 20:39:54 +1100 Subject: [PATCH 14/24] fix Zygote Irrational errors --- src/common.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/common.jl b/src/common.jl index 49e6cde4..f7f907fd 100644 --- a/src/common.jl +++ b/src/common.jl @@ -1,5 +1,7 @@ ## Generic ## +Base.one(::Irrational) = 1 + function Base.fill( value::TrackedReal, dims::Vararg{Union{Integer, AbstractUnitRange}}, From 0cf7b8599d5165c8774ae36aa6dc75e3d194403d Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 15 Feb 2020 20:49:00 +1100 Subject: [PATCH 15/24] typo --- test/distributions.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/distributions.jl b/test/distributions.jl index ab6b62c7..1a4a0487 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -155,7 +155,7 @@ separator() end end separator() -=# + @testset "Multivariate discrete distributions" begin test_head("Testing: Multivariate discrete distributions") mult_disc_dists = [ From 79e7e0374593771221bf667b431d47c9f14392e8 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 15 Feb 2020 22:01:45 +1100 Subject: [PATCH 16/24] == -> isapprox test --- test/others.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/others.jl b/test/others.jl index 3546ad2e..8048e2fb 100644 --- a/test/others.jl +++ b/test/others.jl @@ -19,5 +19,5 @@ end d1 = TuringDiagMvNormal(zeros(10), sigmas) d2 = MvNormal(zeros(10), sigmas) - @test entropy(d1) == entropy(d2) + @test isapprox(entropy(d1), entropy(d2), rtol = 1e-6) end From c5751d410af29be38a519d448f1f0b60752e12a0 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sat, 15 Feb 2020 22:36:33 +1100 Subject: [PATCH 17/24] Update src/multivariate.jl Co-Authored-By: David Widmann --- src/multivariate.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate.jl b/src/multivariate.jl index 3cf05c54..4b9b970e 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -32,7 +32,7 @@ function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer} Tf = float(T) TuringDirichlet(convert(AbstractVector{Tf}, alpha)) end -TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha)) +TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, float(alpha)) Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha) Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha) From dd93b4a0f27a58b05d2313851b6f88061629b1aa Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sat, 15 Feb 2020 22:38:34 +1100 Subject: [PATCH 18/24] Update src/multi.jl Co-Authored-By: David Widmann --- src/multi.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multi.jl b/src/multi.jl index 07b5c6f1..8834d88d 100644 --- a/src/multi.jl +++ b/src/multi.jl @@ -7,7 +7,7 @@ const FillVectorOfUnivariate{ } = VectorOfUnivariate{S, T, Tdists} function FillDist(dist::UnivariateDistribution, N::Int) - return Product(Fill(dist, N)) + return product_distribution(Fill(dist, N)) end FillDist(d::Normal, N::Int) = MvNormal(fill(d.μ, N), d.σ) function Distributions.logpdf( From d6a9e4ee5181ce6f606a22e8b09b1c883e5aa0f5 Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sat, 15 Feb 2020 22:42:13 +1100 Subject: [PATCH 19/24] Update src/multivariate.jl Co-Authored-By: David Widmann --- src/multivariate.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate.jl b/src/multivariate.jl index 4b9b970e..ec45d671 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -131,7 +131,7 @@ struct TuringScalMvNormal{Tm<:AbstractVector, Tσ<:Real} <: ContinuousMultivaria end Base.length(d::TuringScalMvNormal) = length(d.m) -Base.size(d::TuringScalMvNormal) = (length(d), length(d)) +Base.size(d::TuringScalMvNormal) = (length(d),) Distributions.rand(d::TuringScalMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...) function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal, n::Int...) return d.m .+ d.σ .* randn(rng, length(d), n...) From 2b88f7ba646f0c3f0c701efe39f055d4047bc80c Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sat, 15 Feb 2020 22:42:26 +1100 Subject: [PATCH 20/24] Update src/multivariate.jl Co-Authored-By: David Widmann --- src/multivariate.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate.jl b/src/multivariate.jl index ec45d671..b103f1b0 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -119,7 +119,7 @@ end Distributions.params(d::TuringDiagMvNormal) = (d.m, d.σ) Distributions.dim(d::TuringDiagMvNormal) = length(d.m) Base.length(d::TuringDiagMvNormal) = length(d.m) -Base.size(d::TuringDiagMvNormal) = (length(d), length(d)) +Base.size(d::TuringDiagMvNormal) = (length(d),) Distributions.rand(d::TuringDiagMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...) function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal, n::Int...) return d.m .+ d.σ .* randn(rng, length(d), n...) From 751e13899852c0a6dd88c5731780c97db505a5e0 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 16 Feb 2020 01:54:27 +1100 Subject: [PATCH 21/24] fix David's comments --- src/DistributionsAD.jl | 9 ++-- src/{array_dist.jl => arraydist.jl} | 68 ++++++++++++----------------- src/common.jl | 21 +++++++-- src/{multi.jl => filldist.jl} | 50 +++++++++++++++++---- src/multivariate.jl | 2 +- test/distributions.jl | 28 ++++++------ 6 files changed, 107 insertions(+), 71 deletions(-) rename src/{array_dist.jl => arraydist.jl} (52%) rename src/{multi.jl => filldist.jl} (61%) diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 67d5cedc..560689e7 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -12,6 +12,7 @@ using PDMats, using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray, TrackedVecOrMat, track, @grad, data +using SpecialFunctions: logabsgamma, digamma using ZygoteRules: ZygoteRules, @adjoint, pullback using LinearAlgebra: copytri! using Distributions: AbstractMvLogNormal, @@ -39,15 +40,15 @@ export TuringScalMvNormal, TuringPoissonBinomial, TuringWishart, TuringInverseWishart, - ArrayDist, - FillDist + arraydist, + filldist include("common.jl") include("univariate.jl") include("multivariate.jl") include("matrixvariate.jl") include("flatten.jl") -include("array_dist.jl") -include("multi.jl") +include("arraydist.jl") +include("filldist.jl") end diff --git a/src/array_dist.jl b/src/arraydist.jl similarity index 52% rename from src/array_dist.jl rename to src/arraydist.jl index fd904b7f..5c393c58 100644 --- a/src/array_dist.jl +++ b/src/arraydist.jl @@ -1,43 +1,31 @@ # Univariate -const VectorOfUnivariate{ - S <: ValueSupport, - Tdist <: UnivariateDistribution{S}, - Tdists <: AbstractVector{Tdist}, -} = Distributions.Product{S, Tdist, Tdists} - -function ArrayDist(dists::AbstractVector{<:Normal{T}}) where {T} - if T <: TrackedReal - init_m = vcat(dists[1].μ) - means = mapreduce(vcat, drop(dists, 1); init = init_m) do d - d.μ - end - init_v = vcat(dists[1].σ^2) - vars = mapreduce(vcat, drop(dists, 1); init = init_v) do d - d.σ^2 - end - else - means = [d.μ for d in dists] - vars = [d.σ^2 for d in dists] - end +const VectorOfUnivariate = Distributions.Product +function arraydist(dists::AbstractVector{<:Normal{T}}) where {T} + means = mean.(dists) + vars = var.(dists) return MvNormal(means, vars) end -function ArrayDist(dists::AbstractVector{<:UnivariateDistribution}) - return Distributions.Product(dists) +function arraydist(dists::AbstractVector{<:Normal{<:TrackedReal}}) + means = vcatmapreduce(mean, dists) + vars = vcatmapreduce(var, dists) + return MvNormal(means, vars) +end +function arraydist(dists::AbstractVector{<:UnivariateDistribution}) + return product_distribution(dists) end function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real}) - return sum(logpdf.(dist.v, x)) + return sum(vcatmapreduce(logpdf, dist.v, x)) end function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) - # Any other more efficient implementation breaks Zygote - return [logpdf(dist, x[:,i]) for i in 1:size(x, 2)] + # eachcol breaks Zygote, so we need an adjoint + return vcatmapreduce((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x)) end -function Distributions.logpdf( - dist::VectorOfUnivariate, - x::AbstractVector{<:AbstractMatrix{<:Real}}, -) - return logpdf.(Ref(dist), x) +@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) + # Any other more efficient implementation breaks Zygote + f(dist, x) = [sum(logpdf.(dist.v, view(x, :, i))) for i in 1:size(x, 2)] + return pullback(f, dist, x) end struct MatrixOfUnivariate{ @@ -48,14 +36,13 @@ struct MatrixOfUnivariate{ dists::Tdists end Base.size(dist::MatrixOfUnivariate) = size(dist.dists) -function ArrayDist(dists::AbstractMatrix{<:UnivariateDistribution}) +function arraydist(dists::AbstractMatrix{<:UnivariateDistribution}) return MatrixOfUnivariate(dists) end function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real}) # Broadcasting here breaks Tracker for some reason - return sum(zip(dist.dists, x)) do (dist, x) - logpdf(dist, x) - end + # A Zygote adjoint is defined for vcatmapreduce to use broadcasting + return sum(vcatmapreduce(logpdf, dist.dists, x)) end function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) return rand.(Ref(rng), dist.dists) @@ -72,17 +59,16 @@ struct VectorOfMultivariate{ end Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist)) Base.length(dist::VectorOfMultivariate) = length(dist.dists) -function ArrayDist(dists::AbstractVector{<:MultivariateDistribution}) +function arraydist(dists::AbstractVector{<:MultivariateDistribution}) return VectorOfMultivariate(dists) end function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) - return sum(logpdf(dist.dists[i], x[:,i]) for i in 1:length(dist)) + # eachcol breaks Zygote, so we define an adjoint + return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x))) end -function Distributions.logpdf( - dist::VectorOfMultivariate, - x::AbstractVector{<:AbstractVector{<:Real}}, -) - return sum(logpdf(dist.dists[i], x[i]) for i in 1:length(dist)) +@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) + f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2))) + return pullback(f, dist, x) end function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) init = reshape(rand(rng, dist.dists[1]), :, 1) diff --git a/src/common.jl b/src/common.jl index f7f907fd..afb102cc 100644 --- a/src/common.jl +++ b/src/common.jl @@ -2,6 +2,18 @@ Base.one(::Irrational) = 1 +function vcatmapreduce(f, args...) + init = vcat(f(first.(args)...,)) + zipped_args = zip(args...,) + return mapreduce(vcat, drop(zipped_args, 1); init = init) do zarg + f(zarg...,) + end +end +@adjoint function vcatmapreduce(f, args...) + g(f, args...) = f.(args...,) + return pullback(g, f, args...) +end + function Base.fill( value::TrackedReal, dims::Vararg{Union{Integer, AbstractUnitRange}}, @@ -110,9 +122,12 @@ end # SpecialFunctions -function SpecialFunctions.logabsgamma(x::TrackedReal) - v = loggamma(x) - return v, sign(data(v)) +SpecialFunctions.logabsgamma(x::TrackedReal) = track(logabsgamma, x) +@grad function SpecialFunctions.logabsgamma(x::Real) + return logabsgamma(data(x)), Δ -> (digamma(data(x)) * Δ[1],) +end +@adjoint function SpecialFunctions.logabsgamma(x::Real) + return logabsgamma(x), Δ -> (digamma(x) * Δ[1],) end # Some Tracker fixes diff --git a/src/multi.jl b/src/filldist.jl similarity index 61% rename from src/multi.jl rename to src/filldist.jl index 07b5c6f1..d7a0b282 100644 --- a/src/multi.jl +++ b/src/filldist.jl @@ -6,28 +6,49 @@ const FillVectorOfUnivariate{ Tdists <: Fill{T, 1}, } = VectorOfUnivariate{S, T, Tdists} -function FillDist(dist::UnivariateDistribution, N::Int) - return Product(Fill(dist, N)) +function filldist(dist::UnivariateDistribution, N::Int) + return product_distribution(Fill(dist, N)) end -FillDist(d::Normal, N::Int) = MvNormal(fill(d.μ, N), d.σ) +filldist(d::Normal, N::Int) = MvNormal(fill(d.μ, N), d.σ) + function Distributions.logpdf( dist::FillVectorOfUnivariate, x::AbstractVector{<:Real}, ) - return _flat_logpdf(dist.v.value, x) + return _logpdf(dist, x) end function Distributions.logpdf( dist::FillVectorOfUnivariate, x::AbstractMatrix{<:Real}, +) + return _logpdf(dist, x) +end +@adjoint function Distributions.logpdf( + dist::FillVectorOfUnivariate, + x::AbstractMatrix{<:Real}, +) + return pullback(_logpdf, dist, x) +end + +function _logpdf( + dist::FillVectorOfUnivariate, + x::AbstractVector{<:Real}, +) + return _flat_logpdf(dist.v.value, x) +end +function _logpdf( + dist::FillVectorOfUnivariate, + x::AbstractMatrix{<:Real}, ) return _flat_logpdf_mat(dist.v.value, x) end + function _flat_logpdf(dist, x) if toflatten(dist) f, args = flatten(dist) return sum(f.(args..., x)) else - return sum(logpdf.(dist, x)) + return sum(vcatmapreduce(x -> logpdf(dist, x), x)) end end function _flat_logpdf_mat(dist, x) @@ -35,7 +56,8 @@ function _flat_logpdf_mat(dist, x) f, args = flatten(dist) return vec(sum(f.(args..., x), dims = 1)) else - return vec(sum(logpdf.(dist, x), dims = 1)) + temp = vcatmapreduce(x -> logpdf(dist, x), x) + return vec(sum(reshape(temp, size(x)), dims = 1)) end end @@ -45,7 +67,7 @@ const FillMatrixOfUnivariate{ Tdists <: Fill{T, 2}, } = MatrixOfUnivariate{S, T, Tdists} -function FillDist(dist::UnivariateDistribution, N1::Integer, N2::Integer) +function filldist(dist::UnivariateDistribution, N1::Integer, N2::Integer) return MatrixOfUnivariate(Fill(dist, N1, N2)) end function Distributions.logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real}) @@ -63,12 +85,24 @@ const FillVectorOfMultivariate{ Tdists <: Fill{T, 1}, } = VectorOfMultivariate{S, T, Tdists} -function FillDist(dist::MultivariateDistribution, N::Int) +function filldist(dist::MultivariateDistribution, N::Int) return VectorOfMultivariate(Fill(dist, N)) end function Distributions.logpdf( dist::FillVectorOfMultivariate, x::AbstractMatrix{<:Real}, +) + return _logpdf(dist, x) +end +@adjoint function Distributions.logpdf( + dist::FillVectorOfMultivariate, + x::AbstractMatrix{<:Real}, +) + return pullback(_logpdf, dist, x) +end +function _logpdf( + dist::FillVectorOfMultivariate, + x::AbstractMatrix{<:Real}, ) return sum(logpdf(dist.dists.value, x)) end diff --git a/src/multivariate.jl b/src/multivariate.jl index 3cf05c54..5308be7e 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -59,7 +59,7 @@ function simplex_logpdf(alpha, lmnB, x::AbstractVector) sum((alpha .- 1) .* log.(x)) - lmnB end function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) - @views init = vcat(sum((alpha .- 1) .* log.(x[:,1]))) + init = vcat(sum((alpha .- 1) .* log.(view(x, :, 1)))) mapreduce(vcat, drop(eachcol(x), 1); init = init) do c sum((alpha .- 1) .* log.(c)) - lmnB end diff --git a/test/distributions.jl b/test/distributions.jl index 1a4a0487..5bcb5556 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -159,8 +159,8 @@ separator() @testset "Multivariate discrete distributions" begin test_head("Testing: Multivariate discrete distributions") mult_disc_dists = [ - DistSpec(:((p) -> FillDist(Bernoulli(p), dim)), (0.45,), fill(1, dim)), - DistSpec(:((p) -> ArrayDist(fill(Bernoulli(p), dim))), (0.45,), fill(1, dim)), + DistSpec(:((p) -> filldist(Bernoulli(p), dim)), (0.45,), fill(1, dim)), + DistSpec(:((p) -> arraydist(fill(Bernoulli(p), dim))), (0.45,), fill(1, dim)), DistSpec(:((p) -> Multinomial(2, p / sum(p))), (fill(0.5, 2),), [2, 0]), ] for d in mult_disc_dists @@ -176,10 +176,10 @@ separator() test_head("Testing: Multivariate continuous distributions") mult_cont_dists = [ # Vector case - DistSpec(:(() -> FillDist(Beta(), dim)), (), fill(0.5, dim)), - DistSpec(:(() -> ArrayDist(fill(Beta(), dim))), (), fill(0.5, dim)), - DistSpec(:((m, v) -> FillDist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_vec), - DistSpec(:((m, v) -> ArrayDist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_vec), + DistSpec(:(() -> filldist(Beta(), dim)), (), fill(0.5, dim)), + DistSpec(:(() -> arraydist(fill(Beta(), dim))), (), fill(0.5, dim)), + DistSpec(:((m, v) -> filldist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_vec), + DistSpec(:((m, v) -> arraydist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_vec), DistSpec(:MvNormal, (mean, cov_mat), norm_val_vec), DistSpec(:MvNormal, (mean, cov_vec), norm_val_vec), DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_vec), @@ -198,10 +198,10 @@ separator() DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_vec), DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_vec), # Matrix case - DistSpec(:(() -> FillDist(Beta(), dim)), (), fill(0.5, dim, dim)), - DistSpec(:(() -> ArrayDist(fill(Beta(), dim))), (), fill(0.5, dim, dim)), - DistSpec(:((m, v) -> FillDist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_mat), - DistSpec(:((m, v) -> ArrayDist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_mat), + DistSpec(:(() -> filldist(Beta(), dim)), (), fill(0.5, dim, dim)), + DistSpec(:(() -> arraydist(fill(Beta(), dim))), (), fill(0.5, dim, dim)), + DistSpec(:((m, v) -> filldist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_mat), + DistSpec(:((m, v) -> arraydist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_mat), DistSpec(:MvNormal, (mean, cov_vec), norm_val_mat), DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_mat), DistSpec(:MvNormal, (mean, cov_num), norm_val_mat), @@ -254,10 +254,10 @@ separator() @testset "Matrix-variate continuous distributions" begin test_head("Testing: Matrix-variate continuous distributions") matrix_cont_dists = [ - DistSpec(:(() -> FillDist(Beta(), dim, dim)), (), fill(0.5, dim, dim)), - DistSpec(:(() -> ArrayDist(fill(Beta(), dim, dim))), (), fill(0.5, dim, dim)), - DistSpec(:((m, v) -> FillDist(Normal(m, sqrt(v)), dim, 2)), (1.0, 1.0), norm_val_mat), - DistSpec(:((m, v) -> ArrayDist(fill(Normal(m, sqrt(v)), dim, 2))), (1.0, 1.0), norm_val_mat), + DistSpec(:(() -> filldist(Beta(), dim, dim)), (), fill(0.5, dim, dim)), + DistSpec(:(() -> arraydist(fill(Beta(), dim, dim))), (), fill(0.5, dim, dim)), + DistSpec(:((m, v) -> filldist(Normal(m, sqrt(v)), dim, 2)), (1.0, 1.0), norm_val_mat), + DistSpec(:((m, v) -> arraydist(fill(Normal(m, sqrt(v)), dim, 2))), (1.0, 1.0), norm_val_mat), DistSpec(:((n1, n2)->MatrixBeta(dim, n1, n2)), (dim, dim), beta_mat), DistSpec(:Wishart, (dim, cov_mat), cov_mat), DistSpec(:InverseWishart, (dim, cov_mat), cov_mat), From 735f65a7f896b1e7bce13d78e874e88566bbce40 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 16 Feb 2020 02:15:39 +1100 Subject: [PATCH 22/24] remove Dirichlet fix --- src/multivariate.jl | 88 ------------------------------------------- test/distributions.jl | 3 +- 2 files changed, 1 insertion(+), 90 deletions(-) diff --git a/src/multivariate.jl b/src/multivariate.jl index 2927d9b1..8859d377 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -1,91 +1,3 @@ -## Dirichlet ## - -struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution - alpha::TV - alpha0::T - lmnB::T -end -function check(alpha) - all(ai -> ai > 0, alpha) || - throw(ArgumentError("Dirichlet: alpha must be a positive vector.")) -end -Zygote.@nograd DistributionsAD.check - -function TuringDirichlet(alpha::AbstractVector) - check(alpha) - alpha0 = sum(alpha) - lmnB = sum(loggamma, alpha) - loggamma(alpha0) - T = promote_type(typeof(alpha0), typeof(lmnB)) - TV = typeof(alpha) - TuringDirichlet{T, TV}(alpha, alpha0, lmnB) -end - -function TuringDirichlet(d::Integer, alpha::Real) - alpha0 = alpha * d - _alpha = fill(alpha, d) - lmnB = loggamma(alpha) * d - loggamma(alpha0) - T = promote_type(typeof(alpha0), typeof(lmnB)) - TV = typeof(_alpha) - TuringDirichlet{T, TV}(_alpha, alpha0, lmnB) -end -function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer} - Tf = float(T) - TuringDirichlet(convert(AbstractVector{Tf}, alpha)) -end -TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, float(alpha)) - -Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha) -Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha) - -function Distributions.logpdf(d::TuringDirichlet, x::AbstractVector) - simplex_logpdf(d.alpha, d.lmnB, x) -end -function Distributions.logpdf(d::TuringDirichlet, x::AbstractMatrix) - simplex_logpdf(d.alpha, d.lmnB, x) -end -function Distributions.logpdf(d::Dirichlet{T}, x::TrackedVecOrMat) where {T} - TV = typeof(d.alpha) - logpdf(TuringDirichlet{T, TV}(d.alpha, d.alpha0, d.lmnB), x) -end - -@adjoint function Distributions.Dirichlet(alpha) - return pullback(TuringDirichlet, alpha) -end -@adjoint function Distributions.Dirichlet(d, alpha) - return pullback(TuringDirichlet, d, alpha) -end - -function simplex_logpdf(alpha, lmnB, x::AbstractVector) - sum((alpha .- 1) .* log.(x)) - lmnB -end -function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) - init = vcat(sum((alpha .- 1) .* log.(view(x, :, 1)))) - mapreduce(vcat, drop(eachcol(x), 1); init = init) do c - sum((alpha .- 1) .* log.(c)) - lmnB - end -end - -@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector) - simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin - (Δ .* log.(data(x)), -Δ, Δ .* (data(alpha) .- 1)) - end -end -@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) - simplex_logpdf(data(alpha), data(lmnB), data(x)), Δ -> begin - (log.(data(x)) * Δ, -sum(Δ), repeat(data(alpha) .- 1, 1, size(x, 2)) * Diagonal(Δ)) - end -end - -@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector) - return simplex_logpdf(alpha, lmnB, x), Δ -> (Δ .* log.(x), -Δ, Δ .* (alpha .- 1)) -end - -@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) - return simplex_logpdf(alpha, lmnB, x), Δ -> begin - (log.(x) * Δ, -sum(Δ), repeat(alpha .- 1, 1, size(x, 2)) * Diagonal(Δ)) - end -end - ## MvNormal ## """ diff --git a/test/distributions.jl b/test/distributions.jl index 5bcb5556..a74c597e 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -215,8 +215,6 @@ separator() DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat), DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_mat), DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat), - DistSpec(:Dirichlet, (alpha,), dir_val), - DistSpec(:Dirichlet, (alpha,), dir_val), ] broken_mult_cont_dists = [ @@ -233,6 +231,7 @@ separator() DistSpec(:MvNormalCanon, (cov_mat,), norm_val_mat), DistSpec(:MvNormalCanon, (cov_vec,), norm_val_mat), DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val_mat), + DistSpec(:Dirichlet, (alpha,), dir_val), # Test failure DistSpec(:MvNormal, (mean, cov_mat), norm_val_mat), DistSpec(:MvNormal, (cov_mat,), norm_val_mat), From 988878a66eb9102faf2fd43e60a35000bf392a55 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 16 Feb 2020 03:51:27 +1100 Subject: [PATCH 23/24] one of irrationals = true not 1 --- src/common.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common.jl b/src/common.jl index afb102cc..14f7fc5d 100644 --- a/src/common.jl +++ b/src/common.jl @@ -1,6 +1,6 @@ ## Generic ## -Base.one(::Irrational) = 1 +Base.one(::Irrational) = true function vcatmapreduce(f, args...) init = vcat(f(first.(args)...,)) From 105e5e8bea99d1589da0fdcbc65d1425aed45b43 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sun, 16 Feb 2020 04:10:45 +1100 Subject: [PATCH 24/24] define eachcol for Julia 1.0 --- src/common.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/common.jl b/src/common.jl index 14f7fc5d..eee8518f 100644 --- a/src/common.jl +++ b/src/common.jl @@ -1,5 +1,9 @@ ## Generic ## +if VERSION < v"1.1" + eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2)) +end + Base.one(::Irrational) = true function vcatmapreduce(f, args...)