Skip to content

Commit

Permalink
Merge branch 'main' into chore/update-cubecl
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Feb 2, 2025
2 parents 71584e0 + cb0854c commit b83bab9
Show file tree
Hide file tree
Showing 40 changed files with 469 additions and 958 deletions.
1,051 changes: 267 additions & 784 deletions Cargo.lock

Large diffs are not rendered by default.

16 changes: 8 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ version = "0.17.0"
atomic_float = "1"
bytemuck = "1.21.0"
candle-core = { version = "0.8" }
clap = { version = "4.5.26", features = ["derive"] }
clap = { version = "4.5.27", features = ["derive"] }
colored = "2.1.0"
console_error_panic_hook = "0.1.7"
csv = "1.3.1"
Expand All @@ -54,7 +54,7 @@ log = { default-features = false, version = "0.4.25" }
md5 = "0.7.0"
paste = "1"
percent-encoding = "2.3.1"
polars = { version = "0.44.2", features = ["lazy"] }
polars = { version = "0.46.0", features = ["lazy"] }
pretty_assertions = "1.4.1"
proc-macro2 = "1.0.93"
protobuf = "3.7.1"
Expand Down Expand Up @@ -101,7 +101,7 @@ ratatui = "0.29.0"

# WGPU stuff
text_placeholder = "0.5.1"
wgpu = "24.0.0"
wgpu = "24.0.1"

# Benchmarks and Burnbench
arboard = "3.4.1"
Expand Down Expand Up @@ -141,7 +141,7 @@ serde = { version = "1.0.217", default-features = false, features = [
"alloc",
] } # alloc is for no_std, derive is needed
serde_json = { version = "1.0.137", default-features = false }
uuid = { version = "1.12.0", default-features = false }
uuid = { version = "1.12.1", default-features = false }

libc = "0.2.169"
nvml-wrapper = "0.10.0"
Expand All @@ -153,11 +153,11 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "e0734dadca994b02b7dce3b77a575edb1fb2232e" }
### For local development. ###
cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
### For the release. ###
# cubecl = { version = "0.4.0", default-features = false }
# cubecl-common = { version = "0.4.0", default-features = false }
Expand Down
8 changes: 4 additions & 4 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ candle-accelerate = ["burn/candle", "burn/accelerate"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
cuda-jit = ["burn/cuda-jit"]
cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
cuda = ["burn/cuda"]
cuda-fusion = ["cuda", "burn/fusion"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
hip-jit = ["burn/hip-jit"]
hip = ["burn/hip"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
Expand All @@ -27,7 +27,7 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
wgpu-spirv = ["burn/wgpu-spirv", "burn/autotune"]
wgpu-spirv = ["burn/vulkan", "burn/autotune"]
wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"]

[dependencies]
Expand Down
14 changes: 7 additions & 7 deletions backend-comparison/src/burnbenchapp/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ enum BackendValues {
CandleCuda,
#[strum(to_string = "candle-metal")]
CandleMetal,
#[strum(to_string = "cuda")]
Cuda,
#[strum(to_string = "cuda-fusion")]
CudaFusion,
#[cfg(target_os = "linux")]
#[strum(to_string = "hip")]
Hip,
#[strum(to_string = "ndarray")]
Ndarray,
#[strum(to_string = "ndarray-blas-accelerate")]
Expand All @@ -82,13 +89,6 @@ enum BackendValues {
WgpuSpirv,
#[strum(to_string = "wgpu-spirv-fusion")]
WgpuSpirvFusion,
#[strum(to_string = "cuda-jit")]
CudaJit,
#[strum(to_string = "cuda-jit-fusion")]
CudaJitFusion,
#[cfg(target_os = "linux")]
#[strum(to_string = "hip-jit")]
HipJit,
}

#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
Expand Down
20 changes: 10 additions & 10 deletions backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ macro_rules! bench_on_backend {
let feature_name = "wgpu-spirv";
#[cfg(feature = "wgpu-spirv-fusion")]
let feature_name = "wgpu-spirv-fusion";
#[cfg(feature = "cuda-jit")]
let feature_name = "cuda-jit";
#[cfg(feature = "cuda-jit-fusion")]
let feature_name = "cuda-jit-fusion";
#[cfg(feature = "hip-jit")]
let feature_name = "hip-jit";
#[cfg(feature = "cuda")]
let feature_name = "cuda";
#[cfg(feature = "cuda-fusion")]
let feature_name = "cuda-fusion";
#[cfg(feature = "hip")]
let feature_name = "hip";

#[cfg(any(feature = "wgpu"))]
{
Expand Down Expand Up @@ -172,16 +172,16 @@ macro_rules! bench_on_backend {
$fn_name::<Candle>(&device, feature_name, url, token);
}

#[cfg(feature = "cuda-jit")]
#[cfg(feature = "cuda")]
{
use burn::backend::cuda_jit::{Cuda, CudaDevice};
use burn::backend::cuda::{Cuda, CudaDevice};

$fn_name::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
}

#[cfg(feature = "hip-jit")]
#[cfg(feature = "hip")]
{
use burn::backend::hip_jit::{Hip, HipDevice};
use burn::backend::hip::{Hip, HipDevice};

$fn_name::<Hip<half::f16>>(&HipDevice::default(), feature_name, url, token);
}
Expand Down
2 changes: 1 addition & 1 deletion burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ strategies.
| Burn API | PyTorch Equivalent |
| ------------------------------------------------ | -------------------------------------------------- |
| `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` |
| `activation::hard_sigmoid(tensor, alpha, beta) | `nn.functional.hardsigmoid(tensor)` |
| `activation::hard_sigmoid(tensor, alpha, beta)` | `nn.functional.hardsigmoid(tensor)` |
| `activation::leaky_relu(tensor, negative_slope)` | `nn.functional.leaky_relu(tensor, negative_slope)` |
| `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` |
| `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` |
Expand Down
2 changes: 1 addition & 1 deletion burn-book/src/saving-and-loading.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Saving your trained machine learning model is quite easy, no matter the output f
mentioned in the [Record](./building-blocks/record.md) section, different formats are supported to
serialize/deserialize models. By default, we use the `NamedMpkFileRecorder` which uses the
[MessagePack](https://msgpack.org/) binary serialization format with the help of
[smp_serde](https://docs.rs/rmp-serde/).
[rmp_serde](https://docs.rs/rmp-serde/).

```rust, ignore
// Save model in MessagePack format with full precision
Expand Down
19 changes: 10 additions & 9 deletions crates/burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ doc = [
"ndarray",
"tch",
"wgpu",
"cuda-jit",
"hip-jit",
"cuda",
"hip",
"audio",
"vision",
"autodiff",
Expand Down Expand Up @@ -88,7 +88,7 @@ fusion = ["burn-wgpu?/fusion", "burn-cuda?/fusion"]

## Backend features
accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"]
autotune = ["burn-wgpu?/autotune"]
autotune = ["burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-hip?/autotune"]
blas-netlib = ["burn-ndarray?/blas-netlib"]
metal = ["burn-candle?/metal"]
openblas = ["burn-ndarray?/blas-openblas"]
Expand All @@ -100,26 +100,27 @@ template = ["burn-wgpu?/template"]

candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
cuda-jit = ["burn-cuda"]
hip-jit = ["burn-hip"]
cuda = ["burn-cuda"]
hip = ["burn-hip"]
ndarray = ["burn-ndarray"]
tch = ["burn-tch"]
wgpu = ["burn-wgpu"]
wgpu-spirv = ["wgpu", "burn-wgpu/spirv"]
vulkan = ["wgpu", "burn-wgpu/vulkan"]
webgpu = ["wgpu", "burn-wgpu/webgpu"]

# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
record-item-custom-serde = ["thiserror", "regex"]

# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]

test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.
test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray.
test-cuda = ["cuda"] # To use cuda during testing, default uses ndarray.
test-hip = ["hip"] # To use hip during testing, default uses ndarray.
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
test-wgpu-spirv = [
"test-wgpu",
"wgpu-spirv",
"vulkan",
] # To use wgpu-spirv during testing, default uses ndarray.

[dependencies]
Expand Down
22 changes: 14 additions & 8 deletions crates/burn-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,29 @@ pub use burn_wgpu as wgpu;
#[cfg(feature = "wgpu")]
pub use burn_wgpu::Wgpu;

#[cfg(feature = "cuda-jit")]
pub use burn_cuda as cuda_jit;
#[cfg(feature = "webgpu")]
pub use burn_wgpu::WebGpu;

#[cfg(feature = "cuda-jit")]
pub use burn_cuda::Cuda as CudaJit;
#[cfg(feature = "vulkan")]
pub use burn_wgpu::Vulkan;

#[cfg(feature = "cuda")]
pub use burn_cuda as cuda;

#[cfg(feature = "cuda")]
pub use burn_cuda::Cuda;

#[cfg(feature = "candle")]
pub use burn_candle as candle;

#[cfg(feature = "candle")]
pub use burn_candle::Candle;

#[cfg(feature = "hip-jit")]
pub use burn_hip as hip_jit;
#[cfg(feature = "hip")]
pub use burn_hip as hip;

#[cfg(feature = "hip-jit")]
pub use burn_hip::Hip as HipJit;
#[cfg(feature = "hip")]
pub use burn_hip::Hip;

#[cfg(feature = "tch")]
pub use burn_tch as libtorch;
Expand Down
22 changes: 11 additions & 11 deletions crates/burn-dataset/src/dataset/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,20 +269,20 @@ mod tests {
}

fn create_test_dataframe() -> DataFrame {
let s0 = Column::Series(Series::new("int32".into(), &[1i32, 2i32, 3i32]));
let s1 = Column::Series(Series::new("bool".into(), &[true, false, true]));
let s2 = Column::Series(Series::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]));
let s3 = Column::Series(Series::new("string".into(), &["Boo", "Boo2", "Boo3"]));
let s6 = Column::Series(Series::new("int16".into(), &[1i16, 2i16, 3i16]));
let s8 = Column::Series(Series::new("uint32".into(), &[1u32, 2u32, 3u32]));
let s9 = Column::Series(Series::new("uint64".into(), &[1u64, 2u64, 3u64]));
let s10 = Column::Series(Series::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]));
let s11 = Column::Series(Series::new("int64".into(), &[1i64, 2i64, 3i64]));
let s12 = Column::Series(Series::new("int8".into(), &[1i8, 2i8, 3i8]));
let s0 = Column::new("int32".into(), &[1i32, 2i32, 3i32]);
let s1 = Column::new("bool".into(), &[true, false, true]);
let s2 = Column::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]);
let s3 = Column::new("string".into(), &["Boo", "Boo2", "Boo3"]);
let s6 = Column::new("int16".into(), &[1i16, 2i16, 3i16]);
let s8 = Column::new("uint32".into(), &[1u32, 2u32, 3u32]);
let s9 = Column::new("uint64".into(), &[1u64, 2u64, 3u64]);
let s10 = Column::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]);
let s11 = Column::new("int64".into(), &[1i64, 2i64, 3i64]);
let s12 = Column::new("int8".into(), &[1i8, 2i8, 3i8]);

let binary_data: Vec<&[u8]> = vec![&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]];

let s13 = Column::Series(Series::new("binary".into(), binary_data));
let s13 = Column::new("binary".into(), binary_data);
DataFrame::new(vec![s0, s1, s2, s3, s6, s8, s9, s10, s11, s12, s13]).unwrap()
}

Expand Down
5 changes: 3 additions & 2 deletions crates/burn-hip/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ pub type Hip<F = f32, I = i32, B = u8> = burn_fusion::Fusion<JitBackend<HipRunti
// use burn_jit::JitBackend;
//
// pub type TestRuntime = cubecl::hip::HipRuntime;
// pub use half::{bf16, f16};
// pub use half::f16;
//
// burn_jit::testgen_all!();
// // TODO: Add tests for bf16
// burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]);
// }
2 changes: 1 addition & 1 deletion crates/burn-jit/src/fusion/matmul/optimization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl<R: JitRuntime> MatmulOptimization<R> {
fused_matmul_autotune::<R, BT>(self, context);

#[cfg(not(feature = "autotune"))]
if self.execute_fused::<BT>(context).is_err() {
if self.execute_standard_fused::<BT>(context).is_err() {
self.execute_fallback::<BT>(context);
}
}
Expand Down
4 changes: 3 additions & 1 deletion crates/burn-jit/src/kernel/reduce/base.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[cfg(feature = "autotune")]
use super::{autotune_reduce, autotune_sum};
use crate::{
element::JitElement,
Expand Down Expand Up @@ -31,6 +32,7 @@ pub fn sum<Run: JitRuntime, E: JitElement>(
))
}
SumStrategy::Chained(strategy) => reduce::<Run, E, E, Sum>(tensor, strategy),
#[cfg(feature = "autotune")]
SumStrategy::Autotune => Ok(autotune_sum::<Run, E>(&client, tensor)),
}
}
Expand All @@ -53,7 +55,7 @@ impl Default for SumStrategy {
return Self::Autotune;

#[cfg(not(feature = "autotune"))]
return Self::Static(4);
return Self::OneShot(4);
}
}

Expand Down
2 changes: 2 additions & 0 deletions crates/burn-jit/src/kernel/reduce/tune.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ mod reduce_ops {
}

/// Executes autotune on reduce operations.
#[cfg(feature = "autotune")]
pub fn autotune_sum<Run: JitRuntime, E: JitElement>(
client: &ComputeClient<Run::Server, Run::Channel>,
input: JitTensor<Run>,
Expand Down Expand Up @@ -280,6 +281,7 @@ mod sum_ops {
.map_err(|e| e.to_string())
}

#[cfg(feature = "autotune")]
pub(crate) fn sum_chained<Run: JitRuntime, E: JitElement>(
input: JitTensor<Run>,
) -> Result<JitTensor<Run>, String> {
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-remote/src/server/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,12 @@ impl<B: ReprBackend> SessionManager<B> {

impl<B: ReprBackend> Session<B> {
fn new(runner: Runner<B>) -> Self {
let (sender, reveiver) = std::sync::mpsc::sync_channel(1);
let (sender, receiver) = std::sync::mpsc::sync_channel(1);
Self {
runner,
streams: Default::default(),
sender,
receiver: Some(reveiver),
receiver: Some(receiver),
}
}

Expand Down
4 changes: 2 additions & 2 deletions crates/burn-tensor/src/tensor/backend/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ mod tests {
}

#[test]
fn should_build_indices_2d_complexe() {
fn should_build_indices_2d_complex() {
let shape = Shape::new([2, 3]);

let indices = build_indices(&shape, Order::Left);
Expand All @@ -206,7 +206,7 @@ mod tests {
}

#[test]
fn should_build_indices_3d_complexe() {
fn should_build_indices_3d_complex() {
let shape = Shape::new([2, 5, 3]);

let indices = build_indices(&shape, Order::Left);
Expand Down
9 changes: 8 additions & 1 deletion crates/burn-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@ default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"]
doc = ["burn-jit/doc"]
exclusive-memory-only = ["cubecl/exclusive-memory-only"]
fusion = ["burn-fusion", "burn-jit/fusion"]
spirv = ["cubecl/wgpu-spirv"]
std = ["burn-jit/std", "cubecl/std"]
template = ["burn-jit/template", "cubecl/template"]

# Backends
webgpu = ["cubecl-wgsl"]
vulkan = ["cubecl-spirv"]

# Compilers
cubecl-wgsl = []
cubecl-spirv = ["cubecl/wgpu-spirv"]

[dependencies]
cubecl = { workspace = true, features = ["wgpu"] }

Expand Down
Loading

0 comments on commit b83bab9

Please sign in to comment.