Skip to content

Commit

Permalink
Random reshaping.
Browse files: browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Jul 27, 2023
1 parent 24060ee commit d4cb6b6
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions candle-examples/examples/bigcode/model.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -76,14 +76,15 @@ impl Attention {
) -> Result<Tensor> {
// TODO: check if we need scaling/upcasting.
let scale_factor = 1f64 / (self.head_dim as f64).sqrt();
let initial_query_shape = query.shape();
let key_len = key.dim(D::Minus1)?;
let (query, key, attn_shape) = if self.multi_query {
let (b_sz, query_len, _) = query.dims3()?;
let query = query.reshape((b_sz, query_len * self.num_heads, key_len))?;
let attn_shape = (b_sz, query_len, self.num_heads, key_len);
(query, key.clone(), attn_shape)
} else {
let (b_sz, _num_heads, query_len, _head_dime) = query.dims4()?;
let (b_sz, _num_heads, query_len, _head_dim) = query.dims4()?;
let query = query.reshape((b_sz, query_len * self.num_heads, key_len))?;
let key = key.reshape((b_sz * self.num_heads, self.head_dim, key_len))?;
let attn_shape = (b_sz, self.num_heads, query_len, key_len);
Expand All @@ -92,7 +93,10 @@ impl Attention {
let attn_weights = (query.matmul(&key)? * scale_factor)?.reshape(attn_shape)?;
let attn_weights = attn_weights.softmax(D::Minus1)?;
let attn_output = if self.multi_query {
todo!()
attn_weights
.reshape(query.shape())?
.matmul(value)?
.reshape(initial_query_shape)?
} else {
attn_weights.matmul(value)?
};
Expand Down

0 comments on commit d4cb6b6

Please sign in to comment.