diff --git a/Project.toml b/Project.toml index b5357221..36e54d04 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.16" +version = "0.6.17" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/multivariate.jl b/src/multivariate.jl index c41734f1..0f388ebb 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -63,14 +63,9 @@ ZygoteRules.@adjoint function Distributions.Dirichlet(d, alpha) return ZygoteRules.pullback(TuringDirichlet, d, alpha) end -function simplex_logpdf(alpha, lmnB, x::AbstractVector) - sum((alpha .- 1) .* log.(x)) - lmnB -end +simplex_logpdf(alpha, lmnB, x::AbstractVector) = sum(xlogy.(alpha .- 1, x)) - lmnB function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) - @views init = vcat(sum((alpha .- 1) .* log.(x[:,1])) - lmnB) - mapreduce(vcat, drop(eachcol(x), 1); init = init) do c - sum((alpha .- 1) .* log.(c)) - lmnB - end + return vec(sum(xlogy.(alpha .- 1, x); dims=1)) .- lmnB end ZygoteRules.@adjoint function simplex_logpdf(alpha, lmnB, x::AbstractVector) diff --git a/src/reversediff.jl b/src/reversediff.jl index 6ced110b..a4c9f42b 100644 --- a/src/reversediff.jl +++ b/src/reversediff.jl @@ -18,7 +18,8 @@ using ..DistributionsAD: DistributionsAD import SpecialFunctions, NaNMath -import ..DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf, adapt_randn +import ..DistributionsAD: turing_chol, symm_turing_chol, _mv_categorical_logpdf, adapt_randn, + simplex_logpdf import Base.Broadcast: materialize import StatsFuns: logsumexp @@ -47,12 +48,25 @@ using ..DistributionsAD: TuringPoissonBinomial, TuringDirichlet, TuringScalMvNormal, TuringDiagMvNormal, - TuringDenseMvNormal + TuringDenseMvNormal, + VectorOfMultivariate, + FillVectorOfMultivariate include("reversediffx.jl") adapt_randn(rng::Random.AbstractRNG, x::TrackedArray, dims...) = adapt_randn(rng, value(x), dims...) +# without this definition tests of `VectorOfMultivariate` with `Dirichlet` fail +# upstream bug caused by `view` + `track`: https://github.com/JuliaDiff/ReverseDiff.jl/pull/164 +function _logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:TrackedReal}) + return sum(i -> _logpdf(dist.dists[i], x[:, i]), axes(x, 2)) +end + +# fix method ambiguity +function _logpdf(dist::FillVectorOfMultivariate, x::AbstractMatrix{<:TrackedReal}) + return loglikelihood(dist.dists.value, x) +end + function PoissonBinomial(p::TrackedArray{<:Real}; check_args=true) return TuringPoissonBinomial(p; check_args = check_args) end @@ -240,36 +254,60 @@ end # zero mean,, constant variance MvLogNormal(d::Int, σ::TrackedReal) = TuringMvLogNormal(TuringMvNormal(d, σ)) -Dirichlet(alpha::TrackedVector) = TuringDirichlet(alpha) +# Dirichlet + +Dirichlet(alpha::AbstractVector{<:TrackedReal}) = TuringDirichlet(alpha) Dirichlet(d::Integer, alpha::TrackedReal) = TuringDirichlet(d, alpha) +function _logpdf(d::Dirichlet, x::AbstractVector{<:TrackedReal}) + return _logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) +end +function logpdf(d::Dirichlet, x::AbstractMatrix{<:TrackedReal}) + return logpdf(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) +end +function loglikelihood(d::Dirichlet, x::AbstractMatrix{<:TrackedReal}) + return loglikelihood(TuringDirichlet(d.alpha, d.alpha0, d.lmnB), x) +end + +# default definition of `loglikelihood` yields gradients of zero?! +# upstream bug caused by `view` + `track`: https://github.com/JuliaDiff/ReverseDiff.jl/pull/164 +function loglikelihood(d::TuringDirichlet, x::AbstractMatrix{<:TrackedReal}) + return sum(i -> logpdf(d, x[:, i]), axes(x, 2)) +end + for func_header in [ - :(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::AbstractVector)), + :(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractVector)), :(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractVector)), - :(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::TrackedVector)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::AbstractVector)), - :(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::TrackedVector)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::TrackedVector)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::TrackedVector)), + :(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::AbstractVector{<:TrackedReal})), + :(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractVector)), + :(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractVector{<:TrackedReal})), + :(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractVector{<:TrackedReal})), + :(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractVector{<:TrackedReal})), - :(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::AbstractMatrix)), + :(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractMatrix)), :(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractMatrix)), - :(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::TrackedMatrix)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::AbstractMatrix)), - :(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::TrackedMatrix)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::Real, x::TrackedMatrix)), - :(simplex_logpdf(alpha::TrackedVector, lmnB::TrackedReal, x::TrackedMatrix)), + :(simplex_logpdf(alpha::AbstractVector, lmnB::Real, x::AbstractMatrix{<:TrackedReal})), + :(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractMatrix)), + :(simplex_logpdf(alpha::AbstractVector, lmnB::TrackedReal, x::AbstractMatrix{<:TrackedReal})), + :(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::Real, x::AbstractMatrix{<:TrackedReal})), + :(simplex_logpdf(alpha::AbstractVector{<:TrackedReal}, lmnB::TrackedReal, x::AbstractMatrix{<:TrackedReal})), ] @eval $func_header = track(simplex_logpdf, alpha, lmnB, x) end @grad function simplex_logpdf(alpha, lmnB, x::AbstractVector) - simplex_logpdf(value(alpha), value(lmnB), value(x)), Δ -> begin - (Δ .* log.(value(x)), -Δ, Δ .* (value(alpha) .- 1)) + _alpha = value(alpha) + _lmnB = value(lmnB) + _x = value(x) + simplex_logpdf(_alpha, _lmnB, _x), Δ -> begin + (Δ .* log.(_x), -Δ, Δ .* (_alpha .- 1) ./ _x) end end @grad function simplex_logpdf(alpha, lmnB, x::AbstractMatrix) - simplex_logpdf(value(alpha), value(lmnB), value(x)), Δ -> begin - (log.(value(x)) * Δ, -sum(Δ), repeat(value(alpha) .- 1, 1, size(x, 2)) * Diagonal(Δ)) + _alpha = value(alpha) + _lmnB = value(lmnB) + _x = value(x) + simplex_logpdf(_alpha, _lmnB, _x), Δ -> begin + (log.(_x) * Δ, -sum(Δ), ((_alpha .- 1) ./ _x) * Diagonal(Δ)) end end diff --git a/test/Project.toml b/test/Project.toml index 5271e9cf..d1bbbba4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -21,7 +21,7 @@ Combinatorics = "1.0.2" Distributions = "0.24.3" FiniteDifferences = "0.11.3, 0.12" ForwardDiff = "0.10.12" -NNlib = "0.7.7" +NNlib = "0.7.10" PDMats = "0.10.1" ReverseDiff = "1.4.4" StatsBase = "0.33.2" diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index a543a5dd..90b541c6 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -26,10 +26,6 @@ to_positive(x) = exp.(x) to_positive(x::AbstractArray{<:AbstractArray}) = to_positive.(x) - # Create vectors in probability simplex. - to_simplex(x::AbstractArray; dims=1) = NNlib.softmax(x; dims=dims) - to_simplex(x::AbstractArray{<:AbstractArray}; dims=1) = to_simplex.(x; dims=dims) - # Tests that have a `broken` field can be executed but, according to FiniteDifferences, # fail to produce the correct result. These tests can be checked with `@test_broken`. univariate_distributions = DistSpec[ diff --git a/test/runtests.jl b/test/runtests.jl index 9f2d7773..90e91560 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,11 +45,32 @@ if GROUP == "All" || GROUP == "AD" to_posdef(A::AbstractMatrix) = A * A' + I to_posdef_diagonal(a::AbstractVector) = Diagonal(a.^2 .+ 1) + # Create vectors in probability simplex. + to_simplex(x::AbstractArray) = NNlib.softmax(x; dims=1) + to_simplex(x::AbstractArray{<:AbstractArray}) = to_simplex.(x) + + if AD == "All" || AD == "ReverseDiff" + @eval begin + # Define adjoint for ReverseDiff + function to_simplex(x::AbstractArray{<:ReverseDiff.TrackedReal}) + return ReverseDiff.track(to_simplex, x) + end + ReverseDiff.@grad function to_simplex(x) + _x = ReverseDiff.value(x) + y = to_simplex(_x) + function pullback(∇) + return (NNlib.∇softmax(∇, _x, y; dims=1),) + end + return y, pullback + end + end + end + if AD == "All" || AD == "Tracker" @eval begin # Define adjoints for Tracker - to_posdef(A::TrackedMatrix) = Tracker.track(to_posdef, A) - Tracker.@grad function to_posdef(A::TrackedMatrix) + to_posdef(A::Tracker.TrackedMatrix) = Tracker.track(to_posdef, A) + Tracker.@grad function to_posdef(A::Tracker.TrackedMatrix) data_A = Tracker.data(A) S = data_A * data_A' + I function pullback(∇)