Skip to content

Commit

Permalink
ADLR/megatron-lm!1754 - MoE documentation refinement.
Browse files Browse the repository at this point in the history
  • Loading branch information
Victarry authored and ko3n1g committed Aug 2, 2024
1 parent 6dbe4cf commit 0b981f9
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 50 deletions.
Binary file added docs/source/images/moe/token_drop.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
13 changes: 10 additions & 3 deletions examples/mixtral/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@ snapshot_download(repo_id="mistralai/Mixtral-8x7B-v0.1", ignore_patterns=["*.pt"
The HF checkpoints can be converted to Megatron format by using the provided checkpoint converter for HF format.
The target model parallel size(e.g. TP,PP,EP) should be specified.

Currently the converter doesn't support distributed checkpointing yet, so each different parallel config requires a specific checkpoint.
- For training, the recommended model parallel config is TP1EP8PP4
- For inference, the recommended model parallel config is TP1EP1PP2

```
TOKENIZER_MODEL=/workspace/checkpoints/mixtral-hf/tokenizer.model
MEGATRON_PATH="/workspace/megatron-lm"
export PYTHONPATH=$MEGATRON_PATH:$PYTHONPATH
export CUDA_DEVICE_MAX_CONNECTIONS=1
TARGET_TP_SIZE=1
TARGET_PP_SIZE=4
TARGET_EP_SIZE=8
TARGET_TP_SIZE=""
TARGET_EP_SIZE=""
TARGET_PP_SIZE=""
HF_FORMAT_DIR=/workspace/checkpoints/mixtral-hf
MEGATRON_FORMAT_DIR=/workspace/checkpoints/mixtral-mcore-TP${TARGET_TP_SIZE}PP${TARGET_PP_SIZE}EP${TARGET_EP_SIZE}
Expand Down Expand Up @@ -88,6 +92,7 @@ torchrun $DISTRIBUTED_ARGS tools/run_text_generation_server.py \
--num-experts 8 \
--moe-router-topk 2 \
--moe-token-dispatcher-type alltoall \
--moe-grouped-gemm \
--mock-data \
--rotary-base 1000000
```
Expand Down Expand Up @@ -119,6 +124,8 @@ docker run \
bash examples/mixtral/train_mixtral_8x7b_distributed.sh $CHECKPOINT_PATH $TOKENIZER_MODEL $DATA_PATH
```

The above functionality also applys to Mixtral 8x22B actually, you should set the model config (including hidden_size/head_num/num_layers/ffn_hidden_size) properly according to the original [config](https://huggingface.co/mistralai/Mixtral-8x22B-v0.1/blob/main/config.json).

## Acknowledgements
Contributors outside NVIDIA for the huggingface converter and example of Mixtral models in Megatron-Core:
- Peng Li <[email protected]>
Expand Down
185 changes: 141 additions & 44 deletions megatron/core/transformer/moe/README.md
Original file line number Diff line number Diff line change
@@ -1,48 +1,43 @@
# Megatron Core MoE Key Features

### Parallelism
Megatron-Core offers rich parallelism mappings, combining Expert Parallelism with tensor, data, sequence, and pipeline parallelism. This boosts Mixtral 8X7B bf16 training to achieve **438 TFLOPS** as of MCore v0.8.


- **Expert Parallel**
### Parallelism
- **Expert Parallelism**
- A specific method of parallelism for MoE models, where experts are partitioned onto different workers and each worker processes a different batch of training samples, each worker process one or more experts for each MoE layer.
- **3D Parallel**: Data Parallel , Tensor Parallel, Pipeline Parallel, Sequence Parallel
- Note: When using MoE with expert parallelism and tensor parallelism, sequence parallelism must be used.
- **Richer parallel mappings**: EP can be combined with DP/TP/PP/SP for handling larger MoE variants.
- **3D Parallelism**: Data Parallelism, Tensor Parallelism, Pipeline Parallelism
- Note: When using MoE with expert parallelism and tensor parallelism, sequence parallelism must be enabled.
- **Context Parallelism**:
- Split the sequence dimension to support long context training.
- **Richer parallel mappings**: EP can be combined with DP/TP/PP/CP for handling larger MoE variants.
- **Full distributed optimizer support.**

### Router and Load Balancing

- Router type:
- Top-K MLP router
- Load Balancing algorithms:
- Sinkhorn (S-BASE)
- Aux loss / Load balancing loss

### Performance Optimizations

- GroupedGEMM when num local experts > 1
- Supported dtype: bf16
- Performance improvements for larger MoE models
- Enable `--tp-comm-overlap` for MoE

### Token Dispatch Mechanism

- Dropless / No token drop.
- Token drop and padding.
- Dropless / No token drop
- Token drop, with or without padding to capacity

### Ease of use
- Checkpoint converter (coming soon)
- Checkpoint converter for Mixtral models, see the [example](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/mixtral) for details.
- Distributed checkpoining
- Per-layer logging

## Upcoming features

- Enhanced cutlass GroupedGEMM kernels
- Reduced host-device syncs.
- More supported dtype: fp32/bf16/fp16
- Kernel heuristics tuned for H100/A100/A10/L40S
- BWD cutlass GroupedGEMM kernels supported
- Token permutation / unpermutation fusion
- Fused Sinkhorn Kernel
- Context Parallel with MoE
- FP8 training support

# User Guide
Expand All @@ -51,51 +46,80 @@

| Item | Description |
| --- | --- |
| num-experts | Number of Experts in MoE (None means no MoE) |
| expert-model-parallel-size | Degree of expert model parallelism. Default is 1. |
| moe-grouped-gemm | When there are multiple experts per rank, compress multiple local (potentially small) gemms in a single kernel launch to improve the utilization and performance by leveraging the Grouped GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm). |
| moe-router-load-balancing-type | Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss". |
| moe-router-topk | Number of experts to route to for each token. The default is 2. |
| moe-aux-loss-coeff | Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended. Default is 0.0. |
| moe-z-loss-coeff | Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended. Default is None. |
| moe-input-jitter-eps | Add noise to the input tensor by applying jitter with a specified epsilon value. Default is None. |
| moe-token-dispatcher-type | Determines the token dispatcher type. Choices are "allgather" and "alltoall". Default is "allgather". |
| moe-per-layer-logging | Enable per-layer logging for MoE, currently supports auxiliary loss and z loss. |
| moe-expert-capacity-factor | The capacity factor for each expert, None means no token will be dropped. Default is None. |
| moe-pad-expert-input-to-capacity | Pads the input for each expert to match the expert capacity length, effective only after the --moe-expert-capacity-factor is set. |

### Usage

To train a top-2 MoE model with an auxiliary loss, include the following arguments:

```python
| --num-experts | Number of Experts in MoE (None means no MoE) |
| --expert-model-parallel-size | Degree of expert model parallelism. Default is 1. |
| --moe-grouped-gemm | When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine. |
| --moe-router-load-balancing-type | Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss". |
| --moe-router-topk | Number of experts to route to for each token. The default is 2. |
| --moe-aux-loss-coeff | Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended. Default is 0.0. |
| --moe-z-loss-coeff | Scaling coefficient for the z-loss: a starting value of 1e-3 is recommended. Default is None. |
| --moe-input-jitter-eps | Add noise to the input tensor by applying jitter with a specified epsilon value. Default is None. |
| --moe-token-dispatcher-type | Determines the token dispatcher type. Choices are "allgather" and "alltoall". Default is "allgather". |
| --moe-per-layer-logging | Enable per-layer logging for MoE, currently supports auxiliary loss and z loss. |
| --moe-expert-capacity-factor | The capacity factor for each expert, None means no token will be dropped. Default is None. |
| --moe-pad-expert-input-to-capacity | Pads the input for each expert to match the expert capacity length, effective only after the --moe-expert-capacity-factor is set. |
| --moe-token-drop-policy | The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped. |
| --moe-layer-recompute | Enable activation checkpointing for moe_layer, should be used when memory is not sufficient. |
| --moe-extended-tp | (Experimental) Alternative parallelization strategy for expert parallelism. Instead of distributing experts across *expert_model_parallel_size*, each expert is sharded along extendended tensor parallel domain (tensor_model_paralle_size * expert_model_parallel_size). It avoids the load balancing problem with MOE training. Only avaiable with `--moe-token-dispatcher-type allgather`. |


## Usage

### Quick Start
To train a top-2 MoE model with 8 experts and auxiliary loss, include the following arguments:

```bash
--num-experts 8
--expert-model-parallel-size 8
--moe-grouped-gemm
--moe-router-load-balancing-type aux_loss # options: aux_loss, sinkhorn, none. Default is aux_loss.
--moe-router-topk 2
--moe-aux-loss-coeff 1e-2
--use-distributed-optimizer
```

To avoid out-of-memory in dropless MoE training, we can set a large capacity factor, add:

```python
--moe-expert-capacity-factor 4.0
--moe-token-dispatcher-type alltoall
```

To enable the token drop mechanism, such as GShard and SwitchTransformer, include the following arguments:

```python
```bash
--moe-expert-capacity-factor 1.0
--moe-pad-expert-input-to-capacity # Optional
```

The following figure illustrates differenting dropping strategies in MCore:
![Token Droppling Strategies](../../../../docs/source/images/moe/token_drop.png)

1. The default dropless strategy will not drop or pad any token.
2. By setting `--moe-expert-capacity-factor`, the tokens exceed the capcacity of expert will be dropped based on their selected probabilities.
The dropping is performed before the token exchange operation between EP ranks when EP > 1.
The formula of capacity is `capacity = num_tokens_per_rank * topk * capacity_factor / num_experts`.
3. By setting `--moe-pad-expert-input-to-capacity`, the experts with tokens less than capacity will be padded to the capacity.

### Fine-tuning Mixtral Models
Megatron-Core has full support for Mixtral MoE models, and we provide the checkpoint converter for Mixtral models from huggingface format to MCore format.
See more details in the [mixtral example](../../../../examples/mixtral/README.md).

### Distributed Checkpointing
MCore v0.7 introduced fully parallel and asynchronous saving capabilities to distributed checkpointing,
which addresses the issues of low efficiency in the traditional checkpoint saving methods.
It also solved the problem of incompatibility between checkpoints of differnt parallel mappings in the traditional format.
With the new distributed checkpointing solution, MCore can achieve flexible parallelism configurations by saving and loading the unified format checkpoints.
Compared to native PyTorch solution, MCore achieves up to 50x reduction in checkpointing overhead.

With MCore v0.8, MoE supports Distributed Checkpointing, which means users can save and load with any combination of parallelism and it is currently available, including expert parallel.
1. Loading weight and distributed optimizer states with TPxPPxEP resharding is supported in version 0.8.
2. GroupedMLP is also supported, including the ability to switch between GroupedMLP/SequentialMLP when loading and saving.
- When switching between GroupedMLP and SequentialMLP, loading distributed optimizer states is currently unsupported; this feature will be added in version 0.9.
Besides these limitations, Distributed Checkpointing is fully functional.

Usage
- `--use-dist-ckpt` The main argument, it will attempt to save and load using distributed checkpointing.
- `--auto-detect-ckpt-format` With this, it can load both distributed checkpointing and legacy checkpointing.

## Dropless MoE training script example:
<details>
<summary>Click here. </summary>

```bash
#!/bin/bash

Expand Down Expand Up @@ -213,3 +237,76 @@ torchrun ${DISTRIBUTED_ARGS[@]} pretrain_gpt.py \
${LOGGING_ARGS[@]}
```
</details>

# Performance Best Practice

### Tuning Guide of Paralell Mappings

To find a good parallel mapping that help you achieve a high throughput of a new model, there are some general rule that could help. Here is an overview of properties in different aspects for each parallel strategy.

| Parallel Strategy | Peak Activation Memory | Weight Memory | Optimizer states | Communication (Per-Layer) |
|:-----------------:|:-------------------------------:|:--------------:|:---------------------------------:|:-------------------------:|
| TP | 1/N (with SP on) | 1/N | 1/N | High |
| EP | 1 | 1/N in MoELayer| 1/N | Medium |
| PP | 1 (>1 with virtual pipeline) | 1/N | 1/N | Medium |
| CP | 1/N | 1 | 1/N (with distributed optimizer) | Medium |
| DP | 1 | 1 | 1/N (with distributed optimizer) | Low |

For a specific model, the best parallel mapping varies based on the model architecture, trained sequence length and the hardware platform.
Here we provide some general rules to get better performance:
1. Keep the model parallism size as small as possible.
- For the large language models, model parallism is often required to prevent OOM, but it will bring communication overhead and hurt performance.
- With distributed optimizer, master weights and optimizer states will be sharded across all DP ranks with slight communication overhead.
So try to reduce the model parallism size and increase data parallism size when there are lots of free GPU memory during training.
2. Ensure the EPxTP communication winthin the NVLink domain.
- Communications of EP and TP should remain within the NVLink domain as much as possible, as both are communication-intensive.
- If the model is too large and requires scaling across multiple nodes, consider PP before TP and EP. See item 3 for details.
3. Use Pipeline Parallelism to scale the model further.
- Enable Virtual Pipeline Parallelism(VPP) to reduce pp bubbles when PP_size >= 2 by setting `num_layers_per_virtual_pipeline_stage`.
- VPP_size tuning: the legal values of vpp_size are all common divisors of num_layers/pp_size, E.g., num_layers=24, pp_size=4, then we can pick vpp_size from {1, 2, 3, 6}. The larger the vpp_size, the lower the pipeline bubbles, while the larger number of P2P communications between each PP stages. Empirically a value in the middle often gives the best trade-off. `VPP_size=num_layers / PP_size / num_layers_per_virtual_pipeline_stage`
4. Prefer EP over TP for the expert layer when possible:
- TP saves more memory than EP, but EP can achieve better GEMM efficiency and less communication overhead than TP.
- If EP size increased to the number of expert, the local token permutation/un-permutation for experts computation are omitted.
- Simplify the computation graph of moe layers, more convenient for performing potential comm-computation overlapping.
- In practice, EP8TP1 is better than EP4TP2 for 8x7B.
5. Enable Context Parallelism for long context training.
- The efficiency of CP largely depends on whether its communication can be overlapped with computation.
- Emperically, use CP when sequence length >= 8K.


### End-to-End Training Practice
**Use the latest NVIDIA PyTorch or NeMo Docker Image**
- [NGC PyTorch Image](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch)
- [NGC NeMo Image](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo)

**OOM Caused by Token Distribution Imbalance when Training From Scratch**
MoE suffers from a severe load imbalance issue when the router is under-trained, leading to the model easily running out of memory (OOM), which typically occurs in the first 100~300 steps when training from scratch.
Therefore, there are two recommended ways during the first 200 steps to avoid the OOM problem, which can be removed after the token distribution is more stable:
1. Use Extended-TP(`-moe-extended-tp`) to replace EP with TP in MoELayer, this can prevent the load imbalancing between EP ranks. Since current ETP implementation has some memeory overhead, you can further enable activation recomputation only for MoE Layer by adding `--moe-layer-recompute`.
2. Setting capacity factor to a relatively small number like 1.0 by adding `--moe-token-capacity-factor 1.0`.

**Enable Communication Overlap**
- Enable `--overlap-param-gather` and `--overlap-grad-reduce` with distributed optimizer.
- Enable `--tp-comm-overlap` when TP>1.
- Enable p2p comm overlap when PP > 1 by setting `num_layers_per_virtual_pipeline_stage`.

**Enable GroupedGEMM when num_local_experts>1 with `--moe-grouped-gemm`**
- GroupedGEMM has higher efficiency than vanilla sequential GEMMs for each expert.
- Recommend to use the TE version of Grouped GEMM (by upgrading to MCore v0.8 and TE v1.9), which support Gradient Accumulation Fusion and FP8 Training.

### Reference Best Parallel Mapping

Here are the reference parallel mappings of MCore v0.8 for Mixtral 8x7B and 8x22B models:
| Model | Vocab Size| Dispatcher | Precision | #GPUs | SEQ LEN | TP | EP | PP | VP | MBS | GBS |
|:-----------------------:|:---------:|:----------:|:---------:|:-----:|:-------:|:--:|:--:|:--:|:--:|:---:|:---:|
| Mixtral 8x7B(Dropless) | 32K | All-to-All | BF16 | 64 | 4096 | 1 | 8 | 4 | 8 | 1 | 256 |
| Mixtral 8x22B(Dropless) | 32K | All-to-All | BF16 | 128 | 4096 | 4 | 2 | 8 | 7 | 1 | 256 |

Detailed Benchmark Information:
Server:
- 8xH100 80GB HBM3
- NVLink 4th Generation
- InfiniBand 8x400 Gbit/s

Docker Image:
- PyTorch 24.04 with TransformerEngine v1.9
2 changes: 1 addition & 1 deletion megatron/core/transformer/transformer_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class TransformerConfig(ModelParallelConfig):
"""Number of experts to route to for each token."""

moe_router_pre_softmax: bool = False
"""Enable pre-softmax routing for MoE, which means the top-k selection is before the softmax. By default, top-k is done after the softmax."""
"""Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k."""

moe_grouped_gemm: bool = False
"""When there are multiple experts per rank, compress multiple local (potentially small) gemms
Expand Down
4 changes: 2 additions & 2 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1764,9 +1764,9 @@ def _add_moe_args(parser):
group.add_argument('--moe-router-topk', type=int, default=2,
help='Number of experts to route to for each token. The default is 2.')
group.add_argument('--moe-router-pre-softmax', action='store_true',
help='Enable pre-softmax routing for MoE, which means the top-k selection is before the softmax. By default, top-k is done after the softmax.')
help='Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.')
group.add_argument('--moe-grouped-gemm', action='store_true',
help='When there are multiple experts per rank, compress multiple local (potentially small) gemms in a single kernel launch to improve the utilization and performance by leveraging the Grouped GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm).')
help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.')
group.add_argument('--moe-aux-loss-coeff', type=float, default=0.0,
help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.')
group.add_argument('--moe-z-loss-coeff', type=float, default=None,
Expand Down

0 comments on commit 0b981f9

Please sign in to comment.