Skip to content

Commit 4ee4798

Browse files
committed
fix rebase and consistency issues
1 parent 925d5f9 commit 4ee4798

File tree

2 files changed

+33
-95
lines changed

2 files changed

+33
-95
lines changed

src/multivariate.jl

+15-64
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@ function TuringDenseMvNormal(m::AbstractVector, A::AbstractMatrix)
1313
return TuringDenseMvNormal(m, cholesky(A))
1414
end
1515
Base.length(d::TuringDenseMvNormal) = length(d.m)
16-
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDenseMvNormal)
17-
return d.m .+ d.C.U' * randn(rng, length(d))
18-
end
19-
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDenseMvNormal, n::Int)
20-
return d.m .+ d.C.U' * randn(rng, length(d), n)
16+
Distributions.rand(d::TuringDenseMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...)
17+
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDenseMvNormal, n::Int...)
18+
return d.m .+ d.C.U' * randn(rng, length(d), n...)
2119
end
2220

2321
"""
@@ -32,43 +30,21 @@ end
3230

3331
Base.length(d::TuringDiagMvNormal) = length(d.m)
3432
Base.size(d::TuringDiagMvNormal) = (length(d), length(d))
35-
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal)
36-
return d.m .+ d.σ .* randn(rng, length(d))
37-
end
38-
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal, n::Int)
39-
return d.m .+ d.σ .* randn(rng, length(d), n)
40-
end
41-
42-
struct TuringScalMvNormal{Tm<:AbstractVector, Tσ<:Real} <: ContinuousMultivariateDistribution
43-
m::Tm
44-
σ::Tσ
45-
end
46-
47-
Base.length(d::TuringScalMvNormal) = length(d.m)
48-
Base.size(d::TuringScalMvNormal) = (length(d), length(d))
49-
function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal)
50-
return d.m .+ d.σ .* randn(rng, length(d))
51-
end
52-
function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal, n::Int)
53-
return d.m .+ d.σ .* randn(rng, length(d), n)
54-
55-
Base.length(d::TuringDiagMvNormal) = length(d.m)
56-
Base.size(d::TuringDiagMvNormal) = (length(d), length(d))
57-
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal)
58-
return d.m .+ d.σ .* randn(rng, length(d))
33+
Distributions.rand(d::TuringDiagMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...)
34+
function Distributions.rand(rng::Random.AbstractRNG, d::TuringDiagMvNormal, n::Int...)
35+
return d.m .+ d.σ .* randn(rng, length(d), n...)
5936
end
6037

61-
6238
struct TuringScalMvNormal{Tm<:AbstractVector, Tσ<:Real} <: ContinuousMultivariateDistribution
6339
m::Tm
6440
σ::Tσ
6541
end
6642

6743
Base.length(d::TuringScalMvNormal) = length(d.m)
6844
Base.size(d::TuringScalMvNormal) = (length(d), length(d))
69-
function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal)
70-
return d.m .+ d.σ .* randn(rng, length(d))
71-
>>>>>>> multiple distributions as one
45+
Distributions.rand(d::TuringScalMvNormal, n::Int...) = rand(Random.GLOBAL_RNG, d, n...)
46+
function Distributions.rand(rng::Random.AbstractRNG, d::TuringScalMvNormal, n::Int...)
47+
return d.m .+ d.σ .* randn(rng, length(d), n...)
7248
end
7349

7450
for T in (:AbstractVector, :AbstractMatrix)
@@ -95,31 +71,6 @@ function _logpdf(d::TuringDenseMvNormal, x::AbstractVector)
9571
end
9672
function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix)
9773
return -((size(x, 1) * log(2π) + logdet(d.C)) .+ vec(sum(abs2.(zygote_ldiv(d.C.U', x .- d.m)), dims=1))) ./ 2
98-
=======
99-
for T in (:TrackedVector, :TrackedMatrix)
100-
@eval function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.ScalMat}, x::$T)
101-
logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x)
102-
end
103-
end
104-
105-
function _logpdf(d::TuringScalMvNormal, x::AbstractVector)
106-
return -(length(x) * log(2π) + 2 * sum(log(d.σ)) + sum(abs2, (x .- d.m) ./ d.σ)) / 2
107-
end
108-
function _logpdf(d::TuringScalMvNormal, x::AbstractMatrix)
109-
return -(size(x, 2) * log(2π) .+ 2 * sum(log(d.σ)) .+ sum(abs2, (x .- d.m) ./ d.σ, dims=1)') ./ 2
110-
end
111-
112-
function _logpdf(d::TuringDiagMvNormal, x::AbstractVector)
113-
return -(length(x) * log(2π) + 2 * sum(log.(d.σ)) + sum(abs2, (x .- d.m) ./ d.σ)) / 2
114-
end
115-
function _logpdf(d::TuringDiagMvNormal, x::AbstractMatrix)
116-
return -(size(x, 2) * log(2π) .+ 2 * sum(log.(d.σ)) .+ sum(abs2, (x .- d.m) ./ d.σ, dims=1)') ./ 2
117-
end
118-
function _logpdf(d::TuringDenseMvNormal, x::AbstractVector)
119-
return -(length(x) * log(2π) + logdet(d.C) + sum(abs2, zygote_ldiv(d.C.U', x .- d.m))) / 2
120-
end
121-
function _logpdf(d::TuringDenseMvNormal, x::AbstractMatrix)
122-
return -(size(x, 2) * log(2π) .+ logdet(d.C) .+ sum(abs2, zygote_ldiv(d.C.U', x .- d.m), dims=1)') ./ 2
12374
end
12475

12576
# zero mean, dense covariance
@@ -129,9 +80,9 @@ MvNormal(A::TrackedMatrix) = TuringMvNormal(A)
12980
MvNormal::TrackedVector) = TuringMvNormal(σ)
13081

13182
# dense mean, dense covariance
132-
MvNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringDenseMvNormal(m, A)
133-
MvNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringDenseMvNormal(m, A)
134-
MvNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringDenseMvNormal(m, A)
83+
MvNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvNormal(m, A)
84+
MvNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringMvNormal(m, A)
85+
MvNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvNormal(m, A)
13586

13687
# dense mean, diagonal covariance
13788
function MvNormal(
@@ -237,9 +188,9 @@ MvLogNormal(A::TrackedMatrix) = TuringMvLogNormal(TuringMvNormal(A))
237188
MvLogNormal::TrackedVector) = TuringMvLogNormal(TuringMvNormal(σ))
238189

239190
# dense mean, dense covariance
240-
MvLogNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringDenseMvNormal(m, A))
241-
MvLogNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringMvLogNormal(TuringDenseMvNormal(m, A))
242-
MvLogNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringDenseMvNormal(m, A))
191+
MvLogNormal(m::TrackedVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, A))
192+
MvLogNormal(m::TrackedVector{<:Real}, A::Matrix{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, A))
193+
MvLogNormal(m::AbstractVector{<:Real}, A::TrackedMatrix{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, A))
243194

244195
# dense mean, diagonal covariance
245196
function MvLogNormal(

src/univariate.jl

+18-31
Original file line numberDiff line numberDiff line change
@@ -33,39 +33,25 @@ uniformlogpdf(a::TrackedReal, b::TrackedReal, x::TrackedReal) = track(uniformlog
3333
Tracker.@grad function uniformlogpdf(a, b, x)
3434
diff = data(b) - data(a)
3535
T = typeof(diff)
36-
l = -log(diff)
37-
f = isfinite(l)
38-
da = 1/diff
39-
n = T(NaN)
40-
return l, Δ->(f ? da : n, f ? -da : n, f ? zero(T) : n)
36+
if a <= data(x) <= b && a < b
37+
l = -log(diff)
38+
da = 1/diff^2
39+
return l, Δ -> (da * Δ, -da * Δ, zero(T) * Δ)
40+
else
41+
n = T(NaN)
42+
return l, Δ -> (n, n, n)
43+
end
4144
end
4245
ZygoteRules.@adjoint function uniformlogpdf(a, b, x)
4346
diff = b - a
4447
T = typeof(diff)
45-
l = -log(diff)
46-
f = isfinite(l)
47-
da = 1/diff
48-
n = T(NaN)
49-
z = zero(T)
50-
return l, Δ -> (f ? (z, z, z) : (n, n, n))
51-
end
52-
for T in (:TrackedReal, :Real)
53-
@eval @grad function uniformlogpdf(
54-
a::TrackedReal,
55-
b::TrackedReal,
56-
x::$T,
57-
)
58-
ad = data(a)
59-
bd = data(b)
60-
T = typeof(a)
61-
l = logpdf(Uniform(ad, bd), x)
62-
f = isfinite(l)
63-
temp = 1/(bd - ad)^2
64-
dlda = temp
65-
dldb = -temp
48+
if a <= x <= b && a < b
49+
l = -log(diff)
50+
da = 1/diff^2
51+
return l, Δ -> (da * Δ, -da * Δ, zero(T) * Δ)
52+
else
6653
n = T(NaN)
67-
z = zero(T)
68-
return l, Δ -> (f ? (dlda * Δ, dldb * Δ, z) : (n, n, n))
54+
return l, Δ -> (n, n, n)
6955
end
7056
end
7157
ZygoteRules.@adjoint function Distributions.Uniform(args...)
@@ -159,9 +145,9 @@ M, f, arity = DiffRules.@define_diffrule DistributionsAD.semicirclelogpdf(r, x)
159145
da, db = DiffRules.diffrule(M, f, :a, :b)
160146
f = :($M.$f)
161147
@eval begin
162-
@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ ->* $da, Δ * $db)
163-
@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ ->* $da, Tracker._zero(b))
164-
@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (Tracker._zero(a), Δ * $db)
148+
Tracker.@grad $f(a::TrackedReal, b::TrackedReal) = $f(data(a), data(b)), Δ ->* $da, Δ * $db)
149+
Tracker.@grad $f(a::TrackedReal, b::Real) = $f(data(a), b), Δ ->* $da, Tracker._zero(b))
150+
Tracker.@grad $f(a::Real, b::TrackedReal) = $f(a, data(b)), Δ -> (Tracker._zero(a), Δ * $db)
165151
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
166152
$f(a::TrackedReal, b::Real) = track($f, a, b)
167153
$f(a::Real, b::TrackedReal) = track($f, a, b)
@@ -292,6 +278,7 @@ ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray)
292278
((ForwardDiff.jacobian(x -> poissonbinomial_pdf_fft(x), x)::Matrix{T})' * Δ,)
293279
end
294280
end
281+
295282
# The code below doesn't work because of bugs in Zygote. The above is inefficient.
296283
#=
297284
ZygoteRules.@adjoint function poissonbinomial_pdf_fft(x::AbstractArray{<:Real})

0 commit comments

Comments
 (0)