Clean up -jit suffix in feature flags and modules #2705

Merged · 6 commits · Jan 28, 2025
Changes from 1 commit
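The rename is purely mechanical: the `cuda-jit` / `hip-jit` feature flags become `cuda` / `hip`, the `cuda_jit` / `hip_jit` modules become `cuda` / `hip`, the `CudaJit` / `HipJit` aliases become `Cuda` / `Hip`, and `wgpu-spirv` becomes `vulkan`. A minimal sketch of downstream code after the rename (hypothetical module; it assumes a crate feature `cuda` that forwards to `burn/cuda`, as the example crates below do):

```rust
// Sketch only: paths and names follow this PR's diff.
#[cfg(feature = "cuda")]
mod after_rename {
    // Previously: `use burn::backend::{cuda_jit::CudaDevice, CudaJit};`
    use burn::backend::{cuda::CudaDevice, Cuda};

    // `Cuda` replaces the old `CudaJit` export.
    pub type Backend = Cuda;

    pub fn default_device() -> CudaDevice {
        CudaDevice::default()
    }
}
```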
8 changes: 4 additions & 4 deletions backend-comparison/Cargo.toml
@@ -15,10 +15,10 @@ candle-accelerate = ["burn/candle", "burn/accelerate"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
cuda-jit = ["burn/cuda-jit"]
cuda-jit-fusion = ["cuda-jit", "burn/fusion"]
cuda = ["burn/cuda"]
cuda-fusion = ["cuda", "burn/fusion"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
hip-jit = ["burn/hip-jit"]
hip = ["burn/hip"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
@@ -27,7 +27,7 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
wgpu-spirv = ["burn/wgpu-spirv", "burn/autotune"]
wgpu-spirv = ["burn/vulkan", "burn/autotune"]
wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"]

[dependencies]
20 changes: 10 additions & 10 deletions backend-comparison/src/lib.rs
@@ -91,12 +91,12 @@ macro_rules! bench_on_backend {
let feature_name = "wgpu-spirv";
#[cfg(feature = "wgpu-spirv-fusion")]
let feature_name = "wgpu-spirv-fusion";
#[cfg(feature = "cuda-jit")]
let feature_name = "cuda-jit";
#[cfg(feature = "cuda-jit-fusion")]
let feature_name = "cuda-jit-fusion";
#[cfg(feature = "hip-jit")]
let feature_name = "hip-jit";
#[cfg(feature = "cuda")]
let feature_name = "cuda";
#[cfg(feature = "cuda-fusion")]
let feature_name = "cuda-fusion";
#[cfg(feature = "hip")]
let feature_name = "hip";

#[cfg(any(feature = "wgpu"))]
{
@@ -172,16 +172,16 @@ macro_rules! bench_on_backend {
$fn_name::<Candle>(&device, feature_name, url, token);
}

#[cfg(feature = "cuda-jit")]
#[cfg(feature = "cuda")]
{
use burn::backend::cuda_jit::{Cuda, CudaDevice};
use burn::backend::cuda::{Cuda, CudaDevice};

$fn_name::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
}

#[cfg(feature = "hip-jit")]
#[cfg(feature = "hip")]
{
use burn::backend::hip_jit::{Hip, HipDevice};
use burn::backend::hip::{Hip, HipDevice};

$fn_name::<Hip<half::f16>>(&HipDevice::default(), feature_name, url, token);
}
16 changes: 8 additions & 8 deletions crates/burn-core/Cargo.toml
@@ -36,8 +36,8 @@ doc = [
"ndarray",
"tch",
"wgpu",
"cuda-jit",
"hip-jit",
"cuda",
"hip",
"audio",
"vision",
"autodiff",
@@ -100,12 +100,12 @@ template = ["burn-wgpu?/template"]

candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
cuda-jit = ["burn-cuda"]
hip-jit = ["burn-hip"]
cuda = ["burn-cuda"]
hip = ["burn-hip"]
ndarray = ["burn-ndarray"]
tch = ["burn-tch"]
wgpu = ["burn-wgpu"]
wgpu-spirv = ["wgpu", "burn-wgpu/spirv"]
vulkan = ["wgpu", "burn-wgpu/spirv"]

# Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files.
record-item-custom-serde = ["thiserror", "regex"]
@@ -116,13 +116,13 @@ experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]
# Backwards compatibility with previous serialized data format.
record-backward-compat = []

test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray.
test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray.
test-cuda = ["cuda"] # To use cuda during testing, default uses ndarray.
test-hip = ["hip"] # To use hip during testing, default uses ndarray.
test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.
test-wgpu-spirv = [
"test-wgpu",
"wgpu-spirv",
"vulkan",
] # To use wgpu-spirv during testing, default uses ndarray.

[dependencies]
16 changes: 8 additions & 8 deletions crates/burn-core/src/backend.rs
@@ -21,23 +21,23 @@ pub use burn_wgpu as wgpu;
#[cfg(feature = "wgpu")]
pub use burn_wgpu::Wgpu;

#[cfg(feature = "cuda-jit")]
pub use burn_cuda as cuda_jit;
#[cfg(feature = "cuda")]
pub use burn_cuda as cuda;

#[cfg(feature = "cuda-jit")]
pub use burn_cuda::Cuda as CudaJit;
#[cfg(feature = "cuda")]
pub use burn_cuda::Cuda;

#[cfg(feature = "candle")]
pub use burn_candle as candle;

#[cfg(feature = "candle")]
pub use burn_candle::Candle;

#[cfg(feature = "hip-jit")]
pub use burn_hip as hip_jit;
#[cfg(feature = "hip")]
pub use burn_hip as hip;

#[cfg(feature = "hip-jit")]
pub use burn_hip::Hip as HipJit;
#[cfg(feature = "hip")]
pub use burn_hip::Hip;

#[cfg(feature = "tch")]
pub use burn_tch as libtorch;
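For reference, a minimal sketch (not part of the diff) of downstream usage of the renamed HIP re-exports, assuming a crate feature `hip` that forwards to `burn/hip`:

```rust
// Sketch only: the `hip` module and `HipDevice` path follow the re-exports above.
#[cfg(feature = "hip")]
fn default_hip_device() -> burn::backend::hip::HipDevice {
    use burn::backend::hip::HipDevice;

    // The module formerly exported as `hip_jit` is now `hip`;
    // `burn::backend::Hip` likewise replaces the old `HipJit` alias.
    HipDevice::default()
}
```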
6 changes: 3 additions & 3 deletions crates/burn/Cargo.toml
@@ -50,15 +50,15 @@ openblas-system = ["burn-core/openblas-system"]
template = ["burn-core/template"]

candle = ["burn-core/candle"]
cuda-jit = ["burn-core/cuda-jit"]
hip-jit = ["burn-core/hip-jit"]
cuda = ["burn-core/cuda"]
hip = ["burn-core/hip"]
ndarray = ["burn-core/ndarray"]
remote = ["burn-core/remote"]
router = ["burn-core/router"]
server = ["burn-core/server"]
tch = ["burn-core/tch"]
wgpu = ["burn-core/wgpu"]
wgpu-spirv = ["burn-core/wgpu-spirv"]
vulkan = ["burn-core/vulkan"]

# Network utils
network = ["burn-core/network"]
5 changes: 3 additions & 2 deletions crates/burn/src/lib.rs
@@ -76,12 +76,13 @@
//! - `vision`: Enables vision datasets (MnistDataset)
//! - Backends
//! - `wgpu`: Makes available the WGPU backend
//! - `wgpu-spirv`: Makes available the `wgpu` backend with the alternative SPIR-V compiler
//! - `vulkan`: Makes available the `wgpu` backend with the alternative SPIR-V compiler
//! - `cuda`: Makes available the CUDA backend
//! - `hip`: Makes available the HIP backend
//! - `candle`: Makes available the Candle backend
//! - `tch`: Makes available the LibTorch backend
//! - `ndarray`: Makes available the NdArray backend
//! - Backend specifications
//! - `cuda`: If supported, CUDA will be used
//! - `accelerate`: If supported, Accelerate will be used
//! - `blas-netlib`: If supported, Blas Netlib will be used
//! - `openblas`: If supported, Openblas will be used
4 changes: 2 additions & 2 deletions examples/server/Cargo.toml
@@ -8,9 +8,9 @@ version.workspace = true

[features]
default = ["wgpu"]
cuda-jit = ["burn/cuda-jit"]
cuda = ["burn/cuda"]
wgpu = ["burn/wgpu"]
wgpu-spirv = ["wgpu", "burn/wgpu-spirv"]
vulkan = ["wgpu", "burn/vulkan"]
ndarray = ["burn/ndarray"]

[dependencies]
4 changes: 2 additions & 2 deletions examples/server/src/lib.rs
@@ -9,8 +9,8 @@ pub fn start() {
cfg_if::cfg_if! {
if #[cfg(feature = "ndarray")]{
burn::server::start::<burn::backend::NdArray>(Default::default(), port);
} else if #[cfg(feature = "cuda-jit")]{
burn::server::start::<burn::backend::CudaJit>(Default::default(), port);
} else if #[cfg(feature = "cuda")]{
burn::server::start::<burn::backend::Cuda>(Default::default(), port);
} else if #[cfg(feature = "wgpu")] {
burn::server::start::<burn::backend::Wgpu>(Default::default(), port);
} else {
6 changes: 3 additions & 3 deletions examples/text-classification/Cargo.toml
@@ -16,10 +16,10 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
wgpu-spirv = ["wgpu", "burn/wgpu-spirv"]
vulkan = ["wgpu", "burn/vulkan"]
remote = ["burn/remote"]
cuda-jit = ["burn/cuda-jit"]
hip-jit = ["burn/hip-jit"]
cuda = ["burn/cuda"]
hip = ["burn/hip"]

[dependencies]
# Burn
4 changes: 2 additions & 2 deletions examples/text-classification/README.md
@@ -102,6 +102,6 @@ cd burn
# Use the --release flag to really speed up training.

# AG News
cargo run --example ag-news-train --release --features cuda-jit # Train on the ag news dataset
cargo run --example ag-news-infer --release --features cuda-jit # Run inference on the ag news dataset
cargo run --example ag-news-train --release --features cuda # Train on the ag news dataset
cargo run --example ag-news-infer --release --features cuda # Run inference on the ag news dataset
```
12 changes: 6 additions & 6 deletions examples/text-classification/examples/ag-news-infer.rs
@@ -81,13 +81,13 @@ mod wgpu {
}
}

#[cfg(feature = "cuda-jit")]
mod cuda_jit {
#[cfg(feature = "cuda")]
mod cuda {
use crate::{launch, ElemType};
use burn::backend::{cuda_jit::CudaDevice, CudaJit};
use burn::backend::{cuda::CudaDevice, Cuda};

pub fn run() {
launch::<CudaJit<ElemType, i32>>(CudaDevice::default());
launch::<Cuda<ElemType, i32>>(CudaDevice::default());
}
}

@@ -105,6 +105,6 @@ fn main() {
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
#[cfg(feature = "cuda-jit")]
cuda_jit::run();
#[cfg(feature = "cuda")]
cuda::run();
}
20 changes: 10 additions & 10 deletions examples/text-classification/examples/ag-news-train.rs
@@ -101,18 +101,18 @@ mod remote {
}
}

#[cfg(feature = "cuda-jit")]
mod cuda_jit {
#[cfg(feature = "cuda")]
mod cuda {
use crate::{launch, ElemType};
use burn::backend::{Autodiff, CudaJit};
use burn::backend::{Autodiff, Cuda};

pub fn run() {
launch::<Autodiff<CudaJit<ElemType, i32>>>(vec![Default::default()]);
launch::<Autodiff<Cuda<ElemType, i32>>>(vec![Default::default()]);
}
}

#[cfg(feature = "hip-jit")]
mod hip_jit {
#[cfg(feature = "hip")]
mod hip {
use crate::{launch, ElemType};
use burn::backend::{Autodiff, HipJit};

@@ -135,10 +135,10 @@ fn main() {
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
#[cfg(feature = "cuda-jit")]
cuda_jit::run();
#[cfg(feature = "hip-jit")]
hip_jit::run();
#[cfg(feature = "cuda")]
cuda::run();
#[cfg(feature = "hip")]
hip::run();
#[cfg(feature = "remote")]
remote::run();
}
2 changes: 1 addition & 1 deletion examples/wgan/Cargo.toml
@@ -11,7 +11,7 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
cuda-jit = ["burn/cuda-jit"]
cuda = ["burn/cuda"]

[dependencies]
burn = { path = "../../crates/burn", features=["train", "vision"] }
4 changes: 2 additions & 2 deletions examples/wgan/README.md
@@ -12,7 +12,7 @@ Please note that better performance may be gained by adopting a convolution layer

```sh
# Cuda backend
cargo run --example wgan-mnist --release --features cuda-jit
cargo run --example wgan-mnist --release --features cuda

# Wgpu backend
cargo run --example wgan-mnist --release --features wgpu
@@ -36,5 +36,5 @@ cargo run --example wgan-mnist --release --features ndarray-blas-netlib # f32
To generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend.

```sh
cargo run --example wgan-generate --release --features cuda-jit
cargo run --example wgan-generate --release --features cuda
```
12 changes: 6 additions & 6 deletions examples/wgan/examples/wgan-generate.rs
@@ -66,13 +66,13 @@ mod wgpu {
}
}

#[cfg(feature = "cuda-jit")]
mod cuda_jit {
#[cfg(feature = "cuda")]
mod cuda {
use crate::launch;
use burn::backend::{Autodiff, CudaJit};
use burn::backend::{Autodiff, Cuda};

pub fn run() {
launch::<Autodiff<CudaJit>>(Default::default());
launch::<Autodiff<Cuda>>(Default::default());
}
}

@@ -90,6 +90,6 @@ fn main() {
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
#[cfg(feature = "cuda-jit")]
cuda_jit::run();
#[cfg(feature = "cuda")]
cuda::run();
}
12 changes: 6 additions & 6 deletions examples/wgan/examples/wgan-mnist.rs
@@ -78,13 +78,13 @@ mod wgpu {
}
}

#[cfg(feature = "cuda-jit")]
mod cuda_jit {
#[cfg(feature = "cuda")]
mod cuda {
use crate::launch;
use burn::backend::{cuda_jit::CudaDevice, Autodiff, CudaJit};
use burn::backend::{cuda::CudaDevice, Autodiff, Cuda};

pub fn run() {
launch::<Autodiff<CudaJit>>(CudaDevice::default());
launch::<Autodiff<Cuda>>(CudaDevice::default());
}
}

@@ -102,6 +102,6 @@ fn main() {
tch_cpu::run();
#[cfg(feature = "wgpu")]
wgpu::run();
#[cfg(feature = "cuda-jit")]
cuda_jit::run();
#[cfg(feature = "cuda")]
cuda::run();
}
2 changes: 1 addition & 1 deletion xtask/src/commands/test.rs
@@ -83,7 +83,7 @@ pub(crate) fn handle_command(
vec!["--features", "test-wgpu-spirv"],
None,
None,
"std wgpu-spirv",
"std vulkan",
)?;
}
}