Skip to content

Commit 096ec13

Browse files
Chore/update/cubecl (#2067)
1 parent 2046831 commit 096ec13

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+238
-351
lines changed

Cargo.lock

+54-43
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ members = [
1616

1717
exclude = [
1818
"examples/notebook",
19-
"crates/burn-cuda", # comment this line to work on burn-cuda
19+
# "crates/burn-cuda", # comment this line to work on burn-cuda
2020
]
2121

2222
[workspace.package]
@@ -93,7 +93,7 @@ crossterm = "0.27.0"
9393

9494
# WGPU stuff
9595
text_placeholder = "0.5.1"
96-
wgpu = "0.20.1"
96+
wgpu = "22.0.0"
9797

9898
# Benchmarks and Burnbench
9999
arboard = "3.4.0"
@@ -140,8 +140,12 @@ nvml-wrapper = "0.10.0"
140140
sysinfo = "0.30.13"
141141
systemstat = "0.2.3"
142142

143+
### For the main burn branch. ###
143144
cubecl = { version = "0.1.1", git = "https://github.com/tracel-ai/cubecl", default-features = false }
144145
cubecl-common = { version = "0.1.1", git = "https://github.com/tracel-ai/cubecl", default-features = false }
146+
### For local development. ###
147+
# cubecl = { path = "../cubecl/crates/cubecl" }
148+
# cubecl-common = { path = "../cubecl/crates/cubecl-common" }
145149

146150
[profile.dev]
147151
debug = 0 # Speed up compilation time and not necessary.

crates/burn-common/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ tokio = { workspace = true, optional = true }
3232

3333
# Parallel
3434
rayon = { workspace = true, optional = true }
35-
cubecl-common = { workspace = true, default-features = false }
35+
cubecl-common = { workspace = true }
3636

3737
[dev-dependencies]
3838
dashmap = { workspace = true }

crates/burn-core/src/nn/rope_encoding.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ impl<B: Backend> RotaryEncoding<B> {
124124
///
125125
/// Arguments:
126126
/// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
127-
/// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
128-
/// respectively.
127+
/// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
128+
/// respectively.
129129
///
130130
/// Returns:
131131
/// * Output tensor with the same shape as input tensor after applying rotary encoding.
@@ -139,8 +139,8 @@ impl<B: Backend> RotaryEncoding<B> {
139139
///
140140
/// Arguments:
141141
/// * `x` - Input tensor of shape (..., seq_len, d_model). Accommodate both 3D and 4D tensors
142-
/// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
143-
/// respectively.
142+
/// for (batch size, seq_len, hidden_dim) or (batch size, num_heads, seq_len, hidden_dim)
143+
/// respectively.
144144
/// * `start` - Sequence start position index.
145145
///
146146
/// Returns:

crates/burn-cuda/Cargo.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-cuda"
1111
version.workspace = true
1212

1313
[features]
14-
default = ["fusion", "burn-jit/default"]
14+
default = ["fusion", "burn-jit/default", "cubecl/default"]
1515
fusion = ["burn-fusion", "burn-jit/fusion"]
1616
autotune = ["burn-jit/autotune"]
1717
doc = ["burn-jit/doc"]
18-
std = ["burn-jit/std"]
18+
std = ["burn-jit/std", "cubecl/std"]
1919

2020
[dependencies]
2121
cubecl = { workspace = true, features = ["cuda"] }

crates/burn-dataset/src/dataset/sqlite.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@ impl From<&'static str> for SqliteDatasetError {
7373
/// Table columns can be represented in two ways:
7474
///
7575
/// 1. The table can have a column for each field in the `I` struct. In this case, the column names in the table
76-
/// should match the field names of the `I` struct. The field names can be a subset of column names and
77-
/// can be in any order.
76+
/// should match the field names of the `I` struct. The field names can be a subset of column names and
77+
/// can be in any order.
7878
///
7979
/// For the supported field types, refer to:
8080
/// - [Serialization field types](https://docs.rs/serde_rusqlite/latest/serde_rusqlite)
8181
/// - [SQLite data types](https://www.sqlite.org/datatype3.html)
8282
///
8383
/// 2. The fields in the `I` struct can be serialized into a single column `item` in the table. In this case, the table
84-
/// should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields
85-
/// that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using
86-
/// [MessagePack](https://msgpack.org/).
84+
/// should have a single column named `item` of type `BLOB`. This is useful when the `I` struct contains complex fields
85+
/// that cannot be mapped to a SQLite type, such as nested structs, vectors, etc. The serialization is done using
86+
/// [MessagePack](https://msgpack.org/).
8787
///
8888
/// Note: The code automatically figures out which of the above two cases is applicable, and uses the appropriate
8989
/// method to read the data from the table.

0 commit comments

Comments
 (0)