From 789ad0ad514dd64c6ce5b4421bfe2c3a77ec7e07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 18 Apr 2022 16:51:46 +0200 Subject: [PATCH 01/24] test and code for frule --- src/multivariate/mvnormal.jl | 55 ++++++++++++++++++++++++++++++------ test/mvnormal.jl | 45 +++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 9 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 36992d3bd..73ed6b482 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -253,7 +253,7 @@ Base.show(io::IO, d::MvNormal) = length(d::MvNormal) = length(d.μ) mean(d::MvNormal) = d.μ params(d::MvNormal) = (d.μ, d.Σ) -@inline partype(d::MvNormal{T}) where {T<:Real} = T +@inline partype(::MvNormal{T}) where {T<:Real} = T var(d::MvNormal) = diag(d.Σ) cov(d::MvNormal) = Matrix(d.Σ) @@ -382,7 +382,7 @@ function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}) MvNormalStats(s, m, s2, Float64(n)) end -function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) +function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) d = size(x, 1) n = size(x, 2) length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions.")) @@ -410,11 +410,11 @@ end # each kind of covariance # -fit_mle(D::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss) -fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x) -fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w) +fit_mle(::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss) +fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x) +fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w) -fit_mle(D::Type{FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw)) +fit_mle(::Type{<:FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw)) function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}) n = size(x, 2) @@ -425,7 +425,7 @@ function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}) MvNormal(mu, PDMat(C)) end -function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) +function fit_mle(::Type{<:FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) m = size(x, 1) n = size(x, 2) length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) @@ -460,7 +460,7 @@ function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}) MvNormal(mu, PDiagMat(va)) end -function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) +function fit_mle(::Type{<:DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) m = size(x, 1) n = size(x, 2) length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) @@ -495,7 +495,7 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}) MvNormal(mu, ScalMat(m, va / (m * n))) end -function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) +function fit_mle(::Type{<:IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) m = size(x, 1) n = size(x, 2) length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) @@ -515,3 +515,40 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVect end MvNormal(mu, ScalMat(m, va / (m * sw))) end + +## Differentiation + +function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector) + c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d) + sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x) + return c0 - sq/2, ChainRulesCore.@thunk(begin + Δc0 = ChainRulesCore.unthunk(Δc0) + Δsq = ChainRulesCore.unthunk(Δsq) + Δc0 - Δsq/2 + end) +end + +function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::AbstractMvNormal) + y = mvnormal_c0(d) + Δy = ChainRulesCore.@thunk(begin + Δd = ChainRulesCore.unthunk(Δd) + -dot(Δd.Σ, invcov(d)) / 2 + end) + return y, Δy +end + +function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::AbstractMvNormal, x::AbstractVector) + y = sqmahal(d, x) + Δy = ChainRulesCore.@thunk(begin + (_, Δd, Δx) = dargs + Δd = ChainRulesCore.unthunk(Δd) + Δx = ChainRulesCore.unthunk(Δx) + Σinv = inv(d.Σ) + # TODO optimize + dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ') + dx = 2 * dot(Σinv * (x - d.μ), Δx) + dμ = 2 * dot(Σinv * (d.μ - x), Δd.μ) + dΣ + dx + dμ + end) + return (y, Δy) +end diff --git a/test/mvnormal.jl b/test/mvnormal.jl index 1386764a6..4916af67d 100644 --- a/test/mvnormal.jl +++ b/test/mvnormal.jl @@ -9,6 +9,8 @@ using Distributions using LinearAlgebra, Random, Test using SparseArrays using FillArrays +using ChainRulesCore +using ChainRulesTestUtils ###### General Testing @@ -302,3 +304,46 @@ end x = rand(d) @test logpdf(d, x) ≈ logpdf(Normal(), x[1]) + logpdf(Normal(), x[2]) end + +@testset "MvNormal differentiation rules" begin + for n in (3, 10) + for _ in 1:10 + A = Symmetric(rand(n,n)) .+ 4 * Matrix(I, n, n) + @assert isposdef(A) + d = MvNormal(randn(n), A) + # make ΔΣ symmetric, such that Σ ± ΔΣ is PSD + t = 0.001 * ChainRulesTestUtils.rand_tangent(d) + t.Σ .+= t.Σ' + if eigmin(t.Σ) < 0 + while eigmin(d.Σ + t.Σ) < 0 + t.Σ .*= 0.8 + end + end + if eigmax(t.Σ) > 0 + while eigmin(d.Σ - t.Σ) < 0 + t.Σ .*= 0.8 + end + end + # mvnormal_c0 + (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t), Distributions.mvnormal_c0, d) + y2 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ + t.Σ)) + @test unthunk(Δy) ≈ y2 - y atol= n * 1e-4 + y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ)) + @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 + # sqmahal + x = randn(n) + Δx = 0.001 * randn(n) + (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), sqmahal, d, x) + y2 = Distributions.sqmahal(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) + @test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 + y3 = Distributions.sqmahal(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) + @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 + # _logpdf + (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), Distributions._logpdf, d, x) + y2 = Distributions._logpdf(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) + y3 = Distributions._logpdf(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) + @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 + @test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 + end + end +end From 24861ca765decf944a35ea443c779fc84468c38d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 24 Apr 2022 22:49:52 +0200 Subject: [PATCH 02/24] added tests --- src/multivariate/mvnormal.jl | 57 +++++++++++++++++++++++++++++++++++- test/mvnormal.jl | 26 ++++++++++++++-- 2 files changed, 80 insertions(+), 3 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 73ed6b482..56e709d57 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -528,6 +528,26 @@ function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpd end) end +function ChainRulesCore.rrule(::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector) + c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d) + sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x) + function logpdf_MvNormal_pullback(dy) + dy = ChainRulesCore.unthunk(dy) + (_, ∂d_c0) = c0_pullback(dy) + ∂d_c0 = ChainRulesCore.unthunk(∂d_c0) + (_, ∂d_sq, ∂x_sq) = sq_pullback(dy) + ∂d_sq = ChainRulesCore.unthunk(∂d_sq) + ∂x_sq = ChainRulesCore.unthunk(∂x_sq) + backing = NamedTuple{(:μ, :Σ), Tuple{typeof(∂d_sq.μ), typeof(∂d_sq.Σ)}}(( + (∂d_c0.μ - 0.5 * ∂d_sq.μ), + (∂d_c0.Σ - 0.5 * ∂d_sq.Σ), + )) + ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) + return ChainRulesCore.NoTangent(), ∂d, - 0.5 * ∂x_sq + end + return c0 - 0.5 * sq, logpdf_MvNormal_pullback +end + function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::AbstractMvNormal) y = mvnormal_c0(d) Δy = ChainRulesCore.@thunk(begin @@ -537,13 +557,26 @@ function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d return y, Δy end +function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::AbstractMvNormal) + y = mvnormal_c0(d) + function mvnormal_c0_pullback(dy) + ∂d = ChainRulesCore.@thunk(begin + dy = ChainRulesCore.unthunk(dy) + ∂Σ = -dy/2 * invcov(d) + ChainRulesCore.Tangent{typeof(d)}(μ = ChainRulesCore.ZeroTangent(), Σ = ∂Σ) + end) + return ChainRulesCore.NoTangent(), ∂d + end + return y, mvnormal_c0_pullback +end + function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::AbstractMvNormal, x::AbstractVector) y = sqmahal(d, x) Δy = ChainRulesCore.@thunk(begin (_, Δd, Δx) = dargs Δd = ChainRulesCore.unthunk(Δd) Δx = ChainRulesCore.unthunk(Δx) - Σinv = inv(d.Σ) + Σinv = invcov(d) # TODO optimize dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ') dx = 2 * dot(Σinv * (x - d.μ), Δx) @@ -552,3 +585,25 @@ function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::A end) return (y, Δy) end + +function ChainRulesCore.rrule(::typeof(sqmahal), d::AbstractMvNormal, x::AbstractVector) + y = sqmahal(d, x) + function sqmahal_pullback(dy) + ∂x = ChainRulesCore.@thunk(begin + dy = ChainRulesCore.unthunk(dy) + Σinv = invcov(d) + 2dy * Σinv * (x - d.μ) + end) + ∂d = ChainRulesCore.@thunk(begin + dy = ChainRulesCore.unthunk(dy) + Σinv = invcov(d) + cx = x - d.μ + ∂μ = -2dy * Σinv * cx + ∂J = dy * cx * cx' + ∂Σ = - Σinv * ∂J * Σinv + ChainRulesCore.Tangent{typeof(d)}(μ = ∂μ, Σ = ∂Σ) + end) + return (ChainRulesCore.NoTangent(), ∂d, ∂x) + end + return y, sqmahal_pullback +end diff --git a/test/mvnormal.jl b/test/mvnormal.jl index 4916af67d..d11198be9 100644 --- a/test/mvnormal.jl +++ b/test/mvnormal.jl @@ -1,5 +1,6 @@ # Tests on Multivariate Normal distributions +import PDMats import PDMats: ScalMat, PDiagMat, PDMat if isdefined(PDMats, :PDSparseMat) import PDMats: PDSparseMat @@ -326,24 +327,45 @@ end end # mvnormal_c0 (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t), Distributions.mvnormal_c0, d) + y_r, c0_pullback = @inferred ChainRulesCore.rrule(Distributions.mvnormal_c0, d) + @test y_r ≈ y y2 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ + t.Σ)) @test unthunk(Δy) ≈ y2 - y atol= n * 1e-4 y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ)) @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 + (_, ∇c0) = c0_pullback(1.0) + ∇c0 = ChainRulesCore.unthunk(∇c0) + @test dot(∇c0.Σ, t.Σ) ≈ y2 - y atol = n * 1e-4 + @test dot(∇c0.Σ, t.Σ) ≈ y - y3 atol = n * 1e-4 # sqmahal x = randn(n) - Δx = 0.001 * randn(n) + Δx = 0.0001 * randn(n) (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), sqmahal, d, x) + (yr, sqmahal_pullback) = @inferred ChainRulesCore.rrule(sqmahal, d, x) + (_, ∇s_d, ∇s_x) = @inferred sqmahal_pullback(1.0) + ∇s_d = ChainRulesCore.unthunk(∇s_d) + ∇s_x = ChainRulesCore.unthunk(∇s_x) + @test yr ≈ y y2 = Distributions.sqmahal(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) - @test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 y3 = Distributions.sqmahal(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) + @test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 + @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4 + @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4 # _logpdf (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), Distributions._logpdf, d, x) + (yr, logpdf_MvNormal_pullback) = @inferred ChainRulesCore.rrule(Distributions._logpdf, d, x) + @test y ≈ yr + # inference broken + # (_, ∇s_d, ∇s_x) = @inferred logpdf_MvNormal_pullback(1.0) + (_, ∇s_d, ∇s_x) = logpdf_MvNormal_pullback(1.0) + y2 = Distributions._logpdf(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) y3 = Distributions._logpdf(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 @test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 + @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4 + @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4 end end end From 2bcc217ab34fdfe0a2fb40dc1af134054d5c4f2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 25 Apr 2022 22:23:24 +0200 Subject: [PATCH 03/24] diff on MvNormal only for now --- src/multivariate/mvnormal.jl | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 56e709d57..d97eddc68 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -372,7 +372,7 @@ struct MvNormalStats <: SufficientStats tw::Float64 # total sample weight end -function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}) +function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}) d = size(x, 1) n = size(x, 2) s = vec(sum(x, dims=2)) @@ -416,7 +416,7 @@ fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) fit_mle(::Type{<:FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw)) -function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}) +function fit_mle(::Type{FullNormal}, x::AbstractMatrix{Float64}) n = size(x, 2) mu = vec(mean(x, dims=2)) z = x .- mu @@ -445,7 +445,7 @@ function fit_mle(::Type{<:FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVe MvNormal(mu, PDMat(C)) end -function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}) +function fit_mle(::Type{DiagNormal}, x::AbstractMatrix{Float64}) m = size(x, 1) n = size(x, 2) @@ -479,7 +479,7 @@ function fit_mle(::Type{<:DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVe MvNormal(mu, PDiagMat(va)) end -function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}) +function fit_mle(::Type{IsoNormal}, x::AbstractMatrix{Float64}) m = size(x, 1) n = size(x, 2) @@ -528,7 +528,7 @@ function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpd end) end -function ChainRulesCore.rrule(::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector) +function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector) c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d) sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x) function logpdf_MvNormal_pullback(dy) @@ -548,7 +548,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::AbstractMvNormal, x::Abstrac return c0 - 0.5 * sq, logpdf_MvNormal_pullback end -function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::AbstractMvNormal) +function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::MvNormal) y = mvnormal_c0(d) Δy = ChainRulesCore.@thunk(begin Δd = ChainRulesCore.unthunk(Δd) @@ -557,7 +557,7 @@ function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d return y, Δy end -function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::AbstractMvNormal) +function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal) y = mvnormal_c0(d) function mvnormal_c0_pullback(dy) ∂d = ChainRulesCore.@thunk(begin @@ -570,7 +570,7 @@ function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::AbstractMvNormal) return y, mvnormal_c0_pullback end -function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::AbstractMvNormal, x::AbstractVector) +function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::MvNormal, x::AbstractVector) y = sqmahal(d, x) Δy = ChainRulesCore.@thunk(begin (_, Δd, Δx) = dargs @@ -586,7 +586,7 @@ function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::A return (y, Δy) end -function ChainRulesCore.rrule(::typeof(sqmahal), d::AbstractMvNormal, x::AbstractVector) +function ChainRulesCore.rrule(::typeof(sqmahal), d::MvNormal, x::AbstractVector) y = sqmahal(d, x) function sqmahal_pullback(dy) ∂x = ChainRulesCore.@thunk(begin From 1e9571b0fcd9587311b344b6b26602aeb4580d06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Mon, 23 May 2022 09:47:17 -0400 Subject: [PATCH 04/24] unthunk when only one element --- src/multivariate/mvnormal.jl | 44 +++++++++++++++--------------------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index d97eddc68..4184a5148 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -521,11 +521,9 @@ end function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector) c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d) sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x) - return c0 - sq/2, ChainRulesCore.@thunk(begin - Δc0 = ChainRulesCore.unthunk(Δc0) - Δsq = ChainRulesCore.unthunk(Δsq) - Δc0 - Δsq/2 - end) + Δc0 = ChainRulesCore.unthunk(Δc0) + Δsq = ChainRulesCore.unthunk(Δsq) + return c0 - sq/2, Δc0 - Δsq/2 end function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector) @@ -550,39 +548,33 @@ end function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::MvNormal) y = mvnormal_c0(d) - Δy = ChainRulesCore.@thunk(begin - Δd = ChainRulesCore.unthunk(Δd) - -dot(Δd.Σ, invcov(d)) / 2 - end) + Δd = ChainRulesCore.unthunk(Δd) + Δy = -dot(Δd.Σ, invcov(d)) / 2 return y, Δy end function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal) y = mvnormal_c0(d) function mvnormal_c0_pullback(dy) - ∂d = ChainRulesCore.@thunk(begin - dy = ChainRulesCore.unthunk(dy) - ∂Σ = -dy/2 * invcov(d) - ChainRulesCore.Tangent{typeof(d)}(μ = ChainRulesCore.ZeroTangent(), Σ = ∂Σ) - end) + dy = ChainRulesCore.unthunk(dy) + ∂Σ = -dy/2 * invcov(d) + ∂d = ChainRulesCore.Tangent{typeof(d)}(μ = ChainRulesCore.ZeroTangent(), Σ = ∂Σ) return ChainRulesCore.NoTangent(), ∂d end return y, mvnormal_c0_pullback end function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::MvNormal, x::AbstractVector) - y = sqmahal(d, x) - Δy = ChainRulesCore.@thunk(begin - (_, Δd, Δx) = dargs - Δd = ChainRulesCore.unthunk(Δd) - Δx = ChainRulesCore.unthunk(Δx) - Σinv = invcov(d) - # TODO optimize - dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ') - dx = 2 * dot(Σinv * (x - d.μ), Δx) - dμ = 2 * dot(Σinv * (d.μ - x), Δd.μ) - dΣ + dx + dμ - end) + y = sqmahal(d, x) + (_, Δd, Δx) = dargs + Δd = ChainRulesCore.unthunk(Δd) + Δx = ChainRulesCore.unthunk(Δx) + Σinv = invcov(d) + # TODO optimize + dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ') + dx = 2 * dot(Σinv * (x - d.μ), Δx) + dμ = 2 * dot(Σinv * (d.μ - x), Δd.μ) + Δy = dΣ + dx + dμ return (y, Δy) end From ac0995eb5ac432acdaa3cbdc6afdb5ac9f5cea0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Tue, 24 May 2022 03:54:57 +0200 Subject: [PATCH 05/24] Update src/multivariate/mvnormal.jl Co-authored-by: David Widmann --- src/multivariate/mvnormal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 4184a5148..a2b0f1347 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -543,7 +543,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector) ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) return ChainRulesCore.NoTangent(), ∂d, - 0.5 * ∂x_sq end - return c0 - 0.5 * sq, logpdf_MvNormal_pullback + return c0 - sq / 2, logpdf_MvNormal_pullback end function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::MvNormal) From 522c13e0c49bec8f4be4a62ecf502771f113c53f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Tue, 24 May 2022 03:55:08 +0200 Subject: [PATCH 06/24] Update src/multivariate/mvnormal.jl Co-authored-by: David Widmann --- src/multivariate/mvnormal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index a2b0f1347..a9dd2b4bd 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -541,7 +541,7 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector) (∂d_c0.Σ - 0.5 * ∂d_sq.Σ), )) ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) - return ChainRulesCore.NoTangent(), ∂d, - 0.5 * ∂x_sq + return ChainRulesCore.NoTangent(), ∂d, ∂x_sq / (-2) end return c0 - sq / 2, logpdf_MvNormal_pullback end From 6529a4a28af747066a88da2da2ee0015354794c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Tue, 24 May 2022 10:35:17 -0400 Subject: [PATCH 07/24] no backing --- src/multivariate/mvnormal.jl | 9 ++++----- test/mvnormal.jl | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 4184a5148..3194554dd 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -536,11 +536,10 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector) (_, ∂d_sq, ∂x_sq) = sq_pullback(dy) ∂d_sq = ChainRulesCore.unthunk(∂d_sq) ∂x_sq = ChainRulesCore.unthunk(∂x_sq) - backing = NamedTuple{(:μ, :Σ), Tuple{typeof(∂d_sq.μ), typeof(∂d_sq.Σ)}}(( - (∂d_c0.μ - 0.5 * ∂d_sq.μ), - (∂d_c0.Σ - 0.5 * ∂d_sq.Σ), - )) - ∂d = ChainRulesCore.Tangent{typeof(d), typeof(backing)}(backing) + ∂d = ChainRulesCore.Tangent{typeof(d)}(; + μ = ∂d_c0.μ - 0.5 * ∂d_sq.μ, + Σ = ∂d_c0.Σ - 0.5 * ∂d_sq.Σ, + ) return ChainRulesCore.NoTangent(), ∂d, - 0.5 * ∂x_sq end return c0 - 0.5 * sq, logpdf_MvNormal_pullback diff --git a/test/mvnormal.jl b/test/mvnormal.jl index d11198be9..baac33f15 100644 --- a/test/mvnormal.jl +++ b/test/mvnormal.jl @@ -333,7 +333,7 @@ end @test unthunk(Δy) ≈ y2 - y atol= n * 1e-4 y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ)) @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 - (_, ∇c0) = c0_pullback(1.0) + (_, ∇c0) = @inferred c0_pullback(1.0) ∇c0 = ChainRulesCore.unthunk(∇c0) @test dot(∇c0.Σ, t.Σ) ≈ y2 - y atol = n * 1e-4 @test dot(∇c0.Σ, t.Σ) ≈ y - y3 atol = n * 1e-4 From d3981404b048fff1916df1b32e2289f773bf7ab2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Tue, 24 May 2022 13:17:30 -0400 Subject: [PATCH 08/24] fix inference --- src/multivariate/mvnormal.jl | 12 +++++++----- test/mvnormal.jl | 4 +--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 9c3d12805..0651af54e 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -534,13 +534,15 @@ function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector) (_, ∂d_c0) = c0_pullback(dy) ∂d_c0 = ChainRulesCore.unthunk(∂d_c0) (_, ∂d_sq, ∂x_sq) = sq_pullback(dy) - ∂d_sq = ChainRulesCore.unthunk(∂d_sq) - ∂x_sq = ChainRulesCore.unthunk(∂x_sq) + ∂d_sq_v = ChainRulesCore.unthunk(∂d_sq) + ∂x_sq_v::typeof(x) = ChainRulesCore.unthunk(∂x_sq) + μs::typeof(d.μ) = ∂d_sq_v.μ + Σs::Matrix{partype(d)} = ∂d_sq_v.Σ ∂d = ChainRulesCore.Tangent{typeof(d)}(; - μ = ∂d_c0.μ - 0.5 * ∂d_sq.μ, - Σ = ∂d_c0.Σ - 0.5 * ∂d_sq.Σ, + μ = ∂d_c0.μ - 0.5 * μs, + Σ = ∂d_c0.Σ - 0.5 * Σs, ) - return ChainRulesCore.NoTangent(), ∂d, -∂x_sq / 2 + return ChainRulesCore.NoTangent(), ∂d, -∂x_sq_v / 2 end return c0 - sq / 2, logpdf_MvNormal_pullback end diff --git a/test/mvnormal.jl b/test/mvnormal.jl index baac33f15..9cf713e11 100644 --- a/test/mvnormal.jl +++ b/test/mvnormal.jl @@ -356,9 +356,7 @@ end (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), Distributions._logpdf, d, x) (yr, logpdf_MvNormal_pullback) = @inferred ChainRulesCore.rrule(Distributions._logpdf, d, x) @test y ≈ yr - # inference broken - # (_, ∇s_d, ∇s_x) = @inferred logpdf_MvNormal_pullback(1.0) - (_, ∇s_d, ∇s_x) = logpdf_MvNormal_pullback(1.0) + (_, ∇s_d, ∇s_x) = @inferred logpdf_MvNormal_pullback(1.0) y2 = Distributions._logpdf(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) y3 = Distributions._logpdf(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) From 661de16361d2c65dfe22e43bfe263d355df4d015 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Thu, 26 May 2022 09:22:19 -0400 Subject: [PATCH 09/24] unthunk common computation --- src/multivariate/mvnormal.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 0651af54e..f19535016 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -582,14 +582,13 @@ end function ChainRulesCore.rrule(::typeof(sqmahal), d::MvNormal, x::AbstractVector) y = sqmahal(d, x) function sqmahal_pullback(dy) + Σinv = invcov(d) ∂x = ChainRulesCore.@thunk(begin dy = ChainRulesCore.unthunk(dy) - Σinv = invcov(d) 2dy * Σinv * (x - d.μ) end) ∂d = ChainRulesCore.@thunk(begin dy = ChainRulesCore.unthunk(dy) - Σinv = invcov(d) cx = x - d.μ ∂μ = -2dy * Σinv * cx ∂J = dy * cx * cx' From 3c5007c339f2f94b582ce9bf914794faca79a360 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Thu, 26 May 2022 18:37:11 -0400 Subject: [PATCH 10/24] unthunk common computation --- src/multivariate/mvnormal.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index f19535016..8440ac553 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -583,12 +583,11 @@ function ChainRulesCore.rrule(::typeof(sqmahal), d::MvNormal, x::AbstractVector) y = sqmahal(d, x) function sqmahal_pullback(dy) Σinv = invcov(d) + dy = ChainRulesCore.unthunk(dy) ∂x = ChainRulesCore.@thunk(begin - dy = ChainRulesCore.unthunk(dy) 2dy * Σinv * (x - d.μ) end) ∂d = ChainRulesCore.@thunk(begin - dy = ChainRulesCore.unthunk(dy) cx = x - d.μ ∂μ = -2dy * Σinv * cx ∂J = dy * cx * cx' From 1f18e67182f6ee0ce98ae4c6fd2ca24ef3b3e686 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Fri, 27 May 2022 11:11:05 -0400 Subject: [PATCH 11/24] remove rules for logpdf --- src/multivariate/mvnormal.jl | 29 ----------------------------- test/mvnormal.jl | 12 ------------ 2 files changed, 41 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 8440ac553..f2bcdac60 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -518,35 +518,6 @@ end ## Differentiation -function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(_logpdf), d::AbstractMvNormal, x::AbstractVector) - c0, Δc0 = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd), mvnormal_c0, d) - sq, Δsq = ChainRulesCore.frule((ChainRulesCore.NoTangent(), Δd, Δx), sqmahal, d, x) - Δc0 = ChainRulesCore.unthunk(Δc0) - Δsq = ChainRulesCore.unthunk(Δsq) - return c0 - sq/2, Δc0 - Δsq/2 -end - -function ChainRulesCore.rrule(::typeof(_logpdf), d::MvNormal, x::AbstractVector) - c0, c0_pullback = ChainRulesCore.rrule(mvnormal_c0, d) - sq, sq_pullback = ChainRulesCore.rrule(sqmahal, d, x) - function logpdf_MvNormal_pullback(dy) - dy = ChainRulesCore.unthunk(dy) - (_, ∂d_c0) = c0_pullback(dy) - ∂d_c0 = ChainRulesCore.unthunk(∂d_c0) - (_, ∂d_sq, ∂x_sq) = sq_pullback(dy) - ∂d_sq_v = ChainRulesCore.unthunk(∂d_sq) - ∂x_sq_v::typeof(x) = ChainRulesCore.unthunk(∂x_sq) - μs::typeof(d.μ) = ∂d_sq_v.μ - Σs::Matrix{partype(d)} = ∂d_sq_v.Σ - ∂d = ChainRulesCore.Tangent{typeof(d)}(; - μ = ∂d_c0.μ - 0.5 * μs, - Σ = ∂d_c0.Σ - 0.5 * Σs, - ) - return ChainRulesCore.NoTangent(), ∂d, -∂x_sq_v / 2 - end - return c0 - sq / 2, logpdf_MvNormal_pullback -end - function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::MvNormal) y = mvnormal_c0(d) Δd = ChainRulesCore.unthunk(Δd) diff --git a/test/mvnormal.jl b/test/mvnormal.jl index 9cf713e11..76a9262d6 100644 --- a/test/mvnormal.jl +++ b/test/mvnormal.jl @@ -352,18 +352,6 @@ end @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4 @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4 - # _logpdf - (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), Distributions._logpdf, d, x) - (yr, logpdf_MvNormal_pullback) = @inferred ChainRulesCore.rrule(Distributions._logpdf, d, x) - @test y ≈ yr - (_, ∇s_d, ∇s_x) = @inferred logpdf_MvNormal_pullback(1.0) - - y2 = Distributions._logpdf(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) - y3 = Distributions._logpdf(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) - @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 - @test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 - @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4 - @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4 end end end From b60147de53d76a52e2fbe831d534d5b976c9e28f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sat, 30 Jul 2022 17:14:15 +0200 Subject: [PATCH 12/24] revert changes --- src/multivariate/mvnormal.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index f2bcdac60..df6196e46 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -253,7 +253,7 @@ Base.show(io::IO, d::MvNormal) = length(d::MvNormal) = length(d.μ) mean(d::MvNormal) = d.μ params(d::MvNormal) = (d.μ, d.Σ) -@inline partype(::MvNormal{T}) where {T<:Real} = T +@inline partype(d::MvNormal{T}) where {T<:Real} = T var(d::MvNormal) = diag(d.Σ) cov(d::MvNormal) = Matrix(d.Σ) @@ -372,7 +372,7 @@ struct MvNormalStats <: SufficientStats tw::Float64 # total sample weight end -function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}) +function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}) d = size(x, 1) n = size(x, 2) s = vec(sum(x, dims=2)) @@ -382,7 +382,7 @@ function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}) MvNormalStats(s, m, s2, Float64(n)) end -function suffstats(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) +function suffstats(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) d = size(x, 1) n = size(x, 2) length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions.")) @@ -410,13 +410,13 @@ end # each kind of covariance # -fit_mle(::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss) -fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x) -fit_mle(::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w) +fit_mle(D::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss) +fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x) +fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w) -fit_mle(::Type{<:FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw)) +fit_mle(F::Type{FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw)) -function fit_mle(::Type{FullNormal}, x::AbstractMatrix{Float64}) +function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}) n = size(x, 2) mu = vec(mean(x, dims=2)) z = x .- mu @@ -425,7 +425,7 @@ function fit_mle(::Type{FullNormal}, x::AbstractMatrix{Float64}) MvNormal(mu, PDMat(C)) end -function fit_mle(::Type{<:FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) +function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) m = size(x, 1) n = size(x, 2) length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) @@ -445,7 +445,7 @@ function fit_mle(::Type{<:FullNormal}, x::AbstractMatrix{Float64}, w::AbstractVe MvNormal(mu, PDMat(C)) end -function fit_mle(::Type{DiagNormal}, x::AbstractMatrix{Float64}) +function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}) m = size(x, 1) n = size(x, 2) @@ -460,7 +460,7 @@ function fit_mle(::Type{DiagNormal}, x::AbstractMatrix{Float64}) MvNormal(mu, PDiagMat(va)) end -function fit_mle(::Type{<:DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) +function fit_mle(D::Type{DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) m = size(x, 1) n = size(x, 2) length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) @@ -479,7 +479,7 @@ function fit_mle(::Type{<:DiagNormal}, x::AbstractMatrix{Float64}, w::AbstractVe MvNormal(mu, PDiagMat(va)) end -function fit_mle(::Type{IsoNormal}, x::AbstractMatrix{Float64}) +function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}) m = size(x, 1) n = size(x, 2) @@ -495,7 +495,7 @@ function fit_mle(::Type{IsoNormal}, x::AbstractMatrix{Float64}) MvNormal(mu, ScalMat(m, va / (m * n))) end -function fit_mle(::Type{<:IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) +function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVector) m = size(x, 1) n = size(x, 2) length(w) == n || throw(DimensionMismatch("Inconsistent argument dimensions")) From 375cad85a90abfbb36563c614f01e2e752ad10b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sat, 30 Jul 2022 17:15:27 +0200 Subject: [PATCH 13/24] Update src/multivariate/mvnormal.jl Co-authored-by: David Widmann --- src/multivariate/mvnormal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index df6196e46..075082921 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -530,7 +530,7 @@ function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal) function mvnormal_c0_pullback(dy) dy = ChainRulesCore.unthunk(dy) ∂Σ = -dy/2 * invcov(d) - ∂d = ChainRulesCore.Tangent{typeof(d)}(μ = ChainRulesCore.ZeroTangent(), Σ = ∂Σ) + ∂d = ChainRulesCore.Tangent{typeof(d)}(Σ = ∂Σ) return ChainRulesCore.NoTangent(), ∂d end return y, mvnormal_c0_pullback From c34a3ed92a9ad8db5a332a8f51d22358121346e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sat, 30 Jul 2022 17:19:25 +0200 Subject: [PATCH 14/24] Update src/multivariate/mvnormal.jl Co-authored-by: David Widmann --- src/multivariate/mvnormal.jl | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 075082921..9444472f6 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -552,17 +552,16 @@ end function ChainRulesCore.rrule(::typeof(sqmahal), d::MvNormal, x::AbstractVector) y = sqmahal(d, x) - function sqmahal_pullback(dy) - Σinv = invcov(d) - dy = ChainRulesCore.unthunk(dy) - ∂x = ChainRulesCore.@thunk(begin - 2dy * Σinv * (x - d.μ) - end) + Σ = _cov(d) + cx = x - d.μ + z = Σ \ cx + function sqmahal_pullback(_dy) + dy = ChainRulesCore.unthunk(_dy) + ∂x = 2 * dy * z ∂d = ChainRulesCore.@thunk(begin - cx = x - d.μ - ∂μ = -2dy * Σinv * cx + ∂μ = -∂x ∂J = dy * cx * cx' - ∂Σ = - Σinv * ∂J * Σinv + ∂Σ = - (Σ \ ∂J) / Σ ChainRulesCore.Tangent{typeof(d)}(μ = ∂μ, Σ = ∂Σ) end) return (ChainRulesCore.NoTangent(), ∂d, ∂x) From 2d966809ffda41df0906bcfdfba85de7647836c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sat, 30 Jul 2022 17:20:06 +0200 Subject: [PATCH 15/24] revert changes --- src/multivariate/mvnormal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index df6196e46..4cd91ce06 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -414,7 +414,7 @@ fit_mle(D::Type{MvNormal}, ss::MvNormalStats) = fit_mle(FullNormal, ss) fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}) = fit_mle(FullNormal, x) fit_mle(D::Type{MvNormal}, x::AbstractMatrix{Float64}, w::AbstractArray{Float64}) = fit_mle(FullNormal, x, w) -fit_mle(F::Type{FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw)) +fit_mle(D::Type{FullNormal}, ss::MvNormalStats) = MvNormal(ss.m, ss.s2 * inv(ss.tw)) function fit_mle(D::Type{FullNormal}, x::AbstractMatrix{Float64}) n = size(x, 2) From b225298aad7d57f96d560da7606987b3cd41e434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sat, 30 Jul 2022 17:24:52 +0200 Subject: [PATCH 16/24] Update src/multivariate/mvnormal.jl Co-authored-by: David Widmann --- src/multivariate/mvnormal.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 35e7acf2a..79ff877b9 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -543,10 +543,10 @@ function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::M Δx = ChainRulesCore.unthunk(Δx) Σinv = invcov(d) # TODO optimize - dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ') - dx = 2 * dot(Σinv * (x - d.μ), Δx) - dμ = 2 * dot(Σinv * (d.μ - x), Δd.μ) - Δy = dΣ + dx + dμ + z = x - d.μ + dΣ = -dot(Xt_A_X(Σinv, Δd.Σ), z * z') + dx_dμ = 2 * dot(Σinv * z, Δx - Δd.μ) + Δy = dΣ + dx_dμ return (y, Δy) end From 6bc85da60ee4cde78ae4b4f0a04b6b9a9d5ce061 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 31 Jul 2022 09:04:40 +0200 Subject: [PATCH 17/24] avoid materializing Matrix --- src/multivariate/mvnormal.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 3f759cf32..3d39a48aa 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -541,10 +541,9 @@ function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::M (_, Δd, Δx) = dargs Δd = ChainRulesCore.unthunk(Δd) Δx = ChainRulesCore.unthunk(Δx) - Σinv = invcov(d) - # TODO optimize + Σinv = inv(_cov(d)) z = x - d.μ - dΣ = -dot(Xt_A_X(Σinv, Δd.Σ), z * z') + dΣ = -dot(PDMats.Xt_A_X(Σinv, Δd.Σ), z * z') dx_dμ = 2 * dot(Σinv * z, Δx - Δd.μ) Δy = dΣ + dx_dμ return (y, Δy) From e303416ba2091a1c522a671da0a850d597a626b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 31 Jul 2022 09:09:48 +0200 Subject: [PATCH 18/24] revert --- src/multivariate/mvnormal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 3d39a48aa..2272bd1e4 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -541,7 +541,7 @@ function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::M (_, Δd, Δx) = dargs Δd = ChainRulesCore.unthunk(Δd) Δx = ChainRulesCore.unthunk(Δx) - Σinv = inv(_cov(d)) + Σinv = invcov(d) z = x - d.μ dΣ = -dot(PDMats.Xt_A_X(Σinv, Δd.Σ), z * z') dx_dμ = 2 * dot(Σinv * z, Δx - Δd.μ) From e21506aa4a5718a94664f35750f5f14f046a8dc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 31 Jul 2022 09:22:30 +0200 Subject: [PATCH 19/24] fix revert --- src/multivariate/mvnormal.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 2272bd1e4..3f759cf32 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -542,8 +542,9 @@ function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::M Δd = ChainRulesCore.unthunk(Δd) Δx = ChainRulesCore.unthunk(Δx) Σinv = invcov(d) + # TODO optimize z = x - d.μ - dΣ = -dot(PDMats.Xt_A_X(Σinv, Δd.Σ), z * z') + dΣ = -dot(Xt_A_X(Σinv, Δd.Σ), z * z') dx_dμ = 2 * dot(Σinv * z, Δx - Δd.μ) Δy = dΣ + dx_dμ return (y, Δy) From fd272ae5bc428657e9acf8c20d3907ccc3d50dc9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 31 Jul 2022 09:28:51 +0200 Subject: [PATCH 20/24] no alloc --- src/multivariate/mvnormal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 3f759cf32..d3d333cef 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -541,7 +541,7 @@ function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::M (_, Δd, Δx) = dargs Δd = ChainRulesCore.unthunk(Δd) Δx = ChainRulesCore.unthunk(Δx) - Σinv = invcov(d) + Σinv = inv(_cov(d)) # TODO optimize z = x - d.μ dΣ = -dot(Xt_A_X(Σinv, Δd.Σ), z * z') From 8b7d45174259e58827a3cb4d6dcf48f00df4dcc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 31 Jul 2022 09:43:08 +0200 Subject: [PATCH 21/24] revert fdiff --- src/multivariate/mvnormal.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index d3d333cef..885540a81 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -543,10 +543,10 @@ function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::M Δx = ChainRulesCore.unthunk(Δx) Σinv = inv(_cov(d)) # TODO optimize - z = x - d.μ - dΣ = -dot(Xt_A_X(Σinv, Δd.Σ), z * z') - dx_dμ = 2 * dot(Σinv * z, Δx - Δd.μ) - Δy = dΣ + dx_dμ + dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ') + dx = 2 * dot(Σinv * (x - d.μ), Δx) + dμ = 2 * dot(Σinv * (d.μ - x), Δd.μ) + Δy = dΣ + dx + dμ return (y, Δy) end From 53601fef8a615f5f0375147e2931195e91e1d2f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 31 Jul 2022 20:50:04 +0200 Subject: [PATCH 22/24] Update src/multivariate/mvnormal.jl Co-authored-by: David Widmann --- src/multivariate/mvnormal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 885540a81..ae5dd66b6 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -529,7 +529,7 @@ function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal) y = mvnormal_c0(d) function mvnormal_c0_pullback(dy) dy = ChainRulesCore.unthunk(dy) - ∂Σ = -dy/2 * invcov(d) + ∂Σ = (dy / (-2)) / _cov(d) ∂d = ChainRulesCore.Tangent{typeof(d)}(Σ = ∂Σ) return ChainRulesCore.NoTangent(), ∂d end From 2c75061a1e8f9d8f60e6393801b4f72e32ab0d74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 31 Jul 2022 21:01:05 +0200 Subject: [PATCH 23/24] fix op --- src/multivariate/mvnormal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index ae5dd66b6..bb9922126 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -529,7 +529,7 @@ function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal) y = mvnormal_c0(d) function mvnormal_c0_pullback(dy) dy = ChainRulesCore.unthunk(dy) - ∂Σ = (dy / (-2)) / _cov(d) + ∂Σ = _cov(d) \ (dy / (-2)) ∂d = ChainRulesCore.Tangent{typeof(d)}(Σ = ∂Σ) return ChainRulesCore.NoTangent(), ∂d end From 6d88e4db6ee69728f72dde007eb97401b31d0ace Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mathieu=20Besan=C3=A7on?= Date: Sun, 31 Jul 2022 21:09:03 +0200 Subject: [PATCH 24/24] fix op bis, revert to invcov --- src/multivariate/mvnormal.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index bb9922126..630838024 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -529,7 +529,7 @@ function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal) y = mvnormal_c0(d) function mvnormal_c0_pullback(dy) dy = ChainRulesCore.unthunk(dy) - ∂Σ = _cov(d) \ (dy / (-2)) + ∂Σ = (dy / (-2)) * invcov(d) ∂d = ChainRulesCore.Tangent{typeof(d)}(Σ = ∂Σ) return ChainRulesCore.NoTangent(), ∂d end