
Differentiating mvnormal #1554

Open · wants to merge 31 commits into master
Conversation

matbesancon (Member)
No description provided.

codecov-commenter commented May 23, 2022

Codecov Report

Base: 85.95% // Head: 86.02% // This PR increases project coverage by +0.06% 🎉

Coverage data is based on head (8ebf419) compared to base (a31ebc4).
Patch coverage: 100.00% of modified lines in pull request are covered.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1554      +/-   ##
==========================================
+ Coverage   85.95%   86.02%   +0.06%     
==========================================
  Files         129      129              
  Lines        8105     8144      +39     
==========================================
+ Hits         6967     7006      +39     
  Misses       1138     1138              
Impacted Files Coverage Δ
src/multivariate/mvnormal.jl 80.93% <100.00%> (+3.41%) ⬆️


@devmotion (Member) left a comment:

Are the rules for sqmahal etc. needed? Would it maybe be sufficient to just make sure that PDMats works and is optimized for ChainRules?

Comment on lines 522 to 523
c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d)
sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x)
devmotion (Member):

I think we should rather use frule_via_ad here to call back into the AD system, in case it wants to define its own, possibly improved derivatives.
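
A rough sketch of what this could look like, assuming a forwards-capable RuleConfig (the plumbing below is illustrative, not the PR's final code):

```julia
using ChainRulesCore
using Distributions
using Distributions: AbstractMvNormal

# Sketch only: defer the inner derivatives to the AD system via
# frule_via_ad instead of hard-coded frule calls, so an AD backend
# can substitute its own, possibly improved derivatives.
function ChainRulesCore.frule(
    config::RuleConfig{>:HasForwardsMode},
    (_, Δd, Δx),
    ::typeof(Distributions._logpdf),
    d::AbstractMvNormal,
    x::AbstractVector,
)
    c0, Δc0 = frule_via_ad(config, (NoTangent(), Δd), Distributions.mvnormal_c0, d)
    sq, Δsq = frule_via_ad(config, (NoTangent(), Δd, Δx), Distributions.sqmahal, d, x)
    return c0 - sq / 2, unthunk(Δc0) - unthunk(Δsq) / 2
end
```

Only AD systems whose RuleConfig advertises HasForwardsMode would pick up such a rule.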

Comment on lines 524 to 528
return c0 - sq/2, ChainRulesCore.@thunk(begin
Δc0 = ChainRulesCore.unthunk(Δc0)
Δsq = ChainRulesCore.unthunk(Δsq)
Δc0 - Δsq/2
end)
devmotion (Member):

Derivatives should not be thunked if there's only one of them.

Comment on lines 532 to 533
c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d)
sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x)
devmotion (Member):

Same here, this should probably be rrule_via_ad.
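
A sketch of the reverse-mode analogue, assuming a reverse-capable RuleConfig (tangent handling simplified for illustration):

```julia
using ChainRulesCore
using Distributions
using Distributions: AbstractMvNormal

# Sketch only: obtain the inner pullbacks from the AD system via
# rrule_via_ad rather than hard-coded rrule calls.
function ChainRulesCore.rrule(
    config::RuleConfig{>:HasReverseMode},
    ::typeof(Distributions._logpdf),
    d::AbstractMvNormal,
    x::AbstractVector,
)
    c0, c0_pullback = rrule_via_ad(config, Distributions.mvnormal_c0, d)
    sq, sq_pullback = rrule_via_ad(config, Distributions.sqmahal, d, x)
    function logpdf_pullback(dy)
        ∂d_c0 = unthunk(c0_pullback(dy)[2])
        _, ∂d_sq, ∂x_sq = sq_pullback(dy)
        ∂d_sq = unthunk(∂d_sq)
        # Combine the two distribution tangents: logpdf = c0 - sq / 2.
        ∂d = Tangent{typeof(d)}(; μ=∂d_c0.μ - ∂d_sq.μ / 2, Σ=∂d_c0.Σ - ∂d_sq.Σ / 2)
        return NoTangent(), ∂d, -unthunk(∂x_sq) / 2
    end
    return c0 - sq / 2, logpdf_pullback
end
```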

src/multivariate/mvnormal.jl (outdated; resolved comments hidden)
Comment on lines 553 to 556
Δy = ChainRulesCore.@thunk(begin
Δd = ChainRulesCore.unthunk(Δd)
-dot(Δd.Σ, invcov(d)) / 2
end)
devmotion (Member):

Same here, no thunks.

Suggested change
Δy = ChainRulesCore.@thunk(begin
Δd = ChainRulesCore.unthunk(Δd)
-dot(Δd.Σ, invcov(d)) / 2
end)
Δy = -dot(Δd.Σ, invcov(d)) / 2

Comment on lines 563 to 567
∂d = ChainRulesCore.@thunk(begin
dy = ChainRulesCore.unthunk(dy)
∂Σ = -dy/2 * invcov(d)
ChainRulesCore.Tangent{typeof(d)}(μ = ChainRulesCore.ZeroTangent(), Σ = ∂Σ)
end)
devmotion (Member):

No thunk 🙂

src/multivariate/mvnormal.jl (outdated; resolved comments hidden)
test/mvnormal.jl (outdated)
Comment on lines 329 to 368
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t), Distributions.mvnormal_c0, d)
y_r, c0_pullback = @inferred ChainRulesCore.rrule(Distributions.mvnormal_c0, d)
@test y_r ≈ y
y2 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ + t.Σ))
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4
y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ))
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
(_, ∇c0) = c0_pullback(1.0)
∇c0 = ChainRulesCore.unthunk(∇c0)
@test dot(∇c0.Σ, t.Σ) ≈ y2 - y atol = n * 1e-4
@test dot(∇c0.Σ, t.Σ) ≈ y - y3 atol = n * 1e-4
# sqmahal
x = randn(n)
Δx = 0.0001 * randn(n)
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), sqmahal, d, x)
(yr, sqmahal_pullback) = @inferred ChainRulesCore.rrule(sqmahal, d, x)
(_, ∇s_d, ∇s_x) = @inferred sqmahal_pullback(1.0)
∇s_d = ChainRulesCore.unthunk(∇s_d)
∇s_x = ChainRulesCore.unthunk(∇s_x)
@test yr ≈ y
y2 = Distributions.sqmahal(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx)
y3 = Distributions.sqmahal(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx)
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4
# _logpdf
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), Distributions._logpdf, d, x)
(yr, logpdf_MvNormal_pullback) = @inferred ChainRulesCore.rrule(Distributions._logpdf, d, x)
@test y ≈ yr
# inference broken
# (_, ∇s_d, ∇s_x) = @inferred logpdf_MvNormal_pullback(1.0)
(_, ∇s_d, ∇s_x) = logpdf_MvNormal_pullback(1.0)

y2 = Distributions._logpdf(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx)
y3 = Distributions._logpdf(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx)
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4
devmotion (Member):

This seems very complicated. Ideally we should just use test_rrule and test_frule.

matbesancon (Member Author):

This is nontrivial in the case of the perturbation of the covariance matrix.

devmotion (Member):

Yeah, sometimes one has to add custom tests. But the standard should be test_frule and test_rrule, since they check multiple other parts of the CR interface in addition to numerical accuracy and type inference. And even for cases that are problematic for finite differencing (e.g. due to singularities or domain constraints), it is sometimes possible to use test_frule and test_rrule by specifying custom finite differencing methods, as in #1555.
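
For reference, such a test could look roughly like this (the max_range keyword for central_fdm is an assumption about the FiniteDifferences API; the idea is to keep perturbations small enough that the covariance stays positive definite):

```julia
using ChainRulesTestUtils
using FiniteDifferences
using Distributions
using LinearAlgebra

# Sketch: test_frule/test_rrule check numerical accuracy, type
# inference, and the rest of the ChainRules interface in one call.
A = randn(3, 3)
d = MvNormal(randn(3), A * A' + I)  # positive-definite covariance
x = randn(3)

# Custom finite differencing with a restricted range, cf. #1555.
fdm = central_fdm(5, 1; max_range=1e-3)
test_frule(Distributions.sqmahal, d, x; fdm=fdm)
test_rrule(Distributions.sqmahal, d, x; fdm=fdm)
```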

Comment on lines 521 to 527
function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector)
c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d)
sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x)
Δc0 = ChainRulesCore.unthunk(Δc0)
Δsq = ChainRulesCore.unthunk(Δsq)
return c0 - sq/2, Δc0 - Δsq/2
end
devmotion (Member):

I don't think it's useful to add this definition. This is exactly what AD systems do anyway.

Suggested change
function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector)
c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d)
sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x)
Δc0 = ChainRulesCore.unthunk(Δc0)
Δsq = ChainRulesCore.unthunk(Δsq)
return c0 - sq/2, Δc0 - Δsq/2
end

matbesancon (Member Author):

It doesn't cost us much to add this definition, and it lets us have derivatives built in; we can also re-add specialized methods for some MvNormal types if necessary.

devmotion (Member):

Actually, it can be quite problematic to define derivatives that overrule the AD system if they are not needed and e.g. too generic (as is possibly the case here). I've run into multiple issues of this kind with ChainRules, which then require packages that would otherwise just work (without even knowing about ChainRules) to use ChainRulesCore.@opt_out or to define their own rules. The rule here will catch every AbstractMvNormal and AbstractVector, which can be problematic, similar to TuringLang/DistributionsAD.jl#180.

So I strongly recommend not adding rules that are not needed.
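
For context, the opt-out mechanism referred to here: a downstream package with its own distribution type (MyMvNormal is hypothetical) would have to write something like

```julia
using ChainRulesCore
using Distributions
using Distributions: AbstractMvNormal

# Hypothetical downstream type that would otherwise be caught by a
# generic rule defined on AbstractMvNormal.
struct MyMvNormal <: AbstractMvNormal end

# Opt out of the generic rule so AD differentiates _logpdf directly.
ChainRulesCore.@opt_out ChainRulesCore.rrule(
    ::typeof(Distributions._logpdf), ::MyMvNormal, ::AbstractVector
)
```

just to restore plain AD behavior for its own type.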

Comment on lines 529 to 547
function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector)
c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d)
sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x)
function logpdf_MvNormal_pullback(dy)
dy = ChainRulesCore.unthunk(dy)
(_, ∂d_c0) = c0_pullback(dy)
∂d_c0 = ChainRulesCore.unthunk(∂d_c0)
(_, ∂d_sq, ∂x_sq) = sq_pullback(dy)
∂d_sq = ChainRulesCore.unthunk(∂d_sq)
∂x_sq = ChainRulesCore.unthunk(∂x_sq)
backing = NamedTuple{(:μ, :Σ), Tuple{typeof(∂d_sq.μ), typeof(∂d_sq.Σ)}}((
(∂d_c0.μ - 0.5 * ∂d_sq.μ),
(∂d_c0.Σ - 0.5 * ∂d_sq.Σ),
))
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
return ChainRulesCore.NoTangent(), ∂d, ∂x_sq / (-2)
end
return c0 - sq / 2, logpdf_MvNormal_pullback
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, it seems this is exactly what AD does if no rule is defined.

Suggested change
function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector)
c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d)
sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x)
function logpdf_MvNormal_pullback(dy)
dy = ChainRulesCore.unthunk(dy)
(_, ∂d_c0) = c0_pullback(dy)
∂d_c0 = ChainRulesCore.unthunk(∂d_c0)
(_, ∂d_sq, ∂x_sq) = sq_pullback(dy)
∂d_sq = ChainRulesCore.unthunk(∂d_sq)
∂x_sq = ChainRulesCore.unthunk(∂x_sq)
backing = NamedTuple{(:μ, :Σ), Tuple{typeof(∂d_sq.μ), typeof(∂d_sq.Σ)}}((
(∂d_c0.μ - 0.5 * ∂d_sq.μ),
(∂d_c0.Σ - 0.5 * ∂d_sq.Σ),
))
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing)
return ChainRulesCore.NoTangent(), ∂d, ∂x_sq / (-2)
end
return c0 - sq / 2, logpdf_MvNormal_pullback
end

@@ -372,7 +372,7 @@ struct MvNormalStats <: SufficientStats
tw::Float64 # total sample weight
end

function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64})
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64})
devmotion (Member):

Maybe move unrelated changes to a separate PR?

matbesancon (Member Author):

They were all relatively minor things (unused variables).

matbesancon (Member Author):

│   %34 = Base.getproperty(∂d_sq::Tangent{FullNormal}, :μ)::Any
│   %35 = (0.5 * %34)::Any
│   %36 = (%33 - %35)::Any
│   %37 = Base.getproperty(∂d_c0, :Σ)::Matrix{Float64}
│   %38 = Base.getproperty(∂d_sq::Tangent{FullNormal}, :Σ)::Any

One of the instability issues seems to be that the types of the members of Tangent{FullNormal} are not inferable.

matbesancon (Member Author):

Alright, I could fix inference, but it required some assumptions that might be a bit restrictive.

@matbesancon matbesancon marked this pull request as draft May 28, 2022 14:53
@matbesancon matbesancon requested a review from devmotion July 28, 2022 20:18
@matbesancon matbesancon marked this pull request as ready for review July 28, 2022 20:18
@matbesancon
Copy link
Member Author

@devmotion I think everything discussed has been adapted/validated

@@ -372,7 +372,7 @@ struct MvNormalStats <: SufficientStats
tw::Float64 # total sample weight
end

function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64})
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64})
devmotion (Member):

Can you revert the non-CR changes? It seems not only unused names were removed but also some types and dispatches were changed, creating slight inconsistencies related to the open issue about type parameters in fit. IMO it would be much cleaner to avoid these additional changes in this PR and instead fix the dispatches (and names) in a separate PR in a consistent way.

src/multivariate/mvnormal.jl (outdated; resolved comments hidden)
(_, Δd, Δx) = dargs
Δd = ChainRulesCore.unthunk(Δd)
Δx = ChainRulesCore.unthunk(Δx)
Σinv = invcov(d)
devmotion (Member):

Could we avoid computing the inverse?

matbesancon (Member Author):

Not computing the inverse at all would be a bit of a hassle here, since it's reused several times in the whole function. I removed the materialization of the inverse as a dense matrix to lighten the computations; we use the inverse of the Cholesky decomposition directly.
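
The trade-off can be illustrated with a standalone sketch (not the PR code):

```julia
using LinearAlgebra

A = randn(4, 4)
Σ = Symmetric(A * A' + I)  # positive-definite covariance
μ, x = randn(4), randn(4)
C = cholesky(Σ)

# Applying Σ⁻¹ to a vector needs only two triangular solves, so the
# sqmahal gradient with respect to x avoids a dense inverse:
v = C \ (x - μ)  # computes Σ⁻¹ (x - μ)

# But the derivative of the normalization constant with respect to Σ
# is a scaled Σ⁻¹ itself, so the dense inverse must be materialized:
∂Σ = -inv(C) / 2
```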

matbesancon (Member Author):

That change here (b225298) was somehow incorrect.

matbesancon (Member Author):

Reverted the changes; the inverse covariance matrix is needed here. There is no way around it, since the derivative is just a scaled invcov.

matbesancon (Member Author):

Are the rules for sqmahal etc. needed? Would it maybe be sufficient to just make sure that PDMats works and is optimized for ChainRules?

I would say not, since the operations done on top are non-trivial and could be costly.

matbesancon (Member Author):

ping @devmotion for another round of review

Labels: none yet · Projects: none yet · 3 participants