Skip to content

Commit 2223745

Browse files
committed
locking and split() method
1 parent 68e968d commit 2223745

File tree

2 files changed

+25
-20
lines changed

2 files changed

+25
-20
lines changed

README.md

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# `rng-jax`JAX random number generation as a NumPy generator
1+
# `rng-jax`NumPy random number generator API for JAX
22

33
**This is a proof of concept only.**
44

@@ -9,7 +9,6 @@ Wraps JAX's stateless random number generation in a class implementing the
99

1010
```py
1111
>>> import rng_jax
12-
>>>
1312
>>> rng = rng_jax.Generator(42) # same arguments as jax.random.key()
1413
>>> rng.standard_normal(3)
1514
Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32)
@@ -38,23 +37,23 @@ package is to work in tandem with the Array API: array-agnostic code is not
3837
usually compiled at low level. Conversely, native JAX code usually expects a
3938
`key`, anyway, not a `rng_jax.Generator` instance.
4039

41-
To interface with a native JAX function expecting a `key`, use the `.key()`
40+
To interface with a native JAX function expecting a `key`, use the `.split()`
4241
method to obtain a new random key and advance the internal state of the
4342
generator:
4443

4544
```py
45+
>>> import jax
4646
>>> rng = rng_jax.Generator(42)
47-
>>> key = rng.key()
47+
>>> key = rng.split()
4848
>>> jax.random.normal(key, 3)
4949
Array([-0.5675502 , 0.28439185, -0.9320608 ], dtype=float32)
50-
>>> key = rng.key()
50+
>>> key = rng.split()
5151
>>> jax.random.normal(key, 3)
5252
Array([ 0.67903334, -1.220606 , 0.94670606], dtype=float32)
5353
```
5454

55-
The right way to compile array-agnostic code is usually to compile the "main"
56-
function at the highest level of the code. Using the `rng_jax.Generator` class
57-
fully _within_ a compiled function works without issue.
55+
Using the `rng_jax.Generator` class fully _within_ a compiled JAX function
56+
works without issue.
5857

5958
[array-api]: https://data-apis.org/array-api/latest/
6059
[generator]: https://numpy.org/doc/stable/reference/random/generator.html

rng_jax.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import math
6+
from threading import Lock
67
from typing import Literal, Self, TypeAlias
78

89
from jax import Array
@@ -52,8 +53,9 @@ class Generator:
5253
Wrapper class for JAX random number generation.
5354
"""
5455

55-
__slots__ = ("_key",)
56-
_key: Array
56+
__slots__ = ("key", "lock")
57+
key: Array
58+
lock: Lock
5759

5860
@classmethod
5961
def from_key(cls, key: Array) -> Self:
@@ -63,38 +65,43 @@ def from_key(cls, key: Array) -> Self:
6365
if not isinstance(key, Array) or not issubdtype(key.dtype, prng_key):
6466
raise ValueError("not a random key")
6567
rng = object.__new__(cls)
66-
rng._key = key
68+
rng.key = key
69+
rng.lock = Lock()
6770
return rng
6871

6972
def __init__(self, seed: int | ArrayLike, *, impl: str | None = None) -> None:
7073
"""
7174
Create a wrapper instance with a new key.
7275
"""
73-
self._key = key(seed, impl=impl)
76+
self.key = key(seed, impl=impl)
77+
self.lock = Lock()
7478

7579
@property
7680
def __key(self) -> Array:
7781
"""
7882
Return next key for sampling while updating internal state.
7983
"""
80-
self._key, key = split(self._key)
84+
with self.lock:
85+
self.key, key = split(self.key)
8186
return key
8287

83-
def key(self, size: Size = None) -> Array:
88+
def split(self, size: Size = None) -> Array:
8489
"""
85-
Return random key, advancing internal state.
90+
Split random key.
8691
"""
8792
shape = _s(size)
88-
keys = split(self._key, 1 + math.prod(shape))
89-
self._key = keys[0]
93+
with self.lock:
94+
keys = split(self.key, 1 + math.prod(shape))
95+
self.key = keys[0]
9096
return keys[1:].reshape(shape)
9197

9298
def spawn(self, n_children: int) -> list[Self]:
9399
"""
94100
Create new independent child generators.
95101
"""
96-
self._key, *subkeys = split(self._key, num=n_children + 1)
97-
return list(map(self.from_key, subkeys))
102+
with self.lock:
103+
self.key, *keys = split(self.key, num=n_children + 1)
104+
return list(map(self.from_key, keys))
98105

99106
def integers(
100107
self,
@@ -119,7 +126,6 @@ def random(self, size: Size = None, dtype: DTypeLike = float) -> Array:
119126
"""
120127
Return random floats in the half-open interval [0.0, 1.0).
121128
"""
122-
self._key, key = split(self._key)
123129
return uniform(self.__key, _s(size), dtype)
124130

125131
def choice(

0 commit comments

Comments
 (0)