Skip to content

Commit

Permalink
Use orbax.checkpoint._src.futures.future instead of `orbax.checkpoi…
Browse files Browse the repository at this point in the history
…nt.future` inside _src directory

PiperOrigin-RevId: 718290785
  • Loading branch information
mridul-sahu authored and Orbax Authors committed Jan 22, 2025
1 parent f7bbe80 commit e5673dd
Show file tree
Hide file tree
Showing 22 changed files with 228 additions and 162 deletions.
5 changes: 3 additions & 2 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ properties not included in any tree mapping operations.

### Added
- The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`.
- `CommitFuture` and `HandlerAwaitableSignal` for signalling between
Checkpointing layers to enable async directory creation.
- `CommitFutureAwaitingContractedSignals`, `CommitFuture` and
`HandlerAwaitableSignal` for signalling between Checkpointing layers to enable
async directory creation.
- User-provided custom PyTree metadata.

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/_src/checkpointers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ py_library(
deps = [
":checkpointer",
"//checkpoint/orbax/checkpoint:checkpoint_args",
"//checkpoint/orbax/checkpoint:future",
"//checkpoint/orbax/checkpoint:options",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/handlers:async_checkpoint_handler",
Expand All @@ -67,5 +66,6 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/path:atomicity",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_types",
"//orbax/checkpoint:utils",
"//orbax/checkpoint/_src/futures:future",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
from etils import epath
import jax
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future as future_lib
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.checkpointers import checkpointer
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.multihost import multihost
Expand Down Expand Up @@ -106,7 +106,7 @@ def __del__(self):
def _thread_func(
self,
directory: epath.Path,
commit_futures: Sequence[future_lib.Future],
commit_futures: Sequence[future.Future],
on_commit_callback: Callable[[], None],
):
"""Awaits on commit futures and finalizes the checkpoint."""
Expand All @@ -122,8 +122,8 @@ def _thread_func(
thread_start_time = time.time()

# Wait for commit operations to complete.
for future in commit_futures:
future.result()
for commit_future in commit_futures:
commit_future.result()
logging.info(
'[process=%s][thread=%s] %d Handler Commit operations completed.',
current_process,
Expand Down Expand Up @@ -194,7 +194,7 @@ def _thread_func(
def start_async_commit(
self,
directory: epath.Path,
commit_futures: Sequence[future_lib.Future],
commit_futures: Sequence[future.Future],
on_commit_callback: Callable[[], None],
):
"""Completes checkpoint save in a background thread."""
Expand Down Expand Up @@ -304,9 +304,8 @@ def __init__(
self._barrier_sync_key_prefix = barrier_sync_key_prefix
self._file_options = file_options
self._metadata_store = (
checkpoint_metadata_store or checkpoint.metadata_store(
enable_write=True
)
checkpoint_metadata_store
or checkpoint.metadata_store(enable_write=True)
)
self._temporary_path_class = temporary_path_class
timeout_secs = timeout_secs or async_options.timeout_secs
Expand Down
158 changes: 151 additions & 7 deletions checkpoint/orbax/checkpoint/_src/futures/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
"""Futures that can be used for signaling for synchronization."""

import threading
from typing import Any, Coroutine, Optional
import time
from typing import Any, Callable, Coroutine, Optional, Sequence

from absl import logging
import jax
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.futures import synchronization
from orbax.checkpoint._src.multihost import multihost
Expand All @@ -27,9 +29,51 @@
# Module-level shorthands for attributes of
# `synchronization.HandlerAwaitableSignalBarrierKeyGenerator` and for the
# signal enum, so code in this module can use them without the
# `synchronization.` prefix.
get_unique_barrier_key = (
synchronization.HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key
)
is_intialized = (
synchronization.HandlerAwaitableSignalBarrierKeyGenerator.is_intialized
)
HandlerAwaitableSignal = synchronization.HandlerAwaitableSignal
# NOTE(review): no use of this constant is visible in this part of the file;
# presumably it is the value recorded when a signal action succeeds -- confirm
# against its users.
_SIGNAL_ACTION_SUCCESS = 'signal_action_success'


def get_awaitable_signals_from_contract() -> Sequence[HandlerAwaitableSignal]:
  """Returns the signals listed in `AWAITABLE_SIGNALS_CONTRACT`.

  The contract is looked up through the JAX distributed client under the unique
  barrier key generated for `AWAITABLE_SIGNALS_CONTRACT`. The stored value is
  parsed as a comma-separated list of `HandlerAwaitableSignal` values.

  Returns:
    The contracted signals, or an empty list when the lookup raises
    `jax.errors.JaxRuntimeError` (treated as "no contract recorded yet").
  """
  kv_client = multihost.get_jax_distributed_client()
  contract_key = get_unique_barrier_key(
      HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT
  )
  try:
    raw_contract = str(kv_client.key_value_try_get(contract_key))
  except jax.errors.JaxRuntimeError:
    # A failed lookup means no awaitable signals have been contracted yet.
    return []
  contracted = []
  for raw_value in raw_contract.split(','):
    contracted.append(HandlerAwaitableSignal(raw_value))
  return contracted


def add_to_awaitable_signals_contract(
    signals: Sequence[HandlerAwaitableSignal],
):
  """Appends signals to the `AWAITABLE_SIGNALS_CONTRACT` entry.

  The signals already in the contract are read back, the new ones are appended
  after them, and the combined comma-separated list is written (overwriting
  the previous value) under the unique barrier key for
  `AWAITABLE_SIGNALS_CONTRACT`, where lower checkpointing layers can find the
  signals they should wait on.

  Args:
    signals: The signals to append. When empty, nothing is read or written.
  """
  if not signals:
    return

  # Existing entries come first so the stored order reflects insertion order.
  contracted = [*get_awaitable_signals_from_contract(), *signals]
  serialized = ','.join([signal.value for signal in contracted])
  kv_client = multihost.get_jax_distributed_client()
  contract_key = get_unique_barrier_key(
      HandlerAwaitableSignal.AWAITABLE_SIGNALS_CONTRACT
  )
  kv_client.key_value_set(contract_key, serialized, allow_overwrite=True)


class Future(Protocol):
"""Abstracted Orbax Future class.
Expand All @@ -48,6 +92,62 @@ def result(self, timeout: Optional[int] = None) -> Any:
...


class NoopFuture:
  """A future with nothing to wait for: `result` returns immediately."""

  def result(self, timeout: Optional[int] = None) -> Any:
    """Returns `None` right away; `timeout` is accepted but ignored."""
    del timeout  # Unused.


class ChainedFuture:
  """A future that waits on a sequence of futures, then runs a callback.

  The wrapped futures are awaited one after another, in order. `timeout` is a
  single budget shared by the whole chain rather than a per-future allowance.
  """

  def __init__(self, futures: Sequence[Future], cb: Callable[[], None]):
    """Constructor.

    Args:
      futures: The futures to wait on, in order.
      cb: Called with no arguments once every future has completed.
    """
    self._futures = futures
    self._cb = cb

  def result(self, timeout: Optional[int] = None) -> Any:
    """Waits for all futures to complete, then invokes the callback.

    Args:
      timeout: Total number of seconds allowed for the whole chain, or `None`
        to wait indefinitely. Each future is given whatever is left of this
        budget when its turn comes.

    Returns:
      None.

    Raises:
      TimeoutError: If the budget is used up by the time a future completes.
        The callback is not invoked in that case.
    """
    n = len(self._futures)
    start = time.time()
    time_remaining = timeout
    for completed, f in enumerate(self._futures, start=1):
      f.result(timeout=time_remaining)
      if timeout is not None:
        # `time_elapsed` is measured from the start of the whole chain, so the
        # remaining budget must be derived from the original `timeout`.
        # Subtracting it from the running remainder would charge earlier
        # futures' time again on every iteration.
        time_elapsed = time.time() - start
        time_remaining = timeout - time_elapsed
        if time_remaining <= 0:
          raise TimeoutError(
              'ChainedFuture completed {:d}/{:d} futures but timed out after'
              ' {:.2f} seconds.'.format(completed, n, time_elapsed)
          )
    time_elapsed = time.time() - start
    logging.info(
        'ChainedFuture completed %d/%d futures in %.2f seconds.',
        n,
        n,
        time_elapsed,
    )
    self._cb()


class ThreadRaisingException(threading.Thread):
  """A `threading.Thread` whose `join` re-raises a failure from `run`."""

  # Exception captured from `run`; stays `None` while nothing has failed.
  _exception: Optional[Exception] = None

  def run(self):
    """Runs the thread body, recording any `Exception` instead of raising."""
    try:
      super().run()
    except Exception as err:  # pylint: disable=broad-exception-caught
      self._exception = err

  def join(self, timeout=None):
    """Joins the thread, then raises the recorded exception, if any."""
    super().join(timeout=timeout)
    captured = self._exception
    if captured is None:
      return
    raise captured


class _SignalingThread(threading.Thread):
"""Thread that raises an exception if it encounters an error.
Expand All @@ -60,8 +160,8 @@ class _SignalingThread(threading.Thread):
def __init__(
self,
*,
send_signals: list[synchronization.HandlerAwaitableSignal],
receive_signals: list[synchronization.HandlerAwaitableSignal],
send_signals: Sequence[HandlerAwaitableSignal],
receive_signals: Sequence[HandlerAwaitableSignal],
timeout_secs: int = 600,
**kwargs,
):
Expand Down Expand Up @@ -137,10 +237,8 @@ def __init__(
coro: Coroutine[Any, Any, None],
*,
name: str | None = None,
send_signals: list[synchronization.HandlerAwaitableSignal] | None = None,
receive_signals: (
list[synchronization.HandlerAwaitableSignal] | None
) = None,
send_signals: Sequence[HandlerAwaitableSignal] | None = None,
receive_signals: Sequence[HandlerAwaitableSignal] | None = None,
timeout_secs: int = 600,
):
"""Constructor.
Expand All @@ -167,3 +265,49 @@ def __init__(
def result(self, timeout: Optional[float] = None) -> Any:
"""Waits for the commit to complete."""
return self._t.join(timeout=timeout)


class CommitFutureAwaitingContractedSignals(Future):
  """A background commit that first waits for the contracted signals.

  All awaitable signals listed in `AWAITABLE_SIGNALS_CONTRACT` must be received
  before the commit proceeds. Signals may also be sent to indicate that the
  commit has completed. The actual work is delegated to a wrapped
  `CommitFuture`.
  """

  def __init__(
      self,
      coro: Coroutine[Any, Any, None],
      *,
      name: str | None = None,
      send_signals: Sequence[HandlerAwaitableSignal] | None = None,
      skip_if_not_initialized: bool = True,
      timeout_secs: int = 600,
  ):
    """Constructor.

    The contract is read synchronously here; the signals found in it are then
    awaited in the background before the commit runs.

    Args:
      coro: The coroutine to run.
      name: The name of the thread.
      send_signals: Signals to send to indicate that the commit has completed.
      skip_if_not_initialized: If True, the contract is not read (and no
        signals are awaited) while `HandlerAwaitableSignalBarrierKeyGenerator`
        is not initialized.
      timeout_secs: Timeout in seconds for waiting for signals.
    """
    super().__init__()
    should_read_contract = is_intialized() or not skip_if_not_initialized
    contracted_signals = (
        get_awaitable_signals_from_contract() if should_read_contract else []
    )
    self._f = CommitFuture(
        coro,
        name=name,
        send_signals=send_signals,
        receive_signals=contracted_signals,
        timeout_secs=timeout_secs,
    )

  def result(self, timeout: Optional[float] = None) -> Any:
    """Waits on the wrapped `CommitFuture` and returns its result."""
    return self._f.result(timeout=timeout)
8 changes: 8 additions & 0 deletions checkpoint/orbax/checkpoint/_src/futures/synchronization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class HandlerAwaitableSignal(enum.Enum):
`CheckpointHandler` or below.
Attributes:
AWAITABLE_SIGNALS_CONTRACT: Contract that contains a list of signals that
may be sent and can be awaited by the handlers.
STEP_DIRECTORY_CREATION: When received, indicates that the step directory
has been created. The handler should not attempt to write files before the
directory is created.
Expand All @@ -34,6 +36,7 @@ class HandlerAwaitableSignal(enum.Enum):
directory is created.
"""

AWAITABLE_SIGNALS_CONTRACT = "awaitable_signals_contract"
STEP_DIRECTORY_CREATION = "step_directory_creation"
ITEM_DIRECTORY_CREATION = "item_directory_creation"

Expand Down Expand Up @@ -68,3 +71,8 @@ def get_unique_barrier_key(cls, signal: HandlerAwaitableSignal) -> str:
return multihost.unique_barrier_key(
signal.value, suffix=str(cls._operation_id)
)

@classmethod
def is_intialized(cls) -> bool:
  """Reports whether `_operation_id` has been set (i.e. is not `None`)."""
  operation_id = cls._operation_id
  return operation_id is not None
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_get_unique_barrier_key_without_operation_id_raises_error(self):
HandlerAwaitableSignalBarrierKeyGenerator.get_unique_barrier_key(
step_directory_creation_signal
)
self.assertFalse(HandlerAwaitableSignalBarrierKeyGenerator.is_intialized())

def test_get_unique_barrier_key(self):
step_directory_creation_signal = (
Expand All @@ -63,6 +64,7 @@ def test_get_unique_barrier_key(self):
)
)

self.assertTrue(HandlerAwaitableSignalBarrierKeyGenerator.is_intialized())
self.assertEqual(barrier_key_0, expected_barrier_key_0)
self.assertEqual(barrier_key_1, expected_barrier_key_1)

Expand Down
8 changes: 4 additions & 4 deletions checkpoint/orbax/checkpoint/_src/handlers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ py_library(
":handler_registration",
":proto_checkpoint_handler",
"//checkpoint/orbax/checkpoint:checkpoint_args",
"//checkpoint/orbax/checkpoint:future",
"//checkpoint/orbax/checkpoint:options",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src:composite",
Expand All @@ -26,6 +25,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/path:atomicity",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_defaults",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_types",
"//orbax/checkpoint/_src/futures:future",
],
)

Expand Down Expand Up @@ -81,7 +81,6 @@ py_library(
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint:checkpoint_args",
"//checkpoint/orbax/checkpoint:future",
"//checkpoint/orbax/checkpoint:options",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/metadata:empty_values",
Expand All @@ -95,6 +94,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/tree:types",
"//checkpoint/orbax/checkpoint/_src/tree:utils",
"//orbax/checkpoint:utils",
"//orbax/checkpoint/_src/futures:future",
"//orbax/checkpoint/_src/metadata:array_metadata_store",
],
)
Expand All @@ -105,10 +105,10 @@ py_library(
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint:checkpoint_args",
"//checkpoint/orbax/checkpoint:future",
"//checkpoint/orbax/checkpoint:options",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//orbax/checkpoint:utils",
"//orbax/checkpoint/_src/futures:future",
],
)

Expand Down Expand Up @@ -152,10 +152,10 @@ py_library(
deps = [
":async_checkpoint_handler",
"//checkpoint/orbax/checkpoint:checkpoint_args",
"//checkpoint/orbax/checkpoint:future",
"//checkpoint/orbax/checkpoint:options",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//orbax/checkpoint:utils",
"//orbax/checkpoint/_src/futures:future",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
import numpy as np
from orbax.checkpoint import aggregate_handlers
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils
from orbax.checkpoint._src.serialization import type_handlers
Expand Down
Loading

0 comments on commit e5673dd

Please sign in to comment.