-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from TuringLang/mt/array_dist_and_multi
Implement FillDist and ArrayDist
- Loading branch information
Showing
11 changed files
with
440 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# Univariate | ||
|
||
const VectorOfUnivariate = Distributions.Product | ||
|
||
function arraydist(dists::AbstractVector{<:Normal{T}}) where {T} | ||
means = mean.(dists) | ||
vars = var.(dists) | ||
return MvNormal(means, vars) | ||
end | ||
function arraydist(dists::AbstractVector{<:Normal{<:TrackedReal}}) | ||
means = vcatmapreduce(mean, dists) | ||
vars = vcatmapreduce(var, dists) | ||
return MvNormal(means, vars) | ||
end | ||
function arraydist(dists::AbstractVector{<:UnivariateDistribution}) | ||
return product_distribution(dists) | ||
end | ||
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractVector{<:Real}) | ||
return sum(vcatmapreduce(logpdf, dist.v, x)) | ||
end | ||
function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) | ||
# eachcol breaks Zygote, so we need an adjoint | ||
return vcatmapreduce((dist, c) -> logpdf.(dist, c), dist.v, eachcol(x)) | ||
end | ||
@adjoint function Distributions.logpdf(dist::VectorOfUnivariate, x::AbstractMatrix{<:Real}) | ||
# Any other more efficient implementation breaks Zygote | ||
f(dist, x) = [sum(logpdf.(dist.v, view(x, :, i))) for i in 1:size(x, 2)] | ||
return pullback(f, dist, x) | ||
end | ||
|
||
struct MatrixOfUnivariate{ | ||
S <: ValueSupport, | ||
Tdist <: UnivariateDistribution{S}, | ||
Tdists <: AbstractMatrix{Tdist}, | ||
} <: MatrixDistribution{S} | ||
dists::Tdists | ||
end | ||
Base.size(dist::MatrixOfUnivariate) = size(dist.dists) | ||
function arraydist(dists::AbstractMatrix{<:UnivariateDistribution}) | ||
return MatrixOfUnivariate(dists) | ||
end | ||
function Distributions.logpdf(dist::MatrixOfUnivariate, x::AbstractMatrix{<:Real}) | ||
# Broadcasting here breaks Tracker for some reason | ||
# A Zygote adjoint is defined for vcatmapreduce to use broadcasting | ||
return sum(vcatmapreduce(logpdf, dist.dists, x)) | ||
end | ||
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixOfUnivariate) | ||
return rand.(Ref(rng), dist.dists) | ||
end | ||
|
||
# Multivariate | ||
|
||
struct VectorOfMultivariate{ | ||
S <: ValueSupport, | ||
Tdist <: MultivariateDistribution{S}, | ||
Tdists <: AbstractVector{Tdist}, | ||
} <: MatrixDistribution{S} | ||
dists::Tdists | ||
end | ||
Base.size(dist::VectorOfMultivariate) = (length(dist.dists[1]), length(dist)) | ||
Base.length(dist::VectorOfMultivariate) = length(dist.dists) | ||
function arraydist(dists::AbstractVector{<:MultivariateDistribution}) | ||
return VectorOfMultivariate(dists) | ||
end | ||
function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) | ||
# eachcol breaks Zygote, so we define an adjoint | ||
return sum(vcatmapreduce(logpdf, dist.dists, eachcol(x))) | ||
end | ||
@adjoint function Distributions.logpdf(dist::VectorOfMultivariate, x::AbstractMatrix{<:Real}) | ||
f(dist, x) = sum(vcatmapreduce(i -> logpdf(dist.dists[i], view(x, :, i)), 1:size(x, 2))) | ||
return pullback(f, dist, x) | ||
end | ||
function Distributions.rand(rng::Random.AbstractRNG, dist::VectorOfMultivariate) | ||
init = reshape(rand(rng, dist.dists[1]), :, 1) | ||
return mapreduce(i -> rand(rng, dist.dists[i]), hcat, 2:length(dist); init = init) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# Univariate | ||
|
||
const FillVectorOfUnivariate{ | ||
S <: ValueSupport, | ||
T <: UnivariateDistribution{S}, | ||
Tdists <: Fill{T, 1}, | ||
} = VectorOfUnivariate{S, T, Tdists} | ||
|
||
function filldist(dist::UnivariateDistribution, N::Int) | ||
return product_distribution(Fill(dist, N)) | ||
end | ||
filldist(d::Normal, N::Int) = MvNormal(fill(d.μ, N), d.σ) | ||
|
||
function Distributions.logpdf( | ||
dist::FillVectorOfUnivariate, | ||
x::AbstractVector{<:Real}, | ||
) | ||
return _logpdf(dist, x) | ||
end | ||
function Distributions.logpdf( | ||
dist::FillVectorOfUnivariate, | ||
x::AbstractMatrix{<:Real}, | ||
) | ||
return _logpdf(dist, x) | ||
end | ||
@adjoint function Distributions.logpdf( | ||
dist::FillVectorOfUnivariate, | ||
x::AbstractMatrix{<:Real}, | ||
) | ||
return pullback(_logpdf, dist, x) | ||
end | ||
|
||
function _logpdf( | ||
dist::FillVectorOfUnivariate, | ||
x::AbstractVector{<:Real}, | ||
) | ||
return _flat_logpdf(dist.v.value, x) | ||
end | ||
function _logpdf( | ||
dist::FillVectorOfUnivariate, | ||
x::AbstractMatrix{<:Real}, | ||
) | ||
return _flat_logpdf_mat(dist.v.value, x) | ||
end | ||
|
||
function _flat_logpdf(dist, x) | ||
if toflatten(dist) | ||
f, args = flatten(dist) | ||
return sum(f.(args..., x)) | ||
else | ||
return sum(vcatmapreduce(x -> logpdf(dist, x), x)) | ||
end | ||
end | ||
function _flat_logpdf_mat(dist, x) | ||
if toflatten(dist) | ||
f, args = flatten(dist) | ||
return vec(sum(f.(args..., x), dims = 1)) | ||
else | ||
temp = vcatmapreduce(x -> logpdf(dist, x), x) | ||
return vec(sum(reshape(temp, size(x)), dims = 1)) | ||
end | ||
end | ||
|
||
const FillMatrixOfUnivariate{ | ||
S <: ValueSupport, | ||
T <: UnivariateDistribution{S}, | ||
Tdists <: Fill{T, 2}, | ||
} = MatrixOfUnivariate{S, T, Tdists} | ||
|
||
function filldist(dist::UnivariateDistribution, N1::Integer, N2::Integer) | ||
return MatrixOfUnivariate(Fill(dist, N1, N2)) | ||
end | ||
function Distributions.logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real}) | ||
return _flat_logpdf(dist.dists.value, x) | ||
end | ||
function Distributions.rand(rng::Random.AbstractRNG, dist::FillMatrixOfUnivariate) | ||
return rand(rng, dist.dists.value, length.(dist.dists.axes)) | ||
end | ||
|
||
# Multivariate | ||
|
||
const FillVectorOfMultivariate{ | ||
S <: ValueSupport, | ||
T <: MultivariateDistribution{S}, | ||
Tdists <: Fill{T, 1}, | ||
} = VectorOfMultivariate{S, T, Tdists} | ||
|
||
function filldist(dist::MultivariateDistribution, N::Int) | ||
return VectorOfMultivariate(Fill(dist, N)) | ||
end | ||
function Distributions.logpdf( | ||
dist::FillVectorOfMultivariate, | ||
x::AbstractMatrix{<:Real}, | ||
) | ||
return _logpdf(dist, x) | ||
end | ||
@adjoint function Distributions.logpdf( | ||
dist::FillVectorOfMultivariate, | ||
x::AbstractMatrix{<:Real}, | ||
) | ||
return pullback(_logpdf, dist, x) | ||
end | ||
function _logpdf( | ||
dist::FillVectorOfMultivariate, | ||
x::AbstractMatrix{<:Real}, | ||
) | ||
return sum(logpdf(dist.dists.value, x)) | ||
end | ||
function Distributions.rand(rng::Random.AbstractRNG, dist::FillVectorOfMultivariate) | ||
return rand(rng, dist.dists.value, length.(dist.dists.axes)) | ||
end |
Oops, something went wrong.