From 3349cdb42abc77bd3f4584ba0b740915a32cda14 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Thu, 9 Jan 2020 14:34:18 +1100 Subject: [PATCH 1/2] fix MvNormal and LogMvNormal perf --- src/multivariate.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/multivariate.jl b/src/multivariate.jl index eb112748..522eb048 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -51,9 +51,6 @@ end function _logpdf(d::TuringMvNormal, x::AbstractMatrix) return -(dim(d) * log(2π) .+ logdet(d.C) .+ sum(abs2, zygote_ldiv(d.C.U', x .- d.m), dims=1)') ./ 2 end -function _logpdf(d::MvNormal, x::Union{Tracker.TrackedVector, Tracker.TrackedMatrix}) - _logpdf(TuringMvNormal(d.μ, getchol(d.Σ)), x) -end # zero mean, dense covariance MvNormal(A::TrackedMatrix) = MvNormal(zeros(size(A, 1)), A) @@ -120,9 +117,6 @@ end function _logpdf(d::TuringMvLogNormal, x::AbstractVecOrMat{T}) where {T<:Real} return insupport(d, x) ? (_logpdf(d.normal, log.(x)) - sum(log.(x))) : -Inf end -function _logpdf(d::MvLogNormal, x::Union{Tracker.TrackedVector, Tracker.TrackedMatrix}) - _logpdf(TuringMvLogNormal(TuringMvNormal(d.normal.μ, getchol(d.normal.Σ))), x) -end # zero mean, dense covariance MvLogNormal(A::TrackedMatrix) = MvLogNormal(zeros(size(A, 1)), A) From 522c7d3e12f949b059cfb9ca32da86ade3358e91 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Thu, 9 Jan 2020 14:39:56 +1100 Subject: [PATCH 2/2] fix it for real --- src/multivariate.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/multivariate.jl b/src/multivariate.jl index 522eb048..4c96a902 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -35,9 +35,6 @@ for T in (:AbstractVector, :AbstractMatrix) @eval Distributions.logpdf(d::TuringDiagNormal, x::$T) = _logpdf(d, x) @eval Distributions.logpdf(d::TuringMvNormal, x::$T) = _logpdf(d, x) end -for T in (:(Tracker.TrackedVector), :(Tracker.TrackedMatrix)) - @eval Distributions.logpdf(d::MvNormal, x::$T) = _logpdf(d, x) -end function _logpdf(d::TuringDiagNormal, x::AbstractVector) return -(dim(d) * log(2π) + 2 * sum(log.(d.σ)) + sum(abs2, (x .- d.m) ./ d.σ)) / 2