Skip to content

Commit

Permalink
Fix and test TuringDirichlet constructors (#152)
Browse files Browse the repository at this point in the history
* Fix and test `TuringDirichlet` constructors

* Fix typo

* Import `TuringDirichlet`
  • Loading branch information
devmotion authored Feb 3, 2021
1 parent d6aaa64 commit 9806ec3
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DistributionsAD"
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.18"
version = "0.6.19"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
72 changes: 37 additions & 35 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
@@ -1,52 +1,54 @@
## Dirichlet ##

struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution
struct TuringDirichlet{T<:Real,TV<:AbstractVector,S<:Real} <: ContinuousMultivariateDistribution
alpha::TV
alpha0::T
lmnB::T
end
Base.length(d::TuringDirichlet) = length(d.alpha)
function check(alpha)
all(ai -> ai > 0, alpha) ||
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
end

function Distributions._rand!(rng::Random.AbstractRNG,
d::TuringDirichlet,
x::AbstractVector{<:Real})
s = 0.0
n = length(x)
α = d.alpha
for i in 1:n
@inbounds s += (x[i] = rand(rng, Gamma(α[i])))
end
Distributions.multiply!(x, inv(s)) # this returns x
lmnB::S
end

function TuringDirichlet(alpha::AbstractVector)
check(alpha)
all(ai -> ai > 0, alpha) ||
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))

alpha0 = sum(alpha)
lmnB = sum(loggamma, alpha) - loggamma(alpha0)
T = promote_type(typeof(alpha0), typeof(lmnB))
TV = typeof(alpha)
TuringDirichlet{T, TV}(alpha, alpha0, lmnB)
end

function TuringDirichlet(d::Integer, alpha::Real)
alpha0 = alpha * d
_alpha = fill(alpha, d)
lmnB = loggamma(alpha) * d - loggamma(alpha0)
T = promote_type(typeof(alpha0), typeof(lmnB))
TV = typeof(_alpha)
TuringDirichlet{T, TV}(_alpha, alpha0, lmnB)
end
function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer}
TuringDirichlet(float.(alpha))
return TuringDirichlet(alpha, alpha0, lmnB)
end
TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha))
TuringDirichlet(d::Integer, alpha::Real) = TuringDirichlet(Fill(alpha, d))

# TODO: remove?
TuringDirichlet(alpha::AbstractVector{<:Integer}) = TuringDirichlet(float.(alpha))
TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, float(alpha))

# TODO: remove and use `Dirichlet` only for `Tracker.TrackedVector`
Distributions.Dirichlet(alpha::AbstractVector) = TuringDirichlet(alpha)

TuringDirichlet(d::Dirichlet) = TuringDirichlet(d.alpha, d.alpha0, d.lmnB)

Base.length(d::TuringDirichlet) = length(d.alpha)

# copied from Distributions
# TODO: remove and use `Dirichlet`?
function Distributions._rand!(
rng::Random.AbstractRNG,
d::TuringDirichlet,
x::AbstractVector{<:Real},
)
@inbounds for (i, αi) in zip(eachindex(x), d.alpha)
x[i] = rand(rng, Gamma(αi))
end
Distributions.multiply!(x, inv(sum(x))) # this returns x
end
function Distributions._rand!(
rng::AbstractRNG,
d::TuringDirichlet{<:Real,<:FillArrays.AbstractFill},
x::AbstractVector{<:Real}
)
rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x)
Distributions.multiply!(x, inv(sum(x))) # this returns x
end

function Distributions._logpdf(d::TuringDirichlet, x::AbstractVector{<:Real})
return simplex_logpdf(d.alpha, d.lmnB, x)
end
Expand Down
6 changes: 3 additions & 3 deletions src/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,13 @@ Dirichlet(alpha::AbstractVector{<:TrackedReal}) = TuringDirichlet(alpha)
Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)

function _logpdf(d::Dirichlet, x::AbstractVector{<:TrackedReal})
return _logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
return _logpdf(TuringDirichlet(d), x)
end
function logpdf(d::Dirichlet, x::AbstractMatrix{<:TrackedReal})
return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
return logpdf(TuringDirichlet(d), x)
end
function loglikelihood(d::Dirichlet, x::AbstractMatrix{<:TrackedReal})
return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
return loglikelihood(TuringDirichlet(d), x)
end

# default definition of `loglikelihood` yields gradients of zero?!
Expand Down
7 changes: 3 additions & 4 deletions src/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,13 +371,13 @@ Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha)
Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)

function Distributions._logpdf(d::Dirichlet, x::TrackedVector{<:Real})
return Distributions._logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
return Distributions._logpdf(TuringDirichlet(d), x)
end
function Distributions.logpdf(d::Dirichlet, x::TrackedMatrix{<:Real})
return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
return logpdf(TuringDirichlet(d), x)
end
function Distributions.loglikelihood(d::Dirichlet, x::TrackedMatrix{<:Real})
return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
return loglikelihood(TuringDirichlet(d), x)
end

# Fix ambiguities
Expand Down Expand Up @@ -615,4 +615,3 @@ Distributions.InverseWishart(df::TrackedReal, S::AbstractMatrix{<:Real}) = Turin
Distributions.InverseWishart(df::Real, S::TrackedMatrix) = TuringInverseWishart(df, S)
Distributions.InverseWishart(df::TrackedReal, S::TrackedMatrix) = TuringInverseWishart(df, S)
Distributions.InverseWishart(df::TrackedReal, S::AbstractPDMat{<:TrackedReal}) = TuringInverseWishart(df, S)

38 changes: 38 additions & 0 deletions test/others.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,4 +298,42 @@
end
end
end

@testset "TuringDirichlet" begin
dim = 3
n = 4
for alpha in (2, rand())
d1 = TuringDirichlet(dim, alpha)
d2 = Dirichlet(dim, alpha)
d3 = TuringDirichlet(d2)
@test d1.alpha == d2.alpha == d3.alpha
@test d1.alpha0 == d2.alpha0 == d3.alpha0
@test d1.lmnB == d2.lmnB == d3.lmnB

s1 = rand(d1)
@test s1 isa Vector{Float64}
@test length(s1) == dim

s2 = rand(d1, n)
@test s2 isa Matrix{Float64}
@test size(s2) == (dim, n)
end

for alpha in (ones(Int, dim), rand(dim))
d1 = TuringDirichlet(alpha)
d2 = Dirichlet(alpha)
d3 = TuringDirichlet(d2)
@test d1.alpha == d2.alpha == d3.alpha
@test d1.alpha0 == d2.alpha0 == d3.alpha0
@test d1.lmnB == d2.lmnB == d3.lmnB

s1 = rand(d1)
@test s1 isa Vector{Float64}
@test length(s1) == dim

s2 = rand(d1, n)
@test s2 isa Matrix{Float64}
@test size(s2) == (dim, n)
end
end
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using Random, LinearAlgebra, Test

using Distributions: meanlogdet
using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal,
TuringPoissonBinomial
TuringPoissonBinomial, TuringDirichlet
using StatsBase: entropy
using StatsFuns: binomlogpdf, logsumexp, logistic

Expand Down

2 comments on commit 9806ec3

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/29317

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.19 -m "<description of version>" 9806ec30af8b4aa3731e3ee358c7cae27fa01f20
git push origin v0.6.19

Please sign in to comment.