From f8a72e6e43498c774fe42cc779c8ba3951feda95 Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 1 Feb 2020 08:34:59 +1100 Subject: [PATCH 1/3] add missing MvNormal Tracker constructor --- src/multivariate.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/multivariate.jl b/src/multivariate.jl index 225823a7..8361d3c4 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -103,6 +103,12 @@ function MvNormal( ) return TuringMvNormal(m, D) end +function MvNormal( + m::TrackedVector{<:Real}, + D::Diagonal{T, <:AbstractVector{T}} where {T<:Real}, +) + return TuringMvNormal(m, D) +end # dense mean, diagonal covariance MvNormal(m::TrackedVector{<:Real}, σ::TrackedVector{<:Real}) = TuringMvNormal(m, σ) From e958f7214443ea46b9b4ed050cdf2951a702fb2e Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 1 Feb 2020 08:58:12 +1100 Subject: [PATCH 2/3] add missing MvLogNormal constructors --- src/multivariate.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/multivariate.jl b/src/multivariate.jl index 8361d3c4..29bf2358 100644 --- a/src/multivariate.jl +++ b/src/multivariate.jl @@ -217,6 +217,18 @@ function MvLogNormal( ) return TuringMvLogNormal(TuringMvNormal(m, D)) end +function MvLogNormal( + m::TrackedVector{<:Real}, + D::Diagonal{T, <:AbstractVector{T}} where {T<:Real}, +) + return TuringMvLogNormal(TuringMvNormal(m, D)) +end +function MvLogNormal( + m::AbstractVector{<:Real}, + D::Diagonal{T, <:AbstractVector{T}} where {T<:Real}, +) + return MvLogNormal(MvNormal(m, D)) +end # dense mean, diagonal covariance MvLogNormal(m::TrackedVector{<:Real}, σ::TrackedVector{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, σ)) From 6711168013f7868c2c02ecbfb62b7b7fb1f6ca3f Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Sat, 1 Feb 2020 08:58:22 +1100 Subject: [PATCH 3/3] test Diagonal covariance --- test/distributions.jl | 8 ++++++++ test/test_utils.jl | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/test/distributions.jl b/test/distributions.jl index 4071222d..450c689f 100644 --- a/test/distributions.jl +++ b/test/distributions.jl @@ -176,26 +176,34 @@ separator() # Vector case DistSpec(:MvNormal, (mean, cov_mat), norm_val_vec), DistSpec(:MvNormal, (mean, cov_vec), norm_val_vec), + DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_vec), DistSpec(:MvNormal, (mean, cov_num), norm_val_vec), DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_vec), DistSpec(:MvNormal, (cov_mat,), norm_val_vec), DistSpec(:MvNormal, (cov_vec,), norm_val_vec), + DistSpec(:MvNormal, (Diagonal(cov_vec),), norm_val_vec), DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_vec), DistSpec(:MvLogNormal, (mean, cov_mat), norm_val_vec), DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_vec), + DistSpec(:MvLogNormal, (mean, Diagonal(cov_vec)), norm_val_vec), DistSpec(:MvLogNormal, (mean, cov_num), norm_val_vec), DistSpec(:MvLogNormal, (cov_mat,), norm_val_vec), DistSpec(:MvLogNormal, (cov_vec,), norm_val_vec), + DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_vec), DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_vec), # Matrix case DistSpec(:MvNormal, (mean, cov_vec), norm_val_mat), + DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_mat), DistSpec(:MvNormal, (mean, cov_num), norm_val_mat), DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_mat), DistSpec(:MvNormal, (cov_vec,), norm_val_mat), + DistSpec(:MvNormal, (Diagonal(cov_vec),), norm_val_mat), DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_mat), DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_mat), + DistSpec(:MvLogNormal, (mean, Diagonal(cov_vec)), norm_val_mat), DistSpec(:MvLogNormal, (mean, cov_num), norm_val_mat), DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat), + DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_mat), DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat), ] diff --git a/test/test_utils.jl b/test/test_utils.jl index 1f3f309b..add03f5f 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -10,6 +10,7 @@ struct DistSpec{Tθ<:Tuple, Tx} end vectorize(v::Number) = [v] +vectorize(v::Diagonal) = v.diag vectorize(v) = vec(v) pack(vals...) = reduce(vcat, vectorize.(vals)) @generated function unpack(x, vals...) @@ -22,6 +23,9 @@ pack(vals...) = reduce(vcat, vectorize.(vals)) elseif T <: Vector push!(unpacked, :(x[$ind:$ind+length(vals[$i])-1])) ind = :($ind + length(vals[$i])) + elseif T <: Diagonal + push!(unpacked, :(Diagonal(x[$ind:$ind+size(vals[$i],1)-1]))) + ind = :($ind + size(vals[$i],1)) elseif T <: Matrix push!(unpacked, :(reshape(x[$ind:($ind+length(vals[$i])-1)], size(vals[$i])))) ind = :($ind + length(vals[$i]))