(Part 2) feat: allow for tp_size attr for tplizing the model #37054
Merged
Commits (5):

- 1059fff feat: custom tp_size, new transformers tp interface (kmehant)
- bb2950d fix: review cmt - error when tp_plan not set for tp_size (kmehant)
- a65130c fix: nit in docs (kmehant)
- d77505c Merge branch 'main' into tp-size (SunMarc)
- c073736 Merge branch 'main' into tp-size (S1ro1)
The change adds a `_tp_size` class attribute, a `tp_size` argument to `from_pretrained`, and plumbing to record the degree on the loaded model:

```diff
@@ -1784,6 +1784,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
     # for example.
     _tp_plan = None

+    # tensor parallel degree to which the model is sharded.
+    _tp_size = None
+
     # A pipeline parallel plan specifying the layers which may not be present
     # on all ranks when PP is enabled. For top-level models, this attribute is
     # currently defined in respective model code. For base models, this
```

```diff
@@ -3845,6 +3848,8 @@ def from_pretrained(
                 A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
                 `tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
                 `torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
+            tp_size (`int`, *optional*):
+                The tensor parallel degree to shard the model to. If not provided, defaults to the world size.
             offload_folder (`str` or `os.PathLike`, *optional*):
                 If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
             offload_state_dict (`bool`, *optional*):
```

```diff
@@ -3941,6 +3946,7 @@ def from_pretrained(
         generation_config = kwargs.pop("generation_config", None)
         gguf_file = kwargs.pop("gguf_file", None)
         tp_plan = kwargs.pop("tp_plan", None)
+        tp_size = kwargs.pop("tp_size", None)
```
Inline review on `tp_size = kwargs.pop("tp_size", None)`:

Member (SunMarc): let's raise an error if [...]

Author (kmehant): @SunMarc Addressed this comment, thank you.
Context continuing the `@@ -3941` hunk:

```diff
         key_mapping = kwargs.pop("key_mapping", None)
         # Not used anymore -- remove them from the kwargs
         _ = kwargs.pop("resume_download", None)
```
```diff
@@ -3953,7 +3959,8 @@ def from_pretrained(
             raise ValueError(
                 "`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies."
             )
+        if tp_size is not None and tp_plan is None:
+            raise ValueError("tp_plan has to be set when tp_size is passed.")
         if tp_plan is not None and tp_plan != "auto":
             # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
             raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
```
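The new guard composes with the pre-existing `tp_plan` check. A standalone sketch of the combined validation (`validate_tp_args` is a hypothetical helper name for illustration; the real checks live inline in `from_pretrained`):

```python
def validate_tp_args(tp_plan=None, tp_size=None):
    """Sketch of the tensor-parallel argument checks (hypothetical helper)."""
    # tp_size is meaningless without a plan describing how to shard the model.
    if tp_size is not None and tp_plan is None:
        raise ValueError("tp_plan has to be set when tp_size is passed.")
    # Only the predefined "auto" plan is supported for now.
    if tp_plan is not None and tp_plan != "auto":
        raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")


# Accepted combinations pass silently:
validate_tp_args(tp_plan="auto", tp_size=4)
validate_tp_args()  # no tensor parallelism at all is also fine
```

Passing `tp_size` alone, or any plan other than `"auto"`, raises a `ValueError`.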
```diff
@@ -4007,9 +4014,10 @@ def from_pretrained(
                 sys.stderr = open(os.devnull, "w")
             # This is the easiest way to dispatch to the current process device
             device_map = tp_device
-            # Assuming sharding the model onto the world
-            world_size = torch.distributed.get_world_size()
-            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))
+            # Assuming sharding the model onto the world when tp_size not provided
+            tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
+            device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))

         if use_auth_token is not None:
             warnings.warn(
```
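The defaulting logic is small enough to sketch in isolation (`resolve_tp_degree` is a hypothetical helper; in the real code the fallback value comes from `torch.distributed.get_world_size()`):

```python
def resolve_tp_degree(tp_size, world_size):
    # Fall back to sharding across every rank when no explicit degree is given.
    return tp_size if tp_size is not None else world_size
```

For example, on an 8-rank run, `resolve_tp_degree(None, 8)` shards over all 8 ranks, while `resolve_tp_degree(4, 8)` keeps the requested degree of 4, leaving the remaining ranks free for data parallelism, which is the TP + FSDP/DDP motivation raised in the review discussion.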
```diff
@@ -4373,6 +4381,9 @@ def from_pretrained(
                 weights_only=weights_only,
             )

+        # record the tp degree the model was sharded to
+        model._tp_size = tp_size
+
         # make sure token embedding weights are still tied if needed
         model.tie_weights()
```
```diff
@@ -4456,7 +4467,6 @@ def from_pretrained(
         elif from_flax:
             loading_info = None
             return model, loading_info
-
         return model

     @staticmethod
```
```diff
@@ -5100,6 +5110,14 @@ def supports_tp_plan(self):
             return True
         return False

+    @property
+    def tp_size(self):
+        """
+        Returns the model's tensor parallelism degree.
+        """
+        # if None, the model didn't undergo tensor parallel sharding
+        return self._tp_size
+
     @property
     def supports_pp_plan(self):
         if self._pp_plan is not None:
```
Review discussion (on adding the `tp_size` option):

- "Not needed for this specific PR. I don't know if we want to add this option yet. cc @ArthurZucker"
- Reply: "We can have it in a separate PR as well; however, it's needed to support TP + FSDP/DDP. Sure, @ArthurZucker, let me know your thoughts."
- Reply: "@SunMarc Would appreciate it here; I've been looking at enabling TP + FSDP and this is exactly what I used myself. cc @ArthurZucker"