Skip to content

Commit

Permalink
Fix Dirichlet with ReverseDiff (#151)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jan 29, 2021
1 parent 2a6622c commit 3053aae
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 34 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DistributionsAD"
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.16"
version = "0.6.17"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
9 changes: 2 additions & 7 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,9 @@ ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha)
return ZygoteRules.pullback(TuringDirichlet, d, alpha)
end

function simplex_logpdf(alpha, lmnB, x::AbstractVector)
sum((alpha .- 1) .* log.(x)) - lmnB
end
simplex_logpdf(alpha, lmnB, x::AbstractVector) = sum(xlogy.(alpha .- 1, x)) - lmnB
function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
@views init = vcat(sum((alpha .- 1) .* log.(x[:,1])) - lmnB)
mapreduce(vcat, drop(eachcol(x), 1); init = init) do c
sum((alpha .- 1) .* log.(c)) - lmnB
end
return vec(sum(xlogy.(alpha .- 1, x); dims=1)) .- lmnB
end

ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector)
Expand Down
76 changes: 57 additions & 19 deletions src/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ using ..DistributionsAD: DistributionsAD


import SpecialFunctions, NaNMath
import ..DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf, adapt_randn
import ..DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf, adapt_randn,
simplex_logpdf
import Base.Broadcast: materialize
import StatsFuns: logsumexp

Expand Down Expand Up @@ -47,12 +48,25 @@ using ..DistributionsAD: TuringPoissonBinomial,
TuringDirichlet,
TuringScalMvNormal,
TuringDiagMvNormal,
TuringDenseMvNormal
TuringDenseMvNormal,
VectorOfMultivariate,
FillVectorOfMultivariate

include("reversediffx.jl")

adapt_randn(rng::Random.AbstractRNG, x::TrackedArray, dims...) = adapt_randn(rng, value(x), dims...)

# without this definition tests of `VectorOfMultivariate` with `Dirichlet` fail
# upstream bug caused by `view` + `track`: https://github.com/JuliaDiff/ReverseDiff.jl/pull/164
function _logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:TrackedReal})
return sum(i -> _logpdf(dist.dists[i], x[:, i]), axes(x, 2))
end

# fix method ambiguity
function _logpdf(dist::FillVectorOfMultivariate, x::AbstractMatrix{<:TrackedReal})
return loglikelihood(dist.dists.value, x)
end

function PoissonBinomial(p::TrackedArray{<:Real}; check_args=true)
return TuringPoissonBinomial(p; check_args = check_args)
end
Expand Down Expand Up @@ -240,36 +254,60 @@ end
# zero mean,, constant variance
MvLogNormal(d::Int, σ::TrackedReal) = TuringMvLogNormal(TuringMvNormal(d, σ))

Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha)
# Dirichlet

Dirichlet(alpha::AbstractVector{<:TrackedReal}) = TuringDirichlet(alpha)
Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha)

function _logpdf(d::Dirichlet, x::AbstractVector{<:TrackedReal})
return _logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
end
function logpdf(d::Dirichlet, x::AbstractMatrix{<:TrackedReal})
return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
end
function loglikelihood(d::Dirichlet, x::AbstractMatrix{<:TrackedReal})
return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x)
end

# default definition of `loglikelihood` yields gradients of zero?!
# upstream bug caused by `view` + `track`: https://github.com/JuliaDiff/ReverseDiff.jl/pull/164
function loglikelihood(d::TuringDirichlet, x::AbstractMatrix{<:TrackedReal})
return sum(i -> logpdf(d, x[:, i]), axes(x, 2))
end

for func_header in [
:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::AbstractVector)),
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractVector)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractVector)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::TrackedVector)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::AbstractVector)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::TrackedVector)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::TrackedVector)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::TrackedVector)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::AbstractVector{<:TrackedReal})),
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractVector)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractVector{<:TrackedReal})),
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractVector{<:TrackedReal})),
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractVector{<:TrackedReal})),

:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::AbstractMatrix)),
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractMatrix)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractMatrix)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::TrackedMatrix)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::AbstractMatrix)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::TrackedMatrix)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::TrackedMatrix)),
:(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::TrackedMatrix)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::AbstractMatrix{<:TrackedReal})),
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractMatrix)),
:(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractMatrix{<:TrackedReal})),
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractMatrix{<:TrackedReal})),
:(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractMatrix{<:TrackedReal})),
]
@eval $func_header = track(simplex_logpdf, alpha, lmnB, x)
end
@grad function simplex_logpdf(alpha, lmnB, x::AbstractVector)
simplex_logpdf(value(alpha), value(lmnB), value(x)), Δ -> begin
.* log.(value(x)), -Δ, Δ .* (value(alpha) .- 1))
_alpha = value(alpha)
_lmnB = value(lmnB)
_x = value(x)
simplex_logpdf(_alpha, _lmnB, _x), Δ -> begin
.* log.(_x), -Δ, Δ .* (_alpha .- 1) ./ _x)
end
end
@grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix)
simplex_logpdf(value(alpha), value(lmnB), value(x)), Δ -> begin
(log.(value(x)) * Δ, -sum(Δ), repeat(value(alpha) .- 1, 1, size(x, 2)) * Diagonal(Δ))
_alpha = value(alpha)
_lmnB = value(lmnB)
_x = value(x)
simplex_logpdf(_alpha, _lmnB, _x), Δ -> begin
(log.(_x) * Δ, -sum(Δ), ((_alpha .- 1) ./ _x) * Diagonal(Δ))
end
end

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Combinatorics = "1.0.2"
Distributions = "0.24.3"
FiniteDifferences = "0.11.3, 0.12"
ForwardDiff = "0.10.12"
NNlib = "0.7.7"
NNlib = "0.7.10"
PDMats = "0.10.1"
ReverseDiff = "1.4.4"
StatsBase = "0.33.2"
Expand Down
4 changes: 0 additions & 4 deletions test/ad/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
to_positive(x) = exp.(x)
to_positive(x::AbstractArray{<:AbstractArray}) = to_positive.(x)

# Create vectors in probability simplex.
to_simplex(x::AbstractArray; dims=1) = NNlib.softmax(x; dims=dims)
to_simplex(x::AbstractArray{<:AbstractArray}; dims=1) = to_simplex.(x; dims=dims)

# Tests that have a `broken` field can be executed but, according to FiniteDifferences,
# fail to produce the correct result. These tests can be checked with `@test_broken`.
univariate_distributions = DistSpec[
Expand Down
25 changes: 23 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,32 @@ if GROUP == "All" || GROUP == "AD"
to_posdef(A::AbstractMatrix) = A * A' + I
to_posdef_diagonal(a::AbstractVector) = Diagonal(a.^2 .+ 1)

# Create vectors in probability simplex.
to_simplex(x::AbstractArray) = NNlib.softmax(x; dims=1)
to_simplex(x::AbstractArray{<:AbstractArray}) = to_simplex.(x)

if AD == "All" || AD == "ReverseDiff"
@eval begin
# Define adjoint for ReverseDiff
function to_simplex(x::AbstractArray{<:ReverseDiff.TrackedReal})
return ReverseDiff.track(to_simplex, x)
end
ReverseDiff.@grad function to_simplex(x)
_x = ReverseDiff.value(x)
y = to_simplex(_x)
function pullback(∇)
return (NNlib.∇softmax(∇, _x, y; dims=1),)
end
return y, pullback
end
end
end

if AD == "All" || AD == "Tracker"
@eval begin
# Define adjoints for Tracker
to_posdef(A::TrackedMatrix) = Tracker.track(to_posdef, A)
Tracker.@grad function to_posdef(A::TrackedMatrix)
to_posdef(A::Tracker.TrackedMatrix) = Tracker.track(to_posdef, A)
Tracker.@grad function to_posdef(A::Tracker.TrackedMatrix)
data_A = Tracker.data(A)
S = data_A * data_A' + I
function pullback(∇)
Expand Down

2 comments on commit 3053aae

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/28959

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.17 -m "<description of version>" 3053aae1c6b8c49d7bc1b48b3b16163b2d533012
git push origin v0.6.17

Please sign in to comment.