Skip to content

Commit

Permalink
Merge pull request #92 from devmotion/add_tests
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jun 29, 2020
2 parents 7b2a0fb + 7a62896 commit 424ff23
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 32 deletions.
20 changes: 12 additions & 8 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,20 +170,20 @@ end
for T in (:TrackedVector, :TrackedMatrix)
@eval begin
function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.ScalMat}, x::$T)
logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x)
logpdf(TuringScalMvNormal(d.μ, sqrt(d.Σ.value)), x)
end
function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.PDiagMat}, x::$T)
logpdf(TuringDiagMvNormal(d.μ, d.Σ.diag), x)
logpdf(TuringDiagMvNormal(d.μ, sqrt.(d.Σ.diag)), x)
end
function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.PDMat}, x::$T)
logpdf(TuringDenseMvNormal(d.μ, d.Σ.chol), x)
end

function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.ScalMat}, x::$T)
logpdf(TuringMvLogNormal(TuringScalMvNormal(d.normal.μ, d.normal.Σ.value)), x)
logpdf(TuringMvLogNormal(TuringScalMvNormal(d.normal.μ, sqrt(d.normal.Σ.value))), x)
end
function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.PDiagMat}, x::$T)
logpdf(TuringMvLogNormal(TuringDiagMvNormal(d.normal.μ, d.normal.Σ.diag)), x)
logpdf(TuringMvLogNormal(TuringDiagMvNormal(d.normal.μ, sqrt.(d.normal.Σ.diag))), x)
end
function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.PDMat}, x::$T)
logpdf(TuringMvLogNormal(TuringDenseMvNormal(d.normal.μ, d.normal.Σ.chol)), x)
Expand Down Expand Up @@ -406,15 +406,15 @@ for T in (:AbstractVector, :AbstractMatrix)
x::$T
)
return ZygoteRules.pullback(d, x) do d, x
logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x)
logpdf(TuringScalMvNormal(d.μ, sqrt(d.Σ.value)), x)
end
end
ZygoteRules.@adjoint function Distributions.logpdf(
d::MvNormal{<:Any, <:PDMats.PDiagMat},
x::$T
)
return ZygoteRules.pullback(d, x) do d, x
logpdf(TuringDiagMvNormal(d.μ, d.Σ.diag), x)
logpdf(TuringDiagMvNormal(d.μ, sqrt.(d.Σ.diag)), x)
end
end
ZygoteRules.@adjoint function Distributions.logpdf(
Expand All @@ -431,7 +431,9 @@ for T in (:AbstractVector, :AbstractMatrix)
x::$T
)
return ZygoteRules.pullback(d, x) do d, x
dist = TuringMvLogNormal(TuringScalMvNormal(d.normal.μ, d.normal.Σ.value))
dist = TuringMvLogNormal(
TuringScalMvNormal(d.normal.μ, sqrt(d.normal.Σ.value)),
)
logpdf(dist, x)
end
end
Expand All @@ -440,7 +442,9 @@ for T in (:AbstractVector, :AbstractMatrix)
x::$T
)
return ZygoteRules.pullback(d, x) do d, x
dist = TuringMvLogNormal(TuringDiagMvNormal(d.normal.μ, d.normal.Σ.diag))
dist = TuringMvLogNormal(
TuringDiagMvNormal(d.normal.μ, sqrt.(d.normal.Σ.diag)),
)
logpdf(dist, x)
end
end
Expand Down
8 changes: 4 additions & 4 deletions src/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,20 @@ end
for T in (:TrackedVector, :TrackedMatrix)
@eval begin
function logpdf(d::MvNormal{<:Any, <:PDMats.ScalMat}, x::$T)
logpdf(TuringScalMvNormal(d.μ, d.Σ.value), x)
logpdf(TuringScalMvNormal(d.μ, sqrt(d.Σ.value)), x)
end
function logpdf(d::MvNormal{<:Any, <:PDMats.PDiagMat}, x::$T)
logpdf(TuringDiagMvNormal(d.μ, d.Σ.diag), x)
logpdf(TuringDiagMvNormal(d.μ, sqrt.(d.Σ.diag)), x)
end
function logpdf(d::MvNormal{<:Any, <:PDMats.PDMat}, x::$T)
logpdf(TuringDenseMvNormal(d.μ, d.Σ.chol), x)
end

function logpdf(d::MvLogNormal{<:Any, <:PDMats.ScalMat}, x::$T)
logpdf(TuringMvLogNormal(TuringScalMvNormal(d.normal.μ, d.normal.Σ.value)), x)
logpdf(TuringMvLogNormal(TuringScalMvNormal(d.normal.μ, sqrt(d.normal.Σ.value))), x)
end
function logpdf(d::MvLogNormal{<:Any, <:PDMats.PDiagMat}, x::$T)
logpdf(TuringMvLogNormal(TuringDiagMvNormal(d.normal.μ, d.normal.Σ.diag)), x)
logpdf(TuringMvLogNormal(TuringDiagMvNormal(d.normal.μ, sqrt.(d.normal.Σ.diag))), x)
end
function logpdf(d::MvLogNormal{<:Any, <:PDMats.PDMat}, x::$T)
logpdf(TuringMvLogNormal(TuringDenseMvNormal(d.normal.μ, d.normal.Σ.chol)), x)
Expand Down
33 changes: 13 additions & 20 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,35 +105,28 @@ function test_ad(dist::DistSpec; kwargs...)
end
end

if isempty(θ)
# In this case we can only test the gradient with respect to `x`
xtest = vectorize(x)
ftest = let xorig=x
x -> f_allargs(unpack(x, (1,), xorig)...)
end
test_ad(ftest, xtest; kwargs...)
else
# For all combinations of distribution parameters `θ`
for inds in combinations(2:(length(θ) + 1))
# Test only distribution parameters
# For all combinations of distribution parameters `θ`
for inds in powerset(2:(length(θ) + 1))
# Test only distribution parameters
if !isempty(inds)
xtest = mapreduce(vcat, inds) do i
vectorize(θ[i - 1])
end
ftest = let xorig=x, θorig=θ, inds=inds
x -> f_allargs(unpack(x, inds, xorig, θorig...)...)
end
test_ad(ftest, xtest; kwargs...)
end

# Test derivative with respect to location `x` as well
# if the distribution is continuous
if Distributions.value_support(typeof(dist)) === Continuous
xtest = vcat(vectorize(x), xtest)
push!(inds, 1)
ftest = let xorig=x, θorig=θ, inds=inds
x -> f_allargs(unpack(x, inds, xorig, θorig...)...)
end
test_ad(ftest, xtest; kwargs...)
# Test derivative with respect to location `x` as well
# if the distribution is continuous
if Distributions.value_support(typeof(dist)) === Continuous
xtest = isempty(inds) ? vectorize(x) : vcat(vectorize(x), xtest)
push!(inds, 1)
ftest = let xorig=x, θorig=θ, inds=inds
x -> f_allargs(unpack(x, inds, xorig, θorig...)...)
end
test_ad(ftest, xtest; kwargs...)
end
end
end
Expand Down

0 comments on commit 424ff23

Please sign in to comment.