diff --git a/candle-examples/examples/quantized-phi/main.rs b/candle-examples/examples/quantized-phi/main.rs new file mode 100644 index 000000000..301c2e06e --- /dev/null +++ b/candle-examples/examples/quantized-phi/main.rs @@ -0,0 +1,273 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::{Parser, ValueEnum}; +use std::io::Write; +use tokenizers::Tokenizer; + +use candle::quantized::gguf_file; +use candle::Tensor; +use candle_transformers::generation::{LogitsProcessor, Sampling}; + +use candle_examples::token_output_stream::TokenOutputStream; +use candle_transformers::models::quantized_phi as model; +use model::ModelWeights; + +const DEFAULT_PROMPT: &str = "Write a function to count prime numbers up to N. "; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)] +enum Which { + #[value(name = "phi-2")] + Phi2, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp + #[arg(long)] + model: Option, + + /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way + /// and 'chat' for an interactive model where history of previous prompts and generated tokens + /// is preserved. + #[arg(long)] + prompt: Option, + + /// The length of the sample to generate (in tokens). + #[arg(short = 'n', long, default_value_t = 1000)] + sample_len: usize, + + /// The tokenizer config in json format. + #[arg(long)] + tokenizer: Option, + + /// The temperature used to generate samples, use 0 for greedy sampling. + #[arg(long, default_value_t = 0.8)] + temperature: f64, + + /// Nucleus sampling probability cutoff. + #[arg(long)] + top_p: Option, + + /// Only sample among the top K samples. + #[arg(long)] + top_k: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// Process prompt elements separately. + #[arg(long)] + split_prompt: bool, + + /// Run on CPU rather than GPU even if a GPU is available. + #[arg(long)] + cpu: bool, + + /// Penalty to be applied for repeating tokens, 1. means no penalty. + #[arg(long, default_value_t = 1.1)] + repeat_penalty: f32, + + /// The context size to consider for the repeat penalty. + #[arg(long, default_value_t = 64)] + repeat_last_n: usize, + + /// The model size to use. + #[arg(long, default_value = "phi-2")] + which: Which, +} + +impl Args { + fn tokenizer(&self) -> anyhow::Result { + let tokenizer_path = match &self.tokenizer { + Some(config) => std::path::PathBuf::from(config), + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("microsoft/phi-2".to_string()); + api.get("tokenizer.json")? + } + }; + Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg) + } + + fn model(&self) -> anyhow::Result { + let model_path = match &self.model { + Some(config) => std::path::PathBuf::from(config), + None => { + let (repo, filename) = match self.which { + Which::Phi2 => ("TheBloke/phi-2-GGUF", "phi-2.Q4_K_M.gguf"), + }; + let api = hf_hub::api::sync::Api::new()?; + let api = api.model(repo.to_string()); + api.get(filename)? + } + }; + Ok(model_path) + } +} + +fn format_size(size_in_bytes: usize) -> String { + if size_in_bytes < 1_000 { + format!("{}B", size_in_bytes) + } else if size_in_bytes < 1_000_000 { + format!("{:.2}KB", size_in_bytes as f64 / 1e3) + } else if size_in_bytes < 1_000_000_000 { + format!("{:.2}MB", size_in_bytes as f64 / 1e6) + } else { + format!("{:.2}GB", size_in_bytes as f64 / 1e9) + } +} + +fn main() -> anyhow::Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + println!( + "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", + args.temperature, args.repeat_penalty, args.repeat_last_n + ); + + let model_path = args.model()?; + let mut file = std::fs::File::open(&model_path)?; + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + + let mut model = { + let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?; + let mut total_size_in_bytes = 0; + for (_, tensor) in model.tensor_infos.iter() { + let elem_count = tensor.shape.elem_count(); + total_size_in_bytes += + elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size(); + } + println!( + "loaded {:?} tensors ({}) in {:.2}s", + model.tensor_infos.len(), + &format_size(total_size_in_bytes), + start.elapsed().as_secs_f32(), + ); + ModelWeights::from_gguf(model, &mut file, &device)? + }; + println!("model built"); + + let tokenizer = args.tokenizer()?; + let mut tos = TokenOutputStream::new(tokenizer); + let prompt_str = args.prompt.unwrap_or_else(|| DEFAULT_PROMPT.to_string()); + print!("{}", &prompt_str); + let tokens = tos + .tokenizer() + .encode(prompt_str, true) + .map_err(anyhow::Error::msg)?; + let tokens = tokens.get_ids(); + let to_sample = args.sample_len.saturating_sub(1); + let mut all_tokens = vec![]; + let mut logits_processor = { + let temperature = args.temperature; + let sampling = if temperature <= 0. { + Sampling::ArgMax + } else { + match (args.top_k, args.top_p) { + (None, None) => Sampling::All { temperature }, + (Some(k), None) => Sampling::TopK { k, temperature }, + (None, Some(p)) => Sampling::TopP { p, temperature }, + (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, + } + }; + LogitsProcessor::from_sampling(args.seed, sampling) + }; + + let start_prompt_processing = std::time::Instant::now(); + let mut next_token = if !args.split_prompt { + let input = Tensor::new(tokens, &device)?.unsqueeze(0)?; + let logits = model.forward(&input, 0)?; + let logits = logits.squeeze(0)?; + logits_processor.sample(&logits)? + } else { + let mut next_token = 0; + for (pos, token) in tokens.iter().enumerate() { + let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, pos)?; + let logits = logits.squeeze(0)?; + next_token = logits_processor.sample(&logits)? + } + next_token + }; + let prompt_dt = start_prompt_processing.elapsed(); + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + let eos_token = *tos + .tokenizer() + .get_vocab(true) + .get("<|endoftext|>") + .unwrap(); + let start_post_prompt = std::time::Instant::now(); + let mut sampled = 0; + for index in 0..to_sample { + let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?; + let logits = model.forward(&input, tokens.len() + index)?; + let logits = logits.squeeze(0)?; + let logits = if args.repeat_penalty == 1. { + logits + } else { + let start_at = all_tokens.len().saturating_sub(args.repeat_last_n); + candle_transformers::utils::apply_repeat_penalty( + &logits, + args.repeat_penalty, + &all_tokens[start_at..], + )? + }; + next_token = logits_processor.sample(&logits)?; + all_tokens.push(next_token); + if let Some(t) = tos.next_token(next_token)? { + print!("{t}"); + std::io::stdout().flush()?; + } + sampled += 1; + if next_token == eos_token { + break; + }; + } + if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? { + print!("{rest}"); + } + std::io::stdout().flush()?; + let dt = start_post_prompt.elapsed(); + println!( + "\n\n{:4} prompt tokens processed: {:.2} token/s", + tokens.len(), + tokens.len() as f64 / prompt_dt.as_secs_f64(), + ); + println!( + "{sampled:4} tokens generated: {:.2} token/s", + sampled as f64 / dt.as_secs_f64(), + ); + Ok(()) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index f4a719319..5f1a40ad0 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -37,6 +37,7 @@ pub mod quantized_mistral; pub mod quantized_mixformer; pub mod quantized_moondream; pub mod quantized_mpt; +pub mod quantized_phi; pub mod quantized_recurrent_gemma; pub mod quantized_rwkv_v5; pub mod quantized_rwkv_v6; diff --git a/candle-transformers/src/models/quantized_phi.rs b/candle-transformers/src/models/quantized_phi.rs new file mode 100644 index 000000000..0ebf7f4d4 --- /dev/null +++ b/candle-transformers/src/models/quantized_phi.rs @@ -0,0 +1,288 @@ +use std::collections::HashMap; + +use candle::quantized::gguf_file; +use candle::quantized::QTensor; +use candle::{DType, Device, IndexOp, Module, Result, Tensor, D}; +use candle_nn::{Embedding, LayerNorm}; + +pub const MAX_SEQ_LEN: usize = 4096; + +#[derive(Debug, Clone)] +struct QLinear { + inner: candle::quantized::QMatMul, + bias: Tensor, + span: tracing::Span, +} + +impl QLinear { + fn new( + ct: &gguf_file::Content, + r: &mut R, + name: &str, + device: &Device, + ) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "qmatmul"); + let w = ct.tensor(r, &format!("{name}.weight"), device)?; + let b = ct.tensor(r, &format!("{name}.bias"), device)?; + let inner = candle::quantized::QMatMul::from_qtensor(w)?; + let bias = b.dequantize(device)?; + Ok(Self { inner, bias, span }) + } +} + +impl Module for QLinear { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + self.inner.forward(xs)?.broadcast_add(&self.bias) + } +} + +#[derive(Debug, Clone)] +struct Mlp { + ffn_up: QLinear, + ffn_down: QLinear, +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + xs.apply(&self.ffn_up)?.gelu()?.apply(&self.ffn_down) + } +} + +#[derive(Debug, Clone)] +struct LayerWeights { + attn_qkv: QLinear, + attn_output: QLinear, + attn_norm: LayerNorm, + mlp: Mlp, + n_head: usize, + n_kv_head: usize, + head_dim: usize, + cos: Tensor, + sin: Tensor, + rope_dim: usize, + neg_inf: Tensor, + kv_cache: Option<(Tensor, Tensor)>, + span_attn: tracing::Span, + span_rot: tracing::Span, +} + +fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: &Tensor) -> Result { + let shape = mask.shape(); + let m = mask.where_cond(&on_true.broadcast_as(shape.dims())?, on_false)?; + Ok(m) +} + +impl LayerWeights { + fn apply_rotary_emb(&self, xs: &Tensor, index_pos: usize) -> Result { + let _enter = self.span_rot.enter(); + let (_b_sz, _n_head, seq_len, _n_embd) = xs.dims4()?; + let xs_rot = xs.i((.., .., .., ..self.rope_dim))?; + let xs_pass = xs.i((.., .., .., self.rope_dim..))?; + let cos = self.cos.narrow(0, index_pos, seq_len)?; + let sin = self.sin.narrow(0, index_pos, seq_len)?; + let xs_rot = candle_nn::rotary_emb::rope(&xs_rot.contiguous()?, &cos, &sin)?; + Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1) + } + + fn forward_attn( + &mut self, + x: &Tensor, + mask: Option<&Tensor>, + index_pos: usize, + ) -> Result { + let _enter = self.span_attn.enter(); + let (b_sz, seq_len, n_embd) = x.dims3()?; + let qkv = + self.attn_qkv + .forward(x)? + .reshape((b_sz, seq_len, 3, self.n_head, self.head_dim))?; + + let q = qkv.i((.., .., 0))?.transpose(1, 2)?; + let k = qkv.i((.., .., 1))?.transpose(1, 2)?; + let v = qkv.i((.., .., 2))?.transpose(1, 2)?; + // This call to contiguous ensures that the fast kernel can be called below. It's + // actually a no-op except when processing the initial prompt so has no significant + // impact on performance. + let v = v.contiguous()?; + + let q = self.apply_rotary_emb(&q, index_pos)?.contiguous()?; + let k = self.apply_rotary_emb(&k, index_pos)?; + + let (k, v) = match &self.kv_cache { + None => (k.contiguous()?, v.contiguous()?), + Some((k_cache, v_cache)) => { + if index_pos == 0 { + (k.contiguous()?, v.contiguous()?) + } else { + let k = Tensor::cat(&[k_cache, &k], 2)?; + let v = Tensor::cat(&[v_cache, &v], 2)?; + (k.contiguous()?, v.contiguous()?) + } + } + }; + self.kv_cache = Some((k.clone(), v.clone())); + + let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?; + let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?; + + let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; + let att = match mask { + None => att, + Some(mask) => { + let mask = mask.broadcast_as(att.shape())?; + masked_fill(&att, &mask, &self.neg_inf)? + } + }; + let att = candle_nn::ops::softmax_last_dim(&att)?; + // Convert to contiguous as matmul doesn't support strided vs for now. + let y = att.matmul(&v.contiguous()?)?; + let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?; + let y = self.attn_output.forward(&y)?; + Ok(y) + } +} + +#[derive(Debug, Clone)] +pub struct ModelWeights { + tok_embeddings: Embedding, + layers: Vec, + output_norm: LayerNorm, + output: QLinear, + masks: HashMap, + span: tracing::Span, + span_output: tracing::Span, +} + +fn precomput_freqs_cis( + head_dim: usize, + freq_base: f32, + device: &Device, +) -> Result<(Tensor, Tensor)> { + let theta: Vec<_> = (0..head_dim) + .step_by(2) + .map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32)) + .collect(); + let theta = Tensor::new(theta.as_slice(), device)?; + let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)? + .to_dtype(DType::F32)? + .reshape((MAX_SEQ_LEN, 1))? + .matmul(&theta.reshape((1, theta.elem_count()))?)?; + let cos = idx_theta.cos()?; + let sin = idx_theta.sin()?; + Ok((cos, sin)) +} + +fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result { + let w = w.dequantize(&w.device())?; + let b = b.dequantize(&b.device())?; + let ln = LayerNorm::new(w, b, eps); + Ok(ln) +} + +impl ModelWeights { + pub fn from_gguf( + ct: gguf_file::Content, + reader: &mut R, + device: &Device, + ) -> Result { + let md_get = |s: &str| match ct.metadata.get(s) { + None => candle::bail!("cannot find {s} in metadata"), + Some(v) => Ok(v), + }; + + // Parameter extraction from metadata. + let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize; + let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize; + let block_count = md_get("phi2.block_count")?.to_u32()? as usize; + let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize; + let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize; + let ln_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64; + let (cos, sin) = precomput_freqs_cis(rope_dim, 10_000., device)?; + let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?; + + let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?; + let tok_embeddings = tok_embeddings.dequantize(device)?; + let output_norm = layer_norm( + ct.tensor(reader, "output_norm.weight", device)?, + ct.tensor(reader, "output_norm.bias", device)?, + ln_eps, + )?; + let output = QLinear::new(&ct, reader, "output", device)?; + let mut layers = Vec::with_capacity(block_count); + for layer_idx in 0..block_count { + let prefix = format!("blk.{layer_idx}"); + let ffn_up = QLinear::new(&ct, reader, &format!("{prefix}.ffn_up"), device)?; + let ffn_down = QLinear::new(&ct, reader, &format!("{prefix}.ffn_down"), device)?; + let mlp = Mlp { ffn_up, ffn_down }; + let attn_norm = layer_norm( + ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?, + ct.tensor(reader, &format!("{prefix}.attn_norm.bias"), device)?, + ln_eps, + )?; + let span_attn = tracing::span!(tracing::Level::TRACE, "attn"); + let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot"); + layers.push(LayerWeights { + attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?, + attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?, + attn_norm, + mlp, + n_head: head_count, + n_kv_head: head_count_kv, + head_dim: embedding_length / head_count, + cos: cos.clone(), + sin: sin.clone(), + rope_dim, + neg_inf: neg_inf.clone(), + kv_cache: None, + span_attn, + span_rot, + }) + } + let span = tracing::span!(tracing::Level::TRACE, "model"); + let span_output = tracing::span!(tracing::Level::TRACE, "output"); + Ok(Self { + tok_embeddings: Embedding::new(tok_embeddings, embedding_length), + layers, + output_norm, + output, + masks: HashMap::new(), + span, + span_output, + }) + } + + fn mask(&mut self, t: usize, device: &Device) -> Result { + if let Some(mask) = self.masks.get(&t) { + Ok(mask.clone()) + } else { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u8::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), device)?; + self.masks.insert(t, mask.clone()); + Ok(mask) + } + } + + pub fn forward(&mut self, xs: &Tensor, index_pos: usize) -> Result { + let (_b_sz, seq_len) = xs.dims2()?; + let mask = if seq_len == 1 { + None + } else { + Some(self.mask(seq_len, xs.device())?) + }; + let _enter = self.span.enter(); + let mut xs = self.tok_embeddings.forward(xs)?; + for layer in self.layers.iter_mut() { + let residual = &xs; + let xs_norm = xs.apply(&layer.attn_norm)?; + let attn_outputs = layer.forward_attn(&xs_norm, mask.as_ref(), index_pos)?; + let feed_forward_hidden_states = layer.mlp.forward(&xs_norm)?; + xs = (attn_outputs + feed_forward_hidden_states + residual)? + } + let xs = xs.apply(&self.output_norm)?.i((.., seq_len - 1, ..))?; + let _enter = self.span_output.enter(); + self.output.forward(&xs) + } +}