Skip to content

Commit 57d4387

Browse files
committed
flatten all multi of univariate by default
1 parent 9417ca3 commit 57d4387

File tree

2 files changed

+71
-53
lines changed

2 files changed

+71
-53
lines changed

src/flatten.jl

+59-48
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
macro register(dist)
2+
return quote
3+
DistributionsAD.eval(getexpr($(esc(dist))))
4+
DistributionsAD.toflatten(::$(esc(dist))) = true
5+
end
6+
end
17
function getexpr(Tdist)
28
x = gensym()
39
fnames = fieldnames(Tdist)
@@ -13,53 +19,58 @@ function getexpr(Tdist)
1319
)
1420
return :(flatten(dist::$Tdist) = ($func, $flattened_args))
1521
end
16-
for T in ( Bernoulli,
17-
BetaBinomial,
18-
Binomial,
19-
Geometric,
20-
NegativeBinomial,
21-
Poisson,
22-
Skellam,
23-
PoissonBinomial,
24-
Arcsine,
25-
Beta,
26-
BetaPrime,
27-
Biweight,
28-
Cauchy,
29-
Chernoff,
30-
Chi,
31-
Chisq,
32-
Cosine,
33-
Epanechnikov,
34-
Erlang,
35-
Exponential,
36-
FDist,
37-
Frechet,
38-
Gamma,
39-
GeneralizedExtremeValue,
40-
GeneralizedPareto,
41-
Gumbel,
42-
InverseGamma,
43-
InverseGaussian,
44-
Kolmogorov,
45-
Laplace,
46-
Levy,
47-
LocationScale,
48-
Logistic,
49-
LogitNormal,
50-
LogNormal,
51-
Normal,
52-
NormalCanon,
53-
NormalInverseGaussian,
54-
Pareto,
55-
PGeneralizedGaussian,
56-
Rayleigh,
57-
SymTriangularDist,
58-
TDist,
59-
TriangularDist,
60-
Triweight,
61-
Categorical,
62-
Truncated,
63-
)
22+
const flattened_dists = [ Bernoulli,
23+
BetaBinomial,
24+
Binomial,
25+
Geometric,
26+
NegativeBinomial,
27+
Poisson,
28+
Skellam,
29+
PoissonBinomial,
30+
Arcsine,
31+
Beta,
32+
BetaPrime,
33+
Biweight,
34+
Cauchy,
35+
Chernoff,
36+
Chi,
37+
Chisq,
38+
Cosine,
39+
Epanechnikov,
40+
Erlang,
41+
Exponential,
42+
FDist,
43+
Frechet,
44+
Gamma,
45+
GeneralizedExtremeValue,
46+
GeneralizedPareto,
47+
Gumbel,
48+
InverseGamma,
49+
InverseGaussian,
50+
Kolmogorov,
51+
Laplace,
52+
Levy,
53+
LocationScale,
54+
Logistic,
55+
LogitNormal,
56+
LogNormal,
57+
Normal,
58+
NormalCanon,
59+
NormalInverseGaussian,
60+
Pareto,
61+
PGeneralizedGaussian,
62+
Rayleigh,
63+
SymTriangularDist,
64+
TDist,
65+
TriangularDist,
66+
Triweight,
67+
Categorical,
68+
Truncated,
69+
]
70+
for T in flattened_dists
71+
@eval toflatten(::T) = true
72+
end
73+
toflatten(::Distribution) = false
74+
for T in flattened_dists
6475
eval(getexpr(T))
6576
end

src/multi.jl

+12-5
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,19 @@ function Distributions.logpdf(
6161
dist::MultipleContinuousUnivariate,
6262
x::AbstractVector{<:Real},
6363
)
64-
f, args = flatten(dist.dist)
65-
return sum(f.(args..., x))
64+
return _flat_logpdf(dist.dist, x)
6665
end
6766
function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleContinuousUnivariate)
6867
return rand(rng, dist.dist, dist.N)
6968
end
69+
function _flat_logpdf(dist, x)
70+
if toflatten(dist)
71+
f, args = flatten(dist)
72+
return sum(f.(args..., x))
73+
else
74+
return sum(logpdf.(dist, x))
75+
end
76+
end
7077

7178
struct MatrixContinuousUnivariate{
7279
Tdist <: ContinuousUnivariateDistribution,
@@ -83,7 +90,7 @@ function Distributions.logpdf(
8390
dist::MatrixContinuousUnivariate,
8491
x::AbstractMatrix{<:Real}
8592
)
86-
return sum(logpdf.(dist.dist, x))
93+
return _flat_logpdf(dist.dist, x)
8794
end
8895
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixContinuousUnivariate)
8996
return rand(rng, dist.dist, dist.S)
@@ -106,7 +113,7 @@ function Distributions.logpdf(
106113
dist::MultipleDiscreteUnivariate,
107114
x::AbstractVector{<:Integer}
108115
)
109-
return sum(logpdf.(dist.dist, x))
116+
return _flat_logpdf(dist.dist, x)
110117
end
111118
function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleDiscreteUnivariate)
112119
return rand(rng, dist.dist, dist.N)
@@ -127,7 +134,7 @@ function Distributions.logpdf(
127134
dist::MatrixDiscreteUnivariate,
128135
x::AbstractMatrix{<:Real}
129136
)
130-
return sum(logpdf.(dist.dist, x))
137+
return _flat_logpdf(dist.dist, x)
131138
end
132139
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixDiscreteUnivariate)
133140
return rand(rng, dist.dist, dist.S)

0 commit comments

Comments
 (0)