diff --git a/.gitignore b/.gitignore index 99e3306ab6..b8530882a1 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,4 @@ core aesara-venv/ /notebooks/Sandbox* +.vscode/ diff --git a/aesara/tensor/random/basic.py b/aesara/tensor/random/basic.py index 23f32d32c1..503f5ff954 100644 --- a/aesara/tensor/random/basic.py +++ b/aesara/tensor/random/basic.py @@ -1,3 +1,5 @@ +from typing import List, Optional, Union + import numpy as np import scipy.stats as stats @@ -108,6 +110,36 @@ def rng_fn(cls, rng, b, scale, size): pareto = ParetoRV() +class GumbelRV(RandomVariable): + name = "gumbel" + ndim_supp = 0 + ndims_params = [0, 0] + dtype = "floatX" + _print_name = ("Gumbel", "\\operatorname{Gumbel}") + + def __call__( + self, + loc: Union[np.ndarray, float], + scale: Union[np.ndarray, float] = 1.0, + size: Optional[Union[List[int], int]] = None, + **kwargs + ) -> RandomVariable: + return super().__call__(loc, scale, size=size, **kwargs) + + @classmethod + def rng_fn( + cls, + rng: np.random.RandomState, + loc: Union[np.ndarray, float], + scale: Union[np.ndarray, float], + size: Optional[Union[List[int], int]], + ) -> np.ndarray: + return stats.gumbel_r.rvs(loc=loc, scale=scale, size=size, random_state=rng) + + +gumbel = GumbelRV() + + class ExponentialRV(RandomVariable): name = "exponential" ndim_supp = 0 diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 518f69a86f..26e141b444 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -23,6 +23,7 @@ dirichlet, exponential, gamma, + gumbel, halfcauchy, halfnormal, invgamma, @@ -218,6 +219,14 @@ def test_gamma_samples(): rv_numpy_tester(gamma, test_a, test_b, size=[2, 3], test_fn=stats.gamma.rvs) +def test_gumbel_samples(): + test_mu = np.array(0.0, dtype=config.floatX) + test_beta = np.array(1.0, dtype=config.floatX) + + rv_numpy_tester(gumbel, test_mu, test_beta, test_fn=stats.gumbel_r.rvs) + rv_numpy_tester(gumbel, test_mu, test_beta, size=[2, 3], test_fn=stats.gumbel_r.rvs) + + def test_exponential_samples(): rv_numpy_tester(exponential)