Support GPTQ/Marlin format quantization (4bit weight, f16 input) (#89)

guoqingbao authored Oct 14, 2024
1 parent c7ff257 commit 1b4b0d4
Showing 17 changed files with 2,018 additions and 68 deletions.
39 changes: 25 additions & 14 deletions README.md
@@ -13,25 +13,26 @@ Efficient, easy-to-use platform for inference and serving local LLMs including a
- Efficient management of key-value cache with PagedAttention.
- Continuous batching.
- `In-situ` quantization
- `GPTQ/Marlin` format quantization

## Development Status

Currently, candle-vllm supports chat serving for the following models.

| Model ID | Model Type | Supported | Speed (A100, `BF16`) | Throughput (`BF16`, `bs=16`) | Quantized (A100, `Q4K`) |
|--|--|--|--|--|--|
| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)| 553 tks/s (LLaMa3.1 8B) | 75 tks/s (LLaMa3.1 8B) |
| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | 96 tks/s (7B) |
| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|TBD|-|
| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 744 tks/s (3.8B)|135 tks/s (3.8B)|
| #5 | **Yi** |✅|75 tks/s (6B)| 566 tks/s (6B) | 105 tks/s (6B)|
| #6 | **StableLM** |✅|99 tks/s (3B)|TBD|-|
| #7 | BigCode/StarCode |TBD|TBD|TBD |-|
| #8 | ChatGLM |TBD|TBD|TBD |-|
| #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|784 tks/s (1.8B) |-|
| #10 | **Google Gemma** |✅|130 tks/s (2B)|TBD |-|
| #11 | Blip-large (Multimodal) |TBD|TBD|TBD |-|
| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|TBD |-|
| Model ID | Model Type | Supported | Speed (A100, `BF16`) | Throughput (`BF16`, `bs=16`) | Quantized (A100, `Q4K`) | Throughput (`GPTQ/Marlin`, `bs=16`) |
|--|--|--|--|--|--|--|
| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|65 tks/s (LLaMa3.1 8B), **115 tks/s (LLaMa3.1 8B, Marlin)** | 553 tks/s (LLaMa3.1 8B) | 75 tks/s (LLaMa3.1 8B) |**755 tks/s (LLaMa3.1 8B)**|
| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | 96 tks/s (7B) |TBD|
| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|TBD|-|TBD|
| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 744 tks/s (3.8B)|135 tks/s (3.8B)|TBD|
| #5 | **Yi** |✅|75 tks/s (6B)| 566 tks/s (6B) | 105 tks/s (6B)|TBD|
| #6 | **StableLM** |✅|99 tks/s (3B)|TBD|-|TBD|
| #7 | BigCode/StarCode |TBD|TBD|TBD |-|TBD|
| #8 | ChatGLM |TBD|TBD|TBD |-|TBD|
| #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|784 tks/s (1.8B) |-|TBD|
| #10 | **Google Gemma** |✅|130 tks/s (2B)|TBD |-|TBD|
| #11 | Blip-large (Multimodal) |TBD|TBD|TBD |-|TBD|
| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|TBD |-|TBD|


## Demo Chat with candle-vllm (61-65 tokens/s, LLaMa3.1 8B, bf16, on A100)
@@ -191,6 +192,16 @@ async def benchmark():
asyncio.run(benchmark())
```

## GPTQ/Marlin quantization
Candle-vllm now supports GPTQ quantization with the Marlin kernel. If you have `Marlin`-format quantized weights, supply the `quant` (marlin) and `dtype` (f16) parameters, for example:

```
cargo run --release -- --port 2000 --dtype f16 --weight-path /home/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4-Marlin/ llama3 --quant marlin
```
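
Once the server is up, it can be queried like any other candle-vllm deployment through the OpenAI-compatible API used in the benchmark example above. A minimal sketch, assuming the `openai` Python client (>= 1.0); the model name and API key are placeholders, since the server serves the single loaded model:

```
from openai import OpenAI

# Port 2000 matches the command above; /v1 is the OpenAI-compatible route.
client = OpenAI(base_url="http://localhost:2000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="llama3.1-8b-marlin",  # placeholder name; candle-vllm serves the single loaded model
    messages=[{"role": "user", "content": "Explain paged attention in one sentence."}],
    max_tokens=64,
)
print(response.choices[0].message.content)
```
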
You may also use `AutoGPTQ` to convert a model to Marlin format by loading it with `use_marlin=True`; a sketch of the conversion follows the note below.

**Note:** Only 4-bit GPTQ quantization is supported for the Marlin format at the moment, and the input data type must be `f16` (`--dtype f16`). You also need to rename the converted Marlin-format weight file to "model.safetensors" and copy "tokenizer.json" from the source model folder.
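
The conversion and file layout described above can be scripted. The following is a minimal sketch, not an official tool: it assumes `auto-gptq` >= 0.7 (the first release with Marlin support), a single-shard safetensors checkpoint, and placeholder paths; verify the exact saving behaviour against the AutoGPTQ documentation for your version.

```
import shutil
from pathlib import Path

from auto_gptq import AutoGPTQForCausalLM

src = Path("/home/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4")         # existing 4-bit GPTQ checkpoint (placeholder path)
dst = Path("/home/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4-Marlin")  # folder later passed via --weight-path

# Loading with use_marlin=True repacks the GPTQ weights into the Marlin layout.
model = AutoGPTQForCausalLM.from_quantized(str(src), device="cuda:0", use_marlin=True)

# Assumption: save_quantized writes the repacked weights; check your AutoGPTQ version.
dst.mkdir(parents=True, exist_ok=True)
model.save_quantized(str(dst), use_safetensors=True)

# candle-vllm expects a single weight file named "model.safetensors"
# plus the tokenizer.json from the source model (see the note above).
weight_file = next(dst.glob("*.safetensors"))
weight_file.rename(dst / "model.safetensors")
shutil.copy(src / "tokenizer.json", dst / "tokenizer.json")
```

The resulting folder can then be served with the `cargo run ... --quant marlin` command shown above.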

## In-situ quantization for consumer-grade GPUs

Candle-vllm now supports in-situ quantization, allowing the transformation of default weights (F32/F16/BF16) into any GGML format during model loading. This feature helps conserve GPU memory, making it more efficient for consumer-grade GPUs (e.g., RTX 4090). For example, 4-bit quantization can reduce GPU memory usage to less than 12GB for 8B models, while bringing 13B models down to 24GB. To use this feature, simply supply the `quant` parameter when running candle-vllm.
4 changes: 3 additions & 1 deletion kernels/build.rs
@@ -19,7 +19,9 @@ fn main() -> Result<()> {
println!("cargo:rerun-if-changed=src/pagedattention.cu");
println!("cargo:rerun-if-changed=src/copy_blocks_kernel.cu");
println!("cargo:rerun-if-changed=src/reshape_and_cache_kernel.cu");
let builder = bindgen_cuda::Builder::default();
println!("cargo:rerun-if-changed=src/marlin_cuda_kernel.cu");

let builder = bindgen_cuda::Builder::default().arg("--expt-relaxed-constexpr");
println!("cargo:info={builder:?}");
builder.build_lib("libpagedattention.a");

22 changes: 22 additions & 0 deletions kernels/src/ffi.rs
@@ -69,4 +69,26 @@ extern "C" {
dtype: u32,
softscapping: f32,
);

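// Marlin GEMM: multiplies f16 inputs (m x k) by Marlin-packed 4-bit weights (k x n),
// applying one scale per `groupsize` rows of the weight matrix; the result is written to `out`.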
pub fn marlin_4bit_f16(
inputs: *const c_void,
weight: *const c_int,
scales: *const c_void,
out: *const c_void,
m: c_int,
k: c_int,
n: c_int,
workspace: *const c_void, //tensor with at least `n / 128 * max_par` entries that are all zero
groupsize: c_int,
) -> i32;

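// Repacks GPTQ-format quantized weights (k x n, `bits`-bit) into the Marlin kernel's
// weight layout; `perm` holds the GPTQ activation-order permutation, if any.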
pub fn gptq_marlin_repack(
weight: *const c_void,
perm: *const c_void,
result: *const c_void,
k: c_int,
n: c_int,
bits: c_int,
);

}
1 change: 1 addition & 0 deletions kernels/src/lib.rs
@@ -1,5 +1,6 @@
pub const COPY_BLOCKS_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/copy_blocks_kernel.ptx"));
pub const MARLIN_CUDA_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/marlin_cuda_kernel.ptx"));
pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx"));
pub const RESHAPE_AND_CACHE_KERNEL: &str =
include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx"));