diff --git a/README.md b/README.md
index 77cae19..504db62 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ Currently, candle-vllm supports chat serving for the following models.
 | Model ID | Model Type | Supported | Speed (A100, `BF16`) | Throughput (`BF16`, `bs=16`) | Quantized (A100, `Q4K` or `Marlin`) | Throughput (`GTPQ/Marlin`, `bs=16`) |
 |--|--|--|--|--|--|--|
 | #1 | **LLAMA** |✅|65 tks/s (LLaMa3.1 8B) | 553 tks/s (LLaMa3.1 8B) | 75 tks/s (LLaMa3.1 8B), **115 tks/s (LLaMa3.1 8B, Marlin)** |**755 tks/s (LLaMa3.1 8B)**|
-| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | 96 tks/s (7B) |TBD|
+| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | 96 tks/s (7B), **113 tks/s (7B, Marlin)** |**764 tks/s (7B)**|
 | #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|TBD|-|TBD|
 | #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 744 tks/s (3.8B)|135 tks/s (3.8B)|TBD|
 | #5 | **Yi** |✅|75 tks/s (6B)| 566 tks/s (6B) | 105 tks/s (6B)|TBD|
@@ -30,7 +30,7 @@ Currently, candle-vllm supports chat serving for the following models.
 | #7 | BigCode/StarCode |TBD|TBD|TBD |-|TBD|
 | #8 | ChatGLM |TBD|TBD|TBD |-|TBD|
 | #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|784 tks/s (1.8B) |-|TBD|
-| #10 | **Google Gemma** |✅|130 tks/s (2B)|TBD |-|TBD|
+| #10 | **Google Gemma** |✅|130 tks/s (2B)|TBD |**73 tks/s (Gemma2-9B, Marlin)** |**512 tks/s (Gemma2-9B)**|
 | #11 | Blip-large (Multimodal) |TBD|TBD|TBD |-|TBD|
 | #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|TBD |-|TBD|
 
diff --git a/examples/convert_marlin.py b/examples/convert_marlin.py
new file mode 100644
index 0000000..712d586
--- /dev/null
+++ b/examples/convert_marlin.py
@@ -0,0 +1,58 @@
+from transformers import AutoTokenizer
+from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig  # install the latest auto-gptq
+import shutil
+
+# Pipeline to Marlin format: pretrained model (F16/BF16/F32) -> GPTQ (4-bit quantization) -> GPTQ Marlin
+
+# Change the following paths
+pretrained_model_dir = "/home/mistral_7b/"  # path to the original (un-quantized) model
+# intermediate path for the GPTQ (4-bit quantized) model
+# (you may skip the quantization step if you already have a GPTQ model)
+quantized_model_dir = "/home/mistral_7b-int4/"
+save_path = "/home/mistral_7b-int4-Marlin/"  # final saving path for the GPTQ Marlin model
+
+def main():
+    quantize_config = BaseQuantizeConfig(
+        bits=4,  # quantize the model to 4-bit (candle-vllm currently only supports 4-bit quantization for Marlin)
+        group_size=128,  # 128 is the recommended group size
+        desc_act=False,  # False significantly speeds up inference, at a slight cost in perplexity
+    )
+
+    # load the un-quantized model; by default, the model is loaded into CPU memory
+    model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)
+
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
+    examples = [
+        tokenizer(
+            "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
+        )
+    ]
+
+    # quantize the model; examples must be a list of dicts whose only keys are "input_ids" and "attention_mask"
+    model.quantize(examples)
+
+    # save the GPTQ-quantized model
+    model.save_quantized(quantized_model_dir)
+
+    # "use_marlin=True" must be specified to save the model in Marlin format
+    model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0", use_marlin=True, use_safetensors=True)
+    print(model.state_dict().keys())
+
+    model.save_pretrained(save_path)
+
+    # if everything works, the target folder contains the quantized Marlin model "gptq_model-4bit-128g.safetensors"
+    # candle-vllm only supports "model.safetensors" for single-file models or "model.safetensors.index.json" for chunked models
+    shutil.move(save_path + "gptq_model-4bit-128g.safetensors", save_path + "model.safetensors")
+    # tokenizer.json is also required
+    shutil.copy2(pretrained_model_dir + "tokenizer.json", save_path)
+
+if __name__ == "__main__":
+    import logging
+
+    logging.basicConfig(
+        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
+        level=logging.INFO,
+        datefmt="%Y-%m-%d %H:%M:%S",
+    )
+
+    main()
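After the conversion finishes, the resulting `model.safetensors` can be sanity-checked with the `safetensors` Python API before pointing candle-vllm at it. The sketch below is illustrative and not part of the conversion script: it reuses the `save_path` from `examples/convert_marlin.py` above, assumes `torch` and `safetensors` are installed, and the tensor-name suffixes it looks for (`qweight`, `scales`) are typical of GPTQ/Marlin checkpoints produced by auto-gptq but may differ between versions.

```python
# Minimal sanity check for the converted Marlin checkpoint (sketch, see assumptions above).
from safetensors import safe_open

save_path = "/home/mistral_7b-int4-Marlin/"  # same folder as in convert_marlin.py

with safe_open(save_path + "model.safetensors", framework="pt", device="cpu") as f:
    names = list(f.keys())
    print(f"{len(names)} tensors found")

    # Quantized linear layers typically expose packed weights and per-group scales.
    quant_tensors = [n for n in names if n.endswith(("qweight", "scales"))]
    print("example quantized tensors:", quant_tensors[:5])

    # Print shape and dtype of a few quantized tensors to confirm they loaded correctly.
    for name in quant_tensors[:3]:
        t = f.get_tensor(name)
        print(name, tuple(t.shape), t.dtype)
```

If the file opens and lists packed quantized tensors, the folder (together with the copied `tokenizer.json`) should be in the layout the script's final comments describe.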