Skip to content

Commit

Permalink
Clean up -jit suffix in feature flags and modules (#2705)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Jan 28, 2025
1 parent 29c383b commit 2d9e9b9
Show file tree
Hide file tree
Showing 28 changed files with 153 additions and 120 deletions.
8 changes: 4 additions & 4 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ candle-accelerate = ["burn/candle", "burn/accelerate"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
cuda-jit = ["burn/cuda-jit"]
cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
cuda = ["burn/cuda"]
cuda-fusion = ["cuda", "burn/fusion"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
hip-jit = ["burn/hip-jit"]
hip = ["burn/hip"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
Expand All @@ -27,7 +27,7 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
wgpu-spirv = ["burn/wgpu-spirv", "burn/autotune"]
wgpu-spirv = ["burn/vulkan", "burn/autotune"]
wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"]

[dependencies]
Expand Down
20 changes: 10 additions & 10 deletions backend-comparison/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ macro_rules! bench_on_backend {
let feature_name = "wgpu-spirv";
#[cfg(feature = "wgpu-spirv-fusion")]
let feature_name = "wgpu-spirv-fusion";
#[cfg(feature = "cuda-jit")]
let feature_name = "cuda-jit";
#[cfg(feature = "cuda-jit-fusion")]
let feature_name = "cuda-jit-fusion";
#[cfg(feature = "hip-jit")]
let feature_name = "hip-jit";
#[cfg(feature = "cuda")]
let feature_name = "cuda";
#[cfg(feature = "cuda-fusion")]
let feature_name = "cuda-fusion";
#[cfg(feature = "hip")]
let feature_name = "hip";

#[cfg(any(feature = "wgpu"))]
{
Expand Down Expand Up @@ -172,16 +172,16 @@ macro_rules! bench_on_backend {
$fn_name::<Candle>(&device, feature_name, url, token);
}

#[cfg(feature = "cuda-jit")]
#[cfg(feature = "cuda")]
{
use burn::backend::cuda_jit::{Cuda, CudaDevice};
use burn::backend::cuda::{Cuda, CudaDevice};

$fn_name::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
}

#[cfg(feature = "hip-jit")]
#[cfg(feature = "hip")]
{
use burn::backend::hip_jit::{Hip, HipDevice};
use burn::backend::hip::{Hip, HipDevice};

$fn_name::<Hip<half::f16>>(&HipDevice::default(), feature_name, url, token);
}
Expand Down
17 changes: 9 additions & 8 deletions crates/burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ doc = [
"ndarray",
"tch",
"wgpu",
"cuda-jit",
"hip-jit",
"cuda",
"hip",
"audio",
"vision",
"autodiff",
Expand Down Expand Up @@ -100,26 +100,27 @@ template = ["burn-wgpu?/template"]

candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
cuda-jit = ["burn-cuda"]
hip-jit = ["burn-hip"]
cuda = ["burn-cuda"]
hip = ["burn-hip"]
ndarray = ["burn-ndarray"]
tch = ["burn-tch"]
wgpu = ["burn-wgpu"]
wgpu-spirv = ["wgpu", "burn-wgpu/spirv"]
vulkan = ["wgpu", "burn-wgpu/vulkan"]
webgpu = ["wgpu", "burn-wgpu/webgpu"]

# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
record-item-custom-serde = ["thiserror", "regex"]

# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]

test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.
test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray.
test-cuda = ["cuda"] # To use cuda during testing, default uses ndarray.
test-hip = ["hip"] # To use hip during testing, default uses ndarray.
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
test-wgpu-spirv = [
"test-wgpu",
"wgpu-spirv",
"vulkan",
] # To use wgpu-spirv during testing, default uses ndarray.

[dependencies]
Expand Down
22 changes: 14 additions & 8 deletions crates/burn-core/src/backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,29 @@ pub use burn_wgpu as wgpu;
#[cfg(feature = "wgpu")]
pub use burn_wgpu::Wgpu;

#[cfg(feature = "cuda-jit")]
pub use burn_cuda as cuda_jit;
#[cfg(feature = "webgpu")]
pub use burn_wgpu::WebGpu;

#[cfg(feature = "cuda-jit")]
pub use burn_cuda::Cuda as CudaJit;
#[cfg(feature = "vulkan")]
pub use burn_wgpu::Vulkan;

#[cfg(feature = "cuda")]
pub use burn_cuda as cuda;

#[cfg(feature = "cuda")]
pub use burn_cuda::Cuda;

#[cfg(feature = "candle")]
pub use burn_candle as candle;

#[cfg(feature = "candle")]
pub use burn_candle::Candle;

#[cfg(feature = "hip-jit")]
pub use burn_hip as hip_jit;
#[cfg(feature = "hip")]
pub use burn_hip as hip;

#[cfg(feature = "hip-jit")]
pub use burn_hip::Hip as HipJit;
#[cfg(feature = "hip")]
pub use burn_hip::Hip;

#[cfg(feature = "tch")]
pub use burn_tch as libtorch;
Expand Down
5 changes: 3 additions & 2 deletions crates/burn-hip/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ pub type Hip<F = f32, I = i32, B = u8> = burn_fusion::Fusion<JitBackend<HipRunti
// use burn_jit::JitBackend;
//
// pub type TestRuntime = cubecl::hip::HipRuntime;
// pub use half::{bf16, f16};
// pub use half::f16;
//
// burn_jit::testgen_all!();
// // TODO: Add tests for bf16
// burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]);
// }
9 changes: 8 additions & 1 deletion crates/burn-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@ default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"]
doc = ["burn-jit/doc"]
exclusive-memory-only = ["cubecl/exclusive-memory-only"]
fusion = ["burn-fusion", "burn-jit/fusion"]
spirv = ["cubecl/wgpu-spirv"]
std = ["burn-jit/std", "cubecl/std"]
template = ["burn-jit/template", "cubecl/template"]

# Backends
webgpu = ["cubecl-wgsl"]
vulkan = ["cubecl-spirv"]

# Compilers
cubecl-wgsl = []
cubecl-spirv = ["cubecl/wgpu-spirv"]

[dependencies]
cubecl = { workspace = true, features = ["wgpu"] }

Expand Down
57 changes: 35 additions & 22 deletions crates/burn-wgpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,21 @@ pub use burn_jit::{
pub use burn_jit::{tensor::JitTensor, JitBackend};
pub use burn_jit::{BoolElement, FloatElement, IntElement};
pub use cubecl::flex32;
pub use cubecl::wgpu::*;
pub use cubecl::CubeDim;

pub type Wgsl = cubecl::wgpu::WgslCompiler;
#[cfg(feature = "spirv")]
pub type SpirV = cubecl::wgpu::spirv::VkSpirvCompiler;
pub use cubecl::wgpu::{
init_device, init_setup, init_setup_async, MemoryConfiguration, RuntimeOptions, WgpuDevice,
WgpuResource, WgpuRuntime, WgpuSetup, WgpuStorage,
};
// Vulkan and WebGpu would have conflicting type names
pub mod graphics {
pub use cubecl::wgpu::{AutoGraphicsApi, Dx12, GraphicsApi, Metal, OpenGl, Vulkan, WebGpu};
}

#[cfg(feature = "spirv")]
type Compiler = SpirV;
#[cfg(feature = "spirv")]
type Bool = u8;
#[cfg(not(feature = "spirv"))]
type Compiler = Wgsl;
#[cfg(not(feature = "spirv"))]
type Bool = u32;
#[cfg(feature = "cubecl-spirv")]
pub use cubecl::wgpu::spirv::SpirvCompiler;
#[cfg(feature = "cubecl-wgsl")]
pub use cubecl::wgpu::WgslCompiler;

#[cfg(feature = "fusion")]
/// Tensor backend that uses the wgpu crate for executing GPU compute shaders.
Expand All @@ -44,14 +44,14 @@ type Bool = u32;
/// ```rust, ignore
/// fn custom_init() {
/// let device = Default::default();
/// burn::backend::wgpu::init_sync::<burn::backend::wgpu::Vulkan>(
/// burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(
/// &device,
/// Default::default(),
/// );
/// }
/// ```
/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.
/// It's also possible to use an existing wgpu device, by using `init_existing_device`.
/// It's also possible to use an existing wgpu device, by using `init_device`.
///
/// # Notes
///
Expand All @@ -60,7 +60,7 @@ type Bool = u32;
///
/// You can disable the `fusion` feature flag to remove that functionality, which might be
/// necessary on `wasm` for now.
pub type Wgpu<F = f32, I = i32, B = Bool, C = Compiler> =
pub type Wgpu<F = f32, I = i32, B = u32, C = cubecl::wgpu::WgslCompiler> =
burn_fusion::Fusion<JitBackend<cubecl::wgpu::WgpuRuntime<C>, F, I, B>>;

#[cfg(not(feature = "fusion"))]
Expand All @@ -79,14 +79,14 @@ pub type Wgpu<F = f32, I = i32, B = Bool, C = Compiler> =
/// ```rust, ignore
/// fn custom_init() {
/// let device = Default::default();
/// burn::backend::wgpu::init_sync::<burn::backend::wgpu::Vulkan>(
/// burn::backend::wgpu::init_setup::<burn::backend::wgpu::graphics::Vulkan>(
/// &device,
/// Default::default(),
/// );
/// }
/// ```
/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.
/// It's also possible to use an existing wgpu device, by using `init_existing_device`.
/// It's also possible to use an existing wgpu device, by using `init_device`.
///
/// # Notes
///
Expand All @@ -95,20 +95,33 @@ pub type Wgpu<F = f32, I = i32, B = Bool, C = Compiler> =
///
/// You can enable the `fusion` feature flag to add that functionality, which might improve
/// performance.
pub type Wgpu<F = f32, I = i32, B = Bool, C = Compiler> =
pub type Wgpu<F = f32, I = i32, B = u32, C = cubecl::wgpu::WgslCompiler> =
JitBackend<cubecl::wgpu::WgpuRuntime<C>, F, I, B>;

#[cfg(feature = "vulkan")]
/// Tensor backend that leverages the Vulkan graphics API to execute GPU compute shaders compiled to SPIR-V.
pub type Vulkan<F = f32, I = i32, B = u8> = Wgpu<F, I, B, cubecl::wgpu::spirv::VkSpirvCompiler>;

#[cfg(feature = "webgpu")]
/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL.
pub type WebGpu<F = f32, I = i32, B = u32> = Wgpu<F, I, B, WgslCompiler>;

#[cfg(test)]
mod tests {
use burn_jit::JitBackend;
#[cfg(feature = "spirv")]
#[cfg(feature = "vulkan")]
pub use half::f16;
pub type TestRuntime = cubecl::wgpu::WgpuRuntime<super::Compiler>;

#[cfg(feature = "cubecl-spirv")]
type Compiler = cubecl::wgpu::spirv::VkSpirvCompiler;
#[cfg(not(feature = "cubecl-spirv"))]
type Compiler = cubecl::wgpu::WgslCompiler;
pub type TestRuntime = cubecl::wgpu::WgpuRuntime<Compiler>;

// Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it
// breaks a lot of tests from precision issues
#[cfg(feature = "spirv")]
#[cfg(feature = "vulkan")]
burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]);
#[cfg(not(feature = "spirv"))]
#[cfg(not(feature = "vulkan"))]
burn_jit::testgen_all!([f32], [i32], [u32]);
}
7 changes: 4 additions & 3 deletions crates/burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,16 @@ openblas-system = ["burn-core/openblas-system"]
template = ["burn-core/template"]

candle = ["burn-core/candle"]
cuda-jit = ["burn-core/cuda-jit"]
hip-jit = ["burn-core/hip-jit"]
cuda = ["burn-core/cuda"]
hip = ["burn-core/hip"]
ndarray = ["burn-core/ndarray"]
remote = ["burn-core/remote"]
router = ["burn-core/router"]
server = ["burn-core/server"]
tch = ["burn-core/tch"]
wgpu = ["burn-core/wgpu"]
wgpu-spirv = ["burn-core/wgpu-spirv"]
vulkan = ["burn-core/vulkan"]
webgpu = ["burn-core/webgpu"]

# Network utils
network = ["burn-core/network"]
Expand Down
6 changes: 4 additions & 2 deletions crates/burn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,14 @@
//! - `vision`: Enables vision datasets (MnistDataset)
//! - Backends
//! - `wgpu`: Makes available the WGPU backend
//! - `wgpu-spirv`: Makes available the `wgpu` backend with the alternative SPIR-V compiler
//! - `webgpu`: Makes available the `wgpu` backend with the WebGPU Shading Language (WGSL) compiler
//! - `vulkan`: Makes available the `wgpu` backend with the alternative SPIR-V compiler
//! - `cuda`: Makes available the CUDA backend
//! - `hip`: Makes available the HIP backend
//! - `candle`: Makes available the Candle backend
//! - `tch`: Makes available the LibTorch backend
//! - `ndarray`: Makes available the NdArray backend
//! - Backend specifications
//! - `cuda`: If supported, CUDA will be used
//! - `accelerate`: If supported, Accelerate will be used
//! - `blas-netlib`: If supported, Blas Netlib will be use
//! - `openblas`: If supported, Openblas will be use
Expand Down
4 changes: 2 additions & 2 deletions examples/custom-renderer/examples/custom-renderer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
use burn::backend::{wgpu::WgpuDevice, Autodiff, WebGpu};

fn main() {
custom_renderer::run::<Autodiff<Wgpu>>(WgpuDevice::default());
custom_renderer::run::<Autodiff<WebGpu>>(WgpuDevice::default());
}
2 changes: 1 addition & 1 deletion examples/custom-training-loop/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ publish = false
version.workspace = true

[dependencies]
burn = {path = "../../crates/burn", features=["autodiff", "wgpu", "vision"]}
burn = {path = "../../crates/burn", features=["autodiff", "webgpu", "vision"]}
guide = {path = "../guide"}

# Serialization
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use burn::backend::{Autodiff, Wgpu};
use burn::backend::{Autodiff, WebGpu};

fn main() {
custom_training_loop::run::<Autodiff<Wgpu>>(Default::default());
custom_training_loop::run::<Autodiff<WebGpu>>(Default::default());
}
2 changes: 1 addition & 1 deletion examples/guide/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ version.workspace = true
default = ["burn/default"]

[dependencies]
burn = {path = "../../crates/burn", features = ["wgpu", "train", "vision"]}
burn = {path = "../../crates/burn", features = ["webgpu", "train", "vision"]}

# Serialization
log = {workspace = true}
Expand Down
4 changes: 2 additions & 2 deletions examples/guide/src/bin/infer.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#![recursion_limit = "131"]
use burn::{backend::Wgpu, data::dataset::Dataset};
use burn::{backend::WebGpu, data::dataset::Dataset};
use guide::inference;

fn main() {
type MyBackend = Wgpu<f32, i32>;
type MyBackend = WebGpu<f32, i32>;

let device = burn::backend::wgpu::WgpuDevice::default();

Expand Down
4 changes: 2 additions & 2 deletions examples/guide/src/bin/print.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use burn::backend::Wgpu;
use burn::backend::WebGpu;
use guide::model::ModelConfig;

fn main() {
type MyBackend = Wgpu<f32, i32>;
type MyBackend = WebGpu<f32, i32>;

let device = Default::default();
let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);
Expand Down
Loading

0 comments on commit 2d9e9b9

Please sign in to comment.