Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 715503067
  • Loading branch information
niketkumar authored and Orbax Authors committed Jan 15, 2025
1 parent c829e8c commit 913f061
Show file tree
Hide file tree
Showing 7 changed files with 543 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from orbax.checkpoint import utils
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.multihost import multihost
Expand Down Expand Up @@ -282,6 +283,9 @@ def __init__(
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
tree_metadata.PYTREE_METADATA_OPTIONS
),
array_metadata_validator: array_metadata_store_lib.Validator = (
array_metadata_store_lib.Validator()
),
):
"""Creates BasePyTreeCheckpointHandler.
Expand All @@ -301,6 +305,7 @@ def __init__(
enable_post_merge_validation: If True, enables validation of the
parameters after the finalize step.
pytree_metadata_options: `PyTreeMetadataOptions` to manage metadata.
array_metadata_validator: Validator for ArrayMetadata.
"""
self._save_concurrent_bytes = save_concurrent_bytes
self._restore_concurrent_bytes = restore_concurrent_bytes
Expand All @@ -310,14 +315,22 @@ def __init__(
self._type_handler_registry = type_handler_registry
self._enable_post_merge_validation = enable_post_merge_validation
self._pytree_metadata_options = pytree_metadata_options
# Get ArrayMetadata Store from TypeHandler for jax.Array.
# ArrayMetadata persistence is only supported for jax.Array.
self._array_metadata_store = (
array_metadata_store_lib.resolve_array_metadata_store(
type_handler_registry
)
)
self._array_metadata_validator = array_metadata_validator


jax.monitoring.record_event(
'/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
)

self._thread_pool = futures.ThreadPoolExecutor(
max_workers=2, thread_name_prefix='base_pytree_ch'
max_workers=3, thread_name_prefix='base_pytree_ch'
)
logging.info(
'Created BasePyTreeCheckpointHandler: pytree_metadata_options=%s',
Expand Down Expand Up @@ -451,7 +464,7 @@ async def async_save(
leaf.parent_dir == directory for leaf in jax.tree.leaves(param_infos)
)

serialize_ops = []
serialize_ops = [] # List of (coros -> List of futures)
batch_requests = batched_serialization_requests(
item,
param_infos,
Expand All @@ -465,20 +478,29 @@ async def async_save(
]
write_size, _ = _get_batch_memory_size(request.handler, request.values)
tree_memory_size += write_size
# Await copy futures. Returns list of lists.
# Await copy futures. Returns List[List[future.Future]].
commit_futures = await asyncio.gather(*serialize_ops)
# Flatten to List[future.Future].
commit_futures, _ = jax.tree.flatten(commit_futures)

if logging.vlog_is_on(1):
logging.vlog(1, 'param_info: %s', param_infos)
logging.vlog(1, 'save_args: %s', save_args)

save_futures = []
if multihost.is_primary_host(self._primary_host):
commit_futures.append(
self._write_metadata_file(
directory, param_infos, save_args, self._use_zarr3
save_futures.append(
self._thread_pool.submit(
self._write_metadata_after_commits,
commit_futures=commit_futures,
checkpoint_dir=directory,
param_infos=param_infos,
save_args=save_args,
use_zarr3=self._use_zarr3,
)
)
else:
save_futures += commit_futures

_log_io_metrics(
tree_memory_size,
Expand All @@ -487,7 +509,7 @@ async def async_save(
)
return [
future.ChainedFuture(
commit_futures,
save_futures,
functools.partial(
_log_io_metrics,
tree_memory_size,
Expand Down Expand Up @@ -725,14 +747,68 @@ class TrainState:
)
return restored_item

def _get_param_infos_with_write_shape(
self,
param_infos: PyTree,
checkpoint_dir: epath.Path,
array_metadata_store: array_metadata_store_lib.Store,
) -> PyTree:
"""Returns `param_infos` updated with `write_shape`.
Args:
param_infos: A PyTree of ParamInfo to be updated.
checkpoint_dir: The checkpoint directory where write_shape metadata is
saved in ArrayMetadata store.
array_metadata_store: The ArrayMetadata store to read write_shape metadata
from.
"""
if not utils.is_primary_host(self._primary_host):
return param_infos
# Extract write_shape from ArrayMetadata for current process_index.
process_index = multihost.process_index()
array_metadatas = array_metadata_store.read(
checkpoint_dir, process_index=process_index
)
if array_metadatas is None:
jax_array_param_info = type_handlers.any_jax_array_param_info(param_infos)
if jax_array_param_info is not None:
raise ValueError(
f'No ArrayMetadata found for process_index={process_index} in the'
f' checkpoint directory: {checkpoint_dir}. But input PyTree'
' contains at least one jax.Array param_info:'
f' {jax_array_param_info}.'
)
return param_infos

assert isinstance(array_metadatas, list)
array_metadatas_cache = {
array_metadata.param_name: array_metadata
for array_metadata in array_metadatas
}

def update_param_info(param_info: types.ParamInfo) -> types.ParamInfo:
if not type_handlers.represents_jax_array(param_info):
return param_info
if param_info.name not in array_metadatas_cache:
raise ValueError(
f'No ArrayMetadata found for param_info: {param_info}, checkpoint'
f' directory: {checkpoint_dir}, process_index={process_index}.'
)
return dataclasses.replace(
param_info,
write_shape=array_metadatas_cache[param_info.name].write_shape,
)

return jax.tree.map(update_param_info, param_infos)

def _write_metadata_file(
self,
directory: epath.Path,
param_infos: PyTree,
save_args: PyTree,
use_zarr3: bool = False,
) -> future.Future:
def _save_fn():
def _save_fn(param_infos):
if utils.is_primary_host(self._primary_host):
metadata_write_start_time = time.time()
path = directory / PYTREE_METADATA_FILE
Expand All @@ -755,7 +831,35 @@ def _save_fn():
)
return 0

return self._thread_pool.submit(_save_fn)
return self._thread_pool.submit(_save_fn, param_infos)

def _write_metadata_after_commits(
self,
commit_futures: List[future.Future],
checkpoint_dir: epath.Path,
param_infos: PyTree,
save_args: PyTree,
use_zarr3: bool,
) -> None:
if not utils.is_primary_host(self._primary_host):
return
for commit_future in commit_futures:
commit_future.result()
# `write_shape` is extracted from ArrayMetadata store saved during
# materialization of commit_futures. Then it is written to the pytree
# metadata.
# TODO(niket): Simplify all metadata related code in this module after
# removing overriding of self._write_metadata_file() in subclasses. All
# metadata related code can be moved to a separate class and
# BasePyTreeCheckpointHandler should delegate all metadata related code to
# that class.
if self._array_metadata_store is not None:
param_infos = self._get_param_infos_with_write_shape(
param_infos, checkpoint_dir, self._array_metadata_store
)
self._write_metadata_file(
checkpoint_dir, param_infos, save_args, use_zarr3
).result()

def _read_metadata_file(
self, directory: epath.Path
Expand Down Expand Up @@ -834,6 +938,17 @@ def finalize(self, directory: epath.Path) -> None:
Args:
directory: Path where the checkpoint is located.
"""
if (
utils.is_primary_host(self._primary_host)
and self._array_metadata_store is not None
):
array_metadatas = self._array_metadata_store.read(directory)
if array_metadatas is not None:
assert isinstance(array_metadatas, dict) # read all processes.
self._array_metadata_validator.validate_all_array_metadatas(
array_metadatas
)

merge_start_time = time.time()
ts_context = ts_utils.get_ts_context(use_ocdbt=True)
asyncio_utils.run_sync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from orbax.checkpoint._src import asyncio_utils
from orbax.checkpoint._src.handlers import async_checkpoint_handler
from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
from orbax.checkpoint._src.metadata import empty_values
from orbax.checkpoint._src.metadata import tree as tree_metadata
from orbax.checkpoint._src.serialization import serialization
Expand Down Expand Up @@ -475,6 +476,9 @@ def __init__(
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
tree_metadata.PYTREE_METADATA_OPTIONS
),
array_metadata_validator: array_metadata_store_lib.Validator = (
array_metadata_store_lib.Validator()
),
):
"""Creates PyTreeCheckpointHandler.
Expand All @@ -496,6 +500,7 @@ def __init__(
specified, the global type handler registry will be used.
handler_impl: Allows overriding the internal implementation.
pytree_metadata_options: `PyTreeMetadataOptions` to manage metadata.
array_metadata_validator: Validator for ArrayMetadata.
"""
self._aggregate_handler = MsgpackHandler(
primary_host=multiprocessing_options.primary_host,
Expand All @@ -518,6 +523,7 @@ def __init__(
multiprocessing_options=multiprocessing_options,
type_handler_registry=type_handler_registry,
pytree_metadata_options=pytree_metadata_options,
array_metadata_validator=array_metadata_validator,
)
self._pytree_metadata_options = pytree_metadata_options

Expand Down
2 changes: 2 additions & 0 deletions checkpoint/orbax/checkpoint/_src/metadata/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ py_library(
deps = [
":array_metadata",
"//checkpoint/orbax/checkpoint/_src/multihost",
"//checkpoint/orbax/checkpoint/_src/serialization:types",
],
)

Expand All @@ -152,5 +153,6 @@ py_test(
deps = [
":array_metadata",
":array_metadata_store",
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
],
)
Loading

0 comments on commit 913f061

Please sign in to comment.