Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Update to Distributions 0.24.12 #150

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DistributionsAD"
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.16"
version = "0.6.17"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -28,8 +28,8 @@ ChainRules = "0.7"
ChainRulesCore = "0.9.9"
Compat = "3.6"
DiffRules = "0.1, 1.0"
Distributions = "0.23.3, 0.24"
FillArrays = "0.8, 0.9, 0.10"
Distributions = "0.24.12"
FillArrays = "0.8, 0.9, 0.10, 0.11"
ForwardDiff = "0.10.6"
NaNMath = "0.3"
PDMats = "0.9, 0.10"
Expand Down
15 changes: 1 addition & 14 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import StatsFuns: logsumexp,
nbetalogpdf
import Distributions: MvNormal,
MvLogNormal,
poissonbinomial_pdf_fft,
logpdf,
quantile,
PoissonBinomial,
Expand All @@ -46,9 +45,6 @@ export TuringScalMvNormal,
arraydist,
filldist

# check if Distributions >= 0.24 by checking if a generic implementation of `pdf` is defined
const DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF = hasmethod(pdf, Tuple{UnivariateDistribution,Real})

include("common.jl")
include("arraydist.jl")
include("filldist.jl")
Expand All @@ -66,7 +62,7 @@ include("zygote.jl")
using .ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here
include("forwarddiff.jl")

# loads adjoint for `poissonbinomial_pdf` and `poissonbinomial_pdf_fft`
# loads adjoint for `poissonbinomial_pdf`
include("zygote_forwarddiff.jl")
end

Expand Down Expand Up @@ -99,15 +95,6 @@ include("zygote.jl")
return sum(copy(logpdf.(dist.v, x)))
end

function Distributions.logpdf(
dist::LazyVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
size(x, 1) == length(dist) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
return vec(sum(copy(logpdf.(dists, x)), dims = 1))
end

const LazyMatrixOfUnivariate{
S<:ValueSupport,
T<:UnivariateDistribution{S},
Expand Down
23 changes: 1 addition & 22 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,7 @@

const VectorOfUnivariate = Distributions.Product

function arraydist(dists::AbstractVector{<:UnivariateDistribution})
return Product(dists)
end

function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real})
size(x, 1) == length(dist) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
# `eachcol` breaks Zygote, so we use `view` directly
return map(i -> sum(map(logpdf, dist.v, view(x, :, i))), axes(x, 2))
end
arraydist(dists::AbstractVector{<:UnivariateDistribution}) = Product(dists)

struct MatrixOfUnivariate{
S <: ValueSupport,
Expand All @@ -29,12 +20,6 @@ function Distributions._logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Rea
# Broadcasting here breaks Tracker for some reason
return sum(map(logpdf, dist.dists, x))
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return map(x -> logpdf(dist, x), x)
end
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractArray{<:Matrix{<:Real}})
return map(x -> logpdf(dist, x), x)
end

function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate)
return rand.(Ref(rng), dist.dists)
Expand All @@ -59,12 +44,6 @@ function Distributions._logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:R
# `eachcol` breaks Zygote, so we use `view` directly
return sum(i -> logpdf(dist.dists[i], view(x, :, i)), axes(x, 2))
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return map(x -> logpdf(dist, x), x)
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
return map(x -> logpdf(dist, x), x)
end

function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
init = reshape(rand(rng, dist.dists[1]), :, 1)
Expand Down
16 changes: 0 additions & 16 deletions src/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,3 @@ function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::Real, k::Int) where {T}
Δ_r = ForwardDiff.partials(r) * _nbinomlogpdf_grad_1(val_r, p, k)
return FD(nbinomlogpdf(val_r, p, k), Δ_r)
end

## ForwardDiff broadcasting support ##
# If we use Distributions >= 0.24, then `DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF` is `true`.
# In Distributions 0.24 `logpdf` is defined for inputs of type `Real` which are then
# converted to the support of the distributions (such as integers) in their concrete implementations.
# Thus it is no needed to have a special function for dual numbers that performs the conversion
# (and actually this method leads to method ambiguity errors since even discrete distributions now
# define logpdf(::MyDistribution, ::Real), see, e.g.,
# JuliaStats/Distributions.jl@ae2d6c5/src/univariate/discrete/binomial.jl#L119).
if !DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF
@eval begin
function Distributions.logpdf(d::DiscreteUnivariateDistribution, k::ForwardDiff.Dual)
return logpdf(d, convert(Integer, ForwardDiff.value(k)))
end
end
end
21 changes: 0 additions & 21 deletions src/matrixvariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,24 +214,3 @@ function Distributions._rand!(rng::AbstractRNG, d::TuringInverseWishart, A::Abst
X = Distributions._rand!(rng, TuringWishart(d.df, inv(cholesky(d.S))), A)
A .= inv(cholesky!(X))
end

# Only needed in Distributions < 0.24
if !DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF
for T in (:MatrixBeta, :MatrixNormal, :Wishart, :InverseWishart,
:TuringWishart, :TuringInverseWishart,
:VectorOfMultivariate, :MatrixOfUnivariate)
@eval begin
Distributions.loglikelihood(d::$T, X::AbstractMatrix{<:Real}) = logpdf(d, X)
function Distributions.loglikelihood(d::$T, X::AbstractArray{<:Real,3})
(size(X, 1), size(X, 2)) == size(d) || throw(DimensionMismatch("Inconsistent array dimensions."))
return sum(i -> _logpdf(d, view(X, :, :, i)), axes(X, 3))
end
function Distributions.loglikelihood(
d::$T,
X::AbstractArray{<:AbstractMatrix{<:Real}},
)
return sum(x -> logpdf(d, x), X)
end
end
end
end
85 changes: 0 additions & 85 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
@@ -1,88 +1,3 @@
## Dirichlet ##

struct TuringDirichlet{T, TV <: AbstractVector} <: ContinuousMultivariateDistribution
alpha::TV
alpha0::T
lmnB::T
end
Base.length(d::TuringDirichlet) = length(d.alpha)
function check(alpha)
all(ai -> ai > 0, alpha) ||
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
end

function Distributions._rand!(rng::Random.AbstractRNG,
d::TuringDirichlet,
x::AbstractVector{<:Real})
s = 0.0
n = length(x)
α = d.alpha
for i in 1:n
@inbounds s += (x[i] = rand(rng, Gamma(α[i])))
end
Distributions.multiply!(x, inv(s)) # this returns x
end

function TuringDirichlet(alpha::AbstractVector)
check(alpha)
alpha0 = sum(alpha)
lmnB = sum(loggamma, alpha) - loggamma(alpha0)
T = promote_type(typeof(alpha0), typeof(lmnB))
TV = typeof(alpha)
TuringDirichlet{T, TV}(alpha, alpha0, lmnB)
end

function TuringDirichlet(d::Integer, alpha::Real)
alpha0 = alpha * d
_alpha = fill(alpha, d)
lmnB = loggamma(alpha) * d - loggamma(alpha0)
T = promote_type(typeof(alpha0), typeof(lmnB))
TV = typeof(_alpha)
TuringDirichlet{T, TV}(_alpha, alpha0, lmnB)
end
function TuringDirichlet(alpha::AbstractVector{T}) where {T <: Integer}
TuringDirichlet(float.(alpha))
end
TuringDirichlet(d::Integer, alpha::Integer) = TuringDirichlet(d, Float64(alpha))

Distributions.Dirichlet(alpha::AbstractVector) = TuringDirichlet(alpha)

function Distributions._logpdf(d::TuringDirichlet, x::AbstractVector{<:Real})
return simplex_logpdf(d.alpha, d.lmnB, x)
end
function Distributions.logpdf(d::TuringDirichlet, x::AbstractMatrix{<:Real})
size(x, 1) == length(d) ||
throw(DimensionMismatch("Inconsistent array dimensions."))
return simplex_logpdf(d.alpha, d.lmnB, x)
end

ZygoteRules.@adjoint function Distributions.Dirichlet(alpha)
return ZygoteRules.pullback(TuringDirichlet, alpha)
end
ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha)
return ZygoteRules.pullback(TuringDirichlet, d, alpha)
end

function simplex_logpdf(alpha, lmnB, x::AbstractVector)
sum((alpha .- 1) .* log.(x)) - lmnB
end
function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
@views init = vcat(sum((alpha .- 1) .* log.(x[:,1])) - lmnB)
mapreduce(vcat, drop(eachcol(x), 1); init = init) do c
sum((alpha .- 1) .* log.(c)) - lmnB
end
end

ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector)
simplex_logpdf(alpha, lmnB, x), Δ -> (Δ .* log.(x), -Δ, Δ .* (alpha .- 1) ./ x)
end

ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
simplex_logpdf(alpha, lmnB, x), Δ -> begin
(log.(x) * Δ, -sum(Δ), ((alpha .- 1) ./ x) * Diagonal(Δ))
end
end

## MvNormal ##

"""
Expand Down
35 changes: 0 additions & 35 deletions src/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import Distributions: logpdf,
Gamma,
MvNormal,
MvLogNormal,
Dirichlet,
Wishart,
InverseWishart,
PoissonBinomial,
Expand All @@ -44,7 +43,6 @@ using ..DistributionsAD: TuringPoissonBinomial,
TuringMvLogNormal,
TuringWishart,
TuringInverseWishart,
TuringDirichlet,
TuringScalMvNormal,
TuringDiagMvNormal,
TuringDenseMvNormal
Expand Down Expand Up @@ -240,39 +238,6 @@ end
# zero mean,, constant variance
MvLogNormal(d::Int, σ::TrackedReal) = TuringMvLogNormal(TuringMvNormal(d, σ))

Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha)
Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)

for func_header in [
:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::AbstractVector)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractVector)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::TrackedVector)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::AbstractVector)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::TrackedVector)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::TrackedVector)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::TrackedVector)),

:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::AbstractMatrix)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractMatrix)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::TrackedMatrix)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::AbstractMatrix)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::TrackedMatrix)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::TrackedMatrix)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::TrackedMatrix)),
]
@eval $func_header = track(simplex_logpdf, alpha, lmnB, x)
end
@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector)
simplex_logpdf(value(alpha), value(lmnB), value(x)), Δ -> begin
(Δ .* log.(value(x)), -Δ, Δ .* (value(alpha) .- 1))
end
end
@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
simplex_logpdf(value(alpha), value(lmnB), value(x)), Δ -> begin
(log.(value(x)) * Δ, -sum(Δ), repeat(value(alpha) .- 1, 1, size(x, 2)) * Diagonal(Δ))
end
end

Distributions.Wishart(df::TrackedReal, S::Matrix{<:Real}) = TuringWishart(df, S)
Distributions.Wishart(df::TrackedReal, S::AbstractMatrix{<:Real}) = TuringWishart(df, S)
Distributions.Wishart(df::Real, S::AbstractMatrix{<:TrackedReal}) = TuringWishart(df, S)
Expand Down
Loading