Differentiating mvnormal #1554

Status: Open. Wants to merge 31 commits into base: master.

Commits (31):
789ad0a
test and code for frule
matbesancon Apr 18, 2022
24861ca
added tests
matbesancon Apr 24, 2022
6e6c029
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Apr 25, 2022
2bcc217
diff on MvNormal only for now
matbesancon Apr 25, 2022
1e9571b
unthunk when only one element
matbesancon May 23, 2022
ac0995e
Update src/multivariate/mvnormal.jl
matbesancon May 24, 2022
522c13e
Update src/multivariate/mvnormal.jl
matbesancon May 24, 2022
6529a4a
no backing
matbesancon May 24, 2022
4e4d982
conflict
matbesancon May 24, 2022
d398140
fix inference
matbesancon May 24, 2022
661de16
unthunk common computation
matbesancon May 26, 2022
3c5007c
unthunk common computation
matbesancon May 26, 2022
1f18e67
remove rules for logpdf
matbesancon May 27, 2022
b60147d
revert changes
matbesancon Jul 30, 2022
375cad8
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
c34a3ed
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
2d96680
revert changes
matbesancon Jul 30, 2022
f32c223
Merge branch 'cr-mvnormal' of github.com:JuliaStats/Distributions.jl …
matbesancon Jul 30, 2022
b225298
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
cf1242a
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Jul 31, 2022
e2846e8
Merge branch 'cr-mvnormal' of github.com:JuliaStats/Distributions.jl …
matbesancon Jul 31, 2022
6bc85da
avoid materializing Matrix
matbesancon Jul 31, 2022
e303416
revert
matbesancon Jul 31, 2022
e21506a
fix revert
matbesancon Jul 31, 2022
fd272ae
no alloc
matbesancon Jul 31, 2022
8b7d451
revert fdiff
matbesancon Jul 31, 2022
53601fe
Update src/multivariate/mvnormal.jl
matbesancon Jul 31, 2022
2c75061
fix op
matbesancon Jul 31, 2022
6d88e4d
fix op bis, revert to invcov
matbesancon Jul 31, 2022
1f00a3b
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Aug 20, 2022
8ebf419
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Oct 16, 2022
118 changes: 105 additions & 13 deletions src/multivariate/mvnormal.jl
@@ -253,7 +253,7 @@ Base.show(io::IO, d::MvNormal) =
length(d::MvNormal) = length(d.μ)
mean(d::MvNormal) = d.μ
params(d::MvNormal) = (d.μ, d.Σ)
@inline partype(d::MvNormal{T}) where {T<:Real} = T
@inline partype(::MvNormal{T}) where {T<:Real} = T

var(d::MvNormal) = diag(d.Σ)
cov(d::MvNormal) = Matrix(d.Σ)
@@ -372,7 +372,7 @@ struct MvNormalStats <: SufficientStats
tw::Float64 # total sample weight
end

function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64})
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64})
Member: Maybe move unrelated changes to a separate PR?

Member Author: these were all relatively minor changes (unused variables)

Member: Can you revert the non-CR changes? It seems not only unused names were removed but also some types and dispatches were changed, creating slight inconsistencies related to the open issue about type parameters in fit. IMO it would be much cleaner to avoid these additional changes in this PR and instead fix the dispatches (and names) in a separate PR in a consistent way.

d = size(x, 1)
n = size(x, 2)
s = vec(sum(x, dims=2))
@@ -382,7 +382,7 @@ function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64})
MvNormalStats(s, m, s2, Float64(n))
end

function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
d = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions."))
@@ -410,13 +410,13 @@ end
# each kind of covariance
#

fit_mle(D::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss)
fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x)
fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w)
fit_mle(::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss)
fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x)
fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w)

fit_mle(D::Type{FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw))
fit_mle(::Type{<:FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw))

function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64})
function fit_mle(::Type{FullNormal}, x::AbstractMatrix{Float64})
n = size(x, 2)
mu = vec(mean(x, dims=2))
z = x .- mu
@@ -425,7 +425,7 @@ function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64})
MvNormal(mu, PDMat(C))
end

function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function fit_mle(::Type{<:FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
m = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions"))
@@ -445,7 +445,7 @@ function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVec
MvNormal(mu, PDMat(C))
end

function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64})
function fit_mle(::Type{DiagNormal}, x::AbstractMatrix{Float64})
m = size(x, 1)
n = size(x, 2)

@@ -460,7 +460,7 @@ function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64})
MvNormal(mu, PDiagMat(va))
end

function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function fit_mle(::Type{<:DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
m = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions"))
@@ -479,7 +479,7 @@ function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVec
MvNormal(mu, PDiagMat(va))
end

function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64})
function fit_mle(::Type{IsoNormal}, x::AbstractMatrix{Float64})
m = size(x, 1)
n = size(x, 2)

@@ -495,7 +495,7 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64})
MvNormal(mu, ScalMat(m, va / (m * n)))
end

function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function fit_mle(::Type{<:IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
m = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions"))
@@ -515,3 +515,95 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVect
end
MvNormal(mu, ScalMat(m, va / (m * sw)))
end

## Differentiation

function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector)
c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d)
sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x)
Member: I think we should rather use frule_via_ad here for calling back into the AD system, in case it wants to define its own, possibly improved derivatives.

return c0 - sq/2, ChainRulesCore.@thunk(begin
Δc0 = ChainRulesCore.unthunk(Δc0)
Δsq = ChainRulesCore.unthunk(Δsq)
Δc0 - Δsq/2
end)
Member: Derivatives should not be thunked if there's only one of them.

end
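A hedged sketch of what the reviewer's `frule_via_ad` suggestion could look like for this rule; this is an assumption about the intended shape, not the PR's code. It presumes the internal names `_logpdf`, `mvnormal_c0`, and `sqmahal` stay as in the diff, and uses ChainRulesCore's config mechanism so the calling AD system can supply its own inner derivatives:

```julia
using ChainRulesCore
using Distributions: AbstractMvNormal, _logpdf, mvnormal_c0, sqmahal

# A config-accepting frule: only AD systems declaring forward-mode support
# (HasForwardsMode) will dispatch here, so frule_via_ad is safe to call.
function ChainRulesCore.frule(
    config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
    (_, Δd, Δx)::Tuple{Any,Any,Any},
    ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector,
)
    # Defer the inner derivatives to the AD system instead of hard-coding frule calls.
    c0, Δc0 = ChainRulesCore.frule_via_ad(config, (ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d)
    sq, Δsq = ChainRulesCore.frule_via_ad(config, (ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x)
    # Single consumer of each tangent, so no thunking is needed.
    return c0 - sq / 2, ChainRulesCore.unthunk(Δc0) - ChainRulesCore.unthunk(Δsq) / 2
end
```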

function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector)
c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d)
sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x)
Member: Same here, this should probably be rrule_via_ad.

function logpdf_MvNormal_pullback(dy)
dy = ChainRulesCore.unthunk(dy)
(_, ∂d_c0) = c0_pullback(dy)
∂d_c0 = ChainRulesCore.unthunk(∂d_c0)
(_, ∂d_sq, ∂x_sq) = sq_pullback(dy)
∂d_sq = ChainRulesCore.unthunk(∂d_sq)
∂x_sq = ChainRulesCore.unthunk(∂x_sq)
backing = NamedTuple{(:μ, :Σ), Tuple{typeof(∂d_sq.μ), typeof(∂d_sq.Σ)}}((
(∂d_c0.μ - 0.5 * ∂d_sq.μ),
(∂d_c0.Σ - 0.5 * ∂d_sq.Σ),
))
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
return ChainRulesCore.NoTangent(), ∂d, - 0.5 * ∂x_sq
end
return c0 - 0.5 * sq, logpdf_MvNormal_pullback
end

function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::MvNormal)
y = mvnormal_c0(d)
Δy = ChainRulesCore.@thunk(begin
Δd = ChainRulesCore.unthunk(Δd)
-dot(Δd.Σ, invcov(d)) / 2
end)
Member: Same here, no thunks. Suggested change:

    Δy = -dot(Δd.Σ, invcov(d)) / 2
return y, Δy
end
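For reference, the tangent in this rule follows from the identity d(log det Σ) = tr(Σ⁻¹ dΣ); a sketch of the derivation (notation as in the code, not part of the PR):

```latex
c_0(d) = -\tfrac{1}{2}\bigl(k \log(2\pi) + \log\det\Sigma\bigr),
\qquad
\Delta c_0
  = -\tfrac{1}{2}\,\operatorname{tr}\bigl(\Sigma^{-1}\,\Delta\Sigma\bigr)
  = -\tfrac{1}{2}\,\bigl\langle \Sigma^{-1},\, \Delta\Sigma \bigr\rangle
```

which is exactly `-dot(Δd.Σ, invcov(d)) / 2` above.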

function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal)
y = mvnormal_c0(d)
function mvnormal_c0_pullback(dy)
∂d = ChainRulesCore.@thunk(begin
dy = ChainRulesCore.unthunk(dy)
∂Σ = -dy/2 * invcov(d)
ChainRulesCore.Tangent{typeof(d)}(μ = ChainRulesCore.ZeroTangent(), Σ = ∂Σ)
end)
Member: No thunk 🙂

return ChainRulesCore.NoTangent(), ∂d
end
return y, mvnormal_c0_pullback
end

function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::MvNormal, x::AbstractVector)
y = sqmahal(d, x)
Δy = ChainRulesCore.@thunk(begin
(_, Δd, Δx) = dargs
Δd = ChainRulesCore.unthunk(Δd)
Δx = ChainRulesCore.unthunk(Δx)
Σinv = invcov(d)
# TODO optimize
dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ')
dx = 2 * dot(Σinv * (x - d.μ), Δx)
dμ = 2 * dot(Σinv * (d.μ - x), Δd.μ)
dΣ + dx + dμ
end)
return (y, Δy)
end

function ChainRulesCore.rrule(::typeof(sqmahal), d::MvNormal, x::AbstractVector)
y = sqmahal(d, x)
function sqmahal_pullback(dy)
∂x = ChainRulesCore.@thunk(begin
dy = ChainRulesCore.unthunk(dy)
Σinv = invcov(d)
2dy * Σinv * (x - d.μ)
end)
∂d = ChainRulesCore.@thunk(begin
dy = ChainRulesCore.unthunk(dy)
Σinv = invcov(d)
cx = x - d.μ
∂μ = -2dy * Σinv * cx
∂J = dy * cx * cx'
∂Σ = - Σinv * ∂J * Σinv
ChainRulesCore.Tangent{typeof(d)}(μ = ∂μ, Σ = ∂Σ)
end)
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
end
return y, sqmahal_pullback
end
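The pullback above implements the standard gradients of the squared Mahalanobis distance. As a sketch (using d(Σ⁻¹) = -Σ⁻¹ dΣ Σ⁻¹; derivation note, not part of the PR):

```latex
y = (x-\mu)^\top \Sigma^{-1} (x-\mu),
\qquad
\frac{\partial y}{\partial x} = 2\,\Sigma^{-1}(x-\mu),
\quad
\frac{\partial y}{\partial \mu} = -2\,\Sigma^{-1}(x-\mu),
\quad
\frac{\partial y}{\partial \Sigma} = -\,\Sigma^{-1}(x-\mu)(x-\mu)^\top \Sigma^{-1}
```

matching `∂x`, `∂μ`, and `∂Σ` in `sqmahal_pullback` (each scaled by the incoming cotangent `dy`).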
67 changes: 67 additions & 0 deletions test/mvnormal.jl
@@ -1,5 +1,6 @@
# Tests on Multivariate Normal distributions

import PDMats
import PDMats: ScalMat, PDiagMat, PDMat
if isdefined(PDMats, :PDSparseMat)
import PDMats: PDSparseMat
@@ -9,6 +10,8 @@ using Distributions
using LinearAlgebra, Random, Test
using SparseArrays
using FillArrays
using ChainRulesCore
using ChainRulesTestUtils

###### General Testing

@@ -302,3 +305,67 @@ end
x = rand(d)
@test logpdf(d, x) ≈ logpdf(Normal(), x[1]) + logpdf(Normal(), x[2])
end

@testset "MvNormal differentiation rules" begin
for n in (3, 10)
for _ in 1:10
A = Symmetric(rand(n,n)) .+ 4 * Matrix(I, n, n)
@assert isposdef(A)
d = MvNormal(randn(n), A)
# make ΔΣ symmetric, such that Σ ± ΔΣ is PSD
t = 0.001 * ChainRulesTestUtils.rand_tangent(d)
t.Σ .+= t.Σ'
if eigmin(t.Σ) < 0
while eigmin(d.Σ + t.Σ) < 0
t.Σ .*= 0.8
end
end
if eigmax(t.Σ) > 0
while eigmin(d.Σ - t.Σ) < 0
t.Σ .*= 0.8
end
end
# mvnormal_c0
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t), Distributions.mvnormal_c0, d)
y_r, c0_pullback = @inferred ChainRulesCore.rrule(Distributions.mvnormal_c0, d)
@test y_r ≈ y
y2 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ + t.Σ))
@test unthunk(Δy) ≈ y2 - y atol= n * 1e-4
y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ))
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
(_, ∇c0) = c0_pullback(1.0)
∇c0 = ChainRulesCore.unthunk(∇c0)
@test dot(∇c0.Σ, t.Σ) ≈ y2 - y atol = n * 1e-4
@test dot(∇c0.Σ, t.Σ) ≈ y - y3 atol = n * 1e-4
# sqmahal
x = randn(n)
Δx = 0.0001 * randn(n)
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), sqmahal, d, x)
(yr, sqmahal_pullback) = @inferred ChainRulesCore.rrule(sqmahal, d, x)
(_, ∇s_d, ∇s_x) = @inferred sqmahal_pullback(1.0)
∇s_d = ChainRulesCore.unthunk(∇s_d)
∇s_x = ChainRulesCore.unthunk(∇s_x)
@test yr ≈ y
y2 = Distributions.sqmahal(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx)
y3 = Distributions.sqmahal(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx)
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4
# _logpdf
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), Distributions._logpdf, d, x)
(yr, logpdf_MvNormal_pullback) = @inferred ChainRulesCore.rrule(Distributions._logpdf, d, x)
@test y ≈ yr
# inference broken
# (_, ∇s_d, ∇s_x) = @inferred logpdf_MvNormal_pullback(1.0)
(_, ∇s_d, ∇s_x) = logpdf_MvNormal_pullback(1.0)

y2 = Distributions._logpdf(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx)
y3 = Distributions._logpdf(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx)
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4
Member: This seems very complicated. Ideally we should just use test_rrule and test_frule.

Member Author: this is nontrivial in the case of the perturbation of the covariance matrix

Member: Yeah, sometimes one has to add custom tests. But the standard should be test_frule and test_rrule, since they check multiple other parts of the CR interface as well, in addition to numerical accuracy and type inference. And even for cases that are problematic for finite differencing (e.g. due to singularities and domain constraints) it is sometimes possible to use test_frule and test_rrule by specifying custom finite differencing methods, as in #1555.

end
end
end
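For comparison, a hedged sketch of the ChainRulesTestUtils-based tests the reviewer asks for (not part of the PR). Note that `rand_tangent` may perturb `Σ` off the positive-definite cone, which is why the testset above hand-rolls its finite differences; a custom finite-differencing method (as discussed for #1555) may be needed in practice:

```julia
using ChainRulesTestUtils, Distributions, LinearAlgebra

n = 3
A = Symmetric(rand(n, n)) + 4 * Matrix(I, n, n)  # same posdef construction as the testset
d = MvNormal(randn(n), A)
x = randn(n)

# Each call checks the rule against finite differences and also exercises
# the broader ChainRules interface (thunks, tangent types, type inference).
test_frule(Distributions.sqmahal, d, x)
test_rrule(Distributions.sqmahal, d, x)
```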