diff --git a/Cargo.lock b/Cargo.lock index 61741c4552..008d53f5f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -255,10 +255,10 @@ dependencies = [ "arboard", "burn", "burn-common", - "burn-cuda", "burn-wgpu", "clap 4.5.9", "colored", + "cubecl", "derive-new", "dirs 5.0.1", "github-device-flow", @@ -469,41 +469,17 @@ dependencies = [ name = "burn-common" version = "0.14.0" dependencies = [ + "cubecl-common", "dashmap", "data-encoding", - "derive-new", "getrandom", "indicatif", - "pollster", - "rand", "rayon", "reqwest 0.12.5", - "serde", - "spin", "tokio", "web-time", ] -[[package]] -name = "burn-compute" -version = "0.14.0" -dependencies = [ - "async-channel", - "burn-common", - "derive-new", - "dirs 5.0.1", - "hashbrown 0.14.5", - "log", - "md5", - "pollster", - "rand", - "serde", - "serde_json", - "serial_test", - "spin", - "web-time", -] - [[package]] name = "burn-core" version = "0.14.0" @@ -534,49 +510,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "burn-cube" -version = "0.14.0" -dependencies = [ - "burn-compute", - "burn-cube-macros", - "burn-tensor", - "bytemuck", - "derive-new", - "half", - "log", - "num-traits", - "serde", - "trybuild", -] - -[[package]] -name = "burn-cube-macros" -version = "0.14.0" -dependencies = [ - "derive-new", - "proc-macro2", - "quote", - "syn 2.0.71", -] - -[[package]] -name = "burn-cuda" -version = "0.14.0" -dependencies = [ - "burn-common", - "burn-compute", - "burn-cube", - "burn-fusion", - "burn-jit", - "burn-tensor", - "bytemuck", - "cudarc", - "derive-new", - "half", - "log", -] - [[package]] name = "burn-dataset" version = "0.14.0" @@ -653,7 +586,7 @@ dependencies = [ "thiserror", "tracing-core", "tracing-subscriber", - "zip 2.1.3", + "zip 2.1.4", ] [[package]] @@ -662,13 +595,12 @@ version = "0.14.0" dependencies = [ "burn-autodiff", "burn-common", - "burn-compute", - "burn-cube", "burn-fusion", "burn-ndarray", "burn-tensor", "burn-tensor-testgen", "bytemuck", + "cubecl", "derive-new", "half", "hashbrown 0.14.5", @@ -727,6 +659,7 @@ dependencies = [ "burn-common", "burn-tensor-testgen", "bytemuck", + "cubecl", "derive-new", "half", "hashbrown 0.14.5", @@ -768,19 +701,10 @@ dependencies = [ name = "burn-wgpu" version = "0.14.0" dependencies = [ - "async-channel", - "burn-common", - "burn-compute", - "burn-cube", "burn-fusion", "burn-jit", "burn-tensor", - "bytemuck", - "derive-new", - "hashbrown 0.14.5", - "log", - "pollster", - "wgpu", + "cubecl", ] [[package]] @@ -850,9 +774,9 @@ dependencies = [ [[package]] name = "candle-core" -version = "0.5.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "311d8dbe293aa3b5c34f6a57727fafd67d17a74fa8b65276501237c233b34ffd" +checksum = "d5b18de020c2729dbf7ac390325312644808b6ba9b7962f1f724e9185b1d53c7" dependencies = [ "accelerate-src", "byteorder", @@ -877,18 +801,18 @@ dependencies = [ [[package]] name = "candle-kernels" -version = "0.5.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3b4b048ca298fb8be90b0f4d0fe68bdca9de956ab52bb6e381463d955f2b661" +checksum = "8bc0a71be8b2f0950b63fd602a5e10a74a4f94a5fd63059ae455e96163389488" dependencies = [ "bindgen_cuda", ] [[package]] name = "candle-metal-kernels" -version = "0.5.1" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d31136c9541c82b7de0937c9a58210ada38e17d70810e0eacc0a99d849d848d" +checksum = "f889aacd02fd999620a0435133d7cf3b58c81ef9dd5e47c38939b7a72345ea86" dependencies = [ "metal 0.27.0", 
"once_cell", @@ -922,9 +846,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.1.5" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "324c74f2155653c90b04f25b2a47a8a631360cb908f92a772695f430c7e31052" +checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" dependencies = [ "jobserver", "libc", @@ -1362,11 +1286,124 @@ dependencies = [ "memchr", ] +[[package]] +name = "cubecl" +version = "0.1.1" +source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c" +dependencies = [ + "cubecl-core", + "cubecl-cuda", + "cubecl-linalg", + "cubecl-wgpu", +] + +[[package]] +name = "cubecl-common" +version = "0.1.1" +source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c" +dependencies = [ + "derive-new", + "getrandom", + "pollster", + "rand", + "serde", + "spin", + "web-time", +] + +[[package]] +name = "cubecl-core" +version = "0.1.1" +source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c" +dependencies = [ + "bytemuck", + "cubecl-macros", + "cubecl-runtime", + "derive-new", + "half", + "log", + "num-traits", + "serde", +] + +[[package]] +name = "cubecl-cuda" +version = "0.1.1" +source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c" +dependencies = [ + "bytemuck", + "cubecl-common", + "cubecl-core", + "cubecl-runtime", + "cudarc", + "derive-new", + "half", + "log", +] + +[[package]] +name = "cubecl-linalg" +version = "0.1.1" +source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c" +dependencies = [ + "bytemuck", + "cubecl-core", + "cubecl-runtime", + "half", +] + +[[package]] +name = "cubecl-macros" +version = "0.1.1" +source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c" +dependencies = [ + "derive-new", + "proc-macro2", + "quote", + "syn 2.0.71", +] + +[[package]] +name = "cubecl-runtime" +version = "0.1.1" +source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c" +dependencies = [ + "async-channel", + "cubecl-common", + "derive-new", + "dirs 5.0.1", + "hashbrown 0.14.5", + "log", + "md5", + "pollster", + "serde", + "serde_json", + "spin", + "web-time", +] + +[[package]] +name = "cubecl-wgpu" +version = "0.1.1" +source = "git+https://github.com/tracel-ai/cubecl#49d844b3d3281100a61a33a4d7865046fcd44b2c" +dependencies = [ + "async-channel", + "bytemuck", + "cubecl-common", + "cubecl-core", + "cubecl-runtime", + "derive-new", + "hashbrown 0.14.5", + "log", + "pollster", + "wgpu", +] + [[package]] name = "cudarc" -version = "0.11.8" +version = "0.11.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56028291ec3b0f6711e2e1b2d597484d359833dcb68331ce89e538012f835c4" +checksum = "e395cd01168d63af826749573071f3c5069b338ae473cab355d22db0b2bb5a0d" dependencies = [ "half", "libloading 0.8.4", @@ -1421,6 +1458,7 @@ version = "0.14.0" dependencies = [ "burn", "bytemuck", + "cubecl", "derive-new", "log", "serde", @@ -2020,15 +2058,6 @@ dependencies = [ "slab", ] -[[package]] -name = "gelu" -version = "0.14.0" -dependencies = [ - "burn-cube", - "burn-cuda", - "burn-wgpu", -] - [[package]] name = "gemm" version = "0.17.1" @@ -2846,6 +2875,7 @@ dependencies = [ "burn-import", "burn-wgpu", "console_error_panic_hook", + "cubecl", "js-sys", "log", "serde", @@ -4909,9 +4939,9 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] 
name = "sdd" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eb0dde0ccd15e337a3cf738a9a38115c6d8e74795d074e73973dad3d229a897" +checksum = "85f05a494052771fc5bd0619742363b5e24e5ad72ab3111ec2e27925b8edc5f3" [[package]] name = "security-framework" @@ -5706,7 +5736,7 @@ dependencies = [ "serde", "serde_spanned", "toml_datetime", - "winnow 0.6.13", + "winnow 0.6.14", ] [[package]] @@ -5826,20 +5856,6 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" -[[package]] -name = "trybuild" -version = "1.0.97" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b1e5645f2ee8025c2f1d75e1138f2dd034d74e6ba54620f3c569ba2a2a1ea06" -dependencies = [ - "glob", - "serde", - "serde_derive", - "serde_json", - "termcolor", - "toml", -] - [[package]] name = "typenum" version = "1.17.0" @@ -6478,9 +6494,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.6.13" +version = "0.6.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59b5e5f6c299a3c7890b876a2a587f3115162487e704907d9b6cd29473052ba1" +checksum = "374ec40a2d767a3c1b4972d9475ecd557356637be906f2cb3f7fe17a6eb5e22f" dependencies = [ "memchr", ] @@ -6699,9 +6715,9 @@ dependencies = [ [[package]] name = "zip" -version = "2.1.3" +version = "2.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "775a2b471036342aa69bc5a602bc889cb0a06cda00477d0c69566757d5553d39" +checksum = "e29ab4097989787b2029a5981c41b7bfb427b5a601e23f455daacb4d0360a9e9" dependencies = [ "aes", "arbitrary", diff --git a/Cargo.toml b/Cargo.toml index f3e887620e..9e2d7e3cd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ members = [ exclude = [ "examples/notebook", - # "crates/burn-cuda" # comment this line to work on burn-cuda + "crates/burn-cuda", # comment this line to work on burn-cuda ] [workspace.package] @@ -27,7 +27,7 @@ license = "MIT OR Apache-2.0" [workspace.dependencies] bytemuck = "1.16.1" -candle-core = { version = "0.5.1" } +candle-core = { version = "0.6.0" } clap = { version = "4.5.9", features = ["derive"] } colored = "2.1.0" console_error_panic_hook = "0.1.7" @@ -140,6 +140,9 @@ nvml-wrapper = "0.10.0" sysinfo = "0.30.13" systemstat = "0.2.3" +cubecl = { version = "0.1.1", git = "https://github.com/tracel-ai/cubecl", default-features = false } +cubecl-common = { version = "0.1.1", git = "https://github.com/tracel-ai/cubecl", default-features = false } [profile.dev] debug = 0 # Speed up compilation time and not necessary. 
+opt-level = 2
diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml
index 82ce12766d..c450724cda 100644
--- a/backend-comparison/Cargo.toml
+++ b/backend-comparison/Cargo.toml
@@ -24,14 +24,14 @@ tch-cpu = ["burn/tch"]
 tch-gpu = ["burn/tch"]
 wgpu = ["burn/wgpu", "burn/autotune"]
 wgpu-fusion = ["wgpu", "burn/fusion"]
-cuda-jit = ["burn-cuda"]
+# cuda-jit = ["burn-cuda"]

 [dependencies]
 arboard = { workspace = true }
 burn = { path = "../crates/burn", default-features = false }
 burn-common = { path = "../crates/burn-common", version = "0.14.0" }
-burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.14.0" }
-burn-cuda = { path = "../crates/burn-cuda", version = "0.14.0", optional = true }
+burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.14.0", optional = true }
+# burn-cuda = { path = "../crates/burn-cuda", version = "0.14.0", optional = true }
 clap = { workspace = true }
 colored = { workspace = true }
 derive-new = { workspace = true }
@@ -49,6 +49,7 @@ strum_macros = { workspace = true }
 sysinfo = { workspace = true, features = ["serde"] }
 wgpu = { workspace = true }
 wsl = { workspace = true }
+cubecl = { workspace = true, features = ["wgpu"] }

 [dev-dependencies]
 rstest = { workspace = true }
diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs
index 78feeab2ff..efbdf88d16 100644
--- a/backend-comparison/benches/matmul.rs
+++ b/backend-comparison/benches/matmul.rs
@@ -29,7 +29,7 @@ impl<B: Backend, const D: usize> Benchmark for MatmulBenchmark<B, D> {
     }

     fn execute(&self, (lhs, rhs): Self::Args) {
-        lhs.clone().transpose().matmul(rhs.clone());
+        lhs.clone().matmul(rhs.clone());
     }

     fn prepare(&self) -> Self::Args {
@@ -52,11 +52,11 @@ fn bench<B: Backend>(
     token: Option<&str>,
 ) {
     const D: usize = 3;
-    let batch_size = 32;
-    let m = 256;
-    let k = 1024;
-    let n = 256;
-    let shape_lhs = [batch_size, k, m].into();
+    let batch_size = 8;
+    let m = 2048;
+    let k = 2048;
+    let n = 2048;
+    let shape_lhs = [batch_size, m, k].into();
     let shape_rhs = [batch_size, k, n].into();

     let benchmark = MatmulBenchmark::<B, D>::new(shape_lhs, shape_rhs, device.clone());
diff --git a/backend-comparison/src/persistence/system_info.rs b/backend-comparison/src/persistence/system_info.rs
index 1faebedc4c..171877f69f 100644
--- a/backend-comparison/src/persistence/system_info.rs
+++ b/backend-comparison/src/persistence/system_info.rs
@@ -1,5 +1,5 @@
 use burn::serde::{Deserialize, Serialize};
-use burn_wgpu::GraphicsApi;
+use cubecl::wgpu::GraphicsApi;
 use std::collections::HashSet;
 use sysinfo;
 use wgpu;
@@ -51,7 +51,7 @@ impl BenchmarkSystemInfo {
     fn enumerate_gpus() -> Vec<String> {
         let instance = wgpu::Instance::default();
         let adapters: Vec<wgpu::Adapter> = instance
-            .enumerate_adapters(burn_wgpu::AutoGraphicsApi::backend().into())
+            .enumerate_adapters(cubecl::wgpu::AutoGraphicsApi::backend().into())
             .into_iter()
             .filter(|adapter| {
                 let info = adapter.get_info();
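For context on the `system_info.rs` hunk above: the wgpu helpers that backend-comparison previously imported from `burn_wgpu` are now reached through `cubecl::wgpu`. A minimal sketch of adapter enumeration in that style, assuming cubecl re-exports the `GraphicsApi` trait and `AutoGraphicsApi` exactly as the import lines above suggest; `gpu_names` and the discrete-GPU filter are illustrative, not the upstream code:

```rust
// Hedged sketch: enumerate adapters the way the updated enumerate_gpus() does.
// `AutoGraphicsApi::backend()` comes from the diff above; everything else here
// is an assumption for illustration.
use cubecl::wgpu::{AutoGraphicsApi, GraphicsApi};

fn gpu_names() -> Vec<String> {
    let instance = wgpu::Instance::default();
    instance
        // backend() yields a wgpu::Backend; .into() widens it to wgpu::Backends.
        .enumerate_adapters(AutoGraphicsApi::backend().into())
        .into_iter()
        .filter(|adapter| adapter.get_info().device_type == wgpu::DeviceType::DiscreteGpu)
        .map(|adapter| adapter.get_info().name)
        .collect()
}
```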
diff --git a/crates/burn-common/Cargo.toml b/crates/burn-common/Cargo.toml
index fc7d231dde..aa3f44b010 100644
--- a/crates/burn-common/Cargo.toml
+++ b/crates/burn-common/Cargo.toml
@@ -11,8 +11,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-common"
 version.workspace = true

 [features]
-default = ["std"]
-std = ["rand/std", "data-encoding/std", "dep:pollster"]
+default = ["std", "cubecl-common/default"]
+std = ["cubecl-common/std"]
 doc = ["default"]
 network = ["dep:indicatif", "dep:reqwest", "dep:tokio"]
 rayon = ["dep:rayon"]
@@ -23,13 +23,7 @@ web-time = { version = "1.1.0" }

 [dependencies]
-# ** Please make sure all dependencies support no_std when std is disabled **
-rand = { workspace = true }
-spin = { workspace = true } # using in place of use std::sync::Mutex;
-derive-new = { workspace = true }
-serde = { workspace = true }
 data-encoding = { workspace = true }
-pollster = { workspace = true, optional = true }

 # Network downloader
 indicatif = { workspace = true, optional = true }
@@ -38,6 +32,7 @@ tokio = { workspace = true, optional = true }

 # Parallel
 rayon = { workspace = true, optional = true }
+cubecl-common = { workspace = true, default-features = false }

 [dev-dependencies]
 dashmap = { workspace = true }
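The manifest above replaces burn-common's own utility dependencies (rand, spin, serde, pollster) with a single `cubecl-common` dependency; the lib.rs hunk later in this diff re-exports it wholesale (`pub use cubecl_common::*;`). A minimal sketch of what that implies for callers, assuming cubecl-common preserves the old module names (`benchmark`, `rand`, and so on), which the wholesale re-export suggests but this diff does not prove:

```rust
// Hypothetical downstream code: these paths resolved against burn-common's own
// modules before this change and resolve through the cubecl-common re-export
// after it, so call sites would not need to move (assumption: module names survived).
use burn_common::benchmark::{run_benchmark, Benchmark};
use burn_common::rand::{get_seeded_rng, Rng}; // Rng itself is re-exported from rand

fn random_seed_demo() -> u64 {
    let mut rng = get_seeded_rng();
    rng.gen::<u64>()
}
```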
diff --git a/crates/burn-common/src/benchmark.rs b/crates/burn-common/src/benchmark.rs
deleted file mode 100644
index f446d9b63e..0000000000
--- a/crates/burn-common/src/benchmark.rs
+++ /dev/null
@@ -1,307 +0,0 @@
-use alloc::format;
-use alloc::string::String;
-use alloc::vec;
-use alloc::vec::Vec;
-use core::fmt::Display;
-use core::time::Duration;
-
-use serde::{Deserialize, Serialize};
-
-#[cfg(all(not(target_family = "wasm"), feature = "std"))]
-use std::time::Instant;
-#[cfg(all(target_family = "wasm", feature = "std"))]
-use web_time::Instant;
-
-/// Results of a benchmark run.
-#[derive(new, Debug, Default, Clone, Serialize, Deserialize)]
-pub struct BenchmarkDurations {
-    /// All durations of the run, in the order they were benchmarked
-    pub durations: Vec<Duration>,
-}
-
-impl BenchmarkDurations {
-    /// Returns a tuple of durations: (min, max, median)
-    fn min_max_median_durations(&self) -> (Duration, Duration, Duration) {
-        let mut sorted = self.durations.clone();
-        sorted.sort();
-        let min = *sorted.first().unwrap();
-        let max = *sorted.last().unwrap();
-        let median = *sorted.get(sorted.len() / 2).unwrap();
-        (min, max, median)
-    }
-
-    /// Returns the mean duration among all durations
-    pub(crate) fn mean_duration(&self) -> Duration {
-        self.durations.iter().sum::<Duration>() / self.durations.len() as u32
-    }
-
-    /// Returns the variance of the durations
-    pub(crate) fn variance_duration(&self, mean: Duration) -> Duration {
-        let var = self
-            .durations
-            .iter()
-            .map(|duration| {
-                let tmp = duration.as_secs_f64() - mean.as_secs_f64();
-                Duration::from_secs_f64(tmp * tmp)
-            })
-            .sum::<Duration>()
-            / self.durations.len() as u32;
-        var
-    }
-}
-
-impl Display for BenchmarkDurations {
-    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-        let computed = BenchmarkComputations::new(self);
-        let BenchmarkComputations {
-            mean,
-            median,
-            variance,
-            min,
-            max,
-        } = computed;
-        let num_sample = self.durations.len();
-
-        f.write_str(
-            format!(
-                "
-―――――――― Result ―――――――――
-  Samples     {num_sample}
-  Mean        {mean:.3?}
-  Variance    {variance:.3?}
-  Median      {median:.3?}
-  Min         {min:.3?}
-  Max         {max:.3?}
-―――――――――――――――――――――――――"
-            )
-            .as_str(),
-        )
-    }
-}
-
-/// Computed values from benchmark durations.
-#[derive(Debug, Default, Clone, Serialize, Deserialize)]
-pub struct BenchmarkComputations {
-    /// Mean of all the durations.
-    pub mean: Duration,
-    /// Median of all the durations.
-    pub median: Duration,
-    /// Variance of all the durations.
-    pub variance: Duration,
-    /// Minimum duration amongst all durations.
-    pub min: Duration,
-    /// Maximum duration amongst all durations.
-    pub max: Duration,
-}
-
-impl BenchmarkComputations {
-    /// Compute duration values and return a BenchmarkComputations struct
-    pub fn new(durations: &BenchmarkDurations) -> Self {
-        let mean = durations.mean_duration();
-        let (min, max, median) = durations.min_max_median_durations();
-        Self {
-            mean,
-            median,
-            min,
-            max,
-            variance: durations.variance_duration(mean),
-        }
-    }
-}
-/// Benchmark trait.
-pub trait Benchmark {
-    /// Benchmark arguments.
-    type Args: Clone;
-
-    /// Prepare the benchmark, run anything that is essential for the benchmark, but shouldn't
-    /// count as included in the duration.
-    ///
-    /// # Notes
-    ///
-    /// This should not include warmup, the benchmark will be run at least one time without
-    /// measuring the execution time.
-    fn prepare(&self) -> Self::Args;
-    /// Execute the benchmark and return the time it took to complete.
-    fn execute(&self, args: Self::Args);
-    /// Number of samples per run required to have a statistical significance.
-    fn num_samples(&self) -> usize {
-        10
-    }
-    /// Name of the benchmark, should be short and it should match the name
-    /// defined in the crate Cargo.toml
-    fn name(&self) -> String;
-    /// The options passed to the benchmark.
-    fn options(&self) -> Option<String> {
-        None
-    }
-    /// Shapes dimensions
-    fn shapes(&self) -> Vec<Vec<usize>> {
-        vec![]
-    }
-    /// Wait for computation to be over
-    fn sync(&self);
-    /// Run the benchmark a number of times.
-    fn run(&self) -> BenchmarkDurations {
-        #[cfg(not(feature = "std"))]
-        panic!("Attempting to run benchmark in a no-std environment");
-
-        #[cfg(feature = "std")]
-        {
-            // Warmup
-            let args = self.prepare();
-
-            self.execute(args.clone());
-            self.sync();
-
-            let mut durations = Vec::with_capacity(self.num_samples());
-
-            for _ in 0..self.num_samples() {
-                // Prepare
-                self.sync();
-
-                // Execute the benchmark
-                let start = Instant::now();
-                self.execute(args.clone());
-                self.sync();
-                let end = Instant::now();
-
-                // Register the duration
-                durations.push(end - start);
-            }
-
-            BenchmarkDurations { durations }
-        }
-    }
-}
-
-/// Result of a benchmark run, with metadata
-#[derive(Default, Clone)]
-pub struct BenchmarkResult {
-    /// Individual raw results of the run
-    pub raw: BenchmarkDurations,
-    /// Computed values for the run
-    pub computed: BenchmarkComputations,
-    /// Git commit hash of the commit in which the run occurred
-    pub git_hash: String,
-    /// Name of the benchmark
-    pub name: String,
-    /// Options passed to the benchmark
-    pub options: Option<String>,
-    /// Shape dimensions
-    pub shapes: Vec<Vec<usize>>,
-    /// Time just before the run
-    pub timestamp: u128,
-}
-
-impl Display for BenchmarkResult {
-    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-        f.write_str(
-            format!(
-                "
-        Timestamp: {}
-        Git Hash: {}
-        Benchmarking - {}{}
-        ",
-                self.timestamp, self.git_hash, self.name, self.raw
-            )
-            .as_str(),
-        )
-    }
-}
-
-#[cfg(feature = "std")]
-/// Runs the given benchmark on the device and prints result and information.
-pub fn run_benchmark<BM>(benchmark: BM) -> BenchmarkResult
-where
-    BM: Benchmark,
-{
-    let timestamp = std::time::SystemTime::now()
-        .duration_since(std::time::UNIX_EPOCH)
-        .unwrap()
-        .as_millis();
-    let output = std::process::Command::new("git")
-        .args(["rev-parse", "HEAD"])
-        .output()
-        .unwrap();
-    let git_hash = String::from_utf8(output.stdout).unwrap().trim().to_string();
-    let durations = benchmark.run();
-
-    BenchmarkResult {
-        raw: durations.clone(),
-        computed: BenchmarkComputations::new(&durations),
-        git_hash,
-        name: benchmark.name(),
-        options: benchmark.options(),
-        shapes: benchmark.shapes(),
-        timestamp,
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use alloc::vec;
-
-    #[test]
-    fn test_min_max_median_durations_even_number_of_samples() {
-        let durations = BenchmarkDurations {
-            durations: vec![
-                Duration::new(10, 0),
-                Duration::new(20, 0),
-                Duration::new(30, 0),
-                Duration::new(40, 0),
-                Duration::new(50, 0),
-            ],
-        };
-        let (min, max, median) = durations.min_max_median_durations();
-        assert_eq!(min, Duration::from_secs(10));
-        assert_eq!(max, Duration::from_secs(50));
-        assert_eq!(median, Duration::from_secs(30));
-    }
-
-    #[test]
-    fn test_min_max_median_durations_odd_number_of_samples() {
-        let durations = BenchmarkDurations {
-            durations: vec![
-                Duration::new(18, 5),
-                Duration::new(20, 0),
-                Duration::new(30, 0),
-                Duration::new(40, 0),
-            ],
-        };
-        let (min, max, median) = durations.min_max_median_durations();
-        assert_eq!(min, Duration::from_nanos(18000000005_u64));
-        assert_eq!(max, Duration::from_secs(40));
-        assert_eq!(median, Duration::from_secs(30));
-    }
-
-    #[test]
-    fn test_mean_duration() {
-        let durations = BenchmarkDurations {
-            durations: vec![
-                Duration::new(10, 0),
-                Duration::new(20, 0),
-                Duration::new(30, 0),
-                Duration::new(40, 0),
-            ],
-        };
-        let mean = durations.mean_duration();
-        assert_eq!(mean, Duration::from_secs(25));
-    }
-
-    #[test]
-    fn test_variance_duration() {
-        let durations = BenchmarkDurations {
-            durations: vec![
-                Duration::new(10, 0),
-                Duration::new(20, 0),
-                Duration::new(30, 0),
-                Duration::new(40, 0),
-                Duration::new(50, 0),
-            ],
-        };
-        let mean = durations.mean_duration();
-        let variance = durations.variance_duration(mean);
-        assert_eq!(variance, Duration::from_secs(200));
-    }
-}
diff --git a/crates/burn-common/src/lib.rs b/crates/burn-common/src/lib.rs
index c02f98192c..9d7fc998b2 100644
--- a/crates/burn-common/src/lib.rs
+++ b/crates/burn-common/src/lib.rs
@@ -5,28 +5,10 @@
 //!
 //! This library contains common types used by other Burn crates that must be shared.

-#[macro_use]
-extern crate derive_new;
-
 /// Id module contains types for unique identifiers.
 pub mod id;

-/// Rand module contains types for random number generation for non-std environments and for
-/// std environments.
-pub mod rand;
-
-/// Stub module contains types for stubs for non-std environments and for std environments.
-pub mod stub;
-
-/// Module for benchmarking any executable part
-pub mod benchmark;
-
-/// Useful when you need to read async data without having to decorate each function with async
-/// notation.
-pub mod reader;
-
-/// Synchronization type module, used both by ComputeServer and Backends.
-pub mod sync_type;
+pub use cubecl_common::*;

 extern crate alloc;
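Grounded in the tests of the deleted benchmark module above: a standalone check of the removed mean/variance arithmetic. For samples of 10 s through 50 s, the deviations from the 30 s mean square to 400, 100, 0, 100, 400, which average out to 200, exactly the `test_variance_duration` case:

```rust
// Worked example of the variance computation that moved out with benchmark.rs.
use core::time::Duration;

fn main() {
    let durations = [10u64, 20, 30, 40, 50].map(Duration::from_secs);
    let mean = durations.iter().sum::<Duration>() / durations.len() as u32;
    let variance = durations
        .iter()
        .map(|d| {
            // Squared deviation from the mean, carried as a Duration.
            let dev = d.as_secs_f64() - mean.as_secs_f64();
            Duration::from_secs_f64(dev * dev)
        })
        .sum::<Duration>()
        / durations.len() as u32;
    assert_eq!(mean, Duration::from_secs(30));
    assert_eq!(variance, Duration::from_secs(200));
}
```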
diff --git a/crates/burn-common/src/rand.rs b/crates/burn-common/src/rand.rs
deleted file mode 100644
index c9198930c3..0000000000
--- a/crates/burn-common/src/rand.rs
+++ /dev/null
@@ -1,45 +0,0 @@
-pub use rand::{rngs::StdRng, Rng, SeedableRng};
-
-use rand::distributions::Standard;
-use rand::prelude::Distribution;
-
-/// Returns a seeded random number generator using entropy.
-#[cfg(feature = "std")]
-#[inline(always)]
-pub fn get_seeded_rng() -> StdRng {
-    StdRng::from_entropy()
-}
-
-/// Returns a seeded random number generator using a pre-generated seed.
-#[cfg(not(feature = "std"))]
-#[inline(always)]
-pub fn get_seeded_rng() -> StdRng {
-    const CONST_SEED: u64 = 42;
-    StdRng::seed_from_u64(CONST_SEED)
-}
-
-/// Generates random data from a thread-local RNG.
-#[cfg(feature = "std")]
-#[inline]
-pub fn gen_random<T>() -> T
-where
-    Standard: Distribution<T>,
-{
-    rand::thread_rng().gen()
-}
-
-/// Generates random data from a mutex-protected RNG.
-#[cfg(not(feature = "std"))]
-#[inline]
-pub fn gen_random<T>() -> T
-where
-    Standard: Distribution<T>,
-{
-    use crate::stub::Mutex;
-    static RNG: Mutex<Option<StdRng>> = Mutex::new(None);
-    let mut rng = RNG.lock().unwrap();
-    if rng.is_none() {
-        *rng = Some(get_seeded_rng());
-    }
-    rng.as_mut().unwrap().gen()
-}
diff --git a/crates/burn-common/src/reader.rs b/crates/burn-common/src/reader.rs
deleted file mode 100644
index d459a6e9c5..0000000000
--- a/crates/burn-common/src/reader.rs
+++ /dev/null
@@ -1,54 +0,0 @@
-use alloc::{boxed::Box, sync::Arc, task::Wake, vec::Vec};
-use core::{
-    future::Future,
-    pin::Pin,
-    task::{Context, Poll, Waker},
-};
-
-/// A future that is used to read resources from a compute server.
-pub type Reader = Pin<Box<dyn Future<Output = Vec<u8>> + Send>>;
-
-/// Create a reader from a concrete value.
-pub fn reader_from_concrete(val: Vec<u8>) -> Reader {
-    Box::pin(async move { val })
-}
-
-struct DummyWaker;
-
-impl Wake for DummyWaker {
-    fn wake(self: Arc<Self>) {}
-    fn wake_by_ref(self: &Arc<Self>) {}
-}
-
-/// Read a future synchronously.
-///
-/// On WASM futures cannot block, so this only succeeds if the future returns immediately.
-/// If you want to handle this error, please use
-/// try_read_sync instead.
-pub fn read_sync<F: Future<Output = T>, T>(f: F) -> T {
-    try_read_sync(f).expect("Failed to read tensor data synchronously. This can happen on platforms that don't support blocking futures like WASM. If possible, try using an async variant of this function instead.")
-}
-
-/// Read a future synchronously.
-///
-/// On WASM futures cannot block, so this only succeeds if the future returns immediately,
-/// otherwise this returns None.
-pub fn try_read_sync<F: Future<Output = T>, T>(f: F) -> Option<T> {
-    // Create a dummy context.
-    let waker = Waker::from(Arc::new(DummyWaker));
-    let mut context = Context::from_waker(&waker);
-
-    // Pin & poll the future. Some backends don't do async readbacks, and instead immediately get
-    // the data. This lets us detect when a future is synchronous and doesn't require any waiting.
-    let mut pinned = core::pin::pin!(f);
-
-    match pinned.as_mut().poll(&mut context) {
-        Poll::Ready(output) => Some(output),
-        // On platforms that support it, now just block on the future and drive it to completion.
-        #[cfg(all(not(target_family = "wasm"), feature = "std"))]
-        Poll::Pending => Some(pollster::block_on(pinned)),
-        // Otherwise, just bail and return None - these futures will have to be read back asynchronously.
-        #[cfg(any(target_family = "wasm", not(feature = "std")))]
-        Poll::Pending => None,
-    }
-}
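A minimal sketch of the poll-once trick used by the deleted `try_read_sync` above: poll the future a single time with a no-op waker; a synchronous readback resolves immediately, anything else reports as pending. `poll_once` is a hypothetical stand-in, without the `pollster::block_on` fallback the real code used on native std targets:

```rust
use std::future::Future;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

struct DummyWaker;

impl Wake for DummyWaker {
    fn wake(self: Arc<Self>) {}
}

fn poll_once<F: Future<Output = T>, T>(f: F) -> Option<T> {
    // Build a context around a waker that does nothing; we never re-poll.
    let waker = Waker::from(Arc::new(DummyWaker));
    let mut cx = Context::from_waker(&waker);
    let mut pinned = std::pin::pin!(f);
    match pinned.as_mut().poll(&mut cx) {
        Poll::Ready(value) => Some(value),
        Poll::Pending => None, // the removed code fell back to pollster::block_on here
    }
}

fn main() {
    // An already-ready future resolves on the first poll.
    assert_eq!(poll_once(async { 42 }), Some(42));
}
```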
diff --git a/crates/burn-common/src/stub.rs b/crates/burn-common/src/stub.rs
deleted file mode 100644
index 1e81b58c90..0000000000
--- a/crates/burn-common/src/stub.rs
+++ /dev/null
@@ -1,154 +0,0 @@
-#[cfg(not(feature = "std"))]
-use spin::{
-    Mutex as MutexImported, MutexGuard, Once as OnceImported, RwLock as RwLockImported,
-    RwLockReadGuard, RwLockWriteGuard,
-};
-#[cfg(feature = "std")]
-use std::sync::{
-    Mutex as MutexImported, MutexGuard, OnceLock as OnceImported, RwLock as RwLockImported,
-    RwLockReadGuard, RwLockWriteGuard,
-};
-
-/// A mutual exclusion primitive useful for protecting shared data
-///
-/// This mutex will block threads waiting for the lock to become available. The
-/// mutex can also be statically initialized or created via a [Mutex::new]
-///
-/// [Mutex] wrapper to make `spin::Mutex` API compatible with `std::sync::Mutex` to swap
-#[derive(Debug)]
-pub struct Mutex<T> {
-    inner: MutexImported<T>,
-}
-
-impl<T> Mutex<T> {
-    /// Creates a new mutex in an unlocked state ready for use.
-    #[inline(always)]
-    pub const fn new(value: T) -> Self {
-        Self {
-            inner: MutexImported::new(value),
-        }
-    }
-
-    /// Locks the mutex blocking the current thread until it is able to do so.
-    #[inline(always)]
-    pub fn lock(&self) -> Result<MutexGuard<T>, alloc::string::String> {
-        #[cfg(not(feature = "std"))]
-        {
-            Ok(self.inner.lock())
-        }
-
-        #[cfg(feature = "std")]
-        {
-            self.inner.lock().map_err(|err| err.to_string())
-        }
-    }
-}
-
-/// A reader-writer lock which is exclusively locked for writing or shared for reading.
-/// This reader-writer lock will block threads waiting for the lock to become available.
-/// The lock can also be statically initialized or created via a [RwLock::new]
-/// [RwLock] wrapper to make `spin::RwLock` API compatible with `std::sync::RwLock` to swap
-#[derive(Debug)]
-pub struct RwLock<T> {
-    inner: RwLockImported<T>,
-}
-
-impl<T> RwLock<T> {
-    /// Creates a new reader-writer lock in an unlocked state ready for use.
-    #[inline(always)]
-    pub const fn new(value: T) -> Self {
-        Self {
-            inner: RwLockImported::new(value),
-        }
-    }
-
-    /// Locks this rwlock with shared read access, blocking the current thread
-    /// until it can be acquired.
-    #[inline(always)]
-    pub fn read(&self) -> Result<RwLockReadGuard<T>, alloc::string::String> {
-        #[cfg(not(feature = "std"))]
-        {
-            Ok(self.inner.read())
-        }
-        #[cfg(feature = "std")]
-        {
-            self.inner.read().map_err(|err| err.to_string())
-        }
-    }
-
-    /// Locks this rwlock with exclusive write access, blocking the current thread
-    /// until it can be acquired.
-    #[inline(always)]
-    pub fn write(&self) -> Result<RwLockWriteGuard<T>, alloc::string::String> {
-        #[cfg(not(feature = "std"))]
-        {
-            Ok(self.inner.write())
-        }
-
-        #[cfg(feature = "std")]
-        {
-            self.inner.write().map_err(|err| err.to_string())
-        }
-    }
-}
-
-/// A unique identifier for a running thread.
-///
-/// This module is a stub when no std is available to swap with std::thread::ThreadId.
-#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
-pub struct ThreadId(core::num::NonZeroU64);
-
-/// A cell that provides lazy one-time initialization that implements [Sync] and [Send].
-///
-/// This module is a stub when no std is available to swap with [std::sync::OnceLock].
-pub struct SyncOnceCell<T>(OnceImported<T>);
-
-impl<T: core::fmt::Debug> Default for SyncOnceCell<T> {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl<T: core::fmt::Debug> SyncOnceCell<T> {
-    /// Create a new once.
-    #[inline(always)]
-    pub fn new() -> Self {
-        Self(OnceImported::new())
-    }
-
-    /// Initialize the cell with a value.
-    #[inline(always)]
-    pub fn initialized(value: T) -> Self {
-        #[cfg(not(feature = "std"))]
-        {
-            let cell = OnceImported::initialized(value);
-            Self(cell)
-        }
-
-        #[cfg(feature = "std")]
-        {
-            let cell = OnceImported::new();
-            cell.set(value).unwrap();
-
-            Self(cell)
-        }
-    }
-
-    /// Gets the contents of the cell, initializing it with `f` if the cell
-    /// was empty.
-    #[inline(always)]
-    pub fn get_or_init<F>(&self, f: F) -> &T
-    where
-        F: FnOnce() -> T,
-    {
-        #[cfg(not(feature = "std"))]
-        {
-            self.0.call_once(f)
-        }
-
-        #[cfg(feature = "std")]
-        {
-            self.0.get_or_init(f)
-        }
-    }
-}
diff --git a/crates/burn-common/src/sync_type.rs b/crates/burn-common/src/sync_type.rs
deleted file mode 100644
index afe484819d..0000000000
--- a/crates/burn-common/src/sync_type.rs
+++ /dev/null
@@ -1,8 +0,0 @@
-/// What kind of synchronization to use.
-#[derive(Debug, Clone, Copy, PartialEq, Eq)]
-pub enum SyncType {
-    /// Submit all outstanding tasks to the task queue if any.
-    Flush,
-    /// Submit all tasks to the task queue and wait for all of them to complete.
-    Wait,
-}
diff --git a/crates/burn-compute/Cargo.toml b/crates/burn-compute/Cargo.toml
deleted file mode 100644
index 30a0a602b8..0000000000
--- a/crates/burn-compute/Cargo.toml
+++ /dev/null
@@ -1,51 +0,0 @@
-[package]
-authors = ["louisfd ", "Nathaniel Simard"]
-categories = ["science"]
-description = "Compute crate that helps creating high performance async backends."
-edition.workspace = true
-keywords = ["deep-learning", "machine-learning", "data"]
-license.workspace = true
-name = "burn-compute"
-readme.workspace = true
-repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-compute"
-version.workspace = true
-
-[features]
-default = [
-    "std",
-    "channel-mutex",
-    "channel-mpsc",
-    "channel-cell",
-    "storage-bytes",
-    "autotune-persistent-cache",
-]
-std = ["burn-common/std"]
-channel-mutex = []
-channel-cell = []
-channel-mpsc = ["dep:async-channel", "dep:pollster"] # Assume std
-storage-bytes = []
-autotune-persistent-cache = ["dirs", "md5", "serde", "serde_json"] # Assume std
-
-[dependencies]
-burn-common = { path = "../burn-common", version = "0.14.0", default-features = false }
-derive-new = { workspace = true }
-spin = { workspace = true }
-log = { workspace = true }
-hashbrown = { workspace = true }
-dirs = { workspace = true, optional = true }
-serde = { workspace = true, optional = true }
-serde_json = { workspace = true, features = ["std"], optional = true }
-md5 = { workspace = true, optional = true }
-pollster = { workspace = true, optional = true }
-async-channel = { workspace = true, optional = true }
-
-[target.'cfg(target_family = "wasm")'.dependencies]
-web-time = { workspace = true }
-
-[dev-dependencies]
-serial_test = { workspace = true }
-rand = { workspace = true }
-
-[[bench]]
-name = "dynamic"
-harness = false
diff --git a/crates/burn-compute/LICENSE-APACHE b/crates/burn-compute/LICENSE-APACHE
deleted file mode 120000
index 1cd601d0a3..0000000000
--- a/crates/burn-compute/LICENSE-APACHE
+++ /dev/null
@@ -1 +0,0 @@
-../../LICENSE-APACHE
\ No newline at end of file
diff --git a/crates/burn-compute/LICENSE-MIT b/crates/burn-compute/LICENSE-MIT
deleted file mode 120000
index b2cfbdc7b0..0000000000
--- a/crates/burn-compute/LICENSE-MIT
+++ /dev/null
@@ -1 +0,0 @@
-../../LICENSE-MIT
\ No newline at end of file
diff --git a/crates/burn-compute/README.md b/crates/burn-compute/README.md deleted
file mode 100644 index d184ff0ce4..0000000000 --- a/crates/burn-compute/README.md +++ /dev/null @@ -1,7 +0,0 @@ -# Burn Compute - -This crate helps creating high performance async backends. - -- [x] Asynchronous kernel executions -- [x] Memory allocation management -- [x] Autotuning diff --git a/crates/burn-compute/benches/dynamic.rs b/crates/burn-compute/benches/dynamic.rs deleted file mode 100644 index da0f37b677..0000000000 --- a/crates/burn-compute/benches/dynamic.rs +++ /dev/null @@ -1,29 +0,0 @@ -use std::collections::LinkedList; - -use burn_compute::{ - memory_management::{ - dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions}, - MemoryManagement, - }, - storage::BytesStorage, -}; - -const MB: usize = 1024 * 1024; - -fn main() { - let start = std::time::Instant::now(); - let storage = BytesStorage::default(); - let mut mm = DynamicMemoryManagement::new( - storage, - DynamicMemoryManagementOptions::preset(2048 * MB, 32), - ); - let mut handles = LinkedList::new(); - for _ in 0..100 * 2048 { - if handles.len() >= 4000 { - handles.pop_front(); - } - let handle = mm.reserve(MB, || {}); - handles.push_back(handle); - } - println!("{:?}", start.elapsed()); -} diff --git a/crates/burn-compute/src/channel/base.rs b/crates/burn-compute/src/channel/base.rs deleted file mode 100644 index 0fa266159c..0000000000 --- a/crates/burn-compute/src/channel/base.rs +++ /dev/null @@ -1,36 +0,0 @@ -use crate::{ - server::{Binding, ComputeServer, Handle}, - storage::ComputeStorage, -}; -use alloc::vec::Vec; -use burn_common::{reader::Reader, sync_type::SyncType}; - -/// The ComputeChannel trait links the ComputeClient to the ComputeServer -/// while ensuring thread-safety -pub trait ComputeChannel: Clone + core::fmt::Debug + Send + Sync { - /// Given a binding, returns owned resource as bytes - fn read(&self, binding: Binding) -> Reader; - - /// Given a resource handle, return the storage resource. - fn get_resource( - &self, - binding: Binding, - ) -> ::Resource; - - /// Given a resource as bytes, stores it and returns the resource handle - fn create(&self, data: &[u8]) -> Handle; - - /// Reserves `size` bytes in the storage, and returns a handle over them - fn empty(&self, size: usize) -> Handle; - - /// Executes the `kernel` over the given `bindings`. - fn execute( - &self, - kernel: Server::Kernel, - count: Server::DispatchOptions, - bindings: Vec>, - ); - - /// Perform some synchronization of commands on the server. - fn sync(&self, sync_type: SyncType); -} diff --git a/crates/burn-compute/src/channel/cell.rs b/crates/burn-compute/src/channel/cell.rs deleted file mode 100644 index d80a44b9fb..0000000000 --- a/crates/burn-compute/src/channel/cell.rs +++ /dev/null @@ -1,85 +0,0 @@ -use super::ComputeChannel; -use crate::server::{Binding, ComputeServer, Handle}; -use crate::storage::ComputeStorage; -use alloc::sync::Arc; -use alloc::vec::Vec; -use burn_common::reader::Reader; -use burn_common::sync_type::SyncType; - -/// A channel using a [ref cell](core::cell::RefCell) to access the server with mutability. -/// -/// # Important -/// -/// Only use this channel if you don't use any threading in your application, otherwise it will -/// panic or cause undefined behaviors. -/// -/// This is mosly useful for `no-std` environments where threads aren't supported, otherwise prefer -/// the [mutex](super::MutexComputeChannel) or the [mpsc](super::MpscComputeChannel) channels. 
-#[derive(Debug)] -pub struct RefCellComputeChannel { - server: Arc>, -} - -impl Clone for RefCellComputeChannel { - fn clone(&self) -> Self { - Self { - server: self.server.clone(), - } - } -} - -impl RefCellComputeChannel -where - Server: ComputeServer, -{ - /// Create a new cell compute channel. - pub fn new(server: Server) -> Self { - Self { - server: Arc::new(core::cell::RefCell::new(server)), - } - } -} - -impl ComputeChannel for RefCellComputeChannel -where - Server: ComputeServer + Send, -{ - fn read(&self, binding: Binding) -> Reader { - self.server.borrow_mut().read(binding) - } - - fn get_resource( - &self, - binding: Binding, - ) -> ::Resource { - self.server.borrow_mut().get_resource(binding) - } - - fn create(&self, resource: &[u8]) -> Handle { - self.server.borrow_mut().create(resource) - } - - fn empty(&self, size: usize) -> Handle { - self.server.borrow_mut().empty(size) - } - - fn execute( - &self, - kernel_description: Server::Kernel, - count: Server::DispatchOptions, - bindings: Vec>, - ) { - self.server - .borrow_mut() - .execute(kernel_description, count, bindings) - } - - fn sync(&self, sync_type: SyncType) { - self.server.borrow_mut().sync(sync_type) - } -} - -/// This is unsafe, since no concurrency is supported by the `RefCell` channel. -/// However using this channel should only be done in single threaded environments such as `no-std`. -unsafe impl Send for RefCellComputeChannel {} -unsafe impl Sync for RefCellComputeChannel {} diff --git a/crates/burn-compute/src/channel/mod.rs b/crates/burn-compute/src/channel/mod.rs deleted file mode 100644 index 68a1c372a0..0000000000 --- a/crates/burn-compute/src/channel/mod.rs +++ /dev/null @@ -1,17 +0,0 @@ -mod base; -pub use base::*; - -#[cfg(feature = "channel-mutex")] -mod mutex; -#[cfg(feature = "channel-mutex")] -pub use mutex::*; - -#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))] -mod mpsc; -#[cfg(all(feature = "channel-mpsc", not(target_family = "wasm")))] -pub use mpsc::*; - -#[cfg(feature = "channel-cell")] -mod cell; -#[cfg(feature = "channel-cell")] -pub use cell::*; diff --git a/crates/burn-compute/src/channel/mpsc.rs b/crates/burn-compute/src/channel/mpsc.rs deleted file mode 100644 index c7c22659fc..0000000000 --- a/crates/burn-compute/src/channel/mpsc.rs +++ /dev/null @@ -1,181 +0,0 @@ -use burn_common::{reader::Reader, sync_type::SyncType}; -use std::{sync::Arc, thread}; - -use super::ComputeChannel; -use crate::{ - server::{Binding, ComputeServer, Handle}, - storage::ComputeStorage, -}; - -/// Create a channel using a [multi-producer, single-consumer channel to communicate with -/// the compute server spawn on its own thread. -#[derive(Debug)] -pub struct MpscComputeChannel -where - Server: ComputeServer, -{ - state: Arc>, -} - -#[derive(Debug)] -struct MpscComputeChannelState -where - Server: ComputeServer, -{ - _handle: thread::JoinHandle<()>, - sender: async_channel::Sender>, -} - -type Callback = async_channel::Sender; - -enum Message -where - Server: ComputeServer, -{ - Read(Binding, Callback>), - GetResource( - Binding, - Callback<::Resource>, - ), - Create(Vec, Callback>), - Empty(usize, Callback>), - ExecuteKernel( - (Server::Kernel, Server::DispatchOptions), - Vec>, - ), - Sync(SyncType, Callback<()>), -} - -impl MpscComputeChannel -where - Server: ComputeServer + 'static, -{ - /// Create a new mpsc compute channel. 
- pub fn new(mut server: Server) -> Self { - let (sender, receiver) = async_channel::unbounded(); - - let _handle = thread::spawn(move || { - // Run the whole procedure as one blocking future. This is much simpler than trying - // to use some multithreaded executor. - pollster::block_on(async { - while let Ok(message) = receiver.recv().await { - match message { - Message::Read(binding, callback) => { - let data = server.read(binding).await; - callback.send(data).await.unwrap(); - } - Message::GetResource(binding, callback) => { - let data = server.get_resource(binding); - callback.send(data).await.unwrap(); - } - Message::Create(data, callback) => { - let handle = server.create(&data); - callback.send(handle).await.unwrap(); - } - Message::Empty(size, callback) => { - let handle = server.empty(size); - callback.send(handle).await.unwrap(); - } - Message::ExecuteKernel(kernel, bindings) => { - server.execute(kernel.0, kernel.1, bindings); - } - Message::Sync(sync_type, callback) => { - server.sync(sync_type); - callback.send(()).await.unwrap(); - } - }; - } - }); - }); - - let state = Arc::new(MpscComputeChannelState { sender, _handle }); - - Self { state } - } -} - -impl Clone for MpscComputeChannel { - fn clone(&self) -> Self { - Self { - state: self.state.clone(), - } - } -} - -impl ComputeChannel for MpscComputeChannel -where - Server: ComputeServer + 'static, -{ - fn read(&self, binding: Binding) -> Reader { - let sender = self.state.sender.clone(); - - Box::pin(async move { - let (callback, response) = async_channel::unbounded(); - sender.send(Message::Read(binding, callback)).await.unwrap(); - handle_response(response.recv().await) - }) - } - - fn get_resource( - &self, - binding: Binding, - ) -> ::Resource { - let (callback, response) = async_channel::unbounded(); - - self.state - .sender - .send_blocking(Message::GetResource(binding, callback)) - .unwrap(); - - handle_response(response.recv_blocking()) - } - - fn create(&self, data: &[u8]) -> Handle { - let (callback, response) = async_channel::unbounded(); - - self.state - .sender - .send_blocking(Message::Create(data.to_vec(), callback)) - .unwrap(); - - handle_response(response.recv_blocking()) - } - - fn empty(&self, size: usize) -> Handle { - let (callback, response) = async_channel::unbounded(); - self.state - .sender - .send_blocking(Message::Empty(size, callback)) - .unwrap(); - - handle_response(response.recv_blocking()) - } - - fn execute( - &self, - kernel: Server::Kernel, - count: Server::DispatchOptions, - bindings: Vec>, - ) { - self.state - .sender - .send_blocking(Message::ExecuteKernel((kernel, count), bindings)) - .unwrap() - } - - fn sync(&self, sync_type: SyncType) { - let (callback, response) = async_channel::unbounded(); - self.state - .sender - .send_blocking(Message::Sync(sync_type, callback)) - .unwrap(); - handle_response(response.recv_blocking()) - } -} - -fn handle_response(response: Result) -> Response { - match response { - Ok(val) => val, - Err(err) => panic!("Can't connect to the server correctly {err:?}"), - } -} diff --git a/crates/burn-compute/src/channel/mutex.rs b/crates/burn-compute/src/channel/mutex.rs deleted file mode 100644 index 1eeb1bf371..0000000000 --- a/crates/burn-compute/src/channel/mutex.rs +++ /dev/null @@ -1,71 +0,0 @@ -use super::ComputeChannel; -use crate::server::{Binding, ComputeServer, Handle}; -use crate::storage::ComputeStorage; -use alloc::sync::Arc; -use alloc::vec::Vec; -use burn_common::reader::Reader; -use burn_common::sync_type::SyncType; -use spin::Mutex; - -/// The 
MutexComputeChannel ensures thread-safety by locking the server -/// on every operation -#[derive(Debug)] -pub struct MutexComputeChannel { - server: Arc>, -} - -impl Clone for MutexComputeChannel { - fn clone(&self) -> Self { - Self { - server: self.server.clone(), - } - } -} -impl MutexComputeChannel -where - Server: ComputeServer, -{ - /// Create a new mutex compute channel. - pub fn new(server: Server) -> Self { - Self { - server: Arc::new(Mutex::new(server)), - } - } -} - -impl ComputeChannel for MutexComputeChannel -where - Server: ComputeServer, -{ - fn read(&self, handle: Binding) -> Reader { - self.server.lock().read(handle) - } - - fn get_resource( - &self, - binding: Binding, - ) -> ::Resource { - self.server.lock().get_resource(binding) - } - - fn create(&self, data: &[u8]) -> Handle { - self.server.lock().create(data) - } - - fn empty(&self, size: usize) -> Handle { - self.server.lock().empty(size) - } - - fn execute( - &self, - kernel: Server::Kernel, - count: Server::DispatchOptions, - handles: Vec>, - ) { - self.server.lock().execute(kernel, count, handles) - } - - fn sync(&self, sync_type: SyncType) { - self.server.lock().sync(sync_type) - } -} diff --git a/crates/burn-compute/src/client.rs b/crates/burn-compute/src/client.rs deleted file mode 100644 index 2c0bb14ad5..0000000000 --- a/crates/burn-compute/src/client.rs +++ /dev/null @@ -1,119 +0,0 @@ -use crate::{ - channel::ComputeChannel, - server::{Binding, ComputeServer, Handle}, - storage::ComputeStorage, - tune::{AutotuneOperationSet, Tuner}, -}; -use alloc::vec::Vec; -use alloc::{boxed::Box, sync::Arc}; -use burn_common::stub::RwLock; -use burn_common::sync_type::SyncType; - -/// The ComputeClient is the entry point to require tasks from the ComputeServer. -/// It should be obtained for a specific device via the Compute struct. -#[derive(Debug)] -pub struct ComputeClient { - channel: Channel, - tuner: Arc>>, - features: Arc, -} - -impl Clone for ComputeClient -where - S: ComputeServer, - C: ComputeChannel, -{ - fn clone(&self) -> Self { - Self { - channel: self.channel.clone(), - tuner: self.tuner.clone(), - features: self.features.clone(), - } - } -} - -impl ComputeClient -where - Server: ComputeServer, - Channel: ComputeChannel, -{ - /// Create a new client. - pub fn new( - channel: Channel, - tuner: Arc>>, - features: Arc, - ) -> Self { - Self { - channel, - tuner, - features, - } - } - - /// Given a binding, returns owned resource as bytes. - pub async fn read_async(&self, binding: Binding) -> Vec { - self.channel.read(binding).await - } - - /// Given a binding, returns owned resource as bytes. - /// - /// # Remarks - /// Panics if the read operation fails. - pub fn read(&self, binding: Binding) -> Vec { - burn_common::reader::read_sync(self.channel.read(binding)) - } - - /// Given a resource handle, returns the storage resource. - pub fn get_resource( - &self, - binding: Binding, - ) -> ::Resource { - self.channel.get_resource(binding) - } - - /// Given a resource, stores it and returns the resource handle. - pub fn create(&self, data: &[u8]) -> Handle { - self.channel.create(data) - } - - /// Reserves `size` bytes in the storage, and returns a handle over them. - pub fn empty(&self, size: usize) -> Handle { - self.channel.empty(size) - } - - /// Executes the `kernel` over the given `bindings`. 
- pub fn execute( - &self, - kernel: Server::Kernel, - count: Server::DispatchOptions, - bindings: Vec>, - ) { - self.channel.execute(kernel, count, bindings) - } - - /// Wait for the completion of every task in the server. - pub fn sync(&self, sync_type: SyncType) { - self.channel.sync(sync_type) - } - - /// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks - pub fn autotune_execute( - &self, - autotune_operation_set: Box>, - ) { - self.tuner - .write() - .unwrap() - .execute_autotune(autotune_operation_set, self); - } - - /// Get the fastest kernel for the given autotune key if it exists. - pub fn autotune_result(&self, key: &Server::AutotuneKey) -> Option { - self.tuner.read().unwrap().autotune_fastest(key) - } - - /// Get the features supported by the compute server. - pub fn features(&self) -> &Server::FeatureSet { - self.features.as_ref() - } -} diff --git a/crates/burn-compute/src/compute.rs b/crates/burn-compute/src/compute.rs deleted file mode 100644 index 9a35f53841..0000000000 --- a/crates/burn-compute/src/compute.rs +++ /dev/null @@ -1,94 +0,0 @@ -use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer}; -use core::ops::DerefMut; -use hashbrown::HashMap; - -/// The compute type has the responsibility to retrieve the correct compute client based on the -/// given device. -pub struct ComputeRuntime { - clients: spin::Mutex>>>, -} - -impl Default for ComputeRuntime -where - Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, - Server: ComputeServer, - Channel: ComputeChannel, -{ - fn default() -> Self { - Self::new() - } -} - -impl ComputeRuntime -where - Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug, - Server: ComputeServer, - Channel: ComputeChannel, -{ - /// Create a new compute. - pub const fn new() -> Self { - Self { - clients: spin::Mutex::new(None), - } - } - - /// Get the compute client for the given device. - /// - /// Provide the init function to create a new client if it isn't already initialized. - pub fn client(&self, device: &Device, init: Init) -> ComputeClient - where - Init: Fn() -> ComputeClient, - { - let mut clients = self.clients.lock(); - - if clients.is_none() { - Self::register_inner(device, init(), &mut clients); - } - - match clients.deref_mut() { - Some(clients) => match clients.get(device) { - Some(client) => client.clone(), - None => { - let client = init(); - clients.insert(device.clone(), client.clone()); - client - } - }, - _ => unreachable!(), - } - } - - /// Register the compute client for the given device. - /// - /// # Note - /// - /// This function is mostly useful when the creation of the compute client can't be done - /// synchronously and require special context. - /// - /// # Panics - /// - /// If a client is already registered for the given device. 
- pub fn register(&self, device: &Device, client: ComputeClient) { - let mut clients = self.clients.lock(); - - Self::register_inner(device, client, &mut clients); - } - - fn register_inner( - device: &Device, - client: ComputeClient, - clients: &mut Option>>, - ) { - if clients.is_none() { - *clients = Some(HashMap::new()); - } - - if let Some(clients) = clients { - if clients.contains_key(device) { - panic!("Client already created for device {:?}", device); - } - - clients.insert(device.clone(), client); - } - } -} diff --git a/crates/burn-compute/src/id.rs b/crates/burn-compute/src/id.rs deleted file mode 100644 index dacc3cd68a..0000000000 --- a/crates/burn-compute/src/id.rs +++ /dev/null @@ -1,175 +0,0 @@ -use alloc::sync::Arc; - -#[macro_export(local_inner_macros)] -/// Create a new storage ID type. -macro_rules! storage_id_type { - ($name:ident) => { - /// Storage ID. - #[derive(Clone, Hash, PartialEq, Eq, Debug)] - pub struct $name { - value: usize, - } - - impl $name { - /// Create a new ID. - pub fn new() -> Self { - use core::sync::atomic::{AtomicUsize, Ordering}; - - static COUNTER: AtomicUsize = AtomicUsize::new(0); - - let value = COUNTER.fetch_add(1, Ordering::Relaxed); - if value == usize::MAX { - core::panic!("Memory ID overflowed"); - } - Self { value } - } - } - - impl Default for $name { - fn default() -> Self { - Self::new() - } - } - }; -} - -/// Reference to a buffer handle. -#[derive(Clone, Debug)] -pub struct HandleRef { - id: Arc, - all: Arc<()>, -} - -/// Reference to buffer binding. -#[derive(Clone, Debug)] -pub struct BindingRef { - id: Id, - _all: Arc<()>, -} - -impl BindingRef -where - Id: Clone + core::fmt::Debug, -{ - /// The id associated to the buffer. - pub(crate) fn id(&self) -> &Id { - &self.id - } -} - -impl HandleRef -where - Id: Clone + core::fmt::Debug, -{ - /// Create a new handle. - pub(crate) fn new(id: Id) -> Self { - Self { - id: Arc::new(id), - all: Arc::new(()), - } - } - - /// The id associated to the handle. - pub(crate) fn id(&self) -> &Id { - &self.id - } - - /// Get the binding. - pub(crate) fn binding(self) -> BindingRef { - BindingRef { - id: self.id.as_ref().clone(), - _all: self.all, - } - } - - /// If the handle can be mut. - pub(crate) fn can_mut(&self) -> bool { - // 1 memory management reference with 1 tensor reference. - Arc::strong_count(&self.id) <= 2 - } - - /// If the resource is free. - pub(crate) fn is_free(&self) -> bool { - Arc::strong_count(&self.all) <= 1 - } -} - -#[macro_export(local_inner_macros)] -/// Create new memory ID types. -macro_rules! memory_id_type { - ($id:ident, $handle:ident) => { - /// Memory Handle. - #[derive(Clone, Debug)] - pub struct $handle { - value: $crate::id::HandleRef<$id>, - } - - /// Memory ID. - #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)] - pub struct $id { - pub(crate) value: usize, - } - - impl $handle { - /// Create a new ID. 
- pub(crate) fn new() -> Self { - let value = Self::gen_id(); - Self { - value: $crate::id::HandleRef::new($id { value }), - } - } - - fn gen_id() -> usize { - static COUNTER: core::sync::atomic::AtomicUsize = - core::sync::atomic::AtomicUsize::new(0); - - let value = COUNTER.fetch_add(1, core::sync::atomic::Ordering::Relaxed); - if value == usize::MAX { - core::panic!("Memory ID overflowed"); - } - - value - } - } - - impl core::ops::Deref for $handle { - type Target = $crate::id::HandleRef<$id>; - - fn deref(&self) -> &Self::Target { - &self.value - } - } - - impl Default for $handle { - fn default() -> Self { - Self::new() - } - } - }; - - ($id:ident, $handle:ident, $binding:ident) => { - memory_id_type!($id, $handle); - - /// Binding of a memory handle. - #[derive(Clone, Debug)] - pub struct $binding { - value: $crate::id::BindingRef<$id>, - } - - impl $handle { - pub(crate) fn binding(self) -> $binding { - $binding { - value: self.value.binding(), - } - } - } - - impl core::ops::Deref for $binding { - type Target = $crate::id::BindingRef<$id>; - - fn deref(&self) -> &Self::Target { - &self.value - } - } - }; -} diff --git a/crates/burn-compute/src/lib.rs b/crates/burn-compute/src/lib.rs deleted file mode 100644 index 255328b1ee..0000000000 --- a/crates/burn-compute/src/lib.rs +++ /dev/null @@ -1,29 +0,0 @@ -#![cfg_attr(not(feature = "std"), no_std)] -#![warn(missing_docs)] - -//! Burn compute crate that helps creating high performance async backends. - -extern crate alloc; - -#[macro_use] -extern crate derive_new; - -mod id; - -/// Compute channel module. -pub mod channel; -/// Compute client module. -pub mod client; - -/// Autotune module -pub mod tune; - -/// Memory management module. -pub mod memory_management; -/// Compute server module. -pub mod server; -/// Compute Storage module. -pub mod storage; - -mod compute; -pub use compute::*; diff --git a/crates/burn-compute/src/memory_management/base.rs b/crates/burn-compute/src/memory_management/base.rs deleted file mode 100644 index 76858dda3a..0000000000 --- a/crates/burn-compute/src/memory_management/base.rs +++ /dev/null @@ -1,57 +0,0 @@ -use crate::storage::ComputeStorage; - -/// The managed tensor buffer handle that points to some memory segment. -/// It should not contain actual data. -pub trait MemoryHandle: Clone + Send + Sync + core::fmt::Debug { - /// Checks if the underlying memory can be safely mutated. - fn can_mut(&self) -> bool; - /// Get the binding associated to the current handle. - fn binding(self) -> Binding; -} - -/// Binding to a [memory handle](MemoryHandle). -pub trait MemoryBinding: Clone + Send + Sync + core::fmt::Debug {} - -/// The MemoryManagement trait encapsulates strategies for (de)allocating memory. -/// It is bound to the ComputeStorage trait, which does the actual (de)allocations. -/// -/// The MemoryManagement can only reserve memory space or get the resource located at a space. -/// Modification of the resource data should be done directly on the resource. -pub trait MemoryManagement: Send + core::fmt::Debug { - /// The associated type that must implement [MemoryHandle]. 
- type Handle: MemoryHandle; - /// The associated type that must implement [MemoryBinding] - type Binding: MemoryBinding; - - /// Returns the resource from the storage at the specified handle - fn get(&mut self, binding: Self::Binding) -> Storage::Resource; - - /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it - fn reserve(&mut self, size: usize, sync: Sync) -> Self::Handle; - - /// Bypass the memory allocation algorithm to allocate data directly. - /// - /// # Notes - /// - /// Can be useful for servers that want specific control over memory. - fn alloc(&mut self, size: usize, sync: Sync) -> Self::Handle; - - /// Bypass the memory allocation algorithm to deallocate data directly. - /// - /// # Notes - /// - /// Can be useful for servers that want specific control over memory. - fn dealloc(&mut self, binding: Self::Binding); - - /// Fetch the storage used by the memory manager. - /// - /// # Notes - /// - /// The storage should probably not be used for allocations since the handles won't be - /// compatible with the ones provided by the current trait. Prefer using the - /// [alloc](MemoryManagement::alloc) and [dealloc](MemoryManagement::dealloc) functions. - /// - /// This is useful if you need to time the deallocations based on async computation, or to - /// change the mode of storage for different reasons. - fn storage(&mut self) -> &mut Storage; -} diff --git a/crates/burn-compute/src/memory_management/dynamic.rs b/crates/burn-compute/src/memory_management/dynamic.rs deleted file mode 100644 index 4df0aa2e96..0000000000 --- a/crates/burn-compute/src/memory_management/dynamic.rs +++ /dev/null @@ -1,181 +0,0 @@ -use super::memory_pool::{ - MemoryExtensionStrategy, MemoryPool, MemoryPoolBinding, MemoryPoolHandle, RoundingStrategy, - SmallMemoryPool, -}; -use crate::storage::ComputeStorage; -use alloc::vec::Vec; - -use super::MemoryManagement; - -/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks. -pub struct DynamicMemoryManagement { - min_chunk_alignment_offset: usize, - small_memory_pool: SmallMemoryPool, - pools: Vec, - options: Vec, - storage: Storage, -} - -/// Options to initialize a [dynamic memory management](DynamicMemoryManagement). -#[derive(new, Debug)] -pub struct DynamicMemoryManagementOptions { - pools: Vec, - min_chunk_alignment_offset: usize, -} - -/// Options to create a memory pool. -#[derive(Debug)] -pub struct MemoryPoolOptions { - /// The amount of bytes used for each chunk in the memory pool. - pub chunk_size: usize, - /// The number of chunks allocated directly at creation. - /// - /// Useful when you know in advance how much memory you'll need. - pub chunk_num_prealloc: usize, - /// The max size in bytes a slice can take in the pool. - pub slice_max_size: usize, -} - -impl DynamicMemoryManagementOptions { - /// Creates the options from device limits. - pub fn preset(max_chunk_size: usize, min_chunk_alignment_offset: usize) -> Self { - // Rounding down to a factor of 8. - let max_chunk_size = (max_chunk_size / 8) * 8; - - const MB: usize = 1024 * 1024; - - let mut pools = Vec::new(); - - pools.push(MemoryPoolOptions { - chunk_size: max_chunk_size, - chunk_num_prealloc: 0, - slice_max_size: max_chunk_size, - }); - - let mut current = max_chunk_size; - - while current >= 32 * MB { - current /= 4; - - pools.push(MemoryPoolOptions { - chunk_size: current, - chunk_num_prealloc: 0, - // Creating max slices lower than the chunk size reduces fragmentation. 
- slice_max_size: current / 2usize.pow(pools.len() as u32), - }); - } - - Self { - pools, - min_chunk_alignment_offset, - } - } -} - -impl DynamicMemoryManagement { - /// Creates a new instance using the given storage, merging_strategy strategy and slice strategy. - pub fn new(mut storage: Storage, mut options: DynamicMemoryManagementOptions) -> Self { - options - .pools - .sort_by(|pool1, pool2| usize::cmp(&pool1.slice_max_size, &pool2.slice_max_size)); - - let min_chunk_alignment_offset = options.min_chunk_alignment_offset; - - let pools = options - .pools - .iter() - .map(|option| { - let mut pool = MemoryPool::new( - MemoryExtensionStrategy::Never, - RoundingStrategy::FixedAmount(option.chunk_size), - min_chunk_alignment_offset, - ); - - for _ in 0..option.chunk_num_prealloc { - pool.alloc(&mut storage, option.chunk_size, || {}); - } - - pool - }) - .collect(); - - Self { - min_chunk_alignment_offset, - small_memory_pool: SmallMemoryPool::new(min_chunk_alignment_offset), - pools, - options: options.pools, - storage, - } - } -} - -impl core::fmt::Debug for DynamicMemoryManagement { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str( - alloc::format!( - "DynamicMemoryManagement {:?}", - core::any::type_name::(), - ) - .as_str(), - ) - } -} - -impl MemoryManagement for DynamicMemoryManagement { - type Handle = MemoryPoolHandle; - type Binding = MemoryPoolBinding; - - fn get(&mut self, binding: Self::Binding) -> Storage::Resource { - if let Some(handle) = self.small_memory_pool.get(&mut self.storage, &binding) { - return handle; - } - - for pool in &mut self.pools { - if let Some(handle) = pool.get(&mut self.storage, &binding) { - return handle; - } - } - - panic!("No handle found in memory pools"); - } - - fn reserve(&mut self, size: usize, sync: Sync) -> Self::Handle { - if size <= self.min_chunk_alignment_offset { - return self - .small_memory_pool - .reserve(&mut self.storage, size, sync); - } - - for (index, option) in self.options.iter().enumerate() { - if size <= option.slice_max_size { - let pool = &mut self.pools[index]; - return pool.reserve(&mut self.storage, size, sync); - } - } - - panic!("No memory pool big enough to reserve {size} bytes."); - } - - fn alloc(&mut self, size: usize, sync: Sync) -> Self::Handle { - if size <= self.min_chunk_alignment_offset { - return self.small_memory_pool.alloc(&mut self.storage, size, sync); - } - - for (index, option) in self.options.iter().enumerate() { - if size <= option.slice_max_size { - let pool = &mut self.pools[index]; - return pool.alloc(&mut self.storage, size, sync); - } - } - - panic!("No memory pool big enough to alloc {size} bytes."); - } - - fn dealloc(&mut self, _binding: Self::Binding) { - // Can't dealloc slices. 
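// The pool selection in `reserve`/`alloc` above is a first-fit scan over pools
// sorted by `slice_max_size`. Reduced to a free-standing sketch (the caps used
// in the comment below are hypothetical, not values from this diff):
fn pick_pool(slice_max_sizes: &[usize], size: usize) -> Option<usize> {
    // Pools are sorted ascending, so the first pool that fits is also the
    // tightest fit for the request.
    slice_max_sizes.iter().position(|&max| size <= max)
}

// With caps of [1 MiB, 4 MiB, 16 MiB], a 2 MiB request lands in pool 1:
// assert_eq!(pick_pool(&[1 << 20, 4 << 20, 16 << 20], 2 << 20), Some(1));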
- } - - fn storage(&mut self) -> &mut Storage { - &mut self.storage - } -} diff --git a/crates/burn-compute/src/memory_management/memory_pool/base.rs b/crates/burn-compute/src/memory_management/memory_pool/base.rs deleted file mode 100644 index 5c1f784890..0000000000 --- a/crates/burn-compute/src/memory_management/memory_pool/base.rs +++ /dev/null @@ -1,586 +0,0 @@ -use super::index::SearchIndex; -use super::{ - ChunkHandle, ChunkId, MemoryChunk, MemoryPoolBinding, MemoryPoolHandle, MemorySlice, - RingBuffer, SliceHandle, SliceId, -}; -use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization}; -use alloc::vec::Vec; -use hashbrown::{HashMap, HashSet}; - -pub struct MemoryPool { - chunks: HashMap, - slices: HashMap, - #[allow(unused)] // will be used when we rewrite memory extension - memory_extension_strategy: MemoryExtensionStrategy, - rounding: RoundingStrategy, - chunk_index: SearchIndex, - ring: RingBuffer, - recently_added_chunks: Vec, - recently_allocated_size: usize, - buffer_alignment: usize, -} - -struct SliceUpdate { - slice_id: SliceId, - size: usize, -} - -#[derive(new, Debug)] -pub struct Chunk { - pub storage: StorageHandle, - pub handle: ChunkHandle, - pub slices: MemoryPage, -} - -// TODO: consider using generic trait and decouple from Slice -#[derive(new, Debug)] -pub struct MemoryPage { - pub slices: HashMap, -} - -impl MemoryPage { - /// merge slice at first_slice_address with the next slice (if there is one and if it's free) - /// return a boolean representing if a merge happened - fn merge_with_next_slice( - &mut self, - first_slice_address: usize, - slices: &mut HashMap, - ) -> bool { - let first_slice_id = self.find_slice(first_slice_address).expect( - "merge_with_next_slice shouldn't be called with a nonexistent first_slice address", - ); - - let next_slice_address = - first_slice_address + slices.get(&first_slice_id).unwrap().effective_size(); - - if let Some(next_slice_id) = self.find_slice(next_slice_address) { - let (next_slice_eff_size, next_slice_is_free) = { - let next_slice = slices.get(&next_slice_id).unwrap(); - (next_slice.effective_size(), next_slice.is_free()) - }; - if next_slice_is_free { - let first_slice = slices.get_mut(&first_slice_id).unwrap(); - let first_slice_eff_size = first_slice.effective_size(); - let first_slice_offset = first_slice.storage.offset(); - - let merged_size = first_slice_eff_size + next_slice_eff_size; - first_slice.storage.utilization = StorageUtilization::Slice { - size: merged_size, - offset: first_slice_offset, - }; - first_slice.padding = 0; - - // Cleanup of the extra slice - self.slices.remove(&next_slice_address); - slices.remove(&next_slice_id); - return true; - } - return false; - } - false - } - - fn find_slice(&self, address: usize) -> Option { - let slice_id = self.slices.get(&address); - slice_id.copied() - } - - fn insert_slice(&mut self, address: usize, slice: SliceId) { - self.slices.insert(address, slice); - } - - fn slices_sorted_by_address(&self) -> Vec { - let mut entries: Vec<(usize, SliceId)> = self.slices.clone().into_iter().collect(); - entries.sort_by_key(|&(key, _)| key); - let sorted_slices: Vec = entries.into_iter().map(|(_, values)| values).collect(); - sorted_slices - } -} - -#[derive(new, Debug)] -pub struct Slice { - pub storage: StorageHandle, - pub handle: SliceHandle, - pub chunk: ChunkHandle, - pub padding: usize, -} - -impl Slice { - pub fn effective_size(&self) -> usize { - self.storage.size() + self.padding - } -} - -const MIN_SIZE_NEEDED_TO_OFFSET: usize = 16; - -pub enum 
RoundingStrategy { - FixedAmount(usize), - #[allow(unused)] - None, -} - -impl RoundingStrategy { - fn alloc_size(&self, size: usize) -> usize { - match self { - RoundingStrategy::FixedAmount(chunk_size) => { - assert!(*chunk_size >= size); - *chunk_size - } - RoundingStrategy::None => size, - } - } -} - -/// The strategy defines the frequency at which merging of free slices (defragmentation) occurs -#[allow(unused)] -#[derive(Debug)] -pub enum MemoryExtensionStrategy { - /// Once every n calls to reserve. - PeriodTick { - /// Number of calls to be executed before triggering the defragmentation. - period: usize, - /// Current state. Should start at zero. - state: usize, - }, - /// Never defragment. - Never, -} - -#[allow(unused)] -impl MemoryExtensionStrategy { - /// Create a new strategy with the given period. - pub fn new_period_tick(period: usize) -> Self { - MemoryExtensionStrategy::PeriodTick { period, state: 0 } - } - - #[allow(unused)] - fn should_extend_max_memory(&mut self) -> bool { - match self { - MemoryExtensionStrategy::PeriodTick { period, state } => { - *state = (*state + 1) % *period; - *state == 0 - } - MemoryExtensionStrategy::Never => false, - } - } -} - -impl MemoryPool { - pub fn new( - merging_strategy: MemoryExtensionStrategy, - alloc_strategy: RoundingStrategy, - buffer_alignment: usize, - ) -> Self { - Self { - chunks: HashMap::new(), - slices: HashMap::new(), - memory_extension_strategy: merging_strategy, - rounding: alloc_strategy, - chunk_index: SearchIndex::new(), - ring: RingBuffer::new(buffer_alignment), - recently_added_chunks: Vec::new(), - recently_allocated_size: 0, - buffer_alignment, - } - } - - /// Returns the resource from the storage, for the specified handle. - pub fn get( - &mut self, - storage: &mut Storage, - binding: &MemoryPoolBinding, - ) -> Option { - self.slices - .get(binding.slice.id()) - .map(|s| &s.storage) - .map(|h| storage.get(h)) - } - - /// Reserves memory of specified size using the reserve algorithm, and return - /// a handle to the reserved memory. 
- /// - /// Also cleans up, merging free slices together if permitted by the merging strategy. - pub fn reserve( - &mut self, - storage: &mut Storage, - size: usize, - sync: Sync, - ) -> MemoryPoolHandle { - let slice = self.get_free_slice(size); - - match slice { - Some(slice) => MemoryPoolHandle { - slice: slice.clone(), - }, - None => self.alloc(storage, size, sync), - } - } - - pub fn alloc( - &mut self, - storage: &mut Storage, - size: usize, - #[allow(unused)] sync: Sync, - ) -> MemoryPoolHandle { - let alloc_size = self.rounding.alloc_size(size); - self.alloc_slice(storage, alloc_size, size) - } - - fn alloc_slice( - &mut self, - storage: &mut Storage, - alloc_size: usize, - slice_size: usize, - ) -> MemoryPoolHandle { - let chunk_size = self.rounding.alloc_size(alloc_size); - let handle_chunk = self.create_chunk(storage, chunk_size); - let chunk_size = self.chunks.get(handle_chunk.id()).unwrap().storage.size(); - self.recently_added_chunks.push(*handle_chunk.id()); - self.recently_allocated_size += chunk_size; - - let chunk_id = *handle_chunk.id(); - let (slice, extra_slice) = - self.allocate_slices(handle_chunk.clone(), chunk_size, slice_size); - - let handle_slice = slice.handle.clone(); - self.update_chunk_metadata(chunk_id, slice, extra_slice); - - MemoryPoolHandle { - slice: handle_slice, - } - } - - fn allocate_slices( - &self, - handle_chunk: ChunkHandle, - alloc_size: usize, - slice_size: usize, - ) -> (Slice, Option) { - let slice = self.create_slice(0, slice_size, handle_chunk.clone()); - - let effective_size = slice.effective_size(); - - let extra_slice = if effective_size < alloc_size { - Some(self.create_slice(effective_size, alloc_size - effective_size, handle_chunk)) - } else { - None - }; - - (slice, extra_slice) - } - - fn update_chunk_metadata( - &mut self, - chunk_id: ChunkId, - slice: Slice, - extra_slice: Option, - ) { - let slice_id = *slice.handle.id(); - let slice_offset = slice.storage.offset(); - - self.slices.insert(slice_id, slice); - self.chunks - .get_mut(&chunk_id) - .unwrap() - .slices - .slices - .insert(slice_offset, slice_id); - - if let Some(extra_slice) = extra_slice { - let extra_slice_id = *extra_slice.handle.id(); - let extra_slice_offset = extra_slice.storage.offset(); - self.slices.insert(extra_slice_id, extra_slice); - self.chunks - .get_mut(&chunk_id) - .unwrap() - .slices - .slices - .insert(extra_slice_offset, extra_slice_id); - } - } - - #[allow(unused)] - fn display_memory_usage(&self) { - let total_memory_usage: f64 = self - .chunks - .values() - .map(|chunk| chunk.storage.size() as f64) - .sum(); - let effective_memory_usage: f64 = self - .slices - .values() - .filter(|slice| slice.handle.is_free()) - .map(|slice| slice.storage.size() as f64) - .sum(); - let ratio = 100.0 * effective_memory_usage / total_memory_usage; - log::info!("the memory usage is {ratio}"); - } - - /// Finds a free slice that can contain the given size. - /// Returns a handle to the slice, if one exists.
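// The padding bookkeeping used by `get_free_slice` rounds a slice up to the
// buffer alignment; a one-function sketch of that arithmetic (same formula as
// `calculate_padding` further down in this file):
fn padding_for(size: usize, align: usize) -> usize {
    match size % align {
        0 => 0,
        rem => align - rem,
    }
}

// A 100-byte slice under 32-byte alignment carries 28 bytes of padding, for an
// effective size of 128: assert_eq!(padding_for(100, 32), 28);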
- fn get_free_slice(&mut self, size: usize) -> Option { - if size < MIN_SIZE_NEEDED_TO_OFFSET { - return None; - } - - let padding = calculate_padding(size, self.buffer_alignment); - let effective_size = size + padding; - - let slice_id = - self.ring - .find_free_slice(effective_size, &mut self.chunks, &mut self.slices)?; - - let slice = self.slices.get_mut(&slice_id).unwrap(); - let old_slice_size = slice.effective_size(); - - let offset = match slice.storage.utilization { - StorageUtilization::Full(_) => 0, - StorageUtilization::Slice { offset, size: _ } => offset, - }; - slice.storage.utilization = StorageUtilization::Slice { offset, size }; - let new_padding = old_slice_size - size; - slice.padding = new_padding; - assert_eq!( - slice.effective_size(), - old_slice_size, - "new and old slice should have the same size" - ); - - Some(slice.handle.clone()) - } - - /// Creates a slice of size `size` upon the given chunk with the given offset. - fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> Slice { - assert_eq!( - offset % self.buffer_alignment, - 0, - "slice with offset {offset} needs to be a multiple of {}", - self.buffer_alignment - ); - if offset > 0 && size < MIN_SIZE_NEEDED_TO_OFFSET { - panic!("tried to create slice of size {size} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support"); - } - let chunk = self.chunks.get(handle_chunk.id()).unwrap(); - let handle = SliceHandle::new(); - - let storage = StorageHandle { - id: chunk.storage.id.clone(), - utilization: StorageUtilization::Slice { offset, size }, - }; - - let padding = calculate_padding(size, self.buffer_alignment); - - Slice::new(storage, handle, chunk.handle.clone(), padding) - } - - /// Creates a chunk of given size by allocating on the storage. 
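// The in-place shrink performed by `get_free_slice` above preserves the
// slice's effective size: bytes freed by the smaller request become padding,
// so neighbouring offsets stay valid. The invariant in isolation (plain
// numbers, no pool types):
fn shrink(effective_size: usize, requested: usize) -> (usize, usize) {
    assert!(requested <= effective_size);
    let padding = effective_size - requested;
    // The returned (size, padding) pair always sums back to `effective_size`,
    // which is exactly what the `assert_eq!` in `get_free_slice` checks.
    (requested, padding)
}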
- fn create_chunk( - &mut self, - storage: &mut Storage, - size: usize, - ) -> ChunkHandle { - let padding = calculate_padding(size, self.buffer_alignment); - let effective_size = size + padding; - - let storage = storage.alloc(effective_size); - let handle = ChunkHandle::new(); - let id = *handle.id(); - - self.ring.push_chunk(id); - - self.chunks.insert( - id, - Chunk::new(storage, handle.clone(), MemoryPage::new(HashMap::new())), - ); - self.chunk_index.insert(id, size); - - handle - } - - #[allow(unused)] - fn extend_max_memory(&mut self, storage: &mut Storage) { - let mut slices = Vec::::new(); - - let mut deallocations = HashSet::::new(); - - let mut chunks_total_size: usize = 0; - - for chunk_id in &self.recently_added_chunks { - let chunk = self.chunks.get(chunk_id).unwrap(); - let chunk_id = *chunk.handle.id(); - let sorted_slice = chunk.slices.slices_sorted_by_address(); - for slice_id in sorted_slice { - let slice = self.slices.get(&slice_id).unwrap(); - let size = slice.storage.size(); - - slices.push(SliceUpdate { slice_id, size }); - } - chunks_total_size += chunk.storage.size(); - deallocations.insert(chunk_id); - } - - if !slices.is_empty() { - self.move_to_new_chunk(chunks_total_size, storage, &mut slices, &mut deallocations); - } else { - self.deallocate(storage, &mut deallocations); - } - } - - fn deallocate( - &mut self, - storage: &mut Storage, - deallocations: &mut HashSet, - ) { - for id in deallocations.drain() { - let mut chunk = self.chunks.remove(&id).unwrap(); - self.ring.remove_chunk(id); - - for (_address, slice_id) in chunk.slices.slices.drain() { - let slice = self.slices.get(&slice_id).unwrap(); - let chunk_id = *slice.chunk.id(); - - assert_ne!(chunk_id, id, "Chunk id should be updated"); - } - - self.chunk_index.remove(&id); - storage.dealloc(chunk.storage.id); - } - } - - fn move_to_new_chunk( - &mut self, - alloc_size: usize, - storage: &mut Storage, - slices: &mut Vec, - deallocations: &mut HashSet, - ) { - let chunk = self.create_chunk(storage, alloc_size); - let storage_id = self.chunks.get(chunk.id()).unwrap().storage.id.clone(); - let mut offset = 0; - let mut slices_ids: Vec<(usize, SliceId)> = Vec::new(); - - for update in slices.drain(..) { - let slice_id = update.slice_id; - - let slice = self.slices.get_mut(&slice_id).unwrap(); - let old_storage = slice.storage.clone(); - - slice.chunk = chunk.clone(); - slice.storage = StorageHandle { - id: storage_id.clone(), - utilization: StorageUtilization::Slice { - offset, - size: update.size, - }, - }; - storage.copy(&old_storage, &slice.storage); - slices_ids.push((offset, slice_id)); - offset += slice.effective_size(); - } - - let chunk = self.chunks.get_mut(chunk.id()).unwrap(); - let chunk_handle = chunk.handle.clone(); - for (address, slice_id) in slices_ids.drain(..) 
{ - chunk.slices.insert_slice(address, slice_id); - } - let chunk_size = chunk.storage.size(); - let last_slice_size = chunk_size - offset; - assert_eq!(last_slice_size % self.buffer_alignment, 0); - if last_slice_size != 0 { - self.create_slice(offset, last_slice_size, chunk_handle); - } - - self.deallocate(storage, deallocations); - } -} - -fn calculate_padding(size: usize, buffer_alignment: usize) -> usize { - let remainder = size % buffer_alignment; - if remainder != 0 { - buffer_alignment - remainder - } else { - 0 - } -} - -impl MemorySlice for Slice { - fn is_free(&self) -> bool { - self.handle.is_free() - } - - fn size(&self) -> usize { - self.effective_size() - } - - fn split(&mut self, offset_slice: usize, buffer_alignment: usize) -> Option { - let size_new = self.effective_size() - offset_slice; - let offset_new = self.storage.offset() + offset_slice; - let old_size = self.effective_size(); - - let storage_new = StorageHandle { - id: self.storage.id.clone(), - utilization: StorageUtilization::Slice { - offset: offset_new, - size: size_new, - }, - }; - - self.storage.utilization = StorageUtilization::Slice { - offset: self.storage.offset(), - size: offset_slice, - }; - - if offset_new > 0 && size_new < MIN_SIZE_NEEDED_TO_OFFSET { - panic!("tried to create slice of size {size_new} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support"); - } - if offset_new % buffer_alignment != 0 { - panic!("slice with offset {offset_new} needs to be a multiple of {buffer_alignment}"); - } - let handle = SliceHandle::new(); - if size_new < buffer_alignment { - self.padding = old_size - offset_slice; - assert_eq!(self.effective_size(), old_size); - return None; - } - - assert!( - size_new >= buffer_alignment, - "Size new > {buffer_alignment}" - ); - self.padding = 0; - let padding = calculate_padding(size_new - buffer_alignment, buffer_alignment); - Some(Slice::new(storage_new, handle, self.chunk.clone(), padding)) - } - - fn id(&self) -> SliceId { - *self.handle.id() - } - - fn next_slice_position(&self) -> usize { - self.storage.offset() + self.effective_size() - } -} - -impl MemoryChunk for Chunk { - fn merge_next_slice( - &mut self, - from_slice_index: usize, - slices: &mut HashMap, - ) -> bool { - self.slices.merge_with_next_slice(from_slice_index, slices) - } - - fn slice(&self, index: usize) -> Option { - self.slices.find_slice(index) - } - - fn insert_slice( - &mut self, - position: usize, - slice: Slice, - slices: &mut HashMap, - ) { - self.slices.insert_slice(position, slice.id()); - slices.insert(slice.id(), slice); - } -} diff --git a/crates/burn-compute/src/memory_management/memory_pool/handle.rs b/crates/burn-compute/src/memory_management/memory_pool/handle.rs deleted file mode 100644 index 3bb04c2aae..0000000000 --- a/crates/burn-compute/src/memory_management/memory_pool/handle.rs +++ /dev/null @@ -1,33 +0,0 @@ -use crate::memory_id_type; -use crate::memory_management::{MemoryBinding, MemoryHandle}; - -// The ChunkId allows to keep track of how many references there are to a specific chunk. -memory_id_type!(ChunkId, ChunkHandle); -// The SliceId allows to keep track of how many references there are to a specific slice. -memory_id_type!(SliceId, SliceHandle, SliceBinding); - -/// A tensor memory handle, referring to either a chunk or a slice. -#[derive(Debug, Clone)] -pub struct MemoryPoolHandle { - pub slice: SliceHandle, -} - -/// Binding of the [dynamic handle](DynamicHandle). 
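// `can_mut`/`is_free` on these handles are reference-count checks. A sketch of
// the same semantics using `Arc` (an assumption about how `HandleRef` behaves,
// since this diff does not show its internals directly):
use std::sync::Arc;

struct RefCountedHandle(Arc<()>);

impl RefCountedHandle {
    fn can_mut(&self) -> bool {
        // Mutation is only safe while this is the sole strong reference.
        Arc::strong_count(&self.0) == 1
    }
}

// Cloning the handle makes `can_mut` return false until the clone is dropped,
// which is the behaviour the mutability tests at the end of this diff assert.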
-#[derive(Debug, Clone)] -pub struct MemoryPoolBinding { - pub slice: SliceBinding, -} - -impl MemoryBinding for MemoryPoolBinding {} - -impl MemoryHandle for MemoryPoolHandle { - fn can_mut(&self) -> bool { - self.slice.can_mut() - } - - fn binding(self) -> MemoryPoolBinding { - MemoryPoolBinding { - slice: self.slice.binding(), - } - } -} diff --git a/crates/burn-compute/src/memory_management/memory_pool/index.rs b/crates/burn-compute/src/memory_management/memory_pool/index.rs deleted file mode 100644 index f49875b622..0000000000 --- a/crates/burn-compute/src/memory_management/memory_pool/index.rs +++ /dev/null @@ -1,65 +0,0 @@ -use alloc::collections::BTreeMap; -use alloc::vec; -use alloc::vec::Vec; -use core::hash::Hash; -use hashbrown::HashMap; - -/// Data Structure that helps to search items by size efficiently. -pub struct SearchIndex { - items_per_size: BTreeMap>, - sizes_per_item: HashMap, -} - -impl SearchIndex { - /// Create a new item search index. - pub fn new() -> Self { - Self { - items_per_size: BTreeMap::new(), - sizes_per_item: HashMap::new(), - } - } - - /// Insert a new sized item into the search index. - pub fn insert(&mut self, item: T, size: usize) { - self.remove(&item); - - if let Some(values) = self.items_per_size.get_mut(&size) { - values.push(item.clone()) - } else { - self.items_per_size.insert(size, vec![item.clone()]); - } - self.sizes_per_item.insert(item, size); - } - - /// Find the item by size range. - #[allow(unused)] - pub fn find_by_size( - &self, - range: core::ops::Range, - ) -> impl DoubleEndedIterator { - self.items_per_size.range(range).flat_map(|a| a.1) - } - - /// Remove an item from the index. - pub fn remove(&mut self, item: &T) { - let size = match self.sizes_per_item.remove(item) { - Some(size) => size, - None => return, - }; - - if let Some(values) = self.items_per_size.get_mut(&size) { - let mut removed_index = None; - - for (i, v) in values.iter().enumerate() { - if v == item { - removed_index = Some(i); - break; - } - } - - if let Some(index) = removed_index { - values.remove(index); - } - } - } -} diff --git a/crates/burn-compute/src/memory_management/memory_pool/mod.rs b/crates/burn-compute/src/memory_management/memory_pool/mod.rs deleted file mode 100644 index 9353f1959f..0000000000 --- a/crates/burn-compute/src/memory_management/memory_pool/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -pub(crate) mod index; -mod ring; - -mod base; -mod handle; -mod small; - -pub use base::*; -pub use handle::*; -pub use ring::*; -pub use small::*; diff --git a/crates/burn-compute/src/memory_management/memory_pool/ring.rs b/crates/burn-compute/src/memory_management/memory_pool/ring.rs deleted file mode 100644 index f46ef0921c..0000000000 --- a/crates/burn-compute/src/memory_management/memory_pool/ring.rs +++ /dev/null @@ -1,454 +0,0 @@ -use alloc::vec::Vec; -use core::marker::PhantomData; -use hashbrown::HashMap; - -use super::{ChunkId, SliceId}; - -#[derive(Debug)] -pub struct RingBuffer, S: MemorySlice> { - queue: Vec, - chunk_positions: HashMap, - cursor_slice: usize, - cursor_chunk: usize, - _s: PhantomData, - _c: PhantomData, - buffer_alignment: usize, -} - -pub trait MemoryChunk { - fn merge_next_slice(&mut self, slice_position: usize, slices: &mut HashMap) - -> bool; - fn slice(&self, index: usize) -> Option; - fn insert_slice(&mut self, position: usize, slice: S, slices: &mut HashMap); -} - -pub trait MemorySlice: Sized { - fn is_free(&self) -> bool; - fn size(&self) -> usize; - fn split(&mut self, offset: usize, buffer_alignment: usize) -> Option; - 
fn id(&self) -> SliceId; - fn next_slice_position(&self) -> usize; -} - -impl, S: MemorySlice> RingBuffer { - pub fn new(buffer_alignment: usize) -> Self { - Self { - queue: Vec::new(), - chunk_positions: HashMap::new(), - cursor_slice: 0, - cursor_chunk: 0, - _s: PhantomData, - _c: PhantomData, - buffer_alignment, - } - } - - pub fn push_chunk(&mut self, chunk_id: ChunkId) { - self.queue.push(chunk_id); - self.chunk_positions.insert(chunk_id, self.queue.len() - 1); - } - - pub fn remove_chunk(&mut self, chunk_id: ChunkId) { - if let Some(position) = self.chunk_positions.remove(&chunk_id) { - self.queue.remove(position); - } - - self.chunk_positions.clear(); - - for (pos, id) in self.queue.iter().enumerate() { - self.chunk_positions.insert(*id, pos); - } - self.cursor_chunk = 0; - self.cursor_slice = 0; - } - - pub fn find_free_slice( - &mut self, - size: usize, - chunks: &mut HashMap, - slices: &mut HashMap, - ) -> Option { - let max_second = self.cursor_chunk; - let result = self.find_free_slice_in_all_chunks(size, chunks, slices, self.queue.len()); - - if result.is_some() { - return result; - } - - self.cursor_chunk = 0; - self.cursor_slice = 0; - self.find_free_slice_in_all_chunks(size, chunks, slices, max_second) - } - - fn find_free_slice_in_chunk( - &mut self, - size: usize, - chunk: &mut C, - slices: &mut HashMap, - mut slice_index: usize, - ) -> Option<(usize, SliceId)> { - while let Some(slice_id) = chunk.slice(slice_index) { - //mutable borrow scope - { - let slice = slices.get_mut(&slice_id).unwrap(); - - let is_big_enough = slice.size() >= size; - let is_free = slice.is_free(); - - if is_big_enough && is_free { - if slice.size() > size { - if let Some(new_slice) = slice.split(size, self.buffer_alignment) { - let new_slice_id = new_slice.id(); - chunk.insert_slice(slice.next_slice_position(), new_slice, slices); - slices.get(&new_slice_id).unwrap(); - } - } - return Some((slice_index, slice_id)); - } - } - { - let slice = slices.get_mut(&slice_id).unwrap(); - let is_free = slice.is_free(); - if is_free && chunk.merge_next_slice(slice_index, slices) { - continue; - } - } - - if let Some(slice) = slices.get(&slice_id) { - slice_index = slice.next_slice_position(); - } else { - panic!("current slice_id should still be valid after potential merge"); - } - } - - None - } - - fn find_free_slice_in_all_chunks( - &mut self, - size: usize, - chunks: &mut HashMap, - slices: &mut HashMap, - max_cursor_position: usize, - ) -> Option { - let start = self.cursor_chunk; - let end = usize::min(self.queue.len(), max_cursor_position); - let mut slice_index = self.cursor_slice; - - for chunk_index in start..end { - if chunk_index > start { - slice_index = 0; - } - - if let Some(id) = self.queue.get(chunk_index) { - let chunk = chunks.get_mut(id).unwrap(); - let result = self.find_free_slice_in_chunk(size, chunk, slices, slice_index); - - if let Some((_cursor_slice, slice)) = result { - let slice = slices.get(&slice).unwrap(); - self.cursor_slice = slice.next_slice_position(); - self.cursor_chunk = chunk_index; - return Some(slice.id()); - } - } - self.cursor_chunk = chunk_index; - self.cursor_slice = 0; - } - - None - } -} - -#[cfg(test)] -mod tests { - use super::stub::*; - use super::*; - use alloc::vec; - - #[test] - fn simple_1() { - let mut ring = RingBuffer::::new(0); - - let slice_1 = new_slice(0, 100, 0); - let slice_2 = new_slice(1, 200, 1); - let chunk_1 = new_chunk(0, vec![0, 1]); - - let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]); - let mut chunks = 
HashMap::from([(chunk_1.id, chunk_1)]); - - ring.push_chunk(ChunkId { value: 0 }); - - let slice = ring.find_free_slice(50, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, SliceId { value: 0 }); - assert_eq!(slices.get(&slice).unwrap().size, 50); - assert_eq!(slices.len(), 3); - assert_eq!(chunks.values().last().unwrap().slices.len(), 3); - } - - #[test] - fn simple_2() { - let mut ring = RingBuffer::::new(0); - - let slice_1 = new_slice(0, 100, 0); - let slice_2 = new_slice(1, 200, 1); - let chunk_1 = new_chunk(0, vec![0, 1]); - - let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]); - let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]); - - ring.push_chunk(ChunkId { value: 0 }); - - let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, SliceId { value: 0 }); - assert_eq!(slices.get(&slice).unwrap().size, 150); - assert_eq!(slices.len(), 2); - assert_eq!(chunks.values().last().unwrap().slices.len(), 2); - } - - #[test] - fn multiple_chunks() { - let mut ring = RingBuffer::::new(0); - - let slice_1 = new_slice(0, 100, 0); - let slice_2 = new_slice(1, 200, 1); - let slice_3 = new_slice(2, 200, 0); - let slice_4 = new_slice(3, 200, 1); - let chunk_1 = new_chunk(0, vec![0, 1]); - let chunk_2 = new_chunk(1, vec![2, 3]); - - let mut slices = HashMap::from([ - (slice_1.id, slice_1), - (slice_2.id, slice_2), - (slice_3.id, slice_3), - (slice_4.id, slice_4), - ]); - let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]); - - ring.push_chunk(ChunkId { value: 0 }); - ring.push_chunk(ChunkId { value: 1 }); - - slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = false; - slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = false; - - let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, SliceId { value: 2 }); - - let slice = ring.find_free_slice(100, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, SliceId { value: 0 }); - } - - #[test] - fn find_free_slice_with_exact_fit() { - let mut ring = RingBuffer::::new(0); - - let slice_1 = new_slice(0, 100, 0); - let slice_2 = new_slice(1, 200, 1); - let chunk_1 = new_chunk(0, vec![0, 1]); - - let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]); - let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]); - - ring.push_chunk(ChunkId { value: 0 }); - - slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = false; - slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true; - - let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, SliceId { value: 1 }); - assert_eq!(slices.get(&slice).unwrap().size, 200); - assert_eq!(slices.len(), 2); - assert_eq!(chunks.values().last().unwrap().slices.len(), 2); - } - - #[test] - fn find_free_slice_with_merging() { - let mut ring = RingBuffer::::new(0); - - let slice_1 = new_slice(0, 100, 0); - let slice_2 = new_slice(1, 50, 1); - let slice_3 = new_slice(2, 100, 2); - let chunk_1 = new_chunk(0, vec![0, 1, 2]); - - let mut slices = HashMap::from([ - (slice_1.id, slice_1), - (slice_2.id, slice_2), - (slice_3.id, slice_3), - ]); - let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]); - - ring.push_chunk(ChunkId { value: 0 }); - - slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true; - - let 
slice = ring.find_free_slice(250, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slice, SliceId { value: 0 }); - assert_eq!(slices.get(&slice).unwrap().size, 250); - assert_eq!(slices.len(), 1); - assert_eq!(chunks.values().last().unwrap().slices.len(), 1); - } - - #[test] - fn find_free_slice_with_multiple_chunks_and_merging() { - let mut ring = RingBuffer::::new(0); - - let slice_1 = new_slice(0, 50, 0); - let slice_2 = new_slice(1, 50, 1); - let chunk_1 = new_chunk(0, vec![0, 1]); - - let slice_3 = new_slice(2, 100, 0); - let slice_4 = new_slice(3, 50, 1); - let chunk_2 = new_chunk(1, vec![2, 3]); - - let mut slices = HashMap::from([ - (slice_1.id, slice_1), - (slice_2.id, slice_2), - (slice_3.id, slice_3), - (slice_4.id, slice_4), - ]); - let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]); - - ring.push_chunk(ChunkId { value: 0 }); - ring.push_chunk(ChunkId { value: 1 }); - - slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true; - slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = true; - - let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap(); - - assert_eq!(slices.get(&slice).unwrap().size, 150); - assert_eq!(slices.len(), 2); - assert_eq!(chunks.values().last().unwrap().slices.len(), 1); - } - - fn new_slice(id: usize, size: usize, position: usize) -> TestSlice { - TestSlice { - id: SliceId { value: id }, - is_free: true, - size, - position, - } - } - - fn new_chunk(id: usize, slices: Vec) -> TestChunk { - TestChunk { - id: ChunkId { value: id }, - slices: slices.into_iter().map(|i| SliceId { value: i }).collect(), - } - } -} - -#[cfg(test)] -mod stub { - use super::*; - use burn_common::*; - - #[derive(Debug)] - pub struct TestChunk { - pub id: ChunkId, - pub slices: Vec, - } - - #[derive(Debug)] - pub struct TestSlice { - pub id: SliceId, - pub is_free: bool, - pub size: usize, - pub position: usize, - } - - impl MemorySlice for TestSlice { - fn is_free(&self) -> bool { - self.is_free - } - - fn size(&self) -> usize { - self.size - } - - fn split(&mut self, offset: usize, _buffer_alignment: usize) -> Option { - let size_remained = self.size - offset; - self.size = offset; - - Some(Self { - id: SliceId { - value: rand::gen_random(), - }, - is_free: true, - size: size_remained, - position: self.position + 1, - }) - } - - fn id(&self) -> SliceId { - self.id - } - - fn next_slice_position(&self) -> usize { - self.position + 1 - } - } - - impl MemoryChunk for TestChunk { - fn merge_next_slice( - &mut self, - from_slice_index: usize, - slices: &mut HashMap, - ) -> bool { - let slice_id_current = self.slices.get(from_slice_index).unwrap(); - let slice_id_next = self.slices.get(from_slice_index + 1); - let slice_id_next = match slice_id_next { - Some(val) => val, - None => return false, - }; - - let slice_next = slices.get(slice_id_next).unwrap(); - let is_free = slice_next.is_free; - let size = slice_next.size; - - let slice_current = slices.get_mut(slice_id_current).unwrap(); - - if is_free { - slice_current.size += size; - slices.remove(slice_id_next); - self.slices.remove(from_slice_index + 1); - - for (index, temp_slice_id) in self.slices.iter_mut().enumerate() { - let slice = slices.get_mut(temp_slice_id).unwrap(); - slice.position = index; - } - return true; - } - - false - } - - fn slice(&self, index: usize) -> Option { - self.slices.get(index).copied() - } - - fn insert_slice( - &mut 
self, - position: usize, - slice: TestSlice, - slices: &mut HashMap, - ) { - self.slices.insert(position, slice.id()); - slices.insert(slice.id(), slice); - for (index, temp_slice_id) in self.slices.iter_mut().enumerate() { - let temp_slice = slices.get_mut(temp_slice_id).unwrap(); - temp_slice.position = index; - } - } - } -} diff --git a/crates/burn-compute/src/memory_management/memory_pool/small.rs b/crates/burn-compute/src/memory_management/memory_pool/small.rs deleted file mode 100644 index 7eb9cf5cf4..0000000000 --- a/crates/burn-compute/src/memory_management/memory_pool/small.rs +++ /dev/null @@ -1,226 +0,0 @@ -use super::{ChunkHandle, ChunkId, MemoryPoolBinding, MemoryPoolHandle, SliceHandle, SliceId}; -use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization}; -use alloc::vec::Vec; -use hashbrown::HashMap; - -/// A memory pool that allocates fixed-size chunks (32 bytes each) and reuses them to minimize allocations. -/// -/// - Only one slice is supported per chunk due to the limitations in WGPU where small allocations cannot be offset. -/// - The pool uses a ring buffer to efficiently manage and reuse chunks. -/// -/// Fields: -/// - `chunks`: A hashmap storing the allocated chunks by their IDs. -/// - `slices`: A hashmap storing the slices by their IDs. -/// - `ring_buffer`: A vector used as a ring buffer to manage chunk reuse. -/// - `index`: The current position in the ring buffer. -pub struct SmallMemoryPool { - chunks: HashMap, - slices: HashMap, - ring_buffer: Vec, - index: usize, - buffer_storage_alignment_offset: usize, -} - -#[derive(new, Debug)] -pub struct SmallChunk { - pub storage: StorageHandle, - #[allow(dead_code)] - pub handle: ChunkHandle, - pub slice: Option, -} - -#[derive(new, Debug)] -pub struct SmallSlice { - pub storage: StorageHandle, - pub handle: SliceHandle, - #[allow(dead_code)] - pub chunk: ChunkHandle, - pub padding: usize, -} - -impl SmallSlice { - pub fn effective_size(&self) -> usize { - self.storage.size() + self.padding - } -} - -impl SmallMemoryPool { - pub fn new(buffer_storage_alignment_offset: usize) -> Self { - Self { - chunks: HashMap::new(), - slices: HashMap::new(), - ring_buffer: Vec::new(), - index: 0, - buffer_storage_alignment_offset, - } - } - - /// Returns the resource from the storage, for the specified handle. - pub fn get( - &mut self, - storage: &mut Storage, - binding: &MemoryPoolBinding, - ) -> Option { - self.slices - .get(binding.slice.id()) - .map(|s| &s.storage) - .map(|h| storage.get(h)) - } - - /// Reserves memory of specified size using the reserve algorithm, and return - /// a handle to the reserved memory. 
- /// - /// Also clean ups, merging free slices together if permitted by the merging strategy - pub fn reserve( - &mut self, - storage: &mut Storage, - size: usize, - sync: Sync, - ) -> MemoryPoolHandle { - assert!(size <= self.buffer_storage_alignment_offset); - let slice = self.get_free_slice(size); - - match slice { - Some(slice) => MemoryPoolHandle { - slice: slice.clone(), - }, - None => self.alloc(storage, size, sync), - } - } - - pub fn alloc( - &mut self, - storage: &mut Storage, - size: usize, - _sync: Sync, - ) -> MemoryPoolHandle { - assert!(size <= self.buffer_storage_alignment_offset); - - self.alloc_slice(storage, size) - } - - fn alloc_slice( - &mut self, - storage: &mut Storage, - slice_size: usize, - ) -> MemoryPoolHandle { - let handle_chunk = self.create_chunk(storage, self.buffer_storage_alignment_offset); - let chunk_id = *handle_chunk.id(); - let slice = self.allocate_slice(handle_chunk.clone(), slice_size); - - let handle_slice = slice.handle.clone(); - self.update_chunk_metadata(chunk_id, slice); - - MemoryPoolHandle { - slice: handle_slice, - } - } - - fn allocate_slice(&self, handle_chunk: ChunkHandle, slice_size: usize) -> SmallSlice { - let slice = self.create_slice(0, slice_size, handle_chunk.clone()); - - let effective_size = slice.effective_size(); - assert_eq!(effective_size, self.buffer_storage_alignment_offset); - - slice - } - - fn update_chunk_metadata(&mut self, chunk_id: ChunkId, slice: SmallSlice) { - let slice_id = *slice.handle.id(); - - self.slices.insert(slice_id, slice); - self.chunks.get_mut(&chunk_id).unwrap().slice = Some(slice_id); - } - - fn find_free_slice(&mut self) -> Option { - if self.ring_buffer.is_empty() { - return None; - } - for _ in 0..self.ring_buffer.len() { - let chunk_id = self.ring_buffer.get(self.index).unwrap(); - let chunk = self.chunks.get(chunk_id).unwrap(); - let slice = self.slices.get(&chunk.slice.unwrap()).unwrap(); - self.index = (self.index + 1) % self.ring_buffer.len(); - if slice.handle.is_free() { - return Some(*slice.handle.id()); - } - } - None - } - - /// Finds a free slice that can contain the given size - /// Returns the chunk's id and size. - fn get_free_slice(&mut self, size: usize) -> Option { - let slice_id = self.find_free_slice()?; - - let slice = self.slices.get_mut(&slice_id).unwrap(); - let old_slice_size = slice.effective_size(); - - let offset = match slice.storage.utilization { - StorageUtilization::Full(_) => 0, - StorageUtilization::Slice { offset, size: _ } => offset, - }; - assert_eq!(offset, 0); - slice.storage.utilization = StorageUtilization::Slice { offset, size }; - let new_padding = old_slice_size - size; - slice.padding = new_padding; - assert_eq!( - slice.effective_size(), - old_slice_size, - "new and old slice should have the same size" - ); - - Some(slice.handle.clone()) - } - - /// Creates a slice of size `size` upon the given chunk with the given offset. - fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> SmallSlice { - assert_eq!(offset, 0); - let chunk = self.chunks.get(handle_chunk.id()).unwrap(); - let handle = SliceHandle::new(); - - let storage = StorageHandle { - id: chunk.storage.id.clone(), - utilization: StorageUtilization::Slice { offset, size }, - }; - - let padding = calculate_padding(size, self.buffer_storage_alignment_offset); - - SmallSlice::new(storage, handle, chunk.handle.clone(), padding) - } - - /// Creates a chunk of given size by allocating on the storage. 
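// `find_free_slice` above walks the ring buffer starting at the saved cursor
// and returns the first free chunk, visiting each entry at most once. The core
// of that scan, free-standing (a `bool` per chunk stands in for the
// `handle.is_free()` check):
fn next_free(free: &[bool], cursor: &mut usize) -> Option<usize> {
    if free.is_empty() {
        return None;
    }
    for _ in 0..free.len() {
        let i = *cursor % free.len();
        *cursor = (i + 1) % free.len();
        if free[i] {
            return Some(i);
        }
    }
    None
}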
- fn create_chunk( - &mut self, - storage: &mut Storage, - size: usize, - ) -> ChunkHandle { - let padding = calculate_padding(size, self.buffer_storage_alignment_offset); - let effective_size = size + padding; - - let storage = storage.alloc(effective_size); - let handle = ChunkHandle::new(); - let id = *handle.id(); - - self.ring_buffer.push(id); - - self.chunks - .insert(id, SmallChunk::new(storage, handle.clone(), None)); - - handle - } - - #[allow(unused)] - fn deallocate(&mut self, _storage: &mut Storage) { - todo!() - } -} - -fn calculate_padding(size: usize, buffer_storage_alignment_offset: usize) -> usize { - let remainder = size % buffer_storage_alignment_offset; - if remainder != 0 { - buffer_storage_alignment_offset - remainder - } else { - 0 - } -} diff --git a/crates/burn-compute/src/memory_management/mod.rs b/crates/burn-compute/src/memory_management/mod.rs deleted file mode 100644 index 1ed29851cd..0000000000 --- a/crates/burn-compute/src/memory_management/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -pub(crate) mod memory_pool; - -mod base; - -pub use base::*; - -/// Dynamic memory management strategy. -pub mod dynamic; -/// Simple memory management strategy. -pub mod simple; diff --git a/crates/burn-compute/src/memory_management/simple.rs b/crates/burn-compute/src/memory_management/simple.rs deleted file mode 100644 index ad243860fe..0000000000 --- a/crates/burn-compute/src/memory_management/simple.rs +++ /dev/null @@ -1,559 +0,0 @@ -use crate::{ - memory_id_type, - storage::{ComputeStorage, StorageHandle, StorageUtilization}, -}; -use alloc::vec::Vec; -use hashbrown::HashMap; - -#[cfg(all(not(target_family = "wasm"), feature = "std"))] -use std::time; -#[cfg(all(target_family = "wasm", feature = "std"))] -use web_time as time; - -use super::{MemoryBinding, MemoryHandle, MemoryManagement}; - -// The ChunkId allows to keep track of how many references there are to a specific chunk. -memory_id_type!(ChunkId, ChunkHandle, ChunkBinding); -// The SliceId allows to keep track of how many references there are to a specific slice. -memory_id_type!(SliceId, SliceHandle, SliceBinding); - -/// A tensor memory handle, referring to either a chunk or a slice. -#[derive(Debug, Clone)] -pub enum SimpleHandle { - /// A whole chunk of memory. - Chunk(ChunkHandle), - /// A slice of a chunk of memory. - Slice(SliceHandle), -} - -/// Binding of the [simple handle](SimpleHandle). -#[derive(Debug, Clone)] -pub enum SimpleBinding { - /// Binding of the [chunk handle](ChunkHandle). - Chunk(ChunkBinding), - /// Binding of the [slice handle](SliceHandle) - Slice(SliceBinding), -} - -/// The strategy defines the frequency at which deallocation of unused memory chunks should occur. -#[derive(Debug)] -pub enum DeallocStrategy { - /// Once every n calls to reserve. - PeriodTick { - /// Number of calls to be executed before triggering the deallocation. - period: usize, - /// Current state. Should start at zero. - state: usize, - }, - #[cfg(feature = "std")] - /// Once every period of time - PeriodTime { - /// Number of time before triggering the deallocation. - period: time::Duration, - /// Current state. Should start at now. - state: time::Instant, - }, - /// Never deallocate. - Never, -} - -/// The strategy defines when to reuse chunk with slices. -#[derive(Debug)] -pub enum SliceStrategy { - /// Never use slices. - Never, - /// Ratio needed before the chunk can be used as a slice. Between 0 and 1. - Ratio(f32), - /// When the reserved memory is at least {} bytes. 
- MinimumSize(usize), - /// When the reserved memory is less than {} bytes. - MaximumSize(usize), -} - -impl SliceStrategy { - /// If the chunk can be used with a slice. - pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool { - if chunk_size < reserved_size { - return false; - } - - match self { - SliceStrategy::Never => false, - SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio, - SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes, - SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes, - } - } -} - -impl DeallocStrategy { - /// Create a new strategy with the given period. - pub fn new_period_tick(period: usize) -> Self { - DeallocStrategy::PeriodTick { period, state: 0 } - } - - fn should_dealloc(&mut self) -> bool { - match self { - DeallocStrategy::PeriodTick { period, state } => { - *state = (*state + 1) % *period; - *state == 0 - } - #[cfg(feature = "std")] - DeallocStrategy::PeriodTime { period, state } => { - if &state.elapsed() > period { - *state = time::Instant::now(); - true - } else { - false - } - } - DeallocStrategy::Never => false, - } - } -} - -#[derive(new)] -struct Chunk { - storage: StorageHandle, - handle: ChunkHandle, - slices: Vec, -} - -#[derive(new)] -struct Slice { - storage: StorageHandle, - handle: SliceHandle, - // It is important to keep the chunk handle inside the slice, since it increases the ref count - // on the chunk id and makes the `is_free` method return false until the slice is freed. - // - // TL;DR we can't only store the chunk id. - chunk: ChunkHandle, -} - -/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks. -pub struct SimpleMemoryManagement { - chunks: HashMap, - slices: HashMap, - dealloc_strategy: DeallocStrategy, - slice_strategy: SliceStrategy, - storage: Storage, -} - -impl core::fmt::Debug for SimpleMemoryManagement { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str( - alloc::format!( - "SimpleMemoryManagement {:?} - {:?}", - self.dealloc_strategy, - core::any::type_name::(), - ) - .as_str(), - ) - } -} - -impl MemoryBinding for SimpleBinding {} - -impl MemoryHandle for SimpleHandle { - fn can_mut(&self) -> bool { - match &self { - SimpleHandle::Chunk(id) => id.can_mut(), - SimpleHandle::Slice(id) => id.can_mut(), - } - } - - fn binding(self) -> SimpleBinding { - match self { - Self::Chunk(handle) => SimpleBinding::Chunk(handle.binding()), - Self::Slice(handle) => SimpleBinding::Slice(handle.binding()), - } - } -} - -impl MemoryManagement for SimpleMemoryManagement { - type Handle = SimpleHandle; - type Binding = SimpleBinding; - - /// Returns the resource from the storage, for the specified handle. - fn get(&mut self, binding: Self::Binding) -> Storage::Resource { - let storage = match binding { - SimpleBinding::Chunk(chunk) => { - &self - .chunks - .get(chunk.id()) - .expect("Storage not found for the given execution buffer handle") - .storage - } - SimpleBinding::Slice(slice) => { - &self - .slices - .get(slice.id()) - .expect("Storage not found for the given execution buffer handle") - .storage - } - }; - - self.storage.get(storage) - } - - /// Reserves memory of specified size using the reserve algorithm, and returns - /// a handle to the reserved memory. - /// - /// Also cleans up, removing unused slices and chunks if permitted by the deallocation strategy.
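// The period-tick pattern from `should_dealloc` above, isolated: a counter
// that fires once every `period` calls (`period` must be non-zero, as in the
// original):
struct Tick {
    period: usize,
    state: usize,
}

impl Tick {
    fn tick(&mut self) -> bool {
        self.state = (self.state + 1) % self.period;
        self.state == 0
    }
}

// With period = 3 the sequence is false, false, true, repeating — exactly what
// the `period_tick_dealloc_strategy_should_dealloc_after_period` test below
// asserts.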
- fn reserve(&mut self, size: usize, _sync: Sync) -> Self::Handle { - self.cleanup_slices(); - - let handle = self.reserve_algorithm(size); - - if self.dealloc_strategy.should_dealloc() { - self.cleanup_chunks(); - } - - handle - } - - fn alloc(&mut self, size: usize, _sync: Sync) -> Self::Handle { - self.create_chunk(size) - } - - fn dealloc(&mut self, binding: Self::Binding) { - match binding { - SimpleBinding::Chunk(chunk) => { - if let Some(chunk) = self.chunks.remove(chunk.id()) { - self.storage.dealloc(chunk.storage.id); - } - } - SimpleBinding::Slice(_) => panic!("Can't dealloc slice manually"), - } - } - - fn storage(&mut self) -> &mut Storage { - &mut self.storage - } -} - -impl SimpleMemoryManagement { - /// Creates a new instance using the given storage, deallocation strategy and slice strategy. - pub fn new( - storage: Storage, - dealloc_strategy: DeallocStrategy, - slice_strategy: SliceStrategy, - ) -> Self { - Self { - chunks: HashMap::new(), - slices: HashMap::new(), - dealloc_strategy, - slice_strategy, - storage, - } - } - - fn reserve_algorithm(&mut self, size: usize) -> SimpleHandle { - // Looks for a large enough, existing but unused chunk of memory. - let chunk = self.find_free_chunk(size); - - match chunk { - Some(chunk) => { - if size == chunk.storage.size() { - // If there is one of exactly the same size, it reuses it. - SimpleHandle::Chunk(chunk.handle.clone()) - } else { - // Otherwise creates a slice of the right size upon it, always starting at zero. - self.create_slice(size, chunk.handle.clone()) - } - } - // If no chunk available, creates one of exactly the right size. - None => self.create_chunk(size), - } - } - - /// Finds the smallest of the free and large enough chunks to fit `size` - /// Returns the chunk's id and size. - fn find_free_chunk(&self, size: usize) -> Option<&Chunk> { - let mut size_diff_current = usize::MAX; - let mut current = None; - - for chunk in self.chunks.values() { - // If chunk is already used, we do not choose it - if !chunk.handle.is_free() { - continue; - } - - let storage_size = chunk.storage.size(); - - // If we find a chunk of exactly the right size, we stop searching altogether - if size == storage_size { - current = Some(chunk); - break; - } - - // Finds the smallest of the large enough chunks that can accept a slice - // of the given size - if self.slice_strategy.can_use_chunk(storage_size, size) { - let size_diff = storage_size - size; - - if size_diff < size_diff_current { - current = Some(chunk); - size_diff_current = size_diff; - } - } - } - - current - } - - /// Creates a slice of size `size` upon the given chunk. - /// - /// For now slices must start at zero, therefore there can be only one per chunk - fn create_slice(&mut self, size: usize, handle_chunk: ChunkHandle) -> SimpleHandle { - let chunk = self.chunks.get_mut(handle_chunk.id()).unwrap(); - let handle_slice = SliceHandle::new(); - - let storage = StorageHandle { - id: chunk.storage.id.clone(), - utilization: StorageUtilization::Slice { offset: 0, size }, - }; - - if chunk.slices.is_empty() { - self.slices.insert( - *handle_slice.id(), - Slice::new(storage, handle_slice.clone(), handle_chunk.clone()), - ); - } else { - panic!("Can't have more than 1 slice yet."); - } - - chunk.slices.push(*handle_slice.id()); - - SimpleHandle::Slice(handle_slice) - } - - /// Creates a chunk of given size by allocating on the storage. 
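// `find_free_chunk` above is a best-fit search with an exact-match shortcut.
// The same logic over plain (size, is_free) pairs, leaving out the
// `SliceStrategy` filter for brevity (`best_fit` is a hypothetical helper, not
// burn API):
fn best_fit(chunks: &[(usize, bool)], size: usize) -> Option<usize> {
    let mut best: Option<(usize, usize)> = None; // (index, wasted bytes)
    for (i, &(chunk_size, free)) in chunks.iter().enumerate() {
        if !free || chunk_size < size {
            continue;
        }
        if chunk_size == size {
            return Some(i); // exact fit: stop searching immediately
        }
        let waste = chunk_size - size;
        if best.map_or(true, |(_, w)| waste < w) {
            best = Some((i, waste));
        }
    }
    best.map(|(i, _)| i)
}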
- fn create_chunk(&mut self, size: usize) -> SimpleHandle { - let storage = self.storage.alloc(size); - let handle = ChunkHandle::new(); - - self.chunks.insert( - *handle.id(), - Chunk::new(storage, handle.clone(), Vec::new()), - ); - - SimpleHandle::Chunk(handle) - } - - /// Deallocates free chunks and remove them from chunks map. - fn cleanup_chunks(&mut self) { - let mut ids_to_remove = Vec::new(); - - self.chunks.iter().for_each(|(chunk_id, chunk)| { - if chunk.handle.is_free() { - ids_to_remove.push(*chunk_id); - } - }); - - ids_to_remove - .iter() - .map(|chunk_id| self.chunks.remove(chunk_id).unwrap()) - .for_each(|chunk| { - self.storage.dealloc(chunk.storage.id); - }); - } - - /// Removes free slices from slice map and corresponding chunks. - fn cleanup_slices(&mut self) { - let mut ids_to_remove = Vec::new(); - - self.slices.iter().for_each(|(slice_id, slice)| { - if slice.handle.is_free() { - ids_to_remove.push(*slice_id); - } - }); - - ids_to_remove - .iter() - .map(|slice_id| self.slices.remove(slice_id).unwrap()) - .for_each(|slice| { - let chunk = self.chunks.get_mut(slice.chunk.id()).unwrap(); - chunk.slices.retain(|id| id != slice.handle.id()); - }); - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - memory_management::{MemoryHandle, MemoryManagement}, - storage::BytesStorage, - }; - - impl SimpleMemoryManagement { - fn reserve_no_sync(&mut self, size: usize) -> SimpleHandle { - self.reserve(size, || {}) - } - } - - #[test] - fn can_mut_with_single_tensor_reference() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - - let chunk_size = 4; - let simple_handle = memory_management.create_chunk(chunk_size); - - let x = simple_handle.clone(); - core::mem::drop(simple_handle); - - assert!(x.can_mut()); - } - - #[test] - fn two_tensor_references_remove_mutability() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - - let chunk_size = 4; - let simple_handle = memory_management.create_chunk(chunk_size); - - let x = simple_handle.clone(); - - assert!(!simple_handle.can_mut()); - assert!(!x.can_mut()) - } - - #[test] - fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - let chunk_size = 4; - let _chunk_handle = memory_management.reserve_no_sync(chunk_size); - let _new_handle = memory_management.reserve_no_sync(chunk_size); - - assert_eq!(memory_management.chunks.len(), 2); - } - - #[test] - fn when_empty_chunk_is_cleaned_upexists_it_disappears() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Never, - ); - let chunk_size = 4; - let chunk_handle = memory_management.reserve_no_sync(chunk_size); - drop(chunk_handle); - memory_management.cleanup_chunks(); - - assert_eq!(memory_management.chunks.len(), 0); - } - - #[test] - fn never_dealloc_strategy_never_deallocs() { - let mut never_dealloc = DeallocStrategy::Never; - for _ in 0..20 { - assert!(!never_dealloc.should_dealloc()) - } - } - - #[test] - fn period_tick_dealloc_strategy_should_dealloc_after_period() { - let period = 3; - let mut period_tick_dealloc = DeallocStrategy::new_period_tick(period); - - for _ in 0..3 { - for _ in 0..period - 1 { - 
assert!(!period_tick_dealloc.should_dealloc()); - } - assert!(period_tick_dealloc.should_dealloc()); - } - } - - #[test] - fn slice_strategy_minimum_bytes() { - let strategy = SliceStrategy::MinimumSize(100); - - assert!(strategy.can_use_chunk(200, 101)); - assert!(!strategy.can_use_chunk(200, 99)); - } - - #[test] - fn slice_strategy_maximum_bytes() { - let strategy = SliceStrategy::MaximumSize(100); - - assert!(strategy.can_use_chunk(200, 99)); - assert!(!strategy.can_use_chunk(200, 101)); - } - - #[test] - fn slice_strategy_ratio() { - let strategy = SliceStrategy::Ratio(0.9); - - assert!(strategy.can_use_chunk(200, 180)); - assert!(!strategy.can_use_chunk(200, 179)); - } - - #[test] - fn test_handle_mutability() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Ratio(0.5), - ); - let handle = memory_management.reserve_no_sync(10); - - let other_ref = handle.clone(); - - assert!(!handle.can_mut(), "Handle can't be mut when multiple ref."); - drop(other_ref); - assert!(handle.can_mut(), "Handle should be mut when only one ref."); - } - - #[test] - fn test_slice_mutability() { - let mut memory_management = SimpleMemoryManagement::new( - BytesStorage::default(), - DeallocStrategy::Never, - SliceStrategy::Ratio(0.5), - ); - let chunk = memory_management.reserve_no_sync(10); - - if let super::SimpleHandle::Slice(_) = chunk { - panic!("Should be a chunk.") - } - - drop(chunk); - - let slice = memory_management.reserve_no_sync(8); - - if let super::SimpleHandle::Chunk(_) = &slice { - panic!("Should be a slice.") - } - - if let super::SimpleHandle::Slice(slice) = slice { - let other_ref = slice.clone(); - - assert!( - !slice.can_mut(), - "Slice can't be mut when multiple ref to the same handle." - ); - drop(other_ref); - assert!( - slice.can_mut(), - "Slice should be mut when only one ref to the same handle." - ); - assert!( - !slice.is_free(), - "Slice can't be reallocated when one ref still exist." - ); - } - } -} diff --git a/crates/burn-compute/src/server.rs b/crates/burn-compute/src/server.rs deleted file mode 100644 index 948fa24e9a..0000000000 --- a/crates/burn-compute/src/server.rs +++ /dev/null @@ -1,105 +0,0 @@ -use crate::{ - memory_management::{MemoryHandle, MemoryManagement}, - storage::ComputeStorage, - tune::AutotuneKey, -}; -use alloc::vec::Vec; -use burn_common::{reader::Reader, sync_type::SyncType}; -use core::fmt::Debug; - -/// The compute server is responsible for handling resources and computations over resources. -/// -/// Everything in the server is mutable, therefore it should be solely accessed through the -/// [compute channel](crate::channel::ComputeChannel) for thread safety. -pub trait ComputeServer: Send + core::fmt::Debug -where - Self: Sized, -{ - /// The kernel type defines the computation algorithms. - type Kernel: Send; - /// Options when dispatching the kernel, eg. the number of executions. - type DispatchOptions: Send; - /// The [storage](ComputeStorage) type defines how data is stored and accessed. - type Storage: ComputeStorage; - /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type. - type MemoryManagement: MemoryManagement; - /// The key used to cache operations used on specific inputs in autotune - type AutotuneKey: AutotuneKey; - /// Features supported by the compute server. - type FeatureSet: Send + Sync; - - /// Given a handle, returns the owned resource as bytes. 
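// A sketch of how a caller drives this trait end to end (`round_trip` is an
// invented helper; `Handle` and `Binding` are generic over the server type in
// the original source):
fn round_trip<Server: ComputeServer>(
    server: &mut Server,
    kernel: Server::Kernel,
    count: Server::DispatchOptions,
    data: &[u8],
) -> Reader {
    let input = server.create(data); // upload the input bytes
    let output = server.empty(data.len()); // reserve space for the result
    // Handles are consumed into bindings at dispatch time, so the extra clone
    // keeps `output` alive for the readback below.
    server.execute(kernel, count, vec![input.binding(), output.clone().binding()]);
    server.read(output.binding()) // asynchronous readback of the result
}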
-    fn read(&mut self, binding: Binding<Self>) -> Reader;
-
-    /// Given a resource handle, returns the storage resource.
-    fn get_resource(
-        &mut self,
-        binding: Binding<Self>,
-    ) -> <Self::Storage as ComputeStorage>::Resource;
-
-    /// Given a resource as bytes, stores it and returns the memory handle.
-    fn create(&mut self, data: &[u8]) -> Handle<Self>;
-
-    /// Reserves `size` bytes in the storage, and returns a handle over them.
-    fn empty(&mut self, size: usize) -> Handle<Self>;
-
-    /// Executes the `kernel` over the given memory `handles`.
-    ///
-    /// Kernels have mutable access to every resource they are given
-    /// and are responsible for determining which should be read or written.
-    fn execute(
-        &mut self,
-        kernel: Self::Kernel,
-        count: Self::DispatchOptions,
-        bindings: Vec<Binding<Self>>,
-    );
-
-    /// Waits for the completion of every task in the server.
-    fn sync(&mut self, command: SyncType);
-}
-
-/// Server handle containing the [memory handle](MemoryManagement::Handle).
-#[derive(new, Debug)]
-pub struct Handle<Server: ComputeServer> {
-    /// Memory handle.
-    pub memory: <Server::MemoryManagement as MemoryManagement<Server::Storage>>::Handle,
-}
-
-/// Binding of a [tensor handle](Handle) used to execute a kernel.
-#[derive(new, Debug)]
-pub struct Binding<Server: ComputeServer> {
-    /// Memory binding.
-    pub memory: <Server::MemoryManagement as MemoryManagement<Server::Storage>>::Binding,
-}
-
-impl<Server: ComputeServer> Handle<Server> {
-    /// Whether the tensor handle can be reused in place.
-    pub fn can_mut(&self) -> bool {
-        MemoryHandle::can_mut(&self.memory)
-    }
-}
-
-impl<Server: ComputeServer> Handle<Server> {
-    /// Convert the [handle](Handle) into a [binding](Binding).
-    pub fn binding(self) -> Binding<Server> {
-        Binding {
-            memory: MemoryHandle::binding(self.memory),
-        }
-    }
-}
-
-impl<Server: ComputeServer> Clone for Handle<Server> {
-    fn clone(&self) -> Self {
-        Self {
-            memory: self.memory.clone(),
        }
-    }
-}
-
-impl<Server: ComputeServer> Clone for Binding<Server> {
-    fn clone(&self) -> Self {
-        Self {
-            memory: self.memory.clone(),
-        }
-    }
-}
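
The `Handle`/`Binding` pair removed here encodes a small ownership protocol: a `Handle` is freely cloneable and can report, through `can_mut`, whether it is the only reference to its memory, while a `Binding` is the consumed form that a kernel actually receives. A minimal sketch of the intended call pattern, written against the dummy test server that is deleted further down in this diff (so `client`, `DummyElementwiseAddition`, and the `()` dispatch options come from that test code, not from a stable API); it mirrors the `execute_elementwise_addition` integration test below:

    use std::sync::Arc;

    // Create two inputs and an uninitialized output buffer.
    let lhs = client.create(&[0u8, 1, 2]);
    let rhs = client.create(&[4u8, 4, 4]);
    let out = client.empty(3);

    // `out` is still uniquely owned here, so it could be written in place.
    assert!(out.can_mut());

    // `binding()` consumes the handle; cloning first keeps `out` readable.
    client.execute(
        Arc::new(DummyElementwiseAddition),
        (),
        vec![lhs.binding(), rhs.binding(), out.clone().binding()],
    );
    let result = client.read(out.binding());
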
diff --git a/crates/burn-compute/src/storage/base.rs b/crates/burn-compute/src/storage/base.rs
deleted file mode 100644
index db7e95be5c..0000000000
--- a/crates/burn-compute/src/storage/base.rs
+++ /dev/null
@@ -1,64 +0,0 @@
-use crate::storage_id_type;
-
-// This ID is used to map a handle to its actual data.
-storage_id_type!(StorageId);
-
-/// Defines if data uses a full memory chunk or a slice of it.
-#[derive(Clone, Debug)]
-pub enum StorageUtilization {
-    /// Full memory chunk of the specified size.
-    Full(usize),
-    /// Slice of a memory chunk with a start index and size.
-    Slice {
-        /// The offset in bytes from the chunk start.
-        offset: usize,
-        /// The size of the slice in bytes.
-        size: usize,
-    },
-}
-
-/// Contains the [storage id](StorageId) of a resource and the way it is used.
-#[derive(new, Clone, Debug)]
-pub struct StorageHandle {
-    /// Storage id.
-    pub id: StorageId,
-    /// How the storage is used.
-    pub utilization: StorageUtilization,
-}
-
-impl StorageHandle {
-    /// Returns the size the handle is pointing to in memory.
-    pub fn size(&self) -> usize {
-        match self.utilization {
-            StorageUtilization::Full(size) => size,
-            StorageUtilization::Slice { offset: _, size } => size,
-        }
-    }
-
-    /// Returns the offset of the handle within its memory chunk.
-    pub fn offset(&self) -> usize {
-        match self.utilization {
-            StorageUtilization::Full(..) => panic!("full size slice not supported anymore"),
-            StorageUtilization::Slice { offset, .. } => offset,
-        }
-    }
-}
-
-/// Storage types are responsible for allocating and deallocating memory.
-pub trait ComputeStorage: Send {
-    /// The resource associated type determines the way data is implemented and how
-    /// it can be accessed by kernels.
-    type Resource: Send;
-
-    /// Returns the underlying resource for a specified storage handle.
-    fn get(&mut self, handle: &StorageHandle) -> Self::Resource;
-
-    /// Allocates `size` units of memory and returns a handle to it.
-    fn alloc(&mut self, size: usize) -> StorageHandle;
-
-    /// Deallocates the memory pointed to by the given storage id.
-    fn dealloc(&mut self, id: StorageId);
-
-    /// Copies the data under `from` into the memory under `to`.
-    fn copy(&mut self, from: &StorageHandle, to: &StorageHandle);
-}
diff --git a/crates/burn-compute/src/storage/bytes_cpu.rs b/crates/burn-compute/src/storage/bytes_cpu.rs
deleted file mode 100644
index 3b6c493c71..0000000000
--- a/crates/burn-compute/src/storage/bytes_cpu.rs
+++ /dev/null
@@ -1,150 +0,0 @@
-use super::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
-use alloc::alloc::{alloc, dealloc, Layout};
-use hashbrown::HashMap;
-
-/// The bytes storage maps ids to pointers of bytes in a contiguous layout.
-#[derive(Default)]
-pub struct BytesStorage {
-    memory: HashMap<StorageId, AllocatedBytes>,
-}
-
-impl core::fmt::Debug for BytesStorage {
-    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-        f.write_str("BytesStorage")
-    }
-}
-
-/// Can be sent to other threads.
-unsafe impl Send for BytesStorage {}
-unsafe impl Send for BytesResource {}
-
-/// This struct is a pointer to a memory chunk or slice.
-pub struct BytesResource {
-    ptr: *mut u8,
-    utilization: StorageUtilization,
-}
-
-/// This struct refers to a specific (contiguous) layout of bytes.
-struct AllocatedBytes {
-    ptr: *mut u8,
-    layout: Layout,
-}
-
-impl BytesResource {
-    fn get_exact_location_and_length(&self) -> (*mut u8, usize) {
-        match self.utilization {
-            StorageUtilization::Full(len) => (self.ptr, len),
-            StorageUtilization::Slice { offset, size } => unsafe { (self.ptr.add(offset), size) },
-        }
-    }
-
-    /// Returns the resource as a mutable slice of bytes.
-    pub fn write<'a>(&self) -> &'a mut [u8] {
-        let (ptr, len) = self.get_exact_location_and_length();
-
-        unsafe { core::slice::from_raw_parts_mut(ptr, len) }
-    }
-
-    /// Returns the resource as an immutable slice of bytes.
-    pub fn read<'a>(&self) -> &'a [u8] {
-        let (ptr, len) = self.get_exact_location_and_length();
-
-        unsafe { core::slice::from_raw_parts(ptr, len) }
-    }
-}
-
-impl ComputeStorage for BytesStorage {
-    type Resource = BytesResource;
-
-    fn get(&mut self, handle: &StorageHandle) -> Self::Resource {
-        let allocated_bytes = self.memory.get_mut(&handle.id).unwrap();
-
-        BytesResource {
-            ptr: allocated_bytes.ptr,
-            utilization: handle.utilization.clone(),
-        }
-    }
-
-    fn alloc(&mut self, size: usize) -> StorageHandle {
-        let id = StorageId::new();
-        let handle = StorageHandle {
-            id: id.clone(),
-            utilization: StorageUtilization::Full(size),
-        };
-
-        unsafe {
-            let layout = Layout::array::<u8>(size).unwrap();
-            let ptr = alloc(layout);
-            let memory = AllocatedBytes { ptr, layout };
-
-            self.memory.insert(id, memory);
-        }
-
-        handle
-    }
-
-    fn dealloc(&mut self, id: StorageId) {
-        if let Some(memory) = self.memory.remove(&id) {
-            unsafe {
-                dealloc(memory.ptr, memory.layout);
-            }
-        }
-    }
-
-    fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) {
-        assert_eq!(from.size(), to.size());
-
-        let input = self.get(from);
-        let output = self.get(to);
-
-        // Copy byte-by-byte from the `from` region into the `to` region,
-        // honoring each handle's own offset.
-        for i in 0..from.size() {
-            let ptr_in = input.ptr.wrapping_add(i + from.offset());
-            let ptr_out = output.ptr.wrapping_add(i + to.offset());
-
-            unsafe { *ptr_out = *ptr_in }
-        }
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn test_can_alloc_and_dealloc() {
-        let mut storage = BytesStorage::default();
-        let handle_1 = storage.alloc(64);
-
-        assert_eq!(handle_1.size(), 64);
-        storage.dealloc(handle_1.id);
-    }
-
-    #[test]
-    fn test_slices() {
-        let mut storage = BytesStorage::default();
-        let handle_1 = storage.alloc(64);
-        let handle_2 = StorageHandle::new(
-            handle_1.id.clone(),
-            StorageUtilization::Slice {
-                offset: 24,
-                size: 8,
-            },
-        );
-
-        storage
-            .get(&handle_1)
-            .write()
-            .iter_mut()
-            .enumerate()
-            .for_each(|(i, b)| {
-                *b = i as u8;
-            });
-
-        let bytes = storage.get(&handle_2).read().to_vec();
-        storage.dealloc(handle_1.id);
-        assert_eq!(bytes, &[24, 25, 26, 27, 28, 29, 30, 31]);
-    }
-}
diff --git a/crates/burn-compute/src/storage/mod.rs b/crates/burn-compute/src/storage/mod.rs
deleted file mode 100644
index 0bf21bd31c..0000000000
--- a/crates/burn-compute/src/storage/mod.rs
+++ /dev/null
@@ -1,8 +0,0 @@
-mod base;
-
-pub use base::*;
-
-#[cfg(feature = "storage-bytes")]
-mod bytes_cpu;
-#[cfg(feature = "storage-bytes")]
-pub use bytes_cpu::*;
diff --git a/crates/burn-compute/src/tune/mod.rs b/crates/burn-compute/src/tune/mod.rs
deleted file mode 100644
index c501437498..0000000000
--- a/crates/burn-compute/src/tune/mod.rs
+++ /dev/null
@@ -1,9 +0,0 @@
-mod operation;
-mod tune_benchmark;
-mod tune_cache;
-mod tuner;
-
-pub use operation::*;
-pub use tune_benchmark::*;
-pub use tune_cache::*;
-pub use tuner::*;
diff --git a/crates/burn-compute/src/tune/operation.rs b/crates/burn-compute/src/tune/operation.rs
deleted file mode 100644
index ffdcba8bbd..0000000000
--- a/crates/burn-compute/src/tune/operation.rs
+++ /dev/null
@@ -1,69 +0,0 @@
-use alloc::boxed::Box;
-use alloc::string::String;
-use alloc::vec::Vec;
-use core::fmt::{Debug, Display};
-use core::hash::Hash;
-
-/// Default checksum for an operation set.
-#[cfg(feature = "autotune-persistent-cache")]
-pub fn compute_checksum(autotunables: &[Box<dyn AutotuneOperation>]) -> String {
-    let mut checksum = String::new();
-    autotunables.iter().for_each(|op| {
-        checksum += op.name();
-    });
-    format!("{:x}", md5::compute(checksum))
-}
-
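
The default checksum above is just an md5 digest of the concatenated candidate kernel names, which is why renaming, adding, or reordering candidates invalidates previously persisted autotune results. A standalone restatement, assuming only the `md5` crate that burn-compute itself depended on (`checksum_of` is an illustrative name, not part of the deleted API):

    /// Digest of the concatenated candidate names: any change to the
    /// candidate list yields a different checksum, forcing a re-tune.
    fn checksum_of(candidate_names: &[&str]) -> String {
        let joined: String = candidate_names.concat();
        format!("{:x}", md5::compute(joined))
    }

    // checksum_of(&["add", "add_vec4"]) != checksum_of(&["add_vec4", "add"])
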
-/// Groups operations of the same type for autotune.
-pub trait AutotuneOperationSet<K>: Send {
-    /// The key used in the tune cache.
-    fn key(&self) -> K;
-
-    /// All candidate operations for autotuning this operation type.
-    /// Operations can run on toy tensors of relevant size.
-    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>>;
-
-    /// Returns the operation for the given index, matching the order
-    /// returned by autotunables. The operation obtained here runs on the original tensors.
-    fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation>;
-
-    /// Compute a checksum that can invalidate outdated cached auto-tune results.
-    #[cfg(feature = "autotune-persistent-cache")]
-    fn compute_checksum(&self) -> String {
-        compute_checksum(&self.autotunables())
-    }
-}
-
-/// Contains the operation to run and the inputs on which to run it.
-pub trait AutotuneOperation {
-    /// Runs the operation.
-    fn execute(self: Box<Self>);
-
-    /// The name of the operation.
-    fn name(&self) -> &str {
-        core::any::type_name::<Self>()
-    }
-
-    /// Clones the operation and its inputs.
-    fn clone(&self) -> Box<dyn AutotuneOperation>;
-}
-
-#[cfg(feature = "autotune-persistent-cache")]
-/// Trait alias with support for persistent caching.
-pub trait AutotuneKey:
-    Clone
-    + Debug
-    + PartialEq
-    + Eq
-    + Hash
-    + Display
-    + serde::Serialize
-    + serde::de::DeserializeOwned
-    + Send
-    + Sync
-{
-}
-#[cfg(not(feature = "autotune-persistent-cache"))]
-/// Trait alias.
-pub trait AutotuneKey: Clone + Debug + PartialEq + Eq + Hash + Display {}
-impl AutotuneKey for String {}
diff --git a/crates/burn-compute/src/tune/tune_benchmark.rs b/crates/burn-compute/src/tune/tune_benchmark.rs
deleted file mode 100644
index e0fd6a2493..0000000000
--- a/crates/burn-compute/src/tune/tune_benchmark.rs
+++ /dev/null
@@ -1,48 +0,0 @@
-use burn_common::benchmark::Benchmark;
-use burn_common::sync_type::SyncType;
-
-use crate::channel::ComputeChannel;
-use crate::client::ComputeClient;
-use crate::server::ComputeServer;
-
-use super::AutotuneOperation;
-use alloc::boxed::Box;
-use alloc::string::{String, ToString};
-
-/// A benchmark that runs on server handles.
-#[derive(new)]
-pub struct TuneBenchmark<S: ComputeServer, C> {
-    operation: Box<dyn AutotuneOperation>,
-    client: ComputeClient<S, C>,
-}
-
-impl Clone for Box<dyn AutotuneOperation> {
-    fn clone(&self) -> Self {
-        self.as_ref().clone()
-    }
-}
-
-impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
-    type Args = Box<dyn AutotuneOperation>;
-
-    fn prepare(&self) -> Self::Args {
-        self.operation.clone()
-    }
-
-    fn num_samples(&self) -> usize {
-        10
-    }
-
-    fn execute(&self, operation: Self::Args) {
-        AutotuneOperation::execute(operation);
-    }
-
-    fn name(&self) -> String {
-        "autotune".to_string()
-    }
-
-    fn sync(&self) {
-        // For benchmarks, we need to wait for all tasks to complete before returning.
- self.client.sync(SyncType::Wait); - } -} diff --git a/crates/burn-compute/src/tune/tune_cache.rs b/crates/burn-compute/src/tune/tune_cache.rs deleted file mode 100644 index 7208ef9cec..0000000000 --- a/crates/burn-compute/src/tune/tune_cache.rs +++ /dev/null @@ -1,243 +0,0 @@ -#[cfg(feature = "autotune-persistent-cache")] -mod std_imports { - pub use std::fs; - pub use std::fs::File; - pub use std::io; - pub use std::path::Path; - pub use std::path::PathBuf; -} -#[cfg(feature = "autotune-persistent-cache")] -use std_imports::*; - -#[cfg(feature = "autotune-persistent-cache")] -use serde::{Deserialize, Serialize}; - -use super::AutotuneKey; -use super::AutotuneOperation; -use super::AutotuneOperationSet; -use alloc::boxed::Box; -use hashbrown::HashMap; - -#[cfg(feature = "autotune-persistent-cache")] -/// Return the file path for the persistent cache on disk -/// prefix should be the device id computed at the backend level -pub fn get_persistent_cache_file_path(prefix: &str) -> PathBuf { - let home_dir = dirs::home_dir().expect("An home directory should exist"); - let path_dir = home_dir.join(".cache").join("burn").join("autotune"); - let path = Path::new(&path_dir); - path.join(format!("{}-autotune-cache.json", prefix)) -} - -/// In-memory cache entry -#[derive(Debug)] -pub(crate) struct InMemoryCacheEntry { - #[cfg(feature = "autotune-persistent-cache")] - checksum_checked: bool, - fastest_index: usize, -} - -/// Persistent cache entry -#[cfg(feature = "autotune-persistent-cache")] -#[derive(Debug, Serialize, Deserialize)] -pub(crate) struct PersistentCacheEntry { - checksum: String, - fastest_index: usize, -} - -/// Use to find and reuse the best kernel for some input -#[derive(Debug)] -pub(crate) struct TuneCache { - in_memory_cache: HashMap, - #[cfg(feature = "autotune-persistent-cache")] - persistent_cache: HashMap, - #[cfg(feature = "autotune-persistent-cache")] - device_id: String, - #[cfg(feature = "autotune-persistent-cache")] - name: String, -} - -/// Result of the cache try -pub enum TuneCacheResult { - /// An operation is found and given - Hit(Box), - /// No operation is found and the set is given back for ownership - Miss(Box>), -} - -impl TuneCache { - pub(crate) fn new( - #[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))] name: &str, - #[cfg_attr(not(feature = "autotune-persistent-cache"), allow(unused_variables))] - device_id: &str, - ) -> Self { - #[cfg(feature = "autotune-persistent-cache")] - { - let mut cache = TuneCache { - in_memory_cache: HashMap::new(), - persistent_cache: HashMap::new(), - device_id: device_id.to_string(), - name: name.to_string(), - }; - if let Err(e) = cache.load() { - log::warn!( - "Unable to load autotune cache. 
Cache will be ignored ({}).", - e - ); - } - cache - } - - #[cfg(not(feature = "autotune-persistent-cache"))] - { - TuneCache { - in_memory_cache: HashMap::new(), - } - } - } - - pub(crate) fn find_fastest(&self, key: &K) -> Option { - let val = self.in_memory_cache.get(key)?; - - #[cfg(feature = "autotune-persistent-cache")] - if val.checksum_checked { - Some(val.fastest_index) - } else { - None - } - - #[cfg(not(feature = "autotune-persistent-cache"))] - Some(val.fastest_index) - } - - pub(crate) fn try_cache( - &mut self, - autotune_operation_set: Box>, - ) -> TuneCacheResult { - let key = autotune_operation_set.key(); - let result = self.in_memory_cache.get_mut(&key); - - #[cfg(feature = "autotune-persistent-cache")] - { - if let Some(InMemoryCacheEntry { - checksum_checked, - fastest_index, - }) = result - { - if !*checksum_checked { - let checksum = autotune_operation_set.compute_checksum(); - let persistent_entry = self - .persistent_cache - .get(&key) - .expect("Both caches should be in sync"); - if checksum != persistent_entry.checksum { - return TuneCacheResult::Miss(autotune_operation_set); - } - *checksum_checked = true; - } - return TuneCacheResult::Hit(autotune_operation_set.fastest(*fastest_index)); - } - } - - #[cfg(not(feature = "autotune-persistent-cache"))] - { - if let Some(InMemoryCacheEntry { fastest_index, .. }) = result { - return TuneCacheResult::Hit(autotune_operation_set.fastest(*fastest_index)); - } - } - - TuneCacheResult::Miss(autotune_operation_set) - } - - pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) { - self.in_memory_cache.insert( - key, - InMemoryCacheEntry { - #[cfg(feature = "autotune-persistent-cache")] - checksum_checked: true, - fastest_index, - }, - ); - } - - #[cfg(feature = "autotune-persistent-cache")] - pub(crate) fn persistent_cache_insert( - &mut self, - key: K, - checksum: String, - fastest_index: usize, - ) { - self.persistent_cache.insert( - key, - PersistentCacheEntry { - checksum, - fastest_index, - }, - ); - } - - /// Load the persistent cache data from disk - #[cfg(feature = "autotune-persistent-cache")] - pub(crate) fn load(&mut self) -> Result<(), io::Error> { - let file_path = self.get_persistent_cache_file_path(); - // note: reading file from memory is faster than using - // serde from_reader with a buffered reader - // see issue: - // https://github.com/serde-rs/json/issues/160 - match fs::read_to_string(file_path) { - Ok(data) => { - let data: Vec<(K, PersistentCacheEntry)> = serde_json::from_str(&data)?; - for (key, value) in data.into_iter() { - self.persistent_cache.insert(key, value); - } - Ok(()) - } - Err(e) => { - if e.kind() == std::io::ErrorKind::NotFound { - Ok(()) - } else { - Err(e) - } - } - }?; - for (key, entry) in self.persistent_cache.iter() { - self.in_memory_cache.insert( - key.clone(), - InMemoryCacheEntry { - checksum_checked: false, - fastest_index: entry.fastest_index, - }, - ); - } - Ok(()) - } - - /// Save the persistent cache on disk - #[cfg(feature = "autotune-persistent-cache")] - pub(crate) fn save(&self) { - let file_path = self.get_persistent_cache_file_path(); - if let Some(parent_dir) = file_path.parent() { - if !parent_dir.exists() { - fs::create_dir_all(parent_dir).unwrap_or_else(|_| { - panic!( - "Should be able to create directory '{}' for autotune persistent cache file", - parent_dir.to_str().unwrap()) - }); - } - } - let file = File::create(file_path.clone()).unwrap_or_else(|_| { - panic!( - "Should be able to open autotune persistent cache file '{}'", - 
file_path.to_str().unwrap() - ) - }); - let data = self.persistent_cache.iter().collect::>(); - serde_json::to_writer_pretty(file, &data) - .expect("Should be able to write to autotune persistent cache"); - } - - /// Return the file path for the persistent cache on disk - #[cfg(feature = "autotune-persistent-cache")] - pub fn get_persistent_cache_file_path(&self) -> PathBuf { - get_persistent_cache_file_path(&format!("{}-{}", self.name, self.device_id)) - } -} diff --git a/crates/burn-compute/src/tune/tuner.rs b/crates/burn-compute/src/tune/tuner.rs deleted file mode 100644 index bac874c4e1..0000000000 --- a/crates/burn-compute/src/tune/tuner.rs +++ /dev/null @@ -1,124 +0,0 @@ -#[cfg(target_family = "wasm")] -use web_time::Duration; - -#[cfg(not(target_family = "wasm"))] -use core::time::Duration; - -use alloc::boxed::Box; -use alloc::string::ToString; -use alloc::vec::Vec; -use burn_common::benchmark::{Benchmark, BenchmarkComputations, BenchmarkDurations}; - -use crate::channel::ComputeChannel; -use crate::client::ComputeClient; -use crate::server::ComputeServer; -use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCache}; - -use super::AutotuneKey; - -#[derive(Debug)] -/// Executes autotune benchmarking and caching -pub struct Tuner { - tune_cache: TuneCache, -} - -#[allow(clippy::new_without_default)] -impl Tuner { - /// Returns a tuner with cache initialized from persistent cache - pub fn new(name: &str, device_id: &str) -> Self { - Self { - tune_cache: TuneCache::new(name, device_id), - } - } - - /// Fetch the fastest autotune operation index for an autotune key. - pub fn autotune_fastest(&self, key: &K) -> Option { - self.tune_cache.find_fastest(key) - } - - /// Execute the fastest autotune operation if known, otherwise perform some benchmarks before. 
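
As the `load`/`save` pair above shows, the persistent autotune cache is a pretty-printed JSON list of `(key, checksum, fastest_index)` entries under the user's cache directory. A condensed restatement of the path scheme, assuming only the `dirs` crate that burn-compute already used (`cache_path` is an illustrative name, not part of the deleted API):

    use std::path::PathBuf;

    /// Mirrors get_persistent_cache_file_path: the file prefix is
    /// "{name}-{device_id}", rooted at ~/.cache/burn/autotune.
    fn cache_path(name: &str, device_id: &str) -> PathBuf {
        dirs::home_dir()
            .expect("a home directory should exist")
            .join(".cache")
            .join("burn")
            .join("autotune")
            .join(format!("{name}-{device_id}-autotune-cache.json"))
    }

    // cache_path("dummy", "device-0")
    //   -> ~/.cache/burn/autotune/dummy-device-0-autotune-cache.json
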
- pub fn execute_autotune( - &mut self, - autotune_operation_set: Box>, - client: &ComputeClient, - ) where - S: ComputeServer, - C: ComputeChannel, - { - let operation = match self.tune_cache.try_cache(autotune_operation_set) { - super::TuneCacheResult::Hit(ops) => ops, - super::TuneCacheResult::Miss(set) => self.autotuning(set, client), - }; - - AutotuneOperation::execute(operation); - } - - fn autotuning( - &mut self, - autotune_operation_set: Box>, - client: &ComputeClient, - ) -> Box - where - S: ComputeServer, - C: ComputeChannel, - { - let key = autotune_operation_set.key(); - let autotunables = autotune_operation_set.autotunables(); - let mut names = Vec::with_capacity(autotunables.len()); - - let results: Vec = autotunables - .into_iter() - .map(|op| { - names.push(op.name().to_string()); - self.run_benchmark(op, client) - }) - .collect(); - - // Finds the fastest operation, stores it and returns it - let fastest_index = self.find_fastest(results); - let fastest_name = names.get(fastest_index).unwrap(); - log::info!("Fastest result {fastest_name}-{key}"); - - self.tune_cache.cache_insert(key.clone(), fastest_index); - #[cfg(feature = "autotune-persistent-cache")] - { - let checksum = autotune_operation_set.compute_checksum(); - self.tune_cache - .persistent_cache_insert(key, checksum, fastest_index); - self.tune_cache.save(); - } - - match self.tune_cache.try_cache(autotune_operation_set) { - super::TuneCacheResult::Hit(ops) => ops, - super::TuneCacheResult::Miss(_) => panic!("We just inserted, should not miss"), - } - } - - fn run_benchmark( - &mut self, - operation: Box, - client: &ComputeClient, - ) -> BenchmarkDurations - where - S: ComputeServer, - C: ComputeChannel, - { - TuneBenchmark::new(operation, client.clone()).run() - } - - fn find_fastest(&self, results: Vec) -> usize { - let mut smallest_duration = Duration::MAX; - let mut fastest_tunable = None; - - for (i, result) in results.into_iter().enumerate() { - let computed = BenchmarkComputations::new(&result); - - if computed.median < smallest_duration { - smallest_duration = computed.median; - fastest_tunable = Some(i); - } - } - - fastest_tunable.expect("At least one kernel needed. ") - } -} diff --git a/crates/burn-compute/tests/dummy/compute.rs b/crates/burn-compute/tests/dummy/compute.rs deleted file mode 100644 index 493504d3d9..0000000000 --- a/crates/burn-compute/tests/dummy/compute.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::sync::Arc; - -use super::DummyServer; -use burn_common::stub::RwLock; -use burn_compute::channel::MutexComputeChannel; -use burn_compute::client::ComputeClient; -use burn_compute::memory_management::simple::{ - DeallocStrategy, SimpleMemoryManagement, SliceStrategy, -}; -use burn_compute::storage::BytesStorage; -use burn_compute::tune::Tuner; -use burn_compute::ComputeRuntime; - -/// The dummy device. 
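
`find_fastest` above reduces each candidate's `BenchmarkDurations` to a median via `BenchmarkComputations` and keeps the candidate with the smallest one. Its decision rule, isolated into a self-contained sketch (`fastest_index` is an illustrative name; the medians are assumed to be precomputed):

    use std::time::Duration;

    /// Lowest median duration wins; panics if there are no candidates,
    /// matching the `expect` in the deleted code.
    fn fastest_index(medians: &[Duration]) -> usize {
        medians
            .iter()
            .enumerate()
            .min_by_key(|(_, duration)| **duration)
            .map(|(index, _)| index)
            .expect("At least one kernel needed.")
    }

    // fastest_index(&[Duration::from_micros(90), Duration::from_micros(40)]) == 1
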
-#[derive(Clone, Debug, Hash, PartialEq, Eq)] -pub struct DummyDevice; - -pub type DummyChannel = MutexComputeChannel; -pub type DummyClient = ComputeClient; - -static RUNTIME: ComputeRuntime = ComputeRuntime::new(); -pub static TUNER_DEVICE_ID: &str = "tests/dummy-device"; -pub static TUNER_PREFIX: &str = "dummy-tests/dummy-device"; - -pub fn init_client() -> ComputeClient> { - let storage = BytesStorage::default(); - let memory_management = - SimpleMemoryManagement::new(storage, DeallocStrategy::Never, SliceStrategy::Never); - let server = DummyServer::new(memory_management); - let channel = MutexComputeChannel::new(server); - let tuner = Arc::new(RwLock::new(Tuner::new("dummy", TUNER_DEVICE_ID))); - ComputeClient::new(channel, tuner, Arc::new(())) -} - -pub fn client(device: &DummyDevice) -> DummyClient { - RUNTIME.client(device, init_client) -} diff --git a/crates/burn-compute/tests/dummy/kernel.rs b/crates/burn-compute/tests/dummy/kernel.rs deleted file mode 100644 index 30a67d5538..0000000000 --- a/crates/burn-compute/tests/dummy/kernel.rs +++ /dev/null @@ -1,25 +0,0 @@ -use burn_compute::storage::BytesResource; - -/// The DummyKernel trait should be implemented for every supported operation -pub trait DummyKernel: Sync + Send { - fn compute(&self, resources: &mut [BytesResource]); -} - -/// Contains the algorithm for element-wise addition -pub struct DummyElementwiseAddition; - -impl DummyKernel for DummyElementwiseAddition { - fn compute(&self, inputs: &mut [BytesResource]) { - // Notice how the kernel is responsible for determining which inputs - // are read-only and which are writable. - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); - - let size = lhs.len(); - - for i in 0..size { - out[i] = lhs[i] + rhs[i]; - } - } -} diff --git a/crates/burn-compute/tests/dummy/mod.rs b/crates/burn-compute/tests/dummy/mod.rs deleted file mode 100644 index a560368197..0000000000 --- a/crates/burn-compute/tests/dummy/mod.rs +++ /dev/null @@ -1,9 +0,0 @@ -mod compute; -mod kernel; -mod server; -mod tune; - -pub use compute::*; -pub use kernel::*; -pub use server::*; -pub use tune::*; diff --git a/crates/burn-compute/tests/dummy/server.rs b/crates/burn-compute/tests/dummy/server.rs deleted file mode 100644 index 5e786fe163..0000000000 --- a/crates/burn-compute/tests/dummy/server.rs +++ /dev/null @@ -1,74 +0,0 @@ -use std::sync::Arc; - -use burn_common::{reader::reader_from_concrete, sync_type::SyncType}; -use burn_compute::{ - memory_management::{simple::SimpleMemoryManagement, MemoryHandle, MemoryManagement}, - server::{Binding, ComputeServer, Handle}, - storage::{BytesResource, BytesStorage}, -}; -use derive_new::new; - -use super::DummyKernel; - -/// The dummy server is used to test the burn-compute infrastructure. -/// It uses simple memory management with a bytes storage on CPU, without asynchronous tasks. 
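
The `DummyKernel` trait deleted above is the whole kernel contract for the test server: a kernel receives a flat slice of byte resources and decides on its own which are inputs and which is the output. For illustration, a second kernel in the same style (hypothetical, not part of the deleted code):

    use burn_compute::storage::BytesResource;

    use crate::dummy::DummyKernel;

    /// Element-wise subtraction in the same style as DummyElementwiseAddition.
    pub struct DummyElementwiseSubtraction;

    impl DummyKernel for DummyElementwiseSubtraction {
        fn compute(&self, inputs: &mut [BytesResource]) {
            // By the convention of the dummy kernels: two read-only inputs, one output.
            let lhs = inputs[0].read();
            let rhs = inputs[1].read();
            let out = inputs[2].write();

            for i in 0..lhs.len() {
                // Wrapping arithmetic keeps the u8 toy kernel panic-free.
                out[i] = lhs[i].wrapping_sub(rhs[i]);
            }
        }
    }
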
-#[derive(new, Debug)] -pub struct DummyServer> { - memory_management: MM, -} - -impl ComputeServer for DummyServer -where - MM: MemoryManagement, -{ - type DispatchOptions = (); - type Kernel = Arc; - type Storage = BytesStorage; - type MemoryManagement = MM; - type AutotuneKey = String; - type FeatureSet = (); - - fn read(&mut self, binding: Binding) -> burn_common::reader::Reader { - let bytes = self.memory_management.get(binding.memory); - reader_from_concrete(bytes.read().to_vec()) - } - - fn get_resource(&mut self, binding: Binding) -> BytesResource { - self.memory_management.get(binding.memory) - } - - fn create(&mut self, data: &[u8]) -> Handle { - let handle = self.memory_management.reserve(data.len(), || {}); - let resource = self.memory_management.get(handle.clone().binding()); - - let bytes = resource.write(); - - for (i, val) in data.iter().enumerate() { - bytes[i] = *val; - } - - Handle::new(handle) - } - - fn empty(&mut self, size: usize) -> Handle { - Handle::new(self.memory_management.reserve(size, || {})) - } - - fn execute( - &mut self, - kernel: Self::Kernel, - _count: Self::DispatchOptions, - bindings: Vec>, - ) { - let mut resources = bindings - .into_iter() - .map(|binding| self.memory_management.get(binding.memory)) - .collect::>(); - - kernel.compute(&mut resources); - } - - fn sync(&mut self, _: SyncType) { - // Nothing to do with dummy backend. - } -} diff --git a/crates/burn-compute/tests/dummy/tune/autotune_operations.rs b/crates/burn-compute/tests/dummy/tune/autotune_operations.rs deleted file mode 100644 index 97e1a132f4..0000000000 --- a/crates/burn-compute/tests/dummy/tune/autotune_operations.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::sync::Arc; - -use burn_compute::{client::ComputeClient, server::Binding, tune::AutotuneOperation}; -use derive_new::new; - -use crate::dummy::{DummyChannel, DummyKernel, DummyServer}; - -#[derive(new)] -/// Extended kernel that accounts for additional parameters, i.e. needed -/// information that does not count as an input/output. 
-pub struct OneKernelAutotuneOperation { - kernel: Arc, - client: ComputeClient, - shapes: Vec>, - bindings: Vec>, -} - -impl AutotuneOperation for OneKernelAutotuneOperation { - /// Executes the operation on given bindings and server, with the additional parameters - fn execute(self: Box) { - self.client.execute(self.kernel.clone(), (), self.bindings); - } - - fn clone(&self) -> Box { - Box::new(Self { - kernel: self.kernel.clone(), - client: self.client.clone(), - shapes: self.shapes.clone(), - bindings: self.bindings.clone(), - }) - } -} diff --git a/crates/burn-compute/tests/dummy/tune/kernels.rs b/crates/burn-compute/tests/dummy/tune/kernels.rs deleted file mode 100644 index bd0058310b..0000000000 --- a/crates/burn-compute/tests/dummy/tune/kernels.rs +++ /dev/null @@ -1,106 +0,0 @@ -use std::{thread::sleep, time::Duration}; - -use burn_compute::storage::BytesResource; - -use crate::dummy::DummyKernel; - -const SLEEP_MS: u64 = 1; - -pub struct DummyElementwiseAdditionSlowWrong; -pub struct DummyElementwiseMultiplication; -pub struct DummyElementwiseMultiplicationSlowWrong; -pub struct CacheTestFastOn3; -pub struct CacheTestSlowOn3; -pub struct ParameteredKernel; - -impl DummyKernel for DummyElementwiseAdditionSlowWrong { - fn compute(&self, inputs: &mut [BytesResource]) { - // Slow and wrong on purpose, for tests - let lhs = &inputs[0].read(); - let out = &mut inputs[2].write(); - - let size = lhs.len(); - - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = lhs[i] - } - } -} -impl DummyKernel for DummyElementwiseMultiplication { - fn compute(&self, inputs: &mut [BytesResource]) { - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); - - let size = lhs.len(); - - for i in 0..size { - out[i] = lhs[i] * rhs[i]; - } - } -} -impl DummyKernel for DummyElementwiseMultiplicationSlowWrong { - fn compute(&self, inputs: &mut [BytesResource]) { - // Slow and wrong on purpose, for tests - let lhs = &inputs[0].read(); - let out = &mut inputs[2].write(); - - let size = lhs.len(); - - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = lhs[i]; - } - } -} -impl DummyKernel for CacheTestFastOn3 { - fn compute(&self, inputs: &mut [BytesResource]) { - // This is an artificial kernel designed for testing cache only - let lhs = &inputs[0].read(); - let out = &mut inputs[2].write(); - - let size = lhs.len(); - if size == 3 { - out[..size].copy_from_slice(&lhs[..size]); - } else { - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = lhs[i]; - } - } - } -} - -impl DummyKernel for CacheTestSlowOn3 { - fn compute(&self, inputs: &mut [BytesResource]) { - // This is an artificial kernel designed for testing cache only - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); - - let size = lhs.len(); - if size == 3 { - for i in 0..size { - sleep(Duration::from_millis(SLEEP_MS)); - out[i] = rhs[i]; - } - } else { - out[..size].copy_from_slice(&rhs[..size]); - } - } -} - -impl DummyKernel for ParameteredKernel { - fn compute(&self, inputs: &mut [BytesResource]) { - // This is an artificial kernel designed for info buffer - let lhs = &inputs[0].read(); - let rhs = &inputs[1].read(); - let out = &mut inputs[2].write(); - let info = &inputs[3].read(); - - for i in 0..lhs.len() { - out[i] = lhs[i] + rhs[i] + info[0]; - } - } -} diff --git a/crates/burn-compute/tests/dummy/tune/mod.rs b/crates/burn-compute/tests/dummy/tune/mod.rs deleted file mode 100644 index 
c72d787f5f..0000000000 --- a/crates/burn-compute/tests/dummy/tune/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod autotune_operations; -mod kernels; -mod operation_sets; - -pub use autotune_operations::*; -pub use kernels::*; -#[allow(unused)] -pub use operation_sets::*; diff --git a/crates/burn-compute/tests/dummy/tune/operation_sets.rs b/crates/burn-compute/tests/dummy/tune/operation_sets.rs deleted file mode 100644 index dc707ec310..0000000000 --- a/crates/burn-compute/tests/dummy/tune/operation_sets.rs +++ /dev/null @@ -1,194 +0,0 @@ -#[cfg(feature = "autotune-persistent-cache")] -use rand::{distributions::Alphanumeric, Rng}; -use std::sync::Arc; - -#[cfg(feature = "autotune-persistent-cache")] -use burn_compute::tune::compute_checksum; -use burn_compute::{ - server::Binding, - tune::{AutotuneOperation, AutotuneOperationSet}, -}; - -use crate::dummy::{ - CacheTestFastOn3, CacheTestSlowOn3, DummyClient, DummyElementwiseAddition, - DummyElementwiseMultiplication, DummyElementwiseMultiplicationSlowWrong, DummyServer, - OneKernelAutotuneOperation, -}; - -use super::DummyElementwiseAdditionSlowWrong; - -pub struct AdditionAutotuneOperationSet { - client: DummyClient, - key: String, - shapes: Vec>, - bindings: Vec>, -} - -impl AdditionAutotuneOperationSet { - #[allow(dead_code)] - pub fn new( - client: DummyClient, - shapes: Vec>, - bindings: Vec>, - ) -> Self { - Self { - client, - key: format!("{}-{}", "add", log_shape_input_key(&shapes)), - shapes, - bindings, - } - } -} - -impl AutotuneOperationSet for AdditionAutotuneOperationSet { - fn key(&self) -> String { - self.key.clone() - } - - fn autotunables(&self) -> Vec> { - vec![ - Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseAddition), - self.client.clone(), - self.shapes.clone(), - self.bindings.clone(), - )), - Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseAdditionSlowWrong), - self.client.clone(), - self.shapes.clone(), - self.bindings.clone(), - )), - ] - } - - fn fastest(self: Box, fastest_index: usize) -> Box { - self.autotunables()[fastest_index].clone() - } -} - -pub struct MultiplicationAutotuneOperationSet { - client: DummyClient, - key: String, - shapes: Vec>, - bindings: Vec>, -} - -impl MultiplicationAutotuneOperationSet { - #[allow(dead_code)] - pub fn new( - client: DummyClient, - shapes: Vec>, - bindings: Vec>, - ) -> Self { - Self { - client, - key: format!("{}-{}", "mul", log_shape_input_key(&shapes)), - shapes, - bindings, - } - } -} -impl AutotuneOperationSet for MultiplicationAutotuneOperationSet { - fn key(&self) -> String { - self.key.clone() - } - - fn autotunables(&self) -> Vec> { - vec![ - Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseMultiplicationSlowWrong), - self.client.clone(), - self.shapes.clone(), - self.bindings.clone(), - )), - Box::new(OneKernelAutotuneOperation::new( - Arc::new(DummyElementwiseMultiplication), - self.client.clone(), - self.shapes.clone(), - self.bindings.clone(), - )), - ] - } - - fn fastest(self: Box, fastest_index: usize) -> Box { - self.autotunables()[fastest_index].clone() - } -} - -pub struct CacheTestAutotuneOperationSet { - client: DummyClient, - key: String, - shapes: Vec>, - bindings: Vec>, - pub generate_random_checksum: bool, -} - -impl CacheTestAutotuneOperationSet { - #[allow(dead_code)] - pub fn new( - client: DummyClient, - shapes: Vec>, - bindings: Vec>, - ) -> Self { - Self { - client, - key: format!("{}-{}", "cache_test", log_shape_input_key(&shapes)), - shapes, - bindings, - generate_random_checksum: 
false, - } - } -} - -impl AutotuneOperationSet for CacheTestAutotuneOperationSet { - fn key(&self) -> String { - self.key.clone() - } - - fn autotunables(&self) -> Vec> { - vec![ - Box::new(OneKernelAutotuneOperation::new( - Arc::new(CacheTestFastOn3), - self.client.clone(), - self.shapes.clone(), - self.bindings.clone(), - )), - Box::new(OneKernelAutotuneOperation::new( - Arc::new(CacheTestSlowOn3), - self.client.clone(), - self.shapes.clone(), - self.bindings.clone(), - )), - ] - } - - fn fastest(self: Box, fastest_index: usize) -> Box { - self.autotunables()[fastest_index].clone() - } - - #[cfg(feature = "std")] - fn compute_checksum(&self) -> String { - if self.generate_random_checksum { - let rand_string: String = rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(16) - .map(char::from) - .collect(); - rand_string - } else { - compute_checksum(&self.autotunables()) - } - } -} - -pub fn log_shape_input_key(shapes: &[Vec]) -> String { - let mut hash = String::new(); - let lhs = &shapes[0]; - for size in lhs { - let exp = f32::ceil(f32::log2(*size as f32)) as u32; - hash.push_str(2_u32.pow(exp).to_string().as_str()); - hash.push(','); - } - hash -} diff --git a/crates/burn-compute/tests/integration_test.rs b/crates/burn-compute/tests/integration_test.rs deleted file mode 100644 index db6927b457..0000000000 --- a/crates/burn-compute/tests/integration_test.rs +++ /dev/null @@ -1,292 +0,0 @@ -mod dummy; - -use std::sync::Arc; - -use crate::dummy::{client, DummyDevice, DummyElementwiseAddition}; -use burn_compute::ComputeRuntime; - -#[allow(unused)] -use serial_test::serial; - -#[test] -fn created_resource_is_the_same_when_read() { - let client = client(&DummyDevice); - let resource = Vec::from([0, 1, 2]); - let resource_description = client.create(&resource); - - let obtained_resource = client.read(resource_description.binding()); - - assert_eq!(resource, obtained_resource) -} - -#[test] -fn empty_allocates_memory() { - let client = client(&DummyDevice); - let size = 4; - let resource_description = client.empty(size); - let empty_resource = client.read(resource_description.binding()); - - assert_eq!(empty_resource.len(), 4); -} - -#[test] -fn execute_elementwise_addition() { - let client = client(&DummyDevice); - let lhs = client.create(&[0, 1, 2]); - let rhs = client.create(&[4, 4, 4]); - let out = client.empty(3); - - client.execute( - Arc::new(DummyElementwiseAddition), - (), - vec![lhs.binding(), rhs.binding(), out.clone().binding()], - ); - - let obtained_resource = client.read(out.binding()); - - assert_eq!(obtained_resource, Vec::from([4, 5, 6])) -} - -#[test] -#[serial] -#[cfg(feature = "std")] -fn autotune_basic_addition_execution() { - let client = client(&DummyDevice); - - let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs = client.create(&[0, 1, 2]); - let rhs = client.create(&[4, 4, 4]); - let out = client.empty(3); - let handles = vec![lhs.binding(), rhs.binding(), out.clone().binding()]; - - let addition_autotune_kernel = - dummy::AdditionAutotuneOperationSet::new(client.clone(), shapes, handles); - client.autotune_execute(Box::new(addition_autotune_kernel)); - - let obtained_resource = client.read(out.binding()); - - // If slow kernel was selected it would output [0, 1, 2] - assert_eq!(obtained_resource, Vec::from([4, 5, 6])); -} - -#[test] -#[serial] -#[cfg(feature = "std")] -fn autotune_basic_multiplication_execution() { - let client = client(&DummyDevice); - - let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs = client.create(&[0, 1, 2]); 
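
The cache-key comments in the tests around here ('cache_test-1,4', 'cache_test-1,8') follow from `log_shape_input_key` above, which rounds every dimension of the first shape up to the next power of two before printing it. The rounding rule, isolated (`rounded_key` is an illustrative name):

    /// Round each dimension of the lhs shape up to the next power of two,
    /// so shapes [1, 3] and [1, 4] both produce the key fragment "1,4,".
    fn rounded_key(lhs_shape: &[usize]) -> String {
        let mut hash = String::new();
        for &size in lhs_shape {
            let exp = f32::ceil(f32::log2(size as f32)) as u32;
            hash.push_str(&2_u32.pow(exp).to_string());
            hash.push(',');
        }
        hash
    }

    // rounded_key(&[1, 3]) == "1,4,"   rounded_key(&[1, 5]) == "1,8,"
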
- let rhs = client.create(&[4, 4, 4]); - let out = client.empty(3); - let handles = vec![lhs.binding(), rhs.binding(), out.clone().binding()]; - - let multiplication_autotune_kernel = - dummy::MultiplicationAutotuneOperationSet::new(client.clone(), shapes, handles); - client.autotune_execute(Box::new(multiplication_autotune_kernel)); - - let obtained_resource = client.read(out.binding()); - - // If slow kernel was selected it would output [0, 1, 2] - assert_eq!(obtained_resource, Vec::from([0, 4, 8])); -} - -#[test] -#[serial] -#[cfg(feature = "std")] -fn autotune_cache_same_key_return_a_cache_hit() { - type Runtime = ComputeRuntime; - let runtime = Runtime::new(); - - let client = runtime.client(&DummyDevice, dummy::init_client); - - // note: the key name depends on the shapes of the operation set - // see log_shape_input_key for more info. - - // in this test both shapes [1,3] and [1,4] end up with the same key name - // which is 'cache_test-1,4' - let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs_1 = client.create(&[0, 1, 2]); - let rhs_1 = client.create(&[4, 4, 4]); - let out_1 = client.empty(3); - let handles_1 = vec![lhs_1.binding(), rhs_1.binding(), out_1.binding()]; - - let shapes_2 = vec![vec![1, 4], vec![1, 4], vec![1, 4]]; - let lhs_2 = client.create(&[0, 1, 2, 3]); - let rhs_2 = client.create(&[5, 6, 7, 8]); - let out_2 = client.empty(4); - let handles_2 = vec![lhs_2.binding(), rhs_2.binding(), out_2.clone().binding()]; - - let cache_test_autotune_kernel_1 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); - let cache_test_autotune_kernel_2 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); - client.autotune_execute(Box::new(cache_test_autotune_kernel_1)); - client.autotune_execute(Box::new(cache_test_autotune_kernel_2)); - - let obtained_resource = client.read(out_2.binding()); - - // Cache should be hit, so CacheTestFastOn3 should be used, returning lhs - assert_eq!(obtained_resource, Vec::from([0, 1, 2, 3])); -} - -#[test] -#[serial] -#[cfg(feature = "std")] -fn autotune_cache_no_cache_on_disk_return_a_cache_miss() { - // delete the cache file - let file_path = burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX); - let _ = std::fs::remove_file(file_path); - - type Runtime = ComputeRuntime; - let compute = Runtime::new(); - - let client = compute.client(&DummyDevice, dummy::init_client); - - // in this test shapes [1,3] and [1,5] ends up with different key names - // which are 'cache_test-1,4' and 'cache_test-1,8' - let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs_1 = client.create(&[0, 1, 2]); - let rhs_1 = client.create(&[4, 4, 4]); - let out_1 = client.empty(3); - let handles_1 = vec![lhs_1.binding(), rhs_1.binding(), out_1.binding()]; - - let shapes_2 = vec![vec![1, 5], vec![1, 5], vec![1, 5]]; - let lhs_2 = client.create(&[0, 1, 2, 3, 4]); - let rhs_2 = client.create(&[5, 6, 7, 8, 9]); - let out_2 = client.empty(5); - let handles_2 = vec![lhs_2.binding(), rhs_2.binding(), out_2.clone().binding()]; - - let cache_test_autotune_kernel_1 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); - let cache_test_autotune_kernel_2 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); - client.autotune_execute(Box::new(cache_test_autotune_kernel_1)); - client.autotune_execute(Box::new(cache_test_autotune_kernel_2)); - - // read the resource which should update the cache on disk - let 
obtained_resource = client.read(out_2.binding()); - - // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs - assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9])); -} - -#[test] -#[serial] -#[cfg(feature = "std")] -fn autotune_cache_file_path_creation_works_when_path_does_not_exist_yet() { - // delete the cache file - - use burn_common::sync_type::SyncType; - let file_path = burn_compute::tune::get_persistent_cache_file_path(crate::dummy::TUNER_PREFIX); - let parent_dir = file_path - .parent() - .expect("Cache file should have a parent directory"); - // Delete the cache file's parent directory - let _ = std::fs::remove_dir_all(parent_dir); - - type Runtime = ComputeRuntime; - let runtime = Runtime::new(); - let client = runtime.client(&DummyDevice, dummy::init_client); - - // in this test shapes [1,3] and [1,5] ends up with different key names - // which are 'cache_test-1,4' and 'cache_test-1,8' - let shapes = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs = client.create(&[0, 1, 2]); - let rhs = client.create(&[4, 4, 4]); - let out = client.empty(3); - let handles = vec![lhs.binding(), rhs.binding(), out.clone().binding()]; - - let cache_test_autotune_kernel = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes, handles); - client.autotune_execute(Box::new(cache_test_autotune_kernel)); - // ensure that the autotune operations are run and cached - client.sync(SyncType::Wait); - - assert!( - parent_dir.exists(), - "Parent directory of the cache file should exist" - ); - assert!(file_path.exists(), "Cache file should exist"); -} - -#[test] -#[serial] -#[cfg(feature = "std")] -fn autotune_cache_different_keys_return_a_cache_miss() { - let client = client(&DummyDevice); - - // in this test shapes [1,3] and [1,5] ends up with different key names - // which are 'cache_test-1,4' and 'cache_test-1,8' - let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs_1 = client.create(&[0, 1, 2]); - let rhs_1 = client.create(&[4, 4, 4]); - let out_1 = client.empty(3); - let handles_1 = vec![lhs_1.binding(), rhs_1.binding(), out_1.binding()]; - - let shapes_2 = vec![vec![1, 5], vec![1, 5], vec![1, 5]]; - let lhs_2 = client.create(&[0, 1, 2, 3, 4]); - let rhs_2 = client.create(&[5, 6, 7, 8, 9]); - let out_2 = client.empty(5); - let handles_2 = vec![lhs_2.binding(), rhs_2.binding(), out_2.clone().binding()]; - - let cache_test_autotune_kernel_1 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); - let cache_test_autotune_kernel_2 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); - client.autotune_execute(Box::new(cache_test_autotune_kernel_1)); - client.autotune_execute(Box::new(cache_test_autotune_kernel_2)); - - let obtained_resource = client.read(out_2.binding()); - - // Cache should be missed, so CacheTestSlowOn3 (but faster on 5) should be used, returning rhs - assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8, 9])); -} - -#[test] -#[serial] -#[cfg(feature = "std")] -fn autotune_cache_different_checksums_return_a_cache_miss() { - use burn_common::sync_type::SyncType; - - type Runtime = ComputeRuntime; - let runtime = Runtime::new(); - let client = runtime.client(&DummyDevice, dummy::init_client); - - // in this test both shapes [1,3] and [1,4] end up with the same key name - // which is 'cache_test-1,4' - let shapes_1 = vec![vec![1, 3], vec![1, 3], vec![1, 3]]; - let lhs_1 = client.create(&[0, 1, 2]); - let rhs_1 = client.create(&[4, 4, 4]); - let out_1 = 
client.empty(3); - let handles_1 = vec![lhs_1.binding(), rhs_1.binding(), out_1.binding()]; - let cache_test_autotune_kernel_1 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_1, handles_1); - client.autotune_execute(Box::new(cache_test_autotune_kernel_1)); - client.sync(SyncType::Wait); - - // we use a second compute client in order to have freshly initialized autotune cache - // and test invalidation of the cache when the checksum of the operation set is - // different - let runtime = Runtime::new(); - let client = runtime.client(&DummyDevice, dummy::init_client); - - let shapes_2 = vec![vec![1, 4], vec![1, 4], vec![1, 4]]; - let lhs_2 = client.create(&[0, 1, 2, 3]); - let rhs_2 = client.create(&[5, 6, 7, 8]); - let out_2 = client.empty(4); - let handles_2 = vec![lhs_2.binding(), rhs_2.binding(), out_2.clone().binding()]; - - let mut cache_test_autotune_kernel_2 = - dummy::CacheTestAutotuneOperationSet::new(client.clone(), shapes_2, handles_2); - cache_test_autotune_kernel_2.generate_random_checksum = true; - client.autotune_execute(Box::new(cache_test_autotune_kernel_2)); - client.sync(SyncType::Wait); - - let obtained_resource = client.read(out_2.binding()); - - // Cache should be missed because the checksum on 4 is generated randomly - // and thus is always different, - // so CacheTestSlowOn3 (but faster on 4) should be used, returning rhs - assert_eq!(obtained_resource, Vec::from([5, 6, 7, 8])); -} diff --git a/crates/burn-cube-macros/Cargo.toml b/crates/burn-cube-macros/Cargo.toml deleted file mode 100644 index 11757df219..0000000000 --- a/crates/burn-cube-macros/Cargo.toml +++ /dev/null @@ -1,27 +0,0 @@ -[package] -authors = [ - "nathanielsimard ", - "louisfd VariableTracker { - let analyzer = VariableAnalyzer::default(); - analyzer.analyze(func) - } -} - -impl VariableAnalyzer { - fn analyze(mut self, func: &syn::ItemFn) -> VariableTracker { - // Build the vector of (Id, depth), using recursion - self.signature_declarations(&func.sig); - self.find_occurrences_in_stmts(&func.block.stmts, 0); - - self.variable_tracker - } - - fn signature_declarations(&mut self, sig: &syn::Signature) { - for input in &sig.inputs { - match input { - syn::FnArg::Typed(pat) => { - let ident = &*pat.pat; - let is_comptime = is_ty_comptime(&pat.ty); - - match ident { - syn::Pat::Ident(pat_ident) => { - let id = &pat_ident.ident; - self.variable_tracker - .analyze_declare(id.to_string(), 0, is_comptime); - } - _ => todo!("Analysis: unsupported ident {ident:?}"), - } - } - _ => todo!("Analysis: unsupported input {input:?}"), - } - } - } - - fn find_occurrences_in_stmts(&mut self, stmts: &Vec, depth: u8) { - for stmt in stmts { - match stmt { - // Declaration - syn::Stmt::Local(local) => { - let mut is_comptime = false; - let id = match &local.pat { - syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), - syn::Pat::Type(pat_type) => { - is_comptime = is_ty_comptime(&pat_type.ty); - match &*pat_type.pat { - syn::Pat::Ident(pat_ident) => Some(&pat_ident.ident), - _ => todo!("Analysis: unsupported typed path {:?}", pat_type.pat), - } - } - syn::Pat::Wild(_) => None, - _ => todo!("Analysis: unsupported path {:?}", local.pat), - }; - if let Some(id) = id { - self.variable_tracker - .analyze_declare(id.to_string(), depth, is_comptime); - } - if let Some(local_init) = &local.init { - self.find_occurrences_in_expr(&local_init.expr, depth) - } - } - syn::Stmt::Expr(expr, _) => self.find_occurrences_in_expr(expr, depth), - _ => todo!("Analysis: unsupported stmt {stmt:?}"), - } - } - } - - fn 
find_occurrences_in_expr(&mut self, expr: &syn::Expr, depth: u8) { - match expr { - syn::Expr::ForLoop(expr) => { - self.find_occurrences_in_expr(&expr.expr, depth); - - let depth = depth + 1; - - if let syn::Pat::Ident(pat_ident) = &*expr.pat { - let id = &pat_ident.ident; - self.variable_tracker - .analyze_declare(id.to_string(), depth, false); - } - - self.find_occurrences_in_stmts(&expr.body.stmts, depth); - } - syn::Expr::While(expr) => { - let depth = depth + 1; - - self.find_occurrences_in_expr(&expr.cond, depth); - self.find_occurrences_in_stmts(&expr.body.stmts, depth); - } - syn::Expr::Loop(expr) => { - let depth = depth + 1; - - self.find_occurrences_in_stmts(&expr.body.stmts, depth); - } - syn::Expr::If(expr) => { - let depth = depth + 1; - - self.find_occurrences_in_expr(&expr.cond, depth); - self.find_occurrences_in_stmts(&expr.then_branch.stmts, depth); - if let Some((_, expr)) = &expr.else_branch { - match &**expr { - syn::Expr::Block(expr_block) => { - self.find_occurrences_in_stmts(&expr_block.block.stmts, depth); - } - syn::Expr::If(expr) => { - self.find_occurrences_in_expr(&syn::Expr::If(expr.clone()), depth); - } - _ => unreachable!(), - } - } - } - syn::Expr::Assign(expr) => { - self.find_occurrences_in_expr(&expr.left, depth); - self.find_occurrences_in_expr(&expr.right, depth); - } - syn::Expr::Index(expr) => { - self.find_occurrences_in_expr(&expr.expr, depth); - self.find_occurrences_in_expr(&expr.index, depth); - } - syn::Expr::Path(expr) => { - if let Some(ident) = expr.path.get_ident() { - if !KEYWORDS.contains(&ident.to_string().as_str()) { - self.variable_tracker.analyze_reuse(ident, depth, None); - } - } - } - syn::Expr::Binary(expr) => { - self.find_occurrences_in_expr(&expr.left, depth); - self.find_occurrences_in_expr(&expr.right, depth); - } - syn::Expr::Lit(_) => {} - syn::Expr::Call(expr) => { - match &*expr.func { - syn::Expr::Path(expr_path) => { - if let Some(first_segment) = expr_path.path.segments.first() { - // Check if the path segment has generic arguments - if let PathArguments::AngleBracketed(arguments) = - &first_segment.arguments - { - // Extract the generic arguments - for arg in &arguments.args { - match arg { - syn::GenericArgument::Type(_) - | syn::GenericArgument::Constraint(_) => {} - _ => todo!("Analysis: Generic {:?} not supported", arg), - } - } - } - } - } - _ => todo!("Analysis: unsupported func expr {:?}", expr.func), - } - for arg in expr.args.iter() { - self.find_occurrences_in_expr(arg, depth); - } - } - syn::Expr::MethodCall(expr) => { - self.find_occurrences_in_expr(&expr.receiver, depth); - for arg in expr.args.iter() { - self.find_occurrences_in_expr(arg, depth); - } - } - syn::Expr::Break(_) => {} - syn::Expr::Return(expr) => { - if expr.expr.is_some() { - // Unsupported: handled in codegen. - } - } - syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth), - syn::Expr::Array(_expr) => { - // No analysis since only literals are supported - } - syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth), - syn::Expr::Closure(expr) => { - let depth = depth + 1; - - for path in expr.inputs.iter() { - let mut is_comptime = false; - let ident = match path { - Pat::Ident(pat_ident) => &pat_ident.ident, - Pat::Type(pat_type) => { - is_comptime = is_ty_comptime(&pat_type.ty); - - if let Pat::Ident(pat_ident) = &*pat_type.pat { - &pat_ident.ident - } else { - todo!("Analysis: {:?} not supported in closure inputs. ", path); - } - } - _ => todo!("Analysis: {:?} not supported in closure inputs. 
", path), - }; - - self.variable_tracker - .analyze_declare(ident.to_string(), depth, is_comptime); - } - - self.find_occurrences_in_expr(&expr.body, depth) - } - syn::Expr::Unary(expr) => self.find_occurrences_in_expr(&expr.expr, depth), - syn::Expr::Field(expr) => { - if let Member::Named(attribute_ident) = &expr.member { - if let syn::Expr::Path(struct_expr) = &*expr.base { - let struct_ident = struct_expr - .path - .get_ident() - .expect("Analysis: field access only supported on ident struct."); - - self.variable_tracker.analyze_reuse( - struct_ident, - depth, - Some(attribute_ident.to_string()), - ); - } else { - todo!("Analysis: field access only supported on ident struct."); - } - } else { - todo!("Analysis: unnamed attribute not supported."); - } - } - syn::Expr::Struct(expr) => { - for field in expr.fields.iter() { - self.find_occurrences_in_expr(&field.expr, depth) - } - } - syn::Expr::Range(_range) => { - // Error is handled during codegen. - } - _ => { - // Error is handled during codegen. - } - } - } -} - -fn is_ty_comptime(ty: &syn::Type) -> bool { - if let syn::Type::Path(path) = ty { - for segment in path.path.segments.iter() { - if segment.ident == "Comptime" { - return true; - } - } - } - - false -} diff --git a/crates/burn-cube-macros/src/codegen_common/mod.rs b/crates/burn-cube-macros/src/codegen_common/mod.rs deleted file mode 100644 index ed3f3a2df2..0000000000 --- a/crates/burn-cube-macros/src/codegen_common/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub(crate) mod signature; diff --git a/crates/burn-cube-macros/src/codegen_common/signature.rs b/crates/burn-cube-macros/src/codegen_common/signature.rs deleted file mode 100644 index 9f9654f471..0000000000 --- a/crates/burn-cube-macros/src/codegen_common/signature.rs +++ /dev/null @@ -1,70 +0,0 @@ -use quote::ToTokens; - -use crate::tracker::VariableTracker; - -#[derive(Copy, Clone, Debug)] -pub enum ExpandMode { - FuncImpl, - MethodImpl, -} - -pub fn expand_sig( - sig: &syn::Signature, - visibility: &syn::Visibility, - mut variable_tracker: Option<&mut VariableTracker>, - mode: ExpandMode, -) -> proc_macro2::TokenStream { - let mut inputs = quote::quote!(); - - for input in &sig.inputs { - match input { - syn::FnArg::Typed(pat) => { - let ident = pat.pat.clone(); - - if let syn::Pat::Ident(ident) = ident.as_ref() { - if let Some(vars) = &mut variable_tracker { - vars.codegen_declare(ident.ident.to_string(), 0); - } - } - - let ty = no_ref(pat.ty.as_ref()); - inputs.extend(quote::quote! { - #ident: <#ty as burn_cube::frontend::CubeType>::ExpandType, - }); - } - _ => todo!("Only Typed inputs are supported"), - } - } - - let mut output = quote::quote!(); - - match &sig.output { - syn::ReturnType::Default => output.extend(quote::quote! { ()}), - syn::ReturnType::Type(_, ty) => { - let ty = no_ref(ty.as_ref()); - output.extend(quote::quote! { - <#ty as burn_cube::frontend::CubeType>::ExpandType - }); - } - } - - let ident = &sig.ident; - let ident = match mode { - ExpandMode::FuncImpl => syn::Ident::new("__expand".to_string().as_str(), ident.span()), - _ => syn::Ident::new(format!("__expand_{ident}").as_str(), ident.span()), - }; - - let generics = sig.generics.clone().into_token_stream(); - - quote::quote! 
{ - /// Expanded Cube function - #visibility fn #ident #generics (context: &mut burn_cube::frontend::CubeContext, #inputs) -> #output - } -} - -pub fn no_ref(ty: &syn::Type) -> &syn::Type { - match ty { - syn::Type::Reference(val) => &val.elem, - _ => ty, - } -} diff --git a/crates/burn-cube-macros/src/codegen_function/base.rs b/crates/burn-cube-macros/src/codegen_function/base.rs deleted file mode 100644 index e0e25c8c1d..0000000000 --- a/crates/burn-cube-macros/src/codegen_function/base.rs +++ /dev/null @@ -1,94 +0,0 @@ -use proc_macro2::TokenStream; -use quote::ToTokens; - -use super::{expr::codegen_expr, variable::codegen_local}; -use crate::tracker::VariableTracker; - -/// Codegen for a statement (generally one line) -/// Entry point of code generation -pub fn codegen_statement( - statement: &syn::Stmt, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - match statement { - syn::Stmt::Local(local) => codegen_local(local, loop_level, variable_tracker), - syn::Stmt::Expr(expr, semi) => { - let expr = codegen_expr(expr, loop_level, variable_tracker).tokens; - - match semi { - Some(_semi) => quote::quote!( - #expr; - ), - None => expr, - } - } - _ => todo!("Codegen: statement {statement:?} not supported"), - } -} - -/// Codegen for a code block (a list of statements) -pub(crate) fn codegen_block( - block: &syn::Block, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let mut statements = quote::quote!(); - - for statement in block.stmts.iter() { - statements.extend(codegen_statement(statement, loop_level, variable_tracker)); - } - - quote::quote! { - { - #statements - } - } -} - -pub(crate) struct Codegen { - pub tokens: proc_macro2::TokenStream, - pub is_comptime: bool, - pub array_indexing: Option, -} - -pub(crate) struct ArrayIndexing { - pub array: proc_macro2::TokenStream, - pub index: proc_macro2::TokenStream, -} - -impl From for Codegen { - fn from(tokens: proc_macro2::TokenStream) -> Self { - Self { - tokens, - is_comptime: false, - array_indexing: None, - } - } -} - -impl Codegen { - pub fn new>(tokens: S, is_comptime: bool) -> Self { - Self { - tokens: tokens.into(), - is_comptime, - array_indexing: None, - } - } - - pub fn split(self) -> (proc_macro2::TokenStream, bool) { - (self.tokens, self.is_comptime) - } -} - -impl ToTokens for Codegen { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend(self.tokens.clone()); - } - fn into_token_stream(self) -> TokenStream - where - Self: Sized, - { - self.tokens - } -} diff --git a/crates/burn-cube-macros/src/codegen_function/branch.rs b/crates/burn-cube-macros/src/codegen_function/branch.rs deleted file mode 100644 index e8d132d3a0..0000000000 --- a/crates/burn-cube-macros/src/codegen_function/branch.rs +++ /dev/null @@ -1,193 +0,0 @@ -use proc_macro2::TokenStream; - -use crate::{codegen_function::expr::codegen_expr, tracker::VariableTracker}; - -use super::{ - base::{codegen_block, Codegen}, - function::codegen_call, - operation::codegen_binary, - variable::{codegen_lit, codegen_path_var}, -}; - -/// Codegen of for loops -/// Supports range: -/// for i in range(start, end, unroll) {...} -pub(crate) fn codegen_for_loop( - for_loop: &syn::ExprForLoop, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let i = &for_loop.pat; - - if let syn::Pat::Ident(pat_ident) = &*for_loop.pat { - let id = &pat_ident.ident; - variable_tracker.codegen_declare(id.to_string(), loop_level as u8 + 1); - } - - let invalid_for_loop = || { - 
syn::Error::new_spanned( - &for_loop.expr, - "Invalid for loop: use [range](cubecl::prelude::range] instead.", - ) - .into_compile_error() - }; - - match for_loop.expr.as_ref() { - syn::Expr::Call(call) => { - let func_name = match call.func.as_ref() { - syn::Expr::Path(path) => match path.path.get_ident() { - Some(ident) => ident, - None => return invalid_for_loop(), - }, - _ => { - return invalid_for_loop(); - } - }; - - if &func_name.to_string() == "range" { - let mut args = call.args.clone(); - - let unroll = codegen_expr( - &args.pop().unwrap().into_value(), - loop_level, - variable_tracker, - ); - let end = codegen_expr( - &args.pop().unwrap().into_value(), - loop_level, - variable_tracker, - ); - let start = codegen_expr( - &args.pop().unwrap().into_value(), - loop_level, - variable_tracker, - ); - - let block = codegen_block(&for_loop.body, loop_level + 1, variable_tracker); - - quote::quote! { - { - let _start = #start; - let _end = #end; - let _unroll = #unroll; - burn_cube::frontend::branch::range_expand(context, _start, _end, _unroll, |context, #i| #block); - } - } - } else { - invalid_for_loop() - } - } - _ => invalid_for_loop(), - } -} - -/// Codegen for condition of an if or a while -pub(crate) fn codegen_cond( - cond: &syn::Expr, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - match cond { - syn::Expr::Binary(expr) => codegen_binary(expr, loop_level, variable_tracker), - syn::Expr::Lit(expr) => Codegen::new(codegen_lit(expr), false), - syn::Expr::Path(expr) => codegen_path_var(expr, loop_level, variable_tracker), - syn::Expr::Call(expr) => codegen_call(expr, loop_level, variable_tracker), - _ => todo!("{cond:?} cond not supported"), - } -} - -/// Codegen for break statement -pub(crate) fn codegen_break() -> TokenStream { - quote::quote! { - burn_cube::frontend::branch::break_expand(context); - } -} - -/// Codegen for return statement -pub(crate) fn codegen_return(expr_return: &syn::ExprReturn) -> TokenStream { - if expr_return.expr.is_some() { - return syn::Error::new_spanned(expr_return, "Only void return is supported.") - .into_compile_error(); - } - - quote::quote! { - burn_cube::frontend::branch::return_expand(context); - } -} - -/// Codegen for if and if/else statements -/// Supports: -/// if cond {...} -/// if cond {...} else {...} -/// if Comptime::get(...) {...} [else {...}] -/// if Comptime::get(...) {...} [else if Comptime::get(...) {...}]* [else {...}] -pub(crate) fn codegen_if( - expr_if: &syn::ExprIf, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let (cond, is_comptime) = codegen_cond(&expr_if.cond, loop_level, variable_tracker).split(); - let comptime_bool = if is_comptime { - quote::quote! { Some(#cond) } - } else { - quote::quote! { None } - }; - - let then_block = codegen_block(&expr_if.then_branch, loop_level + 1, variable_tracker); - - if let Some((_, expr)) = &expr_if.else_branch { - let else_block = match &**expr { - syn::Expr::Block(expr_block) => { - codegen_block(&expr_block.block, loop_level + 1, variable_tracker) - } - - syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level + 1, variable_tracker), - _ => unreachable!(), - }; - quote::quote! { - { - let _cond = #cond; - burn_cube::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block); - } - } - } else { - quote::quote! 
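The reversed `pop()` order above is deliberate: `Punctuated::pop` removes the *last* element, so `range(start, end, unroll)` yields `unroll`, then `end`, then `start`. A quick standalone check (requires `syn` with the `full` feature, plus `quote`):

```rust
use quote::quote;
use syn::parse_quote;

fn main() {
    let expr: syn::Expr = parse_quote!(range(start, end, unroll));
    let syn::Expr::Call(call) = expr else { unreachable!() };

    let mut args = call.args;
    let unroll = args.pop().unwrap().into_value();
    let end = args.pop().unwrap().into_value();
    let start = args.pop().unwrap().into_value();

    assert_eq!(quote!(#start #end #unroll).to_string(), "start end unroll");
}
```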
{ - let _cond = #cond; - burn_cube::frontend::branch::if_expand(context, #comptime_bool, _cond.into(), |context| #then_block); - } - } -} - -/// Codegen of loop -pub(crate) fn codegen_loop( - loop_expr: &syn::ExprLoop, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let block = codegen_block(&loop_expr.body, loop_level + 1, variable_tracker); - - quote::quote! { - burn_cube::frontend::branch::loop_expand(context, |context| #block); - } -} - -/// Codegen for while loop -pub(crate) fn codegen_while_loop( - while_loop: &syn::ExprWhile, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let (cond, is_comptime) = - codegen_cond(&while_loop.cond, loop_level + 1, variable_tracker).split(); - - if is_comptime { - return syn::Error::new_spanned(while_loop.while_token, "Comptime not supported for while") - .into_compile_error(); - } - - let block = codegen_block(&while_loop.body, loop_level + 1, variable_tracker); - - quote::quote! { - burn_cube::frontend::branch::while_loop_expand(context, |context| #cond, |context| #block); - } -} diff --git a/crates/burn-cube-macros/src/codegen_function/expr.rs b/crates/burn-cube-macros/src/codegen_function/expr.rs deleted file mode 100644 index 6c89b610b6..0000000000 --- a/crates/burn-cube-macros/src/codegen_function/expr.rs +++ /dev/null @@ -1,99 +0,0 @@ -use crate::tracker::VariableTracker; -use proc_macro2::TokenStream; - -use super::{ - base::{codegen_block, Codegen}, - branch::{ - codegen_break, codegen_for_loop, codegen_if, codegen_loop, codegen_return, - codegen_while_loop, - }, - function::{codegen_call, codegen_closure, codegen_expr_method_call}, - operation::{codegen_binary, codegen_unary}, - variable::{ - codegen_array_lit, codegen_assign, codegen_field, codegen_index, codegen_lit, - codegen_path_var, codegen_struct, - }, -}; - -/// Codegen for expressions -pub(crate) fn codegen_expr( - expr: &syn::Expr, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - match expr { - syn::Expr::Call(call) => codegen_call(call, loop_level, variable_tracker), - syn::Expr::Paren(paren) => codegen_expr(&paren.expr, loop_level, variable_tracker), - _ => { - let mut array_indexing = None; - let tokens = match expr { - syn::Expr::Path(path) => { - return codegen_path_var(path, loop_level, variable_tracker) - } - syn::Expr::Binary(op) => return codegen_binary(op, loop_level, variable_tracker), - syn::Expr::Unary(op) => return codegen_unary(op, loop_level, variable_tracker), - syn::Expr::Lit(lit) => codegen_lit(lit), - syn::Expr::Closure(closure) => { - codegen_closure(closure, loop_level, variable_tracker) - } - syn::Expr::Block(block) => codegen_expr_block(block, loop_level, variable_tracker), - syn::Expr::Assign(assign) => codegen_assign(assign, loop_level, variable_tracker), - syn::Expr::ForLoop(for_loop) => { - codegen_for_loop(for_loop, loop_level, variable_tracker) - } - syn::Expr::While(while_loop) => { - codegen_while_loop(while_loop, loop_level, variable_tracker) - } - syn::Expr::Loop(loop_expr) => codegen_loop(loop_expr, loop_level, variable_tracker), - syn::Expr::Break(_) => codegen_break(), - syn::Expr::Return(return_expr) => codegen_return(return_expr), - syn::Expr::If(expr_if) => codegen_if(expr_if, loop_level, variable_tracker), - syn::Expr::MethodCall(call) => { - codegen_expr_method_call(call, loop_level, variable_tracker) - } - syn::Expr::Index(index) => { - let codegen = codegen_index(index, loop_level, variable_tracker); - array_indexing = 
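Note that none of these branch helpers emit native control flow: conditions and bodies are wrapped in closures over the context, and unsupported shapes are rejected with spanned compile errors, as `codegen_while_loop` does for comptime conditions. That error-emission pattern in isolation (a sketch; requires `syn` with the `full` feature):

```rust
fn main() {
    let expr: syn::Expr = syn::parse_quote!(while cond {});
    let syn::Expr::While(while_loop) = expr else { unreachable!() };

    // Same pattern as the deleted code: a spanned error that becomes a
    // `compile_error!` invocation in the macro output.
    let err = syn::Error::new_spanned(
        while_loop.while_token,
        "Comptime not supported for while",
    )
    .into_compile_error();

    println!("{err}"); // compile_error ! { "Comptime not supported for while" }
}
```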
codegen.array_indexing; - codegen.tokens - } - syn::Expr::Array(array) => codegen_array_lit(array), - syn::Expr::Reference(reference) => { - codegen_ref(reference, loop_level, variable_tracker) - } - syn::Expr::Field(field) => codegen_field(field, loop_level, variable_tracker), - syn::Expr::Struct(struct_) => codegen_struct(struct_, loop_level, variable_tracker), - syn::Expr::Range(range) => syn::Error::new_spanned( - range, - "Range is not supported, use [range](cubecl::prelude::range) instead.", - ) - .to_compile_error(), - _ => { - syn::Error::new_spanned(expr, "Expression is not supported").to_compile_error() - } - }; - - let mut codegen = Codegen::new(tokens, false); - codegen.array_indexing = array_indexing; - codegen - } - } -} - -/// Codegen for an expression containing a block -pub(crate) fn codegen_expr_block( - block: &syn::ExprBlock, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - codegen_block(&block.block, loop_level, variable_tracker) -} - -pub(crate) fn codegen_ref( - reference: &syn::ExprReference, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - // We ignore reference for the expansion. - let inner = codegen_expr(&reference.expr, loop_level, variable_tracker); - quote::quote! { #inner } -} diff --git a/crates/burn-cube-macros/src/codegen_function/function.rs b/crates/burn-cube-macros/src/codegen_function/function.rs deleted file mode 100644 index 436ab4a6c1..0000000000 --- a/crates/burn-cube-macros/src/codegen_function/function.rs +++ /dev/null @@ -1,250 +0,0 @@ -use proc_macro2::{Span, TokenStream}; -use quote::quote_spanned; -use syn::{ - punctuated::Punctuated, spanned::Spanned, AngleBracketedGenericArguments, Expr, Ident, - PathArguments, Token, -}; - -use crate::{codegen_function::expr::codegen_expr, tracker::VariableTracker}; - -use super::base::Codegen; - -/// Codegen for method call -/// Supports [expr].method(args) -pub(crate) fn codegen_expr_method_call( - call: &syn::ExprMethodCall, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let receiver = codegen_expr(&call.receiver, loop_level, variable_tracker); - let method_expand = syn::Ident::new( - format!("{}_expand", call.method).as_str(), - proc_macro2::Span::call_site(), - ); - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - quote::quote!( { - #expansion - #receiver . #method_expand ( #variables ) - }) -} - -/// Codegen for a closure -pub(crate) fn codegen_closure( - closure: &syn::ExprClosure, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let mut inputs = quote::quote! {}; - for input in closure.inputs.iter() { - let (ident, ty) = match input { - syn::Pat::Ident(ident) => (&ident.ident, None), - syn::Pat::Type(pat_type) => ( - if let syn::Pat::Ident(ident) = &*pat_type.pat { - &ident.ident - } else { - return syn::Error::new_spanned(pat_type, "Unsupported input") - .into_compile_error(); - }, - Some(pat_type.ty.clone()), - ), - _ => return syn::Error::new_spanned(input, "Unsupported input").into_compile_error(), - }; - - if let Some(ty) = ty { - inputs.extend(quote::quote! { - #ident : #ty, - }); - } else { - inputs.extend(quote::quote! { - #ident, - }); - } - } - - let body = codegen_expr(closure.body.as_ref(), loop_level, variable_tracker); - - quote::quote! 
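The method-call rule above is purely name-based: `expr.method(args)` becomes `expr.method_expand(context, args)`, with `context` injected as the first argument by `codegen_args`. A minimal reproduction of the token output (`clamp` and the `_var_*` names are illustrative):

```rust
use proc_macro2::Span;
use quote::quote;

fn main() {
    let method = syn::Ident::new("clamp", Span::call_site());
    let method_expand =
        syn::Ident::new(&format!("{method}_expand"), Span::call_site());
    let receiver = quote!(tensor);
    let args = quote!(_var_0, _var_1);

    let out = quote!( #receiver . #method_expand (context, #args) );
    println!("{out}"); // tensor . clamp_expand (context , _var_0 , _var_1)
}
```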
{ - |context: &mut CubeContext, #inputs| #body - } -} - -/// Maps -/// [A[::<...>]?::]^* func[::<...>] (args) -/// to -/// [A[::<...>]?::]^* func_expand[::<...>] (context, args) -/// -/// Also returns a bool that is true if it's comptime -pub(crate) fn codegen_call( - call: &syn::ExprCall, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - // We start with parsing the function path - let path: Vec<(&Ident, Option<&AngleBracketedGenericArguments>)> = match call.func.as_ref() { - syn::Expr::Path(expr_path) => { - let mut path = Vec::new(); - for segment in expr_path.path.segments.iter() { - let generics = if let PathArguments::AngleBracketed(arguments) = &segment.arguments - { - Some(arguments) - } else { - None - }; - path.push((&segment.ident, generics)); - } - path - } - _ => { - return Codegen::new( - syn::Error::new_spanned(&call.func, "Unsupported").into_compile_error(), - false, - ) - } - }; - - // Path - let mut path_tokens = TokenStream::new(); - let mut is_comptime = false; - let mut is_plain_func = true; - let mut comptime_func: Option = None; - - for (i, (ident, generics)) in path.iter().enumerate() { - let name = ident.to_string(); - - if name == "Comptime" { - is_comptime = true; - continue; - } - - if let Some(first_char) = name.chars().next() { - if first_char.is_uppercase() { - is_plain_func = false; - } - } - - if i == path.len() - 1 { - if is_comptime { - comptime_func = Some(ident.to_string()); - break; - } - - let func_name_expand = if is_plain_func { - quote::quote! { - #ident::__expand - } - } else { - let ident = syn::Ident::new( - format!("__expand_{ident}").as_str(), - proc_macro2::Span::call_site(), - ); - quote::quote! { - #ident - } - }; - path_tokens.extend(quote_spanned! {func_name_expand.span() => #func_name_expand }); - if let Some(generics) = generics { - path_tokens.extend(quote_spanned! {generics.span() => #generics }); - } - } else if let Some(generics) = generics { - path_tokens.extend(quote_spanned! {ident.span() => #ident }); - path_tokens.extend(quote_spanned! {generics.span() => #generics :: }); - } else { - path_tokens.extend(quote_spanned! {ident.span() => #ident :: }); - } - } - - // Arguments - if let Some(func_name) = comptime_func { - let tokens = match func_name.as_str() { - "get" | "new" => { - let code = call.args.first().unwrap(); - quote::quote! {#code} - } - "map" => { - let args = &call.args; - - // Codegen - quote::quote! { - { - Comptime::map_expand(#args) - } - } - } - "unwrap_or_else" => { - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - // Codegen - quote::quote! {{ - #expansion - Comptime::unwrap_or_else_expand(#variables) - }} - } - "is_some" => { - let code = call.args.first().unwrap(); - quote::quote! { #code.is_some() } - } - "vectorization" => { - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - // Codegen - quote::quote! {{ - #expansion - Comptime::vectorization_expand(#variables) - }} - } - "vectorize" => { - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - // Codegen - quote::quote! {{ - #expansion - Comptime::vectorize_expand(#variables) - }} - } - "runtime" => { - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - // Codegen - quote::quote! 
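The free-versus-associated distinction in `codegen_call` hinges on the case of the path segments: a lower-case `norm(x)` expands through the module the macro generates, while `Tensor::rank(x)` stays on the type and renames the function itself. The two output shapes, with hypothetical names:

```rust
use proc_macro2::Span;
use quote::quote;

fn main() {
    // Free function: `norm(x)` -> `norm::__expand(context, x)`.
    let f = syn::Ident::new("norm", Span::call_site());
    println!("{}", quote!(#f::__expand(context, x)));

    // Associated function: `Tensor::rank(x)` -> `Tensor::__expand_rank(context, x)`.
    let ty = syn::Ident::new("Tensor", Span::call_site());
    let renamed = syn::Ident::new("__expand_rank", Span::call_site());
    println!("{}", quote!(#ty::#renamed(context, x)));
}
```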
{{ - #expansion - Comptime::runtime_expand(#variables) - }} - } - - _ => panic!("Codegen: Comptime function {:?} does not exist", func_name), - }; - - Codegen::new(tokens, true) - } else { - let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker); - - // Codegen - let tokens = quote::quote! {{ - #expansion - #path_tokens (#variables) - }}; - - Codegen::new(tokens, false) - } -} - -fn codegen_args( - args: &Punctuated, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> (TokenStream, TokenStream) { - let mut expansion = quote::quote! {}; - let mut variables = quote::quote! {}; - - variables.extend(quote::quote! { context, }); - - for (i, argument) in args.iter().enumerate() { - let ident = Ident::new(format!("_var_{i}").as_str(), Span::call_site()); - let arg_token = codegen_expr(argument, loop_level, variable_tracker); - expansion.extend(quote::quote! { let #ident = #arg_token; }); - variables.extend(quote::quote! { #ident, }); - } - - (expansion, variables) -} diff --git a/crates/burn-cube-macros/src/codegen_function/launch.rs b/crates/burn-cube-macros/src/codegen_function/launch.rs deleted file mode 100644 index 7f035a55be..0000000000 --- a/crates/burn-cube-macros/src/codegen_function/launch.rs +++ /dev/null @@ -1,521 +0,0 @@ -use proc_macro2::{Span, TokenStream}; -use syn::{parse_quote, Generics, Ident}; - -#[derive(Default)] -struct Codegen { - // Basic attributes. - name: String, - generics: Generics, - fn_inputs: TokenStream, - fn_output: TokenStream, - // States to generate code. - state_comptimes: Vec<(syn::Type, Ident)>, - state_args: Vec, - state_inputs: Vec<(Ident, syn::Type)>, - state_outputs: Vec<(Ident, syn::Type)>, -} - -impl Codegen { - fn from_sig(sig: &syn::Signature) -> Self { - let mut codegen = Codegen::default(); - - let mut first_letter = sig.ident.to_string(); - let second_part = first_letter.split_off(1); - - codegen.name = format!("{}{}", first_letter.to_uppercase(), second_part); - codegen.generics = sig.generics.clone(); - - let mut inputs = quote::quote!(); - - for input in &sig.inputs { - let mut is_output = false; - let mut comptime = false; - - match input { - syn::FnArg::Typed(pat) => { - let (ty, ident) = match pat.pat.as_ref() { - syn::Pat::Ident(ident) => { - if ident.mutability.is_some() { - is_output = true; - } - - if let syn::Type::Reference(ty) = pat.ty.as_ref() { - if ty.mutability.is_some() { - is_output = true; - } - }; - - if let syn::Type::Path(pat) = pat.ty.as_ref() { - if let Some(name) = pat.path.segments.first() { - let name = name.ident.to_string(); - - if name == "Comptime" { - comptime = true; - } - } - }; - - (pat.ty.clone(), ident.ident.clone()) - } - _ => panic!("Nop"), - }; - - if comptime { - codegen.state_args.push(quote::quote! { - self.#ident - }); - } else { - codegen.state_args.push(quote::quote! { - #ident - }); - } - - if comptime { - let ty = no_ref(&ty); - inputs.extend(quote::quote! { - #ident: <#ty as burn_cube::frontend::CubeType>::ExpandType, - }); - } else { - let ty = no_ref(&ty); - inputs.extend(quote::quote! 
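`codegen_args` is the evaluation-order glue for every call form above: each argument is bound to a `_var_i` temporary before the call, so side effects run once and in source order. A compilable miniature of the same loop:

```rust
use proc_macro2::{Span, TokenStream};
use quote::quote;

fn codegen_args(args: &[TokenStream]) -> (TokenStream, TokenStream) {
    let mut expansion = quote!();
    let mut variables = quote!(context,); // context always rides first
    for (i, arg) in args.iter().enumerate() {
        let ident = syn::Ident::new(&format!("_var_{i}"), Span::call_site());
        expansion.extend(quote! { let #ident = #arg; });
        variables.extend(quote! { #ident, });
    }
    (expansion, variables)
}

fn main() {
    let (expansion, variables) = codegen_args(&[quote!(a + b), quote!(c)]);
    println!("{expansion}"); // let _var_0 = a + b ; let _var_1 = c ;
    println!("{variables}"); // context , _var_0 , _var_1 ,
}
```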
{ - #ident: RuntimeArg<'a, #ty, R>, - }); - } - - if is_output { - codegen - .state_outputs - .push((ident.clone(), no_ref(&ty).clone())); - } else if comptime { - codegen - .state_comptimes - .push((first_generic_ty(&ty).clone(), ident.clone())); - } else { - codegen - .state_inputs - .push((ident.clone(), no_ref(&ty).clone())); - } - } - _ => panic!("Only Typed inputs are supported"), - }; - } - - let mut output = quote::quote!(); - - match &sig.output { - syn::ReturnType::Default => output.extend(quote::quote! {()}), - syn::ReturnType::Type(_, ty) => { - output.extend(quote::quote! { - <#ty as burn_cube::frontend::CubeType>::ExpandType - }); - } - } - - codegen.fn_inputs = inputs; - codegen.fn_output = output; - - codegen - } - - fn gen_kernel_struct(&self) -> TokenStream { - let ident = Ident::new(&self.name, Span::call_site()); - let generics = add_runtime(self.generics.clone()); - let phantoms = self.phantoms(&generics, true); - let mut comptimes = quote::quote! {}; - - for (ty, ident) in self.state_comptimes.iter() { - comptimes.extend(quote::quote! { - #ident: #ty, - }); - } - - quote::quote! { - /// Kernel - pub struct #ident #generics { - settings: KernelSettings, - #comptimes - #phantoms - } - } - } - - fn gen_settings(&self) -> TokenStream { - let mut variables = quote::quote! {}; - - for (pos, (ident, _ty)) in self.state_inputs.iter().enumerate() { - variables.extend(quote::quote! { - settings = ArgSettings::::configure_input(&#ident, #pos, settings); - }); - } - - for (pos, (ident, _ty)) in self.state_outputs.iter().enumerate() { - variables.extend(quote::quote! { - settings = ArgSettings::::configure_output(&#ident, #pos, settings); - }); - } - - quote::quote! { - let mut settings = KernelSettings::default(); - settings = settings.cube_dim(cube_dim); - #variables - } - } - - fn gen_register_input(&self) -> TokenStream { - let generics = &self.generics; - let mut variables = quote::quote! {}; - - for (pos, (_ident, ty)) in self.state_inputs.iter().enumerate() { - variables.extend(quote::quote! { - #pos => std::sync::Arc::new(<#ty as LaunchArgExpand>::expand(builder, settings.vectorization_input(#pos))), - }); - } - - quote::quote! { - #[allow(unused)] - fn register_input #generics( - builder: &mut KernelBuilder, - settings: &KernelSettings, - position: usize, - ) -> std::sync::Arc { - match position { - #variables - _ => panic!("Input {position} is invalid."), - } - } - } - } - - fn gen_register_output(&self) -> TokenStream { - let generics = &self.generics; - let mut variables = quote::quote! {}; - - for (pos, (_ident, ty)) in self.state_outputs.iter().enumerate() { - variables.extend(quote::quote! { - #pos => std::sync::Arc::new(<#ty as LaunchArgExpand>::expand_output(builder, settings.vectorization_output(#pos))), - }); - } - - quote::quote! { - #[allow(unused)] - fn register_output #generics ( - builder: &mut KernelBuilder, - settings: &KernelSettings, - position: usize, - ) -> std::sync::Arc { - match position { - #variables - _ => panic!("Input {position} is invalid."), - } - } - } - } - - fn gen_define_impl(&self, expand: &TokenStream) -> TokenStream { - let mut expand_args = quote::quote! { &mut builder.context, }; - - let mut variables = quote::quote! {}; - - for (pos, (ident, ty)) in self.state_inputs.iter().enumerate() { - variables.extend(quote::quote! { - let #ident: &<#ty as CubeType>::ExpandType = inputs - .get(&#pos) - .unwrap() - .downcast_ref() - .expect("Input type should be correct. 
It could be caused by an invalid kernel input/output alias."); - }); - } - - for (pos, (ident, ty)) in self.state_outputs.iter().enumerate() { - variables.extend(quote::quote! { - let #ident: &<#ty as CubeType>::ExpandType = outputs - .get(&#pos) - .unwrap() - .downcast_ref() - .expect("Output type should be correct. It could be caused by an invalid kernel input/output alias."); - }); - } - - for arg in self.state_args.iter() { - expand_args.extend(quote::quote! { - #arg.clone(), - }) - } - - let expand_func = match self.generics.params.is_empty() { - true => quote::quote! { #expand }, - false => { - let generics = self.generics.split_for_impl().1; - quote::quote! { #expand::#generics } - } - }; - - quote::quote! { - #variables - #expand_func(#expand_args); - builder.build(self.settings.clone()) - } - } - - fn gen_define_args(&self) -> TokenStream { - let num_inputs = self.state_inputs.len(); - let num_outputs = self.state_outputs.len(); - - let register_input = self.gen_register_input(); - let register_output = self.gen_register_output(); - - let (register_input_call, register_output_call) = match self.generics.params.is_empty() { - true => ( - quote::quote! { register_input }, - quote::quote! { register_output }, - ), - false => { - let generics = self.generics.split_for_impl().1; - - ( - quote::quote! { register_input::#generics }, - quote::quote! { register_output::#generics }, - ) - } - }; - - let mut variables = quote::quote! {}; - - for (pos, (ident, ty)) in self.state_inputs.iter().enumerate() { - variables.extend(quote::quote! { - let #ident = <&#ty as CubeType>::ExpandType = - *inputs.remove(&#pos).unwrap().downcast().unwrap(); - }); - } - - for (pos, (ident, ty)) in self.state_outputs.iter().enumerate() { - variables.extend(quote::quote! { - let #ident = <&mut #ty as CubeType>::ExpandType = - *outputs.remove(&#pos).unwrap().downcast().unwrap(); - }); - } - - let mut tokens = quote::quote! { - let mut builder = KernelBuilder::default(); - - let mut inputs: std::collections::BTreeMap> = std::collections::BTreeMap::new(); - let mut outputs: std::collections::BTreeMap> = std::collections::BTreeMap::new(); - - for mapping in self.settings.mappings.iter() { - if !inputs.contains_key(&mapping.pos_input) { - inputs.insert( - mapping.pos_input, - #register_input_call(&mut builder, &self.settings, mapping.pos_input), - ); - } - - let input = inputs.get(&mapping.pos_input).unwrap(); - outputs.insert(mapping.pos_output, input.clone()); - } - - #register_input - #register_output - }; - - if num_inputs > 0 { - tokens.extend(quote::quote! { - for i in 0..#num_inputs { - if !inputs.contains_key(&i) { - inputs.insert(i, #register_input_call(&mut builder, &self.settings, i)); - } - } - }); - } - - if num_outputs > 0 { - tokens.extend(quote::quote! { - for i in 0..#num_outputs { - if !outputs.contains_key(&i) { - outputs.insert(i, #register_output_call(&mut builder, &self.settings, i)); - } - } - }); - } - - tokens - } - - fn gen_compile_impl(&self, expand: &TokenStream) -> TokenStream { - let ident = Ident::new(&self.name, Span::call_site()); - let generics = add_runtime(self.generics.clone()); - let (impl_gen, ty_gen, where_gen) = generics.split_for_impl(); - - let mut format_str = "{:?}-{}".to_string(); - for _ in 0..self.state_comptimes.len() { - format_str.push_str("-{:?}"); - } - - let mut format_args = quote::quote! { core::any::TypeId::of::(), self.settings, }; - - for (_, ident) in self.state_comptimes.iter() { - format_args.extend(quote::quote! 
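Two generated artifacts are worth seeing whole. First, the struct emitted by `gen_kernel_struct`: comptime arguments become fields and every type parameter gets a `PhantomData` slot, so the kernel is identifiable purely from compile-time state. A sketch with stand-in types in place of the real `KernelSettings`/`Runtime`:

```rust
use core::marker::PhantomData;

struct KernelSettings; // stand-in

/// Kernel (shape of the generated struct for a hypothetical `gelu<F>`)
#[allow(dead_code)]
pub struct Gelu<F, R> {
    settings: KernelSettings,
    vectorization: u8, // example comptime field
    _f: PhantomData<F>,
    _r: PhantomData<R>,
}

fn main() {
    let _kernel = Gelu::<f32, ()> {
        settings: KernelSettings,
        vectorization: 4,
        _f: PhantomData,
        _r: PhantomData,
    };
}
```

Second, `Kernel::id` builds a cache key from the type id, the settings, and one `-{:?}` per comptime field. The same construction outside the macro, with stand-in values:

```rust
use std::any::TypeId;

fn main() {
    let type_id = TypeId::of::<u32>(); // stands in for the kernel type
    let settings = "cube_dim=256";     // stands in for KernelSettings
    let comptimes = [true, false];     // stands in for comptime fields

    let mut id = format!("{:?}-{}", type_id, settings);
    for value in comptimes {
        id.push_str(&format!("-{:?}", value));
    }
    println!("{id}");
}
```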
{ self.#ident, }); - } - - let define_args = self.gen_define_args(); - let define_impl = self.gen_define_impl(expand); - - quote::quote! { - impl #impl_gen Kernel for #ident #ty_gen #where_gen { - fn define(&self) -> KernelDefinition { - #define_args - #define_impl - } - - fn id(&self) -> String { - format!(#format_str, #format_args) - } - } - } - } - - fn phantoms(&self, generics: &Generics, declaration: bool) -> TokenStream { - let mut phantoms = quote::quote! {}; - - for param in generics.params.iter() { - let ty = match param { - syn::GenericParam::Type(ty) => ty, - _ => continue, - }; - let ident = Ident::new( - format!("_{}", ty.ident.to_string().to_lowercase()).as_str(), - Span::call_site(), - ); - let ty = &ty.ident; - if declaration { - phantoms.extend(quote::quote! { - #ident: core::marker::PhantomData<#ty>, - }); - } else { - phantoms.extend(quote::quote! { - #ident: core::marker::PhantomData::<#ty>, - }); - } - } - phantoms - } - - fn gen_launch_body(&self) -> TokenStream { - let ident = Ident::new(&self.name, Span::call_site()); - let generics = add_runtime(self.generics.clone()); - let phantoms = self.phantoms(&generics, false); - - let mut comptimes = quote::quote! {}; - let settings = self.gen_settings(); - - let mut body = quote::quote! { - let mut launcher = KernelLauncher::::default(); - }; - - for (input, _) in self.state_inputs.iter() { - body.extend(quote::quote! { - #input.register(&mut launcher); - }); - } - - for (input, _) in self.state_outputs.iter() { - body.extend(quote::quote! { - #input.register(&mut launcher); - }); - } - - for (_ty, ident) in self.state_comptimes.iter() { - comptimes.extend(quote::quote! { - #ident, - }); - } - - let kernel = quote::quote! { - #ident { - settings, - #comptimes - #phantoms - } - }; - - quote::quote! { - #settings - - let kernel = #kernel; - - #body - - launcher.launch(cube_count, kernel, client); - } - } -} - -pub fn codegen_launch(sig: &syn::Signature) -> TokenStream { - let codegen = Codegen::from_sig(sig); - - let ident = &sig.ident; - - let ident_expand = quote::quote! { - __expand - }; - - let generics = add_runtime(add_lifetime(sig.generics.clone())); - let body = codegen.gen_launch_body(); - let kernel = codegen.gen_kernel_struct(); - let compile = codegen.gen_compile_impl(&ident_expand); - let (inputs, output) = (codegen.fn_inputs, codegen.fn_output); - let doc = format!("Launch the kernel [{ident}()] on the given runtime."); - - quote::quote! { - #kernel - #compile - - #[allow(clippy::too_many_arguments)] - #[doc = #doc] - pub fn launch #generics ( - client: ComputeClient, - cube_count: CubeCount, - cube_dim: CubeDim, - #inputs - ) -> #output { - #body; - } - } -} - -pub fn add_lifetime(mut generics: Generics) -> Generics { - let lifetime: syn::Generics = parse_quote! {<'a>}; - - generics - .params - .insert(0, lifetime.params.into_iter().next().unwrap()); - generics -} - -pub fn add_runtime(mut generics: Generics) -> Generics { - let runtime: syn::Generics = parse_quote! 
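The `parse_quote!` payload in `add_runtime` was stripped by this diff's rendering; from the surrounding code it parses a single generic parameter, presumably `<R: Runtime>`. Both helpers then splice that parameter into an existing `Generics`:

```rust
use syn::{parse_quote, Generics};

fn add_lifetime(mut generics: Generics) -> Generics {
    let lifetime: Generics = parse_quote!(<'a>);
    generics.params.insert(0, lifetime.params.into_iter().next().unwrap());
    generics
}

fn add_runtime(mut generics: Generics) -> Generics {
    // assumed payload, per the launch signatures elsewhere in the diff
    let runtime: Generics = parse_quote!(<R: Runtime>);
    generics.params.push(runtime.params.into_iter().next().unwrap());
    generics
}

fn main() {
    let generics: Generics = parse_quote!(<F: Float>);
    let generics = add_runtime(add_lifetime(generics));
    println!("{}", quote::quote!(#generics)); // < 'a , F : Float , R : Runtime >
}
```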
{ }; - - generics - .params - .push(runtime.params.into_iter().next().unwrap()); - generics -} - -fn first_generic_ty(ty: &syn::Type) -> syn::Type { - match ty { - syn::Type::Path(pat) => match &pat.path.segments.first().unwrap().arguments { - syn::PathArguments::AngleBracketed(ty) => match ty.args.first().unwrap() { - syn::GenericArgument::Type(ty) => ty.clone(), - _ => panic!("Should have a generic type"), - }, - _ => panic!("Comptime must have a generic"), - }, - _ => todo!(), - } -} - -fn no_ref(ty: &syn::Type) -> &syn::Type { - match ty { - syn::Type::Reference(val) => &val.elem, - _ => ty, - } -} diff --git a/crates/burn-cube-macros/src/codegen_function/mod.rs b/crates/burn-cube-macros/src/codegen_function/mod.rs deleted file mode 100644 index ed9bff87a2..0000000000 --- a/crates/burn-cube-macros/src/codegen_function/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -mod base; -mod branch; -mod expr; -mod function; -mod launch; -mod operation; -mod variable; - -pub(crate) use base::codegen_statement; -pub(crate) use launch::codegen_launch; diff --git a/crates/burn-cube-macros/src/codegen_function/operation.rs b/crates/burn-cube-macros/src/codegen_function/operation.rs deleted file mode 100644 index fa6235c045..0000000000 --- a/crates/burn-cube-macros/src/codegen_function/operation.rs +++ /dev/null @@ -1,270 +0,0 @@ -use crate::tracker::VariableTracker; - -use super::{base::Codegen, expr::codegen_expr}; - -/// Codegen for binary operations (+, -, *, etc.) -pub(crate) fn codegen_binary( - binary: &syn::ExprBinary, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - let lhs = codegen_expr(&binary.left, loop_level, variable_tracker); - let (lhs, is_comptime_lhs, lhs_array) = (lhs.tokens, lhs.is_comptime, lhs.array_indexing); - let (rhs, is_comptime_rhs) = codegen_expr(&binary.right, loop_level, variable_tracker).split(); - - if is_comptime_lhs && is_comptime_rhs { - return Codegen::new( - quote::quote! { - #binary - }, - true, - ); - } - - Codegen::new( - match binary.op { - syn::BinOp::Add(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::add::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Sub(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::sub::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Mul(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::mul::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Div(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::div::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Rem(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::rem::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Ne(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::ne::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Gt(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::gt::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Ge(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::ge::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Lt(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::lt::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Le(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::le::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Eq(_) => quote::quote! 
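Every arm of `codegen_binary` below follows one template: bind both operands to temporaries, then delegate to the matching `expand` function under `burn_cube::frontend`. Generating a single arm in isolation:

```rust
use proc_macro2::{Span, TokenStream};
use quote::quote;

/// Builds one binary-op arm; `op_mod` is e.g. "add", "sub", "mul".
fn binary_arm(op_mod: &str, lhs: &TokenStream, rhs: &TokenStream) -> TokenStream {
    let op = syn::Ident::new(op_mod, Span::call_site());
    quote! {
        {
            let _lhs = #lhs;
            let _rhs = #rhs;
            burn_cube::frontend::#op::expand(context, _lhs, _rhs)
        }
    }
}

fn main() {
    println!("{}", binary_arm("add", &quote!(x), &quote!(y)));
}
```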
{ - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::eq::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::AddAssign(_) => { - if let Some(array) = lhs_array { - let (array, index) = (array.array, array.index); - - quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #rhs; - burn_cube::frontend::add_assign_array_op::expand(context, _array, _index, _value) - } - } - } else { - quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::add_assign_op::expand(context, _lhs, _rhs) - } - } - } - } - syn::BinOp::SubAssign(_) => { - if let Some(array) = lhs_array { - let (array, index) = (array.array, array.index); - - quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #rhs; - burn_cube::frontend::sub_assign_array_op::expand(context, _array, _index, _value) - } - } - } else { - quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::sub_assign_op::expand(context, _lhs, _rhs) - } - } - } - } - syn::BinOp::MulAssign(_) => { - if let Some(array) = lhs_array { - let (array, index) = (array.array, array.index); - - quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #rhs; - burn_cube::frontend::mul_assign_array_op::expand(context, _array, _index, _value) - } - } - } else { - quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::mul_assign_op::expand(context, _lhs, _rhs) - } - } - } - } - syn::BinOp::DivAssign(_) => { - if let Some(array) = lhs_array { - let (array, index) = (array.array, array.index); - - quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #rhs; - burn_cube::frontend::div_assign_array_op::expand(context, _array, _index, _value) - } - } - } else { - quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::div_assign_op::expand(context, _lhs, _rhs) - } - } - } - } - syn::BinOp::And(_) => quote::quote! { - { - - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::and::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Or(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::or::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::BitAnd(_) => quote::quote! { - { - - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::bitand::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::BitXor(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::bitxor::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Shl(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::shl::expand(context, _lhs, _rhs) - } - }, - syn::BinOp::Shr(_) => quote::quote! { - { - let _lhs = #lhs; - let _rhs = #rhs; - burn_cube::frontend::shr::expand(context, _lhs, _rhs) - } - }, - _ => todo!("Codegen: unsupported op {:?}", binary.op), - }, - false, - ) -} - -/// Codegen for unary operations -pub(crate) fn codegen_unary( - unary: &syn::ExprUnary, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - let (inner, is_comptime) = codegen_expr(&unary.expr, loop_level, variable_tracker).split(); - - if is_comptime { - return Codegen::new( - quote::quote! { - #unary - }, - true, - ); - } - - Codegen::new( - match unary.op { - syn::UnOp::Not(_) => quote::quote! 
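The `*Assign` arms are the only asymmetric ones: when the left side was an index expression, the array/index pair recorded by `codegen_index` is reused, so `a[i] += v` maps to `add_assign_array_op` instead of a read-modify-write. The emitted shape:

```rust
use quote::quote;

fn main() {
    let (array, index, value) = (quote!(a), quote!(i), quote!(v));

    // `a[i] += v`, using the array/index pair captured earlier:
    let out = quote! {
        {
            let _array = #array;
            let _index = #index;
            let _value = #value;
            burn_cube::frontend::add_assign_array_op::expand(context, _array, _index, _value)
        }
    };
    println!("{out}");
}
```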
{ - { - let _inner = #inner; - burn_cube::frontend::not::expand(context, _inner) - } - }, - _ => todo!("Codegen: unsupported op {:?}", unary.op), - }, - false, - ) -} diff --git a/crates/burn-cube-macros/src/codegen_function/variable.rs b/crates/burn-cube-macros/src/codegen_function/variable.rs deleted file mode 100644 index 4ca1d88432..0000000000 --- a/crates/burn-cube-macros/src/codegen_function/variable.rs +++ /dev/null @@ -1,322 +0,0 @@ -use proc_macro2::TokenStream; -use quote::ToTokens; -use syn::{punctuated::Punctuated, FieldValue, Lit, Member, PathArguments, Token}; - -use crate::{analyzer::KEYWORDS, codegen_function::expr::codegen_expr, tracker::VariableTracker}; - -use super::base::Codegen; - -/// Codegen for literals -pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream { - match lit.lit { - // We treat floats differently to avoid getting 4..into() for instance - Lit::Float(_) => { - let lit_str = lit.lit.to_token_stream().to_string(); - let float_lit = lit_str.parse::().unwrap(); - quote::quote! { #float_lit } - } - _ => { - quote::quote! { #lit } - } - } -} - -/// Codegen for arrays of literals -pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream { - let mut tokens = quote::quote! {}; - for element in array.elems.iter() { - let token = match element { - syn::Expr::Lit(lit) => codegen_lit(lit), - _ => { - return syn::Error::new_spanned(array, "Only arrays of literals are supported") - .into_compile_error() - } - }; - tokens.extend(quote::quote! { #token, }); - } - quote::quote! { [ #tokens ] } -} - -/// Codegen for a local declaration (let ...) -/// Supports: -/// let x = ... -/// let x: T = ... -/// let _ = ... -/// let mut _ = ... -pub(crate) fn codegen_local( - local: &syn::Local, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let let_tok = local.let_token; - - let ident = match &local.pat { - syn::Pat::Ident(ident) => ident.to_token_stream(), - syn::Pat::Type(pat_type) => match &*pat_type.pat { - syn::Pat::Ident(pat_ident) => pat_ident.to_token_stream(), - _ => todo!("Codegen: Unsupported typed path {:?}", pat_type.pat), - }, - syn::Pat::Wild(wild) => wild.underscore_token.to_token_stream(), - _ => todo!("Codegen: Declaration {:?} is unsupported.", local.pat), - }; - - variable_tracker.codegen_declare(ident.to_string(), loop_level as u8); - - match local.init.as_ref() { - Some(init) => { - let (init, is_comptime) = - codegen_expr(&init.expr, loop_level, variable_tracker).split(); - - if is_comptime { - variable_tracker - .set_as_comptime(ident.to_string(), loop_level as u8, None) - .unwrap(); - } - - if is_comptime { - quote::quote! { - #let_tok #ident = #init; - } - } else { - quote::quote! { - #let_tok #ident = { - let _inner = #init; - burn_cube::frontend::Init::init(_inner, context) - }; - } - } - } - None => { - quote::quote! { - #let_tok #ident; - } - } - } -} - -/// Codegen for indexed access -pub(crate) fn codegen_index( - index: &syn::ExprIndex, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - let array = codegen_expr(&index.expr, loop_level, variable_tracker); - let index = codegen_expr(&index.index, loop_level, variable_tracker); - - let tokens = quote::quote! 
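`codegen_local` draws the comptime line for `let` bindings: comptime initializers stay plain Rust, while runtime ones are routed through `Init::init` so the frontend can register the value in the context. Both output shapes, reproduced standalone:

```rust
use quote::quote;

fn main() {
    let (ident, init) = (quote!(x), quote!(add::expand(context, a, b)));

    // Runtime value: wrapped for registration.
    let runtime = quote! {
        let #ident = {
            let _inner = #init;
            burn_cube::frontend::Init::init(_inner, context)
        };
    };

    // Comptime value: an ordinary `let`.
    let comptime = quote!(let #ident = #init;);

    println!("{runtime}");
    println!("{comptime}");
}
```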
{ - { - let _array = #array; - let _index = #index; - burn_cube::frontend::index::expand(context, _array, _index) - } - }; - - let mut codegen = Codegen::new(tokens, false); - codegen.array_indexing = Some(super::base::ArrayIndexing { - array: array.tokens, - index: index.tokens, - }); - - codegen -} - -/// Codegen for assignation -/// Supports: -/// - scalar -/// - indexed array -pub(crate) fn codegen_assign( - assign: &syn::ExprAssign, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - match assign.left.as_ref() { - syn::Expr::Index(index) => { - let array = codegen_expr(&index.expr, loop_level, variable_tracker); - let index = codegen_expr(&index.index, loop_level, variable_tracker); - let value = codegen_expr(&assign.right, loop_level, variable_tracker); - - quote::quote! { - { - let _array = #array; - let _index = #index; - let _value = #value; - burn_cube::frontend::index_assign::expand(context, _array, _index, _value) - } - } - } - syn::Expr::Path(_) => { - let lhs = codegen_expr(&assign.left, loop_level, variable_tracker); - let rhs = codegen_expr(&assign.right, loop_level, variable_tracker); - - quote::quote! { - { - let _assign_lhs = #lhs; - let _assign_rhs = #rhs; - burn_cube::frontend::assign::expand(context, _assign_rhs, _assign_lhs) - } - } - } - syn::Expr::Field(_) => { - let lhs = codegen_expr(&assign.left, loop_level, variable_tracker); - let rhs = codegen_expr(&assign.right, loop_level, variable_tracker); - - quote::quote! { - { - let _assign_lhs = #lhs; - let _assign_rhs = #rhs; - burn_cube::frontend::assign::expand(context, _assign_rhs, _assign_lhs) - } - } - } - _ => todo!("Assign of expr {:?} unsupported", assign.left), - } -} - -pub(crate) fn codegen_path_var( - path: &syn::ExprPath, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> Codegen { - let ident = match path.path.get_ident() { - Some(ident) => ident, - None => { - return Codegen::new( - quote::quote! { - #path - }, - false, - ); - } - }; - - let name = ident.to_string(); - - if name == "None" { - return Codegen::new(quote::quote! { None }, true); - } - - if KEYWORDS.contains(&name.as_str()) { - Codegen::new( - quote::quote! { - #ident :: expand(context) - }, - false, - ) - } else { - let (will_be_used_again, is_comptime) = variable_tracker - .codegen_reuse(name, loop_level as u8, None) - .unwrap_or((true, false)); - - let output = if will_be_used_again { - quote::quote! { - #ident.clone() - } - } else { - quote::quote! { - #ident - } - }; - - Codegen::new(output, is_comptime) - } -} - -/// Codegen for a field used in rhs of a statement -/// This function adds cloning when necessary -pub(crate) fn codegen_field( - field: &syn::ExprField, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let (struct_, field) = if let Member::Named(attribute_ident) = &field.member { - if let syn::Expr::Path(struct_expr) = &*field.base { - let struct_ident = struct_expr - .path - .get_ident() - .expect("Codegen: field access only supported on ident struct."); - - (struct_ident, attribute_ident) - } else { - todo!("Codegen: field access only supported on ident struct."); - } - } else { - todo!("Codegen: unnamed attribute not supported."); - }; - - let (will_be_used_again, _) = variable_tracker - .codegen_reuse( - struct_.to_string(), - loop_level as u8, - Some(field.to_string()), - ) - .unwrap(); - - if will_be_used_again { - quote::quote! { - #struct_ . #field .clone() - } - } else { - quote::quote! { - #struct_ . 
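`codegen_path_var` is where the use counting pays off: a variable is `.clone()`d on every use except its last, and always cloned when used from a deeper scope than its declaration, since a loop body must not consume an outer value. A toy of that decision, flattening the tracker's (name, scope, repeat) bookkeeping down to a use counter:

```rust
use std::collections::HashMap;

struct Tracker {
    // (name, declaration scope) -> uses remaining
    uses_left: HashMap<(String, u8), usize>,
}

impl Tracker {
    /// Returns true when the generated code should clone instead of move.
    fn codegen_reuse(&mut self, name: &str, declared: u8, used_at: u8) -> bool {
        let n = self
            .uses_left
            .get_mut(&(name.to_string(), declared))
            .expect("variable not declared");
        *n -= 1;
        *n > 0 || declared != used_at
    }
}

fn main() {
    let mut tracker = Tracker {
        uses_left: HashMap::from([(("x".to_string(), 0), 2)]),
    };
    assert!(tracker.codegen_reuse("x", 0, 0));  // first use: clone
    assert!(!tracker.codegen_reuse("x", 0, 0)); // last use, same scope: move
}
```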
#field - } - } -} - -// Codegen for a struct declaration -pub(crate) fn codegen_struct( - struct_: &syn::ExprStruct, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let mut deconstructed_path = Vec::new(); - for segment in struct_.path.segments.iter() { - let generics = if let PathArguments::AngleBracketed(arguments) = &segment.arguments { - Some(arguments) - } else { - None - }; - deconstructed_path.push((&segment.ident, generics)); - } - - let (struct_name, generics) = deconstructed_path - .pop() - .expect("At least one ident in the path"); - - // This is hacky but using ::ExpandType {...} is experimental in Rust - let expanded_struct_name = syn::Ident::new( - format!("{}Expand", struct_name).as_str(), - proc_macro2::Span::call_site(), - ); - - deconstructed_path.push((&expanded_struct_name, generics)); - - // Reconstruct the path - let mut path_tokens = quote::quote! {}; - for (ident, angle_bracketed_generics) in deconstructed_path { - let ident_tokens = ident.to_token_stream(); - let generics_tokens = angle_bracketed_generics.to_token_stream(); - - path_tokens.extend(quote::quote! { - #ident_tokens #generics_tokens - }); - } - - let fields = codegen_field_creation(&struct_.fields, loop_level, variable_tracker); - quote::quote! { - #path_tokens { #fields } - } -} - -fn codegen_field_creation( - fields: &Punctuated, - loop_level: usize, - variable_tracker: &mut VariableTracker, -) -> TokenStream { - let mut field_tokens = quote::quote! {}; - for field in fields.iter() { - let field_name_token = &field.member; - let field_value_token = codegen_expr(&field.expr, loop_level, variable_tracker); - field_tokens.extend(quote::quote! { #field_name_token : #field_value_token, }); - } - field_tokens -} diff --git a/crates/burn-cube-macros/src/codegen_trait/mod.rs b/crates/burn-cube-macros/src/codegen_trait/mod.rs deleted file mode 100644 index bba751a472..0000000000 --- a/crates/burn-cube-macros/src/codegen_trait/mod.rs +++ /dev/null @@ -1,110 +0,0 @@ -use proc_macro2::TokenStream; - -use crate::codegen_common::signature::{expand_sig, ExpandMode}; - -pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream { - let mut expand_items = Vec::new(); - - for item in tr.items.iter() { - match item { - syn::TraitItem::Fn(func) => { - let expand = expand_sig( - &func.sig, - &syn::Visibility::Inherited, - None, - ExpandMode::MethodImpl, - ); - expand_items.push(syn::parse_quote!(#expand;)); - } - _ => continue, - } - } - tr.items.append(&mut expand_items); - - quote::quote! { - #tr - } -} - -pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream { - let mut expand_items = Vec::new(); - - for item in tr.items.iter() { - match item { - syn::ImplItem::Fn(func) => { - let ident = &func.sig.ident; - let ident = quote::quote! {#ident::__expand}; - let mut inputs = quote::quote!(); - - for input in &func.sig.inputs { - match input { - syn::FnArg::Typed(pat) => { - let ident = pat.pat.clone(); - inputs.extend(quote::quote! 
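The `{name}Expand` renaming in `codegen_struct` below works around the language gap the original comment flags: `<State as CubeType>::ExpandType { .. }` is not valid struct-expression syntax on stable Rust, so the macro names the expand struct directly:

```rust
use proc_macro2::Span;

fn main() {
    let name = syn::Ident::new("State", Span::call_site());
    let expanded = syn::Ident::new(&format!("{name}Expand"), name.span());
    assert_eq!(expanded.to_string(), "StateExpand");
}
```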
{ - #ident, - }); - } - _ => todo!("Only Typed inputs are supported"), - } - } - - let expand = expand_sig( - &func.sig, - &syn::Visibility::Inherited, - None, - ExpandMode::MethodImpl, - ); - - let tokens = if !tr.generics.params.is_empty() { - let mut func = func.clone(); - for param in tr.generics.params.iter() { - func.sig.generics.params.push(param.clone()); - } - register_expand(&func, &ident, expand, inputs) - } else { - register_expand(func, &ident, expand, inputs) - }; - - expand_items.push(syn::parse2(tokens).unwrap()); - } - _ => continue, - } - } - tr.items.append(&mut expand_items); - - quote::quote! { - #tr - } -} - -fn register_expand( - func: &syn::ImplItemFn, - name: &TokenStream, - expand: proc_macro2::TokenStream, - inputs: proc_macro2::TokenStream, -) -> proc_macro2::TokenStream { - let (func, func_expand) = if func.sig.generics.params.is_empty() { - ( - quote::quote! { #func }, - quote::quote! { - #name(context, #inputs) - }, - ) - } else { - let (_, gen, _) = &func.sig.generics.split_for_impl(); - ( - quote::quote! { #func }, - quote::quote! { - #name::#gen(context, #inputs) - }, - ) - }; - - quote::quote! ( - #expand { - #[cube] - #func - #func_expand - } - ) -} diff --git a/crates/burn-cube-macros/src/codegen_type/base.rs b/crates/burn-cube-macros/src/codegen_type/base.rs deleted file mode 100644 index d772a28aeb..0000000000 --- a/crates/burn-cube-macros/src/codegen_type/base.rs +++ /dev/null @@ -1,294 +0,0 @@ -use proc_macro::TokenStream; -use quote::quote; -use syn::Ident; - -use super::GenericsCodegen; - -struct TypeCodegen { - name: syn::Ident, - name_launch: syn::Ident, - name_expand: syn::Ident, - fields: Vec, - generics: GenericsCodegen, - vis: syn::Visibility, -} - -impl TypeCodegen { - pub fn expand_ty(&self) -> proc_macro2::TokenStream { - let mut fields = quote::quote! {}; - let name = &self.name_expand; - - for field in self.fields.iter() { - let ident = &field.ident; - let ty = &field.ty; - let vis = &field.vis; - - fields.extend(quote! { - #vis #ident: <#ty as CubeType>::ExpandType, - }); - } - - let generics = self.generics.type_definitions(); - let vis = &self.vis; - - quote! { - #[derive(Clone)] - #vis struct #name #generics { - #fields - } - } - } - - pub fn launch_ty(&self) -> proc_macro2::TokenStream { - let mut fields = quote::quote! {}; - let name = &self.name_launch; - - for field in self.fields.iter() { - let ident = &field.ident; - let ty = &field.ty; - let vis = &field.vis; - - fields.extend(quote! { - #vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>, - }); - } - - let generics = self.generics.all_definitions(); - - quote! { - struct #name #generics { - #fields - } - } - } - - pub fn launch_new(&self) -> proc_macro2::TokenStream { - let mut args = quote::quote! {}; - let mut fields = quote::quote! {}; - let name = &self.name_launch; - - for field in self.fields.iter() { - let ident = &field.ident; - let ty = &field.ty; - let vis = &field.vis; - - args.extend(quote! { - #vis #ident: <#ty as LaunchArg>::RuntimeArg<'a, R>, - }); - fields.extend(quote! { - #ident, - }); - } - - let generics_impl = self.generics.all_definitions(); - let generics_use = self.generics.all_in_use(); - let vis = &self.vis; - - quote! { - impl #generics_impl #name #generics_use { - /// New kernel - #[allow(clippy::too_many_arguments)] - #vis fn new(#args) -> Self { - Self { - #fields - } - } - } - } - } - - pub fn arg_settings_impl(&self) -> proc_macro2::TokenStream { - let mut register_body = quote::quote! {}; - let mut configure_input_body = quote::quote! 
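Putting `expand_ty` and `cube_type_impl` together, `#[derive(CubeType)]` on a one-field struct emits roughly the following; the `CubeType` trait is stubbed here so the sketch compiles on its own:

```rust
pub trait CubeType {
    type ExpandType: Clone;
}

impl CubeType for u32 {
    type ExpandType = u32;
}

pub struct State {
    pub count: u32,
}

// generated: the expand struct mirrors each field at its ExpandType
#[derive(Clone)]
pub struct StateExpand {
    pub count: <u32 as CubeType>::ExpandType,
}

// generated: the original type points at its expanded twin
impl CubeType for State {
    type ExpandType = StateExpand;
}

fn main() {
    let _expanded = StateExpand { count: 3 };
}
```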
{}; - let mut configure_output_body = quote::quote! {}; - let name = &self.name_launch; - - for (pos, field) in self.fields.iter().enumerate() { - let ident = &field.ident; - - register_body.extend(quote! { - self.#ident.register(launcher); - }); - configure_input_body.extend(quote! { - settings = ArgSettings::::configure_input(&self.#ident, #pos, settings); - }); - configure_output_body.extend(quote! { - settings = ArgSettings::::configure_output(&self.#ident, #pos, settings); - }); - } - - let generics_impl = self.generics.all_definitions(); - let generics_use = self.generics.all_in_use(); - - quote! { - impl #generics_impl ArgSettings for #name #generics_use { - fn register(&self, launcher: &mut KernelLauncher) { - #register_body - } - - fn configure_input(&self, position: usize, mut settings: KernelSettings) -> KernelSettings { - #configure_input_body - - settings - } - - fn configure_output(&self, position: usize, mut settings: KernelSettings) -> KernelSettings { - #configure_output_body - - settings - } - } - } - } - - pub fn cube_type_impl(&self) -> proc_macro2::TokenStream { - let name = &self.name; - let name_expand = &self.name_expand; - - let generics_impl = self.generics.type_definitions(); - let generics_use = self.generics.type_in_use(); - - quote! { - impl #generics_impl CubeType for #name #generics_use { - type ExpandType = #name_expand #generics_use; - } - } - } - - pub fn launch_arg_impl(&self) -> proc_macro2::TokenStream { - let mut body_input = quote::quote! {}; - let mut body_output = quote::quote! {}; - let name = &self.name; - let name_launch = &self.name_launch; - let name_expand = &self.name_expand; - - for field in self.fields.iter() { - let ident = &field.ident; - let ty = &field.ty; - let vis = &field.vis; - - body_input.extend(quote! { - #vis #ident: <#ty as LaunchArgExpand>::expand(builder, vectorization), - }); - body_output.extend(quote! { - #vis #ident: <#ty as LaunchArgExpand>::expand_output(builder, vectorization), - }); - } - - let type_generics_impl = self.generics.type_definitions(); - let type_generics_use = self.generics.type_in_use(); - - let runtime_generics_impl = self.generics.runtime_definitions(); - let all_generics_use = self.generics.all_in_use(); - - quote! { - impl #type_generics_impl LaunchArg for #name #type_generics_use { - type RuntimeArg #runtime_generics_impl = #name_launch #all_generics_use; - } - - impl #type_generics_impl LaunchArgExpand for #name #type_generics_use { - fn expand( - builder: &mut KernelBuilder, - vectorization: burn_cube::ir::Vectorization, - ) -> ::ExpandType { - #name_expand { - #body_input - } - } - fn expand_output( - builder: &mut KernelBuilder, - vectorization: burn_cube::ir::Vectorization, - ) -> ::ExpandType { - #name_expand { - #body_output - } - } - } - } - } - - pub fn expand_type_impl(&self) -> proc_macro2::TokenStream { - let name_expand = &self.name_expand; - let type_generics_impl = self.generics.type_definitions(); - let type_generics_use = self.generics.type_in_use(); - - let mut body = quote::quote! {}; - for field in self.fields.iter() { - let ident = &field.ident; - body.extend(quote::quote! { - #ident: Init::init(self.#ident, context), - }); - } - - quote! 
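The generated `ArgSettings::register` simply forwards to each field in declaration order; the positions baked into `configure_input`/`configure_output` come from `enumerate()` at macro time, not from runtime state. A heavily stubbed sketch of the register path:

```rust
struct KernelLauncher {
    log: Vec<String>,
}

struct TensorArg(&'static str); // stand-in for a real runtime argument

impl TensorArg {
    fn register(&self, launcher: &mut KernelLauncher) {
        launcher.log.push(format!("register {}", self.0));
    }
}

struct StateLaunch {
    velocity: TensorArg,
    scale: TensorArg,
}

impl StateLaunch {
    // shape of the generated ArgSettings::register: one line per field
    fn register(&self, launcher: &mut KernelLauncher) {
        self.velocity.register(launcher);
        self.scale.register(launcher);
    }
}

fn main() {
    let args = StateLaunch {
        velocity: TensorArg("velocity"),
        scale: TensorArg("scale"),
    };
    let mut launcher = KernelLauncher { log: vec![] };
    args.register(&mut launcher);
    assert_eq!(launcher.log, ["register velocity", "register scale"]);
}
```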
{ - impl #type_generics_impl Init for #name_expand #type_generics_use { - fn init(self, context: &mut CubeContext) -> Self { - Self { - #body - } - } - } - } - } -} - -pub(crate) fn generate_cube_type(ast: &syn::DeriveInput, with_launch: bool) -> TokenStream { - let name = ast.ident.clone(); - let generics = ast.generics.clone(); - let visibility = ast.vis.clone(); - - let name_string = name.to_string(); - let name_expand = Ident::new(format!("{}Expand", name_string).as_str(), name.span()); - let name_launch = Ident::new(format!("{}Launch", name_string).as_str(), name.span()); - - let mut fields = Vec::new(); - - match &ast.data { - syn::Data::Struct(struct_data) => { - for field in struct_data.fields.iter() { - fields.push(field.clone()); - } - } - syn::Data::Enum(_) => panic!("Only struct can be derived"), - syn::Data::Union(_) => panic!("Only struct can be derived"), - }; - - let codegen = TypeCodegen { - name, - name_launch, - name_expand, - fields, - generics: GenericsCodegen::new(generics), - vis: visibility, - }; - - let expand_ty = codegen.expand_ty(); - let launch_ty = codegen.launch_ty(); - let launch_new = codegen.launch_new(); - - let cube_type_impl = codegen.cube_type_impl(); - let arg_settings_impl = codegen.arg_settings_impl(); - let launch_arg_impl = codegen.launch_arg_impl(); - let expand_type_impl = codegen.expand_type_impl(); - - if with_launch { - quote! { - #expand_ty - #launch_ty - #launch_new - - #cube_type_impl - #arg_settings_impl - #launch_arg_impl - #expand_type_impl - } - .into() - } else { - quote! { - #expand_ty - #cube_type_impl - #expand_type_impl - } - .into() - } -} diff --git a/crates/burn-cube-macros/src/codegen_type/generics.rs b/crates/burn-cube-macros/src/codegen_type/generics.rs deleted file mode 100644 index b92170a047..0000000000 --- a/crates/burn-cube-macros/src/codegen_type/generics.rs +++ /dev/null @@ -1,81 +0,0 @@ -use proc_macro2::{Span, TokenStream}; -use quote::ToTokens; -use syn::{GenericParam, Generics, Ident, Lifetime, LifetimeParam, TypeParam}; - -pub(crate) struct GenericsCodegen { - arg_lifetime: syn::Generics, - arg_runtime: syn::Generics, - type_gens: syn::Generics, -} - -impl GenericsCodegen { - pub(crate) fn new(type_gens: syn::Generics) -> Self { - Self { - arg_lifetime: Self::arg_lifetime(), - arg_runtime: Self::arg_runtime(), - type_gens, - } - } - - fn arg_lifetime() -> Generics { - let mut generics = Generics::default(); - let lifetime = - GenericParam::Lifetime(LifetimeParam::new(Lifetime::new("'a", Span::call_site()))); - generics.params.push(lifetime); - generics - } - - fn arg_runtime() -> Generics { - let mut generics = Generics::default(); - let mut runtime_param = TypeParam::from(Ident::new("R", Span::call_site())); - runtime_param - .bounds - .push(syn::parse_str("Runtime").unwrap()); - let runtime = GenericParam::Type(runtime_param); - generics.params.push(runtime); - generics - } - - pub(crate) fn type_definitions(&self) -> TokenStream { - self.type_gens.to_token_stream() - } - - pub(crate) fn type_in_use(&self) -> TokenStream { - generics_in_use_codegen(self.type_gens.clone()) - } - - pub(crate) fn runtime_definitions(&self) -> TokenStream { - let mut generics = self.arg_runtime.clone(); - generics.params.extend(self.arg_lifetime.params.clone()); - generics.to_token_stream() - } - - pub(crate) fn all_definitions(&self) -> TokenStream { - let mut generics = self.arg_lifetime.clone(); - generics.params.extend(self.arg_runtime.params.clone()); - generics.params.extend(self.type_gens.params.clone()); - 
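`arg_runtime` in the generics helper below shows the manual route to a generic parameter when `parse_quote!` is inconvenient: build a `TypeParam` from an `Ident`, then push a parsed bound (the `Runtime` bound text was stripped by this diff's rendering; it is restored here):

```rust
use proc_macro2::Span;
use quote::ToTokens;
use syn::{GenericParam, Generics, Ident, TypeParam};

fn arg_runtime() -> Generics {
    let mut generics = Generics::default();
    let mut runtime_param = TypeParam::from(Ident::new("R", Span::call_site()));
    runtime_param.bounds.push(syn::parse_str("Runtime").unwrap());
    generics.params.push(GenericParam::Type(runtime_param));
    generics
}

fn main() {
    println!("{}", arg_runtime().to_token_stream()); // < R : Runtime >
}
```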
generics.to_token_stream() - } - - pub(crate) fn all_in_use(&self) -> TokenStream { - let mut generics = self.arg_lifetime.clone(); - generics.params.extend(self.arg_runtime.params.clone()); - generics.params.extend(self.type_gens.params.clone()); - generics_in_use_codegen(generics) - } -} - -fn generics_in_use_codegen(generics: Generics) -> TokenStream { - let mut tokens = quote::quote! {<}; - for generic in generics.params.iter() { - let ident = match generic { - GenericParam::Lifetime(param) => param.lifetime.to_token_stream(), - GenericParam::Type(param) => param.ident.to_token_stream(), - GenericParam::Const(_) => todo!("Const generic not supported"), - }; - tokens.extend(quote::quote! { #ident, }) - } - tokens.extend(quote::quote! {>}); - - tokens -} diff --git a/crates/burn-cube-macros/src/codegen_type/mod.rs b/crates/burn-cube-macros/src/codegen_type/mod.rs deleted file mode 100644 index 68f38dcd04..0000000000 --- a/crates/burn-cube-macros/src/codegen_type/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod base; -mod generics; - -pub(crate) use base::*; -pub(crate) use generics::*; diff --git a/crates/burn-cube-macros/src/lib.rs b/crates/burn-cube-macros/src/lib.rs deleted file mode 100644 index afcdd1bf46..0000000000 --- a/crates/burn-cube-macros/src/lib.rs +++ /dev/null @@ -1,182 +0,0 @@ -#[macro_use] -extern crate derive_new; - -mod analyzer; -mod codegen_function; -mod codegen_trait; -mod codegen_type; -mod tracker; - -pub(crate) mod codegen_common; - -use analyzer::VariableAnalyzer; -use codegen_common::signature::{expand_sig, ExpandMode}; -use codegen_function::{codegen_launch, codegen_statement}; -use codegen_trait::{expand_trait_def, expand_trait_impl}; -use codegen_type::generate_cube_type; -use proc_macro::TokenStream; -use syn::{parse_macro_input, punctuated::Punctuated, token::Comma, Meta}; -use tracker::VariableTracker; - -enum CubeMode { - /// Generates the expanded version of the function - Default, - /// Panics and prints the generated code, useful when debugging - /// Use by writing #[cube(panic)] - Debug, -} - -// Derive macro to define a cube type that is launched with a kernel -#[proc_macro_derive(CubeLaunch)] -pub fn module_derive_cube_launch(input: TokenStream) -> TokenStream { - let input = syn::parse(input).unwrap(); - - generate_cube_type(&input, true) -} - -// Derive macro to define a cube type that is not launched -#[proc_macro_derive(CubeType)] -pub fn module_derive_cube_type(input: TokenStream) -> TokenStream { - let input = syn::parse(input).unwrap(); - - generate_cube_type(&input, false) -} - -struct SupportedAttributes { - mode: CubeMode, - launch: bool, -} - -/// Derive macro for the module. 
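`generics_in_use_codegen` re-renders parameters for *use* position: bounds are dropped, only lifetimes and idents survive, and the trailing comma inside `<...>` is harmless to Rust. Restored with its stripped generics:

```rust
use quote::ToTokens;
use syn::{parse_quote, GenericParam, Generics};

fn generics_in_use(generics: &Generics) -> proc_macro2::TokenStream {
    let mut tokens = quote::quote!(<);
    for param in generics.params.iter() {
        let ident = match param {
            GenericParam::Lifetime(param) => param.lifetime.to_token_stream(),
            GenericParam::Type(param) => param.ident.to_token_stream(),
            GenericParam::Const(_) => todo!("Const generic not supported"),
        };
        tokens.extend(quote::quote!(#ident,));
    }
    tokens.extend(quote::quote!(>));
    tokens
}

fn main() {
    let generics: Generics = parse_quote!(<'a, R: Runtime, F: Float>);
    println!("{}", generics_in_use(&generics)); // < 'a , R , F , >
}
```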
-#[proc_macro_attribute] -pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream { - let args = parse_macro_input!(attr with Punctuated::::parse_terminated); - let attrs = parse_attributes(&args); - - let code: TokenStream = match syn::parse::(tokens).unwrap() { - syn::Item::Fn(func) => cube_fn(func, &attrs), - syn::Item::Impl(item) => expand_trait_impl(item).into(), - syn::Item::Trait(item) => expand_trait_def(item).into(), - _ => panic!("Cube annotations only supported for functions"), - }; - - match attrs.mode { - CubeMode::Default => code, - CubeMode::Debug => panic!("{code}"), - } -} - -fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream { - let mut variable_tracker = VariableAnalyzer::create_tracker(&func); - - match codegen_cube(&func, &mut variable_tracker, attrs.launch) { - Ok(code) => code.into(), - Err(err) => err.into(), - } -} - -fn parse_attributes(args: &Punctuated) -> SupportedAttributes { - let mut mode = CubeMode::Default; - let mut launch = false; - - for arg in args.iter() { - match arg { - Meta::Path(path) => { - if let Some(ident) = path.get_ident().map(|id| id.to_string()) { - match ident.as_str() { - "debug" => { - mode = CubeMode::Debug; - } - "launch" => { - launch = true; - } - _ => panic!("Attribute {ident} is not supported"), - } - } else { - panic!("Only ident attribute supported"); - } - } - Meta::List(_) => panic!("No List attribute supported"), - Meta::NameValue(_) => panic!("No NameValue attribute supported"), - } - } - - SupportedAttributes { mode, launch } -} - -/// Generate the expanded version of a function marked with the cube macro -fn codegen_cube( - func: &syn::ItemFn, - variable_tracker: &mut VariableTracker, - launch: bool, -) -> Result { - let signature = expand_sig( - &func.sig, - &syn::Visibility::Public(Default::default()), // Always public, otherwise we can't import - // it from an outside module. - Some(variable_tracker), - ExpandMode::FuncImpl, - ); - let mut body = quote::quote! {}; - - for statement in func.block.stmts.iter() { - let tokens = codegen_statement(statement, 0, variable_tracker); - body.extend(tokens); - } - - let is_in_error = !variable_tracker.errors.is_empty(); - - if is_in_error { - // When there is an error, we don't generate the expand method, since it's only going to - // create more errors that won't help fixing the issue. - - let mut code = quote::quote! { - #[allow(dead_code)] - #[allow(clippy::too_many_arguments)] - #func - }; - - for err in variable_tracker.errors.drain(..) { - code.extend(err.into_compile_error()); - } - - return Err(code); - } - - let launch_doc = if launch { - "and launch functions " - } else { - "function " - }; - - let launch = if launch { - codegen_launch(&func.sig) - } else { - quote::quote! {} - }; - - let mod_name = &func.sig.ident; - let vis = &func.vis; - let doc = format!("Module containing the expand {launch_doc}of {mod_name}."); - - Ok(quote::quote! 
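`parse_attributes` accepts only bare idents, so `#[cube]`, `#[cube(debug)]`, and `#[cube(launch)]` are the whole attribute surface (note the `CubeMode::Debug` doc comment above still says `panic`, while the parser matches `debug`). A trimmed version of that loop:

```rust
use syn::{punctuated::Punctuated, Meta, Token};

struct Attrs {
    debug: bool,
    launch: bool,
}

fn parse_attributes(args: &Punctuated<Meta, Token![,]>) -> Attrs {
    let mut attrs = Attrs { debug: false, launch: false };
    for arg in args {
        match arg {
            Meta::Path(path) => {
                match path.get_ident().map(|id| id.to_string()).as_deref() {
                    Some("debug") => attrs.debug = true,
                    Some("launch") => attrs.launch = true,
                    other => panic!("Attribute {other:?} is not supported"),
                }
            }
            _ => panic!("Only bare ident attributes are supported"),
        }
    }
    attrs
}

fn main() {
    let args: Punctuated<Meta, Token![,]> =
        Punctuated::from_iter([syn::parse_quote!(launch)]);
    assert!(parse_attributes(&args).launch);
}
```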
{ - #[allow(dead_code)] - #[allow(clippy::too_many_arguments)] - #func - - - #[doc = #doc] - #vis mod #mod_name { - use super::*; - - #launch - - #[allow(unused_mut)] - #[allow(clippy::too_many_arguments)] - #signature { - #body - } - - } - }) -} diff --git a/crates/burn-cube-macros/src/tracker.rs b/crates/burn-cube-macros/src/tracker.rs deleted file mode 100644 index 7371cf1043..0000000000 --- a/crates/burn-cube-macros/src/tracker.rs +++ /dev/null @@ -1,244 +0,0 @@ -use std::collections::HashMap; - -#[derive(new, Hash, PartialEq, Eq, Debug, Clone)] -/// Identifies a variable uniquely -pub struct VariableIdent { - name: String, - repeat: u8, - scope: u8, - field: Option, -} - -#[derive(new, Eq, PartialEq, Hash, Debug)] -/// Identifies a variable, with possible collisions when variables are redeclared -struct VariableKey { - name: String, - scope: u8, -} - -#[derive(Debug, Default)] -/// Tracks variable uses -pub(crate) struct VariableTracker { - scopes_declared: HashMap>, - analysis_repeats: HashMap, - codegen_repeats: HashMap, - variable_uses: HashMap, - pub errors: Vec, -} - -#[derive(Debug, Default)] -/// Encapsulates number of uses and whether this implies cloning -pub(crate) struct VariableUse { - pub num_used: usize, - pub is_comptime: bool, -} - -impl VariableUse { - pub fn should_clone(&self) -> bool { - self.num_used > 1 - } -} - -impl VariableTracker { - /// During analysis, tracks a variable declaration - pub(crate) fn analyze_declare(&mut self, name: String, scope: u8, is_comptime: bool) { - if let Some(scopes) = self.scopes_declared.get_mut(&name) { - if !scopes.contains(&scope) { - scopes.push(scope); - } - } else { - self.scopes_declared.insert(name.clone(), vec![scope]); - } - - let key = VariableKey::new(name.clone(), scope); - let repeat = if let Some(count) = self.analysis_repeats.get_mut(&key) { - *count += 1; - *count - } else { - self.analysis_repeats.insert(key, 0); - 0 - }; - - let analysis = VariableUse { - num_used: 1, - is_comptime, - }; - let variable_ident = VariableIdent::new(name, repeat, scope, None); - self.variable_uses.insert(variable_ident, analysis); - } - - /// During analysis, tracks a variable use - pub(crate) fn analyze_reuse(&mut self, ident: &syn::Ident, scope: u8, field: Option) { - let name = ident.to_string(); - - if name == "None" { - return; - } - - let scopes_declared = match self.scopes_declared.get(&name) { - Some(val) => val, - None => { - self.errors - .push(syn::Error::new_spanned(ident, "Variable not declared")); - return; - } - }; - - let scope = *scopes_declared - .iter() - .filter(|s| **s <= scope) - .max() - .unwrap(); - let key = VariableKey::new(name.clone(), scope); - - // If the name and scope do not match a declared variable, - // then we are using a variable declared in a parent scope, and - // cloning must always happen, therefore no need for further analysis - if let Some(repeat) = self.analysis_repeats.get(&key) { - let variable = VariableIdent::new(name, *repeat, scope, field); - self.analyze(&variable); - } - } - - /// Increments variable use and its parent struct if need be - fn analyze(&mut self, variable_ident: &VariableIdent) { - match self.variable_uses.get_mut(variable_ident) { - Some(variable_use) => { - variable_use.num_used += 1; - } - None => { - // If variable was not inserted yet, it must be a field - if variable_ident.field.is_some() { - let mut parent_ident = variable_ident.clone(); - parent_ident.field = None; - let parent = self.variable_uses.get(&parent_ident).unwrap(); - - let attr_analysis = 
VariableUse { - num_used: 1, - is_comptime: parent.is_comptime, - }; - self.variable_uses - .insert(variable_ident.clone(), attr_analysis); - } else { - panic!("Variable not declared"); - } - } - }; - - // Whether a field was previously seen or not, we must increase the use of the parent struct - if variable_ident.field.is_some() { - let mut declaration_ident = variable_ident.clone(); - declaration_ident.field = None; - let declaration = self - .variable_uses - .get_mut(&declaration_ident) - .unwrap_or_else(|| panic!("Struct {:?} does not exist", declaration_ident)); - declaration.num_used += 1; - } - } - - /// During codegen, tracks a variable declaration. - /// This must be done again to know on what repeat a use occurs - pub(crate) fn codegen_declare(&mut self, name: String, scope: u8) { - let key = VariableKey::new(name.clone(), scope); - if let Some(count) = self.codegen_repeats.get_mut(&key) { - *count += 1; - } else { - self.codegen_repeats.insert(key, 0); - } - } - - /// During codegen, tracks a variable use. - pub(crate) fn codegen_reuse( - &mut self, - name: String, - scope: u8, - field: Option, - ) -> Result<(bool, bool), VariableReuseError> { - let scopes_declared = self - .scopes_declared - .get(&name) - .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?; - let scope_declared = *scopes_declared - .iter() - .filter(|s| **s <= scope) - .max() - .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?; - - let key = VariableKey::new(name.clone(), scope_declared); - let repeat = self.codegen_repeats.get(&key).unwrap_or(&0); - let ident = VariableIdent::new(name.clone(), *repeat, scope_declared, field.clone()); - - let should_clone_parent = if field.is_some() { - let struct_ident = VariableIdent::new(name.clone(), *repeat, scope_declared, None); - let parent_analysis = self - .variable_uses - .get_mut(&struct_ident) - .ok_or_else(|| VariableNotFound::new(name.clone(), scope_declared, None))?; - - parent_analysis.num_used -= 1; - parent_analysis.should_clone() - } else { - false - }; - - let analysis = self - .variable_uses - .get_mut(&ident) - .ok_or_else(|| VariableNotFound::new(name, scope_declared, field))?; - - analysis.num_used -= 1; - let should_clone = - analysis.should_clone() || should_clone_parent || scope_declared != scope; - Ok((should_clone, analysis.is_comptime)) - } - - pub fn set_as_comptime( - &mut self, - name: String, - scope: u8, - field: Option, - ) -> Result<(), VariableReuseError> { - let scopes_declared = self - .scopes_declared - .get(&name) - .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?; - let scope_declared = *scopes_declared - .iter() - .filter(|s| **s <= scope) - .max() - .ok_or_else(|| VariableNotFound::new(name.clone(), scope, field.clone()))?; - - let key = VariableKey::new(name.clone(), scope_declared); - let repeat = self.codegen_repeats.get(&key).unwrap_or(&0); - let ident = VariableIdent::new(name.clone(), *repeat, scope_declared, field.clone()); - - let analysis = self - .variable_uses - .get_mut(&ident) - .ok_or_else(|| VariableNotFound::new(name, scope_declared, field))?; - - analysis.is_comptime = true; - - Ok(()) - } -} - -#[derive(new, Debug)] -pub struct VariableNotFound { - _name: String, - _scope: u8, - _field: Option, -} - -#[derive(Debug)] -#[allow(dead_code)] -pub enum VariableReuseError { - VariableNotFound(VariableNotFound), -} - -impl From for VariableReuseError { - fn from(value: VariableNotFound) -> Self { - Self::VariableNotFound(value) - } -} diff --git 
a/crates/burn-cube/Cargo.toml b/crates/burn-cube/Cargo.toml deleted file mode 100644 index 6b9089c8ed..0000000000 --- a/crates/burn-cube/Cargo.toml +++ /dev/null @@ -1,37 +0,0 @@ -[package] -authors = [ - "nathanielsimard ", - "louisfd ", -] -categories = ["science"] -description = "Cube Compute Language (CubeCL) is a subset of Rust that can be executed on accelerators for compute intensive tasks." -edition.workspace = true -keywords = [] -license.workspace = true -name = "burn-cube" -readme.workspace = true -repository = "https://github.com/tracel-ai/burn/tree/main/burn-cube" -version.workspace = true - -[features] -default = ["tensor"] -std = [] -template = [] -tensor = ["burn-tensor"] -export_tests = [] - -[dependencies] -burn-compute = { path = "../burn-compute", version = "0.14.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.14.0", default-features = false, optional = true } - -bytemuck = { workspace = true } -half = { workspace = true, features = ["bytemuck"] } -serde = { workspace = true } -burn-cube-macros = { path = "../burn-cube-macros", version = "0.14.0" } -derive-new = { workspace = true } -num-traits = { workspace = true } - -log = { workspace = true } - -[dev-dependencies] -trybuild = "1" diff --git a/crates/burn-cube/README.md b/crates/burn-cube/README.md deleted file mode 100644 index c0ce4a1b3f..0000000000 --- a/crates/burn-cube/README.md +++ /dev/null @@ -1,202 +0,0 @@ -
-[logo image: assets/CubeCL.webp]
-
- -[![Rust Version](https://img.shields.io/badge/Rust-1.79.0+-blue)](https://releases.rs/docs/1.79.0) -![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue) - ---- - -**Multi-platform high-performance compute language extension for Rust.** -
- -
-
-## TL;DR
-
-With CubeCL, you can program your GPU in Rust, leveraging zero-cost abstractions to create maintainable, flexible, and optimal compute kernels.
-
-## Motivation
-
-The goal of CubeCL is to ease the pain of writing highly optimized compute kernels that are portable across hardware.
-There is currently no adequate solution when you want optimal performance while still being multi-platform.
-You have to write custom kernels for each hardware target, often in different languages such as CUDA, Metal, or ROCm.
-To fix this, we created a Just-in-Time compiler with three core features: **automatic vectorization**, **comptime**, and **autotune**!
-
-These features are extremely useful for anyone writing high-performance kernels, even when portability is not a concern.
-They improve code composability, reusability, testability, and maintainability, all while staying optimal.
-
-### Disclaimer & History
-
-CubeCL is currently in **alpha**.
-The only supported runtimes are CUDA and WebGPU for now.
-It's easy to add more GPU runtimes, and we intend to support Metal, ROCm, and Vulkan; contributions are welcome!
-We also want to have an optimized JIT CPU runtime with SIMD instructions, leveraging [Cranelift](https://cranelift.dev).
-
-While CubeCL is currently in use in [Burn](https://burn.dev), there are still a lot of rough edges; it isn't refined yet.
-The project started as a WebGPU-only backend for Burn.
-As we optimized it, we realized that we needed an intermediate representation (IR) that could be optimized and then compiled to WGSL.
-Having an IR made it easy to support another compilation target, so we made a CUDA runtime.
-However, writing kernels directly in that IR wasn't easy, so we created a Rust frontend using the [syn](https://github.com/dtolnay/syn) crate.
-Navigating the differences between CUDA and WebGPU, while leveraging both platforms, forced us to come up with general concepts that worked everywhere.
-Hence, CubeCL was born!
-
-## Design
-
-CubeCL is designed around - you guessed it - Cubes! More specifically, it's based on cuboids, because not all axes are the same size.
-Since all compute APIs map to the same kind of hardware, whose compute units are addressed through a 3D grid, our topology translates easily to concepts from other APIs.
-
-### CubeCL - Topology
-
-[diagram: assets/cubecl.drawio.svg, showing units inside a cube inside a hyper-cube]
-
-
- -_A cube is composed of units, so a 3x3x3 cube has 27 units that can be accessed by their positions along the x, y, and z axes. -Similarly, a hyper-cube is composed of cubes, just as a cube is composed of units. -Each cube in the hyper-cube can be accessed by its position relative to the hyper-cube along the x, y, and z axes. -Hence, a hyper-cube of 3x3x3 will have 27 cubes. -In this example, the total number of working units would be 27 x 27 = 729._ - -
-Topology Equivalence 👇 -
-
-Since all topology variables are constant within the kernel entry point, we chose to use the Rust constant syntax with capital letters.
-When creating kernels, we often don't care about a unit's relative position within its cube along each axis, only about its position in general.
-Therefore, each kind of variable also has an axis-independent variant, which is often not present in other languages, with the exception of WebGPU's `local_invocation_index`.
-
-
-| CubeCL         | CUDA        | WebGPU                 |
-| -------------- | ----------- | ---------------------- |
-| CUBE_COUNT     | N/A         | N/A                    |
-| CUBE_COUNT_X   | gridDim.x   | num_workgroups.x       |
-| CUBE_COUNT_Y   | gridDim.y   | num_workgroups.y       |
-| CUBE_COUNT_Z   | gridDim.z   | num_workgroups.z       |
-| CUBE_POS       | N/A         | N/A                    |
-| CUBE_POS_X     | blockIdx.x  | workgroup_id.x         |
-| CUBE_POS_Y     | blockIdx.y  | workgroup_id.y         |
-| CUBE_POS_Z     | blockIdx.z  | workgroup_id.z         |
-| CUBE_DIM       | N/A         | N/A                    |
-| CUBE_DIM_X     | blockDim.x  | workgroup_size.x       |
-| CUBE_DIM_Y     | blockDim.y  | workgroup_size.y       |
-| CUBE_DIM_Z     | blockDim.z  | workgroup_size.z       |
-| UNIT_POS       | N/A         | local_invocation_index |
-| UNIT_POS_X     | threadIdx.x | local_invocation_id.x  |
-| UNIT_POS_Y     | threadIdx.y | local_invocation_id.y  |
-| UNIT_POS_Z     | threadIdx.z | local_invocation_id.z  |
-| SUBCUBE_DIM    | warpSize    | subgroup_size          |
-| ABSOLUTE_POS   | N/A         | N/A                    |
-| ABSOLUTE_POS_X | N/A         | global_invocation_id.x |
-| ABSOLUTE_POS_Y | N/A         | global_invocation_id.y |
-| ABSOLUTE_POS_Z | N/A         | global_invocation_id.z |
-
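As a concrete reading of the table, the axis-wise variables compose exactly like their CUDA and WebGPU counterparts. The following fragment is an editorial sketch (not code from this repository), using only the names defined above:

```rust
#[cube(launch)]
fn index_demo(output: &mut Array<UInt>) {
    // Same arithmetic as CUDA's blockIdx.x * blockDim.x + threadIdx.x,
    // or WGSL's workgroup_id.x * workgroup_size.x + local_invocation_id.x.
    let absolute_x = CUBE_POS_X * CUBE_DIM_X + UNIT_POS_X;

    // ABSOLUTE_POS is the axis-independent shortcut for the same unit.
    output[absolute_x] = ABSOLUTE_POS;
}
```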
-
-## Special Features
-
-#### Automatic Vectorization
-
-High-performance kernels should rely on SIMD instructions whenever possible, but doing so can quickly get complicated!
-With CubeCL, you can specify the vectorization factor of each input variable when launching a kernel.
-Inside the kernel code, you still use only one type, which is dynamically vectorized and supports automatic broadcasting.
-The runtimes compile the kernels with all the information they need to select the best instructions!
-However, since the algorithmic behavior may depend on the vectorization factor, CubeCL lets you access it directly in the kernel when needed, without any performance loss, using the comptime system!
-
-#### Comptime
-
-CubeCL isn't just a new compute language: though it feels like you are writing GPU kernels, you are, in fact, writing compiler plugins that you can fully customize!
-Comptime is a way to modify the compiler IR at runtime, when a kernel is compiled for the first time.
-
-This enables lots of optimizations and flexibility without having to write many separate variants of the same kernels to ensure maximal performance (a small sketch follows at the end of this section).
-
-| Feature                        | Description                                                                                                                                                              |
-| ------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| **Instruction Specialization** | Not all instructions are available on all hardware, but when a specialized one exists, it should be enabled with a simple if statement.                                   |
-| **Automatic Vectorization**    | When you can use SIMD instructions, you should! But since not all hardware supports the same vectorization factors, it can be injected at runtime!                        |
-| **Loop Unrolling**             | You may want multiple flavors of the same kernel, with loop unrolling for only a certain range of values. This can be configured easily with Comptime.                    |
-| **Shape Specialization**       | For deep learning kernels, it's often crucial to rely on different kernels for different input sizes; you can do it by passing the shape information as Comptime values.  |
-| **Compile Time Calculation**   | In general, you can calculate a constant using Rust runtime properties and inject it into a kernel during its compilation, to avoid recalculating it on each execution.   |
-
-#### Autotuning
-
-Autotuning drastically simplifies kernel selection by running small benchmarks at runtime to figure out the best kernels, with the best configurations, to run on the current hardware; an essential feature for portability.
-This feature combines gracefully with comptime to test the effect of different comptime values on performance; sometimes the results can be surprising!
-
-Although the benchmarks add some overhead the first time an application runs, the information gets cached on the device and will be reused.
-This is usually a no-brainer trade-off for throughput-oriented programs such as deep learning models.
-You can even ship the autotune cache with your program, reducing cold-start time when you have more control over the deployment target.
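A minimal sketch of the comptime-driven unrolling described above, using the `range` and `Comptime` constructs from the `burn-cube` frontend removed later in this diff; the kernel itself is illustrative, not code from the repository:

```rust
#[cube(launch)]
fn sum_first_four<F: Float>(input: &Array<F>, output: &mut Array<F>, unroll: Comptime<bool>) {
    let mut sum = F::new(0.0);
    // `unroll` is resolved when the kernel is compiled for the first time,
    // so each value of the flag produces its own specialized kernel variant.
    for i in range(0u32, 4u32, unroll) {
        sum += input[i];
    }
    output[ABSOLUTE_POS] = sum;
}
```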
-
-## Example
-
-CubeCL is designed to be easy to use for Rust programmers: it relies on the same syntax and is fully integrated with the language.
-You can simply add an attribute on top of a Rust function for it to be executed on the GPU.
-
-```rust
-#[cube(launch)]
-fn gelu<F: Float>(input: &Array<F>, output: &mut Array<F>) {
-    if ABSOLUTE_POS < input.len() {
-        let x = input[ABSOLUTE_POS];
-        let gelu = x * (1 + erf(x / sqrt(2))) / 2;
-        output[ABSOLUTE_POS] = gelu;
-    }
-}
-
-fn main() {
-    type Runtime = CudaRuntime;
-
-    let device = Default::default();
-    let client = Runtime::client(&device);
-
-    let input = &[-1., 0., 1., 5.];
-    let input_handle = client.create(f32::as_bytes(input));
-    let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());
-
-    gelu::launch::<F32, Runtime>(
-        client,
-        CubeCount::new(1, 1, 1),
-        CubeDim::new(4, 1, 1),
-        &input_handle,
-        &output_handle,
-    );
-
-    let output = client.read(output_handle.binding()).read_sync().unwrap();
-    let output = f32::from_bytes(&output);
-
-    // Should be [-0.1587, 0.0000, 0.8413, 5.0000]
-    println!("{output:?}");
-}
-```
-
-The `cube` attribute generates the code that is needed to compile a kernel.
-In the example above, the functions `gelu_expand` and `gelu_launch` are automatically generated.
-This allows you to compose Cube functions easily:
-
-```rust
-#[cube]
-fn gelu_scalar<F: Float>(x: F) -> F {
-    x * (1 + erf(x / sqrt(2))) / 2
-}
-
-#[cube(launch)]
-fn gelu<F: Float>(input: Array<F>, mut output: Array<F>) {
-    if ABSOLUTE_POS < input.shape(0) {
-        output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
-    }
-}
-```
-
-Note that you don't have to specify `launch` in a function that is only used by another Cube function.
-In addition, you can have return types without problem, which isn't the case when you are writing an entry point to a kernel using the `launch` attribute.
-The function `gelu_expand` will actually use `gelu_scalar_expand`, making it easy to combine your functions.
-
-## Resource
-
-If you have any questions or want to contribute, don't hesitate to join the [Discord](https://discord.gg/uPEBbYYDB6).
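The expected values above can be sanity-checked on the host with the same formula. This snippet is an editorial addition, assuming the `libm` crate for `erf` (it is not part of the original README):

```rust
// Host-side reference for gelu(x) = x * (1 + erf(x / sqrt(2))) / 2.
fn gelu_ref(x: f64) -> f64 {
    x * (1.0 + libm::erf(x / core::f64::consts::SQRT_2)) / 2.0
}

fn main() {
    let outputs: Vec<f64> = [-1.0f64, 0.0, 1.0, 5.0].iter().map(|&x| gelu_ref(x)).collect();
    println!("{outputs:.4?}"); // approx. [-0.1587, 0.0000, 0.8413, 5.0000]
}
```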
diff --git a/crates/burn-cube/assets/CubeCL.webp b/crates/burn-cube/assets/CubeCL.webp deleted file mode 100644 index 305c1b882c..0000000000 Binary files a/crates/burn-cube/assets/CubeCL.webp and /dev/null differ
diff --git a/crates/burn-cube/assets/cubecl.drawio.svg b/crates/burn-cube/assets/cubecl.drawio.svg deleted file mode 100644 index e0278413ab..0000000000 --- a/crates/burn-cube/assets/cubecl.drawio.svg +++ /dev/null @@ -1,4 +0,0 @@
-[deleted SVG diagram: a 3x3x3 "Cube" of "Unit"s addressed by (x, y, z) positions, nested inside a "Hyper-Cube" of cubes addressed by [x, y, z] positions]
diff --git a/crates/burn-cube/assets/logo.drawio.svg b/crates/burn-cube/assets/logo.drawio.svg deleted file mode 100644 index a83fc35206..0000000000 --- a/crates/burn-cube/assets/logo.drawio.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -

\ No newline at end of file diff --git a/crates/burn-cube/src/codegen/compiler.rs b/crates/burn-cube/src/codegen/compiler.rs deleted file mode 100644 index 4f1e451057..0000000000 --- a/crates/burn-cube/src/codegen/compiler.rs +++ /dev/null @@ -1,21 +0,0 @@ -use crate::ir::{Elem, KernelDefinition}; -use std::fmt::Display; - -/// Trait for compiled code representation -pub trait CompilerRepresentation: Display { - /// Computes and returns the shared memory size - fn shared_memory_size(&self) -> usize; -} - -/// Compiles the representation into its own representation that can be formatted into tokens. -pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug { - /// The representation for the compiled code. - type Representation: CompilerRepresentation; - - /// Compiles the [kernel definition](KernelDefinition) into the compiler's representation. - fn compile(kernel: KernelDefinition) -> Self::Representation; - /// The size of the given element in bytes. - fn elem_size(elem: Elem) -> usize; - /// The maximal size of a shared memory - fn max_shared_memory_size() -> usize; -} diff --git a/crates/burn-cube/src/codegen/execution.rs b/crates/burn-cube/src/codegen/execution.rs deleted file mode 100644 index d9fe2cb5fd..0000000000 --- a/crates/burn-cube/src/codegen/execution.rs +++ /dev/null @@ -1,358 +0,0 @@ -use crate::compute::{CubeCount, KernelTask}; -use crate::frontend::TensorHandle; -use crate::ir::Elem; -use crate::pod::CubeElement; -use crate::{calculate_cube_count_elemwise, Kernel, Runtime, SUBCUBE_DIM_APPROX}; -use burn_compute::client::ComputeClient; -use burn_compute::server::{Binding, ComputeServer, Handle}; - -/// The position of the input or output to calculate the number of cubes to launch. -pub enum CubeCountSettings { - Input { pos: usize }, - Output { pos: usize }, - Custom(CubeCount), -} - -pub struct Execution<'h, K, R: Runtime, Scalars> { - scalars: Scalars, - client: ComputeClient, - kernel: K, - inputs: &'h [TensorHandle<'h, R>], - outputs: &'h [TensorHandle<'h, R>], -} - -impl<'h, K, R: Runtime> Execution<'h, K, R, ()> { - pub fn start( - kernel: K, - client: ComputeClient, - ) -> Execution<'h, K, R, ()> { - Execution { - scalars: (), - client, - kernel, - inputs: &[], - outputs: &[], - } - } - - #[allow(unused)] - pub fn inputs(self, inputs: &'h [TensorHandle<'h, R>]) -> Execution<'h, K, R, ()> { - Execution { - scalars: self.scalars, - client: self.client, - kernel: self.kernel, - inputs, - outputs: self.outputs, - } - } - - pub fn outputs(self, outputs: &'h [TensorHandle<'h, R>]) -> Execution<'h, K, R, ()> { - Execution { - scalars: self.scalars, - client: self.client, - kernel: self.kernel, - inputs: self.inputs, - outputs, - } - } -} - -impl<'h, K, R> Execution<'h, K, R, ()> -where - K: Kernel + 'static, - R: Runtime, -{ - pub fn with_scalars(self, scalars: &[E]) -> Execution<'h, K, R, (&[E],)> { - Execution { - scalars: (scalars,), - client: self.client, - kernel: self.kernel, - inputs: self.inputs, - outputs: self.outputs, - } - } - /// Execute a dynamic kernel. 
- #[allow(unused)] - pub fn execute(self, launch: CubeCountSettings) { - execute_dynamic::( - self.inputs, - self.outputs, - None, - None, - None, - self.kernel, - launch, - self.client, - ) - } -} - -impl<'h, 'a, K, R, E> Execution<'h, K, R, (&'a [E],)> -where - K: Kernel + 'static, - R: Runtime, - E: CubeElement, -{ - pub fn with_scalars<'b, E2>( - self, - scalars: &'b [E2], - ) -> Execution<'h, K, R, (&'a [E], &'b [E2])> { - Execution { - scalars: (self.scalars.0, scalars), - client: self.client, - kernel: self.kernel, - inputs: self.inputs, - outputs: self.outputs, - } - } - - /// Execute a dynamic kernel. - #[allow(unused)] - pub fn execute(self, launch: CubeCountSettings) { - execute_dynamic::( - self.inputs, - self.outputs, - Some(self.scalars.0), - None, - None, - self.kernel, - launch, - self.client, - ) - } -} - -impl<'h, 'a, 'b, K, R, E1, E2> Execution<'h, K, R, (&'a [E1], &'b [E2])> -where - K: Kernel + 'static, - R: Runtime, - E1: CubeElement, - E2: CubeElement, -{ - #[allow(unused, clippy::type_complexity)] - pub fn with_scalars<'c, E3>( - self, - scalars: &'c [E3], - ) -> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])> { - Execution { - scalars: (self.scalars.0, self.scalars.1, scalars), - client: self.client, - kernel: self.kernel, - inputs: self.inputs, - outputs: self.outputs, - } - } - /// Execute a dynamic kernel. - #[allow(clippy::too_many_arguments)] - pub fn execute(self, launch: CubeCountSettings) - where - K: Kernel + 'static, - R: Runtime, - { - execute_dynamic::( - self.inputs, - self.outputs, - Some(self.scalars.0), - Some(self.scalars.1), - None, - self.kernel, - launch, - self.client, - ) - } -} - -impl<'h, 'a, 'b, 'c, K, R, E1, E2, E3> Execution<'h, K, R, (&'a [E1], &'b [E2], &'c [E3])> -where - K: Kernel + 'static, - R: Runtime, - E1: CubeElement, - E2: CubeElement, - E3: CubeElement, -{ - /// Execute a dynamic kernel. - #[allow(unused)] - pub fn execute(self, launch: CubeCountSettings) { - execute_dynamic::( - self.inputs, - self.outputs, - Some(self.scalars.0), - Some(self.scalars.1), - Some(self.scalars.2), - self.kernel, - launch, - self.client, - ) - } -} - -#[allow(clippy::too_many_arguments)] -fn execute_dynamic( - inputs: &[TensorHandle], - outputs: &[TensorHandle], - scalars_1: Option<&[E1]>, - scalars_2: Option<&[E2]>, - scalars_3: Option<&[E3]>, - kernel: K, - launch: CubeCountSettings, - client: ComputeClient, -) where - K: Kernel + 'static, - R: Runtime, - E1: CubeElement, - E2: CubeElement, - E3: CubeElement, -{ - let settings = execute_settings( - inputs, outputs, scalars_1, scalars_2, scalars_3, launch, &client, - ); - let mut handles = settings.handles_tensors; - - handles.push(settings.handle_info.binding()); - for handle in settings.handles_scalars.into_iter() { - handles.push(handle.binding()); - } - - let kernel = Box::new(KernelTask::::new(kernel)); - client.execute(kernel, settings.cube_count, handles); -} - -struct ExecuteSettings { - handles_tensors: Vec>, - handle_info: Handle, - handles_scalars: Vec>, - cube_count: CubeCount, -} - -fn execute_settings<'a, R: Runtime, E1: CubeElement, E2: CubeElement, E3: CubeElement>( - inputs: &'a [TensorHandle], - outputs: &'a [TensorHandle], - scalars_1: Option<&[E1]>, - scalars_2: Option<&[E2]>, - scalars_3: Option<&[E3]>, - launch: CubeCountSettings, - client: &ComputeClient, -) -> ExecuteSettings { - let mut info = Vec::new(); - let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2); - - // Inner function to fill the info buffer. 
- let mut register_info_tensor = |strides: &[usize], shape: &[usize]| { - if info.is_empty() { - info.push(strides.len() as u32); - } - - for s in strides.iter() { - info.push(*s as u32); - } - for s in shape.iter() { - info.push(*s as u32); - } - }; - - let mut num_elems_output = 0; - - // We start by registering the inputs. - for (i, input) in inputs.iter().enumerate() { - if let CubeCountSettings::Input { pos } = &launch { - if i == *pos { - num_elems_output = calculate_num_elems_dyn_rank(input.shape); - } - }; - register_info_tensor(input.strides, input.shape); - handles.push(input.handle.clone().binding()); - } - - // Then we follow with the outputs. - for (i, output) in outputs.iter().enumerate() { - if let CubeCountSettings::Output { pos } = &launch { - if i == *pos { - num_elems_output = calculate_num_elems_dyn_rank(output.shape); - } - }; - register_info_tensor(output.strides, output.shape); - handles.push(output.handle.clone().binding()); - } - - // [2, I0stride0, I0stride1, I0shape0, I0shape1i, I1... O0..., I0len, I1len1, O0len] - if R::require_array_lengths() { - for input in inputs.iter() { - let len = calculate_num_elems_dyn_rank(input.shape); - info.push(len as u32); - } - - for output in outputs.iter() { - let len = calculate_num_elems_dyn_rank(output.shape); - info.push(len as u32); - } - } - - let info = client.create(bytemuck::cast_slice(&info)); - - // Finally we finish with the named bindings. - let handles_scalars = - create_scalar_handles::(scalars_1, scalars_2, scalars_3, client); - - let cube_count = match launch { - CubeCountSettings::Custom(count) => count, - _ => calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX), - }; - - ExecuteSettings { - handles_tensors: handles, - handle_info: info, - handles_scalars, - cube_count, - } -} - -fn create_scalar_handles( - scalars_0: Option<&[E1]>, - scalars_1: Option<&[E2]>, - scalars_2: Option<&[E3]>, - client: &ComputeClient, -) -> Vec> { - // It is crucial that scalars follow this order: float, int, uint - let element_priority = |elem: Elem| match elem { - Elem::Float(_) => 0, - Elem::Int(_) => 1, - Elem::UInt => 2, - Elem::Bool => panic!("Bool scalars are not supported"), - }; - let scalar_priorities: [usize; 3] = [ - element_priority(E1::cube_elem()), - element_priority(E2::cube_elem()), - element_priority(E3::cube_elem()), - ]; - - let mut handles_scalars = Vec::new(); - for i in 0..3 { - for (j, scalar_priority) in scalar_priorities.iter().enumerate() { - if scalar_priority == &i { - if j == 0 { - if let Some(values) = &scalars_0 { - handles_scalars.push(client.create(bytemuck::cast_slice(values))); - } - } else if j == 1 { - if let Some(values) = &scalars_1 { - handles_scalars.push(client.create(bytemuck::cast_slice(values))); - } - } else if j == 2 { - if let Some(values) = &scalars_2 { - handles_scalars.push(client.create(bytemuck::cast_slice(values))); - } - } - } - } - } - - handles_scalars -} - -pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize { - let mut num_elems = 1; - for i in shape.iter() { - num_elems *= i; - } - num_elems -} diff --git a/crates/burn-cube/src/codegen/integrator.rs b/crates/burn-cube/src/codegen/integrator.rs deleted file mode 100644 index c19b810a8c..0000000000 --- a/crates/burn-cube/src/codegen/integrator.rs +++ /dev/null @@ -1,560 +0,0 @@ -use super::Compiler; -use crate::{ - ir::{ - Binding, CubeDim, Elem, Item, KernelDefinition, Location, ReadingStrategy, Scope, Variable, - Vectorization, Visibility, - }, - Runtime, -}; - -/// The kernel integrator allows 
you to create a [kernel definition](KernelDefinition) based on -/// [kernel expansion](KernelExpansion) and [kernel settings](KernelSettings). -#[derive(Clone)] -pub struct KernelIntegrator { - expansion: KernelExpansion, - input_bindings: Vec, - output_bindings: Vec, - named_bindings: Vec<(String, Binding)>, -} - -/// The information necessary to compile a [kernel definition](KernelDefinition). -#[derive(Clone)] -pub struct KernelExpansion { - pub inputs: Vec, - pub outputs: Vec, - pub scope: Scope, -} - -/// Simply indicate the output that can be replaced by the input. -#[derive(new, Clone, Copy, Debug)] -pub struct InplaceMapping { - /// Input position. - pub pos_input: usize, - /// Output position. - pub pos_output: usize, -} - -#[derive(Clone, Copy, Debug)] -enum VectorizationPartial { - Input { - pos: usize, - vectorization: Vectorization, - }, - Output { - pos: usize, - vectorization: Vectorization, - }, -} - -#[derive(Default, Clone)] -pub struct KernelSettings { - pub mappings: Vec, - vectorization_global: Option, - vectorization_partial: Vec, - cube_dim: CubeDim, - pub reading_strategy: Vec<(u16, ReadingStrategy)>, -} - -impl core::fmt::Display for KernelSettings { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // The goal of this implementation is to generate the shortest representation - // that won't clash with any other compilation settings. This is crucial since we rely on - // this representation to know when to compile a new version of a kernel. - // - // Each main section starts with a letter that can't be used by other main sections: - // - // * Mapping: m - // * Input: i - // * Output: o - // - // * Reading Strategy: r - // * Output layout: o - // * Plain: p - // - // * Vectorization Global: vg{factor} - // * Vectorization Partial Input: v{factor}i{pos} - // * Vectorization Partial Output: vo - // * Cube Dim X: x - // * Cube Dim Y: y - // * Cube Dim Z: z - f.write_str("m")?; - for mapping in self.mappings.iter() { - f.write_fmt(format_args!( - "i{}o{}", - mapping.pos_input, mapping.pos_output - ))?; - } - - f.write_str("r")?; - - for (input, strategy) in self.reading_strategy.iter() { - match strategy { - ReadingStrategy::OutputLayout => f.write_fmt(format_args!("i{}o", input)), - ReadingStrategy::Plain => f.write_fmt(format_args!("i{}p", input)), - }?; - } - - match self.vectorization_global { - Some(vectorization) => f.write_fmt(format_args!("vg{}", vectorization))?, - None => f.write_str("vn")?, - }; - - for vectorization in self.vectorization_partial.iter() { - match vectorization { - VectorizationPartial::Input { pos, vectorization } => { - f.write_fmt(format_args!("v{vectorization}i{pos}"))? - } - VectorizationPartial::Output { pos, vectorization } => { - f.write_fmt(format_args!("v{vectorization}o{pos}"))? - } - }; - } - - f.write_fmt(format_args!( - "x{}y{}z{}", - self.cube_dim.x, self.cube_dim.y, self.cube_dim.x - )) - } -} - -impl KernelSettings { - /// Compile the shader with vectorization enabled for all inputs and outputs. - #[allow(dead_code)] - pub fn vectorize_global(mut self, vectorization: Vectorization) -> Self { - self.vectorization_global = Some(vectorization); - self - } - - /// Compile the shader with vectorization enabled for an input. - #[allow(dead_code)] - pub fn vectorize_input(mut self, position: usize, vectorization: Vectorization) -> Self { - // Not setting the vectorization factor when it's the default value reduces the kernel id - // size. 
- if vectorization == 1 { - return self; - } - - self.vectorization_partial - .push(VectorizationPartial::Input { - pos: position, - vectorization, - }); - self - } - - /// Compile the shader with vectorization enabled for an output. - #[allow(dead_code)] - pub fn vectorize_output(mut self, position: usize, vectorization: Vectorization) -> Self { - // Not setting the vectorization factor when it's the default value reduces the kernel id - // size. - if vectorization == 1 { - return self; - } - - self.vectorization_partial - .push(VectorizationPartial::Output { - pos: position, - vectorization, - }); - self - } - - /// Fetch the vectorization for the provided input position. - pub fn vectorization_input(&self, position: usize) -> Vectorization { - if let Some(vec) = self.vectorization_global { - return vec; - } - - for partial in self.vectorization_partial.iter() { - if let VectorizationPartial::Input { pos, vectorization } = partial { - if *pos == position { - return *vectorization; - } - } - } - - 1 - } - - /// Fetch the vectorization for the provided output position. - pub fn vectorization_output(&self, position: usize) -> Vectorization { - if let Some(vec) = self.vectorization_global { - return vec; - } - - for partial in self.vectorization_partial.iter() { - if let VectorizationPartial::Output { pos, vectorization } = partial { - if *pos == position { - return *vectorization; - } - } - } - - 1 - } - - /// Compile the shader with inplace enabled by the given [mapping](InplaceMapping). - /// - /// Notes: - /// - /// You should favor using `dynamic_settings` when using fusion, since the mapping is going to - /// be created from the runtime information. - pub fn inplace(mut self, mappings: Vec) -> Self { - self.mappings = mappings; - self - } - - /// Set cube dimension. - #[allow(dead_code)] - pub fn cube_dim(mut self, cube_dim: CubeDim) -> Self { - self.cube_dim = cube_dim; - self - } -} - -#[allow(dead_code)] -fn is_contiguous(strides: &[usize]) -> bool { - let mut current = 0; - - for stride in strides.iter().rev() { - if current > *stride { - return false; - } - current = *stride; - } - - true -} - -/// Information related to an input. -#[derive(Clone, Debug)] -pub enum InputInfo { - Array { item: Item, visibility: Visibility }, - Scalar { elem: Elem, size: usize }, -} - -impl InputInfo { - /// The item type of the input. - #[allow(dead_code)] - pub fn item(&self) -> Item { - match self { - InputInfo::Array { - item, - visibility: _, - } => *item, - InputInfo::Scalar { elem, size: _ } => Item::new(*elem), - } - } -} - -impl OutputInfo { - /// The item type of the input. - #[allow(dead_code)] - pub fn item(&self) -> Item { - match self { - OutputInfo::ArrayWrite { - item, - local: _, - position: _, - } => *item, - OutputInfo::InputArrayWrite { - item, - input: _, - local: _, - position: _, - } => *item, - OutputInfo::Array { item } => *item, - } - } -} - -/// Information related to an output. -#[derive(Clone, Debug)] -pub enum OutputInfo { - /// Write the local variable to a new array. - /// - /// This will create a new binding in the [kernel definition](KernelDefinition). - ArrayWrite { - item: Item, - local: u16, - position: Variable, - }, - /// Write the local variable to an existing input binding. - InputArrayWrite { - item: Item, - input: u16, - local: u16, - position: Variable, - }, - /// Simply register the output, but don't automatically add a write to it. - /// - /// Useful when a procedure writes to the output using operations. 
- Array { item: Item }, -} - -impl OutputInfo { - #[allow(dead_code)] - pub fn elem_size(&self) -> usize { - let elem = match self { - OutputInfo::ArrayWrite { - item, - local: _, - position: _, - } => bool_elem(item.elem()), - OutputInfo::InputArrayWrite { - item, - input: _, - local: _, - position: _, - } => bool_elem(item.elem()), - OutputInfo::Array { item } => bool_elem(item.elem()), - }; - ::elem_size(elem) - } -} - -impl KernelIntegrator { - /// Starts a new compilation. - pub fn new(info: KernelExpansion) -> Self { - Self { - expansion: info, - input_bindings: Default::default(), - output_bindings: Default::default(), - named_bindings: Default::default(), - } - } - - /// Performs the compilation with the provided [settings](KernelSettings). - pub fn integrate(mut self, mut settings: KernelSettings) -> KernelDefinition { - if let Some(vectorization) = settings.vectorization_global { - self.expansion.scope.vectorize(vectorization); - } - - self.register_inputs(&settings); - self.register_outputs(&mut settings); - - let inputs = self.input_bindings; - let outputs = self.output_bindings; - let mut named = Vec::with_capacity(2); - - named.push(( - "info".to_string(), - Binding { - item: Item::new(Elem::UInt), - visibility: Visibility::Read, - location: Location::Storage, - size: None, // We avoid putting the length here since it will force a new kernel - // for each tensor rank. - }, - )); - - for (name, binding) in self.named_bindings.into_iter() { - named.push((name, binding)); - } - - KernelDefinition { - inputs, - outputs, - named, - cube_dim: settings.cube_dim, - body: self.expansion.scope, - } - } - - fn register_inputs(&mut self, settings: &KernelSettings) { - for (id, strategy) in settings.reading_strategy.iter() { - self.expansion.scope.update_read(*id, *strategy); - } - - for input in self.expansion.inputs.drain(..) { - match input { - InputInfo::Array { item, visibility } => { - let item = if let Some(vectorization) = settings.vectorization_global { - item.vectorize(vectorization) - } else { - item - }; - - self.input_bindings.push(Binding { - item: bool_item(item), - visibility, - location: Location::Storage, - size: None, - }); - } - InputInfo::Scalar { elem, size } => { - let elem = bool_elem(elem); - - self.named_bindings.push(( - format!("scalars_{}", elem), - Binding { - item: Item::new(elem), - visibility: Visibility::Read, - location: Location::Storage, - size: Some(size), - }, - )); - } - } - } - } - - fn register_outputs(&mut self, settings: &mut KernelSettings) { - let mut index = 0; - - if !settings.mappings.is_empty() { - let mut mappings = Vec::new(); - core::mem::swap(&mut settings.mappings, &mut mappings); - - for mapping in mappings { - self.register_inplace_mapping(mapping); - } - } - - for array in self.expansion.outputs.drain(..) 
{ - match array { - OutputInfo::ArrayWrite { - item, - local, - position, - } => { - let item = if let Some(vectorization) = settings.vectorization_global { - item.vectorize(vectorization) - } else { - item - }; - let item_adapted = bool_item(item); - - self.output_bindings.push(Binding { - item: item_adapted, - visibility: Visibility::ReadWrite, - location: Location::Storage, - size: None, - }); - self.expansion.scope.write_global( - Variable::Local { - id: local, - item, - depth: self.expansion.scope.depth, - }, - Variable::GlobalOutputArray { - id: index, - item: item_adapted, - }, - position, - ); - index += 1; - } - OutputInfo::InputArrayWrite { - item, - input, - local, - position, - } => { - let item = if let Some(vectorization) = settings.vectorization_global { - item.vectorize(vectorization) - } else { - item - }; - - self.expansion.scope.write_global( - Variable::Local { - id: local, - item, - depth: self.expansion.scope.depth, - }, - Variable::GlobalInputArray { - id: input, - item: bool_item(item), - }, - position, - ); - } - OutputInfo::Array { item } => { - let item = if let Some(vectorization) = settings.vectorization_global { - item.vectorize(vectorization) - } else { - item - }; - let elem_adapted = bool_item(item); - - self.output_bindings.push(Binding { - item: elem_adapted, - visibility: Visibility::ReadWrite, - location: Location::Storage, - size: None, - }); - - index += 1; - } - } - } - } - - fn register_inplace_mapping(&mut self, mapping: InplaceMapping) { - let output = match self.expansion.outputs.get_mut(mapping.pos_output) { - Some(output) => output, - None => { - // The mapping is handled differently, normally by cube itself. - return; - } - }; - - let (item, local, position) = match output { - OutputInfo::ArrayWrite { item, local, position } => (item, local, position), - OutputInfo::InputArrayWrite { - item: _, - input, - local: _, - position: _, - } => { - assert_eq!( - *input, mapping.pos_input as u16, - "Can't use different inputs for the same output." - ); - return; - } - OutputInfo::Array { item: _ } => panic!("Can't register an inplace operation for an array that isn't using a defined writing strategy."), - }; - - let item = match self.input_bindings.get_mut(mapping.pos_input) { - Some(binding) => { - // Update input visibility. - binding.visibility = Visibility::ReadWrite; - // Inputs modified inplace should be read without any specified layout. - self.expansion - .scope - .update_read(mapping.pos_input as u16, ReadingStrategy::Plain); - - // Use the same item as the input. - // - // The output can be different (i.e inplace boolean operations on float bindings). - binding.item - } - None => *item, - }; - - // Update the output. 
- *output = OutputInfo::InputArrayWrite { - item, - input: mapping.pos_input as u16, - local: *local, - position: *position, - }; - } -} - -fn bool_item(ty: Item) -> Item { - Item { - elem: bool_elem(ty.elem), - vectorization: ty.vectorization, - } -} - -pub fn bool_elem(elem: Elem) -> Elem { - match elem { - // U32 are used for bool tensors - Elem::Bool => Elem::UInt, - _ => elem, - } -} diff --git a/crates/burn-cube/src/codegen/mod.rs b/crates/burn-cube/src/codegen/mod.rs deleted file mode 100644 index b94d342875..0000000000 --- a/crates/burn-cube/src/codegen/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod execution; -mod integrator; - -mod compiler; - -pub use compiler::*; -pub use execution::*; -pub use integrator::*; diff --git a/crates/burn-cube/src/compute/builder.rs b/crates/burn-cube/src/compute/builder.rs deleted file mode 100644 index 5664a9f619..0000000000 --- a/crates/burn-cube/src/compute/builder.rs +++ /dev/null @@ -1,104 +0,0 @@ -use crate::ir::{Elem, Item, Visibility}; -use crate::prelude::KernelDefinition; -use crate::KernelSettings; -use crate::{ - frontend::{CubeContext, ExpandElement}, - InputInfo, KernelExpansion, KernelIntegrator, OutputInfo, -}; -use std::collections::HashMap; - -/// Prepare a kernel to create a [kernel definition](crate::KernelDefinition). -pub struct KernelBuilder { - /// Cube [context](CubeContext). - pub context: CubeContext, - inputs: Vec, - outputs: Vec, - indices: HashMap, - num_input: u16, - num_output: u16, -} - -impl KernelBuilder { - /// Register a scalar and return the [element](ExpandElement) to be used for kernel expansion. - pub fn scalar(&mut self, elem: Elem) -> ExpandElement { - let index = match self.indices.get_mut(&elem) { - Some(index) => match self.inputs.get_mut(*index).unwrap() { - InputInfo::Scalar { elem: _, size } => { - *size += 1; - *size as u16 - 1 - } - _ => panic!("Should be a scalar."), - }, - None => { - self.indices.insert(elem, self.inputs.len()); - self.inputs.push(InputInfo::Scalar { size: 1, elem }); - 0 - } - }; - - self.context.scalar(index, elem) - } - - /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion. - pub fn output_tensor(&mut self, item: Item) -> ExpandElement { - self.outputs.push(OutputInfo::Array { item }); - let variable = self.context.output(self.num_output, item); - self.num_output += 1; - - variable - } - - /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion. - pub fn input_tensor(&mut self, item: Item) -> ExpandElement { - self.inputs.push(InputInfo::Array { - item, - visibility: Visibility::Read, - }); - let variable = self.context.input(self.num_input, item); - self.num_input += 1; - variable - } - - /// Register an output array and return the [element](ExpandElement) to be used for kernel expansion. - pub fn output_array(&mut self, item: Item) -> ExpandElement { - self.outputs.push(OutputInfo::Array { item }); - let variable = self.context.output(self.num_output, item); - self.num_output += 1; - - variable - } - - /// Register an input array and return the [element](ExpandElement) to be used for kernel expansion. - pub fn input_array(&mut self, item: Item) -> ExpandElement { - self.inputs.push(InputInfo::Array { - item, - visibility: Visibility::Read, - }); - let variable = self.context.input(self.num_input, item); - self.num_input += 1; - variable - } - - /// Build the [kernel definition](KernelDefinition). 
- pub fn build(self, settings: KernelSettings) -> KernelDefinition { - KernelIntegrator::new(KernelExpansion { - scope: self.context.into_scope(), - inputs: self.inputs, - outputs: self.outputs, - }) - .integrate(settings) - } -} - -impl Default for KernelBuilder { - fn default() -> Self { - Self { - context: CubeContext::root(), - inputs: Vec::new(), - outputs: Vec::new(), - indices: HashMap::new(), - num_input: 0, - num_output: 0, - } - } -} diff --git a/crates/burn-cube/src/compute/kernel.rs b/crates/burn-cube/src/compute/kernel.rs deleted file mode 100644 index 356ecb4e04..0000000000 --- a/crates/burn-cube/src/compute/kernel.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::marker::PhantomData; - -use crate::{codegen::CompilerRepresentation, ir::CubeDim, Compiler, Kernel}; -use alloc::sync::Arc; -use burn_compute::server::{Binding, ComputeServer}; - -/// A kernel, compiled in the target language -pub struct CompiledKernel { - /// Source code of the kernel - pub source: String, - /// Size of a cube for the compiled kernel - pub cube_dim: CubeDim, - /// The number of bytes used by the share memory - pub shared_mem_bytes: usize, -} - -/// Kernel trait with the ComputeShader that will be compiled and cached based on the -/// provided id. -pub trait CubeTask: Send + Sync { - /// Identifier for the kernel, used for caching kernel compilation. - fn id(&self) -> String; - /// Compile the kernel into source - fn compile(&self) -> CompiledKernel; -} - -/// Wraps a [kernel](Kernel) to create a [cube task](CubeTask). -#[derive(new)] -pub struct KernelTask { - kernel_definition: K, - _compiler: PhantomData, -} - -impl CubeTask for KernelTask { - fn compile(&self) -> CompiledKernel { - let gpu_ir = self.kernel_definition.define(); - let cube_dim = gpu_ir.cube_dim; - let lower_level_ir = C::compile(gpu_ir); - let shared_mem_bytes = lower_level_ir.shared_memory_size(); - let source = lower_level_ir.to_string(); - - CompiledKernel { - source, - cube_dim, - shared_mem_bytes, - } - } - - fn id(&self) -> String { - self.kernel_definition.id().clone() - } -} - -impl CubeTask for Arc { - fn compile(&self) -> CompiledKernel { - self.as_ref().compile() - } - - fn id(&self) -> String { - self.as_ref().id() - } -} - -impl CubeTask for Box { - fn compile(&self) -> CompiledKernel { - self.as_ref().compile() - } - - fn id(&self) -> String { - self.as_ref().id() - } -} - -/// Provides launch information specifying the number of work groups to be used by a compute shader. -pub enum CubeCount { - /// Dispatch x,y,z work groups. - Static(u32, u32, u32), - /// Dispatch work groups based on the values in this buffer. The buffer should contain a u32 array [x, y, z]. 
- Dynamic(Binding), -} - -impl Clone for CubeCount { - fn clone(&self) -> Self { - match self { - Self::Static(x, y, z) => Self::Static(*x, *y, *z), - Self::Dynamic(handle) => Self::Dynamic(handle.clone()), - } - } -} diff --git a/crates/burn-cube/src/compute/launcher.rs b/crates/burn-cube/src/compute/launcher.rs deleted file mode 100644 index 7bc292229b..0000000000 --- a/crates/burn-cube/src/compute/launcher.rs +++ /dev/null @@ -1,337 +0,0 @@ -use crate::compute::{CubeCount, KernelTask}; -use crate::ir::{Elem, FloatKind, IntKind}; -use crate::prelude::ArrayHandle; -use crate::KernelSettings; -use crate::{calculate_num_elems_dyn_rank, frontend::TensorHandle, Kernel, Runtime}; -use burn_compute::client::ComputeClient; -use burn_compute::server::Binding; -use bytemuck::NoUninit; -use num_traits::ToPrimitive; - -/// Prepare a kernel for [launch](KernelLauncher::launch). -pub struct KernelLauncher { - tensors: TensorState, - scalar_bf16: ScalarState, - scalar_f16: ScalarState, - scalar_f32: ScalarState, - scalar_f64: ScalarState, - scalar_u32: ScalarState, - scalar_i64: ScalarState, - scalar_i32: ScalarState, - scalar_order: Vec, - pub settings: KernelSettings, -} - -impl KernelLauncher { - /// Register a tensor to be launched. - pub fn register_tensor(&mut self, tensor: &TensorHandle<'_, R>) { - self.tensors.push(tensor); - } - - /// Register an array to be launched. - pub fn register_array(&mut self, array: &ArrayHandle<'_, R>) { - self.tensors.push(&array.as_tensor()); - } - - /// Register a u32 scalar to be launched. - pub fn register_u32(&mut self, scalar: u32) { - self.register_scalar(Elem::UInt); - self.scalar_u32.push(scalar); - } - - /// Register a i32 scalar to be launched. - pub fn register_i32(&mut self, scalar: i32) { - self.register_scalar(Elem::Int(IntKind::I32)); - self.scalar_i32.push(scalar); - } - - /// Register a i64 scalar to be launched. - pub fn register_i64(&mut self, scalar: i64) { - self.register_scalar(Elem::Int(IntKind::I64)); - self.scalar_i64.push(scalar); - } - - /// Register a bf16 scalar to be launched. - pub fn register_bf16(&mut self, scalar: half::bf16) { - self.register_scalar(Elem::Float(FloatKind::BF16)); - self.scalar_bf16.push(scalar); - } - - /// Register a f16 scalar to be launched. - pub fn register_f16(&mut self, scalar: half::f16) { - self.register_scalar(Elem::Float(FloatKind::F16)); - self.scalar_f16.push(scalar); - } - - /// Register a f32 scalar to be launched. - pub fn register_f32(&mut self, scalar: f32) { - self.register_scalar(Elem::Float(FloatKind::F32)); - self.scalar_f32.push(scalar); - } - - /// Register a f64 scalar to be launched. - pub fn register_f64(&mut self, scalar: f64) { - self.register_scalar(Elem::Float(FloatKind::F64)); - self.scalar_f64.push(scalar); - } - - /// Launch the kernel. - pub fn launch( - self, - cube_count: CubeCount, - kernel: K, - client: ComputeClient, - ) { - let bindings = self.into_bindings(&client); - - let kernel = Box::new(KernelTask::::new(kernel)); - - client.execute(kernel, cube_count, bindings); - } - - /// We need to create the bindings in the same order they are defined in the compilation step. - /// - /// The function [crate::KernelIntegrator::integrate] stars by registering the input tensors followed - /// by the output tensors. Then the tensor metadata, and the scalars at the end. The scalars - /// are registered in the same order they are added. This is why we store the scalar data type - /// in the `scalar_order` vector, so that we can register them in the same order. 
- fn into_bindings( - mut self, - client: &ComputeClient, - ) -> Vec> { - let mut bindings = Vec::new(); - - self.tensors.register(client, &mut bindings); - - for elem in self.scalar_order.drain(..) { - match elem { - Elem::Float(kind) => match kind { - FloatKind::F16 => self.scalar_f16.register::(client, &mut bindings), - FloatKind::BF16 => self.scalar_bf16.register::(client, &mut bindings), - FloatKind::F32 => self.scalar_f32.register::(client, &mut bindings), - FloatKind::F64 => self.scalar_f64.register::(client, &mut bindings), - }, - Elem::Int(kind) => match kind { - IntKind::I32 => self.scalar_i32.register::(client, &mut bindings), - IntKind::I64 => self.scalar_i64.register::(client, &mut bindings), - }, - Elem::UInt => self.scalar_u32.register::(client, &mut bindings), - Elem::Bool => panic!("Bool can't be passed as bindings."), - } - } - - bindings - } - - fn register_scalar(&mut self, elem: Elem) { - if !self.scalar_order.contains(&elem) { - self.scalar_order.push(elem); - } - } -} - -/// Handles the tensor state. -pub enum TensorState { - /// No tensor is registered yet. - Empty, - /// The registered tensors. - Some { - bindings: Vec>, - metadata: Vec, - lengths: Vec, - }, -} - -/// Handles the scalar state of an element type -/// -/// The scalars are grouped to reduce the number of buffers needed to send data to the compute device. -pub enum ScalarState { - /// No scalar of that type is registered yet. - Empty, - /// The registered scalars. - Some(Vec), -} - -impl TensorState { - /// Push a new tensor to the state. - pub fn push(&mut self, tensor: &TensorHandle<'_, R>) { - if let TensorState::Empty = self { - *self = TensorState::Some { - bindings: Vec::with_capacity(1), - metadata: Vec::new(), - lengths: Vec::new(), - }; - }; - - let (bindings, metadata, lengths) = match self { - TensorState::Empty => panic!("Should be init"), - TensorState::Some { - bindings, - metadata, - lengths, - } => (bindings, metadata, lengths), - }; - - bindings.push(tensor.handle.clone().binding()); - - let old_rank = if metadata.is_empty() { - let rank = tensor.strides.len() as u32; - metadata.push(rank); - None - } else if tensor.strides.len() > metadata[0] as usize { - let old_rank = metadata[0]; - let rank = tensor.strides.len() as u32; - Self::adjust_rank(metadata, bindings.len(), rank); - Some(old_rank) - } else { - None - }; - - Self::register_strides(tensor.strides, tensor.shape, old_rank, metadata); - Self::register_shape(tensor.shape, old_rank, metadata); - - if R::require_array_lengths() { - let len = calculate_num_elems_dyn_rank(tensor.shape); - lengths.push(len as u32); - } - } - - fn adjust_rank(metadata: &mut Vec, num_registered: usize, rank: u32) { - let old_rank = metadata[0] as usize; - let rank_diff = rank as usize - old_rank; - let mut updated_metadata = Vec::with_capacity(2 * rank_diff * num_registered); - - for pos in 0..num_registered { - let stride_index = (pos * old_rank * 2) + 1; - let shape_index = stride_index + old_rank; - - let strides_old = &metadata[stride_index..stride_index + old_rank]; - let shape_old = &metadata[shape_index..shape_index + old_rank]; - - Self::register_strides( - strides_old, - shape_old, - Some(old_rank as u32), - &mut updated_metadata, - ); - Self::register_shape(shape_old, Some(old_rank as u32), &mut updated_metadata); - } - - core::mem::swap(&mut updated_metadata, metadata); - } - - fn register_strides( - strides: &[T], - shape: &[T], - old_rank: Option, - output: &mut Vec, - ) { - let old_rank = if let Some(old_rank) = old_rank { - let rank = 
output[0]; - let rank_diff = old_rank - rank; - let padded_strides = if rank_diff > 0 { - shape - .iter() - .take(old_rank as usize) - .map(|a| a.to_u32().unwrap()) - .sum::() - } else { - 0 - }; - - for _ in 0..rank_diff { - output.push(padded_strides.to_u32().unwrap()); - } - - old_rank as usize - } else { - output[0] as usize // same as current. - }; - - for stride in strides.iter().take(old_rank) { - output.push(stride.to_u32().unwrap()); - } - } - - fn register_shape(shape: &[T], old_rank: Option, output: &mut Vec) { - let old_rank = if let Some(old_rank) = old_rank { - let rank = output[0]; - let rank_diff = rank - old_rank; - - for _ in 0..rank_diff { - output.push(1); - } - - old_rank as usize - } else { - output[0] as usize // same as current - }; - - for elem in shape.iter().take(old_rank) { - output.push(elem.to_u32().unwrap()); - } - } - - fn register( - self, - client: &ComputeClient, - bindings_global: &mut Vec>, - ) { - if let Self::Some { - bindings, - mut metadata, - lengths, - } = self - { - if R::require_array_lengths() { - for len in lengths { - metadata.push(len); - } - } - - bindings_global.extend(bindings); - bindings_global.push(client.create(bytemuck::cast_slice(&metadata)).binding()); - } - } -} - -impl ScalarState { - /// Add a new scalar value to the state. - pub fn push(&mut self, val: T) { - match self { - ScalarState::Empty => *self = Self::Some(vec![val]), - ScalarState::Some(values) => values.push(val), - } - } - - fn register( - &self, - client: &ComputeClient, - bindings: &mut Vec>, - ) { - match self { - ScalarState::Empty => (), - ScalarState::Some(values) => { - let handle = client.create(bytemuck::cast_slice(values)); - bindings.push(handle.binding()); - } - } - } -} - -impl Default for KernelLauncher { - fn default() -> Self { - Self { - tensors: TensorState::Empty, - scalar_bf16: ScalarState::Empty, - scalar_f16: ScalarState::Empty, - scalar_f32: ScalarState::Empty, - scalar_f64: ScalarState::Empty, - scalar_u32: ScalarState::Empty, - scalar_i64: ScalarState::Empty, - scalar_i32: ScalarState::Empty, - scalar_order: Vec::new(), - settings: Default::default(), - } - } -} diff --git a/crates/burn-cube/src/compute/mod.rs b/crates/burn-cube/src/compute/mod.rs deleted file mode 100644 index 200399f700..0000000000 --- a/crates/burn-cube/src/compute/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod builder; -mod kernel; -mod launcher; - -pub use builder::*; -pub use kernel::*; -pub use launcher::*; diff --git a/crates/burn-cube/src/frontend/base.rs b/crates/burn-cube/src/frontend/base.rs deleted file mode 100644 index ba92c36c54..0000000000 --- a/crates/burn-cube/src/frontend/base.rs +++ /dev/null @@ -1,12 +0,0 @@ -#[macro_export] -macro_rules! unexpanded { - () => ({ - panic!("Unexpanded Cube functions should not be called. 
"); - }); - ($msg:expr) => ({ - panic!($msg); - }); - ($fmt:expr, $($arg:tt)*) => ({ - panic!($fmt, $($arg)*); - }); -} diff --git a/crates/burn-cube/src/frontend/branch.rs b/crates/burn-cube/src/frontend/branch.rs deleted file mode 100644 index 6caf4c34ea..0000000000 --- a/crates/burn-cube/src/frontend/branch.rs +++ /dev/null @@ -1,153 +0,0 @@ -use std::ops::Deref; - -use crate::frontend::{CubeContext, ExpandElement, UInt}; -use crate::ir::{Branch, Elem, If, IfElse, Item, Loop, RangeLoop, Variable}; - -use super::comptime::Comptime; - -pub fn range(start: S, end: E, _unroll: Comptime) -> impl Iterator -where - S: Into, - E: Into, -{ - let start: UInt = start.into(); - let end: UInt = end.into(); - - (start.val..end.val).map(UInt::new) -} - -pub fn range_expand(context: &mut CubeContext, start: S, end: E, unroll: bool, mut func: F) -where - F: FnMut(&mut CubeContext, ExpandElement), - S: Into, - E: Into, -{ - let start: ExpandElement = start.into(); - let end: ExpandElement = end.into(); - - if unroll { - let start = match start.deref() { - Variable::ConstantScalar { value, .. } => *value as usize, - _ => panic!("Only constant start can be unrolled."), - }; - let end = match end.deref() { - Variable::ConstantScalar { value, .. } => *value as usize, - _ => panic!("Only constant end can be unrolled."), - }; - - for i in start..end { - func(context, i.into()) - } - } else { - let mut child = context.child(); - let index_ty = Item::new(Elem::UInt); - let i = child.scope.borrow_mut().create_local_undeclared(index_ty); - let i = ExpandElement::Plain(i); - - func(&mut child, i.clone()); - - context.register(Branch::RangeLoop(RangeLoop { - i: *i, - start: *start, - end: *end, - scope: child.into_scope(), - })); - } -} - -pub fn if_expand( - context: &mut CubeContext, - comptime_cond: Option, - runtime_cond: ExpandElement, - mut block: IF, -) where - IF: FnMut(&mut CubeContext), -{ - match comptime_cond { - Some(cond) => { - if cond { - block(context); - } - } - None => { - let mut child = context.child(); - - block(&mut child); - - context.register(Branch::If(If { - cond: *runtime_cond, - scope: child.into_scope(), - })); - } - } -} - -pub fn if_else_expand( - context: &mut CubeContext, - comptime_cond: Option, - runtime_cond: ExpandElement, - mut then_block: IF, - mut else_block: EL, -) where - IF: FnMut(&mut CubeContext), - EL: FnMut(&mut CubeContext), -{ - match comptime_cond { - Some(cond) => { - if cond { - then_block(context); - } else { - else_block(context); - } - } - None => { - let mut then_child = context.child(); - then_block(&mut then_child); - - let mut else_child = context.child(); - else_block(&mut else_child); - - context.register(Branch::IfElse(IfElse { - cond: *runtime_cond, - scope_if: then_child.into_scope(), - scope_else: else_child.into_scope(), - })); - } - } -} - -pub fn break_expand(context: &mut CubeContext) { - context.register(Branch::Break); -} - -pub fn return_expand(context: &mut CubeContext) { - context.register(Branch::Return); -} - -pub fn loop_expand(context: &mut CubeContext, mut block: FB) -where - FB: FnMut(&mut CubeContext), -{ - let mut inside_loop = context.child(); - - block(&mut inside_loop); - context.register(Branch::Loop(Loop { - scope: inside_loop.into_scope(), - })); -} - -pub fn while_loop_expand(context: &mut CubeContext, mut cond_fn: FC, mut block: FB) -where - FC: FnMut(&mut CubeContext) -> ExpandElement, - FB: FnMut(&mut CubeContext), -{ - let mut inside_loop = context.child(); - - let cond: ExpandElement = cond_fn(&mut inside_loop); - 
if_expand(&mut inside_loop, None, cond, break_expand); - - block(&mut inside_loop); - context.register(Branch::Loop(Loop { - scope: inside_loop.into_scope(), - })); -} diff --git a/crates/burn-cube/src/frontend/cmma.rs b/crates/burn-cube/src/frontend/cmma.rs deleted file mode 100644 index 5490ea1300..0000000000 --- a/crates/burn-cube/src/frontend/cmma.rs +++ /dev/null @@ -1,238 +0,0 @@ -//! This module exposes cooperative matrix-multiply and accumulate operations. -//! -//! Most of the functions are actually unsafe, since they mutate their input, even if they are -//! passed as reference. -//! -//! # Example -//! -//! This is a basic 16x16x16 matrix multiplication example. -//! -//! ```rust, ignore -//! #[cube(launch)] -//! pub fn example(lhs: &Array, rhs: &Array, out: &mut Array) { -//! let a = cmma::Matrix::::new( -//! cmma::MatrixIdent::A, -//! 16, -//! 16, -//! 16, -//! cmma::MatrixLayout::RowMajor, -//! ); -//! let b = cmma::Matrix::::new( -//! cmma::MatrixIdent::B, -//! 16, -//! 16, -//! 16, -//! cmma::MatrixLayout::ColMajor, -//! ); -//! let c = cmma::Matrix::::new( -//! cmma::MatrixIdent::Accumulator, -//! 16, -//! 16, -//! 16, -//! cmma::MatrixLayout::Undefined, -//! ); -//! cmma::fill::(&c, F32::new(0.0)); -//! cmma::load::(&a, lhs.as_slice(), UInt::new(16)); -//! cmma::load::(&b, rhs.as_slice(), UInt::new(16)); -//! -//! cmma::execute::(&a, &b, &c, &c); -//! -//! cmma::store::( -//! out.as_slice_mut(), -//! &c, -//! UInt::new(16), -//! cmma::MatrixLayout::RowMajor, -//! ); -//! } -//! ``` - -use std::marker::PhantomData; - -use crate::{ - ir::{self, Operation}, - unexpanded, -}; - -use super::{ - CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, Slice, SliceMut, - UInt, -}; - -pub use ir::{MatrixIdent, MatrixLayout}; - -/// A matrix represent a 2D grid of numbers. -/// -/// They can either be in a [row major](MatrixLayout::RowMajor) or a -/// [column major](MatrixLayout::ColMajor) format. -pub struct Matrix { - _c: PhantomData, -} - -/// Expand type of [Matrix]. -#[derive(Clone)] -pub struct MatrixExpand { - elem: ExpandElement, -} - -impl CubeType for Matrix { - type ExpandType = MatrixExpand; -} - -impl Init for MatrixExpand { - fn init(self, _context: &mut CubeContext) -> Self { - self - } -} - -impl Matrix { - /// Create a new matrix that is going to be used in the - /// [matrix-multiply and accumulate](execute()) function. - /// - /// You have to declare the shape used for the execution. - /// The shape of the current matrix is determined using the [MatrixIdent]. - /// - /// * [MatrixIdent::A] Shape => (M, K) - /// * [MatrixIdent::B] Shape => (K, N) - /// * [MatrixIdent::Accumulator] Shape => (M, N) - /// - /// Not all shapes are supported, and the permitted shapes depend on the element type. - /// - /// Refer to [nvidia documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#element-types-and-matrix-sizes). - #[allow(unused_variables)] - pub fn new(ident: MatrixIdent, m: u8, n: u8, k: u8, layout: MatrixLayout) -> Self { - Matrix { _c: PhantomData } - } - - pub fn __expand_new( - context: &mut CubeContext, - ident: MatrixIdent, - m: u8, - n: u8, - k: u8, - layout: MatrixLayout, - ) -> MatrixExpand { - let elem = context.create_matrix(ir::Matrix { - ident, - m, - n, - k, - elem: C::as_elem(), - layout, - }); - MatrixExpand { elem } - } -} - -/// Fill the matrix with the provided value. 
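/// (Editor's note, not part of the original diff.) Like every operation in this
/// file, `fill` is a two-level API: the plain function is an `unexpanded!()`
/// stub that only exists so `#[cube]` code type-checks, while the sibling
/// `fill::__expand` function is what the macro calls during expansion to record
/// a `CoopMma::Fill` instruction in the kernel scope, roughly (the exact call
/// shape emitted by the macro is an assumption here):
///
/// ```rust, ignore
/// // What `cmma::fill::<F32>(&c, F32::new(0.0))` expands to:
/// cmma::fill::__expand::<F32>(context, c_expand, value_expand);
/// ```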
-#[allow(unused_variables)] -pub fn fill(mat: &Matrix, value: C) { - unexpanded!() -} - -/// Module containing the expand function for [fill()]. -pub mod fill { - use super::*; - - /// Expand method of [fill()]. - pub fn __expand( - context: &mut CubeContext, - mat: MatrixExpand, - value: ExpandElement, - ) { - context.register(Operation::CoopMma(ir::CoopMma::Fill { - mat: *mat.elem, - value: *value, - })); - } -} - -/// Load the matrix with the provided array using the stride. -#[allow(unused_variables)] -pub fn load(mat: &Matrix, value: &Slice<'_, C>, stride: UInt) { - unexpanded!() -} - -/// Module containing the expand function for [load()]. -pub mod load { - use super::*; - - /// Expand method of [load()]. - #[allow(unused_variables)] - pub fn __expand( - context: &mut CubeContext, - mat: MatrixExpand, - value: ExpandElementTyped>, - stride: ExpandElement, - ) { - context.register(Operation::CoopMma(ir::CoopMma::Load { - mat: *mat.elem, - value: *value.expand, - stride: *stride, - })); - } -} - -/// Store the matrix in the given array following the given stride and layout. -#[allow(unused_variables)] -pub fn store( - output: &mut SliceMut<'_, C>, - mat: &Matrix, - stride: UInt, - layout: MatrixLayout, -) { - unexpanded!() -} - -/// Module containing the expand function for [store()]. -pub mod store { - use super::*; - - /// Expand method of [store()]. - #[allow(unused_variables)] - pub fn __expand( - context: &mut CubeContext, - output: ExpandElementTyped>, - mat: MatrixExpand, - stride: ExpandElement, - layout: MatrixLayout, - ) { - context.register(Operation::CoopMma(ir::CoopMma::Store { - output: *output.expand, - mat: *mat.elem, - stride: *stride, - layout, - })); - } -} - -/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix). -#[allow(unused_variables)] -pub fn execute( - mat_a: &Matrix, - mat_b: &Matrix, - mat_c: &Matrix, - mat_d: &Matrix, -) { - unexpanded!() -} - -/// Module containing the expand function for [execute()]. -pub mod execute { - use super::*; - - /// Expand method of [execute()]. - pub fn __expand( - context: &mut CubeContext, - mat_a: MatrixExpand, - mat_b: MatrixExpand, - mat_c: MatrixExpand, - mat_d: MatrixExpand, - ) { - context.register(Operation::CoopMma(ir::CoopMma::Execute { - mat_a: *mat_a.elem, - mat_b: *mat_b.elem, - mat_c: *mat_c.elem, - mat_d: *mat_d.elem, - })); - } -} diff --git a/crates/burn-cube/src/frontend/comptime.rs b/crates/burn-cube/src/frontend/comptime.rs deleted file mode 100644 index 8f8d807bc0..0000000000 --- a/crates/burn-cube/src/frontend/comptime.rs +++ /dev/null @@ -1,146 +0,0 @@ -use crate::{ - frontend::{CubeContext, CubeType}, - unexpanded, -}; - -use super::{ExpandElement, Init, UInt, Vectorized}; - -#[derive(Clone, Copy)] -/// Encapsulates a value to signify it must be used at compilation time rather than in the kernel -/// -/// Use `Comptime>` to have an alternate runtime behaviour if the compilation time value is not present -pub struct Comptime { - pub(crate) inner: T, -} - -impl Comptime { - /// Create a new Comptime. Useful when hardcoding values in - /// Cube kernels. For instance: - /// if Comptime::new(false) {...} never generates the inner code block - pub fn new(inner: T) -> Self { - Self { inner } - } - - /// Get the inner value of a Comptime. 
For instance: - /// let c = Comptime::new(false); - /// if Comptime::get(c) {...} - pub fn get(_comptime: Self) -> T { - unexpanded!() - } - - pub fn map R>(_comptime: Self, _closure: F) -> Comptime { - unexpanded!() - } - - pub fn map_expand R>(inner: T, closure: F) -> R { - closure(inner) - } -} - -impl> Comptime> { - /// Map a Comptime optional to a Comptime boolean that tell - /// whether the optional contained a value - pub fn is_some(comptime: Self) -> Comptime { - Comptime::new(comptime.inner.is_some()) - } - - /// Return the inner value of the Comptime if it exists, - /// otherwise tell how to compute it at runtime - pub fn unwrap_or_else(_comptime: Self, mut _alt: F) -> T - where - F: FnOnce() -> T, - { - unexpanded!() - } - - /// Expanded version of unwrap_or_else - pub fn unwrap_or_else_expand( - context: &mut CubeContext, - t: Option, - alt: F, - ) -> ::ExpandType - where - F: FnOnce(&mut CubeContext) -> T::ExpandType, - { - match t { - Some(t) => t.into(), - None => alt(context), - } - } -} - -impl CubeType for Comptime { - type ExpandType = T; -} - -impl Comptime { - pub fn vectorization(_state: &T) -> Comptime { - unexpanded!() - } - - pub fn vectorization_expand(_context: &mut CubeContext, state: T) -> UInt { - state.vectorization_factor() - } -} - -impl> Comptime { - pub fn runtime(_comptime: Self) -> T { - unexpanded!() - } - - pub fn runtime_expand(_context: &mut CubeContext, inner: T) -> ExpandElement { - inner.into() - } -} - -impl> core::ops::Add for Comptime { - type Output = Comptime; - - fn add(self, rhs: Self) -> Self::Output { - Comptime::new(self.inner.add(rhs.inner)) - } -} - -impl> core::ops::Sub for Comptime { - type Output = Comptime; - - fn sub(self, rhs: Self) -> Self::Output { - Comptime::new(self.inner.sub(rhs.inner)) - } -} - -impl> core::ops::Div for Comptime { - type Output = Comptime; - - fn div(self, rhs: Self) -> Self::Output { - Comptime::new(self.inner.div(rhs.inner)) - } -} - -impl> core::ops::Mul for Comptime { - type Output = Comptime; - - fn mul(self, rhs: Self) -> Self::Output { - Comptime::new(self.inner.mul(rhs.inner)) - } -} - -impl> core::ops::Rem for Comptime { - type Output = Comptime; - - fn rem(self, rhs: Self) -> Self::Output { - Comptime::new(self.inner.rem(rhs.inner)) - } -} - -impl core::cmp::PartialEq for Comptime { - fn eq(&self, other: &Self) -> bool { - core::cmp::PartialEq::eq(&self.inner, &other.inner) - } -} - -impl core::cmp::PartialOrd for Comptime { - fn partial_cmp(&self, other: &Self) -> Option { - core::cmp::PartialOrd::partial_cmp(&self.inner, &other.inner) - } -} diff --git a/crates/burn-cube/src/frontend/context.rs b/crates/burn-cube/src/frontend/context.rs deleted file mode 100644 index 1d8c52290f..0000000000 --- a/crates/burn-cube/src/frontend/context.rs +++ /dev/null @@ -1,145 +0,0 @@ -use crate::frontend::ExpandElement; -use crate::ir::{self, Elem, Item, Operation, Scope}; -use alloc::rc::Rc; -use core::cell::RefCell; -use std::collections::HashMap; - -#[derive(Default, Clone)] -pub struct VariablePool { - map: Rc>>>, -} - -impl VariablePool { - /// Returns an old, not used anymore variable, if there exists one. - pub fn reuse(&self, item: Item) -> Option { - let map = self.map.borrow(); - - // Filter for candidate variables of the same Item - let variables = map.get(&item)?; - - // Among the candidates, take a variable if it's only referenced by the map - // Arbitrarily takes the first it finds in reverse order. 
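// (Editor's note, not part of the original diff.) The check below works because
// each ExpandElement::Managed wraps an Rc<Variable>: while kernel expansion
// still holds a clone, at least two strong references exist (the user's plus
// the pool's), so `Rc::strong_count(var) == 1` means only the pool still knows
// the variable and it can safely be handed out again. A minimal illustration:
//
//     use std::rc::Rc;
//     let pooled = Rc::new(0u32);
//     assert_eq!(Rc::strong_count(&pooled), 1); // only the pool owns it: reusable
//     let in_use = pooled.clone();
//     assert_eq!(Rc::strong_count(&pooled), 2); // still referenced elsewhere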
- for variable in variables.iter().rev() { - match variable { - ExpandElement::Managed(var) => { - if Rc::strong_count(var) == 1 { - return Some(variable.clone()); - } - } - ExpandElement::Plain(_) => (), - } - } - - // If no candidate was found, a new var will be needed - None - } - - /// Insert a new variable in the map, which is classified by Item - pub fn insert(&mut self, var: ExpandElement) { - let mut map = self.map.borrow_mut(); - let item = var.item(); - - if let Some(variables) = map.get_mut(&item) { - variables.push(var.clone()); - } else { - map.insert(var.item(), vec![var.clone()]); - } - } -} - -pub struct CubeContext { - pub root: Rc>, - pub scope: Rc>, - pub pool: VariablePool, -} - -impl CubeContext { - /// Create a new cube context, with a root scope - /// A root scope is at the root of a compute shader - /// Therefore there is one cube context per shader - pub fn root() -> CubeContext { - let root = Rc::new(RefCell::new(Scope::root())); - let scope = root.clone(); - - Self { - pool: Default::default(), - scope, - root, - } - } - - pub fn register>(&mut self, op: O) { - self.scope.borrow_mut().register(op) - } - - pub fn child(&mut self) -> CubeContext { - let scope = self.scope.borrow_mut().child(); - - Self { - scope: Rc::new(RefCell::new(scope)), - root: self.root.clone(), - pool: self.pool.clone(), - } - } - - pub fn into_scope(self) -> Scope { - core::mem::drop(self.root); - - Rc::into_inner(self.scope) - .expect("Only one reference") - .into_inner() - } - - /// When a new variable is required, we check if we can reuse an old one - /// Otherwise we create a new one. - pub fn create_local(&mut self, item: Item) -> ExpandElement { - // Reuse an old variable if possible - if let Some(var) = self.pool.reuse(item) { - return var; - } - - // Create a new variable at the root scope - // Insert it in the variable pool for potential reuse - let new = ExpandElement::Managed(Rc::new(self.root.borrow_mut().create_local(item))); - self.pool.insert(new.clone()); - - new - } - - /// Create a new matrix element. - pub fn create_matrix(&mut self, matrix: ir::Matrix) -> ExpandElement { - let variable = self.scope.borrow_mut().create_matrix(matrix); - ExpandElement::Plain(variable) - } - - /// Create a new slice element. 
- pub fn create_slice(&mut self, item: Item) -> ExpandElement { - let variable = self.scope.borrow_mut().create_slice(item); - ExpandElement::Plain(variable) - } - - pub fn create_shared(&mut self, item: Item, size: u32) -> ExpandElement { - ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size)) - } - - pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement { - ExpandElement::Plain(self.root.borrow_mut().create_local_array(item, size)) - } - - /// Obtain the index-th input - pub fn input(&mut self, id: u16, item: Item) -> ExpandElement { - ExpandElement::Plain(crate::ir::Variable::GlobalInputArray { id, item }) - } - - /// Obtain the index-th output - pub fn output(&mut self, id: u16, item: Item) -> ExpandElement { - let var = crate::ir::Variable::GlobalOutputArray { id, item }; - self.scope.borrow_mut().write_global_custom(var); - ExpandElement::Plain(var) - } - - /// Obtain the index-th scalar - pub fn scalar(&self, id: u16, elem: Elem) -> ExpandElement { - ExpandElement::Plain(crate::ir::Variable::GlobalScalar { id, elem }) - } -} diff --git a/crates/burn-cube/src/frontend/element/array.rs b/crates/burn-cube/src/frontend/element/array.rs deleted file mode 100644 index 5831a8d476..0000000000 --- a/crates/burn-cube/src/frontend/element/array.rs +++ /dev/null @@ -1,238 +0,0 @@ -use std::marker::PhantomData; - -use crate::{ - compute::{KernelBuilder, KernelLauncher}, - frontend::CubeType, - ir::{Item, Vectorization}, - unexpanded, KernelSettings, Runtime, -}; -use crate::{ - frontend::{indexation::Index, CubeContext}, - prelude::{assign, index, index_assign, Comptime}, -}; - -use super::{ - ArgSettings, CubePrimitive, ExpandElement, ExpandElementTyped, Init, LaunchArg, - LaunchArgExpand, TensorHandle, UInt, -}; - -/// A contiguous array of elements. -pub struct Array { - _val: PhantomData, -} - -impl CubeType for Array { - type ExpandType = ExpandElementTyped>; -} - -impl Array { - pub fn new(_size: S) -> Self { - Array { _val: PhantomData } - } - - pub fn vectorized(_size: S, _vectorization_factor: UInt) -> Self { - Array { _val: PhantomData } - } - - pub fn __expand_new( - context: &mut CubeContext, - size: S, - ) -> ::ExpandType { - let size = size.value(); - let size = match size { - crate::ir::Variable::ConstantScalar { value, .. } => value as u32, - _ => panic!("Array need constant initialization value"), - }; - context - .create_local_array(Item::new(T::as_elem()), size) - .into() - } - - pub fn __expand_vectorized( - context: &mut CubeContext, - size: S, - vectorization_factor: UInt, - ) -> ::ExpandType { - let size = size.value(); - let size = match size { - crate::ir::Variable::ConstantScalar { value, .. 
} => value as u32, - _ => panic!("Shared memory need constant initialization value"), - }; - context - .create_local_array( - Item::vectorized(T::as_elem(), vectorization_factor.val as u8), - size, - ) - .into() - } - - pub fn to_vectorized(self, _vectorization_factor: Comptime) -> T { - unexpanded!() - } -} - -impl ExpandElementTyped> { - pub fn to_vectorized_expand( - self, - context: &mut CubeContext, - vectorization_factor: UInt, - ) -> ExpandElement { - let factor = vectorization_factor.val; - let var = self.expand.clone(); - let mut new_var = context.create_local(Item::vectorized(var.item().elem(), factor as u8)); - if vectorization_factor.val == 1 { - let element = index::expand(context, self.clone(), 0u32); - assign::expand(context, element, new_var.clone()); - } else { - for i in 0..factor { - let element = index::expand(context, self.expand.clone(), i); - new_var = index_assign::expand(context, new_var, i, element); - } - } - new_var - } -} - -impl CubeType for &Array { - type ExpandType = ExpandElementTyped>; -} -impl Init for ExpandElementTyped> { - fn init(self, _context: &mut crate::prelude::CubeContext) -> Self { - // The type can't be deeply cloned/copied. - self - } -} - -impl Array { - /// Obtain the array length - pub fn len(&self) -> UInt { - unexpanded!() - } -} - -impl LaunchArg for Array { - type RuntimeArg<'a, R: Runtime> = ArrayArg<'a, R>; -} - -impl LaunchArgExpand for Array { - fn expand( - builder: &mut KernelBuilder, - vectorization: Vectorization, - ) -> ExpandElementTyped> { - builder - .input_array(Item::vectorized(C::as_elem(), vectorization)) - .into() - } - fn expand_output( - builder: &mut KernelBuilder, - vectorization: Vectorization, - ) -> ExpandElementTyped> { - builder - .output_array(Item::vectorized(C::as_elem(), vectorization)) - .into() - } -} - -/// Tensor representation with a reference to the [server handle](burn_compute::server::Handle). -pub struct ArrayHandle<'a, R: Runtime> { - pub handle: &'a burn_compute::server::Handle, - pub length: [usize; 1], -} - -pub enum ArrayArg<'a, R: Runtime> { - /// The array is passed with an array handle. - Handle { - /// The array handle. - handle: ArrayHandle<'a, R>, - /// The vectorization factor. - vectorization_factor: u8, - }, - /// The array is aliasing another input array. - Alias { - /// The position of the input array. - input_pos: usize, - }, -} - -impl<'a, R: Runtime> ArgSettings for ArrayArg<'a, R> { - fn register(&self, launcher: &mut KernelLauncher) { - if let ArrayArg::Handle { - handle, - vectorization_factor: _, - } = self - { - launcher.register_array(handle) - } - } - - fn configure_input(&self, position: usize, settings: KernelSettings) -> KernelSettings { - match self { - Self::Handle { - handle: _, - vectorization_factor, - } => settings.vectorize_input(position, *vectorization_factor), - Self::Alias { input_pos: _ } => { - panic!("Not yet supported, only output can be aliased for now."); - } - } - } - - fn configure_output(&self, position: usize, mut settings: KernelSettings) -> KernelSettings { - match self { - Self::Handle { - handle: _, - vectorization_factor, - } => settings.vectorize_output(position, *vectorization_factor), - Self::Alias { input_pos } => { - settings.mappings.push(crate::InplaceMapping { - pos_input: *input_pos, - pos_output: position, - }); - settings - } - } - } -} - -impl<'a, R: Runtime> ArrayArg<'a, R> { - /// Create a new array argument. - /// - /// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization - /// factor of 1. 
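/// (Editor's sketch, not part of the original diff; the launch function name and
/// handles below are hypothetical.) Assuming a kernel generated with
/// `#[cube(launch)]` that reads one input array and writes one output, the two
/// constructors are used at launch time along these lines:
///
/// ```rust, ignore
/// my_kernel_launch(
///     client.clone(),
///     cube_count,
///     cube_dim,
///     ArrayArg::new(&input_handle, input_len),              // input, factor 1
///     ArrayArg::vectorized(4, &output_handle, output_len),  // output, 4 lanes
/// );
/// ```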
- pub fn new(handle: &'a burn_compute::server::Handle, length: usize) -> Self { - ArrayArg::Handle { - handle: ArrayHandle::new(handle, length), - vectorization_factor: 1, - } - } - /// Create a new array argument specified with its vectorization factor. - pub fn vectorized( - vectorization_factor: u8, - handle: &'a burn_compute::server::Handle, - length: usize, - ) -> Self { - ArrayArg::Handle { - handle: ArrayHandle::new(handle, length), - vectorization_factor, - } - } -} - -impl<'a, R: Runtime> ArrayHandle<'a, R> { - pub fn new(handle: &'a burn_compute::server::Handle, length: usize) -> Self { - Self { - handle, - length: [length], - } - } - - pub fn as_tensor(&self) -> TensorHandle<'_, R> { - let shape = &self.length; - - TensorHandle { - handle: self.handle, - strides: &[1], - shape, - } - } -} diff --git a/crates/burn-cube/src/frontend/element/base.rs b/crates/burn-cube/src/frontend/element/base.rs deleted file mode 100644 index f9de7ad4fe..0000000000 --- a/crates/burn-cube/src/frontend/element/base.rs +++ /dev/null @@ -1,277 +0,0 @@ -use std::marker::PhantomData; - -use crate::{ - ir::{Operator, Variable, Vectorization}, - prelude::{init_expand, CubeContext, KernelBuilder, KernelLauncher}, - KernelSettings, Runtime, -}; -use alloc::rc::Rc; - -use super::{UInt, Vectorized}; - -/// Types used in a cube function must implement this trait -/// -/// Variables whose values will be known at runtime must -/// have ExpandElement as associated type -/// Variables whose values will be known at compile time -/// must have the primitive type as associated type -/// -/// Note: Cube functions should be written using CubeTypes, -/// so that the code generated uses the associated ExpandType. -/// This allows Cube code to not necessitate cloning, which is cumbersome -/// in algorithmic code. The necessary cloning will automatically appear in -/// the generated code. -pub trait CubeType { - type ExpandType: Clone + Init; -} - -/// Trait to be implemented by [cube types](CubeType) implementations. -pub trait Init: Sized { - /// Initialize a type within a [context](CubeContext). - /// - /// You can return the same value when the variable is a non-mutable data structure or - /// if the type can not be deeply cloned/copied. - fn init(self, context: &mut CubeContext) -> Self; -} - -/// Defines how a [launch argument](LaunchArg) can be expanded. -/// -/// Normally this type should be implemented two times for an argument. -/// Once for the reference and the other for the mutable reference. Often time, the reference -/// should expand the argument as an input while the mutable reference should expand the argument -/// as an output. -pub trait LaunchArgExpand: CubeType { - /// Register an input variable during compilation that fill the [KernelBuilder]. - fn expand( - builder: &mut KernelBuilder, - vectorization: Vectorization, - ) -> ::ExpandType; - /// Register an output variable during compilation that fill the [KernelBuilder]. - fn expand_output( - builder: &mut KernelBuilder, - vectorization: Vectorization, - ) -> ::ExpandType { - Self::expand(builder, vectorization) - } -} - -/// Defines a type that can be used as argument to a kernel. -pub trait LaunchArg: LaunchArgExpand + Send + Sync + 'static { - /// The runtime argument for the kernel. 
- type RuntimeArg<'a, R: Runtime>: ArgSettings; -} - -impl LaunchArg for () { - type RuntimeArg<'a, R: Runtime> = (); -} - -impl ArgSettings for () { - fn register(&self, _launcher: &mut KernelLauncher) { - // nothing to do - } -} - -impl LaunchArgExpand for () { - fn expand( - _builder: &mut KernelBuilder, - _vectorization: Vectorization, - ) -> ::ExpandType { - } -} - -impl CubeType for () { - type ExpandType = (); -} - -impl Init for () { - fn init(self, _context: &mut CubeContext) -> Self { - self - } -} - -/// Defines the argument settings used to launch a kernel. -pub trait ArgSettings: Send + Sync { - /// Register the information to the [KernelLauncher]. - fn register(&self, launcher: &mut KernelLauncher); - /// Configure an input argument at the given position. - fn configure_input(&self, _position: usize, settings: KernelSettings) -> KernelSettings { - settings - } - /// Configure an output argument at the given position. - fn configure_output(&self, _position: usize, settings: KernelSettings) -> KernelSettings { - settings - } -} - -/// Reference to a JIT variable -#[derive(Clone, Debug)] -pub enum ExpandElement { - /// Variable kept in the variable pool. - Managed(Rc), - /// Variable not kept in the variable pool. - Plain(Variable), -} - -/// Expand type associated with a type. -#[derive(new)] -pub struct ExpandElementTyped { - pub(crate) expand: ExpandElement, - pub(crate) _type: PhantomData, -} - -impl Vectorized for ExpandElementTyped { - fn vectorization_factor(&self) -> UInt { - self.expand.vectorization_factor() - } - - fn vectorize(self, factor: UInt) -> Self { - Self { - expand: self.expand.vectorize(factor), - _type: PhantomData, - } - } -} - -impl Clone for ExpandElementTyped { - fn clone(&self) -> Self { - Self { - expand: self.expand.clone(), - _type: PhantomData, - } - } -} - -impl From for ExpandElementTyped { - fn from(expand: ExpandElement) -> Self { - Self { - expand, - _type: PhantomData, - } - } -} - -impl From> for ExpandElement { - fn from(value: ExpandElementTyped) -> Self { - value.expand - } -} - -impl ExpandElement { - pub fn can_mut(&self) -> bool { - match self { - ExpandElement::Managed(var) => { - if let Variable::Local { .. } = var.as_ref() { - Rc::strong_count(var) <= 2 - } else { - false - } - } - ExpandElement::Plain(_) => false, - } - } -} - -impl core::ops::Deref for ExpandElement { - type Target = Variable; - - fn deref(&self) -> &Self::Target { - match self { - ExpandElement::Managed(var) => var.as_ref(), - ExpandElement::Plain(var) => var, - } - } -} - -impl From for Variable { - fn from(value: ExpandElement) -> Self { - match value { - ExpandElement::Managed(var) => *var, - ExpandElement::Plain(var) => var, - } - } -} - -impl Init for ExpandElement { - fn init(self, context: &mut CubeContext) -> Self { - if self.can_mut() { - // Can reuse inplace :) - return self; - } - - let mut init = |elem: Self| init_expand(context, elem, Operator::Assign); - - match *self { - Variable::GlobalScalar { .. } => init(self), - Variable::LocalScalar { .. } => init(self), - Variable::ConstantScalar { .. } => init(self), - Variable::Local { .. } => init(self), - // Constant should be initialized since the new variable can be mutated afterward. - // And it is assumed those values are cloned. 
-            Variable::Rank
-            | Variable::UnitPos
-            | Variable::UnitPosX
-            | Variable::UnitPosY
-            | Variable::UnitPosZ
-            | Variable::CubePos
-            | Variable::CubePosX
-            | Variable::CubePosY
-            | Variable::CubePosZ
-            | Variable::CubeDim
-            | Variable::CubeDimX
-            | Variable::CubeDimY
-            | Variable::CubeDimZ
-            | Variable::CubeCount
-            | Variable::CubeCountX
-            | Variable::CubeCountY
-            | Variable::CubeCountZ
-            | Variable::SubcubeDim
-            | Variable::AbsolutePos
-            | Variable::AbsolutePosX
-            | Variable::AbsolutePosY
-            | Variable::AbsolutePosZ => init(self),
-            // Array types can't be copied, so we should simply return the same variable.
-            Variable::SharedMemory { .. }
-            | Variable::GlobalInputArray { .. }
-            | Variable::GlobalOutputArray { .. }
-            | Variable::LocalArray { .. }
-            | Variable::Slice { .. }
-            | Variable::Matrix { .. } => self,
-        }
-    }
-}
-
-macro_rules! impl_init_for {
-    ($($t:ty),*) => {
-        $(
-            impl Init for $t {
-                fn init(self, _context: &mut CubeContext) -> Self {
-                    panic!("Shouldn't be called, only for comptime.")
-                }
-            }
-        )*
-    };
-}
-
-// Add all types used within comptime
-impl_init_for!(u32, bool, UInt);
-
-impl<T: Init> Init for Option<T> {
-    fn init(self, context: &mut CubeContext) -> Self {
-        self.map(|o| Init::init(o, context))
-    }
-}
-
-impl<T: CubeType> CubeType for Vec<T> {
-    type ExpandType = Vec<T::ExpandType>;
-}
-
-impl<T: CubeType> CubeType for &mut Vec<T> {
-    type ExpandType = Vec<T::ExpandType>;
-}
-
-impl<T: Init> Init for Vec<T> {
-    fn init(self, context: &mut CubeContext) -> Self {
-        self.into_iter().map(|e| e.init(context)).collect()
-    }
-}
diff --git a/crates/burn-cube/src/frontend/element/bool.rs b/crates/burn-cube/src/frontend/element/bool.rs
deleted file mode 100644
index e5e2675b55..0000000000
--- a/crates/burn-cube/src/frontend/element/bool.rs
+++ /dev/null
@@ -1,28 +0,0 @@
-use crate::frontend::{CubePrimitive, CubeType, ExpandElement};
-use crate::ir::Elem;
-
-use super::Vectorized;
-
-// To be consistent with other primitive types.
-/// Boolean type.
-pub type Bool = bool; - -impl CubeType for bool { - type ExpandType = ExpandElement; -} - -impl CubePrimitive for bool { - fn as_elem() -> Elem { - Elem::Bool - } -} - -impl Vectorized for bool { - fn vectorization_factor(&self) -> crate::prelude::UInt { - todo!() - } - - fn vectorize(self, _factor: crate::prelude::UInt) -> Self { - todo!() - } -} diff --git a/crates/burn-cube/src/frontend/element/cast.rs b/crates/burn-cube/src/frontend/element/cast.rs deleted file mode 100644 index 4187b510a5..0000000000 --- a/crates/burn-cube/src/frontend/element/cast.rs +++ /dev/null @@ -1,31 +0,0 @@ -use crate::frontend::{assign, CubeContext, CubePrimitive, CubeType}; -use crate::ir::{Item, Variable}; -use crate::{frontend::ExpandElement, unexpanded}; - -/// Enable elegant casting from any to any CubeElem -pub trait Cast: CubePrimitive { - fn cast_from(value: From) -> Self; - - fn __expand_cast_from( - context: &mut CubeContext, - value: From, - ) -> ::ExpandType - where - From: Into, - { - let value: ExpandElement = value.into(); - let var: Variable = *value; - let new_var = context.create_local(Item::vectorized( - ::as_elem(), - var.item().vectorization, - )); - assign::expand(context, value, new_var.clone()); - new_var - } -} - -impl Cast for P { - fn cast_from(_value: From) -> Self { - unexpanded!() - } -} diff --git a/crates/burn-cube/src/frontend/element/cube_elem.rs b/crates/burn-cube/src/frontend/element/cube_elem.rs deleted file mode 100644 index e949171f52..0000000000 --- a/crates/burn-cube/src/frontend/element/cube_elem.rs +++ /dev/null @@ -1,49 +0,0 @@ -use crate::frontend::UInt; -use crate::frontend::{CubeType, ExpandElement}; -use crate::ir::{Elem, Variable}; - -use super::Vectorized; - -/// Form of CubeType that encapsulates all primitive types: -/// Numeric, UInt, Bool -pub trait CubePrimitive: - CubeType - + Vectorized - + core::cmp::Eq - + core::cmp::PartialEq - + Send - + Sync - + 'static - + Clone - + Copy -{ - /// Return the element type to use on GPU - fn as_elem() -> Elem; -} - -macro_rules! impl_into_expand_element { - ($type:ty) => { - impl From<$type> for ExpandElement { - fn from(value: $type) -> Self { - ExpandElement::Plain(Variable::from(value)) - } - } - }; -} - -impl_into_expand_element!(u32); -impl_into_expand_element!(usize); -impl_into_expand_element!(bool); -impl_into_expand_element!(f32); -impl_into_expand_element!(i32); -impl_into_expand_element!(i64); - -/// Useful for Comptime -impl From for ExpandElement { - fn from(value: UInt) -> Self { - ExpandElement::Plain(crate::ir::Variable::ConstantScalar { - value: value.val as f64, - elem: UInt::as_elem(), - }) - } -} diff --git a/crates/burn-cube/src/frontend/element/float.rs b/crates/burn-cube/src/frontend/element/float.rs deleted file mode 100644 index c8dec62abe..0000000000 --- a/crates/burn-cube/src/frontend/element/float.rs +++ /dev/null @@ -1,233 +0,0 @@ -use half::{bf16, f16}; - -use crate::frontend::{Ceil, Cos, Erf, Exp, Floor, Log, Log1p, Powf, Recip, Sin, Sqrt, Tanh}; -use crate::frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement, Numeric}; -use crate::ir::{Elem, FloatKind, Item, Variable, Vectorization}; - -use crate::compute::{KernelBuilder, KernelLauncher}; -use crate::prelude::index_assign; -use crate::{unexpanded, Runtime}; - -use super::{LaunchArgExpand, ScalarArgSettings, UInt, Vectorized}; - -/// Floating point numbers. 
Used as input in float kernels -pub trait Float: - Numeric - + Exp - + Log - + Log1p - + Cos - + Sin - + Tanh - + Powf - + Sqrt - + Floor - + Ceil - + Erf - + Recip - + core::ops::Index - + core::ops::IndexMut -{ - fn new(val: f32) -> Self; - fn vectorized(val: f32, vectorization: UInt) -> Self; - fn vectorized_empty(vectorization: UInt) -> Self; - fn __expand_new(context: &mut CubeContext, val: f32) -> ::ExpandType; - fn __expand_vectorized( - context: &mut CubeContext, - val: f32, - vectorization: UInt, - ) -> ::ExpandType; - fn __expand_vectorized_empty( - context: &mut CubeContext, - vectorization: UInt, - ) -> ::ExpandType; -} - -macro_rules! impl_float { - ($type:ident, $primitive:ty) => { - #[derive(Clone, Copy)] - pub struct $type { - pub val: f32, - pub vectorization: u8, - } - - impl CubeType for $type { - type ExpandType = ExpandElement; - } - - impl CubePrimitive for $type { - /// Return the element type to use on GPU - fn as_elem() -> Elem { - Elem::Float(FloatKind::$type) - } - } - - impl Numeric for $type { - type Primitive = $primitive; - } - - impl Float for $type { - fn new(val: f32) -> Self { - Self { - val, - vectorization: 1, - } - } - - fn vectorized(val: f32, vectorization: UInt) -> Self { - if vectorization.val == 1 { - Self::new(val) - } else { - Self { - val, - vectorization: vectorization.val as u8, - } - } - } - - fn vectorized_empty(vectorization: UInt) -> Self { - Self::vectorized(0., vectorization) - } - - fn __expand_new( - _context: &mut CubeContext, - val: f32, - ) -> ::ExpandType { - let new_var = Variable::ConstantScalar { - value: val as f64, - elem: Self::as_elem(), - }; - ExpandElement::Plain(new_var) - } - - fn __expand_vectorized( - context: &mut CubeContext, - val: f32, - vectorization: UInt, - ) -> ::ExpandType { - if vectorization.val == 1 { - Self::__expand_new(context, val) - } else { - let mut new_var = context - .create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8)); - for (i, element) in vec![val; vectorization.val as usize].iter().enumerate() { - new_var = index_assign::expand(context, new_var, i, *element); - } - - new_var - } - } - - fn __expand_vectorized_empty( - context: &mut CubeContext, - vectorization: UInt, - ) -> ::ExpandType { - if vectorization.val == 1 { - Self::__expand_new(context, 0.) 
- } else { - context.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8)) - } - } - } - - impl core::ops::Index for $type { - type Output = Self; - - fn index(&self, _index: UInt) -> &Self::Output { - unexpanded!() - } - } - - impl core::ops::IndexMut for $type { - fn index_mut(&mut self, _index: UInt) -> &mut Self::Output { - unexpanded!() - } - } - - impl LaunchArgExpand for $type { - fn expand(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement { - assert_eq!(vectorization, 1, "Attempted to vectorize a scalar"); - builder.scalar($type::as_elem()) - } - } - - impl Vectorized for $type { - fn vectorization_factor(&self) -> UInt { - UInt { - val: self.vectorization as u32, - vectorization: 1, - } - } - - fn vectorize(mut self, factor: UInt) -> Self { - self.vectorization = factor.vectorization; - self - } - } - }; -} - -impl_float!(F16, f16); -impl_float!(BF16, bf16); -impl_float!(F32, f32); -impl_float!(F64, f64); - -impl From for F32 { - fn from(value: f32) -> Self { - Self { - val: value, - vectorization: 1, - } - } -} - -impl From for BF16 { - fn from(value: f32) -> Self { - Self { - val: value, - vectorization: 1, - } - } -} - -impl From for F16 { - fn from(value: f32) -> Self { - Self { - val: value, - vectorization: 1, - } - } -} - -impl From for F64 { - fn from(value: f32) -> Self { - Self { - val: value, - vectorization: 1, - } - } -} - -impl ScalarArgSettings for f16 { - fn register(&self, settings: &mut KernelLauncher) { - settings.register_f16(*self); - } -} - -impl ScalarArgSettings for bf16 { - fn register(&self, settings: &mut KernelLauncher) { - settings.register_bf16(*self); - } -} - -impl ScalarArgSettings for f32 { - fn register(&self, settings: &mut KernelLauncher) { - settings.register_f32(*self); - } -} - -impl ScalarArgSettings for f64 { - fn register(&self, settings: &mut KernelLauncher) { - settings.register_f64(*self); - } -} diff --git a/crates/burn-cube/src/frontend/element/int.rs b/crates/burn-cube/src/frontend/element/int.rs deleted file mode 100644 index d5a92f73c7..0000000000 --- a/crates/burn-cube/src/frontend/element/int.rs +++ /dev/null @@ -1,146 +0,0 @@ -use crate::compute::{KernelBuilder, KernelLauncher}; -use crate::frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement, Numeric}; -use crate::ir::{Elem, IntKind, Item, Variable, Vectorization}; -use crate::prelude::index_assign; -use crate::Runtime; - -use super::{LaunchArgExpand, ScalarArgSettings, UInt, Vectorized}; - -/// Signed integer. Used as input in int kernels -pub trait Int: Numeric + std::ops::Rem { - fn new(val: i64) -> Self; - fn vectorized(val: i64, vectorization: UInt) -> Self; - fn __expand_new(context: &mut CubeContext, val: i64) -> ::ExpandType; - fn __expand_vectorized( - context: &mut CubeContext, - val: i64, - vectorization: UInt, - ) -> ::ExpandType; -} - -macro_rules! 
impl_int { - ($type:ident, $primitive:ty) => { - #[derive(Clone, Copy)] - pub struct $type { - pub val: $primitive, - pub vectorization: u8, - } - - impl CubeType for $type { - type ExpandType = ExpandElement; - } - - impl CubePrimitive for $type { - fn as_elem() -> Elem { - Elem::Int(IntKind::$type) - } - } - - impl Numeric for $type { - type Primitive = $primitive; - } - - impl Int for $type { - fn new(val: i64) -> Self { - Self { - val: val as $primitive, - vectorization: 1, - } - } - - fn vectorized(val: i64, vectorization: UInt) -> Self { - if vectorization.val == 1 { - Self::new(val) - } else { - Self { - val: val as $primitive, - vectorization: vectorization.val as u8, - } - } - } - - fn __expand_new( - _context: &mut CubeContext, - val: i64, - ) -> ::ExpandType { - let new_var = Variable::ConstantScalar { - value: val as f64, - elem: Self::as_elem(), - }; - ExpandElement::Plain(new_var) - } - - fn __expand_vectorized( - context: &mut CubeContext, - val: i64, - vectorization: UInt, - ) -> ::ExpandType { - if vectorization.val == 1 { - Self::__expand_new(context, val) - } else { - let mut new_var = context - .create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8)); - for (i, element) in vec![val; vectorization.val as usize].iter().enumerate() { - new_var = index_assign::expand(context, new_var, i, *element); - } - - new_var - } - } - } - - impl LaunchArgExpand for $type { - fn expand(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement { - assert_eq!(vectorization, 1, "Attempted to vectorize a scalar"); - builder.scalar($type::as_elem()) - } - } - - impl Vectorized for $type { - fn vectorization_factor(&self) -> UInt { - UInt { - val: self.vectorization as u32, - vectorization: 1, - } - } - - fn vectorize(mut self, factor: UInt) -> Self { - self.vectorization = factor.vectorization; - self - } - } - }; -} - -impl_int!(I32, i32); -impl_int!(I64, i64); - -impl From for I64 { - fn from(value: i64) -> Self { - Self { - val: value, - vectorization: 1, - } - } -} - -impl From for I32 { - fn from(value: i32) -> Self { - Self { - val: value, - vectorization: 1, - } - } -} - -impl ScalarArgSettings for i32 { - fn register(&self, settings: &mut KernelLauncher) { - settings.register_i32(*self); - } -} - -impl ScalarArgSettings for i64 { - fn register(&self, settings: &mut KernelLauncher) { - settings.register_i64(*self); - } -} diff --git a/crates/burn-cube/src/frontend/element/mod.rs b/crates/burn-cube/src/frontend/element/mod.rs deleted file mode 100644 index d67fb86474..0000000000 --- a/crates/burn-cube/src/frontend/element/mod.rs +++ /dev/null @@ -1,27 +0,0 @@ -mod array; -mod base; -mod bool; -mod cast; -mod cube_elem; -mod float; -mod int; -mod numeric; -mod shared_memory; -mod slice; -mod tensor; -mod uint; -mod vectorized; - -pub use array::*; -pub use base::*; -pub use bool::*; -pub use cast::*; -pub use cube_elem::*; -pub use float::*; -pub use int::*; -pub use numeric::*; -pub use shared_memory::*; -pub use slice::*; -pub use tensor::*; -pub use uint::*; -pub use vectorized::*; diff --git a/crates/burn-cube/src/frontend/element/numeric.rs b/crates/burn-cube/src/frontend/element/numeric.rs deleted file mode 100644 index 3c92a2eeb5..0000000000 --- a/crates/burn-cube/src/frontend/element/numeric.rs +++ /dev/null @@ -1,93 +0,0 @@ -use crate::compute::KernelLauncher; -use crate::frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement}; -use crate::ir::{Item, Variable}; -use crate::prelude::Clamp; -use crate::Runtime; -use crate::{ - 
frontend::{index_assign, Abs, Max, Min, Remainder}, - unexpanded, -}; - -use super::{ArgSettings, LaunchArg, LaunchArgExpand}; - -/// Type that encompasses both (unsigned or signed) integers and floats -/// Used in kernels that should work for both. -pub trait Numeric: - Copy - + CubePrimitive - + LaunchArgExpand - + std::ops::Add - + std::ops::AddAssign - + std::ops::SubAssign - + std::ops::MulAssign - + std::ops::DivAssign - + std::ops::Sub - + std::ops::Mul - + std::ops::Div - + std::cmp::PartialOrd - + Abs - + Max - + Min - + Clamp - + Remainder -{ - /// Create a new constant numeric. - /// - /// Note: since this must work for both integer and float - /// only the less expressive of both can be created (int) - /// If a number with decimals is needed, use Float::new. - /// - /// This method panics when unexpanded. For creating an element - /// with a val, use the new method of the sub type. - fn from_int(_val: i64) -> Self { - unexpanded!() - } - - type Primitive: ScalarArgSettings; - - fn from_vec(_vec: [i64; D]) -> Self { - unexpanded!() - } - - fn __expand_from_int(_context: &mut CubeContext, val: i64) -> ::ExpandType { - let new_var = Variable::ConstantScalar { - value: val as f64, - elem: Self::as_elem(), - }; - ExpandElement::Plain(new_var) - } - - fn __expand_from_vec( - context: &mut CubeContext, - vec: [i64; D], - ) -> ::ExpandType { - let mut new_var = context.create_local(Item::vectorized(Self::as_elem(), vec.len() as u8)); - for (i, element) in vec.iter().enumerate() { - new_var = index_assign::expand(context, new_var, i, *element); - } - - new_var - } -} - -/// Similar to [ArgSettings], however only for scalar types that don't depend on the [Runtime] -/// trait. -pub trait ScalarArgSettings: Send + Sync { - /// Register the information to the [KernelLauncher]. - fn register(&self, launcher: &mut KernelLauncher); -} - -#[derive(new)] -pub struct ScalarArg { - elem: T::Primitive, -} - -impl ArgSettings for ScalarArg { - fn register(&self, launcher: &mut crate::compute::KernelLauncher) { - self.elem.register(launcher); - } -} - -impl LaunchArg for T { - type RuntimeArg<'a, R: Runtime> = ScalarArg; -} diff --git a/crates/burn-cube/src/frontend/element/shared_memory.rs b/crates/burn-cube/src/frontend/element/shared_memory.rs deleted file mode 100644 index 3ad49c3300..0000000000 --- a/crates/burn-cube/src/frontend/element/shared_memory.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::marker::PhantomData; - -use crate::{ - frontend::{indexation::Index, CubeContext, CubePrimitive, CubeType}, - ir::Item, -}; - -use super::{ExpandElementTyped, Init, UInt}; - -#[derive(Clone, Copy)] -pub struct SharedMemory { - _val: PhantomData, -} - -impl Init for ExpandElementTyped> { - fn init(self, _context: &mut CubeContext) -> Self { - self - } -} - -impl CubeType for SharedMemory { - type ExpandType = ExpandElementTyped>; -} - -impl SharedMemory { - pub fn new(_size: S) -> Self { - SharedMemory { _val: PhantomData } - } - - pub fn vectorized(_size: S, _vectorization_factor: UInt) -> Self { - SharedMemory { _val: PhantomData } - } - - pub fn __expand_vectorized( - context: &mut CubeContext, - size: S, - vectorization_factor: UInt, - ) -> ::ExpandType { - let size = size.value(); - let size = match size { - crate::ir::Variable::ConstantScalar { value, .. 
-            } => value as u32,
-            _ => panic!("Shared memory need constant initialization value"),
-        };
-        let var = context.create_shared(
-            Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
-            size,
-        );
-        ExpandElementTyped::new(var)
-    }
-
-    pub fn __expand_new<S: Index>(
-        context: &mut CubeContext,
-        size: S,
-    ) -> <Self as CubeType>::ExpandType {
-        let size = size.value();
-        let size = match size {
-            crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
-            _ => panic!("Shared memory need constant initialization value"),
-        };
-        let var = context.create_shared(Item::new(T::as_elem()), size);
-        ExpandElementTyped::new(var)
-    }
-}
diff --git a/crates/burn-cube/src/frontend/element/slice.rs b/crates/burn-cube/src/frontend/element/slice.rs
deleted file mode 100644
index fea7e801d9..0000000000
--- a/crates/burn-cube/src/frontend/element/slice.rs
+++ /dev/null
@@ -1,285 +0,0 @@
-use std::marker::PhantomData;
-
-use super::{
-    Array, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, SharedMemory, Tensor,
-    UInt,
-};
-use crate::{
-    frontend::indexation::Index,
-    ir::{self, Operator},
-    prelude::CubeContext,
-    unexpanded,
-};
-
-/// A read-only contiguous list of elements
-pub struct Slice<'a, E> {
-    _e: PhantomData<E>,
-    _l: &'a (),
-}
-
-/// A read-write contiguous list of elements.
-pub struct SliceMut<'a, E> {
-    _e: PhantomData<E>,
-    _l: &'a mut (),
-}
-
-impl<'a, E> Slice<'a, E> {
-    /// Get the length of the slice.
-    pub fn len(&self) -> UInt {
-        unexpanded!()
-    }
-}
-
-impl<'a, E> SliceMut<'a, E> {
-    /// Get the length of the slice.
-    pub fn len(&self) -> UInt {
-        unexpanded!()
-    }
-}
-
-impl<'a, E: CubeType> CubeType for Slice<'a, E> {
-    type ExpandType = ExpandElementTyped<Slice<'static, E>>;
-}
-
-impl<'a, C: CubeType> Init for ExpandElementTyped<Slice<'a, C>> {
-    fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
-        // The type can't be deeply cloned/copied.
-        self
-    }
-}
-
-impl<'a, E: CubeType> CubeType for SliceMut<'a, E> {
-    type ExpandType = ExpandElementTyped<SliceMut<'static, E>>;
-}
-
-impl<'a, C: CubeType> Init for ExpandElementTyped<SliceMut<'a, C>> {
-    fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
-        // The type can't be deeply cloned/copied.
-        self
-    }
-}
-
-pub trait SliceOperator<E: CubeType>: CubeType {
-    type Expand: SliceOperatorExpand<E>;
-
-    /// Return a read-only view of all elements comprised between the start and end indices.
-    #[allow(unused_variables)]
-    fn slice<Start: Index, End: Index>(&self, start: Start, end: End) -> &'_ Slice<'_, E> {
-        unexpanded!()
-    }
-    /// Expand function of [SliceOperator::slice].
-    fn slice_expand<Start: Index, End: Index>(
-        context: &mut CubeContext,
-        expand: Self::Expand,
-        start: Start,
-        end: End,
-    ) -> ExpandElementTyped<Slice<'static, E>> {
-        expand.slice_expand(context, start, end)
-    }
-
-    /// Return a read-write view of all elements comprised between the start and end indices.
-    #[allow(unused_variables)]
-    fn slice_mut<Start: Index, End: Index>(
-        &mut self,
-        start: Start,
-        end: End,
-    ) -> &'_ mut SliceMut<'_, E> {
-        unexpanded!()
-    }
-
-    /// Expand function of [SliceOperator::slice_mut].
-    fn slice_mut_expand<Start: Index, End: Index>(
-        context: &mut CubeContext,
-        expand: Self::Expand,
-        start: Start,
-        end: End,
-    ) -> ExpandElementTyped<SliceMut<'static, E>> {
-        expand.slice_mut_expand(context, start, end)
-    }
-
-    /// Return a read-write view of all elements comprised between the start and end indices.
-    ///
-    /// # Warning
-    ///
-    /// Ignores the multiple borrow rule.
-    #[allow(unused_variables)]
-    fn slice_mut_unsafe<Start: Index, End: Index>(
-        &self,
-        start: Start,
-        end: End,
-    ) -> SliceMut<'static, E> {
-        unexpanded!()
-    }
-
-    /// Expand function of [SliceOperator::slice_mut_unsafe].
-    fn slice_mut_unsafe_expand<Start: Index, End: Index>(
-        context: &mut CubeContext,
-        expand: Self::Expand,
-        start: Start,
-        end: End,
-    ) -> ExpandElementTyped<SliceMut<'static, E>> {
-        expand.slice_mut_unsafe_expand(context, start, end)
-    }
-
-    /// Reinterpret the current type as a read-only slice.
-    #[allow(unused_variables)]
-    fn as_slice(&self) -> &'_ Slice<'_, E> {
-        unexpanded!()
-    }
-
-    /// Expand function of [SliceOperator::as_slice].
-    fn as_slice_expand(
-        context: &mut CubeContext,
-        expand: Self::Expand,
-    ) -> ExpandElementTyped<Slice<'static, E>> {
-        expand.as_slice_expand(context)
-    }
-
-    /// Reinterpret the current type as a read-write slice.
-    #[allow(unused_variables)]
-    fn as_slice_mut(&mut self) -> &'_ mut SliceMut<'_, E> {
-        unexpanded!()
-    }
-
-    /// Expand function of [SliceOperator::as_slice_mut].
-    fn as_slice_mut_expand(
-        context: &mut CubeContext,
-        expand: Self::Expand,
-    ) -> ExpandElementTyped<SliceMut<'static, E>> {
-        expand.as_slice_mut_expand(context)
-    }
-
-    /// Reinterpret the current type as a read-write slice.
-    ///
-    /// # Warning
-    ///
-    /// Ignores the multiple borrow rule.
-    #[allow(unused_variables)]
-    fn as_slice_mut_unsafe(&self) -> SliceMut<'static, E> {
-        unexpanded!()
-    }
-
-    /// Expand function of [SliceOperator::as_slice_mut_unsafe].
-    fn as_slice_mut_unsafe_expand(
-        context: &mut CubeContext,
-        expand: Self::Expand,
-    ) -> ExpandElementTyped<SliceMut<'static, E>> {
-        expand.as_slice_mut_unsafe_expand(context)
-    }
-}
-
-pub trait SliceOperatorExpand<E: CubeType>: Into<ExpandElement> + Clone {
-    fn slice_base<Start: Index, End: Index>(
-        &self,
-        context: &mut CubeContext,
-        start: Start,
-        end: End,
-    ) -> ExpandElement;
-
-    fn slice_expand<Start: Index, End: Index>(
-        &self,
-        context: &mut CubeContext,
-        start: Start,
-        end: End,
-    ) -> ExpandElementTyped<Slice<'static, E>> {
-        ExpandElementTyped::new(self.slice_base(context, start, end))
-    }
-
-    fn slice_mut_expand<Start: Index, End: Index>(
-        &self,
-        context: &mut CubeContext,
-        start: Start,
-        end: End,
-    ) -> ExpandElementTyped<SliceMut<'static, E>> {
-        ExpandElementTyped::new(self.slice_base(context, start, end))
-    }
-
-    fn slice_mut_unsafe_expand<Start: Index, End: Index>(
-        &self,
-        context: &mut CubeContext,
-        start: Start,
-        end: End,
-    ) -> ExpandElementTyped<SliceMut<'static, E>> {
-        ExpandElementTyped::new(self.slice_base(context, start, end))
-    }
-
-    fn as_slice_expand(&self, _context: &mut CubeContext) -> ExpandElementTyped<Slice<'static, E>> {
-        let expand = self.clone().into();
-        ExpandElementTyped::new(expand)
-    }
-
-    fn as_slice_mut_unsafe_expand(
-        &self,
-        context: &mut CubeContext,
-    ) -> ExpandElementTyped<SliceMut<'static, E>> {
-        self.as_slice_mut_expand(context)
-    }
-
-    fn as_slice_mut_expand(
-        &self,
-        _context: &mut CubeContext,
-    ) -> ExpandElementTyped<SliceMut<'static, E>> {
-        let expand = self.clone().into();
-        ExpandElementTyped::new(expand)
-    }
-}
-
-macro_rules!
slice_op { - ($type:ident) => { - impl SliceOperator for $type { - type Expand = ExpandElementTyped<$type>; - } - - impl SliceOperatorExpand for ExpandElementTyped<$type> { - fn slice_base( - &self, - context: &mut CubeContext, - start: Start, - end: End, - ) -> ExpandElement { - slice_expand(context, self.clone(), start, end) - } - } - }; - (slice $type:ident) => { - impl<'a, E: CubePrimitive> SliceOperator for $type<'a, E> { - type Expand = ExpandElementTyped<$type<'static, E>>; - } - - impl<'a, E: CubePrimitive> SliceOperatorExpand for ExpandElementTyped<$type<'a, E>> { - fn slice_base( - &self, - context: &mut CubeContext, - start: Start, - end: End, - ) -> ExpandElement { - slice_expand(context, self.clone(), start, end) - } - } - }; -} - -slice_op!(Array); -slice_op!(Tensor); -slice_op!(SharedMemory); -slice_op!(slice Slice); -slice_op!(slice SliceMut); - -pub fn slice_expand, S1: Index, S2: Index>( - context: &mut CubeContext, - input: I, - start: S1, - end: S2, // Todo use it to get the length. -) -> ExpandElement { - let input = input.into(); - let out = context.create_slice(input.item()); - - context.register(Operator::Slice(ir::SliceOperator { - input: *input, - start: start.value(), - end: end.value(), - out: *out, - })); - - out -} diff --git a/crates/burn-cube/src/frontend/element/tensor.rs b/crates/burn-cube/src/frontend/element/tensor.rs deleted file mode 100644 index 7a61422c43..0000000000 --- a/crates/burn-cube/src/frontend/element/tensor.rs +++ /dev/null @@ -1,216 +0,0 @@ -use super::{ExpandElementTyped, Init, LaunchArgExpand}; -use crate::{ - frontend::{ - indexation::Index, ArgSettings, CubeContext, CubePrimitive, CubeType, ExpandElement, UInt, - }, - ir::{Elem, Item, Metadata, Variable, Vectorization}, - prelude::{KernelBuilder, KernelLauncher}, - unexpanded, KernelSettings, LaunchArg, Runtime, -}; -use std::marker::PhantomData; - -/// The tensor type is similar to the [array type](crate::prelude::Array), however it comes with more -/// metadata such as [stride](Tensor::stride) and [shape](Tensor::shape). -#[derive(new)] -pub struct Tensor { - _val: PhantomData, -} - -impl CubeType for Tensor { - type ExpandType = ExpandElementTyped>; -} - -impl Init for ExpandElementTyped> { - fn init(self, _context: &mut crate::prelude::CubeContext) -> Self { - // The type can't be deeply cloned/copied. - self - } -} - -impl LaunchArgExpand for Tensor { - fn expand( - builder: &mut KernelBuilder, - vectorization: Vectorization, - ) -> ExpandElementTyped> { - builder - .input_array(Item::vectorized(C::as_elem(), vectorization)) - .into() - } - fn expand_output( - builder: &mut KernelBuilder, - vectorization: Vectorization, - ) -> ExpandElementTyped> { - builder - .output_array(Item::vectorized(C::as_elem(), vectorization)) - .into() - } -} - -impl LaunchArg for Tensor { - type RuntimeArg<'a, R: Runtime> = TensorArg<'a, R>; -} - -/// Tensor representation with a reference to the [server handle](burn_compute::server::Handle), -/// the strides and the shape. -#[derive(new)] -pub struct TensorHandle<'a, R: Runtime> { - pub handle: &'a burn_compute::server::Handle, - pub strides: &'a [usize], - pub shape: &'a [usize], -} - -/// Argument to be used for [tensors](Tensor) passed as arguments to kernels. -pub enum TensorArg<'a, R: Runtime> { - /// The tensor is passed with a tensor handle. - Handle { - /// The tensor handle. - handle: TensorHandle<'a, R>, - /// The vectorization factor. - vectorization_factor: u8, - }, - /// The tensor is aliasing another input tensor. 
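/// (Editor's sketch, not part of the original diff; the launch function name and
/// handles are hypothetical.) Aliasing is how an output is written in place over
/// an input buffer. Assuming a generated launch function with one input tensor
/// and one output:
///
/// ```rust, ignore
/// my_kernel_launch(
///     client.clone(),
///     cube_count,
///     cube_dim,
///     TensorArg::new(&handle, &strides, &shape), // input 0
///     TensorArg::alias(0),                       // output reuses input 0's buffer
/// );
/// ```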
- Alias { - /// The position of the input tensor. - input_pos: usize, - }, -} - -impl<'a, R: Runtime> TensorArg<'a, R> { - /// Create a new tensor argument. - /// - /// Equivalent to using the [vectorized constructor](Self::vectorized) with a vectorization - /// factor of 1. - pub fn new( - handle: &'a burn_compute::server::Handle, - strides: &'a [usize], - shape: &'a [usize], - ) -> Self { - Self::Handle { - handle: TensorHandle::new(handle, strides, shape), - vectorization_factor: 1, - } - } - /// Create a new tensor argument specified with its vectorization factor. - pub fn vectorized( - factor: u8, - handle: &'a burn_compute::server::Handle, - strides: &'a [usize], - shape: &'a [usize], - ) -> Self { - Self::Handle { - handle: TensorHandle::new(handle, strides, shape), - vectorization_factor: factor, - } - } - pub fn alias(position: usize) -> Self { - Self::Alias { - input_pos: position, - } - } -} - -impl<'a, R: Runtime> ArgSettings for TensorArg<'a, R> { - fn register(&self, launcher: &mut KernelLauncher) { - if let TensorArg::Handle { - handle, - vectorization_factor: _, - } = self - { - launcher.register_tensor(handle) - } - } - - fn configure_input(&self, position: usize, settings: KernelSettings) -> KernelSettings { - match self { - TensorArg::Handle { - handle: _, - vectorization_factor, - } => settings.vectorize_input(position, *vectorization_factor), - TensorArg::Alias { input_pos: _ } => { - panic!("Not yet supported, only output can be aliased for now."); - } - } - } - - fn configure_output(&self, position: usize, mut settings: KernelSettings) -> KernelSettings { - match self { - TensorArg::Handle { - handle: _, - vectorization_factor, - } => settings.vectorize_output(position, *vectorization_factor), - TensorArg::Alias { input_pos } => { - settings.mappings.push(crate::InplaceMapping { - pos_input: *input_pos, - pos_output: position, - }); - settings - } - } - } -} - -impl Tensor { - /// Obtain the stride of input at dimension dim - pub fn stride(&self, _dim: C) -> UInt { - unexpanded!() - } - - /// Obtain the shape of input at dimension dim - pub fn shape(&self, _dim: C) -> UInt { - unexpanded!() - } - - /// The length of the buffer representing the tensor. - /// - /// # Warning - /// - /// The length will be affected by the vectorization factor. To obtain the number of elements, - /// you should multiply the length by the vectorization factor. - pub fn len(&self) -> UInt { - unexpanded!() - } - - /// Returns the rank of the tensor. - pub fn rank(&self) -> UInt { - unexpanded!() - } -} - -impl ExpandElementTyped { - // Expanded version of stride - pub fn stride_expand(self, context: &mut CubeContext, dim: C) -> ExpandElement { - let out = context.create_local(Item::new(Elem::UInt)); - context.register(Metadata::Stride { - dim: dim.value(), - var: self.expand.into(), - out: out.clone().into(), - }); - out - } - - // Expanded version of shape - pub fn shape_expand(self, context: &mut CubeContext, dim: C) -> ExpandElement { - let out = context.create_local(Item::new(Elem::UInt)); - context.register(Metadata::Shape { - dim: dim.value(), - var: self.expand.into(), - out: out.clone().into(), - }); - out - } - - // Expanded version of len - pub fn len_expand(self, context: &mut CubeContext) -> ExpandElement { - let out = context.create_local(Item::new(Elem::UInt)); - context.register(Metadata::Length { - var: self.expand.into(), - out: out.clone().into(), - }); - out - } - - // Expanded version of rank. 
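// (Editor's sketch, not part of the original diff.) Inside a #[cube] kernel,
// the metadata buffer built by the launcher is what backs
// Tensor::stride/shape/len/rank. For example, a rank-2 row-major offset with
// vectorization factor 1 could be computed as:
//
//     let offset = row * input.stride(0) + col * input.stride(1);
//     let in_bounds = offset < input.len();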
- pub fn rank_expand(self, _context: &mut CubeContext) -> ExpandElement { - ExpandElement::Plain(Variable::Rank) - } -} diff --git a/crates/burn-cube/src/frontend/element/uint.rs b/crates/burn-cube/src/frontend/element/uint.rs deleted file mode 100644 index 4df2516b69..0000000000 --- a/crates/burn-cube/src/frontend/element/uint.rs +++ /dev/null @@ -1,125 +0,0 @@ -use crate::frontend::{CubeContext, CubePrimitive, CubeType, ExpandElement, Numeric}; -use crate::ir::{Elem, Item, Variable, Vectorization}; -use crate::prelude::{index_assign, KernelBuilder, KernelLauncher}; -use crate::{frontend::Comptime, Runtime}; - -use super::{LaunchArgExpand, ScalarArgSettings, Vectorized}; - -#[derive(Clone, Copy, Debug)] -/// An unsigned int. -/// Preferred for indexing operations -pub struct UInt { - pub val: u32, - pub vectorization: u8, -} - -impl CubeType for UInt { - type ExpandType = ExpandElement; -} - -impl CubePrimitive for UInt { - fn as_elem() -> Elem { - Elem::UInt - } -} - -impl LaunchArgExpand for UInt { - fn expand(builder: &mut KernelBuilder, vectorization: Vectorization) -> ExpandElement { - assert_eq!(vectorization, 1, "Attempted to vectorize a scalar"); - builder.scalar(UInt::as_elem()) - } -} - -impl ScalarArgSettings for u32 { - fn register(&self, settings: &mut KernelLauncher) { - settings.register_u32(*self); - } -} - -impl Numeric for UInt { - type Primitive = u32; -} - -impl UInt { - pub const fn new(val: u32) -> Self { - Self { - val, - vectorization: 1, - } - } - - pub fn __expand_new(_context: &mut CubeContext, val: u32) -> ::ExpandType { - let new_var = Variable::ConstantScalar { - value: val as f64, - elem: Self::as_elem(), - }; - ExpandElement::Plain(new_var) - } - - pub fn vectorized(val: u32, vectorization: UInt) -> Self { - if vectorization.val == 1 { - Self::new(val) - } else { - Self { - val, - vectorization: vectorization.val as u8, - } - } - } - - pub fn __expand_vectorized( - context: &mut CubeContext, - val: u32, - vectorization: UInt, - ) -> ::ExpandType { - if vectorization.val == 1 { - Self::__expand_new(context, val) - } else { - let mut new_var = - context.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8)); - for (i, element) in vec![val; vectorization.val as usize].iter().enumerate() { - new_var = index_assign::expand(context, new_var, i, *element); - } - - new_var - } - } -} - -impl From for UInt { - fn from(value: u32) -> Self { - UInt::new(value) - } -} - -impl From> for UInt { - fn from(value: Comptime) -> Self { - UInt::new(value.inner) - } -} - -impl From for UInt { - fn from(value: usize) -> Self { - UInt::new(value as u32) - } -} - -impl From for UInt { - fn from(value: i32) -> Self { - UInt::new(value as u32) - } -} - -impl Vectorized for UInt { - fn vectorization_factor(&self) -> UInt { - UInt { - val: self.vectorization as u32, - vectorization: 1, - } - } - - fn vectorize(mut self, factor: UInt) -> Self { - self.vectorization = factor.vectorization; - self - } -} diff --git a/crates/burn-cube/src/frontend/element/vectorized.rs b/crates/burn-cube/src/frontend/element/vectorized.rs deleted file mode 100644 index e9497acf80..0000000000 --- a/crates/burn-cube/src/frontend/element/vectorized.rs +++ /dev/null @@ -1,68 +0,0 @@ -use crate::unexpanded; - -use super::{CubeType, ExpandElement, Tensor, UInt}; - -pub trait Vectorized { - fn vectorization_factor(&self) -> UInt; - fn vectorize(self, factor: UInt) -> Self; -} - -impl Vectorized for Tensor { - fn vectorization_factor(&self) -> UInt { - unexpanded!() - } - - fn vectorize(self, 
_factor: UInt) -> Self {
-        unexpanded!()
-    }
-}
-
-impl<T: CubeType> Vectorized for &Tensor<T> {
-    fn vectorization_factor(&self) -> UInt {
-        unexpanded!()
-    }
-
-    fn vectorize(self, _factor: UInt) -> Self {
-        unexpanded!()
-    }
-}
-
-impl<T: CubeType> Vectorized for &mut Tensor<T> {
-    fn vectorization_factor(&self) -> UInt {
-        unexpanded!()
-    }
-
-    fn vectorize(self, _factor: UInt) -> Self {
-        unexpanded!()
-    }
-}
-
-impl Vectorized for ExpandElement {
-    fn vectorization_factor(&self) -> UInt {
-        let var = match self {
-            ExpandElement::Managed(var) => var,
-            ExpandElement::Plain(var) => var,
-        };
-
-        UInt::new(var.item().vectorization as u32)
-    }
-
-    fn vectorize(self, _factor: UInt) -> Self {
-        todo!()
-    }
-}
-
-impl Vectorized for &ExpandElement {
-    fn vectorization_factor(&self) -> UInt {
-        let var = match self {
-            ExpandElement::Managed(var) => var,
-            ExpandElement::Plain(var) => var,
-        };
-
-        UInt::new(var.item().vectorization as u32)
-    }
-
-    fn vectorize(self, _factor: UInt) -> Self {
-        todo!()
-    }
-}
diff --git a/crates/burn-cube/src/frontend/indexation.rs b/crates/burn-cube/src/frontend/indexation.rs
deleted file mode 100644
index 9960d1f4b5..0000000000
--- a/crates/burn-cube/src/frontend/indexation.rs
+++ /dev/null
@@ -1,57 +0,0 @@
-use super::{Comptime, ExpandElement, UInt};
-use crate::ir::{Elem, Variable};
-
-pub trait Index {
-    fn value(self) -> Variable;
-}
-
-impl Index for Comptime<u32> {
-    fn value(self) -> Variable {
-        Variable::ConstantScalar {
-            value: self.inner as f64,
-            elem: Elem::UInt,
-        }
-    }
-}
-
-impl Index for Comptime<i32> {
-    fn value(self) -> Variable {
-        Variable::ConstantScalar {
-            value: self.inner as f64,
-            elem: Elem::UInt,
-        }
-    }
-}
-
-impl Index for i32 {
-    fn value(self) -> Variable {
-        Variable::ConstantScalar {
-            value: self as f64,
-            elem: Elem::UInt,
-        }
-    }
-}
-
-impl Index for u32 {
-    fn value(self) -> Variable {
-        Variable::ConstantScalar {
-            value: self as f64,
-            elem: Elem::UInt,
-        }
-    }
-}
-
-impl Index for UInt {
-    fn value(self) -> Variable {
-        Variable::ConstantScalar {
-            value: self.val as f64,
-            elem: Elem::UInt,
-        }
-    }
-}
-
-impl Index for ExpandElement {
-    fn value(self) -> Variable {
-        *self
-    }
-}
diff --git a/crates/burn-cube/src/frontend/mod.rs b/crates/burn-cube/src/frontend/mod.rs
deleted file mode 100644
index f7f29c4d4e..0000000000
--- a/crates/burn-cube/src/frontend/mod.rs
+++ /dev/null
@@ -1,19 +0,0 @@
-pub mod branch;
-pub mod cmma;
-pub mod synchronization;
-
-mod base;
-mod comptime;
-mod context;
-mod element;
-mod indexation;
-mod operation;
-mod subcube;
-mod topology;
-
-pub use comptime::*;
-pub use context::*;
-pub use element::*;
-pub use operation::*;
-pub use subcube::*;
-pub use topology::*;
diff --git a/crates/burn-cube/src/frontend/operation/assignation.rs b/crates/burn-cube/src/frontend/operation/assignation.rs
deleted file mode 100644
index 0136b8025c..0000000000
--- a/crates/burn-cube/src/frontend/operation/assignation.rs
+++ /dev/null
@@ -1,335 +0,0 @@
-use crate::frontend::{Array, CubeContext, ExpandElement, SharedMemory, Tensor, UInt};
-use crate::{ir, unexpanded};
-
-pub mod assign {
-    use self::ir::{Operator, UnaryOperator};
-
-    use super::*;
-
-    pub fn expand<I: Into<ExpandElement>, O: Into<ExpandElement>>(
-        context: &mut CubeContext,
-        input: I,
-        output: O,
-    ) {
-        context.register(Operator::Assign(UnaryOperator {
-            input: *input.into(),
-            out: *output.into(),
-        }));
-    }
-}
-
-pub mod index_assign {
-    use crate::{frontend::CubeType, prelude::SliceMut, unexpanded};
-
-    use self::ir::{BinaryOperator, Operator, Variable};
-
-    use super::*;
-
-    pub fn expand<A: Into<ExpandElement>, I: Into<ExpandElement>, V: Into<ExpandElement>>(
-        context: &mut CubeContext,
-        array: A,
-        index: I,
-        value: V,
-    ) -> ExpandElement {
-        let array = array.into();
-        let index: Variable = *index.into();
-        let index = match index {
-            Variable::ConstantScalar { value, .. } => Variable::ConstantScalar {
-                value,
-                elem: ir::Elem::UInt,
-            },
-            _ => index,
-        };
-        context.register(Operator::IndexAssign(BinaryOperator {
-            lhs: index,
-            rhs: *value.into(),
-            out: *array,
-        }));
-        array
-    }
-
-    macro_rules! impl_index {
-        ($type:ident) => {
-            impl<E: CubeType, I: Into<UInt>> core::ops::IndexMut<I> for $type<E> {
-                fn index_mut(&mut self, _index: I) -> &mut Self::Output {
-                    unexpanded!()
-                }
-            }
-        };
-    }
-
-    impl_index!(Array);
-    impl_index!(Tensor);
-    impl_index!(SharedMemory);
-
-    impl<'a, E: CubeType, I: Into<UInt>> core::ops::IndexMut<I> for SliceMut<'a, E> {
-        fn index_mut(&mut self, _index: I) -> &mut Self::Output {
-            unexpanded!()
-        }
-    }
-}
-
-pub mod index {
-    use crate::{
-        frontend::{
-            operation::base::{binary_expand, binary_expand_no_vec},
-            CubeType,
-        },
-        prelude::{Slice, SliceMut},
-        unexpanded,
-    };
-
-    use self::ir::{Operator, Variable};
-
-    use super::*;
-
-    pub fn expand<L: Into<ExpandElement>, R: Into<ExpandElement>>(
-        context: &mut CubeContext,
-        array: L,
-        index: R,
-    ) -> ExpandElement {
-        let index: ExpandElement = index.into();
-        let index_var: Variable = *index;
-        let index = match index_var {
-            Variable::ConstantScalar { value, .. } => {
-                ExpandElement::Plain(Variable::ConstantScalar {
-                    value,
-                    elem: ir::Elem::UInt,
-                })
-            }
-            _ => index,
-        };
-        let array: ExpandElement = array.into();
-        let var: Variable = *array;
-        match var {
-            Variable::Local { .. } => binary_expand_no_vec(context, array, index, Operator::Index),
-            _ => binary_expand(context, array, index, Operator::Index),
-        }
-    }
-
-    macro_rules! impl_index {
-        ($type:ident) => {
-            impl<E: CubeType, I: Into<UInt>> core::ops::Index<I> for $type<E> {
-                type Output = E;
-
-                fn index(&self, _index: I) -> &Self::Output {
-                    unexpanded!()
-                }
-            }
-        };
-    }
-
-    impl_index!(Array);
-    impl_index!(Tensor);
-    impl_index!(SharedMemory);
-
-    impl<'a, E: CubeType, I: Into<UInt>> core::ops::Index<I> for SliceMut<'a, E> {
-        type Output = E;
-        fn index(&self, _index: I) -> &Self::Output {
-            unexpanded!()
-        }
-    }
-
-    impl<'a, E: CubeType, I: Into<UInt>> core::ops::Index<I> for Slice<'a, E> {
-        type Output = E;
-        fn index(&self, _index: I) -> &Self::Output {
-            unexpanded!()
-        }
-    }
-}
-
-pub mod add_assign_array_op {
-    use crate::prelude::array_assign_binary_op_expand;
-
-    use self::ir::Operator;
-
-    use super::*;
-
-    pub fn expand<
-        Array: Into<ExpandElement>,
-        Index: Into<ExpandElement>,
-        Value: Into<ExpandElement>,
-    >(
-        context: &mut CubeContext,
-        array: Array,
-        index: Index,
-        value: Value,
-    ) {
-        array_assign_binary_op_expand(context, array, index, value, Operator::Add);
-    }
-}
-
-pub mod sub_assign_array_op {
-    use crate::prelude::array_assign_binary_op_expand;
-
-    use self::ir::Operator;
-
-    use super::*;
-
-    pub fn expand<
-        Array: Into<ExpandElement>,
-        Index: Into<ExpandElement>,
-        Value: Into<ExpandElement>,
-    >(
-        context: &mut CubeContext,
-        array: Array,
-        index: Index,
-        value: Value,
-    ) {
-        array_assign_binary_op_expand(context, array, index, value, Operator::Sub);
-    }
-}
-
-pub mod mul_assign_array_op {
-    use crate::prelude::array_assign_binary_op_expand;
-
-    use self::ir::Operator;
-
-    use super::*;
-
-    pub fn expand<
-        Array: Into<ExpandElement>,
-        Index: Into<ExpandElement>,
-        Value: Into<ExpandElement>,
-    >(
-        context: &mut CubeContext,
-        array: Array,
-        index: Index,
-        value: Value,
-    ) {
-        array_assign_binary_op_expand(context, array, index, value, Operator::Mul);
-    }
-}
-
-pub mod div_assign_array_op {
-    use crate::prelude::array_assign_binary_op_expand;
-
-    use self::ir::Operator;
-
-    use super::*;
-
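    // The four sibling modules share one lowering (sketch): `array[index] op= value`
    // becomes a read (Operator::Index) into a temporary, the binary operator applied
    // to that temporary, then a write-back (Operator::IndexAssign), all emitted by
    // `array_assign_binary_op_expand` in `base.rs`.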
pub fn expand< - Array: Into, - Index: Into, - Value: Into, - >( - context: &mut CubeContext, - array: Array, - index: Index, - value: Value, - ) { - array_assign_binary_op_expand(context, array, index, value, Operator::Div); - } -} - -pub mod add_assign_op { - use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; - - use self::ir::Operator; - - use super::*; - - pub fn expand, R: Into>( - context: &mut CubeContext, - lhs: L, - rhs: R, - ) -> ExpandElement { - assign_op_expand(context, lhs.into(), rhs.into(), Operator::Add) - } - - macro_rules! impl_add_assign { - ($($type:ty),*) => { - $(impl core::ops::AddAssign for $type { - fn add_assign(&mut self, _rhs: Self) { - unexpanded!() - } - })* - }; - } - - impl_add_assign!(F16, BF16, F32, F64, I32, I64, UInt); -} - -pub mod sub_assign_op { - use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; - - use self::ir::Operator; - - use super::*; - - pub fn expand, R: Into>( - context: &mut CubeContext, - lhs: L, - rhs: R, - ) -> ExpandElement { - assign_op_expand(context, lhs.into(), rhs.into(), Operator::Sub) - } - - macro_rules! impl_add_assign { - ($($type:ty),*) => { - $(impl core::ops::SubAssign for $type { - fn sub_assign(&mut self, _rhs: Self) { - unexpanded!() - } - })* - }; - } - - impl_add_assign!(F16, BF16, F32, F64, I32, I64, UInt); -} - -pub mod mul_assign_op { - use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; - - use self::ir::Operator; - - use super::*; - - pub fn expand, R: Into>( - context: &mut CubeContext, - lhs: L, - rhs: R, - ) -> ExpandElement { - assign_op_expand(context, lhs.into(), rhs.into(), Operator::Mul) - } - - macro_rules! impl_add_assign { - ($($type:ty),*) => { - $(impl core::ops::MulAssign for $type { - fn mul_assign(&mut self, _rhs: Self) { - unexpanded!() - } - })* - }; - } - - impl_add_assign!(F16, BF16, F32, F64, I32, I64, UInt); -} - -pub mod div_assign_op { - use crate::frontend::{operation::base::assign_op_expand, BF16, F16, F32, F64, I32, I64}; - - use self::ir::Operator; - - use super::*; - - pub fn expand, R: Into>( - context: &mut CubeContext, - lhs: L, - rhs: R, - ) -> ExpandElement { - assign_op_expand(context, lhs.into(), rhs.into(), Operator::Div) - } - - macro_rules! impl_add_assign { - ($($type:ty),*) => { - $(impl core::ops::DivAssign for $type { - fn div_assign(&mut self, _rhs: Self) { - unexpanded!() - } - })* - }; - } - - impl_add_assign!(F16, BF16, F32, F64, I32, I64, UInt); -} diff --git a/crates/burn-cube/src/frontend/operation/base.rs b/crates/burn-cube/src/frontend/operation/base.rs deleted file mode 100644 index 4d0c705486..0000000000 --- a/crates/burn-cube/src/frontend/operation/base.rs +++ /dev/null @@ -1,245 +0,0 @@ -use crate::frontend::{CubeContext, ExpandElement}; -use crate::ir::{BinaryOperator, Elem, Item, Operator, UnaryOperator, Variable, Vectorization}; - -pub(crate) fn binary_expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - func: F, -) -> ExpandElement -where - F: Fn(BinaryOperator) -> Operator, -{ - let lhs_var: Variable = *lhs; - let rhs_var: Variable = *rhs; - - let item_lhs = lhs.item(); - let item_rhs = rhs.item(); - - let vectorization = check_vectorization(item_lhs.vectorization, item_rhs.vectorization); - let item = Item::vectorized(item_lhs.elem, vectorization); - - // We can only reuse rhs. 
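    // The reuse below is an in-place optimization: `lhs` is taken over as the
    // output when it is mutable and already has the output item type, then `rhs`
    // is tried; only if neither qualifies is a fresh local allocated.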
- let out = if lhs.can_mut() && item_lhs == item { - lhs - } else if rhs.can_mut() && item_rhs == item { - rhs - } else { - context.create_local(item) - }; - - let out_var = *out; - - let op = func(BinaryOperator { - lhs: lhs_var, - rhs: rhs_var, - out: out_var, - }); - - context.register(op); - - out -} - -pub(crate) fn binary_expand_no_vec( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - func: F, -) -> ExpandElement -where - F: Fn(BinaryOperator) -> Operator, -{ - let lhs_var: Variable = *lhs; - let rhs_var: Variable = *rhs; - - let item_lhs = lhs.item(); - let item_rhs = rhs.item(); - - let item = Item::new(item_lhs.elem); - - // We can only reuse rhs. - let out = if lhs.can_mut() && item_lhs == item { - lhs - } else if rhs.can_mut() && item_rhs == item { - rhs - } else { - context.create_local(item) - }; - - let out_var = *out; - - let op = func(BinaryOperator { - lhs: lhs_var, - rhs: rhs_var, - out: out_var, - }); - - context.register(op); - - out -} - -pub(crate) fn cmp_expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - func: F, -) -> ExpandElement -where - F: Fn(BinaryOperator) -> Operator, -{ - let lhs: Variable = *lhs; - let rhs: Variable = *rhs; - let item = lhs.item(); - - check_vectorization(item.vectorization, rhs.item().vectorization); - - let out_item = Item { - elem: Elem::Bool, - vectorization: item.vectorization, - }; - - let out = context.create_local(out_item); - let out_var = *out; - - let op = func(BinaryOperator { - lhs, - rhs, - out: out_var, - }); - - context.register(op); - - out -} - -pub(crate) fn assign_op_expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - func: F, -) -> ExpandElement -where - F: Fn(BinaryOperator) -> Operator, -{ - let lhs_var: Variable = *lhs; - let rhs: Variable = *rhs; - - check_vectorization(lhs_var.item().vectorization, rhs.item().vectorization); - - let op = func(BinaryOperator { - lhs: lhs_var, - rhs, - out: lhs_var, - }); - - context.register(op); - - lhs -} - -pub fn unary_expand(context: &mut CubeContext, input: ExpandElement, func: F) -> ExpandElement -where - F: Fn(UnaryOperator) -> Operator, -{ - let input_var: Variable = *input; - - let item = input.item(); - - let out = if input.can_mut() { - input - } else { - context.create_local(item) - }; - - let out_var = *out; - - let op = func(UnaryOperator { - input: input_var, - out: out_var, - }); - - context.register(op); - - out -} - -pub fn init_expand(context: &mut CubeContext, input: ExpandElement, func: F) -> ExpandElement -where - F: Fn(UnaryOperator) -> Operator, -{ - if input.can_mut() { - return input; - } - - let input_var: Variable = *input; - let item = input.item(); - - let out = context.create_local(item); - let out_var = *out; - - let op = func(UnaryOperator { - input: input_var, - out: out_var, - }); - - context.register(op); - - out -} - -fn check_vectorization(lhs: Vectorization, rhs: Vectorization) -> Vectorization { - let output = u8::max(lhs, rhs); - - if lhs == 1 || rhs == 1 { - return output; - } - - assert!( - lhs == rhs, - "Tried to perform binary operation on different vectorization schemes." 
- ); - - output -} - -pub fn array_assign_binary_op_expand< - Array: Into, - Index: Into, - Value: Into, - F: Fn(BinaryOperator) -> Operator, ->( - context: &mut CubeContext, - array: Array, - index: Index, - value: Value, - func: F, -) { - let array: ExpandElement = array.into(); - let index: ExpandElement = index.into(); - let value: ExpandElement = value.into(); - - let tmp = context.create_local(array.item()); - - let read = Operator::Index(BinaryOperator { - lhs: *array, - rhs: *index, - out: *tmp, - }); - let calculate = func(BinaryOperator { - lhs: *tmp, - rhs: *value, - out: *tmp, - }); - - let write = Operator::IndexAssign(BinaryOperator { - lhs: *index, - rhs: *tmp, - out: *array, - }); - - context.register(read); - context.register(calculate); - context.register(write); -} diff --git a/crates/burn-cube/src/frontend/operation/binary.rs b/crates/burn-cube/src/frontend/operation/binary.rs deleted file mode 100644 index 08cf562442..0000000000 --- a/crates/burn-cube/src/frontend/operation/binary.rs +++ /dev/null @@ -1,331 +0,0 @@ -use crate::frontend::operation::base::binary_expand; -use crate::frontend::{CubeContext, ExpandElement, UInt, BF16, F16, F32, F64, I32, I64}; -use crate::ir::Operator; -use crate::{frontend::CubeType, unexpanded}; - -pub mod add { - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - binary_expand(context, lhs, rhs, Operator::Add) - } - - macro_rules! impl_add { - ($type:ty) => { - impl core::ops::Add for $type { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - (self.val + rhs.val).into() - } - } - }; - } - - impl_add!(F16); - impl_add!(BF16); - impl_add!(F32); - impl_add!(F64); - impl_add!(I32); - impl_add!(I64); - impl_add!(UInt); -} - -pub mod sub { - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - binary_expand(context, lhs, rhs, Operator::Sub) - } - - macro_rules! impl_sub { - ($type:ty) => { - impl core::ops::Sub for $type { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - (self.val - rhs.val).into() - } - } - }; - } - - impl_sub!(F16); - impl_sub!(BF16); - impl_sub!(F32); - impl_sub!(F64); - impl_sub!(I32); - impl_sub!(I64); - impl_sub!(UInt); -} - -pub mod mul { - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - binary_expand(context, lhs, rhs, Operator::Mul) - } - - macro_rules! impl_mul { - ($type:ty) => { - impl core::ops::Mul for $type { - type Output = Self; - - fn mul(self, rhs: Self) -> Self::Output { - (self.val * rhs.val).into() - } - } - }; - } - - impl_mul!(F16); - impl_mul!(BF16); - impl_mul!(F32); - impl_mul!(F64); - impl_mul!(I32); - impl_mul!(I64); - impl_mul!(UInt); -} - -pub mod div { - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - binary_expand(context, lhs, rhs, Operator::Div) - } - - macro_rules! 
impl_div { - ($type:ty) => { - impl core::ops::Div for $type { - type Output = Self; - - fn div(self, rhs: Self) -> Self::Output { - (self.val / rhs.val).into() - } - } - }; - } - - impl_div!(F16); - impl_div!(BF16); - impl_div!(F32); - impl_div!(F64); - impl_div!(I32); - impl_div!(I64); - impl_div!(UInt); -} - -pub mod rem { - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - binary_expand(context, lhs, rhs, Operator::Modulo) - } - - macro_rules! impl_rem { - ($type:ty) => { - impl core::ops::Rem for $type { - type Output = Self; - - fn rem(self, _rhs: Self) -> Self::Output { - unexpanded!() - } - } - }; - } - - impl_rem!(I32); - impl_rem!(I64); - impl_rem!(UInt); -} - -pub mod and { - use super::*; - - pub fn expand, R: Into>( - context: &mut CubeContext, - lhs: L, - rhs: R, - ) -> ExpandElement { - binary_expand(context, lhs.into(), rhs.into(), Operator::And) - } -} - -pub mod bitand { - use super::*; - - pub fn expand, R: Into>( - context: &mut CubeContext, - lhs: L, - rhs: R, - ) -> ExpandElement { - binary_expand(context, lhs.into(), rhs.into(), Operator::BitwiseAnd) - } - - impl core::ops::BitAnd for UInt { - type Output = UInt; - - fn bitand(self, _rhs: Self) -> Self::Output { - unexpanded!() - } - } -} - -pub mod or { - use super::*; - - pub fn expand(context: &mut CubeContext, lhs: L, rhs: R) -> ExpandElement - where - L: Into, - R: Into, - { - binary_expand(context, lhs.into(), rhs.into(), Operator::Or) - } -} - -pub mod bitxor { - use super::*; - - pub fn expand, R: Into>( - context: &mut CubeContext, - lhs: L, - rhs: R, - ) -> ExpandElement { - binary_expand(context, lhs.into(), rhs.into(), Operator::BitwiseXor) - } - - impl core::ops::BitXor for UInt { - type Output = UInt; - - fn bitxor(self, _rhs: Self) -> Self::Output { - unexpanded!() - } - } -} - -pub mod shl { - use super::*; - - pub fn expand, R: Into>( - context: &mut CubeContext, - lhs: L, - rhs: R, - ) -> ExpandElement { - binary_expand(context, lhs.into(), rhs.into(), Operator::ShiftLeft) - } - - impl core::ops::Shl for UInt { - type Output = UInt; - - fn shl(self, _rhs: Self) -> Self::Output { - unexpanded!() - } - } -} - -pub mod shr { - use super::*; - - pub fn expand, R: Into>( - context: &mut CubeContext, - lhs: L, - rhs: R, - ) -> ExpandElement { - binary_expand(context, lhs.into(), rhs.into(), Operator::ShiftRight) - } - - impl core::ops::Shr for UInt { - type Output = UInt; - - fn shr(self, _rhs: Self) -> Self::Output { - unexpanded!() - } - } -} - -/// For binary functions without special syntax -macro_rules! 
impl_binary_func { - ($trait_name:ident, $method_name:ident, $method_name_expand:ident, $operator:expr, $($type:ty),*) => { - pub trait $trait_name: CubeType + Sized { - fn $method_name(self, _rhs: Self) -> Self { - unexpanded!() - } - - fn $method_name_expand(context: &mut CubeContext, lhs: ExpandElement, rhs: ExpandElement) -> ExpandElement { - binary_expand(context, lhs, rhs, $operator) - } - } - - $(impl $trait_name for $type {})* - } -} - -impl_binary_func!( - Powf, - powf, - __expand_powf, - Operator::Powf, - F16, - BF16, - F32, - F64 -); -impl_binary_func!( - Max, - max, - __expand_max, - Operator::Max, - F16, - BF16, - F32, - F64, - I32, - I64, - UInt -); -impl_binary_func!( - Min, - min, - __expand_min, - Operator::Min, - F16, - BF16, - F32, - F64, - I32, - I64, - UInt -); -impl_binary_func!( - Remainder, - rem, - __expand_rem, - Operator::Remainder, - F16, - BF16, - F32, - F64, - I32, - I64, - UInt -); diff --git a/crates/burn-cube/src/frontend/operation/clamp.rs b/crates/burn-cube/src/frontend/operation/clamp.rs deleted file mode 100644 index 3765c5aa9e..0000000000 --- a/crates/burn-cube/src/frontend/operation/clamp.rs +++ /dev/null @@ -1,38 +0,0 @@ -use crate::{ - ir::{ClampOperator, Operator}, - prelude::{CubeContext, CubePrimitive, UInt, BF16, F16, F32, F64, I32, I64}, - unexpanded, -}; - -use super::unary_expand; - -pub trait Clamp: CubePrimitive + Sized { - /// Clamp the input value between the max and min values provided. - #[allow(unused_variables)] - fn clamp(input: Self, min_value: Self, max_value: Self) -> Self { - unexpanded!() - } - fn __expand_clamp( - context: &mut CubeContext, - input: Self::ExpandType, - min_value: Self::ExpandType, - max_value: Self::ExpandType, - ) -> Self::ExpandType { - unary_expand(context, input, |op| { - Operator::Clamp(ClampOperator { - input: op.input, - min_value: *min_value, - max_value: *max_value, - out: op.out, - }) - }) - } -} - -impl Clamp for F16 {} -impl Clamp for BF16 {} -impl Clamp for F32 {} -impl Clamp for F64 {} -impl Clamp for I32 {} -impl Clamp for I64 {} -impl Clamp for UInt {} diff --git a/crates/burn-cube/src/frontend/operation/cmp.rs b/crates/burn-cube/src/frontend/operation/cmp.rs deleted file mode 100644 index 4affba7d76..0000000000 --- a/crates/burn-cube/src/frontend/operation/cmp.rs +++ /dev/null @@ -1,118 +0,0 @@ -use crate::frontend::operation::base::cmp_expand; -use crate::frontend::{CubeContext, ExpandElement, UInt, BF16, F16, F32, F64, I32, I64}; -use crate::ir::Operator; - -macro_rules! 
impl_cmp { - ($type:ty) => { - impl core::cmp::PartialEq for $type { - fn eq(&self, other: &Self) -> bool { - self.val == other.val && self.vectorization == other.vectorization - } - } - - impl core::cmp::Eq for $type {} - - impl core::cmp::PartialOrd for $type { - fn partial_cmp(&self, other: &Self) -> Option { - match self.val.partial_cmp(&other.val) { - Some(core::cmp::Ordering::Equal) => {} - ord => return ord, - } - self.vectorization.partial_cmp(&other.vectorization) - } - } - }; -} - -impl_cmp!(F16); -impl_cmp!(BF16); -impl_cmp!(F32); -impl_cmp!(F64); -impl_cmp!(I32); -impl_cmp!(I64); -impl_cmp!(UInt); - -pub mod ne { - - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - cmp_expand(context, lhs, rhs, Operator::NotEqual) - } -} - -pub mod gt { - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - cmp_expand(context, lhs, rhs, Operator::Greater) - } -} - -pub mod lt { - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - cmp_expand(context, lhs, rhs, Operator::Lower) - } -} - -pub mod ge { - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - cmp_expand(context, lhs, rhs, Operator::GreaterEqual) - } -} - -pub mod le { - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - cmp_expand(context, lhs, rhs, Operator::LowerEqual) - } -} - -pub mod eq { - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - cmp_expand(context, lhs, rhs, Operator::Equal) - } -} - -pub mod add_assign { - use super::*; - - pub fn expand( - context: &mut CubeContext, - lhs: ExpandElement, - rhs: ExpandElement, - ) -> ExpandElement { - cmp_expand(context, lhs, rhs, Operator::Add) - } -} diff --git a/crates/burn-cube/src/frontend/operation/fma.rs b/crates/burn-cube/src/frontend/operation/fma.rs deleted file mode 100644 index 9b106e4c90..0000000000 --- a/crates/burn-cube/src/frontend/operation/fma.rs +++ /dev/null @@ -1,36 +0,0 @@ -use crate::{ - ir::{FmaOperator, Operation, Operator}, - prelude::{CubeContext, CubePrimitive, ExpandElement}, - unexpanded, -}; - -/// Fused multiply-add `A*B+C`. -#[allow(unused_variables)] -pub fn fma(a: C, b: C, c: C) -> C { - unexpanded!() -} - -/// Expand method of [fma]. 
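// Usage sketch (kernel-side, with assumed variables): a fused call traces a
// single Operator::Fma instruction instead of a separate mul and add:
//
//     let d = fma(x, y, acc); // d = x * y + acc
//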
-#[allow(unused_variables)] -pub fn fma_expand( - context: &mut CubeContext, - a: ExpandElement, - b: ExpandElement, - c: ExpandElement, -) -> ExpandElement { - let output = context.create_local(a.item()); - - let out = *output; - let a = *a; - let b = *b; - let c = *c; - - context.register(Operation::Operator(Operator::Fma(FmaOperator { - a, - b, - c, - out, - }))); - - output -} diff --git a/crates/burn-cube/src/frontend/operation/mod.rs b/crates/burn-cube/src/frontend/operation/mod.rs deleted file mode 100644 index 06273444b2..0000000000 --- a/crates/burn-cube/src/frontend/operation/mod.rs +++ /dev/null @@ -1,15 +0,0 @@ -mod assignation; -mod base; -mod binary; -mod clamp; -mod cmp; -mod fma; -mod unary; - -pub use assignation::*; -pub use base::*; -pub use binary::*; -pub use clamp::*; -pub use cmp::*; -pub use fma::*; -pub use unary::*; diff --git a/crates/burn-cube/src/frontend/operation/unary.rs b/crates/burn-cube/src/frontend/operation/unary.rs deleted file mode 100644 index ec780652f3..0000000000 --- a/crates/burn-cube/src/frontend/operation/unary.rs +++ /dev/null @@ -1,110 +0,0 @@ -use crate::{ - frontend::{CubeContext, CubeType, ExpandElement, UInt, BF16, F16, F32, F64, I32, I64}, - ir::Operator, - unexpanded, -}; - -use super::base::unary_expand; - -pub mod not { - use super::*; - - pub fn expand(context: &mut CubeContext, x: ExpandElement) -> ExpandElement { - unary_expand(context, x, Operator::Not) - } -} - -macro_rules! impl_unary_func { - ($trait_name:ident, $method_name:ident, $method_name_expand:ident, $operator:expr, $($type:ty),*) => { - pub trait $trait_name: CubeType + Sized { - fn $method_name(_input: Self) -> Self { - unexpanded!() - } - - fn $method_name_expand(context: &mut CubeContext, x: ExpandElement) -> ExpandElement { - unary_expand(context, x, $operator) - } - } - - $(impl $trait_name for $type {})* - } -} - -impl_unary_func!( - Abs, - abs, - __expand_abs, - Operator::Abs, - F16, - BF16, - F32, - F64, - I32, - I64, - UInt -); -impl_unary_func!(Exp, exp, __expand_exp, Operator::Exp, F16, BF16, F32, F64); -impl_unary_func!(Log, log, __expand_log, Operator::Log, F16, BF16, F32, F64); -impl_unary_func!( - Log1p, - log1p, - __expand_log1p, - Operator::Log1p, - F16, - BF16, - F32, - F64 -); -impl_unary_func!(Cos, cos, __expand_cos, Operator::Cos, F16, BF16, F32, F64); -impl_unary_func!(Sin, sin, __expand_sin, Operator::Sin, F16, BF16, F32, F64); -impl_unary_func!( - Tanh, - tanh, - __expand_tanh, - Operator::Tanh, - F16, - BF16, - F32, - F64 -); -impl_unary_func!( - Sqrt, - sqrt, - __expand_sqrt, - Operator::Sqrt, - F16, - BF16, - F32, - F64 -); -impl_unary_func!( - Floor, - floor, - __expand_floor, - Operator::Floor, - F16, - BF16, - F32, - F64 -); -impl_unary_func!( - Ceil, - ceil, - __expand_ceil, - Operator::Ceil, - F16, - BF16, - F32, - F64 -); -impl_unary_func!(Erf, erf, __expand_erf, Operator::Erf, F16, BF16, F32, F64); -impl_unary_func!( - Recip, - recip, - __expand_recip, - Operator::Recip, - F16, - BF16, - F32, - F64 -); diff --git a/crates/burn-cube/src/frontend/subcube.rs b/crates/burn-cube/src/frontend/subcube.rs deleted file mode 100644 index 3f50154ab7..0000000000 --- a/crates/burn-cube/src/frontend/subcube.rs +++ /dev/null @@ -1,161 +0,0 @@ -use super::{CubeContext, CubePrimitive, ExpandElement}; -use crate::{ - ir::{Elem, InitOperator, Item, Operation, Subcube, UnaryOperator}, - unexpanded, -}; - -/// Returns true if the cube unit has the lowest subcube_unit_id among active unit in the subcube -pub fn subcube_elect() -> bool { - unexpanded!() -} - 
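// Usage sketch (assumed kernel context): elect one unit per subcube for one-off
// work such as initializing shared memory:
//
//     if subcube_elect() {
//         // runs on exactly one active unit of each subcube
//     }
//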
-pub fn subcube_elect_expand(context: &mut CubeContext) -> ExpandElement { - let output = context.create_local(Item::new(Elem::Bool)); - - let out = *output; - - context.register(Operation::Subcube(Subcube::Elect(InitOperator { out }))); - - output -} - -/// Perform a reduce sum operation across all units in a subcube. -#[allow(unused_variables)] -pub fn subcube_sum(value: E) -> E { - unexpanded!() -} - -/// Module containing the expand function for [subcube_sum()]. -pub mod subcube_sum { - use super::*; - - /// Expand method of [subcube_sum()]. - pub fn __expand( - context: &mut CubeContext, - elem: ExpandElement, - ) -> ExpandElement { - let output = context.create_local(elem.item()); - - let out = *output; - let input = *elem; - - context.register(Operation::Subcube(Subcube::Sum(UnaryOperator { - input, - out, - }))); - - output - } -} - -/// Perform a reduce prod operation across all units in a subcube. -pub fn subcube_prod(_elem: E) -> E { - unexpanded!() -} - -/// Module containing the expand function for [subcube_prod()]. -pub mod subcube_prod { - use super::*; - - /// Expand method of [subcube_prod()]. - pub fn __expand( - context: &mut CubeContext, - elem: ExpandElement, - ) -> ExpandElement { - let output = context.create_local(elem.item()); - - let out = *output; - let input = *elem; - - context.register(Operation::Subcube(Subcube::Prod(UnaryOperator { - input, - out, - }))); - - output - } -} - -/// Perform a reduce max operation across all units in a subcube. -pub fn subcube_max(_elem: E) -> E { - unexpanded!() -} - -/// Module containing the expand function for [subcube_max()]. -pub mod subcube_max { - use super::*; - - /// Expand method of [subcube_max()]. - pub fn __expand( - context: &mut CubeContext, - elem: ExpandElement, - ) -> ExpandElement { - let output = context.create_local(elem.item()); - - let out = *output; - let input = *elem; - - context.register(Operation::Subcube(Subcube::Max(UnaryOperator { - input, - out, - }))); - - output - } -} - -/// Perform a reduce min operation across all units in a subcube. -pub fn subcube_min(_elem: E) -> E { - unexpanded!() -} - -/// Module containing the expand function for [subcube_min()]. -pub mod subcube_min { - use super::*; - - /// Expand method of [subcube_min()]. - pub fn __expand( - context: &mut CubeContext, - elem: ExpandElement, - ) -> ExpandElement { - let output = context.create_local(elem.item()); - - let out = *output; - let input = *elem; - - context.register(Operation::Subcube(Subcube::Min(UnaryOperator { - input, - out, - }))); - - output - } -} - -/// Perform a reduce all operation across all units in a subcube. -pub fn subcube_all(_elem: E) -> E { - unexpanded!() -} - -/// Module containing the expand function for [subcube_all()]. -pub mod subcube_all { - use super::*; - - /// Expand method of [subcube_all()]. 
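    // Usage sketch for the reduction family (assumed `value` held by each unit):
    // every active unit of the subcube receives the reduced result:
    //
    //     let total = subcube_sum(value);
    //     let largest = subcube_max(value);
    //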
- pub fn __expand( - context: &mut CubeContext, - elem: ExpandElement, - ) -> ExpandElement { - let output = context.create_local(elem.item()); - - let out = *output; - let input = *elem; - - context.register(Operation::Subcube(Subcube::All(UnaryOperator { - input, - out, - }))); - - output - } -} diff --git a/crates/burn-cube/src/frontend/synchronization.rs b/crates/burn-cube/src/frontend/synchronization.rs deleted file mode 100644 index a47967e4a0..0000000000 --- a/crates/burn-cube/src/frontend/synchronization.rs +++ /dev/null @@ -1,12 +0,0 @@ -use crate::frontend::CubeContext; -use crate::ir::Synchronization; - -pub fn sync_units() {} - -pub mod sync_units { - use super::*; - - pub fn __expand(context: &mut CubeContext) { - context.register(Synchronization::SyncUnits) - } -} diff --git a/crates/burn-cube/src/frontend/topology.rs b/crates/burn-cube/src/frontend/topology.rs deleted file mode 100644 index e969d7e1d7..0000000000 --- a/crates/burn-cube/src/frontend/topology.rs +++ /dev/null @@ -1,189 +0,0 @@ -//! In this file we use a trick where the constant has the same name as the module containing -//! the expand function, so that a user implicitly imports the expand function when importing the constant. - -use crate::frontend::UInt; - -macro_rules! constant { - ($ident:ident, $var:expr, $doc:expr) => { - #[doc = $doc] - pub const $ident: UInt = UInt::new(0u32); - - #[allow(non_snake_case)] - #[doc = $doc] - pub mod $ident { - use crate::frontend::{CubeContext, ExpandElement}; - - /// Expansion of the constant variable. - pub fn expand(_context: &mut CubeContext) -> ExpandElement { - ExpandElement::Plain($var) - } - } - }; -} - -constant!( - SUBCUBE_DIM, - crate::ir::Variable::SubcubeDim, - r" -The total amount of working units in a subcube. -" -); - -constant!( - UNIT_POS, - crate::ir::Variable::UnitPos, - r" -The position of the working unit inside the cube, without regards to axis. -" -); - -constant!( - UNIT_POS_X, - crate::ir::Variable::UnitPosX, - r" -The position of the working unit inside the cube along the X axis. -" -); - -constant!( - UNIT_POS_Y, - crate::ir::Variable::UnitPosY, - r" -The position of the working unit inside the cube along the Y axis. -" -); - -constant!( - UNIT_POS_Z, - crate::ir::Variable::UnitPosZ, - r" -The position of the working unit inside the cube along the Z axis. -" -); - -constant!( - CUBE_DIM, - crate::ir::Variable::CubeDim, - r" -The total amount of working units in a cube. -" -); - -constant!( - CUBE_DIM_X, - crate::ir::Variable::CubeDimX, - r" -The dimension of the cube along the X axis. -" -); - -constant!( - CUBE_DIM_Y, - crate::ir::Variable::CubeDimY, - r" -The dimension of the cube along the Y axis. -" -); - -constant!( - CUBE_DIM_Z, - crate::ir::Variable::CubeDimZ, - r" -The dimension of the cube along the Z axis. -" -); - -constant!( - CUBE_POS, - crate::ir::Variable::CubePos, - r" -The cube position, without regards to axis. -" -); - -constant!( - CUBE_POS_X, - crate::ir::Variable::CubePosX, - r" -The cube position along the X axis. -" -); - -constant!( - CUBE_POS_Y, - crate::ir::Variable::CubePosY, - r" -The cube position along the Y axis. -" -); - -constant!( - CUBE_POS_Z, - crate::ir::Variable::CubePosZ, - r" -The cube position along the Z axis. -" -); -constant!( - CUBE_COUNT, - crate::ir::Variable::CubeCount, - r" -The number of cubes launched. -" -); - -constant!( - CUBE_COUNT_X, - crate::ir::Variable::CubeCountX, - r" -The number of cubes launched along the X axis. 
-" -); - -constant!( - CUBE_COUNT_Y, - crate::ir::Variable::CubeCountY, - r" -The number of cubes launched along the Y axis. -" -); - -constant!( - CUBE_COUNT_Z, - crate::ir::Variable::CubeCountZ, - r" -The number of cubes launched along the Z axis. -" -); - -constant!( - ABSOLUTE_POS, - crate::ir::Variable::AbsolutePos, - r" -The position of the working unit in the whole cube kernel, without regards to cubes and axis. -" -); - -constant!( - ABSOLUTE_POS_X, - crate::ir::Variable::AbsolutePosX, - r" -The index of the working unit in the whole cube kernel along the X axis, without regards to cubes. -" -); - -constant!( - ABSOLUTE_POS_Y, - crate::ir::Variable::AbsolutePosY, - r" -The index of the working unit in the whole cube kernel along the Y axis, without regards to cubes. -" -); - -constant!( - ABSOLUTE_POS_Z, - crate::ir::Variable::AbsolutePosZ, - r" -The index of the working unit in the whole cube kernel along the Z axis, without regards to cubes. -" -); diff --git a/crates/burn-cube/src/ir/branch.rs b/crates/burn-cube/src/ir/branch.rs deleted file mode 100644 index 52ef709684..0000000000 --- a/crates/burn-cube/src/ir/branch.rs +++ /dev/null @@ -1,133 +0,0 @@ -use super::{Elem, Item, Scope, Variable}; -use serde::{Deserialize, Serialize}; - -/// All branching types. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum Branch { - /// An if statement. - If(If), - /// An if else statement. - IfElse(IfElse), - /// A range loop. - RangeLoop(RangeLoop), - /// A loop. - Loop(Loop), - /// A return statement. - Return, - /// A break statement. - Break, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct If { - pub cond: Variable, - pub scope: Scope, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct IfElse { - pub cond: Variable, - pub scope_if: Scope, - pub scope_else: Scope, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct RangeLoop { - pub i: Variable, - pub start: Variable, - pub end: Variable, - pub scope: Scope, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct Loop { - pub scope: Scope, -} - -impl If { - /// Registers an if statement to the given scope. - pub fn register(parent_scope: &mut Scope, cond: Variable, func: F) { - let mut scope = parent_scope.child(); - - func(&mut scope); - - let op = Self { cond, scope }; - parent_scope.register(Branch::If(op)); - } -} - -impl IfElse { - /// Registers an if else statement to the given scope. - pub fn register( - parent_scope: &mut Scope, - cond: Variable, - func_if: IF, - func_else: ELSE, - ) where - IF: Fn(&mut Scope), - ELSE: Fn(&mut Scope), - { - let mut scope_if = parent_scope.child(); - let mut scope_else = parent_scope.child(); - - func_if(&mut scope_if); - func_else(&mut scope_else); - - parent_scope.register(Branch::IfElse(Self { - cond, - scope_if, - scope_else, - })); - } -} - -impl RangeLoop { - /// Registers a range loop to the given scope. - pub fn register( - parent_scope: &mut Scope, - start: Variable, - end: Variable, - func: F, - ) { - let mut scope = parent_scope.child(); - let index_ty = Item::new(Elem::UInt); - let i = scope.create_local_undeclared(index_ty); - - func(i, &mut scope); - - parent_scope.register(Branch::RangeLoop(Self { - i, - start, - end, - scope, - })); - } -} - -impl Loop { - /// Registers a loop to the given scope. 
- pub fn register(parent_scope: &mut Scope, func: F) { - let mut scope = parent_scope.child(); - - func(&mut scope); - - let op = Self { scope }; - parent_scope.register(Branch::Loop(op)); - } -} - -#[allow(missing_docs)] -pub struct UnrolledRangeLoop; - -impl UnrolledRangeLoop { - /// Registers an unrolled range loop to the given scope. - pub fn register(scope: &mut Scope, start: u32, end: u32, func: F) { - for i in start..end { - func(i.into(), scope); - } - } -} diff --git a/crates/burn-cube/src/ir/cmma.rs b/crates/burn-cube/src/ir/cmma.rs deleted file mode 100644 index 133cf8eaff..0000000000 --- a/crates/burn-cube/src/ir/cmma.rs +++ /dev/null @@ -1,60 +0,0 @@ -use serde::{Deserialize, Serialize}; - -use super::{Elem, Variable}; - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub enum MatrixIdent { - A, - B, - Accumulator, -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub enum MatrixLayout { - ColMajor, - RowMajor, - Undefined, -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct Matrix { - pub ident: MatrixIdent, - pub m: u8, - pub n: u8, - pub k: u8, - pub elem: Elem, - pub layout: MatrixLayout, -} - -/// Cooperative Matrix-Multiply and Accumulate Instruction. -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub enum CoopMma { - /// Fill the matrix with the value. - Fill { mat: Variable, value: Variable }, - /// Load the value into the matrix given the stride. - Load { - mat: Variable, - value: Variable, - stride: Variable, - }, - /// Executes D=A*B+C; - /// - /// For implementing a matmul, `D=C` : `C+=A*B` - Execute { - mat_a: Variable, - mat_b: Variable, - mat_c: Variable, - mat_d: Variable, - }, - /// Store the matrix in an output variable following the stride and the layout. - Store { - output: Variable, - mat: Variable, - stride: Variable, - layout: MatrixLayout, - }, -} diff --git a/crates/burn-cube/src/ir/macros.rs b/crates/burn-cube/src/ir/macros.rs deleted file mode 100644 index d50035b51e..0000000000 --- a/crates/burn-cube/src/ir/macros.rs +++ /dev/null @@ -1,462 +0,0 @@ -use super::Variable; - -#[macro_export(local_inner_macros)] -/// Cube Pseudo Assembly. -macro_rules! 
cpa { - // out = lhs + rhs - ($scope:expr, $out:ident = $lhs:ident + $rhs:expr) => { - cpa!($scope, $out = add($lhs, $rhs)) - }; - // out += input - ($scope:expr, $out:ident += $input:ident) => { - cpa!($scope, $out = add($out, $input)) - }; - // out = add(lhs, rhs) - ($scope:expr, $out:ident = add($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Add( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs - rhs - ($scope:expr, $out:ident = $lhs:ident - $rhs:expr) => { - cpa!($scope, $out = sub($lhs, $rhs)) - }; - // out = sub(lhs, rhs) - ($scope:expr, $out:ident = sub($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Sub( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs * rhs - ($scope:expr, $out:ident = $lhs:ident * $rhs:expr) => { - cpa!($scope, $out = mul($lhs, $rhs)) - }; - // out *= input - ($scope:expr, $out:ident *= $input:ident) => { - cpa!($scope, $out = mul($out, $input)) - }; - // out = mul(lhs, rhs) - ($scope:expr, $out:ident = mul($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Mul( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs / rhs - ($scope:expr, $out:ident = $lhs:ident / $rhs:expr) => { - cpa!($scope, $out = div($lhs, $rhs)) - }; - // out = div(lhs, rhs) - ($scope:expr, $out:ident = div($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Div( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs % rhs - ($scope:expr, $out:ident = $lhs:ident % $rhs:expr) => { - cpa!($scope, $out = modulo($lhs, $rhs)) - }; - // out = modulo(lhs, rhs) - ($scope:expr, $out:ident = modulo($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Modulo( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = powf(lhs, rhs) - ($scope:expr, $out:ident = powf($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Powf( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs && rhs - ($scope:expr, $out:ident = $lhs:ident && $rhs:expr) => { - cpa!($scope, $out = and($lhs, $rhs)) - }; - // out = and(lhs, rhs) - ($scope:expr, $out:ident = and($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::And( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs || rhs - ($scope:expr, $out:ident = $lhs:ident || $rhs:expr) => { - cpa!($scope, $out = or($lhs, $rhs)) - }; - // out = or(lhs, rhs) - ($scope:expr, $out:ident = or($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Or( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = !input - ($scope:expr, $out:ident = !$input:expr) => { - cpa!($scope, $out = not($input)) - }; - // out = not(input) - ($scope:expr, $out:ident = not($input:expr)) => { - $scope.register($crate::ir::Operator::Not( - cpa!(unary $input, $out) - )); - }; - // out = lhs & rhs - ($scope:expr, $out: ident = $lhs:ident & $rhs:ident) => { - cpa!($scope, $out = bitwise_and($lhs, $rhs)) - }; - // out = bitwise_and(lhs, rhs) - ($scope:expr, $out:ident = bitwise_and($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::BitwiseAnd( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs ^ rhs - ($scope:expr, $out: ident = $lhs:ident ^ $rhs:ident) => { - cpa!($scope, $out = bitwise_xor($lhs, $rhs)) - }; - // out = bitwise_xor(lhs, rhs) - ($scope:expr, $out:ident = bitwise_xor($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::BitwiseXor( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs << rhs - ($scope:expr, $out: ident = $lhs:ident << $rhs:ident) => { - cpa!($scope, $out = shift_left($lhs, $rhs)) - 
}; - // out = shift_left(lhs, rhs) - ($scope:expr, $out:ident = shift_left($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::ShiftLeft( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs >> rhs - ($scope:expr, $out: ident = $lhs:ident >> $rhs:ident) => { - cpa!($scope, $out = shift_right($lhs, $rhs)) - }; - // out = shift_right(lhs, rhs) - ($scope:expr, $out:ident = shift_right($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::ShiftRight( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs == rhs - ($scope:expr, $out:ident = $lhs:ident == $rhs:expr) => { - cpa!($scope, $out = equal($lhs, $rhs)) - }; - // out = equal(lhs, rhs) - ($scope:expr, $out:ident = equal($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Equal( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs != rhs - ($scope:expr, $out:ident = $lhs:ident != $rhs:expr) => { - cpa!($scope, $out = not_equal($lhs, $rhs)) - }; - // out = not_equal(lhs, rhs) - ($scope:expr, $out:ident = not_equal($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::NotEqual( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs > rhs - ($scope:expr, $out:ident = $lhs:ident > $rhs:expr) => { - cpa!($scope, $out = greater($lhs, $rhs)) - }; - // out = greater(lhs, rhs) - ($scope:expr, $out:ident = greater($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Greater( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs >= rhs - ($scope:expr, $out:ident = $lhs:ident >= $rhs:expr) => { - cpa!($scope, $out = greater_equal($lhs, $rhs)) - }; - // out = greater_equal(lhs, rhs) - ($scope:expr, $out:ident = greater_equal($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::GreaterEqual( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs < rhs - ($scope:expr, $out:ident = $lhs:ident < $rhs:expr) => { - cpa!($scope, $out = lower($lhs, $rhs)) - }; - // out = lower(lhs, rhs) - ($scope:expr, $out:ident = lower($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Lower( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs <= rhs - ($scope:expr, $out:ident = $lhs:ident <= $rhs:expr) => { - cpa!($scope, $out = lower_equal($lhs, $rhs)) - }; - // out = lower_equal(lhs, rhs) - ($scope:expr, $out:ident = lower_equal($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::LowerEqual( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = max(lhs, rhs) - ($scope:expr, $out:ident = max($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Max( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = min(lhs, rhs) - ($scope:expr, $out:ident = min($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Min( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = lhs[rhs] - ($scope:expr, $out:ident = $lhs:ident[$rhs:expr]) => { - cpa!($scope, $out = index($lhs, $rhs)) - }; - // out = index(lhs, rhs) - ($scope:expr, $out:ident = index($lhs:expr, $rhs:expr)) => { - $scope.register($crate::ir::Operator::Index( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = unchecked(lhs[rhs]) - ($scope:expr, $out:ident = unchecked($lhs:ident[$rhs:expr])) => { - $scope.register($crate::ir::Operator::UncheckedIndex( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out[lhs] = rhs - ($scope:expr, $out:ident[$lhs:ident] = $rhs:expr) => { - $scope.register($crate::ir::Operator::IndexAssign( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // unchecked(out[lhs]) = rhs - ($scope:expr, unchecked($out:ident[$lhs:ident]) = 
$rhs:expr) => { - $scope.register($crate::ir::Operator::UncheckedIndexAssign( - cpa!(binary $lhs, $rhs, $out) - )); - }; - // out = |input| - ($scope:expr, $out:ident = |$input:ident|) => { - cpa!($scope, $out = abs($input)) - }; - // out = abs(input) - ($scope:expr, $out:ident = abs($input:expr)) => { - $scope.register($crate::ir::Operator::Abs( - cpa!(unary $input, $out) - )); - }; - // out = exp(input) - ($scope:expr, $out:ident = exp($input:expr)) => { - $scope.register($crate::ir::Operator::Exp( - cpa!(unary $input, $out) - )); - }; - // out = log(input) - ($scope:expr, $out:ident = log($input:expr)) => { - $scope.register($crate::ir::Operator::Log( - cpa!(unary $input, $out) - )); - }; - // out = log1p(input) - ($scope:expr, $out:ident = log1p($input:expr)) => { - $scope.register($crate::ir::Operator::Log1p( - cpa!(unary $input, $out) - )); - }; - // out = cos(input) - ($scope:expr, $out:ident = cos($input:expr)) => { - $scope.register($crate::ir::Operator::Cos( - cpa!(unary $input, $out) - )); - }; - // out = sin(input) - ($scope:expr, $out:ident = sin($input:expr)) => { - $scope.register($crate::ir::Operator::Sin( - cpa!(unary $input, $out) - )); - }; - // out = tanh(input) - ($scope:expr, $out:ident = tanh($input:expr)) => { - $scope.register($crate::ir::Operator::Tanh( - cpa!(unary $input, $out) - )); - }; - // out = sqrt(input) - ($scope:expr, $out:ident = sqrt($input:expr)) => { - $scope.register($crate::ir::Operator::Sqrt( - cpa!(unary $input, $out) - )); - }; - // out = floor(input) - ($scope:expr, $out:ident = floor($input:expr)) => { - $scope.register($crate::ir::Operator::Floor( - cpa!(unary $input, $out) - )); - }; - // out = ceil(input) - ($scope:expr, $out:ident = ceil($input:expr)) => { - $scope.register($crate::ir::Operator::Ceil( - cpa!(unary $input, $out) - )); - }; - // out = erf(input) - ($scope:expr, $out:ident = erf($input:expr)) => { - $scope.register($crate::ir::Operator::Erf( - cpa!(unary $input, $out) - )); - }; - // out = input - ($scope:expr, $out:ident = $input:ident) => { - $scope.register($crate::ir::Operator::Assign( - cpa!(unary $input, $out) - )); - }; - // out = vec4(a, b, c, d) - ($scope:expr, $out:ident = vec4($a:ident,$b:ident,$c:ident,$d:ident)) => { - let i = $scope.zero(Elem::UInt); - cpa!($scope, $out[i] = $a); - cpa!($scope, i = i + 1u32); - cpa!($scope, $out[i] = $b); - cpa!($scope, i = i + 1u32); - cpa!($scope, $out[i] = $c); - cpa!($scope, i = i + 1u32); - cpa!($scope, $out[i] = $d); - }; - // out = input - ($scope:expr, $out:ident = $input:ident) => { - cpa!($scope, $out = cast($input)) - }; - // out = cast(input) - ($scope:expr, $out:ident = cast($input:expr)) => { - $scope.register($crate::ir::Operator::Assign( - cpa!(unary $input, $out) - )); - }; - // out = shape(tensor, dim) - ($scope:expr, $out:ident = shape($input:expr, $dim:expr)) => { - $scope.register($crate::ir::Metadata::Shape { - dim: $dim.into(), - var: $input.into(), - out: $out.into(), - }); - }; - // out = stride(tensor, dim) - ($scope:expr, $out:ident = stride($input:expr, $dim:expr)) => { - $scope.register($crate::ir::Metadata::Stride { - dim: $dim.into(), - var: $input.into(), - out: $out.into(), - }); - }; - // out = len(array) - ($scope:expr, $out:ident = len($input:expr)) => { - $scope.register($crate::ir::Metadata::Length { - var: $input.into(), - out: $out.into(), - }); - }; - // range(start, end).for_each(|i, scope| { ... 
}) - ($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => { - $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), $arg); - }; - // range(start, end, unroll).for_each(|i, scope| { ... }) - ($scope:expr, range($start:expr, $end:expr, $unroll:expr).for_each($arg:expr)) => { - if $unroll { - $crate::ir::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), $arg); - } else { - $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), $arg); - } - }; - // loop(|scope| { ... }) - ($scope:expr, loop($arg:expr)) => { - $crate::ir::Loop::register($scope, $arg); - }; - // if (cond).then(|scope| { ... }) - ($scope:expr, if ($cond:expr).then($arg:expr)) => { - $crate::ir::If::register($scope, $cond.into(), $arg); - }; - // if (cond).then(|scope| { ... }).else(|scope| { ... }) - ($scope:expr, if ($cond:expr).then($arg_if:expr).else($arg_else:expr)) => { - $crate::ir::IfElse::register($scope, $cond.into(), $arg_if, $arg_else); - }; - (binary $lhs:expr, $rhs:expr, $out:expr) => { - $crate::ir::BinaryOperator { - lhs: $lhs.into(), - rhs: $rhs.into(), - out: $out.into(), - } - }; - (unary $input:expr, $out:expr) => { - $crate::ir::UnaryOperator { - input: $input.into(), - out: $out.into(), - } - }; -} - -impl From for Variable { - fn from(value: bool) -> Self { - Self::ConstantScalar { - value: if value { 1.0 } else { 0.0 }, - elem: super::Elem::Bool, - } - } -} - -impl From for Variable { - fn from(value: i32) -> Self { - Self::ConstantScalar { - value: value as f64, - elem: super::Elem::Int(super::IntKind::I32), - } - } -} - -impl From for Variable { - fn from(value: i64) -> Self { - Self::ConstantScalar { - value: value as f64, - elem: super::Elem::Int(super::IntKind::I64), - } - } -} - -impl From for Variable { - fn from(value: f32) -> Self { - Self::ConstantScalar { - value: value as f64, - elem: super::Elem::Float(super::FloatKind::F32), - } - } -} - -impl From for Variable { - fn from(value: f64) -> Self { - Self::ConstantScalar { - value, - elem: super::Elem::Float(super::FloatKind::F64), - } - } -} - -impl From for Variable { - fn from(value: u32) -> Self { - Self::ConstantScalar { - value: value as f64, - elem: super::Elem::UInt, - } - } -} - -impl From for Variable { - fn from(value: usize) -> Self { - Self::ConstantScalar { - value: value as f64, - elem: super::Elem::UInt, - } - } -} - -pub(crate) use cpa; diff --git a/crates/burn-cube/src/ir/mod.rs b/crates/burn-cube/src/ir/mod.rs deleted file mode 100644 index 2e684fb0a5..0000000000 --- a/crates/burn-cube/src/ir/mod.rs +++ /dev/null @@ -1,25 +0,0 @@ -mod branch; -mod cmma; -mod macros; -mod operation; -mod procedure; -mod processing; -mod scope; -mod shader; -mod subcube; -mod synchronization; -mod variable; -mod vectorization; - -pub use branch::*; -pub use cmma::*; -pub use operation::*; -pub use procedure::*; -pub use scope::*; -pub use shader::*; -pub use subcube::*; -pub use synchronization::*; -pub use variable::*; -pub use vectorization::*; - -pub(crate) use macros::cpa; diff --git a/crates/burn-cube/src/ir/operation.rs b/crates/burn-cube/src/ir/operation.rs deleted file mode 100644 index f5ed445bf8..0000000000 --- a/crates/burn-cube/src/ir/operation.rs +++ /dev/null @@ -1,184 +0,0 @@ -use super::{Branch, CoopMma, Procedure, Subcube, Synchronization, Variable}; -use serde::{Deserialize, Serialize}; - -/// All operations that can be used in a GPU compute shader. -/// -/// Notes: -/// -/// [Operator] and [Procedure] can be vectorized, but [Metadata] and [Branch] can't. 
-/// Therefore, during tracing, only operators and procedures can be registered. -/// -/// [Procedure] expansions can safely use all operation variants. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(dead_code, missing_docs)] // Some variants might not be used with different flags -pub enum Operation { - Operator(Operator), - Procedure(Procedure), - Metadata(Metadata), - Branch(Branch), - Synchronization(Synchronization), - Subcube(Subcube), - CoopMma(CoopMma), -} - -/// All operators that can be used in a GPU compute shader. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(dead_code, missing_docs)] // Some variants might not be used with different flags -pub enum Operator { - Add(BinaryOperator), - Fma(FmaOperator), - Sub(BinaryOperator), - Mul(BinaryOperator), - Div(BinaryOperator), - Abs(UnaryOperator), - Exp(UnaryOperator), - Log(UnaryOperator), - Log1p(UnaryOperator), - Cos(UnaryOperator), - Sin(UnaryOperator), - Tanh(UnaryOperator), - Powf(BinaryOperator), - Sqrt(UnaryOperator), - Floor(UnaryOperator), - Ceil(UnaryOperator), - Erf(UnaryOperator), - Recip(UnaryOperator), - Equal(BinaryOperator), - NotEqual(BinaryOperator), - Lower(BinaryOperator), - Clamp(ClampOperator), - Greater(BinaryOperator), - LowerEqual(BinaryOperator), - GreaterEqual(BinaryOperator), - Assign(UnaryOperator), - Modulo(BinaryOperator), - Index(BinaryOperator), - Slice(SliceOperator), - UncheckedIndex(BinaryOperator), - IndexAssign(BinaryOperator), - UncheckedIndexAssign(BinaryOperator), - And(BinaryOperator), - Or(BinaryOperator), - Not(UnaryOperator), - Max(BinaryOperator), - Min(BinaryOperator), - BitwiseAnd(BinaryOperator), - BitwiseXor(BinaryOperator), - ShiftLeft(BinaryOperator), - ShiftRight(BinaryOperator), - Remainder(BinaryOperator), -} - -/// All metadata that can be accessed in a shader. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub enum Metadata { - /// The stride of an array at the given dimension. - Stride { - dim: Variable, - var: Variable, - out: Variable, - }, - /// The shape of an array at the given dimension.
- Shape { - dim: Variable, - var: Variable, - out: Variable, - }, - Length { - var: Variable, - out: Variable, - }, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct BinaryOperator { - pub lhs: Variable, - pub rhs: Variable, - pub out: Variable, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct UnaryOperator { - pub input: Variable, - pub out: Variable, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct InitOperator { - pub out: Variable, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct ClampOperator { - pub input: Variable, - pub min_value: Variable, - pub max_value: Variable, - pub out: Variable, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct SliceOperator { - pub input: Variable, - pub start: Variable, - pub end: Variable, - pub out: Variable, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct ReadGlobalOperator { - pub variable: Variable, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct ReadGlobalWithLayoutOperator { - pub variable: Variable, - pub tensor_read_pos: usize, - pub tensor_layout_pos: usize, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct FmaOperator { - pub a: Variable, - pub b: Variable, - pub c: Variable, - pub out: Variable, -} - -impl From<Operator> for Operation { - fn from(val: Operator) -> Self { - Operation::Operator(val) - } -} - -impl From<Branch> for Operation { - fn from(value: Branch) -> Self { - Self::Branch(value) - } -} - -impl From<Synchronization> for Operation { - fn from(value: Synchronization) -> Self { - Self::Synchronization(value) - } -} - -impl From<Metadata> for Operation { - fn from(val: Metadata) -> Self { - Operation::Metadata(val) - } -} - -impl From<Procedure> for Operation { - fn from(val: Procedure) -> Self { - Operation::Procedure(val) - } -} diff --git a/crates/burn-cube/src/ir/procedure/assign.rs b/crates/burn-cube/src/ir/procedure/assign.rs deleted file mode 100644 index e72f23dceb..0000000000 --- a/crates/burn-cube/src/ir/procedure/assign.rs +++ /dev/null @@ -1,71 +0,0 @@ -use crate::ir::{macros::cpa, Scope, Variable, Vectorization}; -use serde::{Deserialize, Serialize}; - -/// Assign value to a variable based on a given condition.
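// [Editor's note; not in the original source] `ConditionalAssign` is the IR-level
// ternary, out = if cond { lhs } else { rhs }. When the output item is vectorized,
// the `expand` implementation below emits one scalar if/else per lane.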
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct ConditionalAssign { - pub cond: Variable, - pub lhs: Variable, - pub rhs: Variable, - pub out: Variable, -} - -impl ConditionalAssign { - #[allow(missing_docs)] - pub fn expand(self, scope: &mut Scope) { - let cond = self.cond; - let lhs = self.lhs; - let rhs = self.rhs; - let out = self.out; - - let index_var = - |scope: &mut Scope, var: Variable, index: usize| match var.item().vectorization == 1 { - true => var, - false => { - let out = scope.create_local(var.item().elem()); - cpa!(scope, out = var[index]); - out - } - }; - - let mut assign_index = |index: usize| { - let cond = index_var(scope, cond, index); - - cpa!(scope, if (cond).then(|scope| { - let lhs = index_var(scope, lhs, index); - let index: Variable = index.into(); - cpa!(scope, out[index] = lhs); - }).else(|scope| { - let rhs = index_var(scope, rhs, index); - let index: Variable = index.into(); - cpa!(scope, out[index] = rhs); - })); - }; - - let vectorization = out.item().vectorization; - match vectorization == 1 { - true => { - cpa!(scope, if (cond).then(|scope| { - cpa!(scope, out = lhs); - }).else(|scope| { - cpa!(scope, out = rhs); - })); - } - false => { - for i in 0..vectorization { - assign_index(i as usize); - } - } - }; - } - - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - Self { - cond: self.cond.vectorize(vectorization), - lhs: self.lhs.vectorize(vectorization), - rhs: self.rhs.vectorize(vectorization), - out: self.out.vectorize(vectorization), - } - } -} diff --git a/crates/burn-cube/src/ir/procedure/base.rs b/crates/burn-cube/src/ir/procedure/base.rs deleted file mode 100644 index 9fec84adfc..0000000000 --- a/crates/burn-cube/src/ir/procedure/base.rs +++ /dev/null @@ -1,42 +0,0 @@ -use super::{ - CheckedIndex, CheckedIndexAssign, ConditionalAssign, IndexOffsetGlobalWithLayout, ReadGlobal, - ReadGlobalWithLayout, WriteGlobal, -}; -use crate::ir::Vectorization; -use serde::{Deserialize, Serialize}; - -/// Tensor operations that can't be executed with a simple [operator](super::super::Operator) should use a -/// procedure. 
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub enum Procedure { - ReadGlobalWithLayout(ReadGlobalWithLayout), - IndexOffsetGlobalWithLayout(IndexOffsetGlobalWithLayout), - ReadGlobal(ReadGlobal), - WriteGlobal(WriteGlobal), - CheckedIndex(CheckedIndex), - CheckedIndexAssign(CheckedIndexAssign), - ConditionalAssign(ConditionalAssign), -} - -impl Procedure { - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - match self { - Procedure::ReadGlobalWithLayout(op) => { - Procedure::ReadGlobalWithLayout(op.vectorize(vectorization)) - } - Procedure::IndexOffsetGlobalWithLayout(op) => { - Procedure::IndexOffsetGlobalWithLayout(op.vectorize(vectorization)) - } - Procedure::ReadGlobal(op) => Procedure::ReadGlobal(op.vectorize(vectorization)), - Procedure::WriteGlobal(op) => Procedure::WriteGlobal(op.vectorize(vectorization)), - Procedure::CheckedIndex(proc) => Procedure::CheckedIndex(proc.vectorize(vectorization)), - Procedure::CheckedIndexAssign(proc) => { - Procedure::CheckedIndexAssign(proc.vectorize(vectorization)) - } - Procedure::ConditionalAssign(proc) => { - Procedure::ConditionalAssign(proc.vectorize(vectorization)) - } - } - } -} diff --git a/crates/burn-cube/src/ir/procedure/index.rs b/crates/burn-cube/src/ir/procedure/index.rs deleted file mode 100644 index 16bff18423..0000000000 --- a/crates/burn-cube/src/ir/procedure/index.rs +++ /dev/null @@ -1,74 +0,0 @@ -use crate::ir::{macros::cpa, Elem, Item, Scope, Variable, Vectorization}; -use serde::{Deserialize, Serialize}; - -/// Perform a bounds check on the index (lhs) of the value (rhs) -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct CheckedIndex { - pub lhs: Variable, - pub rhs: Variable, - pub out: Variable, -} - -impl CheckedIndex { - #[allow(missing_docs)] - pub fn expand(self, scope: &mut Scope) { - let lhs = self.lhs; - let rhs = self.rhs; - let out = self.out; - let array_len = scope.create_local(Item::new(Elem::UInt)); - let inside_bound = scope.create_local(Item::new(Elem::Bool)); - - cpa!(scope, array_len = len(lhs)); - cpa!(scope, inside_bound = rhs < array_len); - - cpa!(scope, if(inside_bound).then(|scope| { - cpa!(scope, out = unchecked(lhs[rhs])); - }).else(|scope| { - cpa!(scope, out = cast(0)); - })); - } - - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - Self { - lhs: self.lhs.vectorize(vectorization), - rhs: self.rhs.vectorize(vectorization), - out: self.out.vectorize(vectorization), - } - } -} - -/// Perform a bounds check on the index (lhs) of the output before assigning the value (rhs) -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct CheckedIndexAssign { - pub lhs: Variable, - pub rhs: Variable, - pub out: Variable, -} - -impl CheckedIndexAssign { - #[allow(missing_docs)] - pub fn expand(self, scope: &mut Scope) { - let lhs = self.lhs; - let rhs = self.rhs; - let out = self.out; - let array_len = scope.create_local(Item::new(Elem::UInt)); - let inside_bound = scope.create_local(Item::new(Elem::Bool)); - - cpa!(scope, array_len = len(out)); - cpa!(scope, inside_bound = lhs < array_len); - - cpa!(scope, if(inside_bound).then(|scope| { - cpa!(scope, unchecked(out[lhs]) = rhs); - })); - } - - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - Self { - lhs: self.lhs.vectorize(vectorization), - rhs: self.rhs.vectorize(vectorization), - out: self.out.vectorize(vectorization), - } - } -} diff --git
a/crates/burn-cube/src/ir/procedure/mod.rs b/crates/burn-cube/src/ir/procedure/mod.rs deleted file mode 100644 index a537fc04dc..0000000000 --- a/crates/burn-cube/src/ir/procedure/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod assign; -mod base; -mod index; -mod read; -mod write; - -pub use assign::*; -pub use base::*; -pub use index::*; -pub use read::*; -pub use write::*; diff --git a/crates/burn-cube/src/ir/procedure/read.rs b/crates/burn-cube/src/ir/procedure/read.rs deleted file mode 100644 index e63065d53b..0000000000 --- a/crates/burn-cube/src/ir/procedure/read.rs +++ /dev/null @@ -1,200 +0,0 @@ -use super::super::{cpa, Elem, Item, Operator, Scope, Variable}; -use crate::ir::{BinaryOperator, Vectorization}; -use serde::{Deserialize, Serialize}; - -/// Read a global array. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct ReadGlobal { - /// The array to be read. - pub global: Variable, - /// The output variable to write the result. - pub out: Variable, - /// The reference position index. - pub position: Variable, -} - -/// Read a global array with the given layout. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub struct ReadGlobalWithLayout { - /// The array to be read. - pub globals: Vec<Variable>, - /// The output variable to write the result. - pub outs: Vec<Variable>, - /// The layout to be used. - pub layout: Variable, - /// The reference position index. - pub position: Variable, -} - -impl ReadGlobal { - #[allow(missing_docs)] - pub fn expand(self, scope: &mut Scope) { - scope.register(Operator::Index(BinaryOperator { - lhs: self.global, - rhs: self.position, - out: self.out, - })); - } - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - Self { - global: self.global.vectorize(vectorization), - out: self.out.vectorize(vectorization), - position: self.position, - } - } -} - -impl ReadGlobalWithLayout { - /// Try to merge two reads together reducing branching. - pub fn try_merge(&self, other: &Self) -> Option<Self> { - // Can only merge two reads when they share the same reference layout.
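// [Editor's illustration with hypothetical variables, not from the original source]
// merging ReadGlobalWithLayout { globals: [t0], outs: [l0], layout, position }
// with    ReadGlobalWithLayout { globals: [t1], outs: [l1], layout, position }
// yields  ReadGlobalWithLayout { globals: [t0, t1], outs: [l0, l1], layout, position },
// so the per-dimension offset computation in `expand` runs once for both tensors.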
- if self.layout != other.layout { - return None; - } - - if self.position != other.position { - return None; - } - - let mut globals = Vec::with_capacity(self.globals.len() + other.globals.len()); - globals.extend(&self.globals); - globals.extend(&other.globals); - - let mut outs = Vec::with_capacity(self.outs.len() + other.outs.len()); - outs.extend(&self.outs); - outs.extend(&other.outs); - - Some(Self { - globals, - outs, - layout: self.layout, - position: self.position, - }) - } - - #[allow(missing_docs)] - pub fn expand(self, scope: &mut Scope) { - let outputs = self.outs; - let tensors = self.globals; - let indexes = tensors - .iter() - .map(|_| scope.create_local(Elem::UInt)) - .collect::<Vec<_>>(); - - IndexOffsetGlobalWithLayout { - tensors: tensors.clone(), - layout: self.layout, - indexes: indexes.clone(), - position: self.position, - dim_start: 0u32.into(), - dim_end: Variable::Rank, - } - .expand(scope); - - for i in 0..outputs.len() { - let tensor = tensors[i]; - let output = outputs[i]; - let index = indexes[i]; - - cpa!(scope, output = tensor[index]); - } - } - - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - Self { - globals: self - .globals - .iter() - .map(|g| g.vectorize(vectorization)) - .collect(), - layout: self.layout.vectorize(vectorization), - outs: self - .outs - .iter() - .map(|o| o.vectorize(vectorization)) - .collect(), - position: self.position, - } - } -} - -/// Calculate the index offset for all tensor variables provided compatible with the given layout. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct IndexOffsetGlobalWithLayout { - /// Tensor [variables](Variable), same length as [indexes](Self::indexes). - pub tensors: Vec<Variable>, - /// Offsets that are going to be written to. - pub indexes: Vec<Variable>, - /// Reference layout. - pub layout: Variable, - /// Position index that corresponds to the reference layout. - /// - /// All other indexes will be made to be compatible with this one.
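// [Editor's note; formula inferred from the `expand` implementation below] for each
// tensor i and each dimension d in [dim_start, dim_end), the emitted code accumulates
//   index_i += (((position * vec) / layout.stride(d)) % shape_i(d)) * stride_i(d)
// and finally divides every index_i by the vectorization factor `vec`.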
- pub position: Variable, - pub dim_start: Variable, - pub dim_end: Variable, -} - -impl IndexOffsetGlobalWithLayout { - #[allow(missing_docs)] - pub fn expand(self, scope: &mut Scope) { - let layout = self.layout; - let index_item_ty = Item::new(Elem::UInt); - let offset_ref = self.position; - let zero: Variable = 0u32.into(); - let vectorization_factor: u8 = self.tensors[0].item().vectorization; - let vectorization_factor: Variable = (vectorization_factor as u32).into(); - for index in self.indexes.iter() { - cpa!(scope, index = zero); - } - - cpa!( - scope, - range(self.dim_start, self.dim_end).for_each(|i, scope| { - let stride_layout = scope.create_local(index_item_ty); - let ogwl = scope.create_local(index_item_ty); - - cpa!(scope, stride_layout = stride(layout, i)); - cpa!(scope, ogwl = offset_ref * vectorization_factor); - cpa!(scope, ogwl = ogwl / stride_layout); - - for (tensor, index) in self.tensors.iter().zip(self.indexes.iter()) { - let stride = scope.create_local(index_item_ty); - let shape = scope.create_local(index_item_ty); - let tmp = scope.create_local(index_item_ty); - - cpa!(scope, stride = stride(tensor, i)); - cpa!(scope, shape = shape(tensor, i)); - - cpa!(scope, tmp = ogwl % shape); - cpa!(scope, tmp = tmp * stride); - cpa!(scope, index = index + tmp); - } - }) - ); - - for index in self.indexes { - cpa!(scope, index = index / vectorization_factor); - } - } - - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - Self { - tensors: self - .tensors - .iter() - .map(|t| t.vectorize(vectorization)) - .collect(), - indexes: self - .indexes - .iter() - .map(|t| t.vectorize(vectorization)) - .collect(), - layout: self.layout.vectorize(vectorization), - position: self.position.vectorize(vectorization), - dim_start: self.dim_start, - dim_end: self.dim_end, - } - } -} diff --git a/crates/burn-cube/src/ir/procedure/write.rs b/crates/burn-cube/src/ir/procedure/write.rs deleted file mode 100644 index bf25d433d8..0000000000 --- a/crates/burn-cube/src/ir/procedure/write.rs +++ /dev/null @@ -1,30 +0,0 @@ -use crate::ir::{macros::cpa, Scope, Variable, Vectorization}; -use serde::{Deserialize, Serialize}; - -/// Write to a global array. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct WriteGlobal { - pub input: Variable, - pub global: Variable, - pub position: Variable, -} - -impl WriteGlobal { - #[allow(missing_docs)] - pub fn expand(self, scope: &mut Scope) { - let output = self.global; - let input = self.input; - let position = self.position; - - cpa!(scope, output[position] = input); - } - - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - Self { - input: self.input.vectorize(vectorization), - global: self.global.vectorize(vectorization), - position: self.position, - } - } -} diff --git a/crates/burn-cube/src/ir/processing.rs b/crates/burn-cube/src/ir/processing.rs deleted file mode 100644 index d704e832b1..0000000000 --- a/crates/burn-cube/src/ir/processing.rs +++ /dev/null @@ -1,104 +0,0 @@ -use super::{Operation, Procedure, Variable}; -use crate::ir::ReadGlobalWithLayout; - -/// Information necessary when compiling a scope. -pub struct ScopeProcessing { - /// The variable declarations. - pub variables: Vec<Variable>, - /// The operations. - pub operations: Vec<Operation>, -} - -impl ScopeProcessing { - /// Optimize the [variables](Variable) and [operations](Operation). - /// - /// ## Notes: - /// - /// This should be called once right after the creation of the type.
- /// If you built this type from the [scope process function](super::Scope::process), you don't have to - /// call it again. - pub fn optimize(self) -> Self { - self.merge_read_global_with_layout() - } - - /// Merge all compatible [read global with layout procedures](ReadGlobalWithLayout). - fn merge_read_global_with_layout(mut self) -> Self { - #[derive(Default)] - struct Optimization { - merged_procs: Vec<MergedProc>, - } - - #[derive(new)] - struct MergedProc { - proc: ReadGlobalWithLayout, - positions: Vec<usize>, - } - - impl Optimization { - fn new(existing_operations: &[Operation]) -> Self { - let mut optim = Self::default(); - - existing_operations - .iter() - .enumerate() - .for_each(|(position, operation)| { - if let Operation::Procedure(Procedure::ReadGlobalWithLayout(proc)) = - operation - { - optim.register_one(proc, position); - } - }); - - optim - } - - fn register_one(&mut self, proc: &ReadGlobalWithLayout, position: usize) { - for merged_proc in self.merged_procs.iter_mut() { - if let Some(merged) = merged_proc.proc.try_merge(proc) { - merged_proc.proc = merged; - merged_proc.positions.push(position); - return; - } - } - - self.merged_procs - .push(MergedProc::new(proc.clone(), vec![position])); - } - - fn apply(self, existing_operations: Vec<Operation>) -> Vec<Operation> { - if self.merged_procs.is_empty() { - return existing_operations; - } - - let mut operations = Vec::with_capacity(existing_operations.len()); - - for (position, operation) in existing_operations.into_iter().enumerate() { - let mut is_merged_op = false; - - for merged_proc in self.merged_procs.iter() { - if merged_proc.positions[0] == position { - operations.push(Operation::Procedure(Procedure::ReadGlobalWithLayout( - merged_proc.proc.clone(), - ))); - is_merged_op = true; - } - - if merged_proc.positions.contains(&position) { - is_merged_op = true; - } - } - - if !is_merged_op { - operations.push(operation); - } - } - - operations - } - } - - let optimization = Optimization::new(&self.operations); - self.operations = optimization.apply(self.operations); - self - } -} diff --git a/crates/burn-cube/src/ir/scope.rs b/crates/burn-cube/src/ir/scope.rs deleted file mode 100644 index 2e6675981e..0000000000 --- a/crates/burn-cube/src/ir/scope.rs +++ /dev/null @@ -1,437 +0,0 @@ -use super::{ - cpa, processing::ScopeProcessing, Elem, IndexOffsetGlobalWithLayout, Item, Matrix, Operation, - Operator, Procedure, ReadGlobal, ReadGlobalWithLayout, UnaryOperator, Variable, Vectorization, - WriteGlobal, -}; -use serde::{Deserialize, Serialize}; - -/// The scope is the main [operation](Operation) and [variable](Variable) container that simplifies -/// the process of reading inputs, creating local variables and adding new operations. -/// -/// Notes: -/// -/// This type isn't responsible for creating [shader bindings](super::Binding) and figuring out which -/// variable can be written to.
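// [Editor's sketch of typical usage; hypothetical code assembled only from the
// APIs deleted in this diff, not part of the original file]
//
//     let mut scope = Scope::root();
//     let x = scope.create_local(Item::new(Elem::Float(FloatKind::F32)));
//     cpa!(scope, x = abs(x));          // registers Operator::Abs via the macro arm above
//     let processing = scope.process(); // drains variables + operations, then optimizes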
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub struct Scope { - pub depth: u8, - pub operations: Vec<Operation>, - locals: Vec<Variable>, - matrices: Vec<Variable>, - slices: Vec<Variable>, - shared_memories: Vec<Variable>, - local_arrays: Vec<Variable>, - reads_global: Vec<(Variable, ReadingStrategy, Variable, Variable)>, - index_offset_with_output_layout_position: Vec<usize>, - writes_global: Vec<(Variable, Variable, Variable)>, - reads_scalar: Vec<(Variable, Variable)>, - pub layout_ref: Option<Variable>, - undeclared: u16, -} - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub enum ReadingStrategy { - /// Each element will be read in a way to be compatible with the output layout. - OutputLayout, - /// Keep the current layout. - Plain, -} - -impl Scope { - /// Create a scope that is at the root of a - /// [kernel definition](crate::ir::KernelDefinition). - /// - /// A local scope can be created with the [child](Self::child) method. - pub fn root() -> Self { - Self { - depth: 0, - operations: Vec::new(), - locals: Vec::new(), - matrices: Vec::new(), - slices: Vec::new(), - local_arrays: Vec::new(), - shared_memories: Vec::new(), - reads_global: Vec::new(), - index_offset_with_output_layout_position: Vec::new(), - writes_global: Vec::new(), - reads_scalar: Vec::new(), - layout_ref: None, - undeclared: 0, - } - } - - /// Create a variable initialized at zero. - pub fn zero<I: Into<Item>>(&mut self, item: I) -> Variable { - let local = self.create_local(item); - let zero: Variable = 0u32.into(); - cpa!(self, local = zero); - local - } - - /// Create a variable initialized at some value. - pub fn create_with_value<E, I>(&mut self, value: E, item: I) -> Variable - where - E: Into<f64>, - I: Into<Item> + Copy, - { - let local = self.create_local(item); - let value = Variable::ConstantScalar { - value: value.into(), - elem: item.into().elem(), - }; - cpa!(self, local = value); - local - } - - /// Create a matrix variable - pub fn create_matrix(&mut self, matrix: Matrix) -> Variable { - let index = self.matrices.len() as u16; - let variable = Variable::Matrix { - id: index, - mat: matrix, - }; - self.matrices.push(variable); - variable - } - - /// Create a slice variable - pub fn create_slice(&mut self, item: Item) -> Variable { - let id = self.slices.len() as u16; - let variable = Variable::Slice { - id, - item, - depth: self.depth, - }; - self.slices.push(variable); - variable - } - - /// Create a local variable of the given [item type](Item). - pub fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable { - let item = item.into(); - let index = self.new_local_index(); - let local = Variable::Local { - id: index, - item, - depth: self.depth, - }; - self.locals.push(local); - local - } - - /// Create a new local variable, but doesn't perform the declaration. - /// - /// Useful for _for loops_ and other algorithms that require the control over initialization. - pub fn create_local_undeclared(&mut self, item: Item) -> Variable { - let index = self.new_local_index(); - let local = Variable::Local { - id: index, - item, - depth: self.depth, - }; - self.undeclared += 1; - local - } - - /// Reads an input array to a local variable. - /// - /// The index refers to the argument position of the array in the compute shader. - pub fn read_array<I: Into<Item>>( - &mut self, - index: u16, - item: I, - position: Variable, - ) -> Variable { - self.read_input_strategy(index, item.into(), ReadingStrategy::OutputLayout, position) - } - - /// Add the procedure into the scope.
- pub fn index_offset_with_output_layout(&mut self, proc: IndexOffsetGlobalWithLayout) { - self.index_offset_with_output_layout_position - .push(self.operations.len()); - self.operations - .push(Procedure::IndexOffsetGlobalWithLayout(proc).into()); - } - - /// Reads an input scalar to a local variable. - /// - /// The index refers to the scalar position for the same [element](Elem) type. - pub fn read_scalar(&mut self, index: u16, elem: Elem) -> Variable { - let local = Variable::LocalScalar { - id: self.new_local_scalar_index(), - elem, - depth: self.depth, - }; - let scalar = Variable::GlobalScalar { id: index, elem }; - - self.reads_scalar.push((local, scalar)); - - local - } - - /// Retrieve the last local variable that was created. - pub fn last_local_index(&self) -> Option<&Variable> { - self.locals.last() - } - - /// Vectorize the scope using the [vectorization](Vectorization) type. - /// - /// Notes: - /// - /// Scopes created _during_ compilation (after the tracing is done) should not be vectorized. - pub fn vectorize(&mut self, vectorization: Vectorization) { - self.operations - .iter_mut() - .for_each(|op| *op = op.vectorize(vectorization)); - self.locals - .iter_mut() - .for_each(|var| *var = var.vectorize(vectorization)); - self.reads_global - .iter_mut() - .for_each(|(input, _, output, _position)| { - *input = input.vectorize(vectorization); - *output = output.vectorize(vectorization); - }); - self.writes_global - .iter_mut() - .for_each(|(input, output, _)| { - *input = input.vectorize(vectorization); - *output = output.vectorize(vectorization); - }); - } - - /// Writes a variable to given output. - /// - /// Notes: - /// - /// This should only be used when doing compilation. - pub fn write_global(&mut self, input: Variable, output: Variable, position: Variable) { - // This assumes that all outputs have the same layout - if self.layout_ref.is_none() { - self.layout_ref = Some(output); - } - self.writes_global.push((input, output, position)); - } - - /// Writes a variable to given output. - /// - /// Notes: - /// - /// This should only be used when doing compilation. - pub fn write_global_custom(&mut self, output: Variable) { - // This assumes that all outputs have the same layout - if self.layout_ref.is_none() { - self.layout_ref = Some(output); - } - } - - /// Update the [reading strategy](ReadingStrategy) for an input array. - /// - /// Notes: - /// - /// This should only be used when doing compilation. - pub(crate) fn update_read(&mut self, index: u16, strategy: ReadingStrategy) { - if let Some((_, strategy_old, _, _position)) = self - .reads_global - .iter_mut() - .find(|(var, _, _, _)| var.index() == Some(index)) - { - *strategy_old = strategy; - } - } - - #[allow(dead_code)] - pub fn read_globals(&self) -> Vec<(u16, ReadingStrategy)> { - self.reads_global - .iter() - .map(|(var, strategy, _, _)| match var { - Variable::GlobalInputArray { id, .. } => (*id, *strategy), - _ => panic!("Can only read global input arrays."), - }) - .collect() - } - - /// Register an [operation](Operation) into the scope. - pub fn register<T: Into<Operation>>(&mut self, operation: T) { - self.operations.push(operation.into()) - } - - /// Create an empty child scope.
- pub fn child(&mut self) -> Self { - Self { - depth: self.depth + 1, - operations: Vec::new(), - locals: Vec::new(), - matrices: Vec::new(), - slices: Vec::new(), - shared_memories: Vec::new(), - local_arrays: Vec::new(), - reads_global: Vec::new(), - index_offset_with_output_layout_position: Vec::new(), - writes_global: Vec::new(), - reads_scalar: Vec::new(), - layout_ref: self.layout_ref, - undeclared: 0, - } - } - - /// Returns the variables and operations to be declared and executed. - /// - /// Notes: - /// - /// New operations and variables can be created within the same scope without having name - /// conflicts. - pub fn process(&mut self) -> ScopeProcessing { - self.undeclared += self.locals.len() as u16; - - let mut variables = Vec::new(); - core::mem::swap(&mut self.locals, &mut variables); - - for var in self.matrices.drain(..) { - variables.push(var); - } - for var in self.slices.drain(..) { - variables.push(var); - } - - for index in self.index_offset_with_output_layout_position.drain(..) { - if let Some(Operation::Procedure(Procedure::IndexOffsetGlobalWithLayout(proc))) = - self.operations.get_mut(index) - { - proc.layout = self.layout_ref.expect( - "Output should be set when processing an index offset with output layout.", - ); - } - } - - let mut operations = Vec::new(); - - for (input, strategy, local, position) in self.reads_global.drain(..) { - match strategy { - ReadingStrategy::OutputLayout => { - let output = self.layout_ref.expect( - "Output should be set when processing an input with output layout.", - ); - operations.push(Operation::Procedure(Procedure::ReadGlobalWithLayout( - ReadGlobalWithLayout { - globals: vec![input], - layout: output, - outs: vec![local], - position, - }, - ))); - } - ReadingStrategy::Plain => { - operations.push(Operation::Procedure(Procedure::ReadGlobal(ReadGlobal { - global: input, - out: local, - position, - }))) - } - } - } - - for (local, scalar) in self.reads_scalar.drain(..) { - operations.push( - Operator::Assign(UnaryOperator { - input: scalar, - out: local, - }) - .into(), - ); - variables.push(local); - } - - for op in self.operations.drain(..) { - operations.push(op); - } - - for (input, global, position) in self.writes_global.drain(..) { - operations.push(Operation::Procedure(Procedure::WriteGlobal(WriteGlobal { - input, - global, - position, - }))) - } - - ScopeProcessing { - variables, - operations, - } - .optimize() - } - - fn new_local_index(&self) -> u16 { - self.locals.len() as u16 + self.undeclared - } - - fn new_local_scalar_index(&self) -> u16 { - self.reads_scalar.len() as u16 - } - - fn new_shared_index(&self) -> u16 { - self.shared_memories.len() as u16 - } - - fn new_local_array_index(&self) -> u16 { - self.local_arrays.len() as u16 - } - - fn read_input_strategy( - &mut self, - index: u16, - item: Item, - strategy: ReadingStrategy, - position: Variable, - ) -> Variable { - let item_global = match item.elem() { - Elem::Bool => Item { - elem: Elem::UInt, - vectorization: item.vectorization, - }, - _ => item, - }; - let input = Variable::GlobalInputArray { - id: index, - item: item_global, - }; - let index = self.new_local_index(); - let local = Variable::Local { - id: index, - item, - depth: self.depth, - }; - self.reads_global.push((input, strategy, local, position)); - self.locals.push(local); - local - } - - /// Create a shared variable of the given [item type](Item). 
- pub fn create_shared<I: Into<Item>>(&mut self, item: I, shared_memory_size: u32) -> Variable { - let item = item.into(); - let index = self.new_shared_index(); - let shared_memory = Variable::SharedMemory { - id: index, - item, - length: shared_memory_size, - }; - self.shared_memories.push(shared_memory); - shared_memory - } - - /// Create a local array of the given [item type](Item). - pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> Variable { - let item = item.into(); - let index = self.new_local_array_index(); - let local_array = Variable::LocalArray { - id: index, - item, - depth: self.depth, - length: array_size, - }; - self.local_arrays.push(local_array); - local_array - } -} diff --git a/crates/burn-cube/src/ir/shader.rs b/crates/burn-cube/src/ir/shader.rs deleted file mode 100644 index e5d9364159..0000000000 --- a/crates/burn-cube/src/ir/shader.rs +++ /dev/null @@ -1,148 +0,0 @@ -use super::{Scope, Vectorization}; -use crate::SUBCUBE_DIM_APPROX; -use serde::{Deserialize, Serialize}; -use std::fmt::Display; - -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] -#[allow(missing_docs)] -pub enum Location { - Storage, - Cube, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)] -#[allow(missing_docs)] -pub enum Visibility { - Read, - ReadWrite, -} - -#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)] -#[allow(missing_docs)] -pub enum FloatKind { - F16, - BF16, - F32, - F64, -} - -#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)] -#[allow(missing_docs)] -pub enum IntKind { - I32, - I64, -} - -#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)] -#[allow(missing_docs)] -pub enum Elem { - Float(FloatKind), - Int(IntKind), - UInt, - Bool, -} - -impl From<Elem> for Item { - fn from(val: Elem) -> Self { - Item::new(val) - } -} - -#[cfg(feature = "tensor")] -impl From<burn_tensor::DType> for Elem { - fn from(dtype: burn_tensor::DType) -> Self { - match dtype { - burn_tensor::DType::F64 => Elem::Float(FloatKind::F64), - burn_tensor::DType::F32 => Elem::Float(FloatKind::F32), - burn_tensor::DType::F16 => Elem::Float(FloatKind::F16), - burn_tensor::DType::BF16 => Elem::Float(FloatKind::BF16), - burn_tensor::DType::I64 => Elem::Int(IntKind::I64), - burn_tensor::DType::I32 => Elem::Int(IntKind::I32), - burn_tensor::DType::I16 => panic!("i16 isn't supported yet."), - burn_tensor::DType::I8 => panic!("i8 isn't supported yet."), - burn_tensor::DType::U64 => Elem::UInt, - burn_tensor::DType::U32 => Elem::UInt, - burn_tensor::DType::U8 => panic!("u8 isn't supported yet."), - burn_tensor::DType::Bool => Elem::Bool, - burn_tensor::DType::QFloat(_) => panic!("quantized type is not supported yet."), - } - } -} - -impl Display for Elem { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - // NOTE: we'll eventually want to differentiate between int/float types - Self::Float(_) => f.write_str("float"), - Self::Int(_) => f.write_str("int"), - Self::UInt => f.write_str("uint"), - Self::Bool => f.write_str("bool"), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize, Hash)] -pub struct Item { - pub elem: Elem, - pub vectorization: Vectorization, -} - -impl Item { - /// Fetch the elem of the item.
- pub fn elem(&self) -> Elem { - self.elem - } - - /// Create a new item without vectorization - pub fn new(elem: Elem) -> Self { - Self { - elem, - vectorization: 1, - } - } - - /// Create a new item with vectorization - pub fn vectorized(elem: Elem, vectorization: Vectorization) -> Self { - Self { - elem, - vectorization, - } - } -} - -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] -#[allow(missing_docs)] -pub struct Binding { - pub location: Location, - pub visibility: Visibility, - pub item: Item, - pub size: Option<u32>, -} - -#[derive(new, Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize, Hash)] -#[allow(missing_docs)] -pub struct CubeDim { - pub x: u32, - pub y: u32, - pub z: u32, -} - -impl Default for CubeDim { - fn default() -> Self { - Self { - x: SUBCUBE_DIM_APPROX as u32, - y: SUBCUBE_DIM_APPROX as u32, - z: 1, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[allow(missing_docs)] -pub struct KernelDefinition { - pub inputs: Vec<Binding>, - pub outputs: Vec<Binding>, - pub named: Vec<(String, Binding)>, - pub cube_dim: CubeDim, - pub body: Scope, -} diff --git a/crates/burn-cube/src/ir/subcube.rs b/crates/burn-cube/src/ir/subcube.rs deleted file mode 100644 index 36dd2800ab..0000000000 --- a/crates/burn-cube/src/ir/subcube.rs +++ /dev/null @@ -1,21 +0,0 @@ -use super::{BinaryOperator, InitOperator, UnaryOperator}; -use serde::{Deserialize, Serialize}; - -/// All subcube operations. -/// -/// Note that not all backends support subcube (warp/subgroup) operations. Use the [runtime flag](crate::Feature::Subcube). -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(dead_code, missing_docs)] // Some variants might not be used with different flags -pub enum Subcube { - Elect(InitOperator), - All(UnaryOperator), - Any(UnaryOperator), - Broadcast(BinaryOperator), - Sum(UnaryOperator), - Prod(UnaryOperator), - And(UnaryOperator), - Or(UnaryOperator), - Xor(UnaryOperator), - Min(UnaryOperator), - Max(UnaryOperator), -} diff --git a/crates/burn-cube/src/ir/synchronization.rs b/crates/burn-cube/src/ir/synchronization.rs deleted file mode 100644 index 1db20c9b99..0000000000 --- a/crates/burn-cube/src/ir/synchronization.rs +++ /dev/null @@ -1,9 +0,0 @@ -use serde::{Deserialize, Serialize}; - -/// All synchronization types. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub enum Synchronization { - // Synchronize units in a cube.
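// [Editor's note] A unit-level sync typically lowers to a cube-wide execution
// barrier in the backend, e.g. `workgroupBarrier()` in WGSL or `__syncthreads()`
// in CUDA; the actual mapping lives in the compilers, not in this file.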
- SyncUnits, -} diff --git a/crates/burn-cube/src/ir/variable.rs b/crates/burn-cube/src/ir/variable.rs deleted file mode 100644 index ea6335b7ad..0000000000 --- a/crates/burn-cube/src/ir/variable.rs +++ /dev/null @@ -1,159 +0,0 @@ -use super::{Elem, Item, Matrix}; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] -#[allow(missing_docs)] -pub enum Variable { - Rank, - GlobalInputArray { - id: u16, - item: Item, - }, - GlobalScalar { - id: u16, - elem: Elem, - }, - GlobalOutputArray { - id: u16, - item: Item, - }, - Local { - id: u16, - item: Item, - depth: u8, - }, - LocalScalar { - id: u16, - elem: Elem, - depth: u8, - }, - ConstantScalar { - value: f64, - elem: Elem, - }, - SharedMemory { - id: u16, - item: Item, - length: u32, - }, - LocalArray { - id: u16, - item: Item, - depth: u8, - length: u32, - }, - Matrix { - id: u16, - mat: Matrix, - }, - Slice { - id: u16, - item: Item, - depth: u8, - }, - UnitPos, - UnitPosX, - UnitPosY, - UnitPosZ, - CubePos, - CubePosX, - CubePosY, - CubePosZ, - CubeDim, - CubeDimX, - CubeDimY, - CubeDimZ, - CubeCount, - CubeCountX, - CubeCountY, - CubeCountZ, - SubcubeDim, - AbsolutePos, - AbsolutePosX, - AbsolutePosY, - AbsolutePosZ, -} - -impl Variable { - pub fn index(&self) -> Option<u16> { - match self { - Variable::GlobalInputArray { id, .. } => Some(*id), - Variable::GlobalScalar { id, .. } => Some(*id), - Variable::Local { id, .. } => Some(*id), - Variable::Slice { id, .. } => Some(*id), - Variable::LocalScalar { id, .. } => Some(*id), - Variable::GlobalOutputArray { id, .. } => Some(*id), - Variable::ConstantScalar { .. } => None, - Variable::SharedMemory { id, .. } => Some(*id), - Variable::LocalArray { id, .. } => Some(*id), - Variable::Matrix { id, .. } => Some(*id), - Variable::AbsolutePos => None, - Variable::UnitPos => None, - Variable::UnitPosX => None, - Variable::UnitPosY => None, - Variable::UnitPosZ => None, - Variable::Rank => None, - Variable::CubePosX => None, - Variable::CubePosY => None, - Variable::CubePosZ => None, - Variable::AbsolutePosX => None, - Variable::AbsolutePosY => None, - Variable::AbsolutePosZ => None, - Variable::CubeDimX => None, - Variable::CubeDimY => None, - Variable::CubeDimZ => None, - Variable::CubeCountX => None, - Variable::CubeCountY => None, - Variable::CubeCountZ => None, - Variable::CubePos => None, - Variable::CubeCount => None, - Variable::CubeDim => None, - Variable::SubcubeDim => None, - } - } - - /// Fetch the item of the variable. - pub fn item(&self) -> Item { - match self { - Variable::GlobalInputArray { item, .. } => *item, - Variable::GlobalOutputArray { item, .. } => *item, - Variable::GlobalScalar { elem, .. } => Item::new(*elem), - Variable::Local { item, .. } => *item, - Variable::LocalScalar { elem, .. } => Item::new(*elem), - Variable::ConstantScalar { elem, .. } => Item::new(*elem), - Variable::SharedMemory { item, .. } => *item, - Variable::LocalArray { item, .. } => *item, - Variable::Slice { item, .. } => *item, - Variable::Matrix { mat, ..
} => Item::new(mat.elem), - Variable::AbsolutePos => Item::new(Elem::UInt), - Variable::Rank => Item::new(Elem::UInt), - Variable::UnitPos => Item::new(Elem::UInt), - Variable::UnitPosX => Item::new(Elem::UInt), - Variable::UnitPosY => Item::new(Elem::UInt), - Variable::UnitPosZ => Item::new(Elem::UInt), - Variable::CubePosX => Item::new(Elem::UInt), - Variable::CubePosY => Item::new(Elem::UInt), - Variable::CubePosZ => Item::new(Elem::UInt), - Variable::AbsolutePosX => Item::new(Elem::UInt), - Variable::AbsolutePosY => Item::new(Elem::UInt), - Variable::AbsolutePosZ => Item::new(Elem::UInt), - Variable::CubeDimX => Item::new(Elem::UInt), - Variable::CubeDimY => Item::new(Elem::UInt), - Variable::CubeDimZ => Item::new(Elem::UInt), - Variable::CubeCountX => Item::new(Elem::UInt), - Variable::CubeCountY => Item::new(Elem::UInt), - Variable::CubeCountZ => Item::new(Elem::UInt), - Variable::CubePos => Item::new(Elem::UInt), - Variable::CubeCount => Item::new(Elem::UInt), - Variable::CubeDim => Item::new(Elem::UInt), - Variable::SubcubeDim => Item::new(Elem::UInt), - } - } -} - -// Useful with the cube_inline macro. -impl From<&Variable> for Variable { - fn from(value: &Variable) -> Self { - *value - } -} diff --git a/crates/burn-cube/src/ir/vectorization.rs b/crates/burn-cube/src/ir/vectorization.rs deleted file mode 100644 index 15c81b3f0d..0000000000 --- a/crates/burn-cube/src/ir/vectorization.rs +++ /dev/null @@ -1,250 +0,0 @@ -use super::{ - BinaryOperator, ClampOperator, FmaOperator, InitOperator, Item, Operation, Operator, - SliceOperator, Subcube, UnaryOperator, Variable, -}; - -pub type Vectorization = u8; - -impl Operation { - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - match self { - Operation::Operator(op) => Operation::Operator(op.vectorize(vectorization)), - Operation::Procedure(op) => Operation::Procedure(op.vectorize(vectorization)), - Operation::Metadata(_) => panic!( - "Metadata can't be vectorized, they should only be generated after vectorization." - ), - Operation::Branch(_) => panic!( - "A branch can't be vectorized, they should only be generated after vectorization." - ), - Operation::Synchronization(_) => panic!( - "Synchronization instructions can't be vectorized, they should only be generated after vectorization." - ), - Operation::Subcube(op) => Operation::Subcube(op.vectorize(vectorization)), - Operation::CoopMma(_) => panic!( - "Cooperative matrix-multiply and accumulate doesn't support vectorization." 
- ), - } - } -} - -impl Operator { - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - match self { - Operator::Max(op) => Operator::Max(op.vectorize(vectorization)), - Operator::Min(op) => Operator::Min(op.vectorize(vectorization)), - Operator::Add(op) => Operator::Add(op.vectorize(vectorization)), - Operator::Fma(op) => Operator::Fma(op.vectorize(vectorization)), - Operator::Index(op) => Operator::Index(op.vectorize(vectorization)), - Operator::UncheckedIndex(op) => Operator::UncheckedIndex(op.vectorize(vectorization)), - Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)), - Operator::Mul(op) => Operator::Mul(op.vectorize(vectorization)), - Operator::Div(op) => Operator::Div(op.vectorize(vectorization)), - Operator::Floor(op) => Operator::Floor(op.vectorize(vectorization)), - Operator::Ceil(op) => Operator::Ceil(op.vectorize(vectorization)), - Operator::Abs(op) => Operator::Abs(op.vectorize(vectorization)), - Operator::Exp(op) => Operator::Exp(op.vectorize(vectorization)), - Operator::Log(op) => Operator::Log(op.vectorize(vectorization)), - Operator::Log1p(op) => Operator::Log1p(op.vectorize(vectorization)), - Operator::Cos(op) => Operator::Cos(op.vectorize(vectorization)), - Operator::Sin(op) => Operator::Sin(op.vectorize(vectorization)), - Operator::Tanh(op) => Operator::Tanh(op.vectorize(vectorization)), - Operator::Powf(op) => Operator::Powf(op.vectorize(vectorization)), - Operator::Sqrt(op) => Operator::Sqrt(op.vectorize(vectorization)), - Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)), - Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)), - Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)), - Operator::NotEqual(op) => Operator::NotEqual(op.vectorize(vectorization)), - Operator::Lower(op) => Operator::Lower(op.vectorize(vectorization)), - Operator::Clamp(op) => Operator::Clamp(op.vectorize(vectorization)), - Operator::Greater(op) => Operator::Greater(op.vectorize(vectorization)), - Operator::LowerEqual(op) => Operator::LowerEqual(op.vectorize(vectorization)), - Operator::GreaterEqual(op) => Operator::GreaterEqual(op.vectorize(vectorization)), - Operator::Assign(op) => { - if let Variable::GlobalScalar { .. } = op.input { - // Assign will not change the type of the output if the input can't be - // vectorized. 
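// [Editor's note] Concretely: when the input is a `GlobalScalar`, the whole
// operator is returned unchanged rather than vectorized, since a scalar source
// cannot be widened lane-wise; backends broadcast it instead.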
- return Operator::Assign(op.clone()); - } - - Operator::Assign(op.vectorize(vectorization)) - } - Operator::Modulo(op) => Operator::Modulo(op.vectorize(vectorization)), - Operator::IndexAssign(op) => Operator::IndexAssign(op.vectorize(vectorization)), - Operator::UncheckedIndexAssign(op) => { - Operator::UncheckedIndexAssign(op.vectorize(vectorization)) - } - Operator::And(op) => Operator::And(op.vectorize(vectorization)), - Operator::Or(op) => Operator::Or(op.vectorize(vectorization)), - Operator::Not(op) => Operator::Not(op.vectorize(vectorization)), - Operator::BitwiseAnd(op) => Operator::BitwiseAnd(op.vectorize(vectorization)), - Operator::BitwiseXor(op) => Operator::BitwiseXor(op.vectorize(vectorization)), - Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)), - Operator::ShiftRight(op) => Operator::ShiftRight(op.vectorize(vectorization)), - Operator::Remainder(op) => Operator::Remainder(op.vectorize(vectorization)), - Operator::Slice(op) => Operator::Slice(op.vectorize(vectorization)), - } - } -} - -impl BinaryOperator { - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - let lhs = self.lhs.vectorize(vectorization); - let rhs = self.rhs.vectorize(vectorization); - let out = self.out.vectorize(vectorization); - - Self { lhs, rhs, out } - } -} - -impl UnaryOperator { - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - let input = self.input.vectorize(vectorization); - let out = self.out.vectorize(vectorization); - - Self { input, out } - } -} - -impl SliceOperator { - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - let input = self.input.vectorize(vectorization); - let start = self.start.vectorize(vectorization); - let end = self.end.vectorize(vectorization); - let out = self.out.vectorize(vectorization); - - Self { - input, - start, - end, - out, - } - } -} - -impl InitOperator { - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - let out = self.out.vectorize(vectorization); - - Self { out } - } -} - -impl Subcube { - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - match self { - Subcube::Elect(op) => Subcube::Elect(op.vectorize(vectorization)), - Subcube::All(op) => Subcube::All(op.vectorize(vectorization)), - Subcube::Any(op) => Subcube::Any(op.vectorize(vectorization)), - Subcube::Broadcast(op) => Subcube::Broadcast(op.vectorize(vectorization)), - Subcube::Sum(op) => Subcube::Sum(op.vectorize(vectorization)), - Subcube::Prod(op) => Subcube::Prod(op.vectorize(vectorization)), - Subcube::And(op) => Subcube::And(op.vectorize(vectorization)), - Subcube::Or(op) => Subcube::Or(op.vectorize(vectorization)), - Subcube::Xor(op) => Subcube::Xor(op.vectorize(vectorization)), - Subcube::Min(op) => Subcube::Min(op.vectorize(vectorization)), - Subcube::Max(op) => Subcube::Max(op.vectorize(vectorization)), - } - } -} - -impl ClampOperator { - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - Self { - input: self.input.vectorize(vectorization), - out: self.out.vectorize(vectorization), - min_value: self.min_value.vectorize(vectorization), - max_value: self.max_value.vectorize(vectorization), - } - } -} - -impl FmaOperator { - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { - Self { - a: self.a.vectorize(vectorization), - b: self.b.vectorize(vectorization), - c: self.c.vectorize(vectorization), - out: self.out.vectorize(vectorization), - } - } -} - -impl Variable { - pub(crate) fn vectorize(&self, vectorize: 
Vectorization) -> Self { - match self { - Variable::GlobalInputArray { id, item } => Variable::GlobalInputArray { - id: *id, - item: item.vectorize(vectorize), - }, - Variable::Local { id, item, depth } => Variable::Local { - id: *id, - item: item.vectorize(vectorize), - depth: *depth, - }, - Variable::Slice { id, item, depth } => Variable::Slice { - id: *id, - item: item.vectorize(vectorize), - depth: *depth, - }, - Variable::GlobalOutputArray { id, item } => Variable::GlobalOutputArray { - id: *id, - item: item.vectorize(vectorize), - }, - Variable::SharedMemory { id, item, length } => Variable::SharedMemory { - id: *id, - item: item.vectorize(vectorize), - length: item.vectorized_size(vectorize, *length), - }, - Variable::LocalArray { - id, - item, - depth, - length, - } => Variable::LocalArray { - id: *id, - item: item.vectorize(vectorize), - depth: *depth, - length: item.vectorized_size(vectorize, *length), - }, - Variable::ConstantScalar { .. } => *self, - Variable::GlobalScalar { .. } => *self, - Variable::AbsolutePos => *self, - Variable::Rank => *self, - Variable::LocalScalar { .. } => *self, - Variable::Matrix { .. } => *self, - Variable::UnitPos => *self, - Variable::UnitPosX => *self, - Variable::UnitPosY => *self, - Variable::UnitPosZ => *self, - Variable::CubePosX => *self, - Variable::CubePosY => *self, - Variable::CubePosZ => *self, - Variable::AbsolutePosX => *self, - Variable::AbsolutePosY => *self, - Variable::AbsolutePosZ => *self, - Variable::CubeDimX => *self, - Variable::CubeDimY => *self, - Variable::CubeDimZ => *self, - Variable::CubeCountX => *self, - Variable::CubeCountY => *self, - Variable::CubeCountZ => *self, - Variable::CubePos => *self, - Variable::CubeCount => *self, - Variable::CubeDim => *self, - Variable::SubcubeDim => *self, - } - } -} - -impl Item { - pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Item { - Item { - elem: self.elem, - vectorization, - } - } - - pub(crate) fn vectorized_size(&self, vectorize: Vectorization, size: u32) -> u32 { - size / (vectorize as u32) - } -} diff --git a/crates/burn-cube/src/lib.rs b/crates/burn-cube/src/lib.rs deleted file mode 100644 index bf2372b72a..0000000000 --- a/crates/burn-cube/src/lib.rs +++ /dev/null @@ -1,94 +0,0 @@ -extern crate alloc; - -#[macro_use] -extern crate derive_new; - -/// Cube Frontend Types. -pub mod frontend; - -use burn_compute::server::ComputeServer; -pub use frontend::cmma; - -/// Cube Language Internal Representation. -pub mod ir; - -pub mod codegen; -pub mod compute; -pub mod prelude; - -mod pod; -mod runtime; - -pub use codegen::*; -pub use pod::*; -pub use runtime::*; - -pub use burn_cube_macros::cube; -pub use burn_cube_macros::CubeLaunch; -pub use burn_cube_macros::CubeType; - -/// An approximation of the subcube dimension. -pub const SUBCUBE_DIM_APPROX: usize = 16; - -use crate::ir::KernelDefinition; -use frontend::LaunchArg; -use prelude::CubeCount; - -/// Implement this trait to create a [kernel definition](KernelDefinition). -pub trait Kernel: Send + Sync + 'static { - /// Convert to a kernel definition. - fn define(&self) -> KernelDefinition; - /// Identifier for the kernel, used for caching kernel compilation. - fn id(&self) -> String { - format!("{:?}", core::any::TypeId::of::<Self>()) - } -} - -/// Calculate the number of cubes required to execute an operation where one cube unit is -/// assigned to one element.
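// [Editor's worked example, illustrative numbers only] with num_elems = 1_000_000
// and cube_dim = 16: num_elems_per_cube = 256, cube_counts = ceil(1_000_000 / 256)
// = 3907, cube_count_x = ceil(sqrt(3907)) = 63, and cube_count_y =
// ceil(1_000_000 / (63.0 * 256.0)) = 63, so the resulting 63 x 63 grid provides
// 1_016_064 unit slots, enough to cover every element.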
-pub fn calculate_cube_count_elemwise<S: ComputeServer>( - num_elems: usize, - cube_dim: usize, -) -> CubeCount<S> { - let num_elems_per_cube = cube_dim * cube_dim; - let cube_counts = f32::ceil(num_elems as f32 / num_elems_per_cube as f32); - let cube_count_x = f32::ceil(f32::sqrt(cube_counts)); - let cube_count_y = f32::ceil(num_elems as f32 / (cube_count_x * num_elems_per_cube as f32)); - - CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1) -} - -pub fn tensor_vectorization_factor( - factors: &[u8], - shape: &[usize], - strides: &[usize], - dim: usize, -) -> u8 { - if let Some(val) = strides.get(dim) { - if *val != 1 { - return 1; - } - } else { - return 1; - } - - let dim_size = match shape.get(dim) { - Some(val) => val, - None => return 1, - }; - - for factor in factors { - if dim_size % *factor as usize == 0 { - return *factor; - } - } - - 1 -} - -/// Runtime arguments to launch a kernel. -pub type RuntimeArg<'a, T, R> = <T as LaunchArg>::RuntimeArg<'a, R>; - -#[cfg(feature = "export_tests")] -/// Tests only useful for runtimes. -pub mod runtime_tests; diff --git a/crates/burn-cube/src/pod.rs b/crates/burn-cube/src/pod.rs deleted file mode 100644 index b3109c02b6..0000000000 --- a/crates/burn-cube/src/pod.rs +++ /dev/null @@ -1,124 +0,0 @@ -use crate::ir::{Elem, FloatKind, IntKind}; - -/// The base element trait for the jit backend. -pub trait CubeElement: core::fmt::Debug + Send + Sync + 'static + Clone + bytemuck::Pod { - /// Returns the name of the type. - fn type_name() -> &'static str; - /// Convert a slice of elements to a slice of bytes. - fn as_bytes(slice: &[Self]) -> &[u8]; - /// Convert a slice of bytes to a slice of elements. - fn from_bytes(bytes: &[u8]) -> &[Self]; - /// Element representation for `cubecl`. - fn cube_elem() -> Elem; - /// Highest possible value - fn maximum_value() -> Self; - /// Lowest possible value - fn minimum_value() -> Self; -} - -impl CubeElement for u32 { - fn type_name() -> &'static str { - "u32" - } - fn as_bytes(slice: &[Self]) -> &[u8] { - bytemuck::cast_slice(slice) - } - fn from_bytes(bytes: &[u8]) -> &[Self] { - bytemuck::cast_slice(bytes) - } - fn cube_elem() -> Elem { - Elem::UInt - } - fn maximum_value() -> Self { - u32::MAX - } - fn minimum_value() -> Self { - u32::MIN - } -} - -impl CubeElement for i32 { - fn type_name() -> &'static str { - "i32" - } - fn as_bytes(slice: &[Self]) -> &[u8] { - bytemuck::cast_slice(slice) - } - fn from_bytes(bytes: &[u8]) -> &[Self] { - bytemuck::cast_slice(bytes) - } - fn cube_elem() -> Elem { - Elem::Int(IntKind::I32) - } - fn maximum_value() -> Self { - // Seems to cause problem for some GPU - i32::MAX - 1 - } - fn minimum_value() -> Self { - // Seems to cause problem for some GPU - i32::MIN + 1 - } -} - -impl CubeElement for f32 { - fn type_name() -> &'static str { - "f32" - } - fn as_bytes(slice: &[Self]) -> &[u8] { - bytemuck::cast_slice(slice) - } - fn from_bytes(bytes: &[u8]) -> &[Self] { - bytemuck::cast_slice(bytes) - } - fn cube_elem() -> Elem { - Elem::Float(FloatKind::F32) - } - fn maximum_value() -> Self { - f32::MAX - } - fn minimum_value() -> Self { - f32::MIN - } -} - -impl CubeElement for half::f16 { - fn type_name() -> &'static str { - "f16" - } - fn as_bytes(slice: &[Self]) -> &[u8] { - bytemuck::cast_slice(slice) - } - fn from_bytes(bytes: &[u8]) -> &[Self] { - bytemuck::cast_slice(bytes) - } - fn cube_elem() -> Elem { - Elem::Float(FloatKind::F16) - } - fn maximum_value() -> Self { - half::f16::MAX - } - fn minimum_value() -> Self { - half::f16::MIN - } -} - -impl CubeElement for half::bf16 { - fn
type_name() -> &'static str { - "bf16" - } - fn as_bytes(slice: &[Self]) -> &[u8] { - bytemuck::cast_slice(slice) - } - fn from_bytes(bytes: &[u8]) -> &[Self] { - bytemuck::cast_slice(bytes) - } - fn cube_elem() -> Elem { - Elem::Float(FloatKind::BF16) - } - fn maximum_value() -> Self { - half::bf16::MAX - } - fn minimum_value() -> Self { - half::bf16::MIN - } -} diff --git a/crates/burn-cube/src/prelude.rs b/crates/burn-cube/src/prelude.rs deleted file mode 100644 index 027a7f866d..0000000000 --- a/crates/burn-cube/src/prelude.rs +++ /dev/null @@ -1,30 +0,0 @@ -pub use crate::{cube, CubeLaunch, CubeType, Kernel, RuntimeArg}; - -pub use crate::codegen::{KernelExpansion, KernelIntegrator, KernelSettings}; -pub use crate::compute::{ - CompiledKernel, CubeCount, CubeTask, KernelBuilder, KernelLauncher, KernelTask, -}; -pub use crate::frontend::cmma; -pub use crate::frontend::{branch::*, synchronization::*}; -pub use crate::ir::{CubeDim, KernelDefinition}; -pub use crate::runtime::Runtime; - -/// Elements -pub use crate::frontend::{ - Array, ArrayHandle, Bool, Float, LaunchArg, Slice, SliceMut, Tensor, TensorArg, UInt, F16, F32, - F64, I32, I64, -}; -pub use crate::pod::CubeElement; - -/// Topology -pub use crate::frontend::{ - ABSOLUTE_POS, ABSOLUTE_POS_X, ABSOLUTE_POS_Y, ABSOLUTE_POS_Z, CUBE_COUNT, CUBE_COUNT_X, - CUBE_COUNT_Y, CUBE_COUNT_Z, CUBE_DIM, CUBE_DIM_X, CUBE_DIM_Y, CUBE_DIM_Z, CUBE_POS, CUBE_POS_X, - CUBE_POS_Y, CUBE_POS_Z, UNIT_POS, UNIT_POS_X, UNIT_POS_Y, UNIT_POS_Z, -}; - -/// Export subcube operations. -pub use crate::frontend::{subcube_all, subcube_max, subcube_min, subcube_prod, subcube_sum}; -pub use burn_compute::client::ComputeClient; - -pub use crate::frontend::*; diff --git a/crates/burn-cube/src/runtime.rs b/crates/burn-cube/src/runtime.rs deleted file mode 100644 index 2975a3f57b..0000000000 --- a/crates/burn-cube/src/runtime.rs +++ /dev/null @@ -1,78 +0,0 @@ -use crate::{ - codegen::Compiler, - compute::{CubeCount, CubeTask}, - ir::Elem, -}; -use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer}; - -/// Runtime for the CubeCL. -pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { - /// The compiler used to compile the inner representation into tokens. - type Compiler: Compiler; - /// The compute server used to run kernels and perform autotuning. - type Server: ComputeServer< - Kernel = Box<dyn CubeTask>, - DispatchOptions = CubeCount<Self::Server>, - FeatureSet = FeatureSet, - >; - /// The channel used to communicate with the compute server. - type Channel: ComputeChannel<Self::Server>; - /// The device used to retrieve the compute client. - type Device; - - /// Retrieve the compute client from the runtime device. - fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel>; - - /// The runtime name. - fn name() -> &'static str; - - /// Return true if global input array lengths should be added to kernel info. - fn require_array_lengths() -> bool { - false - } -} - -/// The set of [features](Feature) supported by a [runtime](Runtime). -#[derive(Default)] -pub struct FeatureSet { - set: alloc::collections::BTreeSet<Feature>, -} - -impl FeatureSet { - pub fn new(features: &[Feature]) -> Self { - let mut this = Self::default(); - - for feature in features { - this.register(*feature); - } - - this - } - /// Check if the provided [feature](Feature) is supported by the runtime. - pub fn enabled(&self, feature: Feature) -> bool { - self.set.contains(&feature) - } - - /// Register a [feature](Feature) supported by the compute server.
- /// - /// This should only be used by a [runtime](Runtime) when initializing a device. - pub fn register(&mut self, feature: Feature) -> bool { - self.set.insert(feature) - } -} - -/// Every feature that can be supported by a [cube runtime](Runtime). -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum Feature { - /// The subcube feature enables all basic warp/subgroup operations. - Subcube, - /// The cmma feature enables cooperative matrix-multiply and accumulate operations. - Cmma { - a: Elem, - b: Elem, - c: Elem, - m: u8, - k: u8, - n: u8, - }, -} diff --git a/crates/burn-cube/src/runtime_tests/cmma.rs b/crates/burn-cube/src/runtime_tests/cmma.rs deleted file mode 100644 index 6c0668cbb2..0000000000 --- a/crates/burn-cube/src/runtime_tests/cmma.rs +++ /dev/null @@ -1,118 +0,0 @@ -use crate::{self as burn_cube, Feature}; -use burn_cube::{ - ir::{Elem, FloatKind}, - prelude::*, -}; -use burn_tensor::ElementConversion; -use half::f16; - -#[cube(launch)] -/// Executes Out = Lhs @ Rhs.T -pub fn kernel_simple_1(lhs: &Array, rhs: &Array, out: &mut Array) { - let a = cmma::Matrix::::new( - cmma::MatrixIdent::A, - 16, - 16, - 16, - cmma::MatrixLayout::RowMajor, - ); - let b = cmma::Matrix::::new( - cmma::MatrixIdent::B, - 16, - 16, - 16, - cmma::MatrixLayout::ColMajor, - ); - let c = cmma::Matrix::::new( - cmma::MatrixIdent::Accumulator, - 16, - 16, - 16, - cmma::MatrixLayout::Undefined, - ); - cmma::fill::(&c, F32::new(0.0)); - cmma::load::(&a, lhs.as_slice(), UInt::new(16)); - cmma::load::(&b, rhs.as_slice(), UInt::new(16)); - - cmma::execute::(&a, &b, &c, &c); - - cmma::store::( - out.as_slice_mut(), - &c, - UInt::new(16), - cmma::MatrixLayout::RowMajor, - ); -} - -pub fn test_simple_1(client: ComputeClient) { - if !client.features().enabled(Feature::Cmma { - a: Elem::Float(FloatKind::F16), - b: Elem::Float(FloatKind::F16), - c: Elem::Float(FloatKind::F32), - m: 16, - k: 16, - n: 16, - }) { - // We can't execute the test, skip. 
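// Editor's note: a sketch (not part of the diff) of how the deleted
// `FeatureSet` above is populated and queried. A runtime registers what the
// device supports at initialization, and kernels such as `kernel_simple_1`'s
// test probe it before launching, exactly as the surrounding check does for
// `Feature::Cmma`. The f16/f16/f32 16x16x16 shape mirrors that test.
fn device_features() -> FeatureSet {
    let mut features = FeatureSet::new(&[Feature::Subcube]);
    // Cmma support is dtype- and tile-shape-specific, so every supported
    // configuration is registered individually.
    features.register(Feature::Cmma {
        a: Elem::Float(FloatKind::F16),
        b: Elem::Float(FloatKind::F16),
        c: Elem::Float(FloatKind::F32),
        m: 16,
        k: 16,
        n: 16,
    });
    assert!(features.enabled(Feature::Subcube));
    features
}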
- return; - } - - let lhs: Vec = (0..256).map(|i| i.elem()).collect(); - let rhs: Vec = (0..256).map(|i| (i % 8).elem()).collect(); - - let lhs = client.create(f16::as_bytes(&lhs)); - let rhs = client.create(f16::as_bytes(&rhs)); - let out = client.empty(core::mem::size_of::() * 256); - - kernel_simple_1::launch::( - client.clone(), - CubeCount::Static(1, 1, 1), - CubeDim::new(16, 16, 1), - ArrayArg::new(&lhs, 256), - ArrayArg::new(&rhs, 256), - ArrayArg::new(&out, 256), - ); - - let actual = client.read(out.binding()); - let actual = f32::from_bytes(&actual); - - let expected = [ - 504., 504., 504., 504., 504., 504., 504., 504., 504., 504., 504., 504., 504., 504., 504., - 504., 1400., 1400., 1400., 1400., 1400., 1400., 1400., 1400., 1400., 1400., 1400., 1400., - 1400., 1400., 1400., 1400., 2296., 2296., 2296., 2296., 2296., 2296., 2296., 2296., 2296., - 2296., 2296., 2296., 2296., 2296., 2296., 2296., 3192., 3192., 3192., 3192., 3192., 3192., - 3192., 3192., 3192., 3192., 3192., 3192., 3192., 3192., 3192., 3192., 4088., 4088., 4088., - 4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088., 4088., - 4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984., 4984., - 4984., 4984., 4984., 5880., 5880., 5880., 5880., 5880., 5880., 5880., 5880., 5880., 5880., - 5880., 5880., 5880., 5880., 5880., 5880., 6776., 6776., 6776., 6776., 6776., 6776., 6776., - 6776., 6776., 6776., 6776., 6776., 6776., 6776., 6776., 6776., 7672., 7672., 7672., 7672., - 7672., 7672., 7672., 7672., 7672., 7672., 7672., 7672., 7672., 7672., 7672., 7672., 8568., - 8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568., 8568., - 8568., 8568., 9464., 9464., 9464., 9464., 9464., 9464., 9464., 9464., 9464., 9464., 9464., - 9464., 9464., 9464., 9464., 9464., 10360., 10360., 10360., 10360., 10360., 10360., 10360., - 10360., 10360., 10360., 10360., 10360., 10360., 10360., 10360., 10360., 11256., 11256., - 11256., 11256., 11256., 11256., 11256., 11256., 11256., 11256., 11256., 11256., 11256., - 11256., 11256., 11256., 12152., 12152., 12152., 12152., 12152., 12152., 12152., 12152., - 12152., 12152., 12152., 12152., 12152., 12152., 12152., 12152., 13048., 13048., 13048., - 13048., 13048., 13048., 13048., 13048., 13048., 13048., 13048., 13048., 13048., 13048., - 13048., 13048., 13944., 13944., 13944., 13944., 13944., 13944., 13944., 13944., 13944., - 13944., 13944., 13944., 13944., 13944., 13944., 13944., - ]; - - assert_eq!(expected, actual); -} - -#[allow(missing_docs)] -#[macro_export] -macro_rules! 
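// Editor's note (not part of the diff): the `expected` values above can be
// checked by hand. With lhs[i][k] = 16*i + k and rhs[j][k] = k % 8, entry
// (i, j) of Lhs @ Rhs.T is sum over k of (16*i + k) * (k % 8), which is
// independent of j; that is why each of the 16 rows repeats one value 16 times.
fn expected_entry(i: usize) -> f32 {
    (0..16).map(|k| ((16 * i + k) * (k % 8)) as f32).sum()
}
// expected_entry(0) == 504.0 and expected_entry(15) == 13944.0, matching the
// first and last rows of `expected`.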
testgen_cmma { - () => { - use super::*; - - #[test] - fn test_cmma_simple_1() { - let client = TestRuntime::client(&Default::default()); - burn_cube::runtime_tests::cmma::test_simple_1::(client); - } - }; -} diff --git a/crates/burn-cube/src/runtime_tests/launch.rs b/crates/burn-cube/src/runtime_tests/launch.rs deleted file mode 100644 index fe0646d979..0000000000 --- a/crates/burn-cube/src/runtime_tests/launch.rs +++ /dev/null @@ -1,68 +0,0 @@ -use crate as burn_cube; -use burn_cube::prelude::*; - -#[cube(launch)] -pub fn kernel_with_generics(output: &mut Array) { - if UNIT_POS == UInt::new(0) { - output[0] = F::new(5.0); - } -} - -#[cube(launch)] -pub fn kernel_without_generics(output: &mut Array) { - if UNIT_POS == UInt::new(0) { - output[0] = F32::new(5.0); - } -} - -pub fn test_kernel_with_generics(client: ComputeClient) { - let handle = client.create(f32::as_bytes(&[0.0, 1.0])); - - kernel_with_generics::launch::( - client.clone(), - CubeCount::Static(1, 1, 1), - CubeDim::default(), - ArrayArg::new(&handle, 2), - ); - - let actual = client.read(handle.binding()); - let actual = f32::from_bytes(&actual); - - assert_eq!(actual[0], 5.0); -} - -pub fn test_kernel_without_generics(client: ComputeClient) { - let handle = client.create(f32::as_bytes(&[0.0, 1.0])); - - kernel_without_generics::launch::( - client.clone(), - CubeCount::Static(1, 1, 1), - CubeDim::default(), - ArrayArg::new(&handle, 2), - ); - - let actual = client.read(handle.binding()); - let actual = f32::from_bytes(&actual); - - assert_eq!(actual[0], 5.0); -} - -#[allow(missing_docs)] -#[macro_export] -macro_rules! testgen_launch { - () => { - use super::*; - - #[test] - fn test_launch_with_generics() { - let client = TestRuntime::client(&Default::default()); - burn_cube::runtime_tests::launch::test_kernel_with_generics::(client); - } - - #[test] - fn test_launch_without_generics() { - let client = TestRuntime::client(&Default::default()); - burn_cube::runtime_tests::launch::test_kernel_without_generics::(client); - } - }; -} diff --git a/crates/burn-cube/src/runtime_tests/mod.rs b/crates/burn-cube/src/runtime_tests/mod.rs deleted file mode 100644 index 1eda960288..0000000000 --- a/crates/burn-cube/src/runtime_tests/mod.rs +++ /dev/null @@ -1,17 +0,0 @@ -pub mod cmma; -pub mod launch; -pub mod slice; -pub mod subcube; - -#[allow(missing_docs)] -#[macro_export] -macro_rules! testgen_all { - () => { - use burn_cube::prelude::*; - - burn_cube::testgen_subcube!(); - burn_cube::testgen_launch!(); - burn_cube::testgen_cmma!(); - burn_cube::testgen_slice!(); - }; -} diff --git a/crates/burn-cube/src/runtime_tests/slice.rs b/crates/burn-cube/src/runtime_tests/slice.rs deleted file mode 100644 index c46c07fe32..0000000000 --- a/crates/burn-cube/src/runtime_tests/slice.rs +++ /dev/null @@ -1,107 +0,0 @@ -use crate as burn_cube; -use burn_cube::prelude::*; - -#[cube(launch)] -pub fn slice_select(input: &Array, output: &mut Array) { - if UNIT_POS == UInt::new(0) { - let slice = input.slice(2, 3); - output[0] = slice[0u32]; - } -} - -#[cube(launch)] -pub fn slice_assign(input: &Array, output: &mut Array) { - if UNIT_POS == UInt::new(0) { - let slice_1 = output.slice_mut(2, 3); - slice_1[0] = input[0u32]; - } -} - -#[cube(launch)] -pub fn slice_len(input: &Array, output: &mut Array) { - if UNIT_POS == UInt::new(0) { - let slice = input.slice(2, 4); - let _tmp = slice[0]; // It must be used at least once, otherwise wgpu isn't happy. 
- output[0] = slice.len(); - } -} - -pub fn test_slice_select(client: ComputeClient) { - let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0])); - let output = client.empty(core::mem::size_of::()); - - slice_select::launch::( - client.clone(), - CubeCount::Static(1, 1, 1), - CubeDim::new(1, 1, 1), - ArrayArg::new(&input, 5), - ArrayArg::new(&output, 1), - ); - - let actual = client.read(output.binding()); - let actual = f32::from_bytes(&actual); - - assert_eq!(actual[0], 2.0); -} - -pub fn test_slice_len(client: ComputeClient) { - let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0])); - let output = client.empty(core::mem::size_of::()); - - slice_len::launch::( - client.clone(), - CubeCount::Static(1, 1, 1), - CubeDim::new(1, 1, 1), - ArrayArg::new(&input, 5), - ArrayArg::new(&output, 1), - ); - - let actual = client.read(output.binding()); - let actual = u32::from_bytes(&actual); - - assert_eq!(actual, &[2]); -} - -pub fn test_slice_assign(client: ComputeClient) { - let input = client.create(f32::as_bytes(&[15.0])); - let output = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0])); - - slice_assign::launch::( - client.clone(), - CubeCount::Static(1, 1, 1), - CubeDim::new(1, 1, 1), - ArrayArg::new(&input, 5), - ArrayArg::new(&output, 1), - ); - - let actual = client.read(output.binding()); - let actual = f32::from_bytes(&actual); - - assert_eq!(actual, &[0.0, 1.0, 15.0, 3.0, 4.0]); -} - -#[allow(missing_docs)] -#[macro_export] -macro_rules! testgen_slice { - () => { - use super::*; - - #[test] - fn test_slice_select() { - let client = TestRuntime::client(&Default::default()); - burn_cube::runtime_tests::slice::test_slice_select::(client); - } - - #[test] - fn test_slice_assign() { - let client = TestRuntime::client(&Default::default()); - burn_cube::runtime_tests::slice::test_slice_assign::(client); - } - - #[test] - fn test_slice_len() { - let client = TestRuntime::client(&Default::default()); - burn_cube::runtime_tests::slice::test_slice_len::(client); - } - }; -} diff --git a/crates/burn-cube/src/runtime_tests/subcube.rs b/crates/burn-cube/src/runtime_tests/subcube.rs deleted file mode 100644 index c3c9cd53c3..0000000000 --- a/crates/burn-cube/src/runtime_tests/subcube.rs +++ /dev/null @@ -1,153 +0,0 @@ -use crate::{self as burn_cube, Feature}; -use burn_cube::prelude::*; - -#[cube(launch)] -pub fn kernel_sum(output: &mut Tensor) { - let val = output[UNIT_POS]; - let val2 = subcube_sum::(val); - - if UNIT_POS == UInt::new(0) { - output[0] = val2; - } -} - -#[cube(launch)] -pub fn kernel_prod(output: &mut Tensor) { - let val = output[UNIT_POS]; - let val2 = subcube_prod::(val); - - if UNIT_POS == UInt::new(0) { - output[0] = val2; - } -} - -#[cube(launch)] -pub fn kernel_max(output: &mut Tensor) { - let val = output[UNIT_POS]; - let val2 = subcube_max::(val); - - if UNIT_POS == UInt::new(0) { - output[0] = val2; - } -} - -#[cube(launch)] -pub fn kernel_min(output: &mut Tensor) { - let val = output[UNIT_POS]; - let val2 = subcube_min::(val); - - if UNIT_POS == UInt::new(0) { - output[0] = val2; - } -} - -pub fn test_subcube_sum( - client: ComputeClient, -) { - test_subcube_operation::( - &[4.0, 5.0, 7.0, 1.0], - &[17.0, 5.0, 7.0, 1.0], - client.clone(), - |cube_count, cube_dim, handle| { - kernel_sum::launch::(client.clone(), cube_count, cube_dim, handle) - }, - ); -} - -pub fn test_subcube_prod( - client: ComputeClient, -) { - test_subcube_operation::( - &[4.0, 5.0, 7.0, 1.0], - &[140.0, 5.0, 7.0, 1.0], - client.clone(), - |cube_dim, settings, handle| { 
-            kernel_prod::launch::<TestRuntime>(client.clone(), cube_dim, settings, handle)
-        },
-    );
-}
-
-pub fn test_subcube_max<TestRuntime: Runtime>(
-    client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
-) {
-    test_subcube_operation::<TestRuntime, _>(
-        &[4.0, 5.0, 7.0, 1.0],
-        &[7.0, 5.0, 7.0, 1.0],
-        client.clone(),
-        |cube_dim, settings, handle| {
-            kernel_max::launch::<TestRuntime>(client.clone(), cube_dim, settings, handle)
-        },
-    );
-}
-
-pub fn test_subcube_min<TestRuntime: Runtime>(
-    client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
-) {
-    test_subcube_operation::<TestRuntime, _>(
-        &[4.0, 5.0, 7.0, 1.0],
-        &[1.0, 5.0, 7.0, 1.0],
-        client.clone(),
-        |cube_dim, settings, handle| {
-            kernel_min::launch::<TestRuntime>(client.clone(), cube_dim, settings, handle)
-        },
-    );
-}
-
-fn test_subcube_operation<TestRuntime: Runtime, Launch>(
-    input: &[f32],
-    expected: &[f32],
-    client: ComputeClient<TestRuntime::Server, TestRuntime::Channel>,
-    launch: Launch,
-) where
-    Launch: Fn(CubeCount, CubeDim, TensorArg<'_, TestRuntime>),
-{
-    if !client.features().enabled(Feature::Subcube) {
-        // Can't execute the test.
-        return;
-    }
-
-    let handle = client.create(f32::as_bytes(input));
-    let (shape, strides) = ([input.len()], [1]);
-
-    launch(
-        CubeCount::Static(1, 1, 1),
-        CubeDim::new(input.len() as u32, 1, 1),
-        TensorArg::new(&handle, &strides, &shape),
-    );
-
-    let actual = client.read(handle.binding());
-    let actual = f32::from_bytes(&actual);
-
-    assert_eq!(actual, expected);
-}
-
-#[allow(missing_docs)]
-#[macro_export]
-macro_rules! testgen_subcube {
-    () => {
-        use super::*;
-
-        #[test]
-        fn test_subcube_sum() {
-            let client = TestRuntime::client(&Default::default());
-            burn_cube::runtime_tests::subcube::test_subcube_sum::<TestRuntime>(client);
-        }
-
-        #[test]
-        fn test_subcube_prod() {
-            let client = TestRuntime::client(&Default::default());
-            burn_cube::runtime_tests::subcube::test_subcube_prod::<TestRuntime>(client);
-        }
-
-        #[test]
-        fn test_subcube_max() {
-            let client = TestRuntime::client(&Default::default());
-            burn_cube::runtime_tests::subcube::test_subcube_max::<TestRuntime>(client);
-        }
-
-        #[test]
-        fn test_subcube_min() {
-            let client = TestRuntime::client(&Default::default());
-            burn_cube::runtime_tests::subcube::test_subcube_min::<TestRuntime>(client);
-        }
-    };
-}
diff --git a/crates/burn-cube/tests/error/array_variable.rs b/crates/burn-cube/tests/error/array_variable.rs
deleted file mode 100644
index a97ee7555e..0000000000
--- a/crates/burn-cube/tests/error/array_variable.rs
+++ /dev/null
@@ -1,8 +0,0 @@
-use burn_cube::prelude::*;
-
-#[cube]
-fn range(x: UInt, y: UInt) {
-    let _array = [x, y];
-}
-
-fn main() {}
diff --git a/crates/burn-cube/tests/error/array_variable.stderr b/crates/burn-cube/tests/error/array_variable.stderr
deleted file mode 100644
index 7386052f7a..0000000000
--- a/crates/burn-cube/tests/error/array_variable.stderr
+++ /dev/null
@@ -1,5 +0,0 @@
-error: Only arrays of literals are supported
-  --> tests/error/array_variable.rs:5:18
-   |
-5  |     let _array = [x, y];
-   |                  ^^^^^^
diff --git a/crates/burn-cube/tests/error/for_loop_range.rs b/crates/burn-cube/tests/error/for_loop_range.rs
deleted file mode 100644
index e6863287ce..0000000000
--- a/crates/burn-cube/tests/error/for_loop_range.rs
+++ /dev/null
@@ -1,8 +0,0 @@
-use burn_cube::prelude::*;
-
-#[cube]
-fn range() {
-    for _ in 0..10 {}
-}
-
-fn main() {}
diff --git a/crates/burn-cube/tests/error/for_loop_range.stderr b/crates/burn-cube/tests/error/for_loop_range.stderr
deleted file mode 100644
index e650866ed3..0000000000
--- a/crates/burn-cube/tests/error/for_loop_range.stderr
+++ /dev/null
@@ -1,5 +0,0 @@
-error: Invalid for loop: use [range](cubecl::prelude::range) instead.
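// Editor's note: a scalar sketch (not part of the diff) of what the subcube
// kernels above compute. Every unit contributes its lane value, the subcube
// reduction folds them, and only unit 0 writes the result back, which is why
// the tests expect only element 0 of the buffer to change.
fn subcube_sum_reference(lanes: &[f32]) -> Vec<f32> {
    let mut out = lanes.to_vec();
    out[0] = lanes.iter().sum(); // [4, 5, 7, 1] becomes [17, 5, 7, 1]
    out
}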
- --> tests/error/for_loop_range.rs:5:14 - | -5 | for _ in 0..10 {} - | ^^^^^ diff --git a/crates/burn-cube/tests/error/range.rs b/crates/burn-cube/tests/error/range.rs deleted file mode 100644 index 94b99b11f9..0000000000 --- a/crates/burn-cube/tests/error/range.rs +++ /dev/null @@ -1,8 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -fn range() { - 0..10; -} - -fn main() {} diff --git a/crates/burn-cube/tests/error/range.stderr b/crates/burn-cube/tests/error/range.stderr deleted file mode 100644 index 0d7273849f..0000000000 --- a/crates/burn-cube/tests/error/range.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: Range is not supported, use [range](cubecl::prelude::range) instead. - --> tests/error/range.rs:5:5 - | -5 | 0..10; - | ^^^^^ diff --git a/crates/burn-cube/tests/error/return_value.rs b/crates/burn-cube/tests/error/return_value.rs deleted file mode 100644 index 58437cb413..0000000000 --- a/crates/burn-cube/tests/error/return_value.rs +++ /dev/null @@ -1,12 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -fn range(x: UInt, y: UInt) -> UInt { - if x == y { - return x; - } - - y -} - -fn main() {} diff --git a/crates/burn-cube/tests/error/return_value.stderr b/crates/burn-cube/tests/error/return_value.stderr deleted file mode 100644 index 86d6ee82da..0000000000 --- a/crates/burn-cube/tests/error/return_value.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: Only void return is supported. - --> tests/error/return_value.rs:6:9 - | -6 | return x; - | ^^^^^^^^ diff --git a/crates/burn-cube/tests/error/undeclared_variable.rs b/crates/burn-cube/tests/error/undeclared_variable.rs deleted file mode 100644 index e31660ce9b..0000000000 --- a/crates/burn-cube/tests/error/undeclared_variable.rs +++ /dev/null @@ -1,9 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -fn kernel(x: UInt) { - if x == y { - } -} - -fn main() {} diff --git a/crates/burn-cube/tests/error/undeclared_variable.stderr b/crates/burn-cube/tests/error/undeclared_variable.stderr deleted file mode 100644 index fcac87e8cc..0000000000 --- a/crates/burn-cube/tests/error/undeclared_variable.stderr +++ /dev/null @@ -1,11 +0,0 @@ -error: Variable not declared - --> tests/error/undeclared_variable.rs:5:13 - | -5 | if x == y { - | ^ - -error[E0425]: cannot find value `y` in this scope - --> tests/error/undeclared_variable.rs:5:13 - | -5 | if x == y { - | ^ help: a local variable with a similar name exists: `x` diff --git a/crates/burn-cube/tests/frontend/array.rs b/crates/burn-cube/tests/frontend/array.rs deleted file mode 100644 index 502ee2b247..0000000000 --- a/crates/burn-cube/tests/frontend/array.rs +++ /dev/null @@ -1,216 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn array_read_write(array_size: Comptime) { - let mut array = Array::::new(array_size); - array[0] = T::from_int(3); - let _ = array[0]; -} - -#[cube] -pub fn array_to_vectorized_variable() -> T { - let mut array = Array::::new(2); - array[0] = T::from_int(0); - array[1] = T::from_int(1); - array.to_vectorized(Comptime::new(UInt::new(2))) -} - -#[cube] -pub fn array_of_one_to_vectorized_variable() -> T { - let mut array = Array::::new(1); - array[0] = T::from_int(3); - array.to_vectorized(Comptime::new(UInt::new(1))) -} - -#[cube] -pub fn array_add_assign_simple(array: &mut Array) { - array[UInt::new(1)] += UInt::new(1); -} - -#[cube] -pub fn array_add_assign_expr(array: &mut Array) { - array[UInt::new(1) + UInt::new(5)] += UInt::new(1); -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - ir::{self, Elem, Item, Variable}, - }; - - type ElemType = F32; - - #[test] 
- fn cube_support_array() { - let mut context = CubeContext::root(); - - array_read_write::__expand::(&mut context, 512); - assert_eq!( - context.into_scope().operations, - inline_macro_ref_read_write() - ) - } - - #[test] - fn array_add_assign() { - let mut context = CubeContext::root(); - let array = context.input(0, Item::new(Elem::UInt)); - - array_add_assign_simple::__expand(&mut context, array.into()); - let scope = context.into_scope(); - - assert_eq!(scope.operations, inline_macro_array_add_assign_simple()); - } - - #[test] - fn cube_array_to_vectorized() { - let mut context = CubeContext::root(); - - array_to_vectorized_variable::__expand::(&mut context); - assert_eq!( - context.into_scope().operations, - inline_macro_ref_to_vectorized() - ); - } - - #[test] - fn cube_array_of_one_to_vectorized() { - let mut context = CubeContext::root(); - - array_of_one_to_vectorized_variable::__expand::(&mut context); - assert_eq!( - context.into_scope().operations, - inline_macro_ref_one_to_vectorized() - ); - } - - fn inline_macro_ref_read_write() -> Vec { - let context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - let var = scope.create_local(item); - let pos: Variable = 0u32.into(); - - // Create - let array = scope.create_local_array(item, 512); - - // Write - cpa!(scope, array[pos] = 3.0_f32); - - // Read - cpa!(scope, var = array[pos]); - - scope.operations - } - - #[test] - fn array_add_assign_expr() { - let mut context = CubeContext::root(); - let array = context.input(0, Item::new(Elem::UInt)); - - array_add_assign_expr::__expand(&mut context, array.into()); - let scope = context.into_scope(); - - assert_eq!(scope.operations, inline_macro_array_add_assign_expr()); - } - - fn inline_macro_array_add_assign_simple() -> Vec { - let context = CubeContext::root(); - - let mut scope = context.into_scope(); - let local = scope.create_local(Item::new(Elem::UInt)); - - let array = Variable::GlobalInputArray { - id: 0, - item: Item::new(Elem::UInt), - }; - let index = Variable::ConstantScalar { - value: 1., - elem: Elem::UInt, - }; - let value = Variable::ConstantScalar { - value: 1., - elem: Elem::UInt, - }; - - cpa!(scope, local = array[index]); - cpa!(scope, local += value); - cpa!(scope, array[index] = local); - - scope.operations - } - - fn inline_macro_ref_to_vectorized() -> Vec { - let context = CubeContext::root(); - let scalar_item = Item::new(ElemType::as_elem()); - let vectorized_item = Item::vectorized(ElemType::as_elem(), 2); - - let mut scope = context.into_scope(); - let pos0: Variable = 0u32.into(); - let pos1: Variable = 1u32.into(); - let array = scope.create_local_array(scalar_item, 2); - cpa!(scope, array[pos0] = 0.0_f32); - cpa!(scope, array[pos1] = 1.0_f32); - - let vectorized_var = scope.create_local(vectorized_item); - let tmp = scope.create_local(scalar_item); - cpa!(scope, tmp = array[pos0]); - cpa!(scope, vectorized_var[pos0] = tmp); - cpa!(scope, tmp = array[pos1]); - cpa!(scope, vectorized_var[pos1] = tmp); - - scope.operations - } - - fn inline_macro_ref_one_to_vectorized() -> Vec { - let context = CubeContext::root(); - let scalar_item = Item::new(ElemType::as_elem()); - let unvectorized_item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - let pos0: Variable = 0u32.into(); - let array = scope.create_local_array(scalar_item, 1); - cpa!(scope, array[pos0] = 3.0_f32); - - let unvectorized_var = scope.create_local(unvectorized_item); - let tmp = 
scope.create_local(scalar_item); - cpa!(scope, tmp = array[pos0]); - cpa!(scope, unvectorized_var = tmp); - - scope.operations - } - - fn inline_macro_array_add_assign_expr() -> Vec { - let context = CubeContext::root(); - - let mut scope = context.into_scope(); - let index = scope.create_local(Item::new(Elem::UInt)); - let local = scope.create_local(Item::new(Elem::UInt)); - - let array = Variable::GlobalInputArray { - id: 0, - item: Item::new(Elem::UInt), - }; - let const1 = Variable::ConstantScalar { - value: 1., - elem: Elem::UInt, - }; - let const2 = Variable::ConstantScalar { - value: 5., - elem: Elem::UInt, - }; - let value = Variable::ConstantScalar { - value: 1., - elem: Elem::UInt, - }; - - cpa!(scope, index = const1 + const2); - cpa!(scope, local = array[index]); - cpa!(scope, local += value); - cpa!(scope, array[index] = local); - - scope.operations - } -} diff --git a/crates/burn-cube/tests/frontend/assign.rs b/crates/burn-cube/tests/frontend/assign.rs deleted file mode 100644 index b720a3b87d..0000000000 --- a/crates/burn-cube/tests/frontend/assign.rs +++ /dev/null @@ -1,196 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn mut_assign() { - let mut x = UInt::new(0); - x += UInt::new(1); -} - -#[cube] -pub fn mut_assign_input(y: UInt) -> UInt { - let mut x = y; - x += UInt::new(1); - y + UInt::new(2) -} - -#[cube] -pub fn assign_mut_input(mut y: UInt) -> UInt { - let x = y; - y += UInt::new(1); - x + UInt::new(2) -} - -#[cube] -pub fn assign_vectorized(y: UInt) -> UInt { - let vectorization_factor = Comptime::vectorization(&y); - let x = UInt::vectorized(1, Comptime::get(vectorization_factor)); - x + y -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - ir::{Elem, Item, Variable}, - }; - - #[test] - fn cube_mut_assign_test() { - let mut context = CubeContext::root(); - - mut_assign::__expand(&mut context); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_mut_assign() - ); - } - - #[test] - fn cube_mut_assign_input_test() { - let mut context = CubeContext::root(); - - let y = context.create_local(Item::new(UInt::as_elem())); - - mut_assign_input::__expand(&mut context, y); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_mut_assign_input() - ); - } - - #[test] - fn cube_assign_mut_input_test() { - let mut context = CubeContext::root(); - - let y = context.create_local(Item::new(UInt::as_elem())); - - assign_mut_input::__expand(&mut context, y); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_assign_mut_input() - ); - } - - #[test] - fn cube_assign_vectorized_test() { - let mut context = CubeContext::root(); - - let y = context.create_local(Item::vectorized(UInt::as_elem(), 4)); - - assign_vectorized::__expand(&mut context, y); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_assign_vectorized() - ); - } - - fn inline_macro_ref_mut_assign() -> String { - let context = CubeContext::root(); - - let mut scope = context.into_scope(); - let x = scope.create_local(Item::new(Elem::UInt)); - - let zero = Variable::ConstantScalar { - value: 0., - elem: Elem::UInt, - }; - let one = Variable::ConstantScalar { - value: 1., - elem: Elem::UInt, - }; - cpa!(scope, x = zero); - cpa!(scope, x = x + one); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_mut_assign_input() -> String { - let mut context = CubeContext::root(); - let 
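// Editor's note: a plain-Rust sketch (not part of the diff) of the scalar
// behaviour the `assign.rs` kernels below encode. `mut_assign_input` copies
// `y` into a fresh `x`, so the increment of `x` is a dead store and the
// returned value depends on `y` alone; the expected IR in these tests
// records exactly that operation order.
fn mut_assign_input_reference(y: u32) -> u32 {
    let mut x = y;
    x += 1;
    let _ = x; // x is a dead value from here on
    y + 2
}
// mut_assign_input_reference(3) == 5, regardless of the mutation of `x`.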
item = Item::new(Elem::UInt); - let y = context.create_local(item); - - let mut scope = context.into_scope(); - let y: Variable = y.into(); - let x = scope.create_local(item); - - let one = Variable::ConstantScalar { - value: 1., - elem: Elem::UInt, - }; - let two = Variable::ConstantScalar { - value: 2., - elem: Elem::UInt, - }; - cpa!(scope, x = y); - cpa!(scope, x = x + one); - cpa!(scope, y = y + two); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_assign_mut_input() -> String { - let mut context = CubeContext::root(); - let item = Item::new(Elem::UInt); - let y = context.create_local(item); - - let mut scope = context.into_scope(); - let y: Variable = y.into(); - let x = scope.create_local(item); - - let one = Variable::ConstantScalar { - value: 1., - elem: Elem::UInt, - }; - let two = Variable::ConstantScalar { - value: 2., - elem: Elem::UInt, - }; - cpa!(scope, x = y); - cpa!(scope, y = y + one); - cpa!(scope, x = x + two); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_assign_vectorized() -> String { - let mut context = CubeContext::root(); - let item = Item::vectorized(Elem::UInt, 4); - let y = context.create_local(item); - - let mut scope = context.into_scope(); - let y: Variable = y.into(); - let x = scope.create_local(item); - - let zero = Variable::ConstantScalar { - value: 0., - elem: Elem::UInt, - }; - let one = Variable::ConstantScalar { - value: 1., - elem: Elem::UInt, - }; - let two = Variable::ConstantScalar { - value: 2., - elem: Elem::UInt, - }; - let three = Variable::ConstantScalar { - value: 3., - elem: Elem::UInt, - }; - cpa!(scope, x[zero] = one); - cpa!(scope, x[one] = one); - cpa!(scope, x[two] = one); - cpa!(scope, x[three] = one); - cpa!(scope, x = x + y); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/cast_elem.rs b/crates/burn-cube/tests/frontend/cast_elem.rs deleted file mode 100644 index 8ceaf164f9..0000000000 --- a/crates/burn-cube/tests/frontend/cast_elem.rs +++ /dev/null @@ -1,281 +0,0 @@ -use burn_cube::{ - cube, - frontend::{Bool, Cast, Numeric, UInt, F32, I32}, -}; - -// From float -#[cube] -pub fn float_to_float(x: F32) { - let y = x + F32::from_int(2); - let _ = F32::cast_from(y) + F32::from_int(34); -} - -#[cube] -pub fn float_to_int(x: F32) { - let y = x + F32::from_int(2); - let _ = I32::cast_from(y) + I32::from_int(34); -} - -#[cube] -pub fn float_to_uint(x: F32) { - let y = x + F32::from_int(2); - let _ = UInt::cast_from(y) + UInt::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn float_to_bool(x: F32) { - let y = x + F32::from_int(2); - let _ = Bool::cast_from(y) || true; -} - -// From int -#[cube] -pub fn int_to_float(x: I32) { - let y = x + I32::from_int(2); - let _ = F32::cast_from(y) + F32::from_int(34); -} - -#[cube] -#[allow(clippy::useless_conversion)] -pub fn int_to_int(x: I32) { - let y = x + I32::from_int(2); - let _ = I32::cast_from(y) + I32::from_int(34); -} - -#[cube] -pub fn int_to_uint(x: I32) { - let y = x + I32::from_int(2); - let _ = UInt::cast_from(y) + UInt::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn int_to_bool(x: I32) { - let y = x + I32::from_int(2); - let _ = Bool::cast_from(y) || true; -} - -// // From uint -#[cube] -pub fn uint_to_float(x: UInt) { - let y = x + UInt::from_int(2); - let _ = F32::cast_from(y) + F32::from_int(34); -} - -#[cube] -pub fn uint_to_int(x: UInt) { - let y = x + UInt::from_int(2); - let _ = I32::cast_from(y) + I32::from_int(34); -} - -#[cube] 
-#[allow(clippy::useless_conversion)] -pub fn uint_to_uint(x: UInt) { - let y = x + UInt::from_int(2); - let _ = UInt::cast_from(y) + UInt::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn uint_to_bool(x: UInt) { - let y = x + UInt::from_int(2); - let _ = Bool::cast_from(y) || true; -} - -// From bool -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_float(x: Bool) { - let y = x && false; - let _ = F32::cast_from(y) + F32::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_int(x: Bool) { - let y = x && false; - let _ = I32::cast_from(y) + I32::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -pub fn bool_to_uint(x: Bool) { - let y = x && false; - let _ = UInt::cast_from(y) + UInt::from_int(34); -} - -#[cube] -#[allow(clippy::overly_complex_bool_expr)] -#[allow(clippy::useless_conversion)] -pub fn bool_to_bool(x: Bool) { - let y = x && false; - let _ = Bool::cast_from(y) || true; -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - frontend::{CubeContext, CubePrimitive}, - ir::{Elem, Item, Variable}, - }; - - macro_rules! cast_test { - ($name:ident, $module:expr, $from:expr, $to:expr) => { - #[test] - fn $name() { - let mut context = CubeContext::root(); - - let x = context.create_local($from); - - $module(&mut context, x); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_cast($from, $to) - ); - } - }; - } - - cast_test!( - cube_float_to_float_test, - float_to_float::__expand, - Item::new(F32::as_elem()), - Item::new(F32::as_elem()) - ); - - cast_test!( - cube_float_to_int_test, - float_to_int::__expand, - Item::new(F32::as_elem()), - Item::new(I32::as_elem()) - ); - - cast_test!( - cube_float_to_uint_test, - float_to_uint::__expand, - Item::new(F32::as_elem()), - Item::new(Elem::UInt) - ); - - cast_test!( - cube_float_to_bool_test, - float_to_bool::__expand, - Item::new(F32::as_elem()), - Item::new(Elem::Bool) - ); - - cast_test!( - cube_int_to_float_test, - int_to_float::__expand, - Item::new(I32::as_elem()), - Item::new(F32::as_elem()) - ); - - cast_test!( - cube_int_to_int_test, - int_to_int::__expand, - Item::new(I32::as_elem()), - Item::new(I32::as_elem()) - ); - - cast_test!( - cube_int_to_uint_test, - int_to_uint::__expand, - Item::new(I32::as_elem()), - Item::new(Elem::UInt) - ); - - cast_test!( - cube_int_to_bool_test, - int_to_bool::__expand, - Item::new(I32::as_elem()), - Item::new(Elem::Bool) - ); - - cast_test!( - cube_uint_to_float_test, - uint_to_float::__expand, - Item::new(Elem::UInt), - Item::new(F32::as_elem()) - ); - - cast_test!( - cube_uint_to_int_test, - uint_to_int::__expand, - Item::new(Elem::UInt), - Item::new(I32::as_elem()) - ); - - cast_test!( - cube_uint_to_uint_test, - uint_to_uint::__expand, - Item::new(Elem::UInt), - Item::new(Elem::UInt) - ); - - cast_test!( - cube_uint_to_bool_test, - uint_to_bool::__expand, - Item::new(Elem::UInt), - Item::new(Elem::Bool) - ); - - cast_test!( - cube_bool_to_float_test, - bool_to_float::__expand, - Item::new(Elem::Bool), - Item::new(F32::as_elem()) - ); - - cast_test!( - cube_bool_to_int_test, - bool_to_int::__expand, - Item::new(Elem::Bool), - Item::new(I32::as_elem()) - ); - - cast_test!( - cube_bool_to_uint_test, - bool_to_uint::__expand, - Item::new(Elem::Bool), - Item::new(Elem::UInt) - ); - - cast_test!( - cube_bool_to_bool_test, - bool_to_bool::__expand, - Item::new(Elem::Bool), - Item::new(Elem::Bool) - ); - - fn 
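// Editor's note (not part of the diff): every cast test in this file reduces
// to the same three-step pattern, sketched here with native casts: nudge the
// source value, convert, nudge the result, with literals matched to the
// element kind (f32/i32/u32 arithmetic, `&&`/`||` for bool).
fn cast_chain(x: f32) -> i32 {
    let y = x + 2.0;  // source-side op, as in `float_to_int`
    let z = y as i32; // the `cast(...)` step the expanded IR records
    z + 34            // target-side op
}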
inline_macro_ref_cast(from_item: Item, to_item: Item) -> String { - let mut context = CubeContext::root(); - let x = context.create_local(from_item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y = scope.create_local(to_item); - - match from_item.elem() { - Elem::Float(_) => cpa!(scope, x = x + 2f32), - Elem::Int(_) => cpa!(scope, x = x + 2i32), - Elem::UInt => cpa!(scope, x = x + 2u32), - Elem::Bool => cpa!(scope, x = x && false), - } - - cpa!(scope, y = cast(x)); - - match to_item.elem() { - Elem::Float(_) => cpa!(scope, y = y + 34f32), - Elem::Int(_) => cpa!(scope, y = y + 34i32), - Elem::UInt => cpa!(scope, y = y + 34u32), - Elem::Bool => cpa!(scope, y = y || true), - } - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/cast_kind.rs b/crates/burn-cube/tests/frontend/cast_kind.rs deleted file mode 100644 index 26e54b6f6e..0000000000 --- a/crates/burn-cube/tests/frontend/cast_kind.rs +++ /dev/null @@ -1,127 +0,0 @@ -use burn_cube::{ - cube, - frontend::{Cast, Float, Int, Numeric}, -}; - -#[cube] -pub fn cast_float_kind(input: F1) { - let x = input + F1::new(5.9); - let y = F2::cast_from(x); - let _ = y + F2::new(2.3); -} - -#[cube] -pub fn cast_int_kind(input: I1) { - let x = input + I1::new(5); - let y = I2::cast_from(x); - let _ = y + I2::new(2); -} - -#[cube] -pub fn cast_numeric_to_kind(input: T) { - let x = input + T::from_int(5); - let y = I::cast_from(x); - let _ = y + I::from_int(2); -} - -#[cube] -pub fn cast_int_to_numeric(input: I) { - let x = input + I::from_int(5); - let y = T::cast_from(x); - let _ = y + T::from_int(2); -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - frontend::{CubeContext, CubePrimitive, F32, F64, I32, I64}, - ir::{Item, Variable}, - }; - - #[test] - fn cube_cast_float_kind_test() { - let mut context = CubeContext::root(); - let item = Item::new(F64::as_elem()); - - let input = context.create_local(item); - - cast_float_kind::__expand::(&mut context, input); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); - } - - #[test] - fn cube_cast_int_kind_test() { - let mut context = CubeContext::root(); - let item = Item::new(I32::as_elem()); - - let input = context.create_local(item); - - cast_int_kind::__expand::(&mut context, input); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); - } - - #[test] - fn cube_cast_numeric_kind_test() { - let mut context = CubeContext::root(); - let item = Item::new(I32::as_elem()); - - let input = context.create_local(item); - - cast_numeric_to_kind::__expand::(&mut context, input); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); - } - - #[test] - fn cube_cast_kind_numeric_test() { - let mut context = CubeContext::root(); - let item = Item::new(I32::as_elem()); - - let input = context.create_local(item); - - cast_int_to_numeric::__expand::(&mut context, input); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); - } - - fn inline_macro_ref_float() -> String { - let mut context = CubeContext::root(); - let float_64 = Item::new(F64::as_elem()); - let float_32 = Item::new(F32::as_elem()); - let input = context.create_local(float_64); - - let mut scope = context.into_scope(); - let input: Variable = input.into(); - let y = scope.create_local(float_32); - - cpa!(scope, input = input + 5.9f32 as f64); - 
cpa!(scope, y = cast(input)); - cpa!(scope, y = y + 2.3f32); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_int() -> String { - let mut context = CubeContext::root(); - let int_32 = Item::new(I32::as_elem()); - let int_64 = Item::new(I64::as_elem()); - let input = context.create_local(int_32); - - let mut scope = context.into_scope(); - let input: Variable = input.into(); - let y = scope.create_local(int_64); - - cpa!(scope, input = input + 5i32); - cpa!(scope, y = cast(input)); - cpa!(scope, y = y + 2i64); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/comptime.rs b/crates/burn-cube/tests/frontend/comptime.rs deleted file mode 100644 index 3e4be79bb1..0000000000 --- a/crates/burn-cube/tests/frontend/comptime.rs +++ /dev/null @@ -1,323 +0,0 @@ -use burn_cube::prelude::*; - -#[derive(Clone)] -pub struct State { - cond: bool, - bound: u32, -} - -impl Init for State { - fn init(self, _context: &mut CubeContext) -> Self { - self - } -} - -#[cube] -pub fn comptime_if_else(lhs: T, cond: Comptime) { - if Comptime::get(cond) { - let _ = lhs + T::from_int(4); - } else { - let _ = lhs - T::from_int(5); - } -} - -#[cube] -#[allow(clippy::collapsible_else_if)] -pub fn comptime_else_then_if(lhs: T, cond1: Comptime, cond2: Comptime) { - if Comptime::get(cond1) { - let _ = lhs + T::from_int(4); - } else { - if Comptime::get(cond2) { - let _ = lhs + T::from_int(5); - } else { - let _ = lhs - T::from_int(6); - } - } -} - -#[cube] -pub fn comptime_elsif(lhs: T, cond1: Comptime, cond2: Comptime) { - if Comptime::get(cond1) { - let _ = lhs + T::from_int(4); - } else if Comptime::get(cond2) { - let _ = lhs + T::from_int(5); - } else { - let _ = lhs - T::from_int(6); - } -} - -#[cube] -pub fn comptime_elsif_with_runtime1(lhs: T, comptime_cond: Comptime) { - let runtime_cond = lhs >= T::from_int(2); - if Comptime::get(comptime_cond) { - let _ = lhs + T::from_int(4); - } else if runtime_cond { - let _ = lhs + T::from_int(5); - } else { - let _ = lhs - T::from_int(6); - } -} - -#[cube] -pub fn comptime_elsif_with_runtime2(lhs: T, comptime_cond: Comptime) { - let runtime_cond = lhs >= T::from_int(2); - if runtime_cond { - let _ = lhs + T::from_int(4); - } else if Comptime::get(comptime_cond) { - let _ = lhs + T::from_int(5); - } else { - let _ = lhs - T::from_int(6); - } -} - -#[cube] -pub fn comptime_if_expr(lhs: T, x: Comptime, y: Comptime) { - let y2 = x + y; - - if x < y2 { - let _ = lhs + T::from_int(4); - } else { - let _ = lhs - T::from_int(5); - } -} - -#[cube] -pub fn comptime_with_map_bool(state: Comptime) -> T { - let cond = Comptime::map(state, |s: State| s.cond); - - let mut x = T::from_int(3); - if Comptime::get(cond) { - x += T::from_int(4); - } else { - x -= T::from_int(4); - } - x -} - -#[cube] -pub fn comptime_with_map_uint(state: Comptime) -> T { - let bound = Comptime::map(state, |s: State| s.bound); - - let mut x = T::from_int(3); - for _ in range(0u32, Comptime::get(bound), Comptime::new(true)) { - x += T::from_int(4); - } - - x -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - frontend::{CubeContext, CubePrimitive, F32}, - ir::{Elem, Item, Variable}, - }; - - type ElemType = F32; - - #[test] - fn cube_comptime_if_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(ElemType::as_elem())); - - comptime_if_else::__expand::(&mut context, lhs, true); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_comptime(true) - ); - } - - 
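// Editor's note: a native-Rust analogue (not part of the diff) of the
// comptime branching these tests assert. Because the condition is known at
// expansion time, only one branch is ever recorded in the scope, just as a
// const-generic bool is folded away at monomorphization.
fn specialized<const COND: bool>(lhs: f32) -> f32 {
    if COND {
        lhs + 4.0 // the only code present when COND == true
    } else {
        lhs - 5.0 // the only code present when COND == false
    }
}
// specialized::<true>(1.0) == 5.0; the `- 5.0` path does not exist in that
// instantiation's optimized code.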
#[test] - fn cube_comptime_if_numeric_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(ElemType::as_elem())); - - comptime_if_expr::__expand::(&mut context, lhs, UInt::new(4), UInt::new(5)); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_comptime(true) - ); - } - - #[test] - fn cube_comptime_else_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(ElemType::as_elem())); - - comptime_if_else::__expand::(&mut context, lhs, false); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_comptime(false) - ); - } - - #[test] - fn cube_comptime_elsif_test() { - for cond1 in [false, true] { - for cond2 in [false, true] { - let mut context1 = CubeContext::root(); - let lhs = context1.create_local(Item::new(ElemType::as_elem())); - comptime_else_then_if::__expand::(&mut context1, lhs, cond1, cond2); - let scope1 = context1.into_scope(); - - let mut context2 = CubeContext::root(); - let lhs = context2.create_local(Item::new(ElemType::as_elem())); - comptime_elsif::__expand::(&mut context2, lhs, cond1, cond2); - let scope2 = context2.into_scope(); - - assert_eq!( - format!("{:?}", scope1.operations), - format!("{:?}", scope2.operations), - ); - } - } - } - - #[test] - fn cube_comptime_elsif_runtime1_test() { - for cond in [false, true] { - let mut context = CubeContext::root(); - let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_elsif_with_runtime1::__expand::(&mut context, lhs, cond); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_elsif_runtime1(cond) - ); - } - } - - #[test] - fn cube_comptime_elsif_runtime2_test() { - for cond in [false, true] { - let mut context = CubeContext::root(); - let lhs = context.create_local(Item::new(ElemType::as_elem())); - comptime_elsif_with_runtime2::__expand::(&mut context, lhs, cond); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_elsif_runtime2(cond) - ); - } - } - - #[test] - fn cube_comptime_map_bool_test() { - let mut context1 = CubeContext::root(); - let mut context2 = CubeContext::root(); - - let comptime_state_true = State { - cond: true, - bound: 4, - }; - let comptime_state_false = State { - cond: false, - bound: 4, - }; - - comptime_with_map_bool::__expand::(&mut context1, comptime_state_true); - comptime_with_map_bool::__expand::(&mut context2, comptime_state_false); - - let scope1 = context1.into_scope(); - let scope2 = context2.into_scope(); - - assert_ne!( - format!("{:?}", scope1.operations), - format!("{:?}", scope2.operations) - ); - } - - #[test] - fn cube_comptime_map_uint_test() { - let mut context = CubeContext::root(); - - let comptime_state = State { - cond: true, - bound: 4, - }; - - comptime_with_map_uint::__expand::(&mut context, comptime_state); - - let scope = context.into_scope(); - - assert!(!format!("{:?}", scope.operations).contains("RangeLoop")); - } - - fn inline_macro_ref_comptime(cond: bool) -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y = scope.create_local(item); - - if cond { - cpa!(scope, y = x + 4.0f32); - } else { - cpa!(scope, y = x - 5.0f32); - }; - - format!("{:?}", scope.operations) - } - - fn 
inline_macro_ref_elsif_runtime1(comptime_cond: bool) -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let runtime_cond = scope.create_local(Item::new(Elem::Bool)); - let y = scope.create_local(item); - cpa!(scope, runtime_cond = x >= 2.0f32); - - if comptime_cond { - cpa!(scope, y = x + 4.0f32); - } else { - cpa!(&mut scope, if(runtime_cond).then(|scope| { - cpa!(scope, y = x + 5.0f32); - }).else(|scope| { - cpa!(scope, y = x - 6.0f32); - })); - }; - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_elsif_runtime2(comptime_cond: bool) -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let runtime_cond = scope.create_local(Item::new(Elem::Bool)); - let y = scope.create_local(item); - cpa!(scope, runtime_cond = x >= 2.0f32); - - cpa!(&mut scope, if(runtime_cond).then(|scope| { - cpa!(scope, y = x + 4.0f32); - }).else(|scope| { - if comptime_cond { - cpa!(scope, y = x + 5.0f32); - } else { - cpa!(scope, y = x - 6.0f32); - } - })); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/cube_trait.rs b/crates/burn-cube/tests/frontend/cube_trait.rs deleted file mode 100644 index d74814e167..0000000000 --- a/crates/burn-cube/tests/frontend/cube_trait.rs +++ /dev/null @@ -1,109 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -trait FunctionGeneric { - #[allow(unused)] - fn test(lhs: C, rhs: C) -> C; -} - -#[cube] -trait TraitGeneric { - #[allow(unused)] - fn test(lhs: C, rhs: C) -> C; -} - -#[cube] -trait CombinedTraitFunctionGeneric { - #[allow(unused)] - fn test(lhs: C, rhs: C) -> O; -} - -struct Test; - -#[cube] -impl FunctionGeneric for Test { - fn test(lhs: C, rhs: C) -> C { - lhs + rhs - } -} - -#[cube] -impl TraitGeneric for Test { - fn test(lhs: C, rhs: C) -> C { - lhs + rhs - } -} - -#[cube] -impl CombinedTraitFunctionGeneric for Test { - fn test(lhs: C, rhs: C) -> O { - O::cast_from(lhs + rhs) - } -} - -#[cube] -pub fn simple(lhs: C, rhs: C) -> C { - lhs + rhs -} - -#[cube] -pub fn with_cast(lhs: C, rhs: C) -> O { - O::cast_from(lhs + rhs) -} - -mod tests { - use burn_cube::ir::{Item, Scope}; - - use super::*; - - #[test] - fn test_function_generic() { - let mut context = CubeContext::root(); - let lhs = context.create_local(Item::new(F32::as_elem())); - let rhs = context.create_local(Item::new(F32::as_elem())); - - ::__expand_test::(&mut context, lhs, rhs); - - assert_eq!(simple_scope(), context.into_scope()); - } - - #[test] - fn test_trait_generic() { - let mut context = CubeContext::root(); - let lhs = context.create_local(Item::new(F32::as_elem())); - let rhs = context.create_local(Item::new(F32::as_elem())); - - >::__expand_test(&mut context, lhs, rhs); - - assert_eq!(simple_scope(), context.into_scope()); - } - - #[test] - fn test_combined_function_generic() { - let mut context = CubeContext::root(); - let lhs = context.create_local(Item::new(F32::as_elem())); - let rhs = context.create_local(Item::new(F32::as_elem())); - - >::__expand_test::(&mut context, lhs, rhs); - - assert_eq!(with_cast_scope(), context.into_scope()); - } - - fn simple_scope() -> Scope { - let mut context_ref = CubeContext::root(); - let lhs = context_ref.create_local(Item::new(F32::as_elem())); - let rhs = 
context_ref.create_local(Item::new(F32::as_elem())); - - simple::__expand::(&mut context_ref, lhs, rhs); - context_ref.into_scope() - } - - fn with_cast_scope() -> Scope { - let mut context_ref = CubeContext::root(); - let lhs = context_ref.create_local(Item::new(F32::as_elem())); - let rhs = context_ref.create_local(Item::new(F32::as_elem())); - - with_cast::__expand::(&mut context_ref, lhs, rhs); - context_ref.into_scope() - } -} diff --git a/crates/burn-cube/tests/frontend/for_loop.rs b/crates/burn-cube/tests/frontend/for_loop.rs deleted file mode 100644 index dcc1c9e0d2..0000000000 --- a/crates/burn-cube/tests/frontend/for_loop.rs +++ /dev/null @@ -1,79 +0,0 @@ -use burn_cube::{ - cube, - frontend::branch::range, - frontend::{Array, Comptime, CubeContext, CubePrimitive, Float, UInt, F32}, -}; - -type ElemType = F32; - -#[cube] -pub fn for_loop(mut lhs: Array, rhs: F, end: UInt, unroll: Comptime) { - let tmp1 = rhs * rhs; - let tmp2 = tmp1 + rhs; - - for i in range(0u32, end, unroll) { - lhs[i] = tmp2 + lhs[i]; - } -} - -mod tests { - use burn_cube::{cpa, ir::Item}; - - use super::*; - - #[test] - fn test_for_loop_with_unroll() { - let mut context = CubeContext::root(); - let unroll = true; - - let lhs = context.create_local_array(Item::new(ElemType::as_elem()), 4u32); - let rhs = context.create_local(Item::new(ElemType::as_elem())); - let end = 4u32.into(); - - for_loop::__expand::(&mut context, lhs.into(), rhs, end, unroll); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); - } - - #[test] - fn test_for_loop_no_unroll() { - let mut context = CubeContext::root(); - let unroll = false; - - let lhs = context.create_local_array(Item::new(ElemType::as_elem()), 4u32); - let rhs = context.create_local(Item::new(ElemType::as_elem())); - let end = 4u32.into(); - - for_loop::__expand::(&mut context, lhs.into(), rhs, end, unroll); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll)); - } - - fn inline_macro_ref(unroll: bool) -> String { - let context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - let lhs = scope.create_local_array(item, 4u32); - let rhs = scope.create_local(item); - let end = 4u32; - - // Kernel - let tmp1 = scope.create_local(item); - cpa!(scope, tmp1 = rhs * rhs); - cpa!(scope, tmp1 = tmp1 + rhs); - - cpa!( - &mut scope, - range(0u32, end, unroll).for_each(|i, scope| { - cpa!(scope, rhs = lhs[i]); - cpa!(scope, rhs = tmp1 + rhs); - cpa!(scope, lhs[i] = rhs); - }) - ); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/function_call.rs b/crates/burn-cube/tests/frontend/function_call.rs deleted file mode 100644 index 88d5d7f87a..0000000000 --- a/crates/burn-cube/tests/frontend/function_call.rs +++ /dev/null @@ -1,113 +0,0 @@ -use burn_cube::{ - cube, - frontend::{Numeric, UInt}, -}; - -#[cube] -pub fn caller_no_arg(x: UInt) { - let _ = x + callee_no_arg(); -} - -#[cube] -pub fn callee_no_arg() -> UInt { - UInt::from_int(8) -} - -#[cube] -pub fn no_call_no_arg(x: UInt) { - let _ = x + UInt::from_int(8); -} - -#[cube] -pub fn caller_with_arg(x: UInt) { - let _ = x + callee_with_arg(x); -} - -#[cube] -pub fn callee_with_arg(x: UInt) -> UInt { - x * UInt::from_int(8) -} - -#[cube] -pub fn no_call_with_arg(x: UInt) { - let _ = x + x * UInt::from_int(8); -} - -#[cube] -pub fn caller_with_generics(x: T) { - let _ = x + callee_with_generics::(x); -} - -#[cube] -pub 
fn callee_with_generics(x: T) -> T { - x * T::from_int(8) -} - -#[cube] -pub fn no_call_with_generics(x: T) { - let _ = x + x * T::from_int(8); -} - -mod tests { - use super::*; - use burn_cube::{ - frontend::{CubeContext, CubePrimitive, I64}, - ir::{Elem, Item}, - }; - - #[test] - fn cube_call_equivalent_to_no_call_no_arg_test() { - let mut caller_context = CubeContext::root(); - let x = caller_context.create_local(Item::new(Elem::UInt)); - caller_no_arg::__expand(&mut caller_context, x); - let caller_scope = caller_context.into_scope(); - - let mut no_call_context = CubeContext::root(); - let x = no_call_context.create_local(Item::new(Elem::UInt)); - no_call_no_arg::__expand(&mut no_call_context, x); - let no_call_scope = no_call_context.into_scope(); - - assert_eq!( - format!("{:?}", caller_scope.operations), - format!("{:?}", no_call_scope.operations) - ); - } - - #[test] - fn cube_call_equivalent_to_no_call_with_arg_test() { - let mut caller_context = CubeContext::root(); - - let x = caller_context.create_local(Item::new(Elem::UInt)); - caller_with_arg::__expand(&mut caller_context, x); - let caller_scope = caller_context.into_scope(); - - let mut no_call_context = CubeContext::root(); - let x = no_call_context.create_local(Item::new(Elem::UInt)); - no_call_with_arg::__expand(&mut no_call_context, x); - let no_call_scope = no_call_context.into_scope(); - - assert_eq!( - format!("{:?}", caller_scope.operations), - format!("{:?}", no_call_scope.operations) - ); - } - - #[test] - fn cube_call_equivalent_to_no_call_with_generics_test() { - let mut caller_context = CubeContext::root(); - type ElemType = I64; - let x = caller_context.create_local(Item::new(ElemType::as_elem())); - caller_with_generics::__expand::(&mut caller_context, x); - let caller_scope = caller_context.into_scope(); - - let mut no_call_context = CubeContext::root(); - let x = no_call_context.create_local(Item::new(ElemType::as_elem())); - no_call_with_generics::__expand::(&mut no_call_context, x); - let no_call_scope = no_call_context.into_scope(); - - assert_eq!( - format!("{:?}", caller_scope.operations), - format!("{:?}", no_call_scope.operations) - ); - } -} diff --git a/crates/burn-cube/tests/frontend/generic_kernel.rs b/crates/burn-cube/tests/frontend/generic_kernel.rs deleted file mode 100644 index d2879c4086..0000000000 --- a/crates/burn-cube/tests/frontend/generic_kernel.rs +++ /dev/null @@ -1,64 +0,0 @@ -use burn_cube::{cube, frontend::Numeric}; - -#[cube] -pub fn generic_kernel(lhs: T) { - let _ = lhs + T::from_int(5); -} - -mod tests { - use burn_cube::{ - cpa, - frontend::{CubeContext, CubePrimitive, F32, I32}, - ir::{Item, Variable}, - }; - - use super::*; - - #[test] - fn cube_generic_float_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(F32::as_elem())); - - generic_kernel::__expand::(&mut context, lhs); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float()); - } - - #[test] - fn cube_generic_int_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(I32::as_elem())); - - generic_kernel::__expand::(&mut context, lhs); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int()); - } - - fn inline_macro_ref_float() -> String { - let mut context = CubeContext::root(); - let item = Item::new(F32::as_elem()); - let var = context.create_local(item); - - let mut scope = context.into_scope(); - let var: Variable = 
var.into(); - cpa!(scope, var = var + 5.0f32); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_int() -> String { - let mut context = CubeContext::root(); - let item = Item::new(I32::as_elem()); - let var = context.create_local(item); - - let mut scope = context.into_scope(); - let var: Variable = var.into(); - cpa!(scope, var = var + 5); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/if.rs b/crates/burn-cube/tests/frontend/if.rs deleted file mode 100644 index 895e11e02d..0000000000 --- a/crates/burn-cube/tests/frontend/if.rs +++ /dev/null @@ -1,151 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn if_greater(lhs: T) { - if lhs > T::from_int(0) { - let _ = lhs + T::from_int(4); - } -} - -#[cube] -pub fn if_greater_var(lhs: T) { - let x = lhs > T::from_int(0); - if x { - let _ = lhs + T::from_int(4); - } -} - -#[cube] -pub fn if_then_else(lhs: F) { - if lhs < F::from_int(0) { - let _ = lhs + F::from_int(4); - } else { - let _ = lhs - F::from_int(5); - } -} - -#[cube] -pub fn elsif(lhs: F) { - if lhs < F::new(0.) { - let _ = lhs + F::new(2.); - } else if lhs > F::new(0.) { - let _ = lhs + F::new(1.); - } else { - let _ = lhs + F::new(0.); - } -} - -mod tests { - use burn_cube::{ - cpa, - frontend::{CubeContext, CubePrimitive, F32}, - ir::{Elem, Item, Variable}, - }; - - use super::*; - - type ElemType = F32; - - #[test] - fn cube_if_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(ElemType::as_elem())); - - if_greater::__expand::(&mut context, lhs); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_if()); - } - - #[test] - fn cube_if_else_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(ElemType::as_elem())); - - if_then_else::__expand::(&mut context, lhs); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_if_else() - ); - } - - #[test] - fn cube_elsif_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(ElemType::as_elem())); - - elsif::__expand::(&mut context, lhs); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_elsif()); - } - - fn inline_macro_ref_if() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::new(Elem::Bool)); - let lhs: Variable = lhs.into(); - let y = scope.create_local(item); - - cpa!(scope, cond = lhs > 0f32); - cpa!(&mut scope, if(cond).then(|scope| { - cpa!(scope, y = lhs + 4.0f32); - })); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_if_else() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::new(Elem::Bool)); - let lhs: Variable = lhs.into(); - let y = scope.create_local(item); - - cpa!(scope, cond = lhs < 0f32); - cpa!(&mut scope, if(cond).then(|scope| { - cpa!(scope, y = lhs + 4.0f32); - }).else(|scope|{ - cpa!(scope, y = lhs - 5.0f32); - })); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_elsif() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local(item); - - let 
mut scope = context.into_scope(); - let cond1 = scope.create_local(Item::new(Elem::Bool)); - let lhs: Variable = lhs.into(); - let y = scope.create_local(item); - let cond2 = scope.create_local(Item::new(Elem::Bool)); - - cpa!(scope, cond1 = lhs < 0f32); - cpa!(&mut scope, if(cond1).then(|scope| { - cpa!(scope, y = lhs + 2.0f32); - }).else(|mut scope|{ - cpa!(scope, cond2 = lhs > 0f32); - cpa!(&mut scope, if(cond2).then(|scope| { - cpa!(scope, y = lhs + 1.0f32); - }).else(|scope|{ - cpa!(scope, y = lhs + 0.0f32); - })); - })); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/literal.rs b/crates/burn-cube/tests/frontend/literal.rs deleted file mode 100644 index d825b0c73c..0000000000 --- a/crates/burn-cube/tests/frontend/literal.rs +++ /dev/null @@ -1,57 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn literal(lhs: F) { - let _ = lhs + F::from_int(5); -} - -#[cube] -pub fn literal_float_no_decimals(lhs: F) { - let _ = lhs + F::new(5.); -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - ir::{Item, Variable}, - }; - - type ElemType = F32; - - #[test] - fn cube_literal_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(ElemType::as_elem())); - - literal::__expand::(&mut context, lhs); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); - } - - #[test] - fn cube_literal_float_no_decimal_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(ElemType::as_elem())); - - literal_float_no_decimals::__expand::(&mut context, lhs); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); - } - - fn inline_macro_ref() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local(item); - - let mut scope = context.into_scope(); - let lhs: Variable = lhs.into(); - cpa!(scope, lhs = lhs + 5.0f32); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/loop.rs b/crates/burn-cube/tests/frontend/loop.rs deleted file mode 100644 index 5c0d318c6b..0000000000 --- a/crates/burn-cube/tests/frontend/loop.rs +++ /dev/null @@ -1,102 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn while_not(lhs: I) { - while lhs != I::from_int(0) { - let _ = lhs % I::from_int(1); - } -} - -#[cube] -pub fn manual_loop_break(lhs: I) { - loop { - if lhs != I::from_int(0) { - break; - } - let _ = lhs % I::from_int(1); - } -} - -#[cube] -pub fn loop_with_return(lhs: I) { - loop { - if lhs != I::from_int(0) { - return; - } - let _ = lhs % I::from_int(1); - } -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - ir::{Branch, Elem, Item, Variable}, - }; - - type ElemType = I32; - - #[test] - fn cube_while_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(ElemType::as_elem())); - - while_not::__expand::(&mut context, lhs); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(false)); - } - - #[test] - fn cube_loop_break_test() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::new(ElemType::as_elem())); - - manual_loop_break::__expand::(&mut context, lhs); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(false)); - } - - #[test] - fn cube_loop_with_return_test() { - let mut context = CubeContext::root(); - 
- let lhs = context.create_local(Item::new(ElemType::as_elem())); - - loop_with_return::__expand::(&mut context, lhs); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(true)); - } - - fn inline_macro_ref(is_return: bool) -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let lhs = context.create_local(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::new(Elem::Bool)); - let lhs: Variable = lhs.into(); - let rhs = scope.create_local(item); - - cpa!( - &mut scope, - loop(|scope| { - cpa!(scope, cond = lhs != 0); - cpa!(scope, if(cond).then(|scope|{ - match is_return { - true => scope.register(Branch::Return), - false => scope.register(Branch::Break) - } - })); - - cpa!(scope, rhs = lhs % 1i32); - }) - ); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/mod.rs b/crates/burn-cube/tests/frontend/mod.rs deleted file mode 100644 index 2f0aa5ddad..0000000000 --- a/crates/burn-cube/tests/frontend/mod.rs +++ /dev/null @@ -1,23 +0,0 @@ -mod array; -mod assign; -mod cast_elem; -mod cast_kind; -mod comptime; -mod cube_trait; -mod for_loop; -mod function_call; -mod generic_kernel; -mod r#if; -mod literal; -mod r#loop; -mod module_import; -mod ops; -mod parenthesis; -mod redeclare; -mod reuse; -mod shared_memory; -mod r#struct; -mod tensor; -mod topology; -mod r#trait; -mod vectorization; diff --git a/crates/burn-cube/tests/frontend/module_import.rs b/crates/burn-cube/tests/frontend/module_import.rs deleted file mode 100644 index fdd3c4bcac..0000000000 --- a/crates/burn-cube/tests/frontend/module_import.rs +++ /dev/null @@ -1,49 +0,0 @@ -use burn_cube::prelude::*; - -mod elsewhere { - use super::*; - - #[cube] - pub fn my_func(x: F) -> F { - x * F::from_int(2) - } -} - -mod here { - use super::*; - - #[cube] - pub fn caller(x: F) { - let _ = x + elsewhere::my_func::(x); - } - - #[cube] - pub fn no_call_ref(x: F) { - let _ = x + x * F::from_int(2); - } -} - -mod tests { - use super::*; - use burn_cube::ir::Item; - - type ElemType = F32; - - #[test] - fn cube_call_equivalent_to_no_call_no_arg_test() { - let mut caller_context = CubeContext::root(); - let x = caller_context.create_local(Item::new(ElemType::as_elem())); - here::caller::__expand::(&mut caller_context, x); - let caller_scope = caller_context.into_scope(); - - let mut no_call_context = CubeContext::root(); - let x = no_call_context.create_local(Item::new(ElemType::as_elem())); - here::no_call_ref::__expand::(&mut no_call_context, x); - let no_call_scope = no_call_context.into_scope(); - - assert_eq!( - format!("{:?}", caller_scope.operations), - format!("{:?}", no_call_scope.operations) - ); - } -} diff --git a/crates/burn-cube/tests/frontend/ops.rs b/crates/burn-cube/tests/frontend/ops.rs deleted file mode 100644 index e53f664c18..0000000000 --- a/crates/burn-cube/tests/frontend/ops.rs +++ /dev/null @@ -1,427 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn add_op(a: T, b: T) -> T { - a + b -} - -#[cube] -pub fn sub_op(a: T, b: T) -> T { - a - b -} - -#[cube] -pub fn mul_op(a: T, b: T) -> T { - a * b -} - -#[cube] -pub fn div_op(a: T, b: T) -> T { - a / b -} - -#[cube] -pub fn abs_op(a: T) -> T { - T::abs(a) -} - -#[cube] -pub fn exp_op(a: F) -> F { - F::exp(a) -} - -#[cube] -pub fn log_op(a: F) -> F { - F::log(a) -} - -#[cube] -pub fn log1p_op(a: F) -> F { - F::log1p(a) -} - -#[cube] -pub fn cos_op(a: F) -> F { - F::cos(a) -} - -#[cube] -pub fn sin_op(a: F) 
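// Each one-line kernel in this file pins the IR of a single operator. The
// float-only intrinsics (exp, log, cos, sin, tanh, sqrt, erf, recip, ...) come
// from the Float trait, while arithmetic, comparisons, min/max and remainder
// come from Numeric. A combined sketch under the same bounds; `fused_op` is
// hypothetical and assumes the same #[cube] frontend:
#[cube]
pub fn fused_op<F: Float>(a: F, b: F) -> F {
    F::exp(a) + F::sqrt(b) * b
}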
-> F { - F::sin(a) -} - -#[cube] -pub fn tanh_op(a: F) -> F { - F::tanh(a) -} - -#[cube] -pub fn powf_op(a: F, b: F) -> F { - F::powf(a, b) -} - -#[cube] -pub fn sqrt_op(a: F) -> F { - F::sqrt(a) -} - -#[cube] -pub fn floor_op(a: F) -> F { - F::floor(a) -} - -#[cube] -pub fn ceil_op(a: F) -> F { - F::ceil(a) -} - -#[cube] -pub fn erf_op(a: F) -> F { - F::erf(a) -} - -#[cube] -pub fn recip_op(a: F) -> F { - F::recip(a) -} - -#[cube] -pub fn equal_op(a: T, b: T) -> bool { - a == b -} - -#[cube] -pub fn not_equal_op(a: T, b: T) -> bool { - a != b -} - -#[cube] -pub fn lower_op(a: T, b: T) -> bool { - a < b -} - -#[cube] -pub fn greater_op(a: T, b: T) -> bool { - a > b -} - -#[cube] -pub fn lower_equal_op(a: T, b: T) -> bool { - a <= b -} - -#[cube] -pub fn greater_equal_op(a: T, b: T) -> bool { - a >= b -} - -#[cube] -pub fn modulo_op(a: UInt, b: UInt) -> UInt { - a % b -} - -#[cube] -pub fn remainder_op(a: T, b: T) -> T { - T::rem(a, b) -} - -#[cube] -pub fn max_op(a: T, b: T) -> T { - T::max(a, b) -} - -#[cube] -pub fn min_op(a: T, b: T) -> T { - T::min(a, b) -} - -#[cube] -pub fn and_op(a: bool, b: bool) -> bool { - a && b -} - -#[cube] -pub fn or_op(a: bool, b: bool) -> bool { - a || b -} - -#[cube] -pub fn not_op(a: bool) -> bool { - !a -} - -#[cube] -pub fn bitand_op(a: UInt, b: UInt) -> UInt { - a & b -} - -#[cube] -pub fn bitxor_op(a: UInt, b: UInt) -> UInt { - a ^ b -} - -#[cube] -pub fn shl_op(a: UInt, b: UInt) -> UInt { - a << b -} - -#[cube] -pub fn shr_op(a: UInt, b: UInt) -> UInt { - a >> b -} - -#[cube] -pub fn add_assign_op(mut a: T, b: T) { - a += b; -} - -#[cube] -pub fn sub_assign_op(mut a: T, b: T) { - a -= b; -} - -#[cube] -pub fn mul_assign_op(mut a: T, b: T) { - a *= b; -} - -#[cube] -pub fn div_assign_op(mut a: T, b: T) { - a /= b; -} - -mod tests { - use super::*; - use burn_cube::ir::{Elem, FloatKind, Item}; - - macro_rules! binary_test { - ($test_name:ident, $op_expand:expr, $op_name:expr, $func:ident) => { - #[test] - fn $test_name() { - let mut context = CubeContext::root(); - let x = context.create_local(Item::new(Elem::Float(FloatKind::F32))); - let y = context.create_local(Item::new(Elem::Float(FloatKind::F32))); - - $op_expand(&mut context, x, y); - - assert_eq!( - format!("{:?}", context.into_scope().operations), - $func($op_name) - ); - } - }; - } - - macro_rules! unary_test { - ($test_name:ident, $op_expand:expr, $op_name:expr) => { - #[test] - fn $test_name() { - let mut context = CubeContext::root(); - let x = context.create_local(Item::new(Elem::Float(FloatKind::F32))); - - $op_expand(&mut context, x); - - assert_eq!( - format!("{:?}", context.into_scope().operations), - ref_ops_unary($op_name) - ); - } - }; - } - - macro_rules! binary_boolean_test { - ($test_name:ident, $op_expand:expr, $op_name:expr) => { - #[test] - fn $test_name() { - let mut context = CubeContext::root(); - let x = context.create_local(Item::new(Elem::Bool)); - let y = context.create_local(Item::new(Elem::Bool)); - - $op_expand(&mut context, x, y); - - assert_eq!( - format!("{:?}", context.into_scope().operations), - ref_ops_binary_boolean($op_name) - ); - } - }; - } - - macro_rules! 
binary_uint_test { - ($test_name:ident, $op_expand:expr, $op_name:expr) => { - #[test] - fn $test_name() { - let mut context = CubeContext::root(); - let x = context.create_local(Item::new(Elem::UInt)); - let y = context.create_local(Item::new(Elem::UInt)); - - $op_expand(&mut context, x, y); - - assert_eq!( - format!("{:?}", context.into_scope().operations), - ref_ops_binary_uint($op_name) - ); - } - }; - } - - binary_test!(cube_can_add, add_op::__expand::, "Add", ref_ops_binary); - binary_test!(cube_can_sub, sub_op::__expand::, "Sub", ref_ops_binary); - binary_test!(cube_can_mul, mul_op::__expand::, "Mul", ref_ops_binary); - binary_test!(cube_can_div, div_op::__expand::, "Div", ref_ops_binary); - unary_test!(cube_can_abs, abs_op::__expand::, "Abs"); - unary_test!(cube_can_exp, exp_op::__expand::, "Exp"); - unary_test!(cube_can_log, log_op::__expand::, "Log"); - unary_test!(cube_can_log1p, log1p_op::__expand::, "Log1p"); - unary_test!(cube_can_cos, cos_op::__expand::, "Cos"); - unary_test!(cube_can_sin, sin_op::__expand::, "Sin"); - unary_test!(cube_can_tanh, tanh_op::__expand::, "Tanh"); - binary_test!( - cube_can_powf, - powf_op::__expand::, - "Powf", - ref_ops_binary - ); - unary_test!(cube_can_sqrt, sqrt_op::__expand::, "Sqrt"); - unary_test!(cube_can_erf, erf_op::__expand::, "Erf"); - unary_test!(cube_can_recip, recip_op::__expand::, "Recip"); - unary_test!(cube_can_floor, floor_op::__expand::, "Floor"); - unary_test!(cube_can_ceil, ceil_op::__expand::, "Ceil"); - binary_test!(cube_can_eq, equal_op::__expand::, "Equal", ref_ops_cmp); - binary_test!( - cube_can_ne, - not_equal_op::__expand::, - "NotEqual", - ref_ops_cmp - ); - binary_test!(cube_can_lt, lower_op::__expand::, "Lower", ref_ops_cmp); - binary_test!( - cube_can_le, - lower_equal_op::__expand::, - "LowerEqual", - ref_ops_cmp - ); - binary_test!( - cube_can_ge, - greater_equal_op::__expand::, - "GreaterEqual", - ref_ops_cmp - ); - binary_test!( - cube_can_gt, - greater_op::__expand::, - "Greater", - ref_ops_cmp - ); - binary_test!(cube_can_max, max_op::__expand::, "Max", ref_ops_binary); - binary_test!(cube_can_min, min_op::__expand::, "Min", ref_ops_binary); - binary_test!( - cube_can_add_assign, - add_assign_op::__expand::, - "Add", - ref_ops_binary - ); - binary_test!( - cube_can_sub_assign, - sub_assign_op::__expand::, - "Sub", - ref_ops_binary - ); - binary_test!( - cube_can_mul_assign, - mul_assign_op::__expand::, - "Mul", - ref_ops_binary - ); - binary_test!( - cube_can_div_assign, - div_assign_op::__expand::, - "Div", - ref_ops_binary - ); - binary_boolean_test!(cube_can_and, and_op::__expand, "And"); - binary_boolean_test!(cube_can_or, or_op::__expand, "Or"); - binary_uint_test!(cube_can_bitand, bitand_op::__expand, "BitwiseAnd"); - binary_uint_test!(cube_can_bitxor, bitxor_op::__expand, "BitwiseXor"); - binary_uint_test!(cube_can_shl, shl_op::__expand, "ShiftLeft"); - binary_uint_test!(cube_can_shr, shr_op::__expand, "ShiftRight"); - binary_uint_test!(cube_can_mod, modulo_op::__expand, "Modulo"); - binary_test!( - cube_can_rem, - remainder_op::__expand::, - "Remainder", - ref_ops_binary - ); - - #[test] - fn cube_can_not() { - let mut context = CubeContext::root(); - let x = context.create_local(Item::new(Elem::Bool)); - - not_op::__expand(&mut context, x); - - assert_eq!( - format!("{:?}", context.into_scope().operations), - ref_ops_unary_boolean("Not") - ); - } - - fn ref_ops_binary(ops_name: &str) -> String { - ref_ops_template(ops_name, "Float(F32)", "Float(F32)", true) - } - - fn ref_ops_unary(ops_name: &str) 
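// binary_test!, unary_test! and the two macros around this point stamp out one
// #[test] per operator so the operator table below stays a flat list. The same
// code-generation trick in miniature, independent of burn-cube (`op_test` is
// illustrative only):
macro_rules! op_test {
    ($name:ident, $op:tt, $expected:expr) => {
        #[test]
        fn $name() {
            // The operator token is spliced directly into the expression.
            assert_eq!(2 $op 3, $expected);
        }
    };
}
op_test!(adds, +, 5);
op_test!(muls, *, 6);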
-> String { - ref_ops_template(ops_name, "Float(F32)", "Float(F32)", false) - } - - fn ref_ops_cmp(ops_name: &str) -> String { - ref_ops_template(ops_name, "Float(F32)", "Bool", true) - } - - fn ref_ops_unary_boolean(ops_name: &str) -> String { - ref_ops_template(ops_name, "Bool", "Bool", false) - } - - fn ref_ops_binary_boolean(ops_name: &str) -> String { - ref_ops_template(ops_name, "Bool", "Bool", true) - } - - fn ref_ops_binary_uint(ops_name: &str) -> String { - ref_ops_template(ops_name, "UInt", "UInt", true) - } - - fn ref_ops_template(ops_name: &str, in_type: &str, out_type: &str, binary: bool) -> String { - if binary { - let out_number = if in_type == out_type { 0 } else { 2 }; - format!( - "[Operator({ops_name}(BinaryOperator {{ \ - lhs: Local {{ id: 0, item: Item {{ \ - elem: {in_type}, \ - vectorization: 1 \ - }}, depth: 0 }}, \ - rhs: Local {{ id: 1, item: Item {{ \ - elem: {in_type}, \ - vectorization: 1 \ - }}, depth: 0 }}, \ - out: Local {{ id: {out_number}, item: Item {{ \ - elem: {out_type}, \ - vectorization: 1 \ - }}, depth: 0 }} \ - }}))]" - ) - } else { - format!( - "[Operator({ops_name}(UnaryOperator {{ \ - input: Local {{ id: 0, item: Item {{ \ - elem: {in_type}, \ - vectorization: 1 \ - }}, depth: 0 }}, \ - out: Local {{ id: 0, item: Item {{ \ - elem: {out_type}, \ - vectorization: 1 \ - }}, depth: 0 }} \ - }}))]" - ) - } - } -} diff --git a/crates/burn-cube/tests/frontend/parenthesis.rs b/crates/burn-cube/tests/frontend/parenthesis.rs deleted file mode 100644 index bb522cc1e1..0000000000 --- a/crates/burn-cube/tests/frontend/parenthesis.rs +++ /dev/null @@ -1,48 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn parenthesis(x: T, y: T, z: T) -> T { - x * (y + z) -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - ir::{Item, Variable}, - }; - - type ElemType = F32; - - #[test] - fn cube_parenthesis_priority_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - let y = context.create_local(Item::new(ElemType::as_elem())); - let z = context.create_local(Item::new(ElemType::as_elem())); - - parenthesis::__expand::(&mut context, x, y, z); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref()); - } - - fn inline_macro_ref() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local(item); - let y = context.create_local(item); - let z = context.create_local(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y: Variable = y.into(); - let z: Variable = z.into(); - - cpa!(scope, y = y + z); - cpa!(scope, x = x * y); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/redeclare.rs b/crates/burn-cube/tests/frontend/redeclare.rs deleted file mode 100644 index 4fb786721c..0000000000 --- a/crates/burn-cube/tests/frontend/redeclare.rs +++ /dev/null @@ -1,200 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn redeclare_same_scope(mut x: I) { - let i = I::new(1); - x += i; - let i = I::new(2); - x += i; -} - -#[cube] -pub fn redeclare_same_scope_other_type(mut x: I) -> F { - let i = I::new(1); - x += i; - let i = F::new(2.); - i + i -} - -#[cube] -pub fn redeclare_different_scope(mut x: I) { - let y = I::new(1); - x += y; - for _ in range(0u32, 2u32, Comptime::new(false)) { - let y = I::new(2); - x += y; - } -} - -#[cube] -pub fn redeclare_two_for_loops(mut x: UInt) { - for i in range(0u32, 2u32, 
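// redeclare.rs (starting above) checks how `let` shadowing maps to IR: a
// same-type shadow reuses the existing slot as a plain reassignment, while a
// different-type shadow allocates a fresh local, as the inline references below
// build by hand. The source-level pattern in ordinary Rust (`shadowing` is
// illustrative only):
fn shadowing(mut x: i32) -> f32 {
    let i = 1i32;
    x += i;
    let i = 2.0f32; // new binding, new type: needs a new slot in the IR
    i + i + x as f32
}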
Comptime::new(false)) { - x += i; - } - for i in range(0u32, 2u32, Comptime::new(false)) { - x += i; - x += i; - } -} - -mod tests { - use burn_cube::{ - cpa, - ir::{Item, Variable}, - }; - - use super::*; - - type ElemType = I32; - - #[test] - fn cube_redeclare_same_scope_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - - redeclare_same_scope::__expand::(&mut context, x); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_same_scope() - ); - } - - #[test] - fn cube_redeclare_same_scope_other_type_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - - redeclare_same_scope_other_type::__expand::(&mut context, x); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_same_scope_other_type() - ); - } - - #[test] - fn cube_redeclare_different_scope_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - - redeclare_different_scope::__expand::(&mut context, x); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_different() - ); - } - - #[test] - fn cube_redeclare_two_for_loops_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(UInt::as_elem())); - - redeclare_two_for_loops::__expand(&mut context, x); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_two_for_loops() - ); - } - - fn inline_macro_ref_same_scope() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - - let x = context.create_local(item); - let mut scope = context.into_scope(); - let x: Variable = x.into(); - - let i = scope.create_with_value(1, item); - cpa!(scope, x += i); - let value = Variable::ConstantScalar { - value: 2., - elem: item.elem(), - }; - cpa!(scope, i = value); - cpa!(scope, x += i); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_same_scope_other_type() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - - let x = context.create_local(item); - let mut scope = context.into_scope(); - let x: Variable = x.into(); - - let i = scope.create_with_value(1, item); - cpa!(scope, x += i); - let i = scope.create_with_value(2, Item::new(F32::as_elem())); - let y = scope.create_local(Item::new(F32::as_elem())); - cpa!(scope, y = i + i); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_different() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - - let x = context.create_local(item); - let end = 2u32; - let mut scope = context.into_scope(); - let x: Variable = x.into(); - - let y = scope.create_with_value(1, item); - cpa!(scope, x += y); - - cpa!( - &mut scope, - range(0u32, end, false).for_each(|_, scope| { - let value = Variable::ConstantScalar { - value: 2.into(), - elem: item.elem(), - }; - cpa!(scope, y = value); - cpa!(scope, x += y); - }) - ); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_two_for_loops() -> String { - let mut context = CubeContext::root(); - let item = Item::new(UInt::as_elem()); - - let x = context.create_local(item); - let end = 2u32; - let mut scope = context.into_scope(); - let x: Variable = x.into(); - - cpa!( - &mut scope, - range(0u32, end, 
false).for_each(|i, scope| { - cpa!(scope, x += i); - }) - ); - - cpa!( - &mut scope, - range(0u32, end, false).for_each(|i, scope| { - cpa!(scope, x += i); - cpa!(scope, x += i); - }) - ); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/reuse.rs b/crates/burn-cube/tests/frontend/reuse.rs deleted file mode 100644 index 53063f798a..0000000000 --- a/crates/burn-cube/tests/frontend/reuse.rs +++ /dev/null @@ -1,102 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -#[allow(clippy::assign_op_pattern)] -pub fn reuse(mut x: I) { - // a += b is more efficient than a = a + b - // Because the latter does not assume that a is the same in lhs and rhs - // Normally clippy should detect it - while x < I::from_int(10) { - x = x + I::from_int(1); - } -} - -#[cube] -pub fn reuse_incr(mut x: I) { - while x < I::from_int(10) { - x += I::from_int(1); - } -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - ir::{Branch, Elem, Item, Variable}, - }; - - type ElemType = I32; - #[test] - fn cube_reuse_assign_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - - reuse::__expand::(&mut context, x); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_assign()); - } - - #[test] - fn cube_reuse_incr_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - - reuse_incr::__expand::(&mut context, x); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_incr()); - } - - fn inline_macro_ref_assign() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::new(Elem::Bool)); - let x: Variable = x.into(); - let tmp = scope.create_local(item); - - cpa!( - &mut scope, - loop(|scope| { - cpa!(scope, cond = x < 10); - cpa!(scope, if(cond).then(|scope|{ - scope.register(Branch::Break); - })); - - cpa!(scope, tmp = x + 1); - cpa!(scope, x = tmp); - }) - ); - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_incr() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local(item); - - let mut scope = context.into_scope(); - let cond = scope.create_local(Item::new(Elem::Bool)); - let x: Variable = x.into(); - - cpa!( - &mut scope, - loop(|scope| { - cpa!(scope, cond = x < 10); - cpa!(scope, if(cond).then(|scope|{ - scope.register(Branch::Break); - })); - - cpa!(scope, x = x + 1); - }) - ); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/shared_memory.rs b/crates/burn-cube/tests/frontend/shared_memory.rs deleted file mode 100644 index 5018630cc1..0000000000 --- a/crates/burn-cube/tests/frontend/shared_memory.rs +++ /dev/null @@ -1,49 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn shared_memory_read_write(sm_size: Comptime) { - let mut shared = SharedMemory::::new(sm_size); - shared[0] = T::from_int(3); - let _ = shared[0]; -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - ir::{Item, Variable}, - }; - - type ElemType = F32; - - #[test] - fn cube_support_shared_memory() { - let mut context = CubeContext::root(); - - shared_memory_read_write::__expand::(&mut context, 512); - assert_eq!( - format!("{:?}", context.into_scope().operations), - inline_macro_ref() - ); - } - 
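// reuse.rs (deleted just above) documents why `x += b` is preferred over
// `x = x + b` in cube code: the explicit reassignment cannot assume that lhs
// and out are the same variable, so it emits an extra temporary plus an assign.
// The two instruction streams side by side, as a toy IR (illustrative types):
#[allow(dead_code)]
#[derive(Debug)]
enum Stmt {
    Add { lhs: u32, rhs: u32, out: u32 },
    Assign { from: u32, to: u32 },
}
// x = x + 1  =>  [Add { lhs: X, rhs: ONE, out: TMP }, Assign { from: TMP, to: X }]
// x += 1     =>  [Add { lhs: X, rhs: ONE, out: X }]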
- fn inline_macro_ref() -> String { - let context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - let var = scope.create_local(item); - let pos: Variable = 0u32.into(); - - // Create - let shared = scope.create_shared(item, 512); - - // Write - cpa!(scope, shared[pos] = 3.0_f32); - - // Read - cpa!(scope, var = shared[pos]); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/struct.rs b/crates/burn-cube/tests/frontend/struct.rs deleted file mode 100644 index d5539dc070..0000000000 --- a/crates/burn-cube/tests/frontend/struct.rs +++ /dev/null @@ -1,159 +0,0 @@ -use burn_cube::prelude::*; - -#[derive(CubeType)] -pub struct State<T: Numeric> { - first: T, - second: T, -} - -#[cube] -pub fn state_receiver_with_reuse<T: Numeric>(state: State<T>) -> T { - let x = state.first + state.second; - state.second + x + state.first -} - -#[cube] -pub fn attribute_modifier_reuse_field<T: Numeric>(mut state: State<T>) -> T { - state.first = T::from_int(4); - state.first -} - -#[cube] -pub fn attribute_modifier_reuse_struct<T: Numeric>(mut state: State<T>) -> State<T> { - state.first = T::from_int(4); - state -} - -#[cube] -fn creator<T: Numeric>(x: T, second: T) -> State<T> { - let mut state = State::<T> { first: x, second }; - state.second = state.first; - - state -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - ir::{Item, Variable}, - }; - - type ElemType = F32; - - #[test] - fn cube_new_struct_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - let y = context.create_local(Item::new(ElemType::as_elem())); - - creator::__expand::<ElemType>(&mut context, x, y); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - creator_inline_macro_ref() - ); - } - - #[test] - fn cube_struct_as_arg_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - let y = context.create_local(Item::new(ElemType::as_elem())); - - let expanded_state = StateExpand { - first: x, - second: y, - }; - state_receiver_with_reuse::__expand::<ElemType>(&mut context, expanded_state); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - receive_state_with_reuse_inline_macro_ref() - ); - } - - #[test] - fn cube_struct_assign_to_field_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - let y = context.create_local(Item::new(ElemType::as_elem())); - - let expanded_state = StateExpand { - first: x, - second: y, - }; - attribute_modifier_reuse_field::__expand::<ElemType>(&mut context, expanded_state); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - field_modifier_inline_macro_ref() - ); - } - - #[test] - fn cube_struct_assign_to_field_reuse_struct_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - let y = context.create_local(Item::new(ElemType::as_elem())); - - let expanded_state = StateExpand { - first: x, - second: y, - }; - attribute_modifier_reuse_struct::__expand::<ElemType>(&mut context, expanded_state); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - field_modifier_inline_macro_ref() - ); - } - - fn creator_inline_macro_ref() -> String { - let context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - let x = scope.create_local(item); - let y =
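// The StateExpand literal used in these tests is generated by #[derive(CubeType)]
// on State: an expansion-time twin with the same field names, each holding the
// expanded handle rather than a runtime value. A guess at the generated shape
// (the derive's real output is not shown anywhere in this diff):
pub struct StateExpand<T: CubeType> {
    first: <T as CubeType>::ExpandType,
    second: <T as CubeType>::ExpandType,
}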
scope.create_local(item); - cpa!(scope, y = x); - - format!("{:?}", scope.operations) - } - - fn field_modifier_inline_macro_ref() -> String { - let context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - - let mut scope = context.into_scope(); - scope.create_with_value(4, item); - - format!("{:?}", scope.operations) - } - - fn receive_state_with_reuse_inline_macro_ref() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local(item); - let y = context.create_local(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y: Variable = y.into(); - let z = scope.create_local(item); - - cpa!(scope, z = x + y); - cpa!(scope, z = y + z); - cpa!(scope, z = z + x); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/tensor.rs b/crates/burn-cube/tests/frontend/tensor.rs deleted file mode 100644 index 2d27c3ad76..0000000000 --- a/crates/burn-cube/tests/frontend/tensor.rs +++ /dev/null @@ -1,48 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn kernel(input: &Tensor) { - let _shape = input.shape(1); - let _stride = input.stride(1); - let _length = input.len(); -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - ir::{Item, Variable}, - }; - - type ElemType = F32; - - #[test] - fn cube_support_tensor_metadata() { - let mut context = CubeContext::root(); - let input = context.input(0, Item::new(ElemType::as_elem())); - - kernel::__expand::(&mut context, input.into()); - assert_eq!( - format!("{:?}", context.into_scope().operations), - inline_macro_ref() - ); - } - - fn inline_macro_ref() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let input = context.input(0, item); - - let mut scope = context.into_scope(); - let input: Variable = input.into(); - let x = scope.create_local(Item::new(UInt::as_elem())); - let y = scope.create_local(Item::new(UInt::as_elem())); - let z = scope.create_local(Item::new(UInt::as_elem())); - - cpa!(&mut scope, x = shape(input, 1u32)); - cpa!(&mut scope, y = stride(input, 1u32)); - cpa!(&mut scope, z = len(input)); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/topology.rs b/crates/burn-cube/tests/frontend/topology.rs deleted file mode 100644 index 9a65482493..0000000000 --- a/crates/burn-cube/tests/frontend/topology.rs +++ /dev/null @@ -1,46 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn topology_kernel(input: Tensor) { - let x = ABSOLUTE_POS + UInt::new(4); - let _ = input[x]; -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - ir::{Elem, Item, Variable}, - }; - - type ElemType = F32; - - #[test] - fn cube_support_topology() { - let mut context = CubeContext::root(); - let input = context.input(0, Item::new(ElemType::as_elem())); - - topology_kernel::__expand::(&mut context, input.into()); - assert_eq!( - format!("{:?}", context.into_scope().operations), - inline_macro_ref() - ); - } - - fn inline_macro_ref() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let input = context.input(0, item); - - let mut scope = context.into_scope(); - let input: Variable = input.into(); - let x = scope.create_local(Item::new(Elem::UInt)); - let y = scope.create_local(item); - - let id = Variable::AbsolutePos; - cpa!(&mut scope, x = id + 4u32); - cpa!(&mut scope, y = input[x]); - - format!("{:?}", scope.operations) - } -} diff --git 
a/crates/burn-cube/tests/frontend/trait.rs b/crates/burn-cube/tests/frontend/trait.rs deleted file mode 100644 index b85c43c21a..0000000000 --- a/crates/burn-cube/tests/frontend/trait.rs +++ /dev/null @@ -1,177 +0,0 @@ -use burn_cube::prelude::*; - -/// Traits used in Cube kernels must expose an _expand variant -/// for all their methods. However, one does not need to provide its -/// implementation, see examples below. -#[cube] -pub trait Strategy { - fn operation(input_1: T, input_2: T) -> T; -} - -struct AddStrategy; - -#[cube] -/// The actual implementation of AddStrategy's operation -/// Automatically generated an _expand variant -pub fn add_strategy_operation(input_1: T, input_2: T) -> T { - input_1 + input_2 -} - -#[cube] -impl Strategy for AddStrategy { - fn operation(input_1: T, input_2: T) -> T { - add_strategy_operation::(input_1, input_2) - } -} - -struct SubStrategy; - -#[cube] -impl Strategy for SubStrategy { - fn operation(input_1: T, input_2: T) -> T { - input_1 - input_2 - } -} - -#[cube] -pub fn with_strategy_trait, T: Numeric>(x: T, y: T) -> T { - S::operation(x, y) -} - -#[cube] -pub fn two_strategy_traits, S2: Strategy, F: Float>(x: F, y: F) -> F { - let z = S1::operation(x, y); - S2::operation(z, y) -} - -pub trait MethodTypedStrategy { - fn operation(input_1: T, input_2: T) -> T; - fn __expand_operation( - _context: &mut CubeContext, - input_1: ::ExpandType, - input_2: ::ExpandType, - ) -> ::ExpandType; -} - -impl MethodTypedStrategy for AddStrategy { - fn operation(input_1: T, input_2: T) -> T { - add_strategy_operation(input_1, input_2) - } - - fn __expand_operation( - context: &mut CubeContext, - input_1: ::ExpandType, - input_2: ::ExpandType, - ) -> ::ExpandType { - add_strategy_operation::__expand::(context, input_1, input_2) - } -} - -#[cube] -pub fn with_trait_generic_method(x: T, y: T) -> T { - S::operation::(x, y) -} - -mod tests { - use super::*; - use burn_cube::{ - cpa, - ir::{Item, Variable}, - }; - - type ElemType = F32; - #[test] - fn cube_strategy_trait_add_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - let y = context.create_local(Item::new(ElemType::as_elem())); - - with_strategy_trait::__expand::(&mut context, x, y); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_one(true) - ); - } - - #[test] - fn cube_strategy_trait_sub_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - let y = context.create_local(Item::new(ElemType::as_elem())); - - with_strategy_trait::__expand::(&mut context, x, y); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_ref_one(false) - ); - } - - #[test] - fn cube_two_strategy_traits_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - let y = context.create_local(Item::new(ElemType::as_elem())); - - two_strategy_traits::__expand::(&mut context, x, y); - let scope = context.into_scope(); - - assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_two()); - } - - #[test] - fn cube_trait_generic_method_test() { - let mut context = CubeContext::root(); - - let x = context.create_local(Item::new(ElemType::as_elem())); - let y = context.create_local(Item::new(ElemType::as_elem())); - - with_trait_generic_method::__expand::(&mut context, x, y); - let scope = context.into_scope(); - - assert_eq!( - format!("{:?}", 
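// The Strategy trait in this file is resolved entirely at expansion time: which
// operation lands in the IR is decided by the type parameter, i.e. ordinary
// static dispatch, which is why the add and sub tests compare against different
// references. The same mechanism stripped of the #[cube] machinery
// (`Strategy2`/`Add2` are illustrative names only):
trait Strategy2<T> {
    fn operation(a: T, b: T) -> T;
}
struct Add2;
impl Strategy2<i32> for Add2 {
    fn operation(a: i32, b: i32) -> i32 {
        a + b
    }
}
fn with_strategy<S: Strategy2<i32>>(a: i32, b: i32) -> i32 {
    S::operation(a, b) // monomorphized per concrete S; no vtable, no runtime branch
}
// with_strategy::<Add2>(1, 2) == 3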
scope.operations), - inline_macro_ref_one(true) - ); - } - - fn inline_macro_ref_one(is_add_strategy: bool) -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local(item); - let y = context.create_local(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y: Variable = y.into(); - - match is_add_strategy { - true => cpa!(scope, x = x + y), - false => cpa!(scope, x = x - y), - } - - format!("{:?}", scope.operations) - } - - fn inline_macro_ref_two() -> String { - let mut context = CubeContext::root(); - let item = Item::new(ElemType::as_elem()); - let x = context.create_local(item); - let y = context.create_local(item); - - let mut scope = context.into_scope(); - let x: Variable = x.into(); - let y: Variable = y.into(); - - cpa!(scope, x = x - y); - cpa!(scope, x = x + y); - - format!("{:?}", scope.operations) - } -} diff --git a/crates/burn-cube/tests/frontend/vectorization.rs b/crates/burn-cube/tests/frontend/vectorization.rs deleted file mode 100644 index 4dd0a8dcf7..0000000000 --- a/crates/burn-cube/tests/frontend/vectorization.rs +++ /dev/null @@ -1,65 +0,0 @@ -use burn_cube::prelude::*; - -#[cube] -pub fn vectorization_binary(lhs: T) { - let _ = lhs + T::from_vec([4, 5]); -} - -#[cube] -pub fn vectorization_cmp(rhs: T) { - let _ = T::from_vec([4, 5]) > rhs; -} - -mod tests { - use super::*; - use burn_cube::ir::Item; - - type ElemType = F32; - - #[test] - fn cube_vectorization_binary_op_with_same_scheme_does_not_fail() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2)); - - vectorization_binary::__expand::(&mut context, lhs); - } - - #[test] - #[should_panic] - fn cube_vectorization_binary_op_with_different_scheme_fails() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4)); - - vectorization_binary::__expand::(&mut context, lhs); - } - - #[test] - fn cube_vectorization_cmp_op_with_same_scheme_does_not_fail() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2)); - - vectorization_cmp::__expand::(&mut context, lhs); - } - - #[test] - #[should_panic] - fn cube_vectorization_cmp_op_with_different_scheme_fails() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4)); - - vectorization_cmp::__expand::(&mut context, lhs); - } - - #[test] - fn cube_vectorization_can_be_broadcasted() { - let mut context = CubeContext::root(); - - let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 1)); - - vectorization_cmp::__expand::(&mut context, lhs); - } -} diff --git a/crates/burn-cube/tests/mod.rs b/crates/burn-cube/tests/mod.rs deleted file mode 100644 index 40398e64cb..0000000000 --- a/crates/burn-cube/tests/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod frontend; - -#[test] -fn compile_fail_tests() { - let t = trybuild::TestCases::new(); - t.compile_fail("tests/error/*.rs"); -} diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml index 766094108b..420831b2d0 100644 --- a/crates/burn-cuda/Cargo.toml +++ b/crates/burn-cuda/Cargo.toml @@ -18,16 +18,13 @@ doc = ["burn-jit/doc"] std = ["burn-jit/std"] [dependencies] +cubecl = { workspace = true, features = ["cuda"] } burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false } -burn-compute = { path = "../burn-compute", version = "0.14.0", 
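// vectorization.rs (deleted above) pins three rules for mixing vectorization
// factors: equal factors pass, a factor of 1 broadcasts to any factor, and any
// other mismatch panics during expansion. The rule as a standalone function
// (`merged_vectorization` is a toy, not the library's implementation):
fn merged_vectorization(lhs: u8, rhs: u8) -> u8 {
    match (lhs, rhs) {
        (a, b) if a == b => a,
        (1, b) => b, // a scalar broadcasts into the vectorized side
        (a, 1) => a,
        _ => panic!("vectorization schemes differ"),
    }
}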
default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.14.0" } -burn-common = { path = "../burn-common", version = "0.14.0" } -burn-cube = { path = "../burn-cube", version = "0.14.0" } +burn-tensor = { path = "../burn-tensor", version = "0.14.0", features = ["cubecl-cuda"] } burn-fusion = { path = "../burn-fusion", version = "0.14.0", optional = true } half = { workspace = true } bytemuck = { workspace = true } -cudarc = { version = "0.11.7", features = ["cuda-12030"] } log = { workspace = true } derive-new = { workspace = true } @@ -36,9 +33,6 @@ derive-new = { workspace = true } burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false, features = [ "export_tests", ] } -burn-cube = { path = "../burn-cube", version = "0.14.0", features = [ - "export_tests", -] } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-cuda/src/compiler/base.rs b/crates/burn-cuda/src/compiler/base.rs deleted file mode 100644 index d0f2a05534..0000000000 --- a/crates/burn-cuda/src/compiler/base.rs +++ /dev/null @@ -1,554 +0,0 @@ -use burn_cube::{ir as gpu, Compiler}; - -use super::{Instruction, WarpInstruction}; - -#[allow(clippy::too_many_arguments)] -#[derive(new, Clone, Debug, Default)] -pub struct CudaCompiler { - shape: bool, - stride: bool, - num_inputs: usize, - num_outputs: usize, - shared_memories: Vec, - local_arrays: Vec, - id: bool, - rank: bool, - invocation_index: bool, - global_invocation_id: (bool, bool, bool), - wrap_size_checked: bool, - wmma: bool, -} - -impl Compiler for CudaCompiler { - type Representation = super::ComputeShader; - - fn compile(shader: burn_cube::ir::KernelDefinition) -> Self::Representation { - let compiler = Self::default(); - compiler.compile_shader(shader) - } - - fn elem_size(elem: gpu::Elem) -> usize { - Self::compile_elem(elem).size() - } - - fn max_shared_memory_size() -> usize { - // TODO: Find out this value. - usize::MAX - } -} - -impl CudaCompiler { - fn compile_shader(mut self, mut value: gpu::KernelDefinition) -> super::ComputeShader { - self.num_inputs = value.inputs.len(); - self.num_outputs = value.outputs.len(); - - let instructions = self.compile_scope(&mut value.body); - let body = super::Body { - instructions, - stride: true, - shape: true, - shared_memories: self.shared_memories, - local_arrays: self.local_arrays, - rank: self.rank, - id: self.id, - invocation_index: self.invocation_index, - global_invocation_id: self.global_invocation_id, - wrap_size_checked: self.wrap_size_checked, - }; - - super::ComputeShader { - inputs: value - .inputs - .into_iter() - .map(Self::compile_binding) - .collect(), - outputs: value - .outputs - .into_iter() - .map(Self::compile_binding) - .collect(), - named: value - .named - .into_iter() - .map(|(name, binding)| (name, Self::compile_binding(binding))) - .collect(), - cube_dim: value.cube_dim, - body, - wmma_activated: self.wmma, - } - } - - fn compile_scope(&mut self, value: &mut gpu::Scope) -> Vec { - let mut instructions = Vec::new(); - let processing = value.process(); - - for var in processing.variables { - if let gpu::Variable::Slice { .. 
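// CudaCompiler is stateful on purpose: while walking the IR it flips booleans
// (id, rank, invocation_index, shape, stride, wmma, ...) so the generated CUDA
// source only declares the builtins and helpers the kernel actually touches
// (see compile_variable further down). The pattern in isolation (toy names):
#[derive(Default)]
struct UsageFlags {
    absolute_pos: bool,
    rank: bool,
}
fn note_variable(var: &str, flags: &mut UsageFlags) {
    match var {
        "AbsolutePos" => flags.absolute_pos = true, // preamble must define the id
        "Rank" => flags.rank = true,                // preamble must define the rank
        _ => {}
    }
}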
} = var { - continue; - } - instructions.push(Instruction::DeclareVariable { - var: self.compile_variable(var), - }); - } - - processing - .operations - .into_iter() - .for_each(|op| self.compile_operation(&mut instructions, op, value)); - - instructions - } - - fn compile_operation( - &mut self, - instructions: &mut Vec, - operation: gpu::Operation, - scope: &mut gpu::Scope, - ) { - match operation { - gpu::Operation::Operator(op) => instructions.push(self.compile_instruction(op)), - gpu::Operation::Procedure(proc) => self.compile_procedure(instructions, proc, scope), - gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op)), - gpu::Operation::Branch(val) => self.compile_branch(instructions, val), - gpu::Operation::Synchronization(val) => match val { - gpu::Synchronization::SyncUnits => instructions.push(Instruction::SyncThreads), - }, - gpu::Operation::Subcube(op) => { - self.wrap_size_checked = true; - match op { - gpu::Subcube::Sum(op) => { - instructions.push(Instruction::Wrap(WarpInstruction::ReduceSum { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - })) - } - gpu::Subcube::Prod(op) => { - instructions.push(Instruction::Wrap(WarpInstruction::ReduceProd { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - })) - } - gpu::Subcube::Max(op) => { - instructions.push(Instruction::Wrap(WarpInstruction::ReduceMax { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - })) - } - - gpu::Subcube::Min(op) => { - instructions.push(Instruction::Wrap(WarpInstruction::ReduceMin { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - })) - } - - _ => todo!(), - } - } - gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma)), - } - } - - fn compile_cmma(&mut self, cmma: gpu::CoopMma) -> Instruction { - match cmma { - gpu::CoopMma::Fill { mat: frag, value } => { - Instruction::Wmma(super::WmmaInstruction::Fill { - frag: self.compile_variable(frag), - value: self.compile_variable(value), - }) - } - gpu::CoopMma::Load { mat, value, stride } => { - Instruction::Wmma(super::WmmaInstruction::Load { - frag: self.compile_variable(mat), - value: self.compile_variable(value), - stride: self.compile_variable(stride), - }) - } - gpu::CoopMma::Execute { - mat_a, - mat_b, - mat_c, - mat_d, - } => Instruction::Wmma(super::WmmaInstruction::Execute { - frag_a: self.compile_variable(mat_a), - frag_b: self.compile_variable(mat_b), - frag_c: self.compile_variable(mat_c), - frag_d: self.compile_variable(mat_d), - }), - gpu::CoopMma::Store { - output, - mat, - stride, - layout, - } => Instruction::Wmma(super::WmmaInstruction::Store { - output: self.compile_variable(output), - frag: self.compile_variable(mat), - stride: self.compile_variable(stride), - layout: Self::compile_matrix_layout(layout) - .expect("Layout required for store instruction"), - }), - } - } - - fn compile_metadata(&mut self, metadata: gpu::Metadata) -> Instruction { - match metadata { - gpu::Metadata::Stride { dim, var, out } => { - self.stride = true; - let position = match var { - gpu::Variable::GlobalInputArray { id, .. } => id as usize, - gpu::Variable::GlobalOutputArray { id, .. 
} => self.num_inputs + id as usize, - _ => panic!("Only Input and Output have a stride, got: {:?}", var), - }; - Instruction::Stride { - dim: self.compile_variable(dim), - position, - out: self.compile_variable(out), - } - } - gpu::Metadata::Shape { dim, var, out } => { - self.shape = true; - let position = match var { - gpu::Variable::GlobalInputArray { id, .. } => id as usize, - gpu::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize, - _ => panic!("Only Input and Output have a shape, got {:?}", var), - }; - Instruction::Shape { - dim: self.compile_variable(dim), - position, - out: self.compile_variable(out), - } - } - gpu::Metadata::Length { var, out } => { - let input = self.compile_variable(var); - let out = self.compile_variable(out); - - match input { - super::Variable::Slice { .. } => super::Instruction::SliceLength { input, out }, - _ => super::Instruction::Length { - input, - out, - num_inputs: self.num_inputs, - num_outputs: self.num_outputs, - }, - } - } - } - } - - fn compile_branch(&mut self, instructions: &mut Vec, branch: gpu::Branch) { - match branch { - gpu::Branch::If(mut op) => instructions.push(Instruction::If { - cond: self.compile_variable(op.cond), - instructions: self.compile_scope(&mut op.scope), - }), - gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse { - cond: self.compile_variable(op.cond), - instructions_if: self.compile_scope(&mut op.scope_if), - instructions_else: self.compile_scope(&mut op.scope_else), - }), - gpu::Branch::Return => instructions.push(Instruction::Return), - gpu::Branch::Break => instructions.push(Instruction::Break), - gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop { - i: self.compile_variable(range_loop.i), - start: self.compile_variable(range_loop.start), - end: self.compile_variable(range_loop.end), - instructions: self.compile_scope(&mut range_loop.scope), - }), - gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop { - instructions: self.compile_scope(&mut op.scope), - }), - }; - } - fn compile_procedure( - &mut self, - instructions: &mut Vec, - proc: gpu::Procedure, - scope: &mut gpu::Scope, - ) { - let mut compile = |scope: &mut gpu::Scope| { - instructions.extend(self.compile_scope(scope)); - }; - - match proc { - gpu::Procedure::ReadGlobalWithLayout(proc) => { - proc.expand(scope); - compile(scope); - } - gpu::Procedure::ReadGlobal(proc) => { - proc.expand(scope); - compile(scope); - } - gpu::Procedure::WriteGlobal(proc) => { - proc.expand(scope); - compile(scope); - } - gpu::Procedure::ConditionalAssign(proc) => { - proc.expand(scope); - compile(scope); - } - gpu::Procedure::CheckedIndex(proc) => { - proc.expand(scope); - compile(scope); - } - gpu::Procedure::CheckedIndexAssign(proc) => { - proc.expand(scope); - compile(scope); - } - gpu::Procedure::IndexOffsetGlobalWithLayout(proc) => { - proc.expand(scope); - compile(scope); - } - } - } - - fn compile_instruction(&mut self, value: gpu::Operator) -> Instruction { - match value { - gpu::Operator::Add(op) => Instruction::Add(self.compile_binary(op)), - gpu::Operator::Mul(op) => Instruction::Mul(self.compile_binary(op)), - gpu::Operator::Div(op) => Instruction::Div(self.compile_binary(op)), - gpu::Operator::Sub(op) => Instruction::Sub(self.compile_binary(op)), - gpu::Operator::Assign(op) => Instruction::Assign(self.compile_unary(op)), - gpu::Operator::Slice(op) => Instruction::Slice { - input: self.compile_variable(op.input), - start: self.compile_variable(op.start), - end: 
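// Stride and Shape metadata are looked up by buffer position: input arrays
// occupy positions 0..num_inputs and output arrays follow at num_inputs + id,
// which is what both match arms around this point compute. As a standalone
// function (illustrative only):
fn metadata_position(is_input: bool, id: usize, num_inputs: usize) -> usize {
    if is_input {
        id
    } else {
        num_inputs + id
    }
}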
self.compile_variable(op.end), - out: self.compile_variable(op.out), - }, - gpu::Operator::Index(op) => Instruction::Index(self.compile_binary(op)), - gpu::Operator::UncheckedIndex(op) => Instruction::Index(self.compile_binary(op)), - gpu::Operator::IndexAssign(op) => Instruction::IndexAssign(self.compile_binary(op)), - gpu::Operator::UncheckedIndexAssign(op) => { - Instruction::IndexAssign(self.compile_binary(op)) - } - gpu::Operator::Modulo(op) => Instruction::Modulo(self.compile_binary(op)), - gpu::Operator::Equal(op) => Instruction::Equal(self.compile_binary(op)), - gpu::Operator::Lower(op) => Instruction::Lower(self.compile_binary(op)), - gpu::Operator::Greater(op) => Instruction::Greater(self.compile_binary(op)), - gpu::Operator::LowerEqual(op) => Instruction::LowerEqual(self.compile_binary(op)), - gpu::Operator::GreaterEqual(op) => Instruction::GreaterEqual(self.compile_binary(op)), - gpu::Operator::Abs(op) => Instruction::Abs(self.compile_unary(op)), - gpu::Operator::Exp(op) => Instruction::Exp(self.compile_unary(op)), - gpu::Operator::Log(op) => Instruction::Log(self.compile_unary(op)), - gpu::Operator::Log1p(op) => Instruction::Log1p(self.compile_unary(op)), - gpu::Operator::Cos(op) => Instruction::Cos(self.compile_unary(op)), - gpu::Operator::Sin(op) => Instruction::Sin(self.compile_unary(op)), - gpu::Operator::Tanh(op) => Instruction::Tanh(self.compile_unary(op)), - gpu::Operator::Powf(op) => Instruction::Powf(self.compile_binary(op)), - gpu::Operator::Sqrt(op) => Instruction::Sqrt(self.compile_unary(op)), - gpu::Operator::Erf(op) => Instruction::Erf(self.compile_unary(op)), - gpu::Operator::And(op) => Instruction::And(self.compile_binary(op)), - gpu::Operator::Or(op) => Instruction::Or(self.compile_binary(op)), - gpu::Operator::Not(op) => Instruction::Not(self.compile_unary(op)), - gpu::Operator::Max(op) => Instruction::Max(self.compile_binary(op)), - gpu::Operator::Min(op) => Instruction::Min(self.compile_binary(op)), - gpu::Operator::NotEqual(op) => Instruction::NotEqual(self.compile_binary(op)), - gpu::Operator::BitwiseAnd(op) => Instruction::BitwiseAnd(self.compile_binary(op)), - gpu::Operator::BitwiseXor(op) => Instruction::BitwiseXor(self.compile_binary(op)), - gpu::Operator::ShiftLeft(op) => Instruction::ShiftLeft(self.compile_binary(op)), - gpu::Operator::ShiftRight(op) => Instruction::ShiftRight(self.compile_binary(op)), - gpu::Operator::Clamp(op) => Instruction::Clamp { - input: self.compile_variable(op.input), - min_value: self.compile_variable(op.min_value), - max_value: self.compile_variable(op.max_value), - out: self.compile_variable(op.out), - }, - gpu::Operator::Recip(op) => Instruction::Div(super::BinaryInstruction { - lhs: super::Variable::ConstantScalar( - 1.0, - Self::compile_elem(op.input.item().elem()), - ), - rhs: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }), - gpu::Operator::Floor(op) => Instruction::Floor(self.compile_unary(op)), - gpu::Operator::Ceil(op) => Instruction::Ceil(self.compile_unary(op)), - gpu::Operator::Remainder(_op) => todo!(), - gpu::Operator::Fma(op) => Instruction::Fma { - a: self.compile_variable(op.a), - b: self.compile_variable(op.b), - c: self.compile_variable(op.c), - out: self.compile_variable(op.out), - }, - } - } - - fn compile_binary(&mut self, value: gpu::BinaryOperator) -> super::BinaryInstruction { - super::BinaryInstruction { - lhs: self.compile_variable(value.lhs), - rhs: self.compile_variable(value.rhs), - out: self.compile_variable(value.out), - } - } - - fn compile_unary(&mut self, 
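// Not every IR operator has a one-to-one CUDA instruction: Recip, handled just
// above, is rewritten as a Div whose lhs is the constant 1.0, and Remainder is
// still a todo!(). What the rewritten Recip computes (illustrative only):
fn recip_lowering(x: f32) -> f32 {
    1.0 / x // printed as a plain division with a constant left-hand side
}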
value: gpu::UnaryOperator) -> super::UnaryInstruction { - super::UnaryInstruction { - input: self.compile_variable(value.input), - out: self.compile_variable(value.out), - } - } - - fn compile_variable(&mut self, value: gpu::Variable) -> super::Variable { - match value { - gpu::Variable::GlobalInputArray { id, item } => { - super::Variable::GlobalInputArray(id, Self::compile_item(item)) - } - gpu::Variable::GlobalScalar { id, elem } => { - super::Variable::GlobalScalar(id, Self::compile_elem(elem), elem) - } - gpu::Variable::Local { id, item, depth } => super::Variable::Local { - id, - item: Self::compile_item(item), - depth, - }, - gpu::Variable::Slice { id, item, depth } => super::Variable::Slice { - id, - item: Self::compile_item(item), - depth, - }, - gpu::Variable::LocalScalar { id, elem, depth } => super::Variable::LocalScalar { - id, - elem: Self::compile_elem(elem), - depth, - }, - gpu::Variable::GlobalOutputArray { id, item } => { - super::Variable::GlobalOutputArray(id, Self::compile_item(item)) - } - gpu::Variable::ConstantScalar { value, elem } => { - super::Variable::ConstantScalar(value, Self::compile_elem(elem)) - } - gpu::Variable::SharedMemory { id, item, length } => { - let item = Self::compile_item(item); - if !self.shared_memories.iter().any(|s| s.index == id) { - self.shared_memories - .push(super::SharedMemory::new(id, item, length)); - } - super::Variable::SharedMemory(id, item, length) - } - gpu::Variable::AbsolutePos => { - self.id = true; - super::Variable::Id - } - gpu::Variable::Rank => { - self.rank = true; - super::Variable::Rank - } - gpu::Variable::UnitPos => { - self.invocation_index = true; - super::Variable::LocalInvocationIndex - } - gpu::Variable::UnitPosX => super::Variable::LocalInvocationIdX, - gpu::Variable::UnitPosY => super::Variable::LocalInvocationIdY, - gpu::Variable::UnitPosZ => super::Variable::LocalInvocationIdZ, - gpu::Variable::CubePosX => super::Variable::WorkgroupIdX, - gpu::Variable::CubePosY => super::Variable::WorkgroupIdY, - gpu::Variable::CubePosZ => super::Variable::WorkgroupIdZ, - gpu::Variable::AbsolutePosX => { - self.global_invocation_id.0 = true; - super::Variable::GlobalInvocationIdX - } - gpu::Variable::AbsolutePosY => { - self.global_invocation_id.1 = true; - super::Variable::GlobalInvocationIdY - } - gpu::Variable::AbsolutePosZ => { - self.global_invocation_id.2 = true; - super::Variable::GlobalInvocationIdZ - } - gpu::Variable::CubeDimX => super::Variable::WorkgroupSizeX, - gpu::Variable::CubeDimY => super::Variable::WorkgroupSizeY, - gpu::Variable::CubeDimZ => super::Variable::WorkgroupSizeZ, - gpu::Variable::CubeCountX => super::Variable::NumWorkgroupsX, - gpu::Variable::CubeCountY => super::Variable::NumWorkgroupsY, - gpu::Variable::CubeCountZ => super::Variable::NumWorkgroupsZ, - gpu::Variable::LocalArray { - id, - item, - depth, - length, - } => { - let item = Self::compile_item(item); - if !self - .local_arrays - .iter() - .any(|s| s.index == id && s.depth == depth) - { - self.local_arrays - .push(super::LocalArray::new(id, item, depth, length)); - } - super::Variable::LocalArray(id, item, depth, length) - } - gpu::Variable::CubePos => todo!(), - gpu::Variable::CubeDim => todo!(), - gpu::Variable::CubeCount => todo!(), - gpu::Variable::SubcubeDim => todo!(), - gpu::Variable::Matrix { id, mat } => { - self.wmma = true; - super::Variable::WmmaFragment { - id, - frag: Self::compile_matrix(mat), - } - } - } - } - - fn compile_matrix(matrix: gpu::Matrix) -> super::Fragment { - super::Fragment { - ident: 
Self::compile_matrix_ident(matrix.ident), - m: matrix.m, - n: matrix.n, - k: matrix.k, - elem: Self::compile_elem(matrix.elem), - layout: Self::compile_matrix_layout(matrix.layout), - } - } - - fn compile_matrix_ident(ident: gpu::MatrixIdent) -> super::FragmentIdent { - match ident { - gpu::MatrixIdent::A => super::FragmentIdent::A, - gpu::MatrixIdent::B => super::FragmentIdent::B, - gpu::MatrixIdent::Accumulator => super::FragmentIdent::Accumulator, - } - } - - fn compile_matrix_layout(layout: gpu::MatrixLayout) -> Option { - match layout { - gpu::MatrixLayout::ColMajor => Some(super::FragmentLayout::ColMajor), - gpu::MatrixLayout::RowMajor => Some(super::FragmentLayout::RowMajor), - gpu::MatrixLayout::Undefined => None, - } - } - - fn compile_binding(binding: gpu::Binding) -> super::Binding { - super::Binding { - item: Self::compile_item(binding.item), - size: binding.size, - } - } - - fn compile_item(item: gpu::Item) -> super::Item { - match item.vectorization { - 4 => super::Item::Vec4(Self::compile_elem(item.elem)), - 3 => super::Item::Vec3(Self::compile_elem(item.elem)), - 2 => super::Item::Vec2(Self::compile_elem(item.elem)), - 1 => super::Item::Scalar(Self::compile_elem(item.elem)), - _ => panic!("Vectorization factor unsupported {:?}", item.vectorization), - } - } - - fn compile_elem(value: gpu::Elem) -> super::Elem { - match value { - gpu::Elem::Float(kind) => match kind { - gpu::FloatKind::F16 => super::Elem::F16, - gpu::FloatKind::BF16 => super::Elem::BF16, - gpu::FloatKind::F32 => super::Elem::F32, - gpu::FloatKind::F64 => panic!("f64 isn't supported yet"), - }, - gpu::Elem::Int(kind) => match kind { - gpu::IntKind::I32 => super::Elem::I32, - gpu::IntKind::I64 => panic!("i64 isn't supported yet"), - }, - gpu::Elem::UInt => super::Elem::U32, - gpu::Elem::Bool => super::Elem::Bool, - } - } -} diff --git a/crates/burn-cuda/src/compiler/binary.rs b/crates/burn-cuda/src/compiler/binary.rs deleted file mode 100644 index ed722a8acf..0000000000 --- a/crates/burn-cuda/src/compiler/binary.rs +++ /dev/null @@ -1,483 +0,0 @@ -use super::{Component, Elem, InstructionSettings, Item, Variable}; -use std::fmt::Display; - -pub trait Binary { - fn format( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - ) -> std::fmt::Result { - let item = out.item(); - let settings = Self::settings(*item.elem()); - - match item { - Item::Vec4(elem) => { - if settings.native_vec4 && lhs.item() == rhs.item() { - Self::format_native_vec4(f, lhs, rhs, out, elem) - } else { - Self::unroll_vec4(f, lhs, rhs, out, elem) - } - } - Item::Vec3(elem) => { - if settings.native_vec3 && lhs.item() == rhs.item() { - Self::format_native_vec3(f, lhs, rhs, out, elem) - } else { - Self::unroll_vec3(f, lhs, rhs, out, elem) - } - } - Item::Vec2(elem) => { - if settings.native_vec2 && lhs.item() == rhs.item() { - Self::format_native_vec2(f, lhs, rhs, out, elem) - } else { - Self::unroll_vec2(f, lhs, rhs, out, elem) - } - } - Item::Scalar(elem) => Self::format_scalar(f, *lhs, *rhs, *out, elem), - } - } - - fn settings(_elem: Elem) -> InstructionSettings { - InstructionSettings::default() - } - - fn format_scalar( - f: &mut std::fmt::Formatter<'_>, - lhs: Lhs, - rhs: Rhs, - out: Out, - elem: Elem, - ) -> std::fmt::Result - where - Lhs: Component, - Rhs: Component, - Out: Component; - - fn format_native_vec4( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *lhs, *rhs, *out, elem) - } - - fn 
format_native_vec3( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *lhs, *rhs, *out, elem) - } - - fn format_native_vec2( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *lhs, *rhs, *out, elem) - } - - fn unroll_vec2( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - - let out0 = out.index(0); - let out1 = out.index(1); - - Self::format_scalar(f, lhs0, rhs0, out0, elem)?; - Self::format_scalar(f, lhs1, rhs1, out1, elem)?; - - Ok(()) - } - - fn unroll_vec3( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - let lhs2 = lhs.index(2); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - - let out0 = out.index(0); - let out1 = out.index(1); - let out2 = out.index(2); - - Self::format_scalar(f, lhs0, rhs0, out0, elem)?; - Self::format_scalar(f, lhs1, rhs1, out1, elem)?; - Self::format_scalar(f, lhs2, rhs2, out2, elem)?; - - Ok(()) - } - - fn unroll_vec4( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - let lhs2 = lhs.index(2); - let lhs3 = lhs.index(3); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - let rhs3 = rhs.index(3); - - let out0 = out.index(0); - let out1 = out.index(1); - let out2 = out.index(2); - let out3 = out.index(3); - - Self::format_scalar(f, lhs0, rhs0, out0, elem)?; - Self::format_scalar(f, lhs1, rhs1, out1, elem)?; - Self::format_scalar(f, lhs2, rhs2, out2, elem)?; - Self::format_scalar(f, lhs3, rhs3, out3, elem)?; - - Ok(()) - } -} - -macro_rules! operator { - ($name:ident, $op:expr) => { - operator!( - $name, - $op, - InstructionSettings { - native_vec4: false, - native_vec3: false, - native_vec2: false, - } - ); - }; - ($name:ident, $op:expr, $vectorization:expr) => { - pub struct $name; - - impl Binary for $name { - fn format_scalar( - f: &mut std::fmt::Formatter<'_>, - lhs: Lhs, - rhs: Rhs, - out: Out, - _elem: Elem, - ) -> std::fmt::Result { - f.write_fmt(format_args!("{out} = {lhs} {} {rhs};\n", $op)) - } - - #[allow(unused_variables)] - fn settings(elem: Elem) -> InstructionSettings { - $vectorization - } - } - }; -} - -macro_rules! 
function { - ($name:ident, $op:expr) => { - function!( - $name, - $op, - InstructionSettings { - native_vec4: false, - native_vec3: false, - native_vec2: true, - } - ); - }; - ($name:ident, $op:expr, $vectorization:expr) => { - pub struct $name; - - impl Binary for $name { - fn format_scalar( - f: &mut std::fmt::Formatter<'_>, - lhs: Lhs, - rhs: Rhs, - out: Out, - _elem: Elem, - ) -> std::fmt::Result { - f.write_fmt(format_args!("{out} = {}({lhs}, {rhs});\n", $op)) - } - - #[allow(unused_variables)] - fn settings(elem: Elem) -> InstructionSettings { - $vectorization - } - } - }; -} - -operator!(Add, "+"); -operator!(Sub, "-"); -operator!(Div, "/"); -operator!(Mul, "*"); -operator!(Modulo, "%"); -operator!(Equal, "=="); -operator!(NotEqual, "!="); -operator!(Lower, "<"); -operator!(LowerEqual, "<="); -operator!(Greater, ">"); -operator!(GreaterEqual, ">="); -operator!(ShiftLeft, "<<"); -operator!(ShiftRight, ">>"); -operator!(BitwiseAnd, "&"); -operator!(BitwiseXor, "^"); -operator!(Or, "||"); -operator!(And, "&&"); - -function!(Powf, "powf"); -function!(Max, "max"); -function!(Min, "min"); - -pub struct IndexAssign; -pub struct Index; - -impl Binary for IndexAssign { - fn format_scalar( - f: &mut std::fmt::Formatter<'_>, - lhs: Lhs, - rhs: Rhs, - out: Out, - elem: Elem, - ) -> std::fmt::Result - where - Lhs: Component, - Rhs: Component, - Out: Component, - { - let elem_rhs = rhs.elem(); - // Cast only when necessary. - if elem != elem_rhs { - if let Elem::Bool = elem_rhs { - match rhs.item() { - Item::Vec4(_) => { - f.write_fmt(format_args!("{out}[{lhs}] = make_uint4({elem}({rhs}.x), {elem}({rhs}.y), {elem}({rhs}.z), {elem}({rhs}.w));\n")) - }, - Item::Vec3(_) => todo!(), - Item::Vec2(_) => todo!(), - Item::Scalar(_) => todo!(), - } - } else { - f.write_fmt(format_args!("{out}[{lhs}] = {elem}({rhs});\n")) - } - } else { - f.write_fmt(format_args!("{out}[{lhs}] = {rhs};\n")) - } - } - - fn unroll_vec2( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - - Self::format_scalar(f, lhs0, rhs0, *out, elem)?; - Self::format_scalar(f, lhs1, rhs1, *out, elem)?; - - Ok(()) - } - - fn unroll_vec3( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - let lhs2 = lhs.index(2); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - - Self::format_scalar(f, lhs0, rhs0, *out, elem)?; - Self::format_scalar(f, lhs1, rhs1, *out, elem)?; - Self::format_scalar(f, lhs2, rhs2, *out, elem)?; - - Ok(()) - } - - fn unroll_vec4( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - let lhs2 = lhs.index(2); - let lhs3 = lhs.index(3); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - let rhs3 = rhs.index(3); - - Self::format_scalar(f, lhs0, rhs0, *out, elem)?; - Self::format_scalar(f, lhs1, rhs1, *out, elem)?; - Self::format_scalar(f, lhs2, rhs2, *out, elem)?; - Self::format_scalar(f, lhs3, rhs3, *out, elem)?; - - Ok(()) - } - - fn format( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - ) -> std::fmt::Result { - if let Variable::Local { - id: _, - item: _, - depth: _, - 
} = out
-        {
-            return IndexAssignVector::format(f, lhs, rhs, out);
-        };
-
-        let elem = out.elem();
-
-        match lhs.item() {
-            Item::Vec4(_) => Self::unroll_vec4(f, lhs, rhs, out, elem),
-            Item::Vec3(_) => Self::unroll_vec3(f, lhs, rhs, out, elem),
-            Item::Vec2(_) => Self::unroll_vec2(f, lhs, rhs, out, elem),
-            Item::Scalar(_) => Self::format_scalar(f, *lhs, *rhs, *out, elem),
-        }
-    }
-}
-
-impl Binary for Index {
-    fn format(
-        f: &mut std::fmt::Formatter<'_>,
-        lhs: &Variable,
-        rhs: &Variable,
-        out: &Variable,
-    ) -> std::fmt::Result {
-        if let Variable::Local {
-            id: _,
-            item: _,
-            depth: _,
-        } = lhs
-        {
-            return IndexVector::format(f, lhs, rhs, out);
-        }
-
-        Self::format_scalar(f, *lhs, *rhs, *out, out.elem())
-    }
-
-    fn format_scalar<Lhs, Rhs, Out>(
-        f: &mut std::fmt::Formatter<'_>,
-        lhs: Lhs,
-        rhs: Rhs,
-        out: Out,
-        _elem: Elem,
-    ) -> std::fmt::Result
-    where
-        Lhs: Component,
-        Rhs: Component,
-        Out: Component,
-    {
-        f.write_fmt(format_args!("{out} = {lhs}[{rhs}];\n"))
-    }
-}
-
-/// The goal is to support indexing of vectorized types.
-///
-/// # Examples
-///
-/// ```c
-/// float4 var;
-/// float item = var[0]; // We want that.
-/// float item = var.x; // So we compile to that.
-/// ```
-struct IndexVector;
-
-/// The goal is to support indexing of vectorized types.
-///
-/// # Examples
-///
-/// ```c
-/// float4 var;
-///
-/// var[0] = 1.0; // We want that.
-/// var.x = 1.0; // So we compile to that.
-/// ```
-struct IndexAssignVector;
-
-impl IndexVector {
-    fn format(
-        f: &mut std::fmt::Formatter<'_>,
-        lhs: &Variable,
-        rhs: &Variable,
-        out: &Variable,
-    ) -> std::fmt::Result {
-        let index = match rhs {
-            Variable::ConstantScalar(value, _elem) => *value as usize,
-            _ => {
-                let elem = out.elem();
-                return f.write_fmt(format_args!("{out} = *(({elem}*)&{lhs} + {rhs});\n"));
-            }
-        };
-
-        let out = out.index(index);
-        let lhs = lhs.index(index);
-
-        f.write_fmt(format_args!("{out} = {lhs};\n"))
-    }
-}
-
-impl IndexAssignVector {
-    fn format(
-        f: &mut std::fmt::Formatter<'_>,
-        lhs: &Variable,
-        rhs: &Variable,
-        out: &Variable,
-    ) -> std::fmt::Result {
-        let index = match lhs {
-            Variable::ConstantScalar(value, _) => *value as usize,
-            _ => {
-                let elem = out.elem();
-                return f.write_fmt(format_args!("*(({elem}*)&{out} + {lhs}) = {rhs};\n"));
-            }
-        };
-
-        let out = out.index(index);
-        let rhs = rhs.index(index);
-
-        f.write_fmt(format_args!("{out} = {rhs};\n"))
-    }
-}
diff --git a/crates/burn-cuda/src/compiler/body.rs b/crates/burn-cuda/src/compiler/body.rs
deleted file mode 100644
index 72fe5fac4c..0000000000
--- a/crates/burn-cuda/src/compiler/body.rs
+++ /dev/null
@@ -1,89 +0,0 @@
-use super::Instruction;
-use std::fmt::Display;
-
-/// A body is composed of a list of [instructions](Instruction).
-#[derive(Debug, Clone)] -pub struct Body { - pub instructions: Vec, - pub shared_memories: Vec, - pub local_arrays: Vec, - pub stride: bool, - pub shape: bool, - pub id: bool, - pub rank: bool, - pub invocation_index: bool, - pub global_invocation_id: (bool, bool, bool), - pub wrap_size_checked: bool, -} - -impl Display for Body { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.id - || self.global_invocation_id.0 - || self.global_invocation_id.1 - || self.global_invocation_id.2 - { - f.write_str( - " - int3 globalInvocationId = make_int3( - blockIdx.x * blockDim.x + threadIdx.x, - blockIdx.y * blockDim.y + threadIdx.y, - blockIdx.z * blockDim.z + threadIdx.z - ); -", - )?; - } - - if self.id { - f.write_str( - " - uint id = (globalInvocationId.z * gridDim.x * blockDim.x * gridDim.y * blockDim.y) + (globalInvocationId.y * gridDim.x * blockDim.x) + globalInvocationId.x; -", - )?; - } - - if self.invocation_index { - f.write_str( - " - int invocationIndex = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.x * blockDim.y); - ", - )?; - } - if self.wrap_size_checked { - f.write_str( - " - int warpSizeChecked = min(warpSize, blockDim.x * blockDim.y * blockDim.z); -", - )?; - } - - if self.rank || self.stride || self.shape { - f.write_str("uint rank = info[0];\n")?; - } - - if self.stride || self.shape { - f.write_str("uint rank_2 = rank * 2;\n")?; - } - - for shared in self.shared_memories.iter() { - f.write_fmt(format_args!( - "__shared__ {} shared_memory_{}[{}];\n", - shared.item, shared.index, shared.size - ))?; - } - - // Local arrays - for array in self.local_arrays.iter() { - f.write_fmt(format_args!( - "{} l_arr_{}_{}[{}];\n\n", - array.item, array.index, array.depth, array.size - ))?; - } - - for ops in self.instructions.iter() { - f.write_fmt(format_args!("{ops}"))?; - } - - Ok(()) - } -} diff --git a/crates/burn-cuda/src/compiler/element.rs b/crates/burn-cuda/src/compiler/element.rs deleted file mode 100644 index 583f2bb241..0000000000 --- a/crates/burn-cuda/src/compiler/element.rs +++ /dev/null @@ -1,327 +0,0 @@ -use burn_cube::ir as gpu; -use half::{bf16, f16}; -use std::fmt::Display; - -use super::Fragment; - -#[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum Elem { - F32, - F16, - BF16, - I32, - U32, - Bool, -} - -#[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum Item { - Vec4(Elem), - Vec3(Elem), - Vec2(Elem), - Scalar(Elem), -} - -impl Display for Elem { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Elem::F16 => f.write_str("half"), - Elem::F32 => f.write_str("float"), - Elem::BF16 => f.write_str("bf16"), - Elem::I32 => f.write_str("int"), - Elem::U32 => f.write_str("uint"), - Elem::Bool => f.write_str("bool"), - } - } -} - -impl Display for Item { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Item::Vec4(elem) => match elem { - Elem::F32 => f.write_str("float4"), - Elem::I32 => f.write_str("int4"), - Elem::U32 => f.write_str("uint4"), - Elem::Bool => f.write_str("bool4"), - Elem::BF16 => f.write_str("bf164"), - Elem::F16 => f.write_str("f164"), - }, - Item::Vec3(elem) => match elem { - Elem::F32 => f.write_str("float3"), - Elem::I32 => f.write_str("int3"), - Elem::U32 => f.write_str("uint3"), - Elem::Bool => f.write_str("bool3"), - Elem::BF16 => f.write_str("bf163"), - Elem::F16 => f.write_str("f163"), - }, - Item::Vec2(elem) => match elem { - Elem::F32 => f.write_str("float2"), - Elem::I32 => f.write_str("int2"), - Elem::U32 => 
f.write_str("uint2"), - Elem::Bool => f.write_str("bool2"), - Elem::BF16 => f.write_str("bf162"), - Elem::F16 => f.write_str("f162"), - }, - Item::Scalar(elem) => f.write_fmt(format_args!("{elem}")), - } - } -} - -pub trait Component: Display { - fn item(&self) -> Item; - fn elem(&self) -> Elem { - *self.item().elem() - } -} - -impl Component for IndexedVariable { - fn item(&self) -> Item { - self.var.item() - } -} -impl Component for Variable { - fn item(&self) -> Item { - match self { - Variable::GlobalInputArray(_, e) => *e, - Variable::GlobalOutputArray(_, e) => *e, - Variable::SharedMemory(_, e, _) => *e, - Variable::Local { - id: _, - item, - depth: _, - } => *item, - Variable::Slice { - id: _, - item, - depth: _, - } => *item, - Variable::ConstantScalar(_, e) => Item::Scalar(*e), - Variable::GlobalScalar(_, e, _) => Item::Scalar(*e), - Variable::Id => Item::Scalar(Elem::U32), - Variable::LocalInvocationIndex => Item::Scalar(Elem::U32), - Variable::LocalInvocationIdX => Item::Scalar(Elem::U32), - Variable::LocalInvocationIdY => Item::Scalar(Elem::U32), - Variable::LocalInvocationIdZ => Item::Scalar(Elem::U32), - Variable::Rank => Item::Scalar(Elem::U32), - Variable::LocalScalar { - id: _, - elem, - depth: _, - } => Item::Scalar(*elem), - Variable::WorkgroupIdX => Item::Scalar(Elem::U32), - Variable::WorkgroupIdY => Item::Scalar(Elem::U32), - Variable::WorkgroupIdZ => Item::Scalar(Elem::U32), - Variable::GlobalInvocationIdX => Item::Scalar(Elem::U32), - Variable::GlobalInvocationIdY => Item::Scalar(Elem::U32), - Variable::GlobalInvocationIdZ => Item::Scalar(Elem::U32), - Variable::WorkgroupSizeX => Item::Scalar(Elem::U32), - Variable::WorkgroupSizeY => Item::Scalar(Elem::U32), - Variable::WorkgroupSizeZ => Item::Scalar(Elem::U32), - Variable::NumWorkgroupsX => Item::Scalar(Elem::U32), - Variable::NumWorkgroupsY => Item::Scalar(Elem::U32), - Variable::NumWorkgroupsZ => Item::Scalar(Elem::U32), - Variable::LocalArray(_, e, _, _) => *e, - Variable::WarpSize => Item::Scalar(Elem::U32), - Variable::WmmaFragment { id: _, frag } => Item::Scalar(frag.elem), - } - } -} - -#[derive(Debug, Clone, Copy)] -pub enum Variable { - WarpSize, - GlobalInputArray(u16, Item), - GlobalOutputArray(u16, Item), - GlobalScalar(u16, Elem, gpu::Elem), - ConstantScalar(f64, Elem), - Local { id: u16, item: Item, depth: u8 }, - Slice { id: u16, item: Item, depth: u8 }, - LocalScalar { id: u16, elem: Elem, depth: u8 }, - SharedMemory(u16, Item, u32), - LocalArray(u16, Item, u8, u32), - Id, - LocalInvocationIndex, - LocalInvocationIdX, - LocalInvocationIdY, - LocalInvocationIdZ, - Rank, - WorkgroupIdX, - WorkgroupIdY, - WorkgroupIdZ, - GlobalInvocationIdX, - GlobalInvocationIdY, - GlobalInvocationIdZ, - WorkgroupSizeX, - WorkgroupSizeY, - WorkgroupSizeZ, - NumWorkgroupsX, - NumWorkgroupsY, - NumWorkgroupsZ, - WmmaFragment { id: u16, frag: Fragment }, -} - -impl Display for Variable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")), - Variable::LocalScalar { - id: index, - elem: _, - depth: scope_depth, - } => f.write_fmt(format_args!("s_{scope_depth}_{index}")), - Variable::Local { - id: index, - item: _, - depth: scope_depth, - } => f.write_fmt(format_args!("l_{scope_depth}_{index}")), - Variable::Slice { id, item: _, depth } => { - f.write_fmt(format_args!("slice_{depth}_{id}")) - } - Variable::GlobalOutputArray(number, _) => f.write_fmt(format_args!("output_{number}")), - 
Variable::GlobalScalar(number, _, elem) => { - f.write_fmt(format_args!("scalars_{elem}[{number}]")) - } - Variable::ConstantScalar(number, elem) => f.write_fmt(format_args!("{elem}({number})")), - Variable::SharedMemory(number, _, _) => { - f.write_fmt(format_args!("shared_memory_{number}")) - } - Variable::Id => f.write_str("id"), - Variable::LocalInvocationIndex => f.write_str("invocationIndex"), - Variable::LocalInvocationIdX => f.write_str("threadIdx.x"), - Variable::LocalInvocationIdY => f.write_str("threadIdx.y"), - Variable::LocalInvocationIdZ => f.write_str("threadIdx.z"), - Variable::Rank => f.write_str("rank"), - Variable::WorkgroupIdX => f.write_str("blockIdx.x"), - Variable::WorkgroupIdY => f.write_str("blockIdx.y"), - Variable::WorkgroupIdZ => f.write_str("blockIdx.z"), - Variable::WorkgroupSizeX => f.write_str("blockDim.x"), - Variable::WorkgroupSizeY => f.write_str("blockDim.y"), - Variable::WorkgroupSizeZ => f.write_str("blockDim.z"), - Variable::NumWorkgroupsX => f.write_str("gridDim.x"), - Variable::NumWorkgroupsY => f.write_str("gridDim.y"), - Variable::NumWorkgroupsZ => f.write_str("gridDim.z"), - Variable::GlobalInvocationIdX => f.write_str("globalInvocationId.x"), - Variable::GlobalInvocationIdY => f.write_str("globalInvocationId.y"), - Variable::GlobalInvocationIdZ => f.write_str("globalInvocationId.z"), - Variable::LocalArray(id, _item, depth, _size) => { - f.write_fmt(format_args!("l_arr_{}_{}", id, depth)) - } - Variable::WarpSize => f.write_str("warpSize"), - Variable::WmmaFragment { id: index, frag: _ } => { - f.write_fmt(format_args!("frag_{index}")) - } - } - } -} - -impl Variable { - pub fn is_always_scalar(&self) -> bool { - match self { - Variable::GlobalScalar(_, _, _) => true, - Variable::ConstantScalar(_, _) => true, - Variable::LocalScalar { - id: _, - elem: _, - depth: _, - } => true, - Variable::Id => true, - Variable::LocalInvocationIndex => true, - Variable::LocalInvocationIdX => true, - Variable::LocalInvocationIdY => true, - Variable::LocalInvocationIdZ => true, - Variable::Rank => true, - Variable::GlobalInputArray(_, _) => false, - Variable::GlobalOutputArray(_, _) => false, - Variable::SharedMemory(_, _, _) => false, - Variable::Local { - id: _, - item: _, - depth: _, - } => false, - Variable::Slice { - id: _, - item: _, - depth: _, - } => false, - Variable::WorkgroupIdX => true, - Variable::WorkgroupIdY => true, - Variable::WorkgroupIdZ => true, - Variable::GlobalInvocationIdX => true, - Variable::GlobalInvocationIdY => true, - Variable::GlobalInvocationIdZ => true, - Variable::WorkgroupSizeX => true, - Variable::WorkgroupSizeY => true, - Variable::WorkgroupSizeZ => true, - Variable::NumWorkgroupsX => true, - Variable::NumWorkgroupsY => true, - Variable::NumWorkgroupsZ => true, - Variable::LocalArray(_, _, _, _) => false, - Variable::WarpSize => true, - Variable::WmmaFragment { id: _, frag: _ } => false, - } - } - - pub fn index(&self, index: usize) -> IndexedVariable { - IndexedVariable { var: *self, index } - } -} - -#[derive(Debug, Clone)] -pub struct IndexedVariable { - var: Variable, - index: usize, -} - -impl Display for IndexedVariable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let var = &self.var; - let item = self.var.item(); - - match item { - Item::Vec4(_) => match self.index { - 0 => f.write_fmt(format_args!("{var}.x"))?, - 1 => f.write_fmt(format_args!("{var}.y"))?, - 2 => f.write_fmt(format_args!("{var}.z"))?, - 3 => f.write_fmt(format_args!("{var}.w"))?, - _ => unreachable!(), - }, - 
Item::Vec3(_) => match self.index { - 0 => f.write_fmt(format_args!("{var}.x"))?, - 1 => f.write_fmt(format_args!("{var}.y"))?, - 2 => f.write_fmt(format_args!("{var}.z"))?, - _ => unreachable!(), - }, - Item::Vec2(_) => match self.index { - 0 => f.write_fmt(format_args!("{var}.x"))?, - 1 => f.write_fmt(format_args!("{var}.y"))?, - _ => unreachable!(), - }, - Item::Scalar(_) => f.write_fmt(format_args!("{var}"))?, - } - - Ok(()) - } -} -impl Item { - pub fn elem(&self) -> &Elem { - match self { - Item::Vec4(e) => e, - Item::Vec3(e) => e, - Item::Vec2(e) => e, - Item::Scalar(e) => e, - } - } -} - -impl Elem { - pub fn size(&self) -> usize { - match self { - Self::F32 => core::mem::size_of::(), - Self::F16 => core::mem::size_of::(), - Self::BF16 => core::mem::size_of::(), - Self::I32 => core::mem::size_of::(), - Self::U32 => core::mem::size_of::(), - Self::Bool => core::mem::size_of::(), - } - } -} diff --git a/crates/burn-cuda/src/compiler/instruction.rs b/crates/burn-cuda/src/compiler/instruction.rs deleted file mode 100644 index 622416b23e..0000000000 --- a/crates/burn-cuda/src/compiler/instruction.rs +++ /dev/null @@ -1,312 +0,0 @@ -use super::{binary::*, unary::*, Component, Variable, WarpInstruction, WmmaInstruction}; -use std::fmt::Display; - -#[derive(Debug, Clone)] -pub struct BinaryInstruction { - pub lhs: Variable, - pub rhs: Variable, - pub out: Variable, -} - -#[derive(Debug, Clone)] -pub struct UnaryInstruction { - pub input: Variable, - pub out: Variable, -} - -#[derive(Debug, Clone)] -pub enum Instruction { - Length { - input: Variable, - out: Variable, - num_inputs: usize, - num_outputs: usize, - }, - SliceLength { - input: Variable, - out: Variable, - }, - DeclareVariable { - var: Variable, - }, - Modulo(BinaryInstruction), - Add(BinaryInstruction), - Fma { - a: Variable, - b: Variable, - c: Variable, - out: Variable, - }, - Div(BinaryInstruction), - Mul(BinaryInstruction), - Sub(BinaryInstruction), - Index(BinaryInstruction), - IndexAssign(BinaryInstruction), - CheckedIndexAssign(BinaryInstruction), - Assign(UnaryInstruction), - RangeLoop { - i: Variable, - start: Variable, - end: Variable, - instructions: Vec, - }, - Loop { - instructions: Vec, - }, - If { - cond: Variable, - instructions: Vec, - }, - IfElse { - cond: Variable, - instructions_if: Vec, - instructions_else: Vec, - }, - Slice { - input: Variable, - start: Variable, - end: Variable, - out: Variable, - }, - Return, - Break, - Stride { - dim: Variable, - position: usize, - out: Variable, - }, - Shape { - dim: Variable, - position: usize, - out: Variable, - }, - Equal(BinaryInstruction), - NotEqual(BinaryInstruction), - Lower(BinaryInstruction), - Greater(BinaryInstruction), - LowerEqual(BinaryInstruction), - GreaterEqual(BinaryInstruction), - Erf(UnaryInstruction), - BitwiseAnd(BinaryInstruction), - BitwiseXor(BinaryInstruction), - ShiftLeft(BinaryInstruction), - ShiftRight(BinaryInstruction), - Abs(UnaryInstruction), - Exp(UnaryInstruction), - Log(UnaryInstruction), - Log1p(UnaryInstruction), - Cos(UnaryInstruction), - Sin(UnaryInstruction), - Tanh(UnaryInstruction), - Powf(BinaryInstruction), - Sqrt(UnaryInstruction), - Min(BinaryInstruction), - Max(BinaryInstruction), - Not(UnaryInstruction), - Or(BinaryInstruction), - And(BinaryInstruction), - Clamp { - input: Variable, - min_value: Variable, - max_value: Variable, - out: Variable, - }, - SyncThreads, - Ceil(UnaryInstruction), - Floor(UnaryInstruction), - Wrap(WarpInstruction), - Wmma(WmmaInstruction), -} - -impl Display for Instruction { - fn fmt(&self, f: 
&mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Instruction::Return => f.write_str("return;"), - Instruction::Break => f.write_str("break;"), - Instruction::DeclareVariable { var } => match var { - Variable::WmmaFragment { id: _, frag } => { - f.write_fmt(format_args!("{frag} {var};\n")) - } - _ => { - let item = var.item(); - f.write_fmt(format_args!("{item} {var};\n")) - } - }, - Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Slice { - input, - start, - end, - out, - } => { - let item = out.item(); - f.write_fmt(format_args!("uint {out}_length = {end} - {start};\n"))?; - f.write_fmt(format_args!("{item} *{out} = {input} + {start};\n")) - } - Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out), - Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::BitwiseXor(it) => BitwiseXor::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::ShiftLeft(it) => ShiftLeft::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::ShiftRight(it) => ShiftRight::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Index(it) => Index::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::IndexAssign(it) => IndexAssign::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::CheckedIndexAssign(it) => { - IndexAssign::format(f, &it.lhs, &it.rhs, &it.out) - } - Instruction::Assign(it) => Assign::format(f, &it.input, &it.out), - Instruction::RangeLoop { - i, - start, - end, - instructions, - } => { - f.write_fmt(format_args!( - " -for (uint {i} = {start}; {i} < {end}; {i}++) {{ -" - ))?; - for instruction in instructions { - f.write_fmt(format_args!("{instruction}"))?; - } - - f.write_str("}\n") - } - - Instruction::Loop { instructions } => { - f.write_fmt(format_args!("while (true) {{\n"))?; - for i in instructions { - f.write_fmt(format_args!("{i}"))?; - } - f.write_str("}\n") - } - Instruction::If { cond, instructions } => { - f.write_fmt(format_args!("if ({cond}) {{\n"))?; - for i in instructions { - f.write_fmt(format_args!("{i}"))?; - } - f.write_str("}\n") - } - Instruction::IfElse { - cond, - instructions_if, - instructions_else, - } => { - f.write_fmt(format_args!("if ({cond}) {{\n"))?; - for i in instructions_if { - f.write_fmt(format_args!("{i}"))?; - } - f.write_str("} else {\n")?; - for i in instructions_else { - f.write_fmt(format_args!("{i}"))?; - } - f.write_str("}\n") - } - Instruction::Stride { dim, position, out } => f.write_fmt(format_args!( - "{out} = info[({position} * rank_2) + {dim} + 1];\n" - )), - Instruction::Shape { dim, position, out } => f.write_fmt(format_args!( - "{out} = info[({position} * rank_2) + rank + {dim} + 1];\n" - )), - Instruction::Equal(it) => Equal::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::NotEqual(it) => NotEqual::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Lower(it) => Lower::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Greater(it) => Greater::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::LowerEqual(it) => LowerEqual::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::GreaterEqual(it) => GreaterEqual::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Erf(it) => Erf::format(f, &it.input, &it.out), - Instruction::Abs(it) => Abs::format(f, &it.input, &it.out), - Instruction::Exp(it) => Exp::format(f, 
&it.input, &it.out), - Instruction::Log(it) => Log::format(f, &it.input, &it.out), - Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out), - Instruction::Cos(it) => Cos::format(f, &it.input, &it.out), - Instruction::Sin(it) => Sin::format(f, &it.input, &it.out), - Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out), - Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out), - Instruction::Max(it) => Max::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Min(it) => Min::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Not(it) => Not::format(f, &it.input, &it.out), - Instruction::Or(it) => Or::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::And(it) => And::format(f, &it.lhs, &it.rhs, &it.out), - Instruction::Clamp { - input, - min_value, - max_value, - out, - } => f.write_fmt(format_args!( - " -{out} = min({input}, {max_value}); -{out} = max({out}, {min_value}); - " - )), - Instruction::SyncThreads => f.write_str("__syncthreads();\n"), - Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out), - Instruction::Floor(it) => Floor::format(f, &it.input, &it.out), - Instruction::SliceLength { input, out } => { - f.write_fmt(format_args!("{out} = {input}_length;\n")) - } - Instruction::Length { - input, - out, - num_inputs, - num_outputs, - } => { - let offset = num_inputs + num_outputs; - let index = match input { - Variable::GlobalInputArray(index, _) => *index as usize, - Variable::GlobalOutputArray(index, _) => *index as usize + num_inputs, - _ => panic!("Can only know the len of a global array."), - } + 1; - let factor = match input.item() { - super::Item::Vec4(_) => 4, - super::Item::Vec3(_) => 3, - super::Item::Vec2(_) => 2, - super::Item::Scalar(_) => { - return f.write_fmt(format_args!( - "{out} = info[({offset} * 2 * info[0]) + {index}];\n" - )) - } - }; - f.write_fmt(format_args!( - "{out} = info[({offset} * 2 * info[0]) + {index}] / {factor};\n" - )) - } - Instruction::Wrap(it) => f.write_fmt(format_args!("{it}")), - Instruction::Fma { a, b, c, out } => Fma::format(f, a, b, c, out), - Instruction::Wmma(it) => f.write_fmt(format_args!("{it}")), - } - } -} - -struct Fma; - -impl Fma { - fn format( - f: &mut core::fmt::Formatter<'_>, - a: &Variable, - b: &Variable, - c: &Variable, - out: &Variable, - ) -> core::fmt::Result { - let num = match out.item() { - super::Item::Vec4(_) => 4, - super::Item::Vec3(_) => 3, - super::Item::Vec2(_) => 2, - super::Item::Scalar(_) => 1, - }; - - for i in 0..num { - let ai = a.index(i); - let bi = b.index(i); - let ci = c.index(i); - let outi = out.index(i); - - f.write_fmt(format_args!("{outi} = fma({ai}, {bi}, {ci});\n"))?; - } - - Ok(()) - } -} diff --git a/crates/burn-cuda/src/compiler/mma.rs b/crates/burn-cuda/src/compiler/mma.rs deleted file mode 100644 index 0ed4131949..0000000000 --- a/crates/burn-cuda/src/compiler/mma.rs +++ /dev/null @@ -1,129 +0,0 @@ -use std::fmt::Display; - -use super::{Elem, Variable}; - -#[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum FragmentIdent { - A, - B, - Accumulator, -} - -#[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum FragmentLayout { - ColMajor, - RowMajor, -} - -#[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub struct Fragment { - pub ident: FragmentIdent, - pub m: u8, - pub n: u8, - pub k: u8, - pub elem: Elem, - pub layout: Option, -} - -/// Warp Matrix-Multiply and Accumulate Instruction. -#[derive(Debug, Clone, Copy)] -pub enum WmmaInstruction { - /// Fill the fragment with the value. 
- Fill { frag: Variable, value: Variable }, - /// Load the value into the fragment given the stride. - Load { - frag: Variable, - value: Variable, - stride: Variable, - }, - /// Executes D=A*B+C; - /// - /// For implementing a matmul, `D=C` : `C+=A*B` - Execute { - frag_a: Variable, - frag_b: Variable, - frag_c: Variable, - frag_d: Variable, - }, - /// Store the fragment in an output variable following the stride and the layout. - Store { - output: Variable, - frag: Variable, - stride: Variable, - layout: FragmentLayout, - }, -} - -impl Display for FragmentLayout { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - FragmentLayout::ColMajor => f.write_str("wmma::col_major"), - FragmentLayout::RowMajor => f.write_str("wmma::row_major"), - } - } -} - -impl Display for FragmentIdent { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - FragmentIdent::A => f.write_str("wmma::matrix_a"), - FragmentIdent::B => f.write_str("wmma::matrix_b"), - FragmentIdent::Accumulator => f.write_str("wmma::accumulator"), - } - } -} - -impl Display for Fragment { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.layout { - Some(layout) => f.write_fmt(format_args!( - "wmma::fragment<{}, {}, {}, {}, {}, {}>", - self.ident, self.m, self.n, self.k, self.elem, layout - )), - None => f.write_fmt(format_args!( - "wmma::fragment<{}, {}, {}, {}, {}>", - self.ident, self.m, self.n, self.k, self.elem, - )), - } - } -} - -impl Display for WmmaInstruction { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - WmmaInstruction::Fill { frag, value } => { - f.write_fmt(format_args!("wmma::fill_fragment({frag}, {value});\n")) - } - WmmaInstruction::Load { - frag, - value, - stride, - } => f.write_fmt(format_args!( - "wmma::load_matrix_sync({frag}, {value}, {stride});\n" - )), - WmmaInstruction::Execute { - frag_a, - frag_b, - frag_c, - frag_d, - } => f.write_fmt(format_args!( - "wmma::mma_sync({frag_d}, {frag_a}, {frag_b}, {frag_c});\n" - )), - WmmaInstruction::Store { - output, - frag, - stride, - layout, - } => { - let layout = match layout { - FragmentLayout::ColMajor => "wmma::mem_col_major", - FragmentLayout::RowMajor => "wmma::mem_row_major", - }; - - f.write_fmt(format_args!( - "wmma::store_matrix_sync({output}, {frag}, {stride}, {layout});\n" - )) - } - } - } -} diff --git a/crates/burn-cuda/src/compiler/mod.rs b/crates/burn-cuda/src/compiler/mod.rs deleted file mode 100644 index 73846b87af..0000000000 --- a/crates/burn-cuda/src/compiler/mod.rs +++ /dev/null @@ -1,20 +0,0 @@ -pub mod binary; -pub mod unary; - -mod base; -mod body; -mod element; -mod instruction; -mod mma; -mod settings; -mod shader; -mod warp; - -pub use base::*; -pub use body::*; -pub use element::*; -pub use instruction::*; -pub use mma::*; -pub use settings::*; -pub use shader::*; -pub use warp::*; diff --git a/crates/burn-cuda/src/compiler/settings.rs b/crates/burn-cuda/src/compiler/settings.rs deleted file mode 100644 index 09e35427f7..0000000000 --- a/crates/burn-cuda/src/compiler/settings.rs +++ /dev/null @@ -1,6 +0,0 @@ -#[derive(Debug, Default)] -pub struct InstructionSettings { - pub native_vec4: bool, - pub native_vec3: bool, - pub native_vec2: bool, -} diff --git a/crates/burn-cuda/src/compiler/shader.rs b/crates/burn-cuda/src/compiler/shader.rs deleted file mode 100644 index 4681b59fdd..0000000000 --- a/crates/burn-cuda/src/compiler/shader.rs +++ /dev/null @@ -1,159 +0,0 @@ -use burn_cube::{ir::CubeDim, 
CompilerRepresentation};
-
-// use super::{Body, Extension, Item};
-use super::{Body, Item};
-use std::fmt::Display;
-
-#[derive(Debug, PartialEq, Eq, Clone, Copy)]
-pub enum Location {
-    Storage,
-    #[allow(dead_code)]
-    Warp,
-}
-
-#[derive(Debug, PartialEq, Eq, Clone, Copy)]
-pub enum Visibility {
-    Read,
-    ReadWrite,
-}
-
-#[derive(Debug, PartialEq, Eq, Clone)]
-pub struct Binding {
-    pub item: Item,
-    pub size: Option<usize>,
-}
-
-#[derive(Debug, PartialEq, Eq, Clone)]
-pub struct SharedMemory {
-    pub index: u16,
-    pub item: Item,
-    pub size: u32,
-}
-
-#[derive(Debug, PartialEq, Eq, Clone)]
-pub struct LocalArray {
-    pub index: u16,
-    pub item: Item,
-    pub depth: u8,
-    pub size: u32,
-}
-
-impl LocalArray {
-    pub fn new(index: u16, item: Item, depth: u8, size: u32) -> Self {
-        Self {
-            index,
-            item,
-            depth,
-            size,
-        }
-    }
-}
-
-impl SharedMemory {
-    pub fn new(index: u16, item: Item, size: u32) -> Self {
-        Self { index, item, size }
-    }
-}
-
-#[derive(Debug, Clone)]
-pub struct ComputeShader {
-    pub inputs: Vec<Binding>,
-    pub outputs: Vec<Binding>,
-    pub named: Vec<(String, Binding)>,
-    pub cube_dim: CubeDim,
-    pub body: Body,
-    pub wmma_activated: bool,
-}
-
-impl CompilerRepresentation for ComputeShader {
-    fn shared_memory_size(&self) -> usize {
-        let mut current = 0usize;
-
-        for var in self.body.shared_memories.iter() {
-            let factor = match var.item {
-                Item::Vec4(_) => 4,
-                Item::Vec3(_) => 3,
-                Item::Vec2(_) => 2,
-                Item::Scalar(_) => 1,
-            };
-
-            let elem_size_bytes = var.item.elem().size();
-            current += (var.size as usize) * factor * elem_size_bytes;
-        }
-
-        current
-    }
-}
-
-impl Display for ComputeShader {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        if self.wmma_activated {
-            f.write_str("#include <mma.h>\nusing namespace nvcuda;\n")?;
-        }
-
-        f.write_fmt(format_args!(
-            "
-typedef unsigned int uint;
-
-extern \"C\" struct bool4 {{
-    bool x;
-    bool y;
-    bool z;
-    bool w;
-}};
-
-extern \"C\" __global__ void kernel(
-",
-        ))?;
-
-        let num_bindings = self.inputs.len() + self.outputs.len() + self.named.len();
-        let mut binding_index = 0;
-        for (index, binding) in self.inputs.iter().enumerate() {
-            binding_index += 1;
-            f.write_fmt(format_args!("{} input_{}[]", binding.item, index))?;
-            if binding_index < num_bindings {
-                f.write_str(",")?;
-            }
-        }
-        for (index, binding) in self.outputs.iter().enumerate() {
-            binding_index += 1;
-            f.write_fmt(format_args!("{} output_{}[]", binding.item, index))?;
-            if binding_index < num_bindings {
-                f.write_str(",")?;
-            }
-        }
-        for (name, binding) in self.named.iter() {
-            binding_index += 1;
-            f.write_fmt(format_args!("{} {}[]", binding.item, name))?;
-
-            if binding_index < num_bindings {
-                f.write_str(",")?;
-            }
-        }
-
-        f.write_str("\n) {\n")?;
-
-        f.write_fmt(format_args!("{}", self.body))?;
-        f.write_str("\n}")?;
-
-        Ok(())
-    }
-}
-
-impl Display for Location {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Location::Storage => f.write_str("storage"),
-            Location::Warp => f.write_str("workgroup"),
-        }
-    }
-}
-
-impl Display for Visibility {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Visibility::Read => f.write_str("read"),
-            Visibility::ReadWrite => f.write_str("read_write"),
-        }
-    }
-}
diff --git a/crates/burn-cuda/src/compiler/unary.rs b/crates/burn-cuda/src/compiler/unary.rs
deleted file mode 100644
index 1731b1b6b1..0000000000
--- a/crates/burn-cuda/src/compiler/unary.rs
+++ /dev/null
@@ -1,210 +0,0 @@
-use super::{Component, Elem, InstructionSettings,
Item, Variable}; -use std::fmt::Display; - -pub trait Unary { - fn format( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - ) -> std::fmt::Result { - let item = out.item(); - let settings = Self::settings(*item.elem()); - - match item { - Item::Vec4(elem) => { - if settings.native_vec4 { - Self::format_native_vec4(f, input, out, elem) - } else { - Self::unroll_vec4(f, input, out, elem) - } - } - Item::Vec3(elem) => { - if settings.native_vec3 { - Self::format_native_vec3(f, input, out, elem) - } else { - Self::unroll_vec3(f, input, out, elem) - } - } - Item::Vec2(elem) => { - if settings.native_vec2 { - Self::format_native_vec2(f, input, out, elem) - } else { - Self::unroll_vec2(f, input, out, elem) - } - } - Item::Scalar(elem) => Self::format_scalar(f, *input, *out, elem), - } - } - - fn settings(_elem: Elem) -> InstructionSettings { - InstructionSettings::default() - } - - fn format_scalar( - f: &mut std::fmt::Formatter<'_>, - input: Input, - out: Out, - elem: Elem, - ) -> std::fmt::Result - where - Input: Component, - Out: Component; - - fn format_native_vec4( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *input, *out, elem) - } - - fn format_native_vec3( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *input, *out, elem) - } - - fn format_native_vec2( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - Self::format_scalar(f, *input, *out, elem) - } - - fn unroll_vec2( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - let input0 = input.index(0); - let input1 = input.index(1); - - let out0 = out.index(0); - let out1 = out.index(1); - - Self::format_scalar(f, input0, out0, elem)?; - Self::format_scalar(f, input1, out1, elem)?; - - Ok(()) - } - - fn unroll_vec3( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - let input0 = input.index(0); - let input1 = input.index(1); - let input2 = input.index(2); - - let out0 = out.index(0); - let out1 = out.index(1); - let out2 = out.index(2); - - Self::format_scalar(f, input0, out0, elem)?; - Self::format_scalar(f, input1, out1, elem)?; - Self::format_scalar(f, input2, out2, elem)?; - - Ok(()) - } - - fn unroll_vec4( - f: &mut std::fmt::Formatter<'_>, - input: &Variable, - out: &Variable, - elem: Elem, - ) -> std::fmt::Result { - let input0 = input.index(0); - let input1 = input.index(1); - let input2 = input.index(2); - let input3 = input.index(3); - - let out0 = out.index(0); - let out1 = out.index(1); - let out2 = out.index(2); - let out3 = out.index(3); - - Self::format_scalar(f, input0, out0, elem)?; - Self::format_scalar(f, input1, out1, elem)?; - Self::format_scalar(f, input2, out2, elem)?; - Self::format_scalar(f, input3, out3, elem)?; - - Ok(()) - } -} - -macro_rules! 
function { - ($name:ident, $func:expr) => { - pub struct $name; - - impl Unary for $name { - fn format_scalar( - f: &mut std::fmt::Formatter<'_>, - input: Input, - out: Out, - _elem: Elem, - ) -> std::fmt::Result { - f.write_fmt(format_args!("{out} = {}({input});\n", $func)) - } - } - }; -} - -function!(Abs, "abs"); -function!(Log, "log"); -function!(Log1p, "log1p"); -function!(Cos, "cos"); -function!(Sin, "sin"); -function!(Tanh, "tanh"); -function!(Sqrt, "sqrt"); -function!(Exp, "exp"); -function!(Erf, "erff"); -function!(Ceil, "ceil"); -function!(Floor, "floor"); - -pub struct Not; - -impl Unary for Not { - fn format_scalar( - f: &mut std::fmt::Formatter<'_>, - input: Input, - out: Out, - _elem: Elem, - ) -> std::fmt::Result - where - Input: Component, - Out: Component, - { - f.write_fmt(format_args!("{out} = !{input};\n")) - } -} - -pub struct Assign; - -impl Unary for Assign { - fn format_scalar( - f: &mut std::fmt::Formatter<'_>, - input: Input, - out: Out, - elem: Elem, - ) -> std::fmt::Result - where - Input: Component, - Out: Component, - { - // Cast only when necessary. - if elem != input.elem() { - f.write_fmt(format_args!("{out} = {elem}({input});\n")) - } else { - f.write_fmt(format_args!("{out} = {input};\n")) - } - } -} diff --git a/crates/burn-cuda/src/compiler/warp.rs b/crates/burn-cuda/src/compiler/warp.rs deleted file mode 100644 index 9f67de6a0b..0000000000 --- a/crates/burn-cuda/src/compiler/warp.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::fmt::Display; - -use super::Variable; - -#[derive(Clone, Debug)] -pub enum WarpInstruction { - ReduceSum { input: Variable, out: Variable }, - ReduceProd { input: Variable, out: Variable }, - ReduceMax { input: Variable, out: Variable }, - ReduceMin { input: Variable, out: Variable }, -} - -impl Display for WarpInstruction { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - WarpInstruction::ReduceSum { input, out } => f.write_fmt(format_args!( - " -{out} = {input}; - {{ - for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{ - {out} += __shfl_down_sync(0xFFFFFFFF, {out}, offset); - }} -}} - " - )), - WarpInstruction::ReduceProd { input, out } => f.write_fmt(format_args!( - " -{out} = {input}; - {{ - for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{ - {out} *= __shfl_down_sync(0xFFFFFFFF, {out}, offset); - }} -}} - " - )), - WarpInstruction::ReduceMax { input, out } => f.write_fmt(format_args!( - " -{out} = {input}; - {{ -for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{ - {out} = max({out}, __shfl_down_sync(0xFFFFFFFF, {out}, offset)); -}} -}} - " - )), - WarpInstruction::ReduceMin { input, out } => f.write_fmt(format_args!( - " -{out} = {input}; - {{ -for (int offset = warpSizeChecked / 2; offset > 0; offset /= 2) {{ - {out} = min({out}, __shfl_down_sync(0xFFFFFFFF, {out}, offset)); -}} -}} - " - )), - } - } -} diff --git a/crates/burn-cuda/src/compute/mod.rs b/crates/burn-cuda/src/compute/mod.rs deleted file mode 100644 index 4139c3868f..0000000000 --- a/crates/burn-cuda/src/compute/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod server; -mod storage; - -pub use server::*; -pub use storage::*; diff --git a/crates/burn-cuda/src/compute/server.rs b/crates/burn-cuda/src/compute/server.rs deleted file mode 100644 index d147f25974..0000000000 --- a/crates/burn-cuda/src/compute/server.rs +++ /dev/null @@ -1,306 +0,0 @@ -use super::storage::Binding; -use super::storage::CudaStorage; -use burn_compute::{ - memory_management::MemoryManagement, - server::{self, 
ComputeServer},
-};
-use burn_cube::ir::CubeDim;
-use burn_cube::prelude::*;
-use burn_cube::FeatureSet;
-use burn_jit::JitAutotuneKey;
-use burn_tensor::backend::SyncType;
-use burn_tensor::reader_from_concrete;
-use burn_tensor::Reader;
-use cudarc::driver::sys::CUctx_st;
-use cudarc::driver::sys::CUfunc_st;
-use std::collections::HashMap;
-use std::ffi::CStr;
-use std::ffi::CString;
-
-#[derive(Debug)]
-pub struct CudaServer<MM: MemoryManagement<CudaStorage>> {
-    state: CudaServerState<MM>,
-    pub(crate) archs: Vec<i32>,
-    pub(crate) minimum_arch_version: i32,
-}
-
-pub(crate) enum CudaServerState<MM: MemoryManagement<CudaStorage>> {
-    Uninitialized {
-        device_index: usize,
-        init: Box<dyn Fn(usize) -> CudaContext<MM>>,
-    },
-    Initialized {
-        ctx: CudaContext<MM>,
-    },
-}
-
-impl<MM: MemoryManagement<CudaStorage>> core::fmt::Debug for CudaServerState<MM> {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.write_str("Context")
-    }
-}
-
-#[derive(Debug)]
-pub(crate) struct CudaContext<MM: MemoryManagement<CudaStorage>> {
-    context: *mut CUctx_st,
-    stream: cudarc::driver::sys::CUstream,
-    memory_management: MM,
-    module_names: HashMap<String, CompiledKernel>,
-}
-
-#[derive(Debug)]
-struct CompiledKernel {
-    cube_dim: CubeDim,
-    shared_mem_bytes: usize,
-    func: *mut CUfunc_st,
-}
-
-unsafe impl<MM: MemoryManagement<CudaStorage>> Send for CudaServer<MM> {}
-
-impl<MM: MemoryManagement<CudaStorage>> CudaServer<MM> {
-    fn read_sync(&mut self, binding: server::Binding<Self>) -> Vec<u8> {
-        let ctx = self.get_context();
-        let resource = ctx.memory_management.get(binding.memory);
-
-        // TODO: Check if it is possible to make this faster
-        let mut data = vec![0; resource.size() as usize];
-        unsafe {
-            cudarc::driver::result::memcpy_dtoh_async(&mut data, resource.ptr, ctx.stream).unwrap();
-        };
-        ctx.sync();
-        data
-    }
-}
-
-impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
-    type Kernel = Box<dyn CubeTask>;
-    type DispatchOptions = CubeCount<Self>;
-    type Storage = CudaStorage;
-    type MemoryManagement = MM;
-    type AutotuneKey = JitAutotuneKey;
-    type FeatureSet = FeatureSet;
-
-    fn read(&mut self, binding: server::Binding<Self>) -> Reader {
-        reader_from_concrete(self.read_sync(binding))
-    }
-
-    fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
-        let ctx = self.get_context();
-        let handle = ctx.memory_management.reserve(data.len(), || unsafe {
-            cudarc::driver::result::stream::synchronize(ctx.stream).unwrap();
-        });
-        let handle = server::Handle::new(handle);
-        let binding = handle.clone().binding().memory;
-        let resource = ctx.memory_management.get(binding);
-
-        unsafe {
-            cudarc::driver::result::memcpy_htod_async(resource.ptr, data, ctx.stream).unwrap();
-        }
-
-        handle
-    }
-
-    fn empty(&mut self, size: usize) -> server::Handle<Self> {
-        let ctx = self.get_context();
-        let handle = ctx.memory_management.reserve(size, || unsafe {
-            cudarc::driver::result::stream::synchronize(ctx.stream).unwrap();
-        });
-        server::Handle::new(handle)
-    }
-
-    fn execute(
-        &mut self,
-        kernel: Self::Kernel,
-        count: Self::DispatchOptions,
-        bindings: Vec<server::Binding<Self>>,
-    ) {
-        let arch = self.minimum_arch_version;
-
-        let kernel_id = kernel.id();
-
-        let count = match count {
-            CubeCount::Static(x, y, z) => (x, y, z),
-            // TODO: CUDA doesn't have an exact equivalent of dynamic dispatch. Instead, kernels are free to launch other kernels.
-            // One option is to create a dummy kernel with 1 thread that launches the real kernel with the dynamic dispatch settings.
-            // For now, just read the dispatch settings from the buffer.
- CubeCount::Dynamic(binding) => { - let data = self.read_sync(binding); - let data = bytemuck::cast_slice(&data); - assert!( - data.len() == 3, - "Dynamic cube count should contain 3 values" - ); - (data[0], data[1], data[2]) - } - }; - - let ctx = self.get_context(); - - if !ctx.module_names.contains_key(&kernel_id) { - ctx.compile_kernel(&kernel_id, kernel, arch); - } - - let bindings = bindings - .into_iter() - .map(|binding| ctx.memory_management.get(binding.memory).as_binding()) - .collect(); - - ctx.execute_task(kernel_id, count, bindings); - // TODO: fix this - // self.memory_management.storage().perform_deallocations(); - } - - fn sync(&mut self, sync_type: SyncType) { - match sync_type { - // Synchronize the stream if waiting. - SyncType::Wait => { - let ctx = self.get_context(); - ctx.sync(); - } - // Nothing to do - all tasks are already submitted to the stream. - SyncType::Flush => (), - } - } - - fn get_resource( - &mut self, - binding: server::Binding, - ) -> ::Resource { - let ctx = self.get_context(); - ctx.memory_management.get(binding.memory) - } -} - -impl> CudaContext { - pub fn new( - memory_management: MM, - stream: cudarc::driver::sys::CUstream, - context: *mut CUctx_st, - ) -> Self { - Self { - context, - memory_management, - module_names: HashMap::new(), - stream, - } - } - - fn sync(&mut self) { - unsafe { - cudarc::driver::result::stream::synchronize(self.stream).unwrap(); - }; - } - - fn compile_kernel(&mut self, kernel_id: &str, kernel: Box, arch: i32) { - let kernel_compiled = kernel.compile(); - let shared_mem_bytes = kernel_compiled.shared_mem_bytes; - let cube_dim = kernel_compiled.cube_dim; - let arch = format!("--gpu-architecture=sm_{}", arch); - - #[cfg(target_os = "linux")] - let options = &[ - arch.as_str(), - "--include-path=/usr/include", - "--include-path=/usr/include/cuda", - "--include-path=/usr/local/include/cuda", - ]; - #[cfg(not(target_os = "linux"))] // TODO: add include-path for other OS. - let options = &[arch.as_str()]; - - let ptx = unsafe { - let program = cudarc::nvrtc::result::create_program(kernel_compiled.source).unwrap(); - if cudarc::nvrtc::result::compile_program(program, options).is_err() { - let log_raw = cudarc::nvrtc::result::get_program_log(program).unwrap(); - let log_ptr = log_raw.as_ptr(); - let log = CStr::from_ptr(log_ptr).to_str().unwrap(); - let mut message = "[Compilation Error] ".to_string(); - for line in log.split('\n') { - if !line.is_empty() { - message += format!("\n {line}").as_str(); - } - } - let source = kernel.compile().source; - panic!("{message}\n[Source] \n{source}"); - }; - cudarc::nvrtc::result::get_ptx(program).unwrap() - }; - - let func_name = CString::new("kernel".to_string()).unwrap(); - let func = unsafe { - let module = - cudarc::driver::result::module::load_data(ptx.as_ptr() as *const _).unwrap(); - cudarc::driver::result::module::get_function(module, func_name).unwrap() - }; - - self.module_names.insert( - kernel_id.to_string(), - CompiledKernel { - cube_dim, - shared_mem_bytes, - func, - }, - ); - } - - fn execute_task( - &mut self, - kernel_id: String, - dispatch_count: (u32, u32, u32), - mut bindings: Vec, - ) { - let kernel = self.module_names.get(&kernel_id).unwrap(); - let cube_dim = kernel.cube_dim; - unsafe { - cudarc::driver::result::launch_kernel( - kernel.func, - dispatch_count, - (cube_dim.x, cube_dim.y, cube_dim.z), - kernel.shared_mem_bytes as u32, - self.stream, - &mut bindings, - ) - .unwrap(); - }; - } -} - -impl> CudaServer { - /// Create a new cuda server. 
-    pub(crate) fn new(index: usize, init: Box<dyn Fn(usize) -> CudaContext<MM>>) -> Self {
-        let archs = unsafe {
-            let mut num_supported_arg: core::ffi::c_int = 0;
-            cudarc::nvrtc::sys::lib()
-                .nvrtcGetNumSupportedArchs(core::ptr::from_mut(&mut num_supported_arg));
-
-            let mut archs: Vec<i32> = vec![0; num_supported_arg as usize];
-            cudarc::nvrtc::sys::lib().nvrtcGetSupportedArchs(core::ptr::from_mut(&mut archs[0]));
-            archs
-        };
-
-        let minimum_arch_version = archs[0];
-
-        Self {
-            state: CudaServerState::Uninitialized {
-                device_index: index,
-                init,
-            },
-            archs,
-            minimum_arch_version,
-        }
-    }
-
-    fn get_context(&mut self) -> &mut CudaContext<MM> {
-        if let CudaServerState::Uninitialized { device_index, init } = &self.state {
-            let ctx = init(*device_index);
-            self.state = CudaServerState::Initialized { ctx };
-        }
-        if let CudaServerState::Initialized { ctx } = &mut self.state {
-            unsafe {
-                cudarc::driver::result::ctx::set_current(ctx.context).unwrap();
-            };
-            ctx
-        } else {
-            panic!("Context should be initialized");
-        }
-    }
-}
diff --git a/crates/burn-cuda/src/compute/storage.rs b/crates/burn-cuda/src/compute/storage.rs
deleted file mode 100644
index 4c373ed538..0000000000
--- a/crates/burn-cuda/src/compute/storage.rs
+++ /dev/null
@@ -1,132 +0,0 @@
-use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization};
-use cudarc::driver::sys::CUstream;
-use std::collections::HashMap;
-
-/// Buffer storage for cuda.
-pub struct CudaStorage {
-    memory: HashMap<StorageId, cudarc::driver::sys::CUdeviceptr>,
-    deallocations: Vec<StorageId>,
-    stream: cudarc::driver::sys::CUstream,
-}
-
-unsafe impl Send for CudaStorage {}
-
-impl core::fmt::Debug for CudaStorage {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.write_str(format!("CudaStorage {{ device: {:?} }}", self.stream).as_str())
-    }
-}
-
-/// Keeps the actual CUDA device pointers in a hashmap with ids as keys.
-impl CudaStorage {
-    /// Create a new storage on the given CUDA stream.
-    pub fn new(stream: CUstream) -> Self {
-        Self {
-            memory: HashMap::new(),
-            deallocations: Vec::new(),
-            stream,
-        }
-    }
-
-    /// Actually deallocates buffers tagged to be deallocated.
-    pub fn perform_deallocations(&mut self) {
-        for id in self.deallocations.drain(..) {
-            if let Some(ptr) = self.memory.remove(&id) {
-                unsafe {
-                    cudarc::driver::result::free_async(ptr, self.stream).unwrap();
-                }
-            }
-        }
-    }
-}
-
-/// The memory resource that can be allocated for CUDA.
-#[derive(new, Debug)]
-pub struct CudaResource {
-    /// The CUDA device pointer.
-    pub ptr: u64,
-    pub binding: *mut std::ffi::c_void,
-    /// How the resource is used.
-    pub kind: CudaResourceKind,
-}
-
-unsafe impl Send for CudaResource {}
-
-pub type Binding = *mut std::ffi::c_void;
-
-impl CudaResource {
-    /// Return the binding view of the buffer.
-    pub fn as_binding(&self) -> Binding {
-        self.binding
-    }
-
-    /// Return the buffer size.
-    pub fn size(&self) -> u64 {
-        match self.kind {
-            CudaResourceKind::Full { size } => size as u64,
-            CudaResourceKind::Slice { size, offset: _ } => size as u64,
-        }
-    }
-
-    /// Return the buffer offset.
-    pub fn offset(&self) -> u64 {
-        match self.kind {
-            CudaResourceKind::Full { size: _ } => 0,
-            CudaResourceKind::Slice { size: _, offset } => offset as u64,
-        }
-    }
-}
-
-/// How the resource is used, either as a slice or fully.
-#[derive(Debug)]
-pub enum CudaResourceKind {
-    /// Represents an entire buffer.
-    Full { size: usize },
-    /// A slice over a buffer.
- Slice { size: usize, offset: usize }, -} - -impl ComputeStorage for CudaStorage { - type Resource = CudaResource; - - fn get(&mut self, handle: &StorageHandle) -> Self::Resource { - let ptr = self.memory.get(&handle.id).unwrap(); - match handle.utilization { - StorageUtilization::Full(size) => CudaResource::new( - *ptr, - ptr as *const cudarc::driver::sys::CUdeviceptr as *mut std::ffi::c_void, - CudaResourceKind::Full { size }, - ), - StorageUtilization::Slice { offset, size } => CudaResource::new( - *ptr, - ptr as *const cudarc::driver::sys::CUdeviceptr as *mut std::ffi::c_void, - CudaResourceKind::Slice { size, offset }, - ), - } - } - - fn alloc(&mut self, size: usize) -> StorageHandle { - let id = StorageId::new(); - let ptr = unsafe { cudarc::driver::result::malloc_async(self.stream, size).unwrap() }; - self.memory.insert(id.clone(), ptr); - StorageHandle::new(id, StorageUtilization::Full(size)) - } - - fn dealloc(&mut self, id: StorageId) { - self.deallocations.push(id); - } - - fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) { - let num_bytes = from.size(); - - unsafe { - cudarc::driver::result::memcpy_dtod_async( - self.get(to).ptr, - self.get(from).ptr, - num_bytes, - self.stream, - ) - .unwrap(); - } - } -} diff --git a/crates/burn-cuda/src/device.rs b/crates/burn-cuda/src/device.rs deleted file mode 100644 index 04a6014436..0000000000 --- a/crates/burn-cuda/src/device.rs +++ /dev/null @@ -1,12 +0,0 @@ -use burn_tensor::backend::{DeviceId, DeviceOps}; - -#[derive(new, Clone, Debug, PartialEq, Eq, Default, Hash)] -pub struct CudaDevice { - pub index: usize, -} - -impl DeviceOps for CudaDevice { - fn id(&self) -> DeviceId { - DeviceId::new(0, self.index as u32) - } -} diff --git a/crates/burn-cuda/src/lib.rs b/crates/burn-cuda/src/lib.rs index 8f5d181cdf..23958db2e4 100644 --- a/crates/burn-cuda/src/lib.rs +++ b/crates/burn-cuda/src/lib.rs @@ -1,16 +1,8 @@ -#[macro_use] -extern crate derive_new; extern crate alloc; -mod compute; -mod device; -mod runtime; - -pub mod compiler; -pub use device::*; - use burn_jit::JitBackend; -pub use runtime::CudaRuntime; +pub use cubecl::cuda::CudaDevice; +use cubecl::cuda::CudaRuntime; #[cfg(not(feature = "fusion"))] pub type Cuda = JitBackend; @@ -20,10 +12,9 @@ pub type Cuda = burn_fusion::Fusion>; -} - -static RUNTIME: ComputeRuntime> = - ComputeRuntime::new(); - -type Server = CudaServer>; - -impl Runtime for CudaRuntime { - type Compiler = CudaCompiler; - type Server = CudaServer>; - - type Channel = MutexComputeChannel>>; - type Device = CudaDevice; - - fn client(device: &Self::Device) -> ComputeClient { - fn init(index: usize) -> CudaContext> { - cudarc::driver::result::init().unwrap(); - let device_ptr = cudarc::driver::result::device::get(index as i32).unwrap(); - - let ctx = unsafe { - let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap(); - cudarc::driver::result::ctx::set_current(ctx).unwrap(); - ctx - }; - - let stream = cudarc::driver::result::stream::create( - cudarc::driver::result::stream::StreamKind::NonBlocking, - ) - .unwrap(); - let storage = CudaStorage::new(stream); - let memory_management = SimpleMemoryManagement::new( - storage, - DeallocStrategy::new_period_tick(1), - SliceStrategy::Ratio(0.8), - ); - CudaContext::new(memory_management, stream, ctx) - } - - RUNTIME.client(device, move || { - let mut server = CudaServer::new(device.index, Box::new(init)); - let mut features = FeatureSet::new(&[Feature::Subcube]); - let tuner_device_id = tuner_device_id(); - - if let 
diff --git a/crates/burn-cuda/src/device.rs b/crates/burn-cuda/src/device.rs
deleted file mode 100644
index 04a6014436..0000000000
--- a/crates/burn-cuda/src/device.rs
+++ /dev/null
@@ -1,12 +0,0 @@
-use burn_tensor::backend::{DeviceId, DeviceOps};
-
-#[derive(new, Clone, Debug, PartialEq, Eq, Default, Hash)]
-pub struct CudaDevice {
-    pub index: usize,
-}
-
-impl DeviceOps for CudaDevice {
-    fn id(&self) -> DeviceId {
-        DeviceId::new(0, self.index as u32)
-    }
-}
diff --git a/crates/burn-cuda/src/lib.rs b/crates/burn-cuda/src/lib.rs
index 8f5d181cdf..23958db2e4 100644
--- a/crates/burn-cuda/src/lib.rs
+++ b/crates/burn-cuda/src/lib.rs
@@ -1,16 +1,8 @@
-#[macro_use]
-extern crate derive_new;
 extern crate alloc;
 
-mod compute;
-mod device;
-mod runtime;
-
-pub mod compiler;
-pub use device::*;
-
 use burn_jit::JitBackend;
-pub use runtime::CudaRuntime;
+pub use cubecl::cuda::CudaDevice;
+use cubecl::cuda::CudaRuntime;
 
 #[cfg(not(feature = "fusion"))]
 pub type Cuda<F = f32, I = i32> = JitBackend<CudaRuntime, F, I>;
@@ -20,10 +12,9 @@
 pub type Cuda<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<CudaRuntime, F, I>>;
-}
-
diff --git a/crates/burn-cuda/src/runtime.rs b/crates/burn-cuda/src/runtime.rs
deleted file mode 100644
-static RUNTIME: ComputeRuntime<CudaDevice, Server, MutexComputeChannel<Server>> =
-    ComputeRuntime::new();
-
-type Server = CudaServer<SimpleMemoryManagement<CudaStorage>>;
-
-impl Runtime for CudaRuntime {
-    type Compiler = CudaCompiler;
-    type Server = CudaServer<SimpleMemoryManagement<CudaStorage>>;
-
-    type Channel = MutexComputeChannel<CudaServer<SimpleMemoryManagement<CudaStorage>>>;
-    type Device = CudaDevice;
-
-    fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
-        fn init(index: usize) -> CudaContext<SimpleMemoryManagement<CudaStorage>> {
-            cudarc::driver::result::init().unwrap();
-            let device_ptr = cudarc::driver::result::device::get(index as i32).unwrap();
-
-            let ctx = unsafe {
-                let ctx = cudarc::driver::result::primary_ctx::retain(device_ptr).unwrap();
-                cudarc::driver::result::ctx::set_current(ctx).unwrap();
-                ctx
-            };
-
-            let stream = cudarc::driver::result::stream::create(
-                cudarc::driver::result::stream::StreamKind::NonBlocking,
-            )
-            .unwrap();
-            let storage = CudaStorage::new(stream);
-            let memory_management = SimpleMemoryManagement::new(
-                storage,
-                DeallocStrategy::new_period_tick(1),
-                SliceStrategy::Ratio(0.8),
-            );
-            CudaContext::new(memory_management, stream, ctx)
-        }
-
-        RUNTIME.client(device, move || {
-            let mut server = CudaServer::new(device.index, Box::new(init));
-            let mut features = FeatureSet::new(&[Feature::Subcube]);
-            let tuner_device_id = tuner_device_id();
-
-            if let Some(wmma_minimum_version) = register_wmma_features(&mut features, &server.archs)
-            {
-                server.minimum_arch_version =
-                    i32::max(server.minimum_arch_version, wmma_minimum_version);
-            }
-
-            ComputeClient::new(
-                MutexComputeChannel::new(server),
-                Arc::new(RwLock::new(Tuner::new("cuda", &tuner_device_id))),
-                Arc::new(features),
-            )
-        })
-    }
-
-    fn name() -> &'static str {
-        "cuda"
-    }
-
-    fn require_array_lengths() -> bool {
-        true
-    }
-}
-
-fn register_wmma_features(features: &mut FeatureSet, archs: &[i32]) -> Option<i32> {
-    let wmma_minimum_version = 70;
-    let mut wmma = false;
-
-    for arch in archs {
-        if *arch >= wmma_minimum_version {
-            wmma = true;
-            break;
-        }
-    }
-
-    if wmma {
-        // Types fully supported.
-        for (a, b, c) in [
-            (
-                Elem::Float(FloatKind::F16),
-                Elem::Float(FloatKind::F16),
-                Elem::Float(FloatKind::F16),
-            ),
-            (
-                Elem::Float(FloatKind::F16),
-                Elem::Float(FloatKind::F16),
-                Elem::Float(FloatKind::F32),
-            ),
-            (
-                Elem::Float(FloatKind::BF16),
-                Elem::Float(FloatKind::BF16),
-                Elem::Float(FloatKind::F32),
-            ),
-        ] {
-            features.register(Feature::Cmma {
-                a,
-                b,
-                c,
-                m: 16,
-                k: 16,
-                n: 16,
-            });
-            features.register(Feature::Cmma {
-                a,
-                b,
-                c,
-                m: 32,
-                k: 8,
-                n: 16,
-            });
-            features.register(Feature::Cmma {
-                a,
-                b,
-                c,
-                m: 8,
-                k: 32,
-                n: 16,
-            });
-        }
-        return Some(wmma_minimum_version);
-    }
-
-    None
-}
-fn tuner_device_id() -> String {
-    "cuda".into()
-}
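The deleted runtime only registers tensor-core (CMMA) fragment shapes when NVRTC reports an architecture of at least sm_70, the threshold hard-coded in `register_wmma_features` above. A tiny self-contained sketch of that gate (the `FeatureSet` registration itself is elided):

```rust
/// Returns whether any reported compute architecture supports WMMA/CMMA,
/// using the same sm_70 (Volta) threshold as the deleted code above.
fn wmma_supported(archs: &[i32]) -> bool {
    const WMMA_MINIMUM_ARCH: i32 = 70;
    archs.iter().any(|&arch| arch >= WMMA_MINIMUM_ARCH)
}

fn main() {
    assert!(wmma_supported(&[50, 70, 80]));
    assert!(!wmma_supported(&[50, 52, 60]));
}
```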
diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml
index c0eee0a635..d0c1e1f8cd 100644
--- a/crates/burn-jit/Cargo.toml
+++ b/crates/burn-jit/Cargo.toml
@@ -11,8 +11,8 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-jit"
 version.workspace = true
 
 [features]
-default = ["autotune", "std", "burn-compute/default", "fusion"]
-std = []
+default = ["autotune", "std", "fusion", "cubecl/default"]
+std = ["cubecl/std"]
 doc = ["default"]
 autotune = []
 template = []
@@ -27,9 +27,9 @@ export_tests = [
 ]
 
 [dependencies]
+cubecl = { workspace = true, features = ["linalg"] }
 burn-common = { path = "../burn-common", version = "0.14.0" }
-burn-tensor = { path = "../burn-tensor", version = "0.14.0" }
-burn-cube = { path = "../burn-cube", version = "0.14.0" }
+burn-tensor = { path = "../burn-tensor", version = "0.14.0", features = ["cubecl"] }
 burn-fusion = { path = "../burn-fusion", version = "0.14.0", optional = true }
 
 bytemuck = { workspace = true }
@@ -45,10 +45,6 @@ serde = { workspace = true }
 text_placeholder = { workspace = true, features = ["struct_context"] }
 hashbrown = { workspace = true }
 
-burn-compute = { path = "../burn-compute", version = "0.14.0", default-features = false, features = [
-    "channel-mutex",
-    "std",
-] }
 burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.14.0", optional = true }
 
 # When exporting tests
diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs
index 7eb1cc9305..d3945d3c54 100644
--- a/crates/burn-jit/src/backend.rs
+++ b/crates/burn-jit/src/backend.rs
@@ -1,9 +1,9 @@
 use crate::{
     tensor::{JitTensor, QJitTensor},
-    FloatElement, IntElement, JitAutotuneKey, JitRuntime, PrecisionBridge,
+    FloatElement, IntElement, JitRuntime, PrecisionBridge,
 };
-use burn_compute::server::ComputeServer;
-use burn_tensor::backend::{Backend, SyncType};
+use burn_tensor::backend::{Backend, DeviceOps, SyncType};
+use cubecl::server::ComputeServer;
 use rand::{rngs::StdRng, SeedableRng};
 use std::{marker::PhantomData, sync::Mutex};
 
@@ -20,7 +20,7 @@ pub struct JitBackend
 impl<R, F, I> Backend for JitBackend<R, F, I>
 where
     R: JitRuntime,
-    R::Server: ComputeServer<AutotuneKey = JitAutotuneKey>,
+    R::Server: ComputeServer,
     R::Device: burn_tensor::backend::DeviceOps,
     F: FloatElement,
     I: IntElement,
@@ -51,8 +51,12 @@ where
     }
 
     fn sync(device: &Self::Device, sync_type: SyncType) {
+        let sync = match sync_type {
+            SyncType::Flush => cubecl::client::SyncType::Flush,
+            SyncType::Wait => cubecl::client::SyncType::Wait,
+        };
         let client = R::client(device);
-        client.sync(sync_type);
+        client.sync(sync);
     }
 }
 
@@ -73,3 +77,11 @@ impl Default for JitBackend
+
+impl<R: cubecl::Runtime> JitRuntime for R
+where
+    R::Device: DeviceOps,
+{
+    type JitDevice = R::Device;
+    type JitServer = R::Server;
+}
diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs
index a59ea36e96..def5e5e4ba 100644
--- a/crates/burn-jit/src/element.rs
+++ b/crates/burn-jit/src/element.rs
@@ -1,4 +1,4 @@
-use burn_cube::{
+use cubecl::{
     frontend::{Float, Int, Numeric, UInt, BF16, F16, F32, I32},
     CubeElement,
 };
diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs
index 97177fed8a..df79f1fd44 100644
--- a/crates/burn-jit/src/fusion/base.rs
+++ b/crates/burn-jit/src/fusion/base.rs
@@ -3,11 +3,11 @@ use crate::{
     element::JitElement, fusion::ElementWiseBuilder, kernel, tensor::JitTensor, FloatElement,
     IntElement, JitBackend, JitRuntime,
 };
-use burn_compute::client::ComputeClient;
-use burn_cube::{ir::ReadingStrategy, InplaceMapping, KernelExpansion, KernelSettings};
 use burn_fusion::{client::MutexFusionClient, FusionBackend, FusionRuntime};
 use burn_tensor::{repr::ReprBackend, Shape};
 use core::marker::PhantomData;
+use cubecl::client::ComputeClient;
+use cubecl::{ir::ReadingStrategy, InplaceMapping, KernelExpansion, KernelSettings};
 use half::{bf16, f16};
 use serde::{Deserialize, Serialize};
 
@@ -162,7 +162,7 @@ pub struct JitFusionHandle
     /// Compute client for jit.
     pub client: ComputeClient<R::Server, R::Channel>,
     /// The buffer where the data are stored.
-    pub handle: burn_compute::server::Handle<R::Server>,
+    pub handle: cubecl::server::Handle<R::Server>,
     /// The device of the current tensor.
     pub device: R::Device,
     pub(crate) strides: Vec<usize>,
diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs
index 1693310623..f30ccbe8f9 100644
--- a/crates/burn-jit/src/fusion/elemwise/builder.rs
+++ b/crates/burn-jit/src/fusion/elemwise/builder.rs
@@ -4,9 +4,6 @@ use crate::{
     fusion::{tracing::TraceBuilder, JitOptimization},
     JitRuntime,
 };
-use burn_cube::ir::{
-    BinaryOperator, ConditionalAssign, Operator, Procedure, UnaryOperator, Variable,
-};
 use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus};
 use burn_tensor::{
     repr::{
@@ -16,6 +13,10 @@ use burn_tensor::{
     },
     Element,
 };
+use cubecl::ir::{
+    BinaryOperator, ConditionalAssign, ConstantScalarValue, Elem, Operator, Procedure,
+    UnaryOperator, Variable,
+};
 
 /// Fused element wise operations that are normally memory bound.
 pub(crate) struct ElementWiseBuilder<R: JitRuntime> {
@@ -288,10 +289,14 @@ impl ElementWiseBuilder
             return false;
         }
 
-        let input = Variable::ConstantScalar {
-            value: 1.0,
-            elem: desc.dtype.into(),
+        let elem: Elem = desc.dtype.into();
+        let input = match elem {
+            Elem::Float(kind) => ConstantScalarValue::Float(1.0, kind),
+            Elem::Int(kind) => ConstantScalarValue::Int(1, kind),
+            Elem::UInt => ConstantScalarValue::UInt(1),
+            Elem::Bool => ConstantScalarValue::Bool(true),
         };
+        let input = Variable::ConstantScalar(input);
         let out = self.builder.output(desc, Variable::AbsolutePos);
 
         self.builder
@@ -304,10 +309,14 @@ impl ElementWiseBuilder
             return false;
         }
 
-        let input = Variable::ConstantScalar {
-            value: 0.0,
-            elem: desc.dtype.into(),
+        let elem: Elem = desc.dtype.into();
+        let input = match elem {
+            Elem::Float(kind) => ConstantScalarValue::Float(0.0, kind),
+            Elem::Int(kind) => ConstantScalarValue::Int(0, kind),
+            Elem::UInt => ConstantScalarValue::UInt(0),
+            Elem::Bool => ConstantScalarValue::Bool(false),
         };
+        let input = Variable::ConstantScalar(input);
         let out = self.builder.output(desc, Variable::AbsolutePos);
 
         self.builder
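The builder change above replaces the old untyped `ConstantScalar { value: f64, elem }` variable with cubecl's typed `ConstantScalarValue`, so a fused "ones"/"zeros" constant now carries the element kind it will be emitted as. A self-contained sketch of the idea, with simplified stand-ins for `cubecl::ir::Elem` and `ConstantScalarValue` (the real enums additionally carry `FloatKind`/`IntKind` parameters):

```rust
#[derive(Clone, Copy)]
enum Elem {
    Float,
    Int,
    UInt,
    Bool,
}

#[derive(Debug)]
enum ConstantScalarValue {
    Float(f64),
    Int(i64),
    UInt(u64),
    Bool(bool),
}

/// Build the typed "one" constant for an output element, as the new
/// ones-path above does per `Elem` kind.
fn one(elem: Elem) -> ConstantScalarValue {
    match elem {
        Elem::Float => ConstantScalarValue::Float(1.0),
        Elem::Int => ConstantScalarValue::Int(1),
        Elem::UInt => ConstantScalarValue::UInt(1),
        Elem::Bool => ConstantScalarValue::Bool(true),
    }
}

fn main() {
    println!("{:?}", one(Elem::Int)); // Int(1)
}
```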
diff --git a/crates/burn-jit/src/fusion/elemwise/kernel.rs b/crates/burn-jit/src/fusion/elemwise/kernel.rs
index a58c22d9f7..7d9c4714dc 100644
--- a/crates/burn-jit/src/fusion/elemwise/kernel.rs
+++ b/crates/burn-jit/src/fusion/elemwise/kernel.rs
@@ -1,8 +1,8 @@
-use burn_cube::{
+use burn_tensor::repr::TensorDescription;
+use cubecl::{
     calculate_cube_count_elemwise, calculate_num_elems_dyn_rank,
     ir::CubeDim,
     KernelExpansion, KernelSettings,
 };
-use burn_tensor::repr::TensorDescription;
 
 use crate::{
     fusion::{
diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs
index 86d641a33e..65bc5b8d7a 100644
--- a/crates/burn-jit/src/fusion/elemwise/optimization.rs
+++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs
@@ -7,12 +7,15 @@ use super::{
 use crate::{
     fusion::{kernel::FusionKernel, tracing::Trace, JitFusionHandle},
     tune_key::JitAutotuneKey,
-    JitRuntime,
+    JitRuntime, JitTuneId,
 };
 use burn_common::id::IdGenerator;
-use burn_compute::client::ComputeClient;
-use burn_cube::ir::CubeDim;
 use burn_fusion::stream::Context;
+use cubecl::ir::CubeDim;
+use cubecl::{
+    client::ComputeClient,
+    tune::{local_tuner, LocalTuner},
+};
 use serde::{Deserialize, Serialize};
 
 #[derive(new)]
@@ -71,10 +74,14 @@ impl ElementWise
             self.autotune_shape(context),
         ));
 
-        if let Some(index) = client.autotune_result(&key) {
+        let id = JitTuneId::new::<R>(&self.device);
+
+        static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();
+
+        if let Some(index) = TUNER.autotune_result(&id, &key) {
             self.run_kernel(context, client, index)
         } else {
-            self.run_autotune(context, client, key)
+            self.run_autotune(context, client, id, key, &TUNER)
         }
     }
@@ -107,7 +114,9 @@
         &mut self,
         context: &mut Context<'_, JitFusionHandle<R>>,
         client: ComputeClient<R::Server, R::Channel>,
+        id: JitTuneId,
         key: JitAutotuneKey,
+        tuner: &LocalTuner<JitAutotuneKey, JitTuneId>,
     ) {
         let info = self.trace.running();
 
@@ -136,12 +145,16 @@
             false,
         );
 
-        client.autotune_execute(Box::new(ElementWiseAutotuneOperationSet::new(
-            key,
-            kernel_1.into(),
-            kernel_2.into(),
-            kernel_default.into(),
-        )));
+        tuner.execute(
+            &id,
+            &client,
+            Box::new(ElementWiseAutotuneOperationSet::new(
+                key,
+                kernel_1.into(),
+                kernel_2.into(),
+                kernel_default.into(),
+            )),
+        );
     }
 
     pub(crate) fn len(&self) -> usize {
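With burn-compute gone, autotune results no longer live on the compute client: the hunk above caches them in a process-wide `LocalTuner` (from `cubecl::tune`), keyed by a device-level `JitTuneId` plus the autotune key. A minimal sketch of that lookup-or-benchmark flow, using plain `String` keys and a `usize` "fastest kernel index" as stand-ins for the real cubecl types:

```rust
use std::collections::HashMap;

struct Tuner {
    // (tune id, autotune key) -> index of the fastest kernel found so far.
    results: HashMap<(String, String), usize>,
}

impl Tuner {
    fn autotune_result(&self, id: &str, key: &str) -> Option<usize> {
        self.results.get(&(id.to_owned(), key.to_owned())).copied()
    }

    fn execute(&mut self, id: &str, key: &str, benchmark: impl FnOnce() -> usize) -> usize {
        // Only benchmark on a cache miss, then remember the winner.
        if let Some(index) = self.autotune_result(id, key) {
            return index;
        }
        let index = benchmark();
        self.results.insert((id.to_owned(), key.to_owned()), index);
        index
    }
}

fn main() {
    let mut tuner = Tuner { results: HashMap::new() };
    let first = tuner.execute("cuda-0", "elemwise-1024", || 2); // runs the benchmark
    let second = tuner.execute("cuda-0", "elemwise-1024", || unreachable!()); // cache hit
    assert_eq!(first, second);
}
```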
diff --git a/crates/burn-jit/src/fusion/elemwise/tune.rs b/crates/burn-jit/src/fusion/elemwise/tune.rs
index 10be68fbb3..10135f37de 100644
--- a/crates/burn-jit/src/fusion/elemwise/tune.rs
+++ b/crates/burn-jit/src/fusion/elemwise/tune.rs
@@ -3,7 +3,7 @@ use std::fmt::Display;
 use crate::{
     fusion::kernel::AutotunableKernel, tune::anchor, tune_key::JitAutotuneKey, JitRuntime,
 };
-use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet};
+use cubecl::tune::{AutotuneOperation, AutotuneOperationSet};
 use serde::{Deserialize, Serialize};
 
 #[derive(new)]
@@ -38,7 +38,7 @@ impl AutotuneOperationSet for ElementWiseAutotune
         self.key.clone()
     }
 
-    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
+    fn autotunables(&self) -> Vec<Box<dyn AutotuneOperation>> {
         let kernel_1: Box<dyn AutotuneOperation> = self.kernel_1.clone();
         let kernel_2: Box<dyn AutotuneOperation> = self.kernel_2.clone();
diff --git a/crates/burn-jit/src/fusion/kernel.rs b/crates/burn-jit/src/fusion/kernel.rs
index c7a766ac10..4fc17882f3 100644
--- a/crates/burn-jit/src/fusion/kernel.rs
+++ b/crates/burn-jit/src/fusion/kernel.rs
@@ -1,16 +1,16 @@
-use burn_cube::calculate_num_elems_dyn_rank;
-use burn_cube::prelude::*;
+use cubecl::calculate_num_elems_dyn_rank;
+use cubecl::prelude::*;
 
 use crate::fusion::strides_dyn_rank;
 use crate::fusion::JitFusionHandle;
 use crate::kernel::Kernel;
 use crate::JitRuntime;
-use burn_compute::client::ComputeClient;
-use burn_compute::server::Binding;
-use burn_compute::tune::AutotuneOperation;
 use burn_fusion::stream::Context;
 use burn_tensor::repr::TensorDescription;
 use burn_tensor::repr::TensorStatus;
+use cubecl::client::ComputeClient;
+use cubecl::server::Binding;
+use cubecl::tune::AutotuneOperation;
 use std::marker::PhantomData;
 use std::sync::Arc;
diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs
index efb86ad11a..e214f7f8ef 100644
--- a/crates/burn-jit/src/fusion/tracing/builder.rs
+++ b/crates/burn-jit/src/fusion/tracing/builder.rs
@@ -1,12 +1,12 @@
 use super::{trace::Trace, Scalars};
-use burn_cube::ir::{
-    BinaryOperator, Elem, Item, Operation, Operator, Procedure, Scope, Subcube, UnaryOperator,
-    Variable,
-};
 use burn_tensor::{
     repr::{TensorDescription, TensorId, TensorStatus},
     Element,
 };
+use cubecl::ir::{
+    BinaryOperator, Elem, Item, Operation, Operator, Procedure, Scope, Subcube, UnaryOperator,
+    Variable,
+};
 use hashbrown::HashMap;
 
 /// Type facilitating building a [trace](Trace) by doing most of the conversions between the
diff --git a/crates/burn-jit/src/fusion/tracing/trace.rs b/crates/burn-jit/src/fusion/tracing/trace.rs
index 42b3645965..2d008f52c7 100644
--- a/crates/burn-jit/src/fusion/tracing/trace.rs
+++ b/crates/burn-jit/src/fusion/tracing/trace.rs
@@ -1,9 +1,9 @@
 use super::Scalars;
-use burn_cube::{
+use burn_tensor::repr::TensorDescription;
+use cubecl::{
     ir::{Elem, FloatKind, IntKind, Item, Scope, Variable, Visibility},
     InputInfo, KernelExpansion, OutputInfo,
 };
-use burn_tensor::repr::TensorDescription;
 use serde::{Deserialize, Serialize};
 
 /// A trace encapsulates all information necessary to perform the compilation and execution of
diff --git a/crates/burn-jit/src/kernel/binary.rs b/crates/burn-jit/src/kernel/binary.rs
index 1fb1816051..39c82b7f62 100644
--- a/crates/burn-jit/src/kernel/binary.rs
+++ b/crates/burn-jit/src/kernel/binary.rs
@@ -1,6 +1,6 @@
 use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
-use burn_cube::{frontend::TensorHandle, CubeCountSettings, Execution};
 use burn_tensor::Shape;
+use cubecl::{frontend::TensorHandleRef, CubeCountSettings, Execution};
 
 /// Creates a binary kernel.
 #[macro_export]
@@ -50,50 +50,50 @@ macro_rules!
binary { #[allow(clippy::redundant_closure_call)] fn compile( - settings: burn_cube::KernelSettings, - ) -> burn_cube::ir::KernelDefinition + settings: cubecl::KernelSettings, + ) -> cubecl::ir::KernelDefinition where I: $crate::element::JitElement, O: $crate::element::JitElement { - let mut scope = burn_cube::ir::Scope::root(); - let position = burn_cube::ir::Variable::AbsolutePos; + let mut scope = cubecl::ir::Scope::root(); + let position = cubecl::ir::Variable::AbsolutePos; let op = $ops(&mut scope, I::cube_elem(), position); scope.register(op); let local = scope.last_local_index().unwrap().index().unwrap(); - let lhs = burn_cube::InputInfo::Array { - item: burn_cube::ir::Item::new(I::cube_elem()), - visibility: burn_cube::ir::Visibility::Read, + let lhs = cubecl::InputInfo::Array { + item: cubecl::ir::Item::new(I::cube_elem()), + visibility: cubecl::ir::Visibility::Read, }; - let rhs = burn_cube::InputInfo::Array { - item: burn_cube::ir::Item::new(I::cube_elem()), - visibility: burn_cube::ir::Visibility::Read, + let rhs = cubecl::InputInfo::Array { + item: cubecl::ir::Item::new(I::cube_elem()), + visibility: cubecl::ir::Visibility::Read, }; - let out = burn_cube::OutputInfo::ArrayWrite { - item: burn_cube::ir::Item::new(O::cube_elem()), + let out = cubecl::OutputInfo::ArrayWrite { + item: cubecl::ir::Item::new(O::cube_elem()), local, position, }; - let info = burn_cube::prelude::KernelExpansion { + let info = cubecl::prelude::KernelExpansion { inputs: vec![lhs, rhs], outputs: vec![out], scope, }; - burn_cube::prelude::KernelIntegrator::new(info).integrate(settings) + cubecl::prelude::KernelIntegrator::new(info).integrate(settings) } #[allow(clippy::redundant_closure_call)] impl $crate::kernel::Kernel for Ops where - C: burn_cube::Compiler, + C: cubecl::Compiler, I: $crate::element::JitElement, O: $crate::element::JitElement { - fn define(&self) -> burn_cube::ir::KernelDefinition { - let settings = burn_cube::KernelSettings::default(); + fn define(&self) -> cubecl::ir::KernelDefinition { + let settings = cubecl::KernelSettings::default(); compile::(settings) } } @@ -102,16 +102,16 @@ macro_rules! binary { impl $crate::kernel::Kernel for OpsInplaceLhs where - C: burn_cube::Compiler, + C: cubecl::Compiler, I: $crate::element::JitElement, O: $crate::element::JitElement { - fn define(&self) -> burn_cube::ir::KernelDefinition { - let mapping = burn_cube::InplaceMapping { + fn define(&self) -> cubecl::ir::KernelDefinition { + let mapping = cubecl::InplaceMapping { pos_input: 0, pos_output: 0, }; - let settings = burn_cube::KernelSettings::default() + let settings = cubecl::KernelSettings::default() .inplace(vec![mapping]); compile::(settings) } @@ -121,16 +121,16 @@ macro_rules! 
binary { impl $crate::kernel::Kernel for OpsInplaceRhs where - C: burn_cube::Compiler, + C: cubecl::Compiler, I: $crate::element::JitElement, O: $crate::element::JitElement { - fn define(&self) -> burn_cube::ir::KernelDefinition { - let mapping = burn_cube::InplaceMapping { + fn define(&self) -> cubecl::ir::KernelDefinition { + let mapping = cubecl::InplaceMapping { pos_input: 1, pos_output: 0, }; - let settings = burn_cube::KernelSettings::default() + let settings = cubecl::KernelSettings::default() .inplace(vec![mapping]); compile::(settings) } @@ -156,8 +156,8 @@ where if inplace_enabled && lhs.can_mut_broadcast(&rhs) { Execution::start(kernel_inplace_lhs, rhs.client) .inputs(&[ - TensorHandle::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), - TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), + TensorHandleRef::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + TensorHandleRef::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), ]) .execute(CubeCountSettings::Input { pos: 0 }); @@ -165,8 +165,8 @@ where } else if inplace_enabled && rhs.can_mut_broadcast(&lhs) { Execution::start(kernel_inplace_rhs, lhs.client) .inputs(&[ - TensorHandle::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), - TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), + TensorHandleRef::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + TensorHandleRef::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), ]) .execute(CubeCountSettings::Input { pos: 1 }); @@ -189,10 +189,10 @@ where Execution::start(kernel, lhs.client) .inputs(&[ - TensorHandle::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), - TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), + TensorHandleRef::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), + TensorHandleRef::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), ]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &out.handle, &out.strides, &out.shape.dims, diff --git a/crates/burn-jit/src/kernel/cast/base.rs b/crates/burn-jit/src/kernel/cast/base.rs index 944a1abad7..acac2567f0 100644 --- a/crates/burn-jit/src/kernel/cast/base.rs +++ b/crates/burn-jit/src/kernel/cast/base.rs @@ -1,6 +1,6 @@ -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -35,12 +35,12 @@ pub fn cast( ); Execution::start(kernel, tensor.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &tensor.handle, &tensor.strides, &tensor.shape.dims, )]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/cast/bool_cast.rs b/crates/burn-jit/src/kernel/cast/bool_cast.rs index f68ee0b8a9..a4e9a2c0f5 100644 --- a/crates/burn-jit/src/kernel/cast/bool_cast.rs +++ b/crates/burn-jit/src/kernel/cast/bool_cast.rs @@ -1,7 +1,7 @@ use crate::{kernel::Kernel, tensor::JitTensor, JitElement, JitRuntime}; -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, Item, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -28,12 +28,12 @@ pub fn bool_cast( ); Execution::start(kernel, tensor.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &tensor.handle, &tensor.strides, &tensor.shape.dims, )]) - 
.outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/clamp.rs b/crates/burn-jit/src/kernel/clamp.rs index 3a77d37413..58a711610e 100644 --- a/crates/burn-jit/src/kernel/clamp.rs +++ b/crates/burn-jit/src/kernel/clamp.rs @@ -1,4 +1,4 @@ -use burn_cube::prelude::*; +use cubecl::prelude::*; use crate::kernel::{launch_unary, UnaryOp}; use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; diff --git a/crates/burn-jit/src/kernel/comparison.rs b/crates/burn-jit/src/kernel/comparison.rs index 8f3bf75394..0ecec95e2f 100644 --- a/crates/burn-jit/src/kernel/comparison.rs +++ b/crates/burn-jit/src/kernel/comparison.rs @@ -1,10 +1,10 @@ use super::{index_offset_with_layout, Kernel}; use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use burn_cube::{ +use burn_tensor::Shape; +use cubecl::{ calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, Runtime, SUBCUBE_DIM_APPROX, }; -use burn_tensor::Shape; #[cube] pub(crate) trait ComparisonOp: 'static + Send + Sync { @@ -147,7 +147,7 @@ pub(crate) fn launch_cmp< let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); if same_tensor_type && lhs.can_mut_broadcast(&rhs) { kernel_cmp::launch::( - client, + &client, cube_count, CubeDim::default(), TensorArg::vectorized( @@ -171,7 +171,7 @@ pub(crate) fn launch_cmp< JitTensor::new(lhs.client, lhs.handle, lhs.shape, lhs.device, lhs.strides) } else if same_tensor_type && rhs.can_mut_broadcast(&lhs) { kernel_cmp::launch::( - client, + &client, cube_count, CubeDim::default(), TensorArg::vectorized( @@ -200,7 +200,7 @@ pub(crate) fn launch_cmp< let output = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer); kernel_cmp::launch::( - client, + &client, cube_count, CubeDim::default(), TensorArg::vectorized( @@ -252,7 +252,7 @@ pub(crate) fn launch_scalar_cmp< let same_tensor_type = core::any::TypeId::of::() == core::any::TypeId::of::(); if same_tensor_type && tensor.can_mut() { kernel_scalar_cmp::launch::( - client, + &client, cube_count, CubeDim::default(), TensorArg::vectorized( @@ -283,7 +283,7 @@ pub(crate) fn launch_scalar_cmp< ); kernel_scalar_cmp::launch::( - client, + &client, cube_count, CubeDim::default(), TensorArg::vectorized( diff --git a/crates/burn-jit/src/kernel/contiguous.rs b/crates/burn-jit/src/kernel/contiguous.rs index 6e4576ff87..f210b6133f 100644 --- a/crates/burn-jit/src/kernel/contiguous.rs +++ b/crates/burn-jit/src/kernel/contiguous.rs @@ -1,7 +1,7 @@ use super::Kernel; use crate::{tensor::JitTensor, JitElement, JitRuntime}; -use burn_cube::{calculate_cube_count_elemwise, prelude::*}; -use burn_cube::{frontend::TensorArg, KernelSettings, SUBCUBE_DIM_APPROX}; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; +use cubecl::{frontend::TensorArg, KernelSettings, SUBCUBE_DIM_APPROX}; /// Returns the offset of the tensor corresponding to the layout tensor. 
#[cube] @@ -88,7 +88,7 @@ pub fn into_contiguous( ); into_contiguous_kernel::launch::( - client, + &client, cube_count, CubeDim::default(), TensorArg::vectorized( diff --git a/crates/burn-jit/src/kernel/conv/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d.rs index 106979ebd8..60a3879263 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d.rs @@ -1,4 +1,4 @@ -use burn_cube::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX}; +use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX}; use burn_tensor::{ ops::{conv::calculate_conv_output_size, ConvOptions}, @@ -164,7 +164,7 @@ pub(crate) fn conv2d( let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX); conv2d_kernel::launch::( - input.client, + &input.client, cube_dim, CubeDim::default(), TensorArg::new(&input.handle, &input.strides, &input.shape.dims), diff --git a/crates/burn-jit/src/kernel/conv/conv3d.rs b/crates/burn-jit/src/kernel/conv/conv3d.rs index da1e3aea07..5870224a4c 100644 --- a/crates/burn-jit/src/kernel/conv/conv3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv3d.rs @@ -1,4 +1,4 @@ -use burn_cube::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX}; +use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX}; use burn_tensor::{ ops::{conv::calculate_conv_output_size, ConvOptions}, @@ -189,7 +189,7 @@ pub(crate) fn conv3d( }; conv3d_kernel::launch::( - input.client, + &input.client, calculate_cube_count_elemwise(output.shape.num_elements(), SUBCUBE_DIM_APPROX), CubeDim::default(), TensorArg::new(&input.handle, &input.strides, &input.shape.dims), diff --git a/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs index 657b477390..e2b8a6c753 100644 --- a/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs @@ -1,6 +1,6 @@ -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -410,11 +410,11 @@ pub(crate) fn conv_transpose2d( Execution::start(kernel, input.client.clone()) .inputs(&[ - TensorHandle::::new(&input.handle, &input.strides, &input.shape.dims), - TensorHandle::new(&weight.handle, &weight.strides, &weight.shape.dims), - TensorHandle::new(&bias.handle, &bias.strides, &bias.shape.dims), + TensorHandleRef::::new(&input.handle, &input.strides, &input.shape.dims), + TensorHandleRef::new(&weight.handle, &weight.strides, &weight.shape.dims), + TensorHandleRef::new(&bias.handle, &bias.strides, &bias.shape.dims), ]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs b/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs index cc4ca53afe..ac94be7cc5 100644 --- a/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs @@ -1,6 +1,6 @@ -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -498,11 +498,11 @@ pub(crate) fn conv_transpose3d( Execution::start(kernel, input.client.clone()) 
.inputs(&[ - TensorHandle::::new(&input.handle, &input.strides, &input.shape.dims), - TensorHandle::new(&weight.handle, &weight.strides, &weight.shape.dims), - TensorHandle::new(&bias.handle, &bias.strides, &bias.shape.dims), + TensorHandleRef::::new(&input.handle, &input.strides, &input.shape.dims), + TensorHandleRef::new(&weight.handle, &weight.strides, &weight.shape.dims), + TensorHandleRef::new(&bias.handle, &bias.strides, &bias.shape.dims), ]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index d2a8ffae25..dcd632bf97 100644 --- a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -1,14 +1,14 @@ use crate::{ element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime, }; -use burn_cube::{ +use burn_tensor::ElementConversion; +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; -use burn_tensor::ElementConversion; use std::marker::PhantomData; #[derive(new)] @@ -136,12 +136,12 @@ pub(crate) fn flip_on_output( let kernel = FlipEagerKernel::::new(D); Execution::start(kernel, tensor.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &tensor.handle, &tensor.strides, &tensor.shape.dims, )]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index a58d32491a..f50cdc8653 100644 --- a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -1,11 +1,11 @@ use crate::{ element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime, }; -use burn_cube::ir::{ +use cubecl::ir::{ Elem, IndexOffsetGlobalWithLayout, IntKind, Item, KernelDefinition, Scope, Variable, Visibility, }; -use burn_cube::{ - cpa, frontend::TensorHandle, CubeCountSettings, Execution, InputInfo, KernelExpansion, +use cubecl::{ + cpa, frontend::TensorHandleRef, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; use std::marker::PhantomData; @@ -136,10 +136,10 @@ pub(crate) fn gather::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), - TensorHandle::new(&indices.handle, &indices.strides, &indices.shape.dims), + TensorHandleRef::::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), + TensorHandleRef::new(&indices.handle, &indices.strides, &indices.shape.dims), ]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/index/repeat.rs b/crates/burn-jit/src/kernel/index/repeat.rs index 111b985c43..1a215557c1 100644 --- a/crates/burn-jit/src/kernel/index/repeat.rs +++ b/crates/burn-jit/src/kernel/index/repeat.rs @@ -1,7 +1,7 @@ use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime}; -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -123,12 +123,12 
@@ pub(crate) fn repeat( let kernel = RepeatEagerKernel::::new(dim, D1); Execution::start(kernel, input.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &input.handle, &input.strides, &input.shape.dims, )]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs index 3438c4980e..ba9e1bb8d4 100644 --- a/crates/burn-jit/src/kernel/index/scatter.rs +++ b/crates/burn-jit/src/kernel/index/scatter.rs @@ -4,15 +4,15 @@ use crate::{ tensor::JitTensor, JitRuntime, }; -use burn_cube::{ - calculate_cube_count_elemwise, cpa, frontend::TensorHandle, CubeCountSettings, KernelExpansion, - KernelIntegrator, KernelSettings, +use cubecl::{ + calculate_cube_count_elemwise, cpa, frontend::TensorHandleRef, CubeCountSettings, + KernelExpansion, KernelIntegrator, KernelSettings, }; -use burn_cube::{ +use cubecl::{ ir::{Branch, Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility}, Execution, }; -use burn_cube::{InputInfo, SUBCUBE_DIM_APPROX}; +use cubecl::{InputInfo, SUBCUBE_DIM_APPROX}; use std::marker::PhantomData; #[derive(new)] @@ -227,9 +227,9 @@ pub(crate) fn scatter::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), - TensorHandle::new(&indices.handle, &indices.strides, &indices.shape.dims), - TensorHandle::new(&value.handle, &value.strides, &value.shape.dims), + TensorHandleRef::::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), + TensorHandleRef::new(&indices.handle, &indices.strides, &indices.shape.dims), + TensorHandleRef::new(&value.handle, &value.strides, &value.shape.dims), ]) .execute(CubeCountSettings::Custom(cube_count)); diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index 6e1a5782d6..54592af3b2 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -1,9 +1,9 @@ use crate::{ element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime, }; -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -128,14 +128,14 @@ pub(crate) fn select::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), + TensorHandleRef::::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), // This is a current hacks because the info buffer that contains the strides and shapes is // hardcoded to only contains information about tensors of the same rank. However, since // we don't rely on the shape and stride of the indices tensors, it doesn't matter // which value we put, it just needs to be of the same rank. 
- TensorHandle::new(&indices.handle, &[1; D], &[1; D]), + TensorHandleRef::new(&indices.handle, &[1; D], &[1; D]), ]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index 78c48706e2..aadc4c8ac2 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -4,9 +4,9 @@ use crate::{ tensor::JitTensor, JitRuntime, }; -use burn_cube::{ +use cubecl::{ calculate_cube_count_elemwise, cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Branch, Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, }; @@ -212,11 +212,11 @@ pub(crate) fn select_assign::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), - TensorHandle::new(&value.handle, &value.strides, &value.shape.dims), + TensorHandleRef::::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), + TensorHandleRef::new(&value.handle, &value.strides, &value.shape.dims), // We use the custom strides here instead of the shape, since we don't use it in the // kernel, but we need to put the right number of dimensions (rank). - TensorHandle::new(&indices.handle, &strides, &strides), + TensorHandleRef::new(&indices.handle, &strides, &strides), ]) .execute(CubeCountSettings::Custom(cube_count)); diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index b0f15862c3..add5f9bd76 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -1,14 +1,14 @@ use crate::{ element::JitElement, kernel::Kernel, ops::numeric::empty_device, tensor::JitTensor, JitRuntime, }; -use burn_cube::{ +use burn_tensor::{ElementConversion, Shape}; +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; -use burn_tensor::{ElementConversion, Shape}; use std::{marker::PhantomData, ops::Range}; #[derive(new)] @@ -134,12 +134,12 @@ pub(crate) fn slice_on_output::new(D1); Execution::start(kernel, tensor.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &tensor.handle, &tensor.strides, &tensor.shape.dims, )]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/index/slice_assign.rs b/crates/burn-jit/src/kernel/index/slice_assign.rs index 8ccdb18e87..61445c83d4 100644 --- a/crates/burn-jit/src/kernel/index/slice_assign.rs +++ b/crates/burn-jit/src/kernel/index/slice_assign.rs @@ -1,11 +1,11 @@ use crate::{element::JitElement, kernel::Kernel, tensor::JitTensor, JitRuntime}; -use burn_cube::{ +use burn_tensor::ElementConversion; +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, }; -use burn_tensor::ElementConversion; use std::{marker::PhantomData, ops::Range}; #[derive(new)] @@ -138,8 +138,8 @@ pub(crate) fn slice_assign::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), - TensorHandle::new(&value.handle, &value.strides, 
&value.shape.dims), + TensorHandleRef::::new(&tensor.handle, &tensor.strides, &tensor.shape.dims), + TensorHandleRef::new(&value.handle, &value.strides, &value.shape.dims), ]) .with_scalars(&scalars) .execute(CubeCountSettings::Input { pos: 0 }); diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index 60527f2985..97fedcdfe0 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -1,6 +1,6 @@ -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -412,12 +412,12 @@ pub(crate) fn interpolate_bicubic_launch( let kernel = InterpolateBicubicEagerKernel::::new(); Execution::start(kernel, input.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &input.handle, &input.strides, &input.shape.dims, )]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs index cbe2f1a63b..eefebf3a2b 100644 --- a/crates/burn-jit/src/kernel/interpolate/bilinear.rs +++ b/crates/burn-jit/src/kernel/interpolate/bilinear.rs @@ -1,6 +1,6 @@ -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -225,12 +225,12 @@ pub(crate) fn interpolate_bilinear_launch( let kernel = InterpolateBilinearEagerKernel::::new(); Execution::start(kernel, input.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &input.handle, &input.strides, &input.shape.dims, )]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs index 9607c6dda0..2f5dc4554f 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest.rs @@ -1,6 +1,6 @@ -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -168,12 +168,12 @@ pub(crate) fn interpolate_nearest_launch( let kernel = InterpolateNearestEagerKernel::::new(); Execution::start(kernel, input.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &input.handle, &input.strides, &input.shape.dims, )]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs index d7ecc92353..d4503a683f 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs @@ -1,6 +1,6 @@ -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, 
InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -226,12 +226,12 @@ pub(crate) fn interpolate_nearest_backward_launch( let kernel = InterpolateNearestBackwardEagerKernel::::new(); Execution::start(kernel, out_grad.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &out_grad.handle, &out_grad.strides, &out_grad.shape.dims, )]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/mask/mask_fill.rs b/crates/burn-jit/src/kernel/mask/mask_fill.rs index d026d83db4..b10f2c0d79 100644 --- a/crates/burn-jit/src/kernel/mask/mask_fill.rs +++ b/crates/burn-jit/src/kernel/mask/mask_fill.rs @@ -1,4 +1,4 @@ -use burn_cube::{frontend::TensorHandle, CubeCountSettings, Execution}; +use cubecl::{frontend::TensorHandleRef, CubeCountSettings, Execution}; use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; @@ -46,10 +46,10 @@ fn mask_fill_readonly::new(&input.handle, &input.strides, &input.shape.dims), - TensorHandle::new(&mask.handle, &mask.strides, &mask.shape.dims), + TensorHandleRef::::new(&input.handle, &input.strides, &input.shape.dims), + TensorHandleRef::new(&mask.handle, &mask.strides, &mask.shape.dims), ]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, @@ -71,8 +71,8 @@ fn mask_fill_inplace::new(&input.handle, &input.strides, &input.shape.dims), - TensorHandle::new(&mask.handle, &mask.strides, &mask.shape.dims), + TensorHandleRef::::new(&input.handle, &input.strides, &input.shape.dims), + TensorHandleRef::new(&mask.handle, &mask.strides, &mask.shape.dims), ]) .with_scalars(&[value]) .execute(CubeCountSettings::Input { pos: 0 }); diff --git a/crates/burn-jit/src/kernel/mask/mask_where.rs b/crates/burn-jit/src/kernel/mask/mask_where.rs index c3be34cb4b..79bd996ebf 100644 --- a/crates/burn-jit/src/kernel/mask/mask_where.rs +++ b/crates/burn-jit/src/kernel/mask/mask_where.rs @@ -1,4 +1,4 @@ -use burn_cube::{frontend::TensorHandle, CubeCountSettings, Execution}; +use cubecl::{frontend::TensorHandleRef, CubeCountSettings, Execution}; use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; @@ -49,11 +49,11 @@ fn mask_where_readonly::new(&input.handle, &input.strides, &input.shape.dims), - TensorHandle::new(&mask.handle, &mask.strides, &mask.shape.dims), - TensorHandle::new(&value.handle, &value.strides, &value.shape.dims), + TensorHandleRef::::new(&input.handle, &input.strides, &input.shape.dims), + TensorHandleRef::new(&mask.handle, &mask.strides, &mask.shape.dims), + TensorHandleRef::new(&value.handle, &value.strides, &value.shape.dims), ]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, @@ -75,9 +75,9 @@ fn mask_where_inplace::new(&input.handle, &input.strides, &input.shape.dims), - TensorHandle::new(&mask.handle, &mask.strides, &mask.shape.dims), - TensorHandle::new(&value.handle, &value.strides, &value.shape.dims), + TensorHandleRef::::new(&input.handle, &input.strides, &input.shape.dims), + TensorHandleRef::new(&mask.handle, &mask.strides, &mask.shape.dims), + TensorHandleRef::new(&value.handle, &value.strides, &value.shape.dims), ]) .execute(CubeCountSettings::Input { pos: 0 }); diff --git a/crates/burn-jit/src/kernel/mask/shader.rs b/crates/burn-jit/src/kernel/mask/shader.rs index b24bc0d905..0689b2ee9c 100644 --- 
a/crates/burn-jit/src/kernel/mask/shader.rs
+++ b/crates/burn-jit/src/kernel/mask/shader.rs
@@ -1,4 +1,4 @@
-use burn_cube::{
+use cubecl::{
     cpa,
     ir::{Elem, IndexOffsetGlobalWithLayout, Item, KernelDefinition, Scope, Variable, Visibility},
     InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo,
diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs
index 009cd8081c..2e1d3070b5 100644
--- a/crates/burn-jit/src/kernel/matmul/base.rs
+++ b/crates/burn-jit/src/kernel/matmul/base.rs
@@ -1,11 +1,7 @@
+use super::{init_matmul_output, matmul_simple};
 use crate::{tensor::JitTensor, FloatElement, JitRuntime};
-use burn_cube::prelude::*;
 use burn_tensor::Shape;
-
-use super::{
-    config::Tiling2dConfig, init_matmul_output, matmul_simple, matmul_tiling_2d,
-    matmul_tiling_2d_cube, matmul_tiling_2d_padded,
-};
+use cubecl::prelude::*;
 
 #[cfg(feature = "autotune")]
 use super::matmul_autotune;
@@ -19,15 +15,11 @@ pub enum MatmulStrategy {
         /// Number of invocations in y
         grid_y: usize,
     },
-    /// A tiling 2d kernel will be used, with support for any matrix size without padding.
-    Tiling2d(Tiling2dConfig),
-    /// A tiling 2d kernel will be used, with support for any matrix size with padding.
-    Tiling2dPadded(Tiling2dConfig),
     #[cfg(feature = "autotune")]
     /// Using autotune to chose the best kernel based on runtime information.
     Autotune,
-    /// A tiling 2d kernel with everything vectorized, and comptime bound checks
-    Tiling2dCube(Tiling2dConfig),
+    /// Cube implementation of matmul.
+    Cube,
 }
 
 impl Default for MatmulStrategy {
@@ -37,7 +29,7 @@ impl Default for MatmulStrategy
         return MatmulStrategy::Autotune;
 
         #[cfg(not(feature = "autotune"))]
-        MatmulStrategy::Tiling2d(Tiling2dConfig::default())
+        MatmulStrategy::Cube
     }
 }
 
@@ -52,17 +44,16 @@ pub fn matmul
             let out = init_matmul_output(&lhs, &rhs);
             matmul_simple(lhs, rhs, out, grid_x, grid_y)
         }
-        MatmulStrategy::Tiling2d(config) => {
-            let out = init_matmul_output(&lhs, &rhs);
-            matmul_tiling_2d(lhs, rhs, out, config)
-        }
-        MatmulStrategy::Tiling2dPadded(config) => {
-            let out = init_matmul_output(&lhs, &rhs);
-            matmul_tiling_2d_padded(lhs, rhs, out, config)
-        }
-        MatmulStrategy::Tiling2dCube(config) => {
-            let out = init_matmul_output(&lhs, &rhs);
-            matmul_tiling_2d_cube(lhs, rhs, out, config)
+        MatmulStrategy::Cube => {
+            let out = init_matmul_output::<R, E, D>(&lhs, &rhs);
+            let client = &lhs.client;
+            cubecl::linalg::matmul::launch_ref::<R, E>(
+                client,
+                lhs.as_handle_ref(),
+                rhs.as_handle_ref(),
+                out.as_handle_ref(),
+            );
+            out
         }
         #[cfg(feature = "autotune")]
         MatmulStrategy::Autotune => matmul_autotune(lhs, rhs),
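Throughout this PR, kernel launches hand cubecl a borrowed view of each tensor instead of an owned handle (`TensorHandle` becomes `TensorHandleRef`), and the new `MatmulStrategy::Cube` arm above passes three such views straight to `cubecl::linalg`. A simplified, self-contained sketch of what such a borrow-only handle looks like (illustrative stand-ins, not the real cubecl types):

```rust
/// A `TensorHandleRef`-like view: borrows the buffer plus layout metadata,
/// so a kernel can be launched without cloning the tensor.
struct TensorHandleRef<'a> {
    handle: &'a [u8], // stand-in for the device buffer handle
    strides: &'a [usize],
    shape: &'a [usize],
}

struct JitTensor {
    handle: Vec<u8>,
    strides: Vec<usize>,
    shape: Vec<usize>,
}

impl JitTensor {
    fn as_handle_ref(&self) -> TensorHandleRef<'_> {
        TensorHandleRef {
            handle: &self.handle,
            strides: &self.strides,
            shape: &self.shape,
        }
    }
}

fn main() {
    let t = JitTensor { handle: vec![0u8; 64], strides: vec![4, 1], shape: vec![4, 4] };
    let view = t.as_handle_ref();
    assert_eq!(view.shape.len(), 2); // metadata borrowed, nothing copied
}
```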
-            block_size_n: 64,
-            tile_size: 4,
-            unroll: false,
-        }
-    }
-}
-
-impl Init for CubeTiling2dConfig {
-    fn init(self, _context: &mut CubeContext) -> Self {
-        self
-    }
-}
-
-#[derive(Debug, Clone, Copy)]
-/// Tiling 2D parameters
-pub struct CubeTiling2dConfig {
-    /// Block size along dimension of lhs
-    pub block_size_m: UInt,
-    /// Block size along common dimension
-    pub block_size_k: UInt,
-    /// Block size along dimension of rhs
-    pub block_size_n: UInt,
-    /// Loop unrolling for inner compute loop. Probably slower
-    pub unroll_compute: bool,
-    /// Loop unrolling for all loops related to vectorization/tile size. Probably faster
-    pub unroll_tile: bool,
-    /// Bounds must be checked on lhs dimension
-    pub check_m_bounds: bool,
-    /// Bounds must be checked on common dimension
-    pub check_k_bounds: bool,
-    /// Bounds must be checked on rhs dimension
-    pub check_n_bounds: bool,
-    /// Tile size. Should correspond to vectorization of inputs/outputs/shared memory
-    pub tile_size: UInt,
-    /// Lhs is transposed in global memory
-    pub lhs_transposed: bool,
-    /// Rhs is transposed in global memory
-    pub rhs_transposed: bool,
-}
-
-impl CubeTiling2dConfig {
-    pub fn new(
-        config: &Tiling2dConfig,
-        m: usize,
-        k: usize,
-        n: usize,
-        lhs_transposed: bool,
-        rhs_transposed: bool,
-    ) -> Self {
-        assert!(
-            config.block_size_k <= config.block_size_m
-                && config.block_size_k <= config.block_size_n,
-            "Larger block size in k than m or n results in unfilled shared memory."
-        );
-        assert!(
-            config.block_size_m % config.tile_size == 0
-                && config.block_size_k % config.tile_size == 0
-                && config.block_size_n % config.tile_size == 0,
-            "Tiling 2d algorithm assumes tile size divides block size perfectly. "
-        );
-
-        CubeTiling2dConfig {
-            block_size_m: UInt::new(config.block_size_m as u32),
-            block_size_k: UInt::new(config.block_size_k as u32),
-            block_size_n: UInt::new(config.block_size_n as u32),
-            unroll_compute: config.unroll,
-            unroll_tile: true,
-            check_m_bounds: m % config.block_size_m != 0,
-            check_k_bounds: k % config.block_size_k != 0,
-            check_n_bounds: n % config.block_size_n != 0,
-            tile_size: UInt::new(config.tile_size as u32),
-            lhs_transposed,
-            rhs_transposed,
-        }
-    }
-}
-
-pub fn tiling2d_cube_count<R: JitRuntime, const D: usize>(
-    output_shape: &Shape<D>,
-    config: &Tiling2dConfig,
-) -> CubeCount<R::Server> {
-    let num_rows = output_shape.dims[D - 2];
-    let num_cols = output_shape.dims[D - 1];
-
-    let cubes_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
-    let cubes_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
-    let mut num_iter = 1;
-    for i in 0..D - 2 {
-        num_iter *= output_shape.dims[i];
-    }
-
-    CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
-}
-
-pub fn tiling2d_cube_dim(config: &Tiling2dConfig) -> CubeDim {
-    CubeDim::new(
-        (config.block_size_m / config.tile_size) as u32,
-        (config.block_size_n / config.tile_size) as u32,
-        1,
-    )
-}
diff --git a/crates/burn-jit/src/kernel/matmul/mod.rs b/crates/burn-jit/src/kernel/matmul/mod.rs
index e1e4e30daa..633743564b 100644
--- a/crates/burn-jit/src/kernel/matmul/mod.rs
+++ b/crates/burn-jit/src/kernel/matmul/mod.rs
@@ -1,13 +1,5 @@
 mod base;
-mod config;
 mod simple;
-mod tiling2d;
-#[cfg(not(feature = "export_tests"))]
-mod tiling2d_cube;
-#[cfg(feature = "export_tests")]
-/// Tiling 2d cube functions
-pub mod tiling2d_cube;
-mod tiling2d_shader;
 mod tune;
 
 /// Contains utilitary for matmul operation
@@ -17,14 +9,3 @@ pub use base::*;
 pub use simple::*;
 pub use tune::*;
 pub use utils::*;
-
-#[cfg(feature = "export_tests")]
-#[allow(missing_docs)]
-pub mod
padding; - -#[cfg(not(feature = "export_tests"))] -mod padding; - -pub use config::Tiling2dConfig; -pub use tiling2d::*; -pub use tiling2d_cube::*; diff --git a/crates/burn-jit/src/kernel/matmul/padding.rs b/crates/burn-jit/src/kernel/matmul/padding.rs deleted file mode 100644 index 34b7c28cab..0000000000 --- a/crates/burn-jit/src/kernel/matmul/padding.rs +++ /dev/null @@ -1,96 +0,0 @@ -use crate::{ - element::JitElement, - kernel::{slice_assign, slice_on_output}, - ops::numeric::zeros_device, - tensor::JitTensor, - JitRuntime, -}; -use burn_tensor::{Element, Shape}; -use std::ops::Range; - -// Output of the pad_round function. Allows to know explicitly if early return occurred -pub enum PaddingOutput { - Padded(JitTensor), - Unchanged(JitTensor), -} - -impl PaddingOutput { - pub fn into_tensor(self) -> JitTensor { - match self { - PaddingOutput::Padded(tensor) => tensor, - PaddingOutput::Unchanged(tensor) => tensor, - } - } -} - -/// Pads tensor with zeros to make tensor number of rows and columns -/// divisible by some quantity. -/// For instance tensor of shape [1000, 1000] with divisors 64 and 64 -/// will be padded to [1024, 1024] with the last 24 elements being zeros -pub fn pad_round( - tensor: JitTensor, - row_divisor: usize, - col_divisor: usize, -) -> PaddingOutput { - let previous_row_dim = tensor.shape.dims[D - 2]; - let previous_col_dim = tensor.shape.dims[D - 1]; - let row_modulo = previous_row_dim % row_divisor; - let col_modulo = previous_col_dim % col_divisor; - - let new_row_dim = match row_modulo { - 0 => previous_row_dim, - _ => previous_row_dim + row_divisor - row_modulo, - }; - let new_col_dim = match col_modulo { - 0 => previous_col_dim, - _ => previous_col_dim + col_divisor - col_modulo, - }; - if previous_row_dim == new_row_dim && previous_col_dim == new_col_dim { - return PaddingOutput::Unchanged(tensor); - } - - let mut padded_shape = Vec::with_capacity(D); - for i in 0..D - 2 { - padded_shape.push(tensor.shape.dims[i]); - } - padded_shape.push(new_row_dim); - padded_shape.push(new_col_dim); - - PaddingOutput::Padded(padding::(tensor, padded_shape.into())) -} - -/// Pads tensor by adding zeros when padded dim is larger than tensor dim -pub fn padding( - tensor: JitTensor, - padded_shape: Shape, -) -> JitTensor { - let ranges = padded_shape - .dims - .iter() - .map(|dim| 0..*dim) - .collect::>>() - .try_into() - .unwrap(); - - slice_assign::( - zeros_device::(tensor.client.clone(), tensor.device.clone(), padded_shape), - ranges, - tensor, - ) -} - -/// Crops tensor by deleting values when cropped dim is smaller than tensor dim -pub fn crop( - tensor: JitTensor, - output: JitTensor, -) -> JitTensor { - let ranges = output - .shape - .dims - .iter() - .map(|dim| 0..*dim) - .collect::>>() - .try_into() - .unwrap(); - slice_on_output::(tensor, output, ranges) -} diff --git a/crates/burn-jit/src/kernel/matmul/simple.rs b/crates/burn-jit/src/kernel/matmul/simple.rs index 63c68f58df..28c23623e4 100644 --- a/crates/burn-jit/src/kernel/matmul/simple.rs +++ b/crates/burn-jit/src/kernel/matmul/simple.rs @@ -4,11 +4,11 @@ use crate::{ tensor::JitTensor, FloatElement, JitRuntime, }; -use burn_cube::ir::KernelDefinition; -use burn_cube::{frontend::TensorArg, KernelSettings}; +use cubecl::ir::KernelDefinition; +use cubecl::{frontend::TensorArg, KernelSettings}; use super::simple_cube_count; -use burn_cube::prelude::*; +use cubecl::prelude::*; #[cube(launch)] fn matmul_kernel( @@ -117,7 +117,7 @@ pub fn matmul_simple( }; matmul_kernel::launch::( - lhs.client, + &lhs.client, 
cube_count, CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1), TensorArg::vectorized( diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs deleted file mode 100644 index 74807f44f3..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs +++ /dev/null @@ -1,197 +0,0 @@ -use burn_cube::{ - frontend::TensorHandle, - ir::{BinaryOperator, Elem, FloatKind, KernelDefinition, Scope, Variable, Visibility}, - CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, - OutputInfo, -}; -use burn_tensor::{Element, Shape}; - -use crate::{ - element::JitElement, - kernel::{into_contiguous, matmul::config::tiling2d_cube_dim, Kernel}, - tensor::{JitTensor, MatrixLayout}, - JitRuntime, -}; -use std::marker::PhantomData; - -use super::{ - config::tiling2d_cube_count, - padding::{crop, pad_round, PaddingOutput}, - shape_out, - tiling2d_shader::MatmulTiling2dShader, -}; -use crate::kernel::matmul::config::Tiling2dConfig; - -#[derive(new, Debug)] -struct MatmulTiling2dEagerKernel { - config: Tiling2dConfig, - bounds_check_required: bool, - _runtime: PhantomData, - _elem: PhantomData, -} - -impl Kernel for MatmulTiling2dEagerKernel { - fn define(&self) -> KernelDefinition { - let mut scope = Scope::root(); - let elem = E::cube_elem(); - assert!( - elem == Elem::Float(FloatKind::F32) || elem == Elem::Float(FloatKind::F64), - "Only float elements are supported." - ); - let item = elem.into(); - - let lhs = Variable::GlobalInputArray { id: 0, item }; - let rhs = Variable::GlobalInputArray { id: 1, item }; - let out = Variable::GlobalOutputArray { id: 0, item }; - - scope.write_global_custom(out); - - MatmulTiling2dShader { - variables: BinaryOperator { lhs, rhs, out }, - config: self.config.clone(), - bounds_check_required: self.bounds_check_required, - } - .expand(&mut scope); - - let lhs = InputInfo::Array { - item, - visibility: Visibility::Read, - }; - let rhs = InputInfo::Array { - item, - visibility: Visibility::Read, - }; - let out = OutputInfo::Array { item }; - - let info = KernelExpansion { - inputs: vec![lhs, rhs], - outputs: vec![out], - scope, - }; - - let settings = KernelSettings::default().cube_dim(tiling2d_cube_dim(&self.config)); - KernelIntegrator::new(info).integrate(settings) - } - - fn id(&self) -> String { - format!( - "{:?}config={:?}boundcheck={:?}", - core::any::TypeId::of::(), - self.config, - self.bounds_check_required - ) - } -} - -/// Matrix multiplication using tiling 2d algorithm with -/// vec4 primitive on both lhs and rhs, with no padding needed -pub fn matmul_tiling_2d( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, - config: Tiling2dConfig, -) -> JitTensor { - let bounds_check_required = check_bound_requirement(&lhs.shape, &rhs.shape, &config); - - let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), bounds_check_required); - let client = lhs.client.clone(); - - let check_layout = |tensor: JitTensor| match tensor.matrix_layout() { - MatrixLayout::Contiguous => (tensor, false), - MatrixLayout::MildlyPermuted { - transposed, - batch_swap: _, - } => (tensor, transposed), - MatrixLayout::HighlyPermuted => (into_contiguous(tensor), false), - }; - let (lhs, _lhs_transposed) = check_layout(lhs); - let (rhs, _rhs_transposed) = check_layout(rhs); - - Execution::start(kernel, client) - .inputs(&[ - TensorHandle::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), - TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), - ]) - .outputs(&[TensorHandle::new( - 
&out.handle, - &out.strides, - &out.shape.dims, - )]) - .execute(CubeCountSettings::Custom(tiling2d_cube_count::( - &out.shape, &config, - ))); - - out -} - -/// Matrix multiplication using tiling 2d algorithm with padding needed -pub fn matmul_tiling_2d_padded( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, - config: Tiling2dConfig, -) -> JitTensor { - let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), false); - let client = lhs.client.clone(); - - // A tensor may need to be padded, in which case it will implicitly become contiguous - // If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim. - // If batches were swapped among themselves, or if the last two dims are transposed, the underlying - // kernel handles it without needing to turn it into contiguous. - let round_lhs = pad_round::(lhs, config.block_size_m, config.block_size_k); - let lhs = match round_lhs { - PaddingOutput::Unchanged(tensor) - if tensor.matrix_layout() == MatrixLayout::HighlyPermuted => - { - into_contiguous(tensor) - } - _ => round_lhs.into_tensor(), - }; - let round_rhs = pad_round::(rhs, config.block_size_k, config.block_size_n); - let rhs = match round_rhs { - PaddingOutput::Unchanged(tensor) - if tensor.matrix_layout() == MatrixLayout::HighlyPermuted => - { - into_contiguous(tensor) - } - _ => round_rhs.into_tensor(), - }; - - let rounded_output_shape = shape_out(&lhs, &rhs); - - let num_elems = rounded_output_shape.num_elements(); - let buffer = client.empty(num_elems * core::mem::size_of::()); - let rounded_output = JitTensor::new_contiguous( - rhs.client.clone(), - rhs.device.clone(), - rounded_output_shape.clone(), - buffer, - ); - - Execution::start(kernel, client) - .inputs(&[ - TensorHandle::::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), - TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), - ]) - .outputs(&[TensorHandle::new( - &rounded_output.handle, - &rounded_output.strides, - &rounded_output.shape.dims, - )]) - .execute(CubeCountSettings::Custom(tiling2d_cube_count::( - &rounded_output.shape, - &config, - ))); - - crop(rounded_output, out) -} - -fn check_bound_requirement( - lhs_shape: &Shape, - rhs_shape: &Shape, - config: &Tiling2dConfig, -) -> bool { - lhs_shape.dims[D - 2] % config.block_size_m != 0 - || lhs_shape.dims[D - 1] % config.block_size_k != 0 - || rhs_shape.dims[D - 1] % config.block_size_n != 0 -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs deleted file mode 100644 index e46f9a8c10..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs +++ /dev/null @@ -1,156 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::config::CubeTiling2dConfig; - -use super::block_loop::block_loop; - -#[cube(launch)] -#[allow(unused_mut)] -pub fn tiling2d_cube_kernel( - lhs: &Tensor, - rhs: &Tensor, - out: &mut Tensor, - config: Comptime, -) { - let dims = get_dims::(lhs, rhs); - let coordinates = calculate_coordinates(CUBE_POS_X, CUBE_POS_Y, UNIT_POS, config); - let offsets = calculate_batch_offsets::(lhs, rhs, out, CUBE_POS_Z); - let shared_memories = make_shared_memories::(config); - block_loop::( - lhs, - rhs, - out, - coordinates, - offsets, - shared_memories, - config, - dims, - ); -} - -#[derive(CubeType, Copy, Clone)] -/// Information available at runtime only -/// Strides assume contiguous -pub(crate) struct Dimensions { - pub m: UInt, - pub k: UInt, - pub n: UInt, -} - -#[derive(CubeType, Copy, Clone)] 
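The deleted `check_bound_requirement` just above was the entire policy for selecting the bounds-checked kernel variant: checks are needed exactly when some block size fails to divide the matching matrix dimension. Restated over plain sizes:

```rust
// True when any of m, k, n is not a multiple of its block size, in which
// case the removed tiling2d kernel had to guard its loads and stores.
fn bounds_check_required(
    (m, k, n): (usize, usize, usize),
    (block_m, block_k, block_n): (usize, usize, usize),
) -> bool {
    m % block_m != 0 || k % block_k != 0 || n % block_n != 0
}

fn main() {
    assert!(!bounds_check_required((64, 64, 64), (16, 16, 16)));
    assert!(bounds_check_required((65, 64, 64), (16, 16, 16)));
}
```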
-pub(crate) struct SharedMemories { - pub lhs: SharedMemory, - pub rhs: SharedMemory, -} - -#[derive(CubeType, Copy, Clone)] -/// Number of elements in previous batches -/// Not divided by vectorization facto -pub(crate) struct BatchOffsets { - pub lhs: UInt, - pub rhs: UInt, - pub out: UInt, -} - -#[derive(CubeType, Copy, Clone)] -pub(crate) struct Coordinates { - pub unit_row: UInt, - pub unit_col: UInt, - pub skip_row: UInt, - pub skip_col: UInt, -} - -#[cube] -fn get_dims(lhs: &Tensor, rhs: &Tensor) -> Dimensions { - let rank = lhs.rank(); - let first_dim = rank - UInt::new(2); - let second_dim = rank - UInt::new(1); - let m = lhs.shape(first_dim); - let k = lhs.shape(second_dim); - let n = rhs.shape(second_dim); - - Dimensions { m, k, n } -} - -#[cube] -fn calculate_coordinates( - cube_pos_x: UInt, - cube_pos_y: UInt, - unit_pos: UInt, - config: Comptime, -) -> Coordinates { - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_n = Comptime::map(config, |c| c.block_size_n); - let tile_size = Comptime::map(config, |c| c.tile_size); - - let n_units_per_row = ((Comptime::runtime(block_size_n) - UInt::new(1)) - / Comptime::runtime(tile_size)) - + UInt::new(1); - - // Cube offset - let skip_row = cube_pos_x * Comptime::runtime(block_size_m); - let skip_col = cube_pos_y * Comptime::runtime(block_size_n); - - // Position of the first element of the unit, relative to the cube - let unit_row = (unit_pos / n_units_per_row) * Comptime::runtime(tile_size); - let unit_col = (unit_pos % n_units_per_row) * Comptime::runtime(tile_size); - - Coordinates { - unit_row, - unit_col, - skip_row, - skip_col, - } -} - -#[cube] -#[allow(unused_mut)] -fn calculate_batch_offsets( - lhs: &Tensor, - rhs: &Tensor, - out: &Tensor, - batch_number: UInt, -) -> BatchOffsets { - let rank = out.rank(); - - let dim_m = lhs.shape(rank - UInt::new(2)); - let dim_n = rhs.shape(rank - UInt::new(1)); - - // Batch offset for output - let mut offset_out = dim_m * dim_n * batch_number; - let mut offset_lhs = UInt::new(0); - let mut offset_rhs = UInt::new(0); - - // Batch offset for lhs, rhs - for b in range(0u32, rank - UInt::new(2), Comptime::new(false)) { - let tmp = offset_out / out.stride(b); - offset_lhs += tmp % lhs.shape(b) * lhs.stride(b); - offset_rhs += tmp % rhs.shape(b) * rhs.stride(b); - } - - BatchOffsets { - lhs: offset_lhs, - rhs: offset_rhs, - out: offset_out, - } -} - -#[cube] -fn make_shared_memories(config: Comptime) -> SharedMemories { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let block_size_n = Comptime::map(config, |c| c.block_size_n); - - let lhs = SharedMemory::::vectorized( - Comptime::get(block_size_k * block_size_m / tile_size), - Comptime::get(tile_size), - ); - - let rhs = SharedMemory::::vectorized( - Comptime::get(block_size_k * block_size_n / tile_size), - Comptime::get(tile_size), - ); - - SharedMemories { lhs, rhs } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs deleted file mode 100644 index 110902db07..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs +++ /dev/null @@ -1,82 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::config::CubeTiling2dConfig; - -use super::{ - base::{BatchOffsets, Coordinates, Dimensions, SharedMemories}, - compute_loop::compute_loop, - 
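The deleted `calculate_coordinates` above maps a flat unit index to a tile position inside the cube's block. The same arithmetic in plain Rust, to make the layout easier to check by eye:

```rust
// Each cube skips whole blocks; each unit owns one tile-sized square
// inside its cube's block, with units laid out row-major.
fn calculate_coordinates(
    cube_x: u32,
    cube_y: u32,
    unit: u32,
    (block_m, block_n, tile): (u32, u32, u32),
) -> (u32, u32, u32, u32) {
    let units_per_row = (block_n - 1) / tile + 1;
    let skip_row = cube_x * block_m;
    let skip_col = cube_y * block_n;
    let unit_row = (unit / units_per_row) * tile;
    let unit_col = (unit % units_per_row) * tile;
    (unit_row, unit_col, skip_row, skip_col)
}

fn main() {
    // Block 64x64, tile 4: unit 17 sits one tile down, one tile right.
    assert_eq!(calculate_coordinates(0, 0, 17, (64, 64, 4)), (4, 4, 0, 0));
    // Cube (1, 2) skips one block of rows (skip_row = 64).
    assert_eq!(calculate_coordinates(1, 2, 0, (64, 64, 4)).2, 64);
}
```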
load_shared_memory::load_to_shared_memories, - tile::{loader::TileLoader, writer::TileWriter}, - write_output::write_to_output, -}; - -#[cube] -pub(crate) fn block_loop( - lhs: &Tensor, - rhs: &Tensor, - out: &mut Tensor, - coordinates: Coordinates, - offsets: BatchOffsets, - shared: SharedMemories, - config: Comptime, - dims: Dimensions, -) { - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let mut results = init_results::(config); - - let n_loops = calculate_n_loops::(lhs.shape(lhs.rank() - UInt::new(1)), config); - - for k in range(0u32, n_loops, Comptime::new(false)) { - let k = k * Comptime::runtime(block_size_k); - - load_to_shared_memories::>( - lhs, - rhs, - coordinates, - k, - offsets, - shared, - config, - dims, - ); - - sync_units(); - - compute_loop::(coordinates, shared.lhs, shared.rhs, &mut results, config); - - sync_units(); - } - - write_to_output::>(out, &results, coordinates, offsets.out, dims, config); -} - -#[cube] -fn init_results(config: Comptime) -> Array { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - - let mut results = Array::::new(Comptime::get(tile_size * tile_size)); - for i in range(0u32, Comptime::get(tile_size * tile_size), unroll) { - results[i] = F::new(0.); - } - - results -} - -#[cube] -#[allow(unused_assignments)] -fn calculate_n_loops(dim_k: UInt, config: Comptime) -> UInt { - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); - - let mut n_loops = UInt::new(0); // TODO support syntax let x = if ... else ... - if Comptime::get(check_k_bounds) { - n_loops = UInt::cast_from(F::ceil( - F::cast_from(dim_k) / F::cast_from(Comptime::runtime(block_size_k)), - )); - } else { - n_loops = dim_k / Comptime::runtime(block_size_k); - } - - n_loops -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs deleted file mode 100644 index 29fe627f22..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs +++ /dev/null @@ -1,152 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::config::CubeTiling2dConfig; - -use super::{base::Coordinates, outer_product::tile_outer_product}; - -#[cube] -#[allow(unused_mut)] -pub(crate) fn compute_loop( - coordinates: Coordinates, - shared_lhs: SharedMemory, - shared_rhs: SharedMemory, - results: &mut Array, - config: Comptime, -) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_k = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); - let block_size_n = Comptime::map(config, |c| c.block_size_n); - let unroll = Comptime::map(config, |c| c.unroll_compute); - - let unit_row = coordinates.unit_row; - let unit_col = coordinates.unit_col; - - for dot_index in range(0u32, block_size_k, unroll) { - let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m)) - / Comptime::runtime(tile_size)]; - let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n)) - / Comptime::runtime(tile_size)]; - - tile_outer_product::(register_m, register_n, results, config); - } -} - -#[cfg(feature = "export_tests")] -/// Compute loop exported tests -pub mod tests { - use crate::{ - kernel::matmul::{ - config::CubeTiling2dConfig, - tiling2d_cube::test_utils::{ - assert_equals, create_empty, make_config, range_tensor, 
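The deleted `calculate_n_loops` above works around the then-missing `let x = if … else …` syntax and takes a ceiling division through a float round trip when k-bounds checks are on. An integer-only restatement (assuming, on my part, that the float ceil was only there for lack of an integer ceil in the DSL):

```rust
// Number of k-slices the block loop must run: ceil(dim_k / block_k) when
// dim_k may not be divisible by the block size, plain division otherwise.
fn calculate_n_loops(dim_k: u32, block_k: u32, check_k_bounds: bool) -> u32 {
    if check_k_bounds {
        (dim_k + block_k - 1) / block_k // integer ceiling, no float cast
    } else {
        dim_k / block_k
    }
}

fn main() {
    assert_eq!(calculate_n_loops(100, 32, true), 4); // 3 full slices + 1 partial
    assert_eq!(calculate_n_loops(96, 32, false), 3);
}
```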
range_tensor_transposed, - TILE_SIZE, - }, - }, - JitRuntime, - }; - - use super::{super::base::CoordinatesExpand, *}; - - #[cube(launch)] - fn compute_loop_test( - lhs: Tensor, - rhs: Tensor, - unit_row: UInt, - unit_col: UInt, - results: &mut Array, - config: Comptime, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_k = Comptime::map(config, |c| c.block_size_m); - let block_size_n = Comptime::map(config, |c| c.block_size_m); - let sm_size_lhs = block_size_m * block_size_k / tile_size; - let sm_size_rhs = block_size_n * block_size_k / tile_size; - - // Shared memories are not launchable, so we launch with tensor and convert to shared memory - let mut shared_lhs = - SharedMemory::::vectorized(Comptime::get(sm_size_lhs), Comptime::get(tile_size)); - for i in range(0u32, lhs.len(), Comptime::new(false)) { - shared_lhs[i] = lhs[i]; - } - - let mut shared_rhs = - SharedMemory::::vectorized(Comptime::get(sm_size_rhs), Comptime::get(tile_size)); - for i in range(0u32, rhs.len(), Comptime::new(false)) { - shared_rhs[i] = rhs[i]; - } - - for i in range(0u32, 16u32, Comptime::new(false)) { - results[i] = F::new(0.); - } - - let coordinates = Coordinates { - unit_row, - unit_col, - skip_row: UInt::new(0), - skip_col: UInt::new(0), - }; - - compute_loop(coordinates, shared_lhs, shared_rhs, results, config) - } - - /// Exported test - pub fn compute_loop_unit_test(device: &R::Device) { - let lhs = range_tensor::(8, 8, device); - let rhs = range_tensor::(8, 8, device); - let results = create_empty::(TILE_SIZE, TILE_SIZE, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - const SOME_DIM: usize = 12; - let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM); - - compute_loop_test::launch::( - lhs.client.clone(), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape.dims), - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape.dims), - ScalarArg::new(0), - ScalarArg::new(0), - ArrayArg::new(&results, 1), - config, - ); - - let expected = &[ - 8960.0, 9184.0, 9408.0, 9632.0, 9184.0, 9416.0, 9648.0, 9880.0, 9408.0, 9648.0, 9888.0, - 10128.0, 9632.0, 9880.0, 10128.0, 10376.0, - ]; - assert_equals::(results, expected, device); - } - - /// Exported test - pub fn compute_loop_unit_offset_test(device: &R::Device) { - let lhs = range_tensor_transposed::(8, 4, device); - let rhs = range_tensor::(4, 8, device); - let results = create_empty::(TILE_SIZE, TILE_SIZE, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(4, 8, 4); - - compute_loop_test::launch::( - lhs.client.clone(), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape.dims), - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape.dims), - ScalarArg::new(4), - ScalarArg::new(4), - ArrayArg::new(&results, 1), - config, - ); - - let expected = &[ - 1160.0, 1230.0, 1300.0, 1370.0, 1416.0, 1502.0, 1588.0, 1674.0, 1672.0, 1774.0, 1876.0, - 1978.0, 1928.0, 2046.0, 2164.0, 2282.0, - ]; - assert_equals::(results, expected, device); - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/launch.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/launch.rs deleted file mode 100644 index 9f2606f09a..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/launch.rs +++ /dev/null @@ 
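In the deleted `compute_loop`, shared memory holds vectorized lines, so every scalar position is divided by the tile size after the k-slice offset is applied. The index arithmetic in isolation, under that reading:

```rust
// Line index of the register a unit reads at k-slice `dot_index`: offset
// into the block by `dot_index * block_len`, then convert the scalar
// position to a vec4 line position by dividing by the tile size.
fn register_line(unit_pos: u32, dot_index: u32, block_len: u32, tile: u32) -> u32 {
    (unit_pos + dot_index * block_len) / tile
}

fn main() {
    // block_size_m = 64, tile 4: unit_row 8 at dot_index 2 reads line 34.
    assert_eq!(register_line(8, 2, 64, 4), 34);
}
```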
-1,97 +0,0 @@ -use std::cmp::max; - -use burn_cube::{frontend::TensorArg, Compiler}; - -use crate::{ - kernel::{ - into_contiguous, - matmul::{ - config::{tiling2d_cube_count, tiling2d_cube_dim, CubeTiling2dConfig}, - tiling2d_cube::base::tiling2d_cube_kernel, - Tiling2dConfig, - }, - }, - tensor::{JitTensor, MatrixLayout}, - FloatElement, JitRuntime, -}; - -/// Matrix multiplication using tiling 2d algorithm -pub fn matmul_tiling_2d_cube( - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, - config: Tiling2dConfig, -) -> JitTensor { - assert!( - config.block_size_k * max(config.block_size_m, config.block_size_n) - <= ::max_shared_memory_size(), - "Shared memory limit will be busted. " - ); - - let m = lhs.shape.dims[D - 2]; - let k = lhs.shape.dims[D - 1]; - let n = rhs.shape.dims[D - 1]; - - let client = lhs.client.clone(); - - let check_layout = |tensor: JitTensor| match tensor.matrix_layout() { - MatrixLayout::Contiguous => (tensor, false), - MatrixLayout::MildlyPermuted { - transposed, - batch_swap: _, - } => (tensor, transposed), - MatrixLayout::HighlyPermuted => (into_contiguous(tensor), false), - }; - let (lhs, lhs_transposed) = check_layout(lhs); - let (rhs, rhs_transposed) = check_layout(rhs); - - let vectorization = |shape: usize| { - [4, 2] - .into_iter() - .filter(|v| shape % v == 0) - .map(|v| v as u8) - .next() - .unwrap_or(1) - }; - - let lhs_vectorization = match lhs_transposed { - true => vectorization(m), - false => 1, - }; - let rhs_vectorization = match rhs_transposed { - true => 1, - false => vectorization(n), - }; - let out_vectorization = vectorization(n); - - let cube_count = tiling2d_cube_count::(&out.shape, &config); - let cube_dim = tiling2d_cube_dim(&config); - let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed); - - tiling2d_cube_kernel::launch::( - client, - cube_count, - cube_dim, - TensorArg::vectorized( - lhs_vectorization, - &lhs.handle, - &lhs.strides, - &lhs.shape.dims, - ), - TensorArg::vectorized( - rhs_vectorization, - &rhs.handle, - &rhs.strides, - &rhs.shape.dims, - ), - TensorArg::vectorized( - out_vectorization, - &out.handle, - &out.strides, - &out.shape.dims, - ), - cube_config, - ); - - out -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs deleted file mode 100644 index dc9ce5e9ab..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs +++ /dev/null @@ -1,723 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::config::CubeTiling2dConfig; - -use super::{ - base::{BatchOffsets, Coordinates, Dimensions, SharedMemories}, - tile::block_io::{ - base::BlockLoader, horizontal_block_check::HorizontalCheckBlockIO, - unchecked_block::UncheckedBlockIO, vertical_block_check::VerticalCheckBlockIO, - whole_block_check::WholeCheckBlockIO, - }, -}; - -#[derive(CubeType)] -#[allow(dead_code)] -pub(crate) struct LoadInfo { - pub coordinates: Coordinates, - pub k: UInt, - pub batch_offset: UInt, - pub shared_memory: SharedMemory, - pub config: Comptime, - pub dims: Dimensions, -} - -#[cube] -pub(crate) trait Loader: Sync + Send + 'static { - fn load_lhs_plain>(lhs: &Tensor, load_info: LoadInfo); - fn load_lhs_transposed>(lhs: &Tensor, load_info: LoadInfo); - fn load_rhs_plain>(rhs: &Tensor, load_info: LoadInfo); - fn load_rhs_transposed>(rhs: &Tensor, load_info: LoadInfo); -} - -#[cube] -pub(crate) fn load_to_shared_memories>( - lhs: &Tensor, - rhs: &Tensor, - 
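The deleted launch path above picks the widest vectorization that evenly divides the relevant dimension, preferring 4-wide over 2-wide over scalar. The closure, extracted into a standalone function:

```rust
// Widest supported load width that evenly divides `shape`: 4, then 2,
// falling back to scalar loads when neither divides.
fn vectorization(shape: usize) -> u8 {
    [4, 2]
        .into_iter()
        .filter(|v| shape % v == 0)
        .map(|v| v as u8)
        .next()
        .unwrap_or(1)
}

fn main() {
    assert_eq!(vectorization(16), 4);
    assert_eq!(vectorization(6), 2);
    assert_eq!(vectorization(7), 1);
}
```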
coordinates: Coordinates, - k: UInt, - offsets: BatchOffsets, - shared: SharedMemories, - config: Comptime, - dims: Dimensions, -) { - let lhs_transposed = Comptime::map(config, |c| c.lhs_transposed); - let rhs_transposed = Comptime::map(config, |c| c.rhs_transposed); - - let lhs_load_info = LoadInfo { - coordinates, - k, - batch_offset: offsets.lhs, - shared_memory: shared.lhs, - config, - dims, - }; - let rhs_load_info = LoadInfo { - coordinates, - k, - batch_offset: offsets.rhs, - shared_memory: shared.rhs, - config, - dims, - }; - - // Lhs must be loaded as transposed. If it already is transposed in global memory, we load as plain. - if Comptime::get(lhs_transposed) { - load_lhs_plain::(lhs, lhs_load_info, config); - } else { - load_lhs_transposed::(lhs, lhs_load_info, config); - } - - // Rhs must be loaded as plain. If it is transposed in global memory, we transpose it back. - if Comptime::get(rhs_transposed) { - load_rhs_transposed::(rhs, rhs_load_info, config); - } else { - load_rhs_plain::(rhs, rhs_load_info, config); - } -} - -#[cube] -fn load_lhs_transposed>( - lhs: &Tensor, - load_info: LoadInfo, - config: Comptime, -) { - let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds); - let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); - - if Comptime::get(check_m_bounds) { - if Comptime::get(check_k_bounds) { - L::load_lhs_transposed::(lhs, load_info); - } else { - L::load_lhs_transposed::(lhs, load_info); - } - } else if Comptime::get(check_k_bounds) { - L::load_lhs_transposed::(lhs, load_info); - } else { - L::load_lhs_transposed::(lhs, load_info); - } -} - -#[cube] -fn load_lhs_plain>( - lhs: &Tensor, - load_info: LoadInfo, - config: Comptime, -) { - let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds); - let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); - - if Comptime::get(check_k_bounds) { - if Comptime::get(check_m_bounds) { - L::load_lhs_plain::(lhs, load_info); - } else { - L::load_lhs_plain::(lhs, load_info); - } - } else if Comptime::get(check_m_bounds) { - L::load_lhs_plain::(lhs, load_info); - } else { - L::load_lhs_plain::(lhs, load_info); - } -} - -#[cube] -fn load_rhs_transposed>( - rhs: &Tensor, - load_info: LoadInfo, - config: Comptime, -) { - let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); - let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds); - - if Comptime::get(check_n_bounds) { - if Comptime::get(check_k_bounds) { - L::load_rhs_transposed::(rhs, load_info); - } else { - L::load_rhs_transposed::(rhs, load_info); - } - } else if Comptime::get(check_k_bounds) { - L::load_rhs_transposed::(rhs, load_info); - } else { - L::load_rhs_transposed::(rhs, load_info); - } -} - -#[cube] -fn load_rhs_plain>( - rhs: &Tensor, - load_info: LoadInfo, - config: Comptime, -) { - let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds); - let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds); - - if Comptime::get(check_k_bounds) { - if Comptime::get(check_n_bounds) { - L::load_rhs_plain::(rhs, load_info); - } else { - L::load_rhs_plain::(rhs, load_info); - } - } else if Comptime::get(check_n_bounds) { - L::load_rhs_plain::(rhs, load_info); - } else { - L::load_rhs_plain::(rhs, load_info); - } -} - -#[cfg(feature = "export_tests")] -/// Exported tests for loading to shared memory -pub mod tests { - use crate::kernel::matmul::tiling2d_cube::load_shared_memory::LoadInfoExpand; - use crate::kernel::matmul::tiling2d_cube::test_utils::{ - assert_equals, create_empty, make_config, 
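The deleted load wrappers above exist only to turn two comptime booleans into one of four monomorphized block-IO strategies, so the generated kernel carries no bounds branches it does not need. The dispatch table, flattened into ordinary Rust for illustration:

```rust
// Which block-IO specialization the removed `load_lhs_transposed` wrapper
// selects for each combination of bounds checks.
fn block_io(check_vertical: bool, check_horizontal: bool) -> &'static str {
    match (check_vertical, check_horizontal) {
        (true, true) => "WholeCheckBlockIO",
        (true, false) => "VerticalCheckBlockIO",
        (false, true) => "HorizontalCheckBlockIO",
        (false, false) => "UncheckedBlockIO",
    }
}

fn main() {
    // Shapes divisible by the block sizes take the fully unchecked path.
    assert_eq!(block_io(false, false), "UncheckedBlockIO");
}
```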
range_tensor, TILE_SIZE, - }; - use crate::kernel::matmul::tiling2d_cube::tile::loader::TileLoader; - use crate::JitRuntime; - - use super::{ - super::base::{CoordinatesExpand, DimensionsExpand}, - *, - }; - - #[cube(launch)] - fn load_tensor_test( - tensor: &Tensor, - sm_out: &mut Array, - unit_row: UInt, - unit_col: UInt, - k: UInt, - config: Comptime, - is_lhs: Comptime, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let sm_size = block_size_k * block_size_m / tile_size; - let shared_memory = - SharedMemory::::vectorized(Comptime::get(sm_size), Comptime::get(tile_size)); - - let batch_offset = UInt::new(0); - - let coordinates = Coordinates { - unit_row, - unit_col, - skip_row: UInt::new(0), - skip_col: UInt::new(0), - }; - - if Comptime::get(is_lhs) { - let dims = Dimensions { - m: tensor.shape(tensor.rank() - UInt::new(2)), - k: tensor.shape(tensor.rank() - UInt::new(1)), - n: UInt::new(0), - }; - let info = LoadInfo { - coordinates, - k, - batch_offset, - shared_memory, - config, - dims, - }; - - load_lhs_transposed::>(tensor, info, config); - } else { - let dims = Dimensions { - m: UInt::new(0), - k: tensor.shape(tensor.rank() - UInt::new(2)), - n: tensor.shape(tensor.rank() - UInt::new(1)), - }; - let info = LoadInfo { - coordinates, - k, - batch_offset, - shared_memory, - config, - dims, - }; - - load_rhs_plain::>(tensor, info, config); - } - - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { - sm_out[i] = shared_memory[i]; - } - } - - #[cube(launch)] - fn load_tensor_permuted_test( - tensor: &Tensor, - sm_out: &mut Array, - unit_row: UInt, - unit_col: UInt, - k: UInt, - config: Comptime, - is_lhs: Comptime, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let sm_size = block_size_k * block_size_m / tile_size; - let shared_memory = - SharedMemory::::vectorized(Comptime::get(sm_size), Comptime::get(tile_size)); - - let batch_offset = UInt::new(0); - - let coordinates = Coordinates { - unit_row, - unit_col, - skip_row: UInt::new(0), - skip_col: UInt::new(0), - }; - - if Comptime::get(is_lhs) { - // Permuted - let dims = Dimensions { - m: tensor.shape(tensor.rank() - UInt::new(1)), - k: tensor.shape(tensor.rank() - UInt::new(2)), - n: UInt::new(0), - }; - let info = LoadInfo { - coordinates, - k, - batch_offset, - shared_memory, - config, - dims, - }; - - load_lhs_plain::>(tensor, info, config); - } else { - // Permuted - let dims = Dimensions { - m: UInt::new(0), - k: tensor.shape(tensor.rank() - UInt::new(1)), - n: tensor.shape(tensor.rank() - UInt::new(2)), - }; - let info = LoadInfo { - coordinates, - k, - batch_offset, - shared_memory, - config, - dims, - }; - - load_rhs_transposed::>(tensor, info, config); - } - - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { - sm_out[i] = shared_memory[i]; - } - } - - #[cube(launch)] - fn load_tensor_multiple_tiles_test( - tensor: &Tensor, - sm_out: &mut Array, - k: UInt, - config: Comptime, - is_lhs: Comptime, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let block_size_k = Comptime::map(config, |c| c.block_size_k); - let block_size_m = Comptime::map(config, |c| c.block_size_m); - let sm_size = block_size_k * block_size_m / tile_size; - let shared_memory = - 
SharedMemory::::vectorized(Comptime::get(sm_size), Comptime::get(tile_size)); - - let unit_row = UInt::new(4) * UNIT_POS_X; - let unit_col = UInt::new(4) * UNIT_POS_Y; - let batch_offset = UInt::new(0); - - let coordinates = Coordinates { - unit_row, - unit_col, - skip_row: UInt::new(0), - skip_col: UInt::new(0), - }; - - if Comptime::get(is_lhs) { - let dims = Dimensions { - m: tensor.shape(tensor.rank() - UInt::new(2)), - k: tensor.shape(tensor.rank() - UInt::new(1)), - n: UInt::new(0), - }; - let info = LoadInfo { - coordinates, - k, - batch_offset, - shared_memory, - config, - dims, - }; - - load_lhs_transposed::>(tensor, info, config); - } else { - let dims = Dimensions { - m: UInt::new(0), - k: tensor.shape(tensor.rank() - UInt::new(2)), - n: tensor.shape(tensor.rank() - UInt::new(1)), - }; - let info = LoadInfo { - coordinates, - k, - batch_offset, - shared_memory, - config, - dims, - }; - - load_rhs_plain::>(tensor, info, config); - } - - for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) { - sm_out[i] = shared_memory[i]; - } - } - - /// Exported test - pub fn load_lhs_transposed_unit_test(device: &R::Device) { - let lhs = range_tensor::(16, 16, device); - let sm_out = create_empty::(8, 8, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(16, 16, 8); - - load_tensor_test::launch::( - lhs.client.clone(), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - true, - ); - - let expected = &[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 76.0, 92.0, 108.0, 124.0, 0.0, 0.0, 0.0, 0.0, 77.0, 93.0, 109.0, 125.0, 0.0, - 0.0, 0.0, 0.0, 78.0, 94.0, 110.0, 126.0, 0.0, 0.0, 0.0, 0.0, 79.0, 95.0, 111.0, 127.0, - ]; - assert_equals::(sm_out, expected, device); - } - - /// Exported test - pub fn load_lhs_transposed_out_of_bounds_cube_test(device: &R::Device) { - let vectorization_factor = 1; - let lhs = range_tensor::(5, 1, device); - let sm_out = create_empty::(8, 8, device); - let cube_dim = CubeDim::new(2, 2, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(5, 1, 1); - - load_tensor_multiple_tiles_test::launch::( - lhs.client.clone(), - cube_count, - cube_dim, - TensorArg::vectorized( - vectorization_factor as u8, - &lhs.handle, - &lhs.strides, - &lhs.shape.dims, - ), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(0), - config, - true, - ); - - let expected = &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - ]; - assert_equals::(sm_out, expected, device); - } - - /// Exported test - pub fn load_lhs_transposed_cube_test(device: &R::Device) { - let lhs = range_tensor::(8, 8, device); - let sm_out = create_empty::(8, 8, device); - let cube_dim = CubeDim::new(2, 2, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(8, 8, 8); - - load_tensor_multiple_tiles_test::launch::( - lhs.client.clone(), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, 
&lhs.strides, &lhs.shape.dims), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(0), - config, - true, - ); - - let expected = &[ - 0.0, 8.0, 16.0, 24.0, 32.0, 40.0, 48.0, 56.0, 1.0, 9.0, 17.0, 25.0, 33.0, 41.0, 49.0, - 57.0, 2.0, 10.0, 18.0, 26.0, 34.0, 42.0, 50.0, 58.0, 3.0, 11.0, 19.0, 27.0, 35.0, 43.0, - 51.0, 59.0, 4.0, 12.0, 20.0, 28.0, 36.0, 44.0, 52.0, 60.0, 5.0, 13.0, 21.0, 29.0, 37.0, - 45.0, 53.0, 61.0, 6.0, 14.0, 22.0, 30.0, 38.0, 46.0, 54.0, 62.0, 7.0, 15.0, 23.0, 31.0, - 39.0, 47.0, 55.0, 63.0, - ]; - assert_equals::(sm_out, expected, device); - } - - /// Exported test - pub fn load_lhs_transposed_offset_cube_test(device: &R::Device) { - let lhs = range_tensor::(8, 16, device); - let sm_out = create_empty::(8, 8, device); - let cube_dim = CubeDim::new(2, 2, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(8, 8, 16); - - load_tensor_multiple_tiles_test::launch::( - lhs.client.clone(), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(8), - config, - true, - ); - - let expected = &[ - 8.0, 24.0, 40.0, 56.0, 72.0, 88.0, 104.0, 120.0, 9.0, 25.0, 41.0, 57.0, 73.0, 89.0, - 105.0, 121.0, 10.0, 26.0, 42.0, 58.0, 74.0, 90.0, 106.0, 122.0, 11.0, 27.0, 43.0, 59.0, - 75.0, 91.0, 107.0, 123.0, 12.0, 28.0, 44.0, 60.0, 76.0, 92.0, 108.0, 124.0, 13.0, 29.0, - 45.0, 61.0, 77.0, 93.0, 109.0, 125.0, 14.0, 30.0, 46.0, 62.0, 78.0, 94.0, 110.0, 126.0, - 15.0, 31.0, 47.0, 63.0, 79.0, 95.0, 111.0, 127.0, - ]; - assert_equals::(sm_out, expected, device); - } - - /// Exported test - pub fn load_rhs_plain_unit_test(device: &R::Device) { - let rhs = range_tensor::(16, 16, device); - let sm_out = create_empty::(8, 8, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(8, 16, 16); - - load_tensor_test::launch::( - rhs.client.clone(), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape.dims), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - false, - ); - - let expected = &[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 196.0, 197.0, 198.0, 199.0, 0.0, 0.0, 0.0, 0.0, 212.0, 213.0, 214.0, 215.0, - 0.0, 0.0, 0.0, 0.0, 228.0, 229.0, 230.0, 231.0, 0.0, 0.0, 0.0, 0.0, 244.0, 245.0, - 246.0, 247.0, - ]; - assert_equals::(sm_out, expected, device); - } - - /// Exported test - pub fn load_rhs_plain_cube_test(device: &R::Device) { - let rhs = range_tensor::(8, 8, device); - let sm_out = create_empty::(8, 8, device); - let cube_dim = CubeDim::new(2, 2, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(8, 8, 8); - - load_tensor_multiple_tiles_test::launch::( - rhs.client.clone(), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape.dims), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(0), - config, - false, - ); - - let expected = &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, - 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, - 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, - 44.0, 45.0, 46.0, 47.0, 
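A quick CPU cross-check of the expected array in `load_lhs_transposed_cube_test` above: loading an 8x8 row-major range tensor transposed means column i of the input becomes row i of shared memory.

```rust
// Transpose an n x n row-major slice: the effect the removed
// `load_lhs_transposed` has on a contiguous lhs tile.
fn transpose_load(input: &[f32], n: usize) -> Vec<f32> {
    let mut out = vec![0.0; n * n];
    for r in 0..n {
        for c in 0..n {
            out[c * n + r] = input[r * n + c];
        }
    }
    out
}

fn main() {
    let input: Vec<f32> = (0..64).map(|v| v as f32).collect();
    let sm = transpose_load(&input, 8);
    // First row of the test's expected shared memory.
    assert_eq!(&sm[..8], &[0.0, 8.0, 16.0, 24.0, 32.0, 40.0, 48.0, 56.0]);
}
```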
48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, - 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, - ]; - assert_equals::(sm_out, expected, device); - } - - /// Exported test - pub fn load_rhs_plain_cube_offset_test(device: &R::Device) { - let rhs = range_tensor::(16, 8, device); - let sm_out = create_empty::(8, 8, device); - let cube_dim = CubeDim::new(2, 2, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(16, 16, 8); - - load_tensor_multiple_tiles_test::launch::( - rhs.client.clone(), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape.dims), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(8), - config, - false, - ); - - let expected = &[ - 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, - 78.0, 79.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, - 92.0, 93.0, 94.0, 95.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, - 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 112.0, 113.0, 114.0, 115.0, 116.0, - 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0, - ]; - assert_equals::(sm_out, expected, device); - } - - /// Exported test - pub fn load_lhs_plain_unit_test(device: &R::Device) { - let lhs = range_tensor::(16, 16, device); - let sm_out = create_empty::(8, 8, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(16, 16, 8); - - load_tensor_permuted_test::launch::( - lhs.client.clone(), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - true, - ); - - let expected = &[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 196.0, 197.0, 198.0, 199.0, 0.0, 0.0, 0.0, 0.0, 212.0, 213.0, 214.0, 215.0, - 0.0, 0.0, 0.0, 0.0, 228.0, 229.0, 230.0, 231.0, 0.0, 0.0, 0.0, 0.0, 244.0, 245.0, - 246.0, 247.0, - ]; - assert_equals::(sm_out, expected, device); - } - - /// Exported test - pub fn load_lhs_plain_out_of_bounds_unit_test(device: &R::Device) { - let (m, k) = (6, 14); - let lhs = range_tensor::(k, m, device); - let sm_out = create_empty::(8, 8, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(m, k, 8); - - load_tensor_permuted_test::launch::( - lhs.client.clone(), - cube_count, - cube_dim, - TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - true, - ); - - let expected = &[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 76.0, 77.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 82.0, 83.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - ]; - assert_equals::(sm_out, expected, device); - } - - /// Exported test - pub fn load_rhs_transposed_unit_test(device: &R::Device) { - let rhs = range_tensor::(16, 16, device); - let sm_out = create_empty::(8, 8, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let 
config = make_config(16, 16, 8); - - load_tensor_permuted_test::launch::( - rhs.client.clone(), - cube_count, - cube_dim, - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - false, - ); - - let expected = &[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 76.0, 92.0, 108.0, 124.0, 0.0, 0.0, 0.0, 0.0, 77.0, 93.0, 109.0, 125.0, 0.0, - 0.0, 0.0, 0.0, 78.0, 94.0, 110.0, 126.0, 0.0, 0.0, 0.0, 0.0, 79.0, 95.0, 111.0, 127.0, - ]; - assert_equals::(sm_out, expected, device); - } - - /// Exported test - pub fn load_rhs_transposed_out_of_bounds_unit_test(device: &R::Device) { - let (k, n) = (14, 6); - let rhs = range_tensor::(n, k, device); - let sm_out = create_empty::(8, 8, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(8, k, n); - - load_tensor_permuted_test::launch::( - rhs.client.clone(), - cube_count, - cube_dim, - TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), - ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64), - ScalarArg::new(4), - ScalarArg::new(4), - ScalarArg::new(8), - config, - false, - ); - - let expected = &[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 68.0, 82.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 69.0, 83.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - ]; - assert_equals::(sm_out, expected, device); - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/mod.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/mod.rs deleted file mode 100644 index d971c92ff1..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/mod.rs +++ /dev/null @@ -1,19 +0,0 @@ -mod base; -mod block_loop; -mod compute_loop; -mod launch; -mod load_shared_memory; -mod outer_product; -#[cfg(feature = "export_tests")] -mod test_utils; -mod tile; -mod write_output; - -pub use launch::matmul_tiling_2d_cube; - -#[cfg(feature = "export_tests")] -pub use { - compute_loop::tests as compute_loop_tests, - load_shared_memory::tests as load_shared_memory_tests, - outer_product::tests as outer_product_tests, write_output::tests as write_output_tests, -}; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/outer_product.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/outer_product.rs deleted file mode 100644 index 2ab90e6116..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/outer_product.rs +++ /dev/null @@ -1,118 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::config::CubeTiling2dConfig; - -#[cube] -pub(crate) fn tile_outer_product( - register_m: F, - register_n: F, - results: &mut Array, - config: Comptime, -) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - - for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) { - let res_pos_base = res_idx_m * Comptime::runtime(tile_size); - for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) { - let mul = register_m[res_idx_m] * register_n[res_idx_n]; - results[res_pos_base + res_idx_n] += mul; - } - } -} - -#[cfg(feature = "export_tests")] -/// Exported tests for outer product -pub 
mod tests { - use crate::{ - kernel::matmul::{ - config::CubeTiling2dConfig, - tiling2d_cube::test_utils::{assert_equals, create_empty, make_config}, - }, - JitRuntime, - }; - - use super::*; - - #[cube(launch)] - #[allow(unused_mut)] - fn tile_outer_product_test( - register_m: Array, - register_n: Array, - results: &mut Array, - config: Comptime, - ) { - // We launch with array then convert to vectorized float, - // because direct launch of vectorized float is not supported - let tile_size = Comptime::map(config, |c| c.tile_size); - let register_m = register_m.to_vectorized(tile_size); - let register_n = register_n.to_vectorized(tile_size); - - for i in range( - 0u32, - Comptime::get(tile_size * tile_size), - Comptime::new(false), - ) { - results[i] = F::new(0.); - } - tile_outer_product::(register_m, register_n, results, config) - } - - /// Exported test - pub fn tile_outer_product_vectorized_unit_test(device: &R::Device) { - let client = R::client(device); - let register_m = client.create(f32::as_bytes(&[0., 1., 2., 3.])); - let register_n = client.create(f32::as_bytes(&[1., 2., 3., 4.])); - let results = create_empty::(4, 4, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - const SOME_DIM: usize = 12; - let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM); - - tile_outer_product_test::launch::( - client.clone(), - cube_count, - cube_dim, - ArrayArg::new(®ister_m, 4), - ArrayArg::new(®ister_n, 4), - ArrayArg::new(&results, 16), - config, - ); - - let expected = &[ - 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 3.0, 6.0, 9.0, 12.0, - ]; - assert_equals::(results, expected, device); - } - - /// Exported test - pub fn tile_outer_product_vectorized_unit_test_2(device: &R::Device) { - let client = R::client(device); - - let register_m = client.create(f32::as_bytes(&[16., 20., 24., 28.])); - let register_n = client.create(f32::as_bytes(&[4., 5., 6., 7.])); - let results = create_empty::(4, 4, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - const SOME_DIM: usize = 12; - let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM); - - tile_outer_product_test::launch::( - client.clone(), - cube_count, - cube_dim, - ArrayArg::new(®ister_m, 4), - ArrayArg::new(®ister_n, 4), - ArrayArg::new(&results, 16), - config, - ); - - let expected = &[ - 64.0, 80.0, 96.0, 112.0, 80.0, 100.0, 120.0, 140.0, 96.0, 120.0, 144.0, 168.0, 112.0, - 140.0, 168.0, 196.0, - ]; - assert_equals::(results, expected, device); - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/test_utils.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/test_utils.rs deleted file mode 100644 index 4dacb8e7da..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/test_utils.rs +++ /dev/null @@ -1,89 +0,0 @@ -use burn_compute::server::Handle; -use burn_cube::CubeElement; - -use crate::{ - kernel::matmul::config::{CubeTiling2dConfig, Tiling2dConfig}, - tensor::JitTensor, - JitBackend, JitRuntime, -}; - -pub(crate) const TILE_SIZE: usize = 4; - -pub(crate) fn range_tensor( - x: usize, - y: usize, - device: &R::Device, -) -> JitTensor { - type B = JitBackend; - - let n_elements = (x * y) as i64; - burn_tensor::Tensor::, 1, burn_tensor::Int>::arange(0..n_elements, device) - .reshape([x, y]) - .float() - .into_primitive() - .tensor() -} - -pub(crate) fn range_tensor_transposed( - x: usize, - y: usize, - device: &R::Device, -) -> JitTensor { - type B = JitBackend; - - let n_elements = (x * y) as i64; - - 
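The deleted `tile_outer_product` is the arithmetic core of the whole kernel: per k step, every unit accumulates a tile-sized outer product of its two registers. A scalar CPU version that reproduces the first exported test:

```rust
// Accumulate the outer product of two length-4 registers into a 4x4
// result tile, as the removed `tile_outer_product` does per k index.
fn tile_outer_product(reg_m: &[f32; 4], reg_n: &[f32; 4], results: &mut [f32; 16]) {
    for i in 0..4 {
        for j in 0..4 {
            results[i * 4 + j] += reg_m[i] * reg_n[j];
        }
    }
}

fn main() {
    let mut results = [0.0f32; 16];
    tile_outer_product(&[0., 1., 2., 3.], &[1., 2., 3., 4.], &mut results);
    // Matches `tile_outer_product_vectorized_unit_test`'s expected output.
    assert_eq!(
        results,
        [0., 0., 0., 0., 1., 2., 3., 4., 2., 4., 6., 8., 3., 6., 9., 12.]
    );
}
```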
burn_tensor::Tensor::, 2>::from_data( - burn_tensor::Tensor::, 1, burn_tensor::Int>::arange(0..n_elements, device) - .reshape([x, y]) - .float() - .transpose() - .into_data(), - device, - ) - .into_primitive() - .tensor() -} - -pub(crate) fn zeros_tensor( - x: usize, - y: usize, - device: &R::Device, -) -> JitTensor { - type B = JitBackend; - burn_tensor::Tensor::, 2>::zeros([x, y], device) - .into_primitive() - .tensor() -} - -pub(crate) fn create_empty( - x: usize, - y: usize, - device: &R::Device, -) -> Handle<::JitServer> { - let client = R::client(device); - client.empty(x * y * core::mem::size_of::()) -} - -pub(crate) fn assert_equals( - output: Handle<::JitServer>, - expected: &[f32], - device: &R::Device, -) { - let client = R::client(device); - - let actual = client.read(output.binding()); - let actual = f32::from_bytes(&actual); - - assert_eq!(actual, expected); -} - -pub(crate) fn make_config(m: usize, k: usize, n: usize) -> CubeTiling2dConfig { - let tiling2d_config = Tiling2dConfig { - block_size_m: 8, - block_size_k: 8, - block_size_n: 8, - ..Default::default() - }; - CubeTiling2dConfig::new(&tiling2d_config, m, k, n, false, false) -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/base.rs deleted file mode 100644 index 4d50b86aa2..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/base.rs +++ /dev/null @@ -1,78 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::{ - config::CubeTiling2dConfig, - tiling2d_cube::{ - tile::{ - loader::{CheckBounds, ReadTileInfo}, - memory_access::ContiguousAccess, - }, - write_output::WriteTileInfo, - }, -}; - -#[cube] -pub(crate) trait BlockLoader: Send + Sync + 'static { - fn load_tile_plain>( - tensor: &Tensor, - shared_memory: &mut SharedMemory, - read_tile_info: ReadTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ); - - fn load_tile_transposed( - tensor: &Tensor, - shared_memory: &mut SharedMemory, - read_tile_info: ReadTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ); -} - -#[cube] -pub(crate) trait BlockWriter: Send + Sync + 'static { - fn write_output>( - out: &mut Tensor, - results: &Array, - write_tile_info: WriteTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ); -} - -#[cube] -pub(crate) fn all_zeros_runtime( - shared_memory: &mut SharedMemory, - start: UInt, - sm_position_base: UInt, - sm_stride: UInt, - config: Comptime, -) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let zeros = F::vectorized(0., Comptime::get(tile_size)); - - for i in range(start, Comptime::get(tile_size), Comptime::new(false)) { - let sm_position = (sm_position_base + i * sm_stride) / Comptime::runtime(tile_size); - - shared_memory[sm_position] = zeros; - } -} - -#[cube] -pub(crate) fn all_zeros_comptime( - shared_memory: &mut SharedMemory, - sm_position_base: UInt, - sm_stride: UInt, - config: Comptime, -) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let zeros = F::vectorized(0., Comptime::get(tile_size)); - - for i in range(0u32, Comptime::get(tile_size), unroll) { - let sm_position = (sm_position_base + i * sm_stride) / Comptime::runtime(tile_size); - - shared_memory[sm_position] = zeros; - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/horizontal_block_check.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/horizontal_block_check.rs deleted 
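The `all_zeros_*` helpers above back-fill the unread tail of a tile with zero lines, so the compute loop never has to branch on bounds. A sketch modeling shared memory as vec4 lines (a modeling assumption; the real code uses the DSL's vectorized `SharedMemory`):

```rust
// Zero-fill rows `num_reads..tile` of a tile whose rows sit `stride`
// scalars apart, converting scalar positions to vec4 line indices.
fn all_zeros_runtime(
    sm: &mut [[f32; 4]],
    num_reads: usize,
    position_base: usize,
    stride: usize,
    tile: usize,
) {
    for i in num_reads..tile {
        sm[(position_base + i * stride) / tile] = [0.0; 4];
    }
}

fn main() {
    let mut sm = [[1.0f32; 4]; 8];
    // Only 2 of 4 rows were in bounds; lines 4 and 6 get zeroed.
    all_zeros_runtime(&mut sm, 2, 0, 8, 4);
    assert_eq!(sm[4], [0.0; 4]);
    assert_eq!(sm[6], [0.0; 4]);
    assert_eq!(sm[0], [1.0; 4]); // valid rows untouched
}
```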
file mode 100644 index 8b09877fd4..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/horizontal_block_check.rs +++ /dev/null @@ -1,118 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::{ - config::CubeTiling2dConfig, - tiling2d_cube::{ - tile::{ - loader::{CheckBounds, ReadTileInfo}, - memory_access::{ - ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, - WritePositionsExpand, - }, - }, - write_output::WriteTileInfo, - }, -}; - -use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter}; - -pub(crate) struct HorizontalCheckBlockIO; - -#[cube] -impl BlockLoader for HorizontalCheckBlockIO { - fn load_tile_plain>( - tensor: &Tensor, - shared_memory: &mut SharedMemory, - info: ReadTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let vectorization = Comptime::vectorization(&tensor); - let unroll = Comptime::map(config, |c| c.unroll_tile); - - let col = check_bounds.skip_col + info.read_col; - if check_bounds.dim_horizontal > col { - for i in range(0u32, Comptime::get(tile_size), unroll) { - let gm_position = - (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); - - shared_memory[sm_position] = - A::read_contiguous_checked(tensor, gm_position, check_bounds, info, config); - } - } else { - all_zeros_comptime(shared_memory, info.sm_position_base, info.sm_stride, config); - } - } - - fn load_tile_transposed( - tensor: &Tensor, - shared_memory: &mut SharedMemory, - info: ReadTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - - let mut num_reads = UInt::new(0); - let col = check_bounds.skip_col + info.read_col; - let dim_horizontal = check_bounds.dim_horizontal; - if dim_horizontal > col { - num_reads = UInt::min(dim_horizontal - col, Comptime::runtime(tile_size)); - } - - for i in range(0u32, num_reads, Comptime::new(false)) { - let gm_position = info.gm_position_base + i; - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); - - shared_memory[sm_position] = UnmatchingVectorization::read_strided_unchecked( - tensor, - gm_position, - info.gm_stride, - config, - ); - } - - all_zeros_runtime( - shared_memory, - num_reads, - info.sm_position_base, - info.sm_stride, - config, - ); - } -} - -#[cube] -impl BlockWriter for HorizontalCheckBlockIO { - fn write_output>( - out: &mut Tensor, - results: &Array, - info: WriteTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let coordinates = info.coordinates; - - let col = coordinates.skip_col + coordinates.unit_col; - - if check_bounds.dim_horizontal > col { - let row = coordinates.skip_row + coordinates.unit_row; - let out_position_base = row * info.out_stride + col + info.offset_output; - - for result_index in range(0u32, Comptime::get(tile_size), unroll) { - let positions = WritePositions { - result: result_index * Comptime::runtime(tile_size), - out: out_position_base + result_index * info.out_stride, - }; - - A::write_contiguous_checked(out, results, positions, check_bounds, col, config); - } - } - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/mod.rs 
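The horizontal check above clamps the transposed read count to whatever remains before the matrix edge; the vertical and whole-block variants further below apply the same clamp to the other axis, or to both. The clamp on its own:

```rust
// Rows/columns actually readable at offset `pos` along a dimension of
// length `dim`, capped at one full tile.
fn num_reads(dim: u32, pos: u32, tile: u32) -> u32 {
    if dim > pos {
        (dim - pos).min(tile)
    } else {
        0
    }
}

fn main() {
    assert_eq!(num_reads(10, 8, 4), 2); // partial tile at the edge
    assert_eq!(num_reads(10, 12, 4), 0); // entirely out of bounds
    assert_eq!(num_reads(100, 8, 4), 4); // interior tile reads fully
}
```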
b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/mod.rs deleted file mode 100644 index 50c913843b..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub mod base; -pub mod horizontal_block_check; -pub mod unchecked_block; -pub mod vertical_block_check; -pub mod whole_block_check; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/unchecked_block.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/unchecked_block.rs deleted file mode 100644 index d695f3da65..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/unchecked_block.rs +++ /dev/null @@ -1,96 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::{ - config::CubeTiling2dConfig, - tiling2d_cube::{ - tile::{ - loader::{CheckBounds, ReadTileInfo}, - memory_access::{ - ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, - WritePositionsExpand, - }, - }, - write_output::WriteTileInfo, - }, -}; - -use super::base::{BlockLoader, BlockWriter}; - -/// Assumes block sizes divide tensor shape -pub(crate) struct UncheckedBlockIO; - -#[cube] -impl BlockLoader for UncheckedBlockIO { - fn load_tile_plain>( - tensor: &Tensor, - shared_memory: &mut SharedMemory, - info: ReadTileInfo, - config: Comptime, - _check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let vectorization = Comptime::vectorization(&tensor); - - for i in range(0u32, Comptime::get(tile_size), unroll) { - let gm_position = - (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); - - shared_memory[sm_position] = A::read_contiguous_unchecked(tensor, gm_position, config); - } - } - - fn load_tile_transposed( - tensor: &Tensor, - shared_memory: &mut SharedMemory, - info: ReadTileInfo, - config: Comptime, - _check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - - for i in range(0u32, Comptime::get(tile_size), unroll) { - let gm_position = info.gm_position_base + i; - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); - - shared_memory[sm_position] = UnmatchingVectorization::read_strided_unchecked( - tensor, - gm_position, - info.gm_stride, - config, - ); - } - } -} - -#[cube] -impl BlockWriter for UncheckedBlockIO { - fn write_output>( - out: &mut Tensor, - results: &Array, - info: WriteTileInfo, - config: Comptime, - _check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let coordinates = info.coordinates; - - let row = coordinates.skip_row + coordinates.unit_row; - let col = coordinates.skip_col + coordinates.unit_col; - let out_position_base = row * info.out_stride + col + info.offset_output; - - for result_index in range(0u32, Comptime::get(tile_size), unroll) { - let positions = WritePositions { - result: result_index * Comptime::runtime(tile_size), - out: out_position_base + result_index * info.out_stride, - }; - - A::write_contiguous_unchecked(out, results, positions, config); - } - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/vertical_block_check.rs 
b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/vertical_block_check.rs deleted file mode 100644 index 1e91a32ac9..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/vertical_block_check.rs +++ /dev/null @@ -1,120 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::{ - config::CubeTiling2dConfig, - tiling2d_cube::{ - tile::{ - loader::{CheckBounds, ReadTileInfo}, - memory_access::{ - ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, - WritePositionsExpand, - }, - }, - write_output::WriteTileInfo, - }, -}; - -use super::base::{all_zeros_runtime, BlockLoader, BlockWriter}; - -pub(crate) struct VerticalCheckBlockIO; - -#[cube] -impl BlockLoader for VerticalCheckBlockIO { - fn load_tile_plain>( - tensor: &Tensor, - shared_memory: &mut SharedMemory, - info: ReadTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let vectorization = Comptime::vectorization(&tensor); - - let mut num_reads = UInt::new(0); - let row = check_bounds.skip_row + info.read_row; - if check_bounds.dim_vertical > row { - num_reads = UInt::min( - check_bounds.dim_vertical - row, - Comptime::runtime(tile_size), - ); - } - - for i in range(0u32, num_reads, Comptime::new(false)) { - let gm_position = - (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); - - shared_memory[sm_position] = A::read_contiguous_unchecked(tensor, gm_position, config); - } - - all_zeros_runtime( - shared_memory, - num_reads, - info.sm_position_base, - info.sm_stride, - config, - ); - } - - fn load_tile_transposed( - tensor: &Tensor, - shared_memory: &mut SharedMemory, - info: ReadTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - - for i in range(0u32, Comptime::get(tile_size), unroll) { - let gm_position = info.gm_position_base + i; - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); - - shared_memory[sm_position] = UnmatchingVectorization::read_strided_checked( - tensor, - gm_position, - info.gm_stride, - check_bounds, - info, - config, - ); - } - } -} - -#[cube] -impl BlockWriter for VerticalCheckBlockIO { - fn write_output>( - out: &mut Tensor, - results: &Array, - info: WriteTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let coordinates = info.coordinates; - - let row = coordinates.skip_row + coordinates.unit_row; - let col = coordinates.skip_col + coordinates.unit_col; - let out_position_base = row * info.out_stride + col + info.offset_output; - - let mut num_writes = UInt::new(0); - if check_bounds.dim_vertical > row { - num_writes = UInt::min( - check_bounds.dim_vertical - row, - Comptime::runtime(tile_size), - ); - } - - for result_index in range(0u32, num_writes, Comptime::new(false)) { - let positions = WritePositions { - result: result_index * Comptime::runtime(tile_size), - out: out_position_base + result_index * info.out_stride, - }; - - A::write_contiguous_unchecked(out, results, positions, config); - } - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/whole_block_check.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/whole_block_check.rs deleted file mode 
100644 index e868b16337..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/block_io/whole_block_check.rs +++ /dev/null @@ -1,143 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::{ - config::CubeTiling2dConfig, - tiling2d_cube::{ - tile::{ - loader::{CheckBounds, ReadTileInfo}, - memory_access::{ - ContiguousAccess, StridedAccess, UnmatchingVectorization, WritePositions, - WritePositionsExpand, - }, - }, - write_output::WriteTileInfo, - }, -}; - -use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter}; - -pub(crate) struct WholeCheckBlockIO; - -#[cube] -impl BlockLoader for WholeCheckBlockIO { - fn load_tile_plain>( - tensor: &Tensor, - shared_memory: &mut SharedMemory, - info: ReadTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let vectorization = Comptime::vectorization(&tensor); - - let col = check_bounds.skip_col + info.read_col; - if check_bounds.dim_horizontal > col { - let mut num_reads_vertical = UInt::new(0); - let row = check_bounds.skip_row + info.read_row; - if check_bounds.dim_vertical > row { - num_reads_vertical = UInt::min( - check_bounds.dim_vertical - row, - Comptime::runtime(tile_size), - ); - } - - for i in range(0u32, num_reads_vertical, Comptime::new(false)) { - let gm_position = - (info.gm_position_base + i * info.gm_stride) / Comptime::runtime(vectorization); - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); - - shared_memory[sm_position] = - A::read_contiguous_checked(tensor, gm_position, check_bounds, info, config); - } - - all_zeros_runtime( - shared_memory, - num_reads_vertical, - info.sm_position_base, - info.sm_stride, - config, - ); - } else { - all_zeros_comptime(shared_memory, info.sm_position_base, info.sm_stride, config); - } - } - fn load_tile_transposed( - tensor: &Tensor, - shared_memory: &mut SharedMemory, - info: ReadTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - - let mut num_reads_horizontal = UInt::new(0); - let col = check_bounds.skip_col + info.read_col; - let dim_horizontal = check_bounds.dim_horizontal; - if dim_horizontal > col { - num_reads_horizontal = UInt::min(dim_horizontal - col, Comptime::runtime(tile_size)); - } - - for i in range(0u32, num_reads_horizontal, Comptime::new(false)) { - let gm_position = info.gm_position_base + i; - let sm_position = - (info.sm_position_base + i * info.sm_stride) / Comptime::runtime(tile_size); - - shared_memory[sm_position] = UnmatchingVectorization::read_strided_checked( - tensor, - gm_position, - info.gm_stride, - check_bounds, - info, - config, - ); - } - - all_zeros_runtime( - shared_memory, - num_reads_horizontal, - info.sm_position_base, - info.sm_stride, - config, - ); - } -} - -#[cube] -impl BlockWriter for WholeCheckBlockIO { - fn write_output>( - out: &mut Tensor, - results: &Array, - info: WriteTileInfo, - config: Comptime, - check_bounds: CheckBounds, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let coordinates = info.coordinates; - - let col = coordinates.skip_col + coordinates.unit_col; - - if check_bounds.dim_horizontal > col { - let mut num_writes_vertical = UInt::new(0); - let row = coordinates.skip_row + coordinates.unit_row; - - if check_bounds.dim_vertical > row { - num_writes_vertical = UInt::min( - check_bounds.dim_vertical - row, - Comptime::runtime(tile_size), - ); - } - - let out_position_base = 
row * info.out_stride + col + info.offset_output; - - for result_index in range(0u32, num_writes_vertical, Comptime::new(false)) { - let positions = WritePositions { - result: result_index * Comptime::runtime(tile_size), - out: out_position_base + result_index * info.out_stride, - }; - - A::write_contiguous_checked(out, results, positions, check_bounds, col, config); - } - } - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/loader.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/loader.rs deleted file mode 100644 index 4df21a1df0..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/loader.rs +++ /dev/null @@ -1,216 +0,0 @@ -use std::marker::PhantomData; - -use burn_cube::prelude::*; - -use crate::kernel::matmul::tiling2d_cube::load_shared_memory::{LoadInfo, Loader}; - -use super::{ - block_io::base::BlockLoader, - memory_access::{MatchingVectorization, UnmatchingVectorization}, -}; - -// Transposed tensor's vectorization must be 1 -// Plain tensor's vectorization must equal tile size -pub(crate) struct TileLoader { - _f: PhantomData, -} - -#[derive(CubeType)] -pub(crate) struct LoadIndices { - pub offset: UInt, - pub gm_stride: UInt, - pub sm_stride: UInt, -} - -#[derive(CubeType, Copy, Clone)] -pub(crate) struct CheckBounds { - pub dim_vertical: UInt, - pub dim_horizontal: UInt, - pub skip_row: UInt, - pub skip_col: UInt, -} - -#[derive(CubeType, Copy, Clone)] -pub(crate) struct ReadTileInfo { - pub read_row: UInt, - pub read_col: UInt, - pub gm_position_base: UInt, - pub sm_position_base: UInt, - pub gm_stride: UInt, - pub sm_stride: UInt, -} - -#[cube] -impl Loader for TileLoader { - fn load_lhs_plain>(lhs: &Tensor, load_info: LoadInfo) { - let config = load_info.config; - let dims = load_info.dims; - let coordinates = load_info.coordinates; - let gm_stride = dims.m; - - let load_indices = LoadIndices { - offset: coordinates.skip_row + load_info.k * gm_stride + load_info.batch_offset, - gm_stride, - sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_n)), - }; - let check_bounds = CheckBounds { - dim_vertical: dims.k, - dim_horizontal: dims.m, - skip_row: load_info.k, - skip_col: coordinates.skip_row, - }; - - load_plain::(lhs, load_info, load_indices, check_bounds); - } - - fn load_lhs_transposed>(lhs: &Tensor, load_info: LoadInfo) { - let config = load_info.config; - let dims = load_info.dims; - let coordinates = load_info.coordinates; - let gm_stride = dims.k; - - let load_indices = LoadIndices { - offset: coordinates.skip_row * gm_stride + load_info.k + load_info.batch_offset, - gm_stride, - sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_m)), - }; - let check_bounds = CheckBounds { - dim_vertical: dims.m, - dim_horizontal: dims.k, - skip_row: coordinates.skip_row, - skip_col: load_info.k, - }; - - load_transposed::(lhs, load_info, load_indices, check_bounds); - } - - fn load_rhs_plain>(rhs: &Tensor, load_info: LoadInfo) { - let coordinates = load_info.coordinates; - let dims = load_info.dims; - let config = load_info.config; - let gm_stride = dims.n; - - let load_indices = LoadIndices { - offset: coordinates.skip_col + load_info.k * gm_stride + load_info.batch_offset, - gm_stride, - sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_n)), - }; - let check_bounds = CheckBounds { - dim_vertical: dims.k, - dim_horizontal: dims.n, - skip_row: load_info.k, - skip_col: coordinates.skip_col, - }; - - load_plain::(rhs, load_info, load_indices, check_bounds); - } - - fn 
load_rhs_transposed>(rhs: &Tensor, load_info: LoadInfo) { - let config = load_info.config; - let dims = load_info.dims; - let coordinates = load_info.coordinates; - let gm_stride = dims.k; - - let load_indices = LoadIndices { - offset: coordinates.skip_col * gm_stride + load_info.k + load_info.batch_offset, - gm_stride, - sm_stride: Comptime::runtime(Comptime::map(config, |c| c.block_size_n)), - }; - let check_bounds = CheckBounds { - dim_vertical: dims.n, - dim_horizontal: dims.k, - skip_row: coordinates.skip_col, - skip_col: load_info.k, - }; - - load_transposed::(rhs, load_info, load_indices, check_bounds); - } -} - -#[cube] -pub(crate) fn load_plain>( - tensor: &Tensor, - load_info: LoadInfo, - load_indices: LoadIndices, - check_bounds: CheckBounds, -) { - let coordinates = load_info.coordinates; - let config = load_info.config; - - let vectorization = Comptime::vectorization(tensor); - let tile_size = Comptime::map(config, |c| c.tile_size); - let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); - - let read_row = coordinates.unit_row; - let read_col = coordinates.unit_col; - let write_row = coordinates.unit_row; - let write_col = coordinates.unit_col; - - let gm_position_base = read_row * load_indices.gm_stride + read_col + load_indices.offset; - let sm_position_base = write_row * load_indices.sm_stride + write_col; - - let read_tile_info = ReadTileInfo { - read_row, - read_col, - gm_position_base, - sm_position_base, - gm_stride: load_indices.gm_stride, - sm_stride: load_indices.sm_stride, - }; - let mut sm = load_info.shared_memory; - - if write_row < sm_dim_vertical { - if vectorization == tile_size { - L::load_tile_plain::( - tensor, - &mut sm, - read_tile_info, - config, - check_bounds, - ); - } else { - L::load_tile_plain::( - tensor, - &mut sm, - read_tile_info, - config, - check_bounds, - ); - } - } -} - -#[cube] -pub(crate) fn load_transposed>( - tensor: &Tensor, - load_info: LoadInfo, - load_indices: LoadIndices, - check_bounds: CheckBounds, -) { - let coordinates = load_info.coordinates; - let config = load_info.config; - - let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); - - let read_row = coordinates.unit_row; - let read_col = coordinates.unit_col; - let write_row = coordinates.unit_col; - let write_col = coordinates.unit_row; - - let gm_position_base = read_row * load_indices.gm_stride + read_col + load_indices.offset; - let sm_position_base = write_row * load_indices.sm_stride + write_col; - - let read_tile_info = ReadTileInfo { - read_row, - read_col, - gm_position_base, - sm_position_base, - gm_stride: load_indices.gm_stride, - sm_stride: load_indices.sm_stride, - }; - let mut sm = load_info.shared_memory; - - if write_row < sm_dim_vertical { - L::load_tile_transposed(tensor, &mut sm, read_tile_info, config, check_bounds); - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/memory_access.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/memory_access.rs deleted file mode 100644 index 862472e0dc..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/memory_access.rs +++ /dev/null @@ -1,319 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::config::CubeTiling2dConfig; - -use super::loader::{CheckBounds, ReadTileInfo}; - -#[derive(CubeType)] -pub(crate) struct WritePositions { - pub out: UInt, - pub result: UInt, -} - -#[cube] -pub(crate) trait ContiguousAccess: Send + Sync + 'static { - fn read_contiguous_unchecked( - tensor: &Tensor, - 
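For readers without the pre-migration sources: the deleted memory_access.rs declared around this point splits every tile transfer into a contiguous pattern (plain layout, whole rows) and a strided pattern (transposed layout, where the vectorization factor must be 1). A minimal CPU model of that distinction, with illustrative names only, not the cubecl DSL:

fn read_contiguous(src: &[f32], base: usize, tile_size: usize) -> Vec<f32> {
    // One tile row: `tile_size` consecutive elements starting at `base`.
    src[base..base + tile_size].to_vec()
}

fn read_strided(src: &[f32], base: usize, stride: usize, tile_size: usize) -> Vec<f32> {
    // One tile column: `tile_size` elements spaced `stride` apart, which is
    // what the transposed load needs when the vectorization factor is 1.
    (0..tile_size).map(|i| src[base + i * stride]).collect()
}

fn main() {
    // 4x4 row-major matrix with values 0..16, so the row stride is 4.
    let m: Vec<f32> = (0..16).map(|x| x as f32).collect();
    assert_eq!(read_contiguous(&m, 4, 4), vec![4.0, 5.0, 6.0, 7.0]); // row 1
    assert_eq!(read_strided(&m, 1, 4, 4), vec![1.0, 5.0, 9.0, 13.0]); // column 1
}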
gm_position: UInt, - config: Comptime, - ) -> F; - - fn read_contiguous_checked( - tensor: &Tensor, - gm_position: UInt, - check_bounds: CheckBounds, - read_info: ReadTileInfo, - config: Comptime, - ) -> F; - - fn write_contiguous_unchecked( - out: &mut Tensor, - results: &Array, - positions: WritePositions, - config: Comptime, - ); - - fn write_contiguous_checked( - out: &mut Tensor, - results: &Array, - positions: WritePositions, - check_bounds: CheckBounds, - write_col: UInt, - config: Comptime, - ); -} - -#[cube] -pub(crate) trait StridedAccess: Send + Sync + 'static { - fn read_strided_unchecked( - tensor: &Tensor, - gm_position: UInt, - gm_stride: UInt, - config: Comptime, - ) -> F; - - fn read_strided_checked( - tensor: &Tensor, - gm_position: UInt, - gm_stride: UInt, - check_bounds: CheckBounds, - info: ReadTileInfo, - config: Comptime, - ) -> F; -} - -#[derive(new)] -/// When vectorization == tile_size -pub(crate) struct MatchingVectorization; - -/// When vectorization != tile_size -#[derive(new)] -pub(crate) struct UnmatchingVectorization; - -#[cube] -impl ContiguousAccess for MatchingVectorization { - fn read_contiguous_unchecked( - tensor: &Tensor, - gm_position: UInt, - _config: Comptime, - ) -> F { - tensor[gm_position] - } - - fn read_contiguous_checked( - tensor: &Tensor, - gm_position: UInt, - _check_bounds: CheckBounds, - _read_info: ReadTileInfo, - config: Comptime, - ) -> F { - // If vectorization matches, then it's certain to fit since tile_size divides block_sizes - MatchingVectorization::read_contiguous_unchecked(tensor, gm_position, config) - } - - fn write_contiguous_unchecked( - out: &mut Tensor, - results: &Array, - positions: WritePositions, - config: Comptime, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - - let mut output_elem = F::vectorized_empty(Comptime::get(tile_size)); - - for i in range(0u32, Comptime::get(tile_size), unroll) { - output_elem[i] = results[positions.result + i]; - } - - out[positions.out / Comptime::runtime(tile_size)] = output_elem; - } - - fn write_contiguous_checked( - out: &mut Tensor, - results: &Array, - positions: WritePositions, - _check_bounds: CheckBounds, - _write_col: UInt, - config: Comptime, - ) { - // If vectorization matches, then it's certain to fit since tile_size divides block_sizes - MatchingVectorization::write_contiguous_unchecked(out, results, positions, config) - } -} - -#[cube] -impl ContiguousAccess for UnmatchingVectorization { - fn read_contiguous_unchecked( - tensor: &Tensor, - gm_position: UInt, - config: Comptime, - ) -> F { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let vectorization_factor = Comptime::vectorization(tensor); - let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); - - let mut vector = F::vectorized(0., Comptime::get(tile_size)); - - for i in range( - 0u32, - Comptime::get(tile_size / vectorization_factor), - unroll, - ) { - let runtime_vectorization = Comptime::runtime(vectorization_factor); - - if Comptime::get(is_scalar) { - vector[i] = tensor[gm_position + i]; - } else { - let intermediate = tensor[gm_position + i]; - - for j in range(0u32, Comptime::get(vectorization_factor), unroll) { - vector[i * runtime_vectorization + j] = intermediate[j]; - } - } - } - - vector - } - - fn read_contiguous_checked( - tensor: &Tensor, - gm_position: UInt, - check_bounds: CheckBounds, - read_info: ReadTileInfo, - config: Comptime, - ) -> F { 
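The checked read whose body continues below never touches memory past the tensor edge: it clamps the element count to min(dim_horizontal - read_col, tile_size) and leaves every out-of-bounds vector lane at zero. A plain-Rust sketch of that clamping, assuming row-major storage and an illustrative helper name:

fn read_checked(src: &[f32], base: usize, col: usize, dim_horizontal: usize, tile_size: usize) -> Vec<f32> {
    // Lanes that fall outside the matrix keep their zero initialization.
    let mut out = vec![0.0; tile_size];
    if dim_horizontal > col {
        let num_reads = usize::min(dim_horizontal - col, tile_size);
        out[..num_reads].copy_from_slice(&src[base..base + num_reads]);
    }
    out
}

fn main() {
    // Reading a 4-wide tile that starts at column 4 of a 6-wide row:
    // only two elements are valid, the rest stay zero-padded.
    let row: Vec<f32> = (0..6).map(|x| x as f32).collect();
    assert_eq!(read_checked(&row, 4, 4, 6, 4), vec![4.0, 5.0, 0.0, 0.0]);
}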
- let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let vectorization_factor = Comptime::vectorization(tensor); - let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); - let runtime_vectorization = Comptime::runtime(vectorization_factor); - - let mut vector = F::vectorized(0., Comptime::get(tile_size)); - - let mut num_loops = UInt::new(0); - if check_bounds.dim_horizontal > read_info.read_col { - let num_reads = UInt::min( - check_bounds.dim_horizontal - read_info.read_col, - Comptime::runtime(tile_size), - ); - num_loops = num_reads / runtime_vectorization; - } - - for i in range(0u32, num_loops, Comptime::new(false)) { - if Comptime::get(is_scalar) { - vector[i] = tensor[gm_position + i]; - } else { - let intermediate = tensor[gm_position + i]; - - for j in range(0u32, Comptime::get(vectorization_factor), unroll) { - vector[i * runtime_vectorization + j] = intermediate[j]; - } - } - } - - vector - } - - fn write_contiguous_unchecked( - out: &mut Tensor, - results: &Array, - positions: WritePositions, - config: Comptime, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - let vectorization_factor = Comptime::vectorization(out); - let runtime_vectorization = Comptime::runtime(vectorization_factor); - let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); - - for i in range( - 0u32, - Comptime::get(tile_size / vectorization_factor), - unroll, - ) { - if Comptime::get(is_scalar) { - out[i + positions.out] = results[positions.result + i]; - } else { - let mut output_elem = F::vectorized_empty(Comptime::get(vectorization_factor)); - - for j in range(0u32, Comptime::get(vectorization_factor), unroll) { - let index = i * runtime_vectorization + j; - output_elem[j] = results[positions.result + index]; - } - - out[i + positions.out / runtime_vectorization] = output_elem; - } - } - } - - fn write_contiguous_checked( - out: &mut Tensor, - results: &Array, - positions: WritePositions, - check_bounds: CheckBounds, - write_col: UInt, - config: Comptime, - ) { - let tile_size = Comptime::map(config, |c| c.tile_size); - let vectorization_factor = Comptime::vectorization(out); - let runtime_vectorization = Comptime::runtime(vectorization_factor); - let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1); - - let mut num_loops = UInt::new(0); - if check_bounds.dim_horizontal > write_col { - let num_writes = UInt::min( - check_bounds.dim_horizontal - write_col, - Comptime::runtime(tile_size), - ); - num_loops = num_writes / runtime_vectorization; - } - - for i in range(0u32, num_loops, Comptime::new(false)) { - let unroll = Comptime::map(config, |c| c.unroll_tile); - - if Comptime::get(is_scalar) { - out[i + positions.out] = results[positions.result + i]; - } else { - let mut output_elem = F::vectorized_empty(Comptime::get(vectorization_factor)); - - for j in range(0u32, Comptime::get(vectorization_factor), unroll) { - let index = i * runtime_vectorization + j; - output_elem[j] = results[positions.result + index]; - } - - out[i + positions.out / runtime_vectorization] = output_elem; - } - } - } -} - -#[cube] -impl StridedAccess for UnmatchingVectorization { - fn read_strided_unchecked( - tensor: &Tensor, - gm_position: UInt, - gm_stride: UInt, - config: Comptime, - ) -> F { - let tile_size = Comptime::map(config, |c| c.tile_size); - let unroll = Comptime::map(config, |c| c.unroll_tile); - - let mut vertical = 
F::vectorized_empty(Comptime::get(tile_size)); - for i in range(0u32, Comptime::get(tile_size), unroll) { - vertical[i] = tensor[gm_position + i * gm_stride]; - } - - vertical - } - - fn read_strided_checked( - tensor: &Tensor, - gm_position: UInt, - gm_stride: UInt, - check_bounds: CheckBounds, - info: ReadTileInfo, - config: Comptime, - ) -> F { - let tile_size = Comptime::map(config, |c| c.tile_size); - - let mut vertical = F::vectorized_empty(Comptime::get(tile_size)); - - let mut num_reads = UInt::new(0); - let row = check_bounds.skip_row + info.read_row; - let dim_vertical = check_bounds.dim_vertical; - if dim_vertical > row { - num_reads = UInt::min(dim_vertical - row, Comptime::runtime(tile_size)); - } - - for i in range(0u32, num_reads, Comptime::new(false)) { - vertical[i] = tensor[gm_position + i * gm_stride]; - } - for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) { - vertical[i] = F::new(0.); - } - - vertical - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/mod.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/mod.rs deleted file mode 100644 index 015d4a59c7..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -pub mod block_io; -pub mod loader; -pub mod memory_access; -pub mod writer; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/writer.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/writer.rs deleted file mode 100644 index 09c1a063ee..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/tile/writer.rs +++ /dev/null @@ -1,60 +0,0 @@ -use std::marker::PhantomData; - -use burn_cube::prelude::*; - -use crate::kernel::matmul::{ - config::CubeTiling2dConfig, - tiling2d_cube::{ - base::Dimensions, - write_output::{OutputWriter, WriteTileInfo}, - }, -}; - -use super::{ - block_io::base::BlockWriter, - loader::{CheckBounds, CheckBoundsExpand}, - memory_access::{MatchingVectorization, UnmatchingVectorization}, -}; -pub(crate) struct TileWriter { - _f: PhantomData, -} - -#[cube] -impl OutputWriter for TileWriter { - fn write_output>( - out: &mut Tensor, - results: &Array, - write_info: WriteTileInfo, - dims: Dimensions, - config: Comptime, - ) { - let vectorization = Comptime::vectorization(out); - let tile_size = Comptime::map(config, |c| c.tile_size); - let coordinates = write_info.coordinates; - - let check_bounds = CheckBounds { - dim_vertical: dims.m, - dim_horizontal: dims.n, - skip_row: coordinates.skip_row, - skip_col: coordinates.skip_col, - }; - - if vectorization == tile_size { - B::write_output::( - out, - results, - write_info, - config, - check_bounds, - ); - } else { - B::write_output::( - out, - results, - write_info, - config, - check_bounds, - ); - } - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs deleted file mode 100644 index 42d2ee8d1c..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs +++ /dev/null @@ -1,263 +0,0 @@ -use burn_cube::prelude::*; - -use crate::kernel::matmul::config::CubeTiling2dConfig; - -use super::{ - base::{Coordinates, Dimensions}, - tile::block_io::{ - base::BlockWriter, horizontal_block_check::HorizontalCheckBlockIO, - unchecked_block::UncheckedBlockIO, vertical_block_check::VerticalCheckBlockIO, - whole_block_check::WholeCheckBlockIO, - }, -}; - -#[derive(CubeType)] -pub(crate) struct WriteTileInfo { - pub coordinates: Coordinates, - pub offset_output: UInt, - 
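The writer plumbing around this point follows one pattern: two compile-time flags, check_m_bounds and check_n_bounds, statically pick one of four block-IO strategies, so a kernel whose shapes divide the block sizes evenly carries no bounds checks at all (see write_to_output just below). An illustrative sketch of that selection table:

#[derive(Debug, PartialEq)]
enum BlockIo {
    Unchecked,
    VerticalCheck,
    HorizontalCheck,
    WholeCheck,
}

// Mirrors the nested Comptime::get branches in write_to_output.
fn select_block_io(check_m_bounds: bool, check_n_bounds: bool) -> BlockIo {
    match (check_m_bounds, check_n_bounds) {
        (true, true) => BlockIo::WholeCheck,
        (true, false) => BlockIo::VerticalCheck,
        (false, true) => BlockIo::HorizontalCheck,
        (false, false) => BlockIo::Unchecked,
    }
}

fn main() {
    // A check is needed on a dimension only when the block size does not
    // divide it evenly, e.g. m = 6 with block_size_m = 8, but n = 8 with 8.
    assert_eq!(select_block_io(6 % 8 != 0, 8 % 8 != 0), BlockIo::VerticalCheck);
}

Because the flags are comptime, the choice happens during kernel expansion rather than on the GPU, which is the point of generating four separate IO paths.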
pub out_stride: UInt, -} - -#[cube] -pub(crate) trait OutputWriter: Sync + Send + 'static { - fn write_output>( - out: &mut Tensor, - results: &Array, - write_tile_info: WriteTileInfo, - dims: Dimensions, - config: Comptime, - ); -} - -#[cube] -pub(crate) fn write_to_output>( - out: &mut Tensor, - results: &Array, - coordinates: Coordinates, - offset_output: UInt, - dims: Dimensions, - config: Comptime, -) { - let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds); - let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds); - - let write_info = WriteTileInfo { - coordinates, - offset_output, - out_stride: dims.n, - }; - - if Comptime::get(check_m_bounds) { - if Comptime::get(check_n_bounds) { - W::write_output::(out, results, write_info, dims, config); - } else { - W::write_output::(out, results, write_info, dims, config); - } - } else if Comptime::get(check_n_bounds) { - W::write_output::(out, results, write_info, dims, config); - } else { - W::write_output::(out, results, write_info, dims, config); - } -} - -#[cfg(feature = "export_tests")] -/// Exported tests for write output -pub mod tests { - use crate::{ - kernel::matmul::tiling2d_cube::{ - test_utils::{ - assert_equals, make_config, range_tensor, range_tensor_transposed, zeros_tensor, - TILE_SIZE, - }, - tile::writer::TileWriter, - }, - JitRuntime, - }; - - use super::{ - super::base::{CoordinatesExpand, DimensionsExpand}, - *, - }; - - #[cube(launch)] - fn write_to_output_test( - out: &mut Tensor, - results: &mut Array, - config: Comptime, - ) { - let coordinates = Coordinates { - unit_row: UInt::new(4), - unit_col: UInt::new(4), - skip_row: UInt::new(0), - skip_col: UInt::new(0), - }; - let dims = Dimensions { - m: out.shape(out.rank() - UInt::new(2)), - k: UInt::new(0), - n: out.shape(out.rank() - UInt::new(1)), - }; - - write_to_output::>(out, results, coordinates, UInt::new(0), dims, config); - } - - #[cube(launch)] - fn write_results_to_output_out_of_bounds_test( - out: &mut Tensor, - results: &mut Array, - config: Comptime, - ) { - let coordinates = Coordinates { - unit_row: UNIT_POS_X * UInt::new(4), - unit_col: UNIT_POS_Y * UInt::new(4), - skip_row: UInt::new(0), - skip_col: UInt::new(0), - }; - let dims = Dimensions { - m: out.shape(out.rank() - UInt::new(2)), - k: UInt::new(0), - n: out.shape(out.rank() - UInt::new(1)), - }; - - write_to_output::>(out, results, coordinates, UInt::new(0), dims, config); - } - - /// Exported test - pub fn write_to_output_over_height_unit_test(device: &R::Device) { - let out = zeros_tensor::(6, 8, device); - let tile = range_tensor::(4, 4, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(6, 8, 8); - - write_to_output_test::launch::( - out.client.clone(), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape.dims), - ArrayArg::new(&tile.handle, 16), - config, - ); - - let expected = &[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, - ]; - assert_equals::(out.handle, expected, device); - } - - /// Exported test - pub fn write_to_output_over_width_unit_test(device: &R::Device) { - let out = zeros_tensor::(8, 4, device); - let tile = range_tensor::(4, 4, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let 
config = make_config(8, 8, 4); - - write_to_output_test::launch::( - out.client.clone(), - cube_count, - cube_dim, - TensorArg::vectorized(TILE_SIZE as u8, &out.handle, &out.strides, &out.shape.dims), - ArrayArg::new(&tile.handle, 16), - config, - ); - - let expected = &[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - ]; - assert_equals::(out.handle, expected, device); - } - - /// Exported test - pub fn write_to_output_vectorized_less_than_tile_unit_test(device: &R::Device) { - let vectorization = 2; - let out = zeros_tensor::(8, 8, device); - let tile = range_tensor::(4, 4, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(8, 8, 8); - - write_to_output_test::launch::( - out.client.clone(), - cube_count, - cube_dim, - TensorArg::vectorized( - vectorization as u8, - &out.handle, - &out.strides, - &out.shape.dims, - ), - ArrayArg::new(&tile.handle, 16), - config, - ); - - let expected = &[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0, - 0.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, - ]; - assert_equals::(out.handle, expected, device); - } - - /// Exported test - pub fn write_to_output_scalar_unit_test(device: &R::Device) { - let vectorization = 1; - let out = zeros_tensor::(8, 8, device); - let tile = range_tensor::(4, 4, device); - let cube_dim = CubeDim::new(1, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(8, 8, 8); - - write_to_output_test::launch::( - out.client.clone(), - cube_count, - cube_dim, - TensorArg::vectorized( - vectorization as u8, - &out.handle, - &out.strides, - &out.shape.dims, - ), - ArrayArg::new(&tile.handle, 16), - config, - ); - - let expected = &[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0, - 0.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, - ]; - assert_equals::(out.handle, expected, device); - } - - /// Exported test - pub fn write_to_output_scalar_out_of_bounds_cube_test(device: &R::Device) { - let vectorization = 1; - let out = zeros_tensor::(5, 1, device); - let results = range_tensor_transposed::(4, 4, device); - let cube_dim = CubeDim::new(2, 1, 1); - let cube_count = CubeCount::Static(1, 1, 1); - - let config = make_config(5, 8, 1); - - write_results_to_output_out_of_bounds_test::launch::( - out.client.clone(), - cube_count, - cube_dim, - TensorArg::vectorized(vectorization, &out.handle, &out.strides, &out.shape.dims), - ArrayArg::new(&results.handle, 16), - config, - ); - - let expected = &[0.0, 1.0, 2.0, 3.0, 0.0]; - assert_equals::(out.handle, expected, device); - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs deleted file mode 100644 index 762c470356..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/base.rs +++ /dev/null @@ -1,68 +0,0 @@ -use burn_cube::cpa; -use burn_cube::ir::{BinaryOperator, Scope, Synchronization, Variable}; - -use 
crate::kernel::matmul::config::Tiling2dConfig; -use crate::kernel::matmul::tiling2d_shader::{ - computation_loop, gather_shader_information, load_shared_memory, write_to_output, -}; - -pub(crate) struct MatmulTiling2dShader { - pub variables: BinaryOperator, - pub config: Tiling2dConfig, - pub bounds_check_required: bool, -} - -pub(crate) struct Tiling2dState { - pub n_loops: Variable, - pub k: Variable, - pub lhs: Variable, - pub rhs: Variable, - pub out: Variable, - pub offset_lhs: Variable, - pub offset_rhs: Variable, - pub offset_output: Variable, - pub row: Variable, - pub col: Variable, - pub dim_m: Variable, - pub dim_k: Variable, - pub dim_n: Variable, - pub thread_col: Variable, - pub thread_row: Variable, - pub shared_lhs: Variable, - pub shared_rhs: Variable, - pub register_m: Variable, - pub register_n: Variable, - pub results: Variable, - pub lhs_stride_col: Variable, - pub lhs_stride_row: Variable, - pub rhs_stride_col: Variable, - pub rhs_stride_row: Variable, - pub out_stride_row: Variable, - pub out_stride_col: Variable, -} - -impl MatmulTiling2dShader { - pub(crate) fn expand(self, scope: &mut Scope) { - let shader_state = gather_shader_information(scope, &self); - - let block_size_k: Variable = self.config.block_size_k.into(); - cpa!( - scope, - range(0u32, shader_state.n_loops).for_each(|i, scope| { - // From 0 to K with steps block_size_k - let k = shader_state.k; - cpa!(scope, k = i * block_size_k); - - load_shared_memory(scope, &self, &shader_state); - - scope.register(Synchronization::SyncUnits); - - computation_loop(scope, &self, &shader_state); - - scope.register(Synchronization::SyncUnits); - }) - ); - - write_to_output(scope, &self, &shader_state); - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs deleted file mode 100644 index 6b128684cf..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/computation.rs +++ /dev/null @@ -1,87 +0,0 @@ -use burn_cube::{ - cpa, - ir::{Elem, Scope, Variable}, -}; - -use super::{MatmulTiling2dShader, Tiling2dState}; - -#[allow(clippy::too_many_arguments)] -pub fn computation_loop( - scope: &mut Scope, - shader: &MatmulTiling2dShader, - shader_state: &Tiling2dState, -) { - let thread_col = shader_state.thread_col; - let thread_row = shader_state.thread_row; - let shared_lhs = shader_state.shared_lhs; - let shared_rhs = shader_state.shared_rhs; - let register_m = shader_state.register_m; - let register_n = shader_state.register_n; - let results = shader_state.results; - - let block_size_k: Variable = shader.config.block_size_k.into(); - let block_size_n: Variable = shader.config.block_size_n.into(); - let elem = results.item().elem(); - - let lhs_sm_position = scope.create_local(Elem::UInt); - let rhs_sm_position = scope.create_local(Elem::UInt); - - let registered_m = scope.create_local(elem); - let registered_n = scope.create_local(elem); - - let multiplied = scope.create_local(elem); - let results_position = scope.create_local(Elem::UInt); - let results_before = scope.create_local(elem); - let results_after = scope.create_local(elem); - - cpa!( - scope, - range( - 0u32, - shader.config.block_size_k as u32, - shader.config.unroll - ) - .for_each(|dot_index, scope| { - // Load a subcolumn of values from lhs - cpa!(scope, lhs_sm_position = thread_row / 4u32); - cpa!(scope, lhs_sm_position *= block_size_k); - cpa!(scope, lhs_sm_position += dot_index); - cpa!(scope, register_m = shared_lhs[lhs_sm_position]); - - // Load 
a subrow of values from rhs - cpa!(scope, rhs_sm_position = dot_index * block_size_n); - cpa!(scope, rhs_sm_position += thread_col); - cpa!(scope, rhs_sm_position = rhs_sm_position / 4u32); - cpa!(scope, register_n = shared_rhs[rhs_sm_position]); - - cpa!( - scope, - range(0u32, shader.config.tile_size as u32, shader.config.unroll).for_each( - |res_idx_m, scope| { - cpa!( - scope, - range(0u32, shader.config.tile_size as u32, shader.config.unroll) - .for_each(|res_idx_n, scope| { - cpa!(scope, registered_m = register_m[res_idx_m]); - cpa!(scope, registered_n = register_n[res_idx_n]); - - cpa!(scope, multiplied = registered_m * registered_n); - - cpa!( - scope, - results_position = res_idx_m * shader.config.tile_size - ); - cpa!(scope, results_position += res_idx_n); - - cpa!(scope, results_before = results[results_position]); - cpa!(scope, results_after = results_before + multiplied); - - cpa!(scope, results[results_position] = results_after); - }) - ); - } - ) - ); - }) - ); -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs deleted file mode 100644 index ea161307b3..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/load_shared_memory.rs +++ /dev/null @@ -1,281 +0,0 @@ -use burn_cube::{ - cpa, - ir::{Elem, Scope, Variable}, -}; - -use super::{MatmulTiling2dShader, Tiling2dState}; - -enum InputIdentifier { - Lhs, - Rhs, -} - -pub(crate) fn load_shared_memory( - scope: &mut Scope, - shader: &MatmulTiling2dShader, - shader_state: &Tiling2dState, -) { - if shader.bounds_check_required { - load_shared_memory_with_bound_check(scope, shader, shader_state, InputIdentifier::Lhs); - load_shared_memory_with_bound_check(scope, shader, shader_state, InputIdentifier::Rhs); - } else { - load_shared_memory_no_bound_check(scope, shader, shader_state, InputIdentifier::Lhs); - load_shared_memory_no_bound_check(scope, shader, shader_state, InputIdentifier::Rhs); - } -} - -#[allow(clippy::too_many_arguments)] -fn load_shared_memory_with_bound_check( - scope: &mut Scope, - shader: &MatmulTiling2dShader, - shader_state: &Tiling2dState, - input_identifier: InputIdentifier, -) { - let ( - input, - input_offset, - shared_memory, - thread_idx_1, - thread_idx_2, - stride_1, - stride_2, - dim, - pos_in_dim, - ) = match input_identifier { - InputIdentifier::Lhs => ( - shader_state.lhs, - shader_state.offset_lhs, - shader_state.shared_lhs, - shader_state.thread_col, - shader_state.thread_row, - shader_state.lhs_stride_col, - shader_state.lhs_stride_row, - shader_state.dim_m, - shader_state.row, - ), - InputIdentifier::Rhs => ( - shader_state.rhs, - shader_state.offset_rhs, - shader_state.shared_rhs, - shader_state.thread_row, - shader_state.thread_col, - shader_state.rhs_stride_row, - shader_state.rhs_stride_col, - shader_state.dim_n, - shader_state.col, - ), - }; - let k = shader_state.k; - let dim_k = shader_state.dim_k; - - // How close is the thread to the end of the matrix. 
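The branch ladder that follows pads a vec4 load with zeros once fewer than four valid elements remain before the matrix edge. A CPU model of that edge handling (load_vec4_padded is a hypothetical helper, not part of the kernel):

fn load_vec4_padded(src: &[f32], positions: [usize; 4], remain: usize) -> [f32; 4] {
    // Read only the `remain` in-bounds lanes; the rest stay zero, matching
    // the val_0..val_3 zero-fill in the shader below.
    let mut v = [0.0f32; 4];
    for (lane, &pos) in positions.iter().enumerate().take(remain.min(4)) {
        v[lane] = src[pos];
    }
    v
}

fn main() {
    let col: Vec<f32> = (0..10).map(|x| x as f32).collect();
    // Only 3 rows remain below this thread, so lane 3 is zero-padded.
    assert_eq!(load_vec4_padded(&col, [6, 7, 8, 9], 3), [6.0, 7.0, 8.0, 0.0]);
}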
- // If < 4 then it is an edge case - let remain = scope.create_local(Elem::UInt); - cpa!(scope, remain = dim - pos_in_dim); - - let block_size_k: Variable = shader.config.block_size_k.into(); - let block_size_n: Variable = shader.config.block_size_n.into(); - let elem = input.item().elem(); - - let current = scope.create_local(Elem::UInt); - let aligned_with_shared_memory = scope.create_local(Elem::Bool); - let sm_position = scope.create_local(Elem::UInt); - let within_input = scope.create_local(Elem::Bool); - let current_with_k = scope.create_local(Elem::UInt); - let remain_at_least_1 = scope.create_local(Elem::Bool); - let read_condition = scope.create_local(Elem::Bool); - let val_vec4 = scope.create_local(shared_memory.item()); - - let tmp = scope.create_local(Elem::UInt); - let position_0 = scope.create_local(Elem::UInt); - let position_1 = scope.create_local(Elem::UInt); - let position_2 = scope.create_local(Elem::UInt); - let position_3 = scope.create_local(Elem::UInt); - let remain_n = scope.create_local(Elem::Bool); - - let val_0 = scope.create_local(elem); - let val_1 = scope.create_local(elem); - let val_2 = scope.create_local(elem); - let val_3 = scope.create_local(elem); - let zero: Variable = 0u32.into(); - - cpa!( - scope, - range(0_u32, 4u32, shader.config.unroll).for_each(|j, scope| { - cpa!(scope, current = thread_idx_1 + j); - - cpa!(scope, aligned_with_shared_memory = current < block_size_k); - - // To avoid overwriting following row in shared memory - cpa!(scope, if(aligned_with_shared_memory).then(|scope|{ - - // Position in shared memory - match input_identifier { - InputIdentifier::Lhs => { - cpa!(scope, sm_position = thread_idx_2 / 4u32); - cpa!(scope, sm_position *= block_size_k); - cpa!(scope, sm_position += current); - }, - InputIdentifier::Rhs => { - cpa!(scope, sm_position = current * block_size_n); - cpa!(scope, sm_position += thread_idx_2); - cpa!(scope, sm_position = sm_position / 4u32); - } - } - - // To pad with zeros if outside lhs - cpa!(scope, current_with_k = current + k); - cpa!(scope, within_input = current_with_k < dim_k); - cpa!(scope, remain_at_least_1 = remain >= 1u32); - cpa!(scope, read_condition = within_input && remain_at_least_1); - - cpa!(scope, if(read_condition).then(|scope| { - cpa!(scope, position_0 = k + current); - cpa!(scope, position_0 *= stride_1); - cpa!(scope, tmp = thread_idx_2 * stride_2); - cpa!(scope, position_0 += tmp); - cpa!(scope, position_0 += input_offset); - cpa!(scope, position_1 = position_0 + stride_2); - cpa!(scope, position_2 = position_1 + stride_2); - cpa!(scope, position_3 = position_2 + stride_2); - - cpa!(scope, remain_n = remain >= 4u32); - cpa!(scope, if(remain_n).then(|scope|{ - cpa!(scope, val_0 = input[position_0]); - cpa!(scope, val_1 = input[position_1]); - cpa!(scope, val_2 = input[position_2]); - cpa!(scope, val_3 = input[position_3]); - - }).else(|scope|{ - cpa!(scope, remain_n = remain == 3u32); - cpa!(scope, if(remain_n).then(|scope|{ - cpa!(scope, val_0 = input[position_0]); - cpa!(scope, val_1 = input[position_1]); - cpa!(scope, val_2 = input[position_2]); - cpa!(scope, val_3 = zero); - - }).else(|scope|{ - cpa!(scope, remain_n = remain == 2u32); - cpa!(scope, if(remain_n).then(|scope|{ - cpa!(scope, val_0 = input[position_0]); - cpa!(scope, val_1 = input[position_1]); - cpa!(scope, val_2 = zero); - cpa!(scope, val_3 = zero); - - }).else(|scope|{ - cpa!(scope, remain_n = remain == 1u32); - cpa!(scope, if(remain_n).then(|scope|{ - cpa!(scope, val_0 = input[position_0]); - cpa!(scope, val_1 = 
zero); - cpa!(scope, val_2 = zero); - cpa!(scope, val_3 = zero); - })); - })); - })); - })); - - cpa!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3)); - cpa!(scope, shared_memory[sm_position] = val_vec4); - - }).else(|scope|{ - cpa!(scope, val_0 = zero); - cpa!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0)); - cpa!(scope, shared_memory[sm_position] = val_vec4); - })); - })); - }) - ); -} - -#[allow(clippy::too_many_arguments)] -fn load_shared_memory_no_bound_check( - scope: &mut Scope, - shader: &MatmulTiling2dShader, - shader_state: &Tiling2dState, - input_identifier: InputIdentifier, -) { - let (input, input_offset, shared_memory, thread_idx_1, thread_idx_2, stride_1, stride_2) = - match input_identifier { - InputIdentifier::Lhs => ( - shader_state.lhs, - shader_state.offset_lhs, - shader_state.shared_lhs, - shader_state.thread_col, - shader_state.thread_row, - shader_state.lhs_stride_col, - shader_state.lhs_stride_row, - ), - InputIdentifier::Rhs => ( - shader_state.rhs, - shader_state.offset_rhs, - shader_state.shared_rhs, - shader_state.thread_row, - shader_state.thread_col, - shader_state.rhs_stride_row, - shader_state.rhs_stride_col, - ), - }; - let k = shader_state.k; - - let block_size_k: Variable = shader.config.block_size_k.into(); - let block_size_n: Variable = shader.config.block_size_n.into(); - let elem = input.item().elem(); - - let current = scope.create_local(Elem::UInt); - let aligned_with_shared_memory = scope.create_local(Elem::Bool); - let sm_position = scope.create_local(Elem::UInt); - - let tmp = scope.create_local(Elem::UInt); - let position_0 = scope.create_local(Elem::UInt); - let position_1 = scope.create_local(Elem::UInt); - let position_2 = scope.create_local(Elem::UInt); - let position_3 = scope.create_local(Elem::UInt); - let val_0 = scope.create_local(elem); - let val_1 = scope.create_local(elem); - let val_2 = scope.create_local(elem); - let val_3 = scope.create_local(elem); - let val_vec4 = scope.create_local(shared_memory.item()); - - cpa!( - scope, - range(0_u32, 4u32, shader.config.unroll).for_each(|j, scope| { - cpa!(scope, current = thread_idx_1 + j); - - cpa!(scope, aligned_with_shared_memory = current < block_size_k); - - // To avoid overwriting following row in shared memory - cpa!(scope, if(aligned_with_shared_memory).then(|scope|{ - - match input_identifier { - InputIdentifier::Lhs => { - cpa!(scope, sm_position = thread_idx_2 / 4u32); - cpa!(scope, sm_position *= block_size_k); - cpa!(scope, sm_position += current); - }, - InputIdentifier::Rhs => { - cpa!(scope, sm_position = current * block_size_n); - cpa!(scope, sm_position += thread_idx_2); - cpa!(scope, sm_position = sm_position / 4u32); - } - } - - cpa!(scope, position_0 = k + current); - cpa!(scope, position_0 *= stride_1); - cpa!(scope, tmp = thread_idx_2 * stride_2); - cpa!(scope, position_0 += tmp); - cpa!(scope, position_0 += input_offset); - cpa!(scope, position_1 = position_0 + stride_2); - cpa!(scope, position_2 = position_1 + stride_2); - cpa!(scope, position_3 = position_2 + stride_2); - - cpa!(scope, val_0 = input[position_0]); - cpa!(scope, val_1 = input[position_1]); - cpa!(scope, val_2 = input[position_2]); - cpa!(scope, val_3 = input[position_3]); - - cpa!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3)); - cpa!(scope, shared_memory[sm_position] = val_vec4); - })); - }) - ); -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/mod.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/mod.rs deleted file mode 100644 index 3ed28903d7..0000000000 --- 
a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod base; -mod computation; -mod load_shared_memory; -mod shader_information; -mod write_output; - -pub(crate) use base::*; -pub(crate) use computation::*; -pub(crate) use load_shared_memory::*; -pub(crate) use shader_information::*; -pub(crate) use write_output::*; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs deleted file mode 100644 index d9bad3b8d3..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/shader_information.rs +++ /dev/null @@ -1,183 +0,0 @@ -use burn_cube::{ - cpa, - ir::{Elem, Item, Scope, Variable}, -}; - -use super::{MatmulTiling2dShader, Tiling2dState}; - -pub(crate) fn gather_shader_information( - scope: &mut Scope, - shader: &MatmulTiling2dShader, -) -> Tiling2dState { - // Inputs - let lhs = shader.variables.lhs; - let rhs = shader.variables.rhs; - let out = shader.variables.out; - - // Config variables - let block_size_m: Variable = shader.config.block_size_m.into(); - let block_size_k: Variable = shader.config.block_size_k.into(); - let block_size_n: Variable = shader.config.block_size_n.into(); - let tile_size_m: Variable = shader.config.tile_size.into(); - let tile_size_n: Variable = shader.config.tile_size.into(); - let n_threads_per_row: Variable = - (((shader.config.block_size_n - 1) / shader.config.tile_size) + 1).into(); - let results_size = (shader.config.tile_size * shader.config.tile_size) as u32; - - // Shader info - let local_idx = Variable::UnitPos; - let batch = Variable::AbsolutePosZ; - - // Shapes - let rank = Variable::Rank; - let last_dim = scope.create_local(Elem::UInt); - let second_to_last_dim = scope.create_local(Elem::UInt); - let dim_m = scope.create_local(Elem::UInt); - let dim_k = scope.create_local(Elem::UInt); - let dim_n = scope.create_local(Elem::UInt); - cpa!(scope, last_dim = rank - 1u32); - cpa!(scope, second_to_last_dim = rank - 2u32); - cpa!(scope, dim_m = shape(lhs, second_to_last_dim)); - cpa!(scope, dim_k = shape(lhs, last_dim)); - cpa!(scope, dim_n = shape(rhs, last_dim)); - - // Strides - let lhs_stride_row = scope.create_local(Elem::UInt); - let lhs_stride_col = scope.create_local(Elem::UInt); - let rhs_stride_row = scope.create_local(Elem::UInt); - let rhs_stride_col = scope.create_local(Elem::UInt); - let out_stride_row = scope.create_local(Elem::UInt); - let out_stride_col = scope.create_local(Elem::UInt); - cpa!(scope, lhs_stride_row = stride(lhs, second_to_last_dim)); - cpa!(scope, lhs_stride_col = stride(lhs, last_dim)); - cpa!(scope, rhs_stride_row = stride(rhs, second_to_last_dim)); - cpa!(scope, rhs_stride_col = stride(rhs, last_dim)); - cpa!(scope, out_stride_row = stride(out, second_to_last_dim)); - cpa!(scope, out_stride_col = stride(out, last_dim)); - - // Cube offset - let skip_row = scope.create_local(Elem::UInt); - let skip_col = scope.create_local(Elem::UInt); - let cube_pos_x = Variable::CubePosX; - let cube_pos_y = Variable::CubePosY; - cpa!(scope, skip_row = cube_pos_x); - cpa!(scope, skip_row *= block_size_m); - cpa!(scope, skip_col = cube_pos_y); - cpa!(scope, skip_col *= block_size_n); - - // Position of the first element of the thread, relative to the block - let thread_row = scope.create_local(Elem::UInt); - let thread_col = scope.create_local(Elem::UInt); - cpa!(scope, thread_row = local_idx / n_threads_per_row); - cpa!(scope, thread_row *= tile_size_m); - cpa!(scope, thread_col = 
local_idx % n_threads_per_row); - cpa!(scope, thread_col *= tile_size_n); - - // Position of the first element of the thread, in absolute (in one batch) - let row = scope.create_local(Elem::UInt); - let col = scope.create_local(Elem::UInt); - cpa!(scope, row = skip_row + thread_row); - cpa!(scope, col = skip_col + thread_col); - - // Calculate offset. - let offset_lhs = scope.create_local(Elem::UInt); - let offset_rhs = scope.create_local(Elem::UInt); - cpa!(scope, offset_lhs = skip_row * lhs_stride_row); - cpa!(scope, offset_rhs = skip_col * rhs_stride_col); - - // Batch offset for the output. - let offset_output = scope.create_local(Elem::UInt); - let batch_dims = scope.create_local(Elem::UInt); - cpa!(scope, offset_output = dim_m * dim_n); - cpa!(scope, offset_output = offset_output * batch); - - // Batch offset for the lhs & rhs matrices. - let stride_lhs = scope.create_local(Elem::UInt); - let stride_rhs = scope.create_local(Elem::UInt); - let stride_output = scope.create_local(Elem::UInt); - let shape_lhs = scope.create_local(Elem::UInt); - let shape_rhs = scope.create_local(Elem::UInt); - let tmp = scope.create_local(Elem::UInt); - let tmp_lhs = scope.create_local(Elem::UInt); - let tmp_rhs = scope.create_local(Elem::UInt); - cpa!(scope, batch_dims = rank - 2u32); - cpa!( - scope, - range(0u32, batch_dims).for_each(|b, scope| { - cpa!(scope, stride_lhs = stride(lhs, b)); - cpa!(scope, stride_rhs = stride(rhs, b)); - cpa!(scope, stride_output = stride(out, b)); - cpa!(scope, shape_lhs = shape(lhs, b)); - cpa!(scope, shape_rhs = shape(rhs, b)); - - cpa!(scope, tmp = offset_output / stride_output); - cpa!(scope, tmp_lhs = tmp % shape_lhs); - cpa!(scope, tmp_lhs = tmp_lhs * stride_lhs); - cpa!(scope, offset_lhs += tmp_lhs); - - cpa!(scope, tmp_rhs = tmp % shape_rhs); - cpa!(scope, tmp_rhs = tmp_rhs * stride_rhs); - cpa!(scope, offset_rhs += tmp_rhs); - }) - ); - - let elem = lhs.item().elem(); - - // Registers used in the compute pass - let results = scope.create_local_array(elem, results_size); - let register_m = scope.create_local(Item::vectorized(elem, 4)); - let register_n = scope.create_local(Item::vectorized(elem, 4)); - let shared_lhs = scope.create_shared( - Item::vectorized(elem, 4), - shader.config.block_size_m as u32 * shader.config.block_size_k as u32 / 4u32, - ); - let shared_rhs = scope.create_shared( - Item::vectorized(elem, 4), - shader.config.block_size_k as u32 * shader.config.block_size_n as u32 / 4u32, - ); - - // Calculate exact number of loop iterations - let n_loops = scope.create_local(Elem::UInt); - let k = scope.create_local(Elem::UInt); - if shader.bounds_check_required { - let dim_k_float = scope.create_local(elem); - let block_size_k_float = scope.create_local(elem); - let n_loops_float = scope.create_local(elem); - cpa!(scope, dim_k_float = dim_k); - cpa!(scope, block_size_k_float = block_size_k); - cpa!(scope, n_loops_float = dim_k_float / block_size_k_float); - cpa!(scope, n_loops_float = ceil(n_loops_float)); - cpa!(scope, n_loops = n_loops_float); - } else { - cpa!(scope, n_loops = dim_k / block_size_k); - } - - Tiling2dState { - n_loops, - k, - lhs, - rhs, - out, - offset_lhs, - offset_rhs, - offset_output, - row, - col, - dim_m, - dim_k, - dim_n, - thread_col, - thread_row, - shared_lhs, - shared_rhs, - register_m, - register_n, - results, - lhs_stride_col, - lhs_stride_row, - rhs_stride_col, - rhs_stride_row, - out_stride_row, - out_stride_col, - } -} diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs 
b/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs deleted file mode 100644 index 0ce06307a9..0000000000 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_shader/write_output.rs +++ /dev/null @@ -1,124 +0,0 @@ -use burn_cube::{ - cpa, - ir::{Elem, Scope, Variable}, -}; - -use super::{MatmulTiling2dShader, Tiling2dState}; - -#[allow(clippy::too_many_arguments)] -pub fn write_to_output( - scope: &mut Scope, - shader: &MatmulTiling2dShader, - shader_state: &Tiling2dState, -) { - let row = shader_state.row; - let col = shader_state.col; - - let row_index = scope.create_local(Elem::UInt); - let col_index = scope.create_local(Elem::UInt); - - if shader.bounds_check_required { - let dim_m = shader_state.dim_m; - let dim_n = shader_state.dim_n; - - let within_output = scope.create_local(Elem::Bool); - let within_output_tmp = scope.create_local(Elem::Bool); - - cpa!( - scope, - range(0u32, shader.config.tile_size as u32, shader.config.unroll).for_each( - |res_idx_m, scope| { - cpa!( - scope, - range(0u32, shader.config.tile_size as u32, shader.config.unroll).for_each( - |res_idx_n, scope| { - cpa!(scope, row_index = row + res_idx_m); - cpa!(scope, col_index = col + res_idx_n); - - cpa!(scope, within_output = row_index < dim_m); - cpa!(scope, within_output_tmp = col_index < dim_n); - cpa!(scope, within_output = within_output && within_output_tmp); - - cpa!(scope, if(within_output).then(|scope|{ - write_inner( - scope, - shader, - shader_state, - res_idx_m, - res_idx_n, - row_index, - col_index, - ); - })); - } - ) - ); - } - ) - ); - } else { - cpa!( - scope, - range(0u32, shader.config.tile_size as u32, shader.config.unroll).for_each( - |res_idx_m, scope| { - cpa!( - scope, - range(0u32, shader.config.tile_size as u32, shader.config.unroll).for_each( - |res_idx_n, scope| { - cpa!(scope, row_index = row + res_idx_m); - cpa!(scope, col_index = col + res_idx_n); - - write_inner( - scope, - shader, - shader_state, - res_idx_m, - res_idx_n, - row_index, - col_index, - ) - } - ) - ); - } - ) - ); - } -} - -#[allow(clippy::too_many_arguments)] -fn write_inner( - scope: &mut Scope, - shader: &MatmulTiling2dShader, - shader_state: &Tiling2dState, - res_idx_m: Variable, - res_idx_n: Variable, - row_index: Variable, - col_index: Variable, -) { - let offset_output = shader_state.offset_output; - let out = shader_state.out; - let out_stride_row = shader_state.out_stride_row; - let out_stride_col = shader_state.out_stride_col; - let results = shader_state.results; - - let elem = results.item().elem(); - let results_position = scope.create_local(Elem::UInt); - let result = scope.create_local(elem); - let output_position = scope.create_local(Elem::UInt); - - cpa!( - scope, - results_position = res_idx_m * shader.config.tile_size - ); - cpa!(scope, results_position += res_idx_n); - - cpa!(scope, result = results[results_position]); - - cpa!(scope, row_index *= out_stride_row); - cpa!(scope, col_index *= out_stride_col); - cpa!(scope, output_position = row_index + col_index); - cpa!(scope, output_position += offset_output); - - cpa!(scope, out[output_position] = result); -} diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index d35de0bf6f..e58d97b656 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -1,16 +1,13 @@ -use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet}; use burn_tensor::{Element, ElementConversion}; +use cubecl::tune::{local_tuner, 
AutotuneOperation, AutotuneOperationSet, LocalTuner}; use crate::{ element::FloatElement, - kernel::{ - matmul::{config::Tiling2dConfig, utils::init_matmul_output}, - prng::random_like_uniform, - }, + kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform}, ops::numeric::empty_device, tensor::JitTensor, tune_key::JitAutotuneKey, - JitRuntime, + JitRuntime, JitTuneId, }; use super::key::MatmulAutotuneKey; @@ -59,32 +56,7 @@ impl AutotuneOperationSet AutotuneOperationSet Box::new(SimpleMatmul::new(self.lhs, self.rhs, self.out)), 1 => Box::new(SimpleMatmul16x16::new(self.lhs, self.rhs, self.out)), - 2 => Box::new(Tiling2dMatmul::new(self.lhs, self.rhs, self.out)), - 3 => Box::new(Tiling2dMatmulUnrolled::new(self.lhs, self.rhs, self.out)), - 4 => Box::new(Tiling2dMatmulPadded::new(self.lhs, self.rhs, self.out)), - 5 => Box::new(Tiling2dMatmulPaddedUnrolled::new( - self.lhs, self.rhs, self.out, - )), - 6 => Box::new(Tiling2dMatmulCube::new(self.lhs, self.rhs, self.out)), - 7 => Box::new(Tiling2dMatmulCubeUnrolled::new( - self.lhs, self.rhs, self.out, - )), + 2 => Box::new(MatmulCube::new(self.lhs, self.rhs, self.out)), _ => panic!("Fastest index is out of bound"), } } @@ -116,9 +79,13 @@ pub fn matmul_autotune let output = init_matmul_output(&lhs, &rhs); - let operation_set = Box::new(MatmulAutotuneOperationSet::new(lhs, rhs, output.clone())); + static TUNER: LocalTuner = local_tuner!(); - client.autotune_execute(operation_set); + TUNER.execute( + &JitTuneId::new::(&lhs.device), + &client, + Box::new(MatmulAutotuneOperationSet::new(lhs, rhs, output.clone())), + ); output } @@ -160,56 +127,15 @@ matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| { crate::kernel::matmul::matmul_simple(lhs, rhs, out, 16, 16) }); -// Maybe the fastest for transposed inputs, without loop unrolling -matmul_tune_ops!(Tiling2dMatmul, |lhs, rhs, out| { - crate::kernel::matmul::matmul_tiling_2d(lhs, rhs, out, Tiling2dConfig::default()) -}); - -// Maybe the fastest for transposed inputs, with loop unrolling -matmul_tune_ops!(Tiling2dMatmulUnrolled, |lhs, rhs, out| { - crate::kernel::matmul::matmul_tiling_2d( - lhs, - rhs, - out, - Tiling2dConfig { - unroll: true, - ..Default::default() - }, - ) -}); - -// Maybe the fastest when fixed size, without loop unrolling -matmul_tune_ops!(Tiling2dMatmulPadded, |lhs, rhs, out| { - crate::kernel::matmul::matmul_tiling_2d_padded(lhs, rhs, out, Tiling2dConfig::default()) -}); - -// Maybe the fastest when fixed sizes, with loop unrolling -matmul_tune_ops!(Tiling2dMatmulPaddedUnrolled, |lhs, rhs, out| { - crate::kernel::matmul::matmul_tiling_2d_padded( - lhs, - rhs, - out, - Tiling2dConfig { - unroll: true, - ..Default::default() - }, - ) -}); - // Probably the fastest in the general case, without loop unrolling -matmul_tune_ops!(Tiling2dMatmulCube, |lhs, rhs, out| { - crate::kernel::matmul::matmul_tiling_2d_cube(lhs, rhs, out, Tiling2dConfig::default()) -}); - -// Probably the fastest in the general case, with loop unrolling -matmul_tune_ops!(Tiling2dMatmulCubeUnrolled, |lhs, rhs, out| { - crate::kernel::matmul::matmul_tiling_2d_cube( - lhs, - rhs, - out, - Tiling2dConfig { - unroll: true, - ..Default::default() - }, - ) -}); +matmul_tune_ops!( + MatmulCube, + |lhs: JitTensor, rhs: JitTensor, out: JitTensor| { + cubecl::linalg::matmul::launch_ref::( + &lhs.client, + lhs.as_handle_ref(), + rhs.as_handle_ref(), + out.as_handle_ref(), + ); + } +); diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index 867b92a287..19b5b896bd 100644 --- 
a/crates/burn-jit/src/kernel/mod.rs
+++ b/crates/burn-jit/src/kernel/mod.rs
@@ -13,7 +13,7 @@ pub use contiguous::*;
 pub use mask::*;
 pub(crate) use unary::*;
-pub use burn_cube::{Kernel, SUBCUBE_DIM_APPROX};
+pub use cubecl::{Kernel, SUBCUBE_DIM_APPROX};
 /// Convolution kernels
 pub mod conv;
diff --git a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs
index 7ed905c612..2bc228a0da 100644
--- a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs
+++ b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d.rs
@@ -1,6 +1,6 @@
 use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};
-use burn_cube::{frontend::TensorHandle, CubeCountSettings, Execution};
 use burn_tensor::Shape;
+use cubecl::{frontend::TensorHandleRef, CubeCountSettings, Execution};
 use super::AdaptivePool2dEagerKernel;
@@ -16,12 +16,12 @@ pub(crate) fn adaptive_avg_pool2d(
     let kernel = AdaptivePool2dEagerKernel::<R, E>::new();
     Execution::start(kernel, input.client)
-        .inputs(&[TensorHandle::<R>::new(
+        .inputs(&[TensorHandleRef::<R>::new(
             &input.handle,
             &input.strides,
             &input.shape.dims,
         )])
-        .outputs(&[TensorHandle::new(
+        .outputs(&[TensorHandleRef::new(
             &output.handle,
             &output.strides,
             &output.shape.dims,
diff --git a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs
index e3754b0863..5b42712ff4 100644
--- a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs
+++ b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs
@@ -1,8 +1,8 @@
 use std::marker::PhantomData;
-use burn_cube::{
+use cubecl::{
     cpa,
-    frontend::TensorHandle,
+    frontend::TensorHandleRef,
     ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
     CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings,
     OutputInfo,
@@ -261,12 +261,12 @@ pub(crate) fn adaptive_avg_pool2d_backward(
     let kernel = AdaptiveAvgPool2dBackwardEagerKernel::<R, E>::new();
     Execution::start(kernel, x.client)
-        .inputs(&[TensorHandle::<R>::new(
+        .inputs(&[TensorHandleRef::<R>::new(
             &out_grad.handle,
             &out_grad.strides,
             &out_grad.shape.dims,
         )])
-        .outputs(&[TensorHandle::new(
+        .outputs(&[TensorHandleRef::new(
             &output.handle,
             &output.strides,
             &output.shape.dims,
diff --git a/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs b/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs
index 28c0eb5df9..fed897b419 100644
--- a/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs
+++ b/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs
@@ -1,4 +1,4 @@
-use burn_cube::{
+use cubecl::{
     cpa,
     ir::{Elem, KernelDefinition, Scope, Variable, Visibility},
     InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo,
diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d.rs
index c2853727c6..2d0ee46545 100644
--- a/crates/burn-jit/src/kernel/pool/avg_pool2d.rs
+++ b/crates/burn-jit/src/kernel/pool/avg_pool2d.rs
@@ -1,11 +1,11 @@
 use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};
-use burn_cube::{
+use burn_tensor::{ops::conv::calculate_pool_output_size, Shape};
+use cubecl::{
     cpa,
-    frontend::TensorHandle,
+    frontend::TensorHandleRef,
     ir::{Elem, Item, Scope, Variable},
     CubeCountSettings, Execution,
 };
-use burn_tensor::{ops::conv::calculate_pool_output_size, Shape};
 use std::fmt::Debug;
 use super::{Pool2dEagerKernel, PoolStrategy};
@@ -101,8 +101,12 @@ pub(crate) fn
avg_pool2d( let kernel = Pool2dEagerKernel::::new(kernel_size, pool_strategy); Execution::start(kernel, x.client) - .inputs(&[TensorHandle::::new(&x.handle, &x.strides, &x.shape.dims)]) - .outputs(&[TensorHandle::new( + .inputs(&[TensorHandleRef::::new( + &x.handle, + &x.strides, + &x.shape.dims, + )]) + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs index 0eb6bce765..67bcd03ea0 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs @@ -5,9 +5,9 @@ use crate::{ tensor::JitTensor, JitRuntime, }; -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, IntKind, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -409,12 +409,12 @@ pub(crate) fn avg_pool2d_backward( let kernel = AvgPool2dBackwardEagerKernel::::new(kernel_size, count_include_pad); Execution::start(kernel, x.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &grad.handle, &grad.strides, &grad.shape.dims, )]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/pool/base.rs b/crates/burn-jit/src/kernel/pool/base.rs index 51c315f0bf..c3d514182b 100644 --- a/crates/burn-jit/src/kernel/pool/base.rs +++ b/crates/burn-jit/src/kernel/pool/base.rs @@ -1,6 +1,6 @@ use std::fmt::Debug; -use burn_cube::ir::{Item, Scope, Variable}; +use cubecl::ir::{Item, Scope, Variable}; pub(crate) trait PoolStrategy: Send + Sync + 'static + Clone + Debug { type Accumulator: Copy; diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d.rs b/crates/burn-jit/src/kernel/pool/max_pool2d.rs index 6e2e0ae34d..fedd11f273 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d.rs @@ -1,6 +1,6 @@ -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, Item, Scope, Variable}, CubeCountSettings, Execution, }; @@ -21,10 +21,7 @@ impl PoolStrategy for MaxPool { fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator { let max_val = scope.create_local(item); - let max_initial = Variable::ConstantScalar { - value: E::minimum_value().to_f64(), - elem: item.elem(), - }; + let max_initial = item.elem().constant_from_f64(E::minimum_value().to_f64()); cpa!(scope, max_val = max_initial); max_val } @@ -70,10 +67,7 @@ impl PoolStrategy for MaxPoolWithIndices { fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator { let max_val = scope.create_local(item); - let max_initial = Variable::ConstantScalar { - value: E::minimum_value().to_f64(), - elem: item.elem(), - }; + let max_initial = item.elem().constant_from_f64(E::minimum_value().to_f64()); cpa!(scope, max_val = max_initial); let max_index = scope.create_local(Elem::UInt); (max_val, max_index) @@ -143,8 +137,12 @@ pub(crate) fn max_pool2d( let kernel = Pool2dEagerKernel::, R, E>::new(kernel_size, MaxPool::default()); Execution::start(kernel, x.client) - .inputs(&[TensorHandle::::new(&x.handle, &x.strides, &x.shape.dims)]) - .outputs(&[TensorHandle::new( + .inputs(&[TensorHandleRef::::new( + &x.handle, + &x.strides, + &x.shape.dims, + )]) + .outputs(&[TensorHandleRef::new( 
&output.handle, &output.strides, &output.shape.dims, @@ -196,10 +194,14 @@ pub(crate) fn max_pool2d_with_indices::new(&x.handle, &x.strides, &x.shape.dims)]) + .inputs(&[TensorHandleRef::::new( + &x.handle, + &x.strides, + &x.shape.dims, + )]) .outputs(&[ - TensorHandle::new(&output.handle, &output.strides, &output.shape.dims), - TensorHandle::new(&indices.handle, &indices.strides, &indices.shape.dims), + TensorHandleRef::new(&output.handle, &output.strides, &output.shape.dims), + TensorHandleRef::new(&indices.handle, &indices.strides, &indices.shape.dims), ]) .with_scalars(&[ stride[0] as i32, diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index 764700c7be..2310c01053 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -5,9 +5,9 @@ use crate::{ tensor::JitTensor, JitRuntime, }; -use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -353,10 +353,10 @@ pub(crate) fn max_pool2d_with_indices_backward::new(&indices.handle, &indices.strides, &indices.shape.dims), - TensorHandle::new(&grad.handle, &grad.strides, &grad.shape.dims), + TensorHandleRef::::new(&indices.handle, &indices.strides, &indices.shape.dims), + TensorHandleRef::new(&grad.handle, &grad.strides, &grad.shape.dims), ]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/pool/pool2d_shader.rs b/crates/burn-jit/src/kernel/pool/pool2d_shader.rs index 6a7dc62c90..71d7602e09 100644 --- a/crates/burn-jit/src/kernel/pool/pool2d_shader.rs +++ b/crates/burn-jit/src/kernel/pool/pool2d_shader.rs @@ -1,4 +1,4 @@ -use burn_cube::{ +use cubecl::{ cpa, ir::{Elem, IntKind, Item, KernelDefinition, Scope, Variable, Visibility}, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, diff --git a/crates/burn-jit/src/kernel/prng/base.rs b/crates/burn-jit/src/kernel/prng/base.rs index 6d15d1bdee..df82b3bc8b 100644 --- a/crates/burn-jit/src/kernel/prng/base.rs +++ b/crates/burn-jit/src/kernel/prng/base.rs @@ -1,4 +1,4 @@ -use burn_cube::{ +use cubecl::{ cpa, ir::{Elem, Scope, Variable}, prelude::*, @@ -31,7 +31,7 @@ pub(crate) fn random, R: JitRuntime, E: JitElement, const D: usize>( let seeds = get_seeds(); Execution::start(kernel, client) - .outputs(&[TensorHandle::::new( + .outputs(&[TensorHandleRef::::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/prng/bernoulli.rs b/crates/burn-jit/src/kernel/prng/bernoulli.rs index 4ba9ce867b..a89e891e24 100644 --- a/crates/burn-jit/src/kernel/prng/bernoulli.rs +++ b/crates/burn-jit/src/kernel/prng/bernoulli.rs @@ -1,8 +1,8 @@ -use burn_cube::{ +use burn_tensor::Shape; +use cubecl::{ cpa, ir::{Elem, FloatKind, Scope, Variable}, }; -use burn_tensor::Shape; use crate::{ kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2}, diff --git a/crates/burn-jit/src/kernel/prng/normal.rs b/crates/burn-jit/src/kernel/prng/normal.rs index cb2544a88e..836dca19e4 100644 --- a/crates/burn-jit/src/kernel/prng/normal.rs +++ b/crates/burn-jit/src/kernel/prng/normal.rs @@ -1,4 +1,4 @@ -use burn_cube::{ +use cubecl::{ cpa, ir::{Elem, FloatKind, Scope, Variable}, 
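The launches in these kernels all follow the same fluent shape: `Execution::start(kernel, client)` followed by `.inputs(..)`, `.outputs(..)`, and optionally `.with_scalars(..)` for things like strides and padding. A self-contained mock of that builder shape (the real cubecl builder also takes cube-count settings and typed tensor handles; everything below is a simplified stand-in):

```rust
/// Simplified mock of the fluent launch builder used above; the field types
/// and `execute` signature are illustrative, not the cubecl API.
struct Execution<'a> {
    kernel: &'a str,
    inputs: Vec<&'a [f32]>,
    outputs: Vec<usize>, // output buffer lengths, for the demo
    scalars: Vec<i32>,
}

impl<'a> Execution<'a> {
    fn start(kernel: &'a str) -> Self {
        Self { kernel, inputs: Vec::new(), outputs: Vec::new(), scalars: Vec::new() }
    }
    fn inputs(mut self, inputs: &[&'a [f32]]) -> Self {
        self.inputs.extend_from_slice(inputs);
        self
    }
    fn outputs(mut self, lens: &[usize]) -> Self {
        self.outputs.extend_from_slice(lens);
        self
    }
    fn with_scalars(mut self, scalars: &[i32]) -> Self {
        self.scalars.extend_from_slice(scalars);
        self
    }
    fn execute(self) {
        println!(
            "launch {} ({} inputs, {} outputs, {} scalars)",
            self.kernel, self.inputs.len(), self.outputs.len(), self.scalars.len()
        );
    }
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0, 4.0];
    Execution::start("max_pool2d")
        .inputs(&[&x])
        .outputs(&[4])
        .with_scalars(&[2, 2]) // e.g. stride values, as in the pooling launches above
        .execute();
}
```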
}; diff --git a/crates/burn-jit/src/kernel/prng/uniform.rs b/crates/burn-jit/src/kernel/prng/uniform.rs index 1716d18c6e..c0c7ebe5ad 100644 --- a/crates/burn-jit/src/kernel/prng/uniform.rs +++ b/crates/burn-jit/src/kernel/prng/uniform.rs @@ -1,8 +1,8 @@ -use burn_cube::{ +use burn_tensor::Shape; +use cubecl::{ cpa, ir::{Elem, FloatKind, Scope, Variable}, }; -use burn_tensor::Shape; use crate::{ kernel::prng::{cast_uint_to_float, lcg_step, taus_step_0, taus_step_1, taus_step_2}, diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs b/crates/burn-jit/src/kernel/reduce/naive/argmax.rs index ec5af005f4..e1402e91ac 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/argmax.rs @@ -1,5 +1,5 @@ use crate::{kernel::reduce::Argmax, JitElement}; -use burn_cube::{ +use cubecl::{ cpa, ir::{Elem, Item, Scope, Variable}, }; @@ -16,10 +16,9 @@ impl ReduceDimNaive for Argmax { ) -> Self::Accumulator { let index = scope.create_local(Elem::UInt); let max = scope.create_local(input_item); - let max_initial = Variable::ConstantScalar { - value: E::minimum_value().to_f64(), - elem: input_item.elem(), - }; + let max_initial = input_item + .elem() + .constant_from_f64(E::minimum_value().to_f64()); cpa!(scope, max = max_initial); (max, index) diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs b/crates/burn-jit/src/kernel/reduce/naive/argmin.rs index 44c1f24d63..6359272a9f 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/argmin.rs @@ -1,4 +1,4 @@ -use burn_cube::{ +use cubecl::{ cpa, ir::{Elem, Item, Scope, Variable}, }; @@ -17,10 +17,10 @@ impl ReduceDimNaive for Argmin { ) -> Self::Accumulator { let index = scope.create_local(Elem::UInt); let min = scope.create_local(input_item); - let min_initial = Variable::ConstantScalar { - value: E::maximum_value().to_f64(), - elem: input_item.elem(), - }; + let min_initial = input_item + .elem() + .constant_from_f64(E::maximum_value().to_f64()); + cpa!(scope, min = min_initial); (min, index) diff --git a/crates/burn-jit/src/kernel/reduce/naive/base.rs b/crates/burn-jit/src/kernel/reduce/naive/base.rs index 56a07d2c41..8850f4a079 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/base.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/base.rs @@ -1,4 +1,4 @@ -use burn_cube::ir::{Item, Scope, Variable}; +use cubecl::ir::{Item, Scope, Variable}; use crate::JitElement; diff --git a/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs index 83828dcf25..f255e6c61a 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs @@ -1,5 +1,5 @@ use crate::{kernel::reduce::MeanDim, JitElement}; -use burn_cube::{ +use cubecl::{ cpa, ir::{Item, Scope, Variable}, }; diff --git a/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs index 8b3ad3a86e..474e72b446 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs @@ -1,5 +1,5 @@ use crate::{kernel::reduce::ProdDim, JitElement}; -use burn_cube::{ +use cubecl::{ cpa, ir::{Item, Scope, Variable}, }; diff --git a/crates/burn-jit/src/kernel/reduce/naive/shader.rs b/crates/burn-jit/src/kernel/reduce/naive/shader.rs index ed70b0cbd7..a0a094204a 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/shader.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/shader.rs @@ -1,6 +1,6 @@ 
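The other change repeated across these kernels swaps the hand-assembled `Variable::ConstantScalar { value, elem }` literal for an `elem().constant_from_f64(..)` helper, so the element type owns the construction of typed constants. A sketch with mirror types (not cubecl's actual IR definitions):

```rust
/// Illustrative mirrors of the IR types; not the cubecl definitions.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Elem {
    F32,
    I32,
}

#[derive(Debug, PartialEq)]
enum Variable {
    ConstantScalar { value: f64, elem: Elem },
}

impl Elem {
    /// The helper the diff switches to: every call site asks the element
    /// type for a constant instead of assembling the enum variant by hand.
    fn constant_from_f64(self, value: f64) -> Variable {
        Variable::ConstantScalar { value, elem: self }
    }
}

fn main() {
    let min_f32 = f32::MIN as f64;
    // Before: Variable::ConstantScalar { value: min_f32, elem: Elem::F32 }
    // After:
    let max_initial = Elem::F32.constant_from_f64(min_f32);
    assert_eq!(
        max_initial,
        Variable::ConstantScalar { value: min_f32, elem: Elem::F32 }
    );
    // The same helper serves any element kind, e.g. an i32 zero.
    let zero = Elem::I32.constant_from_f64(0.0);
    assert_eq!(zero, Variable::ConstantScalar { value: 0.0, elem: Elem::I32 });
}
```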
-use burn_cube::{ +use cubecl::{ cpa, - frontend::TensorHandle, + frontend::TensorHandleRef, ir::{Elem, KernelDefinition, Scope, Variable, Visibility}, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, @@ -156,12 +156,12 @@ pub fn reduce_dim_naive< let kernel = NaiveReduceDimEagerKernel::::new(dim); Execution::start(kernel, input.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &input.handle, &input.strides, &input.shape.dims, )]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs index 8325b13330..71591478dd 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs @@ -1,5 +1,5 @@ use crate::{kernel::reduce::SumDim, JitElement}; -use burn_cube::{ +use cubecl::{ cpa, ir::{Item, Scope, Variable}, }; diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs index 910f98697c..edd9beef8a 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs @@ -1,5 +1,5 @@ use crate::{kernel::reduce::Argmax, JitElement}; -use burn_cube::{ +use cubecl::{ cpa, ir::{Elem, Item, Scope, Variable}, }; @@ -17,11 +17,9 @@ impl ReduceDimShared for Argmax { ) -> Self::Accumulator { let value_shared_memory = scope.create_shared(input_item, shared_memory_size); let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size); - - let max = Variable::ConstantScalar { - value: E::minimum_value().to_f64(), - elem: input_item.elem(), - }; + let max = input_item + .elem() + .constant_from_f64(E::minimum_value().to_f64()); cpa!(scope, value_shared_memory[write_position] = max); (value_shared_memory, index_shared_memory) } diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs index 3d17728400..e7a1f1694b 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs @@ -1,4 +1,4 @@ -use burn_cube::{ +use cubecl::{ cpa, ir::{Elem, Item, Scope, Variable}, }; @@ -18,11 +18,9 @@ impl ReduceDimShared for Argmin { ) -> Self::Accumulator { let value_shared_memory = scope.create_shared(input_item, shared_memory_size); let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size); - - let min = Variable::ConstantScalar { - value: E::maximum_value().to_f64(), - elem: input_item.elem(), - }; + let min = input_item + .elem() + .constant_from_f64(E::maximum_value().to_f64()); cpa!(scope, value_shared_memory[write_position] = min); (value_shared_memory, index_shared_memory) } diff --git a/crates/burn-jit/src/kernel/reduce/shared/base.rs b/crates/burn-jit/src/kernel/reduce/shared/base.rs index ddb6335c27..d62f169d13 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/base.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/base.rs @@ -1,4 +1,4 @@ -use burn_cube::ir::{Item, Scope, Variable}; +use cubecl::ir::{Item, Scope, Variable}; use crate::JitElement; diff --git a/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs index 6a229059a7..0339d9da43 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs @@ -1,5 +1,5 @@ use 
crate::{kernel::reduce::MeanDim, JitElement}; -use burn_cube::{ +use cubecl::{ cpa, ir::{Item, Scope, Variable}, }; diff --git a/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs index 638450cede..961e192a8b 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs @@ -1,5 +1,5 @@ use crate::{kernel::reduce::ProdDim, JitElement}; -use burn_cube::{ +use cubecl::{ cpa, ir::{Item, Scope, Variable}, }; diff --git a/crates/burn-jit/src/kernel/reduce/shared/shader.rs b/crates/burn-jit/src/kernel/reduce/shared/shader.rs index db80791e42..792dd6c51b 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/shader.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/shader.rs @@ -1,5 +1,5 @@ -use burn_cube::{ - cpa, frontend::TensorHandle, ir::KernelDefinition, prelude::CubeCount, CubeCountSettings, +use cubecl::{ + cpa, frontend::TensorHandleRef, ir::KernelDefinition, prelude::CubeCount, CubeCountSettings, Execution, InputInfo, KernelExpansion, KernelIntegrator, KernelSettings, OutputInfo, }; use std::marker::PhantomData; @@ -10,7 +10,7 @@ use crate::{ tensor::JitTensor, JitRuntime, }; -use burn_cube::ir::{Branch, CubeDim, Elem, Scope, Synchronization, Variable, Visibility}; +use cubecl::ir::{Branch, CubeDim, Elem, Scope, Synchronization, Variable, Visibility}; use super::base::ReduceDimShared; @@ -269,12 +269,12 @@ pub fn reduce_dim_shared< ); Execution::start(kernel, input.client) - .inputs(&[TensorHandle::::new( + .inputs(&[TensorHandleRef::::new( &input.handle, &input.strides, &input.shape.dims, )]) - .outputs(&[TensorHandle::new( + .outputs(&[TensorHandleRef::new( &output.handle, &output.strides, &output.shape.dims, diff --git a/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs index ad850f19c4..db85b09cb7 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs @@ -1,5 +1,5 @@ use crate::{kernel::reduce::SumDim, JitElement}; -use burn_cube::{ +use cubecl::{ cpa, ir::{Item, Scope, Variable}, }; diff --git a/crates/burn-jit/src/kernel/reduce/tune/base.rs b/crates/burn-jit/src/kernel/reduce/tune/base.rs index 14006b1927..7c9077dd81 100644 --- a/crates/burn-jit/src/kernel/reduce/tune/base.rs +++ b/crates/burn-jit/src/kernel/reduce/tune/base.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; -use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet}; use burn_tensor::{Element, ElementConversion}; +use cubecl::tune::{local_tuner, AutotuneOperation, AutotuneOperationSet, LocalTuner}; use crate::{ element::JitElement, @@ -15,7 +15,7 @@ use crate::{ ops::numeric::empty_device, tensor::JitTensor, tune_key::JitAutotuneKey, - JitRuntime, + JitRuntime, JitTuneId, }; use super::ReduceAutotuneKey; @@ -120,6 +120,7 @@ pub(crate) fn reduce_dim_autotune< let client = input.client.clone(); let output = init_reduce_output(&input, reduce_dim); + let id = JitTuneId::new::(&input.device); let operation_set = Box::new(ReduceDimAutotuneOperationSet::::new( input, @@ -127,7 +128,9 @@ pub(crate) fn reduce_dim_autotune< reduce_dim, )); - client.autotune_execute(operation_set); + static TUNER: LocalTuner = local_tuner!(); + + TUNER.execute(&id, &client, operation_set); output } diff --git a/crates/burn-jit/src/kernel/unary.rs b/crates/burn-jit/src/kernel/unary.rs index b2ebf220f1..0b5fa179ad 100644 --- a/crates/burn-jit/src/kernel/unary.rs +++ 
b/crates/burn-jit/src/kernel/unary.rs @@ -1,5 +1,5 @@ use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use burn_cube::{ +use cubecl::{ calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, unexpanded, SUBCUBE_DIM_APPROX, }; @@ -79,7 +79,7 @@ where if tensor.can_mut() && is_contiguous { unary_kernel::launch::( - client, + &client, cube_count, CubeDim::default(), TensorArg::vectorized( @@ -105,7 +105,7 @@ where ); unary_kernel::launch::( - client, + &client, cube_count, CubeDim::default(), TensorArg::vectorized( diff --git a/crates/burn-jit/src/lib.rs b/crates/burn-jit/src/lib.rs index 450454b071..070e522e01 100644 --- a/crates/burn-jit/src/lib.rs +++ b/crates/burn-jit/src/lib.rs @@ -18,7 +18,8 @@ pub(crate) mod tune; /// Elements for JIT backend pub mod element; -use burn_cube::{ +use burn_tensor::backend::{DeviceId, DeviceOps}; +use cubecl::{ compute::{CubeCount, CubeTask}, Runtime, }; @@ -45,12 +46,37 @@ pub mod tests; /// Just-in-Time runtime extending the [cube runtime](Runtime). pub trait JitRuntime: Runtime { - /// The device that should also implement [DeviceOps](burn_tensor::backend::DeviceOps). + /// The device that should also implement [burn_tensor::backend::DeviceOps]. type JitDevice: burn_tensor::backend::DeviceOps; /// The cube server with the [JitAutotuneKey]. - type JitServer: burn_compute::server::ComputeServer< - AutotuneKey = JitAutotuneKey, + type JitServer: cubecl::server::ComputeServer< Kernel = Box, DispatchOptions = CubeCount, >; } + +/// ID used to identify a Just-in-Time environment. +#[derive(Hash, PartialEq, Eq, Debug, Clone)] +pub struct JitTuneId { + device: DeviceId, + name: &'static str, +} + +impl JitTuneId { + /// Create a new ID. + pub fn new(device: &R::Device) -> Self { + Self { + device: DeviceOps::id(device), + name: R::name(), + } + } +} + +impl core::fmt::Display for JitTuneId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!( + "device-{}-{}-{}", + self.device.type_id, self.device.index_id, self.name + )) + } +} diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index f1c9ba1160..ae4aca6f0d 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -1,6 +1,6 @@ use crate::{element::JitElement, kernel, tensor::JitTensor, JitRuntime}; -use burn_cube::CubeElement; use burn_tensor::{Shape, TensorData}; +use cubecl::CubeElement; use std::marker::PhantomData; pub(crate) fn from_data( diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 128c1195dc..def5282950 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -4,10 +4,10 @@ use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; use crate::kernel::{self, launch_unary, reduce, unary_op, UnaryOp}; use crate::JitBackend; use crate::{FloatElement, IntElement, JitRuntime}; -use burn_cube::prelude::*; use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor}; use burn_tensor::ElementConversion; use burn_tensor::{ops::FloatTensorOps, Distribution, Shape, TensorData}; +use cubecl::prelude::*; use std::ops::Range; impl FloatTensorOps for JitBackend diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index e5fdc50ecb..f36b0e5ecb 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -2,10 +2,10 @@ use super::{expand, numeric, permute}; use crate::kernel::prng::{random_bernoulli, 
random_normal, random_uniform}; use crate::kernel::{launch_unary, unary_op, UnaryOp}; use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime}; -use burn_cube::frontend::Numeric; -use burn_cube::prelude::*; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Shape, TensorData}; +use cubecl::frontend::Numeric; +use cubecl::prelude::*; use std::ops::Range; impl IntTensorOps for JitBackend diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index a334ad9c65..11807ebfab 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -1,11 +1,11 @@ use crate::kernel::{launch_unary, unary_op, UnaryOp}; use crate::{binary, JitRuntime}; use crate::{element::JitElement, tensor::JitTensor}; -use burn_compute::client::ComputeClient; -use burn_cube::ir::{BinaryOperator, Elem, Operator, Scope, Variable}; -use burn_cube::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX}; -use burn_cube::{tensor_vectorization_factor, Runtime}; use burn_tensor::{ElementConversion, Shape}; +use cubecl::client::ComputeClient; +use cubecl::ir::{BinaryOperator, Elem, Operator, Scope, Variable}; +use cubecl::{calculate_cube_count_elemwise, prelude::*, SUBCUBE_DIM_APPROX}; +use cubecl::{tensor_vectorization_factor, Runtime}; pub fn full( shape: Shape, @@ -43,7 +43,7 @@ pub fn full_device( ); full_kernel::launch::( - empty.client.clone(), + &empty.client, cube_count, CubeDim::default(), TensorArg::vectorized( diff --git a/crates/burn-jit/src/template/base.rs b/crates/burn-jit/src/template/base.rs index e94515af0b..e7ad950225 100644 --- a/crates/burn-jit/src/template/base.rs +++ b/crates/burn-jit/src/template/base.rs @@ -1,5 +1,5 @@ use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use burn_cube::prelude::*; +use cubecl::prelude::*; use super::SourceTemplate; @@ -35,12 +35,12 @@ impl CubeTask for SourceKernel { /// Generates kernel source code by replacing some information using templating. #[macro_export] -macro_rules! kernel_wgsl { +macro_rules! kernel_source { ( $struct:ident, $file:expr ) => { - /// Generated kernel from wgsl file. + /// Generated kernel from a source file. #[derive(new)] pub struct $struct; diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index 8fb678c282..4f6ac226a4 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -1,15 +1,14 @@ use crate::element::JitElement; use crate::kernel::{launch_unary, unary_op, UnaryOp}; use crate::JitRuntime; -use burn_compute::client::ComputeClient; -use burn_compute::server::Handle; -use burn_cube::frontend::Numeric; -use burn_cube::prelude::*; use burn_tensor::Shape; +use cubecl::client::ComputeClient; +use cubecl::frontend::Numeric; +use cubecl::linalg::tensor::{matrix_layout, MatrixLayout, TensorHandle}; +use cubecl::prelude::{TensorHandleRef, *}; +use cubecl::server::Handle; use std::marker::PhantomData; -use super::layout::{memory_layout, MatrixLayout}; - /// The basic tensor primitive struct. #[derive(new)] pub struct JitTensor @@ -30,6 +29,14 @@ where pub(crate) elem: PhantomData, } +impl From> + for TensorHandle +{ + fn from(val: JitTensor) -> Self { + TensorHandle::new(val.shape.dims.to_vec(), val.strides.to_vec(), val.handle) + } +} + impl core::fmt::Debug for JitTensor where R: JitRuntime, @@ -121,6 +128,15 @@ where } } + /// Return the reference to a tensor handle. 
+ pub fn as_handle_ref(&self) -> TensorHandleRef<'_, R> { + TensorHandleRef { + handle: &self.handle, + strides: &self.strides, + shape: &self.shape.dims, + } + } + pub(crate) fn can_mut_broadcast(&self, rhs: &Self) -> bool { if !self.handle.can_mut() { return false; @@ -171,6 +187,6 @@ where } pub(crate) fn matrix_layout(&self) -> MatrixLayout { - memory_layout(&self.strides) + matrix_layout(&self.strides) } } diff --git a/crates/burn-jit/src/tensor/layout.rs b/crates/burn-jit/src/tensor/layout.rs deleted file mode 100644 index 52b6b65166..0000000000 --- a/crates/burn-jit/src/tensor/layout.rs +++ /dev/null @@ -1,123 +0,0 @@ -#[derive(PartialEq, Eq, Debug)] -/// Layout for matrix tensors, i.e. tensors whose interpretation -/// is a bunch of batched matrices of 2 dimensions -pub(crate) enum MatrixLayout { - /// Memory is wholly contiguous, with row major layout - Contiguous, - /// Permutations happened, but may not impact some kernels - MildlyPermuted { - /// Last two dims are inverted - transposed: bool, - /// Some permutations exist in batch dimensions - batch_swap: bool, - }, - /// Permutations happened between batch dimensions and last two dims - HighlyPermuted, -} - -pub(crate) fn memory_layout(strides: &[usize; D]) -> MatrixLayout { - if D <= 1 { - return MatrixLayout::Contiguous; - } - - let mut transposed = false; - let mut batch_swap = false; - let row_stride = strides[D - 2]; - let col_stride = strides[D - 1]; - if row_stride < col_stride { - transposed = true; - } - let mut previous_stride = row_stride; - - for d in 0..D - 2 { - let current_stride = strides[D - 3 - d]; - if current_stride < row_stride || current_stride < col_stride { - return MatrixLayout::HighlyPermuted; - } - if current_stride < previous_stride { - batch_swap = true; - } - - previous_stride = current_stride; - } - - if transposed || batch_swap { - MatrixLayout::MildlyPermuted { - transposed, - batch_swap, - } - } else { - MatrixLayout::Contiguous - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn layout_is_contiguous() { - let strides = &[8, 4, 2, 1]; - assert_eq!(memory_layout(strides), MatrixLayout::Contiguous); - } - - #[test] - fn vector_is_contiguous() { - let strides = &[1]; - assert_eq!(memory_layout(strides), MatrixLayout::Contiguous) - } - - #[test] - fn layout_is_transposed_only() { - let strides = &[8, 4, 1, 2]; - if let MatrixLayout::MildlyPermuted { - transposed, - batch_swap, - } = memory_layout(strides) - { - assert!(transposed && !batch_swap); - } else { - unreachable!() - } - } - - #[test] - fn layout_has_swapped_batches_only() { - let strides = &[4, 8, 2, 1]; - if let MatrixLayout::MildlyPermuted { - transposed, - batch_swap, - } = memory_layout(strides) - { - assert!(!transposed && batch_swap); - } else { - unreachable!() - } - } - - #[test] - fn layout_has_swapped_batches_and_is_transposed() { - let strides = &[4, 8, 1, 2]; - if let MatrixLayout::MildlyPermuted { - transposed, - batch_swap, - } = memory_layout(strides) - { - assert!(transposed && batch_swap); - } else { - unreachable!() - } - } - - #[test] - fn layout_has_batch_swapped_with_row() { - let strides = &[8, 2, 4, 1]; - assert_eq!(memory_layout(strides), MatrixLayout::HighlyPermuted); - } - - #[test] - fn layout_has_batch_swapped_with_col() { - let strides = &[1, 4, 2, 8]; - assert_eq!(memory_layout(strides), MatrixLayout::HighlyPermuted); - } -} diff --git a/crates/burn-jit/src/tensor/mod.rs b/crates/burn-jit/src/tensor/mod.rs index d484b81985..960a77e445 100644 --- a/crates/burn-jit/src/tensor/mod.rs +++ 
b/crates/burn-jit/src/tensor/mod.rs @@ -1,7 +1,5 @@ mod base; -mod layout; mod qtensor; pub use base::*; -pub(crate) use layout::*; pub(crate) use qtensor::*; diff --git a/crates/burn-jit/src/tests/matmul.rs b/crates/burn-jit/src/tests/matmul.rs index 65ea81e56d..d0f373cc1f 100644 --- a/crates/burn-jit/src/tests/matmul.rs +++ b/crates/burn-jit/src/tests/matmul.rs @@ -1,7 +1,7 @@ #[burn_tensor_testgen::testgen(matmul)] mod tests { use super::*; - use burn_jit::kernel::matmul::{matmul, MatmulStrategy, Tiling2dConfig}; + use burn_jit::kernel::matmul::{matmul, MatmulStrategy}; use burn_tensor::{Shape, Tensor, TensorPrimitive}; mod simple { @@ -110,493 +110,6 @@ mod tests { } } - mod tiling2d_padded { - use super::*; - - #[test] - pub fn straightforward() { - test_with_params(1, 2, 1, 1, 1); - } - - #[test] - pub fn shapes_smaller_than_blocks() { - test_with_params(8, 8, 8, 1, 1); - } - - #[test] - pub fn n_smaller_than_m() { - test_with_params(8, 8, 3, 1, 1); - } - - #[test] - pub fn m_smaller_than_n() { - test_with_params(3, 8, 8, 1, 1); - } - - #[test] - pub fn k_smaller_than_m_n() { - test_with_params(8, 3, 8, 1, 1); - } - - #[test] - pub fn k_larger_than_m_n() { - test_with_params(8, 48, 8, 1, 1); - } - - #[test] - pub fn multibatch_1_dim() { - test_with_params(8, 8, 8, 3, 1); - } - - #[test] - pub fn multibatch_2_dims() { - test_with_params(8, 8, 8, 3, 4); - } - - #[test] - pub fn blocks_divide_shapes_unevenly() { - test_with_params(7, 7, 7, 1, 1); - } - - #[test] - pub fn medium() { - test_with_params(17, 16, 16, 1, 1); - } - - #[test] - pub fn large() { - test_with_params(134, 242, 250, 1, 1); - } - - #[test] - fn swapped_batches_no_padding() { - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), - swap, - swap, - shape_lhs, - shape_rhs, - ); - } - - #[test] - fn swapped_row_col_no_padding() { - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), - swap_lhs, - swap_rhs, - shape_lhs, - shape_rhs, - ); - } - - #[test] - fn swapped_row_with_batch_no_padding() { - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), - swap_lhs, - swap_rhs, - shape_lhs, - shape_rhs, - ); - } - - #[test] - fn stable_test() { - let ref_tensor_device = Default::default(); - let x = - ReferenceTensor::<2>::from_floats([[0., 1., 2.], [3., 4., 5.]], &ref_tensor_device); - let y = - ReferenceTensor::from_floats([[0., 1.], [2., 3.], [4., 5.]], &ref_tensor_device); - - let test_tensor_device = Default::default(); - let x_jit = TestTensor::from_data(x.to_data(), &test_tensor_device); - let y_jit = TestTensor::from_data(y.to_data(), &test_tensor_device); - - let z_reference = x.matmul(y); - let z = Tensor::::from_primitive(TensorPrimitive::Float(matmul( - x_jit.into_primitive().tensor(), - y_jit.into_primitive().tensor(), - MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), - ))); - - z_reference.into_data().assert_approx_eq(&z.into_data(), 3); - } - - #[test] - fn stable_test_2() { - let ref_tensor_device = Default::default(); - let x = ReferenceTensor::<2>::from_floats( - [[0., 1.], [2., 3.], [4., 5.]], - &ref_tensor_device, - ); - let y = ReferenceTensor::from_floats([[0., 1., 
2.], [3., 4., 5.]], &ref_tensor_device); - - let test_tensor_device = Default::default(); - let x_jit = TestTensor::from_data(x.to_data(), &test_tensor_device); - let y_jit = TestTensor::from_data(y.to_data(), &test_tensor_device); - - let z_reference = x.matmul(y); - let z = Tensor::::from_primitive(TensorPrimitive::Float(matmul( - x_jit.into_primitive().tensor(), - y_jit.into_primitive().tensor(), - MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), - ))); - - z_reference.into_data().assert_approx_eq(&z.into_data(), 3); - } - - fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference( - MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()), - shape_lhs, - shape_rhs, - ); - } - } - - mod tiling2d { - use super::*; - - #[test] - pub fn straightforward() { - test_with_params(1, 2, 1, 1, 1); - } - - #[test] - pub fn shapes_smaller_than_blocks() { - test_with_params(8, 8, 8, 1, 1); - } - - #[test] - pub fn shapes_equal_blocks() { - test_with_params(64, 32, 64, 2, 2); - } - - #[test] - pub fn m_exceeds_block() { - test_with_params(75, 32, 64, 2, 2); - } - - #[test] - pub fn k_exceeds_block() { - test_with_params(64, 33, 32, 1, 1); - } - - #[test] - pub fn test_matmul_irregular_shape() { - test_with_params(123, 255, 72, 3, 5); - } - - #[test] - pub fn test64_matmul_unpadded_n_exceeds_block() { - test_with_params(64, 32, 75, 2, 2); - } - - #[test] - pub fn n_smaller_than_m() { - test_with_params(8, 8, 3, 1, 1); - } - - #[test] - pub fn m_smaller_than_n() { - test_with_params(3, 8, 8, 1, 1); - } - - #[test] - pub fn k_smaller_than_m_n() { - test_with_params(8, 3, 8, 1, 1); - } - - #[test] - pub fn k_larger_than_m_n() { - test_with_params(8, 48, 8, 1, 1); - } - - #[test] - pub fn multibatch_1_dim() { - test_with_params(8, 8, 8, 3, 1); - } - - #[test] - pub fn multibatch_2_dims() { - test_with_params(8, 8, 8, 3, 4); - } - - #[test] - pub fn blocks_divide_shapes_unevenly() { - test_with_params(7, 7, 7, 1, 1); - } - - #[test] - pub fn medium() { - test_with_params(17, 16, 16, 1, 1); - } - - #[test] - pub fn large() { - test_with_params(134, 242, 250, 1, 1); - } - - #[test] - fn swapped_batches_no_padding() { - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2d(Tiling2dConfig::default()), - swap, - swap, - shape_lhs, - shape_rhs, - ); - } - - #[test] - fn swapped_row_col_no_padding() { - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2d((Tiling2dConfig::default())), - swap_lhs, - swap_rhs, - shape_lhs, - shape_rhs, - ); - } - - #[test] - fn swapped_row_with_batch_no_padding() { - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2d(Tiling2dConfig::default()), - swap_lhs, - swap_rhs, - shape_lhs, - shape_rhs, - ); - } - - fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference( - MatmulStrategy::Tiling2d(Tiling2dConfig::default()), - shape_lhs, - shape_rhs, - ); - } - } - - mod tiling2d_cube { - use super::*; - - #[test] - pub fn straightforward() { - test_with_params(1, 2, 1, 1, 1); - } - - 
#[test] - pub fn shapes_smaller_than_blocks() { - test_with_params(8, 8, 8, 1, 1); - } - - #[test] - pub fn shapes_equal_blocks() { - test_with_params(64, 32, 64, 2, 2); - } - - #[test] - pub fn m_exceeds_block() { - test_with_params(75, 32, 64, 2, 2); - } - - #[test] - pub fn k_exceeds_block() { - test_with_params(64, 33, 32, 1, 1); - } - - #[test] - pub fn test_matmul_irregular_shape() { - test_with_params(123, 255, 72, 3, 5); - } - - #[test] - pub fn test64_matmul_unpadded_n_exceeds_block() { - test_with_params(64, 32, 75, 2, 2); - } - - #[test] - pub fn n_smaller_than_m() { - test_with_params(8, 8, 3, 1, 1); - } - - #[test] - pub fn m_smaller_than_n() { - test_with_params(3, 8, 8, 1, 1); - } - - #[test] - pub fn k_smaller_than_m_n() { - test_with_params(8, 3, 8, 1, 1); - } - - #[test] - pub fn k_larger_than_m_n() { - test_with_params(8, 48, 8, 1, 1); - } - - #[test] - pub fn multibatch_1_dim() { - test_with_params(8, 8, 8, 3, 1); - } - - #[test] - pub fn multibatch_2_dims() { - test_with_params(8, 8, 8, 3, 4); - } - - #[test] - pub fn blocks_divide_shapes_unevenly() { - test_with_params(7, 7, 7, 1, 1); - } - - #[test] - pub fn medium() { - test_with_params(17, 16, 16, 1, 1); - } - - #[test] - pub fn large() { - test_with_params(256, 256, 256, 1, 1); - } - - #[test] - pub fn use_vec2() { - test_with_params(2, 2, 2, 1, 1); - } - - #[test] - fn swapped_batches_no_padding() { - let swap = [0, 1]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()), - swap, - swap, - shape_lhs, - shape_rhs, - ); - } - - #[test] - fn swapped_row_col_no_padding() { - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, 4, 4]; - let shape_rhs = [3, 2, 4, 4]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dCube((Tiling2dConfig::default())), - swap_lhs, - swap_rhs, - shape_lhs, - shape_rhs, - ); - } - - #[test] - fn swapped_lhs_row_col_large_uneven_m() { - let (m, k, n) = (252, 256, 256); - let swap_lhs = [2, 3]; - let swap_rhs = [0, 0]; - let shape_lhs = [3, 2, k, m]; - let shape_rhs = [3, 2, k, n]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dCube((Tiling2dConfig::default())), - swap_lhs, - swap_rhs, - shape_lhs, - shape_rhs, - ); - } - - #[test] - fn swapped_rhs_row_col_large_uneven_n() { - let (m, k, n) = (256, 256, 252); - let swap_lhs = [0, 0]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, m, k]; - let shape_rhs = [3, 2, n, k]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dCube((Tiling2dConfig::default())), - swap_lhs, - swap_rhs, - shape_lhs, - shape_rhs, - ); - } - - #[test] - fn swapped_both_row_col_large_uneven_k() { - let (m, k, n) = (256, 252, 256); - let swap_lhs = [2, 3]; - let swap_rhs = [2, 3]; - let shape_lhs = [3, 2, k, m]; - let shape_rhs = [3, 2, n, k]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dCube((Tiling2dConfig::default())), - swap_lhs, - swap_rhs, - shape_lhs, - shape_rhs, - ); - } - - #[test] - fn swapped_row_with_batch_no_padding() { - let swap_lhs = [0, 3]; - let swap_rhs = [0, 2]; - let shape_lhs = [4, 4, 4, 4]; - let shape_rhs = [4, 4, 4, 4]; - same_as_reference_swapped_dims( - MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()), - swap_lhs, - swap_rhs, - shape_lhs, - shape_rhs, - ); - } - - fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) { - let shape_lhs = [batch_1, batch_2, m, k]; - let shape_rhs = [batch_1, batch_2, k, n]; - same_as_reference( - 
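The removed tiling2d test modules all funnel into a `same_as_reference` helper that runs the strategy under test and compares the result against a reference backend with approximate equality. A pure-CPU sketch of that harness pattern (a naive 2x2 matmul stands in for the GPU kernel; names are illustrative):

```rust
/// Naive reference matmul for 2x2 row-major matrices.
fn matmul_reference(a: &[f32; 4], b: &[f32; 4]) -> [f32; 4] {
    [
        a[0] * b[0] + a[1] * b[2], a[0] * b[1] + a[1] * b[3],
        a[2] * b[0] + a[3] * b[2], a[2] * b[1] + a[3] * b[3],
    ]
}

/// Stand-in for the strategy under test; in the real suite this would be a
/// tiled GPU kernel selected via `MatmulStrategy`.
fn matmul_under_test(a: &[f32; 4], b: &[f32; 4]) -> [f32; 4] {
    matmul_reference(a, b)
}

/// The pattern behind `same_as_reference`: run both paths on the same data
/// and assert approximate equality, since float kernels reorder operations.
fn same_as_reference(a: [f32; 4], b: [f32; 4]) {
    let expected = matmul_reference(&a, &b);
    let actual = matmul_under_test(&a, &b);
    for (e, r) in expected.iter().zip(actual.iter()) {
        assert!((e - r).abs() < 1e-3, "expected {e}, got {r}");
    }
}

fn main() {
    same_as_reference([0., 1., 2., 3.], [4., 5., 6., 7.]);
    println!("strategy matches reference");
}
```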
MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()), - shape_lhs, - shape_rhs, - ); - } - } - mod padding { use super::*; use burn_jit::kernel::matmul::padding::{crop, pad_round}; diff --git a/crates/burn-jit/src/tests/matmul_cube.rs b/crates/burn-jit/src/tests/matmul_cube.rs deleted file mode 100644 index 8cea6a3990..0000000000 --- a/crates/burn-jit/src/tests/matmul_cube.rs +++ /dev/null @@ -1,125 +0,0 @@ -#[burn_tensor_testgen::testgen(matmul_cube)] -mod tests { - use super::*; - use burn_jit::kernel::matmul::tiling2d_cube::{ - compute_loop_tests, load_shared_memory_tests, outer_product_tests, write_output_tests, - }; - use burn_jit::kernel::matmul::{matmul, MatmulStrategy, Tiling2dConfig}; - use burn_tensor::{Shape, Tensor}; - - #[test] - pub fn tiling2d_matmul_outer_product_vectorized_test() { - outer_product_tests::tile_outer_product_vectorized_unit_test::( - &Default::default(), - ) - } - - #[test] - pub fn tiling2d_matmul_outer_product_vectorized_test_2() { - outer_product_tests::tile_outer_product_vectorized_unit_test_2::( - &Default::default(), - ) - } - - #[test] - pub fn tiling2d_matmul_compute_loop_vectorized_test() { - compute_loop_tests::compute_loop_unit_test::(&Default::default()) - } - - #[test] - pub fn compute_loop_unit_offset_test() { - compute_loop_tests::compute_loop_unit_offset_test::(&Default::default()) - } - - #[test] - pub fn load_lhs_transposed_unit_test() { - load_shared_memory_tests::load_lhs_transposed_unit_test::(&Default::default()) - } - - #[test] - pub fn load_lhs_transposed_cube_test() { - load_shared_memory_tests::load_lhs_transposed_cube_test::(&Default::default()) - } - - #[test] - pub fn load_lhs_plain_unit_test() { - load_shared_memory_tests::load_lhs_plain_unit_test::(&Default::default()) - } - - #[test] - pub fn load_lhs_plain_out_of_bounds_unit_test() { - load_shared_memory_tests::load_lhs_plain_out_of_bounds_unit_test::( - &Default::default(), - ) - } - - #[test] - pub fn load_lhs_transposed_out_of_bounds_cube_test() { - load_shared_memory_tests::load_lhs_transposed_out_of_bounds_cube_test::( - &Default::default(), - ) - } - - #[test] - pub fn load_lhs_transposed_offset_cube_test() { - load_shared_memory_tests::load_lhs_transposed_offset_cube_test::( - &Default::default(), - ) - } - - #[test] - pub fn load_rhs_plain_unit_test() { - load_shared_memory_tests::load_rhs_plain_unit_test::(&Default::default()) - } - - #[test] - pub fn load_rhs_plain_cube_test() { - load_shared_memory_tests::load_rhs_plain_cube_test::(&Default::default()) - } - - #[test] - pub fn load_rhs_plain_cube_offset_test() { - load_shared_memory_tests::load_rhs_plain_cube_offset_test::(&Default::default()) - } - - #[test] - pub fn load_rhs_transposed_unit_test() { - load_shared_memory_tests::load_rhs_transposed_unit_test::(&Default::default()) - } - - #[test] - pub fn load_rhs_transposed_out_of_bounds_unit_test() { - load_shared_memory_tests::load_rhs_transposed_out_of_bounds_unit_test::( - &Default::default(), - ) - } - - #[test] - pub fn write_to_output_over_height_unit_test() { - write_output_tests::write_to_output_over_height_unit_test::(&Default::default()) - } - - #[test] - pub fn write_to_output_over_width_unit_test() { - write_output_tests::write_to_output_over_width_unit_test::(&Default::default()) - } - - #[test] - pub fn write_to_output_vectorized_less_than_tile_unit_test() { - write_output_tests::write_to_output_vectorized_less_than_tile_unit_test::( - &Default::default(), - ) - } - - #[test] - pub fn write_to_output_scalar_unit_test() { - 
write_output_tests::write_to_output_scalar_unit_test::(&Default::default()) - } - - #[test] - pub fn write_to_output_scalar_out_of_bounds_cube_test() { - write_output_tests::write_to_output_scalar_out_of_bounds_cube_test::( - &Default::default(), - ) - } -} diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index d6ec2edf17..a0bb1aa7cf 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -13,7 +13,6 @@ mod gather; mod mask_fill; mod mask_where; mod matmul; -pub mod matmul_cube; mod max_pool2d; mod max_pool2d_backward; mod normal; @@ -74,8 +73,6 @@ macro_rules! testgen_all { burn_jit::testgen_cat!(); burn_jit::testgen_clamp!(); burn_jit::testgen_unary!(); - burn_jit::testgen_matmul!(); - burn_jit::testgen_matmul_cube!(); } } mod jit_fusion { diff --git a/crates/burn-jit/src/tune_key.rs b/crates/burn-jit/src/tune_key.rs index 9513f02f5d..ebc1e9202c 100644 --- a/crates/burn-jit/src/tune_key.rs +++ b/crates/burn-jit/src/tune_key.rs @@ -1,5 +1,5 @@ use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey}; -use burn_compute::tune::AutotuneKey; +use cubecl::tune::AutotuneKey; use serde::{Deserialize, Serialize}; use std::fmt::Display; diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index c23bde5e1d..7bc5f681cf 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -17,10 +17,14 @@ experimental-named-tensor = [] export_tests = ["burn-tensor-testgen"] std = ["rand/std", "half/std", "num-traits/std", "burn-common/std", "burn-common/rayon"] repr = [] +cubecl = ["dep:cubecl"] +cubecl-wgpu = ["cubecl", "cubecl/wgpu"] +cubecl-cuda = ["cubecl", "cubecl/cuda"] [dependencies] burn-common = { path = "../burn-common", version = "0.14.0", default-features = false} burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.14.0", optional = true } +cubecl = { workspace = true, optional = true } derive-new = { workspace = true } half = { workspace = true, features = ["bytemuck"] } diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs index dade906ac3..3d5032b562 100644 --- a/crates/burn-tensor/src/lib.rs +++ b/crates/burn-tensor/src/lib.rs @@ -26,3 +26,65 @@ pub(crate) use tensor::check::macros::check; pub use tensor::*; pub use burn_common::reader::*; // Useful so that backends don't have to add `burn_common` as a dependency. 
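The `burn-tensor` additions that follow implement `DeviceOps` for cubecl's device types behind feature flags; the interesting detail is how a 64-bit wgpu device id is squeezed into burn's 32-bit `DeviceId`. A sketch of just that arithmetic (`DeviceId` is mirrored here for illustration):

```rust
/// Mirror of burn's `DeviceId`, a pair of 32-bit ids, for illustration only.
#[derive(Debug, PartialEq)]
struct DeviceId {
    type_id: u32,
    index_id: u32,
}

impl DeviceId {
    fn new(type_id: u32, index_id: u32) -> Self {
        Self { type_id, index_id }
    }
}

/// The wrap used below for `WgpuDevice::Existing`: a 64-bit wgpu id is
/// reduced modulo `u32::MAX`, accepting a small collision risk.
fn existing_device_id(wgpu_id: u64) -> DeviceId {
    DeviceId::new(5, (wgpu_id % (u32::MAX as u64)) as u32)
}

fn main() {
    // Small ids pass through unchanged...
    assert_eq!(existing_device_id(7), DeviceId::new(5, 7));
    // ...large ids wrap around modulo u32::MAX.
    let big = u32::MAX as u64 + 3;
    assert_eq!(existing_device_id(big), DeviceId::new(5, 3));
}
```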
+ +#[cfg(feature = "cubecl")] +mod cube { + use cubecl::ir::{Elem, FloatKind, IntKind}; + + impl From<crate::DType> for cubecl::ir::Elem { + fn from(dtype: crate::DType) -> Self { + match dtype { + crate::DType::F64 => Elem::Float(FloatKind::F64), + crate::DType::F32 => Elem::Float(FloatKind::F32), + crate::DType::F16 => Elem::Float(FloatKind::F16), + crate::DType::BF16 => Elem::Float(FloatKind::BF16), + crate::DType::I64 => Elem::Int(IntKind::I64), + crate::DType::I32 => Elem::Int(IntKind::I32), + crate::DType::I16 => panic!("i16 isn't supported yet."), + crate::DType::I8 => panic!("i8 isn't supported yet."), + crate::DType::U64 => Elem::UInt, + crate::DType::U32 => Elem::UInt, + crate::DType::U8 => panic!("u8 isn't supported yet."), + crate::DType::Bool => Elem::Bool, + crate::DType::QFloat(_) => panic!("quantized type is not supported yet."), + } + } + } +} + +#[cfg(feature = "cubecl-wgpu")] +mod cube_wgpu { + use crate::backend::{DeviceId, DeviceOps}; + use cubecl::wgpu::WgpuDevice; + + impl DeviceOps for WgpuDevice { + fn id(&self) -> DeviceId { + match self { + WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32), + WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32), + WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32), + WgpuDevice::Cpu => DeviceId::new(3, 0), + WgpuDevice::BestAvailable => DeviceId::new(4, 0), + // For an existing device, use the 64 bit wgpu device ID as the burn DeviceID. + // We're only storing 32 bits, so wrap the 64 bit value to 32 bits. This + // might collide - but a 1 in 4 billion chance seems ok given there's only a few + // devices in flight at any time. + WgpuDevice::Existing(id) => { + DeviceId::new(5, (id.inner() % (u32::MAX as u64)) as u32) + } + } + } + } +} + +#[cfg(feature = "cubecl-cuda")] +mod cube_cuda { + use crate::backend::{DeviceId, DeviceOps}; + use cubecl::cuda::CudaDevice; + + impl DeviceOps for CudaDevice { + fn id(&self) -> DeviceId { + DeviceId::new(0, self.index as u32) + } + } +} diff --git a/crates/burn-tensor/src/tensor/backend/device.rs b/crates/burn-tensor/src/tensor/backend/device.rs index 64279f5b01..f075f56767 100644 --- a/crates/burn-tensor/src/tensor/backend/device.rs +++ b/crates/burn-tensor/src/tensor/backend/device.rs @@ -12,3 +12,9 @@ pub trait DeviceOps: Clone + Default + PartialEq + Send + Sync + core::fmt::Debu /// Return the [device id](DeviceId). fn id(&self) -> DeviceId; } + +impl core::fmt::Display for DeviceId { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_fmt(format_args!("{:?}", self)) + } +} diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index 5accb02837..0df271431c 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -59,9 +59,21 @@ impl TensorData { } /// Initializes a new tensor data structure from the provided values. 
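The `TensorData::init` hunk just below swaps a `bytemuck` copy for a zero-copy reinterpretation of `Vec<E>` as `Vec<u8>`, scaling both length and capacity by the element size. A self-contained sketch of the same trick; it deliberately uses `[u8; 4]` as the element type so the reinterpreted allocation keeps alignment 1 and the conversion stays sound:

```rust
/// Zero-copy reinterpretation of `Vec<[u8; 4]>` as `Vec<u8>`.
/// Length and capacity are scaled by the element size; the source vector is
/// forgotten so the allocation is not freed twice.
fn into_bytes(mut values: Vec<[u8; 4]>) -> Vec<u8> {
    let factor = core::mem::size_of::<[u8; 4]>(); // 4 bytes per element
    let len = values.len() * factor;
    let capacity = values.capacity() * factor;
    let ptr = values.as_mut_ptr();
    core::mem::forget(values);
    // SAFETY: `ptr` came from a live Vec allocation; `len <= capacity`; the
    // element type has alignment 1 and no padding, so every byte is
    // initialized and the allocator sees the same layout when freeing.
    unsafe { Vec::from_raw_parts(ptr as *mut u8, len, capacity) }
}

fn main() {
    let mut v: Vec<[u8; 4]> = Vec::with_capacity(5);
    v.push([1, 2, 3, 4]);
    v.push([5, 6, 7, 8]);
    let bytes = into_bytes(v);
    assert_eq!(bytes.len(), 2 * 4); // length scales by the element size
    assert_eq!(bytes.capacity(), 5 * 4); // spare capacity is preserved too
    assert_eq!(bytes, vec![1, 2, 3, 4, 5, 6, 7, 8]);
}
```

This is exactly the bookkeeping the new `should_convert_bytes_correctly` test in the hunk checks: both the logical length and the unused capacity survive the cast.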
- fn init>>(value: Vec, shape: S, dtype: DType) -> Self { + fn init>>(mut value: Vec, shape: S, dtype: DType) -> Self { + // Ensure `E` satisfies the `Pod` trait requirements + assert_eq!(core::mem::size_of::() % core::mem::size_of::(), 0); + + let factor = core::mem::size_of::() / core::mem::size_of::(); + let len = value.len() * factor; + let capacity = value.capacity() * factor; + let ptr = value.as_mut_ptr(); + + core::mem::forget(value); + + let bytes = unsafe { Vec::from_raw_parts(ptr as *mut u8, len, capacity) }; + Self { - bytes: bytemuck::checked::cast_slice(&value).to_vec(), + bytes, shape: shape.into(), dtype, } @@ -1050,4 +1062,16 @@ mod tests { data1.assert_approx_eq(&data2, 2); } + + #[test] + fn should_convert_bytes_correctly() { + let mut vector: Vec = Vec::with_capacity(5); + vector.push(2.0); + vector.push(3.0); + let data1 = TensorData::new(vector, vec![2]); + + let factor = core::mem::size_of::() / core::mem::size_of::(); + assert_eq!(data1.bytes.len(), 2 * factor); + assert_eq!(data1.bytes.capacity(), 5 * factor); + } } diff --git a/crates/burn-tensor/src/tests/quantization/scheme.rs b/crates/burn-tensor/src/tests/quantization/scheme.rs index 11b59ceb06..810c0a1ece 100644 --- a/crates/burn-tensor/src/tests/quantization/scheme.rs +++ b/crates/burn-tensor/src/tests/quantization/scheme.rs @@ -20,7 +20,7 @@ mod tests { qparams .scale .into_data() - .assert_approx_eq(&TensorData::from([0.009_019_608]), 9); + .assert_approx_eq(&TensorData::from([0.009_019_608]), 8); qparams .offset .unwrap() @@ -42,7 +42,7 @@ mod tests { qparams .scale .into_data() - .assert_approx_eq(&TensorData::from([0.014_173_228]), 9); + .assert_approx_eq(&TensorData::from([0.014_173_228]), 8); assert!(qparams.offset.is_none()); } } diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index 48dab9b8b0..c6a2b9a544 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -11,34 +11,21 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-wgpu" version.workspace = true [features] -default = ["fusion", "burn-jit/default"] +default = ["fusion", "burn-jit/default", "cubecl/default"] fusion = ["burn-fusion", "burn-jit/fusion"] autotune = ["burn-jit/autotune"] -template = ["burn-jit/template", "burn-cube/template"] +template = ["burn-jit/template", "cubecl/template"] doc = ["burn-jit/doc"] std = ["burn-jit/std"] [dependencies] +cubecl = { workspace = true, features = ["wgpu"] } + burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false } -burn-compute = { path = "../burn-compute", version = "0.14.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.14.0" } -burn-common = { path = "../burn-common", version = "0.14.0" } +burn-tensor = { path = "../burn-tensor", version = "0.14.0", features = ["cubecl-wgpu"] } burn-fusion = { path = "../burn-fusion", version = "0.14.0", optional = true } -burn-cube = { path = "../burn-cube", version = "0.14.0" } - -bytemuck = { workspace = true } -wgpu = { workspace = true, features = ["fragile-send-sync-non-atomic-wasm"] } -pollster = { workspace = true } - -log = { workspace = true } -async-channel = { workspace = true } -derive-new = { workspace = true } -hashbrown = { workspace = true } [dev-dependencies] burn-jit = { path = "../burn-jit", version = "0.14.0", default-features = false, features = [ "export_tests", ] } -burn-cube = { path = "../burn-cube", version = "0.14.0", features = [ - "export_tests", -] } diff --git a/crates/burn-wgpu/src/compiler/mod.rs 
b/crates/burn-wgpu/src/compiler/mod.rs deleted file mode 100644 index 4cd660d773..0000000000 --- a/crates/burn-wgpu/src/compiler/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod wgsl; diff --git a/crates/burn-wgpu/src/compiler/wgsl/base.rs b/crates/burn-wgpu/src/compiler/wgsl/base.rs deleted file mode 100644 index 3eda1c3d23..0000000000 --- a/crates/burn-wgpu/src/compiler/wgsl/base.rs +++ /dev/null @@ -1,307 +0,0 @@ -use burn_cube::ir as cube; -use std::fmt::Display; - -#[derive(Debug, Clone)] -pub enum Variable { - SubgroupSize, - GlobalInputArray(u16, Item), - GlobalOutputArray(u16, Item), - GlobalScalar(u16, Elem, cube::Elem), - ConstantScalar(f64, Elem), - Local { - id: u16, - item: Item, - depth: u8, - }, - Named { - name: String, - item: Item, - is_array: bool, - }, - Slice { - id: u16, - item: Item, - depth: u8, - }, - LocalScalar { - id: u16, - elem: Elem, - depth: u8, - }, - SharedMemory(u16, Item, u32), - LocalArray(u16, Item, u8, u32), - Id, - LocalInvocationIndex, - LocalInvocationIdX, - LocalInvocationIdY, - LocalInvocationIdZ, - Rank, - WorkgroupId, - WorkgroupIdX, - WorkgroupIdY, - WorkgroupIdZ, - GlobalInvocationIdX, - GlobalInvocationIdY, - GlobalInvocationIdZ, - WorkgroupSize, - WorkgroupSizeX, - WorkgroupSizeY, - WorkgroupSizeZ, - NumWorkgroups, - NumWorkgroupsX, - NumWorkgroupsY, - NumWorkgroupsZ, -} - -#[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum Elem { - F32, - I32, - U32, - Bool, -} - -#[derive(Debug, Clone, PartialEq, Eq, Copy)] -pub enum Item { - Vec4(Elem), - Vec3(Elem), - Vec2(Elem), - Scalar(Elem), -} - -#[derive(Debug, Clone)] -pub struct IndexedVariable { - var: Variable, - index: usize, -} - -impl Variable { - pub fn is_always_scalar(&self) -> bool { - match self { - Variable::GlobalScalar(_, _, _) => true, - Variable::ConstantScalar(_, _) => true, - Variable::LocalScalar { - id: _, - elem: _, - depth: _, - } => true, - Variable::Id => true, - Variable::LocalInvocationIndex => true, - Variable::LocalInvocationIdX => true, - Variable::LocalInvocationIdY => true, - Variable::LocalInvocationIdZ => true, - Variable::Rank => true, - Variable::GlobalInputArray(_, _) => false, - Variable::GlobalOutputArray(_, _) => false, - Variable::SharedMemory(_, _, _) => false, - Variable::LocalArray(_, _, _, _) => false, - Variable::Local { .. } => false, - Variable::Named { .. } => false, - Variable::Slice { .. } => false, - Variable::WorkgroupIdX => true, - Variable::WorkgroupIdY => true, - Variable::WorkgroupIdZ => true, - Variable::GlobalInvocationIdX => true, - Variable::GlobalInvocationIdY => true, - Variable::GlobalInvocationIdZ => true, - Variable::WorkgroupSizeX => true, - Variable::WorkgroupSizeY => true, - Variable::WorkgroupSizeZ => true, - Variable::NumWorkgroupsX => true, - Variable::NumWorkgroupsY => true, - Variable::NumWorkgroupsZ => true, - Variable::WorkgroupId => true, - Variable::WorkgroupSize => true, - Variable::NumWorkgroups => true, - Variable::SubgroupSize => true, - } - } - pub fn index(&self, index: usize) -> IndexedVariable { - IndexedVariable { - var: self.clone(), - index, - } - } - - pub fn item(&self) -> Item { - match self { - Self::GlobalInputArray(_, e) => *e, - Self::GlobalOutputArray(_, e) => *e, - Self::SharedMemory(_, e, _) => *e, - Self::LocalArray(_, e, _, _) => *e, - Self::Local { item, .. } => *item, - Self::Slice { item, .. } => *item, - Self::Named { item, .. 
} => *item, - Self::ConstantScalar(_, e) => Item::Scalar(*e), - Self::GlobalScalar(_, e, _) => Item::Scalar(*e), - Self::Id => Item::Scalar(Elem::U32), - Self::LocalInvocationIndex => Item::Scalar(Elem::U32), - Self::LocalInvocationIdX => Item::Scalar(Elem::U32), - Self::LocalInvocationIdY => Item::Scalar(Elem::U32), - Self::LocalInvocationIdZ => Item::Scalar(Elem::U32), - Self::Rank => Item::Scalar(Elem::U32), - Self::LocalScalar { elem, .. } => Item::Scalar(*elem), - Self::WorkgroupId => Item::Scalar(Elem::U32), - Self::WorkgroupIdX => Item::Scalar(Elem::U32), - Self::WorkgroupIdY => Item::Scalar(Elem::U32), - Self::WorkgroupIdZ => Item::Scalar(Elem::U32), - Self::GlobalInvocationIdX => Item::Scalar(Elem::U32), - Self::GlobalInvocationIdY => Item::Scalar(Elem::U32), - Self::GlobalInvocationIdZ => Item::Scalar(Elem::U32), - Self::WorkgroupSize => Item::Scalar(Elem::U32), - Self::WorkgroupSizeX => Item::Scalar(Elem::U32), - Self::WorkgroupSizeY => Item::Scalar(Elem::U32), - Self::WorkgroupSizeZ => Item::Scalar(Elem::U32), - Self::NumWorkgroups => Item::Scalar(Elem::U32), - Self::NumWorkgroupsX => Item::Scalar(Elem::U32), - Self::NumWorkgroupsY => Item::Scalar(Elem::U32), - Self::NumWorkgroupsZ => Item::Scalar(Elem::U32), - Self::SubgroupSize => Item::Scalar(Elem::U32), - } - } - pub fn elem(&self) -> Elem { - *self.item().elem() - } -} - -impl Item { - pub fn elem(&self) -> &Elem { - match self { - Item::Vec4(e) => e, - Item::Vec3(e) => e, - Item::Vec2(e) => e, - Item::Scalar(e) => e, - } - } - - pub fn vectorization_factor(&self) -> usize { - match self { - Item::Vec4(_) => 4, - Item::Vec3(_) => 3, - Item::Vec2(_) => 2, - Item::Scalar(_) => 1, - } - } -} - -impl Elem { - pub fn size(&self) -> usize { - match self { - Self::F32 => core::mem::size_of::(), - Self::I32 => core::mem::size_of::(), - Self::U32 => core::mem::size_of::(), - Self::Bool => core::mem::size_of::(), - } - } -} - -impl Display for Elem { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::F32 => f.write_str("f32"), - Self::I32 => f.write_str("i32"), - Self::U32 => f.write_str("u32"), - Self::Bool => f.write_str("bool"), - } - } -} - -impl Display for Item { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Item::Vec4(elem) => f.write_fmt(format_args!("vec4<{elem}>")), - Item::Vec3(elem) => f.write_fmt(format_args!("vec3<{elem}>")), - Item::Vec2(elem) => f.write_fmt(format_args!("vec2<{elem}>")), - Item::Scalar(elem) => f.write_fmt(format_args!("{elem}")), - } - } -} - -impl Display for Variable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Variable::GlobalInputArray(number, _) => { - f.write_fmt(format_args!("input_{number}_global")) - } - Variable::LocalScalar { - id: index, - elem: _, - depth: scope_depth, - } => f.write_fmt(format_args!("s_{scope_depth}_{index}")), - Variable::Local { - id: index, - item: _, - depth: scope_depth, - } => f.write_fmt(format_args!("l_{scope_depth}_{index}")), - Variable::Named { name, .. 
} => f.write_str(name), - Variable::Slice { - id: index, - item: _, - depth: scope_depth, - } => f.write_fmt(format_args!("slice_{scope_depth}_{index}")), - Variable::GlobalOutputArray(number, _) => { - f.write_fmt(format_args!("output_{number}_global")) - } - Variable::GlobalScalar(number, _, elem) => { - f.write_fmt(format_args!("scalars_{elem}[{number}]")) - } - Variable::ConstantScalar(number, elem) => match elem { - Elem::F32 => f.write_fmt(format_args!("{number}f")), - Elem::I32 => f.write_fmt(format_args!("{number}i")), - Elem::U32 => f.write_fmt(format_args!("{number}u")), - Elem::Bool => f.write_fmt(format_args!("bool({number})")), - }, - Variable::SharedMemory(number, _, _) => { - f.write_fmt(format_args!("shared_memory_{number}")) - } - Variable::LocalArray(number, _, scope_depth, _) => { - f.write_fmt(format_args!("a_{scope_depth}_{number}")) - } - Variable::Id => f.write_str("id"), - Variable::LocalInvocationIndex => f.write_str("local_idx"), - Variable::LocalInvocationIdX => f.write_str("local_invocation_id.x"), - Variable::LocalInvocationIdY => f.write_str("local_invocation_id.y"), - Variable::LocalInvocationIdZ => f.write_str("local_invocation_id.z"), - Variable::Rank => f.write_str("rank"), - Variable::WorkgroupId => f.write_str("workgroup_id_no_axis"), - Variable::WorkgroupIdX => f.write_str("workgroup_id.x"), - Variable::WorkgroupIdY => f.write_str("workgroup_id.y"), - Variable::WorkgroupIdZ => f.write_str("workgroup_id.z"), - Variable::GlobalInvocationIdX => f.write_str("global_id.x"), - Variable::GlobalInvocationIdY => f.write_str("global_id.y"), - Variable::GlobalInvocationIdZ => f.write_str("global_id.z"), - Variable::WorkgroupSizeX => f.write_str("WORKGROUP_SIZE_X"), - Variable::WorkgroupSizeY => f.write_str("WORKGROUP_SIZE_Y"), - Variable::WorkgroupSizeZ => f.write_str("WORKGROUP_SIZE_Z"), - Variable::NumWorkgroupsX => f.write_str("num_workgroups.x"), - Variable::NumWorkgroupsY => f.write_str("num_workgroups.y"), - Variable::NumWorkgroupsZ => f.write_str("num_workgroups.z"), - Variable::WorkgroupSize => f.write_str("workgroup_size_no_axis"), - Variable::NumWorkgroups => f.write_str("num_workgroups_no_axis"), - Variable::SubgroupSize => f.write_str("subgroup_size"), - } - } -} - -impl Display for IndexedVariable { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let should_index = |item: &Item| match item { - Item::Vec4(_) => true, - Item::Vec3(_) => true, - Item::Vec2(_) => true, - Item::Scalar(_) => false, - }; - - let var = &self.var; - let item = var.item(); - let index = self.index; - - match self.var { - Variable::GlobalScalar(_, _, _) => f.write_fmt(format_args!("{var}")), - _ => match should_index(&item) { - true => f.write_fmt(format_args!("{var}[{index}]")), - false => f.write_fmt(format_args!("{var}")), - }, - } - } -} diff --git a/crates/burn-wgpu/src/compiler/wgsl/body.rs b/crates/burn-wgpu/src/compiler/wgsl/body.rs deleted file mode 100644 index debb734045..0000000000 --- a/crates/burn-wgpu/src/compiler/wgsl/body.rs +++ /dev/null @@ -1,38 +0,0 @@ -use super::Instruction; -use std::fmt::Display; - -/// A body is composed of a list of [instructions](Instruction). -/// -/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size -/// X and Y, but with Z=1. 
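The deleted `base.rs` above shows the WGSL backend's codegen style, which now ships with cubecl's wgpu crate: IR variables render themselves as WGSL identifiers through `Display`, so emitting a statement is plain string formatting. A tiny mirror of that idea (illustrative enum, not the real compiler):

```rust
use std::fmt::{self, Display};

/// Tiny mirror of the deleted WGSL codegen: variables know their own WGSL
/// spelling, following the names in the deleted `base.rs` above.
enum Variable {
    GlobalInputArray(u16),
    LocalInvocationIndex,
    WorkgroupIdX,
    ConstantScalarU32(u64),
}

impl Display for Variable {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Variable::GlobalInputArray(n) => write!(f, "input_{n}_global"),
            Variable::LocalInvocationIndex => f.write_str("local_idx"),
            Variable::WorkgroupIdX => f.write_str("workgroup_id.x"),
            // WGSL u32 literals get a `u` suffix, as in the deleted code.
            Variable::ConstantScalarU32(v) => write!(f, "{v}u"),
        }
    }
}

fn main() {
    assert_eq!(Variable::GlobalInputArray(0).to_string(), "input_0_global");
    assert_eq!(Variable::LocalInvocationIndex.to_string(), "local_idx");
    // Emitting a WGSL statement is just string formatting over variables.
    let line = format!(
        "let offset = {} + {};",
        Variable::WorkgroupIdX,
        Variable::ConstantScalarU32(4)
    );
    assert_eq!(line, "let offset = workgroup_id.x + 4u;");
}
```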
-#[derive(Debug, Clone)] -pub struct Body { - pub instructions: Vec<Instruction>, - pub rank: bool, - pub id: bool, - pub stride: bool, - pub shape: bool, -} - -impl Display for Body { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if self.id { - f.write_str( - "let id = (global_id.z * num_workgroups.x * WORKGROUP_SIZE_X * num_workgroups.y * WORKGROUP_SIZE_Y) + (global_id.y * num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n", - )?; - } - if self.rank || self.stride || self.shape { - f.write_str("let rank: u32 = info[0];\n")?; - } - - if self.stride || self.shape { - f.write_str("let rank_2: u32 = rank * 2u;\n")?; - } - - for ops in self.instructions.iter() { - f.write_fmt(format_args!("{ops}"))?; - } - - Ok(()) - } -} diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs deleted file mode 100644 index f5aeac4fc5..0000000000 --- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs +++ /dev/null @@ -1,744 +0,0 @@ -use super::{shader::ComputeShader, Item, SharedMemory}; -use super::{LocalArray, Subgroup}; -use crate::compiler::wgsl; -use burn_cube::ir as cube; - -/// Wgsl Compiler. -#[derive(Clone, Default)] -pub struct WgslCompiler { - num_inputs: usize, - num_outputs: usize, - local_invocation_index: bool, - local_invocation_id: bool, - global_invocation_id: bool, - workgroup_id: bool, - rank: bool, - id: bool, - stride: bool, - shape: bool, - num_workgroups: bool, - workgroup_id_no_axis: bool, - workgroup_size_no_axis: bool, - num_workgroup_no_axis: bool, - shared_memories: Vec<SharedMemory>, - local_arrays: Vec<LocalArray>, -} - -impl core::fmt::Debug for WgslCompiler { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("WgslCompiler") - } -} - -impl burn_cube::Compiler for WgslCompiler { - type Representation = ComputeShader; - - fn compile(shader: cube::KernelDefinition) -> Self::Representation { - let mut compiler = Self::default(); - compiler.compile_shader(shader) - } - - fn elem_size(elem: cube::Elem) -> usize { - Self::compile_elem(elem).size() - } - - fn max_shared_memory_size() -> usize { - 8192 - } -} - -impl WgslCompiler { - fn compile_shader(&mut self, mut value: cube::KernelDefinition) -> wgsl::ComputeShader { - self.num_inputs = value.inputs.len(); - self.num_outputs = value.outputs.len(); - - let instructions = self.compile_scope(&mut value.body); - let extensions = register_extensions(&instructions); - let body = wgsl::Body { - instructions, - rank: true, - id: self.id, - stride: self.stride, - shape: self.shape, - }; - - wgsl::ComputeShader { - inputs: value - .inputs - .into_iter() - .map(Self::compile_binding) - .collect(), - outputs: value - .outputs - .into_iter() - .map(Self::compile_binding) - .collect(), - named: value - .named - .into_iter() - .map(|(name, binding)| (name, Self::compile_binding(binding))) - .collect(), - shared_memories: self.shared_memories.clone(), - local_arrays: self.local_arrays.clone(), - workgroup_size: value.cube_dim, - global_invocation_id: self.global_invocation_id || self.id, - local_invocation_index: self.local_invocation_index, - local_invocation_id: self.local_invocation_id, - num_workgroups: self.id - || self.num_workgroups - || self.num_workgroup_no_axis - || self.workgroup_id_no_axis, - workgroup_id: self.workgroup_id || self.workgroup_id_no_axis, - body, - extensions, - num_workgroups_no_axis: self.num_workgroup_no_axis, - workgroup_id_no_axis: self.workgroup_id_no_axis, - workgroup_size_no_axis: self.workgroup_size_no_axis, - } - } - - fn
compile_item(item: cube::Item) -> Item { - let elem = Self::compile_elem(item.elem); - match item.vectorization { - 1 => wgsl::Item::Scalar(elem), - 2 => wgsl::Item::Vec2(elem), - 3 => wgsl::Item::Vec3(elem), - 4 => wgsl::Item::Vec4(elem), - _ => panic!("Unsupported vectorizations scheme {:?}", item.vectorization), - } - } - - fn compile_elem(value: cube::Elem) -> wgsl::Elem { - match value { - cube::Elem::Float(f) => match f { - cube::FloatKind::F16 => panic!("f16 is not yet supported"), - cube::FloatKind::BF16 => panic!("bf16 is not a valid WgpuElement"), - cube::FloatKind::F32 => wgsl::Elem::F32, - cube::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"), - }, - cube::Elem::Int(i) => match i { - cube::IntKind::I32 => wgsl::Elem::I32, - cube::IntKind::I64 => panic!("i64 is not a valid WgpuElement"), - }, - cube::Elem::UInt => wgsl::Elem::U32, - cube::Elem::Bool => wgsl::Elem::Bool, - } - } - - fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable { - match value { - cube::Variable::GlobalInputArray { id, item } => { - wgsl::Variable::GlobalInputArray(id, Self::compile_item(item)) - } - cube::Variable::GlobalScalar { id, elem } => { - wgsl::Variable::GlobalScalar(id, Self::compile_elem(elem), elem) - } - cube::Variable::Local { id, item, depth } => wgsl::Variable::Local { - id, - item: Self::compile_item(item), - depth, - }, - cube::Variable::Slice { id, item, depth } => wgsl::Variable::Slice { - id, - item: Self::compile_item(item), - depth, - }, - cube::Variable::LocalScalar { id, elem, depth } => wgsl::Variable::LocalScalar { - id, - elem: Self::compile_elem(elem), - depth, - }, - cube::Variable::GlobalOutputArray { id, item } => { - wgsl::Variable::GlobalOutputArray(id, Self::compile_item(item)) - } - cube::Variable::ConstantScalar { value, elem } => { - wgsl::Variable::ConstantScalar(value, Self::compile_elem(elem)) - } - cube::Variable::SharedMemory { id, item, length } => { - let item = Self::compile_item(item); - if !self.shared_memories.iter().any(|s| s.index == id) { - self.shared_memories - .push(SharedMemory::new(id, item, length)); - } - wgsl::Variable::SharedMemory(id, item, length) - } - cube::Variable::LocalArray { - id, - item, - depth, - length, - } => { - let item = Self::compile_item(item); - if !self.local_arrays.iter().any(|s| s.index == id) { - self.local_arrays - .push(LocalArray::new(id, item, depth, length)); - } - wgsl::Variable::LocalArray(id, item, depth, length) - } - cube::Variable::AbsolutePos => { - self.id = true; - wgsl::Variable::Id - } - cube::Variable::Rank => { - self.rank = true; - wgsl::Variable::Rank - } - cube::Variable::UnitPos => { - self.local_invocation_index = true; - wgsl::Variable::LocalInvocationIndex - } - cube::Variable::UnitPosX => { - self.local_invocation_id = true; - wgsl::Variable::LocalInvocationIdX - } - cube::Variable::UnitPosY => { - self.local_invocation_id = true; - wgsl::Variable::LocalInvocationIdY - } - cube::Variable::UnitPosZ => { - self.local_invocation_id = true; - wgsl::Variable::LocalInvocationIdZ - } - cube::Variable::CubePosX => { - self.workgroup_id = true; - wgsl::Variable::WorkgroupIdX - } - cube::Variable::CubePosY => { - self.workgroup_id = true; - wgsl::Variable::WorkgroupIdY - } - cube::Variable::CubePosZ => { - self.workgroup_id = true; - wgsl::Variable::WorkgroupIdZ - } - cube::Variable::AbsolutePosX => { - self.global_invocation_id = true; - wgsl::Variable::GlobalInvocationIdX - } - cube::Variable::AbsolutePosY => { - self.global_invocation_id = true; - 
wgsl::Variable::GlobalInvocationIdY - } - cube::Variable::AbsolutePosZ => { - self.global_invocation_id = true; - wgsl::Variable::GlobalInvocationIdZ - } - cube::Variable::CubeDimX => wgsl::Variable::WorkgroupSizeX, - cube::Variable::CubeDimY => wgsl::Variable::WorkgroupSizeY, - cube::Variable::CubeDimZ => wgsl::Variable::WorkgroupSizeZ, - cube::Variable::CubeCountX => { - self.num_workgroups = true; - wgsl::Variable::NumWorkgroupsX - } - cube::Variable::CubeCountY => { - self.num_workgroups = true; - wgsl::Variable::NumWorkgroupsY - } - cube::Variable::CubeCountZ => { - self.num_workgroups = true; - wgsl::Variable::NumWorkgroupsZ - } - cube::Variable::CubePos => { - self.workgroup_id_no_axis = true; - wgsl::Variable::WorkgroupId - } - cube::Variable::CubeDim => { - self.workgroup_size_no_axis = true; - wgsl::Variable::WorkgroupSize - } - cube::Variable::CubeCount => { - self.num_workgroup_no_axis = true; - wgsl::Variable::NumWorkgroups - } - cube::Variable::SubcubeDim => wgsl::Variable::SubgroupSize, - cube::Variable::Matrix { .. } => { - panic!("Cooperative matrix-multiply and accumulate not supported.") - } - } - } - - fn compile_scope(&mut self, value: &mut cube::Scope) -> Vec<wgsl::Instruction> { - let mut instructions = Vec::new(); - let processing = value.process(); - - for var in processing.variables { - // We don't declare slices. - if let cube::Variable::Slice { .. } = var { - continue; - } - - instructions.push(wgsl::Instruction::DeclareVariable { - var: self.compile_variable(var), - }); - } - - processing - .operations - .into_iter() - .for_each(|op| self.compile_operation(&mut instructions, op, value)); - - instructions - } - - fn compile_operation( - &mut self, - instructions: &mut Vec<wgsl::Instruction>, - operation: cube::Operation, - scope: &mut cube::Scope, - ) { - match operation { - cube::Operation::Operator(op) => instructions.push(self.compile_instruction(op)), - cube::Operation::Procedure(proc) => self.compile_procedure(instructions, proc, scope), - cube::Operation::Metadata(op) => instructions.push(self.compile_metadata(op)), - cube::Operation::Branch(val) => self.compile_branch(instructions, val), - cube::Operation::Synchronization(val) => { - self.compile_synchronization(instructions, val) - } - cube::Operation::Subcube(op) => self.compile_subgroup(instructions, op), - cube::Operation::CoopMma(_) => { - panic!("Cooperative matrix-multiply and accumulate isn't supported on wgpu.") - } - } - } - - fn compile_subgroup( - &mut self, - instructions: &mut Vec<wgsl::Instruction>, - subgroup: cube::Subcube, - ) { - let op = match subgroup { - cube::Subcube::Elect(op) => Subgroup::Elect { - out: self.compile_variable(op.out), - }, - cube::Subcube::All(op) => Subgroup::All { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Subcube::Any(op) => Subgroup::Any { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Subcube::Broadcast(op) => Subgroup::Broadcast { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Subcube::Sum(op) => Subgroup::Sum { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Subcube::Prod(op) => Subgroup::Prod { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Subcube::And(op) => Subgroup::And { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Subcube::Or(op) => Subgroup::Or { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Subcube::Xor(op) => Subgroup::Xor { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Subcube::Min(op) => Subgroup::Min { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Subcube::Max(op) => Subgroup::Max { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - }; - - instructions.push(wgsl::Instruction::Subgroup(op)); - } - - fn compile_branch(&mut self, instructions: &mut Vec<wgsl::Instruction>, branch: cube::Branch) { - match branch { - cube::Branch::If(mut op) => instructions.push(wgsl::Instruction::If { - cond: self.compile_variable(op.cond), - instructions: self.compile_scope(&mut op.scope), - }), - cube::Branch::IfElse(mut op) => instructions.push(wgsl::Instruction::IfElse { - cond: self.compile_variable(op.cond), - instructions_if: self.compile_scope(&mut op.scope_if), - instructions_else: self.compile_scope(&mut op.scope_else), - }), - cube::Branch::Return => instructions.push(wgsl::Instruction::Return), - cube::Branch::Break => instructions.push(wgsl::Instruction::Break), - cube::Branch::RangeLoop(mut range_loop) => { - instructions.push(wgsl::Instruction::RangeLoop { - i: self.compile_variable(range_loop.i), - start: self.compile_variable(range_loop.start), - end: self.compile_variable(range_loop.end), - instructions: self.compile_scope(&mut range_loop.scope), - }) - } - cube::Branch::Loop(mut op) => instructions.push(wgsl::Instruction::Loop { - instructions: self.compile_scope(&mut op.scope), - }), - }; - } - - fn compile_synchronization( - &mut self, - instructions: &mut Vec<wgsl::Instruction>, - synchronization: cube::Synchronization, - ) { - match synchronization { - cube::Synchronization::SyncUnits => { - instructions.push(wgsl::Instruction::WorkgroupBarrier) - } - }; - } - - fn compile_procedure( - &mut self, - instructions: &mut Vec<wgsl::Instruction>, - proc: cube::Procedure, - scope: &mut cube::Scope, - ) { - let mut compile = |scope: &mut cube::Scope| { - instructions.extend(self.compile_scope(scope)); - }; - - match proc { - cube::Procedure::ReadGlobalWithLayout(proc) => { - proc.expand(scope); - compile(scope); - } - cube::Procedure::ReadGlobal(proc) => { - proc.expand(scope); - compile(scope); - } - cube::Procedure::WriteGlobal(proc) => { - proc.expand(scope); - compile(scope); - } - cube::Procedure::ConditionalAssign(proc) => { - proc.expand(scope); - compile(scope); - } - cube::Procedure::CheckedIndex(proc) => { - proc.expand(scope); - compile(scope); - } - cube::Procedure::CheckedIndexAssign(proc) => { - proc.expand(scope); - compile(scope); - } - cube::Procedure::IndexOffsetGlobalWithLayout(proc) => { - proc.expand(scope); - compile(scope); - } - } - } - - fn compile_metadata(&mut self, metadata: cube::Metadata) -> wgsl::Instruction { - match metadata { - cube::Metadata::Stride { dim, var, out } => { - self.stride = true; - let position = match var { - cube::Variable::GlobalInputArray { id, .. } => id as usize, - cube::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize, - _ => panic!("Only Input and Output have a stride, got: {:?}", var), - }; - wgsl::Instruction::Stride { - dim: self.compile_variable(dim), - position, - out: self.compile_variable(out), - } - } - cube::Metadata::Shape { dim, var, out } => { - self.shape = true; - let position = match var { - cube::Variable::GlobalInputArray { id, .. } => id as usize, - cube::Variable::GlobalOutputArray { id, ..
} => self.num_inputs + id as usize, - _ => panic!("Only Input and Output have a shape, got {:?}", var), - }; - wgsl::Instruction::Shape { - dim: self.compile_variable(dim), - position, - out: self.compile_variable(out), - } - } - cube::Metadata::Length { var, out } => wgsl::Instruction::Length { - out: self.compile_variable(out), - var: self.compile_variable(var), - }, - } - } - - fn compile_instruction(&mut self, value: cube::Operator) -> wgsl::Instruction { - match value { - cube::Operator::Max(op) => wgsl::Instruction::Max { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Min(op) => wgsl::Instruction::Min { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Add(op) => wgsl::Instruction::Add { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Fma(op) => wgsl::Instruction::Fma { - a: self.compile_variable(op.a), - b: self.compile_variable(op.b), - c: self.compile_variable(op.c), - out: self.compile_variable(op.out), - }, - cube::Operator::Index(op) => wgsl::Instruction::Index { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::UncheckedIndex(op) => wgsl::Instruction::Index { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Modulo(op) => wgsl::Instruction::Modulo { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Sub(op) => wgsl::Instruction::Sub { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Mul(op) => wgsl::Instruction::Mul { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Div(op) => wgsl::Instruction::Div { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Abs(op) => wgsl::Instruction::Abs { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::Exp(op) => wgsl::Instruction::Exp { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::Log(op) => wgsl::Instruction::Log { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::Log1p(op) => wgsl::Instruction::Log1p { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::Cos(op) => wgsl::Instruction::Cos { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::Sin(op) => wgsl::Instruction::Sin { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::Tanh(op) => wgsl::Instruction::Tanh { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::Powf(op) => wgsl::Instruction::Powf { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Sqrt(op) => wgsl::Instruction::Sqrt { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, 
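Aside on the info buffer the Stride/Shape metadata above compiles against: info[0] holds the rank, followed per tensor by rank stride entries and then rank shape entries, which is what the emitted expressions info[(position * rank_2) + dim + 1] and info[(position * rank_2) + rank + dim + 1] (see the Instruction display further down) index into. A host-side sketch of that packing, under the assumption of exactly that layout (hypothetical helper, not crate code); the remaining operator arms of compile_instruction continue below.

    // Packs metadata the way the generated WGSL reads it back:
    // info[0] = rank, then for each tensor `rank` strides followed by `rank` shapes.
    fn pack_info(tensors: &[(Vec<u32>, Vec<u32>)]) -> Vec<u32> {
        let rank = tensors.first().map(|(s, _)| s.len()).unwrap_or(0) as u32;
        let mut info = vec![rank];
        for (strides, shape) in tensors {
            info.extend_from_slice(strides);
            info.extend_from_slice(shape);
        }
        info
    }

    fn main() {
        // One rank-2 tensor with strides [4, 1] and shape [2, 4].
        let info = pack_info(&[(vec![4, 1], vec![2, 4])]);
        let (rank, dim, position) = (info[0] as usize, 1usize, 0usize);
        let rank_2 = rank * 2;
        assert_eq!(info[position * rank_2 + dim + 1], 1); // stride of dim 1
        assert_eq!(info[position * rank_2 + rank + dim + 1], 4); // shape of dim 1
    }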
- cube::Operator::Floor(op) => wgsl::Instruction::Floor { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::Ceil(op) => wgsl::Instruction::Ceil { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::Erf(op) => wgsl::Instruction::Erf { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::Recip(op) => wgsl::Instruction::Recip { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::Equal(op) => wgsl::Instruction::Equal { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Lower(op) => wgsl::Instruction::Lower { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Clamp(op) => wgsl::Instruction::Clamp { - input: self.compile_variable(op.input), - min_value: self.compile_variable(op.min_value), - max_value: self.compile_variable(op.max_value), - out: self.compile_variable(op.out), - }, - cube::Operator::Greater(op) => wgsl::Instruction::Greater { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::LowerEqual(op) => wgsl::Instruction::LowerEqual { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::GreaterEqual(op) => wgsl::Instruction::GreaterEqual { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::NotEqual(op) => wgsl::Instruction::NotEqual { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Assign(op) => wgsl::Instruction::Assign { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::IndexAssign(op) => wgsl::Instruction::IndexAssign { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::UncheckedIndexAssign(op) => wgsl::Instruction::IndexAssign { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::And(op) => wgsl::Instruction::And { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Or(op) => wgsl::Instruction::Or { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Not(op) => wgsl::Instruction::Not { - input: self.compile_variable(op.input), - out: self.compile_variable(op.out), - }, - cube::Operator::BitwiseAnd(op) => wgsl::Instruction::BitwiseAnd { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::BitwiseXor(op) => wgsl::Instruction::BitwiseXor { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::ShiftLeft(op) => wgsl::Instruction::ShiftLeft { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::ShiftRight(op) => wgsl::Instruction::ShiftRight { - lhs: 
self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Remainder(op) => wgsl::Instruction::Remainder { - lhs: self.compile_variable(op.lhs), - rhs: self.compile_variable(op.rhs), - out: self.compile_variable(op.out), - }, - cube::Operator::Slice(op) => wgsl::Instruction::Slice { - input: self.compile_variable(op.input), - start: self.compile_variable(op.start), - end: self.compile_variable(op.end), - out: self.compile_variable(op.out), - }, - } - } - - fn compile_location(value: cube::Location) -> wgsl::Location { - match value { - cube::Location::Storage => wgsl::Location::Storage, - cube::Location::Cube => wgsl::Location::Workgroup, - } - } - - fn compile_visibility(value: cube::Visibility) -> wgsl::Visibility { - match value { - cube::Visibility::Read => wgsl::Visibility::Read, - cube::Visibility::ReadWrite => wgsl::Visibility::ReadWrite, - } - } - - fn compile_binding(value: cube::Binding) -> wgsl::Binding { - wgsl::Binding { - visibility: Self::compile_visibility(value.visibility), - location: Self::compile_location(value.location), - item: Self::compile_item(value.item), - size: value.size, - } - } -} - -fn register_extensions(instructions: &[wgsl::Instruction]) -> Vec<wgsl::Extension> { - let mut extensions = Vec::new(); - - let mut register_extension = |extension: wgsl::Extension| { - if !extensions.contains(&extension) { - extensions.push(extension); - } - }; - - // Since not all instructions are native to WGSL, we need to add the custom ones. - for instruction in instructions { - match instruction { - wgsl::Instruction::Powf { lhs: _, rhs, out } => { - register_extension(wgsl::Extension::PowfPrimitive(out.item())); - - if rhs.is_always_scalar() { - register_extension(wgsl::Extension::PowfScalar(out.item())); - } else { - register_extension(wgsl::Extension::Powf(out.item())); - } - } - wgsl::Instruction::Erf { input, out: _ } => { - register_extension(wgsl::Extension::Erf(input.item())); - } - #[cfg(target_os = "macos")] - wgsl::Instruction::Tanh { input, out: _ } => { - register_extension(wgsl::Extension::SafeTanh(input.item())) - } - wgsl::Instruction::If { - cond: _, - instructions, - } => { - for extension in register_extensions(instructions) { - register_extension(extension); - } - } - _ => {} - } - } - - extensions -} diff --git a/crates/burn-wgpu/src/compiler/wgsl/extension.rs b/crates/burn-wgpu/src/compiler/wgsl/extension.rs deleted file mode 100644 index c03f29d757..0000000000 --- a/crates/burn-wgpu/src/compiler/wgsl/extension.rs +++ /dev/null @@ -1,281 +0,0 @@ -use super::base::Item; -use std::fmt::Display; - -/// Not all functions are native to WGSL, so this struct allows to support more functions.
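The Extension enum that follows opens with powf support: WGSL's pow is only well defined for non-negative bases, so the generated powf_primitive special-cases even and odd integer exponents. A Rust mirror of that case analysis, as a reference sketch of the same logic (not crate code):

    // Mirrors the generated WGSL `powf_primitive`: pow() is only reliable
    // for non-negative bases, so integer exponents are special-cased.
    fn powf_primitive(lhs: f32, rhs: f32) -> f32 {
        let modulo = rhs % 2.0;
        if rhs == 0.0 {
            1.0
        } else if modulo == 0.0 {
            // Even exponent: the result is positive regardless of lhs's sign.
            lhs.abs().powf(rhs)
        } else if modulo == 1.0 && lhs < 0.0 {
            // Odd exponent with a negative base: flip the sign.
            -((-lhs).powf(rhs))
        } else {
            // Fractional exponent: only defined for lhs >= 0 anyway.
            lhs.powf(rhs)
        }
    }

    fn main() {
        assert_eq!(powf_primitive(-2.0, 3.0), -8.0); // odd exponent keeps the sign
        assert_eq!(powf_primitive(-2.0, 2.0), 4.0); // even exponent drops it
    }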
-#[derive(Debug, PartialEq, Eq, Clone)] -pub enum Extension { - PowfScalar(Item), - PowfPrimitive(Item), - Powf(Item), - Erf(Item), - #[cfg(target_os = "macos")] - SafeTanh(Item), -} - -impl Display for Extension { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Extension::PowfScalar(elem) => format_powf_scalar(f, elem), - Extension::PowfPrimitive(elem) => format_powf_primitive(f, elem), - Extension::Powf(elem) => format_powf(f, elem), - Extension::Erf(elem) => format_erf(f, elem), - #[cfg(target_os = "macos")] - Extension::SafeTanh(elem) => format_safe_tanh(f, elem), - } - } -} - -fn format_powf_scalar(f: &mut core::fmt::Formatter<'_>, item: &Item) -> core::fmt::Result { - match item { - Item::Vec4(elem) => f.write_fmt(format_args!( - " -fn powf_scalar(lhs: {item}, rhs: {elem}) -> {item} {{ - return vec4( - powf_primitive(lhs[0], rhs), - powf_primitive(lhs[1], rhs), - powf_primitive(lhs[2], rhs), - powf_primitive(lhs[3], rhs), - ); -}} -" - )), - Item::Vec3(elem) => f.write_fmt(format_args!( - " -fn powf_scalar(lhs: {item}, rhs: {elem}) -> {item} {{ - return vec3( - powf_primitive(lhs[0], rhs), - powf_primitive(lhs[1], rhs), - powf_primitive(lhs[2], rhs), - ); -}} -" - )), - Item::Vec2(elem) => f.write_fmt(format_args!( - " -fn powf_scalar(lhs: {item}, rhs: {elem}) -> {item} {{ - return vec2( - powf_primitive(lhs[0], rhs), - powf_primitive(lhs[1], rhs), - ); -}} -" - )), - Item::Scalar(elem) => f.write_fmt(format_args!( - " -fn powf_scalar(lhs: {elem}, rhs: {elem}) -> {elem} {{ - return powf_primitive(lhs, rhs); -}} -" - )), - } -} - -fn format_powf_primitive( - f: &mut std::fmt::Formatter<'_>, - item: &Item, -) -> Result<(), std::fmt::Error> { - let elem = item.elem(); - f.write_fmt(format_args!( - " -fn powf_primitive(lhs: {elem}, rhs: {elem}) -> {elem} {{ - let modulo = rhs % 2.0; - if rhs == 0.0 {{ - return 1.0; - }} - if (modulo == 0.0) {{ - // Even number - return pow(abs(lhs), rhs); - }} else if (modulo == 1.0 && lhs < 0.0) {{ - // Odd number - return -1.0 * pow(-1.0 * lhs, rhs); - }} else {{ - // Float number - return pow(lhs, rhs); - }} -}} -" - ))?; - Ok(()) -} - -fn format_powf(f: &mut core::fmt::Formatter<'_>, item: &Item) -> core::fmt::Result { - match item { - Item::Vec4(_) => f.write_fmt(format_args!( - " -fn powf(lhs: {item}, rhs: {item}) -> {item} {{ - return vec4( - powf_primitive(lhs[0], rhs[0]), - powf_primitive(lhs[1], rhs[1]), - powf_primitive(lhs[2], rhs[2]), - powf_primitive(lhs[3], rhs[3]), - ); -}} -" - )), - Item::Vec3(_) => f.write_fmt(format_args!( - " -fn powf(lhs: {item}, rhs: {item}) -> {item} {{ - return vec3( - powf_primitive(lhs[0], rhs[0]), - powf_primitive(lhs[1], rhs[1]), - powf_primitive(lhs[2], rhs[2]), - ); -}} -" - )), - Item::Vec2(_) => f.write_fmt(format_args!( - " -fn powf(lhs: {item}, rhs: {item}) -> {item} {{ - return vec2( - powf_primitive(lhs[0], rhs[0]), - powf_primitive(lhs[1], rhs[1]), - ); -}} -" - )), - Item::Scalar(elem) => f.write_fmt(format_args!( - " -fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{ - return powf_primitive(lhs, rhs); -}} -" - )), - } -} - -fn format_erf(f: &mut core::fmt::Formatter<'_>, ty: &Item) -> core::fmt::Result { - let elem = ty.elem(); - f.write_fmt(format_args!( - " -/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations -/// -/// > (maximum error: 1.5×10−7) -/// > All of these approximations are valid for x ≥ 0. 
To use these approximations for negative x, use the fact that erf x is an odd function, so erf x = −erf(−x). -fn erf_positive_scalar(x: {elem}) -> {elem} {{ - let p = 0.3275911; - let a1 = 0.254829592; - let a2 = -0.284496736; - let a3 = 1.421413741; - let a4 = -1.453152027; - let a5 = 1.061405429; - - let t = 1.0 / (1.0 + p * abs(x)); - let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1; - - return 1.0 - (tmp * t * exp(-x * x)); -}} - -fn erf_scalar(x: {elem}) -> {elem} {{ - if (x < 0.0) {{ - return -1.0 * erf_positive_scalar(-1.0 * x); - }} - - return erf_positive_scalar(x); -}} -" - ))?; - - match ty { - Item::Vec4(_) => f.write_fmt(format_args!( - " -fn erf(x: {ty}) -> {ty} {{ - return vec4( - erf_scalar(x[0]), - erf_scalar(x[1]), - erf_scalar(x[2]), - erf_scalar(x[3]), - ); -}} - " - )), - Item::Vec3(_) => f.write_fmt(format_args!( - " -fn erf(x: {ty}) -> {ty} {{ - return vec3( - erf_scalar(x[0]), - erf_scalar(x[1]), - erf_scalar(x[2]), - ); -}} - " - )), - Item::Vec2(_) => f.write_fmt(format_args!( - " -fn erf(x: {ty}) -> {ty} {{ - return vec2( - erf_scalar(x[0]), - erf_scalar(x[1]), - ); -}} - " - )), - Item::Scalar(_) => f.write_fmt(format_args!( - " -fn erf(x: {ty}) -> {ty} {{ - return erf_scalar(x); -}} - " - )), - } -} - -#[cfg(target_os = "macos")] -fn format_safe_tanh(f: &mut core::fmt::Formatter<'_>, item: &Item) -> core::fmt::Result { - let elem = item.elem(); - - f.write_fmt(format_args!( - " -/// Metal has a weird numerical behaviour with tanh for inputs over 43.0 -fn safe_tanh_scalar(x: {elem}) -> {elem} {{ - if x > 43.0 {{ - return 1.0; - }} else {{ - return tanh(x); - }} -}} -" - ))?; - - match item { - Item::Vec4(_) => f.write_fmt(format_args!( - " -fn safe_tanh(x: {item}) -> {item} {{ - return vec4( - safe_tanh_scalar(x[0]), - safe_tanh_scalar(x[1]), - safe_tanh_scalar(x[2]), - safe_tanh_scalar(x[3]), - ); -}} -" - )), - Item::Vec3(_) => f.write_fmt(format_args!( - " -fn safe_tanh(x: {item}) -> {item} {{ - return vec3( - safe_tanh_scalar(x[0]), - safe_tanh_scalar(x[1]), - safe_tanh_scalar(x[2]), - ); -}} -" - )), - Item::Vec2(_) => f.write_fmt(format_args!( - " -fn safe_tanh(x: {item}) -> {item} {{ - return vec2( - safe_tanh_scalar(x[0]), - safe_tanh_scalar(x[1]), - ); -}} -" - )), - Item::Scalar(_) => f.write_fmt(format_args!( - " -fn safe_tanh(x: {item}) -> {item} {{ - return safe_tanh_scalar(x); -}} -" - )), - } -} diff --git a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs deleted file mode 100644 index fedd8ce3c4..0000000000 --- a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs +++ /dev/null @@ -1,727 +0,0 @@ -use super::{ - base::{Item, Variable}, - Elem, IndexedVariable, Subgroup, -}; -use std::fmt::Display; - -/// All instructions that can be used in a WGSL compute shader. -#[derive(Debug, Clone)] -#[allow(dead_code)] // Some variants might not be used with different flags -pub enum Instruction { - DeclareVariable { - var: Variable, - }, - Max { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Min { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Add { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Fma { - a: Variable, - b: Variable, - c: Variable, - out: Variable, - }, - If { - cond: Variable, - instructions: Vec<Instruction>, - }, - IfElse { - cond: Variable, - instructions_if: Vec<Instruction>, - instructions_else: Vec<Instruction>, - }, - Return, - Break, - WorkgroupBarrier, - // Index handles casting to correct local variable. - Index { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - // Index assign handles casting to correct output variable. - IndexAssign { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - // Assign handle casting to correct output variable. - Assign { - input: Variable, - out: Variable, - }, - Modulo { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Sub { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Mul { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Div { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Abs { - input: Variable, - out: Variable, - }, - Exp { - input: Variable, - out: Variable, - }, - Log { - input: Variable, - out: Variable, - }, - Log1p { - input: Variable, - out: Variable, - }, - Cos { - input: Variable, - out: Variable, - }, - Sin { - input: Variable, - out: Variable, - }, - Tanh { - input: Variable, - out: Variable, - }, - Powf { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Sqrt { - input: Variable, - out: Variable, - }, - Erf { - input: Variable, - out: Variable, - }, - Recip { - input: Variable, - out: Variable, - }, - Equal { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Lower { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Clamp { - input: Variable, - min_value: Variable, - max_value: Variable, - out: Variable, - }, - Greater { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - LowerEqual { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - GreaterEqual { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - NotEqual { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Stride { - dim: Variable, - position: usize, - out: Variable, - }, - Length { - var: Variable, - out: Variable, - }, - Shape { - dim: Variable, - position: usize, - out: Variable, - }, - RangeLoop { - i: Variable, - start: Variable, - end: Variable, - instructions: Vec<Instruction>, - }, - And { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Or { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Not { - input: Variable, - out: Variable, - }, - Loop { - instructions: Vec<Instruction>, - }, - BitwiseAnd { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - BitwiseXor { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - ShiftLeft { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - ShiftRight { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Floor { - input: Variable, - out: Variable, - }, - Ceil { - input: Variable, - out: Variable, - }, - Remainder { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Slice { - input: Variable, - start: Variable, - end: Variable, - out: Variable, - }, - Subgroup(Subgroup), -} - -impl Display for Instruction { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Instruction::DeclareVariable { var } => { - let item = var.item(); - f.write_fmt(format_args!("var {var}: {item};\n")) - } - Instruction::Add { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = {lhs} + {rhs};\n")) - } - Instruction::Slice { - input, - start, - end, - out, - } => { - f.write_fmt(format_args!("let {out}_offset = {start};\n"))?; - f.write_fmt(format_args!("let {out}_length = {end} - {start};\n"))?; - f.write_fmt(format_args!("let {out}_ptr = &{input};\n")) - } - Instruction::Fma { a, b, c, out } => { - f.write_fmt(format_args!("{out} = fma({a}, {b}, {c});\n")) - } - Instruction::Min { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = min({lhs}, {rhs});\n")) - } - Instruction::Max { lhs, rhs, out } => {
- f.write_fmt(format_args!("{out} = max({lhs}, {rhs});\n")) - } - Instruction::And { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = {lhs} && {rhs};\n")) - } - Instruction::Or { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = {lhs} || {rhs};\n")) - } - Instruction::Not { input, out } => f.write_fmt(format_args!("{out} = !{input};\n")), - Instruction::Index { lhs, rhs, out } => match lhs { - Variable::Slice { item, .. } => { - let offset = Variable::Named { - name: format!("{lhs}_offset"), - item: Item::Scalar(Elem::U32), - is_array: false, - }; - let lhs = Variable::Named { - name: format!("(*{lhs}_ptr)"), - item: *item, - is_array: true, - }; - index(f, &lhs, rhs, out, Some(offset)) - } - _ => index(f, lhs, rhs, out, None), - }, - Instruction::Modulo { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n")) - } - Instruction::Remainder { lhs, rhs, out } => f.write_fmt(format_args!( - "{out} = {lhs} - {rhs} * floor({lhs} / {rhs});\n" - )), - Instruction::Sub { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = {lhs} - {rhs};\n")) - } - Instruction::Mul { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = {lhs} * {rhs};\n")) - } - Instruction::Div { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = {lhs} / {rhs};\n")) - } - Instruction::Abs { input, out } => f.write_fmt(format_args!("{out} = abs({input});\n")), - Instruction::Exp { input, out } => f.write_fmt(format_args!("{out} = exp({input});\n")), - Instruction::Log { input, out } => f.write_fmt(format_args!("{out} = log({input});\n")), - Instruction::Clamp { - input, - min_value, - max_value, - out, - } => unroll( - f, - out.item().vectorization_factor(), - [input, min_value, max_value, out], - |f, [input, min, max, out]| { - f.write_fmt(format_args!("{out} = clamp({input}, {min}, {max});\n")) - }, - ), - Instruction::Powf { lhs, rhs, out } => { - let vectorization_factor = out.item().vectorization_factor(); - - if rhs.is_always_scalar() { - f.write_fmt(format_args!("{out} = powf_scalar({lhs}, {rhs});\n")) - } else { - unroll( - f, - vectorization_factor, - [lhs, rhs, out], - |f, [lhs, rhs, out]| { - f.write_fmt(format_args!("{out} = powf_primitive({lhs}, {rhs});\n")) - }, - ) - } - } - Instruction::Sqrt { input, out } => { - f.write_fmt(format_args!("{out} = sqrt({input});\n")) - } - Instruction::Log1p { input, out } => { - f.write_fmt(format_args!("{out} = log({input} + 1.0);\n")) - } - Instruction::Cos { input, out } => f.write_fmt(format_args!("{out} = cos({input});\n")), - Instruction::Sin { input, out } => f.write_fmt(format_args!("{out} = sin({input});\n")), - Instruction::Tanh { input, out } => { - #[cfg(target_os = "macos")] - let result = f.write_fmt(format_args!("{out} = safe_tanh({input});\n")); - #[cfg(not(target_os = "macos"))] - let result = f.write_fmt(format_args!("{out} = tanh({input});\n")); - - result - } - Instruction::Erf { input, out } => f.write_fmt(format_args!("{out} = erf({input});\n")), - Instruction::Recip { input, out } => { - f.write_fmt(format_args!("{out} = 1.0 / {input};")) - } - Instruction::Equal { lhs, rhs, out } => comparison(lhs, rhs, out, "==", f), - Instruction::Lower { lhs, rhs, out } => comparison(lhs, rhs, out, "<", f), - Instruction::Greater { lhs, rhs, out } => comparison(lhs, rhs, out, ">", f), - Instruction::LowerEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "<=", f), - Instruction::GreaterEqual { lhs, rhs, out } => comparison(lhs, rhs, out, ">=", f), - Instruction::NotEqual { lhs, rhs, out } => comparison(lhs, rhs, out, 
"!=", f), - Instruction::Assign { input, out } => match out.item() { - Item::Vec4(elem) => { - let input0 = input.index(0); - let input1 = input.index(1); - let input2 = input.index(2); - let input3 = input.index(3); - - f.write_fmt(format_args!( - "{out} = vec4( - {elem}({input0}), - {elem}({input1}), - {elem}({input2}), - {elem}({input3}), -); -" - )) - } - Item::Vec3(elem) => { - let input0 = input.index(0); - let input1 = input.index(1); - let input2 = input.index(2); - - f.write_fmt(format_args!( - "{out} = vec3( - {elem}({input0}), - {elem}({input1}), - {elem}({input2}), -); -" - )) - } - Item::Vec2(elem) => { - let input0 = input.index(0); - let input1 = input.index(1); - - f.write_fmt(format_args!( - "{out} = vec2( - {elem}({input0}), - {elem}({input1}), -); -" - )) - } - Item::Scalar(elem) => f.write_fmt(format_args!("{out} = {elem}({input});\n")), - }, - Instruction::Stride { dim, position, out } => f.write_fmt(format_args!( - "{out} = info[({position}u * rank_2) + {dim} + 1u];\n" - )), - Instruction::Shape { dim, position, out } => f.write_fmt(format_args!( - "{out} = info[({position}u * rank_2) + rank + {dim} + 1u];\n" - )), - Instruction::RangeLoop { - i, - start, - end, - instructions, - } => { - f.write_fmt(format_args!( - " -for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{ -" - ))?; - for instruction in instructions { - f.write_fmt(format_args!("{instruction}"))?; - } - - f.write_str("}\n") - } - Instruction::IndexAssign { lhs, rhs, out } => { - if let Variable::Slice { item, .. } = out { - let offset = Variable::Named { - name: format!("{out}_offset"), - item: Item::Scalar(Elem::U32), - is_array: false, - }; - let out = Variable::Named { - name: format!("(*{out}_ptr)"), - item: *item, - is_array: true, - }; - - index_assign(f, lhs, rhs, &out, Some(offset)) - } else { - index_assign(f, lhs, rhs, out, None) - } - } - Instruction::If { cond, instructions } => { - f.write_fmt(format_args!("if {cond} {{\n"))?; - for i in instructions { - f.write_fmt(format_args!("{i}"))?; - } - f.write_str("}\n") - } - Instruction::IfElse { - cond, - instructions_if, - instructions_else, - } => { - f.write_fmt(format_args!("if {cond} {{\n"))?; - for i in instructions_if { - f.write_fmt(format_args!("{i}"))?; - } - f.write_str("} else {\n")?; - for i in instructions_else { - f.write_fmt(format_args!("{i}"))?; - } - f.write_str("}\n") - } - Instruction::Return => f.write_str("return;\n"), - Instruction::Break => f.write_str("break;\n"), - Instruction::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"), - Instruction::Length { var, out } => match var { - Variable::Slice { .. 
} => f.write_fmt(format_args!("{out} = {var}_length;\n")), - _ => f.write_fmt(format_args!("{out} = arrayLength(&{var});\n")), - }, - Instruction::Loop { instructions } => { - f.write_fmt(format_args!("loop {{\n"))?; - for i in instructions { - f.write_fmt(format_args!("{i}"))?; - } - f.write_str("}\n") - } - Instruction::BitwiseAnd { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = {lhs} & {rhs};\n")) - } - Instruction::BitwiseXor { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = {lhs} ^ {rhs};\n")) - } - Instruction::ShiftLeft { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = {lhs} << {rhs};\n")) - } - Instruction::ShiftRight { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = {lhs} >> {rhs};\n")) - } - Instruction::Floor { input, out } => { - f.write_fmt(format_args!("{out} = floor({input});\n")) - } - Instruction::Ceil { input, out } => { - f.write_fmt(format_args!("{out} = ceil({input});\n")) - } - Instruction::Subgroup(op) => f.write_fmt(format_args!("{op}")), - } - } -} - -fn comparison( - lhs: &Variable, - rhs: &Variable, - out: &Variable, - op: &str, - f: &mut std::fmt::Formatter<'_>, -) -> std::fmt::Result { - match out.item() { - Item::Vec4(_) => { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - let lhs2 = lhs.index(2); - let lhs3 = lhs.index(3); - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - let rhs3 = rhs.index(3); - - f.write_fmt(format_args!( - " -{out} = vec4({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2}, {lhs3} {op} {rhs3}); -" - )) - } - Item::Vec3(_) => { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - let lhs2 = lhs.index(2); - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - - f.write_fmt(format_args!( - " -{out} = vec3({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2}); -" - )) - } - Item::Vec2(_) => { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - - f.write_fmt(format_args!( - " -{out} = vec2({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}); -" - )) - } - Item::Scalar(_) => match rhs.item() { - Item::Scalar(_) => f.write_fmt(format_args!("{out} = {lhs} {op} {rhs};\n")), - _ => panic!("Can only compare a scalar when the output is a scalar"), - }, - } -} - -fn unroll< - const N: usize, - F: Fn(&mut core::fmt::Formatter<'_>, [IndexedVariable; N]) -> core::fmt::Result, ->( - f: &mut core::fmt::Formatter<'_>, - vectorization_factor: usize, - variables: [&Variable; N], - func: F, -) -> core::fmt::Result { - for i in 0..vectorization_factor { - let mut tmp = Vec::with_capacity(N); - for var in variables.iter().take(N) { - tmp.push(var.index(i)); - } - let vars = tmp.try_into().unwrap(); - - func(f, vars)?; - } - Ok(()) -} - -struct IndexOffset { - var: Variable, - offset: Option<Variable>, - index: usize, -} -impl IndexOffset { - fn new(var: &Variable, offset: &Option<Variable>, index: usize) -> Self { - Self { - var: var.clone(), - offset: offset.clone(), - index, - } - } -} - -impl Display for IndexOffset { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let var = self.var.index(self.index); - - match &self.offset { - Some(offset) => { - let offset = offset.index(self.index); - f.write_fmt(format_args!("{var} + {offset}")) - } - None => f.write_fmt(format_args!("{var}")), - } - } -} - -fn index( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - offset: Option<Variable>, -) -> core::fmt::Result { - let item = out.item(); - match offset { - Some(offset) => f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs} + {offset}]);\n")), - None => f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n")), - } -} - -fn index_assign( - f: &mut std::fmt::Formatter<'_>, - lhs: &Variable, - rhs: &Variable, - out: &Variable, - offset: Option<Variable>, -) -> core::fmt::Result { - match lhs.item() { - Item::Vec4(elem) => { - let lhs0 = IndexOffset::new(lhs, &offset, 0); - let lhs1 = IndexOffset::new(lhs, &offset, 1); - let lhs2 = IndexOffset::new(lhs, &offset, 2); - let lhs3 = IndexOffset::new(lhs, &offset, 3); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - let rhs3 = rhs.index(3); - - f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; - f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?; - f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?; - f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n")) - } - Item::Vec3(elem) => { - let lhs0 = IndexOffset::new(lhs, &offset, 0); - let lhs1 = IndexOffset::new(lhs, &offset, 1); - let lhs2 = IndexOffset::new(lhs, &offset, 2); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - - f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; - f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?; - f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n")) - } - Item::Vec2(elem) => { - let lhs0 = IndexOffset::new(lhs, &offset, 0); - let lhs1 = IndexOffset::new(lhs, &offset, 1); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - - f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; - f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n")) - } - Item::Scalar(_elem) => { - let is_array = match out { - Variable::GlobalInputArray(_, _) - | Variable::GlobalOutputArray(_, _) - | Variable::SharedMemory(_, _, _) - | Variable::Slice { .. } - | Variable::LocalArray(_, _, _, _) => true, - Variable::Named { is_array, .. } => *is_array, - _ => false, - }; - - if !is_array { - let elem_out = out.elem(); - let casting_type = match rhs.item() { - Item::Vec4(_) => Item::Vec4(elem_out), - Item::Vec3(_) => Item::Vec3(elem_out), - Item::Vec2(_) => Item::Vec2(elem_out), - Item::Scalar(_) => Item::Scalar(elem_out), - }; - f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n")) - } else { - let item_rhs = rhs.item(); - let item_out = out.item(); - let lhs = IndexOffset::new(lhs, &offset, 0); - - let vectorization_factor = item_out.vectorization_factor(); - if vectorization_factor > item_rhs.vectorization_factor() { - let casting_type = item_out.elem(); - f.write_fmt(format_args!("{out}[{lhs}] = vec{vectorization_factor}("))?; - for i in 0..vectorization_factor { - let value = rhs.index(i); - f.write_fmt(format_args!("{casting_type}({value})"))?; - - if i < vectorization_factor - 1 { - f.write_str(",")?; - } - } - f.write_str(");\n") - } else { - let casting_type = item_out; - f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n")) - } - } - } - } -} diff --git a/crates/burn-wgpu/src/compiler/wgsl/mod.rs b/crates/burn-wgpu/src/compiler/wgsl/mod.rs deleted file mode 100644 index 836d909fd6..0000000000 --- a/crates/burn-wgpu/src/compiler/wgsl/mod.rs +++ /dev/null @@ -1,15 +0,0 @@ -mod base; -mod body; -mod compiler; -mod extension; -mod instructions; -mod shader; -mod subgroup; - -pub(crate) use base::*; -pub(crate) use body::*; -pub use compiler::*; -pub(crate) use extension::*; -pub(crate) use instructions::*; -pub(crate) use shader::*; -pub(crate) use subgroup::*; diff --git a/crates/burn-wgpu/src/compiler/wgsl/shader.rs b/crates/burn-wgpu/src/compiler/wgsl/shader.rs deleted file mode 100644 index 4618485561..0000000000 --- a/crates/burn-wgpu/src/compiler/wgsl/shader.rs +++ /dev/null @@ -1,242 +0,0 @@ -use super::{Body, Extension, Item}; -use burn_cube::{ir::CubeDim, CompilerRepresentation}; -use std::fmt::Display; - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum Location { - Storage, - Workgroup, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum Visibility { - Read, - ReadWrite, -} - -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct Binding { - pub location: Location, - pub visibility: Visibility, - pub item: Item, - pub size: Option<usize>, -} - -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct SharedMemory { - location: Location, - pub index: u16, - item: Item, - size: u32, -} - -impl SharedMemory { - pub fn new(index: u16, item: Item, size: u32) -> Self { - Self { - location: Location::Workgroup, - index, - item, - size, - } - } -} - -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct LocalArray { - pub index: u16, - item: Item, - name: u8, - size: u32, -} - -impl LocalArray { - pub fn new(index: u16, item: Item, name: u8, size: u32) -> Self { - Self { - index, - item, - name, - size, - } - } -} - -#[derive(Debug, Clone)] -pub struct ComputeShader { - pub inputs: Vec<Binding>, - pub outputs: Vec<Binding>, - pub named: Vec<(String, Binding)>, - pub shared_memories: Vec<SharedMemory>, - pub local_arrays: Vec<LocalArray>, - pub workgroup_size: CubeDim, - pub global_invocation_id: bool, - pub local_invocation_index: bool, - pub local_invocation_id: bool, - pub num_workgroups: bool, - pub workgroup_id: bool, - pub num_workgroups_no_axis: bool, - pub workgroup_id_no_axis: bool, - pub workgroup_size_no_axis: bool, - pub body: Body, - pub extensions: Vec<Extension>, -} - -impl Display for ComputeShader { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Self::format_bindings(f, "input", &self.inputs, 0)?; - Self::format_bindings(f, "output", &self.outputs, self.inputs.len())?; - - for (i, (name, binding)) in self.named.iter().enumerate() { - Self::format_binding( - f, - name.as_str(), - binding, - self.inputs.len() + self.outputs.len() + i, - )?; - } - - for array in self.shared_memories.iter() { - f.write_fmt(format_args!( - "var<{}> shared_memory_{}: array<{}, {}>;\n\n", - array.location, array.index, array.item, array.size - ))?; - } - - f.write_fmt(format_args!( - "const WORKGROUP_SIZE_X = {}u; -const WORKGROUP_SIZE_Y = {}u; -const WORKGROUP_SIZE_Z = {}u;\n", - self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z - ))?; - - f.write_fmt(format_args!( - " -@compute -@workgroup_size({}, {}, {}) -fn main( -", - self.workgroup_size.x, self.workgroup_size.y, self.workgroup_size.z - ))?; - - if self.global_invocation_id { - f.write_str(" @builtin(global_invocation_id) global_id: vec3<u32>,\n")?; - } - - if self.local_invocation_index { - f.write_str(" @builtin(local_invocation_index) local_idx: u32,\n")?; - } - - if self.local_invocation_id { - f.write_str(" @builtin(local_invocation_id) local_invocation_id: vec3<u32>,\n")?; - } - - if self.num_workgroups { - f.write_str(" @builtin(num_workgroups) num_workgroups: vec3<u32>,\n")?; - } - - if self.workgroup_id { - f.write_str(" @builtin(workgroup_id) workgroup_id: vec3<u32>,\n")?; - } - - // Open body - f.write_fmt(format_args!(") {{"))?; - - // Local arrays - for array in self.local_arrays.iter() { - f.write_fmt(format_args!( - "var a_{}_{}: array<{}, {}>;\n\n", - array.name, array.index, array.item, array.size - ))?; - } - - // Body - if self.workgroup_id_no_axis { - f.write_str("let workgroup_id_no_axis = (num_workgroups.y * num_workgroups.x * workgroup_id.z) + (num_workgroups.x * workgroup_id.y) + workgroup_id.x;\n")?; - } - - if self.workgroup_size_no_axis { - f.write_str("let workgroup_size_no_axis = WORKGROUP_SIZE_X * WORKGROUP_SIZE_Y * WORKGROUP_SIZE_Z;\n")?; - } - - if self.num_workgroups_no_axis { - f.write_str("let num_workgroups_no_axis = num_workgroups.x * num_workgroups.y * num_workgroups.z;\n")?; - } - - f.write_fmt(format_args!("{}", self.body))?; - - // Close body - f.write_fmt(format_args!("}}"))?; - - for extension in self.extensions.iter() { - f.write_fmt(format_args!("{extension}\n\n"))?; - } - - Ok(()) - } -} - -impl ComputeShader { - fn format_bindings( - f: &mut core::fmt::Formatter<'_>, - prefix: &str, - bindings: &[Binding], - num_entry: usize, - ) -> core::fmt::Result { - for (i, binding) in bindings.iter().enumerate() { - Self::format_binding( - f, - format!("{prefix}_{i}_global").as_str(), - binding, - num_entry + i, - )?; - } - - Ok(()) - } - - fn format_binding( - f: &mut core::fmt::Formatter<'_>, - name: &str, - binding: &Binding, - num_entry: usize, - ) -> core::fmt::Result { - let ty = match binding.size { - Some(size) => format!("array<{}, {}>", binding.item, size), - None => format!("array<{}>", binding.item), - }; - - f.write_fmt(format_args!( - "@group(0) -@binding({}) -var<{}, {}> {}: {}; -\n", - num_entry, binding.location, binding.visibility, name, ty - ))?; - - Ok(()) - } -} - -impl Display for Location { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Location::Storage => f.write_str("storage"), - Location::Workgroup => f.write_str("workgroup"), - } - } -} - -impl Display for Visibility { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Visibility::Read => f.write_str("read_write"), - Visibility::ReadWrite =>
f.write_str("read_write"), - } - } -} - -impl CompilerRepresentation for ComputeShader { - fn shared_memory_size(&self) -> usize { - // not used in wgsl compiler - 0 - } -} diff --git a/crates/burn-wgpu/src/compiler/wgsl/subgroup.rs b/crates/burn-wgpu/src/compiler/wgsl/subgroup.rs deleted file mode 100644 index d686b00e88..0000000000 --- a/crates/burn-wgpu/src/compiler/wgsl/subgroup.rs +++ /dev/null @@ -1,89 +0,0 @@ -use super::Variable; -use std::fmt::Display; - -#[derive(Debug, Clone)] -#[allow(dead_code, missing_docs)] // Some variants might not be used with different flags -pub enum Subgroup { - Elect { - out: Variable, - }, - All { - input: Variable, - out: Variable, - }, - Any { - input: Variable, - out: Variable, - }, - Broadcast { - lhs: Variable, - rhs: Variable, - out: Variable, - }, - Sum { - input: Variable, - out: Variable, - }, - Prod { - input: Variable, - out: Variable, - }, - And { - input: Variable, - out: Variable, - }, - Or { - input: Variable, - out: Variable, - }, - Xor { - input: Variable, - out: Variable, - }, - Min { - input: Variable, - out: Variable, - }, - Max { - input: Variable, - out: Variable, - }, -} - -impl Display for Subgroup { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Subgroup::Elect { out } => f.write_fmt(format_args!("{out} = subgroupElect();\n")), - Subgroup::All { input, out } => { - f.write_fmt(format_args!("{out} = subgroupAll({input});\n")) - } - Subgroup::Any { input, out } => { - f.write_fmt(format_args!("{out} = subgroupAny({input});\n")) - } - Subgroup::Broadcast { lhs, rhs, out } => { - f.write_fmt(format_args!("{out} = subgroupBroadcast({lhs}, {rhs});\n")) - } - Subgroup::Sum { input, out } => { - f.write_fmt(format_args!("{out} = subgroupAdd({input});\n")) - } - Subgroup::Prod { input, out } => { - f.write_fmt(format_args!("{out} = subgroupMul({input});\n")) - } - Subgroup::And { input, out } => { - f.write_fmt(format_args!("{out} = subgroupAnd({input});\n")) - } - Subgroup::Or { input, out } => { - f.write_fmt(format_args!("{out} = subgroupOr({input});\n")) - } - Subgroup::Xor { input, out } => { - f.write_fmt(format_args!("{out} = subgroupXor({input});\n")) - } - Subgroup::Min { input, out } => { - f.write_fmt(format_args!("{out} = subgroupMin({input});\n")) - } - Subgroup::Max { input, out } => { - f.write_fmt(format_args!("{out} = subgroupMax({input});\n")) - } - } - } -} diff --git a/crates/burn-wgpu/src/compute/mod.rs b/crates/burn-wgpu/src/compute/mod.rs deleted file mode 100644 index 4139c3868f..0000000000 --- a/crates/burn-wgpu/src/compute/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod server; -mod storage; - -pub use server::*; -pub use storage::*; diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs deleted file mode 100644 index b432d670d3..0000000000 --- a/crates/burn-wgpu/src/compute/server.rs +++ /dev/null @@ -1,353 +0,0 @@ -use std::num::NonZeroU64; - -use super::WgpuStorage; -use alloc::{borrow::Cow, sync::Arc}; -use burn_compute::{ - memory_management::MemoryManagement, - server::{self, ComputeServer}, -}; -use burn_cube::{prelude::*, FeatureSet}; -use burn_jit::JitAutotuneKey; -use burn_tensor::{backend::SyncType, Reader}; -use hashbrown::HashMap; -use wgpu::{ - util::{BufferInitDescriptor, DeviceExt, StagingBelt}, - BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor, -}; - -// Allocations with existing data smaller than this can use a staging belt -// which speeds up the allocation. 
A higher number here will catch more -// allocations, but can also increase memory usage. -const SMALL_ALLOC_SIZE: usize = 512; - -/// Wgpu compute server. -#[derive(Debug)] -pub struct WgpuServer> { - memory_management: MM, - device: Arc, - queue: Arc, - encoder: CommandEncoder, - staging_belt: StagingBelt, - pipelines: HashMap>, - tasks_max: usize, - tasks_count: usize, -} - -impl WgpuServer -where - MM: MemoryManagement, -{ - /// Create a new server. - pub fn new( - memory_management: MM, - device: Arc, - queue: Arc, - tasks_max: usize, - ) -> Self { - let encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { - label: Some("Command Encoder"), - }); - - Self { - memory_management, - device, - queue, - encoder, - staging_belt: StagingBelt::new(SMALL_ALLOC_SIZE as u64), - pipelines: HashMap::new(), - tasks_max, - tasks_count: 0, - } - } - - fn register_compute( - &mut self, - pipeline: Arc, - bind_group: BindGroup, - count: CubeCount, - ) { - // First resolve the dispatch buffer if needed. The weird ordering is because the lifetime of this - // needs to be longer than the compute pass, so we can't do this just before dispatching. - let dispatch_resource = match count.clone() { - CubeCount::Dynamic(binding) => Some(self.memory_management.get(binding.memory)), - _ => None, - }; - - let mut compute = self - .encoder - .begin_compute_pass(&wgpu::ComputePassDescriptor { - label: None, - timestamp_writes: None, - }); - - compute.set_pipeline(&pipeline); - compute.set_bind_group(0, &bind_group, &[]); - - match count { - CubeCount::Static(x, y, z) => { - compute.dispatch_workgroups(x, y, z); - } - CubeCount::Dynamic(_) => { - let resource = dispatch_resource.as_ref().unwrap(); - compute.dispatch_workgroups_indirect(&resource.buffer, resource.offset()); - } - } - - self.tasks_count += 1; - } - - fn pipeline(&mut self, kernel: ::Kernel) -> Arc { - let kernel_id = kernel.id(); - - if let Some(pipeline) = self.pipelines.get(&kernel_id) { - return pipeline.clone(); - } - - let compile = kernel.compile(); - let pipeline = self.compile_source(&compile.source); - - self.pipelines.insert(kernel_id.clone(), pipeline.clone()); - - pipeline - } - - fn compile_source(&self, source: &str) -> Arc { - let module = self.device.create_shader_module(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }); - - Arc::new( - self.device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: None, - layout: None, - module: &module, - entry_point: "main", - compilation_options: Default::default(), - }), - ) - } - - fn create_read_buffer(&mut self, handle: server::Binding) -> wgpu::Buffer { - let resource = self.memory_management.get(handle.memory); - - let size = resource.size(); - let buffer_dest = self.device.create_buffer(&wgpu::BufferDescriptor { - label: None, - size, - usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, - mapped_at_creation: false, - }); - - self.encoder.copy_buffer_to_buffer( - &resource.buffer, - resource.offset(), - &buffer_dest, - 0, - size, - ); - self.tasks_count += 1; - - self.sync(SyncType::Flush); - buffer_dest - } -} - -impl ComputeServer for WgpuServer -where - MM: MemoryManagement, -{ - type Kernel = Box; - type DispatchOptions = CubeCount; - type Storage = WgpuStorage; - type MemoryManagement = MM; - type AutotuneKey = JitAutotuneKey; - type FeatureSet = FeatureSet; - - fn read(&mut self, binding: server::Binding) -> Reader { - let device = self.device.clone(); - let buffer = 
self.create_read_buffer(binding); - - Box::pin(async move { - let buffer_slice = buffer.slice(..); - let (sender, receiver) = async_channel::bounded(1); - - buffer_slice.map_async(wgpu::MapMode::Read, move |v| { - sender - .try_send(v) - .expect("Unable to send buffer slice result to async channel.") - }); - - device.poll(wgpu::Maintain::Wait); - - let result = receiver - .recv() - .await - .expect("Unable to receive buffer slice result."); - - if let Ok(()) = result { - let data = buffer_slice.get_mapped_range(); - let result = bytemuck::cast_slice(&data).to_vec(); - - drop(data); - buffer.unmap(); - result - } else { - panic!("Unable to read buffer {:?}", result) - } - }) - } - - fn get_resource( - &mut self, - binding: server::Binding, - ) -> ::Resource { - self.memory_management.get(binding.memory) - } - - /// When we create a new handle from existing data, we use custom allocations so that we don't - /// have to execute the current pending tasks. - /// - /// This is important, otherwise the compute passes are going to be too small and we won't be able to - /// fully utilize the GPU. - fn create(&mut self, data: &[u8]) -> server::Handle { - let handle = server::Handle::new(self.memory_management.reserve(data.len(), || { - flush_tasks( - &mut self.encoder, - &self.queue, - &self.device, - &mut self.tasks_count, - &mut self.staging_belt, - ); - self.device.poll(wgpu::Maintain::Wait); - })); - - let non_zero_len = NonZeroU64::new(data.len() as u64); - - // If there's nothing to copy, don't need to do any work here. - if let Some(len) = non_zero_len { - let binding = handle.clone().binding(); - let resource = self.memory_management.get(binding.memory); - - if data.len() < SMALL_ALLOC_SIZE { - // Use a staging belt if the allocation is small enough. This is faster than allocating a new buffer. - // Ideally, we could use queue.write_buffer_with(), which seems to be the recommended method for performance, - // but that doesn't seem to work, as we might re-use a buffer multiple times, and need to schedule this - // precisely in the encoder. 
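The comment above weighs wgpu's `StagingBelt` against `queue.write_buffer_with()` for small uploads. A minimal sketch of the belt's full submit cycle in plain wgpu, assuming an already-initialized `device`/`queue` pair, since the ordering of `write_buffer`, `finish`, and `recall` is easy to get wrong:

use std::num::NonZeroU64;
use wgpu::util::StagingBelt;

fn upload_and_submit(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    belt: &mut StagingBelt,
    target: &wgpu::Buffer,
    data: &[u8],
) {
    let mut encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });

    if let Some(len) = NonZeroU64::new(data.len() as u64) {
        // Borrows a mapped chunk from the belt; the GPU-side copy into
        // `target` is recorded on `encoder`.
        belt.write_buffer(&mut encoder, target, 0, len, device)
            .copy_from_slice(data);
    }

    // `finish` must come before submission and `recall` after, exactly the
    // order `flush_tasks` uses further down in this file.
    belt.finish();
    queue.submit(Some(encoder.finish()));
    belt.recall();
}
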
- let mut write_buf = self.staging_belt.write_buffer( - &mut self.encoder, - &resource.buffer, - resource.offset(), - len, - &self.device, - ); - write_buf.copy_from_slice(data); - } else { - let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor { - label: Some("Buffer Src"), - contents: data, - usage: wgpu::BufferUsages::COPY_SRC, - })); - self.encoder.copy_buffer_to_buffer( - &buffer_src, - 0, - &resource.buffer, - resource.offset(), - buffer_src.size(), - ); - } - self.tasks_count += 1; - } - - handle - } - - fn empty(&mut self, size: usize) -> server::Handle { - server::Handle::new(self.memory_management.reserve(size, || { - flush_tasks( - &mut self.encoder, - &self.queue, - &self.device, - &mut self.tasks_count, - &mut self.staging_belt, - ); - self.device.poll(wgpu::Maintain::Wait); - })) - } - - fn execute( - &mut self, - kernel: Self::Kernel, - count: Self::DispatchOptions, - bindings: Vec>, - ) { - let pipeline = self.pipeline(kernel); - let group_layout = pipeline.get_bind_group_layout(0); - - let memory_handles = bindings - .into_iter() - .map(|binding| self.memory_management.get(binding.memory)) - .collect::>(); - - let entries = memory_handles - .iter() - .enumerate() - .map(|(i, buffer)| wgpu::BindGroupEntry { - binding: i as u32, - resource: buffer.as_binding(), - }) - .collect::>(); - - let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { - label: None, - layout: &group_layout, - entries: &entries, - }); - - self.register_compute(pipeline, bind_group, count); - - if self.tasks_count >= self.tasks_max { - self.sync(SyncType::Flush); - } - } - - fn sync(&mut self, sync_type: SyncType) { - flush_tasks( - &mut self.encoder, - &self.queue, - &self.device, - &mut self.tasks_count, - &mut self.staging_belt, - ); - - // Cleanup allocations and deallocations. - self.memory_management.storage().perform_deallocations(); - - if sync_type == SyncType::Wait { - self.device.poll(wgpu::Maintain::Wait); - } - } -} - -/// Flush tasks using the [command encoder](CommandEncoder). -/// -/// This implementation is decoupled from both the [server](WgpuServer) and [memory management](MemoryManagement). -/// This decoupling allows for safe usage within sync callbacks when allocating memory buffers. -fn flush_tasks( - encoder: &mut CommandEncoder, - queue: &wgpu::Queue, - device: &wgpu::Device, - tasks_count: &mut usize, - staging_belt: &mut StagingBelt, -) { - staging_belt.finish(); - - let mut new_encoder = - device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); - core::mem::swap(&mut new_encoder, encoder); - - queue.submit(Some(new_encoder.finish())); - *tasks_count = 0; - staging_belt.recall(); -} diff --git a/crates/burn-wgpu/src/compute/storage.rs b/crates/burn-wgpu/src/compute/storage.rs deleted file mode 100644 index a2733b806e..0000000000 --- a/crates/burn-wgpu/src/compute/storage.rs +++ /dev/null @@ -1,146 +0,0 @@ -use burn_compute::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; -use hashbrown::HashMap; -use std::{num::NonZeroU64, sync::Arc}; - -/// Buffer storage for wgpu. -pub struct WgpuStorage { - memory: HashMap>, - deallocations: Vec, - device: Arc, - queue: Arc, -} - -impl core::fmt::Debug for WgpuStorage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) - } -} - -/// The memory resource that can be allocated for wgpu. 
-#[derive(new, Debug)] -pub struct WgpuResource { - /// The wgpu buffer. - pub buffer: Arc, - /// How the resource is used. - pub kind: WgpuResourceKind, -} - -impl WgpuResource { - /// Return the binding view of the buffer. - pub fn as_binding(&self) -> wgpu::BindingResource { - let binding = match &self.kind { - WgpuResourceKind::Full => self.buffer.as_entire_buffer_binding(), - WgpuResourceKind::Slice(offs, size) => wgpu::BufferBinding { - buffer: &self.buffer, - offset: *offs, - size: Some(*size), - }, - }; - wgpu::BindingResource::Buffer(binding) - } - - /// Return the buffer size. - pub fn size(&self) -> u64 { - match self.kind { - WgpuResourceKind::Full => self.buffer.size(), - WgpuResourceKind::Slice(_, size) => size.get(), - } - } - - /// Return the buffer offset. - pub fn offset(&self) -> u64 { - match self.kind { - WgpuResourceKind::Full => 0, - WgpuResourceKind::Slice(offset, _) => offset, - } - } -} - -/// How the resource is used, either as a slice or fully. -#[derive(Debug)] -pub enum WgpuResourceKind { - /// Represents an entire buffer. - Full, - /// A slice over a buffer. - Slice(wgpu::BufferAddress, wgpu::BufferSize), -} - -/// Keeps actual wgpu buffer references in a hashmap with ids as key. -impl WgpuStorage { - /// Create a new storage on the given [device](wgpu::Device). - pub fn new(device: Arc, queue: Arc) -> Self { - Self { - memory: HashMap::new(), - deallocations: Vec::new(), - device, - queue, - } - } - - /// Actually deallocates buffers tagged to be deallocated. - pub fn perform_deallocations(&mut self) { - for id in self.deallocations.drain(..) { - if let Some(buffer) = self.memory.remove(&id) { - buffer.destroy() - } - } - } -} - -impl ComputeStorage for WgpuStorage { - type Resource = WgpuResource; - - fn get(&mut self, handle: &StorageHandle) -> Self::Resource { - let buffer = self.memory.get(&handle.id).unwrap(); - - match handle.utilization { - StorageUtilization::Full(_) => { - WgpuResource::new(buffer.clone(), WgpuResourceKind::Full) - } - StorageUtilization::Slice { offset, size } => WgpuResource::new( - buffer.clone(), - WgpuResourceKind::Slice(offset as u64, NonZeroU64::new(size as u64).unwrap()), - ), - } - } - - fn alloc(&mut self, size: usize) -> StorageHandle { - let id = StorageId::new(); - let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor { - label: None, - size: size as u64, - usage: wgpu::BufferUsages::COPY_DST - | wgpu::BufferUsages::STORAGE - | wgpu::BufferUsages::COPY_SRC - | wgpu::BufferUsages::INDIRECT, - mapped_at_creation: false, - })); - - self.memory.insert(id.clone(), buffer); - - StorageHandle::new(id, StorageUtilization::Full(size)) - } - - fn dealloc(&mut self, id: StorageId) { - self.deallocations.push(id); - } - - fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) { - let mut encoder = self - .device - .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); - - let from = self.get(from); - let to = self.get(to); - - encoder.copy_buffer_to_buffer( - &from.buffer, - from.offset(), - &to.buffer, - to.offset(), - to.size(), - ); - - self.queue.submit(Some(encoder.finish())); - } -} diff --git a/crates/burn-wgpu/src/device.rs b/crates/burn-wgpu/src/device.rs deleted file mode 100644 index 50f1fca1a9..0000000000 --- a/crates/burn-wgpu/src/device.rs +++ /dev/null @@ -1,56 +0,0 @@ -/// The device struct when using the `wgpu` backend. -/// -/// Note that you need to provide the device index when using a GPU backend. 
-/// -/// # Example -/// -/// ```no_run -/// use burn_wgpu::WgpuDevice; -/// -/// let device_gpu_1 = WgpuDevice::DiscreteGpu(0); // First discrete GPU found. -/// let device_gpu_2 = WgpuDevice::DiscreteGpu(1); // Second discrete GPU found. -/// ``` -#[derive(Clone, Debug, Hash, PartialEq, Eq)] -pub enum WgpuDevice { - /// Discrete GPU with the given index. The index is the index of the discrete GPU in the list - /// of all discrete GPUs found on the system. - DiscreteGpu(usize), - - /// Integrated GPU with the given index. The index is the index of the integrated GPU in the - /// list of all integrated GPUs found on the system. - IntegratedGpu(usize), - - /// Virtual GPU with the given index. The index is the index of the virtual GPU in the list of - /// all virtual GPUs found on the system. - VirtualGpu(usize), - - /// CPU. - Cpu, - - /// The best available device found with the current [graphics API](crate::GraphicsApi). - /// - /// Priority - /// - /// 1. DiscreteGpu - /// 2. IntegratedGpu - /// 3. VirtualGpu - /// 4. Cpu - /// - /// # Notes - /// - /// A device might be identified as [Other](wgpu::DeviceType::Other) by [wgpu](wgpu), in this case, we chose this device over - /// `IntegratedGpu` since it's often a discrete GPU. - BestAvailable, - - /// Use an externally created, existing, wgpu setup. This is helpful when using Burn in conjunction - /// with some existing wgpu setup (eg. egui or bevy), as resources can be transferred in & out of Burn. - /// - /// The device is indexed by the global wgpu [adapter ID](wgpu::Device::global_id). - Existing(wgpu::Id), -} - -impl Default for WgpuDevice { - fn default() -> Self { - Self::BestAvailable - } -} diff --git a/crates/burn-wgpu/src/element.rs b/crates/burn-wgpu/src/element.rs deleted file mode 100644 index 9d13b8060b..0000000000 --- a/crates/burn-wgpu/src/element.rs +++ /dev/null @@ -1,35 +0,0 @@ -use burn_jit::JitElement; - -use crate::compiler::wgsl; - -/// The base element trait for the wgpu backend. -pub trait WgpuElement: JitElement { - fn wgpu_elem() -> wgsl::Elem; -} - -/// The float element type for the wgpu backend. -pub trait FloatElement: WgpuElement + burn_jit::FloatElement {} - -/// The int element type for the wgpu backend. -pub trait IntElement: WgpuElement + burn_jit::IntElement {} - -impl WgpuElement for u32 { - fn wgpu_elem() -> wgsl::Elem { - wgsl::Elem::U32 - } -} - -impl WgpuElement for i32 { - fn wgpu_elem() -> wgsl::Elem { - wgsl::Elem::I32 - } -} - -impl WgpuElement for f32 { - fn wgpu_elem() -> wgsl::Elem { - wgsl::Elem::F32 - } -} - -impl FloatElement for f32 {} -impl IntElement for i32 {} diff --git a/crates/burn-wgpu/src/graphics.rs b/crates/burn-wgpu/src/graphics.rs deleted file mode 100644 index 470e39ffc5..0000000000 --- a/crates/burn-wgpu/src/graphics.rs +++ /dev/null @@ -1,96 +0,0 @@ -/// The basic trait to specify which graphics API to use as Backend. -/// -/// Options are: -/// - [Vulkan](Vulkan) -/// - [Metal](Metal) -/// - [OpenGL](OpenGl) -/// - [DirectX 12](Dx12) -/// - [WebGpu](WebGpu) -pub trait GraphicsApi: Send + Sync + core::fmt::Debug + Default + Clone + 'static { - /// The wgpu backend. - fn backend() -> wgpu::Backend; -} - -/// Vulkan graphics API. -#[derive(Default, Debug, Clone)] -pub struct Vulkan; - -/// Metal graphics API. -#[derive(Default, Debug, Clone)] -pub struct Metal; - -/// OpenGL graphics API. -#[derive(Default, Debug, Clone)] -pub struct OpenGl; - -/// DirectX 12 graphics API. -#[derive(Default, Debug, Clone)] -pub struct Dx12; - -/// WebGpu graphics API. 
-#[derive(Default, Debug, Clone)] -pub struct WebGpu; - -/// Automatic graphics API based on OS. -#[derive(Default, Debug, Clone)] -pub struct AutoGraphicsApi; - -impl GraphicsApi for Vulkan { - fn backend() -> wgpu::Backend { - wgpu::Backend::Vulkan - } -} - -impl GraphicsApi for Metal { - fn backend() -> wgpu::Backend { - wgpu::Backend::Metal - } -} - -impl GraphicsApi for OpenGl { - fn backend() -> wgpu::Backend { - wgpu::Backend::Gl - } -} - -impl GraphicsApi for Dx12 { - fn backend() -> wgpu::Backend { - wgpu::Backend::Dx12 - } -} - -impl GraphicsApi for WebGpu { - fn backend() -> wgpu::Backend { - wgpu::Backend::BrowserWebGpu - } -} - -impl GraphicsApi for AutoGraphicsApi { - fn backend() -> wgpu::Backend { - // Allow overriding AutoGraphicsApi backend with ENV var in std test environments - #[cfg(not(no_std))] - #[cfg(test)] - if let Ok(backend_str) = std::env::var("AUTO_GRAPHICS_BACKEND") { - match backend_str.to_lowercase().as_str() { - "metal" => return wgpu::Backend::Metal, - "vulkan" => return wgpu::Backend::Vulkan, - "dx12" => return wgpu::Backend::Dx12, - "opengl" => return wgpu::Backend::Gl, - "webgpu" => return wgpu::Backend::BrowserWebGpu, - _ => { - eprintln!( - "Invalid graphics backend specified in GRAPHICS_BACKEND environment \ - variable" - ); - std::process::exit(1); - } - } - } - - // In a no_std environment or if the environment variable is not set - #[cfg(target_os = "macos")] - return wgpu::Backend::Metal; - #[cfg(not(target_os = "macos"))] - return wgpu::Backend::Vulkan; - } -} diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index 1b94e7b829..8b49ec7e50 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -1,34 +1,19 @@ -#[macro_use] -extern crate derive_new; - extern crate alloc; -mod compiler; -mod compute; -mod device; -mod element; -mod graphics; -mod runtime; - -#[cfg(feature = "template")] -pub use burn_cube::ir::CubeDim; #[cfg(feature = "template")] pub use burn_jit::{ kernel::{into_contiguous, Kernel}, - kernel_wgsl, + kernel_source, template::{build_info, KernelSource, SourceKernel, SourceTemplate}, }; -pub use device::*; -pub use element::*; -pub use graphics::*; -pub use runtime::*; - -pub use burn_cube::prelude::CubeCount; pub use burn_jit::{tensor::JitTensor, JitBackend}; +pub use burn_jit::{FloatElement, IntElement}; +pub use cubecl::ir::CubeDim; +pub use cubecl::wgpu::*; #[cfg(feature = "fusion")] -/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders. +/// Tensor backend that uses the wgpu crate for executing GPU compute shaders. /// /// This backend can target multiple graphics APIs, including: /// - [Vulkan] on Linux, Windows, and Android. @@ -54,15 +39,15 @@ pub use burn_jit::{tensor::JitTensor, JitBackend}; /// /// # Notes /// -/// This version of the [wgpu] backend uses [burn_fusion] to compile and optimize streams of tensor +/// This version of the wgpu backend uses [burn_fusion] to compile and optimize streams of tensor /// operations for improved performance. /// /// You can disable the `fusion` feature flag to remove that functionality, which might be /// necessary on `wasm` for now. -pub type Wgpu = burn_fusion::Fusion>; +pub type Wgpu = burn_fusion::Fusion>; #[cfg(not(feature = "fusion"))] -/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders. +/// Tensor backend that uses the wgpu crate for executing GPU compute shaders. /// /// This backend can target multiple graphics APIs, including: /// - [Vulkan] on Linux, Windows, and Android. 
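After this change, `WgpuRuntime`, `WgpuDevice`, and the graphics-API markers all reach users through the `pub use cubecl::wgpu::*;` re-export, so most downstream imports are unchanged. A minimal sketch of constructing the backend post-migration; the tensor values are illustrative:

use burn_wgpu::{Wgpu, WgpuDevice};

fn main() {
    // `Wgpu` is now an alias over `JitBackend<cubecl::wgpu::WgpuRuntime, F, I>`
    // (wrapped in `burn_fusion::Fusion` when the `fusion` feature is enabled),
    // but callers still name it exactly as before.
    type Backend = Wgpu<f32, i32>;

    let device = WgpuDevice::default();
    let tensor = burn_tensor::Tensor::<Backend, 1>::from_floats([-1.0, 0.0, 1.0], &device);
    println!("{tensor}");
}
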
@@ -88,7 +73,7 @@ pub type Wgpu = burn_fusion::Fusion = JitBackend; #[cfg(test)] mod tests { - use super::*; - - pub type TestRuntime = crate::WgpuRuntime; + use burn_jit::JitBackend; + pub type TestRuntime = cubecl::wgpu::WgpuRuntime; burn_jit::testgen_all!(); - burn_cube::testgen_all!(); } diff --git a/crates/burn-wgpu/src/runtime.rs b/crates/burn-wgpu/src/runtime.rs deleted file mode 100644 index 161a4802cc..0000000000 --- a/crates/burn-wgpu/src/runtime.rs +++ /dev/null @@ -1,338 +0,0 @@ -use crate::{ - compiler::wgsl, - compute::{WgpuServer, WgpuStorage}, - AutoGraphicsApi, GraphicsApi, WgpuDevice, -}; -use alloc::sync::Arc; -use burn_common::stub::RwLock; -use burn_compute::{ - channel::MutexComputeChannel, - client::ComputeClient, - memory_management::dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions}, - tune::Tuner, - ComputeRuntime, -}; -use burn_cube::{Feature, FeatureSet, Runtime}; -use burn_jit::JitRuntime; -use burn_tensor::backend::{DeviceId, DeviceOps}; -use wgpu::{AdapterInfo, DeviceDescriptor}; - -/// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend. -/// For advanced configuration, use [`init_sync`] to pass in runtime options or to select a -/// specific graphics API. -#[derive(Debug)] -pub struct WgpuRuntime {} - -impl JitRuntime for WgpuRuntime { - type JitDevice = WgpuDevice; - type JitServer = WgpuServer>; -} - -/// The compute instance is shared across all [wgpu runtimes](WgpuRuntime). -static RUNTIME: ComputeRuntime> = - ComputeRuntime::new(); - -type Server = WgpuServer>; - -impl Runtime for WgpuRuntime { - type Compiler = wgsl::WgslCompiler; - type Server = WgpuServer>; - - type Channel = MutexComputeChannel>>; - type Device = WgpuDevice; - - fn client(device: &Self::Device) -> ComputeClient { - RUNTIME.client(device, move || { - let (adapter, device_wgpu, queue) = - pollster::block_on(create_wgpu_setup::(device)); - create_client(adapter, device_wgpu, queue, RuntimeOptions::default()) - }) - } - - fn name() -> &'static str { - "wgpu" - } -} - -impl DeviceOps for WgpuDevice { - fn id(&self) -> DeviceId { - match self { - WgpuDevice::DiscreteGpu(index) => DeviceId::new(0, *index as u32), - WgpuDevice::IntegratedGpu(index) => DeviceId::new(1, *index as u32), - WgpuDevice::VirtualGpu(index) => DeviceId::new(2, *index as u32), - WgpuDevice::Cpu => DeviceId::new(3, 0), - WgpuDevice::BestAvailable => DeviceId::new(4, 0), - // For an existing device, use the 64 bit wgpu device ID as the burn DeviceID. - // We're only storing 32 bits, so wrap the the 64 bit value to 32 bits. This - // might collide - but a 1 in 4 billion chance seems ok given there's only a few - // devices in flight at any time. - WgpuDevice::Existing(id) => DeviceId::new(5, (id.inner() % (u32::MAX as u64)) as u32), - } - } -} - -/// The values that control how a WGPU Runtime will perform its calculations. -pub struct RuntimeOptions { - /// Control the amount of compute tasks to be aggregated into a single GPU command. 
- pub tasks_max: usize, -} - -impl Default for RuntimeOptions { - fn default() -> Self { - #[cfg(test)] - const DEFAULT_MAX_TASKS: usize = 1; - #[cfg(not(test))] - const DEFAULT_MAX_TASKS: usize = 16; - - let tasks_max = match std::env::var("BURN_WGPU_MAX_TASKS") { - Ok(value) => value - .parse::() - .expect("BURN_WGPU_MAX_TASKS should be a positive integer."), - Err(_) => DEFAULT_MAX_TASKS, - }; - - Self { tasks_max } - } -} - -pub fn init_existing_device( - adapter: Arc, - device: Arc, - queue: Arc, - options: RuntimeOptions, -) -> WgpuDevice { - let device_id = WgpuDevice::Existing(device.as_ref().global_id()); - let client = create_client(adapter, device, queue, options); - RUNTIME.register(&device_id, client); - device_id -} - -/// Initialize a client on the given device with the given options. This function is useful to configure the runtime options -/// or to pick a different graphics API. On wasm, it is necessary to use [`init_async`] instead. -pub fn init_sync(device: &WgpuDevice, options: RuntimeOptions) { - pollster::block_on(init_async::(device, options)); -} - -/// Like [`init_sync`], but async, necessary for wasm. -pub async fn init_async(device: &WgpuDevice, options: RuntimeOptions) { - let (adapter, device_wgpu, queue) = create_wgpu_setup::(device).await; - let client = create_client(adapter, device_wgpu, queue, options); - RUNTIME.register(device, client) -} - -async fn create_wgpu_setup( - device: &WgpuDevice, -) -> (Arc, Arc, Arc) { - let (device_wgpu, queue, adapter) = select_device::(device).await; - - log::info!( - "Created wgpu compute server on device {:?} => {:?}", - device, - adapter.get_info() - ); - (Arc::new(adapter), Arc::new(device_wgpu), Arc::new(queue)) -} - -fn create_client( - adapter: Arc, - device_wgpu: Arc, - queue: Arc, - options: RuntimeOptions, -) -> ComputeClient< - WgpuServer>, - MutexComputeChannel>>, -> { - let limits = device_wgpu.limits(); - let storage = WgpuStorage::new(device_wgpu.clone(), queue.clone()); - let memory_management = DynamicMemoryManagement::new( - storage, - DynamicMemoryManagementOptions::preset( - limits.max_storage_buffer_binding_size as usize, - limits.min_storage_buffer_offset_alignment as usize, - ), - ); - let server = WgpuServer::new(memory_management, device_wgpu, queue, options.tasks_max); - let channel = MutexComputeChannel::new(server); - let tuner_device_id = tuner_device_id(adapter.get_info()); - - let features = adapter.features(); - let mut features_cube = FeatureSet::default(); - - if features.contains(wgpu::Features::SUBGROUP) { - features_cube.register(Feature::Subcube); - } - - ComputeClient::new( - channel, - Arc::new(RwLock::new(Tuner::new("wgpu", &tuner_device_id))), - Arc::new(features_cube), - ) -} - -/// Select the wgpu device and queue based on the provided [device](WgpuDevice). 
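`RuntimeOptions` and the `init_sync`/`init_async` entry points deleted above now live in cubecl and remain reachable through the `burn_wgpu` re-export. A minimal sketch of pinning the task batch size in code instead of through the `BURN_WGPU_MAX_TASKS` environment variable, written against the signatures shown in this file; `Vulkan` is just an illustrative choice of graphics API:

use burn_wgpu::{init_sync, RuntimeOptions, Vulkan, WgpuDevice};

fn main() {
    let device = WgpuDevice::DiscreteGpu(0);
    // Flush the command encoder after every 32 registered tasks rather than
    // the default of 16; larger batches mean fewer, bigger submissions.
    init_sync::<Vulkan>(&device, RuntimeOptions { tasks_max: 32 });
}
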
-pub async fn select_device( - device: &WgpuDevice, -) -> (wgpu::Device, wgpu::Queue, wgpu::Adapter) { - #[cfg(target_family = "wasm")] - let adapter = select_adapter::(device).await; - - #[cfg(not(target_family = "wasm"))] - let adapter = select_adapter::(device); - let limits = adapter.limits(); - - let (device, queue) = adapter - .request_device( - &DeviceDescriptor { - label: None, - required_features: adapter.features(), - required_limits: limits, - }, - None, - ) - .await - .map_err(|err| { - format!( - "Unable to request the device with the adapter {:?}, err {:?}", - adapter.get_info(), - err - ) - }) - .unwrap(); - - (device, queue, adapter) -} - -fn tuner_device_id(info: AdapterInfo) -> String { - format!("wgpu-{}-{}", info.device, info.backend.to_str()) -} - -#[cfg(target_family = "wasm")] -async fn select_adapter(_device: &WgpuDevice) -> wgpu::Adapter { - let instance = wgpu::Instance::default(); - - instance - .request_adapter(&wgpu::RequestAdapterOptionsBase::default()) - .await - .unwrap() -} - -#[cfg(not(target_family = "wasm"))] -fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { - use wgpu::DeviceType; - - let instance = wgpu::Instance::default(); - let mut adapters_other = Vec::new(); - let mut adapters = Vec::new(); - - instance - .enumerate_adapters(G::backend().into()) - .into_iter() - .for_each(|adapter| { - let device_type = adapter.get_info().device_type; - - if let DeviceType::Other = device_type { - adapters_other.push(adapter); - return; - } - - let is_same_type = match device { - WgpuDevice::DiscreteGpu(_) => device_type == DeviceType::DiscreteGpu, - WgpuDevice::IntegratedGpu(_) => device_type == DeviceType::IntegratedGpu, - WgpuDevice::VirtualGpu(_) => device_type == DeviceType::VirtualGpu, - WgpuDevice::Cpu => device_type == DeviceType::Cpu, - WgpuDevice::BestAvailable => true, - WgpuDevice::Existing(_) => { - unreachable!("Cannot select an adapter for an existing device.") - } - }; - - if is_same_type { - adapters.push(adapter); - } - }); - - fn select( - num: usize, - error: &str, - mut adapters: Vec, - mut adapters_other: Vec, - ) -> wgpu::Adapter { - if adapters.len() <= num { - if adapters_other.len() <= num { - panic!( - "{}, adapters {:?}, other adapters {:?}", - error, - adapters - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - adapters_other - .into_iter() - .map(|adapter| adapter.get_info()) - .collect::>(), - ); - } - - return adapters_other.remove(num); - } - - adapters.remove(num) - } - - let adapter = match device { - WgpuDevice::DiscreteGpu(num) => select( - *num, - "No Discrete GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::IntegratedGpu(num) => select( - *num, - "No Integrated GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::VirtualGpu(num) => select( - *num, - "No Virtual GPU device found", - adapters, - adapters_other, - ), - WgpuDevice::Cpu => select(0, "No CPU device found", adapters, adapters_other), - WgpuDevice::BestAvailable => { - let mut most_performant_adapter = None; - let mut current_score = -1; - - adapters - .into_iter() - .chain(adapters_other) - .for_each(|adapter| { - let info = adapter.get_info(); - let score = match info.device_type { - DeviceType::DiscreteGpu => 5, - DeviceType::Other => 4, // Let's be optimistic with the Other device, it's - // often a Discrete Gpu. 
- DeviceType::IntegratedGpu => 3, - DeviceType::VirtualGpu => 2, - DeviceType::Cpu => 1, - }; - - if score > current_score { - most_performant_adapter = Some(adapter); - current_score = score; - } - }); - - if let Some(adapter) = most_performant_adapter { - adapter - } else { - panic!("No adapter found for graphics API {:?}", G::default()); - } - } - WgpuDevice::Existing(_) => unreachable!("Cannot select an adapter for an existing device."), - }; - - log::info!("Using adapter {:?}", adapter.get_info()); - - adapter -} diff --git a/examples/custom-wgpu-kernel/Cargo.toml b/examples/custom-wgpu-kernel/Cargo.toml index 5eebcdde49..67e3f438c1 100644 --- a/examples/custom-wgpu-kernel/Cargo.toml +++ b/examples/custom-wgpu-kernel/Cargo.toml @@ -13,6 +13,7 @@ burn = { path = "../../crates/burn", default-features = false, features = [ "autotune", "template", ] } +cubecl = { workspace = true, features = ["wgpu"] } # Serialization log = { workspace = true } diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs index 84a9826608..5bff5bd1d2 100644 --- a/examples/custom-wgpu-kernel/src/forward.rs +++ b/examples/custom-wgpu-kernel/src/forward.rs @@ -3,16 +3,17 @@ use crate::FloatTensor; use super::Backend; use burn::{ backend::wgpu::{ - build_info, into_contiguous, kernel_wgsl, CubeCount, CubeDim, FloatElement, IntElement, - JitBackend, JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime, + build_info, into_contiguous, kernel_source, FloatElement, IntElement, JitBackend, + JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime, }, tensor::Shape, }; +use cubecl::{CubeCount, CubeDim}; use derive_new::new; use std::marker::PhantomData; // Source the kernel written in WGSL. -kernel_wgsl!(FusedMatmulAddReluRaw, "./kernel.wgsl"); +kernel_source!(FusedMatmulAddReluRaw, "./kernel.wgsl"); // Define our kernel type with cube information. 
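The `kernel_wgsl!` macro is renamed to `kernel_source!` with identical usage: it declares a unit struct whose `KernelSource` impl embeds the templated WGSL file at compile time. A minimal sketch with an illustrative kernel name and file path:

use burn::backend::wgpu::{kernel_source, KernelSource, SourceTemplate};

// Declares a `ScaleRaw` type whose `KernelSource::source()` returns the
// templated contents of the WGSL file.
kernel_source!(ScaleRaw, "./scale.wgsl");

Placeholders such as `{{ elem }}` and `{{ workgroup_size_x }}` in the WGSL source are substituted through the `SourceTemplate` before the shader is compiled, as the `kernel.wgsl` hunk below illustrates.
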
#[derive(new, Debug)] diff --git a/examples/custom-wgpu-kernel/src/kernel.wgsl b/examples/custom-wgpu-kernel/src/kernel.wgsl index 9b41c73821..00871fe3bc 100644 --- a/examples/custom-wgpu-kernel/src/kernel.wgsl +++ b/examples/custom-wgpu-kernel/src/kernel.wgsl @@ -1,14 +1,14 @@ @group(0) @binding(0) -var lhs: array<{{ elem }}>; +var lhs: array<{{ elem }}>; @group(0) @binding(1) -var rhs: array<{{ elem }}>; +var rhs: array<{{ elem }}>; @group(0) @binding(2) -var bias: array<{{ elem }}>; +var bias: array<{{ elem }}>; @group(0) @binding(3) @@ -16,7 +16,7 @@ var output: array<{{ elem }}>; @group(0) @binding(4) -var info: array; +var info: array; const BLOCK_SIZE = {{ workgroup_size_x }}u; diff --git a/examples/gelu/Cargo.toml b/examples/gelu/Cargo.toml deleted file mode 100644 index c906548dd1..0000000000 --- a/examples/gelu/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -authors = [] -name = "gelu" -publish = false -edition.workspace = true -license.workspace = true -version.workspace = true - -[features] -default = ["wgpu"] -cuda = ["burn-cuda"] -wgpu = ["burn-wgpu"] - -[dependencies] -burn-cube = { path = "../../crates/burn-cube", version = "0.14.0" } -burn-cuda = { path = "../../crates/burn-cuda", version = "0.14.0", optional = true } -burn-wgpu = { path = "../../crates/burn-wgpu", version = "0.14.0", optional = true } diff --git a/examples/gelu/examples/gelu.rs b/examples/gelu/examples/gelu.rs deleted file mode 100644 index 374bfa3966..0000000000 --- a/examples/gelu/examples/gelu.rs +++ /dev/null @@ -1,6 +0,0 @@ -fn main() { - #[cfg(feature = "cuda")] - gelu::launch::(&Default::default()); - #[cfg(feature = "wgpu")] - gelu::launch::(&Default::default()); -} diff --git a/examples/gelu/src/lib.rs b/examples/gelu/src/lib.rs deleted file mode 100644 index a4c7d8040a..0000000000 --- a/examples/gelu/src/lib.rs +++ /dev/null @@ -1,36 +0,0 @@ -use burn_cube::prelude::*; - -#[cube(launch)] -fn gelu(input: &Array, output: &mut Array) { - if ABSOLUTE_POS < input.len() { - output[ABSOLUTE_POS] = gelu_scalar::(input[ABSOLUTE_POS]); - } -} - -#[cube] -fn gelu_scalar(x: F) -> F { - x * (F::new(1.0) + F::erf(x / F::sqrt(F::new(2.0)))) / F::new(2.0) -} - -pub fn launch(device: &R::Device) { - let client = R::client(device); - println!("Executing gelu with runtime {:?}", R::name()); - - let input = &[-1., 0., 1., 5.]; - let input_handle = client.create(f32::as_bytes(input)); - let output_handle = client.empty(input.len() * core::mem::size_of::()); - - gelu::launch::( - client.clone(), - CubeCount::Static(1, 1, 1), - CubeDim::default(), - ArrayArg::new(&input_handle, input.len()), - ArrayArg::new(&output_handle, input.len()), - ); - - let output = client.read(output_handle.binding()); - let output = f32::from_bytes(&output); - - // Should be [-0.1587, 0.0000, 0.8413, 5.0000] - println!("{output:?}"); -} diff --git a/examples/image-classification-web/Cargo.toml b/examples/image-classification-web/Cargo.toml index 652141ff22..87d4e06364 100644 --- a/examples/image-classification-web/Cargo.toml +++ b/examples/image-classification-web/Cargo.toml @@ -17,6 +17,7 @@ half_precision = [] burn = { path = "../../crates/burn", version = "0.14.0", default-features = false, features = [ "ndarray", ] } +cubecl = { workspace = true, features = ["wgpu"] } burn-wgpu = { path = "../../crates/burn-wgpu", version = "0.14.0", default-features = false, features = [ "autotune", ] } diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs index 8ca6cb660b..84bc75194d 100644 --- 
a/examples/image-classification-web/src/web.rs +++ b/examples/image-classification-web/src/web.rs @@ -11,7 +11,8 @@ use crate::model::{label::LABELS, normalizer::Normalizer, squeezenet::Model as S use burn::{backend::NdArray, prelude::*, tensor::activation::softmax}; use burn_candle::Candle; -use burn_wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice}; +use burn_wgpu::{Wgpu, WgpuDevice}; +use cubecl::wgpu::{init_async, AutoGraphicsApi}; use serde::Serialize; use wasm_bindgen::prelude::*; diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index 6fe1ae2953..c7bca549ff 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -84,13 +84,10 @@ mod tch_cpu { #[cfg(feature = "wgpu")] mod wgpu { use crate::{launch, ElemType}; - use burn::backend::{ - wgpu::{Wgpu, WgpuDevice}, - Autodiff, - }; + use burn::backend::{wgpu::Wgpu, Autodiff}; pub fn run() { - launch::>>(vec![WgpuDevice::default()]); + launch::>>(vec![Default::default()]); } } diff --git a/xtask/src/runchecks.rs b/xtask/src/runchecks.rs index 850b07c81e..78f04fdbb8 100644 --- a/xtask/src/runchecks.rs +++ b/xtask/src/runchecks.rs @@ -232,10 +232,6 @@ fn no_std_checks() { // Run checks for the following crates build_and_test_no_std("burn", []); build_and_test_no_std("burn-core", []); - build_and_test_no_std( - "burn-compute", - ["--features", "channel-mutex,storage-bytes"], - ); build_and_test_no_std("burn-common", []); build_and_test_no_std("burn-tensor", []); build_and_test_no_std("burn-ndarray", []);
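The deleted `gelu` example above was the in-tree showcase for the kernel DSL that this patch moves out to cubecl. For reference, a sketch of the same kernel expressed against the extracted `cubecl` 0.1 API, with the generic parameters that the deleted listing elided restored; launch signature details may differ slightly in later cubecl releases:

use cubecl::prelude::*;

#[cube(launch)]
fn gelu<F: Float>(input: &Array<F>, output: &mut Array<F>) {
    if ABSOLUTE_POS < input.len() {
        output[ABSOLUTE_POS] = gelu_scalar::<F>(input[ABSOLUTE_POS]);
    }
}

#[cube]
fn gelu_scalar<F: Float>(x: F) -> F {
    x * (F::new(1.0) + F::erf(x / F::sqrt(F::new(2.0)))) / F::new(2.0)
}

pub fn launch<R: Runtime>(device: &R::Device) {
    let client = R::client(device);
    println!("Executing gelu with runtime {:?}", R::name());

    let input = &[-1., 0., 1., 5.];
    let input_handle = client.create(f32::as_bytes(input));
    let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());

    gelu::launch::<F32, R>(
        client.clone(),
        CubeCount::Static(1, 1, 1),
        CubeDim::default(),
        ArrayArg::new(&input_handle, input.len()),
        ArrayArg::new(&output_handle, input.len()),
    );

    let output = client.read(output_handle.binding());
    let output = f32::from_bytes(&output);

    // Should print [-0.1587, 0.0000, 0.8413, 5.0000].
    println!("{output:?}");
}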