Llama3.2 vision does not run with distributed state dict #2277

Open · acisseJZhong (Contributor) opened this issue Jan 17, 2025 · 1 comment
Running the Llama3.2 vision full finetune distributed recipe with distributed state dict enabled, I am hitting errors that come from `set_model_state_dict`.
Full error log:

```
[rank2]: Traceback (most recent call last):
[rank2]:   File "/home/jessicazhong/torchtune/recipes/full_finetune_distributed.py", line 911, in <module>
[rank2]:     sys.exit(recipe_main())
[rank2]:              ^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank2]:     sys.exit(recipe_main(conf))
[rank2]:              ^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/torchtune/recipes/full_finetune_distributed.py", line 905, in recipe_main
[rank2]:     recipe.setup(cfg=cfg)
[rank2]:   File "/home/jessicazhong/torchtune/recipes/full_finetune_distributed.py", line 253, in setup
[rank2]:     self._model = self._setup_model(
[rank2]:                   ^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/torchtune/recipes/full_finetune_distributed.py", line 546, in _setup_model
[rank2]:     training.load_from_full_model_state_dict(
[rank2]:   File "/home/jessicazhong/torchtune/torchtune/training/_distributed.py", line 216, in load_from_full_model_state_dict
[rank2]:     return set_model_state_dict(
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict.py", line 1218, in set_model_state_dict
[rank2]:     return _load_model_state_dict(model, model_state_dict, info)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank2]:     return func(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict.py", line 591, in _load_model_state_dict
[rank2]:     _state_dict_fn(model, "load_state_dict")(
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2561, in load_state_dict
[rank2]:     load(self, state_dict)
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2549, in load
[rank2]:     load(child, child_state_dict, child_prefix)  # noqa: F821
[rank2]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2549, in load
[rank2]:     load(child, child_state_dict, child_prefix)  # noqa: F821
[rank2]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2549, in load
[rank2]:     load(child, child_state_dict, child_prefix)  # noqa: F821
[rank2]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2554, in load
[rank2]:     out = hook(module, incompatible_keys)
[rank2]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 251, in <lambda>
[rank2]:     lambda *args, **kwargs: self.reset_sharded_param()
[rank2]:                             ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune-v0.5.0/lib/python3.11/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 839, in reset_sharded_param
[rank2]:     local_tensor = new_param._local_tensor
[rank2]:                    ^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: AttributeError: 'Parameter' object has no attribute '_local_tensor'
```
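For anyone triaging: the `AttributeError` fires inside FSDP2's `reset_sharded_param` load hook, which expects each loaded parameter to be a `DTensor` (hence the `_local_tensor` access) but instead finds a plain `Parameter`. Below is a minimal sketch of the same call path, not the torchtune recipe itself: the toy two-layer model, the shapes, and the `StateDictOptions` shown are illustrative assumptions (the exact options torchtune passes are not visible in the traceback), and on a simple model like this the load is expected to succeed, so this demonstrates the mechanics rather than reproducing the Llama3.2 vision failure.

```python
# Minimal sketch (assumption: torch 2.6-style fully_shard import) of loading
# a *full* (unsharded) state dict into a fully_shard-wrapped model via
# set_model_state_dict, the call made at torchtune's
# load_from_full_model_state_dict per the traceback above.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    set_model_state_dict,
)
from torch.distributed.fsdp import fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Toy stand-in for the vision model, sharded layer-by-layer like the recipe
# shards transformer blocks. After fully_shard, each parameter is a DTensor.
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).cuda()
for layer in model:
    fully_shard(layer)
fully_shard(model)

# A full (unsharded) state dict, e.g. as loaded from a checkpoint on CPU.
full_sd = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)).state_dict()

# set_model_state_dict distributes the full tensors onto the model's DTensor
# parameters; FSDP2's load hook then calls reset_sharded_param, which is where
# the issue's AttributeError is raised when a plain Parameter slips through.
set_model_state_dict(
    model,
    full_sd,
    options=StateDictOptions(full_state_dict=True),
)

dist.destroy_process_group()
```

Run with e.g. `torchrun --nproc_per_node=2 repro.py` on a multi-GPU host.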
acisseJZhong (Contributor, Author) commented:
I believe @mori360 is currently looking into it.
