Fix model saving bug post training with tensor parallel in Accelerate #36434
bursteratom wants to merge 4 commits into huggingface:main
Conversation
Force-pushed from 2b9d0b4 to df531d2
@kmehant Wondering what your thoughts are?
cc @ArthurZucker who's also doing a big TP refactor right now!
Force-pushed from d4e4907 to 4460137
@ArthurZucker @kmehant Seems like I'm failing a couple of tests, but I'm struggling to find the root cause. Wondering if you two could kindly take a look?
Same problem for me. T_T #36433
gathered_state_dict = {}
for key, value in state_dict.items():
    if hasattr(value, "_local_tensor"):
        gathered_state_dict[key] = value.to_local().cpu()
Note: we might want to do something closer to https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/distributed/tensor/_api.py#L572
Yeah, using full_tensor will be better, I think.
@bursteratom and I found that full_tensor would hang here, not 100% sure why, but we could investigate more if manually redistributing doesn't work.
@salmanmohammadi I wonder if it's related: pytorch/pytorch#115310
@muellerzr Should this be in transformers, or is the preference that this sort of unsharding lives in accelerate?
@winglian We have (will have) similar stuff in Accelerate for FSDP2, so if we want to support both TP + FSDP2 on the Accelerate side, it would possibly need to live in both places. Though I remember full_tensor() working for me there; I might take a look at this too.
value.to_local().cpu()
This would only return the rank-local shard of the tensor if the DTensor has a Shard placement, which is highly likely for TP. Wouldn't that mean the state dicts are now different on each rank? Isn't that a problem?
Yes, this is correct. .to_local() only returns the local part of the tensor if it was sharded (which it most likely was, as we're talking about TP), so each process ends up with only its own shard. A possible reason full_tensor() hangs is that, IIRC, it requires communication across all ranks, and here only the main process is running the save.
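For illustration, a minimal standalone sketch of both behaviors discussed above. It assumes a recent torch where the DTensor APIs are public under `torch.distributed.tensor` (older releases expose them under `torch.distributed._tensor`); run with e.g. `torchrun --nproc_per_node=2`:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Build a 1-D mesh over 2 ranks and shard a tensor along dim 0.
mesh = init_device_mesh("cpu", (2,))
full = torch.arange(8, dtype=torch.float32)
dt = distribute_tensor(full, mesh, placements=[Shard(0)])

# to_local() returns only the rank-local shard: rank 0 sees
# [0, 1, 2, 3] and rank 1 sees [4, 5, 6, 7], so a state dict
# built this way differs across ranks.
local = dt.to_local()

# full_tensor() is a collective: every rank must reach this call,
# otherwise the participating ranks hang waiting for the others,
# which matches the deadlock described in this thread.
gathered = dt.full_tensor()
assert torch.equal(gathered, full)
```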
cc @muellerzr @SunMarc for accelerate as well
Force-pushed from 45866d4 to 809275b
SunMarc left a comment:
Thanks! Please add a test also.
Force-pushed from 3b345fa to 24a6c33
Would using full_tensor be a better approach?
@machinelearningprodigy I initially used full_tensor, but it would hang; see the discussion above.
Force-pushed from dedaa12 to 9708c36
cc @kwen2501 if you have any idea
Thank you so much for taking a look at this @SunMarc!!!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from e569f9a to 9c31402
ArthurZucker left a comment:
IMO we should be careful here and make sure we can save without exploding memory.
gathered_state_dict = {}
for key, value in state_dict.items():
    if hasattr(value, "_local_tensor"):
        gathered_state_dict[key] = value.to_local().cpu()
Memory will explode, no? This should happen in the function that writes the files, to make sure you save bit by bit.
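As a rough sketch of one way to keep the peak bounded (a hypothetical helper, not this PR's code): materialize a single full tensor at a time and free it before moving on. The full dict still accumulates in host RAM, so a complete fix would hand each tensor to the file writer as it is produced, as suggested above.

```python
import torch

def gather_one_at_a_time(state_dict):
    """Hypothetical helper: bound peak device memory while de-sharding.

    Only one fully materialized tensor exists on the device at any
    time; it is moved to CPU and the device copy freed before the
    next key is processed.
    """
    out = {}
    for key, value in state_dict.items():
        if hasattr(value, "full_tensor"):  # DTensor check, as in the diff above
            full = value.full_tensor()     # collective gather, runs on every rank
            out[key] = full.cpu()
            del full                       # release device memory before the next gather
        else:
            out[key] = value.cpu()
    return out
```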
Force-pushed from 9c31402 to 2217e31
Force-pushed from 2217e31 to ee271a0
re @S1ro1 might be good to fix this properly somehow
Oh, this should actually be fixed by #37919 already. Should probably close then.
SG!
What does this PR do?
Currently, attempting to save a model after training with tensor parallel in Accelerate raises
RuntimeError: Attempted to access the data pointer on an invalid python storage. This is because the state dict is not properly gathered from the sharded tensors beforehand. This PR fixes the error, allowing the model to be saved successfully. Big thank you to @salmanmohammadi for the discussion!
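In rough outline, the shape of the fix is to convert DTensor entries to plain tensors before anything touches their data pointers, then write from the main process only. A hedged sketch, not the PR's actual code, using full_tensor() (which, per the discussion above, requires every rank to participate or it can hang); the function name and save path are illustrative:

```python
import torch
from accelerate import Accelerator
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older torch

def save_after_tp_training(model: torch.nn.Module, path: str) -> None:
    """Hypothetical save path: de-shard the state dict, then write on rank 0 only."""
    accelerator = Accelerator()
    # Every rank must run this loop, since full_tensor() is a collective.
    plain = {
        key: (value.full_tensor() if isinstance(value, DTensor) else value).cpu()
        for key, value in model.state_dict().items()
    }
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        torch.save(plain, path)
```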
Fixes #34194 (comment) and #36436
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.