Skip to content

Commit

Permalink
Merge pull request #18 from TuringLang/mt/mvnormal
Browse files Browse the repository at this point in the history
Fix MvNormal and LogMvNormal perf
  • Loading branch information
mohamed82008 authored Jan 9, 2020
2 parents d0a1063 + 522c7d3 commit 97936bf
Showing 1 changed file with 0 additions and 9 deletions.
9 changes: 0 additions & 9 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ for T in (:AbstractVector, :AbstractMatrix)
@eval Distributions.logpdf(d::TuringDiagNormal, x::$T) = _logpdf(d, x)
@eval Distributions.logpdf(d::TuringMvNormal, x::$T) = _logpdf(d, x)
end
for T in (:(Tracker.TrackedVector), :(Tracker.TrackedMatrix))
@eval Distributions.logpdf(d::MvNormal, x::$T) = _logpdf(d, x)
end

function _logpdf(d::TuringDiagNormal, x::AbstractVector)
return -(dim(d) * log(2π) + 2 * sum(log.(d.σ)) + sum(abs2, (x .- d.m) ./ d.σ)) / 2
Expand All @@ -51,9 +48,6 @@ end
function _logpdf(d::TuringMvNormal, x::AbstractMatrix)
return -(dim(d) * log(2π) .+ logdet(d.C) .+ sum(abs2, zygote_ldiv(d.C.U', x .- d.m), dims=1)') ./ 2
end
function _logpdf(d::MvNormal, x::Union{Tracker.TrackedVector, Tracker.TrackedMatrix})
_logpdf(TuringMvNormal(d.μ, getchol(d.Σ)), x)
end

# zero mean, dense covariance
MvNormal(A::TrackedMatrix) = MvNormal(zeros(size(A, 1)), A)
Expand Down Expand Up @@ -120,9 +114,6 @@ end
function _logpdf(d::TuringMvLogNormal, x::AbstractVecOrMat{T}) where {T<:Real}
return insupport(d, x) ? (_logpdf(d.normal, log.(x)) - sum(log.(x))) : -Inf
end
function _logpdf(d::MvLogNormal, x::Union{Tracker.TrackedVector, Tracker.TrackedMatrix})
_logpdf(TuringMvLogNormal(TuringMvNormal(d.normal.μ, getchol(d.normal.Σ))), x)
end

# zero mean, dense covariance
MvLogNormal(A::TrackedMatrix) = MvLogNormal(zeros(size(A, 1)), A)
Expand Down

0 comments on commit 97936bf

Please sign in to comment.