Skip to content

Commit

Permalink
Fix scalar size issue in CategoricalRV
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Mar 26, 2021
1 parent ce9c17f commit 0237886
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
10 changes: 7 additions & 3 deletions aesara/tensor/random/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,10 +413,14 @@ def rng_fn(cls, rng, p, size):
size = tuple(np.atleast_1d(size))
ind_shape = p.shape[:-1]

if len(size) > 0 and size[-len(ind_shape) :] != ind_shape:
raise ValueError("Parameters shape and size do not match.")
if len(ind_shape) > 0:
if len(size) > 0 and size[-len(ind_shape) :] != ind_shape:
raise ValueError("Parameters shape and size do not match.")

samples_shape = size[: -len(ind_shape)] + ind_shape
else:
samples_shape = size

samples_shape = size[: -len(ind_shape)] + ind_shape
unif_samples = rng.uniform(size=samples_shape)
samples = vsearchsorted(p.cumsum(axis=-1), unif_samples)

Expand Down
4 changes: 4 additions & 0 deletions tests/tensor/random/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,10 @@ def test_categorical_samples():

rng_state = np.random.RandomState(np.random.MT19937(np.random.SeedSequence(1234)))

assert categorical.rng_fn(rng_state, np.array([1.0 / 3.0] * 3), size=10).shape == (
10,
)

p = np.array([[100000, 1, 1], [1, 100000, 1], [1, 1, 100000]], dtype=config.floatX)
p = p / p.sum(axis=-1)

Expand Down

0 comments on commit 0237886

Please sign in to comment.