diff --git a/Project.toml b/Project.toml index 1815b08..43205ed 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.35" +version = "0.6.36" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -27,8 +27,8 @@ ChainRules = "1" ChainRulesCore = "1" Compat = "3.6" DiffRules = "0.1, 1.0" -Distributions = "0.25.32" -FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12" +Distributions = "0.25.41" +FillArrays = "0.9, 0.10, 0.11, 0.12" NaNMath = "0.3" PDMats = "0.9, 0.10, 0.11" Requires = "1" diff --git a/src/DistributionsAD.jl b/src/DistributionsAD.jl index 92349bf..9be245a 100644 --- a/src/DistributionsAD.jl +++ b/src/DistributionsAD.jl @@ -1,9 +1,9 @@ module DistributionsAD -using PDMats, - LinearAlgebra, - Distributions, - Random, +using PDMats, + LinearAlgebra, + Distributions, + Random, SpecialFunctions, StatsFuns, Compat, @@ -16,20 +16,20 @@ using PDMats, using SpecialFunctions: logabsgamma, digamma using LinearAlgebra: copytri!, AbstractTriangular -using Distributions: AbstractMvLogNormal, +using Distributions: AbstractMvLogNormal, ContinuousMultivariateDistribution using Base.Iterators: drop import StatsBase -import StatsFuns: logsumexp, - binomlogpdf, - nbinomlogpdf, - poislogpdf, +import StatsFuns: logsumexp, + binomlogpdf, + nbinomlogpdf, + poislogpdf, nbetalogpdf -import Distributions: MvNormal, - MvLogNormal, - logpdf, - quantile, +import Distributions: MvNormal, + MvLogNormal, + logpdf, + quantile, PoissonBinomial, Binomial, BetaBinomial, @@ -53,7 +53,6 @@ include("multivariate.jl") include("matrixvariate.jl") include("flatten.jl") -include("chainrules.jl") include("zygote.jl") @init begin diff --git a/src/chainrules.jl b/src/chainrules.jl deleted file mode 100644 index 45c8ee7..0000000 --- a/src/chainrules.jl +++ /dev/null @@ -1,11 +0,0 @@ -## Uniform ## - -@scalar_rule( - uniformlogpdf(a, b, x), - @setup( - insupport = a <= x <= b, - diff = b - a, - c = insupport ? inv(diff) : inv(one(diff)), - ), - (c, -c, ZeroTangent()), -) diff --git a/src/flatten.jl b/src/flatten.jl index f386a4b..d175049 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -8,8 +8,8 @@ function getexpr(Tdist) x = gensym() fnames = fieldnames(Tdist) flattened_args = Expr(:tuple, [:(dist.$f) for f in fnames]...) - func = Expr(:->, - Expr(:tuple, fnames..., x), + func = Expr(:->, + Expr(:tuple, fnames..., x), Expr(:block, Expr(:call, :logpdf, Expr(:call, :($Tdist), fnames...), @@ -58,7 +58,6 @@ const flattened_dists = [ Bernoulli, TDist, TriangularDist, Triweight, - TuringUniform, ] for T in flattened_dists @eval toflatten(::$T) = true diff --git a/src/tracker.jl b/src/tracker.jl index e0fe1e2..414d442 100644 --- a/src/tracker.jl +++ b/src/tracker.jl @@ -208,36 +208,33 @@ adapt_randn(rng::AbstractRNG, x::TrackedArray, dims...) = adapt_randn(rng, data( ## Uniform ## -Distributions.Uniform(a::TrackedReal, b::Real) = TuringUniform{TrackedReal}(a, b) -Distributions.Uniform(a::Real, b::TrackedReal) = TuringUniform{TrackedReal}(a, b) -Distributions.Uniform(a::TrackedReal, b::TrackedReal) = TuringUniform{TrackedReal}(a, b) -Distributions.logpdf(d::Uniform, x::TrackedReal) = uniformlogpdf(d.a, d.b, x) - -uniformlogpdf(a::Real, b::Real, x::TrackedReal) = track(uniformlogpdf, a, b, x) -uniformlogpdf(a::TrackedReal, b::TrackedReal, x::Real) = track(uniformlogpdf, a, b, x) -uniformlogpdf(a::TrackedReal, b::TrackedReal, x::TrackedReal) = track(uniformlogpdf, a, b, x) -@grad function uniformlogpdf(a, b, x) - # compute log pdf - diff = data(b) - data(a) - insupport = a <= data(x) <= b - lp = insupport ? -log(diff) : log(zero(diff)) - - function pullback(Δ) - z = zero(x) * Δ - if insupport - c = Δ / diff - return c, -c, z - else - c = Δ / one(diff) - cNaN = oftype(c, NaN) - return cNaN, cNaN, oftype(z, NaN) +logpdf(d::Uniform, x::TrackedReal) = track(uniformlogpdf, d.a, d.b, x) +logpdf(d::Uniform{<:TrackedReal}, x::Real) = track(uniformlogpdf, d.a, d.b, x) +logpdf(d::Uniform{<:TrackedReal}, x::TrackedReal) = track(uniformlogpdf, d.a, d.b, x) + +# avoid any possible promotions of the outer constructor +uniformlogpdf(a::T, b::T, x::Real) where {T<:Real} = logpdf(Uniform{T}(a, b), x) +@grad function uniformlogpdf(_a::T, _b::T, _x::Real) where {T<:Real} + # Compute log probability + a = data(_a) + b = data(_b) + x = data(_x) + insupport = a <= x <= b + diff = b - a + Ω = insupport ? -log(diff) : log(zero(diff)) + + # Define pullback + function uniformlogpdf_pullback(Δ) + Δa = Δ / diff + if !insupport + Δa = zero(Δa) end + return Δa, -Δa, zero(x) end - return lp, pullback + return Ω, uniformlogpdf_pullback end - ## Binomial ## binomlogpdf(n::Int, p::TrackedReal, x::Int) = track(binomlogpdf, n, p, x) diff --git a/src/univariate.jl b/src/univariate.jl index baab8c5..82b8d66 100644 --- a/src/univariate.jl +++ b/src/univariate.jl @@ -1,31 +1,3 @@ -## Uniform ## - -struct TuringUniform{T} <: ContinuousUnivariateDistribution - a::T - b::T -end -TuringUniform() = TuringUniform(0.0, 1.0) -function TuringUniform(a::Int, b::Int) - return TuringUniform{Float64}(Float64(a), Float64(b)) -end -function TuringUniform(a::Real, b::Real) - T = promote_type(typeof(a), typeof(b)) - return TuringUniform{T}(T(a), T(b)) -end -Distributions.logpdf(d::TuringUniform, x::Real) = uniformlogpdf(d.a, d.b, x) - -Base.minimum(d::TuringUniform) = d.a -Base.maximum(d::TuringUniform) = d.b - -function uniformlogpdf(a, b, x) - diff = b - a - if a <= x <= b - return -log(diff) - else - return log(zero(diff)) - end -end - ## PoissonBinomial ## struct TuringPoissonBinomial{T<:Real, TV1<:AbstractVector{T}, TV2<:AbstractVector} <: DiscreteUnivariateDistribution diff --git a/src/zygote.jl b/src/zygote.jl index 5500e8e..e7dec8e 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -1,9 +1,3 @@ -## Uniform ## - -ZygoteRules.@adjoint function Distributions.Uniform(args...) - return ZygoteRules.pullback(TuringUniform, args...) -end - ## Product # Tests with `Kolmogorov` seem to fail otherwise?! diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl deleted file mode 100644 index 9496ba2..0000000 --- a/test/ad/chainrules.jl +++ /dev/null @@ -1,7 +0,0 @@ -@testset "chainrules" begin - x = randn() - z = x + exp(randn()) - y = z + exp(randn()) - test_frule(DistributionsAD.uniformlogpdf, x, y, z) - test_rrule(DistributionsAD.uniformlogpdf, x, y, z) -end diff --git a/test/ad/distributions.jl b/test/ad/distributions.jl index 08ddb92..ef5ee18 100644 --- a/test/ad/distributions.jl +++ b/test/ad/distributions.jl @@ -201,9 +201,6 @@ DistSpec(Uniform, (), 0.5), DistSpec(Uniform, (alpha, alpha + beta), alpha + beta * gamma), - DistSpec(TuringUniform, (), 0.5), - DistSpec(TuringUniform, (alpha, alpha + beta), alpha + beta * gamma), - DistSpec(VonMises, (), 1.0), DistSpec(Weibull, (), 1.0), diff --git a/test/ad/others.jl b/test/ad/others.jl index a0b7431..1c70558 100644 --- a/test/ad/others.jl +++ b/test/ad/others.jl @@ -1,9 +1,5 @@ @testset "AD: Others" begin if GROUP == "All" || GROUP == "Tracker" - @testset "TuringUniform" begin - @test logpdf(TuringUniform(), param(0.5)) == 0 - end - @testset "Semicircle" begin @test Tracker.data(logpdf(Semicircle(1.0), param(0.5))) == logpdf(Semicircle(1.0), 0.5) end @@ -17,7 +13,7 @@ @testset "zygote_ldiv" begin A = to_posdef(rand(3, 3)) B = to_posdef(rand(3, 3)) - + test_reverse_mode_ad(randn(3, 3), A, B) do A, B return DistributionsAD.zygote_ldiv(A, B) end @@ -84,10 +80,10 @@ v = rand(rng, T, n) d = rand(Int, n) tp = ReverseDiff.InstructionTape() - x = ReverseDiff.TrackedArray(v, d, tp) + x = ReverseDiff.TrackedArray(v, d, tp) test_adapt_randn(rng, x, T, dims...) end end end end -end \ No newline at end of file +end diff --git a/test/others.jl b/test/others.jl index 10bd450..23549ae 100644 --- a/test/others.jl +++ b/test/others.jl @@ -96,10 +96,6 @@ end end - @testset "TuringUniform" begin - @test logpdf(TuringUniform(), 0.5) == 0 - end - @testset "TuringPoissonBinomial" begin d1 = TuringPoissonBinomial([0.5, 0.5]) d2 = PoissonBinomial([0.5, 0.5]) diff --git a/test/runtests.jl b/test/runtests.jl index d03627e..3829094 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,7 @@ using PDMats using Random, LinearAlgebra, Test using Distributions: meanlogdet -using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal, +using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringPoissonBinomial, TuringDirichlet using StatsBase: entropy using StatsFuns: StatsFuns, logsumexp, logistic @@ -25,6 +25,5 @@ end if GROUP == "All" || GROUP in ("ForwardDiff", "Zygote", "ReverseDiff", "Tracker") include("ad/utils.jl") include("ad/others.jl") - include("ad/chainrules.jl") include("ad/distributions.jl") end