diff --git a/src/Bijectors.jl b/src/Bijectors.jl index d1d47e43..55e4758b 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -7,7 +7,6 @@ using LinearAlgebra using MappedArrays export TransformDistribution, - RealDistribution, PositiveDistribution, UnitDistribution, SimplexDistribution, @@ -335,16 +334,33 @@ function invlink(d::PDMatDistribution, Y::AbstractMatrix{<:Real}) return LowerTriangular(X) * LowerTriangular(X)' end -function logpdf_with_trans(d::PDMatDistribution, X::AbstractMatrix{<:Real}, transform::Bool) - lp = logpdf(d, X) +function logpdf_with_trans( + d::PDMatDistribution, + X::AbstractMatrix{<:Real}, + transform::Bool +) + T = eltype(X) + Xcf = cholesky(X, check=false) + if !issuccess(Xcf) + Xcf = cholesky(X + (eps(T) * norm(X)) * I) + end + lp = getlogp(d, Xcf, X) if transform && isfinite(lp) - U = cholesky(X).U - lp += sum((dim(d) .- (1:dim(d)) .+ 2) .* log.(view(U, diagind(U)))) - lp += dim(d) * log(2) + U = Xcf.U + @inbounds @simd for i in 1:dim(d) + lp += (dim(d) - i + 2) * log(U[i, i]) + end + lp += dim(d) * log(T(2)) end return lp end - +function getlogp(d::Wishart, Xcf, X) + return 0.5 * ((d.df - (dim(d) + 1)) * logdet(Xcf) - tr(d.S \ X)) - d.c0 +end +function getlogp(d::InverseWishart, Xcf, X) + Ψ = Matrix(d.Ψ) + return -0.5 * ((d.df + dim(d) + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) - d.c0 +end ############################################ # Defaults (assume identity link function) #