From 899515a903a98c876eaf9508e3184c52a50bc63a Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 13 Aug 2024 14:04:17 +0800 Subject: [PATCH 01/15] Support in-situ quantization --- README.md | 40 +++-- src/lib.rs | 92 ++++++++++- src/openai/models/gemma.rs | 73 +++++++-- src/openai/models/linear.rs | 263 ++++++++++++++++++++++++++++++- src/openai/models/llama.rs | 44 ++++-- src/openai/models/mistral.rs | 67 ++++++-- src/openai/models/mod.rs | 2 + src/openai/models/phi2.rs | 60 +++++-- src/openai/models/phi3.rs | 46 +++++- src/openai/models/qwen2.rs | 63 ++++++-- src/openai/models/stable_lm.rs | 69 ++++++-- src/openai/models/yi.rs | 67 ++++++-- src/openai/pipelines/pipeline.rs | 49 ++---- tests/tests.rs | 1 + 14 files changed, 800 insertions(+), 136 deletions(-) diff --git a/README.md b/README.md index 251bd97..25a9ccd 100644 --- a/README.md +++ b/README.md @@ -12,25 +12,26 @@ Efficient, easy-to-use platform for inference and serving local LLMs including a - Streaming support in generation. - Efficient management of key-value cache with PagedAttention. - Continuous batching. +- In-situ quantization ## Develop Status Currently, candle-vllm supports chat serving for the following models. -| Model ID | Model Type | Supported | Speed (A100, BF16) | Throughput (bs=16) -|--|--|--|--|--| -| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)| 553 tks/s (LLaMa3.1 8B) | -| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | -| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|TBD| -| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 744 tks/s (3.8B)| -| #5 | **Yi** |✅|75 tks/s (6B)| 566 tks/s (6B) | -| #6 | **StableLM** |✅|99 tks/s (3B)|TBD| -| #7 | BigCode/StarCode |TBD|TBD|TBD | -| #8 | ChatGLM |TBD|TBD|TBD | -| #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|784 tks/s (1.8B) | -| #10 | **Google Gemma** |✅|130 tks/s (2B)|TBD | -| #11 | Blip-large (Multimodal) |TBD|TBD|TBD | -| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|TBD | +| Model ID | Model Type | Supported | Speed (A100, BF16) | Throughput (bs=16) | Quantized (A100, Q8_0) | +|--|--|--|--|--|--| +| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)| 553 tks/s (LLaMa3.1 8B) | 65 tks/s (LLaMa3.1 8B) | +| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | 78 tks/s (7B) | +| #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|TBD|-| +| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 744 tks/s (3.8B)|116 tks/s (3.8B)| +| #5 | **Yi** |✅|75 tks/s (6B)| 566 tks/s (6B) | 79 tks/s (6B)| +| #6 | **StableLM** |✅|99 tks/s (3B)|TBD|-| +| #7 | BigCode/StarCode |TBD|TBD|TBD |-| +| #8 | ChatGLM |TBD|TBD|TBD |-| +| #9 | **QWen2 (1.8B, 7B)** |✅|148 tks/s (1.8B)|784 tks/s (1.8B) |-| +| #10 | **Google Gemma** |✅|130 tks/s (2B)|TBD |-| +| #11 | Blip-large (Multimodal) |TBD|TBD|TBD |-| +| #12 | Moondream-2 (Multimodal LLM) |TBD|TBD|TBD |-| ## Demo Chat with candle-vllm (61-65 tokens/s, LLaMa3.1 8B, bf16, on A100) @@ -187,6 +188,17 @@ async def benchmark(): asyncio.run(benchmark()) ``` +## In-situ quantization for consumer-grade GPUs + +Candle-vllm now supports in-situ quantization, allowing the transformation of default weights (F32/F16/BF16) into any GGML format during model loading. This feature helps conserve GPU memory, making it more efficient for consumer-grade GPUs (e.g., RTX 4090). For example, 8-bit quantization can reduce memory usage to less than 20GB for 8B models, while 4-bit quantization can bring it down to under 22GB for 13B models. To use this feature, simply supply the quant parameter when running candle-vllm. + +``` +cargo run --release -- --port 2000 --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --quant q8_0 +``` + +Options for `quant` parameters: ["q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2k", "q3k","q4k","q5k","q6k"] + +**Please note** that batched processing still requires optimization when operating in quantization mode. ## Usage Help For general configuration help, run `cargo run -- --help`. diff --git a/src/lib.rs b/src/lib.rs index f53d074..0ef5d21 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,10 +2,7 @@ use candle::Result; use candle_core as candle; use clap::Subcommand; -use openai::pipelines::{ - pipeline::{DefaultLoader, SpecificConfig}, - ModelLoader, -}; +use openai::pipelines::{pipeline::DefaultLoader, ModelLoader}; #[derive(Debug, Subcommand)] pub enum ModelSelected { @@ -23,6 +20,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the llama3 model (default llama3.1-8b). @@ -39,6 +39,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the phi2 model (default 2.7b). @@ -55,6 +58,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the phi3 model (default 3.8b). @@ -77,6 +83,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the qwen model (default 1.8b). @@ -99,6 +108,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the gemma model (default 2b). @@ -115,6 +127,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the mistral model (default 7b). @@ -131,6 +146,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the Yi model (default 6b). @@ -147,6 +165,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, /// Select the stable-lm model (default zephyr-3b). @@ -163,6 +184,9 @@ pub enum ModelSelected { #[arg(long)] max_gen_tokens: Option, + + #[arg(long)] + quant: Option, }, } @@ -174,18 +198,21 @@ impl ToString for ModelSelected { temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "llama".to_string(), ModelSelected::Llama3 { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "llama3".to_string(), ModelSelected::Phi2 { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "phi2".to_string(), ModelSelected::Phi3 { repeat_last_n: _, @@ -194,6 +221,7 @@ impl ToString for ModelSelected { top_p: _, penalty: _, max_gen_tokens: _, + quant: _, } => "phi3".to_string(), ModelSelected::Qwen2 { repeat_last_n: _, @@ -202,35 +230,73 @@ impl ToString for ModelSelected { top_p: _, penalty: _, max_gen_tokens: _, + quant: _, } => "qwen2".to_string(), ModelSelected::Gemma { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "gemma".to_string(), ModelSelected::Mistral { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "mistral".to_string(), ModelSelected::Yi { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "yi".to_string(), ModelSelected::StableLM { repeat_last_n: _, temperature: _, penalty: _, max_gen_tokens: _, + quant: _, } => "stablelm".to_string(), } } } +#[derive(Debug, Clone)] +pub struct SpecificConfig { + repeat_last_n: Option, + temperature: Option, + top_k: Option, + top_p: Option, + penalty: Option, + max_gen_tokens: Option, + quant: Option, +} + +impl SpecificConfig { + pub fn new( + repeat_last_n: Option, + temperature: Option, + top_k: Option, + top_p: Option, + penalty: Option, + max_gen_tokens: Option, + quant: Option, + ) -> Self { + Self { + repeat_last_n, + temperature, + top_k, + top_p, + penalty, + max_gen_tokens, + quant, + } + } +} + pub fn get_model_loader( selected_model: ModelSelected, model_id: Option, @@ -241,6 +307,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -250,6 +317,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "llama".to_string(), )), @@ -264,6 +332,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -273,6 +342,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "llama3".to_string(), )), @@ -287,6 +357,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -296,6 +367,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "phi2".to_string(), )), @@ -312,6 +384,7 @@ pub fn get_model_loader( top_p, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -321,6 +394,7 @@ pub fn get_model_loader( top_p, penalty, max_gen_tokens, + quant, ), "phi3".to_string(), )), @@ -337,6 +411,7 @@ pub fn get_model_loader( top_p, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -346,6 +421,7 @@ pub fn get_model_loader( top_p, penalty, max_gen_tokens, + quant, ), "qwen2".to_string(), )), @@ -360,6 +436,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -369,6 +446,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "gemma".to_string(), )), @@ -383,6 +461,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -392,6 +471,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "mistral".to_string(), )), @@ -407,6 +487,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -416,6 +497,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "yi".to_string(), )), @@ -431,6 +513,7 @@ pub fn get_model_loader( temperature, penalty, max_gen_tokens, + quant, } => ( Box::new(DefaultLoader::new( SpecificConfig::new( @@ -440,6 +523,7 @@ pub fn get_model_loader( None, penalty, max_gen_tokens, + quant, ), "stablelm".to_string(), )), diff --git a/src/openai/models/gemma.rs b/src/openai/models/gemma.rs index af29b8a..9cb54df 100644 --- a/src/openai/models/gemma.rs +++ b/src/openai/models/gemma.rs @@ -1,12 +1,14 @@ use super::Config; -use crate::openai::models::linear::{linear_b, linear_no_bias as linear, Linear}; +use crate::openai::models::linear::{ + linear_b_x as linear_b, linear_no_bias_x as linear, LinearX as Linear, +}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_core as candle; use candle_nn::Activation; use candle_nn::{RmsNorm, VarBuilder}; - use either::Either; use std::iter::zip; use std::sync::Arc; @@ -31,7 +33,12 @@ pub struct GemmaConfig { } impl GemmaConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { let hidden_act = match (self.hidden_act, self.hidden_activation) { (None, Some(act)) | (Some(act), None) => Some(act), (Some(_), Some(_)) => panic!("both hidden_act and hidden_activation are set"), @@ -61,6 +68,7 @@ impl GemmaConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specifi_config: scfg.clone(), } } } @@ -135,9 +143,24 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = linear( + hidden_sz, + intermediate_sz, + vb.pp("gate_proj"), + &cfg.specifi_config.quant, + )?; + let up_proj = linear( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + &cfg.specifi_config.quant, + )?; + let down_proj = linear( + intermediate_sz, + hidden_sz, + vb.pp("down_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { gate_proj, up_proj, @@ -175,10 +198,34 @@ impl Attention { let num_kv_heads = cfg.num_key_value_heads; let head_dim = cfg.hidden_size / cfg.num_attention_heads; let bias = cfg.attention_bias; - let q_proj = linear_b(hidden_sz, num_heads * head_dim, bias, vb.pp("q_proj"))?; - let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"))?; - let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"))?; - let o_proj = linear_b(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?; + let q_proj = linear_b( + hidden_sz, + num_heads * head_dim, + bias, + vb.pp("q_proj"), + &cfg.specifi_config.quant, + )?; + let k_proj = linear_b( + hidden_sz, + num_kv_heads * head_dim, + bias, + vb.pp("k_proj"), + &cfg.specifi_config.quant, + )?; + let v_proj = linear_b( + hidden_sz, + num_kv_heads * head_dim, + bias, + vb.pp("v_proj"), + &cfg.specifi_config.quant, + )?; + let o_proj = linear_b( + num_heads * head_dim, + hidden_sz, + bias, + vb.pp("o_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { q_proj, k_proj, @@ -340,7 +387,11 @@ impl Gemma { layers.push(layer) } let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let lm_head = Linear::new(embed_tokens.embeddings().clone(), None); + let lm_head = Linear::new( + embed_tokens.embeddings().clone(), + None, + &cfg.specifi_config.quant, + ); Ok(Self { embed_tokens, layers, diff --git a/src/openai/models/linear.rs b/src/openai/models/linear.rs index ca330dc..61f3001 100644 --- a/src/openai/models/linear.rs +++ b/src/openai/models/linear.rs @@ -18,8 +18,15 @@ //! # Ok(()) } //! ``` use crate::candle::Module; -use crate::candle::{Result, Tensor}; +use crate::candle::{ + quantized::{gguf_file, QMatMul, QTensor}, + DType, Device, Result, Tensor, +}; +use candle_core::quantized; use candle_nn::init; +use either::Either; +use std::sync::Arc; + #[derive(Clone, Debug)] pub struct Linear { weight: Tensor, @@ -126,3 +133,257 @@ pub fn linear_b( linear_no_bias(in_dim, out_dim, vb) } } + +#[derive(Debug, Clone)] +pub struct QLinear { + inner: QMatMul, + bias: Option, + dtype: DType, +} + +impl QLinear { + pub fn new( + ct: &gguf_file::Content, + r: &mut R, + name: &str, + device: &Device, + ) -> Result { + let w = ct.tensor(r, &format!("{name}.weight"), device)?; + let b = ct.tensor(r, &format!("{name}.bias"), device)?; + let inner = QMatMul::from_qtensor(w)?; + let bias = b.dequantize(device)?; + Ok(Self { + inner, + bias: Some(bias), + dtype: DType::F32, + }) + } + + pub fn from_linear(linear: Linear) -> Self { + Self { + inner: QMatMul::Tensor(linear.weight().clone()), + bias: linear.bias().cloned(), + dtype: linear.weight().dtype(), + } + } + + pub fn from_parts(w: Tensor, b: Option) -> Self { + let dtype = w.dtype(); + Self { + inner: QMatMul::Tensor(w), + bias: b, + dtype, + } + } + + pub fn from_qparts(w: QTensor, b: Option) -> Self { + if let Some(ref b) = b { + assert_eq!(b.dtype(), DType::F32); + } + Self { + inner: QMatMul::QTensor(Arc::new(w)), + bias: b, + dtype: DType::F32, + } + } + + pub fn from_qparts_x(w: QTensor, b: Option, dtype: DType) -> Self { + let bx = match b { + Some(b_) => { + if b_.dtype() != DType::F32 { + Some(b_.to_dtype(DType::F32).unwrap()) + } else { + Some(b_) + } + } + _ => None, + }; + + Self { + inner: QMatMul::QTensor(Arc::new(w)), + bias: bx, + dtype: dtype, + } + } + + pub fn from_linear_x(linear: Linear, quant: String) -> Self { + let weight = linear.weight(); + let dtype = weight.dtype(); + use quantized::GgmlDType; + + let ggml_dtype = match quant.as_str() { + "q4_0" => GgmlDType::Q4_0, + "q4_1" => GgmlDType::Q4_1, + "q5_0" => GgmlDType::Q5_0, + "q5_1" => GgmlDType::Q5_1, + "q8_0" => GgmlDType::Q8_0, + "q2k" => GgmlDType::Q2K, + "q3k" => GgmlDType::Q3K, + "q4k" => GgmlDType::Q4K, + "q5k" => GgmlDType::Q5K, + "q6k" => GgmlDType::Q6K, + _ => panic!("Unsupported GGML data type!"), + }; + let qtensor = QTensor::quantize(weight, ggml_dtype).unwrap(); + let qbias = match linear.bias() { + Some(b) => Some(b.clone()), + _ => None, + }; + + QLinear::from_qparts_x(qtensor, qbias, dtype) + } + + pub fn from_old_and_qmatmul(inner: QMatMul, old: &Self) -> Self { + Self { + inner, + bias: old.bias.clone(), + dtype: old.dtype, + } + } + + pub fn inner(&mut self) -> &mut QMatMul { + &mut self.inner + } + + pub fn inner_ref(&self) -> &QMatMul { + &self.inner + } + + pub fn is_quant(&self) -> bool { + matches!(self.inner, QMatMul::QTensor(_)) + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } + + pub fn bias_mut(&mut self) -> Option<&mut Tensor> { + self.bias.as_mut() + } +} + +impl Module for QLinear { + fn forward(&self, x: &Tensor) -> Result { + let xs = if self.is_quant() { + let x1 = match *x.dims() { + [bsize, seq_len, dim1, dim2] => { + if seq_len > 1 { + x.to_dtype(DType::F32)? + } else { + x.reshape((bsize, dim1, dim2))?.to_dtype(DType::F32)? + } + } + [bsize, seq_len, dim] => { + if seq_len > 1 { + x.to_dtype(DType::F32)? + } else { + x.reshape((bsize, dim))?.to_dtype(DType::F32)? + } + } + _ => x.to_dtype(DType::F32)?, + }; + x1 + } else { + x.clone() + }; + + let xs = match *x.dims() { + [bsize, seq_len, dim1, _] => { + if seq_len > 1 { + QMatMul::forward(&self.inner, &xs)? + } else { + QMatMul::forward(&self.inner, &xs)?.reshape((bsize, seq_len, dim1, ()))? + } + } + [bsize, seq_len, _] => { + if seq_len > 1 { + QMatMul::forward(&self.inner, &xs)? + } else { + QMatMul::forward(&self.inner, &xs)?.reshape((bsize, seq_len, ()))? + } + } + _ => QMatMul::forward(&self.inner, &xs)?, + }; + + if let Some(bias) = &self.bias { + xs.broadcast_add(bias)?.to_dtype(self.dtype) + } else { + xs.to_dtype(self.dtype) + } + } +} + +#[derive(Debug, Clone)] +pub struct LinearX(Either); + +impl Module for LinearX { + fn forward(&self, x: &Tensor) -> Result { + match &self.0 { + Either::Left(ln) => ln.forward(x), + Either::Right(ln) => ln.forward(x), + } + } +} +impl LinearX { + pub fn new(weight: Tensor, bias: Option, quant: &Option) -> Self { + let ln = Linear::new(weight, bias); + if let Some(quatized_type) = quant { + LinearX(Either::Right(QLinear::from_linear_x( + ln, + quatized_type.clone(), + ))) + } else { + LinearX(Either::Left(ln)) + } + } +} + +pub fn linear_x( + in_dim: usize, + out_dim: usize, + vb: candle_nn::VarBuilder, + quant: &Option, +) -> Result { + let ln = linear(in_dim, out_dim, vb).unwrap(); + if let Some(quatized_type) = quant { + Ok(LinearX(Either::Right(QLinear::from_linear_x( + ln, + quatized_type.clone(), + )))) + } else { + Ok(LinearX(Either::Left(ln))) + } +} + +pub fn linear_no_bias_x( + in_dim: usize, + out_dim: usize, + vb: candle_nn::VarBuilder, + quant: &Option, +) -> Result { + let init_ws = init::DEFAULT_KAIMING_NORMAL; + let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?; + let ln = Linear::new(ws, None); + if let Some(quatized_type) = quant { + Ok(LinearX(Either::Right(QLinear::from_linear_x( + ln, + quatized_type.clone(), + )))) + } else { + Ok(LinearX(Either::Left(ln))) + } +} + +pub fn linear_b_x( + in_dim: usize, + out_dim: usize, + bias: bool, + vb: candle_nn::VarBuilder, + quant: &Option, +) -> Result { + if bias { + linear_x(in_dim, out_dim, vb, quant) + } else { + linear_no_bias_x(in_dim, out_dim, vb, quant) + } +} diff --git a/src/openai/models/llama.rs b/src/openai/models/llama.rs index 64f5ccb..e76f98d 100644 --- a/src/openai/models/llama.rs +++ b/src/openai/models/llama.rs @@ -1,15 +1,16 @@ use super::Config; -use crate::openai::models::linear::{linear_no_bias as linear, Linear}; +use crate::openai::models::linear::{linear_no_bias_x as linear, LinearX as Linear}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle::{DType, Device, IndexOp, Result, Tensor}; use candle_core as candle; use candle_nn::{embedding, Embedding, Module, VarBuilder}; use candle_transformers::models::with_tracing::RmsNorm; - pub const MAX_SEQ_LEN: usize = 4096; use crate::openai::models::TokenID; use std::iter::zip; + #[derive(Debug, Clone, serde::Deserialize)] pub struct LlamaConfig { pub hidden_size: usize, @@ -31,7 +32,12 @@ fn default_rope() -> f32 { } impl LlamaConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -56,6 +62,7 @@ impl LlamaConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specifi_config: scfg.clone(), } } } @@ -184,10 +191,10 @@ impl CausalSelfAttention { let size_in = cfg.hidden_size; let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; - let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?; - let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; - let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; - let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; + let q_proj = linear(size_in, size_q, vb.pp("q_proj"), &cfg.specifi_config.quant)?; + let k_proj = linear(size_in, size_kv, vb.pp("k_proj"), &cfg.specifi_config.quant)?; + let v_proj = linear(size_in, size_kv, vb.pp("v_proj"), &cfg.specifi_config.quant)?; + let o_proj = linear(size_q, size_in, vb.pp("o_proj"), &cfg.specifi_config.quant)?; let head_dim = cfg.hidden_size / cfg.num_attention_heads; Ok(Self { q_proj, @@ -232,9 +239,19 @@ impl Mlp { let span = tracing::span!(tracing::Level::TRACE, "mlp"); let h_size = cfg.hidden_size; let i_size = cfg.intermediate_size; - let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?; - let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"))?; - let c_proj = linear(i_size, h_size, vb.pp("down_proj"))?; + let c_fc1 = linear( + h_size, + i_size, + vb.pp("gate_proj"), + &cfg.specifi_config.quant, + )?; + let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"), &cfg.specifi_config.quant)?; + let c_proj = linear( + i_size, + h_size, + vb.pp("down_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { c_fc1, c_fc2, @@ -358,7 +375,12 @@ impl Llama { pub fn load(vb: VarBuilder, cfg: &Config, dtype: DType, device: &Device) -> Result { let wte = embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?; - let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specifi_config.quant, + )?; let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.num_hidden_layers) .map(|i| Block::load(vb.pp(&format!("model.layers.{i}")), cfg, dtype, device).unwrap()) diff --git a/src/openai/models/mistral.rs b/src/openai/models/mistral.rs index 80cf71b..23b407a 100644 --- a/src/openai/models/mistral.rs +++ b/src/openai/models/mistral.rs @@ -1,7 +1,8 @@ use super::Config; -use crate::openai::models::linear::{linear_no_bias, Linear}; +use crate::openai::models::linear::{linear_no_bias_x as linear_no_bias, LinearX as Linear}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_nn::{Activation, VarBuilder}; use candle_transformers::models::with_tracing::RmsNorm; @@ -28,7 +29,12 @@ pub struct MistralConfig { } impl MistralConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -53,6 +59,7 @@ impl MistralConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specifi_config: scfg.clone(), } } } @@ -124,9 +131,24 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("gate_proj"), + &cfg.specifi_config.quant, + )?; + let up_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + &cfg.specifi_config.quant, + )?; + let down_proj = linear_no_bias( + intermediate_sz, + hidden_sz, + vb.pp("down_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { gate_proj, up_proj, @@ -163,10 +185,30 @@ impl Attention { let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; let head_dim = hidden_sz / num_heads; - let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let q_proj = linear_no_bias( + hidden_sz, + num_heads * head_dim, + vb.pp("q_proj"), + &cfg.specifi_config.quant, + )?; + let k_proj = linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("k_proj"), + &cfg.specifi_config.quant, + )?; + let v_proj = linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("v_proj"), + &cfg.specifi_config.quant, + )?; + let o_proj = linear_no_bias( + num_heads * head_dim, + hidden_sz, + vb.pp("o_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { q_proj, k_proj, @@ -322,7 +364,12 @@ impl Mistral { layers.push(layer) } let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear_no_bias( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specifi_config.quant, + )?; Ok(Self { embed_tokens, layers, diff --git a/src/openai/models/mod.rs b/src/openai/models/mod.rs index ac7d85d..1ae2ea4 100644 --- a/src/openai/models/mod.rs +++ b/src/openai/models/mod.rs @@ -7,6 +7,7 @@ pub mod phi3; pub mod qwen2; pub mod stable_lm; pub mod yi; +use crate::SpecificConfig; use candle_core::DType; use either::Either; use serde::Deserialize; @@ -45,6 +46,7 @@ pub struct Config { pub kv_cache_dtype: DType, pub use_qkv_bias: Option, pub custom_stop_tokens: Option>, + pub specifi_config: SpecificConfig, } impl Config { diff --git a/src/openai/models/phi2.rs b/src/openai/models/phi2.rs index fa9e2d5..48067e8 100644 --- a/src/openai/models/phi2.rs +++ b/src/openai/models/phi2.rs @@ -1,7 +1,8 @@ use super::Config; -use crate::openai::models::linear::{linear_no_bias as linear, Linear}; +use crate::openai::models::linear::{linear_no_bias_x as linear, LinearX as Linear}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use candle_transformers::models::with_tracing::{layer_norm, Embedding, LayerNorm}; @@ -32,7 +33,12 @@ pub struct Phi2Config { } impl Phi2Config { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -57,6 +63,7 @@ impl Phi2Config { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specifi_config: scfg.clone(), } } } @@ -115,8 +122,18 @@ struct MLP { impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { - let fc1 = linear(cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"))?; - let fc2 = linear(cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"))?; + let fc1 = linear( + cfg.hidden_size, + cfg.intermediate_size, + vb.pp("fc1"), + &cfg.specifi_config.quant, + )?; + let fc2 = linear( + cfg.intermediate_size, + cfg.hidden_size, + vb.pp("fc2"), + &cfg.specifi_config.quant, + )?; Ok(Self { fc1, fc2, @@ -153,10 +170,30 @@ impl Attention { let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; let head_dim = cfg.hidden_size / cfg.num_attention_heads; - let q_proj = linear(cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear(cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let dense = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("dense"))?; + let q_proj = linear( + cfg.hidden_size, + num_heads * head_dim, + vb.pp("q_proj"), + &cfg.specifi_config.quant, + )?; + let k_proj = linear( + cfg.hidden_size, + num_kv_heads * head_dim, + vb.pp("k_proj"), + &cfg.specifi_config.quant, + )?; + let v_proj = linear( + cfg.hidden_size, + num_kv_heads * head_dim, + vb.pp("v_proj"), + &cfg.specifi_config.quant, + )?; + let dense = linear( + num_heads * head_dim, + cfg.hidden_size, + vb.pp("dense"), + &cfg.specifi_config.quant, + )?; // Alternative rope scalings are not supported. let rotary_emb = RotaryEmbedding::new(cfg, dtype, vb.device())?; let (q_layernorm, k_layernorm) = if cfg.qk_layer_rms_norm.unwrap() { @@ -324,7 +361,12 @@ impl Phi2 { let layer = DecoderLayer::new(cfg, dtype, vb_m.pp(layer_idx))?; layers.push(layer) } - let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specifi_config.quant, + )?; Ok(Self { embed_tokens, layers, diff --git a/src/openai/models/phi3.rs b/src/openai/models/phi3.rs index c45e891..9c10033 100644 --- a/src/openai/models/phi3.rs +++ b/src/openai/models/phi3.rs @@ -1,9 +1,10 @@ // This implementation is based on: // https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py use super::{Config, RopeScaling}; -use crate::openai::models::linear::{linear_no_bias as linear, Linear}; +use crate::openai::models::linear::{linear_no_bias_x as linear, LinearX as Linear}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_core as candle; use candle_nn::VarBuilder; @@ -33,7 +34,12 @@ pub struct PhiConfig { } impl PhiConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -58,6 +64,7 @@ impl PhiConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specifi_config: scfg.clone(), } } } @@ -235,8 +242,18 @@ impl Attention { let num_kv_heads = cfg.num_key_value_heads; let head_dim = cfg.hidden_size / cfg.num_attention_heads; let op_size = num_heads * head_dim + 2 * num_kv_heads * head_dim; - let qkv_proj = linear(cfg.hidden_size, op_size, vb.pp("qkv_proj"))?; - let o_proj = linear(num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"))?; + let qkv_proj = linear( + cfg.hidden_size, + op_size, + vb.pp("qkv_proj"), + &cfg.specifi_config.quant, + )?; + let o_proj = linear( + num_heads * head_dim, + cfg.hidden_size, + vb.pp("o_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { qkv_proj, o_proj, @@ -340,8 +357,18 @@ impl Mlp { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_size = cfg.hidden_size; let i_size = cfg.intermediate_size; - let gate_up_proj = linear(hidden_size, 2 * i_size, vb.pp("gate_up_proj"))?; - let down_proj = linear(i_size, hidden_size, vb.pp("down_proj"))?; + let gate_up_proj = linear( + hidden_size, + 2 * i_size, + vb.pp("gate_up_proj"), + &cfg.specifi_config.quant, + )?; + let down_proj = linear( + i_size, + hidden_size, + vb.pp("down_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { gate_up_proj, down_proj, @@ -430,7 +457,12 @@ impl Phi { layers.push(layer) } let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specifi_config.quant, + )?; Ok(Self { embed_tokens, layers, diff --git a/src/openai/models/qwen2.rs b/src/openai/models/qwen2.rs index 40fd59d..e446617 100644 --- a/src/openai/models/qwen2.rs +++ b/src/openai/models/qwen2.rs @@ -1,7 +1,10 @@ use super::Config; -use crate::openai::models::linear::{linear, linear_no_bias, Linear}; +use crate::openai::models::linear::{ + linear_no_bias_x as linear_no_bias, linear_x as linear, LinearX as Linear, +}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_core as candle; use candle_nn::VarBuilder; @@ -31,7 +34,12 @@ pub struct QwenConfig { } impl QwenConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -56,6 +64,7 @@ impl QwenConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, + specifi_config: scfg.clone(), } } } @@ -125,9 +134,24 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("gate_proj"), + &cfg.specifi_config.quant, + )?; + let up_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + &cfg.specifi_config.quant, + )?; + let down_proj = linear_no_bias( + intermediate_sz, + hidden_sz, + vb.pp("down_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { gate_proj, up_proj, @@ -164,10 +188,30 @@ impl Attention { let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; let head_dim = hidden_sz / num_heads; - let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let q_proj = linear( + hidden_sz, + num_heads * head_dim, + vb.pp("q_proj"), + &cfg.specifi_config.quant, + )?; + let k_proj = linear( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("k_proj"), + &cfg.specifi_config.quant, + )?; + let v_proj = linear( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("v_proj"), + &cfg.specifi_config.quant, + )?; + let o_proj = linear_no_bias( + num_heads * head_dim, + hidden_sz, + vb.pp("o_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { q_proj, k_proj, @@ -330,6 +374,7 @@ impl Qwen2 { } else { vb.pp("lm_head") }, + &cfg.specifi_config.quant, )?; Ok(Self { embed_tokens, diff --git a/src/openai/models/stable_lm.rs b/src/openai/models/stable_lm.rs index b2acc54..1c116ef 100644 --- a/src/openai/models/stable_lm.rs +++ b/src/openai/models/stable_lm.rs @@ -1,7 +1,10 @@ use super::Config; -use crate::openai::models::linear::{linear, linear_no_bias, Linear}; +use crate::openai::models::linear::{ + linear_no_bias_x as linear_no_bias, linear_x as linear, LinearX as Linear, +}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor, D}; use candle_nn::{Activation, LayerNorm, VarBuilder}; use either::Either; @@ -31,7 +34,12 @@ pub struct StableLMConfig { } impl StableLMConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -59,6 +67,7 @@ impl StableLMConfig { kv_cache_dtype, use_qkv_bias: Some(self.use_qkv_bias.unwrap_or(false)), custom_stop_tokens: None, + specifi_config: scfg.clone(), } } } @@ -125,9 +134,24 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("gate_proj"), + &cfg.specifi_config.quant, + )?; + let up_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + &cfg.specifi_config.quant, + )?; + let down_proj = linear_no_bias( + intermediate_sz, + hidden_sz, + vb.pp("down_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { gate_proj, up_proj, @@ -173,10 +197,30 @@ impl Attention { linear_no_bias }; - let q_proj = linear_layer(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear_layer(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let q_proj = linear_layer( + hidden_sz, + num_heads * head_dim, + vb.pp("q_proj"), + &cfg.specifi_config.quant, + )?; + let k_proj = linear_layer( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("k_proj"), + &cfg.specifi_config.quant, + )?; + let v_proj = linear_layer( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("v_proj"), + &cfg.specifi_config.quant, + )?; + let o_proj = linear_no_bias( + num_heads * head_dim, + hidden_sz, + vb.pp("o_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { q_proj, k_proj, @@ -333,7 +377,12 @@ impl StableLM { layers.push(layer) } let norm = candle_nn::layer_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear_no_bias( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specifi_config.quant, + )?; Ok(Self { embed_tokens, layers, diff --git a/src/openai/models/yi.rs b/src/openai/models/yi.rs index 20de25a..295dd7a 100644 --- a/src/openai/models/yi.rs +++ b/src/openai/models/yi.rs @@ -1,7 +1,8 @@ use super::Config; -use crate::openai::models::linear::{linear_no_bias, Linear}; +use crate::openai::models::linear::{linear_no_bias_x as linear_no_bias, LinearX as Linear}; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; +use crate::SpecificConfig; use candle_core::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_nn::{Activation, VarBuilder}; use candle_transformers::models::with_tracing::RmsNorm; @@ -28,7 +29,12 @@ pub struct YiConfig { } impl YiConfig { - pub fn into_config(self, use_flash_attn: bool, kv_cache_dtype: DType) -> Config { + pub fn into_config( + self, + use_flash_attn: bool, + kv_cache_dtype: DType, + scfg: &SpecificConfig, + ) -> Config { Config { hidden_size: self.hidden_size, intermediate_size: self.intermediate_size, @@ -53,6 +59,7 @@ impl YiConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: Some(vec!["<|im_end|>".to_string()]), + specifi_config: scfg.clone(), } } } @@ -123,9 +130,24 @@ impl MLP { fn new(cfg: &Config, vb: VarBuilder) -> Result { let hidden_sz = cfg.hidden_size; let intermediate_sz = cfg.intermediate_size; - let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; - let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; - let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + let gate_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("gate_proj"), + &cfg.specifi_config.quant, + )?; + let up_proj = linear_no_bias( + hidden_sz, + intermediate_sz, + vb.pp("up_proj"), + &cfg.specifi_config.quant, + )?; + let down_proj = linear_no_bias( + intermediate_sz, + hidden_sz, + vb.pp("down_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { gate_proj, up_proj, @@ -162,10 +184,30 @@ impl Attention { let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; let head_dim = hidden_sz / num_heads; - let q_proj = linear_no_bias(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; - let k_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; - let v_proj = linear_no_bias(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; - let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + let q_proj = linear_no_bias( + hidden_sz, + num_heads * head_dim, + vb.pp("q_proj"), + &cfg.specifi_config.quant, + )?; + let k_proj = linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("k_proj"), + &cfg.specifi_config.quant, + )?; + let v_proj = linear_no_bias( + hidden_sz, + num_kv_heads * head_dim, + vb.pp("v_proj"), + &cfg.specifi_config.quant, + )?; + let o_proj = linear_no_bias( + num_heads * head_dim, + hidden_sz, + vb.pp("o_proj"), + &cfg.specifi_config.quant, + )?; Ok(Self { q_proj, k_proj, @@ -319,7 +361,12 @@ impl Yi { layers.push(layer) } let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; - let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?; + let lm_head = linear_no_bias( + cfg.hidden_size, + cfg.vocab_size, + vb.pp("lm_head"), + &cfg.specifi_config.quant, + )?; Ok(Self { embed_tokens, layers, diff --git a/src/openai/pipelines/pipeline.rs b/src/openai/pipelines/pipeline.rs index 38ab155..2bf5340 100644 --- a/src/openai/pipelines/pipeline.rs +++ b/src/openai/pipelines/pipeline.rs @@ -26,7 +26,7 @@ use crate::{ PipelineConfig, }, paged_attention::input_metadata::InputMetadata, - try_api, + try_api, SpecificConfig, }; use candle_core::{DType, Device, IndexOp, Tensor}; use candle_examples::token_output_stream::TokenOutputStream; @@ -42,37 +42,6 @@ const EOS_TOKEN: &str = ""; const SAMPLING_SEED: u64 = 299792458; const MIN_GEN_TOKENS: usize = 128; const MAX_GEN_TOKENS: usize = 4096; - -#[derive(Debug, Clone)] -pub struct SpecificConfig { - repeat_last_n: Option, - temperature: Option, - top_k: Option, - top_p: Option, - penalty: Option, - max_gen_tokens: Option, -} - -impl SpecificConfig { - pub fn new( - repeat_last_n: Option, - temperature: Option, - top_k: Option, - top_p: Option, - penalty: Option, - max_gen_tokens: Option, - ) -> Self { - Self { - repeat_last_n, - temperature, - top_k, - top_p, - penalty, - max_gen_tokens, - } - } -} - enum LLMModel { LLAMA(Llama), Phi2(Phi2), @@ -176,50 +145,50 @@ impl ModelLoader for DefaultLoader { let config: LlamaConfig = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "phi2" => { let config: Phi2Config = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); //Phi2 use F32 type for kvcache - config.into_config(false, DType::F32) + config.into_config(false, DType::F32, &specific_args) } "phi3" => { let config: PhiConfig = try_api!(serde_json::from_slice(&try_api!(std::fs::read( paths.get_config_filename() )),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "qwen2" => { let config: QwenConfig = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "gemma" => { let config: GemmaConfig = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "mistral" => { let config: MistralConfig = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "yi" => { let config: YiConfig = try_api!(serde_json::from_slice(&try_api!(std::fs::read( paths.get_config_filename() )),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } "stablelm" => { let config: StableLMConfig = try_api!(serde_json::from_slice(&try_api!( std::fs::read(paths.get_config_filename()) ),)); - config.into_config(false, dtype) + config.into_config(false, dtype, &specific_args) } _ => panic!("Model not supported!"), }; diff --git a/tests/tests.rs b/tests/tests.rs index 7a65c69..ccd9842 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -24,6 +24,7 @@ async fn test_llama() -> Result<(), APIError> { penalty: Some(1.1), temperature: None, max_gen_tokens: Some(512), + quant: None, }, Some("meta-llama/Llama-2-7b-chat-hf".to_string()), ); From 6e791f5a19e8c7eb574cdd85acfb37627fc0ceca Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 13 Aug 2024 14:12:14 +0800 Subject: [PATCH 02/15] Typo fix --- src/openai/models/gemma.rs | 18 +++++++++--------- src/openai/models/llama.rs | 18 +++++++++--------- src/openai/models/mistral.rs | 18 +++++++++--------- src/openai/models/mod.rs | 2 +- src/openai/models/phi2.rs | 16 ++++++++-------- src/openai/models/phi3.rs | 12 ++++++------ src/openai/models/qwen2.rs | 18 +++++++++--------- src/openai/models/stable_lm.rs | 18 +++++++++--------- src/openai/models/yi.rs | 18 +++++++++--------- 9 files changed, 69 insertions(+), 69 deletions(-) diff --git a/src/openai/models/gemma.rs b/src/openai/models/gemma.rs index 9cb54df..0a8c6e3 100644 --- a/src/openai/models/gemma.rs +++ b/src/openai/models/gemma.rs @@ -68,7 +68,7 @@ impl GemmaConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, - specifi_config: scfg.clone(), + specific_config: scfg.clone(), } } } @@ -147,19 +147,19 @@ impl MLP { hidden_sz, intermediate_sz, vb.pp("gate_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let up_proj = linear( hidden_sz, intermediate_sz, vb.pp("up_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let down_proj = linear( intermediate_sz, hidden_sz, vb.pp("down_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { gate_proj, @@ -203,28 +203,28 @@ impl Attention { num_heads * head_dim, bias, vb.pp("q_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let k_proj = linear_b( hidden_sz, num_kv_heads * head_dim, bias, vb.pp("k_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let v_proj = linear_b( hidden_sz, num_kv_heads * head_dim, bias, vb.pp("v_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let o_proj = linear_b( num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { q_proj, @@ -390,7 +390,7 @@ impl Gemma { let lm_head = Linear::new( embed_tokens.embeddings().clone(), None, - &cfg.specifi_config.quant, + &cfg.specific_config.quant, ); Ok(Self { embed_tokens, diff --git a/src/openai/models/llama.rs b/src/openai/models/llama.rs index e76f98d..7b3fd65 100644 --- a/src/openai/models/llama.rs +++ b/src/openai/models/llama.rs @@ -62,7 +62,7 @@ impl LlamaConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, - specifi_config: scfg.clone(), + specific_config: scfg.clone(), } } } @@ -191,10 +191,10 @@ impl CausalSelfAttention { let size_in = cfg.hidden_size; let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; - let q_proj = linear(size_in, size_q, vb.pp("q_proj"), &cfg.specifi_config.quant)?; - let k_proj = linear(size_in, size_kv, vb.pp("k_proj"), &cfg.specifi_config.quant)?; - let v_proj = linear(size_in, size_kv, vb.pp("v_proj"), &cfg.specifi_config.quant)?; - let o_proj = linear(size_q, size_in, vb.pp("o_proj"), &cfg.specifi_config.quant)?; + let q_proj = linear(size_in, size_q, vb.pp("q_proj"), &cfg.specific_config.quant)?; + let k_proj = linear(size_in, size_kv, vb.pp("k_proj"), &cfg.specific_config.quant)?; + let v_proj = linear(size_in, size_kv, vb.pp("v_proj"), &cfg.specific_config.quant)?; + let o_proj = linear(size_q, size_in, vb.pp("o_proj"), &cfg.specific_config.quant)?; let head_dim = cfg.hidden_size / cfg.num_attention_heads; Ok(Self { q_proj, @@ -243,14 +243,14 @@ impl Mlp { h_size, i_size, vb.pp("gate_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; - let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"), &cfg.specifi_config.quant)?; + let c_fc2 = linear(h_size, i_size, vb.pp("up_proj"), &cfg.specific_config.quant)?; let c_proj = linear( i_size, h_size, vb.pp("down_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { c_fc1, @@ -379,7 +379,7 @@ impl Llama { cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let ln_f = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?; let blocks: Vec<_> = (0..cfg.num_hidden_layers) diff --git a/src/openai/models/mistral.rs b/src/openai/models/mistral.rs index 23b407a..038e267 100644 --- a/src/openai/models/mistral.rs +++ b/src/openai/models/mistral.rs @@ -59,7 +59,7 @@ impl MistralConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, - specifi_config: scfg.clone(), + specific_config: scfg.clone(), } } } @@ -135,19 +135,19 @@ impl MLP { hidden_sz, intermediate_sz, vb.pp("gate_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let up_proj = linear_no_bias( hidden_sz, intermediate_sz, vb.pp("up_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let down_proj = linear_no_bias( intermediate_sz, hidden_sz, vb.pp("down_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { gate_proj, @@ -189,25 +189,25 @@ impl Attention { hidden_sz, num_heads * head_dim, vb.pp("q_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let k_proj = linear_no_bias( hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let v_proj = linear_no_bias( hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let o_proj = linear_no_bias( num_heads * head_dim, hidden_sz, vb.pp("o_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { q_proj, @@ -368,7 +368,7 @@ impl Mistral { cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { embed_tokens, diff --git a/src/openai/models/mod.rs b/src/openai/models/mod.rs index 1ae2ea4..1cb94e6 100644 --- a/src/openai/models/mod.rs +++ b/src/openai/models/mod.rs @@ -46,7 +46,7 @@ pub struct Config { pub kv_cache_dtype: DType, pub use_qkv_bias: Option, pub custom_stop_tokens: Option>, - pub specifi_config: SpecificConfig, + pub specific_config: SpecificConfig, } impl Config { diff --git a/src/openai/models/phi2.rs b/src/openai/models/phi2.rs index 48067e8..4ee25a6 100644 --- a/src/openai/models/phi2.rs +++ b/src/openai/models/phi2.rs @@ -63,7 +63,7 @@ impl Phi2Config { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, - specifi_config: scfg.clone(), + specific_config: scfg.clone(), } } } @@ -126,13 +126,13 @@ impl MLP { cfg.hidden_size, cfg.intermediate_size, vb.pp("fc1"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let fc2 = linear( cfg.intermediate_size, cfg.hidden_size, vb.pp("fc2"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { fc1, @@ -174,25 +174,25 @@ impl Attention { cfg.hidden_size, num_heads * head_dim, vb.pp("q_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let k_proj = linear( cfg.hidden_size, num_kv_heads * head_dim, vb.pp("k_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let v_proj = linear( cfg.hidden_size, num_kv_heads * head_dim, vb.pp("v_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let dense = linear( num_heads * head_dim, cfg.hidden_size, vb.pp("dense"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; // Alternative rope scalings are not supported. let rotary_emb = RotaryEmbedding::new(cfg, dtype, vb.device())?; @@ -365,7 +365,7 @@ impl Phi2 { cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { embed_tokens, diff --git a/src/openai/models/phi3.rs b/src/openai/models/phi3.rs index 9c10033..f750129 100644 --- a/src/openai/models/phi3.rs +++ b/src/openai/models/phi3.rs @@ -64,7 +64,7 @@ impl PhiConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, - specifi_config: scfg.clone(), + specific_config: scfg.clone(), } } } @@ -246,13 +246,13 @@ impl Attention { cfg.hidden_size, op_size, vb.pp("qkv_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let o_proj = linear( num_heads * head_dim, cfg.hidden_size, vb.pp("o_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { qkv_proj, @@ -361,13 +361,13 @@ impl Mlp { hidden_size, 2 * i_size, vb.pp("gate_up_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let down_proj = linear( i_size, hidden_size, vb.pp("down_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { gate_up_proj, @@ -461,7 +461,7 @@ impl Phi { cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { embed_tokens, diff --git a/src/openai/models/qwen2.rs b/src/openai/models/qwen2.rs index e446617..8b835b2 100644 --- a/src/openai/models/qwen2.rs +++ b/src/openai/models/qwen2.rs @@ -64,7 +64,7 @@ impl QwenConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: None, - specifi_config: scfg.clone(), + specific_config: scfg.clone(), } } } @@ -138,19 +138,19 @@ impl MLP { hidden_sz, intermediate_sz, vb.pp("gate_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let up_proj = linear_no_bias( hidden_sz, intermediate_sz, vb.pp("up_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let down_proj = linear_no_bias( intermediate_sz, hidden_sz, vb.pp("down_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { gate_proj, @@ -192,25 +192,25 @@ impl Attention { hidden_sz, num_heads * head_dim, vb.pp("q_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let k_proj = linear( hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let v_proj = linear( hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let o_proj = linear_no_bias( num_heads * head_dim, hidden_sz, vb.pp("o_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { q_proj, @@ -374,7 +374,7 @@ impl Qwen2 { } else { vb.pp("lm_head") }, - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { embed_tokens, diff --git a/src/openai/models/stable_lm.rs b/src/openai/models/stable_lm.rs index 1c116ef..9388c28 100644 --- a/src/openai/models/stable_lm.rs +++ b/src/openai/models/stable_lm.rs @@ -67,7 +67,7 @@ impl StableLMConfig { kv_cache_dtype, use_qkv_bias: Some(self.use_qkv_bias.unwrap_or(false)), custom_stop_tokens: None, - specifi_config: scfg.clone(), + specific_config: scfg.clone(), } } } @@ -138,19 +138,19 @@ impl MLP { hidden_sz, intermediate_sz, vb.pp("gate_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let up_proj = linear_no_bias( hidden_sz, intermediate_sz, vb.pp("up_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let down_proj = linear_no_bias( intermediate_sz, hidden_sz, vb.pp("down_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { gate_proj, @@ -201,25 +201,25 @@ impl Attention { hidden_sz, num_heads * head_dim, vb.pp("q_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let k_proj = linear_layer( hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let v_proj = linear_layer( hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let o_proj = linear_no_bias( num_heads * head_dim, hidden_sz, vb.pp("o_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { q_proj, @@ -381,7 +381,7 @@ impl StableLM { cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { embed_tokens, diff --git a/src/openai/models/yi.rs b/src/openai/models/yi.rs index 295dd7a..3ff4ecd 100644 --- a/src/openai/models/yi.rs +++ b/src/openai/models/yi.rs @@ -59,7 +59,7 @@ impl YiConfig { kv_cache_dtype, use_qkv_bias: None, custom_stop_tokens: Some(vec!["<|im_end|>".to_string()]), - specifi_config: scfg.clone(), + specific_config: scfg.clone(), } } } @@ -134,19 +134,19 @@ impl MLP { hidden_sz, intermediate_sz, vb.pp("gate_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let up_proj = linear_no_bias( hidden_sz, intermediate_sz, vb.pp("up_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let down_proj = linear_no_bias( intermediate_sz, hidden_sz, vb.pp("down_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { gate_proj, @@ -188,25 +188,25 @@ impl Attention { hidden_sz, num_heads * head_dim, vb.pp("q_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let k_proj = linear_no_bias( hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let v_proj = linear_no_bias( hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; let o_proj = linear_no_bias( num_heads * head_dim, hidden_sz, vb.pp("o_proj"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { q_proj, @@ -365,7 +365,7 @@ impl Yi { cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"), - &cfg.specifi_config.quant, + &cfg.specific_config.quant, )?; Ok(Self { embed_tokens, From 504398d5527ffbda52a5098378277f8d681ad3d3 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 13 Aug 2024 14:15:35 +0800 Subject: [PATCH 03/15] Cargo fmt --- src/openai/models/llama.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/openai/models/llama.rs b/src/openai/models/llama.rs index 7b3fd65..168139f 100644 --- a/src/openai/models/llama.rs +++ b/src/openai/models/llama.rs @@ -192,8 +192,18 @@ impl CausalSelfAttention { let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; let q_proj = linear(size_in, size_q, vb.pp("q_proj"), &cfg.specific_config.quant)?; - let k_proj = linear(size_in, size_kv, vb.pp("k_proj"), &cfg.specific_config.quant)?; - let v_proj = linear(size_in, size_kv, vb.pp("v_proj"), &cfg.specific_config.quant)?; + let k_proj = linear( + size_in, + size_kv, + vb.pp("k_proj"), + &cfg.specific_config.quant, + )?; + let v_proj = linear( + size_in, + size_kv, + vb.pp("v_proj"), + &cfg.specific_config.quant, + )?; let o_proj = linear(size_q, size_in, vb.pp("o_proj"), &cfg.specific_config.quant)?; let head_dim = cfg.hidden_size / cfg.num_attention_heads; Ok(Self { From a3e1fc4ecab7729a277cbdc92917cfc8d1094540 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Tue, 13 Aug 2024 16:06:36 +0800 Subject: [PATCH 04/15] Optimize quantized matmul in batch processing & update Q4K results --- README.md | 30 +++++++---- src/openai/models/linear.rs | 101 ++++++++++++++++++++++++++++-------- 2 files changed, 100 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 25a9ccd..f51b896 100644 --- a/README.md +++ b/README.md @@ -12,19 +12,19 @@ Efficient, easy-to-use platform for inference and serving local LLMs including a - Streaming support in generation. - Efficient management of key-value cache with PagedAttention. - Continuous batching. -- In-situ quantization +- `In-situ` quantization ## Develop Status Currently, candle-vllm supports chat serving for the following models. -| Model ID | Model Type | Supported | Speed (A100, BF16) | Throughput (bs=16) | Quantized (A100, Q8_0) | +| Model ID | Model Type | Supported | Speed (A100, `BF16`) | Throughput (`BF16`, `bs=16`) | Quantized (A100, `Q4K`) | |--|--|--|--|--|--| -| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)| 553 tks/s (LLaMa3.1 8B) | 65 tks/s (LLaMa3.1 8B) | -| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | 78 tks/s (7B) | +| #1 | **LLAMA/LLAMA2/LLaMa3/LLaMa3.1** |✅|74 tks/s (7B), 65 tks/s (LLaMa3.1 8B)| 553 tks/s (LLaMa3.1 8B) | 75 tks/s (LLaMa3.1 8B) | +| #2 | **Mistral** |✅|70 tks/s (7B)| 585 tks/s (7B) | 96 tks/s (7B) | | #3 | **Phi (v1, v1.5, v2)** |✅|97 tks/s (2.7B, F32+BF16)|TBD|-| -| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 744 tks/s (3.8B)|116 tks/s (3.8B)| -| #5 | **Yi** |✅|75 tks/s (6B)| 566 tks/s (6B) | 79 tks/s (6B)| +| #4 | **Phi-3 (3.8B, 7B)** |✅|107 tks/s (3.8B)| 744 tks/s (3.8B)|135 tks/s (3.8B)| +| #5 | **Yi** |✅|75 tks/s (6B)| 566 tks/s (6B) | 105 tks/s (6B)| | #6 | **StableLM** |✅|99 tks/s (3B)|TBD|-| | #7 | BigCode/StarCode |TBD|TBD|TBD |-| | #8 | ChatGLM |TBD|TBD|TBD |-| @@ -190,15 +190,19 @@ asyncio.run(benchmark()) ## In-situ quantization for consumer-grade GPUs -Candle-vllm now supports in-situ quantization, allowing the transformation of default weights (F32/F16/BF16) into any GGML format during model loading. This feature helps conserve GPU memory, making it more efficient for consumer-grade GPUs (e.g., RTX 4090). For example, 8-bit quantization can reduce memory usage to less than 20GB for 8B models, while 4-bit quantization can bring it down to under 22GB for 13B models. To use this feature, simply supply the quant parameter when running candle-vllm. +Candle-vllm now supports in-situ quantization, allowing the transformation of default weights (F32/F16/BF16) into any GGML format during model loading. This feature helps conserve GPU memory, making it more efficient for consumer-grade GPUs (e.g., RTX 4090). For example, 4-bit quantization can reduce GPU memory usage to less than 12GB for 8B models, while bring 13B models down to 24GB. To use this feature, simply supply the quant parameter when running candle-vllm. ``` -cargo run --release -- --port 2000 --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --quant q8_0 +cargo run --release -- --port 2000 --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --quant q4k ``` Options for `quant` parameters: ["q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2k", "q3k","q4k","q5k","q6k"] -**Please note** that batched processing still requires optimization when operating in quantization mode. +**Please note**: + +1) It may takes few minutes to load F32/F16/BF16 models into quantized; + +2) Batched processing still requires further optimizations when operating in quantization mode. ## Usage Help For general configuration help, run `cargo run -- --help`. @@ -237,6 +241,14 @@ cargo run --release -- --port 2000 --weight-path /home/mistral_7b/ mistral --rep `--max-gen-tokens` parameter is used to control the maximum output tokens per chat response. The value will be set to 1/5 of max_sequence_len by default. +For `consumer GPUs`, it is suggested to run the models under GGML formats, e.g., + +``` +cargo run --release -- --port 2000 --weight-path /home/Meta-Llama-3.1-8B-Instruct/ llama3 --quant q4k +``` + +where `quant` is one of ["q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2k", "q3k","q4k","q5k","q6k"]. + ## Report issue Installing `candle-vllm` is as simple as the following steps. If you have any problems, please create an [issue](https://github.com/EricLBuehler/candle-lora/issues). diff --git a/src/openai/models/linear.rs b/src/openai/models/linear.rs index 61f3001..d4b2bce 100644 --- a/src/openai/models/linear.rs +++ b/src/openai/models/linear.rs @@ -262,31 +262,25 @@ impl QLinear { } } -impl Module for QLinear { - fn forward(&self, x: &Tensor) -> Result { - let xs = if self.is_quant() { - let x1 = match *x.dims() { - [bsize, seq_len, dim1, dim2] => { - if seq_len > 1 { - x.to_dtype(DType::F32)? - } else { - x.reshape((bsize, dim1, dim2))?.to_dtype(DType::F32)? - } +impl QLinear { + pub fn forward_no_dequant(&self, x: &Tensor) -> Result { + let xs = match *x.dims() { + [bsize, seq_len, dim1, dim2] => { + if seq_len > 1 { + x.to_dtype(DType::F32)? + } else { + x.reshape((bsize, dim1, dim2))?.to_dtype(DType::F32)? } - [bsize, seq_len, dim] => { - if seq_len > 1 { - x.to_dtype(DType::F32)? - } else { - x.reshape((bsize, dim))?.to_dtype(DType::F32)? - } + } + [bsize, seq_len, dim] => { + if seq_len > 1 { + x.to_dtype(DType::F32)? + } else { + x.reshape((bsize, dim))?.to_dtype(DType::F32)? } - _ => x.to_dtype(DType::F32)?, - }; - x1 - } else { - x.clone() + } + _ => x.to_dtype(DType::F32)?, }; - let xs = match *x.dims() { [bsize, seq_len, dim1, _] => { if seq_len > 1 { @@ -311,6 +305,69 @@ impl Module for QLinear { xs.to_dtype(self.dtype) } } + + pub fn forward_via_f16(&self, x: &Tensor) -> Result { + let in_dtype = x.dtype(); + let w = self.inner.dequantize_f16()?; + let w = match *x.dims() { + [b1, seq_len, _, _] => { + if seq_len > 1 { + w.broadcast_left((b1, seq_len))?.t()? + } else { + w.t()? + } + } + [bsize, seq_len, _] => { + if seq_len > 1 { + w.broadcast_left(bsize)?.t()? + } else { + w.t()? + } + } + _ => w.t()?, + }; + let x = x.to_dtype(DType::F16)?; + let x = match *x.dims() { + [bsize, seq_len, dim1, dim2] => { + if seq_len > 1 { + x.matmul(&w)? + } else { + let wdim = w.dims()[w.dims().len() - 1]; + x.reshape((bsize * seq_len, dim1, dim2))? + .matmul(&w)? + .reshape((bsize, seq_len, dim1, wdim))? + } + } + [bsize, seq_len, dim] => { + if seq_len > 1 { + x.matmul(&w)? + } else { + let wdim = w.dims()[w.dims().len() - 1]; + x.reshape((bsize * seq_len, dim))? + .matmul(&w)? + .reshape((bsize, seq_len, wdim))? + } + } + _ => x.matmul(&w)?, + }; + + if let Some(bias) = &self.bias { + x.broadcast_add(bias)?.to_dtype(in_dtype) + } else { + x.to_dtype(in_dtype) + } + } +} + +impl Module for QLinear { + fn forward(&self, x: &Tensor) -> Result { + let batch = x.dims()[0]; + if batch > 4 { + self.forward_via_f16(x) //suitable for batched + } else { + self.forward_no_dequant(x) //faster in single-query + } + } } #[derive(Debug, Clone)] From 80f56ae036d2bf37152270c524aae8c01ff63aae Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 14 Aug 2024 13:28:32 +0800 Subject: [PATCH 05/15] Fix bug for non-stream response --- src/openai/openai_server.rs | 67 ++++++++++++++---------------- src/openai/pipelines/llm_engine.rs | 6 +-- 2 files changed, 34 insertions(+), 39 deletions(-) diff --git a/src/openai/openai_server.rs b/src/openai/openai_server.rs index f4d899c..19ca9a5 100644 --- a/src/openai/openai_server.rs +++ b/src/openai/openai_server.rs @@ -178,24 +178,30 @@ pub async fn chat_completions( let (response_tx, rx) = flume::unbounded(); // println!("{:?}", sampling_params); - if request.stream.is_some_and(|x| x) { - let _ = tokio::task::spawn_blocking(move || { - tokio::runtime::Handle::current().block_on(async move { - { - //send completion request to inference engine - let mut model = data.model.lock().await; - model.add_request( - token_ids, - request_id.clone(), - SystemTime::now(), - sampling_params, - request.logprobs.unwrap_or(false), - Some(response_tx), - ); - model.notify.notify_one(); - } - }); + let finish_notify = data.finish_notify.clone(); + let data_clone = data.clone(); + let request_id_clone = request_id.clone(); + let stream_request = request.stream.is_some_and(|x| x); + let model_name = request.model.clone(); + let _ = tokio::task::spawn_blocking(move || { + tokio::runtime::Handle::current().block_on(async move { + { + //send completion request to inference engine + let mut model = data.model.lock().await; + model.add_request( + token_ids, + request_id.clone(), + SystemTime::now(), + sampling_params, + request.logprobs.unwrap_or(false), + Some(response_tx), + ); + model.notify.notify_one(); + } }); + }); + + if stream_request { ChatResponder::Streamer( Sse::new(Streamer { rx, @@ -212,35 +218,24 @@ pub async fn chat_completions( ), ) } else { - //send completion request to inference engine - let mut model = data.model.lock().await; - model.add_request( - token_ids, - request_id.clone(), - SystemTime::now(), - sampling_params, - request.logprobs.unwrap_or(false), - Some(response_tx), - ); - model.notify.notify_one(); // wait until current response finished - data.finish_notify.notified().await; - let model = data.model.lock().await; - if !model.completion_records.contains_key(&request_id) { + finish_notify.notified().await; + let model = data_clone.model.lock().await; + if !model.completion_records.contains_key(&request_id_clone) { return ChatResponder::ModelError(APIError::from(format!( "Unable to generate response for request {}", - request_id + request_id_clone ))); } - let choices = &model.completion_records[&request_id].0; - let usage = &model.completion_records[&request_id].1; + let choices = &model.completion_records[&request_id_clone].0; + let usage = &model.completion_records[&request_id_clone].1; ChatResponder::Completion(ChatCompletionResponse { - id: request_id, + id: request_id_clone, choices: choices.to_vec(), created: usage.created, - model: request.model.clone(), + model: model_name, object: "chat.completion", usage: usage.clone(), }) diff --git a/src/openai/pipelines/llm_engine.rs b/src/openai/pipelines/llm_engine.rs index 527eb41..d9bed06 100644 --- a/src/openai/pipelines/llm_engine.rs +++ b/src/openai/pipelines/llm_engine.rs @@ -93,9 +93,9 @@ impl LLMEngine { if result.len() == 0 { continue; } - - let _ = result.values() - .map(|usage| e.completion_records.insert(usage.1.request_id.clone(), usage.clone())); + for request_id in result.keys() { + e.completion_records.insert(request_id.to_string(), result[request_id].clone()); + } finish_notify.notify_one(); //chat completion statistics From bd476d316c31546a6506ff20f0b5e01893b3f0c0 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 14 Aug 2024 14:04:39 +0800 Subject: [PATCH 06/15] Ask users to provide huggingface token if no token cached and passed to the program. --- src/main.rs | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/main.rs b/src/main.rs index 1968be5..0b17e9f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -100,7 +100,30 @@ async fn main() -> Result<(), APIError> { safetensors_files }, }), - _ => loader.download_model(model_id, None, args.hf_token, args.hf_token_path)?, + _ => { + if args.hf_token.is_none() && args.hf_token_path.is_none() { + //no token provided + let token_path = format!( + "{}/.cache/huggingface/token", + dirs::home_dir() + .ok_or(APIError::new_str("No home directory"))? + .display() + ); + if !Path::new(&token_path).exists() { + //also no token cache + use std::io::Write; + let mut input_token = String::new(); + println!("Please provide your huggingface token to download model:\n"); + std::io::stdin() + .read_line(&mut input_token) + .expect("Failed to read token!"); + std::fs::create_dir_all(Path::new(&token_path).parent().unwrap()).unwrap(); + let mut output = std::fs::File::create(token_path).unwrap(); + write!(output, "{}", input_token.trim()).expect("Failed to save token!"); + } + } + loader.download_model(model_id, None, args.hf_token, args.hf_token_path)? + } }; let dtype = match args.dtype.as_deref() { From afb50f3ace6240ee632807bfb135e147be911392 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 14 Aug 2024 15:27:41 +0800 Subject: [PATCH 07/15] No crash when both hidden_act and hidden_activation are set for gemma model --- src/openai/models/gemma.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/openai/models/gemma.rs b/src/openai/models/gemma.rs index 0a8c6e3..289c50a 100644 --- a/src/openai/models/gemma.rs +++ b/src/openai/models/gemma.rs @@ -41,7 +41,10 @@ impl GemmaConfig { ) -> Config { let hidden_act = match (self.hidden_act, self.hidden_activation) { (None, Some(act)) | (Some(act), None) => Some(act), - (Some(_), Some(_)) => panic!("both hidden_act and hidden_activation are set"), + (Some(act), Some(_)) => { + println!("both hidden_act and hidden_activation are set"); + Some(act) + } (None, None) => panic!("none of hidden_act and hidden_activation are set"), }; Config { From 616ffc6488dab32353c09a81dff5fc923951298a Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 14 Aug 2024 15:28:31 +0800 Subject: [PATCH 08/15] Print the number of decoded tokens for each request --- examples/benchmark.py | 2 +- src/openai/pipelines/llm_engine.rs | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index ff51ce8..560cdf9 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -57,7 +57,7 @@ async def benchmark(): # avoid generating very short answers for i in range(len(prompts)): - prompts[i] = prompts[i] + " Describe in about {} words.".format((int(max_tokens / 1.3 / 10) + 1) * 10) + prompts[i] = prompts[i] + " Respond in more than {} words.".format((int(max_tokens / 10) + 1) * 10) # send 16 chat requests at the same time tasks: List[asyncio.Task] = [] diff --git a/src/openai/pipelines/llm_engine.rs b/src/openai/pipelines/llm_engine.rs index d9bed06..9a10e17 100644 --- a/src/openai/pipelines/llm_engine.rs +++ b/src/openai/pipelines/llm_engine.rs @@ -272,9 +272,12 @@ impl LLMEngine { .duration_since(prompt_finish_time) .unwrap() .as_millis(); + let seq = group.get_seqs().values().nth(0).unwrap(); + let decoded_tokens = seq.deref().get_len() - seq.deref().get_prompt_len(); println!( - "Request {} decoding finished in {} seconds", + "Request {} decoding {} tokens finished in {} seconds", group.request_id, + decoded_tokens, completion_time_costs / 1000 ); // Create choices from the group From 360a22768fa1ce9d249da7909558f45da9a5d834 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 21 Aug 2024 16:55:31 +0800 Subject: [PATCH 09/15] Restore previous bug fix --- kernels/src/lib.rs | 8 ++- src/openai/openai_server.rs | 28 +++++--- src/openai/pipelines/llm_engine.rs | 112 ++++++++++++++--------------- 3 files changed, 78 insertions(+), 70 deletions(-) diff --git a/kernels/src/lib.rs b/kernels/src/lib.rs index c43d980..5903b07 100644 --- a/kernels/src/lib.rs +++ b/kernels/src/lib.rs @@ -1,4 +1,6 @@ -pub const COPY_BLOCKS_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/copy_blocks_kernel.ptx")); +pub const COPY_BLOCKS_KERNEL: &str = + include_str!(concat!(env!("OUT_DIR"), "/copy_blocks_kernel.ptx")); pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx")); -pub const RESHAPE_AND_CACHE_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx")); -pub mod ffi; +pub const RESHAPE_AND_CACHE_KERNEL: &str = + include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx")); +pub mod ffi; \ No newline at end of file diff --git a/src/openai/openai_server.rs b/src/openai/openai_server.rs index 06e2b3d..44f05b8 100644 --- a/src/openai/openai_server.rs +++ b/src/openai/openai_server.rs @@ -184,17 +184,23 @@ pub async fn chat_completions( let stream_request = request.stream.is_some_and(|x| x); let model_name = request.model.clone(); - //send completion request to inference engine - let mut model = data.model.lock().await; - model.add_request( - token_ids, - request_id.clone(), - SystemTime::now(), - sampling_params, - request.logprobs.unwrap_or(false), - Some(response_tx), - ); - model.notify.notify_one(); + let _ = tokio::task::spawn_blocking(move || { + tokio::runtime::Handle::current().block_on(async move { + { + //send completion request to inference engine + let mut model = data.model.lock().await; + model.add_request( + token_ids, + request_id.clone(), + SystemTime::now(), + sampling_params, + request.logprobs.unwrap_or(false), + Some(response_tx), + ); + model.notify.notify_one(); + } + }); + }); if stream_request { ChatResponder::Streamer( diff --git a/src/openai/pipelines/llm_engine.rs b/src/openai/pipelines/llm_engine.rs index ed121f1..1e57f04 100644 --- a/src/openai/pipelines/llm_engine.rs +++ b/src/openai/pipelines/llm_engine.rs @@ -83,63 +83,63 @@ impl LLMEngine { })); let engine_clone = engine.clone(); - tokio::runtime::Handle::current().block_on(async move { - loop { - notify.notified().await; // Blocking call to wait for notification - let _ = tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; - let mut e = engine.lock().await; - let result = e.generate_once().unwrap(); - if result.is_empty() { - continue; - } - for request_id in result.keys() { - e.completion_records - .insert(request_id.to_string(), result[request_id].clone()); - } - finish_notify.notify_one(); - - //chat completion statistics - let overall_usage = ChatCompletionUsageResponse { - request_id: "".to_string(), - created: 0, - completion_tokens: result - .values() - .map(|(_, usage)| usage.completion_tokens) - .sum(), - prompt_tokens: result.values().map(|(_, usage)| usage.prompt_tokens).sum(), - total_tokens: result.values().map(|(_, usage)| usage.total_tokens).sum(), - prompt_time_costs: result - .values() - .map(|(_, usage)| usage.prompt_time_costs) - .max() - .unwrap_or(0), - completion_time_costs: result - .values() - .map(|(_, usage)| usage.completion_time_costs) - .max() - .unwrap_or(0), - }; + let _ = tokio::task::spawn_blocking(move || { + tokio::runtime::Handle::current().block_on(async move { + loop { + notify.notified().await; // Blocking call to wait for notification + let _ = tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + let mut e = engine.lock().await; + let result = e.generate_once().unwrap(); + if result.len() == 0 { + continue; + } + for request_id in result.keys() { + e.completion_records.insert(request_id.to_string(), result[request_id].clone()); + } + finish_notify.notify_one(); + + //chat completion statistics + let overall_usage = ChatCompletionUsageResponse { + request_id: "".to_string(), + created: 0, + completion_tokens: result.values() + .map(|(_, usage)| usage.completion_tokens) + .sum(), + prompt_tokens: result.values().map(|(_, usage)| usage.prompt_tokens).sum(), + total_tokens: result.values().map(|(_, usage)| usage.total_tokens).sum(), + prompt_time_costs: result + .values() + .map(|(_, usage)| usage.prompt_time_costs) + .max() + .unwrap_or(0), + completion_time_costs: result + .values() + .map(|(_, usage)| usage.completion_time_costs) + .max() + .unwrap_or(0), + }; - println!( - "\r\n [{} requests] Prefilling: {} prompt tokens processed in {} seconds", - result.len(), - overall_usage.prompt_tokens, - overall_usage.prompt_time_costs / 1000 - ); - - println!( - "\r\n [{} requests] Decoding: {} tokens processed in {} seconds ({} tokens/s)", - result.len(), - overall_usage.completion_tokens, - overall_usage.completion_time_costs / 1000, - overall_usage.completion_tokens * 1000 - / if overall_usage.completion_time_costs > 0 { - overall_usage.completion_time_costs - } else { - 1 - } - ); - } + println!( + "\r\n [{} requests] Prefilling: {} prompt tokens processed in {} seconds", + result.len(), + overall_usage.prompt_tokens, + overall_usage.prompt_time_costs / 1000 + ); + + println!( + "\r\n [{} requests] Decoding: {} tokens processed in {} seconds ({} tokens/s)", + result.len(), + overall_usage.completion_tokens, + overall_usage.completion_time_costs / 1000, + overall_usage.completion_tokens * 1000 + / if overall_usage.completion_time_costs > 0 { + overall_usage.completion_time_costs + } else { + 1 + } + ); + } + }); }); Ok(engine_clone) From a33884f11e97b25fd7c4124c3ea25ca30ac99e23 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 21 Aug 2024 18:18:55 +0800 Subject: [PATCH 10/15] Support softcapping (Gemma-2 models) --- examples/benchmark.py | 2 +- kernels/src/ffi.rs | 2 + kernels/src/lib.rs | 2 +- kernels/src/pagedattention.cu | 56 +++++++++++++++----- src/backend/paged_attention.rs | 6 ++- src/openai/models/gemma.rs | 96 ++++++++++++++++++++++++++++------ src/openai/models/llama.rs | 4 ++ src/openai/models/mistral.rs | 4 ++ src/openai/models/mod.rs | 6 ++- src/openai/models/phi2.rs | 4 ++ src/openai/models/phi3.rs | 4 ++ src/openai/models/qwen2.rs | 4 ++ src/openai/models/stable_lm.rs | 4 ++ src/openai/models/yi.rs | 4 ++ src/paged_attention/mod.rs | 6 +++ src/scheduler/cache_engine.rs | 4 +- 16 files changed, 172 insertions(+), 36 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index 560cdf9..1595b02 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -57,7 +57,7 @@ async def benchmark(): # avoid generating very short answers for i in range(len(prompts)): - prompts[i] = prompts[i] + " Respond in more than {} words.".format((int(max_tokens / 10) + 1) * 10) + prompts[i] = prompts[i] + " Respond in more than {} words.".format(int(max_tokens / 10) * 10) # send 16 chat requests at the same time tasks: List[asyncio.Task] = [] diff --git a/kernels/src/ffi.rs b/kernels/src/ffi.rs index 423db2c..6c1078d 100644 --- a/kernels/src/ffi.rs +++ b/kernels/src/ffi.rs @@ -40,6 +40,7 @@ extern "C" { kv_head_stride: c_int, dtype: u32, + softscapping: f32, ); pub fn paged_attention_v2( @@ -66,5 +67,6 @@ extern "C" { kv_head_stride: c_int, dtype: u32, + softscapping: f32, ); } diff --git a/kernels/src/lib.rs b/kernels/src/lib.rs index 5903b07..54dce19 100644 --- a/kernels/src/lib.rs +++ b/kernels/src/lib.rs @@ -3,4 +3,4 @@ pub const COPY_BLOCKS_KERNEL: &str = pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedattention.ptx")); pub const RESHAPE_AND_CACHE_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx")); -pub mod ffi; \ No newline at end of file +pub mod ffi; diff --git a/kernels/src/pagedattention.cu b/kernels/src/pagedattention.cu index f3a5082..6127677 100644 --- a/kernels/src/pagedattention.cu +++ b/kernels/src/pagedattention.cu @@ -73,6 +73,20 @@ inline __device__ float block_sum(float* red_smem, float sum) { return VLLM_SHFL_SYNC(sum, 0); } +inline __device__ float fast_tanh(float x) { + #if defined(__CUDA_ARCH__) + #if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDA_ARCH__ >= 750) + float y; + asm volatile ( "tanh.approx.f32 %0, %1; " : "=f"(y) : "f"(x)); + return y; + #else + return ::tanhf(x); + #endif + #else + return std::tanh(x); + #endif +} + // TODO(woosuk): Merge the last two dimensions of the grid. // Grid: (num_heads, num_seqs, max_num_partitions). template< @@ -96,7 +110,8 @@ __device__ void paged_attention_kernel( const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, - const int kv_head_stride) { + const int kv_head_stride, + const float softscapping) { const int seq_idx = blockIdx.y; const int partition_idx = blockIdx.z; const int max_num_partitions = gridDim.z; @@ -212,6 +227,10 @@ __device__ void paged_attention_kernel( // Compute dot product. // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + + if (softscapping != 1.0) { + qk = fast_tanh(qk / softscapping) * softscapping; + } // Add the ALiBi bias if slopes are given. qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0; @@ -409,11 +428,12 @@ __global__ void paged_attention_v1_kernel( const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, - const int kv_head_stride) { + const int kv_head_stride, + const float softscapping) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride, softscapping); } // Grid: (num_heads, num_seqs, max_num_partitions). @@ -438,11 +458,12 @@ __global__ void paged_attention_v2_kernel( const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, - const int kv_head_stride) { + const int kv_head_stride, + const float softscapping) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride); + q_stride, kv_block_stride, kv_head_stride, softscapping); } // Grid: (num_heads, num_seqs). @@ -564,7 +585,8 @@ __global__ void paged_attention_v2_reduce_kernel( alibi_slopes_ptr, \ q_stride, \ kv_block_stride, \ - kv_head_stride); + kv_head_stride,\ + softscapping); // TODO(woosuk): Tune NUM_THREADS. template< @@ -588,7 +610,8 @@ void paged_attention_v1_launcher( int max_num_blocks_per_seq, int q_stride, int kv_block_stride, - int kv_head_stride + int kv_head_stride, + float softscapping ) { // int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); @@ -652,7 +675,8 @@ void paged_attention_v1_launcher( max_num_blocks_per_seq, \ q_stride, \ kv_block_stride, \ - kv_head_stride); + kv_head_stride, \ + softscapping); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -691,7 +715,8 @@ extern "C" void paged_attention_v1( int32_t kv_block_stride, int32_t kv_head_stride, - uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32 + uint32_t dtype, // 0 => f16; 1 => bf16; 2 => f32 + float softscapping ) { if (dtype == 2) { CALL_V1_LAUNCHER_BLOCK_SIZE(float); @@ -719,7 +744,8 @@ extern "C" void paged_attention_v1( alibi_slopes, \ q_stride, \ kv_block_stride, \ - kv_head_stride); \ + kv_head_stride,\ + softscapping); \ vllm::paged_attention_v2_reduce_kernel \ <<>>( \ reinterpret_cast(out), \ @@ -754,8 +780,8 @@ void paged_attention_v2_launcher( int max_num_blocks_per_seq, int q_stride, int kv_block_stride, - int kv_head_stride - + int kv_head_stride, + float softscapping ) { // int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); @@ -825,7 +851,8 @@ void paged_attention_v2_launcher( max_num_blocks_per_seq, \ q_stride, \ kv_block_stride, \ - kv_head_stride); + kv_head_stride,\ + softscapping); // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. @@ -867,7 +894,8 @@ extern "C" void paged_attention_v2( int32_t kv_block_stride, int32_t kv_head_stride, - uint32_t dtype // 0 => f16; 1 => bf16; 2 => f32 + uint32_t dtype, // 0 => f16; 1 => bf16; 2 => f32 + float softscapping ) { if (dtype == 2) { CALL_V2_LAUNCHER_BLOCK_SIZE(float); diff --git a/src/backend/paged_attention.rs b/src/backend/paged_attention.rs index 784a1cc..412191d 100644 --- a/src/backend/paged_attention.rs +++ b/src/backend/paged_attention.rs @@ -11,7 +11,7 @@ use std::ffi::c_int; struct PagedAttention { softmax_scale: f32, - + softcapping: f32, key_cache: Tensor, value_cache: Tensor, block_tables: Tensor, @@ -187,6 +187,7 @@ impl PagedAttention { kv_block_stride as c_int, kv_head_stride as c_int, internal_type, + self.softcapping, ) } } else { @@ -223,6 +224,7 @@ impl PagedAttention { kv_block_stride as c_int, kv_head_stride as c_int, internal_type, + self.softcapping, ) } } @@ -277,6 +279,7 @@ pub fn paged_attention( context_lens: &Tensor, max_context_len: usize, softmax_scale: f32, + softcapping: f32, ) -> Result { let op = PagedAttention { softmax_scale, @@ -285,6 +288,7 @@ pub fn paged_attention( block_tables: block_tables.clone(), context_lens: context_lens.clone(), max_context_len, + softcapping, }; q.apply_op1(op) } diff --git a/src/openai/models/gemma.rs b/src/openai/models/gemma.rs index 2a444a6..c1e177f 100644 --- a/src/openai/models/gemma.rs +++ b/src/openai/models/gemma.rs @@ -15,7 +15,7 @@ use std::sync::Arc; #[derive(serde::Deserialize, Debug, Clone)] pub struct GemmaConfig { pub attention_bias: bool, - pub head_dim: usize, + pub head_dim: Option, // The code gemma configs include both hidden_act and hidden_activation. pub hidden_act: Option, pub hidden_activation: Option, @@ -30,6 +30,8 @@ pub struct GemmaConfig { pub bos_token_id: usize, pub eos_token_id: usize, pub max_position_embeddings: Option, + pub attn_logit_softcapping: Option, + pub final_logit_softcapping: Option, } impl GemmaConfig { @@ -49,6 +51,10 @@ impl GemmaConfig { }; Config { hidden_size: self.hidden_size, + head_dim: Some( + self.head_dim + .unwrap_or(self.hidden_size / self.num_attention_heads), + ), intermediate_size: self.intermediate_size, vocab_size: self.vocab_size, num_hidden_layers: self.num_hidden_layers, @@ -72,6 +78,8 @@ impl GemmaConfig { use_qkv_bias: None, custom_stop_tokens: None, specific_config: scfg.clone(), + attn_logit_softcapping: self.attn_logit_softcapping, + final_logit_softcapping: self.final_logit_softcapping, } } } @@ -89,7 +97,7 @@ struct RotaryEmbedding { impl RotaryEmbedding { fn new(_dtype: DType, cfg: &Config, dev: &Device) -> Result { - let dim = cfg.hidden_size / cfg.num_attention_heads; + let dim = cfg.get_head_size(); let max_seq_len = cfg.max_seq_len; let inv_freq: Vec<_> = (0..dim) .step_by(2) @@ -199,7 +207,7 @@ impl Attention { let hidden_sz = cfg.hidden_size; let num_heads = cfg.num_attention_heads; let num_kv_heads = cfg.num_key_value_heads; - let head_dim = cfg.hidden_size / cfg.num_attention_heads; + let head_dim = cfg.head_dim.unwrap(); let bias = cfg.attention_bias; let q_proj = linear_b( hidden_sz, @@ -258,6 +266,7 @@ impl Attention { input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, input_metadata: &mut InputMetadata, + softcapping: Option, ) -> Result { let (b_sz, seq_len, _) = xs.dims3()?; @@ -307,13 +316,13 @@ impl Attention { cache.map(|(k_, _)| k_.clone()), cache.map(|(_, v_)| v_.clone()), input_metadata, + softcapping, )?; let y = if attention_mask.is_some() { - y.transpose(1, 2)? - .reshape(&[b_sz, seq_len, self.hidden_size])? + y.transpose(1, 2)?.reshape((b_sz, seq_len, ()))? } else { - y.reshape(&[b_sz, seq_len, self.hidden_size])? + y.reshape((b_sz, seq_len, ()))? }; let y = self.o_proj.forward(&y)?; Ok(y) @@ -324,6 +333,8 @@ struct DecoderLayer { self_attn: Attention, mlp: MLP, input_layernorm: RmsNorm, + post_feedforward_layernorm: Option, + pre_feedforward_layernorm: Option, post_attention_layernorm: RmsNorm, } @@ -333,6 +344,27 @@ impl DecoderLayer { let mlp = MLP::new(cfg, vb.pp("mlp"))?; let input_layernorm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + + let pre_feedforward_layernorm = if cfg.attn_logit_softcapping.is_some() { + Some(rms_norm( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("pre_feedforward_layernorm"), + )?) + } else { + None + }; + + let post_feedforward_layernorm = if cfg.attn_logit_softcapping.is_some() { + Some(rms_norm( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_feedforward_layernorm"), + )?) + } else { + None + }; + let post_attention_layernorm = rms_norm( cfg.hidden_size, cfg.rms_norm_eps, @@ -342,6 +374,8 @@ impl DecoderLayer { self_attn, mlp, input_layernorm, + pre_feedforward_layernorm, + post_feedforward_layernorm, post_attention_layernorm, }) } @@ -353,16 +387,38 @@ impl DecoderLayer { input_positions: &[Vec], cache: Option<(&Tensor, &Tensor)>, input_metadata: &mut InputMetadata, + softcapping: Option, ) -> Result { let residual = xs; let xs = self.input_layernorm.forward(xs)?; - let xs = - self.self_attn - .forward(&xs, attention_mask, input_positions, cache, input_metadata)?; - let xs = (xs + residual)?; - let residual = &xs; - let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; - residual + xs + let xs = self.self_attn.forward( + &xs, + attention_mask, + input_positions, + cache, + input_metadata, + softcapping, + )?; + + if softcapping.is_some() { + let xs = xs.apply(&self.post_attention_layernorm)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = match &self.pre_feedforward_layernorm { + Some(l) => l.forward(&xs)?, + None => xs.clone(), + }; + let xs = xs.apply(&self.mlp)?; + let xs = match &self.post_feedforward_layernorm { + Some(l) => l.forward(&xs)?, + None => xs, + }; + residual + xs + } else { + let xs = (xs + residual)?; + let residual = &xs; + residual + xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)? + } } } @@ -440,6 +496,7 @@ impl Gemma { input_positions, Some((k_cache, v_cache)), input_metadata, + self.cfg.attn_logit_softcapping, )? } } else { @@ -450,14 +507,21 @@ impl Gemma { input_positions, None, input_metadata, + self.cfg.attn_logit_softcapping, )? } } - xs.i((.., seq_len - 1, ..))? + let logits = xs + .i((.., seq_len - 1, ..))? .apply(&self.norm)? - .apply(&self.lm_head)? - .to_dtype(DType::F32) + .apply(&self.lm_head)?; + + let logits = match self.cfg.final_logit_softcapping { + None => logits, + Some(sc) => ((logits / sc)?.tanh()? * sc)?, + }; + logits.to_dtype(DType::F32) } pub fn get_config(&self) -> &Config { diff --git a/src/openai/models/llama.rs b/src/openai/models/llama.rs index 2bcd088..616fe5b 100644 --- a/src/openai/models/llama.rs +++ b/src/openai/models/llama.rs @@ -40,6 +40,7 @@ impl LlamaConfig { ) -> Config { Config { hidden_size: self.hidden_size, + head_dim: Some(self.hidden_size / self.num_attention_heads), intermediate_size: self.intermediate_size, vocab_size: self.vocab_size, num_hidden_layers: self.num_hidden_layers, @@ -63,6 +64,8 @@ impl LlamaConfig { use_qkv_bias: None, custom_stop_tokens: None, specific_config: scfg.clone(), + attn_logit_softcapping: None, + final_logit_softcapping: None, } } } @@ -174,6 +177,7 @@ impl CausalSelfAttention { cache.map(|(k_, _)| k_.clone()), cache.map(|(_, v_)| v_.clone()), input_metadata, + None, )?; let y = if attention_mask.is_some() { diff --git a/src/openai/models/mistral.rs b/src/openai/models/mistral.rs index 8c309b6..22838e1 100644 --- a/src/openai/models/mistral.rs +++ b/src/openai/models/mistral.rs @@ -37,6 +37,7 @@ impl MistralConfig { ) -> Config { Config { hidden_size: self.hidden_size, + head_dim: Some(self.hidden_size / self.num_attention_heads), intermediate_size: self.intermediate_size, vocab_size: self.vocab_size, num_hidden_layers: self.num_hidden_layers, @@ -60,6 +61,8 @@ impl MistralConfig { use_qkv_bias: None, custom_stop_tokens: None, specific_config: scfg.clone(), + attn_logit_softcapping: None, + final_logit_softcapping: None, } } } @@ -281,6 +284,7 @@ impl Attention { cache.map(|(k_, _)| k_.clone()), cache.map(|(_, v_)| v_.clone()), input_metadata, + None, )?; let y = if attention_mask.is_some() { diff --git a/src/openai/models/mod.rs b/src/openai/models/mod.rs index 1cb94e6..b40449b 100644 --- a/src/openai/models/mod.rs +++ b/src/openai/models/mod.rs @@ -24,6 +24,7 @@ pub struct TokenID( #[derive(Debug, Clone)] pub struct Config { pub hidden_size: usize, + pub head_dim: Option, pub intermediate_size: usize, pub vocab_size: usize, pub num_hidden_layers: usize, @@ -47,10 +48,13 @@ pub struct Config { pub use_qkv_bias: Option, pub custom_stop_tokens: Option>, pub specific_config: SpecificConfig, + pub attn_logit_softcapping: Option, + pub final_logit_softcapping: Option, } impl Config { pub fn get_head_size(&self) -> usize { - self.hidden_size / self.num_attention_heads + self.head_dim + .unwrap_or(self.hidden_size / self.num_attention_heads) } } diff --git a/src/openai/models/phi2.rs b/src/openai/models/phi2.rs index 8082e52..4a31341 100644 --- a/src/openai/models/phi2.rs +++ b/src/openai/models/phi2.rs @@ -41,6 +41,7 @@ impl Phi2Config { ) -> Config { Config { hidden_size: self.hidden_size, + head_dim: Some(self.hidden_size / self.num_attention_heads), intermediate_size: self.intermediate_size, vocab_size: self.vocab_size, num_hidden_layers: self.num_hidden_layers, @@ -64,6 +65,8 @@ impl Phi2Config { use_qkv_bias: None, custom_stop_tokens: None, specific_config: scfg.clone(), + attn_logit_softcapping: None, + final_logit_softcapping: None, } } } @@ -284,6 +287,7 @@ impl Attention { cache.map(|(k_, _)| k_.clone()), cache.map(|(_, v_)| v_.clone()), input_metadata, + None, )?; let y = if attention_mask.is_some() { diff --git a/src/openai/models/phi3.rs b/src/openai/models/phi3.rs index 259527f..78bcc76 100644 --- a/src/openai/models/phi3.rs +++ b/src/openai/models/phi3.rs @@ -42,6 +42,7 @@ impl PhiConfig { ) -> Config { Config { hidden_size: self.hidden_size, + head_dim: Some(self.hidden_size / self.num_attention_heads), intermediate_size: self.intermediate_size, vocab_size: self.vocab_size, num_hidden_layers: self.num_hidden_layers, @@ -65,6 +66,8 @@ impl PhiConfig { use_qkv_bias: None, custom_stop_tokens: None, specific_config: scfg.clone(), + attn_logit_softcapping: None, + final_logit_softcapping: None, } } } @@ -332,6 +335,7 @@ impl Attention { cache.map(|(k_, _)| k_.clone()), cache.map(|(_, v_)| v_.clone()), input_metadata, + None, )?; let y = if attention_mask.is_some() { diff --git a/src/openai/models/qwen2.rs b/src/openai/models/qwen2.rs index 2d8c7ea..67788e2 100644 --- a/src/openai/models/qwen2.rs +++ b/src/openai/models/qwen2.rs @@ -42,6 +42,7 @@ impl QwenConfig { ) -> Config { Config { hidden_size: self.hidden_size, + head_dim: Some(self.hidden_size / self.num_attention_heads), intermediate_size: self.intermediate_size, vocab_size: self.vocab_size, num_hidden_layers: self.num_hidden_layers, @@ -65,6 +66,8 @@ impl QwenConfig { use_qkv_bias: None, custom_stop_tokens: None, specific_config: scfg.clone(), + attn_logit_softcapping: None, + final_logit_softcapping: None, } } } @@ -283,6 +286,7 @@ impl Attention { cache.map(|(k_, _)| k_.clone()), cache.map(|(_, v_)| v_.clone()), input_metadata, + None, )?; let y = if attention_mask.is_some() { diff --git a/src/openai/models/stable_lm.rs b/src/openai/models/stable_lm.rs index 9977c94..faaceed 100644 --- a/src/openai/models/stable_lm.rs +++ b/src/openai/models/stable_lm.rs @@ -42,6 +42,7 @@ impl StableLMConfig { ) -> Config { Config { hidden_size: self.hidden_size, + head_dim: Some(self.hidden_size / self.num_attention_heads), intermediate_size: self.intermediate_size, vocab_size: self.vocab_size, num_hidden_layers: self.num_hidden_layers, @@ -68,6 +69,8 @@ impl StableLMConfig { use_qkv_bias: Some(self.use_qkv_bias.unwrap_or(false)), custom_stop_tokens: None, specific_config: scfg.clone(), + attn_logit_softcapping: None, + final_logit_softcapping: None, } } } @@ -295,6 +298,7 @@ impl Attention { cache.map(|(k_, _)| k_.clone()), cache.map(|(_, v_)| v_.clone()), input_metadata, + None, )?; let y = if attention_mask.is_some() { diff --git a/src/openai/models/yi.rs b/src/openai/models/yi.rs index 199097e..f663c57 100644 --- a/src/openai/models/yi.rs +++ b/src/openai/models/yi.rs @@ -37,6 +37,7 @@ impl YiConfig { ) -> Config { Config { hidden_size: self.hidden_size, + head_dim: Some(self.hidden_size / self.num_attention_heads), intermediate_size: self.intermediate_size, vocab_size: self.vocab_size, num_hidden_layers: self.num_hidden_layers, @@ -60,6 +61,8 @@ impl YiConfig { use_qkv_bias: None, custom_stop_tokens: Some(vec!["<|im_end|>".to_string()]), specific_config: scfg.clone(), + attn_logit_softcapping: None, + final_logit_softcapping: None, } } } @@ -280,6 +283,7 @@ impl Attention { cache.map(|(k_, _)| k_.clone()), cache.map(|(_, v_)| v_.clone()), input_metadata, + None, )?; let y = if attention_mask.is_some() { diff --git a/src/paged_attention/mod.rs b/src/paged_attention/mod.rs index c2a020a..7ce04e4 100644 --- a/src/paged_attention/mod.rs +++ b/src/paged_attention/mod.rs @@ -67,6 +67,7 @@ impl PagedAttention { mut key_cache: Option, mut value_cache: Option, input_metadata: &mut InputMetadata, + softcapping: Option, ) -> Result { let dims = input_metadata.slot_mapping.dims(); let slot_mapping = if dims.len() > 1 { @@ -94,6 +95,10 @@ impl PagedAttention { } else { (query.matmul(&key.t()?)? * f64::from(self.scale))? }; + let att = match softcapping { + None => att, + Some(sc) => ((att / sc)?.tanh()? * sc)?, + }; let att = att.broadcast_add(mask)?; let att = candle_nn::ops::softmax_last_dim(&att.to_dtype(DType::F32)?)? @@ -173,6 +178,7 @@ impl PagedAttention { input_metadata.context_lens.as_ref().unwrap(), input_metadata.max_context_len.unwrap(), self.scale, + softcapping.unwrap_or(1.0f64) as f32, ) } } diff --git a/src/scheduler/cache_engine.rs b/src/scheduler/cache_engine.rs index 07f0517..5ef85b4 100644 --- a/src/scheduler/cache_engine.rs +++ b/src/scheduler/cache_engine.rs @@ -161,7 +161,7 @@ impl CacheEngine { let x = 16 / element_size; ( model_config.num_key_value_heads, - model_config.hidden_size / model_config.num_attention_heads / x, + model_config.get_head_size() / x, block_size, x, ) @@ -173,7 +173,7 @@ impl CacheEngine { ) -> (usize, usize, usize) { ( model_config.num_key_value_heads, - model_config.hidden_size / model_config.num_attention_heads, + model_config.get_head_size(), block_size, ) } From 761067e168cdfde30ab4359da3b16baa91347bda Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Wed, 21 Aug 2024 18:24:23 +0800 Subject: [PATCH 11/15] Update lib.rs --- kernels/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/kernels/src/lib.rs b/kernels/src/lib.rs index 2bc1b60..54dce19 100644 --- a/kernels/src/lib.rs +++ b/kernels/src/lib.rs @@ -4,4 +4,3 @@ pub const PAGEDATTENTION: &str = include_str!(concat!(env!("OUT_DIR"), "/pagedat pub const RESHAPE_AND_CACHE_KERNEL: &str = include_str!(concat!(env!("OUT_DIR"), "/reshape_and_cache_kernel.ptx")); pub mod ffi; - From ff84499ea6cc5469ce6c94afb2108fb8aee7d74d Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Thu, 22 Aug 2024 10:28:21 +0800 Subject: [PATCH 12/15] Fix Gemma-2 multiple eos/bos ids --- src/openai/models/gemma.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/openai/models/gemma.rs b/src/openai/models/gemma.rs index c1e177f..98e4e46 100644 --- a/src/openai/models/gemma.rs +++ b/src/openai/models/gemma.rs @@ -2,6 +2,7 @@ use super::Config; use crate::openai::models::linear::{ linear_b_x as linear_b, linear_no_bias_x as linear, LinearX as Linear, }; +use crate::openai::models::TokenID; use crate::paged_attention::input_metadata::InputMetadata; use crate::paged_attention::PagedAttention; use crate::SpecificConfig; @@ -9,9 +10,9 @@ use candle::{DType, Device, IndexOp, Module, Result, Tensor}; use candle_core as candle; use candle_nn::Activation; use candle_nn::{RmsNorm, VarBuilder}; -use either::Either; use std::iter::zip; use std::sync::Arc; + #[derive(serde::Deserialize, Debug, Clone)] pub struct GemmaConfig { pub attention_bias: bool, @@ -27,8 +28,8 @@ pub struct GemmaConfig { pub rms_norm_eps: f64, pub rope_theta: f64, pub vocab_size: usize, - pub bos_token_id: usize, - pub eos_token_id: usize, + pub bos_token_id: TokenID, + pub eos_token_id: TokenID, pub max_position_embeddings: Option, pub attn_logit_softcapping: Option, pub final_logit_softcapping: Option, @@ -63,8 +64,8 @@ impl GemmaConfig { rms_norm_eps: self.rms_norm_eps, rope_theta: self.rope_theta, use_flash_attn, - bos_token_id: super::TokenID(Either::Left(Some(self.bos_token_id as u32))), - eos_token_id: super::TokenID(Either::Left(Some(self.eos_token_id as u32))), + bos_token_id: self.bos_token_id, + eos_token_id: self.eos_token_id, max_seq_len: self.max_position_embeddings.unwrap_or(4096), sliding_window: None, hidden_act, @@ -198,7 +199,6 @@ struct Attention { num_kv_heads: usize, head_dim: usize, rotary_emb: Arc, - hidden_size: usize, attn: PagedAttention, } @@ -246,7 +246,6 @@ impl Attention { num_kv_heads, head_dim, rotary_emb, - hidden_size: hidden_sz, attn: PagedAttention::new( cfg.num_attention_heads, head_dim, From 2c81291dc445610f0b0ccf3c552a925a2d1b8759 Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Fri, 23 Aug 2024 17:04:48 +0800 Subject: [PATCH 13/15] Custom benchmark with parameters --- examples/benchmark.py | 50 ++++++++++++++++++++++------------------- src/openai/streaming.rs | 2 -- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index 1595b02..72564bb 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -3,15 +3,26 @@ from openai import Stream from openai.types.chat import ChatCompletionChunk from typing import List -# Run: cargo run --release -- --port 2000 --model-id --repeat-last-n 64 +import argparse +# Run candle-vllm service: cargo run --release -- --port 2000 --model-id --repeat-last-n 64 # MODEL_ID is the huggingface model id or local weight path # MODEL_TYPE is one of ["llama", "llama3", "mistral", "phi2", "phi3", "qwen2", "gemma", "yi", "stable-lm"] - +# Then run this file: python3 examples/benchmark.py --batch 16 openai.api_key = "EMPTY" openai.base_url = "http://localhost:2000/v1/" +# You may add your custom prompts here +PROMPT_CANDIDATES = ["Explain how to best learn Rust.", + "Please talk about deep learning.", + "Do you know the capital city of China? Talk the details of you known.", + "Who is the best female actor in the world? Explain why.", + "Let me know how to deal with depression?", + "How to make money in short time?", + "What is the future trend of large language model?", + "The famous tech companies in the world."] + async def chat_completion(model, max_tokens, prompt): completion = openai.chat.completions.create( model=model, @@ -34,26 +45,12 @@ async def stream_response(response_idx, stream: Stream[ChatCompletionChunk]): result += r return (response_idx, result) -async def benchmark(): - model = "mistral7b" - max_tokens = 1024 - # 16 requests - prompts = ["Explain how to best learn Rust.", - "Please talk about deep learning.", - "Do you know the capital city of China? Talk the details of you known.", - "Who is the best female actor in the world? Explain why.", - "Let me know how to deal with depression?", - "How to make money in short time?", - "What is the future trend of large language model?", - "The famous tech companies in the world.", - "Explain how to best learn Rust.", - "Please talk about deep learning.", - "Do you know the capital city of China? Talk the details of you known.", - "Who is the best female actor in the world? Explain why.", - "Let me know how to deal with depression?", - "How to make money in short time?", - "What is the future trend of large language model?", - "The famous tech companies in the world."] +async def benchmark(batch, max_tokens=1024): + model = "any" # model used dependent on the server side + # candidate requests + prompts = [] + for i in range(batch): + prompts.append(PROMPT_CANDIDATES[i % len(PROMPT_CANDIDATES)]) # avoid generating very short answers for i in range(len(prompts)): @@ -86,4 +83,11 @@ async def benchmark(): print("\n\n Response {}: \n\n {}".format(idx, output)) -asyncio.run(benchmark()) \ No newline at end of file +if __name__ == "__main__": + batch = 1 + parser = argparse.ArgumentParser(description="Using 'batch' and 'max_tokens' parameters for candle-vllm benchmark.") + parser.add_argument('--batch', default=16, type=int) + parser.add_argument('--max_tokens', default=1024, type=int) + + args = parser.parse_args() + asyncio.run(benchmark(args.batch, args.max_tokens)) \ No newline at end of file diff --git a/src/openai/streaming.rs b/src/openai/streaming.rs index 455dab1..a88d019 100644 --- a/src/openai/streaming.rs +++ b/src/openai/streaming.rs @@ -50,11 +50,9 @@ impl Stream for Streamer { Poll::Ready(Some(Ok(Event::default().data("[DONE]")))) } }, - Err(e) => { if self.status == StreamingStatus::Started && e == flume::TryRecvError::Disconnected { - //no TryRecvError::Disconnected returned even if the client closed the stream or disconnected self.status = StreamingStatus::Interrupted; Poll::Ready(None) } else { From 221eace6e4935d4ca56dd0bfde77b86b79cac47e Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Fri, 23 Aug 2024 17:08:02 +0800 Subject: [PATCH 14/15] Mention arguments for benchmark.py --- README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.md b/README.md index f51b896..3cb96fb 100644 --- a/README.md +++ b/README.md @@ -134,6 +134,9 @@ After the `candle-vllm` service is running, run the Python script and enjoy effi ## Batched requests +``` shell +python3 examples/benchmark.py --batch 16 --max_tokens 1024 +``` Refer to `examples/benchmark.py` ``` python From 08f9491cd37fb6dd1a337312959493b346a908ec Mon Sep 17 00:00:00 2001 From: Guoqing Bao Date: Fri, 23 Aug 2024 17:09:24 +0800 Subject: [PATCH 15/15] Tweak --- examples/benchmark.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/benchmark.py b/examples/benchmark.py index 72564bb..d5df23b 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -84,10 +84,8 @@ async def benchmark(batch, max_tokens=1024): if __name__ == "__main__": - batch = 1 parser = argparse.ArgumentParser(description="Using 'batch' and 'max_tokens' parameters for candle-vllm benchmark.") parser.add_argument('--batch', default=16, type=int) parser.add_argument('--max_tokens', default=1024, type=int) - args = parser.parse_args() asyncio.run(benchmark(args.batch, args.max_tokens)) \ No newline at end of file