Skip to content

Commit

Permalink
Mimi streaming fixes.
Browse files: browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Sep 22, 2024
1 parent c79bf42 commit 3277844
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
5 changes: 4 additions & 1 deletion candle-examples/examples/mimi/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ fn main() -> Result<()> {
for chunk_start in (0..seq_len).step_by(chunk_size) {
let chunk_len = usize::min(chunk_size, seq_len - chunk_start);
let codes = codes.narrow(candle::D::Minus1, chunk_start, chunk_len)?;
pcm_chunks.push(model.decode(&codes)?)
let pcm = model.decode_step(&codes.into())?;
if let Some(pcm) = pcm.as_option() {
pcm_chunks.push(pcm.clone())
}
}
Tensor::cat(&pcm_chunks, candle::D::Minus1)?
}
Expand Down
10 changes: 10 additions & 0 deletions candle-transformers/src/models/mimi/transformer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,16 @@ impl StreamingMultiheadAttention {
let pre_ws = match mask {
None => pre_ws,
Some(mask) => {
// This is a bit cumbersome and slightly incorrect: when a new slice is provided,
// the kv cache ends up with the slice offset rather than offset + t. In the mimi
// context, with an offset of 250, this should not make much difference though.
let mask_len = mask.dim(D::Minus1)?;
let pre_ws_len = pre_ws.dim(D::Minus1)?;
let mask = if pre_ws_len < mask_len {
mask.narrow(D::Minus1, mask_len - pre_ws_len, pre_ws_len)?
} else {
mask.clone()
};
let mask = mask.broadcast_left((b, self.num_heads))?;
let neg_inf = self.neg_inf.broadcast_as(pre_ws.shape())?;
mask.where_cond(&neg_inf, &pre_ws)?
Expand Down

0 comments on commit 3277844

Please sign in to comment.