Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Differentiating mvnormal #1554

Open
wants to merge 31 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
789ad0a
test and code for frule
matbesancon Apr 18, 2022
24861ca
added tests
matbesancon Apr 24, 2022
6e6c029
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Apr 25, 2022
2bcc217
diff on MvNormal only for now
matbesancon Apr 25, 2022
1e9571b
unthunk when only one element
matbesancon May 23, 2022
ac0995e
Update src/multivariate/mvnormal.jl
matbesancon May 24, 2022
522c13e
Update src/multivariate/mvnormal.jl
matbesancon May 24, 2022
6529a4a
no backing
matbesancon May 24, 2022
4e4d982
conflict
matbesancon May 24, 2022
d398140
fix inference
matbesancon May 24, 2022
661de16
unthunk common computation
matbesancon May 26, 2022
3c5007c
unthunk common computation
matbesancon May 26, 2022
1f18e67
remove rules for logpdf
matbesancon May 27, 2022
b60147d
revert changes
matbesancon Jul 30, 2022
375cad8
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
c34a3ed
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
2d96680
revert changes
matbesancon Jul 30, 2022
f32c223
Merge branch 'cr-mvnormal' of github.com:JuliaStats/Distributions.jl …
matbesancon Jul 30, 2022
b225298
Update src/multivariate/mvnormal.jl
matbesancon Jul 30, 2022
cf1242a
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Jul 31, 2022
e2846e8
Merge branch 'cr-mvnormal' of github.com:JuliaStats/Distributions.jl …
matbesancon Jul 31, 2022
6bc85da
avoid materializing Matrix
matbesancon Jul 31, 2022
e303416
revert
matbesancon Jul 31, 2022
e21506a
fix revert
matbesancon Jul 31, 2022
fd272ae
no alloc
matbesancon Jul 31, 2022
8b7d451
revert fdiff
matbesancon Jul 31, 2022
53601fe
Update src/multivariate/mvnormal.jl
matbesancon Jul 31, 2022
2c75061
fix op
matbesancon Jul 31, 2022
6d88e4d
fix op bis, revert to invcov
matbesancon Jul 31, 2022
1f00a3b
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Aug 20, 2022
8ebf419
Merge branch 'master' of github.com:JuliaStats/Distributions.jl into …
matbesancon Oct 16, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -515,3 +515,56 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVect
end
MvNormal(mu, ScalMat(m, va / (m * sw)))
end

## Differentiation

function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::MvNormal)
y = mvnormal_c0(d)
Δd = ChainRulesCore.unthunk(Δd)
Δy = -dot(Δd.Σ, invcov(d)) / 2
return y, Δy
end

function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal)
y = mvnormal_c0(d)
function mvnormal_c0_pullback(dy)
dy = ChainRulesCore.unthunk(dy)
∂Σ = (dy / (-2)) * invcov(d)
∂d = ChainRulesCore.Tangent{typeof(d)}(Σ = ∂Σ)
return ChainRulesCore.NoTangent(), ∂d
end
return y, mvnormal_c0_pullback
end

function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::MvNormal, x::AbstractVector)
y = sqmahal(d, x)
(_, Δd, Δx) = dargs
Δd = ChainRulesCore.unthunk(Δd)
Δx = ChainRulesCore.unthunk(Δx)
Σinv = inv(_cov(d))
# TODO optimize
dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ')
dx = 2 * dot(Σinv * (x - d.μ), Δx)
dμ = 2 * dot(Σinv * (d.μ - x), Δd.μ)
Δy = dΣ + dx + dμ
matbesancon marked this conversation as resolved.
Show resolved Hide resolved
return (y, Δy)
end

function ChainRulesCore.rrule(::typeof(sqmahal), d::MvNormal, x::AbstractVector)
y = sqmahal(d, x)
Σ = _cov(d)
cx = x - d.μ
z = Σ \ cx
function sqmahal_pullback(_dy)
dy = ChainRulesCore.unthunk(_dy)
∂x = 2 * dy * z
∂d = ChainRulesCore.@thunk(begin
∂μ = -∂x
∂J = dy * cx * cx'
∂Σ = - (Σ \ ∂J) / Σ
ChainRulesCore.Tangent{typeof(d)}(μ = ∂μ, Σ = ∂Σ)
end)
return (ChainRulesCore.NoTangent(), ∂d, ∂x)
end
return y, sqmahal_pullback
end
53 changes: 53 additions & 0 deletions test/multivariate/mvnormal.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Tests on Multivariate Normal distributions

import PDMats
import PDMats: ScalMat, PDiagMat, PDMat
if isdefined(PDMats, :PDSparseMat)
import PDMats: PDSparseMat
Expand All @@ -9,6 +10,8 @@ using Distributions
using LinearAlgebra, Random, Test
using SparseArrays
using FillArrays
using ChainRulesCore
using ChainRulesTestUtils

###### General Testing

Expand Down Expand Up @@ -302,3 +305,53 @@ end
x = rand(d)
@test logpdf(d, x) ≈ logpdf(Normal(), x[1]) + logpdf(Normal(), x[2])
end

@testset "MvNormal differentiation rules" begin
for n in (3, 10)
for _ in 1:10
A = Symmetric(rand(n,n)) .+ 4 * Matrix(I, n, n)
@assert isposdef(A)
d = MvNormal(randn(n), A)
# make ΔΣ symmetric, such that Σ ± ΔΣ is PSD
t = 0.001 * ChainRulesTestUtils.rand_tangent(d)
t.Σ .+= t.Σ'
if eigmin(t.Σ) < 0
while eigmin(d.Σ + t.Σ) < 0
t.Σ .*= 0.8
end
end
if eigmax(t.Σ) > 0
while eigmin(d.Σ - t.Σ) < 0
t.Σ .*= 0.8
end
end
# mvnormal_c0
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t), Distributions.mvnormal_c0, d)
y_r, c0_pullback = @inferred ChainRulesCore.rrule(Distributions.mvnormal_c0, d)
@test y_r ≈ y
y2 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ + t.Σ))
@test unthunk(Δy) ≈ y2 - y atol= n * 1e-4
y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ))
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
(_, ∇c0) = @inferred c0_pullback(1.0)
∇c0 = ChainRulesCore.unthunk(∇c0)
@test dot(∇c0.Σ, t.Σ) ≈ y2 - y atol = n * 1e-4
@test dot(∇c0.Σ, t.Σ) ≈ y - y3 atol = n * 1e-4
# sqmahal
x = randn(n)
Δx = 0.0001 * randn(n)
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), sqmahal, d, x)
(yr, sqmahal_pullback) = @inferred ChainRulesCore.rrule(sqmahal, d, x)
(_, ∇s_d, ∇s_x) = @inferred sqmahal_pullback(1.0)
∇s_d = ChainRulesCore.unthunk(∇s_d)
∇s_x = ChainRulesCore.unthunk(∇s_x)
@test yr ≈ y
y2 = Distributions.sqmahal(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx)
y3 = Distributions.sqmahal(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx)
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4
end
end
end