diff --git a/Project.toml b/Project.toml index 0fba3bc1..56a638c5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.18" +version = "0.6.19" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/multivariate.jl b/src/multivariate.jl index 0f388ebb..634c023d 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -1,52 +1,54 @@ ## Dirichlet ## -struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution +struct TuringDirichlet{T<:Real,TV<:AbstractVector,S<:Real} <: ContinuousMultivariateDistribution alpha::TV alpha0::T - lmnB::T -end -Base.length(d::TuringDirichlet) = length(d.alpha) -function check(alpha) - all(ai -> ai > 0, alpha) || - throw(ArgumentError("Dirichlet: alpha must be a positive vector.")) -end - -function Distributions._rand!(rng::Random.AbstractRNG, - d::TuringDirichlet, - x::AbstractVector{<:Real}) - s = 0.0 - n = length(x) - α = d.alpha - for i in 1:n - @inbounds s += (x[i] = rand(rng, Gamma(α[i]))) - end - Distributions.multiply!(x, inv(s)) # this returns x + lmnB::S end function TuringDirichlet(alpha::AbstractVector) - check(alpha) + all(ai -> ai > 0, alpha) || + throw(ArgumentError("Dirichlet: alpha must be a positive vector.")) + alpha0 = sum(alpha) lmnB = sum(loggamma, alpha) - loggamma(alpha0) - T = promote_type(typeof(alpha0), typeof(lmnB)) - TV = typeof(alpha) - TuringDirichlet{T, TV}(alpha, alpha0, lmnB) -end -function TuringDirichlet(d::Integer, alpha::Real) - alpha0 = alpha * d - _alpha = fill(alpha, d) - lmnB = loggamma(alpha) * d - loggamma(alpha0) - T = promote_type(typeof(alpha0), typeof(lmnB)) - TV = typeof(_alpha) - TuringDirichlet{T, TV}(_alpha, alpha0, lmnB) -end -function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer} - TuringDirichlet(float.(alpha)) + return TuringDirichlet(alpha, alpha0, lmnB) end -TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha)) +TuringDirichlet(d::Integer, alpha::Real) = TuringDirichlet(Fill(alpha, d)) +# TODO: remove? +TuringDirichlet(alpha::AbstractVector{<:Integer}) = TuringDirichlet(float.(alpha)) +TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, float(alpha)) + +# TODO: remove and use `Dirichlet` only for `Tracker.TrackedVector` Distributions.Dirichlet(alpha::AbstractVector) = TuringDirichlet(alpha) +TuringDirichlet(d::Dirichlet) = TuringDirichlet(d.alpha, d.alpha0, d.lmnB) + +Base.length(d::TuringDirichlet) = length(d.alpha) + +# copied from Distributions +# TODO: remove and use `Dirichlet`? +function Distributions._rand!( + rng::Random.AbstractRNG, + d::TuringDirichlet, + x::AbstractVector{<:Real}, +) + @inbounds for (i, αi) in zip(eachindex(x), d.alpha) + x[i] = rand(rng, Gamma(αi)) + end + Distributions.multiply!(x, inv(sum(x))) # this returns x +end +function Distributions._rand!( + rng::AbstractRNG, + d::TuringDirichlet{<:Real,<:FillArrays.AbstractFill}, + x::AbstractVector{<:Real} +) + rand!(rng, Gamma(FillArrays.getindex_value(d.alpha)), x) + Distributions.multiply!(x, inv(sum(x))) # this returns x +end + function Distributions._logpdf(d::TuringDirichlet, x::AbstractVector{<:Real}) return simplex_logpdf(d.alpha, d.lmnB, x) end diff --git a/src/reversediff.jl b/src/reversediff.jl index a4c9f42b..6e853d3b 100644 --- a/src/reversediff.jl +++ b/src/reversediff.jl @@ -260,13 +260,13 @@ Dirichlet(alpha::AbstractVector{<:TrackedReal}) = TuringDirichlet(alpha) Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha) function _logpdf(d::Dirichlet, x::AbstractVector{<:TrackedReal}) - return _logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) + return _logpdf(TuringDirichlet(d), x) end function logpdf(d::Dirichlet, x::AbstractMatrix{<:TrackedReal}) - return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) + return logpdf(TuringDirichlet(d), x) end function loglikelihood(d::Dirichlet, x::AbstractMatrix{<:TrackedReal}) - return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) + return loglikelihood(TuringDirichlet(d), x) end # default definition of `loglikelihood` yields gradients of zero?! diff --git a/src/tracker.jl b/src/tracker.jl index f08ec1ad..f4a79067 100644 --- a/src/tracker.jl +++ b/src/tracker.jl @@ -371,13 +371,13 @@ Distributions.Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha) Distributions.Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha) function Distributions._logpdf(d::Dirichlet, x::TrackedVector{<:Real}) - return Distributions._logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) + return Distributions._logpdf(TuringDirichlet(d), x) end function Distributions.logpdf(d::Dirichlet, x::TrackedMatrix{<:Real}) - return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) + return logpdf(TuringDirichlet(d), x) end function Distributions.loglikelihood(d::Dirichlet, x::TrackedMatrix{<:Real}) - return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) + return loglikelihood(TuringDirichlet(d), x) end # Fix ambiguities @@ -615,4 +615,3 @@ Distributions.InverseWishart(df::TrackedReal, S::AbstractMatrix{<:Real}) = Turin Distributions.InverseWishart(df::Real, S::TrackedMatrix) = TuringInverseWishart(df, S) Distributions.InverseWishart(df::TrackedReal, S::TrackedMatrix) = TuringInverseWishart(df, S) Distributions.InverseWishart(df::TrackedReal, S::AbstractPDMat{<:TrackedReal}) = TuringInverseWishart(df, S) - diff --git a/test/others.jl b/test/others.jl index fc6308a4..fc545bfd 100644 --- a/test/others.jl +++ b/test/others.jl @@ -298,4 +298,42 @@ end end end + + @testset "TuringDirichlet" begin + dim = 3 + n = 4 + for alpha in (2, rand()) + d1 = TuringDirichlet(dim, alpha) + d2 = Dirichlet(dim, alpha) + d3 = TuringDirichlet(d2) + @test d1.alpha == d2.alpha == d3.alpha + @test d1.alpha0 == d2.alpha0 == d3.alpha0 + @test d1.lmnB == d2.lmnB == d3.lmnB + + s1 = rand(d1) + @test s1 isa Vector{Float64} + @test length(s1) == dim + + s2 = rand(d1, n) + @test s2 isa Matrix{Float64} + @test size(s2) == (dim, n) + end + + for alpha in (ones(Int, dim), rand(dim)) + d1 = TuringDirichlet(alpha) + d2 = Dirichlet(alpha) + d3 = TuringDirichlet(d2) + @test d1.alpha == d2.alpha == d3.alpha + @test d1.alpha0 == d2.alpha0 == d3.alpha0 + @test d1.lmnB == d2.lmnB == d3.lmnB + + s1 = rand(d1) + @test s1 isa Vector{Float64} + @test length(s1) == dim + + s2 = rand(d1, n) + @test s2 isa Matrix{Float64} + @test size(s2) == (dim, n) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 90e91560..779e5ec1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,7 +10,7 @@ using Random, LinearAlgebra, Test using Distributions: meanlogdet using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal, - TuringPoissonBinomial + TuringPoissonBinomial, TuringDirichlet using StatsBase: entropy using StatsFuns: binomlogpdf, logsumexp, logistic