Multi-host Checkpointing Error #999
I just encountered the same issue in a different project. Even the toy example from the official tutorial fails. I'm running the following code on a TPUv4-16.
The error says:
It seems like something is wrong with Orbax multi-host saving? The same code works perfectly on a single host with a TPUv4-8.
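The key difference between the two machines is the number of hosts. Below is a small sketch that assumes the Cloud TPU v4 naming convention: the number in the slice name counts TensorCores (two per chip), and each TPU VM host drives four chips. Under that assumption a v4-8 is a single host, while a v4-16 already spans two. The helper name is hypothetical.

```python
def tpu_v4_host_count(accelerator_type: str) -> int:
    """Estimate how many TPU VM hosts a TPU v4 slice spans.

    Assumes the v4 naming convention: the suffix counts TensorCores,
    each chip has 2 TensorCores, and each host drives 4 chips.
    """
    prefix, _, cores = accelerator_type.partition("-")
    if prefix != "v4" or not cores.isdigit():
        raise ValueError(f"Not a TPU v4 accelerator type: {accelerator_type!r}")
    chips = int(cores) // 2
    # Even the smallest slice occupies one full host.
    return max(1, chips // 4)


for name in ("v4-8", "v4-16", "v4-32"):
    print(name, "->", tpu_v4_host_count(name), "host(s)")
```

Only the v4-8 case runs as a single process, which is consistent with it being the one configuration where saving works.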
I think the issue is that Orbax assumes the root directory is on global storage that every host can access. If you specify
Also, @YUE-FAN why are you using
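One way to catch this mistake early is to validate the directory before building the checkpoint manager. The following is a hypothetical guard, not part of Orbax. It treats any path without a URI scheme (for example `/home/...`) as host-local. Shared mounts such as NFS also lack a scheme, so this heuristic can flag them even though they would work.

```python
from urllib.parse import urlparse


def looks_host_local(path: str) -> bool:
    """Heuristic: True if `path` has no URI scheme (or is a file:// URI).

    Such paths usually live on each host's own disk. Shared mounts such as
    NFS also have no scheme, so this can report false positives.
    """
    return urlparse(path).scheme in ("", "file")


def check_checkpoint_dir(path: str, process_count: int) -> None:
    """Hypothetical guard to call before creating a CheckpointManager."""
    if process_count > 1 and looks_host_local(path):
        raise ValueError(
            f"{path!r} looks host-local, but {process_count} processes will "
            "save to it; use storage every host can see, e.g. a gs:// bucket."
        )


check_checkpoint_dir("gs://my-bucket/checkpoints", process_count=4)  # passes
check_checkpoint_dir("/tmp/checkpoints", process_count=1)  # single host: fine
```

In a JAX program, the process count would come from `jax.process_count()`, and you would run the check before `ocp.CheckpointManager(path, ...)`.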
Hi @cpgaffney1, thanks for the response! I changed the directory to a GCS bucket and now see a different error, but it is still related to a directory not being found. I believe I have set all the correct permissions, since I am able to write other files to the storage bucket with no issue. I also ran this test after installing orbax-checkpoint directly from GitHub rather than the PyPI release, and got the same error.
Traceback (most recent call last):
  File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 47, in <module>
    test_checkpointing()
  File "/home/simonsenan/dnadiffusion-jax/checkpoint_test.py", line 34, in test_checkpointing
    with ocp.CheckpointManager(path, item_names=("state", "custom_data"), metadata=global_metadata) as mngr:
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 634, in __init__
    self._save_metadata(metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1331, in _save_metadata
    self._metadata_checkpointer.save(path, metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 165, in save
    tmpdir = utils.create_tmp_directory(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/step.py", line 626, in create_tmp_directory
    checkpoint_metadata_store.write(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 238, in write
    self._store_impl.write(checkpoint_path, checkpoint_metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 126, in write
    raise ValueError(f'Checkpoint path does not exist: {checkpoint_path}')
ValueError: Checkpoint path does not exist: gs://dnadiffusion-bucket/checkpoints/metadata
I think you don't actually have the latest version of the code. See checkpointer.py. At head there is no reference to
Also checkpoint
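The traceback paths point at `~/.local/lib/python3.10/site-packages`, so it is worth confirming on every host which orbax-checkpoint version is actually installed. Below is a minimal sketch using the standard library's `importlib.metadata`. The helper names are hypothetical, and the sketch does not assume which release contains the fix.

```python
import re
from importlib import metadata


def installed_version(dist_name: str):
    """Return the installed version string of a distribution, or None."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def version_tuple(version: str) -> tuple:
    """Turn '0.5.22' into (0, 5, 22), stopping at the first non-numeric part."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


# Run this on every host; the outputs should all agree.
print("orbax-checkpoint:", installed_version("orbax-checkpoint") or "not installed")
```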
@cpgaffney1 I am using version 0.5.20 of orbax.checkpoint. What do we need to update?
@cpgaffney1 You're right, I was accidentally reinstalling the PyPI version when loading my package. Installing directly from GitHub does resolve the issue (along with syncing all hosts before saving). Is this fix included in the 0.5.21 release on PyPI from yesterday, or will it be in the next release? Otherwise, thanks for all your assistance, and feel free to close this issue!
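In JAX, the usual way to line up hosts is `jax.experimental.multihost_utils.sync_global_devices(name)`. The toy sketch below only mimics the ordering that the sync enforces. Threads stand in for hosts, and a `threading.Barrier` stands in for the sync call. Host 0 creates the step directory, and every host checks for it only after the barrier. The function name and host count are illustrative assumptions.

```python
import os
import tempfile
import threading

NUM_HOSTS = 4  # stand-in for jax.process_count()


def simulate_save(root: str, num_hosts: int = NUM_HOSTS) -> list:
    """Return, per simulated host, whether it saw the step directory."""
    step_dir = os.path.join(root, "checkpoints", "105")
    barrier = threading.Barrier(num_hosts)  # stand-in for sync_global_devices
    seen = [False] * num_hosts

    def host(index: int) -> None:
        if index == 0:
            # Only the primary host creates the directory.
            os.makedirs(step_dir, exist_ok=True)
        # No host proceeds until all have arrived, so the primary's
        # makedirs has finished before anyone checks for the directory.
        barrier.wait()
        seen[index] = os.path.isdir(step_dir)

    threads = [threading.Thread(target=host, args=(i,)) for i in range(num_hosts)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen


with tempfile.TemporaryDirectory() as root:
    print(simulate_save(root))  # [True, True, True, True]
```

Without the barrier, a non-primary thread could run `isdir` before host 0 has created the directory, which is the kind of race the sync guards against.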
Thanks! Using GCS solves the problem perfectly :D I was using MaxText,
Hi @cpgaffney1, oddly I managed to get it working for my toy example, but after migrating back to the main library I'm again seeing a similar error.
Traceback (most recent call last):
  File "/home/simonsenan/dnadiffusion-jax/train.py", line 224, in main
    train(
  File "/home/simonsenan/dnadiffusion-jax/train.py", line 195, in train
    checkpoint_manager.save(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py", line 1078, in save
    self._checkpointer.save(save_directory, args=args)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 193, in save
    tmpdir = self.create_temporary_path(directory)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py", line 157, in create_temporary_path
    tmpdir.create()
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/atomicity.py", line 441, in create
    return _create_tmp_directory(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/path/atomicity.py", line 192, in _create_tmp_directory
    checkpoint_metadata_store.write(
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 238, in write
    self._store_impl.write(checkpoint_path, checkpoint_metadata)
  File "/home/simonsenan/.local/lib/python3.10/site-packages/orbax/checkpoint/metadata/checkpoint.py", line 126, in write
    raise ValueError(f'Checkpoint path does not exist: {checkpoint_path}')
ValueError: Checkpoint path does not exist: gs://dnadiffusion-bucket/checkpoints/105
I can still see it successfully writing a checkpoint to my Google Cloud Storage bucket, so I wonder whether this error is coming from one of the non-primary hosts? Let me know if I'm still overlooking something. As per your last suggestion, I did confirm that I am running the latest version of Orbax (0.5.22, installed directly from GitHub).
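The `atomicity.py` frames in the traceback suggest that the step directory is first created under a temporary name and only appears at its final path once it is finalized. Below is a generic sketch of that temporary-directory-then-rename pattern. It is not Orbax's actual implementation, and the `.tmp` suffix and `state.json` file name are made up. It shows one way a host could find nothing at `checkpoints/105` at one moment, even though the finished checkpoint shows up in storage later.

```python
import json
import os
import tempfile


def atomic_save(root: str, step: int, payload: dict) -> str:
    """Write `payload` for `step` so the final directory appears all at once.

    Generic tmp-then-rename pattern; the `.tmp` suffix is an assumption.
    """
    final_dir = os.path.join(root, str(step))
    tmp_dir = final_dir + ".tmp"
    os.makedirs(tmp_dir)
    with open(os.path.join(tmp_dir, "state.json"), "w") as f:
        json.dump(payload, f)
    # Until this rename, anyone looking for `final_dir` finds nothing.
    os.rename(tmp_dir, final_dir)
    return final_dir


with tempfile.TemporaryDirectory() as root:
    path = atomic_save(root, 105, {"loss": 0.1})
    print(sorted(os.listdir(root)))  # ['105']
```

On a local POSIX filesystem, the directory rename is atomic. Object stores like GCS have no atomic directory rename, so a library has to finalize checkpoints there by some other means.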
@niketkumar could you take a look at this? It's a bit weird because
Based on the above error stack, it is not likely that
The checkpoint_metadata_store write is called right after the
(Looking at the stack, it seems that CheckpointManager initializes a Checkpointer, not AsyncCheckpointer. For a Checkpointer, we only allow synchronous
@ssenan You can check if the current host is a primary or not with
@ssenan I didn't get what you meant by
I am not sure how the run managed to write a checkpoint despite the above error. Can you please explain your scenario and observations in a bit more detail?
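To find out which host raises the error, each process can label its own failures. The decorator below is hypothetical: it re-raises any exception with the process index attached and treats process 0 as the primary host. In a JAX program, the two numbers would come from `jax.process_index()` and `jax.process_count()`. Here they are plain arguments so the sketch runs anywhere, and `save_step` is a stand-in that always fails.

```python
import functools


def tag_errors_with_host(process_index: int, process_count: int):
    """Decorator: re-raise exceptions labeled with the host that raised them."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except Exception as exc:
                # Assumes process 0 is the primary host.
                role = "primary" if process_index == 0 else "non-primary"
                raise RuntimeError(
                    f"[host {process_index}/{process_count}, {role}] "
                    f"{type(exc).__name__}: {exc}"
                ) from exc
        return wrapper
    return decorator


@tag_errors_with_host(process_index=2, process_count=4)
def save_step(step: int) -> None:
    # Stand-in for checkpoint_manager.save(step, ...); always fails here.
    raise ValueError(f"Checkpoint path does not exist: gs://bucket/checkpoints/{step}")


try:
    save_step(105)
except RuntimeError as err:
    print(err)
# [host 2/4, non-primary] ValueError: Checkpoint path does not exist: gs://bucket/checkpoints/105
```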
@niketkumar Sorry for the delay; I will check on this and elaborate further in a couple of days.
Hi Everyone,
I've been trying to checkpoint training using Orbax in this project. When I test the code locally, I'm able to checkpoint successfully, but when training on a TPU v4-32 VM I encounter an issue related to directories not being found.
I've put together a simpler example using code from the Orbax docs, which outputs a similar error.
which appears to succeed on the process index, but fail on the rest of the hosts.
Here is the error I see:
Finally, here's the command I normally use to install all my dependencies on the VM:
Is this issue related to my sharding and the directory not being created on all of the hosts or something on the Orbax end? Any assistance is greatly appreciated!
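On the question of whether the directory gets created on every host: with a host-local path, each TPU VM has its own disk, so a directory made by one host is invisible to the others. Below is a toy simulation that assumes separate temporary directories can stand in for each host's disk. The function name is hypothetical.

```python
import os
import tempfile


def visible_on_each_host(host_roots: list, rel_path: str) -> list:
    """Host 0 creates `rel_path` under its own root; report who can see it."""
    os.makedirs(os.path.join(host_roots[0], rel_path), exist_ok=True)
    return [os.path.isdir(os.path.join(root, rel_path)) for root in host_roots]


with tempfile.TemporaryDirectory() as base:
    rel = os.path.join("checkpoints", "metadata")
    # Host-local paths: one separate disk per host.
    local_disks = [os.path.join(base, f"host{i}") for i in range(4)]
    print(visible_on_each_host(local_disks, rel))  # [True, False, False, False]
    # Shared storage: every host resolves the same root.
    shared = [os.path.join(base, "bucket")] * 4
    print(visible_on_each_host(shared, rel))  # [True, True, True, True]
```

This matches the thread's resolution of pointing the checkpoint directory at storage that all hosts share, such as a GCS bucket.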