Update MoE examples (#192)
* Update MoE examples

* Add top-level link

* Fix deepseek_moe_w8a8_int8.py

* Add deepseek_moe_w8a8_fp8.py

* Quality

* Quality
mgoin authored Sep 23, 2024
1 parent 23c499a commit 2e0035f
Showing 5 changed files with 98 additions and 14 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -38,6 +38,7 @@ Applying quantization with `llmcompressor`:
* [Activation quantization to `int8`](examples/quantization_w8a8_int8)
* [Activation quantization to `fp8`](examples/quantization_w8a8_fp8)
* [Weight only quantization to `int4`](examples/quantization_w4a16)
+* [Quantizing MoE LLMs](examples/quantizing_moe)

### User Guides
Deep dives into advanced usage of `llmcompressor`:
20 changes: 11 additions & 9 deletions examples/quantizing_moe/README.md
@@ -1,6 +1,6 @@
-# Quantizing TinyMixtral-4x248M-MoE Model with FP8
+# Quantizing Mixtral-8x7B-Instruct-v0.1 Model with FP8

-This directory contains an example script for quantizing the `TinyMixtral-4x248M-MoE` model using the FP8 quantization scheme.
+This directory contains an example script for quantizing the `Mixtral-8x7B-Instruct-v0.1` model using the static per-tensor FP8 quantization scheme.

## Installation

@@ -17,7 +17,7 @@ pip install -e .
The provided example script demonstrates an end-to-end process for applying the quantization algorithm:

```bash
-python3 mixtral_moe_fp8.py
+python3 mixtral_moe_w8a8_fp8.py
```

## Creating a Quantized MoE Model
@@ -27,7 +27,7 @@ This example leverages `llm-compressor` and `compressed-tensors` to create an FP
You can follow the detailed steps below or simply run the example script with:

```bash
-python examples/quantizing_moe/mixtral_moe_fp8.py
+python mixtral_moe_w8a8_fp8.py
```

### Step 1: Select a Model, Dataset, and Recipe
@@ -36,12 +36,12 @@ In this step, you'll choose a baseline model for quantization, a dataset for cal

- **Models**: Can be referenced from a local directory or retrieved from the Hugging Face Hub.
- **Datasets**: Can also be from a local directory or the Hugging Face Hub.
-- **Recipes**: These are YAML files or Python modifier objects that describe how a model should be optimized during or after training. In this example, we use a `GPTQModifier` object with the scheme set to `FP8`.
+- **Recipes**: These are YAML files or Python modifier objects that describe how a model should be optimized during or after training. In this example, we use a `QuantizationModifier` object with the scheme set to `FP8`.

```python
-from llmcompressor.modifiers.quantization.gptq import GPTQModifier
+from llmcompressor.modifiers.quantization import QuantizationModifier

-recipe = GPTQModifier(scheme="FP8", targets="Linear", ignore=["lm_head", "re:.*block_sparse_moe.gate"], sequential_update=True)
+recipe = QuantizationModifier(scheme="FP8", targets="Linear", ignore=["lm_head", "re:.*block_sparse_moe.gate"])
```

NOTE: `.*block_sparse_moe.gate` layers do not quantize well, hence they are ignored!
@@ -69,9 +69,11 @@ oneshot(

### Custom Quantization

+NOTE: Only per-tensor quantization is supported in vLLM as of now (`vllm==0.6.1`)

The repository supports multiple quantization techniques configured via a recipe. Supported strategies include `tensor`, `group`, and `channel` quantization.

-In the above example, FP8 channel-wise quantization is used as specified by the `FP8` scheme. For other preset schemes, refer to the [quantization schemes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py) in the `Compressed-Tensors` library.
+In the above example, FP8 per-tensor quantization is used as specified by the `FP8` scheme. For other preset schemes, refer to the [quantization schemes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py) in the `compressed-tensors` library.

A custom scheme can also be specified using `config_groups`:

@@ -89,7 +91,7 @@ config_groups = {
"num_bits": 8,
"type": "int",
"symmetric": true,
"strategy": "tensor",
"strategy": "group",
"group_size": 128,
}
}
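For context, here is a rough sketch of how a `config_groups` dictionary like the one above might be wired into a recipe. This is not part of the diff: the outer `targets`/`weights` nesting and the `config_groups=` argument follow the `compressed-tensors` `QuantizationScheme` layout and are assumptions for illustration.

```python
from llmcompressor.modifiers.quantization import QuantizationModifier

# Sketch of a custom group-wise INT8 weight scheme (illustrative; the outer
# "targets"/"weights" keys are assumed and not shown in the hunk above).
config_groups = {
    "group_0": {
        "targets": ["Linear"],
        "weights": {
            "num_bits": 8,
            "type": "int",
            "symmetric": True,
            "strategy": "group",
            "group_size": 128,
        },
    }
}

recipe = QuantizationModifier(
    config_groups=config_groups,
    ignore=["lm_head", "re:.*block_sparse_moe.gate"],
)
```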
83 changes: 83 additions & 0 deletions examples/quantizing_moe/deepseek_moe_w8a8_fp8.py
@@ -0,0 +1,83 @@
from datasets import load_dataset
from transformers import AutoTokenizer

from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot

# select a Mixture of Experts model for quantization
MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"

model = SparseAutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype="auto", trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Select calibration dataset.
# it's recommended to use more calibration samples for MoE models so each expert is hit
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 2048
MAX_SEQUENCE_LENGTH = 2048


# Load dataset and preprocess.
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))


def preprocess(example):
    return {
        "text": tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
    }


ds = ds.map(preprocess)


# Tokenize inputs.
def tokenize(sample):
    return tokenizer(
        sample["text"],
        padding=False,
        max_length=MAX_SEQUENCE_LENGTH,
        truncation=True,
        add_special_tokens=False,
    )


ds = ds.map(tokenize, remove_columns=ds.column_names)

# define a llmcompressor recipe for FP8 W8A8 quantization
# since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision
recipe = [
    QuantizationModifier(
        targets="Linear",
        scheme="FP8",
        ignore=["lm_head", "re:.*mlp.gate$"],
    ),
]

SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8"

oneshot(
    model=model,
    dataset=ds,
    recipe=recipe,
    max_seq_length=MAX_SEQUENCE_LENGTH,
    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
    save_compressed=True,
    output_dir=SAVE_DIR,
)


print("========== SAMPLE GENERATION ==============")
SAMPLE_INPUT = ["I love quantization because"]
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
inputs = tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(model.device)
output = model.generate(**inputs, max_length=50)
text_output = tokenizer.batch_decode(output)
print(text_output)
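As an aside (not part of this commit), a per-tensor FP8 checkpoint produced this way can typically be served directly with vLLM. A minimal sketch, assuming `vllm>=0.6.1` with FP8 support and the `SAVE_DIR` produced by the script above:

```python
# Sketch only: load the compressed checkpoint saved by the script above.
from vllm import LLM, SamplingParams

llm = LLM(model="DeepSeek-Coder-V2-Lite-Instruct-FP8", trust_remote_code=True)
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["I love quantization because"], params)
print(outputs[0].outputs[0].text)
```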
2 changes: 1 addition & 1 deletion examples/quantizing_moe/deepseek_moe_w8a8_int8.py
@@ -62,7 +62,7 @@ def tokenize(sample):

ds = ds.map(tokenize, remove_columns=ds.column_names)

-# define a llmcompressor recipe for W416 quantization
+# define a llmcompressor recipe for INT8 W8A8 quantization
# since the MoE gate layers are sensitive to quantization, we add them to the ignore
# list so they remain at full precision
recipe = [
6 changes: 2 additions & 4 deletions examples/quantizing_moe/mixtral_moe_w8a8_fp8.py
@@ -2,7 +2,7 @@

from transformers import AutoTokenizer

-from llmcompressor.modifiers.quantization.gptq import GPTQModifier
+from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
from llmcompressor.transformers.compression.helpers import calculate_offload_device_map

@@ -34,9 +34,7 @@
"re:.*block_sparse_moe.gate", # does not quantize well
]

-recipe = GPTQModifier(
-    scheme="FP8", targets="Linear", ignore=layers_to_ignore, sequential_update=True
-)
+recipe = QuantizationModifier(scheme="FP8", targets="Linear", ignore=layers_to_ignore)


oneshot(
