diff --git a/candle-examples/examples/bigcode/model.rs b/candle-examples/examples/bigcode/model.rs
index 4190f6667..f23d6ca8f 100644
--- a/candle-examples/examples/bigcode/model.rs
+++ b/candle-examples/examples/bigcode/model.rs
@@ -1,5 +1,4 @@
-use anyhow::Result;
-use candle::{DType, Device, IndexOp, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
 
 fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
@@ -18,16 +17,8 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
 }
 
 fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
-    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
-        (Ok(weight), Ok(bias)) => (weight, bias),
-        (Err(err), _) | (_, Err(err)) => {
-            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
-                (weight, bias)
-            } else {
-                return Err(err.into());
-            }
-        }
-    };
+    let weight = vb.get(size, "weight")?;
+    let bias = vb.get(size, "bias")?;
     Ok(LayerNorm::new(weight, bias, eps))
 }
 
@@ -48,6 +39,7 @@ struct Attention {
     c_proj: Linear,
     embed_dim: usize,
     kv_dim: usize,
+    num_heads: usize,
     head_dim: usize,
     multi_query: bool,
 }
@@ -70,15 +62,44 @@ impl Attention {
             embed_dim: hidden_size,
             kv_dim,
             head_dim,
+            num_heads: cfg.num_attention_heads,
             multi_query: cfg.multi_query,
         })
     }
 
-    fn attn(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<(Tensor, Tensor)> {
-        todo!()
+    fn attn(
+        &self,
+        query: &Tensor,
+        key: &Tensor,
+        value: &Tensor,
+        attention_mask: &Tensor,
+    ) -> Result<Tensor> {
+        // TODO: check if we need scaling/upcasting.
+        let scale_factor = 1f64 / (self.head_dim as f64).sqrt();
+        let key_len = key.dim(D::Minus1)?;
+        let (query, key, attn_shape) = if self.multi_query {
+            let (b_sz, query_len, _) = query.dims3()?;
+            let query = query.reshape((b_sz, query_len * self.num_heads, self.head_dim))?;
+            let attn_shape = (b_sz, query_len, self.num_heads, key_len);
+            (query, key.clone(), attn_shape)
+        } else {
+            let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?;
+            let query = query.reshape((b_sz * self.num_heads, query_len, self.head_dim))?;
+            let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?;
+            let attn_shape = (b_sz, self.num_heads, query_len, key_len);
+            (query, key, attn_shape)
+        };
+        let attn_weights = (query.matmul(&key)? * scale_factor)?.reshape(attn_shape)?;
+        let attn_weights = attn_weights.softmax(D::Minus1)?;
+        let attn_output = if self.multi_query {
+            todo!()
+        } else {
+            attn_weights.matmul(value)?
+        };
+        Ok(attn_output)
     }
 
-    fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let qkv = self.c_attn.forward(hidden_states)?;
         let (query, key_value) = if self.multi_query {
             let query = qkv.i((.., .., ..self.embed_dim))?;
@@ -97,8 +118,16 @@ impl Attention {
         // TODO: layer past
         let key = key_value.narrow(D::Minus1, 0, self.head_dim)?;
         let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?;
-        let (attn_output, attn_weights) = self.attn(&query, &key.t()?, &value)?; // TODO: masks
-        todo!()
+        let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?;
+        let attn_output = if self.multi_query {
+            attn_output
+        } else {
+            attn_output
+                .transpose(1, 2)?
+                .reshape(hidden_states.shape())?
+        };
+        let attn_output = self.c_proj.forward(&attn_output)?;
+        Ok(attn_output)
     }
 }
 
@@ -145,10 +174,10 @@ impl Block {
         })
     }
 
-    fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let residual = hidden_states;
         let hidden_states = self.ln_1.forward(hidden_states)?;
-        let attn_outputs = self.attn.forward(&hidden_states)?;
+        let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?;
         let hidden_states = (&attn_outputs + residual)?;
         let residual = &hidden_states;
         let hidden_states = self.ln_2.forward(&hidden_states)?;
@@ -191,13 +220,15 @@ impl GPTBigCode {
         })
     }
 
-    pub fn forward(&mut self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let attention_mask = Tensor::zeros(1, DType::F32, input_ids.device())?; // TODO
+        let position_ids = Tensor::zeros(1, DType::F32, input_ids.device())?; // TODO
        let (_b_sz, seq_len) = input_ids.dims2()?;
         let input_embeds = self.wte.forward(input_ids)?;
-        let position_embeds = self.wpe.forward(position_ids)?;
+        let position_embeds = self.wpe.forward(&position_ids)?;
         let mut hidden_states = (&input_embeds + &position_embeds)?;
         for block in self.blocks.iter_mut() {
-            hidden_states = block.forward(&hidden_states)?;
+            hidden_states = block.forward(&hidden_states, &attention_mask)?;
         }
         let hidden_states = self.ln_f.forward(&hidden_states)?;
         let hidden_states = hidden_states.i((.., seq_len - 1, seq_len))?;
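
A possible follow-up for the mask TODOs above, sketched for review purposes
rather than taken from this patch: GPT-BigCode attention is causal, so the
`attention_mask` placeholder could eventually become a lower-triangular mask
built once per sequence length, along these lines (the `make_causal_mask`
name and the u8 encoding are assumptions, not part of this diff):

fn make_causal_mask(seq_len: usize, device: &Device) -> Result<Tensor> {
    // Position i may attend to positions j <= i (lower-triangular mask).
    let mask: Vec<u8> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| u8::from(j <= i)))
        .collect();
    Tensor::from_slice(&mask, (seq_len, seq_len), device)
}

The `Tensor::zeros(1, ...)` stand-ins in `GPTBigCode::forward` would then be
replaced by this mask (broadcast against the attention scores before the
softmax) and by a real `0..seq_len` position-id tensor.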