Skip to content

Commit

Permalink
Move testgen to burn-jit
Browse files Browse the repository at this point in the history
  • Loading branch information
wingertge committed Jan 28, 2025
1 parent d292f85 commit 9e65150
Show file tree
Hide file tree
Showing 8 changed files with 8 additions and 19 deletions.
2 changes: 0 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions crates/burn-cuda/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,6 @@ log = { workspace = true }
burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [
"export_tests",
] }
burn-vision = { path = "../burn-vision", version = "0.17.0", default-features = false, features = [
"export_tests",
] }
paste = { workspace = true }


Expand Down
1 change: 0 additions & 1 deletion crates/burn-cuda/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,4 @@ mod tests {

// TODO: Add tests for bf16
burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]);
burn_vision::testgen_all!();
}
3 changes: 2 additions & 1 deletion crates/burn-jit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ export_tests = [
"serial_test",
"burn-autodiff/export_tests",
"burn-tensor/export_tests",
"burn-vision?/export_tests",
"burn-vision/export_tests",
"burn-ndarray",
"fusion",
"vision",
"paste",
]
fusion = ["burn-fusion"]
Expand Down
4 changes: 4 additions & 0 deletions crates/burn-jit/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub use burn_autodiff;
pub use burn_fusion;
pub use burn_ndarray;
pub use burn_tensor;
pub use burn_vision;
pub use serial_test;

#[macro_export]
Expand All @@ -43,7 +44,10 @@ macro_rules! testgen_all {
};
([$($float:ident),*], [$($int:ident),*], [$($bool:ident),*]) => {
mod jit {
pub use $crate::tests::burn_vision;

burn_jit::testgen_jit!([$($float),*], [$($int),*], [$($bool),*]);
burn_vision::testgen_all!();

mod kernel {
use super::*;
Expand Down
6 changes: 0 additions & 6 deletions crates/burn-vision/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@ macro_rules! testgen_all {
() => {
use burn_tensor::{Bool, Float, Int};

pub type TestBackend = JitBackend<TestRuntime, f32, i32, u32>;

type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, Int>;
type TestTensorBool<const D: usize> = burn_tensor::Tensor<TestBackend, D, Bool>;

pub mod vision {
pub use super::*;

Expand Down
7 changes: 2 additions & 5 deletions crates/burn-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ std = ["burn-jit/std", "cubecl/std"]
template = ["burn-jit/template", "cubecl/template"]

# Backends
webgpu = ["cubecl-wgsl"]
vulkan = ["cubecl-spirv"]
webgpu = ["cubecl-wgsl"]

# Compilers
cubecl-wgsl = []
cubecl-spirv = ["cubecl/wgpu-spirv"]
cubecl-wgsl = []

[dependencies]
cubecl = { workspace = true, features = ["wgpu"] }
Expand All @@ -42,9 +42,6 @@ burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features =
burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [
"export_tests",
] }
burn-vision = { path = "../burn-vision", version = "0.17.0", features = [
"export_tests",
] }
half = { workspace = true }
paste = { workspace = true }

Expand Down
1 change: 0 additions & 1 deletion crates/burn-wgpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,5 +124,4 @@ mod tests {
burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]);
#[cfg(not(feature = "vulkan"))]
burn_jit::testgen_all!([f32], [i32], [u32]);
burn_vision::testgen_all!();
}

0 comments on commit 9e65150

Please sign in to comment.