Implement the attention mechanism.
LaurentMazare committed Jul 27, 2023
1 parent db20b45 commit 24060ee
Showing 1 changed file with 53 additions and 22 deletions.
candle-examples/examples/bigcode/model.rs (+53, -22)
@@ -1,5 +1,4 @@
-use anyhow::Result;
-use candle::{DType, Device, IndexOp, Tensor, D};
+use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, LayerNorm, Linear, VarBuilder};
 
 fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
@@ -18,16 +17,8 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
 }
 
 fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
-    let (weight, bias) = match (vb.get(size, "weight"), vb.get(size, "bias")) {
-        (Ok(weight), Ok(bias)) => (weight, bias),
-        (Err(err), _) | (_, Err(err)) => {
-            if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
-                (weight, bias)
-            } else {
-                return Err(err.into());
-            }
-        }
-    };
+    let weight = vb.get(size, "weight")?;
+    let bias = vb.get(size, "bias")?;
     Ok(LayerNorm::new(weight, bias, eps))
 }
 
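The simplified layer_norm helper now requires the checkpoint to expose "weight" and "bias" tensors under exactly those names; the earlier fallback to "gamma"/"beta" is gone. A minimal usage sketch, assuming VarBuilder::pp is available for prefix scoping; the "ln_1" prefix and the 1e-5 epsilon are illustrative values, not taken from this commit.

// Hypothetical call site for the helper above; the prefix and epsilon are illustrative.
fn block_ln_1(hidden_size: usize, vb: VarBuilder) -> Result<LayerNorm> {
    // Reads "weight" and "bias" under the "ln_1" prefix; this now errors out if a
    // checkpoint only provides gamma/beta, which the removed fallback tolerated.
    layer_norm(hidden_size, 1e-5, vb.pp("ln_1"))
}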
@@ -48,6 +39,7 @@ struct Attention
     c_proj: Linear,
     embed_dim: usize,
     kv_dim: usize,
+    num_heads: usize,
     head_dim: usize,
     multi_query: bool,
 }
@@ -70,15 +62,44 @@ impl Attention
             embed_dim: hidden_size,
             kv_dim,
             head_dim,
+            num_heads: cfg.num_attention_heads,
             multi_query: cfg.multi_query,
         })
     }
 
-    fn attn(&self, query: &Tensor, key: &Tensor, value: &Tensor) -> Result<(Tensor, Tensor)> {
-        todo!()
+    fn attn(
+        &self,
+        query: &Tensor,
+        key: &Tensor,
+        value: &Tensor,
+        attention_mask: &Tensor,
+    ) -> Result<Tensor> {
+        // TODO: check if we need scaling/upcasting.
+        let scale_factor = 1f64 / (self.head_dim as f64).sqrt();
+        let key_len = key.dim(D::Minus1)?;
+        let (query, key, attn_shape) = if self.multi_query {
+            let (b_sz, query_len, _) = query.dims3()?;
+            let query = query.reshape((b_sz, query_len * self.num_heads, key_len))?;
+            let attn_shape = (b_sz, query_len, self.num_heads, key_len);
+            (query, key.clone(), attn_shape)
+        } else {
+            let (b_sz, _num_heads, query_len, _head_dime) = query.dims4()?;
+            let query = query.reshape((b_sz, query_len * self.num_heads, key_len))?;
+            let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?;
+            let attn_shape = (b_sz, self.num_heads, query_len, key_len);
+            (query, key, attn_shape)
+        };
+        let attn_weights = (query.matmul(&key)? * scale_factor)?.reshape(attn_shape)?;
+        let attn_weights = attn_weights.softmax(D::Minus1)?;
+        let attn_output = if self.multi_query {
+            todo!()
+        } else {
+            attn_weights.matmul(value)?
+        };
+        Ok(attn_output)
     }
 
-    fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let qkv = self.c_attn.forward(hidden_states)?;
         let (query, key_value) = if self.multi_query {
             let query = qkv.i((.., .., ..self.embed_dim))?;
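The new attn is a scaled dot-product attention; in the multi-query case every query head shares a single key/value head, so the per-head queries can be folded into one batched matmul against that shared key. Below is a minimal sketch of that shape bookkeeping, using only operations that already appear in this file (dims3, reshape, matmul, softmax) and ignoring the attention mask, which the commit also leaves stubbed. Names and shapes are illustrative; this is not a drop-in replacement for the code above, whose multi-query output path is still a todo!().

// Sketch: multi-query attention over a single shared key/value head.
// query:  (b, q_len, n_heads * head_dim)
// key_t:  (b, head_dim, k_len)   shared key head, already transposed
// value:  (b, k_len, head_dim)   shared value head
fn mqa_sketch(
    query: &Tensor,
    key_t: &Tensor,
    value: &Tensor,
    n_heads: usize,
    head_dim: usize,
) -> Result<Tensor> {
    let (b, q_len, _) = query.dims3()?;
    let scale = 1f64 / (head_dim as f64).sqrt();
    // Fold the head axis into the row axis so one matmul covers every query head.
    let q = query.reshape((b, q_len * n_heads, head_dim))?;
    let w = (q.matmul(key_t)? * scale)?; // (b, q_len * n_heads, k_len)
    let w = w.softmax(D::Minus1)?; // normalize over key positions
    let out = w.matmul(value)?; // (b, q_len * n_heads, head_dim)
    out.reshape((b, q_len, n_heads * head_dim)) // merge heads back into the hidden size
}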
@@ -97,8 +118,16 @@ impl Attention
         // TODO: layer past
         let key = key_value.narrow(D::Minus1, 0, self.head_dim)?;
         let value = key_value.narrow(D::Minus1, self.head_dim, self.head_dim)?;
-        let (attn_output, attn_weights) = self.attn(&query, &key.t()?, &value)?; // TODO: masks
-        todo!()
+        let attn_output = self.attn(&query, &key.t()?, &value, attention_mask)?;
+        let attn_output = if self.multi_query {
+            attn_output
+        } else {
+            attn_output
+                .transpose(1, 2)?
+                .reshape(hidden_states.shape())?
+        };
+        let attn_output = self.c_proj.forward(&attn_output)?;
+        Ok(attn_output)
     }
 }
 
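In the standard multi-head branch, attn returns a (b, n_heads, q_len, head_dim) tensor, and the transpose/reshape above folds the heads back into the hidden size before the c_proj output projection. A small sketch of just that head-merge step, with illustrative names, using only operations already present in this diff (dims4, transpose, reshape); depending on the candle version, a contiguous() call may be needed before the reshape.

// Sketch: fold per-head outputs (b, n_heads, q_len, head_dim) back into
// (b, q_len, n_heads * head_dim) so they can feed the output projection.
fn merge_heads(attn_output: &Tensor) -> Result<Tensor> {
    let (b, n_heads, q_len, head_dim) = attn_output.dims4()?;
    attn_output
        .transpose(1, 2)? // (b, q_len, n_heads, head_dim)
        .reshape((b, q_len, n_heads * head_dim))
}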
@@ -145,10 +174,10 @@ impl Block
         })
     }
 
-    fn forward(&mut self, hidden_states: &Tensor) -> Result<Tensor> {
+    fn forward(&mut self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
         let residual = hidden_states;
         let hidden_states = self.ln_1.forward(hidden_states)?;
-        let attn_outputs = self.attn.forward(&hidden_states)?;
+        let attn_outputs = self.attn.forward(&hidden_states, attention_mask)?;
         let hidden_states = (&attn_outputs + residual)?;
         let residual = &hidden_states;
         let hidden_states = self.ln_2.forward(&hidden_states)?;
@@ -191,13 +220,15 @@ impl GPTBigCode
         })
     }
 
-    pub fn forward(&mut self, input_ids: &Tensor, position_ids: &Tensor) -> Result<Tensor> {
+    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let attention_mask = Tensor::zeros(1, DType::F32, input_ids.device())?; // TODO
+        let position_ids = Tensor::zeros(1, DType::F32, input_ids.device())?; // TODO
         let (_b_sz, seq_len) = input_ids.dims2()?;
         let input_embeds = self.wte.forward(input_ids)?;
-        let position_embeds = self.wpe.forward(position_ids)?;
+        let position_embeds = self.wpe.forward(&position_ids)?;
         let mut hidden_states = (&input_embeds + &position_embeds)?;
         for block in self.blocks.iter_mut() {
-            hidden_states = block.forward(&hidden_states)?;
+            hidden_states = block.forward(&hidden_states, &attention_mask)?;
         }
         let hidden_states = self.ln_f.forward(&hidden_states)?;
         let hidden_states = hidden_states.i((.., seq_len - 1, seq_len))?;
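Both the attention mask and the position ids are stubbed out above with placeholder zero tensors and marked TODO in this commit. Purely for orientation, the position ids eventually need to enumerate the token positions fed to wpe; a hypothetical sketch, assuming candle exposes Tensor::arange (which this diff does not use):

// Hypothetical sketch only; this commit deliberately leaves position ids as a TODO.
fn position_ids_sketch(seq_len: usize, device: &Device) -> Result<Tensor> {
    // 0, 1, ..., seq_len - 1, one index per token position.
    Tensor::arange(0u32, seq_len as u32, device)
}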
