-
Notifications
You must be signed in to change notification settings - Fork 421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Differentiating mvnormal #1554
base: master
Are you sure you want to change the base?
Differentiating mvnormal #1554
Changes from 4 commits
789ad0a
24861ca
6e6c029
2bcc217
1e9571b
ac0995e
522c13e
6529a4a
4e4d982
d398140
661de16
3c5007c
1f18e67
b60147d
375cad8
c34a3ed
2d96680
f32c223
b225298
cf1242a
e2846e8
6bc85da
e303416
e21506a
fd272ae
8b7d451
53601fe
2c75061
6d88e4d
1f00a3b
8ebf419
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -253,7 +253,7 @@ Base.show(io::IO, d::MvNormal) = | |||||||||||
length(d::MvNormal) = length(d.μ) | ||||||||||||
mean(d::MvNormal) = d.μ | ||||||||||||
params(d::MvNormal) = (d.μ, d.Σ) | ||||||||||||
@inline partype(d::MvNormal{T}) where {T<:Real} = T | ||||||||||||
@inline partype(::MvNormal{T}) where {T<:Real} = T | ||||||||||||
|
||||||||||||
var(d::MvNormal) = diag(d.Σ) | ||||||||||||
cov(d::MvNormal) = Matrix(d.Σ) | ||||||||||||
|
@@ -372,7 +372,7 @@ struct MvNormalStats <: SufficientStats | |||||||||||
tw::Float64 # total sample weight | ||||||||||||
end | ||||||||||||
|
||||||||||||
function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you revert non-CR changes? It seems not only unused names were removed but also some types and dispatches changed, creating slight inconsistencies and related to the open issue about type parameters in |
||||||||||||
d = size(x, 1) | ||||||||||||
n = size(x, 2) | ||||||||||||
s = vec(sum(x, dims=2)) | ||||||||||||
|
@@ -382,7 +382,7 @@ function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}) | |||||||||||
MvNormalStats(s, m, s2, Float64(n)) | ||||||||||||
end | ||||||||||||
|
||||||||||||
function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
d = size(x, 1) | ||||||||||||
n = size(x, 2) | ||||||||||||
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions.")) | ||||||||||||
|
@@ -410,13 +410,13 @@ end | |||||||||||
# each kind of covariance | ||||||||||||
# | ||||||||||||
|
||||||||||||
fit_mle(D::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss) | ||||||||||||
fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x) | ||||||||||||
fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w) | ||||||||||||
fit_mle(::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss) | ||||||||||||
fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x) | ||||||||||||
fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w) | ||||||||||||
|
||||||||||||
fit_mle(D::Type{FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw)) | ||||||||||||
fit_mle(::Type{<:FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw)) | ||||||||||||
|
||||||||||||
function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
function fit_mle(::Type{FullNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
n = size(x, 2) | ||||||||||||
mu = vec(mean(x, dims=2)) | ||||||||||||
z = x .- mu | ||||||||||||
|
@@ -425,7 +425,7 @@ function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}) | |||||||||||
MvNormal(mu, PDMat(C)) | ||||||||||||
end | ||||||||||||
|
||||||||||||
function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
function fit_mle(::Type{<:FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
m = size(x, 1) | ||||||||||||
n = size(x, 2) | ||||||||||||
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) | ||||||||||||
|
@@ -445,7 +445,7 @@ function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVec | |||||||||||
MvNormal(mu, PDMat(C)) | ||||||||||||
end | ||||||||||||
|
||||||||||||
function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
function fit_mle(::Type{DiagNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
m = size(x, 1) | ||||||||||||
n = size(x, 2) | ||||||||||||
|
||||||||||||
|
@@ -460,7 +460,7 @@ function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}) | |||||||||||
MvNormal(mu, PDiagMat(va)) | ||||||||||||
end | ||||||||||||
|
||||||||||||
function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
function fit_mle(::Type{<:DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
m = size(x, 1) | ||||||||||||
n = size(x, 2) | ||||||||||||
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) | ||||||||||||
|
@@ -479,7 +479,7 @@ function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVec | |||||||||||
MvNormal(mu, PDiagMat(va)) | ||||||||||||
end | ||||||||||||
|
||||||||||||
function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
function fit_mle(::Type{IsoNormal}, x::AbstractMatrix{Float64}) | ||||||||||||
m = size(x, 1) | ||||||||||||
n = size(x, 2) | ||||||||||||
|
||||||||||||
|
@@ -495,7 +495,7 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}) | |||||||||||
MvNormal(mu, ScalMat(m, va / (m * n))) | ||||||||||||
end | ||||||||||||
|
||||||||||||
function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
function fit_mle(::Type{<:IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) | ||||||||||||
m = size(x, 1) | ||||||||||||
n = size(x, 2) | ||||||||||||
length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) | ||||||||||||
|
@@ -515,3 +515,95 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVect | |||||||||||
end | ||||||||||||
MvNormal(mu, ScalMat(m, va / (m * sw))) | ||||||||||||
end | ||||||||||||
|
||||||||||||
## Differentiation | ||||||||||||
|
||||||||||||
function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector) | ||||||||||||
c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d) | ||||||||||||
sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should use rather |
||||||||||||
return c0 - sq/2, ChainRulesCore.@thunk(begin | ||||||||||||
Δc0 = ChainRulesCore.unthunk(Δc0) | ||||||||||||
Δsq = ChainRulesCore.unthunk(Δsq) | ||||||||||||
Δc0 - Δsq/2 | ||||||||||||
end) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Derivatives should not be thunked if there's only one of them. |
||||||||||||
end | ||||||||||||
|
||||||||||||
function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector) | ||||||||||||
c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d) | ||||||||||||
sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, this should probably be |
||||||||||||
function logpdf_MvNormal_pullback(dy) | ||||||||||||
dy = ChainRulesCore.unthunk(dy) | ||||||||||||
(_, ∂d_c0) = c0_pullback(dy) | ||||||||||||
∂d_c0 = ChainRulesCore.unthunk(∂d_c0) | ||||||||||||
(_, ∂d_sq, ∂x_sq) = sq_pullback(dy) | ||||||||||||
∂d_sq = ChainRulesCore.unthunk(∂d_sq) | ||||||||||||
∂x_sq = ChainRulesCore.unthunk(∂x_sq) | ||||||||||||
backing = NamedTuple{(:μ, :Σ), Tuple{typeof(∂d_sq.μ), typeof(∂d_sq.Σ)}}(( | ||||||||||||
(∂d_c0.μ - 0.5 * ∂d_sq.μ), | ||||||||||||
(∂d_c0.Σ - 0.5 * ∂d_sq.Σ), | ||||||||||||
)) | ||||||||||||
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) | ||||||||||||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
return ChainRulesCore.NoTangent(), ∂d, - 0.5 * ∂x_sq | ||||||||||||
matbesancon marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
end | ||||||||||||
return c0 - 0.5 * sq, logpdf_MvNormal_pullback | ||||||||||||
matbesancon marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
end | ||||||||||||
|
||||||||||||
function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::MvNormal) | ||||||||||||
y = mvnormal_c0(d) | ||||||||||||
Δy = ChainRulesCore.@thunk(begin | ||||||||||||
Δd = ChainRulesCore.unthunk(Δd) | ||||||||||||
-dot(Δd.Σ, invcov(d)) / 2 | ||||||||||||
end) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here, no thunks.
Suggested change
|
||||||||||||
return y, Δy | ||||||||||||
end | ||||||||||||
|
||||||||||||
function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal) | ||||||||||||
y = mvnormal_c0(d) | ||||||||||||
function mvnormal_c0_pullback(dy) | ||||||||||||
∂d = ChainRulesCore.@thunk(begin | ||||||||||||
dy = ChainRulesCore.unthunk(dy) | ||||||||||||
∂Σ = -dy/2 * invcov(d) | ||||||||||||
ChainRulesCore.Tangent{typeof(d)}(μ = ChainRulesCore.ZeroTangent(), Σ = ∂Σ) | ||||||||||||
end) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No thunk 🙂 |
||||||||||||
return ChainRulesCore.NoTangent(), ∂d | ||||||||||||
end | ||||||||||||
return y, mvnormal_c0_pullback | ||||||||||||
end | ||||||||||||
|
||||||||||||
function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::MvNormal, x::AbstractVector) | ||||||||||||
y = sqmahal(d, x) | ||||||||||||
Δy = ChainRulesCore.@thunk(begin | ||||||||||||
(_, Δd, Δx) = dargs | ||||||||||||
Δd = ChainRulesCore.unthunk(Δd) | ||||||||||||
Δx = ChainRulesCore.unthunk(Δx) | ||||||||||||
Σinv = invcov(d) | ||||||||||||
# TODO optimize | ||||||||||||
dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ') | ||||||||||||
dx = 2 * dot(Σinv * (x - d.μ), Δx) | ||||||||||||
dμ = 2 * dot(Σinv * (d.μ - x), Δd.μ) | ||||||||||||
dΣ + dx + dμ | ||||||||||||
end) | ||||||||||||
matbesancon marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
return (y, Δy) | ||||||||||||
end | ||||||||||||
|
||||||||||||
function ChainRulesCore.rrule(::typeof(sqmahal), d::MvNormal, x::AbstractVector) | ||||||||||||
y = sqmahal(d, x) | ||||||||||||
function sqmahal_pullback(dy) | ||||||||||||
∂x = ChainRulesCore.@thunk(begin | ||||||||||||
dy = ChainRulesCore.unthunk(dy) | ||||||||||||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
Σinv = invcov(d) | ||||||||||||
matbesancon marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
2dy * Σinv * (x - d.μ) | ||||||||||||
end) | ||||||||||||
∂d = ChainRulesCore.@thunk(begin | ||||||||||||
dy = ChainRulesCore.unthunk(dy) | ||||||||||||
Σinv = invcov(d) | ||||||||||||
cx = x - d.μ | ||||||||||||
∂μ = -2dy * Σinv * cx | ||||||||||||
∂J = dy * cx * cx' | ||||||||||||
∂Σ = - Σinv * ∂J * Σinv | ||||||||||||
ChainRulesCore.Tangent{typeof(d)}(μ = ∂μ, Σ = ∂Σ) | ||||||||||||
end) | ||||||||||||
return (ChainRulesCore.NoTangent(), ∂d, ∂x) | ||||||||||||
end | ||||||||||||
return y, sqmahal_pullback | ||||||||||||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Tests on Multivariate Normal distributions | ||
|
||
import PDMats | ||
import PDMats: ScalMat, PDiagMat, PDMat | ||
if isdefined(PDMats, :PDSparseMat) | ||
import PDMats: PDSparseMat | ||
|
@@ -9,6 +10,8 @@ using Distributions | |
using LinearAlgebra, Random, Test | ||
using SparseArrays | ||
using FillArrays | ||
using ChainRulesCore | ||
using ChainRulesTestUtils | ||
|
||
###### General Testing | ||
|
||
|
@@ -302,3 +305,67 @@ end | |
x = rand(d) | ||
@test logpdf(d, x) ≈ logpdf(Normal(), x[1]) + logpdf(Normal(), x[2]) | ||
end | ||
|
||
@testset "MvNormal differentiation rules" begin | ||
for n in (3, 10) | ||
for _ in 1:10 | ||
A = Symmetric(rand(n,n)) .+ 4 * Matrix(I, n, n) | ||
@assert isposdef(A) | ||
d = MvNormal(randn(n), A) | ||
# make ΔΣ symmetric, such that Σ ± ΔΣ is PSD | ||
t = 0.001 * ChainRulesTestUtils.rand_tangent(d) | ||
t.Σ .+= t.Σ' | ||
if eigmin(t.Σ) < 0 | ||
while eigmin(d.Σ + t.Σ) < 0 | ||
t.Σ .*= 0.8 | ||
end | ||
end | ||
if eigmax(t.Σ) > 0 | ||
while eigmin(d.Σ - t.Σ) < 0 | ||
t.Σ .*= 0.8 | ||
end | ||
end | ||
# mvnormal_c0 | ||
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t), Distributions.mvnormal_c0, d) | ||
y_r, c0_pullback = @inferred ChainRulesCore.rrule(Distributions.mvnormal_c0, d) | ||
@test y_r ≈ y | ||
y2 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ + t.Σ)) | ||
@test unthunk(Δy) ≈ y2 - y atol= n * 1e-4 | ||
y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ)) | ||
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 | ||
(_, ∇c0) = c0_pullback(1.0) | ||
∇c0 = ChainRulesCore.unthunk(∇c0) | ||
@test dot(∇c0.Σ, t.Σ) ≈ y2 - y atol = n * 1e-4 | ||
@test dot(∇c0.Σ, t.Σ) ≈ y - y3 atol = n * 1e-4 | ||
# sqmahal | ||
x = randn(n) | ||
Δx = 0.0001 * randn(n) | ||
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), sqmahal, d, x) | ||
(yr, sqmahal_pullback) = @inferred ChainRulesCore.rrule(sqmahal, d, x) | ||
(_, ∇s_d, ∇s_x) = @inferred sqmahal_pullback(1.0) | ||
∇s_d = ChainRulesCore.unthunk(∇s_d) | ||
∇s_x = ChainRulesCore.unthunk(∇s_x) | ||
@test yr ≈ y | ||
y2 = Distributions.sqmahal(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) | ||
y3 = Distributions.sqmahal(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) | ||
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 | ||
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 | ||
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4 | ||
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4 | ||
# _logpdf | ||
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), Distributions._logpdf, d, x) | ||
(yr, logpdf_MvNormal_pullback) = @inferred ChainRulesCore.rrule(Distributions._logpdf, d, x) | ||
@test y ≈ yr | ||
# inference broken | ||
# (_, ∇s_d, ∇s_x) = @inferred logpdf_MvNormal_pullback(1.0) | ||
(_, ∇s_d, ∇s_x) = logpdf_MvNormal_pullback(1.0) | ||
|
||
y2 = Distributions._logpdf(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) | ||
y3 = Distributions._logpdf(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) | ||
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 | ||
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 | ||
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4 | ||
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This seems very complicated. Ideally we should just use There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is nontrivial in the case of the perturbation of the covariance matrix. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, sometimes one has to add custom tests. But the standard should be |
||
end | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe move unrelated changes to a separate PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These were all relatively minor things (unused variables).