diff --git a/Project.toml b/Project.toml index 01ebc0d0..00cbbeff 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,9 @@ version = "0.3.2" [deps] Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" +DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" @@ -18,10 +20,14 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Combinatorics = "0.7" +DiffRules = "0.1, 1.0" Distributions = "0.22" +FillArrays = "0.8" +FiniteDifferences = "0.9" ForwardDiff = "0.10.6" PDMats = "0.9" SpecialFunctions = "0.8, 0.9, 0.10" +StatsBase = "0.32" StatsFuns = "0.8, 0.9" Tracker = "0.2.5" Zygote = "0.4.7" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 6520414b..560689e7 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -11,11 +11,15 @@ using PDMats, StatsFuns using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray, - TrackedVecOrMat, track, data -using ZygoteRules: ZygoteRules, pullback + TrackedVecOrMat, track, @grad, data +using SpecialFunctions: logabsgamma, digamma +using ZygoteRules: ZygoteRules, @adjoint, pullback using LinearAlgebra: copytri! using Distributions: AbstractMvLogNormal, ContinuousMultivariateDistribution +using DiffRules, SpecialFunctions, FillArrays +using ForwardDiff: @define_binary_dual_op # Needed for `eval`ing diffrules here +using Base.Iterators: drop import StatsFuns: logsumexp, binomlogpdf, @@ -35,11 +39,16 @@ export TuringScalMvNormal, TuringMvLogNormal, TuringPoissonBinomial, TuringWishart, - TuringInverseWishart + TuringInverseWishart, + arraydist, + filldist include("common.jl") include("univariate.jl") include("multivariate.jl") include("matrixvariate.jl") +include("flatten.jl") +include("arraydist.jl") +include("filldist.jl") end diff --git a/src/arraydist.jl b/src/arraydist.jl new file mode 100644 index 00000000..5c393c58 --- /dev/null +++ b/src/arraydist.jl @@ -0,0 +1,76 @@ +# Univariate + +const VectorOfUnivariate = Distributions.Product + +function arraydist(dists::AbstractVector{<:Normal{T}}) where {T} + means = mean.(dists) + vars = var.(dists) + return MvNormal(means, vars) +end +function arraydist(dists::AbstractVector{<:Normal{<:TrackedReal}}) + means = vcatmapreduce(mean, dists) + vars = vcatmapreduce(var, dists) + return MvNormal(means, vars) +end +function arraydist(dists::AbstractVector{<:UnivariateDistribution}) + return product_distribution(dists) +end +function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real}) + return sum(vcatmapreduce(logpdf, dist.v, x)) +end +function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) + # eachcol breaks Zygote, so we need an adjoint + return vcatmapreduce((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x)) +end +@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) + # Any other more efficient implementation breaks Zygote + f(dist, x) = [sum(logpdf.(dist.v, view(x, :, i))) for i in 1:size(x, 2)] + return pullback(f, dist, x) +end + +struct MatrixOfUnivariate{ + S <: ValueSupport, + Tdist <: UnivariateDistribution{S}, + Tdists <: AbstractMatrix{Tdist}, +} <: MatrixDistribution{S} + dists::Tdists +end +Base.size(dist::MatrixOfUnivariate) = size(dist.dists) +function arraydist(dists::AbstractMatrix{<:UnivariateDistribution}) + return MatrixOfUnivariate(dists) +end +function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real}) + # Broadcasting here breaks Tracker for some reason + # A Zygote adjoint is defined for vcatmapreduce to use broadcasting + return sum(vcatmapreduce(logpdf, dist.dists, x)) +end +function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) + return rand.(Ref(rng), dist.dists) +end + +# Multivariate + +struct VectorOfMultivariate{ + S <: ValueSupport, + Tdist <: MultivariateDistribution{S}, + Tdists <: AbstractVector{Tdist}, +} <: MatrixDistribution{S} + dists::Tdists +end +Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist)) +Base.length(dist::VectorOfMultivariate) = length(dist.dists) +function arraydist(dists::AbstractVector{<:MultivariateDistribution}) + return VectorOfMultivariate(dists) +end +function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) + # eachcol breaks Zygote, so we define an adjoint + return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x))) +end +@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) + f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2))) + return pullback(f, dist, x) +end +function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) + init = reshape(rand(rng, dist.dists[1]), :, 1) + return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init) +end diff --git a/src/common.jl b/src/common.jl index dc9788d7..eee8518f 100644 --- a/src/common.jl +++ b/src/common.jl @@ -1,12 +1,30 @@ ## Generic ## +if VERSION < v"1.1" + eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2)) +end + +Base.one(::Irrational) = true + +function vcatmapreduce(f, args...) + init = vcat(f(first.(args)...,)) + zipped_args = zip(args...,) + return mapreduce(vcat, drop(zipped_args, 1); init = init) do zarg + f(zarg...,) + end +end +@adjoint function vcatmapreduce(f, args...) + g(f, args...) = f.(args...,) + return pullback(g, f, args...) +end + function Base.fill( value::TrackedReal, dims::Vararg{Union{Integer, AbstractUnitRange}}, ) return track(fill, value, dims...) end -Tracker.@grad function Base.fill(value::Real, dims...) +@grad function Base.fill(value::Real, dims...) return fill(data(value), dims...), function(Δ) size(Δ) ≢ dims && error("Dimension mismatch") return (sum(Δ), map(_->nothing, dims)...) @@ -16,7 +34,7 @@ end ## StatsFuns ## logsumexp(x::TrackedArray) = track(logsumexp, x) -Tracker.@grad function logsumexp(x::TrackedArray) +@grad function logsumexp(x::TrackedArray) lse = logsumexp(data(x)) return lse, Δ -> (Δ .* exp.(x .- lse),) end @@ -24,7 +42,7 @@ end ## Linear algebra ## LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A) -Tracker.@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix) +@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix) return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),) end @@ -39,27 +57,27 @@ function turing_chol(A::AbstractMatrix, check) (chol.factors, chol.info) end turing_chol(A::TrackedMatrix, check) = track(turing_chol, A, check) -Tracker.@grad function turing_chol(A::AbstractMatrix, check) +@grad function turing_chol(A::AbstractMatrix, check) C, back = pullback(unsafe_cholesky, data(A), data(check)) return (C.factors, C.info), Δ->back((factors=data(Δ[1]),)) end unsafe_cholesky(x, check) = cholesky(x, check=check) -ZygoteRules.@adjoint function unsafe_cholesky(Σ::Real, check) +@adjoint function unsafe_cholesky(Σ::Real, check) C = cholesky(Σ; check=check) return C, function(Δ::NamedTuple) issuccess(C) || return (zero(Σ), nothing) (Δ.factors[1, 1] / (2 * C.U[1, 1]), nothing) end end -ZygoteRules.@adjoint function unsafe_cholesky(Σ::Diagonal, check) +@adjoint function unsafe_cholesky(Σ::Diagonal, check) C = cholesky(Σ; check=check) return C, function(Δ::NamedTuple) issuccess(C) || (Diagonal(zero(diag(Δ.factors))), nothing) (Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing) end end -ZygoteRules.@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check) +@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check) C = cholesky(Σ; check=check) return C, function(Δ::NamedTuple) issuccess(C) || return (zero(Δ.factors), nothing) @@ -78,7 +96,7 @@ end # Specialised logdet for cholesky to target the triangle directly. logdet_chol_tri(U::AbstractMatrix) = 2 * sum(log, U[diagind(U)]) logdet_chol_tri(U::TrackedMatrix) = track(logdet_chol_tri, U) -Tracker.@grad function logdet_chol_tri(U::AbstractMatrix) +@grad function logdet_chol_tri(U::AbstractMatrix) U_data = data(U) return logdet_chol_tri(U_data), Δ->(Matrix(Diagonal(2 .* Δ ./ diag(U_data))),) end @@ -88,6 +106,7 @@ function LinearAlgebra.logdet(C::Cholesky{<:TrackedReal, <:TrackedMatrix}) end # Tracker's implementation of ldiv isn't good. We'll use Zygote's instead. + zygote_ldiv(A::AbstractMatrix, B::AbstractVecOrMat) = A \ B function zygote_ldiv(A::TrackedMatrix, B::TrackedVecOrMat) return track(zygote_ldiv, A, B) @@ -96,7 +115,7 @@ function zygote_ldiv(A::TrackedMatrix, B::AbstractVecOrMat) return track(zygote_ldiv, A, B) end zygote_ldiv(A::AbstractMatrix, B::TrackedVecOrMat) = track(zygote_ldiv, A, B) -Tracker.@grad function zygote_ldiv(A, B) +@grad function zygote_ldiv(A, B) Y, back = pullback(\, data(A), data(B)) return Y, Δ->back(data(Δ)) end @@ -104,3 +123,41 @@ end function Base.:\(a::Cholesky{<:TrackedReal, <:TrackedArray}, b::AbstractVecOrMat) return (a.U \ (a.U' \ b)) end + +# SpecialFunctions + +SpecialFunctions.logabsgamma(x::TrackedReal) = track(logabsgamma, x) +@grad function SpecialFunctions.logabsgamma(x::Real) + return logabsgamma(data(x)), Δ -> (digamma(data(x)) * Δ[1],) +end +@adjoint function SpecialFunctions.logabsgamma(x::Real) + return logabsgamma(x), Δ -> (digamma(x) * Δ[1],) +end + +# Some Tracker fixes + +for i = 0:2, c = Tracker.combinations([:AbstractArray, :TrackedArray, :TrackedReal, :Number], i), f = [:hcat, :vcat] + if :TrackedReal in c + cnames = map(_ -> gensym(), c) + @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) = + track($f, $(cnames...), x, xs...) + end +end +@grad function vcat(x::Real) + vcat(data(x)), (Δ) -> (Δ[1],) +end +@grad function vcat(x1::Real, x2::Real) + vcat(data(x1), data(x2)), (Δ) -> (Δ[1], Δ[2]) +end +@grad function vcat(x1::AbstractVector, x2::Real) + vcat(data(x1), data(x2)), (Δ) -> (Δ[1:length(x1)], Δ[length(x1)+1]) +end + +# Zygote fill has issues with non-numbers + +@adjoint function fill(x::T, dims...) where {T} + function zfill(x, dims...,) + return reshape([x for i in 1:prod(dims)], dims) + end + pullback(zfill, x, dims...) +end diff --git a/src/filldist.jl b/src/filldist.jl new file mode 100644 index 00000000..d7a0b282 --- /dev/null +++ b/src/filldist.jl @@ -0,0 +1,111 @@ +# Univariate + +const FillVectorOfUnivariate{ + S <: ValueSupport, + T <: UnivariateDistribution{S}, + Tdists <: Fill{T, 1}, +} = VectorOfUnivariate{S, T, Tdists} + +function filldist(dist::UnivariateDistribution, N::Int) + return product_distribution(Fill(dist, N)) +end +filldist(d::Normal, N::Int) = MvNormal(fill(d.μ, N), d.σ) + +function Distributions.logpdf( + dist::FillVectorOfUnivariate, + x::AbstractVector{<:Real}, +) + return _logpdf(dist, x) +end +function Distributions.logpdf( + dist::FillVectorOfUnivariate, + x::AbstractMatrix{<:Real}, +) + return _logpdf(dist, x) +end +@adjoint function Distributions.logpdf( + dist::FillVectorOfUnivariate, + x::AbstractMatrix{<:Real}, +) + return pullback(_logpdf, dist, x) +end + +function _logpdf( + dist::FillVectorOfUnivariate, + x::AbstractVector{<:Real}, +) + return _flat_logpdf(dist.v.value, x) +end +function _logpdf( + dist::FillVectorOfUnivariate, + x::AbstractMatrix{<:Real}, +) + return _flat_logpdf_mat(dist.v.value, x) +end + +function _flat_logpdf(dist, x) + if toflatten(dist) + f, args = flatten(dist) + return sum(f.(args..., x)) + else + return sum(vcatmapreduce(x -> logpdf(dist, x), x)) + end +end +function _flat_logpdf_mat(dist, x) + if toflatten(dist) + f, args = flatten(dist) + return vec(sum(f.(args..., x), dims = 1)) + else + temp = vcatmapreduce(x -> logpdf(dist, x), x) + return vec(sum(reshape(temp, size(x)), dims = 1)) + end +end + +const FillMatrixOfUnivariate{ + S <: ValueSupport, + T <: UnivariateDistribution{S}, + Tdists <: Fill{T, 2}, +} = MatrixOfUnivariate{S, T, Tdists} + +function filldist(dist::UnivariateDistribution, N1::Integer, N2::Integer) + return MatrixOfUnivariate(Fill(dist, N1, N2)) +end +function Distributions.logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real}) + return _flat_logpdf(dist.dists.value, x) +end +function Distributions.rand(rng::Random.AbstractRNG, dist::FillMatrixOfUnivariate) + return rand(rng, dist.dists.value, length.(dist.dists.axes)) +end + +# Multivariate + +const FillVectorOfMultivariate{ + S <: ValueSupport, + T <: MultivariateDistribution{S}, + Tdists <: Fill{T, 1}, +} = VectorOfMultivariate{S, T, Tdists} + +function filldist(dist::MultivariateDistribution, N::Int) + return VectorOfMultivariate(Fill(dist, N)) +end +function Distributions.logpdf( + dist::FillVectorOfMultivariate, + x::AbstractMatrix{<:Real}, +) + return _logpdf(dist, x) +end +@adjoint function Distributions.logpdf( + dist::FillVectorOfMultivariate, + x::AbstractMatrix{<:Real}, +) + return pullback(_logpdf, dist, x) +end +function _logpdf( + dist::FillVectorOfMultivariate, + x::AbstractMatrix{<:Real}, +) + return sum(logpdf(dist.dists.value, x)) +end +function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate) + return rand(rng, dist.dists.value, length.(dist.dists.axes)) +end diff --git a/src/flatten.jl b/src/flatten.jl new file mode 100644 index 00000000..a305d1b5 --- /dev/null +++ b/src/flatten.jl @@ -0,0 +1,76 @@ +macro register(dist) + return quote + DistributionsAD.eval(getexpr($(esc(dist)))) + DistributionsAD.toflatten(::$(esc(dist))) = true + end +end +function getexpr(Tdist) + x = gensym() + fnames = fieldnames(Tdist) + flattened_args = Expr(:tuple, [:(dist.$f) for f in fnames]...) + func = Expr(:->, + Expr(:tuple, fnames..., x), + Expr(:block, + Expr(:call, :logpdf, + Expr(:call, :($Tdist), fnames...), + x, + ) + ) + ) + return :(flatten(dist::$Tdist) = ($func, $flattened_args)) +end +const flattened_dists = [ Bernoulli, + BetaBinomial, + Binomial, + Geometric, + NegativeBinomial, + Poisson, + Skellam, + PoissonBinomial, + Arcsine, + Beta, + BetaPrime, + Biweight, + Cauchy, + Chernoff, + Chi, + Chisq, + Cosine, + Epanechnikov, + Erlang, + Exponential, + FDist, + Frechet, + Gamma, + GeneralizedExtremeValue, + GeneralizedPareto, + Gumbel, + InverseGamma, + InverseGaussian, + Kolmogorov, + Laplace, + Levy, + LocationScale, + Logistic, + LogitNormal, + LogNormal, + Normal, + NormalCanon, + NormalInverseGaussian, + Pareto, + PGeneralizedGaussian, + Rayleigh, + SymTriangularDist, + TDist, + TriangularDist, + Triweight, + Categorical, + Truncated, + ] +for T in flattened_dists + @eval toflatten(::$T) = true +end +toflatten(::Distribution) = false +for T in flattened_dists + eval(getexpr(T)) +end diff --git a/src/matrixvariate.jl b/src/matrixvariate.jl index a57c4849..0bbe16f5 100644 --- a/src/matrixvariate.jl +++ b/src/matrixvariate.jl @@ -201,10 +201,10 @@ end ## Adjoints -ZygoteRules.@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real}) +@adjoint function Distributions.Wishart(df::Real, S::AbstractMatrix{<:Real}) return pullback(TuringWishart, df, S) end -ZygoteRules.@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real}) +@adjoint function Distributions.InverseWishart(df::Real, S::AbstractMatrix{<:Real}) return pullback(TuringInverseWishart, df, S) end diff --git a/src/multivariate.jl b/src/multivariate.jl index 157133e5..8859d377 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -13,11 +13,9 @@ function TuringDenseMvNormal(m::AbstractVector, A::AbstractMatrix) return TuringDenseMvNormal(m, cholesky(A)) end Base.length(d::TuringDenseMvNormal) = length(d.m) -function Distributions.rand(rng::Random.AbstractRNG, d::TuringDenseMvNormal) - return d.m .+ d.C.U' * randn(rng, length(d)) -end -function Distributions.rand(rng::Random.AbstractRNG, d::TuringDenseMvNormal, n::Int) - return d.m .+ d.C.U' * randn(rng, length(d), n) +Distributions.rand(d::TuringDenseMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...) +function Distributions.rand(rng::Random.AbstractRNG, d::TuringDenseMvNormal, n::Int...) + return d.m .+ d.C.U' * randn(rng, length(d), n...) end """ @@ -33,9 +31,10 @@ end Distributions.params(d::TuringDiagMvNormal) = (d.m, d.σ) Distributions.dim(d::TuringDiagMvNormal) = length(d.m) Base.length(d::TuringDiagMvNormal) = length(d.m) -Base.size(d::TuringDiagMvNormal) = (length(d), ) -function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal, n::Int) - return d.m .+ d.σ .* randn(rng, length(d), n) +Base.size(d::TuringDiagMvNormal) = (length(d),) +Distributions.rand(d::TuringDiagMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...) +function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal, n::Int...) + return d.m .+ d.σ .* randn(rng, length(d), n...) end struct TuringScalMvNormal{Tm<:AbstractVector, Tσ<:Real} <: ContinuousMultivariateDistribution @@ -44,12 +43,10 @@ struct TuringScalMvNormal{Tm<:AbstractVector, Tσ<:Real} <: ContinuousMultivaria end Base.length(d::TuringScalMvNormal) = length(d.m) -Base.size(d::TuringScalMvNormal) = (length(d), ) -function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal) - return d.m .+ d.σ .* randn(rng, length(d)) -end -function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal, n::Int) - return d.m .+ d.σ .* randn(rng, length(d), n) +Base.size(d::TuringScalMvNormal) = (length(d),) +Distributions.rand(d::TuringScalMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...) +function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal, n::Int...) + return d.m .+ d.σ .* randn(rng, length(d), n...) end for T in (:AbstractVector, :AbstractMatrix) @@ -268,18 +265,18 @@ MvLogNormal(d::Int, σ::TrackedReal{<:Real}) = TuringMvLogNormal(TuringMvNormal( ## Zygote adjoint -ZygoteRules.@adjoint function Distributions.MvNormal( +@adjoint function Distributions.MvNormal( A::Union{AbstractVector{<:Real}, AbstractMatrix{<:Real}}, ) return pullback(TuringMvNormal, A) end -ZygoteRules.@adjoint function Distributions.MvNormal( +@adjoint function Distributions.MvNormal( m::AbstractVector{<:Real}, A::Union{Real, UniformScaling, AbstractVecOrMat{<:Real}}, ) return pullback(TuringMvNormal, m, A) end -ZygoteRules.@adjoint function Distributions.MvNormal( +@adjoint function Distributions.MvNormal( d::Int, A::Real, ) diff --git a/src/univariate.jl b/src/univariate.jl index 3db7d0df..d1c2e3aa 100644 --- a/src/univariate.jl +++ b/src/univariate.jl @@ -30,25 +30,31 @@ end uniformlogpdf(a::Real, b::Real, x::TrackedReal) = track(uniformlogpdf, a, b, x) uniformlogpdf(a::TrackedReal, b::TrackedReal, x::Real) = track(uniformlogpdf, a, b, x) uniformlogpdf(a::TrackedReal, b::TrackedReal, x::TrackedReal) = track(uniformlogpdf, a, b, x) -Tracker.@grad function uniformlogpdf(a, b, x) +@grad function uniformlogpdf(a, b, x) diff = data(b) - data(a) T = typeof(diff) - l = -log(diff) - f = isfinite(l) - da = 1/diff - n = T(NaN) - return l, Δ->(f ? da : n, f ? -da : n, f ? zero(T) : n) + if a <= data(x) <= b && a < b + l = -log(diff) + da = 1/diff^2 + return l, Δ -> (da * Δ, -da * Δ, zero(T) * Δ) + else + n = T(NaN) + return l, Δ -> (n, n, n) + end end -ZygoteRules.@adjoint function uniformlogpdf(a, b, x) +@adjoint function uniformlogpdf(a, b, x) diff = b - a T = typeof(diff) - l = -log(diff) - f = isfinite(l) - da = 1/diff - n = T(NaN) - return l, Δ->(f ? da : n, f ? -da : n, f ? zero(T) : n) + if a <= x <= b && a < b + l = -log(diff) + da = 1/diff^2 + return l, Δ -> (da * Δ, -da * Δ, zero(T) * Δ) + else + n = T(NaN) + return l, Δ -> (n, n, n) + end end -ZygoteRules.@adjoint function Distributions.Uniform(args...) +@adjoint function Distributions.Uniform(args...) return pullback(TuringUniform, args...) end @@ -61,7 +67,7 @@ function _betalogpdfgrad(α, β, x) dx = (α - 1)/x + (1 - β)/(1 - x) return (dα, dβ, dx) end -ZygoteRules.@adjoint function betalogpdf(α::Real, β::Real, x::Number) +@adjoint function betalogpdf(α::Real, β::Real, x::Number) return betalogpdf(α, β, x), Δ -> (Δ .* _betalogpdfgrad(α, β, x)) end @@ -73,7 +79,7 @@ function _gammalogpdfgrad(k, θ, x) dx = (k - 1)/x - 1/θ return (dk, dθ, dx) end -ZygoteRules.@adjoint function gammalogpdf(k::Real, θ::Real, x::Number) +@adjoint function gammalogpdf(k::Real, θ::Real, x::Number) return gammalogpdf(k, θ, x), Δ -> (Δ .* _gammalogpdfgrad(k, θ, x)) end @@ -86,7 +92,7 @@ function _chisqlogpdfgrad(k, x) dx = (hk - 1)/x - one(hk)/2 return (dk, dx) end -ZygoteRules.@adjoint function chisqlogpdf(k::Real, x::Number) +@adjoint function chisqlogpdf(k::Real, x::Number) return chisqlogpdf(k, x), Δ -> (Δ .* _chisqlogpdfgrad(k, x)) end @@ -103,7 +109,7 @@ function _fdistlogpdfgrad(v1, v2, x) dx = v1 / 2 * (1 / x - temp3) - 1 / x return (dv1, dv2, dx) end -ZygoteRules.@adjoint function fdistlogpdf(v1::Real, v2::Real, x::Number) +@adjoint function fdistlogpdf(v1::Real, v2::Real, x::Number) return fdistlogpdf(v1, v2, x), Δ -> (Δ .* _fdistlogpdfgrad(v1, v2, x)) end @@ -114,31 +120,37 @@ function _tdistlogpdfgrad(v, x) dx = -x * (v + 1) / (v + x^2) return (dv, dx) end -ZygoteRules.@adjoint function tdistlogpdf(v::Real, x::Number) +@adjoint function tdistlogpdf(v::Real, x::Number) return tdistlogpdf(v, x), Δ -> (Δ .* _tdistlogpdfgrad(v, x)) end ## Semicircle ## +function semicircle_dldr(r, x) + diffsq = r^2 - x^2 + return -2 / r + r / diffsq +end +function semicircle_dldx(r, x) + diffsq = r^2 - x^2 + return -x / diffsq +end + logpdf(d::Semicircle{<:Real}, x::TrackedReal) = semicirclelogpdf(d.r, x) logpdf(d::Semicircle{<:TrackedReal}, x::Real) = semicirclelogpdf(d.r, x) logpdf(d::Semicircle{<:TrackedReal}, x::TrackedReal) = semicirclelogpdf(d.r, x) -semicirclelogpdf(r::TrackedReal, x::Real) = track(semicirclelogpdf, r, x) -semicirclelogpdf(r::Real, x::TrackedReal) = track(semicirclelogpdf, r, x) -semicirclelogpdf(r::TrackedReal, x::TrackedReal) = track(semicirclelogpdf, r, x) -Tracker.@grad function semicirclelogpdf(r, x) - rd = data(r) - xd = data(x) - xx, rr = promote(xd, float(rd)) - d = Semicircle(rr) - T = typeof(xx) - l = logpdf(d, xx) - f = isfinite(l) - n = T(NaN) - return l, function (Δ) - diffsq = rr^2 - xx^2 - (f ? Δ*(-2/rr + rr/diffsq) : n, f ? Δ*(-xx/diffsq) : n) - end + +semicirclelogpdf(r, x) = logpdf(Semicircle(r), x) +M, f, arity = DiffRules.@define_diffrule DistributionsAD.semicirclelogpdf(r, x) = + :(semicircle_dldr($r, $x)), :(semicircle_dldx($r, $x)) +da, db = DiffRules.diffrule(M, f, :a, :b) +f = :($M.$f) +@eval begin + @grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ -> (Δ * $da, Tracker._zero(b)) + @grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (Tracker._zero(a), Δ * $db) + $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) + $f(a::TrackedReal, b::Real) = track($f, a, b) + $f(a::Real, b::TrackedReal) = track($f, a, b) end if VERSION < v"1.2" Base.inv(::Irrational{:π}) = 1/π @@ -147,11 +159,11 @@ end ## Binomial ## binomlogpdf(n::Int, p::TrackedReal, x::Int) = track(binomlogpdf, n, p, x) -Tracker.@grad function binomlogpdf(n::Int, p::TrackedReal, x::Int) +@grad function binomlogpdf(n::Int, p::TrackedReal, x::Int) return binomlogpdf(n, data(p), x), Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing) end -ZygoteRules.@adjoint function binomlogpdf(n::Int, p::Real, x::Int) +@adjoint function binomlogpdf(n::Int, p::Real, x::Int) return binomlogpdf(n, p, x), Δ->(nothing, Δ * (x / p - (n - x) / (1 - p)), nothing) end @@ -174,15 +186,15 @@ _nbinomlogpdf_grad_2(r, p, k) = -k / (1 - p) + r / p nbinomlogpdf(n::TrackedReal, p::TrackedReal, x::Int) = track(nbinomlogpdf, n, p, x) nbinomlogpdf(n::Real, p::TrackedReal, x::Int) = track(nbinomlogpdf, n, p, x) nbinomlogpdf(n::TrackedReal, p::Real, x::Int) = track(nbinomlogpdf, n, p, x) -Tracker.@grad function nbinomlogpdf(r::TrackedReal, p::TrackedReal, k::Int) +@grad function nbinomlogpdf(r::TrackedReal, p::TrackedReal, k::Int) return nbinomlogpdf(data(r), data(p), k), Δ->(Δ * _nbinomlogpdf_grad_1(r, p, k), Δ * _nbinomlogpdf_grad_2(r, p, k), nothing) end -Tracker.@grad function nbinomlogpdf(r::Real, p::TrackedReal, k::Int) +@grad function nbinomlogpdf(r::Real, p::TrackedReal, k::Int) return nbinomlogpdf(data(r), data(p), k), Δ->(Tracker._zero(r), Δ * _nbinomlogpdf_grad_2(r, p, k), nothing) end -Tracker.@grad function nbinomlogpdf(r::TrackedReal, p::Real, k::Int) +@grad function nbinomlogpdf(r::TrackedReal, p::Real, k::Int) return nbinomlogpdf(data(r), data(p), k), Δ->(Δ * _nbinomlogpdf_grad_1(r, p, k), Tracker._zero(p), nothing) end @@ -191,10 +203,11 @@ function nbinomlogpdf(r::ForwardDiff.Dual{T}, p::ForwardDiff.Dual{T}, k::Int) wh FD = ForwardDiff.Dual{T} val_p = ForwardDiff.value(p) val_r = ForwardDiff.value(r) - - Δ_r = ForwardDiff.partials(r) * _nbinomlogpdf_grad_1(val_r, val_p, k) - Δ_p = ForwardDiff.partials(p) * _nbinomlogpdf_grad_2(val_r, val_p, k) - Δ = Δ_p + Δ_r + Δ_r = ForwardDiff.partials(r) + dr = _nbinomlogpdf_grad_1(val_r, val_p, k) + Δ_p = ForwardDiff.partials(p) + dp = _nbinomlogpdf_grad_2(val_r, val_p, k) + Δ = ForwardDiff._mul_partials(Δ_r, Δ_p, dr, dp) return FD(nbinomlogpdf(val_r, val_p, k), Δ) end function nbinomlogpdf(r::Real, p::ForwardDiff.Dual{T}, k::Int) where {T} @@ -213,11 +226,11 @@ end ## Poisson ## poislogpdf(v::TrackedReal, x::Int) = track(poislogpdf, v, x) -Tracker.@grad function poislogpdf(v::TrackedReal, x::Int) +@grad function poislogpdf(v::TrackedReal, x::Int) return poislogpdf(data(v), x), Δ->(Δ * (x/v - 1), nothing) end -ZygoteRules.@adjoint function poislogpdf(v::Real, x::Int) +@adjoint function poislogpdf(v::Real, x::Int) return poislogpdf(v, x), Δ->(Δ * (x/v - 1), nothing) end @@ -249,7 +262,7 @@ Base.minimum(d::TuringPoissonBinomial) = 0 Base.maximum(d::TuringPoissonBinomial) = length(d.p) poissonbinomial_pdf_fft(x::TrackedArray) = track(poissonbinomial_pdf_fft, x) -Tracker.@grad function poissonbinomial_pdf_fft(x::TrackedArray) +@grad function poissonbinomial_pdf_fft(x::TrackedArray) x_data = data(x) T = eltype(x_data) fft = poissonbinomial_pdf_fft(x_data) @@ -258,16 +271,17 @@ Tracker.@grad function poissonbinomial_pdf_fft(x::TrackedArray) end end # FIXME: This is inefficient, replace with the commented code below once Zygote supports it. -ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray) +@adjoint function poissonbinomial_pdf_fft(x::AbstractArray) T = eltype(x) fft = poissonbinomial_pdf_fft(x) return fft, Δ -> begin ((ForwardDiff.jacobian(x -> poissonbinomial_pdf_fft(x), x)::Matrix{T})' * Δ,) end end + # The code below doesn't work because of bugs in Zygote. The above is inefficient. #= -ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{<:Real}) +@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{<:Real}) return pullback(poissonbinomial_pdf_fft_zygote, x) end function poissonbinomial_pdf_fft_zygote(p::AbstractArray{T}) where {T <: Real} diff --git a/test/distributions.jl b/test/distributions.jl index 450c689f..a74c597e 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -159,6 +159,8 @@ separator() @testset "Multivariate discrete distributions" begin test_head("Testing: Multivariate discrete distributions") mult_disc_dists = [ + DistSpec(:((p) -> filldist(Bernoulli(p), dim)), (0.45,), fill(1, dim)), + DistSpec(:((p) -> arraydist(fill(Bernoulli(p), dim))), (0.45,), fill(1, dim)), DistSpec(:((p) -> Multinomial(2, p / sum(p))), (fill(0.5, 2),), [2, 0]), ] for d in mult_disc_dists @@ -174,6 +176,10 @@ separator() test_head("Testing: Multivariate continuous distributions") mult_cont_dists = [ # Vector case + DistSpec(:(() -> filldist(Beta(), dim)), (), fill(0.5, dim)), + DistSpec(:(() -> arraydist(fill(Beta(), dim))), (), fill(0.5, dim)), + DistSpec(:((m, v) -> filldist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_vec), + DistSpec(:((m, v) -> arraydist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_vec), DistSpec(:MvNormal, (mean, cov_mat), norm_val_vec), DistSpec(:MvNormal, (mean, cov_vec), norm_val_vec), DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_vec), @@ -192,6 +198,10 @@ separator() DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_vec), DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_vec), # Matrix case + DistSpec(:(() -> filldist(Beta(), dim)), (), fill(0.5, dim, dim)), + DistSpec(:(() -> arraydist(fill(Beta(), dim))), (), fill(0.5, dim, dim)), + DistSpec(:((m, v) -> filldist(Normal(m, sqrt(v)), dim)), (1.0, 1.0), norm_val_mat), + DistSpec(:((m, v) -> arraydist(fill(Normal(m, sqrt(v)), dim))), (1.0, 1.0), norm_val_mat), DistSpec(:MvNormal, (mean, cov_vec), norm_val_mat), DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_mat), DistSpec(:MvNormal, (mean, cov_num), norm_val_mat), @@ -215,7 +225,6 @@ separator() DistSpec(:MvNormalCanon, (cov_mat,), norm_val_vec), DistSpec(:MvNormalCanon, (cov_vec,), norm_val_vec), DistSpec(:(cov_num -> MvNormalCanon(dim, cov_num)), (cov_num,), norm_val_vec), - DistSpec(:Dirichlet, (alpha,), dir_val), DistSpec(:MvNormalCanon, (mean, cov_mat), norm_val_mat), DistSpec(:MvNormalCanon, (mean, cov_vec), norm_val_mat), DistSpec(:MvNormalCanon, (mean, cov_num), norm_val_mat), @@ -244,6 +253,10 @@ separator() @testset "Matrix-variate continuous distributions" begin test_head("Testing: Matrix-variate continuous distributions") matrix_cont_dists = [ + DistSpec(:(() -> filldist(Beta(), dim, dim)), (), fill(0.5, dim, dim)), + DistSpec(:(() -> arraydist(fill(Beta(), dim, dim))), (), fill(0.5, dim, dim)), + DistSpec(:((m, v) -> filldist(Normal(m, sqrt(v)), dim, 2)), (1.0, 1.0), norm_val_mat), + DistSpec(:((m, v) -> arraydist(fill(Normal(m, sqrt(v)), dim, 2))), (1.0, 1.0), norm_val_mat), DistSpec(:((n1, n2)->MatrixBeta(dim, n1, n2)), (dim, dim), beta_mat), DistSpec(:Wishart, (dim, cov_mat), cov_mat), DistSpec(:InverseWishart, (dim, cov_mat), cov_mat), diff --git a/test/others.jl b/test/others.jl index 3546ad2e..8048e2fb 100644 --- a/test/others.jl +++ b/test/others.jl @@ -19,5 +19,5 @@ end d1 = TuringDiagMvNormal(zeros(10), sigmas) d2 = MvNormal(zeros(10), sigmas) - @test entropy(d1) == entropy(d2) + @test isapprox(entropy(d1), entropy(d2), rtol = 1e-6) end