From 2e0035f83b3bd56c7bfe6c497db84277ddd1dd96 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Mon, 23 Sep 2024 10:36:45 -0400
Subject: [PATCH] Update MoE examples (#192)

* Update MoE examples

* Add top-level link

* Fix deepseek_moe_w8a8_int8.py

* Add deepseek_moe_w8a8_fp8.py

* Quality

* Quality
---
 README.md                                     |  1 +
 examples/quantizing_moe/README.md             | 20 +++--
 .../quantizing_moe/deepseek_moe_w8a8_fp8.py   | 83 +++++++++++++++++++
 ..._moe_w8a8.py => deepseek_moe_w8a8_int8.py} |  2 +-
 ...ral_moe_fp8.py => mixtral_moe_w8a8_fp8.py} |  6 +-
 5 files changed, 98 insertions(+), 14 deletions(-)
 create mode 100644 examples/quantizing_moe/deepseek_moe_w8a8_fp8.py
 rename examples/quantizing_moe/{deepseek_moe_w8a8.py => deepseek_moe_w8a8_int8.py} (97%)
 rename examples/quantizing_moe/{mixtral_moe_fp8.py => mixtral_moe_w8a8_fp8.py} (89%)

diff --git a/README.md b/README.md
index 67f4d1809..207d236ac 100644
--- a/README.md
+++ b/README.md
@@ -38,6 +38,7 @@ Applying quantization with `llmcompressor`:
 * [Activation quantization to `int8`](examples/quantization_w8a8_int8)
 * [Activation quantization to `fp8`](examples/quantization_w8a8_fp8)
 * [Weight only quantization to `int4`](examples/quantization_w4a16)
+* [Quantizing MoE LLMs](examples/quantizing_moe)
 
 ### User Guides
 Deep dives into advanced usage of `llmcompressor`:
diff --git a/examples/quantizing_moe/README.md b/examples/quantizing_moe/README.md
index 2ac3db1dd..4421bdf01 100644
--- a/examples/quantizing_moe/README.md
+++ b/examples/quantizing_moe/README.md
@@ -1,6 +1,6 @@
-# Quantizing TinyMixtral-4x248M-MoE Model with FP8
+# Quantizing Mixtral-8x7B-Instruct-v0.1 Model with FP8
 
-This directory contains an example script for quantizing the `TinyMixtral-4x248M-MoE` model using the FP8 quantization scheme.
+This directory contains an example script for quantizing the `Mixtral-8x7B-Instruct-v0.1` model using the static per-tensor FP8 quantization scheme.
 
 ## Installation
 
@@ -17,7 +17,7 @@ pip install -e .
 The provided example script demonstrates an end-to-end process for applying the quantization algorithm:
 
 ```bash
-python3 mixtral_moe_fp8.py
+python3 mixtral_moe_w8a8_fp8.py
 ```
 
 ## Creating a Quantized MoE Model
@@ -27,7 +27,7 @@ This example leverages `llm-compressor` and `compressed-tensors` to create an FP
 You can follow the detailed steps below or simply run the example script with:
 
 ```bash
-python examples/quantizing_moe/mixtral_moe_fp8.py
+python mixtral_moe_w8a8_fp8.py
 ```
 
 ### Step 1: Select a Model, Dataset, and Recipe
@@ -36,12 +36,12 @@ In this step, you'll choose a baseline model for quantization, a dataset for cal
 
 - **Models**: Can be referenced from a local directory or retrieved from the Hugging Face Hub.
 - **Datasets**: Can also be from a local directory or the Hugging Face Hub.
-- **Recipes**: These are YAML files or Python modifier objects that describe how a model should be optimized during or after training. In this example, we use a `GPTQModifier` object with the scheme set to `FP8`.
+- **Recipes**: These are YAML files or Python modifier objects that describe how a model should be optimized during or after training. In this example, we use a `QuantizationModifier` object with the scheme set to `FP8`.
 
 ```python
-from llmcompressor.modifiers.quantization.gptq import GPTQModifier
+from llmcompressor.modifiers.quantization import QuantizationModifier
 
-recipe = GPTQModifier(scheme="FP8", targets="Linear", ignore=["lm_head", "re:.*block_sparse_moe.gate"], sequential_update=True)
+recipe = QuantizationModifier(scheme="FP8", targets="Linear", ignore=["lm_head", "re:.*block_sparse_moe.gate"])
 ```
 
 NOTE: `.*block_sparse_moe.gate` layers do not quantize well, hence they are ignored!
@@ -69,9 +69,11 @@ oneshot(
 
 ### Custom Quantization
 
+NOTE: Only per-tensor quantization is supported in vLLM as of now (`vllm==0.6.1`).
+
 The repository supports multiple quantization techniques configured via a recipe. Supported strategies include `tensor`, `group`, and `channel` quantization.
 
-In the above example, FP8 channel-wise quantization is used as specified by the `FP8` scheme. For other preset schemes, refer to the [quantization schemes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py) in the `Compressed-Tensors` library.
+In the above example, FP8 per-tensor quantization is used as specified by the `FP8` scheme. For other preset schemes, refer to the [quantization schemes](https://github.com/neuralmagic/compressed-tensors/blob/main/src/compressed_tensors/quantization/quant_scheme.py) in the `compressed-tensors` library.
 
 A custom scheme can also be specified using `config_groups`:
 
@@ -89,7 +91,7 @@ config_groups = {
         "num_bits": 8,
         "type": "int",
         "symmetric": true,
-        "strategy": "tensor",
+        "strategy": "group",
         "group_size": 128,
     }
 }
diff --git a/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py b/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py
new file mode 100644
index 000000000..32db0485f
--- /dev/null
+++ b/examples/quantizing_moe/deepseek_moe_w8a8_fp8.py
@@ -0,0 +1,83 @@
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+from llmcompressor.modifiers.quantization import QuantizationModifier
+from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
+
+# select a Mixture of Experts model for quantization
+MODEL_ID = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
+
+model = SparseAutoModelForCausalLM.from_pretrained(
+    MODEL_ID, device_map="auto", torch_dtype="auto", trust_remote_code=True
+)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+
+# Select calibration dataset.
+# it's recommended to use more calibration samples for MoE models so each expert is hit
+DATASET_ID = "HuggingFaceH4/ultrachat_200k"
+DATASET_SPLIT = "train_sft"
+NUM_CALIBRATION_SAMPLES = 2048
+MAX_SEQUENCE_LENGTH = 2048
+
+
+# Load dataset and preprocess.
+ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
+ds = ds.shuffle(seed=42).select(range(NUM_CALIBRATION_SAMPLES))
+
+
+def preprocess(example):
+    return {
+        "text": tokenizer.apply_chat_template(
+            example["messages"],
+            tokenize=False,
+        )
+    }
+
+
+ds = ds.map(preprocess)
+
+
+# Tokenize inputs.
+def tokenize(sample):
+    return tokenizer(
+        sample["text"],
+        padding=False,
+        max_length=MAX_SEQUENCE_LENGTH,
+        truncation=True,
+        add_special_tokens=False,
+    )
+
+
+ds = ds.map(tokenize, remove_columns=ds.column_names)
+
+# define a llmcompressor recipe for FP8 W8A8 quantization
+# since the MoE gate layers are sensitive to quantization, we add them to the ignore
+# list so they remain at full precision
+recipe = [
+    QuantizationModifier(
+        targets="Linear",
+        scheme="FP8",
+        ignore=["lm_head", "re:.*mlp.gate$"],
+    ),
+]
+
+SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8"
+
+oneshot(
+    model=model,
+    dataset=ds,
+    recipe=recipe,
+    max_seq_length=MAX_SEQUENCE_LENGTH,
+    num_calibration_samples=NUM_CALIBRATION_SAMPLES,
+    save_compressed=True,
+    output_dir=SAVE_DIR,
+)
+
+
+print("========== SAMPLE GENERATION ==============")
+SAMPLE_INPUT = ["I love quantization because"]
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+inputs = tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(model.device)
+output = model.generate(**inputs, max_length=50)
+text_output = tokenizer.batch_decode(output)
+print(text_output)
diff --git a/examples/quantizing_moe/deepseek_moe_w8a8.py b/examples/quantizing_moe/deepseek_moe_w8a8_int8.py
similarity index 97%
rename from examples/quantizing_moe/deepseek_moe_w8a8.py
rename to examples/quantizing_moe/deepseek_moe_w8a8_int8.py
index 11a1023e6..3c02f5d8d 100644
--- a/examples/quantizing_moe/deepseek_moe_w8a8.py
+++ b/examples/quantizing_moe/deepseek_moe_w8a8_int8.py
@@ -62,7 +62,7 @@ def tokenize(sample):
 
 ds = ds.map(tokenize, remove_columns=ds.column_names)
 
-# define a llmcompressor recipe for W416 quantization
+# define a llmcompressor recipe for INT8 W8A8 quantization
 # since the MoE gate layers are sensitive to quantization, we add them to the ignore
 # list so they remain at full precision
 recipe = [
diff --git a/examples/quantizing_moe/mixtral_moe_fp8.py b/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py
similarity index 89%
rename from examples/quantizing_moe/mixtral_moe_fp8.py
rename to examples/quantizing_moe/mixtral_moe_w8a8_fp8.py
index 4c7fd1725..ac7510b03 100644
--- a/examples/quantizing_moe/mixtral_moe_fp8.py
+++ b/examples/quantizing_moe/mixtral_moe_w8a8_fp8.py
@@ -2,7 +2,7 @@
 
 from transformers import AutoTokenizer
 
-from llmcompressor.modifiers.quantization.gptq import GPTQModifier
+from llmcompressor.modifiers.quantization import QuantizationModifier
 from llmcompressor.transformers import SparseAutoModelForCausalLM, oneshot
 from llmcompressor.transformers.compression.helpers import calculate_offload_device_map
 
@@ -34,9 +34,7 @@
     "re:.*block_sparse_moe.gate",  # does not quantize well
 ]
 
-recipe = GPTQModifier(
-    scheme="FP8", targets="Linear", ignore=layers_to_ignore, sequential_update=True
-)
+recipe = QuantizationModifier(scheme="FP8", targets="Linear", ignore=layers_to_ignore)
 
 oneshot(
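
For reference, the compressed checkpoints these example scripts write are meant to be loaded directly in vLLM, which (per the README note added above) supports per-tensor FP8 as of `vllm==0.6.1`. The sketch below is not part of the patch; it assumes the `SAVE_DIR` naming used in `deepseek_moe_w8a8_fp8.py` (i.e. a local `DeepSeek-Coder-V2-Lite-Instruct-FP8` directory) and a working vLLM installation.

```python
# Hypothetical follow-up, not part of this patch: load the FP8 checkpoint written
# by deepseek_moe_w8a8_fp8.py (oneshot(..., output_dir=SAVE_DIR)) and generate with vLLM.
from vllm import LLM, SamplingParams

llm = LLM(
    model="DeepSeek-Coder-V2-Lite-Instruct-FP8",  # SAVE_DIR from the example script
    trust_remote_code=True,  # the DeepSeek-V2 tokenizer/config may require this
)

sampling = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["I love quantization because"], sampling)
print(outputs[0].outputs[0].text)
```

The same pattern applies to the Mixtral FP8 output; only the checkpoint directory name changes.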