step_prefix cannot contain _ -- Checkpoint manager does not recognize multiple _. #1499

Open
scott-yj-yang opened this issue Jan 14, 2025 · 0 comments

Bug Description:

When I create a CheckpointManager with options like the following,

import orbax.checkpoint as ocp

options = ocp.CheckpointManagerOptions(step_prefix="ppo_networks")
with ocp.CheckpointManager(
    ".../model_checkpoints/7358284e-a603-453f-9024-f69a27a293c4",
    options=options,
) as mngr:
    mngr.restore(0)

and my directory looks like this,

[Image: checkpoint directory listing]

it raises the following ValueError when instantiating the manager object.

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[13], line 5
      1 import orbax.checkpoint as ocp
      4 options = ocp.CheckpointManagerOptions(step_prefix="ppo_networks")
----> 5 with ocp.CheckpointManager(
      6     "/root/vast/scott-yang/track-mjx/model_checkpoints/7358284e-a603-453f-9024-f69a27a293c4",
      7     options=options,
      8 ) as mngr:
      9     mngr.restore(0)

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py:685, in CheckpointManager.__init__(self, directory, checkpointers, options, metadata, item_names, item_handlers, logger, handler_registry)
    675   self._cleanup_tmp_directories()
    677 self._step_name_format = (
    678     self._options.step_name_format
    679     or step_lib.standard_name_format(
   (...)
    682     )
    683 )
--> 685 self._checkpoints = self._load_checkpoint_infos()
    687 self._metadata_checkpointer = Checkpointer(
    688     JsonCheckpointHandler(
    689         multiprocessing_options=self._multiprocessing_options
   (...)
    694     temporary_path_class=self._options.temporary_path_class,
    695 )
    696 if self._options.read_only and not self._metadata_path().exists():

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/checkpoint_manager.py:1431, in CheckpointManager._load_checkpoint_infos(self)
   1423 """Loads a list of CheckpointInfo for existing checkpoints.
   1424 
   1425 If none are present, returns empty list.
   (...)
   1428   a list of CheckpointInfo, sorted by increasing step.
   1429 """
   1430 start = time.time()
-> 1431 steps = utils.checkpoint_steps(
   1432     self.directory, self._options.single_host_load_and_broadcast
   1433 )
   1434 steps.sort()  # Prefer in-place sort.
   1436 if not steps:

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:698, in checkpoint_steps(checkpoint_dir, single_host_load_and_broadcast)
    696   padded_step_list = multihost.broadcast_one_to_all(padded_step_list)
    697   return [step for step in padded_step_list if step >= 0]
--> 698 return _checkpoint_steps(checkpoint_dir)

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:682, in checkpoint_steps.<locals>._checkpoint_steps(path)
    681 def _checkpoint_steps(path: epath.Path) -> List[int]:
--> 682   return [
    683       step_from_checkpoint_name(s.name) for s in checkpoint_steps_paths(path)
    684   ]

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:683, in <listcomp>(.0)
    681 def _checkpoint_steps(path: epath.Path) -> List[int]:
    682   return [
--> 683       step_from_checkpoint_name(s.name) for s in checkpoint_steps_paths(path)
    684   ]

File ~/miniforge3/envs/track_mjx/lib/python3.11/site-packages/orbax/checkpoint/path/step.py:645, in step_from_checkpoint_name(name)
    643 elif tmp_match := re.match(TMP_DIR_STEP_PATTERN, name):
    644   return int(tmp_match.group(1))
--> 645 raise ValueError(f'Unrecognized name format: {name}.')

ValueError: Unrecognized name format: ppo_networks_1024000.

Specifically, this is step_from_checkpoint_name in step.py:

def step_from_checkpoint_name(name: str) -> int:
  """Returns the step from a checkpoint name. Also works for tmp checkpoints."""
  if name.isdigit():
    return int(name)
  elif name.split('_')[-1].isdigit():
    split = name.split('_')
    if len(split) == 2 and split[0]:
      return int(split[-1])
  elif tmp_match := re.match(TMP_DIR_STEP_PATTERN, name):
    return int(tmp_match.group(1))
  raise ValueError(f'Unrecognized name format: {name}.')
It assumes that splitting the name on _ yields exactly two parts, so a name whose prefix itself contains _ (such as ppo_networks_1024000) is never recognized. The prefix needs input validation.
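
To make the failure concrete without needing Orbax installed, here is a minimal standalone sketch. `step_from_name_current` is a simplified copy of the logic quoted above (the tmp-directory branch is left out, since `TMP_DIR_STEP_PATTERN` is not shown here). `step_from_name_rpartition` and `validate_step_prefix` are hypothetical names of my own, not Orbax functions; they only illustrate two possible directions (split on the last `_` only, or reject such prefixes up front).

```python
def step_from_name_current(name: str) -> int:
  """Simplified copy of the quoted logic (tmp-directory branch omitted)."""
  if name.isdigit():
    return int(name)
  split = name.split('_')
  # Only accepts exactly one underscore: '<prefix>_<step>'.
  if split[-1].isdigit() and len(split) == 2 and split[0]:
    return int(split[-1])
  raise ValueError(f'Unrecognized name format: {name}.')


def step_from_name_rpartition(name: str) -> int:
  """Hypothetical variant: split only on the LAST underscore."""
  if name.isdigit():
    return int(name)
  prefix, sep, step = name.rpartition('_')
  # Require a non-empty prefix, a separator, and a numeric step.
  if sep and prefix and step.isdigit():
    return int(step)
  raise ValueError(f'Unrecognized name format: {name}.')


def validate_step_prefix(step_prefix: str) -> None:
  """Hypothetical early check: fail at option creation, not at directory scan."""
  if '_' in step_prefix:
    raise ValueError(
        f'step_prefix must not contain "_", got: {step_prefix!r}.'
    )


if __name__ == '__main__':
  name = 'ppo_networks_1024000'
  try:
    step_from_name_current(name)
  except ValueError as e:
    print('current logic:', e)
  print('rpartition variant:', step_from_name_rpartition(name))
```

Either direction would address the report: the `rpartition` variant accepts prefixes with underscores, while the validation variant keeps the current parsing and surfaces the restriction with a clearer error. Whether the looser parsing interacts badly with temporary directory names is something I have not checked.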
