41
41
from orbax .checkpoint import utils
42
42
from orbax .checkpoint ._src import asyncio_utils
43
43
from orbax .checkpoint ._src .handlers import async_checkpoint_handler
44
+ from orbax .checkpoint ._src .metadata import array_metadata_store as array_metadata_store_lib
44
45
from orbax .checkpoint ._src .metadata import empty_values
45
46
from orbax .checkpoint ._src .metadata import tree as tree_metadata
46
47
from orbax .checkpoint ._src .multihost import multihost
@@ -282,6 +283,9 @@ def __init__(
282
283
pytree_metadata_options : tree_metadata .PyTreeMetadataOptions = (
283
284
tree_metadata .PYTREE_METADATA_OPTIONS
284
285
),
286
+ array_metadata_validator : array_metadata_store_lib .Validator = (
287
+ array_metadata_store_lib .Validator ()
288
+ ),
285
289
):
286
290
"""Creates BasePyTreeCheckpointHandler.
287
291
@@ -301,6 +305,7 @@ def __init__(
301
305
enable_post_merge_validation: If True, enables validation of the
302
306
parameters after the finalize step.
303
307
pytree_metadata_options: `PyTreeMetadataOptions` to manage metadata.
308
+ array_metadata_validator: Validator for ArrayMetadata.
304
309
"""
305
310
self ._save_concurrent_bytes = save_concurrent_bytes
306
311
self ._restore_concurrent_bytes = restore_concurrent_bytes
@@ -310,18 +315,28 @@ def __init__(
310
315
self ._type_handler_registry = type_handler_registry
311
316
self ._enable_post_merge_validation = enable_post_merge_validation
312
317
self ._pytree_metadata_options = pytree_metadata_options
318
+ # Get ArrayMetadata Store from TypeHandler for jax.Array.
319
+ # ArrayMetadata persistence is only supported for jax.Array.
320
+ self ._array_metadata_store = (
321
+ array_metadata_store_lib .resolve_array_metadata_store (
322
+ type_handler_registry
323
+ )
324
+ )
325
+ self ._array_metadata_validator = array_metadata_validator
313
326
314
327
315
328
jax .monitoring .record_event (
316
329
'/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
317
330
)
318
331
319
332
self ._thread_pool = futures .ThreadPoolExecutor (
320
- max_workers = 2 , thread_name_prefix = 'base_pytree_ch'
333
+ max_workers = 3 , thread_name_prefix = 'base_pytree_ch'
321
334
)
322
335
logging .info (
323
- 'Created BasePyTreeCheckpointHandler: pytree_metadata_options=%s' ,
336
+ 'Created BasePyTreeCheckpointHandler: pytree_metadata_options=%s,'
337
+ ' array_metadata_store=%s' ,
324
338
self ._pytree_metadata_options ,
339
+ self ._array_metadata_store ,
325
340
)
326
341
327
342
def get_param_names (self , item : PyTree ) -> PyTree :
@@ -451,7 +466,7 @@ async def async_save(
451
466
leaf .parent_dir == directory for leaf in jax .tree .leaves (param_infos )
452
467
)
453
468
454
- serialize_ops = []
469
+ serialize_ops = [] # List of (coros -> List of futures)
455
470
batch_requests = batched_serialization_requests (
456
471
item ,
457
472
param_infos ,
@@ -465,20 +480,29 @@ async def async_save(
465
480
]
466
481
write_size , _ = _get_batch_memory_size (request .handler , request .values )
467
482
tree_memory_size += write_size
468
- # Await copy futures. Returns list of lists .
483
+ # Await copy futures. Returns List[List[future.Future]] .
469
484
commit_futures = await asyncio .gather (* serialize_ops )
485
+ # Flatten to List[future.Future].
470
486
commit_futures , _ = jax .tree .flatten (commit_futures )
471
487
472
488
if logging .vlog_is_on (1 ):
473
489
logging .vlog (1 , 'param_info: %s' , param_infos )
474
490
logging .vlog (1 , 'save_args: %s' , save_args )
475
491
492
+ save_futures = []
476
493
if multihost .is_primary_host (self ._primary_host ):
477
- commit_futures .append (
478
- self ._write_metadata_file (
479
- directory , param_infos , save_args , self ._use_zarr3
494
+ save_futures .append (
495
+ self ._thread_pool .submit (
496
+ self ._write_metadata_after_commits ,
497
+ commit_futures = commit_futures ,
498
+ checkpoint_dir = directory ,
499
+ param_infos = param_infos ,
500
+ save_args = save_args ,
501
+ use_zarr3 = self ._use_zarr3 ,
480
502
)
481
503
)
504
+ else :
505
+ save_futures += commit_futures
482
506
483
507
_log_io_metrics (
484
508
tree_memory_size ,
@@ -487,7 +511,7 @@ async def async_save(
487
511
)
488
512
return [
489
513
future .ChainedFuture (
490
- commit_futures ,
514
+ save_futures ,
491
515
functools .partial (
492
516
_log_io_metrics ,
493
517
tree_memory_size ,
@@ -725,14 +749,68 @@ class TrainState:
725
749
)
726
750
return restored_item
727
751
752
+ def _get_param_infos_with_write_shape (
753
+ self ,
754
+ param_infos : PyTree ,
755
+ checkpoint_dir : epath .Path ,
756
+ array_metadata_store : array_metadata_store_lib .Store ,
757
+ ) -> PyTree :
758
+ """Returns `param_infos` updated with `write_shape`.
759
+
760
+ Args:
761
+ param_infos: A PyTree of ParamInfo to be updated.
762
+ checkpoint_dir: The checkpoint directory where write_shape metadata is
763
+ saved in ArrayMetadata store.
764
+ array_metadata_store: The ArrayMetadata store to read write_shape metadata
765
+ from.
766
+ """
767
+ if not utils .is_primary_host (self ._primary_host ):
768
+ return param_infos
769
+ # Extract write_shape from ArrayMetadata for current process_index.
770
+ process_index = multihost .process_index ()
771
+ array_metadatas = array_metadata_store .read (
772
+ checkpoint_dir , process_index = process_index
773
+ )
774
+ if array_metadatas is None :
775
+ jax_array_param_info = type_handlers .any_jax_array_param_info (param_infos )
776
+ if jax_array_param_info is not None :
777
+ raise ValueError (
778
+ f'No ArrayMetadata found for process_index={ process_index } in the'
779
+ f' checkpoint directory: { checkpoint_dir } . But input PyTree'
780
+ ' contains at least one jax.Array param_info:'
781
+ f' { jax_array_param_info } .'
782
+ )
783
+ return param_infos
784
+
785
+ assert isinstance (array_metadatas , list )
786
+ array_metadatas_cache = {
787
+ array_metadata .param_name : array_metadata
788
+ for array_metadata in array_metadatas
789
+ }
790
+
791
+ def update_param_info (param_info : types .ParamInfo ) -> types .ParamInfo :
792
+ if not type_handlers .represents_jax_array (param_info ):
793
+ return param_info
794
+ if param_info .name not in array_metadatas_cache :
795
+ raise ValueError (
796
+ f'No ArrayMetadata found for param_info: { param_info } , checkpoint'
797
+ f' directory: { checkpoint_dir } , process_index={ process_index } .'
798
+ )
799
+ return dataclasses .replace (
800
+ param_info ,
801
+ write_shape = array_metadatas_cache [param_info .name ].write_shape ,
802
+ )
803
+
804
+ return jax .tree .map (update_param_info , param_infos )
805
+
728
806
def _write_metadata_file (
729
807
self ,
730
808
directory : epath .Path ,
731
809
param_infos : PyTree ,
732
810
save_args : PyTree ,
733
811
use_zarr3 : bool = False ,
734
812
) -> future .Future :
735
- def _save_fn ():
813
+ def _save_fn (param_infos ):
736
814
if utils .is_primary_host (self ._primary_host ):
737
815
metadata_write_start_time = time .time ()
738
816
path = directory / PYTREE_METADATA_FILE
@@ -755,7 +833,35 @@ def _save_fn():
755
833
)
756
834
return 0
757
835
758
- return self ._thread_pool .submit (_save_fn )
836
+ return self ._thread_pool .submit (_save_fn , param_infos )
837
+
838
+ def _write_metadata_after_commits (
839
+ self ,
840
+ commit_futures : List [future .Future ],
841
+ checkpoint_dir : epath .Path ,
842
+ param_infos : PyTree ,
843
+ save_args : PyTree ,
844
+ use_zarr3 : bool ,
845
+ ) -> None :
846
+ if not utils .is_primary_host (self ._primary_host ):
847
+ return
848
+ for commit_future in commit_futures :
849
+ commit_future .result ()
850
+ # `write_shape` is extracted from ArrayMetadata store saved during
851
+ # materialization of commit_futures. Then it is written to the pytree
852
+ # metadata.
853
+ # TODO(b/390465017): Simplify all metadata related code in this module after
854
+ # removing overriding of self._write_metadata_file() in subclasses. All
855
+ # metadata related code can be moved to a separate class and
856
+ # BasePyTreeCheckpointHandler should delegate all metadata related code to
857
+ # that class.
858
+ if self ._array_metadata_store is not None :
859
+ param_infos = self ._get_param_infos_with_write_shape (
860
+ param_infos , checkpoint_dir , self ._array_metadata_store
861
+ )
862
+ self ._write_metadata_file (
863
+ checkpoint_dir , param_infos , save_args , use_zarr3
864
+ ).result ()
759
865
760
866
def _read_metadata_file (
761
867
self , directory : epath .Path
@@ -834,6 +940,21 @@ def finalize(self, directory: epath.Path) -> None:
834
940
Args:
835
941
directory: Path where the checkpoint is located.
836
942
"""
943
+ if self ._array_metadata_store is not None :
944
+ if self ._primary_host is None :
945
+ logging .warning (
946
+ '[process=%s] Skipped cross-host ArrayMetadata validation'
947
+ ' because all hosts are primary (e.g. local storage).' ,
948
+ multihost .process_index (),
949
+ )
950
+ elif utils .is_primary_host (self ._primary_host ):
951
+ array_metadatas = self ._array_metadata_store .read (directory )
952
+ if array_metadatas is not None :
953
+ assert isinstance (array_metadatas , dict ) # read all processes.
954
+ self ._array_metadata_validator .validate_all_array_metadatas (
955
+ array_metadatas
956
+ )
957
+
837
958
merge_start_time = time .time ()
838
959
ts_context = ts_utils .get_ts_context (use_ocdbt = True )
839
960
asyncio_utils .run_sync (
0 commit comments