Skip to content

Commit

Permalink
More rotating kv-cache.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Sep 22, 2024
1 parent 3277844 commit d6f01f6
Showing 1 changed file with 9 additions and 15 deletions.
24 changes: 9 additions & 15 deletions candle-transformers/src/models/mimi/transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,6 @@ impl StreamingMultiheadAttention {
let pre_ws = match mask {
None => pre_ws,
Some(mask) => {
// This is a bit cumbersome and slightly incorrect: when providing a new slice
// the kv cache will have a slice offset rather than offset + t. In the mimi
// context of an offset of 250, this would not make much difference though.
let mask_len = mask.dim(D::Minus1)?;
let pre_ws_len = pre_ws.dim(D::Minus1)?;
let mask = if pre_ws_len < mask_len {
mask.narrow(D::Minus1, mask_len - pre_ws_len, pre_ws_len)?
} else {
mask.clone()
};
let mask = mask.broadcast_left((b, self.num_heads))?;
let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
mask.where_cond(&neg_inf, &pre_ws)?
Expand Down Expand Up @@ -639,18 +629,22 @@ impl StreamingTransformer {

pub fn forward_ca(&mut self, xs: &Tensor, ca_src: Option<&Tensor>) -> Result<Tensor> {
let (_b, t, c) = xs.dims3()?;
// We will extract at most "context" from the kv_cache.
// Note that the mask will discard the values that are before context.
let pos = self.layers[0]
.self_attn
.kv_cache
.k_cache()
.current_seq_len()
.min(self.context);
.current_seq_len();
let mask = if t == 1 {
None
} else {
Some(get_mask(t, pos + t, self.context, xs.device())?)
let cache_out_len = if t < self.context {
(pos + t).min(self.context)
} else {
t
};
// TODO: this is wrong, the mask depends on the kv-cache offset because of its rotating
// nature.
Some(get_mask(t, cache_out_len, self.context, xs.device())?)
};
let mut xs = match self.positional_embedding {
PositionalEmbedding::Rope | PositionalEmbedding::None => xs.clone(),
Expand Down

0 comments on commit d6f01f6

Please sign in to comment.