Skip to content

Commit

Permalink
flatten all multi of univariate by default
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed Jan 23, 2020
1 parent 4ee4798 commit b43750b
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 53 deletions.
107 changes: 59 additions & 48 deletions src/flatten.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
macro register(dist)
return quote
DistributionsAD.eval(getexpr($(esc(dist))))
DistributionsAD.toflatten(::$(esc(dist))) = true
end
end
function getexpr(Tdist)
x = gensym()
fnames = fieldnames(Tdist)
Expand All @@ -13,53 +19,58 @@ function getexpr(Tdist)
)
return :(flatten(dist::$Tdist) = ($func, $flattened_args))
end
for T in ( Bernoulli,
BetaBinomial,
Binomial,
Geometric,
NegativeBinomial,
Poisson,
Skellam,
PoissonBinomial,
Arcsine,
Beta,
BetaPrime,
Biweight,
Cauchy,
Chernoff,
Chi,
Chisq,
Cosine,
Epanechnikov,
Erlang,
Exponential,
FDist,
Frechet,
Gamma,
GeneralizedExtremeValue,
GeneralizedPareto,
Gumbel,
InverseGamma,
InverseGaussian,
Kolmogorov,
Laplace,
Levy,
LocationScale,
Logistic,
LogitNormal,
LogNormal,
Normal,
NormalCanon,
NormalInverseGaussian,
Pareto,
PGeneralizedGaussian,
Rayleigh,
SymTriangularDist,
TDist,
TriangularDist,
Triweight,
Categorical,
Truncated,
)
const flattened_dists = [ Bernoulli,
BetaBinomial,
Binomial,
Geometric,
NegativeBinomial,
Poisson,
Skellam,
PoissonBinomial,
Arcsine,
Beta,
BetaPrime,
Biweight,
Cauchy,
Chernoff,
Chi,
Chisq,
Cosine,
Epanechnikov,
Erlang,
Exponential,
FDist,
Frechet,
Gamma,
GeneralizedExtremeValue,
GeneralizedPareto,
Gumbel,
InverseGamma,
InverseGaussian,
Kolmogorov,
Laplace,
Levy,
LocationScale,
Logistic,
LogitNormal,
LogNormal,
Normal,
NormalCanon,
NormalInverseGaussian,
Pareto,
PGeneralizedGaussian,
Rayleigh,
SymTriangularDist,
TDist,
TriangularDist,
Triweight,
Categorical,
Truncated,
]
for T in flattened_dists
@eval toflatten(::T) = true
end
toflatten(::Distribution) = false
for T in flattened_dists
eval(getexpr(T))
end
17 changes: 12 additions & 5 deletions src/multi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,19 @@ function Distributions.logpdf(
dist::MultipleContinuousUnivariate,
x::AbstractVector{<:Real},
)
f, args = flatten(dist.dist)
return sum(f.(args..., x))
return _flat_logpdf(dist.dist, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleContinuousUnivariate)
return rand(rng, dist.dist, dist.N)
end
function _flat_logpdf(dist, x)
if toflatten(dist)
f, args = flatten(dist)
return sum(f.(args..., x))
else
return sum(logpdf.(dist, x))
end
end

struct MatrixContinuousUnivariate{
Tdist <: ContinuousUnivariateDistribution,
Expand All @@ -83,7 +90,7 @@ function Distributions.logpdf(
dist::MatrixContinuousUnivariate,
x::AbstractMatrix{<:Real}
)
return sum(logpdf.(dist.dist, x))
return _flat_logpdf(dist.dist, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixContinuousUnivariate)
return rand(rng, dist.dist, dist.S)
Expand All @@ -106,7 +113,7 @@ function Distributions.logpdf(
dist::MultipleDiscreteUnivariate,
x::AbstractVector{<:Integer}
)
return sum(logpdf.(dist.dist, x))
return _flat_logpdf(dist.dist, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::MultipleDiscreteUnivariate)
return rand(rng, dist.dist, dist.N)
Expand All @@ -127,7 +134,7 @@ function Distributions.logpdf(
dist::MatrixDiscreteUnivariate,
x::AbstractMatrix{<:Real}
)
return sum(logpdf.(dist.dist, x))
return _flat_logpdf(dist.dist, x)
end
function Distributions.rand(rng::Random.AbstractRNG, dist::MatrixDiscreteUnivariate)
return rand(rng, dist.dist, dist.S)
Expand Down

0 comments on commit b43750b

Please sign in to comment.