diff --git a/Project.toml b/Project.toml index 60488b66..e429698e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.9" +version = "0.6.10" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -28,7 +28,7 @@ ChainRules = "0.7" ChainRulesCore = "0.9.9" Compat = "3.6" DiffRules = "0.1, 1.0" -Distributions = "0.23.3" +Distributions = "0.23.3, 0.24" FillArrays = "0.8, 0.9" ForwardDiff = "0.10.6" NaNMath = "0.3" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 6e708dc9..c9306485 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -46,6 +46,9 @@ export TuringScalMvNormal, arraydist, filldist +# check if Distributions >= 0.24 by checking if a generic implementation of `pdf` is defined +const DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF = hasmethod(pdf, Tuple{UnivariateDistribution,Real}) + include("common.jl") include("arraydist.jl") include("filldist.jl") @@ -63,7 +66,7 @@ include("zygote.jl") using .ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here include("forwarddiff.jl") - # loads adjoint for `poissonbinomial_pdf_fft` + # loads adjoint for `poissonbinomial_pdf` and `poissonbinomial_pdf_fft` include("zygote_forwarddiff.jl") end diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 3efbbf71..4e591be2 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -60,7 +60,17 @@ function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::Real, k::Int) where {T} end ## ForwardDiff broadcasting support ## - -function Distributions.logpdf(d::DiscreteUnivariateDistribution, k::ForwardDiff.Dual) - return logpdf(d, convert(Integer, ForwardDiff.value(k))) +# If we use Distributions >= 0.24, then `DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF` is `true`. +# In Distributions 0.24 `logpdf` is defined for inputs of type `Real` which are then +# converted to the support of the distributions (such as integers) in their concrete implementations. 
+# Thus it is not needed to have a special function for dual numbers that performs the conversion +# (and actually this method leads to method ambiguity errors since even discrete distributions now +# define logpdf(::MyDistribution, ::Real), see, e.g., +# JuliaStats/Distributions.jl@ae2d6c5/src/univariate/discrete/binomial.jl#L119). +if !DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF + @eval begin + function Distributions.logpdf(d::DiscreteUnivariateDistribution, k::ForwardDiff.Dual) + return logpdf(d, convert(Integer, ForwardDiff.value(k))) + end + end end diff --git a/src/matrixvariate.jl b/src/matrixvariate.jl index ac57d9dd..42c588d7 100644 --- a/src/matrixvariate.jl +++ b/src/matrixvariate.jl @@ -215,21 +215,23 @@ function Distributions._rand!(rng::AbstractRNG, d::TuringInverseWishart, A::Abst A .= inv(cholesky!(X)) end -# TODO: Remove when available in Distributions -for T in (:MatrixBeta, :MatrixNormal, :Wishart, :InverseWishart, - :TuringWishart, :TuringInverseWishart, - :VectorOfMultivariate, :MatrixOfUnivariate) - @eval begin - Distributions.loglikelihood(d::$T, X::AbstractMatrix{<:Real}) = logpdf(d, X) - function Distributions.loglikelihood(d::$T, X::AbstractArray{<:Real,3}) - (size(X, 1), size(X, 2)) == size(d) || throw(DimensionMismatch("Inconsistent array dimensions.")) - return sum(i -> _logpdf(d, view(X, :, :, i)), axes(X, 3)) - end - function Distributions.loglikelihood( - d::$T, - X::AbstractArray{<:AbstractMatrix{<:Real}}, - ) - return sum(x -> logpdf(d, x), X) +# Only needed in Distributions < 0.24 +if !DISTRIBUTIONS_HAS_GENERIC_UNIVARIATE_PDF + for T in (:MatrixBeta, :MatrixNormal, :Wishart, :InverseWishart, + :TuringWishart, :TuringInverseWishart, + :VectorOfMultivariate, :MatrixOfUnivariate) + @eval begin + Distributions.loglikelihood(d::$T, X::AbstractMatrix{<:Real}) = logpdf(d, X) + function Distributions.loglikelihood(d::$T, X::AbstractArray{<:Real,3}) + (size(X, 1), size(X, 2)) == size(d) || throw(DimensionMismatch("Inconsistent array 
dimensions.")) + return sum(i -> _logpdf(d, view(X, :, :, i)), axes(X, 3)) + end + function Distributions.loglikelihood( + d::$T, + X::AbstractArray{<:AbstractMatrix{<:Real}}, + ) + return sum(x -> logpdf(d, x), X) + end end end end diff --git a/src/tracker.jl b/src/tracker.jl index 93f7fcc4..f08ec1ad 100644 --- a/src/tracker.jl +++ b/src/tracker.jl @@ -260,16 +260,30 @@ end PoissonBinomial(p::TrackedArray{<:Real}; check_args=true) = TuringPoissonBinomial(p; check_args = check_args) + +# TODO: add adjoints without ForwardDiff poissonbinomial_pdf_fft(x::TrackedArray) = track(poissonbinomial_pdf_fft, x) @grad function poissonbinomial_pdf_fft(x::TrackedArray) x_data = data(x) T = eltype(x_data) fft = poissonbinomial_pdf_fft(x_data) return fft, Δ -> begin - ((ForwardDiff.jacobian(x -> poissonbinomial_pdf_fft(x), x_data)::Matrix{T})' * Δ,) + ((ForwardDiff.jacobian(poissonbinomial_pdf_fft, x_data)::Matrix{T})' * Δ,) end end +if isdefined(Distributions, :poissonbinomial_pdf) + Distributions.poissonbinomial_pdf(x::TrackedArray) = track(Distributions.poissonbinomial_pdf, x) + @grad function Distributions.poissonbinomial_pdf(x::TrackedArray) + x_data = data(x) + T = eltype(x_data) + value = Distributions.poissonbinomial_pdf(x_data) + function poissonbinomial_pdf_pullback(Δ) + return ((ForwardDiff.jacobian(Distributions.poissonbinomial_pdf, x_data)::Matrix{T})' * Δ,) + end + return value, poissonbinomial_pdf_pullback + end +end ## Semicircle ## diff --git a/src/univariate.jl b/src/univariate.jl index a7469ab9..d86b0577 100644 --- a/src/univariate.jl +++ b/src/univariate.jl @@ -38,11 +38,14 @@ struct TuringPoissonBinomial{T<:Real, TV1<:AbstractVector{T}, TV2<:AbstractVecto pmf::TV2 end -function TuringPoissonBinomial(p::AbstractArray{<:Real}; check_args = true) - pb = Distributions.poissonbinomial_pdf_fft(p) - ϵ = eps(eltype(pb)) - check_args && @assert all(x -> x >= -ϵ, pb) && isapprox(sum(pb), 1; atol=ϵ) - return TuringPoissonBinomial(p, pb) +# if available use the faster 
`poissonbinomial_pdf` +@eval begin + function TuringPoissonBinomial(p::AbstractArray{<:Real}; check_args = true) + pb = $(isdefined(Distributions, :poissonbinomial_pdf) ? Distributions.poissonbinomial_pdf : Distributions.poissonbinomial_pdf_fft)(p) + ϵ = eps(eltype(pb)) + check_args && @assert all(x -> x >= -ϵ, pb) && isapprox(sum(pb), 1; atol=ϵ) + return TuringPoissonBinomial(p, pb) + end end function logpdf(d::TuringPoissonBinomial{T}, k::Int) where T<:Real diff --git a/src/zygote_forwarddiff.jl b/src/zygote_forwarddiff.jl index a9ea088f..7e157379 100644 --- a/src/zygote_forwarddiff.jl +++ b/src/zygote_forwarddiff.jl @@ -1,68 +1,20 @@ # Zygote loads ForwardDiff, so this adjoint will autmatically be loaded together # with `using Zygote`. -# FIXME: This is inefficient, replace with the commented code below once Zygote supports it. +# TODO: add adjoints without ForwardDiff @adjoint function poissonbinomial_pdf_fft(x::AbstractArray{T}) where T<:Real fft = poissonbinomial_pdf_fft(x) return fft, Δ -> begin - ((ForwardDiff.jacobian(x -> poissonbinomial_pdf_fft(x), x)::Matrix{T})' * Δ,) + ((ForwardDiff.jacobian(poissonbinomial_pdf_fft, x)::Matrix{T})' * Δ,) end end -# The code below doesn't work because of bugs in Zygote. The above is inefficient. 
-#= -ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{<:Real}) - return ZygoteRules.pullback(poissonbinomial_pdf_fft_zygote, x) -end -function poissonbinomial_pdf_fft_zygote(p::AbstractArray{T}) where {T <: Real} - n = length(p) - ω = 2 * one(T) / (n + 1) - lmax = ceil(Int, n/2) - x1 = [one(T)/(n + 1)] - x_lmaxp1 = map(1:lmax) do l - logz = zero(T) - argz = zero(T) - for j=1:n - zjl = 1 - p[j] + p[j] * cospi(ω*l) + im * p[j] * sinpi(ω * l) - logz += log(abs(zjl)) - argz += atan(imag(zjl), real(zjl)) - end - dl = exp(logz) - return dl * cos(argz) / (n + 1) + dl * sin(argz) * im / (n + 1) - end - x_lmaxp2_end = [conj(x[l + 1]) for l in lmax:-1:1 if n + 1 - l > l] - x = vcat(x1; x_lmaxp1, x_lmaxp2_end) - y = [sum(x[j] * cis(-π * float(T)(2 * mod(j * k, n)) / n) for j in 1:n) for k in 1:n] - return max.(0, real.(y)) -end -function poissonbinomial_pdf_fft_zygote2(p::AbstractArray{T}) where {T <: Real} - n = length(p) - ω = 2 * one(T) / (n + 1) - x = Vector{Complex{T}}(undef, n+1) - lmax = ceil(Int, n/2) - x[1] = one(T)/(n + 1) - for l=1:lmax - logz = zero(T) - argz = zero(T) - for j=1:n - zjl = 1 - p[j] + p[j] * cospi(ω*l) + im * p[j] * sinpi(ω * l) - logz += log(abs(zjl)) - argz += atan(imag(zjl), real(zjl)) +if isdefined(Distributions, :poissonbinomial_pdf) + @adjoint function Distributions.poissonbinomial_pdf(x::AbstractArray{T}) where T<:Real + value = Distributions.poissonbinomial_pdf(x) + function poissonbinomial_pdf_pullback(Δ) + return ((ForwardDiff.jacobian(Distributions.poissonbinomial_pdf, x)::Matrix{T})' * Δ,) end - dl = exp(logz) - x[l + 1] = dl * cos(argz) / (n + 1) + dl * sin(argz) * im / (n + 1) - if n + 1 - l > l - x[n + 1 - l + 1] = conj(x[l + 1]) - end - end - max.(0, real.(_dft_zygote(copy(x)))) -end -function _dft_zygote(x::Vector{T}) where T - n = length(x) - y = Zygote.Buffer(zeros(complex(float(T)), n)) - @inbounds for j = 0:n-1, k = 0:n-1 - y[k+1] += x[j+1] * cis(-π * float(T)(2 * mod(j * k, n)) / n) + return value, 
poissonbinomial_pdf_pullback end - return copy(y) end -=# diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 8e472516..b1e7b552 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -395,6 +395,11 @@ # Broken distributions d.f(d.θ...) isa Union{VonMises,TriangularDist} && continue + # Skellam only fails in these tests with ReverseDiff + # Ref: https://github.com/TuringLang/DistributionsAD.jl/pull/119#issuecomment-705769224 + filldist_broken = d.f(d.θ...) isa Skellam ? (:ReverseDiff,) : d.broken + arraydist_broken = d.broken + # Create `filldist` distribution f_filldist = (θ...,) -> filldist(d.f(θ...), n) d_filldist = f_filldist(d.θ...) @@ -420,7 +425,7 @@ f_filldist, d.θ, x; - broken=d.broken, + broken=filldist_broken, ) ) test_ad( @@ -429,7 +434,7 @@ f_arraydist, d.θ, x; - broken=d.broken, + broken=arraydist_broken, ) ) end diff --git a/test/runtests.jl b/test/runtests.jl index fd5f0da0..baf1463c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,22 +5,6 @@ using Combinatorics using Distributions using FiniteDifferences using PDMats -using Requires - -# Figure out which AD backend to test -const AD = get(ENV, "AD", "All") -if AD == "All" || AD == "ForwardDiff" - @eval using ForwardDiff -end -if AD == "All" || AD == "Zygote" - @eval using Zygote -end -if AD == "All" || AD == "ReverseDiff" - @eval using ReverseDiff -end -if AD == "All" || AD == "Tracker" - @eval using Tracker -end using Random, LinearAlgebra, Test @@ -35,21 +19,19 @@ Random.seed!(1) # Set seed that all testsets should reset to. 
const FDM = FiniteDifferences const GROUP = get(ENV, "GROUP", "All") -# Create positive definite matrix -to_posdef(A::AbstractMatrix) = A * A' + I -to_posdef_diagonal(a::AbstractVector) = Diagonal(a.^2 .+ 1) - -@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin - # Define adjoints for Tracker - to_posdef(A::TrackedMatrix) = Tracker.track(to_posdef, A) - Tracker.@grad function to_posdef(A::TrackedMatrix) - data_A = Tracker.data(A) - S = data_A * data_A' + I - function pullback(∇) - return ((∇ + ∇') * data_A,) - end - return S, pullback - end +# Figure out which AD backend to test +const AD = get(ENV, "AD", "All") +if AD == "All" || AD == "ForwardDiff" + @eval using ForwardDiff +end +if AD == "All" || AD == "Zygote" + @eval using Zygote +end +if AD == "All" || AD == "ReverseDiff" + @eval using ReverseDiff +end +if AD == "All" || AD == "Tracker" + @eval using Tracker end if GROUP == "All" || GROUP == "Others" @@ -57,6 +39,25 @@ if GROUP == "All" || GROUP == "Others" end if GROUP == "All" || GROUP == "AD" + # Create positive definite matrix + to_posdef(A::AbstractMatrix) = A * A' + I + to_posdef_diagonal(a::AbstractVector) = Diagonal(a.^2 .+ 1) + + if AD == "All" || AD == "Tracker" + @eval begin + # Define adjoints for Tracker + to_posdef(A::TrackedMatrix) = Tracker.track(to_posdef, A) + Tracker.@grad function to_posdef(A::TrackedMatrix) + data_A = Tracker.data(A) + S = data_A * data_A' + I + function pullback(∇) + return ((∇ + ∇') * data_A,) + end + return S, pullback + end + end + end + include("ad/utils.jl") include("ad/chainrules.jl") include("ad/distributions.jl")