Use cat for faster MQA computation.
LaurentMazare committed Apr 11, 2024
1 parent a0460cd commit c40098d
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions candle-transformers/src/models/quantized_llama.rs
@@ -205,7 +205,7 @@ impl LayerWeights {
         };
         self.kv_cache = Some((k.clone(), v.clone()));
 
-        // Support for MQA, useful for 70B models.
+        // Support for MQA, useful for 70B models and mistral.
         let k = self.repeat_kv(k)?;
         let v = self.repeat_kv(v)?;
 
@@ -231,11 +231,7 @@ impl LayerWeights {
             Ok(x)
         } else {
             let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
-            let x = x
-                .unsqueeze(2)?
-                .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
-                .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))?;
-            Ok(x)
+            Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
         }
     }
 }
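For reference, the change swaps the unsqueeze/expand/reshape chain in LayerWeights::repeat_kv for a single Tensor::cat along the sequence axis followed by a reshape; both yield the same (b_sz, n_kv_head * n_rep, seq_len, head_dim) layout, and the commit title reports the cat-based form as faster. Below is a minimal standalone sketch, not part of the commit: the helper names repeat_kv_expand and repeat_kv_cat are illustrative, and candle_core is assumed as a dependency.

use candle_core::{Device, Result, Tensor};

// Old variant: broadcast the kv heads via unsqueeze/expand, then merge the
// (n_kv_head, n_rep) axes with a reshape.
fn repeat_kv_expand(x: &Tensor, n_rep: usize) -> Result<Tensor> {
    let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
    x.unsqueeze(2)?
        .expand((b_sz, n_kv_head, n_rep, seq_len, head_dim))?
        .reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}

// New variant: concatenate n_rep references along the seq_len axis, then
// reshape so each kv head's copies become adjacent attention heads.
fn repeat_kv_cat(x: &Tensor, n_rep: usize) -> Result<Tensor> {
    let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
    Tensor::cat(&vec![x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
}

fn main() -> Result<()> {
    // Shapes as in a grouped-query layer: 2 kv heads repeated 4x for 8 query heads.
    let x = Tensor::rand(0f32, 1f32, (1, 2, 3, 4), &Device::Cpu)?;
    let a = repeat_kv_expand(&x, 4)?;
    let b = repeat_kv_cat(&x, 4)?;
    assert_eq!(b.dims(), &[1, 8, 3, 4]);
    // The two variants must agree element-wise, not just in shape.
    let diff = (&a - &b)?.abs()?.sum_all()?.to_scalar::<f32>()?;
    assert_eq!(diff, 0.0);
    Ok(())
}

The element-wise check matters because cat runs along dim 2 (seq_len) rather than introducing a new head axis: the subsequent reshape regroups those copies under each kv head, reproducing exactly the head ordering that the expand-based version produced.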
