1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
@@ -27,6 +27,7 @@ trigger creation if it has not already been started.
- Add new `OrbaxV0Layout` that will handle specific v0 checkpoint format logic.
- Add sharding fallback for target tree leaves in `StandardCheckpointHandler`
restore, removing sharding/topology warnings.
- Add PyTorch DCP (Distributed Checkpoint) to the benchmark suite.

## [0.11.31] - 2025-12-11

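For context on the changelog entry above: PyTorch DCP is the torch.distributed.checkpoint API. A minimal save/restore sketch of that API (torch >= 2.2; the path and state dict below are illustrative assumptions, not the benchmark code added in this change):

# Minimal PyTorch DCP sketch -- illustrative only, not the benchmark code.
import torch
import torch.distributed.checkpoint as dcp

state_dict = {"weight": torch.zeros(4, 4)}

# Writes a checkpoint directory; runs single-process if no torch process
# group has been initialized.
dcp.save(state_dict, checkpoint_id="/tmp/dcp_ckpt")

# Restores in place into preallocated tensors.
restored = {"weight": torch.empty(4, 4)}
dcp.load(restored, checkpoint_id="/tmp/dcp_ckpt")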
@@ -122,12 +122,27 @@ def _partition_axis_name(offset: int) -> str:



def is_safetensor_checkpoint(path: str | epath.Path) -> bool:
  """Checks whether the checkpoint at `path` is a safetensors checkpoint."""
  path = epath.Path(path)
  for f in path.iterdir():
    if f.is_file() and 'safetensors' in f.name:
      return True
  return False


def load_checkpoint(path: str) -> Any:
  """Loads a test checkpoint PyTree from the provided path."""
  logging.info('Loading checkpoint from path: %s', path)
  path = epath.Path(path)

  # If this is a safetensors checkpoint, return the path directly: the
  # benchmark test loads the checkpoint straight from the path, so there is
  # no need to materialize it as a PyTree here.
  if is_safetensor_checkpoint(path):
    return path

  use_ocdbt = type_handlers.is_ocdbt_checkpoint(path)
  with checkpointer.Checkpointer(
      pytree_checkpoint_handler.PyTreeCheckpointHandler(use_ocdbt=use_ocdbt)
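Since load_checkpoint above returns the raw path for safetensors checkpoints, the consuming benchmark presumably opens the file itself. A minimal sketch with the safetensors library (the file name is a hypothetical example):

# Sketch: reading tensors from a safetensors file at the returned path.
# The file name below is illustrative.
from safetensors import safe_open

with safe_open("/tmp/ckpt/model.safetensors", framework="pt") as f:
  for key in f.keys():
    tensor = f.get_tensor(key)  # lazily loads a single tensor by name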
46 changes: 35 additions & 11 deletions checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/core.py
@@ -20,6 +20,7 @@
import hashlib
import itertools
import sys
import threading
from typing import Any, Callable

from absl import logging
@@ -33,6 +34,29 @@
from orbax.checkpoint._src.testing.benchmarks.core import metric as metric_lib


def _sync_global_processes(name: str):
  """Syncs global processes via torch.distributed if available, else multihost."""
  try:
    import torch.distributed as dist  # pylint: disable=g-import-not-at-top

    if dist.is_initialized():
      logging.vlog(
          1,
          "[process=%s][thread=%s] sync_global_processes with torch"
          " barrier: %s",
          dist.get_rank(),
          threading.current_thread().name,
          name,
      )
      dist.barrier()
      return
  except ImportError:
    pass
  multihost.sync_global_processes(name)


@dataclasses.dataclass(frozen=True)
class BenchmarkOptions:
"""Base class for benchmark generator options."""
@@ -148,13 +172,13 @@ def run(self, repeat_index: int | None = None) -> TestResult:
name += f"_repeat_{repeat_index}"
logging.info(
"[process_id=%s] Setting up test: %s",
multihost.process_index(),
metric_lib.get_process_index(),
name,
)

benchmark_metrics = metric_lib.Metrics(name=f"{name} Internal")
with benchmark_metrics.measure("sync_global_processes:benchmark:run"):
multihost.sync_global_processes("benchmark:run")
_sync_global_processes("benchmark:run")

path = directory_setup.setup_test_directory(
self.name, self.output_dir, repeat_index
@@ -163,7 +187,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
with benchmark_metrics.measure(
"sync_global_processes:benchmark:setup_test_directory"
):
multihost.sync_global_processes("benchmark:setup_test_directory")
_sync_global_processes("benchmark:setup_test_directory")

if self.checkpoint_config.path is None:
data = checkpoint_generation.generate_checkpoint(
@@ -175,7 +199,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:
with benchmark_metrics.measure(
"sync_global_processes:benchmark:setup_pytree"
):
multihost.sync_global_processes("benchmark:setup_pytree")
_sync_global_processes("benchmark:setup_pytree")

context = TestContext(
pytree=data,
@@ -191,7 +215,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:

logging.info(
"[process_id=%s] Executing test function: %s",
multihost.process_index(),
metric_lib.get_process_index(),
name,
)
try:
@@ -201,13 +225,13 @@ def run(self, repeat_index: int | None = None) -> TestResult:
# execution is recorded in the TestResult.
if sys.version_info >= (3, 11):
e.add_note(
f"[process_id={multihost.process_index()}],"
f"[process_id={metric_lib.get_process_index()}],"
f" {test_context_summary[:100]}"
)
logging.error(
"[process_id=%s] Test function '%s' context: %s, raised an"
" exception: %s",
multihost.process_index(),
metric_lib.get_process_index(),
name,
test_context_summary[:100],
e,
@@ -221,7 +245,7 @@ def run(self, repeat_index: int | None = None) -> TestResult:

logging.info(
"[process_id=%s] Test finished: %s",
multihost.process_index(),
metric_lib.get_process_index(),
name,
)

@@ -304,13 +328,13 @@ def _get_options_product(self) -> Sequence[BenchmarkOptions]:
option_instances.append(option_instance)
logging.info(
"[process_id=%s] Generating valid option combination: %s",
multihost.process_index(),
metric_lib.get_process_index(),
option_instance,
)
else:
logging.info(
"[process_id=%s] Skipping invalid option combination: %s",
multihost.process_index(),
metric_lib.get_process_index(),
option_instance,
)
return option_instances
@@ -458,5 +482,5 @@ def run(self) -> Sequence[TestResult]:
)

logging.info(self._suite_metrics.generate_report())
multihost.sync_global_processes("test_suite:run_end")
_sync_global_processes("test_suite:run_end")
return all_results
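The torch paths in _sync_global_processes (and in get_process_index below) only activate once a torch process group exists. A minimal sketch of initializing one so the dist.barrier() and dist.get_rank() branches are exercised; the backend choice and env values are assumptions for a local single-process run:

# Sketch: start a single-process torch.distributed group so the torch
# branches above are taken. Values here are illustrative.
import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

assert dist.is_initialized() and dist.get_rank() == 0
dist.barrier()  # trivially passes with world_size=1
dist.destroy_process_group()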
@@ -16,7 +16,7 @@

from absl import logging
from etils import epath
import jax
from orbax.checkpoint._src.testing.benchmarks.core import metric


def setup_test_directory(
@@ -39,7 +39,7 @@ def setup_test_directory(
if repeat_index is not None:
path = path / f"repeat_{repeat_index}"
logging.info("Setting up test directory at: %s", path)
if jax.process_index() == 0:
if metric.get_process_index() == 0:
if path.exists():
logging.warning("Test directory %s already exists. Deleting it.", path)
path.rmtree()
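The hunk above keeps directory mutation on the primary host only; switching to metric.get_process_index() makes that check agree with the torch rank when torch.distributed is active. As a standalone sketch of the pattern (function and parameter names are illustrative, not part of this change):

# Sketch of the primary-host-only directory setup pattern used above.
from etils import epath

def setup_dir(path: epath.Path, process_index: int) -> epath.Path:
  # Only process 0 touches the filesystem; other processes are expected
  # to wait at a barrier before using the directory.
  if process_index == 0:
    if path.exists():
      path.rmtree()
    path.mkdir(parents=True)
  return path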
26 changes: 19 additions & 7 deletions checkpoint/orbax/checkpoint/_src/testing/benchmarks/core/metric.py
@@ -35,6 +35,18 @@
import tensorstore as ts


def get_process_index() -> int:
  """Returns the process index from torch.distributed if available, else multihost."""
  try:
    import torch.distributed as dist  # pylint: disable=g-import-not-at-top

    if dist.is_initialized():
      return dist.get_rank()
  except ImportError:
    pass
  return multihost.process_index()


class BaseMetric:
"""Base class for a metric type."""

@@ -47,7 +59,7 @@ def start(self):
self._start_time = time.perf_counter()
logging.info(
"[process_id=%s] Starting metric: '%s'...",
multihost.process_index(),
get_process_index(),
self.name,
)

@@ -56,7 +68,7 @@ def stop(self) -> dict[str, tuple[Any, str]]:
duration = time.perf_counter() - self._start_time
logging.info(
"[process_id=%s] Finished metric: '%s' (took %.4fs)",
multihost.process_index(),
get_process_index(),
self.name,
duration,
)
@@ -168,7 +180,7 @@ def stop(self) -> dict[str, tuple[Any, str]]:

self._log_tracemalloc_snapshot_diff(
self.name,
multihost.process_index(),
get_process_index(),
self._start_snapshot,
end_snapshot,
top_n=15,
@@ -285,7 +297,7 @@ def stop(self) -> dict[str, tuple[Any, str]]:
diff = self._diff_metrics(self._start_metrics, end_metrics)
logging.info(
"[process_id=%s] Finished metric: %s, num_diffs=%d",
multihost.process_index(),
get_process_index(),
self.name,
len(diff),
)
@@ -423,12 +435,12 @@ def report(self):
"""Logs a formatted report of all collected metrics."""
report_lines = []
report_lines.append(
f"---[process_id={multihost.process_index()}] {self.name} Metrics"
f"---[process_id={get_process_index()}] {self.name} Metrics"
" Report ---"
)
if not self.results:
report_lines.append(
f"[process_id={multihost.process_index()}] No metrics recorded."
f"[process_id={get_process_index()}] No metrics recorded."
)
else:
for name, (value, unit) in sorted(self.results.items()):
@@ -649,7 +661,7 @@ def export_to_tensorboard(self, tensorboard_dir: epath.Path):
"""Exports metrics to TensorBoard."""
logging.info("Writing per-repetition metrics to TensorBoard...")
for benchmark_name, results in self._runs.items():
is_primary_host = multihost.process_index() == 0
is_primary_host = get_process_index() == 0
writer = metric_writers.create_default_writer(
tensorboard_dir,
just_logging=not is_primary_host,