Skip to content

Commit

Permalink
Showing 14 changed files with 751 additions and 265 deletions.
27 changes: 27 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: CI

on:
push:
branches:
- master
pull_request:
types: [opened, synchronize, reopened]

jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
matrix:
julia-version: [1.0.5, 1.2.0, 1.3]
julia-arch: [x64, x86]
os: [ubuntu-latest, macOS-latest]
exclude:
- os: macOS-latest
julia-arch: x86

steps:
- uses: actions/checkout@v1.0.0
- uses: julia-actions/setup-julia@latest
with:
version: ${{ matrix.julia-version }}
- uses: julia-actions/julia-runtest@master
23 changes: 0 additions & 23 deletions .travis.yml

This file was deleted.

2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@ version = "0.3.2"

[deps]
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -20,6 +21,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Combinatorics = "0.7"
Compat = "3.6"
DiffRules = "0.1, 1.0"
Distributions = "0.22"
FillArrays = "0.8"
3 changes: 2 additions & 1 deletion src/DistributionsAD.jl
Original file line number Diff line number Diff line change
@@ -8,7 +8,8 @@ using PDMats,
Random,
Combinatorics,
SpecialFunctions,
StatsFuns
StatsFuns,
Compat

using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
TrackedVecOrMat, track, @grad, data
12 changes: 12 additions & 0 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
@@ -44,6 +44,12 @@ function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real
# A Zygote adjoint is defined for vcatmapreduce to use broadcasting
return sum(vcatmapreduce(logpdf, dist.dists, x))
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return vcatmapreduce(x -> logpdf(dist, x), x)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}})
return vcatmapreduce(x -> logpdf(dist, x), x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
return rand.(Ref(rng), dist.dists)
end
@@ -66,6 +72,12 @@ function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Re
# eachcol breaks Zygote, so we define an adjoint
return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x)))
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x))
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(dist, x), x), size(x))
end
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
return pullback(f, dist, x)
4 changes: 1 addition & 3 deletions src/common.jl
Original file line number Diff line number Diff line change
@@ -4,8 +4,6 @@ if VERSION < v"1.1"
eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2))
end

Base.one(::Irrational) = true

function vcatmapreduce(f, args...)
init = vcat(f(first.(args)...,))
zipped_args = zip(args...,)
@@ -14,7 +12,7 @@ function vcatmapreduce(f, args...)
end
end
@adjoint function vcatmapreduce(f, args...)
g(f, args...) = f.(args...,)
g(f, args...) = f.(args...)
return pullback(g, f, args...)
end

26 changes: 18 additions & 8 deletions src/filldist.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Univariate

Tracker.dual(x::Int, p) = x

const FillVectorOfUnivariate{
S <: ValueSupport,
T <: UnivariateDistribution{S},
@@ -46,15 +48,23 @@ end
function _flat_logpdf(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return sum(f.(args..., x))
if any(Tracker.istracked, args)
return sum(f.(args..., x))
else
return sum(logpdf.(dist, x))
end
else
return sum(vcatmapreduce(x -> logpdf(dist, x), x))
end
end
function _flat_logpdf_mat(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return vec(sum(f.(args..., x), dims = 1))
if any(Tracker.istracked, args)
return vec(sum(f.(args..., x), dims = 1))
else
return vec(sum(logpdf.(dist, x), dims = 1))
end
else
temp = vcatmapreduce(x -> logpdf(dist, x), x)
return vec(sum(reshape(temp, size(x)), dims = 1))
@@ -74,7 +84,7 @@ function Distributions.logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:
return _flat_logpdf(dist.dists.value, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::FillMatrixOfUnivariate)
return rand(rng, dist.dists.value, length.(dist.dists.axes))
return rand(rng, dist.dists.value, length.(dist.dists.axes)...,)
end

# Multivariate
@@ -94,18 +104,18 @@ function Distributions.logpdf(
)
return _logpdf(dist, x)
end
@adjoint function Distributions.logpdf(
function _logpdf(
dist::FillVectorOfMultivariate,
x::AbstractMatrix{<:Real},
)
return pullback(_logpdf, dist, x)
return sum(logpdf(dist.dists.value, x))
end
function _logpdf(
@adjoint function Distributions.logpdf(
dist::FillVectorOfMultivariate,
x::AbstractMatrix{<:Real},
)
return sum(logpdf(dist.dists.value, x))
return pullback(_logpdf, dist, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate)
return rand(rng, dist.dists.value, length.(dist.dists.axes))
return rand(rng, dist.dists.value, length.(dist.dists.axes)...,)
end
13 changes: 6 additions & 7 deletions src/flatten.jl
Original file line number Diff line number Diff line change
@@ -26,7 +26,6 @@ const flattened_dists = [ Bernoulli,
NegativeBinomial,
Poisson,
Skellam,
PoissonBinomial,
Arcsine,
Beta,
BetaPrime,
@@ -42,10 +41,10 @@ const flattened_dists = [ Bernoulli,
FDist,
Frechet,
Gamma,
GeneralizedExtremeValue,
#GeneralizedExtremeValue,
GeneralizedPareto,
Gumbel,
InverseGamma,
#InverseGamma,
InverseGaussian,
Kolmogorov,
Laplace,
@@ -55,17 +54,17 @@ const flattened_dists = [ Bernoulli,
LogitNormal,
LogNormal,
Normal,
NormalCanon,
NormalInverseGaussian,
#NormalCanon,
#NormalInverseGaussian,
Pareto,
PGeneralizedGaussian,
Rayleigh,
SymTriangularDist,
TDist,
TriangularDist,
Triweight,
Categorical,
Truncated,
#Truncated,
#VonMises,
]
for T in flattened_dists
@eval toflatten(::$T) = true
51 changes: 49 additions & 2 deletions src/matrixvariate.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
## MatrixBeta

function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:TrackedMatrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
end
@adjoint function Distributions.logpdf(d::MatrixBeta, X::AbstractArray{<:Matrix{<:Real}})
f(d, X) = map(x -> logpdf(d, x), X)
return pullback(f, d, X)
end

# Adapted from Distributions.jl

## Wishart
@@ -12,6 +22,13 @@ end

#### Constructors

function TuringWishart(d::Wishart)
d = TuringWishart(d.df, getchol(d.S), d.c0)
end
getchol(p::PDMat) = p.chol
getchol(p::PDiagMat) = Diagonal(map(sqrt, p.diag))
getchol(p::ScalMat) = Diagonal(fill(sqrt(p.value), p.dim))

function TuringWishart(df::T, S::AbstractMatrix) where {T <: Real}
p = size(S, 1)
df > p - 1 || error("dpf should be greater than dim - 1.")
@@ -66,7 +83,7 @@ end
function Distributions.entropy(d::TuringWishart)
p = Distributions.dim(d)
df = d.df
d.c0 - 0.5 * (df - p - 1) * meanlogdet(d) + 0.5 * df * p
d.c0 - 0.5 * (df - p - 1) * Distributions.meanlogdet(d) + 0.5 * df * p
end

# Gupta/Nagar (1999) Theorem 3.3.15.i
@@ -82,12 +99,24 @@ end

#### Evaluation

function Distributions.logpdf(d::Wishart, X::TrackedMatrix)
return logpdf(TuringWishart(d), X)
end
function Distributions.logpdf(d::Wishart, X::AbstractArray{<:TrackedMatrix})
return logpdf(TuringWishart(d), X)
end
function Distributions.logpdf(d::TuringWishart, X::AbstractMatrix{<:Real})
df = d.df
p = Distributions.dim(d)
Xcf = cholesky(X)
return 0.5 * ((df - (p + 1)) * logdet(Xcf) - tr(d.chol \ X)) - d.c0
end
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
end
function Distributions.logpdf(d::TuringWishart, X::AbstractArray{<:Matrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
end

#### Sampling
function Distributions._rand!(rng::AbstractRNG, d::TuringWishart, A::AbstractMatrix)
@@ -128,6 +157,13 @@ end

#### Constructors

function TuringInverseWishart(d::InverseWishart)
d = TuringInverseWishart(d.df, getmatrix(d.Ψ), d.c0)
end
getmatrix(p::PDMat) = p.mat
getmatrix(p::PDiagMat) = Diagonal(p.diag)
getmatrix(p::ScalMat) = Diagonal(fill(p.value, p.dim))

function TuringInverseWishart(df::T, Ψ::AbstractMatrix) where T<:Real
p = size(Ψ, 1)
df > p - 1 || error("df should be greater than dim - 1.")
@@ -182,6 +218,12 @@ end

#### Evaluation

function Distributions.logpdf(d::InverseWishart, X::TrackedMatrix)
return logpdf(TuringInverseWishart(d), X)
end
function Distributions.logpdf(d::InverseWishart, X::AbstractArray{<:TrackedMatrix})
return logpdf(TuringInverseWishart(d), X)
end
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractMatrix{<:Real})
p = Distributions.dim(d)
df = d.df
@@ -190,7 +232,12 @@ function Distributions.logpdf(d::TuringInverseWishart, X::AbstractMatrix{<:Real}
Ψ = d.S
-0.5 * ((df + p + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) - d.c0
end

function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:AbstractMatrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
end
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractArray{<:Matrix{<:Real}})
return reshape(vcatmapreduce(x -> logpdf(d, x), X), size(X))
end

#### Sampling

16 changes: 14 additions & 2 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
@@ -116,8 +116,6 @@ struct TuringDiagMvNormal{Tm<:AbstractVector, Tσ<:AbstractVector} <: Continuous
σ::Tσ
end

Distributions.params(d::TuringDiagMvNormal) = (d.m, d.σ)
Distributions.dim(d::TuringDiagMvNormal) = length(d.m)
Base.length(d::TuringDiagMvNormal) = length(d.m)
Base.size(d::TuringDiagMvNormal) = (length(d),)
Distributions.rand(d::TuringDiagMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...)
@@ -164,6 +162,20 @@ function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix)
return -((size(x, 1) * log(2π) + logdet(d.C)) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2
end

for T in (:TrackedVector, :TrackedMatrix)
@eval begin
function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.ScalMat}, x::$T)
logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x)
end
function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.PDiagMat}, x::$T)
logpdf(TuringDiagMvNormal(d.μ, d.Σ.diag), x)
end
function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.PDMat}, x::$T)
logpdf(TuringDenseMvNormal(d.μ, d.Σ.chol), x)
end
end
end

import StatsBase: entropy
function entropy(d::TuringDiagMvNormal)
T = eltype(d.σ)
681 changes: 472 additions & 209 deletions test/distributions.jl

Large diffs are not rendered by default.

120 changes: 112 additions & 8 deletions test/others.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,117 @@
using StatsBase: entropy

@testset "unsafe_cholesky" begin
A = rand(3, 3); A = A + A' + 3I
@test Matrix(DistributionsAD.unsafe_cholesky(A, true)) == Matrix(cholesky(A))
@test !issuccess(DistributionsAD.unsafe_cholesky(rand(3,3), false))
@test_throws PosDefException DistributionsAD.unsafe_cholesky(rand(3,3), true)
end

@testset "TuringWishart" begin
dim = 3
A = Matrix{Float64}(I, dim, dim)
dW1 = Wishart(dim + 4, A)
dW2 = TuringWishart(dim + 4, A)

@testset "$F" for F in (size, rank, mean, meanlogdet, entropy, cov, var)
@test F(dW1) == F(dW2)
end
@test Matrix(mode(dW1)) == mode(dW2)
xw = rand(dW2)
@test insupport(dW1, xw)
@test insupport(dW2, xw)
@test logpdf(dW1, xw) == logpdf(dW2, xw)
end

@testset "TuringInverseWishart" begin
dim = 3
A = Matrix{Float64}(I, dim, dim)
dIW1 = InverseWishart(dim + 4, A)
dIW2 = TuringInverseWishart(dim + 4, A)

@testset "$F" for F in (size, rank, mean, cov, var)
@test F(dIW1) == F(dIW2)
end
@test Matrix(mode(dIW1)) == mode(dIW2)
xiw = rand(dIW2)
@test insupport(dIW1, xiw)
@test insupport(dIW2, xiw)
@test logpdf(dIW1, xiw) == logpdf(dIW2, xiw)
end

@testset "TuringMvNormal" begin
@testset "$TD" for TD in [TuringDenseMvNormal, TuringDiagMvNormal, TuringScalMvNormal]
m = rand(3)
if TD <: TuringDenseMvNormal
C = Matrix{Float64}(I, 3, 3)
d1 = TuringMvNormal(m, C)
elseif TD <: TuringDiagMvNormal
C = ones(3)
d1 = TuringMvNormal(m, C)
else
C = 1.0
d1 = TuringMvNormal(m, C)
end
d2 = MvNormal(m, C)

@testset "$F" for F in (length, size)
@test F(d1) == F(d2)
end

x1 = rand(d1)
x2 = rand(d1, 3)
@test isapprox(logpdf(d1, x1), logpdf(d2, x1), rtol = 1e-6)
@test isapprox(logpdf(d1, x2), logpdf(d2, x2), rtol = 1e-6)
end
end

@testset "TuringMvLogNormal" begin
@testset "$TD" for TD in [TuringDenseMvNormal, TuringDiagMvNormal, TuringScalMvNormal]
m = rand(3)
if TD <: TuringDenseMvNormal
C = Matrix{Float64}(I, 3, 3)
d1 = TuringMvLogNormal(TuringMvNormal(m, C))
elseif TD <: TuringDiagMvNormal
C = ones(3)
d1 = TuringMvLogNormal(TuringMvNormal(m, C))
else
C = 1.0
d1 = TuringMvLogNormal(TuringMvNormal(m, C))
end
d2 = MvLogNormal(MvNormal(m, C))

@test length(d1) == length(d2)

x1 = rand(d1)
x2 = rand(d1, 3)
@test isapprox(logpdf(d1, x1), logpdf(d2, x1), rtol = 1e-6)
@test isapprox(logpdf(d1, x2), logpdf(d2, x2), rtol = 1e-6)

x2[:, 1] .= -1
@test isinf(logpdf(d1, x2)[1])
@test isinf(logpdf(d2, x2)[1])
end
end

@testset "TuringUniform" begin
@test logpdf(TuringUniform(), param(0.5)) == 0
end

@testset "Semicircle" begin
@test Tracker.data(logpdf(Semicircle(1.0), param(0.5))) == logpdf(Semicircle(1.0), 0.5)
end

@testset "TuringPoissonBinomial" begin
d1 = TuringPoissonBinomial([0.5, 0.5])
d2 = PoissonBinomial([0.5, 0.5])
@test quantile(d1, 0.5) == quantile(d2, 0.5)
@test minimum(d1) == minimum(d2)
end

@testset "Inverse of pi" begin
@test 1/pi == inv(pi)
end

@testset "Others" begin
@test fill(param(1.0), 3) isa TrackedArray
x = rand(3)
@@ -13,11 +125,3 @@ using StatsBase: entropy
B = copy(A)
@test DistributionsAD.zygote_ldiv(A, B) == A \ B
end

@testset "Extras from StatsBase.jl" begin
sigmas = exp.(randn(10))
d1 = TuringDiagMvNormal(zeros(10), sigmas)
d2 = MvNormal(zeros(10), sigmas)

@test isapprox(entropy(d1), entropy(d2), rtol = 1e-6)
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -4,6 +4,8 @@ using DistributionsAD, Test, LinearAlgebra, Combinatorics
using ForwardDiff: Dual
using StatsFuns: binomlogpdf, logsumexp
const FDM = FiniteDifferences
using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringUniform, unsafe_cholesky
using Distributions: meanlogdet

include("test_utils.jl")
include("distributions.jl")
36 changes: 34 additions & 2 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
if VERSION < v"1.1"
isnothing(x) = x === nothing
end

struct ADTestFunction
name::String
f::Function
@@ -11,15 +15,20 @@ end

vectorize(v::Number) = [v]
vectorize(v::Diagonal) = v.diag
vectorize(v::Vector{<:Matrix}) = mapreduce(vec, vcat, v)
vectorize(v) = vec(v)
pack(vals...) = reduce(vcat, vectorize.(vals))

@generated function unpack(x, vals...)
unpacked = []
ind = :(1)
for (i, T) in enumerate(vals)
if T <: Number
push!(unpacked, :(x[$ind]))
ind = :($ind + 1)
elseif T <: Vector{<:Matrix}
push!(unpacked, :(unpack_vec_of_mats(x[$ind:$ind+sum(length, vals[$i])-1], vals[$i])))
ind = :($ind + sum(length, vals[$i]))
elseif T <: Vector
push!(unpacked, :(x[$ind:$ind+length(vals[$i])-1]))
ind = :($ind + length(vals[$i]))
@@ -38,6 +47,15 @@ pack(vals...) = reduce(vcat, vectorize.(vals))
return ($(unpacked...),)
end
end
function unpack_vec_of_mats(x, val)
ind = 1
return map(1:length(val)) do i
out = reshape(x[ind : ind + length(val[i]) - 1], size(val[i]))
ind += length(val[i])
out
end
end

function get_function(dist::DistSpec, inds, val)
syms = []
args = []
@@ -55,7 +73,14 @@ function get_function(dist::DistSpec, inds, val)
push!(syms, sym)
expr = quote
($(syms...),) -> begin
temp = logpdf($(dist.name)($(args...)), $(sym))
temp_args = ($(args...),)
temp_dist = $(dist.name)(temp_args...)
temp_x = $(sym)
if temp_dist isa UnivariateDistribution && temp_x isa AbstractArray
temp = logpdf.(temp_dist, temp_x)
else
temp = logpdf(temp_dist, temp_x)
end
if temp isa AbstractVector
return sum(temp)
else
@@ -74,7 +99,14 @@ function get_function(dist::DistSpec, inds, val)
@assert length(inds) > 0
expr = quote
($(syms...),) -> begin
temp = logpdf($(dist.name)($(args...)), $(dist.x))
temp_args = ($(args...),)
temp_dist = $(dist.name)(temp_args...)
temp_x = $(dist.x)
if temp_dist isa UnivariateDistribution && temp_x isa AbstractArray
temp = logpdf.(temp_dist, temp_x)
else
temp = logpdf(temp_dist, temp_x)
end
if temp isa AbstractVector
return sum(temp)
else

0 comments on commit 45ccbab

Please sign in to comment.