Skip to content

Commit

Permalink
Merge pull request #26 from TuringLang/mt/missing_MvNormal_constr
Browse files Browse the repository at this point in the history
Missing MvNormal and MvLogNormal constructors and Diagonal covariance tests
  • Loading branch information
mohamed82008 authored Jan 31, 2020
2 parents 4bce688 + 6711168 commit 91fb311
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 0 deletions.
18 changes: 18 additions & 0 deletions src/multivariate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,12 @@ function MvNormal(
)
return TuringMvNormal(m, D)
end
function MvNormal(
m::TrackedVector{<:Real},
D::Diagonal{T, <:AbstractVector{T}} where {T<:Real},
)
return TuringMvNormal(m, D)
end

# dense mean, diagonal covariance
MvNormal(m::TrackedVector{<:Real}, σ::TrackedVector{<:Real}) = TuringMvNormal(m, σ)
Expand Down Expand Up @@ -211,6 +217,18 @@ function MvLogNormal(
)
return TuringMvLogNormal(TuringMvNormal(m, D))
end
function MvLogNormal(
m::TrackedVector{<:Real},
D::Diagonal{T, <:AbstractVector{T}} where {T<:Real},
)
return TuringMvLogNormal(TuringMvNormal(m, D))
end
function MvLogNormal(
m::AbstractVector{<:Real},
D::Diagonal{T, <:AbstractVector{T}} where {T<:Real},
)
return MvLogNormal(MvNormal(m, D))
end

# dense mean, diagonal covariance
MvLogNormal(m::TrackedVector{<:Real}, σ::TrackedVector{<:Real}) = TuringMvLogNormal(TuringMvNormal(m, σ))
Expand Down
8 changes: 8 additions & 0 deletions test/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,26 +176,34 @@ separator()
# Vector case
DistSpec(:MvNormal, (mean, cov_mat), norm_val_vec),
DistSpec(:MvNormal, (mean, cov_vec), norm_val_vec),
DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_vec),
DistSpec(:MvNormal, (mean, cov_num), norm_val_vec),
DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_vec),
DistSpec(:MvNormal, (cov_mat,), norm_val_vec),
DistSpec(:MvNormal, (cov_vec,), norm_val_vec),
DistSpec(:MvNormal, (Diagonal(cov_vec),), norm_val_vec),
DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_vec),
DistSpec(:MvLogNormal, (mean, cov_mat), norm_val_vec),
DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_vec),
DistSpec(:MvLogNormal, (mean, Diagonal(cov_vec)), norm_val_vec),
DistSpec(:MvLogNormal, (mean, cov_num), norm_val_vec),
DistSpec(:MvLogNormal, (cov_mat,), norm_val_vec),
DistSpec(:MvLogNormal, (cov_vec,), norm_val_vec),
DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_vec),
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_vec),
# Matrix case
DistSpec(:MvNormal, (mean, cov_vec), norm_val_mat),
DistSpec(:MvNormal, (mean, Diagonal(cov_vec)), norm_val_mat),
DistSpec(:MvNormal, (mean, cov_num), norm_val_mat),
DistSpec(:((m, v) -> MvNormal(m, v*I)), (mean, cov_num), norm_val_mat),
DistSpec(:MvNormal, (cov_vec,), norm_val_mat),
DistSpec(:MvNormal, (Diagonal(cov_vec),), norm_val_mat),
DistSpec(:(cov_num -> MvNormal(dim, cov_num)), (cov_num,), norm_val_mat),
DistSpec(:MvLogNormal, (mean, cov_vec), norm_val_mat),
DistSpec(:MvLogNormal, (mean, Diagonal(cov_vec)), norm_val_mat),
DistSpec(:MvLogNormal, (mean, cov_num), norm_val_mat),
DistSpec(:MvLogNormal, (cov_vec,), norm_val_mat),
DistSpec(:MvLogNormal, (Diagonal(cov_vec),), norm_val_mat),
DistSpec(:(cov_num -> MvLogNormal(dim, cov_num)), (cov_num,), norm_val_mat),
]

Expand Down
4 changes: 4 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ struct DistSpec{Tθ<:Tuple, Tx}
end

vectorize(v::Number) = [v]
vectorize(v::Diagonal) = v.diag
vectorize(v) = vec(v)
pack(vals...) = reduce(vcat, vectorize.(vals))
@generated function unpack(x, vals...)
Expand All @@ -22,6 +23,9 @@ pack(vals...) = reduce(vcat, vectorize.(vals))
elseif T <: Vector
push!(unpacked, :(x[$ind:$ind+length(vals[$i])-1]))
ind = :($ind + length(vals[$i]))
elseif T <: Diagonal
push!(unpacked, :(Diagonal(x[$ind:$ind+size(vals[$i],1)-1])))
ind = :($ind + size(vals[$i],1))
elseif T <: Matrix
push!(unpacked, :(reshape(x[$ind:($ind+length(vals[$i])-1)], size(vals[$i]))))
ind = :($ind + length(vals[$i]))
Expand Down

0 comments on commit 91fb311

Please sign in to comment.