Skip to content

Commit e0ed32e

Browse files
mridul-sahuOrbax Authors
authored and
Orbax Authors
committed
Use orbax.checkpoint._src.futures.future instead of orbax.checkpoint.future inside _src directory
PiperOrigin-RevId: 718290785
1 parent f7bbe80 commit e0ed32e

22 files changed

+233
-167
lines changed

checkpoint/CHANGELOG.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ properties not included in any tree mapping operations.
1919

2020
### Added
2121
- The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`.
22-
- `CommitFuture` and `HandlerAwaitableSignal` for signalling between
23-
Checkpointing layers to enable async directory creation.
22+
- `CommitFutureAwaitDirectorySignals`, `CommitFuture` and
23+
`HandlerAwaitableSignal` for signalling between Checkpointing layers to enable
24+
async directory creation.
2425
- User-provided custom PyTree metadata.
2526

2627
### Fixed

checkpoint/orbax/checkpoint/_src/checkpointers/BUILD

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ py_library(
5757
deps = [
5858
":checkpointer",
5959
"//checkpoint/orbax/checkpoint:checkpoint_args",
60-
"//checkpoint/orbax/checkpoint:future",
6160
"//checkpoint/orbax/checkpoint:options",
6261
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
6362
"//checkpoint/orbax/checkpoint/_src/handlers:async_checkpoint_handler",
@@ -67,5 +66,6 @@ py_library(
6766
"//checkpoint/orbax/checkpoint/_src/path:atomicity",
6867
"//checkpoint/orbax/checkpoint/_src/path:atomicity_types",
6968
"//orbax/checkpoint:utils",
69+
"//orbax/checkpoint/_src/futures:future",
7070
],
7171
)

checkpoint/orbax/checkpoint/_src/checkpointers/async_checkpointer.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@
2323
from etils import epath
2424
import jax
2525
from orbax.checkpoint import checkpoint_args
26-
from orbax.checkpoint import future as future_lib
2726
from orbax.checkpoint import options as options_lib
2827
from orbax.checkpoint import utils
2928
from orbax.checkpoint._src import asyncio_utils
3029
from orbax.checkpoint._src.checkpointers import checkpointer
30+
from orbax.checkpoint._src.futures import future
3131
from orbax.checkpoint._src.handlers import async_checkpoint_handler
3232
from orbax.checkpoint._src.metadata import checkpoint
3333
from orbax.checkpoint._src.multihost import multihost
@@ -106,7 +106,7 @@ def __del__(self):
106106
def _thread_func(
107107
self,
108108
directory: epath.Path,
109-
commit_futures: Sequence[future_lib.Future],
109+
commit_futures: Sequence[future.Future],
110110
on_commit_callback: Callable[[], None],
111111
):
112112
"""Awaits on commit futures and finalizes the checkpoint."""
@@ -122,8 +122,8 @@ def _thread_func(
122122
thread_start_time = time.time()
123123

124124
# Wait for commit operations to complete.
125-
for future in commit_futures:
126-
future.result()
125+
for commit_future in commit_futures:
126+
commit_future.result()
127127
logging.info(
128128
'[process=%s][thread=%s] %d Handler Commit operations completed.',
129129
current_process,
@@ -194,7 +194,7 @@ def _thread_func(
194194
def start_async_commit(
195195
self,
196196
directory: epath.Path,
197-
commit_futures: Sequence[future_lib.Future],
197+
commit_futures: Sequence[future.Future],
198198
on_commit_callback: Callable[[], None],
199199
):
200200
"""Completes checkpoint save in a background thread."""
@@ -304,9 +304,8 @@ def __init__(
304304
self._barrier_sync_key_prefix = barrier_sync_key_prefix
305305
self._file_options = file_options
306306
self._metadata_store = (
307-
checkpoint_metadata_store or checkpoint.metadata_store(
308-
enable_write=True
309-
)
307+
checkpoint_metadata_store
308+
or checkpoint.metadata_store(enable_write=True)
310309
)
311310
self._temporary_path_class = temporary_path_class
312311
timeout_secs = timeout_secs or async_options.timeout_secs

checkpoint/orbax/checkpoint/_src/futures/future.py

+151-7
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
"""Futures that can be used for signaling for synchronization."""
1616

1717
import threading
18-
from typing import Any, Coroutine, Optional
18+
import time
19+
from typing import Any, Callable, Coroutine, Optional, Sequence
1920

2021
from absl import logging
22+
import jax
2123
from orbax.checkpoint._src import asyncio_utils
2224
from orbax.checkpoint._src.futures import synchronization
2325
from orbax.checkpoint._src.multihost import multihost
@@ -27,9 +29,51 @@
2729
get_unique_barrier_key = (
2830
synchronization.HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key
2931
)
32+
is_intialized = (
33+
synchronization.HandlerAwaitableSignalBarrierKeyGenerator.is_intialized
34+
)
35+
HandlerAwaitableSignal = synchronization.HandlerAwaitableSignal
3036
_SIGNAL_ACTION_SUCCESS = 'signal_action_success'
3137

3238

39+
def get_awaitable_signals_from_contract() -> Sequence[HandlerAwaitableSignal]:
40+
"""Gets the awaitable signals that may be sent for the current operation id."""
41+
client = multihost.get_jax_distributed_client()
42+
barrier_key = get_unique_barrier_key(
43+
HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT
44+
)
45+
try:
46+
values_str = str(client.key_value_try_get(barrier_key))
47+
return [HandlerAwaitableSignal(value) for value in values_str.split(',')]
48+
except jax.errors.JaxRuntimeError:
49+
# If the key is not found, then there are no awaitable signals yet.
50+
return []
51+
52+
53+
def add_to_awaitable_signals_contract(
54+
signals: Sequence[HandlerAwaitableSignal],
55+
):
56+
"""Adds awaitable signals to `AWAITABLE_SIGNALS_CONTRACT` for lower checkpointing layers to wait on.
57+
58+
These signals are added to the list of awaitable signals for the current
59+
opertation id in `HandlerAwaitableSignalBarrierKeyGenerator`.
60+
61+
Args:
62+
signals: The signals to add to the list of awaitable signals.
63+
"""
64+
if not signals:
65+
return
66+
67+
current_signals = list(get_awaitable_signals_from_contract())
68+
current_signals.extend(signals)
69+
keys = ','.join([current_signal.value for current_signal in current_signals])
70+
client = multihost.get_jax_distributed_client()
71+
barrier_key = get_unique_barrier_key(
72+
HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT
73+
)
74+
client.key_value_set(barrier_key, keys, allow_overwrite=True)
75+
76+
3377
class Future(Protocol):
3478
"""Abstracted Orbax Future class.
3579
@@ -48,6 +92,62 @@ def result(self, timeout: Optional[int] = None) -> Any:
4892
...
4993

5094

95+
class NoopFuture:
96+
97+
def result(self, timeout: Optional[int] = None) -> Any:
98+
del timeout
99+
return None
100+
101+
102+
class ChainedFuture:
103+
"""A future representing a sequence of multiple futures."""
104+
105+
def __init__(self, futures: Sequence[Future], cb: Callable[[], None]):
106+
self._futures = futures
107+
self._cb = cb
108+
109+
def result(self, timeout: Optional[int] = None) -> Any:
110+
"""Waits for all futures to complete."""
111+
n = len(self._futures)
112+
start = time.time()
113+
time_remaining = timeout
114+
for k, f in enumerate(self._futures):
115+
f.result(timeout=time_remaining)
116+
if time_remaining is not None:
117+
time_elapsed = time.time() - start
118+
time_remaining -= time_elapsed
119+
if time_remaining <= 0:
120+
raise TimeoutError(
121+
'ChainedFuture completed {:d}/{:d} futures but timed out after'
122+
' {:.2f} seconds.'.format(k, n, time_elapsed)
123+
)
124+
time_elapsed = time.time() - start
125+
logging.info(
126+
'ChainedFuture completed %d/%d futures in %.2f seconds.',
127+
n,
128+
n,
129+
time_elapsed,
130+
)
131+
self._cb()
132+
133+
134+
class ThreadRaisingException(threading.Thread):
135+
"""Thread that raises an exception if it encounters an error."""
136+
137+
_exception: Optional[Exception] = None
138+
139+
def run(self):
140+
try:
141+
super().run()
142+
except Exception as e: # pylint: disable=broad-exception-caught
143+
self._exception = e
144+
145+
def join(self, timeout=None):
146+
super().join(timeout=timeout)
147+
if self._exception is not None:
148+
raise self._exception
149+
150+
51151
class _SignalingThread(threading.Thread):
52152
"""Thread that raises an exception if it encounters an error.
53153
@@ -60,8 +160,8 @@ class _SignalingThread(threading.Thread):
60160
def __init__(
61161
self,
62162
*,
63-
send_signals: list[synchronization.HandlerAwaitableSignal],
64-
receive_signals: list[synchronization.HandlerAwaitableSignal],
163+
send_signals: Sequence[HandlerAwaitableSignal],
164+
receive_signals: Sequence[HandlerAwaitableSignal],
65165
timeout_secs: int = 600,
66166
**kwargs,
67167
):
@@ -137,10 +237,8 @@ def __init__(
137237
coro: Coroutine[Any, Any, None],
138238
*,
139239
name: str | None = None,
140-
send_signals: list[synchronization.HandlerAwaitableSignal] | None = None,
141-
receive_signals: (
142-
list[synchronization.HandlerAwaitableSignal] | None
143-
) = None,
240+
send_signals: Sequence[HandlerAwaitableSignal] | None = None,
241+
receive_signals: Sequence[HandlerAwaitableSignal] | None = None,
144242
timeout_secs: int = 600,
145243
):
146244
"""Constructor.
@@ -167,3 +265,49 @@ def __init__(
167265
def result(self, timeout: Optional[float] = None) -> Any:
168266
"""Waits for the commit to complete."""
169267
return self._t.join(timeout=timeout)
268+
269+
270+
class CommitFutureAwaitingContractedSignals(Future):
271+
"""Represents the result of a background commit.
272+
273+
May send signals to indicate that the commit has completed. Waits for all
274+
awaitable signals in the `AWAITABLE_SIGNALS_CONTRACT` to be set before
275+
proceeding with the commit.
276+
"""
277+
278+
def __init__(
279+
self,
280+
coro: Coroutine[Any, Any, None],
281+
*,
282+
name: str | None = None,
283+
send_signals: Sequence[HandlerAwaitableSignal] | None = None,
284+
skip_if_not_initialized: bool = True,
285+
timeout_secs: int = 600,
286+
):
287+
"""Constructor.
288+
289+
Synchronously gets all awaitable signals in the contract and waits to
290+
receive them in background before proceeding with the commit.
291+
292+
Args:
293+
coro: The coroutine to run.
294+
name: The name of the thread.
295+
send_signals: Signals to send to indicate that the commit has completed.
296+
skip_if_not_initialized: If True, skip fetching signals if the
297+
`HandlerAwaitableSignalBarrierKeyGenerator` is not initialized.
298+
timeout_secs: Timeout in seconds for waiting for signals.
299+
"""
300+
super().__init__()
301+
receive_signals = []
302+
if is_intialized() or not skip_if_not_initialized:
303+
receive_signals = get_awaitable_signals_from_contract()
304+
self._f = CommitFuture(
305+
coro,
306+
name=name,
307+
send_signals=send_signals,
308+
receive_signals=receive_signals,
309+
timeout_secs=timeout_secs,
310+
)
311+
312+
def result(self, timeout: Optional[float] = None) -> Any:
313+
return self._f.result(timeout=timeout)

checkpoint/orbax/checkpoint/_src/futures/synchronization.py

+8
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class HandlerAwaitableSignal(enum.Enum):
2626
`CheckpointHandler or below.`
2727
2828
Attributes:
29+
AWAITABLE_SIGNALS_CONTRACT: Contract that contains a list of signals that
30+
may be sent and can be awaited by the handlers.
2931
STEP_DIRECTORY_CREATION: When recieved, indicates that the step directory
3032
has been created. The handler should not attempt to write files before the
3133
directory is created.
@@ -34,6 +36,7 @@ class HandlerAwaitableSignal(enum.Enum):
3436
directory is created.
3537
"""
3638

39+
AWAITABLE_SIGNALS_CONTRACT = "awaitable_signals_contract"
3740
STEP_DIRECTORY_CREATION = "step_directory_creation"
3841
ITEM_DIRECTORY_CREATION = "item_directory_creation"
3942

@@ -68,3 +71,8 @@ def get_unique_barrier_key(cls, signal: HandlerAwaitableSignal) -> str:
6871
return multihost.unique_barrier_key(
6972
signal.value, suffix=str(cls._operation_id)
7073
)
74+
75+
@classmethod
76+
def is_intialized(cls) -> bool:
77+
"""Returns whether the operation id counter is initialized."""
78+
return cls._operation_id is not None

checkpoint/orbax/checkpoint/_src/futures/synchronization_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_get_unique_barrier_key_without_operation_id_raises_error(self):
3838
HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key(
3939
step_directory_creation_signal
4040
)
41+
self.assertFalse(HandlerAwaitableSignalBarrierKeyGenerator.is_intialized())
4142

4243
def test_get_unique_barrier_key(self):
4344
step_directory_creation_signal = (
@@ -63,6 +64,7 @@ def test_get_unique_barrier_key(self):
6364
)
6465
)
6566

67+
self.assertTrue(HandlerAwaitableSignalBarrierKeyGenerator.is_intialized())
6668
self.assertEqual(barrier_key_0, expected_barrier_key_0)
6769
self.assertEqual(barrier_key_1, expected_barrier_key_1)
6870

0 commit comments

Comments
 (0)