diff --git a/src/multivariate.jl b/src/multivariate.jl index 225823a7..29bf2358 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -103,6 +103,12 @@ function MvNormal( ) return TuringMvNormal(m, D) end +function MvNormal( + m::TrackedVector{<:Real}, + D::Diagonal{T, <:AbstractVector{T}} where {T<:Real}, +) + return TuringMvNormal(m, D) +end # dense mean, diagonal covariance MvNormal(m::TrackedVector{<:Real}, σ::TrackedVector{<:Real}) = TuringMvNormal(m, σ) @@ -211,6 +217,18 @@ function MvLogNormal( ) return TuringMvLogNormal(TuringMvNormal(m, D)) end +function MvLogNormal( + m::TrackedVector{<:Real}, + D::Diagonal{T, <:AbstractVector{T}} where {T<:Real}, +) + return TuringMvLogNormal(TuringMvNormal(m, D)) +end +function MvLogNormal( + m::AbstractVector{<:Real}, + D::Diagonal{T, <:AbstractVector{T}} where {T<:Real}, +) + return MvLogNormal(MvNormal(m, D)) +end # dense mean, diagonal covariance MvLogNormal(m::TrackedVector{<:Real}, σ::TrackedVector{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, σ)) diff --git a/test/distributions.jl b/test/distributions.jl index 4071222d..450c689f 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -176,26 +176,34 @@ separator() # Vector case DistSpec(:MvNormal, (mean, cov_mat), norm_val_vec), DistSpec(:MvNormal, (mean, cov_vec), norm_val_vec), + DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_vec), DistSpec(:MvNormal, (mean, cov_num), norm_val_vec), DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_vec), DistSpec(:MvNormal, (cov_mat,), norm_val_vec), DistSpec(:MvNormal, (cov_vec,), norm_val_vec), + DistSpec(:MvNormal, (Diagonal(cov_vec),), norm_val_vec), DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_vec), DistSpec(:MvLogNormal, (mean, cov_mat), norm_val_vec), DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_vec), + DistSpec(:MvLogNormal, (mean, Diagonal(cov_vec)), norm_val_vec), DistSpec(:MvLogNormal, (mean, cov_num), norm_val_vec), DistSpec(:MvLogNormal, (cov_mat,), norm_val_vec), DistSpec(:MvLogNormal, (cov_vec,), norm_val_vec), + DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_vec), DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_vec), # Matrix case DistSpec(:MvNormal, (mean, cov_vec), norm_val_mat), + DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_mat), DistSpec(:MvNormal, (mean, cov_num), norm_val_mat), DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_mat), DistSpec(:MvNormal, (cov_vec,), norm_val_mat), + DistSpec(:MvNormal, (Diagonal(cov_vec),), norm_val_mat), DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_mat), DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_mat), + DistSpec(:MvLogNormal, (mean, Diagonal(cov_vec)), norm_val_mat), DistSpec(:MvLogNormal, (mean, cov_num), norm_val_mat), DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat), + DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_mat), DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat), ] diff --git a/test/test_utils.jl b/test/test_utils.jl index 1f3f309b..add03f5f 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -10,6 +10,7 @@ struct DistSpec{Tθ<:Tuple, Tx} end vectorize(v::Number) = [v] +vectorize(v::Diagonal) = v.diag vectorize(v) = vec(v) pack(vals...) = reduce(vcat, vectorize.(vals)) @generated function unpack(x, vals...) @@ -22,6 +23,9 @@ pack(vals...) = reduce(vcat, vectorize.(vals)) elseif T <: Vector push!(unpacked, :(x[$ind:$ind+length(vals[$i])-1])) ind = :($ind + length(vals[$i])) + elseif T <: Diagonal + push!(unpacked, :(Diagonal(x[$ind:$ind+size(vals[$i],1)-1]))) + ind = :($ind + size(vals[$i],1)) elseif T <: Matrix push!(unpacked, :(reshape(x[$ind:($ind+length(vals[$i])-1)], size(vals[$i])))) ind = :($ind + length(vals[$i]))