Skip to content

Commit

Permalink
Remove TuringUniform (#211)
Browse files Browse the repository at this point in the history
* Remove `TuringUniform`

* Update tests
  • Loading branch information
devmotion authored Jan 23, 2022
1 parent 5570ac5 commit 9461c6b
Show file tree
Hide file tree
Showing 12 changed files with 44 additions and 113 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DistributionsAD"
uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
version = "0.6.35"
version = "0.6.36"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -27,8 +27,8 @@ ChainRules = "1"
ChainRulesCore = "1"
Compat = "3.6"
DiffRules = "0.1, 1.0"
Distributions = "0.25.32"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
Distributions = "0.25.41"
FillArrays = "0.9, 0.10, 0.11, 0.12"
NaNMath = "0.3"
PDMats = "0.9, 0.10, 0.11"
Requires = "1"
Expand Down
27 changes: 13 additions & 14 deletions src/DistributionsAD.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module DistributionsAD

using PDMats,
LinearAlgebra,
Distributions,
Random,
using PDMats,
LinearAlgebra,
Distributions,
Random,
SpecialFunctions,
StatsFuns,
Compat,
Expand All @@ -16,20 +16,20 @@ using PDMats,

using SpecialFunctions: logabsgamma, digamma
using LinearAlgebra: copytri!, AbstractTriangular
using Distributions: AbstractMvLogNormal,
using Distributions: AbstractMvLogNormal,
ContinuousMultivariateDistribution
using Base.Iterators: drop

import StatsBase
import StatsFuns: logsumexp,
binomlogpdf,
nbinomlogpdf,
poislogpdf,
import StatsFuns: logsumexp,
binomlogpdf,
nbinomlogpdf,
poislogpdf,
nbetalogpdf
import Distributions: MvNormal,
MvLogNormal,
logpdf,
quantile,
import Distributions: MvNormal,
MvLogNormal,
logpdf,
quantile,
PoissonBinomial,
Binomial,
BetaBinomial,
Expand All @@ -53,7 +53,6 @@ include("multivariate.jl")
include("matrixvariate.jl")
include("flatten.jl")

include("chainrules.jl")
include("zygote.jl")

@init begin
Expand Down
11 changes: 0 additions & 11 deletions src/chainrules.jl

This file was deleted.

5 changes: 2 additions & 3 deletions src/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ function getexpr(Tdist)
x = gensym()
fnames = fieldnames(Tdist)
flattened_args = Expr(:tuple, [:(dist.$f) for f in fnames]...)
func = Expr(:->,
Expr(:tuple, fnames..., x),
func = Expr(:->,
Expr(:tuple, fnames..., x),
Expr(:block,
Expr(:call, :logpdf,
Expr(:call, :($Tdist), fnames...),
Expand Down Expand Up @@ -58,7 +58,6 @@ const flattened_dists = [ Bernoulli,
TDist,
TriangularDist,
Triweight,
TuringUniform,
]
for T in flattened_dists
@eval toflatten(::$T) = true
Expand Down
47 changes: 22 additions & 25 deletions src/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,36 +208,33 @@ adapt_randn(rng::AbstractRNG, x::TrackedArray, dims...) = adapt_randn(rng, data(

## Uniform ##

Distributions.Uniform(a::TrackedReal, b::Real) = TuringUniform{TrackedReal}(a, b)
Distributions.Uniform(a::Real, b::TrackedReal) = TuringUniform{TrackedReal}(a, b)
Distributions.Uniform(a::TrackedReal, b::TrackedReal) = TuringUniform{TrackedReal}(a, b)
Distributions.logpdf(d::Uniform, x::TrackedReal) = uniformlogpdf(d.a, d.b, x)

uniformlogpdf(a::Real, b::Real, x::TrackedReal) = track(uniformlogpdf, a, b, x)
uniformlogpdf(a::TrackedReal, b::TrackedReal, x::Real) = track(uniformlogpdf, a, b, x)
uniformlogpdf(a::TrackedReal, b::TrackedReal, x::TrackedReal) = track(uniformlogpdf, a, b, x)
@grad function uniformlogpdf(a, b, x)
# compute log pdf
diff = data(b) - data(a)
insupport = a <= data(x) <= b
lp = insupport ? -log(diff) : log(zero(diff))

function pullback(Δ)
z = zero(x) * Δ
if insupport
c = Δ / diff
return c, -c, z
else
c = Δ / one(diff)
cNaN = oftype(c, NaN)
return cNaN, cNaN, oftype(z, NaN)
logpdf(d::Uniform, x::TrackedReal) = track(uniformlogpdf, d.a, d.b, x)
logpdf(d::Uniform{<:TrackedReal}, x::Real) = track(uniformlogpdf, d.a, d.b, x)
logpdf(d::Uniform{<:TrackedReal}, x::TrackedReal) = track(uniformlogpdf, d.a, d.b, x)

# avoid any possible promotions of the outer constructor
uniformlogpdf(a::T, b::T, x::Real) where {T<:Real} = logpdf(Uniform{T}(a, b), x)
@grad function uniformlogpdf(_a::T, _b::T, _x::Real) where {T<:Real}
# Compute log probability
a = data(_a)
b = data(_b)
x = data(_x)
insupport = a <= x <= b
diff = b - a
Ω = insupport ? -log(diff) : log(zero(diff))

# Define pullback
function uniformlogpdf_pullback(Δ)
Δa = Δ / diff
if !insupport
Δa = zero(Δa)
end
return Δa, -Δa, zero(x)
end

return lp, pullback
return Ω, uniformlogpdf_pullback
end


## Binomial ##

binomlogpdf(n::Int, p::TrackedReal, x::Int) = track(binomlogpdf, n, p, x)
Expand Down
28 changes: 0 additions & 28 deletions src/univariate.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,3 @@
## Uniform ##

struct TuringUniform{T} <: ContinuousUnivariateDistribution
a::T
b::T
end
TuringUniform() = TuringUniform(0.0, 1.0)
function TuringUniform(a::Int, b::Int)
return TuringUniform{Float64}(Float64(a), Float64(b))
end
function TuringUniform(a::Real, b::Real)
T = promote_type(typeof(a), typeof(b))
return TuringUniform{T}(T(a), T(b))
end
Distributions.logpdf(d::TuringUniform, x::Real) = uniformlogpdf(d.a, d.b, x)

Base.minimum(d::TuringUniform) = d.a
Base.maximum(d::TuringUniform) = d.b

function uniformlogpdf(a, b, x)
diff = b - a
if a <= x <= b
return -log(diff)
else
return log(zero(diff))
end
end

## PoissonBinomial ##

struct TuringPoissonBinomial{T<:Real, TV1<:AbstractVector{T}, TV2<:AbstractVector} <: DiscreteUnivariateDistribution
Expand Down
6 changes: 0 additions & 6 deletions src/zygote.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
## Uniform ##

ZygoteRules.@adjoint function Distributions.Uniform(args...)
return ZygoteRules.pullback(TuringUniform, args...)
end

## Product

# Tests with `Kolmogorov` seem to fail otherwise?!
Expand Down
7 changes: 0 additions & 7 deletions test/ad/chainrules.jl

This file was deleted.

3 changes: 0 additions & 3 deletions test/ad/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,6 @@
DistSpec(Uniform, (), 0.5),
DistSpec(Uniform, (alpha, alpha + beta), alpha + beta * gamma),

DistSpec(TuringUniform, (), 0.5),
DistSpec(TuringUniform, (alpha, alpha + beta), alpha + beta * gamma),

DistSpec(VonMises, (), 1.0),

DistSpec(Weibull, (), 1.0),
Expand Down
10 changes: 3 additions & 7 deletions test/ad/others.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
@testset "AD: Others" begin
if GROUP == "All" || GROUP == "Tracker"
@testset "TuringUniform" begin
@test logpdf(TuringUniform(), param(0.5)) == 0
end

@testset "Semicircle" begin
@test Tracker.data(logpdf(Semicircle(1.0), param(0.5))) == logpdf(Semicircle(1.0), 0.5)
end
Expand All @@ -17,7 +13,7 @@
@testset "zygote_ldiv" begin
A = to_posdef(rand(3, 3))
B = to_posdef(rand(3, 3))

test_reverse_mode_ad(randn(3, 3), A, B) do A, B
return DistributionsAD.zygote_ldiv(A, B)
end
Expand Down Expand Up @@ -84,10 +80,10 @@
v = rand(rng, T, n)
d = rand(Int, n)
tp = ReverseDiff.InstructionTape()
x = ReverseDiff.TrackedArray(v, d, tp)
x = ReverseDiff.TrackedArray(v, d, tp)
test_adapt_randn(rng, x, T, dims...)
end
end
end
end
end
end
4 changes: 0 additions & 4 deletions test/others.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,6 @@
end
end

@testset "TuringUniform" begin
@test logpdf(TuringUniform(), 0.5) == 0
end

@testset "TuringPoissonBinomial" begin
d1 = TuringPoissonBinomial([0.5, 0.5])
d2 = PoissonBinomial([0.5, 0.5])
Expand Down
3 changes: 1 addition & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using PDMats
using Random, LinearAlgebra, Test

using Distributions: meanlogdet
using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal,
using DistributionsAD: TuringMvNormal, TuringMvLogNormal,
TuringPoissonBinomial, TuringDirichlet
using StatsBase: entropy
using StatsFuns: StatsFuns, logsumexp, logistic
Expand All @@ -25,6 +25,5 @@ end
if GROUP == "All" || GROUP in ("ForwardDiff", "Zygote", "ReverseDiff", "Tracker")
include("ad/utils.jl")
include("ad/others.jl")
include("ad/chainrules.jl")
include("ad/distributions.jl")
end

2 comments on commit 9461c6b

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/53019

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.36 -m "<description of version>" 9461c6b9eb3472efce160d37fbe1a1b4ec8c9bb6
git push origin v0.6.36

Please sign in to comment.