From b43750b555aea13f7579b18f109056270e72371a Mon Sep 17 00:00:00 2001 From: mohamed82008 Date: Thu, 23 Jan 2020 16:38:46 +1100 Subject: [PATCH] flatten all multi of univariate by default --- src/flatten.jl | 107 +++++++++++++++++++++++++++---------------------- src/multi.jl | 17 +++++--- 2 files changed, 71 insertions(+), 53 deletions(-) diff --git a/src/flatten.jl b/src/flatten.jl index 16372a89..acd862ad 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -1,3 +1,9 @@ +macro register(dist) + return quote + DistributionsAD.eval(getexpr($(esc(dist)))) + DistributionsAD.toflatten(::$(esc(dist))) = true + end +end function getexpr(Tdist) x = gensym() fnames = fieldnames(Tdist) @@ -13,53 +19,58 @@ function getexpr(Tdist) ) return :(flatten(dist::$Tdist) = ($func, $flattened_args)) end -for T in ( Bernoulli, - BetaBinomial, - Binomial, - Geometric, - NegativeBinomial, - Poisson, - Skellam, - PoissonBinomial, - Arcsine, - Beta, - BetaPrime, - Biweight, - Cauchy, - Chernoff, - Chi, - Chisq, - Cosine, - Epanechnikov, - Erlang, - Exponential, - FDist, - Frechet, - Gamma, - GeneralizedExtremeValue, - GeneralizedPareto, - Gumbel, - InverseGamma, - InverseGaussian, - Kolmogorov, - Laplace, - Levy, - LocationScale, - Logistic, - LogitNormal, - LogNormal, - Normal, - NormalCanon, - NormalInverseGaussian, - Pareto, - PGeneralizedGaussian, - Rayleigh, - SymTriangularDist, - TDist, - TriangularDist, - Triweight, - Categorical, - Truncated, - ) +const flattened_dists = [ Bernoulli, + BetaBinomial, + Binomial, + Geometric, + NegativeBinomial, + Poisson, + Skellam, + PoissonBinomial, + Arcsine, + Beta, + BetaPrime, + Biweight, + Cauchy, + Chernoff, + Chi, + Chisq, + Cosine, + Epanechnikov, + Erlang, + Exponential, + FDist, + Frechet, + Gamma, + GeneralizedExtremeValue, + GeneralizedPareto, + Gumbel, + InverseGamma, + InverseGaussian, + Kolmogorov, + Laplace, + Levy, + LocationScale, + Logistic, + LogitNormal, + LogNormal, + Normal, + NormalCanon, + NormalInverseGaussian, + Pareto, + PGeneralizedGaussian, + Rayleigh, + SymTriangularDist, + TDist, + TriangularDist, + Triweight, + Categorical, + Truncated, + ] +for T in flattened_dists + @eval toflatten(::T) = true +end +toflatten(::Distribution) = false +for T in flattened_dists eval(getexpr(T)) end diff --git a/src/multi.jl b/src/multi.jl index 35aacbcf..03a04ee7 100644 --- a/src/multi.jl +++ b/src/multi.jl @@ -61,12 +61,19 @@ function Distributions.logpdf( dist::MultipleContinuousUnivariate, x::AbstractVector{<:Real}, ) - f, args = flatten(dist.dist) - return sum(f.(args..., x)) + return _flat_logpdf(dist.dist, x) end function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleContinuousUnivariate) return rand(rng, dist.dist, dist.N) end +function _flat_logpdf(dist, x) + if toflatten(dist) + f, args = flatten(dist) + return sum(f.(args..., x)) + else + return sum(logpdf.(dist, x)) + end +end struct MatrixContinuousUnivariate{ Tdist <: ContinuousUnivariateDistribution, @@ -83,7 +90,7 @@ function Distributions.logpdf( dist::MatrixContinuousUnivariate, x::AbstractMatrix{<:Real} ) - return sum(logpdf.(dist.dist, x)) + return _flat_logpdf(dist.dist, x) end function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixContinuousUnivariate) return rand(rng, dist.dist, dist.S) @@ -106,7 +113,7 @@ function Distributions.logpdf( dist::MultipleDiscreteUnivariate, x::AbstractVector{<:Integer} ) - return sum(logpdf.(dist.dist, x)) + return _flat_logpdf(dist.dist, x) end function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleDiscreteUnivariate) return rand(rng, dist.dist, dist.N) @@ -127,7 +134,7 @@ function Distributions.logpdf( dist::MatrixDiscreteUnivariate, x::AbstractMatrix{<:Real} ) - return sum(logpdf.(dist.dist, x)) + return _flat_logpdf(dist.dist, x) end function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixDiscreteUnivariate) return rand(rng, dist.dist, dist.S)