Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Differentiating mvnormal #1554

Open
wants to merge 31 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
789ad0a
test and code for frule
matbesancon Apr 18, 2022
24861ca
added tests
matbesancon Apr 24, 2022
6e6c029
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Apr 25, 2022
2bcc217
diff on MvNormal only for now
matbesancon Apr 25, 2022
1e9571b
unthunk when only one element
matbesancon May 23, 2022
ac0995e
Update src/multivariate/mvnormal.jl
matbesancon May 24, 2022
522c13e
Update src/multivariate/mvnormal.jl
matbesancon May 24, 2022
6529a4a
no backing
matbesancon May 24, 2022
4e4d982
conflict
matbesancon May 24, 2022
d398140
fix inference
matbesancon May 24, 2022
661de16
unthunk common computation
matbesancon May 26, 2022
3c5007c
unthunk common computation
matbesancon May 26, 2022
1f18e67
remove rules for logpdf
matbesancon May 27, 2022
b60147d
revert changes
matbesancon Jul 30, 2022
375cad8
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
c34a3ed
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
2d96680
revert changes
matbesancon Jul 30, 2022
f32c223
Merge branch 'cr-mvnormal' of github.com:JuliaStats/Distributions.jl …
matbesancon Jul 30, 2022
b225298
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
cf1242a
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Jul 31, 2022
e2846e8
Merge branch 'cr-mvnormal' of github.com:JuliaStats/Distributions.jl …
matbesancon Jul 31, 2022
6bc85da
avoid materializing Matrix
matbesancon Jul 31, 2022
e303416
revert
matbesancon Jul 31, 2022
e21506a
fix revert
matbesancon Jul 31, 2022
fd272ae
no alloc
matbesancon Jul 31, 2022
8b7d451
revert fdiff
matbesancon Jul 31, 2022
53601fe
Update src/multivariate/mvnormal.jl
matbesancon Jul 31, 2022
2c75061
fix op
matbesancon Jul 31, 2022
6d88e4d
fix op bis, revert to invcov
matbesancon Jul 31, 2022
1f00a3b
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Aug 20, 2022
8ebf419
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Oct 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 97 additions & 13 deletions src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ Base.show(io::IO, d::MvNormal) =
length(d::MvNormal) = length(d.μ)
mean(d::MvNormal) = d.μ
params(d::MvNormal) = (d.μ, d.Σ)
@inline partype(d::MvNormal{T}) where {T<:Real} = T
@inline partype(::MvNormal{T}) where {T<:Real} = T

var(d::MvNormal) = diag(d.Σ)
cov(d::MvNormal) = Matrix(d.Σ)
Expand Down Expand Up @@ -372,7 +372,7 @@ struct MvNormalStats <: SufficientStats
tw::Float64 # total sample weight
end

function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64})
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe move unrelated changes to a separate PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there were all relatively minor things (unused variables)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you revert non-CR changes? It seems not only unused names were removed but also some types and dispatches changed, creating slight inconsistencies and related to the open issue about type parameters in fit. IMO it would br much cleaner to avoid these additonal changes in this PR here and instead fix the dispatches (and names) in a separate PR in a consistent way.

d = size(x, 1)
n = size(x, 2)
s = vec(sum(x, dims=2))
Expand All @@ -382,7 +382,7 @@ function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64})
MvNormalStats(s, m, s2, Float64(n))
end

function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
d = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions."))
Expand Down Expand Up @@ -410,13 +410,13 @@ end
# each kind of covariance
#

fit_mle(D::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss)
fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x)
fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w)
fit_mle(::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss)
fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x)
fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w)

fit_mle(D::Type{FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw))
fit_mle(::Type{<:FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw))

function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64})
function fit_mle(::Type{FullNormal}, x::AbstractMatrix{Float64})
n = size(x, 2)
mu = vec(mean(x, dims=2))
z = x .- mu
Expand All @@ -425,7 +425,7 @@ function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64})
MvNormal(mu, PDMat(C))
end

function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function fit_mle(::Type{<:FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
m = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions"))
Expand All @@ -445,7 +445,7 @@ function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVec
MvNormal(mu, PDMat(C))
end

function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64})
function fit_mle(::Type{DiagNormal}, x::AbstractMatrix{Float64})
m = size(x, 1)
n = size(x, 2)

Expand All @@ -460,7 +460,7 @@ function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64})
MvNormal(mu, PDiagMat(va))
end

function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function fit_mle(::Type{<:DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
m = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions"))
Expand All @@ -479,7 +479,7 @@ function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVec
MvNormal(mu, PDiagMat(va))
end

function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64})
function fit_mle(::Type{IsoNormal}, x::AbstractMatrix{Float64})
m = size(x, 1)
n = size(x, 2)

Expand All @@ -495,7 +495,7 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64})
MvNormal(mu, ScalMat(m, va / (m * n)))
end

function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
function fit_mle(::Type{<:IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector)
m = size(x, 1)
n = size(x, 2)
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions"))
Expand All @@ -515,3 +515,87 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVect
end
MvNormal(mu, ScalMat(m, va / (m * sw)))
end

## Differentiation

function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector)
c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d)
sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should use rather frule_via_ad here for calling back into the AD system, in case it wants to define its own, possibly improved derivatives.

Δc0 = ChainRulesCore.unthunk(Δc0)
Δsq = ChainRulesCore.unthunk(Δsq)
return c0 - sq/2, Δc0 - Δsq/2
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's useful to add this definition. This is exactly what AD systems do anyway.

Suggested change
function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector)
c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d)
sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x)
Δc0 = ChainRulesCore.unthunk(Δc0)
Δsq = ChainRulesCore.unthunk(Δsq)
return c0 - sq/2, Δc0 - Δsq/2
end

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't cost us much to add this definition and lets us have derivatives built-in, we can also re-add specialized methods for some MvNormal if necessary

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it can be quite problematic to define derivatives that overrule the AD system if they are not needed and e.g. to generic (as possibly the case here). I've ran into multiple issues of this kind with ChainRules, which then requires e.g. packages that otherwise would just work (without even knowing about ChainRules) to use ChainRules.@opt_out or define their own rules. The rule here will catch every AbstractMvNormal and AbstractVector which can be problematic e.g. similar to TuringLang/DistributionsAD.jl#180.

So I strongly recommend not adding rules that are not needed.


function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector)
c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d)
sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, this should probably be rrule_via_ad.

function logpdf_MvNormal_pullback(dy)
dy = ChainRulesCore.unthunk(dy)
(_, ∂d_c0) = c0_pullback(dy)
∂d_c0 = ChainRulesCore.unthunk(∂d_c0)
(_, ∂d_sq, ∂x_sq) = sq_pullback(dy)
∂d_sq_v = ChainRulesCore.unthunk(∂d_sq)
∂x_sq_v::typeof(x) = ChainRulesCore.unthunk(∂x_sq)
μs::typeof(d.μ) = ∂d_sq_v.μ
Σs::Matrix{partype(d)} = ∂d_sq_v.Σ
∂d = ChainRulesCore.Tangent{typeof(d)}(;
μ = ∂d_c0.μ - 0.5 * μs,
Σ = ∂d_c0.Σ - 0.5 * Σs,
)
return ChainRulesCore.NoTangent(), ∂d, -∂x_sq_v / 2
end
return c0 - sq / 2, logpdf_MvNormal_pullback
end

function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::MvNormal)
y = mvnormal_c0(d)
Δd = ChainRulesCore.unthunk(Δd)
Δy = -dot(Δd.Σ, invcov(d)) / 2
return y, Δy
end

function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal)
y = mvnormal_c0(d)
function mvnormal_c0_pullback(dy)
dy = ChainRulesCore.unthunk(dy)
∂Σ = -dy/2 * invcov(d)
matbesancon marked this conversation as resolved.
Show resolved Hide resolved
∂d = ChainRulesCore.Tangent{typeof(d)}(μ = ChainRulesCore.ZeroTangent(), Σ = ∂Σ)
matbesancon marked this conversation as resolved.
Show resolved Hide resolved
return ChainRulesCore.NoTangent(), ∂d
end
return y, mvnormal_c0_pullback
end

function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::MvNormal, x::AbstractVector)
y = sqmahal(d, x)
(_, Δd, Δx) = dargs
Δd = ChainRulesCore.unthunk(Δd)
Δx = ChainRulesCore.unthunk(Δx)
Σinv = invcov(d)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we avoid computing the inverse?

# TODO optimize
dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ')
dx = 2 * dot(Σinv * (x - d.μ), Δx)
dμ = 2 * dot(Σinv * (d.μ - x), Δd.μ)
Δy = dΣ + dx + dμ
matbesancon marked this conversation as resolved.
Show resolved Hide resolved
return (y, Δy)
end

function ChainRulesCore.rrule(::typeof(sqmahal), d::MvNormal, x::AbstractVector)
y = sqmahal(d, x)
function sqmahal_pullback(dy)
Σinv = invcov(d)
∂x = ChainRulesCore.@thunk(begin
dy = ChainRulesCore.unthunk(dy)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
2dy * Σinv * (x - d.μ)
end)
∂d = ChainRulesCore.@thunk(begin
dy = ChainRulesCore.unthunk(dy)
cx = x - d.μ
∂μ = -2dy * Σinv * cx
∂J = dy * cx * cx'
∂Σ = - Σinv * ∂J * Σinv
ChainRulesCore.Tangent{typeof(d)}(μ = ∂μ, Σ = ∂Σ)
end)
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
end
return y, sqmahal_pullback
end
65 changes: 65 additions & 0 deletions test/mvnormal.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Tests on Multivariate Normal distributions

import PDMats
import PDMats: ScalMat, PDiagMat, PDMat
if isdefined(PDMats, :PDSparseMat)
import PDMats: PDSparseMat
Expand All @@ -9,6 +10,8 @@ using Distributions
using LinearAlgebra, Random, Test
using SparseArrays
using FillArrays
using ChainRulesCore
using ChainRulesTestUtils

###### General Testing

Expand Down Expand Up @@ -302,3 +305,65 @@ end
x = rand(d)
@test logpdf(d, x) ≈ logpdf(Normal(), x[1]) + logpdf(Normal(), x[2])
end

@testset "MvNormal differentiation rules" begin
for n in (3, 10)
for _ in 1:10
A = Symmetric(rand(n,n)) .+ 4 * Matrix(I, n, n)
@assert isposdef(A)
d = MvNormal(randn(n), A)
# make ΔΣ symmetric, such that Σ ± ΔΣ is PSD
t = 0.001 * ChainRulesTestUtils.rand_tangent(d)
t.Σ .+= t.Σ'
if eigmin(t.Σ) < 0
while eigmin(d.Σ + t.Σ) < 0
t.Σ .*= 0.8
end
end
if eigmax(t.Σ) > 0
while eigmin(d.Σ - t.Σ) < 0
t.Σ .*= 0.8
end
end
# mvnormal_c0
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t), Distributions.mvnormal_c0, d)
y_r, c0_pullback = @inferred ChainRulesCore.rrule(Distributions.mvnormal_c0, d)
@test y_r ≈ y
y2 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ + t.Σ))
@test unthunk(Δy) ≈ y2 - y atol= n * 1e-4
y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ))
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
(_, ∇c0) = @inferred c0_pullback(1.0)
∇c0 = ChainRulesCore.unthunk(∇c0)
@test dot(∇c0.Σ, t.Σ) ≈ y2 - y atol = n * 1e-4
@test dot(∇c0.Σ, t.Σ) ≈ y - y3 atol = n * 1e-4
# sqmahal
x = randn(n)
Δx = 0.0001 * randn(n)
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), sqmahal, d, x)
(yr, sqmahal_pullback) = @inferred ChainRulesCore.rrule(sqmahal, d, x)
(_, ∇s_d, ∇s_x) = @inferred sqmahal_pullback(1.0)
∇s_d = ChainRulesCore.unthunk(∇s_d)
∇s_x = ChainRulesCore.unthunk(∇s_x)
@test yr ≈ y
y2 = Distributions.sqmahal(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx)
y3 = Distributions.sqmahal(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx)
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4
# _logpdf
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), Distributions._logpdf, d, x)
(yr, logpdf_MvNormal_pullback) = @inferred ChainRulesCore.rrule(Distributions._logpdf, d, x)
@test y ≈ yr
(_, ∇s_d, ∇s_x) = @inferred logpdf_MvNormal_pullback(1.0)

y2 = Distributions._logpdf(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx)
y3 = Distributions._logpdf(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx)
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4
end
end
end