diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index 6126c1d8f..630838024 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -515,3 +515,56 @@ function fit_mle(D::Type{IsoNormal}, x::AbstractMatrix{Float64}, w::AbstractVect end MvNormal(mu, ScalMat(m, va / (m * sw))) end + +## Differentiation + +function ChainRulesCore.frule((_, Δd)::Tuple{Any,Any}, ::typeof(mvnormal_c0), d::MvNormal) + y = mvnormal_c0(d) + Δd = ChainRulesCore.unthunk(Δd) + Δy = -dot(Δd.Σ, invcov(d)) / 2 + return y, Δy +end + +function ChainRulesCore.rrule(::typeof(mvnormal_c0), d::MvNormal) + y = mvnormal_c0(d) + function mvnormal_c0_pullback(dy) + dy = ChainRulesCore.unthunk(dy) + ∂Σ = (dy / (-2)) * invcov(d) + ∂d = ChainRulesCore.Tangent{typeof(d)}(Σ = ∂Σ) + return ChainRulesCore.NoTangent(), ∂d + end + return y, mvnormal_c0_pullback +end + +function ChainRulesCore.frule(dargs::Tuple{Any,Any,Any}, ::typeof(sqmahal), d::MvNormal, x::AbstractVector) + y = sqmahal(d, x) + (_, Δd, Δx) = dargs + Δd = ChainRulesCore.unthunk(Δd) + Δx = ChainRulesCore.unthunk(Δx) + Σinv = inv(_cov(d)) + # TODO optimize + dΣ = -dot(Σinv * Δd.Σ * Σinv, x * x' - d.μ * x' - x * d.μ' + d.μ * d.μ') + dx = 2 * dot(Σinv * (x - d.μ), Δx) + dμ = 2 * dot(Σinv * (d.μ - x), Δd.μ) + Δy = dΣ + dx + dμ + return (y, Δy) +end + +function ChainRulesCore.rrule(::typeof(sqmahal), d::MvNormal, x::AbstractVector) + y = sqmahal(d, x) + Σ = _cov(d) + cx = x - d.μ + z = Σ \ cx + function sqmahal_pullback(_dy) + dy = ChainRulesCore.unthunk(_dy) + ∂x = 2 * dy * z + ∂d = ChainRulesCore.@thunk(begin + ∂μ = -∂x + ∂J = dy * cx * cx' + ∂Σ = - (Σ \ ∂J) / Σ + ChainRulesCore.Tangent{typeof(d)}(μ = ∂μ, Σ = ∂Σ) + end) + return (ChainRulesCore.NoTangent(), ∂d, ∂x) + end + return y, sqmahal_pullback +end diff --git a/test/multivariate/mvnormal.jl b/test/multivariate/mvnormal.jl index 1386764a6..76a9262d6 100644 --- a/test/multivariate/mvnormal.jl +++ b/test/multivariate/mvnormal.jl @@ -1,5 +1,6 @@ # Tests on Multivariate Normal distributions +import PDMats import PDMats: ScalMat, PDiagMat, PDMat if isdefined(PDMats, :PDSparseMat) import PDMats: PDSparseMat @@ -9,6 +10,8 @@ using Distributions using LinearAlgebra, Random, Test using SparseArrays using FillArrays +using ChainRulesCore +using ChainRulesTestUtils ###### General Testing @@ -302,3 +305,53 @@ end x = rand(d) @test logpdf(d, x) ≈ logpdf(Normal(), x[1]) + logpdf(Normal(), x[2]) end + +@testset "MvNormal differentiation rules" begin + for n in (3, 10) + for _ in 1:10 + A = Symmetric(rand(n,n)) .+ 4 * Matrix(I, n, n) + @assert isposdef(A) + d = MvNormal(randn(n), A) + # make ΔΣ symmetric, such that Σ ± ΔΣ is PSD + t = 0.001 * ChainRulesTestUtils.rand_tangent(d) + t.Σ .+= t.Σ' + if eigmin(t.Σ) < 0 + while eigmin(d.Σ + t.Σ) < 0 + t.Σ .*= 0.8 + end + end + if eigmax(t.Σ) > 0 + while eigmin(d.Σ - t.Σ) < 0 + t.Σ .*= 0.8 + end + end + # mvnormal_c0 + (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t), Distributions.mvnormal_c0, d) + y_r, c0_pullback = @inferred ChainRulesCore.rrule(Distributions.mvnormal_c0, d) + @test y_r ≈ y + y2 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ + t.Σ)) + @test unthunk(Δy) ≈ y2 - y atol= n * 1e-4 + y3 = Distributions.mvnormal_c0(MvNormal(d.μ, d.Σ - t.Σ)) + @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 + (_, ∇c0) = @inferred c0_pullback(1.0) + ∇c0 = ChainRulesCore.unthunk(∇c0) + @test dot(∇c0.Σ, t.Σ) ≈ y2 - y atol = n * 1e-4 + @test dot(∇c0.Σ, t.Σ) ≈ y - y3 atol = n * 1e-4 + # sqmahal + x = randn(n) + Δx = 0.0001 * randn(n) + (y, Δy) = @inferred ChainRulesCore.frule((ChainRulesCore.NoTangent(), t, Δx), sqmahal, d, x) + (yr, sqmahal_pullback) = @inferred ChainRulesCore.rrule(sqmahal, d, x) + (_, ∇s_d, ∇s_x) = @inferred sqmahal_pullback(1.0) + ∇s_d = ChainRulesCore.unthunk(∇s_d) + ∇s_x = ChainRulesCore.unthunk(∇s_x) + @test yr ≈ y + y2 = Distributions.sqmahal(MvNormal(d.μ + t.μ, d.Σ + t.Σ), x + Δx) + y3 = Distributions.sqmahal(MvNormal(d.μ - t.μ, d.Σ - t.Σ), x - Δx) + @test unthunk(Δy) ≈ y2 - y atol = n * 1e-4 + @test unthunk(Δy) ≈ y - y3 atol = n * 1e-4 + @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y2 - y atol = n * 1e-4 + @test dot(∇s_d.Σ, t.Σ) + dot(∇s_d.μ, t.μ) + dot(∇s_x, Δx) ≈ y - y3 atol = n * 1e-4 + end + end +end