diff --git a/src/arraydist.jl b/src/arraydist.jl index 8624568e..7c5c29c5 100644 --- a/src/arraydist.jl +++ b/src/arraydist.jl @@ -5,7 +5,8 @@ function summaporbroadcast(f, dists::AbstractArray, x::AbstractArray) return sum(map(f, dists, x)) end function summaporbroadcast(f, dists::AbstractVector, x::AbstractMatrix) - return map(x -> summaporbroadcast(f, dists, x), eachcol(x)) + # `eachcol` breaks Zygote, so we use `view` directly + return map(i -> summaporbroadcast(f, dists, view(x, :, i)), axes(x, 2)) end @init @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin function summaporbroadcast(f, dists::LazyArrays.BroadcastArray, x::AbstractArray) @@ -33,14 +34,6 @@ function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real # eachcol breaks Zygote, so we need an adjoint return summaporbroadcast(logpdf, dist.v, x) end -ZygoteRules.@adjoint function Distributions.logpdf( - dist::VectorOfUnivariate, - x::AbstractMatrix{<:Real} -) - # Any other more efficient implementation breaks Zygote - f(dist, x) = [sum(logpdf.(dist.v, view(x, :, i))) for i in 1:size(x, 2)] - return ZygoteRules.pullback(f, dist, x) -end struct MatrixOfUnivariate{ S <: ValueSupport, @@ -80,9 +73,10 @@ Base.length(dist::VectorOfMultivariate) = length(dist.dists) function arraydist(dists::AbstractVector{<:MultivariateDistribution}) return VectorOfMultivariate(dists) end + function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) - # eachcol breaks Zygote, so we define an adjoint - return sum(map(logpdf, dist.dists, eachcol(x))) + # `eachcol` breaks Zygote, so we use `view` directly + return sum(map((d, i) -> logpdf(d, view(x, :, i)), dist.dists, axes(x, 2))) end function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}}) return map(x -> logpdf(dist, x), x) @@ -90,14 +84,7 @@ end function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}}) return map(x -> logpdf(dist, x), x) end -ZygoteRules.@adjoint function Distributions.logpdf( - dist::VectorOfMultivariate, - x::AbstractMatrix{<:Real} -) - return ZygoteRules.pullback(dist, x) do dist, x - sum(map(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2))) - end -end + function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) init = reshape(rand(rng, dist.dists[1]), :, 1) return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init) diff --git a/src/filldist.jl b/src/filldist.jl index 7dcafb96..4dbd9672 100644 --- a/src/filldist.jl +++ b/src/filldist.jl @@ -23,12 +23,6 @@ function Distributions.logpdf( ) return _logpdf(dist, x) end -ZygoteRules.@adjoint function Distributions.logpdf( - dist::FillVectorOfUnivariate, - x::AbstractMatrix{<:Real}, -) - return ZygoteRules.pullback(_logpdf, dist, x) -end function _logpdf( dist::FillVectorOfUnivariate, @@ -103,18 +97,14 @@ function Distributions.logpdf( ) return _logpdf(dist, x) end + function _logpdf( dist::FillVectorOfMultivariate, x::AbstractMatrix{<:Real}, ) return sum(logpdf(dist.dists.value, x)) end -ZygoteRules.@adjoint function Distributions.logpdf( - dist::FillVectorOfMultivariate, - x::AbstractMatrix{<:Real}, -) - return ZygoteRules.pullback(_logpdf, dist, x) -end + function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate) return rand(rng, dist.dists.value, length.(dist.dists.axes)...,) end diff --git a/src/multivariate.jl b/src/multivariate.jl index 7000cb65..77ba5c23 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -10,9 +10,7 @@ function check(alpha) all(ai -> ai > 0, alpha) || throw(ArgumentError("Dirichlet: alpha must be a positive vector.")) end -ZygoteRules.@adjoint function check(alpha) - return check(alpha), _ -> (nothing,) -end + function Distributions._rand!(rng::Random.AbstractRNG, d::TuringDirichlet, x::AbstractVector{<:Real})