Remove burn-import dep and upgrade to burn 0.15 (#52)
* Remove burn-import dep and upgrade to burn 0.15

* Add missing cubecl runtime feature flag
laggui authored Dec 12, 2024
1 parent 473929c commit 3654542
Showing 2 changed files with 9 additions and 6 deletions.
6 changes: 3 additions & 3 deletions bert-burn/Cargo.toml
@@ -6,7 +6,7 @@ version = "0.2.0"
 edition = "2021"

 [features]
-default = ["wgpu", "fusion"]
+default = []
 f16 = []
 ndarray = ["burn/ndarray"]
 tch-cpu = ["burn/tch"]
@@ -20,14 +20,14 @@ safetensors = ["candle-core/default"]

 [dependencies]
 # Burn
-burn = { version = "0.14", default-features = false, features = ["dataset", "std"] }
+burn = { version = "0.15", default-features = false, features = ["dataset", "std"] }
+cubecl-runtime = { version = "0.3.0", features = ["channel-mpsc"] } # missing feature flag when burn default-features are off
 candle-core = { version = "0.3" }
 # Tokenizer
 tokenizers = { version = "0.15.0", default-features = false, features = [
     "onig",
     "http",
 ] }
-burn-import = "0.13"
 derive-new = "0.6.0"
 hf-hub = { version = "0.3.2", features = ["tokio"] }

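With the default feature set now empty, downstream builds have to opt into a backend explicitly (for example `--features wgpu,fusion`, the pair that used to be the default). The following is a minimal sketch, not part of the commit, of what explicit backend selection looks like under burn 0.15 with the `wgpu` feature enabled; the tensor and its sizes are purely illustrative:

    use burn::backend::wgpu::{Wgpu, WgpuDevice};
    use burn::tensor::Tensor;

    fn main() {
        // With `default = []`, the caller picks the device/backend instead of the crate.
        let device = WgpuDevice::default();
        // Models and tensors are then parameterized over that backend, e.g. the
        // `BertEmbeddings` type changed below would be used as `BertEmbeddings<Wgpu>`.
        let ones: Tensor<Wgpu, 2> = Tensor::ones([2, 3], &device);
        println!("{:?}", ones.dims()); // [2, 3]
    }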
9 changes: 6 additions & 3 deletions bert-burn/src/embedding.rs
@@ -56,7 +56,7 @@ impl BertEmbeddingsConfig {
 impl<B: Backend> BertEmbeddings<B> {
     pub fn forward(&self, item: BertInferenceBatch<B>) -> Tensor<B, 3, Float> {
         // Items batch contains the tokenized input and padding mask, each of dim: [batch_size, max_seq_length]
-        let input_shape = &item.tokens.shape();
+        let input_shape = item.tokens.shape();
         let input_ids = item.tokens;

         // Embed tokens
@@ -76,7 +76,9 @@ impl<B: Backend> BertEmbeddings<B> {

         let seq_length = input_shape.dims[1];
         let mut position_ids_tensor: Tensor<B, 2, Int> =
-            Tensor::arange(0..seq_length as i64, device).reshape([1, seq_length]);
+            Tensor::arange(0..seq_length as i64, device)
+                .reshape([1, seq_length])
+                .expand(input_shape.clone());

         if self.max_position_embeddings != 512 {
             // RoBERTa use a different scheme than BERT to create position indexes where padding tokens are given
@@ -87,7 +89,8 @@ impl<B: Backend> BertEmbeddings<B> {
                 ..(seq_length as i64) + (self.pad_token_idx as i64) + 1,
                 device,
             )
-            .reshape([1, seq_length]);
+            .reshape([1, seq_length])
+            .expand(input_shape);
             position_ids_tensor =
                 position_ids.mask_fill(item.mask_pad.clone(), self.pad_token_idx as i32);
         }
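The embedding change makes the position-id tensor match the token shape up front: the `[1, seq_length]` row of positions is now expanded to `[batch_size, seq_length]` before `mask_fill` is applied against the padding mask. Below is a minimal sketch of that broadcast, not taken from the commit, assuming burn 0.15 with its ndarray backend; the batch and sequence sizes are illustrative:

    use burn::backend::ndarray::{NdArray, NdArrayDevice};
    use burn::tensor::{Int, Tensor};

    fn main() {
        let device = NdArrayDevice::default();
        let (batch_size, seq_length): (usize, usize) = (2, 4);

        // Stand-in for `item.tokens`: a [batch_size, seq_length] tensor of token ids.
        let tokens: Tensor<NdArray, 2, Int> = Tensor::zeros([batch_size, seq_length], &device);
        let input_shape = tokens.shape();

        // Same construction as the updated forward(): one row of positions...
        let row: Tensor<NdArray, 2, Int> =
            Tensor::arange(0..seq_length as i64, &device).reshape([1, seq_length]);
        // ...broadcast so every sequence in the batch gets its own copy of the ids.
        let per_batch = row.expand(input_shape);

        println!("{:?}", per_batch.dims()); // [2, 4]
    }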
