Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions docs/source/ar/trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -674,29 +674,7 @@ use_cpu: false
```

</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

```

</hfoption>
</hfoptions>
يُعد أمر [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) هو الطريقة المُوصى بها لتشغيل نص البرمجى للتدريب على نظام موزع باستخدام Accelerate و [`Trainer`] مع المعلمات المحددة في `config_file.yaml`. يتم حفظ هذا الملف في مجلد ذاكرة التخزين المؤقت لـ Accelerate ويتم تحميله تلقائيًا عند تشغيل `accelerate_launch`.

Expand Down
22 changes: 1 addition & 21 deletions docs/source/en/trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -341,29 +341,9 @@ use_cpu: false
```

</hfoption>
<hfoption id="Tensor parallelism with PyTorch 2">

```yaml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```

</hfoptions>


Run [accelerate_launch](https://hf.co/docs/accelerate/package_reference/cli#accelerate-launch) to start training with the configurations set in `config_file.yaml`. This file is saved to the Accelerate cache folder and automatically loaded when you run `accelerate_launch`.

The example below launches the [run_glue.py](../../../examples/pytorch/text-classification/run_glue) script with the FSDP configuration shown earlier. Parameters from the `config_file.yaml` file can also be directly set in the command line.
Expand Down
23 changes: 0 additions & 23 deletions docs/source/es/trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,29 +363,6 @@ use_cpu: false

</hfoption>

<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

```

</hfoption>
</hfoptions>

El comando [`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) es la forma recomendada de lanzar tu script de entrenamiento en un sistema distribuido con Accelerate y [`Trainer`] con los parámetros especificados en `config_file.yaml`. Este archivo se guarda en la carpeta de caché de Accelerate y se carga automáticamente cuando ejecutas `accelerate_launch`.
Expand Down
22 changes: 0 additions & 22 deletions docs/source/ko/trainer.md
Original file line number Diff line number Diff line change
Expand Up @@ -549,29 +549,7 @@ use_cpu: false
```

</hfoption>
<hfoption id="Tensor Parallelism with PyTorch 2">

```yml
compute_environment: LOCAL_MACHINE
tp_config:
tp_size: 4
distributed_type: TP
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

```

</hfoption>
</hfoptions>

[`accelerate_launch`](https://huggingface.co/docs/accelerate/package_reference/cli#accelerate-launch) 명령은 Accelerate와 [`Trainer`]를 사용하여 분산 시스템에서 훈련 스크립트를 실행하는 권장 방법이며, `config_file.yaml`에 지정된 매개변수를 사용합니다. 이 파일은 Accelerate 캐시 폴더에 저장되며 `accelerate_launch`를 실행할 때 자동으로 로드됩니다.
Expand Down
28 changes: 23 additions & 5 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1784,6 +1784,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMi
# for example.
_tp_plan = None

# tensor parallel degree to which model is sharded to.
_tp_size = None

# A pipeline parallel plan specifying the layers which may not be present
# on all ranks when PP is enabled. For top-level models, this attribute is
# currently defined in respective model code. For base models, this
Expand Down Expand Up @@ -3845,6 +3848,8 @@ def from_pretrained(
A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
`tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
`torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
tp_size (`str`, *optional*):
A torch tensor parallel degree. If not provided would default to world size.
Comment on lines +3851 to +3852
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed for this specific PR. I don't know if we want to add this option yet cc @ArthurZucker

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can have it in a separate PR as well, however, its needed to support TP + FSDP/DDP.

I don't know if we want to add this option yet

Sure, @ArthurZucker Let me know your thoughts.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SunMarc Would appreciate it here, been looking at enabling TP + FSDP and this is exactly what I used myself.
cc @ArthurZucker

offload_folder (`str` or `os.PathLike`, *optional*):
If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
offload_state_dict (`bool`, *optional*):
Expand Down Expand Up @@ -3941,6 +3946,7 @@ def from_pretrained(
generation_config = kwargs.pop("generation_config", None)
gguf_file = kwargs.pop("gguf_file", None)
tp_plan = kwargs.pop("tp_plan", None)
tp_size = kwargs.pop("tp_size", None)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's raise an error if tp_size was set but not tp_plan

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SunMarc Addressed this comment, thank you.

key_mapping = kwargs.pop("key_mapping", None)
# Not used anymore -- remove them from the kwargs
_ = kwargs.pop("resume_download", None)
Expand All @@ -3953,7 +3959,8 @@ def from_pretrained(
raise ValueError(
"`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies."
)

if tp_size is not None and tp_plan is None:
raise ValueError("tp_plan has to be set when tp_size is passed.")
if tp_plan is not None and tp_plan != "auto":
# TODO: we can relax this check when we support taking tp_plan from a json file, for example.
raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
Expand Down Expand Up @@ -4007,9 +4014,10 @@ def from_pretrained(
sys.stderr = open(os.devnull, "w")
# This is the easiest way to dispatch to the current process device
device_map = tp_device
# Assuming sharding the model onto the world
world_size = torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (world_size,))

# Assuming sharding the model onto the world when tp_size not provided
tp_size = tp_size if tp_size is not None else torch.distributed.get_world_size()
device_mesh = torch.distributed.init_device_mesh(tp_device.type, (tp_size,))

if use_auth_token is not None:
warnings.warn(
Expand Down Expand Up @@ -4373,6 +4381,9 @@ def from_pretrained(
weights_only=weights_only,
)

# record tp degree the model sharded to
model._tp_size = tp_size

# make sure token embedding weights are still tied if needed
model.tie_weights()

Expand Down Expand Up @@ -4456,7 +4467,6 @@ def from_pretrained(
elif from_flax:
loading_info = None
return model, loading_info

return model

@staticmethod
Expand Down Expand Up @@ -5100,6 +5110,14 @@ def supports_tp_plan(self):
return True
return False

@property
def tp_size(self):
"""
Returns the model's tensor parallelism degree.
"""
# if None, the model didn't undergo tensor parallel sharding
return self._tp_size

@property
def supports_pp_plan(self):
if self._pp_plan is not None:
Expand Down
6 changes: 3 additions & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def __init__(
self.hp_name = None
self.deepspeed = None
self.is_in_train = False

self.model = model
self.create_accelerator_and_postprocess()

# memory metrics - must set up as early as possible
Expand Down Expand Up @@ -5137,10 +5137,10 @@ def create_accelerator_and_postprocess(self):
args.update(accelerator_config)
# tp is initialized at Accelerator init phase so
# args should be prepared here
if self.args.tp_size > 1:
if hasattr(self.model, "tp_size") and self.model.tp_size is not None and self.model.tp_size > 1:
self.is_tp_enabled = True
if version.parse(accelerate_version) > version.parse("1.3.0"):
args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.args.tp_size)
args["torch_tp_plugin"] = TorchTensorParallelPlugin(tp_size=self.model.tp_size)
else:
raise ValueError("Requires accelerate>1.3.0 to use Tensor Parallelism.")

Expand Down
24 changes: 0 additions & 24 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,10 +554,6 @@ class TrainingArguments:
Will use gradient checkpointing over each nested XLA FSDP wrapped layer. This setting can only be
used when the xla flag is set to true, and an auto wrapping policy is specified through
fsdp_min_num_params or fsdp_transformer_layer_cls_to_wrap.
tp_size (`int`, *optional*):
Use tp_size to enable PyTorch tensor parallelism. Tensor parallelism support is only available to models having `base_tp_plan`
in their respective config classes.
Set a value greater than 1 to activate TP. The same is used to prepare device mesh internally. Requires accelerate>1.3.0.
deepspeed (`str` or `dict`, *optional*):
Use [Deepspeed](https://github.com/deepspeedai/DeepSpeed). This is an experimental feature and its API may
evolve in the future. The value is either the location of DeepSpeed json config file (e.g.,
Expand Down Expand Up @@ -1244,18 +1240,6 @@ class TrainingArguments:
)
},
)
tp_size: Optional[int] = field(
default=0,
metadata={
"help": (
"Use tp_size to enable pytorch tensor parallelism."
"Tensor parallelism support is only available to models having `base_tp_plan` in their respective config classes."
"Set a value greater than 1 to activate TP."
"The same is used to prepare device mesh internally."
"Requires accelerate>1.3.0."
)
},
)
fsdp_transformer_layer_cls_to_wrap: Optional[str] = field(
default=None,
metadata={
Expand Down Expand Up @@ -1941,14 +1925,6 @@ def __post_init__(self):
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")

if self.tp_size > 1:
if not is_accelerate_available("1.3.1"):
raise NotImplementedError(
"TP using PyTorch requires Accelerate version `accelerate` >= 1.3.1. "
"This is not supported and we recommend you to update your version."
)
os.environ["ACCELERATE_USE_TP"] = "true"
os.environ["TP_SIZE"] = str(self.tp_size)
# accelerate integration for FSDP
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
os.environ["ACCELERATE_USE_FSDP"] = "true"
Expand Down