Skip to content

Commit

Permalink
Remove some custom adjoints (#104)
Browse files Browse the repository at this point in the history
* Remove some custom adjoints

* Fix logpdf for VectorOfMultivariate
  • Loading branch information
devmotion authored Aug 23, 2020
1 parent 01ad761 commit a96b159
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 34 deletions.
25 changes: 6 additions & 19 deletions src/arraydist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ function summaporbroadcast(f, dists::AbstractArray, x::AbstractArray)
return sum(map(f, dists, x))
end
function summaporbroadcast(f, dists::AbstractVector, x::AbstractMatrix)
return map(x -> summaporbroadcast(f, dists, x), eachcol(x))
# `eachcol` breaks Zygote, so we use `view` directly
return map(i -> summaporbroadcast(f, dists, view(x, :, i)), axes(x, 2))
end
@init @require LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" begin
function summaporbroadcast(f, dists::LazyArrays.BroadcastArray, x::AbstractArray)
Expand Down Expand Up @@ -33,14 +34,6 @@ function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real
# eachcol breaks Zygote, so we need an adjoint
return summaporbroadcast(logpdf, dist.v, x)
end
ZygoteRules.@adjoint function Distributions.logpdf(
dist::VectorOfUnivariate,
x::AbstractMatrix{<:Real}
)
# Any other more efficient implementation breaks Zygote
f(dist, x) = [sum(logpdf.(dist.v, view(x, :, i))) for i in 1:size(x, 2)]
return ZygoteRules.pullback(f, dist, x)
end

struct MatrixOfUnivariate{
S <: ValueSupport,
Expand Down Expand Up @@ -80,24 +73,18 @@ Base.length(dist::VectorOfMultivariate) = length(dist.dists)
function arraydist(dists::AbstractVector{<:MultivariateDistribution})
return VectorOfMultivariate(dists)
end

function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real})
# eachcol breaks Zygote, so we define an adjoint
return sum(map(logpdf, dist.dists, eachcol(x)))
# `eachcol` breaks Zygote, so we use `view` directly
return sum(map((d, i) -> logpdf(d, view(x, :, i)), dist.dists, axes(x, 2)))
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:AbstractMatrix{<:Real}})
return map(x -> logpdf(dist, x), x)
end
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractArray{<:Matrix{<:Real}})
return map(x -> logpdf(dist, x), x)
end
ZygoteRules.@adjoint function Distributions.logpdf(
dist::VectorOfMultivariate,
x::AbstractMatrix{<:Real}
)
return ZygoteRules.pullback(dist, x) do dist, x
sum(map(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2)))
end
end

function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate)
init = reshape(rand(rng, dist.dists[1]), :, 1)
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init)
Expand Down
14 changes: 2 additions & 12 deletions src/filldist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,6 @@ function Distributions.logpdf(
)
return _logpdf(dist, x)
end
ZygoteRules.@adjoint function Distributions.logpdf(
dist::FillVectorOfUnivariate,
x::AbstractMatrix{<:Real},
)
return ZygoteRules.pullback(_logpdf, dist, x)
end

function _logpdf(
dist::FillVectorOfUnivariate,
Expand Down Expand Up @@ -103,18 +97,14 @@ function Distributions.logpdf(
)
return _logpdf(dist, x)
end

function _logpdf(
dist::FillVectorOfMultivariate,
x::AbstractMatrix{<:Real},
)
return sum(logpdf(dist.dists.value, x))
end
ZygoteRules.@adjoint function Distributions.logpdf(
dist::FillVectorOfMultivariate,
x::AbstractMatrix{<:Real},
)
return ZygoteRules.pullback(_logpdf, dist, x)
end

function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate)
return rand(rng, dist.dists.value, length.(dist.dists.axes)...,)
end
4 changes: 1 addition & 3 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ function check(alpha)
all(ai -> ai > 0, alpha) ||
throw(ArgumentError("Dirichlet: alpha must be a positive vector."))
end
ZygoteRules.@adjoint function check(alpha)
return check(alpha), _ -> (nothing,)
end

function Distributions._rand!(rng::Random.AbstractRNG,
d::TuringDirichlet,
x::AbstractVector{<:Real})
Expand Down

2 comments on commit a96b159

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/20032

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.4 -m "<description of version>" a96b159ab25aab67d1a2076726e8b9c392eb6fc7
git push origin v0.6.4

Please sign in to comment.