Skip to content

Commit 24eaa0f

Browse files
niketkumarOrbax Authors
authored and
Orbax Authors
committed
Internal change.
PiperOrigin-RevId: 715616303
1 parent 6ed02f3 commit 24eaa0f

10 files changed

+720
-38
lines changed

checkpoint/orbax/checkpoint/_src/handlers/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ py_library(
5858
"//checkpoint/orbax/checkpoint/_src/serialization:tensorstore_utils",
5959
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
6060
"//checkpoint/orbax/checkpoint/_src/tree:utils",
61+
"//orbax/checkpoint/_src/metadata:array_metadata_store",
6162
],
6263
)
6364

@@ -77,6 +78,7 @@ py_library(
7778
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
7879
"//checkpoint/orbax/checkpoint/_src/serialization:types",
7980
"//checkpoint/orbax/checkpoint/_src/tree:utils",
81+
"//orbax/checkpoint/_src/metadata:array_metadata_store",
8082
],
8183
)
8284

checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py

+131-10
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from orbax.checkpoint import utils
4242
from orbax.checkpoint._src import asyncio_utils
4343
from orbax.checkpoint._src.handlers import async_checkpoint_handler
44+
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
4445
from orbax.checkpoint._src.metadata import empty_values
4546
from orbax.checkpoint._src.metadata import tree as tree_metadata
4647
from orbax.checkpoint._src.multihost import multihost
@@ -282,6 +283,9 @@ def __init__(
282283
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
283284
tree_metadata.PYTREE_METADATA_OPTIONS
284285
),
286+
array_metadata_validator: array_metadata_store_lib.Validator = (
287+
array_metadata_store_lib.Validator()
288+
),
285289
):
286290
"""Creates BasePyTreeCheckpointHandler.
287291
@@ -301,6 +305,7 @@ def __init__(
301305
enable_post_merge_validation: If True, enables validation of the
302306
parameters after the finalize step.
303307
pytree_metadata_options: `PyTreeMetadataOptions` to manage metadata.
308+
array_metadata_validator: Validator for ArrayMetadata.
304309
"""
305310
self._save_concurrent_bytes = save_concurrent_bytes
306311
self._restore_concurrent_bytes = restore_concurrent_bytes
@@ -310,18 +315,28 @@ def __init__(
310315
self._type_handler_registry = type_handler_registry
311316
self._enable_post_merge_validation = enable_post_merge_validation
312317
self._pytree_metadata_options = pytree_metadata_options
318+
# Get ArrayMetadata Store from TypeHandler for jax.Array.
319+
# ArrayMetadata persistence is only supported for jax.Array.
320+
self._array_metadata_store = (
321+
array_metadata_store_lib.resolve_array_metadata_store(
322+
type_handler_registry
323+
)
324+
)
325+
self._array_metadata_validator = array_metadata_validator
313326

314327

315328
jax.monitoring.record_event(
316329
'/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
317330
)
318331

319332
self._thread_pool = futures.ThreadPoolExecutor(
320-
max_workers=2, thread_name_prefix='base_pytree_ch'
333+
max_workers=3, thread_name_prefix='base_pytree_ch'
321334
)
322335
logging.info(
323-
'Created BasePyTreeCheckpointHandler: pytree_metadata_options=%s',
336+
'Created BasePyTreeCheckpointHandler: pytree_metadata_options=%s,'
337+
' array_metadata_store=%s',
324338
self._pytree_metadata_options,
339+
self._array_metadata_store,
325340
)
326341

327342
def get_param_names(self, item: PyTree) -> PyTree:
@@ -451,7 +466,7 @@ async def async_save(
451466
leaf.parent_dir == directory for leaf in jax.tree.leaves(param_infos)
452467
)
453468

454-
serialize_ops = []
469+
serialize_ops = [] # List of (coros -> List of futures)
455470
batch_requests = batched_serialization_requests(
456471
item,
457472
param_infos,
@@ -465,20 +480,29 @@ async def async_save(
465480
]
466481
write_size, _ = _get_batch_memory_size(request.handler, request.values)
467482
tree_memory_size += write_size
468-
# Await copy futures. Returns list of lists.
483+
# Await copy futures. Returns List[List[future.Future]].
469484
commit_futures = await asyncio.gather(*serialize_ops)
485+
# Flatten to List[future.Future].
470486
commit_futures, _ = jax.tree.flatten(commit_futures)
471487

472488
if logging.vlog_is_on(1):
473489
logging.vlog(1, 'param_info: %s', param_infos)
474490
logging.vlog(1, 'save_args: %s', save_args)
475491

492+
save_futures = []
476493
if multihost.is_primary_host(self._primary_host):
477-
commit_futures.append(
478-
self._write_metadata_file(
479-
directory, param_infos, save_args, self._use_zarr3
494+
save_futures.append(
495+
self._thread_pool.submit(
496+
self._write_metadata_after_commits,
497+
commit_futures=commit_futures,
498+
checkpoint_dir=directory,
499+
param_infos=param_infos,
500+
save_args=save_args,
501+
use_zarr3=self._use_zarr3,
480502
)
481503
)
504+
else:
505+
save_futures += commit_futures
482506

483507
_log_io_metrics(
484508
tree_memory_size,
@@ -487,7 +511,7 @@ async def async_save(
487511
)
488512
return [
489513
future.ChainedFuture(
490-
commit_futures,
514+
save_futures,
491515
functools.partial(
492516
_log_io_metrics,
493517
tree_memory_size,
@@ -725,14 +749,68 @@ class TrainState:
725749
)
726750
return restored_item
727751

752+
def _get_param_infos_with_write_shape(
753+
self,
754+
param_infos: PyTree,
755+
checkpoint_dir: epath.Path,
756+
array_metadata_store: array_metadata_store_lib.Store,
757+
) -> PyTree:
758+
"""Returns `param_infos` updated with `write_shape`.
759+
760+
Args:
761+
param_infos: A PyTree of ParamInfo to be updated.
762+
checkpoint_dir: The checkpoint directory where write_shape metadata is
763+
saved in ArrayMetadata store.
764+
array_metadata_store: The ArrayMetadata store to read write_shape metadata
765+
from.
766+
"""
767+
if not utils.is_primary_host(self._primary_host):
768+
return param_infos
769+
# Extract write_shape from ArrayMetadata for current process_index.
770+
process_index = multihost.process_index()
771+
array_metadatas = array_metadata_store.read(
772+
checkpoint_dir, process_index=process_index
773+
)
774+
if array_metadatas is None:
775+
jax_array_param_info = type_handlers.any_jax_array_param_info(param_infos)
776+
if jax_array_param_info is not None:
777+
raise ValueError(
778+
f'No ArrayMetadata found for process_index={process_index} in the'
779+
f' checkpoint directory: {checkpoint_dir}. But input PyTree'
780+
' contains at least one jax.Array param_info:'
781+
f' {jax_array_param_info}.'
782+
)
783+
return param_infos
784+
785+
assert isinstance(array_metadatas, list)
786+
array_metadatas_cache = {
787+
array_metadata.param_name: array_metadata
788+
for array_metadata in array_metadatas
789+
}
790+
791+
def update_param_info(param_info: types.ParamInfo) -> types.ParamInfo:
792+
if not type_handlers.represents_jax_array(param_info):
793+
return param_info
794+
if param_info.name not in array_metadatas_cache:
795+
raise ValueError(
796+
f'No ArrayMetadata found for param_info: {param_info}, checkpoint'
797+
f' directory: {checkpoint_dir}, process_index={process_index}.'
798+
)
799+
return dataclasses.replace(
800+
param_info,
801+
write_shape=array_metadatas_cache[param_info.name].write_shape,
802+
)
803+
804+
return jax.tree.map(update_param_info, param_infos)
805+
728806
def _write_metadata_file(
729807
self,
730808
directory: epath.Path,
731809
param_infos: PyTree,
732810
save_args: PyTree,
733811
use_zarr3: bool = False,
734812
) -> future.Future:
735-
def _save_fn():
813+
def _save_fn(param_infos):
736814
if utils.is_primary_host(self._primary_host):
737815
metadata_write_start_time = time.time()
738816
path = directory / PYTREE_METADATA_FILE
@@ -755,7 +833,35 @@ def _save_fn():
755833
)
756834
return 0
757835

758-
return self._thread_pool.submit(_save_fn)
836+
return self._thread_pool.submit(_save_fn, param_infos)
837+
838+
def _write_metadata_after_commits(
839+
self,
840+
commit_futures: List[future.Future],
841+
checkpoint_dir: epath.Path,
842+
param_infos: PyTree,
843+
save_args: PyTree,
844+
use_zarr3: bool,
845+
) -> None:
846+
if not utils.is_primary_host(self._primary_host):
847+
return
848+
for commit_future in commit_futures:
849+
commit_future.result()
850+
# `write_shape` is extracted from ArrayMetadata store saved during
851+
# materialization of commit_futures. Then it is written to the pytree
852+
# metadata.
853+
# TODO(b/390465017): Simplify all metadata related code in this module after
854+
# removing overriding of self._write_metadata_file() in subclasses. All
855+
# metadata related code can be moved to a separate class and
856+
# BasePyTreeCheckpointHandler should delegate all metadata related code to
857+
# that class.
858+
if self._array_metadata_store is not None:
859+
param_infos = self._get_param_infos_with_write_shape(
860+
param_infos, checkpoint_dir, self._array_metadata_store
861+
)
862+
self._write_metadata_file(
863+
checkpoint_dir, param_infos, save_args, use_zarr3
864+
).result()
759865

760866
def _read_metadata_file(
761867
self, directory: epath.Path
@@ -834,6 +940,21 @@ def finalize(self, directory: epath.Path) -> None:
834940
Args:
835941
directory: Path where the checkpoint is located.
836942
"""
943+
if self._array_metadata_store is not None:
944+
if self._primary_host is None:
945+
logging.warning(
946+
'[process=%s] Skipped cross-host ArrayMetadata validation'
947+
' because all hosts are primary (e.g. local storage).',
948+
multihost.process_index(),
949+
)
950+
elif utils.is_primary_host(self._primary_host):
951+
array_metadatas = self._array_metadata_store.read(directory)
952+
if array_metadatas is not None:
953+
assert isinstance(array_metadatas, dict) # read all processes.
954+
self._array_metadata_validator.validate_all_array_metadatas(
955+
array_metadatas
956+
)
957+
837958
merge_start_time = time.time()
838959
ts_context = ts_utils.get_ts_context(use_ocdbt=True)
839960
asyncio_utils.run_sync(

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py

+6
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from orbax.checkpoint._src import asyncio_utils
4242
from orbax.checkpoint._src.handlers import async_checkpoint_handler
4343
from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler
44+
from orbax.checkpoint._src.metadata import array_metadata_store as array_metadata_store_lib
4445
from orbax.checkpoint._src.metadata import empty_values
4546
from orbax.checkpoint._src.metadata import tree as tree_metadata
4647
from orbax.checkpoint._src.serialization import serialization
@@ -475,6 +476,9 @@ def __init__(
475476
pytree_metadata_options: tree_metadata.PyTreeMetadataOptions = (
476477
tree_metadata.PYTREE_METADATA_OPTIONS
477478
),
479+
array_metadata_validator: array_metadata_store_lib.Validator = (
480+
array_metadata_store_lib.Validator()
481+
),
478482
):
479483
"""Creates PyTreeCheckpointHandler.
480484
@@ -496,6 +500,7 @@ def __init__(
496500
specified, the global type handler registry will be used.
497501
handler_impl: Allows overriding the internal implementation.
498502
pytree_metadata_options: `PyTreeMetadataOptions` to manage metadata.
503+
array_metadata_validator: Validator for ArrayMetadata.
499504
"""
500505
self._aggregate_handler = MsgpackHandler(
501506
primary_host=multiprocessing_options.primary_host,
@@ -518,6 +523,7 @@ def __init__(
518523
multiprocessing_options=multiprocessing_options,
519524
type_handler_registry=type_handler_registry,
520525
pytree_metadata_options=pytree_metadata_options,
526+
array_metadata_validator=array_metadata_validator,
521527
)
522528
self._pytree_metadata_options = pytree_metadata_options
523529

checkpoint/orbax/checkpoint/_src/metadata/BUILD

+5
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,13 @@ py_library(
161161
py_library(
162162
name = "array_metadata_store",
163163
srcs = ["array_metadata_store.py"],
164+
visibility = default_visibility + [
165+
"//learning/deepmind/jax/roc/formats/roc_orbax:__subpackages__",
166+
],
164167
deps = [
165168
":array_metadata",
166169
"//checkpoint/orbax/checkpoint/_src/multihost",
170+
"//checkpoint/orbax/checkpoint/_src/serialization:types",
167171
],
168172
)
169173

@@ -173,5 +177,6 @@ py_test(
173177
deps = [
174178
":array_metadata",
175179
":array_metadata_store",
180+
"//checkpoint/orbax/checkpoint/_src/serialization:type_handlers",
176181
],
177182
)

0 commit comments

Comments
 (0)