Skip to content

Commit

Permalink
Move the function to utils + use it in mistral.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Apr 11, 2024
1 parent c40098d commit 677f22d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 26 deletions.
16 changes: 2 additions & 14 deletions candle-transformers/src/models/mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,18 +216,6 @@ impl Attention {
})
}

fn repeat_kv(&self, xs: Tensor) -> Result<Tensor> {
let n_rep = self.num_kv_groups;
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, num_kv_heads, seq_len, head_dim) = xs.dims4()?;
xs.unsqueeze(2)?
.expand((b_sz, num_kv_heads, n_rep, seq_len, head_dim))?
.reshape((b_sz, num_kv_heads * n_rep, seq_len, head_dim))
}
}

fn forward(
&mut self,
xs: &Tensor,
Expand Down Expand Up @@ -266,8 +254,8 @@ impl Attention {
};
self.kv_cache = Some((key_states.clone(), value_states.clone()));

let key_states = self.repeat_kv(key_states)?;
let value_states = self.repeat_kv(value_states)?;
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?;
let value_states = crate::utils::repeat_kv(value_states, self.num_kv_groups)?;

let attn_output = if self.use_flash_attn {
// flash-attn expects (b_sz, seq_len, nheads, head_dim)
Expand Down
14 changes: 2 additions & 12 deletions candle-transformers/src/models/quantized_llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ impl LayerWeights {
self.kv_cache = Some((k.clone(), v.clone()));

// Support for MQA, useful for 70B models and mistral.
let k = self.repeat_kv(k)?;
let v = self.repeat_kv(v)?;
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;

let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
let att = match mask {
Expand All @@ -224,16 +224,6 @@ impl LayerWeights {
let y = self.attention_wo.forward(&y)?;
Ok(y)
}

fn repeat_kv(&self, x: Tensor) -> Result<Tensor> {
let n_rep = self.n_head / self.n_kv_head;
if n_rep == 1 {
Ok(x)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}
}
}

#[derive(Debug, Clone)]
Expand Down
14 changes: 14 additions & 0 deletions candle-transformers/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,17 @@ pub fn apply_repeat_penalty(logits: &Tensor, penalty: f32, context: &[u32]) -> R
let logits_len = logits.len();
Tensor::from_vec(logits, logits_len, device)
}

/// Repeats a key or value tensor for grouped query attention
/// The input tensor should have a shape `(batch, num_kv_heads, seq_len, head_dim)`,
pub fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
if n_rep == 1 {
Ok(xs)
} else {
let (b_sz, n_kv_head, seq_len, head_dim) = xs.dims4()?;
// Using cat is faster than a broadcast as it avoids going through a potentially
// strided copy.
// https://github.com/huggingface/candle/pull/2043
Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}
}

0 comments on commit 677f22d

Please sign in to comment.