From e3f8a67ada079bea6748db34e41febab53051e37 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 15:26:59 +0100 Subject: [PATCH] Added default impls for `filldist` and `arraydist` (#264) * added default impls for `filldist` and `arraydist` * Apply suggestions from code review Co-authored-by: David Widmann * Update filldist.jl (#265) --------- Co-authored-by: David Widmann --- Project.toml | 2 +- src/arraydist.jl | 7 +++++++ src/filldist.jl | 11 ++++++++++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 8d61a5a..e510591 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DistributionsAD" uuid = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -version = "0.6.54" +version = "0.6.55" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/arraydist.jl b/src/arraydist.jl index 28e9e2b..867fdc0 100644 --- a/src/arraydist.jl +++ b/src/arraydist.jl @@ -1,3 +1,10 @@ +""" + arraydist(dists) + +Create a distribution from an array of distributions. +""" +arraydist(dists::AbstractArray{<:Distribution}) = product_distribution(dists) + # Univariate const VectorOfUnivariate = Distributions.Product diff --git a/src/filldist.jl b/src/filldist.jl index 67b758a..d67b0b2 100644 --- a/src/filldist.jl +++ b/src/filldist.jl @@ -1,5 +1,14 @@ +# Default implementation just defers to Distributions.jl. +""" + filldist(d::Distribution, ns...) + +Create a product distribution using `FillArrays.Fill` as the array type. +""" +filldist(d::Distribution, n1::Int, ns::Int...) = product_distribution(Fill(d, n1, ns...)) + # Univariate +# TODO: Do we even need these? Probably should benchmark to be sure. const FillVectorOfUnivariate{ S <: ValueSupport, T <: UnivariateDistribution{S}, @@ -59,7 +68,7 @@ const FillMatrixOfUnivariate{ Tdists <: Fill{T, 2}, } = MatrixOfUnivariate{S, T, Tdists} -function filldist(dist::UnivariateDistribution, N1::Integer, N2::Integer) +function filldist(dist::UnivariateDistribution, N1::Int, N2::Int) return MatrixOfUnivariate(Fill(dist, N1, N2)) end function Distributions._logpdf(dist::FillMatrixOfUnivariate, x::AbstractMatrix{<:Real})