Skip to content

Commit

Permalink
Added Gumbel RV
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck authored and brandonwillard committed Apr 2, 2021
1 parent 16ee436 commit 06c1792
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,4 @@ core

aesara-venv/
/notebooks/Sandbox*
.vscode/
32 changes: 32 additions & 0 deletions aesara/tensor/random/basic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List, Optional, Union

import numpy as np
import scipy.stats as stats

Expand Down Expand Up @@ -108,6 +110,36 @@ def rng_fn(cls, rng, b, scale, size):
pareto = ParetoRV()


class GumbelRV(RandomVariable):
name = "gumbel"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
_print_name = ("Gumbel", "\\operatorname{Gumbel}")

def __call__(
self,
loc: Union[np.ndarray, float],
scale: Union[np.ndarray, float] = 1.0,
size: Optional[Union[List[int], int]] = None,
**kwargs
) -> RandomVariable:
return super().__call__(loc, scale, size=size, **kwargs)

@classmethod
def rng_fn(
cls,
rng: np.random.RandomState,
loc: Union[np.ndarray, float],
scale: Union[np.ndarray, float],
size: Optional[Union[List[int], int]],
) -> np.ndarray:
return stats.gumbel_r.rvs(loc=loc, scale=scale, size=size, random_state=rng)


gumbel = GumbelRV()


class ExponentialRV(RandomVariable):
name = "exponential"
ndim_supp = 0
Expand Down
9 changes: 9 additions & 0 deletions tests/tensor/random/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
dirichlet,
exponential,
gamma,
gumbel,
halfcauchy,
halfnormal,
invgamma,
Expand Down Expand Up @@ -218,6 +219,14 @@ def test_gamma_samples():
rv_numpy_tester(gamma, test_a, test_b, size=[2, 3], test_fn=stats.gamma.rvs)


def test_gumbel_samples():
test_mu = np.array(0.0, dtype=config.floatX)
test_beta = np.array(1.0, dtype=config.floatX)

rv_numpy_tester(gumbel, test_mu, test_beta, test_fn=stats.gumbel_r.rvs)
rv_numpy_tester(gumbel, test_mu, test_beta, size=[2, 3], test_fn=stats.gumbel_r.rvs)


def test_exponential_samples():

rv_numpy_tester(exponential)
Expand Down

0 comments on commit 06c1792

Please sign in to comment.