Add an example for marlin format conversion & update results (#91)
guoqingbao authored Oct 16, 2024
1 parent 885ee3e commit 46b10ad
Showing 2 changed files with 60 additions and 2 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -22,15 +22,15 @@ Currently, candle-vllm supports chat serving for the following models.
| Model ID | Model Type | Supported | Speed (A100, `BF16`) | Throughput (`BF16`, `bs=16`) | Quantized (A100, `Q4K` or `Marlin`) | Throughput (`GPTQ/Marlin`, `bs=16`) |
|--|--|--|--|--|--|--|
| #1 | **LLAMA** |✅|65 tks/s (LLaMa3.1 8B) | 553 tks/s (LLaMa3.1 8B) | 75 tks/s (LLaMa3.1 8B), **115 tks/s (LLaMa3.1 8B, Marlin)** |**755 tks/s (LLaMa3.1 8B)**|
-| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | 96 tks/s (7B) |TBD|
+| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | 96 tks/s (7B), **113 tks/s (7B, Marlin)** |**764 tks/s (7B)**|
| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|TBD|-|TBD|
| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 744 tks/s (3.8B)|135 tks/s (3.8B)|TBD|
| #5 | **Yi** |✅|75 tks/s (6B)| 566 tks/s (6B) | 105 tks/s (6B)|TBD|
| #6 | **StableLM** |✅|99 tks/s (3B)|TBD|-|TBD|
| #7 | BigCode/StarCode |TBD|TBD|TBD |-|TBD|
| #8 | ChatGLM |TBD|TBD|TBD |-|TBD|
| #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|784 tks/s (1.8B) |-|TBD|
-| #10 | **Google Gemma** |✅|130 tks/s (2B)|TBD |-|TBD|
+| #10 | **Google Gemma** |✅|130 tks/s (2B)|TBD |**73 tks/s (Gemma2-9B, Marlin)** |**512 tks/s (Gemma2-9B)**|
| #11 | Blip-large (Multimodal) |TBD|TBD|TBD |-|TBD|
| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|TBD |-|TBD|

58 changes: 58 additions & 0 deletions examples/convert_marlin.py
@@ -0,0 +1,58 @@
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig  # install the latest auto-gptq
import shutil

# pipeline to the Marlin format: pretrained model (F16/BF16/F32) -> GPTQ (4-bit quantization) -> GPTQ Marlin

# change the following paths
pretrained_model_dir = "/home/mistral_7b/"  # path to the original (un-quantized) model
# intermediate saving path for the GPTQ (4-bit quantized) model
# (you may skip the quantization step if you already have a GPTQ model)
quantized_model_dir = "/home/mistral_7b-int4/"
save_path = "/home/mistral_7b-int4-Marlin/"  # final saving path for the GPTQ Marlin model

def main():
    quantize_config = BaseQuantizeConfig(
        bits=4,  # quantize the model to 4-bit (candle-vllm currently only supports 4-bit quantization for Marlin)
        group_size=128,  # a group size of 128 is recommended
        desc_act=False,  # set to False to significantly speed up inference, at a slight cost in perplexity
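        # note: Marlin kernels expect symmetric 4-bit quantization; BaseQuantizeConfig already defaults to sym=True, so no extra flag is needed here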
    )

    # load the un-quantized model; by default, the model is always loaded into CPU memory
    model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
    examples = [
        tokenizer(
            "auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
        )
    ]

    # quantize the model; the examples should be a list of dicts whose keys can only be "input_ids" and "attention_mask"
    model.quantize(examples)

    # save the quantized model
    model.save_quantized(quantized_model_dir)

    # "use_marlin=True" must be specified to save the model in Marlin format
    model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0", use_marlin=True, use_safetensors=True)
    print(model.state_dict().keys())

    model.save_pretrained(save_path)

    # if everything works fine, the target folder should contain the quantized Marlin model, named "gptq_model-4bit-128g.safetensors"
    # candle-vllm only supports "model.safetensors" for single-file models or "model.safetensors.index.json" for chunked models
    shutil.move(save_path + "gptq_model-4bit-128g.safetensors", save_path + "model.safetensors")
    # candle-vllm also needs tokenizer.json
    shutil.copy2(pretrained_model_dir + "tokenizer.json", save_path)

if __name__ == "__main__":
    import logging

    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    main()

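For reference, here is a minimal sanity check (a sketch, not part of this commit) that the converted folder contains the files candle-vllm expects; the save_path value is assumed to match the script above.

import os

save_path = "/home/mistral_7b-int4-Marlin/"  # assumed: same folder used in convert_marlin.py

# candle-vllm expects either a single-file or a chunked safetensors model, plus tokenizer.json
model_files = ["model.safetensors", "model.safetensors.index.json"]
has_model = any(os.path.exists(os.path.join(save_path, f)) for f in model_files)
has_tokenizer = os.path.exists(os.path.join(save_path, "tokenizer.json"))

print("model weights found:", has_model)
print("tokenizer.json found:", has_tokenizer)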