Handle contiguity + bugfix + use in mimi.
LaurentMazare committed Sep 22, 2024
1 parent 9964c6d commit 58c1e90
Showing 3 changed files with 15 additions and 11 deletions.
16 changes: 10 additions & 6 deletions candle-nn/src/kv_cache.rs
@@ -224,23 +224,27 @@ impl RotatingCache {

         self.current_seq_len += seq_len;
         if seq_len >= self.max_seq_len {
-            let src = src.narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?;
-            ad.slice_set(&src, self.dim, 0)?;
+            let to_copy = src
+                .narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
+                .contiguous()?;
+            ad.slice_set(&to_copy, self.dim, 0)?;
             self.offset = 0;
             // Here we return `src` rather than `ad` so that all the past can be used.
-            Ok(src)
+            Ok(src.clone())
         } else {
             let rem_len = self.max_seq_len - self.offset;
             if seq_len <= rem_len {
-                ad.slice_set(src, self.dim, self.offset)?;
+                ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
                 self.offset = (self.offset + seq_len) % self.max_seq_len;
             } else {
                 // We have to make two copies here as we go over the boundary of the cache.
                 if rem_len > 0 {
-                    let src1 = src.narrow(self.dim, 0, rem_len)?;
+                    let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
                     ad.slice_set(&src1, self.dim, self.offset)?;
                 }
-                let src2 = src.narrow(self.dim, rem_len, seq_len - rem_len)?;
+                let src2 = src
+                    .narrow(self.dim, rem_len, seq_len - rem_len)?
+                    .contiguous()?;
                 ad.slice_set(&src2, self.dim, 0)?;
                 self.offset = seq_len - rem_len;
             }
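Why the `.contiguous()?` calls: `Tensor::narrow` returns a strided view into the parent buffer rather than a fresh copy, and the cache's `slice_set` destination expects a densely packed source, so narrowed slices have to be materialized before being copied in. A minimal standalone sketch of the distinction (using `candle_core` directly; the 3x4 example tensor is purely illustrative and not part of this commit):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let t = Tensor::arange(0f32, 12., &dev)?.reshape((3, 4))?;
    // Narrowing along dim 1 keeps the parent's row stride, so the
    // resulting view is not laid out contiguously in memory.
    let view = t.narrow(1, 1, 2)?;
    assert!(!view.is_contiguous());
    // `contiguous` copies the view into a fresh, densely packed buffer.
    let owned = view.contiguous()?;
    assert!(owned.is_contiguous());
    Ok(())
}
```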
2 changes: 1 addition & 1 deletion candle-nn/tests/kv_cache.rs
@@ -71,7 +71,7 @@ fn rotating_kv_cache() -> Result<()> {

     let t = Tensor::new(&[0., 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
     let data = cache.append(&t)?;
-    assert_eq!(data.to_vec1::<f64>()?, [3., 4., 5., 6., 7., 8.]);
+    assert_eq!(data.to_vec1::<f64>()?, [0., 1., 2., 3., 4., 5., 6., 7., 8.]);
     assert_eq!(cache.current_seq_len(), 22);
     assert_eq!(cache.offset(), 0);
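The test update reflects the return-value bugfix above: when the appended chunk is at least as long as the cache capacity, `append` now returns the whole input (via `src.clone()`) rather than only the last `max_seq_len` values, so attention for the current step can still see every incoming token while the ring buffer retains just the most recent window. A minimal sketch, assuming a fresh `RotatingCache` with the same dim-0, capacity-6 setup as the test:

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::kv_cache::RotatingCache;

fn main() -> Result<()> {
    // Cache over dim 0 with room for 6 values.
    let mut cache = RotatingCache::new(0, 6);
    let t = Tensor::new(&[0f64, 1., 2., 3., 4., 5., 6., 7., 8.], &Device::Cpu)?;
    // The input is longer than the capacity: after the fix, `append`
    // hands back all 9 values instead of only the last 6 ...
    let data = cache.append(&t)?;
    assert_eq!(data.dims1()?, 9);
    // ... while the ring buffer itself keeps only the 6 most recent.
    Ok(())
}
```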
8 changes: 4 additions & 4 deletions candle-transformers/src/models/mimi/transformer.rs
@@ -127,7 +127,7 @@ pub struct StreamingMultiheadAttention {
     context: usize,
     neg_inf: Tensor,
     rope: Option<Arc<RotaryEmbedding>>,
-    kv_cache: candle_nn::kv_cache::KvCache,
+    kv_cache: candle_nn::kv_cache::RotatingKvCache,
     pos: usize,
     use_flash_attn: bool,
     span: tracing::Span,
@@ -153,7 +153,7 @@ impl StreamingMultiheadAttention {
             num_heads: cfg.num_heads,
             context: cfg.context,
             neg_inf,
-            kv_cache: candle_nn::kv_cache::KvCache::new(2, cfg.max_seq_len),
+            kv_cache: candle_nn::kv_cache::RotatingKvCache::new(2, cfg.context),
             pos: 0,
             use_flash_attn: false,
             span: tracing::span!(tracing::Level::TRACE, "mha"),
@@ -236,7 +236,7 @@ impl StreamingMultiheadAttention {
         self.kv_cache.reset()
     }

-    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
         self.kv_cache = kv_cache
     }
 }
@@ -582,7 +582,7 @@ impl StreamingTransformerLayer {
         self.self_attn.reset_kv_cache()
     }

-    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::KvCache) {
+    pub fn set_kv_cache(&mut self, kv_cache: candle_nn::kv_cache::RotatingKvCache) {
         self.self_attn.set_kv_cache(kv_cache)
     }
 }
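On the mimi side, the unbounded `KvCache` is swapped for a `RotatingKvCache` whose capacity is `cfg.context` rather than `cfg.max_seq_len`, so streaming inference holds at most one attention window of keys and values regardless of how long the stream runs. A rough usage sketch; the tensor shapes and the `(k, v)` pair returned by `append` are assumptions based on the 4-D `(batch, heads, seq, head_dim)` layout implied by caching along dim 2:

```rust
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::RotatingKvCache;

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Dim 2 is the sequence axis; capacity is the attention context.
    let context = 4;
    let mut cache = RotatingKvCache::new(2, context);
    let k = Tensor::zeros((1, 8, 3, 64), DType::F32, &dev)?;
    let v = k.clone();
    // `append` stores the new keys/values and returns the tensors to
    // attend over, capped at `context` steps once the buffer has wrapped.
    let (k_all, v_all) = cache.append(&k, &v)?;
    assert_eq!(k_all.dims()[2], 3);
    let _ = v_all;
    Ok(())
}
```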
