diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml
index 265dbeaaf0..821d189fe0 100644
--- a/backend-comparison/Cargo.toml
+++ b/backend-comparison/Cargo.toml
@@ -15,10 +15,10 @@ candle-accelerate = ["burn/candle", "burn/accelerate"]
 candle-cpu = ["burn/candle"]
 candle-cuda = ["burn/candle-cuda"]
 candle-metal = ["burn/candle", "burn/metal"]
-cuda-jit = ["burn/cuda-jit"]
-cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
+cuda = ["burn/cuda"]
+cuda-fusion = ["cuda", "burn/fusion"]
 default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
-hip-jit = ["burn/hip-jit"]
+hip = ["burn/hip"]
 ndarray = ["burn/ndarray"]
 ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
 ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
@@ -27,7 +27,7 @@ tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
 wgpu = ["burn/wgpu", "burn/autotune"]
 wgpu-fusion = ["wgpu", "burn/fusion"]
-wgpu-spirv = ["burn/wgpu-spirv", "burn/autotune"]
+wgpu-spirv = ["burn/vulkan", "burn/autotune"]
 wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"]
 
 [dependencies]
diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs
index 26b08bc3b8..b3351e9dd5 100644
--- a/backend-comparison/src/lib.rs
+++ b/backend-comparison/src/lib.rs
@@ -91,12 +91,12 @@ macro_rules! bench_on_backend {
         let feature_name = "wgpu-spirv";
         #[cfg(feature = "wgpu-spirv-fusion")]
         let feature_name = "wgpu-spirv-fusion";
-        #[cfg(feature = "cuda-jit")]
-        let feature_name = "cuda-jit";
-        #[cfg(feature = "cuda-jit-fusion")]
-        let feature_name = "cuda-jit-fusion";
-        #[cfg(feature = "hip-jit")]
-        let feature_name = "hip-jit";
+        #[cfg(feature = "cuda")]
+        let feature_name = "cuda";
+        #[cfg(feature = "cuda-fusion")]
+        let feature_name = "cuda-fusion";
+        #[cfg(feature = "hip")]
+        let feature_name = "hip";
 
         #[cfg(any(feature = "wgpu"))]
         {
@@ -172,16 +172,16 @@ macro_rules! bench_on_backend {
             $fn_name::(&device, feature_name, url, token);
         }
 
-        #[cfg(feature = "cuda-jit")]
+        #[cfg(feature = "cuda")]
         {
-            use burn::backend::cuda_jit::{Cuda, CudaDevice};
+            use burn::backend::cuda::{Cuda, CudaDevice};
 
             $fn_name::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
         }
 
-        #[cfg(feature = "hip-jit")]
+        #[cfg(feature = "hip")]
         {
-            use burn::backend::hip_jit::{Hip, HipDevice};
+            use burn::backend::hip::{Hip, HipDevice};
 
             $fn_name::<Hip<half::f16>>(&HipDevice::default(), feature_name, url, token);
         }
diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml
index e895cc4572..5d926cd0b3 100644
--- a/crates/burn-core/Cargo.toml
+++ b/crates/burn-core/Cargo.toml
@@ -36,8 +36,8 @@ doc = [
     "ndarray",
     "tch",
     "wgpu",
-    "cuda-jit",
-    "hip-jit",
+    "cuda",
+    "hip",
    "audio",
    "vision",
    "autodiff",
@@ -100,12 +100,13 @@ template = ["burn-wgpu?/template"]
 
 candle = ["burn-candle"]
 candle-cuda = ["candle", "burn-candle/cuda"]
-cuda-jit = ["burn-cuda"]
-hip-jit = ["burn-hip"]
+cuda = ["burn-cuda"]
+hip = ["burn-hip"]
 ndarray = ["burn-ndarray"]
 tch = ["burn-tch"]
 wgpu = ["burn-wgpu"]
-wgpu-spirv = ["wgpu", "burn-wgpu/spirv"]
+vulkan = ["wgpu", "burn-wgpu/vulkan"]
+webgpu = ["wgpu", "burn-wgpu/webgpu"]
 
 # Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
 record-item-custom-serde = ["thiserror", "regex"]
@@ -113,13 +114,13 @@ record-item-custom-serde = ["thiserror", "regex"]
 # Serialization formats
 experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
 
-test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.
-test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray.
+test-cuda = ["cuda"] # To use cuda during testing, default uses ndarray. +test-hip = ["hip"] # To use hip during testing, default uses ndarray. test-tch = ["tch"] # To use tch during testing, default uses ndarray. test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray. test-wgpu-spirv = [ "test-wgpu", - "wgpu-spirv", + "vulkan", ] # To use wgpu-spirv during testing, default uses ndarray. [dependencies] diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs index bd4c959302..31ac3a8c41 100644 --- a/crates/burn-core/src/backend.rs +++ b/crates/burn-core/src/backend.rs @@ -21,11 +21,17 @@ pub use burn_wgpu as wgpu; #[cfg(feature = "wgpu")] pub use burn_wgpu::Wgpu; -#[cfg(feature = "cuda-jit")] -pub use burn_cuda as cuda_jit; +#[cfg(feature = "webgpu")] +pub use burn_wgpu::WebGpu; -#[cfg(feature = "cuda-jit")] -pub use burn_cuda::Cuda as CudaJit; +#[cfg(feature = "vulkan")] +pub use burn_wgpu::Vulkan; + +#[cfg(feature = "cuda")] +pub use burn_cuda as cuda; + +#[cfg(feature = "cuda")] +pub use burn_cuda::Cuda; #[cfg(feature = "candle")] pub use burn_candle as candle; @@ -33,11 +39,11 @@ pub use burn_candle as candle; #[cfg(feature = "candle")] pub use burn_candle::Candle; -#[cfg(feature = "hip-jit")] -pub use burn_hip as hip_jit; +#[cfg(feature = "hip")] +pub use burn_hip as hip; -#[cfg(feature = "hip-jit")] -pub use burn_hip::Hip as HipJit; +#[cfg(feature = "hip")] +pub use burn_hip::Hip; #[cfg(feature = "tch")] pub use burn_tch as libtorch; diff --git a/crates/burn-hip/src/lib.rs b/crates/burn-hip/src/lib.rs index fc8f704e74..13f5239637 100644 --- a/crates/burn-hip/src/lib.rs +++ b/crates/burn-hip/src/lib.rs @@ -26,7 +26,8 @@ pub type Hip = burn_fusion::Fusion( +/// burn::backend::wgpu::init_setup::( /// &device, /// Default::default(), /// ); /// } /// ``` /// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API. -/// It's also possible to use an existing wgpu device, by using `init_existing_device`. +/// It's also possible to use an existing wgpu device, by using `init_device`. /// /// # Notes /// @@ -60,7 +60,7 @@ type Bool = u32; /// /// You can disable the `fusion` feature flag to remove that functionality, which might be /// necessary on `wasm` for now. -pub type Wgpu = +pub type Wgpu = burn_fusion::Fusion, F, I, B>>; #[cfg(not(feature = "fusion"))] @@ -79,14 +79,14 @@ pub type Wgpu = /// ```rust, ignore /// fn custom_init() { /// let device = Default::default(); -/// burn::backend::wgpu::init_sync::( +/// burn::backend::wgpu::init_setup::( /// &device, /// Default::default(), /// ); /// } /// ``` /// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API. -/// It's also possible to use an existing wgpu device, by using `init_existing_device`. +/// It's also possible to use an existing wgpu device, by using `init_device`. /// /// # Notes /// @@ -95,20 +95,33 @@ pub type Wgpu = /// /// You can enable the `fusion` feature flag to add that functionality, which might improve /// performance. -pub type Wgpu = +pub type Wgpu = JitBackend, F, I, B>; +#[cfg(feature = "vulkan")] +/// Tensor backend that leverages the Vulkan graphics API to execute GPU compute shaders compiled to SPIR-V. +pub type Vulkan = Wgpu; + +#[cfg(feature = "webgpu")] +/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL. 
+pub type WebGpu<F = f32, I = i32, B = u8> = Wgpu<F, I, B, WgslCompiler>;
+
 #[cfg(test)]
 mod tests {
     use burn_jit::JitBackend;
-    #[cfg(feature = "spirv")]
+    #[cfg(feature = "vulkan")]
     pub use half::f16;
 
-    pub type TestRuntime = cubecl::wgpu::WgpuRuntime;
+
+    #[cfg(feature = "cubecl-spirv")]
+    type Compiler = cubecl::wgpu::spirv::VkSpirvCompiler;
+    #[cfg(not(feature = "cubecl-spirv"))]
+    type Compiler = cubecl::wgpu::WgslCompiler;
+    pub type TestRuntime = cubecl::wgpu::WgpuRuntime<Compiler>;
 
     // Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it
     // breaks a lot of tests from precision issues
-    #[cfg(feature = "spirv")]
+    #[cfg(feature = "vulkan")]
     burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]);
-    #[cfg(not(feature = "spirv"))]
+    #[cfg(not(feature = "vulkan"))]
     burn_jit::testgen_all!([f32], [i32], [u32]);
 }
diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml
index cd13682a4b..b0abf7d178 100644
--- a/crates/burn/Cargo.toml
+++ b/crates/burn/Cargo.toml
@@ -50,15 +50,16 @@ openblas-system = ["burn-core/openblas-system"]
 template = ["burn-core/template"]
 
 candle = ["burn-core/candle"]
-cuda-jit = ["burn-core/cuda-jit"]
-hip-jit = ["burn-core/hip-jit"]
+cuda = ["burn-core/cuda"]
+hip = ["burn-core/hip"]
 ndarray = ["burn-core/ndarray"]
 remote = ["burn-core/remote"]
 router = ["burn-core/router"]
 server = ["burn-core/server"]
 tch = ["burn-core/tch"]
 wgpu = ["burn-core/wgpu"]
-wgpu-spirv = ["burn-core/wgpu-spirv"]
+vulkan = ["burn-core/vulkan"]
+webgpu = ["burn-core/webgpu"]
 
 # Network utils
 network = ["burn-core/network"]
diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs
index b0ecf06a71..203d1a802d 100644
--- a/crates/burn/src/lib.rs
+++ b/crates/burn/src/lib.rs
@@ -76,12 +76,14 @@
 //! - `vision`: Enables vision datasets (MnistDataset)
 //! - Backends
 //!   - `wgpu`: Makes available the WGPU backend
-//!   - `wgpu-spirv`: Makes available the `wgpu` backend with the alternative SPIR-V compiler
+//!   - `webgpu`: Makes available the `wgpu` backend with the WebGPU Shading Language (WGSL) compiler
+//!   - `vulkan`: Makes available the `wgpu` backend with the alternative SPIR-V compiler
+//!   - `cuda`: Makes available the CUDA backend
+//!   - `hip`: Makes available the HIP backend
 //!   - `candle`: Makes available the Candle backend
 //!   - `tch`: Makes available the LibTorch backend
 //!   - `ndarray`: Makes available the NdArray backend
 //! - Backend specifications
-//!   - `cuda`: If supported, CUDA will be used
 //!   - `accelerate`: If supported, Accelerate will be used
 //!   - `blas-netlib`: If supported, Blas Netlib will be use
 //!   - `openblas`: If supported, Openblas will be use
diff --git a/examples/custom-renderer/examples/custom-renderer.rs b/examples/custom-renderer/examples/custom-renderer.rs
index ea580833df..aa344b1d2b 100644
--- a/examples/custom-renderer/examples/custom-renderer.rs
+++ b/examples/custom-renderer/examples/custom-renderer.rs
@@ -1,5 +1,5 @@
-use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu};
+use burn::backend::{wgpu::WgpuDevice, Autodiff, WebGpu};
 
 fn main() {
-    custom_renderer::run::<Autodiff<Wgpu>>(WgpuDevice::default());
+    custom_renderer::run::<Autodiff<WebGpu>>(WgpuDevice::default());
 }
diff --git a/examples/custom-training-loop/Cargo.toml b/examples/custom-training-loop/Cargo.toml
index 536307fdba..6e1fca1e92 100644
--- a/examples/custom-training-loop/Cargo.toml
+++ b/examples/custom-training-loop/Cargo.toml
@@ -7,7 +7,7 @@ publish = false
 version.workspace = true
 
 [dependencies]
-burn = {path = "../../crates/burn", features=["autodiff", "wgpu", "vision"]}
+burn = {path = "../../crates/burn", features=["autodiff", "webgpu", "vision"]}
 guide = {path = "../guide"}
 
 # Serialization
diff --git a/examples/custom-training-loop/examples/custom-training-loop.rs b/examples/custom-training-loop/examples/custom-training-loop.rs
index a418ede196..ec9d55f42a 100644
--- a/examples/custom-training-loop/examples/custom-training-loop.rs
+++ b/examples/custom-training-loop/examples/custom-training-loop.rs
@@ -1,5 +1,5 @@
-use burn::backend::{Autodiff, Wgpu};
+use burn::backend::{Autodiff, WebGpu};
 
 fn main() {
-    custom_training_loop::run::<Autodiff<Wgpu>>(Default::default());
+    custom_training_loop::run::<Autodiff<WebGpu>>(Default::default());
 }
diff --git a/examples/guide/Cargo.toml b/examples/guide/Cargo.toml
index e60b8d45e5..aea61f5e25 100644
--- a/examples/guide/Cargo.toml
+++ b/examples/guide/Cargo.toml
@@ -10,7 +10,7 @@ version.workspace = true
 default = ["burn/default"]
 
 [dependencies]
-burn = {path = "../../crates/burn", features = ["wgpu", "train", "vision"]}
+burn = {path = "../../crates/burn", features = ["webgpu", "train", "vision"]}
 
 # Serialization
 log = {workspace = true}
diff --git a/examples/guide/src/bin/infer.rs b/examples/guide/src/bin/infer.rs
index 6a246d85f0..44c5b1dabc 100644
--- a/examples/guide/src/bin/infer.rs
+++ b/examples/guide/src/bin/infer.rs
@@ -1,9 +1,9 @@
 #![recursion_limit = "131"]
-use burn::{backend::Wgpu, data::dataset::Dataset};
+use burn::{backend::WebGpu, data::dataset::Dataset};
 use guide::inference;
 
 fn main() {
-    type MyBackend = Wgpu<f32, i32>;
+    type MyBackend = WebGpu<f32, i32>;
 
     let device = burn::backend::wgpu::WgpuDevice::default();
diff --git a/examples/guide/src/bin/print.rs b/examples/guide/src/bin/print.rs
index 9432aa93a4..6f3b710c25 100644
--- a/examples/guide/src/bin/print.rs
+++ b/examples/guide/src/bin/print.rs
@@ -1,8 +1,8 @@
-use burn::backend::Wgpu;
+use burn::backend::WebGpu;
 use guide::model::ModelConfig;
 
 fn main() {
-    type MyBackend = Wgpu<f32, i32>;
+    type MyBackend = WebGpu<f32, i32>;
 
     let device = Default::default();
     let model = ModelConfig::new(10, 512).init::<MyBackend>(&device);
diff --git a/examples/guide/src/bin/train.rs b/examples/guide/src/bin/train.rs
index 04f1f44146..a4acf02b69 100644
--- a/examples/guide/src/bin/train.rs
+++ b/examples/guide/src/bin/train.rs
@@ -1,5 +1,5 @@
 use burn::{
-    backend::{Autodiff, Wgpu},
+    backend::{Autodiff, WebGpu},
     data::dataset::Dataset,
     optim::AdamConfig,
 };
@@ -10,7 +10,7 @@ use guide::{
 };
 
 fn main() {
-    type MyBackend = Wgpu<f32, i32>;
+    type MyBackend = WebGpu<f32, i32>;
     type MyAutodiffBackend = Autodiff<MyBackend>;
 
     // Create a default Wgpu device
diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs
index 4b20507abc..a9868099f6 100644
--- a/examples/image-classification-web/src/web.rs
+++ b/examples/image-classification-web/src/web.rs
@@ -14,7 +14,7 @@ use burn::{
     tensor::activation::softmax,
 };
 
-use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
+use burn::backend::wgpu::{graphics::AutoGraphicsApi, WebGpu, WgpuDevice};
 use burn_candle::Candle;
 
 use serde::Serialize;
@@ -37,8 +37,8 @@ pub enum ModelType {
     /// The model is loaded to the NdArray backend
     WithNdArrayBackend(Model<NdArray<f32>>),
 
-    /// The model is loaded to the Wgpu backend
-    WithWgpuBackend(Model<Wgpu<f32, i32>>),
+    /// The model is loaded to the WebGpu backend
+    WithWgpuBackend(Model<WebGpu<f32, i32>>),
 }
 
 /// The image is 224x224 pixels with 3 channels (RGB)
diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml
index bb4824fba9..f9f80bdb8d 100644
--- a/examples/server/Cargo.toml
+++ b/examples/server/Cargo.toml
@@ -7,10 +7,10 @@ publish = false
 version.workspace = true
 
 [features]
-default = ["wgpu"]
-cuda-jit = ["burn/cuda-jit"]
-wgpu = ["burn/wgpu"]
-wgpu-spirv = ["wgpu", "burn/wgpu-spirv"]
+default = ["webgpu"]
+cuda = ["burn/cuda"]
+webgpu = ["burn/webgpu"]
+vulkan = ["burn/vulkan"]
 ndarray = ["burn/ndarray"]
 
 [dependencies]
diff --git a/examples/server/src/lib.rs b/examples/server/src/lib.rs
index 70705a0876..014a5e2cf5 100644
--- a/examples/server/src/lib.rs
+++ b/examples/server/src/lib.rs
@@ -11,10 +11,12 @@ pub fn start() {
     cfg_if::cfg_if! {
         if #[cfg(feature = "ndarray")]{
             burn::server::start::<burn::backend::NdArray>(Default::default(), port);
-        } else if #[cfg(feature = "cuda-jit")]{
-            burn::server::start::<burn::backend::CudaJit>(Default::default(), port);
-        } else if #[cfg(feature = "wgpu")] {
-            burn::server::start::<burn::backend::Wgpu>(Default::default(), port);
+        } else if #[cfg(feature = "cuda")]{
+            burn::server::start::<burn::backend::Cuda>(Default::default(), port);
+        } else if #[cfg(feature = "webgpu")] {
+            burn::server::start::<burn::backend::WebGpu>(Default::default(), port);
+        } else if #[cfg(feature = "vulkan")] {
+            burn::server::start::<burn::backend::Vulkan>(Default::default(), port);
         } else {
             panic!("No backend selected, can't start server on port {port}");
         }
diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml
index 4ec5d7c89a..043c61672d 100644
--- a/examples/text-classification/Cargo.toml
+++ b/examples/text-classification/Cargo.toml
@@ -16,10 +16,10 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
 tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
 wgpu = ["burn/wgpu"]
-wgpu-spirv = ["wgpu", "burn/wgpu-spirv"]
+vulkan = ["wgpu", "burn/vulkan"]
 remote = ["burn/remote"]
-cuda-jit = ["burn/cuda-jit"]
-hip-jit = ["burn/hip-jit"]
+cuda = ["burn/cuda"]
+hip = ["burn/hip"]
 
 [dependencies]
 # Burn
diff --git a/examples/text-classification/README.md b/examples/text-classification/README.md
index 8bc611361f..9d62606706 100644
--- a/examples/text-classification/README.md
+++ b/examples/text-classification/README.md
@@ -102,6 +102,6 @@ cd burn
 # Use the --release flag to really speed up training.
 
 # AG News
-cargo run --example ag-news-train --release --features cuda-jit # Train on the ag news dataset
-cargo run --example ag-news-infer --release --features cuda-jit # Run inference on the ag news dataset
+cargo run --example ag-news-train --release --features cuda # Train on the ag news dataset
+cargo run --example ag-news-infer --release --features cuda # Run inference on the ag news dataset
 ```
diff --git a/examples/text-classification/examples/ag-news-infer.rs b/examples/text-classification/examples/ag-news-infer.rs
index 9af5c6c6eb..77626e0b60 100644
--- a/examples/text-classification/examples/ag-news-infer.rs
+++ b/examples/text-classification/examples/ag-news-infer.rs
@@ -81,13 +81,13 @@ mod wgpu {
     }
 }
 
-#[cfg(feature = "cuda-jit")]
-mod cuda_jit {
+#[cfg(feature = "cuda")]
+mod cuda {
     use crate::{launch, ElemType};
-    use burn::backend::{cuda_jit::CudaDevice, CudaJit};
+    use burn::backend::{cuda::CudaDevice, Cuda};
 
     pub fn run() {
-        launch::<CudaJit<ElemType, i32>>(CudaDevice::default());
+        launch::<Cuda<ElemType, i32>>(CudaDevice::default());
     }
 }
 
@@ -105,6 +105,6 @@ fn main() {
     tch_cpu::run();
     #[cfg(feature = "wgpu")]
     wgpu::run();
-    #[cfg(feature = "cuda-jit")]
-    cuda_jit::run();
+    #[cfg(feature = "cuda")]
+    cuda::run();
 }
diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs
index 1be9803a15..9a9cab44bd 100644
--- a/examples/text-classification/examples/ag-news-train.rs
+++ b/examples/text-classification/examples/ag-news-train.rs
@@ -103,18 +103,18 @@ mod remote {
     }
 }
 
-#[cfg(feature = "cuda-jit")]
-mod cuda_jit {
+#[cfg(feature = "cuda")]
+mod cuda {
     use crate::{launch, ElemType};
-    use burn::backend::{Autodiff, CudaJit};
+    use burn::backend::{Autodiff, Cuda};
 
     pub fn run() {
-        launch::<Autodiff<CudaJit<ElemType, i32>>>(vec![Default::default()]);
+        launch::<Autodiff<Cuda<ElemType, i32>>>(vec![Default::default()]);
     }
 }
 
-#[cfg(feature = "hip-jit")]
-mod hip_jit {
+#[cfg(feature = "hip")]
+mod hip {
     use crate::{launch, ElemType};
-    use burn::backend::{Autodiff, HipJit};
+    use burn::backend::{Autodiff, Hip};
@@ -137,10 +137,10 @@ fn main() {
     tch_cpu::run();
     #[cfg(feature = "wgpu")]
     wgpu::run();
-    #[cfg(feature = "cuda-jit")]
-    cuda_jit::run();
-    #[cfg(feature = "hip-jit")]
-    hip_jit::run();
+    #[cfg(feature = "cuda")]
+    cuda::run();
+    #[cfg(feature = "hip")]
+    hip::run();
     #[cfg(feature = "remote")]
     remote::run();
 }
diff --git a/examples/wgan/Cargo.toml b/examples/wgan/Cargo.toml
index 48d5680f51..d6ee6345b1 100644
--- a/examples/wgan/Cargo.toml
+++ b/examples/wgan/Cargo.toml
@@ -11,7 +11,7 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
 tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
 wgpu = ["burn/wgpu"]
-cuda-jit = ["burn/cuda-jit"]
+cuda = ["burn/cuda"]
 
 [dependencies]
 burn = { path = "../../crates/burn", features=["train", "vision"] }
diff --git a/examples/wgan/README.md b/examples/wgan/README.md
index d7252ba520..0828145f61 100644
--- a/examples/wgan/README.md
+++ b/examples/wgan/README.md
@@ -12,7 +12,7 @@ Please note that better performance maybe gained by adopting a convolution layer
 
 ```sh
 # Cuda backend
-cargo run --example wgan-mnist --release --features cuda-jit
+cargo run --example wgan-mnist --release --features cuda
 
 # Wgpu backend
 cargo run --example wgan-mnist --release --features wgpu
@@ -36,5 +36,5 @@ cargo run --example wgan-mnist --release --features ndarray-blas-netlib # f32
 To generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend.
 
 ```sh
-cargo run --example wgan-generate --release --features cuda-jit
+cargo run --example wgan-generate --release --features cuda
 ```
diff --git a/examples/wgan/examples/wgan-generate.rs b/examples/wgan/examples/wgan-generate.rs
index fa66623ca3..1b4a51a535 100644
--- a/examples/wgan/examples/wgan-generate.rs
+++ b/examples/wgan/examples/wgan-generate.rs
@@ -66,13 +66,13 @@ mod wgpu {
     }
 }
 
-#[cfg(feature = "cuda-jit")]
-mod cuda_jit {
+#[cfg(feature = "cuda")]
+mod cuda {
     use crate::launch;
-    use burn::backend::{Autodiff, CudaJit};
+    use burn::backend::{Autodiff, Cuda};
 
     pub fn run() {
-        launch::<Autodiff<CudaJit>>(Default::default());
+        launch::<Autodiff<Cuda>>(Default::default());
     }
 }
 
@@ -90,6 +90,6 @@ fn main() {
     tch_cpu::run();
     #[cfg(feature = "wgpu")]
     wgpu::run();
-    #[cfg(feature = "cuda-jit")]
-    cuda_jit::run();
+    #[cfg(feature = "cuda")]
+    cuda::run();
 }
diff --git a/examples/wgan/examples/wgan-mnist.rs b/examples/wgan/examples/wgan-mnist.rs
index d964b07844..787acfec94 100644
--- a/examples/wgan/examples/wgan-mnist.rs
+++ b/examples/wgan/examples/wgan-mnist.rs
@@ -78,13 +78,13 @@ mod wgpu {
     }
 }
 
-#[cfg(feature = "cuda-jit")]
-mod cuda_jit {
+#[cfg(feature = "cuda")]
+mod cuda {
     use crate::launch;
-    use burn::backend::{cuda_jit::CudaDevice, Autodiff, CudaJit};
+    use burn::backend::{cuda::CudaDevice, Autodiff, Cuda};
 
     pub fn run() {
-        launch::<Autodiff<CudaJit>>(CudaDevice::default());
+        launch::<Autodiff<Cuda>>(CudaDevice::default());
     }
 }
 
@@ -102,6 +102,6 @@ fn main() {
     tch_cpu::run();
     #[cfg(feature = "wgpu")]
     wgpu::run();
-    #[cfg(feature = "cuda-jit")]
-    cuda_jit::run();
+    #[cfg(feature = "cuda")]
+    cuda::run();
 }
diff --git a/xtask/src/commands/test.rs b/xtask/src/commands/test.rs
index 47e50f80ed..5b94b2909e 100644
--- a/xtask/src/commands/test.rs
+++ b/xtask/src/commands/test.rs
@@ -83,7 +83,7 @@ pub(crate) fn handle_command(
                 vec!["--features", "test-wgpu-spirv"],
                 None,
                 None,
-                "std wgpu-spirv",
+                "std vulkan",
             )?;
         }
     }
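
Migration sketch for downstream crates (illustrative, not part of the diff): the rename is mechanical. Cargo features `cuda-jit`, `hip-jit`, and `wgpu-spirv` become `cuda`, `hip`, and `vulkan` (with `webgpu` newly selecting the WGSL path), and the `CudaJit`/`HipJit` aliases become `Cuda`/`Hip`. The crate below is hypothetical; the `Backend` alias and `run` helper are invented for the example, while the `burn::backend` paths follow the re-exports added in `crates/burn-core/src/backend.rs` above.

```rust
// Cargo.toml of the hypothetical downstream crate:
//   before: burn = { features = ["cuda-jit"] }
//   after:  burn = { features = ["cuda"] }

// Before: use burn::backend::{Autodiff, CudaJit};
// After: the alias drops the `-jit` suffix and the module is now `cuda`.
use burn::backend::{cuda::CudaDevice, Autodiff, Cuda};

// Same element-type parameters the ag-news examples pass to the backend.
type Backend = Autodiff<Cuda<f32, i32>>;

fn main() {
    // Device construction is untouched by the rename.
    let device = CudaDevice::default();

    // Entry points keep their signatures; only the backend type changed.
    run::<Backend>(device);

    // wgpu users coming from `wgpu-spirv` pick one of the new aliases instead:
    //   use burn::backend::Vulkan;  // SPIR-V compiler, feature "vulkan"
    //   use burn::backend::WebGpu;  // WGSL compiler, feature "webgpu"
}

fn run<B: burn::tensor::backend::AutodiffBackend>(_device: B::Device) {
    // A real crate would build its model and optimizer here.
}
```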