Skip to content

Commit

Permalink
Add ability to create directories asynchronously via signaling in Asy…
Browse files Browse the repository at this point in the history
…nc Checkpointer.

PiperOrigin-RevId: 717447279
  • Loading branch information
mridul-sahu authored and Orbax Authors committed Jan 22, 2025
1 parent f7bbe80 commit b796446
Show file tree
Hide file tree
Showing 25 changed files with 395 additions and 184 deletions.
5 changes: 3 additions & 2 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ properties not included in any tree mapping operations.

### Added
- The ability to specify a custom `snapshot_dir` in `checkpoints_iterator`.
- `CommitFuture` and `HandlerAwaitableSignal` for signalling between
Checkpointing layers to enable async directory creation.
- `CommitFutureAwaitDirectorySignals`, `CommitFuture` and
`HandlerAwaitableSignal` for signalling between Checkpointing layers to enable
async directory creation.
- User-provided custom PyTree metadata.

### Fixed
Expand Down
3 changes: 2 additions & 1 deletion checkpoint/orbax/checkpoint/_src/checkpointers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ py_library(
deps = [
":checkpointer",
"//checkpoint/orbax/checkpoint:checkpoint_args",
"//checkpoint/orbax/checkpoint:future",
"//checkpoint/orbax/checkpoint:options",
"//checkpoint/orbax/checkpoint/_src:asyncio_utils",
"//checkpoint/orbax/checkpoint/_src/handlers:async_checkpoint_handler",
Expand All @@ -67,5 +66,7 @@ py_library(
"//checkpoint/orbax/checkpoint/_src/path:atomicity",
"//checkpoint/orbax/checkpoint/_src/path:atomicity_types",
"//orbax/checkpoint:utils",
"//orbax/checkpoint/_src/futures:future",
"//orbax/checkpoint/_src/futures:synchronization",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
from etils import epath
import jax
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future as future_lib
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.checkpointers import checkpointer
from orbax.checkpoint._src.futures import future
from orbax.checkpoint._src.futures import synchronization
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.metadata import checkpoint
from orbax.checkpoint._src.multihost import multihost
Expand All @@ -37,6 +38,8 @@


BarrierSyncFn = multihost.BarrierSyncFn
HandlerAwaitableSignal = synchronization.HandlerAwaitableSignal
_DIRECTORY_CREATION_SIGNALS = [HandlerAwaitableSignal.STEP_DIRECTORY_CREATION]


def _on_commit_callback(
Expand Down Expand Up @@ -106,7 +109,7 @@ def __del__(self):
def _thread_func(
self,
directory: epath.Path,
commit_futures: Sequence[future_lib.Future],
commit_futures: Sequence[future.Future],
on_commit_callback: Callable[[], None],
):
"""Awaits on commit futures and finalizes the checkpoint."""
Expand All @@ -122,8 +125,8 @@ def _thread_func(
thread_start_time = time.time()

# Wait for commit operations to complete.
for future in commit_futures:
future.result()
for commit_future in commit_futures:
commit_future.result()
logging.info(
'[process=%s][thread=%s] %d Handler Commit operations completed.',
current_process,
Expand Down Expand Up @@ -194,7 +197,7 @@ def _thread_func(
def start_async_commit(
self,
directory: epath.Path,
commit_futures: Sequence[future_lib.Future],
commit_futures: Sequence[future.Future],
on_commit_callback: Callable[[], None],
):
"""Completes checkpoint save in a background thread."""
Expand Down Expand Up @@ -296,6 +299,9 @@ def __init__(
self._primary_host = multiprocessing_options.primary_host
self._active_processes = multiprocessing_options.active_processes
self._post_finalization_callback = async_options.post_finalization_callback
self._create_directories_asynchronously = (
async_options.create_directories_asynchronously
)
barrier_sync_key_prefix = (
''
if multiprocessing_options.barrier_sync_key_prefix is None
Expand All @@ -304,9 +310,8 @@ def __init__(
self._barrier_sync_key_prefix = barrier_sync_key_prefix
self._file_options = file_options
self._metadata_store = (
checkpoint_metadata_store or checkpoint.metadata_store(
enable_write=True
)
checkpoint_metadata_store
or checkpoint.metadata_store(enable_write=True)
)
self._temporary_path_class = temporary_path_class
timeout_secs = timeout_secs or async_options.timeout_secs
Expand All @@ -324,12 +329,65 @@ def __init__(
barrier_sync_key_prefix=barrier_sync_key_prefix,
)

def _create_temporary_path_asyncronously(
self, temporary_path: atomicity_types.TemporaryPath
) -> future.Future:
start = time.time()
# Sync for existence check to complete on all hosts before directory
# creation starts.
multihost.sync_global_processes(
multihost.unique_barrier_key(
'create_tmp_directory:post_existence_check',
prefix=self._barrier_sync_key_prefix,
),
timeout=multihost.DIRECTORY_CREATION_TIMEOUT,
processes=self._active_processes,
)

commit_future = future.NoopFuture()
if utils.is_primary_host(self._primary_host):
commit_future = future.CommitFutureAwaitingContractedSignals(
atomicity.create_paths(
[temporary_path], create_path_start_time=start
),
send_signals=_DIRECTORY_CREATION_SIGNALS,
timeout_secs=multihost.DIRECTORY_CREATION_TIMEOUT,
)
future.add_to_awaitable_signals_contract(_DIRECTORY_CREATION_SIGNALS)

# Sync to enusre that all hosts have the same awaitable signals contract.
multihost.sync_global_processes(
multihost.unique_barrier_key(
'add_to_awaitable_signals_contract',
prefix=self._barrier_sync_key_prefix,
),
timeout=multihost.DIRECTORY_CREATION_TIMEOUT,
processes=self._active_processes,
)
return commit_future

def _syncronize_next_awaitable_signal_operation_id(self):
# Synchronize next operation id if async directory creation is enabled
# across all hosts.
if self._create_directories_asynchronously:
synchronization.HandlerAwaitableSignalBarrierKeyGenerator.next_operation_id()

multihost.sync_global_processes(
multihost.unique_barrier_key(
'next_awaitable_signal_operation_id:sync',
prefix=self._barrier_sync_key_prefix,
),
timeout=multihost.DIRECTORY_CREATION_TIMEOUT,
processes=self._active_processes,
)

async def _save(
self, directory: epath.PathLike, *args, force: bool = False, **kwargs
):
checkpoint_start_time = time.time()
directory = epath.Path(directory)
self.wait_until_finished()
self._syncronize_next_awaitable_signal_operation_id()

jax.monitoring.record_event('/jax/orbax/write/async/start')
logging.info(
Expand All @@ -351,13 +409,20 @@ async def _save(
else:
raise ValueError(f'Destination {directory} already exists.')

tmpdir = await self.create_temporary_path(directory)
commit_ops = []
tmpdir = self.get_temporary_path(directory)
if self._create_directories_asynchronously:
commit_ops.append(self._create_temporary_path_asyncronously(tmpdir))
else:
await self.create_temporary_path(tmpdir)
# Run copy ops.
# Try to save using new CheckpointArgs API if supported by the handler.
ckpt_args = checkpointer.construct_checkpoint_args(
self._handler, True, *args, **kwargs
)
commit_ops = await self._handler.async_save(tmpdir.get(), args=ckpt_args)
commit_ops.extend(
await self._handler.async_save(tmpdir.get(), args=ckpt_args)
)
commit_ops, _ = jax.tree.flatten(commit_ops)
commit_ops = [op for op in commit_ops if op is not None]

Expand Down
18 changes: 14 additions & 4 deletions checkpoint/orbax/checkpoint/_src/checkpointers/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __init__(

jax.monitoring.record_event('/jax/orbax/checkpointer/init')

async def create_temporary_path(
def get_temporary_path(
self, directory: epath.Path
) -> atomicity_types.TemporaryPath:
temporary_path_class = (
Expand All @@ -157,10 +157,19 @@ async def create_temporary_path(
multiprocessing_options=multiprocessing_options,
file_options=self._file_options,
)
return tmpdir

async def create_temporary_path(
self, temporary_path: atomicity_types.TemporaryPath
):
multiprocessing_options = options_lib.MultiprocessingOptions(
primary_host=self._primary_host,
active_processes=self._active_processes,
barrier_sync_key_prefix=self._barrier_sync_key_prefix,
)
await atomicity.create_all(
[tmpdir], multiprocessing_options=multiprocessing_options
[temporary_path], multiprocessing_options=multiprocessing_options
)
return tmpdir

def save(
self,
Expand Down Expand Up @@ -206,8 +215,9 @@ def save(
else:
raise ValueError(f'Destination {directory} already exists.')
ckpt_args = construct_checkpoint_args(self._handler, True, *args, **kwargs)
tmpdir = self.get_temporary_path(directory)
# tmpdir creation also does an initial StepMetadata save.
tmpdir = asyncio_utils.run_sync(self.create_temporary_path(directory))
asyncio_utils.run_sync(self.create_temporary_path(tmpdir))
self._handler.save(tmpdir.get(), args=ckpt_args)
if utils.is_primary_host(self._primary_host):
# Update StepMetadata after the handler save is complete. (blocking write)
Expand Down
Loading

0 comments on commit b796446

Please sign in to comment.