Skip to content

Commit

Permalink
Added Pareto RV
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck authored and brandonwillard committed Apr 2, 2021
1 parent b2229fc commit 16ee436
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
18 changes: 18 additions & 0 deletions aesara/tensor/random/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,24 @@ def rng_fn(cls, rng, shape, scale, size):
gamma = GammaRV()


class ParetoRV(RandomVariable):
name = "pareto"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
_print_name = ("Pareto", "\\operatorname{Pareto}")

def __call__(self, b, scale=1.0, size=None, **kwargs):
return super().__call__(b, scale, size=size, **kwargs)

@classmethod
def rng_fn(cls, rng, b, scale, size):
return stats.pareto.rvs(b, scale=scale, size=size, random_state=rng)


pareto = ParetoRV()


class ExponentialRV(RandomVariable):
name = "exponential"
ndim_supp = 0
Expand Down
8 changes: 8 additions & 0 deletions tests/tensor/random/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
multivariate_normal,
nbinom,
normal,
pareto,
permutation,
poisson,
polyagamma,
Expand Down Expand Up @@ -227,6 +228,13 @@ def test_exponential_samples():
rv_numpy_tester(exponential, test_lambda, size=[2, 3])


def test_pareto_samples():
test_alpha = np.array(0.5, dtype=config.floatX)

rv_numpy_tester(pareto, test_alpha, test_fn=stats.pareto.rvs)
rv_numpy_tester(pareto, test_alpha, size=[2, 3], test_fn=stats.pareto.rvs)


def test_mvnormal_samples():
def test_fn(mean=None, cov=None, size=None, rng=None):
if mean is None:
Expand Down

0 comments on commit 16ee436

Please sign in to comment.