From 6a54ca115eebfd367fbf09d3ea582f6b2e810e0d Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 28 Jul 2023 09:57:32 +0100 Subject: [PATCH] Add some Bigcode model (#260) * Start sketching the bigcode gpt model. * Sketch the bigcode model. * Implement the attention mechanism. * Random reshaping. * Sketch more of the example. * Add some kv cache. * Properly generate the position ids. * Proper attention mask. * Bail on upcasting. * Properly apply the attention mask. * Add the smaller starcoder variants. * Update for the new hub api. * Fix a shape issue. * Fix another shape issue. * Get some logits out. * Adjust the weigth names. --- candle-examples/examples/bigcode/main.rs | 161 ++++++++++ candle-examples/examples/bigcode/model.rs | 357 ++++++++++++++++++++++ 2 files changed, 518 insertions(+) create mode 100644 candle-examples/examples/bigcode/main.rs create mode 100644 candle-examples/examples/bigcode/model.rs diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs new file mode 100644 index 000000000..165d1c8c3 --- /dev/null +++ b/candle-examples/examples/bigcode/main.rs @@ -0,0 +1,161 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::{Error as E, Result}; +use clap::Parser; + +mod model; +use model::{Config, GPTBigCode}; + +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use candle_transformers::generation::LogitsProcessor; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::Tokenizer; + +struct TextGeneration { + model: GPTBigCode, + device: Device, + tokenizer: Tokenizer, + logits_processor: LogitsProcessor, +} + +impl TextGeneration { + fn new( + model: GPTBigCode, + tokenizer: Tokenizer, + seed: u64, + temp: Option, + device: &Device, + ) -> Self { + let logits_processor = LogitsProcessor::new(seed, temp); + Self { + model, + tokenizer, + logits_processor, + device: device.clone(), + } + } + + fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { + println!("starting the inference loop"); + let mut tokens = self + .tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + + let mut new_tokens = vec![]; + let start_gen = std::time::Instant::now(); + for index in 0..sample_len { + let start_gen = std::time::Instant::now(); + let (context_size, past_len) = if self.model.config().use_cache && index > 0 { + (1, tokens.len().saturating_sub(1)) + } else { + (tokens.len(), 0) + }; + let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; + let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; + let logits = self.model.forward(&input, past_len)?; + let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; + + let next_token = self.logits_processor.sample(&logits)?; + tokens.push(next_token); + new_tokens.push(next_token); + println!("> {:?}", start_gen.elapsed()); + println!( + "{} token: {} '{}'", + index + 1, + next_token, + self.tokenizer + .decode(vec![next_token], true) + .map_err(E::msg)? + ); + } + let dt = start_gen.elapsed(); + println!( + "{sample_len} tokens generated ({} token/s)\n----\n{}\n----", + sample_len as f64 / dt.as_secs_f64(), + self.tokenizer.decode(new_tokens, true).map_err(E::msg)? + ); + Ok(()) + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(long)] + prompt: String, + + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The length of the sample to generate (in tokens). + #[arg(long, default_value_t = 100)] + sample_len: usize, + + #[arg(long, default_value = "bigcode/starcoderbase-1b")] + model_id: String, + + #[arg(long, default_value = "main")] + revision: String, + + #[arg(long)] + weight_file: Option, +} + +fn main() -> Result<()> { + let args = Args::parse(); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let repo = api.repo(Repo::with_revision( + args.model_id, + RepoType::Model, + args.revision, + )); + let tokenizer_filename = repo.get("tokenizer.json")?; + let filenames = match args.weight_file { + Some(weight_file) => vec![std::path::PathBuf::from(weight_file.clone())], + None => { + let repo_filenames: Vec = vec![]; + repo_filenames + .iter() + .map(|f| repo.get(f)) + .collect::, _>>()? + } + }; + println!("retrieved the files in {:?}", start.elapsed()); + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let weights = filenames + .iter() + .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? })) + .collect::>>()?; + let weights = weights + .iter() + .map(|f| Ok(f.deserialize()?)) + .collect::>>()?; + + let start = std::time::Instant::now(); + let device = candle_examples::device(args.cpu)?; + let vb = VarBuilder::from_safetensors(weights, DType::F32, &device); + let config = Config::starcoder_1b(); + let model = GPTBigCode::load(vb, config)?; + println!("loaded the model in {:?}", start.elapsed()); + + let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device); + pipeline.run(&args.prompt, args.sample_len)?; + Ok(()) +} diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs new file mode 100644 index 000000000..e9172adf7 --- /dev/null +++ b/candle-examples/examples/bigcode/model.rs @@ -0,0 +1,357 @@ +use candle::{DType, Device, IndexOp, Result, Tensor, D}; +use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder}; + +fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result { + let weight = vb.get((size2, size1), "weight")?; + let bias = if bias { + Some(vb.get(size2, "bias")?) + } else { + None + }; + Ok(Linear::new(weight, bias)) +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} + +fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result { + let weight = vb.get(size, "weight")?; + let bias = vb.get(size, "bias")?; + Ok(LayerNorm::new(weight, bias, eps)) +} + +fn make_causal_mask(t: usize) -> Result { + let mask: Vec<_> = (0..t) + .flat_map(|i| (0..t).map(move |j| u32::from(j > i))) + .collect(); + let mask = Tensor::from_slice(&mask, (t, t), &Device::Cpu)?; + Ok(mask) +} + +#[derive(Debug)] +pub struct Config { + pub vocab_size: usize, + // max_position_embeddings aka n_positions + pub max_position_embeddings: usize, + // num_hidden_layers aka n_layer + pub num_hidden_layers: usize, + // hidden_size aka n_embd + pub hidden_size: usize, + pub layer_norm_epsilon: f64, + pub n_inner: Option, + // num_attention_heads aka n_head + pub num_attention_heads: usize, + pub multi_query: bool, + pub use_cache: bool, +} + +impl Config { + #[allow(dead_code)] + pub fn starcoder_1b() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 24, + hidden_size: 2048, + layer_norm_epsilon: 1e-5, + n_inner: Some(8192), + num_attention_heads: 16, + multi_query: true, + use_cache: true, + } + } + + #[allow(dead_code)] + pub fn starcoder_3b() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 36, + hidden_size: 2816, + layer_norm_epsilon: 1e-5, + n_inner: Some(11264), + num_attention_heads: 22, + multi_query: true, + use_cache: true, + } + } + + #[allow(dead_code)] + pub fn starcoder_7b() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 42, + hidden_size: 4096, + layer_norm_epsilon: 1e-5, + n_inner: Some(16384), + num_attention_heads: 32, + multi_query: true, + use_cache: true, + } + } + + #[allow(dead_code)] + pub fn starcoder() -> Self { + Self { + vocab_size: 49152, + max_position_embeddings: 8192, + num_hidden_layers: 40, + hidden_size: 6144, + layer_norm_epsilon: 1e-5, + n_inner: Some(24576), + num_attention_heads: 48, + multi_query: true, + use_cache: true, + } + } +} + +struct Attention { + c_attn: Linear, + c_proj: Linear, + kv_cache: Option, + use_cache: bool, + embed_dim: usize, + kv_dim: usize, + num_heads: usize, + head_dim: usize, + multi_query: bool, +} + +impl Attention { + pub fn load(vb: VarBuilder, cfg: &Config) -> Result { + let hidden_size = cfg.hidden_size; + let head_dim = hidden_size / cfg.num_attention_heads; + let kv_heads = if cfg.multi_query { + 1 + } else { + cfg.num_attention_heads + }; + let kv_dim = kv_heads * head_dim; + let c_attn = linear(hidden_size, hidden_size + 2 * kv_dim, true, vb.pp("c_attn"))?; + let c_proj = linear(hidden_size, hidden_size, true, vb.pp("c_proj"))?; + Ok(Self { + c_proj, + c_attn, + embed_dim: hidden_size, + kv_cache: None, + use_cache: cfg.use_cache, + kv_dim, + head_dim, + num_heads: cfg.num_attention_heads, + multi_query: cfg.multi_query, + }) + } + + fn attn( + &self, + query: &Tensor, + key: &Tensor, + value: &Tensor, + attention_mask: &Tensor, + ) -> Result { + if query.dtype() != DType::F32 { + // If we start supporting f16 models, we may need the upcasting scaling bits. + // https://github.com/huggingface/transformers/blob/a0042379269bea9182c1f87e6b2eee4ba4c8cce8/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py#L133 + candle::bail!("upcasting is not supported {:?}", query.dtype()) + } + let scale_factor = 1f64 / (self.head_dim as f64).sqrt(); + let initial_query_shape = query.shape(); + let key_len = key.dim(D::Minus1)?; + let (query, key, attn_shape, attn_view) = if self.multi_query { + let (b_sz, query_len, _) = query.dims3()?; + let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; + let attn_shape = (b_sz, query_len, self.num_heads, key_len); + let attn_view = (b_sz, query_len * self.num_heads, key_len); + (query, key.clone(), attn_shape, attn_view) + } else { + let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?; + let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?; + let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?; + let attn_shape = (b_sz, self.num_heads, query_len, key_len); + let attn_view = (b_sz * self.num_heads, query_len, key_len); + (query, key, attn_shape, attn_view) + }; + + let attn_weights = (query.matmul(&key)? * scale_factor)?.reshape(attn_shape)?; + let attention_mask = attention_mask.broadcast_as(attn_shape)?; + let mask_value = + Tensor::new(f32::NEG_INFINITY, query.device())?.broadcast_as(attn_shape)?; + let attn_weights = attention_mask.where_cond(&attn_weights, &mask_value)?; + let attn_weights = attn_weights.softmax(D::Minus1)?; + let attn_output = if self.multi_query { + attn_weights + .reshape(attn_view)? + .matmul(value)? + .reshape(initial_query_shape)? + } else { + attn_weights.matmul(value)? + }; + Ok(attn_output) + } + + fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { + let qkv = self.c_attn.forward(hidden_states)?; + let (query, key_value) = if self.multi_query { + let query = qkv.i((.., .., ..self.embed_dim))?; + let key_value = qkv.i((.., .., self.embed_dim..self.embed_dim + 2 * self.kv_dim))?; + (query, key_value) + } else { + let mut dims = qkv.dims().to_vec(); + dims.pop(); + dims.push(self.embed_dim); + dims.push(self.head_dim * 3); + let qkv = qkv.reshape(dims)?.transpose(1, 2)?; + let query = qkv.i((.., .., .., ..self.head_dim))?; + let key_value = qkv.i((.., .., .., self.head_dim..3 * self.head_dim))?; + (query, key_value) + }; + let mut key_value = key_value; + if self.use_cache { + if let Some(kv_cache) = &self.kv_cache { + // TODO: we could trim the tensors to MAX_SEQ_LEN so that this would work for + // arbitrarily large sizes. + key_value = Tensor::cat(&[kv_cache, &key_value], D::Minus2)?.contiguous()?; + } + self.kv_cache = Some(key_value.clone()) + } + + let key = key_value.narrow(D::Minus1, 0, self.head_dim)?; + let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?; + let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?; + let attn_output = if self.multi_query { + attn_output + } else { + attn_output + .transpose(1, 2)? + .reshape(hidden_states.shape())? + }; + let attn_output = self.c_proj.forward(&attn_output)?; + Ok(attn_output) + } +} + +struct Mlp { + c_fc: Linear, + c_proj: Linear, +} + +impl Mlp { + fn load(inner_dim: usize, vb: VarBuilder, cfg: &Config) -> Result { + let c_fc = linear(cfg.hidden_size, inner_dim, true, vb.pp("c_fc"))?; + let c_proj = linear(inner_dim, cfg.hidden_size, true, vb.pp("c_proj"))?; + Ok(Self { c_fc, c_proj }) + } + + fn forward(&mut self, hidden_states: &Tensor) -> Result { + let hidden_states = self.c_fc.forward(hidden_states)?.gelu()?; + let hidden_states = self.c_proj.forward(&hidden_states)?; + Ok(hidden_states) + } +} + +// TODO: Add cross-attention? +struct Block { + ln_1: LayerNorm, + attn: Attention, + ln_2: LayerNorm, + mlp: Mlp, +} + +impl Block { + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let hidden_size = cfg.hidden_size; + let inner_dim = cfg.n_inner.unwrap_or(4 * hidden_size); + let ln_1 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_1"))?; + let attn = Attention::load(vb.pp("attn"), cfg)?; + let ln_2 = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb.pp("ln_2"))?; + let mlp = Mlp::load(inner_dim, vb.pp("mlp"), cfg)?; + Ok(Self { + ln_1, + attn, + ln_2, + mlp, + }) + } + + fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { + let residual = hidden_states; + let hidden_states = self.ln_1.forward(hidden_states)?; + let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?; + let hidden_states = (&attn_outputs + residual)?; + let residual = &hidden_states; + let hidden_states = self.ln_2.forward(&hidden_states)?; + let hidden_states = self.mlp.forward(&hidden_states)?; + let hidden_states = (&hidden_states + residual)?; + Ok(hidden_states) + } +} + +pub struct GPTBigCode { + wte: Embedding, + wpe: Embedding, + blocks: Vec, + ln_f: LayerNorm, + lm_head: Linear, + bias: Tensor, + config: Config, +} + +impl GPTBigCode { + pub fn config(&self) -> &Config { + &self.config + } + + pub fn load(vb: VarBuilder, cfg: Config) -> Result { + let hidden_size = cfg.hidden_size; + let vb_t = vb.pp("transformer"); + let wte = embedding(cfg.vocab_size, hidden_size, vb_t.pp("wte"))?; + let wpe = embedding(cfg.max_position_embeddings, hidden_size, vb_t.pp("wpe"))?; + let blocks = (0..cfg.num_hidden_layers) + .map(|i| Block::load(vb_t.pp(&format!("h.{i}")), &cfg)) + .collect::>>()?; + let ln_f = layer_norm(hidden_size, cfg.layer_norm_epsilon, vb_t.pp("ln_f"))?; + let lm_head = linear(hidden_size, cfg.vocab_size, false, vb.pp("lm_head"))?; + let bias = make_causal_mask(cfg.max_position_embeddings)?; + Ok(Self { + wte, + wpe, + blocks, + lm_head, + ln_f, + bias, + config: cfg, + }) + } + + pub fn forward(&mut self, input_ids: &Tensor, past_len: usize) -> Result { + let dev = input_ids.device(); + let (b_sz, seq_len) = input_ids.dims2()?; + + let key_len = past_len + seq_len; + let attention_mask = self.bias.i((past_len..key_len, ..key_len))?.unsqueeze(0)?; + // MQA models: (batch_size, query_length, n_heads, key_length) + // MHA models: (batch_size, n_heads, query_length, key_length) + let seq_len_dim = if self.config.multi_query { 2 } else { 1 }; + let attention_mask = attention_mask.unsqueeze(seq_len_dim)?; + + let position_ids = Tensor::arange(past_len as u32, (past_len + seq_len) as u32, dev)?; + let position_ids = position_ids.unsqueeze(0)?.broadcast_as((b_sz, seq_len))?; + let input_embeds = self.wte.forward(input_ids)?; + let position_embeds = self.wpe.forward(&position_ids)?; + + let mut hidden_states = (&input_embeds + &position_embeds)?; + for block in self.blocks.iter_mut() { + hidden_states = block.forward(&hidden_states, &attention_mask)?; + } + let hidden_states = self.ln_f.forward(&hidden_states)?; + let hidden_states = hidden_states + .reshape((b_sz, seq_len, self.config.hidden_size))? + .narrow(1, seq_len - 1, 1)?; + let logits = self.lm_head.forward(&hidden_states)?.squeeze(1)?; + Ok(logits) + } +}