From 142380b09d40565089852203ae6d350e103ad34f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 09:24:32 +0100 Subject: [PATCH 01/38] Add NamedTupleVariate --- src/Distributions.jl | 1 + src/common.jl | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/src/Distributions.jl b/src/Distributions.jl index ceb6063c7..b43416d99 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -39,6 +39,7 @@ export Multivariate, Matrixvariate, CholeskyVariate, + NamedTupleVariate, Discrete, Continuous, Sampleable, diff --git a/src/common.jl b/src/common.jl index 703e12693..8b0672545 100644 --- a/src/common.jl +++ b/src/common.jl @@ -16,6 +16,12 @@ const Univariate = ArrayLikeVariate{0} const Multivariate = ArrayLikeVariate{1} const Matrixvariate = ArrayLikeVariate{2} +""" +`F <: NamedTupleVariate{K}` specifies that the variate or a sample is of type +`NamedTuple{K}`. +""" +abstract type NamedTupleVariate{K} <: VariateForm end + """ `F <: CholeskyVariate` specifies that the variate or a sample is of type `LinearAlgebra.Cholesky`. From 191ca1ab9d2698989712f61c0345b6e393f5193b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 09:24:53 +0100 Subject: [PATCH 02/38] Add ProductNamedTupleDistribution --- src/Distributions.jl | 1 + src/namedtuple/productnamedtuple.jl | 105 ++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) create mode 100644 src/namedtuple/productnamedtuple.jl diff --git a/src/Distributions.jl b/src/Distributions.jl index b43416d99..7b8a2ef42 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -296,6 +296,7 @@ include("univariates.jl") include("edgeworth.jl") include("multivariates.jl") include("matrixvariates.jl") +include("namedtuple/productnamedtuple.jl") include("cholesky/lkjcholesky.jl") include("samplers.jl") diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl new file mode 100644 index 000000000..546a46688 --- /dev/null +++ b/src/namedtuple/productnamedtuple.jl @@ -0,0 +1,105 @@ +struct ProductNamedTupleDistribution{Tnames,Tdists,eltypes,S<:ValueSupport} <: + Distribution{NamedTupleVariate{Tnames},S} + dists::NamedTuple{Tnames,Tdists} +end +function ProductNamedTupleDistribution( + dists::NamedTuple{K,V} +) where {K,V<:Tuple{Vararg{Distribution}}} + eltypes = Tuple{map(eltype, values(dists))...} + # TODO: allow mixed ValueSupports here + vs = _product_valuesupport(dists) + return ProductNamedTupleDistribution{K,V,eltypes,vs}(dists) +end + +function Base.show(io::IO, d::ProductNamedTupleDistribution) + show_multline(io, d, collect(pairs(d.dists))) +end + +function distrname(::ProductNamedTupleDistribution{K}) where {K} + return "ProductNamedTupleDistribution{$K}" +end + +""" + product_distribution(dists::Namedtuple{K,Tuple{Vararg{Distribution}}}) where {K} + +Create a distribution of `NamedTuple`s as a product distribution of independent named +distributions. + +The function falls back to constructing a [`ProductNamedTupleDistribution`](@ref) +distribution but specialized methods can be defined. +""" +function product_distribution(dists::NamedTuple{<:Any,<:Tuple{Vararg{Distribution}}}) + return ProductNamedTupleDistribution(dists) +end + +# Properties + +function Base.eltype(::Type{<:ProductNamedTupleDistribution{K,<:Any,V}}) where {K,V} + return NamedTuple{K,V} +end + +function minimum( + d::ProductNamedTupleDistribution{<:Any,<:Tuple{Vararg{UnivariateDistribution}}} +) + return map(minimum, d.dists) +end +function maximum( + d::ProductNamedTupleDistribution{<:Any,<:Tuple{Vararg{UnivariateDistribution}}} +) + return map(maximum, d.dists) +end + +function insupport(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K} + return all(Base.splat(insupport), zip(dist.dists, x)) +end + +# Evaluation + +function pdf(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K} + return exp(logpdf(dist, x)) +end +function logpdf(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K} + return mapreduce(logpdf, +, dist.dists, x) +end + +# Statistics + +mode(d::ProductNamedTupleDistribution) = map(mode, d.dists) + +mean(d::ProductNamedTupleDistribution) = map(mean, d.dists) + +var(d::ProductNamedTupleDistribution) = map(var, d.dists) + +entropy(d::ProductNamedTupleDistribution) = sum(entropy, d.dists) + +function kldivergence( + d1::ProductNamedTupleDistribution{K}, d2::ProductNamedTupleDistribution{K} +) where {K} + return mapreduce(kldivergence, +, d1.dists, d2.dists) +end + +# Sampling + +function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution{K}) where {K} + return NamedTuple{K}(map(Base.Fix1(rand, rng), d.dists)) +end +function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution, dims::Dims) + x = rand(rng, d) + xs = Array{typeof(x)}(undef, dims) + xs[1] = x + for i in Iterators.drop(eachindex(xs), 1) + xs[i] = rand(rng, d) + end + return xs +end + +function _rand!( + rng::AbstractRNG, + d::ProductNamedTupleDistribution, + xs::AbstractArray, +) + for i in eachindex(xs) + xs[i] = Random.rand(rng, d) + end + return xs +end From 399b03b1ed075306027133b740e4f6c4c92aef22 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 17:14:28 +0100 Subject: [PATCH 03/38] Correctly implement eltype --- src/namedtuple/productnamedtuple.jl | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 546a46688..434bfbef8 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -1,18 +1,27 @@ -struct ProductNamedTupleDistribution{Tnames,Tdists,eltypes,S<:ValueSupport} <: +struct ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <: Distribution{NamedTupleVariate{Tnames},S} dists::NamedTuple{Tnames,Tdists} end function ProductNamedTupleDistribution( dists::NamedTuple{K,V} ) where {K,V<:Tuple{Vararg{Distribution}}} - eltypes = Tuple{map(eltype, values(dists))...} - # TODO: allow mixed ValueSupports here vs = _product_valuesupport(dists) - return ProductNamedTupleDistribution{K,V,eltypes,vs}(dists) + eltypes = _product_namedtuple_eltype(dists) + return ProductNamedTupleDistribution{K,V,vs,eltypes}(dists) end +_gentype(d::UnivariateDistribution) = eltype(d) +_gentype(d::Distribution{<:ArrayLikeVariate{S}}) where {S} = Array{eltype(d),S} +function _gentype(d::Distribution{CholeskyVariate}) + T = eltype(d) + return LinearAlgebra.Cholesky{T,Matrix{T}} +end +_gentype(::Distribution) = Any + +_product_namedtuple_eltype(dists) = typejoin(map(_gentype, dists)...) + function Base.show(io::IO, d::ProductNamedTupleDistribution) - show_multline(io, d, collect(pairs(d.dists))) + return show_multline(io, d, collect(pairs(d.dists))) end function distrname(::ProductNamedTupleDistribution{K}) where {K} @@ -34,9 +43,7 @@ end # Properties -function Base.eltype(::Type{<:ProductNamedTupleDistribution{K,<:Any,V}}) where {K,V} - return NamedTuple{K,V} -end +Base.eltype(::Type{<:ProductNamedTupleDistribution{<:Any,<:Any,<:Any,T}}) where {T} = T function minimum( d::ProductNamedTupleDistribution{<:Any,<:Tuple{Vararg{UnivariateDistribution}}} From eb946a8d17be160558b52d84193e12f003ce4d2e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 17:14:57 +0100 Subject: [PATCH 04/38] Simplify insupport implementation --- src/namedtuple/productnamedtuple.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 434bfbef8..a30d606ee 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -57,7 +57,7 @@ function maximum( end function insupport(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K} - return all(Base.splat(insupport), zip(dist.dists, x)) + return all(map(insupport, dist.dists, x)) end # Evaluation From 32ca2f03c0d22699d594280323c8dcfc4b0c5003 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 17:15:13 +0100 Subject: [PATCH 05/38] Overload std for ProductNamedTupleDistribution --- src/namedtuple/productnamedtuple.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index a30d606ee..336c975f3 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -77,6 +77,8 @@ mean(d::ProductNamedTupleDistribution) = map(mean, d.dists) var(d::ProductNamedTupleDistribution) = map(var, d.dists) +std(d::ProductNamedTupleDistribution) = map(std, d.dists) + entropy(d::ProductNamedTupleDistribution) = sum(entropy, d.dists) function kldivergence( From a416f02a17e98997b4bf749e79ed46655db543b5 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 17:16:57 +0100 Subject: [PATCH 06/38] Simplify rand for ProductNamedTupleDistribution --- src/namedtuple/productnamedtuple.jl | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 336c975f3..3d60c0b4a 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -93,11 +93,8 @@ function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution{K}) where return NamedTuple{K}(map(Base.Fix1(rand, rng), d.dists)) end function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution, dims::Dims) - x = rand(rng, d) - xs = Array{typeof(x)}(undef, dims) - xs[1] = x - for i in Iterators.drop(eachindex(xs), 1) - xs[i] = rand(rng, d) + xs = return map(CartesianIndices(dims)) do _ + return rand(rng, d) end return xs end From 7deff941fbc66955c603cf853688a9dd388e50e6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 17:17:06 +0100 Subject: [PATCH 07/38] Reformat line --- src/namedtuple/productnamedtuple.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 3d60c0b4a..576269c08 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -99,11 +99,7 @@ function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution, dims::Dim return xs end -function _rand!( - rng::AbstractRNG, - d::ProductNamedTupleDistribution, - xs::AbstractArray, -) +function _rand!(rng::AbstractRNG, d::ProductNamedTupleDistribution, xs::AbstractArray) for i in eachindex(xs) xs[i] = Random.rand(rng, d) end From 978b2de3b36ab74df2327c76a39b14d8305fd3da Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 17:20:05 +0100 Subject: [PATCH 08/38] Add docstring to ProductNamedTupleDistribution --- src/namedtuple/productnamedtuple.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 576269c08..ccefcb34e 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -1,3 +1,14 @@ +""" + ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <: + Distribution{NamedTupleVariate{Tnames},S} + + A distribution of `NamedTuple`s, constructed from a `NamedTuple` of independent named + distributions. + + Users should use [`product_distribution`](@ref) to construct a product distribution of + independent distributions instead of constructing a `ProductNamedTupleDistribution` + directly. +""" struct ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <: Distribution{NamedTupleVariate{Tnames},S} dists::NamedTuple{Tnames,Tdists} From b718e59f91fbbbef4070ecab0b3df38bd2ee32fd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 19:49:58 +0100 Subject: [PATCH 09/38] Add marginal API function --- src/Distributions.jl | 1 + src/common.jl | 10 ++++++++++ src/namedtuple/productnamedtuple.jl | 2 ++ 3 files changed, 13 insertions(+) diff --git a/src/Distributions.jl b/src/Distributions.jl index 7b8a2ef42..d4484e3bf 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -229,6 +229,7 @@ export sqmahal!, # in-place evaluation of sqmahal location, # get the location parameter location!, # provide storage for the location parameter (used in multivariate distribution mvlognormal) + marginal, # marginal distributions mean, # mean of distribution meandir, # mean direction (of a spherical distribution) meanform, # convert a normal distribution from canonical form to mean form diff --git a/src/common.jl b/src/common.jl index 8b0672545..32a8d4961 100644 --- a/src/common.jl +++ b/src/common.jl @@ -475,6 +475,16 @@ Base.@propagate_inbounds function loglikelihood( return sum(Base.Fix1(logpdf, d), x) end +""" + marginal(d::Distribution, k...) -> Distribution + +Return the marginal distribution of `d` at the indices `k...`. + +The result is the distribution of the variate `rand(d)[k...]` that one would obtain by +integrating over all other indices. +""" +marginal(d::Distribution, k...) + ## TODO: the following types need to be improved abstract type SufficientStats end abstract type IncompleteDistribution end diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index ccefcb34e..7d58e5be7 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -39,6 +39,8 @@ function distrname(::ProductNamedTupleDistribution{K}) where {K} return "ProductNamedTupleDistribution{$K}" end +marginal(d::ProductNamedTupleDistribution, k::Union{Int,Symbol}) = d.dists[k] + """ product_distribution(dists::Namedtuple{K,Tuple{Vararg{Distribution}}}) where {K} From d08431d7e1c9c061b22960b2f19a43b11104f5bc Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 19:50:08 +0100 Subject: [PATCH 10/38] Add marginal for ProductDistribution --- src/product.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/product.jl b/src/product.jl index 7a4904ae7..05e6927df 100644 --- a/src/product.jl +++ b/src/product.jl @@ -67,6 +67,8 @@ const ArrayOfUnivariateDistribution{N,D,S<:ValueSupport,T} = ProductDistribution const FillArrayOfUnivariateDistribution{N,D<:Fill{<:Any,N},S<:ValueSupport,T} = ProductDistribution{N,0,D,S,T} +marginal(d::ProductDistribution, i...) = d.dists[i...] + ## General definitions function Base.eltype(::Type{<:ProductDistribution{<:Any,<:Any,<:Any,<:ValueSupport,T}}) where {T} return T From 79e5d594d96cc73d0f7521a9c5a7d31d3b20c4cf Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 20:06:27 +0100 Subject: [PATCH 11/38] Rearrange marginal --- src/namedtuple/productnamedtuple.jl | 4 ++-- src/product.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 7d58e5be7..2558d637f 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -39,8 +39,6 @@ function distrname(::ProductNamedTupleDistribution{K}) where {K} return "ProductNamedTupleDistribution{$K}" end -marginal(d::ProductNamedTupleDistribution, k::Union{Int,Symbol}) = d.dists[k] - """ product_distribution(dists::Namedtuple{K,Tuple{Vararg{Distribution}}}) where {K} @@ -69,6 +67,8 @@ function maximum( return map(maximum, d.dists) end +marginal(d::ProductNamedTupleDistribution, k::Union{Int,Symbol}) = d.dists[k] + function insupport(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K} return all(map(insupport, dist.dists, x)) end diff --git a/src/product.jl b/src/product.jl index 05e6927df..8424cffb8 100644 --- a/src/product.jl +++ b/src/product.jl @@ -67,8 +67,6 @@ const ArrayOfUnivariateDistribution{N,D,S<:ValueSupport,T} = ProductDistribution const FillArrayOfUnivariateDistribution{N,D<:Fill{<:Any,N},S<:ValueSupport,T} = ProductDistribution{N,0,D,S,T} -marginal(d::ProductDistribution, i...) = d.dists[i...] - ## General definitions function Base.eltype(::Type{<:ProductDistribution{<:Any,<:Any,<:Any,<:ValueSupport,T}}) where {T} return T @@ -95,6 +93,8 @@ minimum(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(minimum, d.dis maximum(d::ArrayOfUnivariateDistribution) = map(maximum, d.dists) maximum(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(maximum, d.dists)) +marginal(d::ProductDistribution, i...) = d.dists[i...] + function entropy(d::ArrayOfUnivariateDistribution) # we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020) return sum(Broadcast.instantiate(Broadcast.broadcasted(entropy, d.dists))) From 52fb9a07eee6f2c521257a4947b57a7e709652df Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 20:45:06 +0100 Subject: [PATCH 12/38] Allow tuple indexing via marginal --- src/namedtuple/productnamedtuple.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 2558d637f..0c216d823 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -68,6 +68,11 @@ function maximum( end marginal(d::ProductNamedTupleDistribution, k::Union{Int,Symbol}) = d.dists[k] +if VERSION ≥ v"1.7.0-DEV.294" + function marginal(d::ProductNamedTupleDistribution, ks::Tuple{Symbol,Vararg{Symbol}}) + return ProductNamedTupleDistribution(d.dists[ks]) + end +end function insupport(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K} return all(map(insupport, dist.dists, x)) From 1509abd06b9ae0b66aac67328613456716ddddec Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 21:33:47 +0100 Subject: [PATCH 13/38] Make logpdf type-stable --- src/namedtuple/productnamedtuple.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 0c216d823..bafb0b223 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -83,8 +83,9 @@ end function pdf(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K} return exp(logpdf(dist, x)) end + function logpdf(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K} - return mapreduce(logpdf, +, dist.dists, x) + return sum(map(logpdf, dist.dists, x)) end # Statistics From 450fb7d52dc4e7540ca5ef775593a83d80bdd7b6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 21:33:53 +0100 Subject: [PATCH 14/38] Add loglikelihood --- src/namedtuple/productnamedtuple.jl | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index bafb0b223..12e0b95e0 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -88,6 +88,16 @@ function logpdf(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where return sum(map(logpdf, dist.dists, x)) end +function loglikelihood(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K} + return logpdf(dist, x) +end + +function loglikelihood( + dist::ProductNamedTupleDistribution{K}, xs::AbstractArray{<:NamedTuple{K}} +) where {K} + return sum(Base.Fix1(loglikelihood, dist), xs) +end + # Statistics mode(d::ProductNamedTupleDistribution) = map(mode, d.dists) From eb2ed6c5a45e6231179c10d3dcb99cedaf76b673 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 21:34:09 +0100 Subject: [PATCH 15/38] Support extrema for multivariate distributions --- src/namedtuple/productnamedtuple.jl | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 12e0b95e0..983b01e67 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -56,16 +56,9 @@ end Base.eltype(::Type{<:ProductNamedTupleDistribution{<:Any,<:Any,<:Any,T}}) where {T} = T -function minimum( - d::ProductNamedTupleDistribution{<:Any,<:Tuple{Vararg{UnivariateDistribution}}} -) - return map(minimum, d.dists) -end -function maximum( - d::ProductNamedTupleDistribution{<:Any,<:Tuple{Vararg{UnivariateDistribution}}} -) - return map(maximum, d.dists) -end +Base.minimum(d::ProductNamedTupleDistribution) = map(minimum, d.dists) + +Base.maximum(d::ProductNamedTupleDistribution) = map(maximum, d.dists) marginal(d::ProductNamedTupleDistribution, k::Union{Int,Symbol}) = d.dists[k] if VERSION ≥ v"1.7.0-DEV.294" From e3a0814812efc45f9935422e621cff12672d6d78 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 21:39:49 +0100 Subject: [PATCH 16/38] Add tests --- test/namedtuple/productnamedtuple.jl | 198 +++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 199 insertions(+) create mode 100644 test/namedtuple/productnamedtuple.jl diff --git a/test/namedtuple/productnamedtuple.jl b/test/namedtuple/productnamedtuple.jl new file mode 100644 index 000000000..51d4ab8d0 --- /dev/null +++ b/test/namedtuple/productnamedtuple.jl @@ -0,0 +1,198 @@ +using Distributions +using Distributions: ProductNamedTupleDistribution +using LinearAlgebra +using Random +using Test + +@testset "ProductNamedTupleDistribution" begin + @testset "Constructor" begin + nt = (x=Normal(1.0, 2.0), y=Normal(3.0, 4.0)) + d = @inferred ProductNamedTupleDistribution(nt) + @test d isa ProductNamedTupleDistribution + @test d.dists === nt + @test Distributions.variate_form(typeof(d)) === NamedTupleVariate{(:x, :y)} + @test Distributions.value_support(typeof(d)) === Continuous + + nt = ( + x=Normal(), + y=Dirichlet(10, 1.0), + z=DiscreteUniform(1, 10), + w=LKJCholesky(3, 2.0), + ) + d = @inferred ProductNamedTupleDistribution(nt) + @test d isa ProductNamedTupleDistribution + @test d.dists === nt + @test Distributions.variate_form(typeof(d)) === NamedTupleVariate{(:x, :y, :z, :w)} + @test Distributions.value_support(typeof(d)) === Continuous + end + + @testset "product_distribution" begin + nt = (x=Normal(1.0, 2.0), y=Normal(3.0, 4.0)) + d = @inferred product_distribution(nt) + @test d === ProductNamedTupleDistribution(nt) + + nt = ( + x=Normal(), + y=Dirichlet(10, 1.0), + z=DiscreteUniform(1, 10), + w=LKJCholesky(3, 2.0), + ) + d = @inferred product_distribution(nt) + @test d === ProductNamedTupleDistribution(nt) + end + + @testset "show" begin + d = ProductNamedTupleDistribution((x=Gamma(1.0, 2.0), y=Normal())) + @test sprint(show, d) == """ + ProductNamedTupleDistribution{(:x, :y)}( + x: Gamma{Float64}(α=1.0, θ=2.0) + y: Normal{Float64}(μ=0.0, σ=1.0) + ) + """ + end + + @testset "Properties" begin + @testset "eltype" begin + nt = (x=Normal(1.0, 2.0), y=Normal(3.0, 4.0)) + d = ProductNamedTupleDistribution(nt) + @test eltype(d) === Float64 + + nt = (x=Normal(), y=Gamma()) + d = ProductNamedTupleDistribution(nt) + @test eltype(d) === Float64 + + nt = (x=Bernoulli(),) + d = ProductNamedTupleDistribution(nt) + @test eltype(d) === Bool + + nt = (x=Normal(), y=Bernoulli()) + d = ProductNamedTupleDistribution(nt) + @test eltype(d) === Real + + nt = (w=LKJCholesky(3, 2.0),) + d = ProductNamedTupleDistribution(nt) + @test eltype(d) === LinearAlgebra.Cholesky{Float64,Array{Float64,2}} + + nt = ( + x=Normal(), + y=Dirichlet(10, 1.0), + z=DiscreteUniform(1, 10), + w=LKJCholesky(3, 2.0), + ) + d = ProductNamedTupleDistribution(nt) + @test eltype(d) === Any + end + + @testset "minimum" begin + nt = (x=Normal(1.0, 2.0), y=Gamma(), z=MvNormal(ones(5))) + d = ProductNamedTupleDistribution(nt) + @test @inferred(minimum(d)) == + (x=minimum(nt.x), y=minimum(nt.y), z=minimum(nt.z)) + end + + @testset "maximum" begin + nt = (x=Normal(1.0, 2.0), y=Gamma(), z=MvNormal(ones(5))) + d = ProductNamedTupleDistribution(nt) + @test @inferred(maximum(d)) == + (x=maximum(nt.x), y=maximum(nt.y), z=maximum(nt.z)) + end + + @testset "marginal" begin + nt = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0)) + d = ProductNamedTupleDistribution(nt) + @test marginal(d, :x) === nt[:x] + @test marginal(d, :y) === nt[:y] + @test marginal(d, :z) === nt[:z] + @test marginal(d, 1) === nt[1] + @test marginal(d, 2) === nt[2] + @test marginal(d, 3) === nt[3] + if VERSION ≥ v"1.7.0-DEV.294" + @test marginal(d, (:x, :y)) === + ProductNamedTupleDistribution((x=nt[:x], y=nt[:y])) + @test marginal(d, (:z, :x)) === + ProductNamedTupleDistribution((z=nt[:z], x=nt[:x])) + @test_throws ErrorException marginal(d, (:x, :w)) + end + @test_throws MethodError marginal(d, ()) + end + + @testset "insupport" begin + nt = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0)) + d = ProductNamedTupleDistribution(nt) + x = (x=rand(nt.x), y=rand(nt.y), z=rand(nt.z)) + @test @inferred(insupport(d, x)) + @test_throws MethodError insupport(d, NamedTuple{(:y, :z, :x)}(x)) + @test_throws MethodError insupport(d, NamedTuple{(:x, :y)}(x)) + @test !insupport(d, merge(x, (x=NaN,))) + @test !insupport(d, merge(x, (y=-1,))) + @test !insupport(d, merge(x, (z=fill(0.25, 4),))) + end + end + + @testset "Evaluation" begin + nt = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0), w=Bernoulli()) + d = ProductNamedTupleDistribution(nt) + x = (x=rand(nt.x), y=rand(nt.y), z=rand(nt.z), w=rand(nt.w)) + @test @inferred(logpdf(d, x)) == + logpdf(nt.x, x.x) + logpdf(nt.y, x.y) + logpdf(nt.z, x.z) + logpdf(nt.w, x.w) + @test @inferred(pdf(d, x)) == exp(logpdf(d, x)) + @test @inferred(loglikelihood(d, x)) == logpdf(d, x) + xs = [(x=rand(nt.x), y=rand(nt.y), z=rand(nt.z), w=rand(nt.w)) for _ in 1:10] + @test @inferred(loglikelihood(d, xs)) == sum(logpdf.(Ref(d), xs)) + end + + @testset "Statistics" begin + nt = (x=Normal(1.0, 2.0), y=Gamma(), z=MvNormal(1.0:5.0), w=Poisson(100)) + d = ProductNamedTupleDistribution(nt) + @test @inferred(mode(d)) == (x=mode(nt.x), y=mode(nt.y), z=mode(nt.z), w=mode(nt.w)) + @test @inferred(mean(d)) == (x=mean(nt.x), y=mean(nt.y), z=mean(nt.z), w=mean(nt.w)) + @test @inferred(var(d)) == (x=var(nt.x), y=var(nt.y), z=var(nt.z), w=var(nt.w)) + @test @inferred(entropy(d)) == + entropy(nt.x) + entropy(nt.y) + entropy(nt.z) + entropy(nt.w) + + d1 = ProductNamedTupleDistribution((x=Normal(1.0, 2.0), y=Gamma())) + d2 = ProductNamedTupleDistribution((x=Normal(), y=Gamma(2.0, 3.0))) + @test kldivergence(d1, d2) == + kldivergence(d1.dists.x, d2.dists.x) + kldivergence(d1.dists.y, d2.dists.y) + + d3 = ProductNamedTupleDistribution((x=Normal(1.0, 2.0), y=Gamma(6.0, 7.0))) + @test std(d3) == (x=std(d3.dists.x), y=std(d3.dists.y)) + end + + @testset "Sampling" begin + rng = MersenneTwister(973) + + @testset "rand" begin + nt = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0), w=Bernoulli()) + d = ProductNamedTupleDistribution(nt) + rng = MersenneTwister(973) + x1 = @inferred rand(rng, d) + @test eltype(x1) === eltype(d) + rng = MersenneTwister(973) + x2 = ( + x=rand(rng, nt.x), y=rand(rng, nt.y), z=rand(rng, nt.z), w=rand(rng, nt.w) + ) + @test x1 == x2 + x3 = rand(rng, d) + @test x3 != x1 + + xs1 = @inferred rand(rng, d, 10) + @test length(xs1) == 10 + @test all(insupport.(Ref(d), xs1)) + + xs2 = @inferred rand(rng, d, (2, 3, 4)) + @test size(xs2) == (2, 3, 4) + @test all(insupport.(Ref(d), xs2)) + end + + @testset "rand!" begin + d = ProductNamedTupleDistribution(( + x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0), w=Bernoulli() + )) + x = rand(d) + xs = Array{typeof(x)}(undef, (2, 3, 4)) + rand!(d, xs) + @test all(insupport.(Ref(d), xs)) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ce3f16b79..00a0fae15 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,6 +61,7 @@ const tests = [ "qq", "univariate/continuous/pgeneralizedgaussian", "product", + "namedtuple/productnamedtuple.jl", "univariate/discrete/discretenonparametric", "univariate/continuous/chernoff", "univariate_bounds", # extra file compared to /src From 9acc869d99e384b3114879c594846ef17a8996f4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 22:39:27 +0100 Subject: [PATCH 17/38] Improve type-inferrability --- src/namedtuple/productnamedtuple.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 983b01e67..2654b1268 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -16,8 +16,8 @@ end function ProductNamedTupleDistribution( dists::NamedTuple{K,V} ) where {K,V<:Tuple{Vararg{Distribution}}} - vs = _product_valuesupport(dists) - eltypes = _product_namedtuple_eltype(dists) + vs = _product_valuesupport(values(dists)) + eltypes = _product_namedtuple_eltype(values(dists)) return ProductNamedTupleDistribution{K,V,vs,eltypes}(dists) end @@ -48,7 +48,9 @@ distributions. The function falls back to constructing a [`ProductNamedTupleDistribution`](@ref) distribution but specialized methods can be defined. """ -function product_distribution(dists::NamedTuple{<:Any,<:Tuple{Vararg{Distribution}}}) +function product_distribution( + dists::NamedTuple{<:Any,<:Tuple{Distribution,Vararg{Distribution}}} +) return ProductNamedTupleDistribution(dists) end @@ -101,7 +103,7 @@ var(d::ProductNamedTupleDistribution) = map(var, d.dists) std(d::ProductNamedTupleDistribution) = map(std, d.dists) -entropy(d::ProductNamedTupleDistribution) = sum(entropy, d.dists) +entropy(d::ProductNamedTupleDistribution) = sum(entropy, values(d.dists)) function kldivergence( d1::ProductNamedTupleDistribution{K}, d2::ProductNamedTupleDistribution{K} From 6d8df2ae3f33e97d590cddfcd8ca4847d0d2b3c6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 23 Nov 2023 23:03:09 +0100 Subject: [PATCH 18/38] Remove extension --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 00a0fae15..ad5043b18 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,7 +61,7 @@ const tests = [ "qq", "univariate/continuous/pgeneralizedgaussian", "product", - "namedtuple/productnamedtuple.jl", + "namedtuple/productnamedtuple", "univariate/discrete/discretenonparametric", "univariate/continuous/chernoff", "univariate_bounds", # extra file compared to /src From 800de5b7a2ad5e6b9d23f57f1d91bb2cc9cad6a1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 15 Jul 2024 14:05:11 +0200 Subject: [PATCH 19/38] Apply suggestions from code review Co-authored-by: David Widmann --- src/namedtuple/productnamedtuple.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 2654b1268..6697a8f9f 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -15,7 +15,7 @@ struct ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <: end function ProductNamedTupleDistribution( dists::NamedTuple{K,V} -) where {K,V<:Tuple{Vararg{Distribution}}} +) where {K,V<:Tuple{Distribution,Vararg{Distribution}}} vs = _product_valuesupport(values(dists)) eltypes = _product_namedtuple_eltype(values(dists)) return ProductNamedTupleDistribution{K,V,vs,eltypes}(dists) @@ -125,7 +125,7 @@ end function _rand!(rng::AbstractRNG, d::ProductNamedTupleDistribution, xs::AbstractArray) for i in eachindex(xs) - xs[i] = Random.rand(rng, d) + xs[i] = rand(rng, d) end return xs end From 0b835875dff062a2d837aeacc03ff86f5bdce298 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 15 Jul 2024 14:45:00 +0200 Subject: [PATCH 20/38] Remove marginal --- src/Distributions.jl | 1 - src/common.jl | 10 ---------- src/namedtuple/productnamedtuple.jl | 7 ------- src/product.jl | 2 -- test/namedtuple/productnamedtuple.jl | 19 ------------------- 5 files changed, 39 deletions(-) diff --git a/src/Distributions.jl b/src/Distributions.jl index 05c8156b4..3e72a8f35 100644 --- a/src/Distributions.jl +++ b/src/Distributions.jl @@ -231,7 +231,6 @@ export sqmahal!, # in-place evaluation of sqmahal location, # get the location parameter location!, # provide storage for the location parameter (used in multivariate distribution mvlognormal) - marginal, # marginal distributions mean, # mean of distribution meandir, # mean direction (of a spherical distribution) meanform, # convert a normal distribution from canonical form to mean form diff --git a/src/common.jl b/src/common.jl index c9e259343..8effc7f27 100644 --- a/src/common.jl +++ b/src/common.jl @@ -470,16 +470,6 @@ Base.@propagate_inbounds function loglikelihood( return sum(Base.Fix1(logpdf, d), x) end -""" - marginal(d::Distribution, k...) -> Distribution - -Return the marginal distribution of `d` at the indices `k...`. - -The result is the distribution of the variate `rand(d)[k...]` that one would obtain by -integrating over all other indices. -""" -marginal(d::Distribution, k...) - ## TODO: the following types need to be improved abstract type SufficientStats end abstract type IncompleteDistribution end diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 6697a8f9f..fa46df16a 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -62,13 +62,6 @@ Base.minimum(d::ProductNamedTupleDistribution) = map(minimum, d.dists) Base.maximum(d::ProductNamedTupleDistribution) = map(maximum, d.dists) -marginal(d::ProductNamedTupleDistribution, k::Union{Int,Symbol}) = d.dists[k] -if VERSION ≥ v"1.7.0-DEV.294" - function marginal(d::ProductNamedTupleDistribution, ks::Tuple{Symbol,Vararg{Symbol}}) - return ProductNamedTupleDistribution(d.dists[ks]) - end -end - function insupport(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K} return all(map(insupport, dist.dists, x)) end diff --git a/src/product.jl b/src/product.jl index 2f52a149b..d6f7aaa6b 100644 --- a/src/product.jl +++ b/src/product.jl @@ -95,8 +95,6 @@ minimum(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(minimum, d.dis maximum(d::ArrayOfUnivariateDistribution) = map(maximum, d.dists) maximum(d::VectorOfUnivariateDistribution{<:Tuple}) = collect(map(maximum, d.dists)) -marginal(d::ProductDistribution, i...) = d.dists[i...] - function entropy(d::ArrayOfUnivariateDistribution) # we use pairwise summation (https://github.com/JuliaLang/julia/pull/31020) return sum(Broadcast.instantiate(Broadcast.broadcasted(entropy, d.dists))) diff --git a/test/namedtuple/productnamedtuple.jl b/test/namedtuple/productnamedtuple.jl index 51d4ab8d0..c082eeda1 100644 --- a/test/namedtuple/productnamedtuple.jl +++ b/test/namedtuple/productnamedtuple.jl @@ -97,25 +97,6 @@ using Test (x=maximum(nt.x), y=maximum(nt.y), z=maximum(nt.z)) end - @testset "marginal" begin - nt = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0)) - d = ProductNamedTupleDistribution(nt) - @test marginal(d, :x) === nt[:x] - @test marginal(d, :y) === nt[:y] - @test marginal(d, :z) === nt[:z] - @test marginal(d, 1) === nt[1] - @test marginal(d, 2) === nt[2] - @test marginal(d, 3) === nt[3] - if VERSION ≥ v"1.7.0-DEV.294" - @test marginal(d, (:x, :y)) === - ProductNamedTupleDistribution((x=nt[:x], y=nt[:y])) - @test marginal(d, (:z, :x)) === - ProductNamedTupleDistribution((z=nt[:z], x=nt[:x])) - @test_throws ErrorException marginal(d, (:x, :w)) - end - @test_throws MethodError marginal(d, ()) - end - @testset "insupport" begin nt = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0)) d = ProductNamedTupleDistribution(nt) From ba03eea8d35471e9daaf23512026f50a4ece9f00 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 15 Jul 2024 15:38:28 +0200 Subject: [PATCH 21/38] Add sampler for product namedtuple --- src/namedtuple/productnamedtuple.jl | 6 ++++++ src/samplers.jl | 4 +++- src/samplers/productnamedtuple.jl | 21 +++++++++++++++++++++ test/namedtuple/productnamedtuple.jl | 24 ++++++++++++++++++++++-- 4 files changed, 52 insertions(+), 3 deletions(-) create mode 100644 src/samplers/productnamedtuple.jl diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index fa46df16a..04269c9fd 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -106,6 +106,12 @@ end # Sampling +function sampler(d::ProductNamedTupleDistribution{K,<:Any,S}) where {K,S} + samplers = map(sampler, d.dists) + Tsamplers = typeof(values(samplers)) + return ProductNamedTupleSampler{K,Tsamplers,S}(samplers) +end + function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution{K}) where {K} return NamedTuple{K}(map(Base.Fix1(rand, rng), d.dists)) end diff --git a/src/samplers.jl b/src/samplers.jl index 794f2bff4..fa667fff2 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -24,7 +24,9 @@ for fname in ["aliastable.jl", "vonmises.jl", "vonmisesfisher.jl", "discretenonparametric.jl", - "categorical.jl"] + "categorical.jl", + "productnamedtuple.jl", + ] include(joinpath("samplers", fname)) end diff --git a/src/samplers/productnamedtuple.jl b/src/samplers/productnamedtuple.jl new file mode 100644 index 000000000..792ca8569 --- /dev/null +++ b/src/samplers/productnamedtuple.jl @@ -0,0 +1,21 @@ +struct ProductNamedTupleSampler{Tnames,Tsamplers,S<:ValueSupport} <: + Sampleable{NamedTupleVariate{Tnames},S} + samplers::NamedTuple{Tnames,Tsamplers} +end + +function Base.rand(rng::AbstractRNG, spl::ProductNamedTupleSampler{K}) where {K} + return NamedTuple{K}(map(Base.Fix1(rand, rng), spl.samplers)) +end + +function _rand(rng::AbstractRNG, spl::ProductNamedTupleSampler, dims::Dims) + return map(CartesianIndices(dims)) do _ + return rand(rng, spl) + end +end + +function _rand!(rng::AbstractRNG, spl::ProductNamedTupleSampler, xs::AbstractArray) + for i in eachindex(xs) + xs[i] = rand(rng, spl) + end + return xs +end diff --git a/test/namedtuple/productnamedtuple.jl b/test/namedtuple/productnamedtuple.jl index c082eeda1..9ed459bf3 100644 --- a/test/namedtuple/productnamedtuple.jl +++ b/test/namedtuple/productnamedtuple.jl @@ -1,5 +1,5 @@ using Distributions -using Distributions: ProductNamedTupleDistribution +using Distributions: ProductNamedTupleDistribution, ProductNamedTupleSampler using LinearAlgebra using Random using Test @@ -140,9 +140,29 @@ using Test @test std(d3) == (x=std(d3.dists.x), y=std(d3.dists.y)) end - @testset "Sampling" begin + @testset "Sampler" begin + nt1 = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0), w=Bernoulli()) + d1 = ProductNamedTupleDistribution(nt1) + # sampler(::Gamma) is type-unstable + spl = @inferred ProductNamedTupleSampler{(:x, :y, :z, :w)} sampler(d1) + @test spl.samplers == (; (k => sampler(v) for (k, v) in pairs(nt1))...) + rng = MersenneTwister(973) + x1 = @inferred rand(rng, d1) rng = MersenneTwister(973) + x2 = ( + x=rand(rng, nt1.x), y=rand(rng, nt1.y), z=rand(rng, nt1.z), w=rand(rng, nt1.w) + ) + @test x1 == x2 + x3 = rand(rng, d1) + @test x3 != x1 + + # sampler should now be type-stable + nt2 = (x=Normal(1.0, 2.0), z=Dirichlet(5, 1.0), w=Bernoulli()) + d2 = ProductNamedTupleDistribution(nt2) + @inferred sampler(d2) + end + @testset "Sampling" begin @testset "rand" begin nt = (x=Normal(1.0, 2.0), y=Gamma(), z=Dirichlet(5, 1.0), w=Bernoulli()) d = ProductNamedTupleDistribution(nt) From 1712be648c5dd85479a337c558514a2fd90d3e63 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 15 Jul 2024 15:38:58 +0200 Subject: [PATCH 22/38] Use ProductNamedTupleSampler for array rand calls --- src/namedtuple/productnamedtuple.jl | 14 +++++--------- test/namedtuple/productnamedtuple.jl | 11 +++++++++-- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 04269c9fd..dda48cfdd 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -115,16 +115,12 @@ end function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution{K}) where {K} return NamedTuple{K}(map(Base.Fix1(rand, rng), d.dists)) end -function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution, dims::Dims) - xs = return map(CartesianIndices(dims)) do _ - return rand(rng, d) - end - return xs +function Base.rand( + rng::AbstractRNG, d::ProductNamedTupleDistribution{K}, dims::Dims +) where {K} + return convert(AbstractArray{<:NamedTuple{K}}, _rand(rng, sampler(d), dims)) end function _rand!(rng::AbstractRNG, d::ProductNamedTupleDistribution, xs::AbstractArray) - for i in eachindex(xs) - xs[i] = rand(rng, d) - end - return xs + return _rand!(rng, sampler(d), xs) end diff --git a/test/namedtuple/productnamedtuple.jl b/test/namedtuple/productnamedtuple.jl index 9ed459bf3..e7f1d7ed6 100644 --- a/test/namedtuple/productnamedtuple.jl +++ b/test/namedtuple/productnamedtuple.jl @@ -177,13 +177,20 @@ using Test x3 = rand(rng, d) @test x3 != x1 - xs1 = @inferred rand(rng, d, 10) + # not completely type-inferrable due to sampler(::Gamma) being type-unstable + xs1 = @inferred Vector{<:NamedTuple{(:x, :y, :z, :w)}} rand(rng, d, 10) @test length(xs1) == 10 @test all(insupport.(Ref(d), xs1)) - xs2 = @inferred rand(rng, d, (2, 3, 4)) + xs2 = @inferred Array{<:NamedTuple{(:x, :y, :z, :w)},3} rand(rng, d, (2, 3, 4)) @test size(xs2) == (2, 3, 4) @test all(insupport.(Ref(d), xs2)) + + nt2 = (x=Normal(1.0, 2.0), z=Dirichlet(5, 1.0), w=Bernoulli()) + d2 = ProductNamedTupleDistribution(nt2) + # now type-inferrable + @inferred rand(rng, d2, 10) + @inferred rand(rng, d2, 2, 3, 4) end @testset "rand!" begin From 1056d0d7e7b25ffdd57fedd9faa29d701c38009e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 19 Aug 2024 10:36:56 +0200 Subject: [PATCH 23/38] Add docs page for product distributions --- docs/make.jl | 1 + docs/src/product.md | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 docs/src/product.md diff --git a/docs/make.jl b/docs/make.jl index f95b3f60b..c5ab4409e 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -17,6 +17,7 @@ makedocs( "reshape.md", "cholesky.md", "mixture.md", + "product.md", "order_statistics.md", "convolution.md", "fit.md", diff --git a/docs/src/product.md b/docs/src/product.md new file mode 100644 index 000000000..63225f8df --- /dev/null +++ b/docs/src/product.md @@ -0,0 +1,2 @@ +# Product Distributions + From 58937fd741f05ae0ba07398a1bcd4fd5eb401fad Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 19 Aug 2024 10:39:41 +0200 Subject: [PATCH 24/38] Fix typo --- src/namedtuple/productnamedtuple.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index dda48cfdd..778bfc01a 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -40,7 +40,7 @@ function distrname(::ProductNamedTupleDistribution{K}) where {K} end """ - product_distribution(dists::Namedtuple{K,Tuple{Vararg{Distribution}}}) where {K} + product_distribution(dists::NamedTuple{K,Tuple{Vararg{Distribution}}}) where {K} Create a distribution of `NamedTuple`s as a product distribution of independent named distributions. From d7fd8421310726b41883cfbe9d1fe48f2c098230 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 19 Aug 2024 11:24:26 +0200 Subject: [PATCH 25/38] Fix ProductNamedTuple docstring --- src/namedtuple/productnamedtuple.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 778bfc01a..6310777bd 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -2,12 +2,12 @@ ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <: Distribution{NamedTupleVariate{Tnames},S} - A distribution of `NamedTuple`s, constructed from a `NamedTuple` of independent named - distributions. +A distribution of `NamedTuple`s, constructed from a `NamedTuple` of independent named +distributions. - Users should use [`product_distribution`](@ref) to construct a product distribution of - independent distributions instead of constructing a `ProductNamedTupleDistribution` - directly. +Users should use [`product_distribution`](@ref) to construct a product distribution of +independent distributions instead of constructing a `ProductNamedTupleDistribution` +directly. """ struct ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <: Distribution{NamedTupleVariate{Tnames},S} From c8b1602235fc80a79e734a449ac54c828dcce19f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 19 Aug 2024 11:24:52 +0200 Subject: [PATCH 26/38] Add deprecation warning to Product docstring --- src/multivariate/product.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/multivariate/product.jl b/src/multivariate/product.jl index 3bb0f0d9e..1227e5ba1 100644 --- a/src/multivariate/product.jl +++ b/src/multivariate/product.jl @@ -10,6 +10,10 @@ An N dimensional `MultivariateDistribution` constructed from a vector of N indep ```julia Product(Uniform.(rand(10), 1)) # A 10-dimensional Product from 10 independent `Uniform` distributions. ``` + +!!! note + `Product` is deprecated and will be removed in the next breaking release. + Use [`product_distribution`](@ref) instead. """ struct Product{ S<:ValueSupport, From db029c57a09189dfc736f25e8787489b15eedc9b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 19 Aug 2024 11:25:58 +0200 Subject: [PATCH 27/38] Move multivariate product distributions to own page --- docs/src/multivariate.md | 10 ---------- docs/src/product.md | 12 ++++++++++++ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/docs/src/multivariate.md b/docs/src/multivariate.md index c4e7c1764..35eff32ca 100644 --- a/docs/src/multivariate.md +++ b/docs/src/multivariate.md @@ -58,7 +58,6 @@ MvNormalCanon MvLogitNormal MvLogNormal Dirichlet -Product ``` ## Addition Methods @@ -105,15 +104,6 @@ params{D<:Distributions.AbstractMvLogNormal}(::Type{D},m::AbstractVector,S::Abst Distributions._logpdf(d::MultivariateDistribution, x::AbstractArray) ``` -## Product distributions - -```@docs -Distributions.product_distribution -``` - -Using `product_distribution` is advised to construct product distributions. -For some distributions, it constructs a special multivariate type. - ## Index ```@index diff --git a/docs/src/product.md b/docs/src/product.md index 63225f8df..e62439b09 100644 --- a/docs/src/product.md +++ b/docs/src/product.md @@ -1,2 +1,14 @@ # Product Distributions +Product distributions are joint distributions of multiple independent distributions. +It is recommended to use `product_distribution` to construct product distributions. +Depending on the type of the argument, it may construct a different distribution type. + +## Multivariate products + +```@docs +Distributions.product_distribution(::AbstractArray{<:Distribution{<:ArrayLikeVariate}}) +Distributions.product_distribution(::AbstractVector{<:Normal}) +Distributions.ProductDistribution +Distributions.Product +``` From 2634adbc28e5c10284b98429b1a0991e127f10d0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 19 Aug 2024 11:26:24 +0200 Subject: [PATCH 28/38] Document NamedTuple products --- docs/src/product.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/src/product.md b/docs/src/product.md index e62439b09..4226ec799 100644 --- a/docs/src/product.md +++ b/docs/src/product.md @@ -12,3 +12,10 @@ Distributions.product_distribution(::AbstractVector{<:Normal}) Distributions.ProductDistribution Distributions.Product ``` + +## NamedTuple-variate products + +```@docs +Distributions.product_distribution(::NamedTuple{<:Any,<:Tuple{Distribution,Vararg{Distribution}}}) +Distributions.ProductNamedTupleDistribution +``` From eb5b1761881d15f2eb27c4f3f76915ae741ee7e9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 19 Aug 2024 11:44:49 +0200 Subject: [PATCH 29/38] Add docs index --- docs/src/product.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/src/product.md b/docs/src/product.md index 4226ec799..9a01b6cd2 100644 --- a/docs/src/product.md +++ b/docs/src/product.md @@ -19,3 +19,9 @@ Distributions.Product Distributions.product_distribution(::NamedTuple{<:Any,<:Tuple{Distribution,Vararg{Distribution}}}) Distributions.ProductNamedTupleDistribution ``` + +## Index + +```@index +Pages = ["product.md"] +``` From 3ebc3ba912f97b2ffaac2c0351a45ebaca6a4d17 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 19 Aug 2024 11:45:16 +0200 Subject: [PATCH 30/38] Document usage of ProductNamedTuple --- src/namedtuple/productnamedtuple.jl | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 6310777bd..d7ff7d708 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -8,6 +8,32 @@ distributions. Users should use [`product_distribution`](@ref) to construct a product distribution of independent distributions instead of constructing a `ProductNamedTupleDistribution` directly. + +# Examples + +```jldoctest ProductNamedTuple; setup = :(using Random; Random.seed!(832)) +julia> d = product_distribution((x=Normal(), y=Dirichlet([2, 4]))) +ProductNamedTupleDistribution{(:x, :y)}( +x: Normal{Float64}(μ=0.0, σ=1.0) +y: Dirichlet{Int64, Vector{Int64}, Float64}(alpha=[2, 4]) +) + + +julia> nt = rand(d) +(x = 1.5155385995160346, y = [0.533531876438439, 0.466468123561561]) + +julia> pdf(d, nt) +0.13702825691074877 + +julia> mode(d) # mode of marginals +(x = 0.0, y = [0.25, 0.75]) + +julia> mean(d) # mean of marginals +(x = 0.0, y = [0.3333333333333333, 0.6666666666666666]) + +julia> var(d) # var of marginals +(x = 1.0, y = [0.031746031746031744, 0.031746031746031744]) +``` """ struct ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <: Distribution{NamedTupleVariate{Tnames},S} From f0dd8c4808f1ba71d42a7f362bb56fdda0b8766e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 19 Aug 2024 13:03:03 +0200 Subject: [PATCH 31/38] Load Distributions for jldoctest --- src/namedtuple/productnamedtuple.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index d7ff7d708..3e77d0a41 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -11,7 +11,7 @@ directly. # Examples -```jldoctest ProductNamedTuple; setup = :(using Random; Random.seed!(832)) +```jldoctest ProductNamedTuple; setup = :(using Distributions, Random; Random.seed!(832)) julia> d = product_distribution((x=Normal(), y=Dirichlet([2, 4]))) ProductNamedTupleDistribution{(:x, :y)}( x: Normal{Float64}(μ=0.0, σ=1.0) From 121dd2bb25fbbb053ff40e754c3d78b78422d5d1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 4 Sep 2024 11:07:40 +0200 Subject: [PATCH 32/38] Apply suggestions from code review Co-authored-by: David Widmann --- src/common.jl | 2 +- src/namedtuple/productnamedtuple.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/common.jl b/src/common.jl index 8effc7f27..4fdb36ecb 100644 --- a/src/common.jl +++ b/src/common.jl @@ -20,7 +20,7 @@ const Matrixvariate = ArrayLikeVariate{2} `F <: NamedTupleVariate{K}` specifies that the variate or a sample is of type `NamedTuple{K}`. """ -abstract type NamedTupleVariate{K} <: VariateForm end +struct NamedTupleVariate{K} <: VariateForm end """ `F <: CholeskyVariate` specifies that the variate or a sample is of type diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 3e77d0a41..4d6b4f03e 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -55,7 +55,7 @@ function _gentype(d::Distribution{CholeskyVariate}) end _gentype(::Distribution) = Any -_product_namedtuple_eltype(dists) = typejoin(map(_gentype, dists)...) +_product_namedtuple_eltype(dists::NamedTuple{K,V}) where {K,V} = __product_promote_type(eltype, V) function Base.show(io::IO, d::ProductNamedTupleDistribution) return show_multline(io, d, collect(pairs(d.dists))) @@ -127,7 +127,7 @@ entropy(d::ProductNamedTupleDistribution) = sum(entropy, values(d.dists)) function kldivergence( d1::ProductNamedTupleDistribution{K}, d2::ProductNamedTupleDistribution{K} ) where {K} - return mapreduce(kldivergence, +, d1.dists, d2.dists) + return sum(map(kldivergence, d1.dists, d2.dists)) end # Sampling From a86cac4585adbf130076a48af981dc842049c66c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 4 Sep 2024 17:32:06 +0200 Subject: [PATCH 33/38] Call method on NamedTuple --- src/namedtuple/productnamedtuple.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index 4d6b4f03e..df01162e1 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -43,7 +43,7 @@ function ProductNamedTupleDistribution( dists::NamedTuple{K,V} ) where {K,V<:Tuple{Distribution,Vararg{Distribution}}} vs = _product_valuesupport(values(dists)) - eltypes = _product_namedtuple_eltype(values(dists)) + eltypes = _product_namedtuple_eltype(dists) return ProductNamedTupleDistribution{K,V,vs,eltypes}(dists) end From 46fdcfc8419b4e25c497e99f6b191d6074f740a8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 5 Sep 2024 14:51:25 +0200 Subject: [PATCH 34/38] Revert to typejoin based eltype --- src/namedtuple/productnamedtuple.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index df01162e1..cb047d204 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -55,7 +55,7 @@ function _gentype(d::Distribution{CholeskyVariate}) end _gentype(::Distribution) = Any -_product_namedtuple_eltype(dists::NamedTuple{K,V}) where {K,V} = __product_promote_type(eltype, V) +_product_namedtuple_eltype(dists) = typejoin(map(_gentype, dists)...) function Base.show(io::IO, d::ProductNamedTupleDistribution) return show_multline(io, d, collect(pairs(d.dists))) From 54b0d03985b10b17bf2e94e3404d7b34a7dd18a7 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 5 Sep 2024 14:54:43 +0200 Subject: [PATCH 35/38] Explicitly check eltype of dist matches that of draw --- test/namedtuple/productnamedtuple.jl | 44 ++++++++++------------------ 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/test/namedtuple/productnamedtuple.jl b/test/namedtuple/productnamedtuple.jl index e7f1d7ed6..ae28c41ad 100644 --- a/test/namedtuple/productnamedtuple.jl +++ b/test/namedtuple/productnamedtuple.jl @@ -53,34 +53,22 @@ using Test @testset "Properties" begin @testset "eltype" begin - nt = (x=Normal(1.0, 2.0), y=Normal(3.0, 4.0)) - d = ProductNamedTupleDistribution(nt) - @test eltype(d) === Float64 - - nt = (x=Normal(), y=Gamma()) - d = ProductNamedTupleDistribution(nt) - @test eltype(d) === Float64 - - nt = (x=Bernoulli(),) - d = ProductNamedTupleDistribution(nt) - @test eltype(d) === Bool - - nt = (x=Normal(), y=Bernoulli()) - d = ProductNamedTupleDistribution(nt) - @test eltype(d) === Real - - nt = (w=LKJCholesky(3, 2.0),) - d = ProductNamedTupleDistribution(nt) - @test eltype(d) === LinearAlgebra.Cholesky{Float64,Array{Float64,2}} - - nt = ( - x=Normal(), - y=Dirichlet(10, 1.0), - z=DiscreteUniform(1, 10), - w=LKJCholesky(3, 2.0), - ) - d = ProductNamedTupleDistribution(nt) - @test eltype(d) === Any + @testset for nt in [ + (x=Normal(1.0, 2.0), y=Normal(3.0, 4.0)), + (x=Normal(), y=Gamma()), + (x=Bernoulli(),), + (x=Normal(), y=Bernoulli()), + (w=LKJCholesky(3, 2.0),), + ( + x=Normal(), + y=Dirichlet(10, 1.0), + z=DiscreteUniform(1, 10), + w=LKJCholesky(3, 2.0), + ), + ] + d = ProductNamedTupleDistribution(nt) + @test eltype(d) === eltype(rand(d)) + end end @testset "minimum" begin From fe284b12f59cb427c088ec797e280951227ffa7d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 5 Sep 2024 14:57:29 +0200 Subject: [PATCH 36/38] Correctly compute eltype for nested prod namedtuple distributions --- src/namedtuple/productnamedtuple.jl | 3 +++ test/namedtuple/productnamedtuple.jl | 1 + 2 files changed, 4 insertions(+) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index cb047d204..e3f3c6880 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -53,6 +53,9 @@ function _gentype(d::Distribution{CholeskyVariate}) T = eltype(d) return LinearAlgebra.Cholesky{T,Matrix{T}} end +function _gentype(d::ProductNamedTupleDistribution{K}) where {K} + return NamedTuple{K,Tuple{map(_gentype, values(d.dists))...}} +end _gentype(::Distribution) = Any _product_namedtuple_eltype(dists) = typejoin(map(_gentype, dists)...) diff --git a/test/namedtuple/productnamedtuple.jl b/test/namedtuple/productnamedtuple.jl index ae28c41ad..294279ec0 100644 --- a/test/namedtuple/productnamedtuple.jl +++ b/test/namedtuple/productnamedtuple.jl @@ -65,6 +65,7 @@ using Test z=DiscreteUniform(1, 10), w=LKJCholesky(3, 2.0), ), + (x = product_distribution((x=Normal(), y=Gamma())),), ] d = ProductNamedTupleDistribution(nt) @test eltype(d) === eltype(rand(d)) From 1eabd23a02257b5dc61c52af7104be53998fbb60 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 5 Sep 2024 19:11:36 +0200 Subject: [PATCH 37/38] Revert "Call method on NamedTuple" This reverts commit a86cac4585adbf130076a48af981dc842049c66c. --- src/namedtuple/productnamedtuple.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/namedtuple/productnamedtuple.jl b/src/namedtuple/productnamedtuple.jl index e3f3c6880..853159ee3 100644 --- a/src/namedtuple/productnamedtuple.jl +++ b/src/namedtuple/productnamedtuple.jl @@ -43,7 +43,7 @@ function ProductNamedTupleDistribution( dists::NamedTuple{K,V} ) where {K,V<:Tuple{Distribution,Vararg{Distribution}}} vs = _product_valuesupport(values(dists)) - eltypes = _product_namedtuple_eltype(dists) + eltypes = _product_namedtuple_eltype(values(dists)) return ProductNamedTupleDistribution{K,V,vs,eltypes}(dists) end From 28a7c0054aa6bb1727cee6415bd3808ff216233f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 5 Sep 2024 23:39:50 +0200 Subject: [PATCH 38/38] Update test/namedtuple/productnamedtuple.jl Co-authored-by: David Widmann --- test/namedtuple/productnamedtuple.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/namedtuple/productnamedtuple.jl b/test/namedtuple/productnamedtuple.jl index 294279ec0..2f121fd33 100644 --- a/test/namedtuple/productnamedtuple.jl +++ b/test/namedtuple/productnamedtuple.jl @@ -43,7 +43,7 @@ using Test @testset "show" begin d = ProductNamedTupleDistribution((x=Gamma(1.0, 2.0), y=Normal())) - @test sprint(show, d) == """ + @test repr(d) == """ ProductNamedTupleDistribution{(:x, :y)}( x: Gamma{Float64}(α=1.0, θ=2.0) y: Normal{Float64}(μ=0.0, σ=1.0)