15
15
"""Futures that can be used for signaling for synchronization."""
16
16
17
17
import threading
18
- from typing import Any , Coroutine , Optional
18
+ import time
19
+ from typing import Any , Callable , Coroutine , Optional , Sequence
19
20
20
21
from absl import logging
22
+ import jax
21
23
from orbax .checkpoint ._src import asyncio_utils
22
24
from orbax .checkpoint ._src .futures import synchronization
23
25
from orbax .checkpoint ._src .multihost import multihost
27
29
# Module-level shorthands for names from the `synchronization` module, so the
# code below does not have to spell out the full attribute paths.
_KeyGenerator = synchronization.HandlerAwaitableSignalBarrierKeyGenerator
get_unique_barrier_key = _KeyGenerator.get_unique_barrier_key
is_intialized = _KeyGenerator.is_intialized
HandlerAwaitableSignal = synchronization.HandlerAwaitableSignal
_SIGNAL_ACTION_SUCCESS = 'signal_action_success'
31
37
32
38
39
def get_awaitable_signals_from_contract() -> Sequence[HandlerAwaitableSignal]:
  """Gets the awaitable signals that may be sent for the current operation id.

  Reads the value stored under the `AWAITABLE_SIGNALS_CONTRACT` barrier key via
  the JAX distributed client and parses it as a comma-separated list of signal
  values.

  Returns:
    One `HandlerAwaitableSignal` per comma-separated entry, or an empty list
    when the lookup raises `jax.errors.JaxRuntimeError`.
  """
  kv_client = multihost.get_jax_distributed_client()
  contract_key = get_unique_barrier_key(
      HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT
  )
  try:
    raw_contract = str(kv_client.key_value_try_get(contract_key))
  except jax.errors.JaxRuntimeError:
    # A failed lookup is treated as "no awaitable signals registered yet".
    return []
  else:
    return [
        HandlerAwaitableSignal(entry) for entry in raw_contract.split(',')
    ]
51
+
52
+
53
def add_to_awaitable_signals_contract(
    signals: Sequence[HandlerAwaitableSignal],
):
  """Appends signals to `AWAITABLE_SIGNALS_CONTRACT` for lower layers to await.

  The signals currently in the contract are read back, the new ones are
  appended after them, and the combined comma-separated list is written to
  the contract's barrier key, overwriting the previous value.

  Args:
    signals: The signals to add to the list of awaitable signals. Nothing is
      read or written when this is empty.
  """
  if not signals:
    return

  merged_signals = [*get_awaitable_signals_from_contract(), *signals]
  serialized = ','.join(signal.value for signal in merged_signals)
  kv_client = multihost.get_jax_distributed_client()
  contract_key = get_unique_barrier_key(
      HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT
  )
  kv_client.key_value_set(contract_key, serialized, allow_overwrite=True)
75
+
76
+
33
77
class Future (Protocol ):
34
78
"""Abstracted Orbax Future class.
35
79
@@ -48,6 +92,62 @@ def result(self, timeout: Optional[int] = None) -> Any:
48
92
...
49
93
50
94
95
class NoopFuture:
  """A future with nothing to wait for."""

  def result(self, timeout: Optional[int] = None) -> Any:
    """Returns `None` immediately; `timeout` is accepted but ignored."""
    del timeout  # Unused.
    return None
100
+
101
+
102
class ChainedFuture:
  """A future representing a sequence of multiple futures.

  `result` waits on each wrapped future in order and then invokes the
  completion callback.
  """

  def __init__(self, futures: Sequence['Future'], cb: Callable[[], None]):
    """Constructor.

    Args:
      futures: The futures to wait on, in order.
      cb: Callback invoked with no arguments once every future has completed.
    """
    self._futures = futures
    self._cb = cb

  def result(self, timeout: Optional[int] = None) -> Any:
    """Waits for all futures to complete, then runs the callback.

    Args:
      timeout: Total budget in seconds shared by all futures, or None to wait
        without a deadline. Each future receives whatever is left of the
        budget when its turn comes.

    Returns:
      None.

    Raises:
      TimeoutError: If the budget is exhausted by the time a future completes.
        The callback is not invoked in that case.
    """
    n = len(self._futures)
    start = time.time()
    time_remaining = timeout
    for k, f in enumerate(self._futures):
      f.result(timeout=time_remaining)
      if timeout is not None:
        time_elapsed = time.time() - start
        # `time_elapsed` is measured from `start`, i.e. it is cumulative, so
        # the remaining budget must be derived from the original timeout.
        # Decrementing the running value would count earlier futures twice.
        time_remaining = timeout - time_elapsed
        if time_remaining <= 0:
          # Future `k` (zero-based) has just finished, so `k + 1` are done.
          raise TimeoutError(
              'ChainedFuture completed {:d}/{:d} futures but timed out after'
              ' {:.2f} seconds.'.format(k + 1, n, time_elapsed)
          )
    time_elapsed = time.time() - start
    logging.info(
        'ChainedFuture completed %d/%d futures in %.2f seconds.',
        n,
        n,
        time_elapsed,
    )
    self._cb()
132
+
133
+
134
class ThreadRaisingException(threading.Thread):
  """Thread that stores an exception from `run` and re-raises it in `join`."""

  # Exception caught while running the thread body; None if nothing was caught.
  _exception: Optional[Exception] = None

  def run(self):
    """Runs the thread body, recording any `Exception` instead of raising."""
    try:
      super().run()
    except Exception as error:  # pylint: disable=broad-exception-caught
      self._exception = error

  def join(self, timeout=None):
    """Joins the thread, then re-raises the recorded exception, if any."""
    super().join(timeout=timeout)
    recorded = self._exception
    if recorded is None:
      return
    raise recorded
149
+
150
+
51
151
class _SignalingThread (threading .Thread ):
52
152
"""Thread that raises an exception if it encounters an error.
53
153
@@ -60,8 +160,8 @@ class _SignalingThread(threading.Thread):
60
160
def __init__ (
61
161
self ,
62
162
* ,
63
- send_signals : list [ synchronization . HandlerAwaitableSignal ],
64
- receive_signals : list [ synchronization . HandlerAwaitableSignal ],
163
+ send_signals : Sequence [ HandlerAwaitableSignal ],
164
+ receive_signals : Sequence [ HandlerAwaitableSignal ],
65
165
timeout_secs : int = 600 ,
66
166
** kwargs ,
67
167
):
@@ -137,10 +237,8 @@ def __init__(
137
237
coro : Coroutine [Any , Any , None ],
138
238
* ,
139
239
name : str | None = None ,
140
- send_signals : list [synchronization .HandlerAwaitableSignal ] | None = None ,
141
- receive_signals : (
142
- list [synchronization .HandlerAwaitableSignal ] | None
143
- ) = None ,
240
+ send_signals : Sequence [HandlerAwaitableSignal ] | None = None ,
241
+ receive_signals : Sequence [HandlerAwaitableSignal ] | None = None ,
144
242
timeout_secs : int = 600 ,
145
243
):
146
244
"""Constructor.
@@ -167,3 +265,49 @@ def __init__(
167
265
def result(self, timeout: Optional[float] = None) -> Any:
  """Waits for the commit to complete by joining the underlying thread.

  Args:
    timeout: Forwarded to the thread's `join` as its `timeout` argument.

  Returns:
    Whatever the thread's `join` returns.
  """
  commit_thread = self._t
  return commit_thread.join(timeout=timeout)
268
+
269
+
270
class CommitFutureAwaitingContractedSignals(Future):
  """Represents the result of a background commit gated on contracted signals.

  Wraps a `CommitFuture` whose `receive_signals` are the awaitable signals found
  in the `AWAITABLE_SIGNALS_CONTRACT` at construction time. May also send
  signals to indicate that the commit has completed.
  """

  def __init__(
      self,
      coro: Coroutine[Any, Any, None],
      *,
      name: str | None = None,
      send_signals: Sequence[HandlerAwaitableSignal] | None = None,
      skip_if_not_initialized: bool = True,
      timeout_secs: int = 600,
  ):
    """Constructor.

    The contract is read synchronously here; the resulting signals are handed
    to the wrapped `CommitFuture` as the signals it should receive.

    Args:
      coro: The coroutine to run.
      name: The name of the thread.
      send_signals: Signals to send to indicate that the commit has completed.
      skip_if_not_initialized: If True, the contract is only read when
        `HandlerAwaitableSignalBarrierKeyGenerator` reports it is initialized;
        otherwise no signals are awaited. If False, the contract is always
        read.
      timeout_secs: Timeout in seconds, forwarded to the wrapped
        `CommitFuture`.
    """
    super().__init__()
    should_read_contract = is_intialized() or not skip_if_not_initialized
    contracted_signals = (
        get_awaitable_signals_from_contract() if should_read_contract else []
    )
    self._f = CommitFuture(
        coro,
        name=name,
        send_signals=send_signals,
        receive_signals=contracted_signals,
        timeout_secs=timeout_secs,
    )

  def result(self, timeout: Optional[float] = None) -> Any:
    """Returns the wrapped `CommitFuture`'s result, forwarding `timeout`."""
    return self._f.result(timeout=timeout)
0 commit comments