3
3
"""
4
4
5
5
import math
6
+ from threading import Lock
6
7
from typing import Literal , Self , TypeAlias
7
8
8
9
from jax import Array
@@ -52,8 +53,9 @@ class Generator:
52
53
Wrapper class for JAX random number generation.
53
54
"""
54
55
55
- __slots__ = ("_key" ,)
56
- _key : Array
56
+ __slots__ = ("key" , "lock" )
57
+ key : Array
58
+ lock : Lock
57
59
58
60
@classmethod
59
61
def from_key (cls , key : Array ) -> Self :
@@ -63,38 +65,43 @@ def from_key(cls, key: Array) -> Self:
63
65
if not isinstance (key , Array ) or not issubdtype (key .dtype , prng_key ):
64
66
raise ValueError ("not a random key" )
65
67
rng = object .__new__ (cls )
66
- rng ._key = key
68
+ rng .key = key
69
+ rng .lock = Lock ()
67
70
return rng
68
71
69
72
def __init__ (self , seed : int | ArrayLike , * , impl : str | None = None ) -> None :
70
73
"""
71
74
Create a wrapper instance with a new key.
72
75
"""
73
- self ._key = key (seed , impl = impl )
76
+ self .key = key (seed , impl = impl )
77
+ self .lock = Lock ()
74
78
75
79
@property
76
80
def __key (self ) -> Array :
77
81
"""
78
82
Return next key for sampling while updating internal state.
79
83
"""
80
- self ._key , key = split (self ._key )
84
+ with self .lock :
85
+ self .key , key = split (self .key )
81
86
return key
82
87
83
- def key (self , size : Size = None ) -> Array :
88
+ def split (self , size : Size = None ) -> Array :
84
89
"""
85
- Return random key, advancing internal state .
90
+ Split random key.
86
91
"""
87
92
shape = _s (size )
88
- keys = split (self ._key , 1 + math .prod (shape ))
89
- self ._key = keys [0 ]
93
+ with self .lock :
94
+ keys = split (self .key , 1 + math .prod (shape ))
95
+ self .key = keys [0 ]
90
96
return keys [1 :].reshape (shape )
91
97
92
98
def spawn (self , n_children : int ) -> list [Self ]:
93
99
"""
94
100
Create new independent child generators.
95
101
"""
96
- self ._key , * subkeys = split (self ._key , num = n_children + 1 )
97
- return list (map (self .from_key , subkeys ))
102
+ with self .lock :
103
+ self .key , * keys = split (self .key , num = n_children + 1 )
104
+ return list (map (self .from_key , keys ))
98
105
99
106
def integers (
100
107
self ,
@@ -119,7 +126,6 @@ def random(self, size: Size = None, dtype: DTypeLike = float) -> Array:
119
126
"""
120
127
Return random floats in the half-open interval [0.0, 1.0).
121
128
"""
122
- self ._key , key = split (self ._key )
123
129
return uniform (self .__key , _s (size ), dtype )
124
130
125
131
def choice (
0 commit comments