diff --git a/Project.toml b/Project.toml index 66b06685..162f651e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.10.2" +version = "0.10.3" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index c6fa9a76..193327fa 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -230,24 +230,23 @@ function pd_logpdf_with_trans(d, X::AbstractMatrix{<:Real}, transform::Bool) end lp = getlogp(d, Xcf, X) if transform && isfinite(lp) - U = Xcf.U - d = dim(d) - lp += sum((d .- (1:d) .+ 2) .* log.(diag(U))) - lp += d * log(T(2)) + n = size(d, 1) + lp += sum(((n + 2) .- (1:n)) .* log.(diag(Xcf.factors))) + lp += n * oftype(lp, IrrationalConstants.logtwo) end return lp end function getlogp(d::MatrixBeta, Xcf, X) n1, n2 = params(d) - p = dim(d) + p = size(d, 1) return ((n1 - p - 1) / 2) * logdet(Xcf) + ((n2 - p - 1) / 2) * logdet(I - X) + d.logc0 end function getlogp(d::Wishart, Xcf, X) - return 0.5 * ((d.df - (dim(d) + 1)) * logdet(Xcf) - tr(d.S \ X)) + d.logc0 + return ((d.df - (size(d, 1) + 1)) * logdet(Xcf) - tr(d.S \ X)) / 2 + d.logc0 end function getlogp(d::InverseWishart, Xcf, X) Ψ = Matrix(d.Ψ) - return -0.5 * ((d.df + dim(d) + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) + d.logc0 + return -((d.df + size(d, 1) + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) / 2 + d.logc0 end include("utils.jl") diff --git a/src/compat/distributionsad.jl b/src/compat/distributionsad.jl index 85ef72ad..5402ce00 100644 --- a/src/compat/distributionsad.jl +++ b/src/compat/distributionsad.jl @@ -70,9 +70,9 @@ end ispd(::TuringWishart) = true ispd(::TuringInverseWishart) = true function getlogp(d::TuringWishart, Xcf, X) - return 0.5 * ((d.df - (dim(d) + 1)) * logdet(Xcf) - tr(d.chol \ X)) + d.logc0 + return ((d.df - (size(d, 1) + 1)) * logdet(Xcf) - tr(d.chol \ X)) / 2 + d.logc0 end function getlogp(d::TuringInverseWishart, Xcf, X) Ψ = d.S - return -0.5 * ((d.df + dim(d) + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) + d.logc0 + return -((d.df + size(d, 1) + 1) * logdet(Xcf) + tr(Xcf \ Ψ)) / 2 + d.logc0 end diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index b16ef58b..e2adf050 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -108,11 +108,14 @@ function pd_logpdf_with_trans_zygote( end lp = getlogp(d, Xcf, X) if transform && isfinite(lp) - U = Xcf.U - @inbounds for i in 1:dim(d) - lp += (dim(d) - i + 2) * log(U[i, i]) + factors = Xcf.factors + n = size(d, 1) + k = n + 2 + @inbounds for i in diagind(factors) + k -= 1 + lp += k * log(factors[i]) end - lp += dim(d) * log(T(2)) + lp += n * oftype(lp, IrrationalConstants.logtwo) end return lp end