-
Notifications
You must be signed in to change notification settings - Fork 421
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Differentiating mvnormal #1554
base: master
Are you sure you want to change the base?
Differentiating mvnormal #1554
Conversation
Codecov ReportBase: 85.95% // Head: 86.02% // Increases project coverage by
Additional details and impacted files@@ Coverage Diff @@
## master #1554 +/- ##
==========================================
+ Coverage 85.95% 86.02% +0.06%
==========================================
Files 129 129
Lines 8105 8144 +39
==========================================
+ Hits 6967 7006 +39
Misses 1138 1138
Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report at Codecov. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are the rules for sqmahal
etc. needed? Would it maybe be sufficient to just make sure that PDMats works and is optimized for ChainRules?
src/multivariate/mvnormal.jl
Outdated
c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d) | ||
sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should use rather frule_via_ad
here for calling back into the AD system, in case it wants to define its own, possibly improved derivatives.
src/multivariate/mvnormal.jl
Outdated
return c0 - sq/2, ChainRulesCore.@thunk(begin | ||
Δc0 = ChainRulesCore.unthunk(Δc0) | ||
Δsq = ChainRulesCore.unthunk(Δsq) | ||
Δc0 - Δsq/2 | ||
end) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Derivatives should not be thunked if there's only one of them.
src/multivariate/mvnormal.jl
Outdated
c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d) | ||
sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, this should probably be rrule_via_ad
.
src/multivariate/mvnormal.jl
Outdated
Δy = ChainRulesCore.@thunk(begin | ||
Δd = ChainRulesCore.unthunk(Δd) | ||
-dot(Δd.Σ, invcov(d)) / 2 | ||
end) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, no thunks.
Δy = ChainRulesCore.@thunk(begin | |
Δd = ChainRulesCore.unthunk(Δd) | |
-dot(Δd.Σ, invcov(d)) / 2 | |
end) | |
Δy = -dot(Δd.Σ, invcov(d)) / 2 |
src/multivariate/mvnormal.jl
Outdated
∂d = ChainRulesCore.@thunk(begin | ||
dy = ChainRulesCore.unthunk(dy) | ||
∂Σ = -dy/2 * invcov(d) | ||
ChainRulesCore.Tangent{typeof(d)}(μ = ChainRulesCore.ZeroTangent(), Σ = ∂Σ) | ||
end) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No thunk 🙂
test/mvnormal.jl
Outdated
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t), Distributions.mvnormal_c0, d) | ||
y_r, c0_pullback = @inferred ChainRulesCore.rrule(Distributions.mvnormal_c0, d) | ||
@test y_r ≈ y | ||
y2 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ + t.Σ)) | ||
@test unthunk(Δy) ≈ y2 - y atol= n * 1e-4 | ||
y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ)) | ||
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 | ||
(_, ∇c0) = c0_pullback(1.0) | ||
∇c0 = ChainRulesCore.unthunk(∇c0) | ||
@test dot(∇c0.Σ, t.Σ) ≈ y2 - y atol = n * 1e-4 | ||
@test dot(∇c0.Σ, t.Σ) ≈ y - y3 atol = n * 1e-4 | ||
# sqmahal | ||
x = randn(n) | ||
Δx = 0.0001 * randn(n) | ||
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), sqmahal, d, x) | ||
(yr, sqmahal_pullback) = @inferred ChainRulesCore.rrule(sqmahal, d, x) | ||
(_, ∇s_d, ∇s_x) = @inferred sqmahal_pullback(1.0) | ||
∇s_d = ChainRulesCore.unthunk(∇s_d) | ||
∇s_x = ChainRulesCore.unthunk(∇s_x) | ||
@test yr ≈ y | ||
y2 = Distributions.sqmahal(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) | ||
y3 = Distributions.sqmahal(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) | ||
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 | ||
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 | ||
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4 | ||
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4 | ||
# _logpdf | ||
(y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), Distributions._logpdf, d, x) | ||
(yr, logpdf_MvNormal_pullback) = @inferred ChainRulesCore.rrule(Distributions._logpdf, d, x) | ||
@test y ≈ yr | ||
# inference broken | ||
# (_, ∇s_d, ∇s_x) = @inferred logpdf_MvNormal_pullback(1.0) | ||
(_, ∇s_d, ∇s_x) = logpdf_MvNormal_pullback(1.0) | ||
|
||
y2 = Distributions._logpdf(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) | ||
y3 = Distributions._logpdf(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) | ||
@test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 | ||
@test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 | ||
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4 | ||
@test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems very complicated. Ideally we should just use test_rrule
and test_frule
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is nontrivial in the case of the perturbation of the covariance matrix
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, sometimes one has to add custom tests. But the standard should be test_frule
and test_rrule
since they check multiple other parts of the CR interface as well, in addition to numerical accuracy and type inference. And even for cases that are problematic for finite differencing (e.g. due to singularities and domain constraints) it is sometimes possible to use test_frule
and test_rrule
by specifying custom finite differencing methods, such as e.g. in #1555.
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: David Widmann <[email protected]>
src/multivariate/mvnormal.jl
Outdated
function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector) | ||
c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d) | ||
sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x) | ||
Δc0 = ChainRulesCore.unthunk(Δc0) | ||
Δsq = ChainRulesCore.unthunk(Δsq) | ||
return c0 - sq/2, Δc0 - Δsq/2 | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it's useful to add this definition. This is exactly what AD systems do anyway.
function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector) | |
c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d) | |
sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x) | |
Δc0 = ChainRulesCore.unthunk(Δc0) | |
Δsq = ChainRulesCore.unthunk(Δsq) | |
return c0 - sq/2, Δc0 - Δsq/2 | |
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't cost us much to add this definition and lets us have derivatives built-in, we can also re-add specialized methods for some MvNormal if necessary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, it can be quite problematic to define derivatives that overrule the AD system if they are not needed and e.g. to generic (as possibly the case here). I've ran into multiple issues of this kind with ChainRules, which then requires e.g. packages that otherwise would just work (without even knowing about ChainRules) to use ChainRules.@opt_out
or define their own rules. The rule here will catch every AbstractMvNormal
and AbstractVector
which can be problematic e.g. similar to TuringLang/DistributionsAD.jl#180.
So I strongly recommend not adding rules that are not needed.
src/multivariate/mvnormal.jl
Outdated
function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector) | ||
c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d) | ||
sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x) | ||
function logpdf_MvNormal_pullback(dy) | ||
dy = ChainRulesCore.unthunk(dy) | ||
(_, ∂d_c0) = c0_pullback(dy) | ||
∂d_c0 = ChainRulesCore.unthunk(∂d_c0) | ||
(_, ∂d_sq, ∂x_sq) = sq_pullback(dy) | ||
∂d_sq = ChainRulesCore.unthunk(∂d_sq) | ||
∂x_sq = ChainRulesCore.unthunk(∂x_sq) | ||
backing = NamedTuple{(:μ, :Σ), Tuple{typeof(∂d_sq.μ), typeof(∂d_sq.Σ)}}(( | ||
(∂d_c0.μ - 0.5 * ∂d_sq.μ), | ||
(∂d_c0.Σ - 0.5 * ∂d_sq.Σ), | ||
)) | ||
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) | ||
return ChainRulesCore.NoTangent(), ∂d, ∂x_sq / (-2) | ||
end | ||
return c0 - sq / 2, logpdf_MvNormal_pullback | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, it seems this is exactly what AD does if no rule is defined.
function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector) | |
c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d) | |
sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x) | |
function logpdf_MvNormal_pullback(dy) | |
dy = ChainRulesCore.unthunk(dy) | |
(_, ∂d_c0) = c0_pullback(dy) | |
∂d_c0 = ChainRulesCore.unthunk(∂d_c0) | |
(_, ∂d_sq, ∂x_sq) = sq_pullback(dy) | |
∂d_sq = ChainRulesCore.unthunk(∂d_sq) | |
∂x_sq = ChainRulesCore.unthunk(∂x_sq) | |
backing = NamedTuple{(:μ, :Σ), Tuple{typeof(∂d_sq.μ), typeof(∂d_sq.Σ)}}(( | |
(∂d_c0.μ - 0.5 * ∂d_sq.μ), | |
(∂d_c0.Σ - 0.5 * ∂d_sq.Σ), | |
)) | |
∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) | |
return ChainRulesCore.NoTangent(), ∂d, ∂x_sq / (-2) | |
end | |
return c0 - sq / 2, logpdf_MvNormal_pullback | |
end |
src/multivariate/mvnormal.jl
Outdated
@@ -372,7 +372,7 @@ struct MvNormalStats <: SufficientStats | |||
tw::Float64 # total sample weight | |||
end | |||
|
|||
function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}) | |||
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe move unrelated changes to a separate PR?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there were all relatively minor things (unused variables)
One of the instability issues seems to be that the type of the members of Tangent{FullNormal} don't seem to be inferrable |
Alright I could fix inference but it took some assumptions which might be a bit restrictive |
@devmotion I think everything discussed has been adapted/validated |
src/multivariate/mvnormal.jl
Outdated
@@ -372,7 +372,7 @@ struct MvNormalStats <: SufficientStats | |||
tw::Float64 # total sample weight | |||
end | |||
|
|||
function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}) | |||
function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you revert non-CR changes? It seems not only unused names were removed but also some types and dispatches changed, creating slight inconsistencies and related to the open issue about type parameters in fit
. IMO it would br much cleaner to avoid these additonal changes in this PR here and instead fix the dispatches (and names) in a separate PR in a consistent way.
src/multivariate/mvnormal.jl
Outdated
(_, Δd, Δx) = dargs | ||
Δd = ChainRulesCore.unthunk(Δd) | ||
Δx = ChainRulesCore.unthunk(Δx) | ||
Σinv = invcov(d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we avoid computing the inverse?
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: David Widmann <[email protected]>
not computing the inverse at all would be a bit of a hassle here since it's reused several times in the whole function. I removed the materialization of the inverse as a matrix here to lighten the computations, we are using the inverse of the Cholesky decomposition directly |
That change here: b225298 |
Co-authored-by: David Widmann <[email protected]>
Reverted the changes, the inverse covariance matrix is needed here, there is no way around it since the derivative is just the invcov scaled |
I would say not since the operations done on top are non-trivial and could be costly |
ping @devmotion for another round of review |
No description provided.