From 80e67b7aef6d7f7003992a6c40d1c3ba72d2e2cb Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 6 Jan 2025 08:05:18 -0500 Subject: [PATCH 01/61] Bump libc from 0.2.168 to 0.2.169 (#2656) Bumps [libc](https://github.com/rust-lang/libc) from 0.2.168 to 0.2.169. - [Release notes](https://github.com/rust-lang/libc/releases) - [Changelog](https://github.com/rust-lang/libc/blob/0.2.169/CHANGELOG.md) - [Commits](https://github.com/rust-lang/libc/compare/0.2.168...0.2.169) --- updated-dependencies: - dependency-name: libc dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7b555924fc..3cdd163311 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3825,9 +3825,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8" [[package]] name = "libc" -version = "0.2.168" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aaeb2981e0606ca11d79718f8bb01164f1d6ed75080182d3abf017e6d244b6d" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libfuzzer-sys" diff --git a/Cargo.toml b/Cargo.toml index e72651e70c..7c15091c3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -143,7 +143,7 @@ serde = { version = "1.0.217", default-features = false, features = [ serde_json = { version = "1.0.134", default-features = false } uuid = { version = "1.11.0", default-features = false } -libc = "0.2.168" +libc = "0.2.169" nvml-wrapper = "0.10.0" sysinfo = "0.32.1" systemstat = "0.2.3" From e2fa9c45b3ccc75a5d97db1ca64e86a8e33367b0 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Mon, 6 Jan 2025 12:52:59 -0500 Subject: [PATCH 02/61] Update cubecl (#2662) --- Cargo.lock | 24 ++++++++++++------------ Cargo.toml | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3cdd163311..128196fae0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1637,7 +1637,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1669,7 +1669,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1686,7 +1686,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1705,7 +1705,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1719,7 +1719,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1735,7 +1735,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1761,7 +1761,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "bytemuck", "cubecl-core", @@ -1773,7 +1773,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1788,7 +1788,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1825,7 +1825,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "async-channel", "async-lock", @@ -1846,7 +1846,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1860,7 +1860,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=34af9342a2b4f8dcf1b0047afbea0f26405b92cf#34af9342a2b4f8dcf1b0047afbea0f26405b92cf" +source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 7c15091c3c..863943f0b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "34af9342a2b4f8dcf1b0047afbea0f26405b92cf" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "34af9342a2b4f8dcf1b0047afbea0f26405b92cf" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } From a644430b0a07f3113b401b75c0c379415428631d Mon Sep 17 00:00:00 2001 From: Tiago Sanona <40792244+tsanona@users.noreply.github.com> Date: Tue, 7 Jan 2025 21:08:27 +0100 Subject: [PATCH 03/61] add fscore, reorganize mod and change error in doc. (#2648) * add fscore, reorganize mod and change error in doc. * add function args to doc * fscore to fbetascore and formatting metrics doc in book * Minor doc fixes --------- Co-authored-by: Tiago Sanona Co-authored-by: Guillaume Lagrange --- burn-book/src/building-blocks/metric.md | 3 +- crates/burn-train/src/metric/fbetascore.rs | 195 +++++++++++++++++++++ crates/burn-train/src/metric/mod.rs | 41 +++-- crates/burn-train/src/metric/precision.rs | 4 +- crates/burn-train/src/metric/recall.rs | 6 +- 5 files changed, 227 insertions(+), 22 deletions(-) create mode 100644 crates/burn-train/src/metric/fbetascore.rs diff --git a/burn-book/src/building-blocks/metric.md b/burn-book/src/building-blocks/metric.md index e029aca708..e5dd4eaae9 100644 --- a/burn-book/src/building-blocks/metric.md +++ b/burn-book/src/building-blocks/metric.md @@ -4,11 +4,12 @@ When working with the learner, you have the option to record metrics that will b throughout the training process. We currently offer a restricted range of metrics. | Metric | Description | -|------------------|---------------------------------------------------------| +| ---------------- | ------------------------------------------------------- | | Accuracy | Calculate the accuracy in percentage | | TopKAccuracy | Calculate the top-k accuracy in percentage | | Precision | Calculate precision in percentage | | Recall | Calculate recall in percentage | +| FBetaScore | Calculate Fβ score in percentage | | AUROC | Calculate the area under curve of ROC in percentage | | Loss | Output the loss used for the backward pass | | CPU Temperature | Fetch the temperature of CPUs | diff --git a/crates/burn-train/src/metric/fbetascore.rs b/crates/burn-train/src/metric/fbetascore.rs new file mode 100644 index 0000000000..5eeba0aa9c --- /dev/null +++ b/crates/burn-train/src/metric/fbetascore.rs @@ -0,0 +1,195 @@ +use super::{ + classification::{ClassReduction, ClassificationMetricConfig, DecisionRule}, + confusion_stats::{ConfusionStats, ConfusionStatsInput}, + state::{FormatOptions, NumericMetricState}, + Metric, MetricEntry, MetricMetadata, Numeric, +}; +use burn_core::{ + prelude::{Backend, Tensor}, + tensor::cast::ToElement, +}; +use core::marker::PhantomData; +use std::num::NonZeroUsize; + +/// The [F-beta score](https://en.wikipedia.org/wiki/F-score) metric. +#[derive(Default)] +pub struct FBetaScoreMetric { + state: NumericMetricState, + _b: PhantomData, + config: ClassificationMetricConfig, + beta: f64, +} + +impl FBetaScoreMetric { + /// F-beta score metric for binary classification. + /// + /// # Arguments + /// + /// * `beta` - Positive real factor to weight recall's importance. + /// * `threshold` - The threshold to transform a probability into a binary prediction. + #[allow(dead_code)] + pub fn binary(beta: f64, threshold: f64) -> Self { + Self { + config: ClassificationMetricConfig { + decision_rule: DecisionRule::Threshold(threshold), + // binary classification results are the same independently of class_reduction + ..Default::default() + }, + beta, + ..Default::default() + } + } + + /// F-beta score metric for multiclass classification. + /// + /// # Arguments + /// + /// * `beta` - Positive real factor to weight recall's importance. + /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). + /// * `class_reduction` - [Class reduction](ClassReduction) type. + #[allow(dead_code)] + pub fn multiclass(beta: f64, top_k: usize, class_reduction: ClassReduction) -> Self { + Self { + config: ClassificationMetricConfig { + decision_rule: DecisionRule::TopK( + NonZeroUsize::new(top_k).expect("top_k must be non-zero"), + ), + class_reduction, + }, + beta, + ..Default::default() + } + } + + /// F-beta score metric for multi-label classification. + /// + /// # Arguments + /// + /// * `beta` - Positive real factor to weight recall's importance. + /// * `threshold` - The threshold to transform a probability into a binary prediction. + /// * `class_reduction` - [Class reduction](ClassReduction) type. + #[allow(dead_code)] + pub fn multilabel(beta: f64, threshold: f64, class_reduction: ClassReduction) -> Self { + Self { + config: ClassificationMetricConfig { + decision_rule: DecisionRule::Threshold(threshold), + class_reduction, + }, + beta, + ..Default::default() + } + } + + fn class_average(&self, mut aggregated_metric: Tensor) -> f64 { + use ClassReduction::{Macro, Micro}; + let avg_tensor = match self.config.class_reduction { + Micro => aggregated_metric, + Macro => { + if aggregated_metric.contains_nan().any().into_scalar() { + let nan_mask = aggregated_metric.is_nan(); + aggregated_metric = aggregated_metric + .clone() + .select(0, nan_mask.bool_not().argwhere().squeeze(1)) + } + aggregated_metric.mean() + } + }; + avg_tensor.into_scalar().to_f64() + } +} + +impl Metric for FBetaScoreMetric { + const NAME: &'static str = "FBetaScore"; + type Input = ConfusionStatsInput; + + fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry { + let [sample_size, _] = input.predictions.dims(); + + let cf_stats = ConfusionStats::new(input, &self.config); + let scaled_true_positive = cf_stats.clone().true_positive() * (1.0 + self.beta.powi(2)); + let metric = self.class_average( + scaled_true_positive.clone() + / (scaled_true_positive + + cf_stats.clone().false_negative() * self.beta.powi(2) + + cf_stats.false_positive()), + ); + + self.state.update( + 100.0 * metric, + sample_size, + FormatOptions::new(Self::NAME).unit("%").precision(2), + ) + } + + fn clear(&mut self) { + self.state.reset() + } +} + +impl Numeric for FBetaScoreMetric { + fn value(&self) -> f64 { + self.state.value() + } +} + +#[cfg(test)] +mod tests { + use super::{ + ClassReduction::{self, *}, + FBetaScoreMetric, Metric, MetricMetadata, Numeric, + }; + use crate::tests::{dummy_classification_input, ClassificationType, THRESHOLD}; + use burn_core::tensor::TensorData; + use rstest::rstest; + + #[rstest] + #[case::binary_b1(1.0, THRESHOLD, 0.5)] + #[case::binary_b2(2.0, THRESHOLD, 0.5)] + fn test_binary_fscore(#[case] beta: f64, #[case] threshold: f64, #[case] expected: f64) { + let input = dummy_classification_input(&ClassificationType::Binary).into(); + let mut metric = FBetaScoreMetric::binary(beta, threshold); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } + + #[rstest] + #[case::multiclass_b1_micro_k1(1.0, Micro, 1, 3.0/5.0)] + #[case::multiclass_b1_micro_k2(1.0, Micro, 2, 2.0/(5.0/4.0 + 10.0/4.0))] + #[case::multiclass_b1_macro_k1(1.0, Macro, 1, (0.5 + 2.0/(1.0 + 2.0) + 2.0/(2.0 + 1.0))/3.0)] + #[case::multiclass_b1_macro_k2(1.0, Macro, 2, (2.0/(1.0 + 2.0) + 2.0/(1.0 + 4.0) + 0.5)/3.0)] + #[case::multiclass_b2_micro_k1(2.0, Micro, 1, 3.0/5.0)] + #[case::multiclass_b2_micro_k2(2.0, Micro, 2, 5.0*4.0/(4.0*5.0 + 10.0))] + #[case::multiclass_b2_macro_k1(2.0, Macro, 1, (0.5 + 5.0/(4.0 + 2.0) + 5.0/(8.0 + 1.0))/3.0)] + #[case::multiclass_b2_macro_k2(2.0, Macro, 2, (5.0/(4.0 + 2.0) + 5.0/(4.0 + 4.0) + 0.5)/3.0)] + fn test_multiclass_fscore( + #[case] beta: f64, + #[case] class_reduction: ClassReduction, + #[case] top_k: usize, + #[case] expected: f64, + ) { + let input = dummy_classification_input(&ClassificationType::Multiclass).into(); + let mut metric = FBetaScoreMetric::multiclass(beta, top_k, class_reduction); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } + + #[rstest] + #[case::multilabel_micro(1.0, Micro, THRESHOLD, 2.0/(9.0/5.0 + 8.0/5.0))] + #[case::multilabel_macro(1.0, Macro, THRESHOLD, (2.0/(2.0 + 3.0/2.0) + 2.0/(1.0 + 3.0/2.0) + 2.0/(3.0+2.0))/3.0)] + #[case::multilabel_micro(2.0, Micro, THRESHOLD, 5.0/(4.0*9.0/5.0 + 8.0/5.0))] + #[case::multilabel_macro(2.0, Macro, THRESHOLD, (5.0/(8.0 + 3.0/2.0) + 5.0/(4.0 + 3.0/2.0) + 5.0/(12.0+2.0))/3.0)] + fn test_multilabel_fscore( + #[case] beta: f64, + #[case] class_reduction: ClassReduction, + #[case] threshold: f64, + #[case] expected: f64, + ) { + let input = dummy_classification_input(&ClassificationType::Multilabel).into(); + let mut metric = FBetaScoreMetric::multilabel(beta, threshold, class_reduction); + let _entry = metric.update(&input, &MetricMetadata::fake()); + TensorData::from([metric.value()]) + .assert_approx_eq(&TensorData::from([expected * 100.0]), 3) + } +} diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs index e6358e3023..191099a383 100644 --- a/crates/burn-train/src/metric/mod.rs +++ b/crates/burn-train/src/metric/mod.rs @@ -1,23 +1,32 @@ /// State module. pub mod state; +/// Module responsible to save and exposes data collected during training. +pub mod store; mod acc; mod auroc; mod base; #[cfg(feature = "metrics")] +mod confusion_stats; +#[cfg(feature = "metrics")] mod cpu_temp; #[cfg(feature = "metrics")] mod cpu_use; #[cfg(feature = "metrics")] mod cuda; +#[cfg(feature = "metrics")] +mod fbetascore; mod hamming; +#[cfg(feature = "metrics")] +mod iteration; mod learning_rate; mod loss; #[cfg(feature = "metrics")] mod memory_use; - #[cfg(feature = "metrics")] -mod iteration; +mod precision; +#[cfg(feature = "metrics")] +mod recall; #[cfg(feature = "metrics")] mod top_k_acc; @@ -25,11 +34,15 @@ pub use acc::*; pub use auroc::*; pub use base::*; #[cfg(feature = "metrics")] +pub use confusion_stats::ConfusionStatsInput; +#[cfg(feature = "metrics")] pub use cpu_temp::*; #[cfg(feature = "metrics")] pub use cpu_use::*; #[cfg(feature = "metrics")] pub use cuda::*; +#[cfg(feature = "metrics")] +pub use fbetascore::*; pub use hamming::*; #[cfg(feature = "metrics")] pub use iteration::*; @@ -38,25 +51,17 @@ pub use loss::*; #[cfg(feature = "metrics")] pub use memory_use::*; #[cfg(feature = "metrics")] +pub use precision::*; +#[cfg(feature = "metrics")] +pub use recall::*; +#[cfg(feature = "metrics")] pub use top_k_acc::*; +#[cfg(feature = "metrics")] +pub(crate) mod classification; pub(crate) mod processor; -// Expose `ItemLazy` so it can be implemented for custom types -pub use processor::ItemLazy; - -/// Module responsible to save and exposes data collected during training. -pub mod store; -pub(crate) mod classification; #[cfg(feature = "metrics")] pub use crate::metric::classification::ClassReduction; -mod confusion_stats; -pub use confusion_stats::ConfusionStatsInput; -#[cfg(feature = "metrics")] -mod precision; -#[cfg(feature = "metrics")] -pub use precision::*; -#[cfg(feature = "metrics")] -mod recall; -#[cfg(feature = "metrics")] -pub use recall::*; +// Expose `ItemLazy` so it can be implemented for custom types +pub use processor::ItemLazy; diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs index 067261cbdf..375d368795 100644 --- a/crates/burn-train/src/metric/precision.rs +++ b/crates/burn-train/src/metric/precision.rs @@ -42,6 +42,7 @@ impl PrecisionMetric { /// # Arguments /// /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). + /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self { @@ -60,6 +61,7 @@ impl PrecisionMetric { /// # Arguments /// /// * `threshold` - The threshold to transform a probability into a binary value. + /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { @@ -129,7 +131,7 @@ mod tests { use rstest::rstest; #[rstest] - #[case::binary_macro(THRESHOLD, 0.5)] + #[case::binary(THRESHOLD, 0.5)] fn test_binary_precision(#[case] threshold: f64, #[case] expected: f64) { let input = dummy_classification_input(&ClassificationType::Binary).into(); let mut metric = PrecisionMetric::binary(threshold); diff --git a/crates/burn-train/src/metric/recall.rs b/crates/burn-train/src/metric/recall.rs index 8ce4351396..5003ddcd03 100644 --- a/crates/burn-train/src/metric/recall.rs +++ b/crates/burn-train/src/metric/recall.rs @@ -11,7 +11,7 @@ use burn_core::{ use core::marker::PhantomData; use std::num::NonZeroUsize; -///The Precision Metric +///The Recall Metric #[derive(Default)] pub struct RecallMetric { state: NumericMetricState, @@ -42,6 +42,7 @@ impl RecallMetric { /// # Arguments /// /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`). + /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self { Self { @@ -60,6 +61,7 @@ impl RecallMetric { /// # Arguments /// /// * `threshold` - The threshold to transform a probability into a binary prediction. + /// * `class_reduction` - [Class reduction](ClassReduction) type. #[allow(dead_code)] pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self { Self { @@ -128,7 +130,7 @@ mod tests { use rstest::rstest; #[rstest] - #[case::binary_macro(THRESHOLD, 0.5)] + #[case::binary(THRESHOLD, 0.5)] fn test_binary_recall(#[case] threshold: f64, #[case] expected: f64) { let input = dummy_classification_input(&ClassificationType::Binary).into(); let mut metric = RecallMetric::binary(threshold); From 2b4be6cae3d8123bb5f09918bc95d3572a376c1c Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Tue, 7 Jan 2025 16:05:15 -0500 Subject: [PATCH 04/61] Refactor unary + binary kernels (#2665) * Refactor unary + binary kernels * Improve float unary * Cleanup binary --- crates/burn-jit/src/kernel/binary.rs | 281 +++++++++++--------- crates/burn-jit/src/kernel/clamp.rs | 41 +-- crates/burn-jit/src/kernel/mod.rs | 6 +- crates/burn-jit/src/kernel/unary.rs | 158 ----------- crates/burn-jit/src/kernel/unary_float.rs | 181 +++++++++++++ crates/burn-jit/src/kernel/unary_numeric.rs | 106 ++++++++ crates/burn-jit/src/ops/float_ops.rs | 185 +++---------- crates/burn-jit/src/ops/int_ops.rs | 22 +- crates/burn-jit/src/ops/numeric.rs | 2 +- crates/burn-jit/src/tensor/base.rs | 34 ++- 10 files changed, 549 insertions(+), 467 deletions(-) delete mode 100644 crates/burn-jit/src/kernel/unary.rs create mode 100644 crates/burn-jit/src/kernel/unary_float.rs create mode 100644 crates/burn-jit/src/kernel/unary_numeric.rs diff --git a/crates/burn-jit/src/kernel/binary.rs b/crates/burn-jit/src/kernel/binary.rs index d799d1caea..d7c4d789ab 100644 --- a/crates/burn-jit/src/kernel/binary.rs +++ b/crates/burn-jit/src/kernel/binary.rs @@ -4,34 +4,60 @@ use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, use burn_tensor::Shape; use cubecl::{ calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, - tensor_vectorization_factor, + tensor_line_size_parallel, }; use super::into_contiguous; +pub(crate) trait BinaryOpFamily: Send + Sync + 'static { + type BinaryOp: BinaryOp; +} + #[cube] pub(crate) trait BinaryOp: 'static + Send + Sync { /// Execute a binary operation. fn execute(lhs: Line, rhs: Line) -> Line; } -pub(crate) trait BinaryOpSpec: Send + Sync + 'static { - type C: Numeric; -} -pub(crate) struct Spec { - _c: PhantomData, -} - -impl BinaryOpSpec for Spec { - type C = C; -} - pub(crate) struct AddOp; pub(crate) struct SubOp; pub(crate) struct MulOp; pub(crate) struct DivOp; pub(crate) struct RemainderOp; -pub(crate) struct PowOp; + +/// Since Powf only works on float, but we still want to implement the numeric binary op family, we +/// set another precision in the family type to cast, when necessary, the input value to a valid +/// float. +/// +/// Because of this we won't benefit from the cubecl rust compilation speed improvement from using +/// the family pattern for [PowOp], but at least we don't duplicate code. +pub(crate) struct PowOp { + _f: PhantomData, +} + +impl BinaryOpFamily for AddOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for SubOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for MulOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for DivOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for RemainderOp { + type BinaryOp = Self; +} + +impl BinaryOpFamily for PowOp { + type BinaryOp = Self; +} #[cube] impl BinaryOp for AddOp { @@ -69,30 +95,34 @@ impl BinaryOp for RemainderOp { } #[cube] -impl BinaryOp for PowOp { +impl BinaryOp for PowOp { fn execute(lhs: Line, rhs: Line) -> Line { - Line::powf(lhs, rhs) + let lhs = Line::::cast_from(lhs); + let rhs = Line::::cast_from(rhs); + let out = Line::powf(lhs, rhs); + + Line::cast_from(out) } } -#[cube(launch)] -pub(crate) fn kernel_scalar_binop>( - input: &Tensor>, - scalar: BS::C, - output: &mut Tensor>, +#[cube(launch_unchecked)] +pub(crate) fn kernel_scalar_binop( + input: &Tensor>, + scalar: C, + output: &mut Tensor>, ) { if ABSOLUTE_POS >= output.len() { return; } - output[ABSOLUTE_POS] = O::execute(input[ABSOLUTE_POS], Line::new(scalar)); + output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); } -#[cube(launch)] -pub(crate) fn kernel_binop>( - lhs: &Tensor>, - rhs: &Tensor>, - out: &mut Tensor>, +#[cube(launch_unchecked)] +pub(crate) fn kernel_binop( + lhs: &Tensor>, + rhs: &Tensor>, + out: &mut Tensor>, #[comptime] rank: Option, #[comptime] to_contiguous_lhs: bool, #[comptime] to_contiguous_rhs: bool, @@ -106,7 +136,7 @@ pub(crate) fn kernel_binop>( } if to_contiguous_lhs { - offset_lhs = index_offset_with_layout::( + offset_lhs = index_offset_with_layout::( lhs, out, offset_out, @@ -117,7 +147,7 @@ pub(crate) fn kernel_binop>( } if to_contiguous_rhs { - offset_rhs = index_offset_with_layout::( + offset_rhs = index_offset_with_layout::( rhs, out, offset_out, @@ -127,20 +157,27 @@ pub(crate) fn kernel_binop>( ); } - out[offset_out] = O::execute(lhs[offset_lhs], rhs[offset_rhs]); + out[offset_out] = O::BinaryOp::::execute(lhs[offset_lhs], rhs[offset_rhs]); } -pub(crate) fn launch_binop>( +pub(crate) fn launch_binop( lhs: JitTensor, rhs: JitTensor, ) -> JitTensor { let ndims = lhs.shape.num_dims(); - let vectorization_factor_lhs = - tensor_vectorization_factor(&[4, 2], &lhs.shape.dims, &lhs.strides, ndims - 1); - let vectorization_factor_rhs = - tensor_vectorization_factor(&[4, 2], &rhs.shape.dims, &rhs.strides, ndims - 1); - - let vectorization_factor = Ord::min(vectorization_factor_lhs, vectorization_factor_rhs); + let line_size_lhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &lhs.shape.dims, + &lhs.strides, + ndims - 1, + ); + let line_size_rhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &rhs.shape.dims, + &rhs.strides, + ndims - 1, + ); + let line_size = Ord::min(line_size_lhs, line_size_rhs); let mut shape_out = vec![0; ndims]; lhs.shape @@ -157,59 +194,60 @@ pub(crate) fn launch_binop>( let num_elems = shape_out.num_elements(); let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - - if lhs.can_mut_broadcast(&rhs) { - kernel_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - lhs.as_tensor_arg::(vectorization_factor), - rhs.as_tensor_arg::(vectorization_factor), - TensorArg::alias(0), - None, - false, - rhs.strides != lhs.strides || rhs.shape != lhs.shape, - ); - - lhs - } else if rhs.can_mut_broadcast(&lhs) { - kernel_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - lhs.as_tensor_arg::(vectorization_factor), - rhs.as_tensor_arg::(vectorization_factor), - TensorArg::alias(1), - None, - rhs.strides != lhs.strides || rhs.shape != lhs.shape, - false, - ); - - rhs - } else { - let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); - let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; - let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape; - - kernel_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - lhs.as_tensor_arg::(vectorization_factor), - rhs.as_tensor_arg::(vectorization_factor), - output.as_tensor_arg::(vectorization_factor), - None, - to_contiguous_lhs, - to_contiguous_rhs, - ); - - output + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if lhs.can_mut_broadcast(&rhs) { + kernel_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(0), + None, + false, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + ); + + lhs + } else if rhs.can_mut_broadcast(&lhs) { + kernel_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(1), + None, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + false, + ); + + rhs + } else { + let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); + let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; + let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape; + + kernel_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + None, + to_contiguous_lhs, + to_contiguous_rhs, + ); + + output + } } } -pub(crate) fn launch_scalar_binop>( +pub(crate) fn launch_scalar_binop( mut tensor: JitTensor, scalar: E, ) -> JitTensor { @@ -219,42 +257,47 @@ pub(crate) fn launch_scalar_binop>( // Vectorization is only enabled when the last dimension is contiguous. let ndims = tensor.shape.num_dims(); - let vectorization_factor = - tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, ndims - 1); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); let client = tensor.client.clone(); let num_elems = tensor.shape.num_elements(); let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - - if tensor.can_mut() { - kernel_scalar_binop::launch::, O, R>( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(vectorization_factor), - ScalarArg::new(scalar), - TensorArg::alias(0), - ); - - tensor - } else { - let output = empty_device::( - tensor.client.clone(), - tensor.device.clone(), - tensor.shape.clone(), - ); - - kernel_scalar_binop::launch::, O, R>( - &client, - cube_count, - CubeDim::default(), - tensor.as_tensor_arg::(vectorization_factor), - ScalarArg::new(scalar), - output.as_tensor_arg::(vectorization_factor), - ); - - output + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if tensor.can_mut() { + kernel_scalar_binop::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + TensorArg::alias(0), + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + kernel_scalar_binop::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + output.as_tensor_arg::(line_size), + ); + + output + } } } diff --git a/crates/burn-jit/src/kernel/clamp.rs b/crates/burn-jit/src/kernel/clamp.rs index 683e8aff8f..ec2bc93d1f 100644 --- a/crates/burn-jit/src/kernel/clamp.rs +++ b/crates/burn-jit/src/kernel/clamp.rs @@ -1,7 +1,11 @@ use cubecl::prelude::*; -use crate::kernel::{launch_unary, UnaryOp}; -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; +use crate::{ + element::JitElement, + kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}, + tensor::JitTensor, + JitRuntime, +}; #[derive(CubeLaunch)] struct Options { @@ -16,28 +20,25 @@ pub(crate) fn clamp( ) -> JitTensor { struct ClampOp; - impl UnaryOp for ClampOp { - type Options = Options; + #[cube] + impl NumericUnaryOp for ClampOp { + type Options = Options; - fn __expand_execute( - context: &mut CubeContext, - input: as CubeType>::ExpandType, - options: OptionsExpand, - ) -> as CubeType>::ExpandType { - #[cube] - fn execute(input: Line, options: &Options) -> Line { - Line::clamp( - input, - Line::new(options.min_value), - Line::new(options.max_value), - ) - } - - execute::expand(context, input, options) + fn execute(input: Line, options: &Self::Options) -> Line { + Line::clamp( + input, + Line::new(options.min_value), + Line::new(options.max_value), + ) } } - launch_unary::(input, |_| { + impl NumericUnaryOpFamily for ClampOp { + type Options = Options; + type Unary = Self; + } + + launch_unary_numeric::(input, |_| { OptionsLaunch::new(ScalarArg::new(min_value), ScalarArg::new(max_value)) }) } diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index afa8ecd6fa..660ae2f6fd 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -5,13 +5,15 @@ mod comparison; mod contiguous; mod index; mod mask; -mod unary; +mod unary_float; +mod unary_numeric; pub(crate) use binary::*; pub use cast::*; pub use contiguous::*; pub use mask::*; -pub(crate) use unary::*; +pub(crate) use unary_float::*; +pub(crate) use unary_numeric::*; pub use cubecl::{Kernel, PLANE_DIM_APPROX}; diff --git a/crates/burn-jit/src/kernel/unary.rs b/crates/burn-jit/src/kernel/unary.rs deleted file mode 100644 index 09f9c77689..0000000000 --- a/crates/burn-jit/src/kernel/unary.rs +++ /dev/null @@ -1,158 +0,0 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; -use cubecl::{ - calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, - tensor_vectorization_factor, unexpanded, -}; - -#[cube] -pub(crate) trait UnaryOp: 'static + Send + Sync { - type Options: LaunchArg; - - /// Execute a unary operation. - fn execute(_input: Line, _options: &Self::Options) -> Line { - unexpanded!(); - } -} - -#[cube(launch)] -pub(crate) fn unary_kernel>( - input: &Tensor>, - output: &mut Tensor>, - options: &O::Options, - #[comptime] rank: Option, - #[comptime] to_contiguous: bool, -) { - let offset_output = ABSOLUTE_POS; - - if offset_output >= output.len() { - return; - } - - if to_contiguous { - let offset_input = index_offset_with_layout::( - input, - output, - offset_output, - 0, - rank.unwrap_or_else(|| output.rank()), - rank.is_some(), - ); - - output[offset_output] = O::execute(input[offset_input], options); - } else { - output[offset_output] = O::execute(input[offset_output], options); - } -} - -pub(crate) fn launch_unary, F>( - tensor: JitTensor, - options: F, -) -> JitTensor -where - // Magic fix for lifetime, the closure is supposed to capture everything required to create the - // argument. - for<'a> F: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, -{ - let ndims = tensor.shape.num_dims(); - // Vectorization is only enabled when the last dimension is contiguous. - let vectorization_factor = - tensor_vectorization_factor(&[4, 2], &tensor.shape.dims, &tensor.strides, ndims - 1); - - let client = tensor.client.clone(); - let num_elems = tensor.shape.num_elements(); - - let cube_dim = CubeDim::default(); - let cube_count = - calculate_cube_count_elemwise(num_elems / vectorization_factor as usize, cube_dim); - let is_contiguous = tensor.is_contiguous(); - - if tensor.can_mut() && tensor.is_contiguous_buffer() { - unary_kernel::launch::( - &client, - cube_count, - cube_dim, - tensor.as_tensor_arg::(vectorization_factor), - TensorArg::alias(0), - options(&()), - None, - false, - ); - - tensor - } else { - let output = empty_device::( - tensor.client.clone(), - tensor.device.clone(), - tensor.shape.clone(), - ); - - unary_kernel::launch::( - &client, - cube_count, - CubeDim::default(), - tensor.as_tensor_arg::(vectorization_factor), - output.as_tensor_arg::(vectorization_factor), - options(&()), - Some(ndims as u32), - !is_contiguous, - ); - output - } -} - -macro_rules! unary_op { - ($name:ident, $elem:ident, $expand:expr) => { - struct $name; - - impl UnaryOp for $name { - type Options = (); - - #[allow(clippy::redundant_closure_call)] - fn __expand_execute( - context: &mut CubeContext, - input: as CubeType>::ExpandType, - _options: ::ExpandType, - ) -> as CubeType>::ExpandType { - $expand(context, input) - } - } - }; - (scalar $name:ident, $elem:ident, $expand:expr) => { - struct $name; - - impl UnaryOp for $name { - type Options = C; - - #[allow(clippy::redundant_closure_call)] - fn __expand_execute( - context: &mut CubeContext, - input: as CubeType>::ExpandType, - scalar: C::ExpandType, - ) -> as CubeType>::ExpandType { - $expand(context, input, scalar) - } - } - }; - (float($tensor:expr) => $exp:expr) => {{ - unary_op!(Op, Float, $exp); - launch_unary::($tensor, |_| ()) - }}; - (int($tensor:expr) => $exp:expr) => {{ - unary_op!(Op, Numeric, $exp); - launch_unary::($tensor, |_| ()) - }}; - (numeric($tensor:expr) => $exp:expr) => {{ - unary_op!(Op, Numeric, $exp); - launch_unary::($tensor, |_| ()) - }}; - (numeric($tensor:expr, $scalar:expr) => $exp:expr) => {{ - unary_op!(scalar Op, Numeric, $exp); - launch_unary::($tensor, |_| ScalarArg::new($scalar)) - }}; - (float($tensor:expr, $scalar:expr) => $exp:expr) => {{ - unary_op!(scalar Op, Float, $exp); - launch_unary::($tensor, |_| ScalarArg::new($scalar)) - }}; -} - -pub(crate) use unary_op; diff --git a/crates/burn-jit/src/kernel/unary_float.rs b/crates/burn-jit/src/kernel/unary_float.rs new file mode 100644 index 0000000000..33a311ecbc --- /dev/null +++ b/crates/burn-jit/src/kernel/unary_float.rs @@ -0,0 +1,181 @@ +use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +pub(crate) trait FloatUnaryOpFamily: 'static + Send + Sync { + type Options: LaunchArg; + type Unary: FloatUnaryOp>; +} + +#[cube] +pub(crate) trait FloatUnaryOp: 'static + Send + Sync { + type Options: LaunchArg; + + fn execute(input: Line, options: &Self::Options) -> Line; +} + +#[cube(launch_unchecked)] +pub(crate) fn unary_float( + input: &Tensor>, + output: &mut Tensor>, + options: &O::Options, + #[comptime] rank: Option, + #[comptime] to_contiguous: bool, +) { + let offset_output = ABSOLUTE_POS; + + if offset_output >= output.len() { + return; + } + + if comptime![to_contiguous] { + let offset_input = index_offset_with_layout::( + input, + output, + offset_output, + 0, + rank.unwrap_or_else(|| output.rank()), + rank.is_some(), + ); + + output[offset_output] = O::Unary::::execute(input[offset_input], options); + } else { + output[offset_output] = O::Unary::::execute(input[offset_output], options); + } +} + +pub(crate) fn launch_unary_float(tensor: JitTensor, args: Args) -> JitTensor +where + // Magic fix for lifetime, the closure is supposed to capture everything required to create the + // argument. + for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, + R: JitRuntime, + E: JitElement + Float, + O: FloatUnaryOpFamily, +{ + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + let is_contiguous = tensor.is_contiguous(); + + unsafe { + if tensor.can_mut() && tensor.is_contiguous_buffer() { + unary_float::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + TensorArg::alias(0), + args(&()), + None, + false, + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + unary_float::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + args(&()), + Some(ndims as u32), + !is_contiguous, + ); + output + } + } +} + +/// Use comptime enum to implement all unary operations that don't have any input argument in the +/// kernel definition. +pub(crate) mod unary_basic { + use crate::execute_with_dtype; + + use super::*; + + pub(crate) fn launch(tensor: JitTensor, args: Args) -> JitTensor + where + R: JitRuntime, + for<'a> Args: FnOnce(&'a ()) -> &'a BasicFloatUnaryKind, + { + execute_with_dtype!( + float(tensor.dtype), + F, + launch_unary_float::(tensor, |input| { + BasicFloatUnaryOptionsLaunch::new(args(input)) + }) + ) + } + + #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)] + pub enum BasicFloatUnaryKind { + Exp, + Log, + Log1p, + Sqrt, + Abs, + Cos, + Sin, + Tanh, + Round, + Floor, + Ceil, + Erf, + Recip, + } + + #[derive(CubeLaunch)] + struct BasicFloatUnaryOptions { + #[cube(comptime)] + kind: BasicFloatUnaryKind, + } + struct BasicFloatUnary; + + #[cube] + impl FloatUnaryOp for BasicFloatUnary { + type Options = BasicFloatUnaryOptions; + + fn execute(input: Line, options: &Self::Options) -> Line { + match comptime![options.kind] { + BasicFloatUnaryKind::Exp => Line::exp(input), + BasicFloatUnaryKind::Log => Line::log(input), + BasicFloatUnaryKind::Log1p => Line::log1p(input), + BasicFloatUnaryKind::Sqrt => Line::sqrt(input), + BasicFloatUnaryKind::Abs => Line::abs(input), + BasicFloatUnaryKind::Cos => Line::cos(input), + BasicFloatUnaryKind::Sin => Line::sin(input), + BasicFloatUnaryKind::Tanh => Line::tanh(input), + BasicFloatUnaryKind::Round => Line::round(input), + BasicFloatUnaryKind::Floor => Line::floor(input), + BasicFloatUnaryKind::Ceil => Line::ceil(input), + BasicFloatUnaryKind::Erf => Line::erf(input), + BasicFloatUnaryKind::Recip => Line::recip(input), + } + } + } + + impl FloatUnaryOpFamily for BasicFloatUnary { + type Options = BasicFloatUnaryOptions; + type Unary = Self; + } +} diff --git a/crates/burn-jit/src/kernel/unary_numeric.rs b/crates/burn-jit/src/kernel/unary_numeric.rs new file mode 100644 index 0000000000..0b8dcb2cbc --- /dev/null +++ b/crates/burn-jit/src/kernel/unary_numeric.rs @@ -0,0 +1,106 @@ +use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +pub(crate) trait NumericUnaryOpFamily: 'static + Send + Sync { + type Options: LaunchArg; + type Unary: NumericUnaryOp>; +} + +#[cube] +pub(crate) trait NumericUnaryOp: 'static + Send + Sync { + type Options: LaunchArg; + + fn execute(input: Line, options: &Self::Options) -> Line; +} + +#[cube(launch_unchecked)] +pub(crate) fn unary_numeric( + input: &Tensor>, + output: &mut Tensor>, + options: &O::Options, + #[comptime] rank: Option, + #[comptime] to_contiguous: bool, +) { + let offset_output = ABSOLUTE_POS; + + if offset_output >= output.len() { + return; + } + + if comptime![to_contiguous] { + let offset_input = index_offset_with_layout::( + input, + output, + offset_output, + 0, + rank.unwrap_or_else(|| output.rank()), + rank.is_some(), + ); + + output[offset_output] = O::Unary::::execute(input[offset_input], options); + } else { + output[offset_output] = O::Unary::::execute(input[offset_output], options); + } +} + +pub(crate) fn launch_unary_numeric(tensor: JitTensor, args: Args) -> JitTensor +where + // Magic fix for lifetime, the closure is supposed to capture everything required to create the + // argument. + for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, + R: JitRuntime, + E: JitElement + Numeric, + O: NumericUnaryOpFamily, +{ + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + let is_contiguous = tensor.is_contiguous(); + + unsafe { + if tensor.can_mut() && tensor.is_contiguous_buffer() { + unary_numeric::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + TensorArg::alias(0), + args(&()), + None, + false, + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + unary_numeric::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + args(&()), + Some(ndims as u32), + !is_contiguous, + ); + output + } + } +} diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 2dc8a4a6f2..6090e895ae 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -1,6 +1,9 @@ use super::{expand, numeric, permute}; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; -use crate::kernel::{self, launch_unary, reduce, unary_op, UnaryOp}; +use crate::kernel::unary_basic::BasicFloatUnaryKind; +use crate::kernel::{ + self, launch_unary_float, reduce, unary_basic, FloatUnaryOp, FloatUnaryOpFamily, +}; use crate::{ element::BoolElement, kernel::matmul::{matmul, MatmulStrategy}, @@ -389,185 +392,75 @@ where } fn float_exp(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::exp(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Exp) } fn float_log(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::log(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Log) } fn float_log1p(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::log1p(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Log1p) } fn float_powf_scalar(lhs: FloatTensor, rhs: f32) -> FloatTensor { + struct Powf; + + #[cube] + impl FloatUnaryOp for Powf { + type Options = F; + + fn execute(input: Line, options: &Self::Options) -> Line { + Line::powf(input, Line::new(*options)) + } + } + + impl FloatUnaryOpFamily for Powf { + type Options = F; + type Unary = Self; + } + execute_with_dtype!( float(lhs.dtype), F, - unary_op!(float(lhs, rhs.elem::()) => |context, tensor, scalar| { - #[cube] - fn execute(input: Line, scalar: C) -> Line { - Line::powf(input, Line::new(scalar)) - } - execute::expand::(context, tensor, scalar) - }) + launch_unary_float::(lhs, |_| ScalarArg::new(rhs.elem::())) ) } fn float_sqrt(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::sqrt(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Sqrt) } fn float_abs(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::abs(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Abs) } fn float_cos(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::cos(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Cos) } fn float_sin(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::sin(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Sin) } fn float_tanh(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::tanh(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Tanh) } fn float_round(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::round(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Round) } fn float_floor(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::floor(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Floor) } fn float_ceil(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::ceil(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Ceil) } fn float_erf(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::erf(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Erf) } fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor { @@ -603,17 +496,7 @@ where } fn float_recip(tensor: FloatTensor) -> FloatTensor { - execute_with_dtype!( - float(tensor.dtype), - F, - unary_op!(float(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - Line::recip(input) - } - execute::expand::(context, tensor) - }) - ) + unary_basic::launch::(tensor, |_| &BasicFloatUnaryKind::Recip) } fn float_repeat_dim(tensor: FloatTensor, dim: usize, times: usize) -> FloatTensor { diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 25bb92521f..a0e181a9c7 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -1,5 +1,5 @@ use super::{expand, numeric, permute}; -use crate::kernel::{launch_unary, unary_op, UnaryOp}; +use crate::kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}; use crate::{ element::BoolElement, kernel::prng::{random_bernoulli, random_normal, random_uniform}, @@ -229,13 +229,23 @@ where } fn int_abs(tensor: IntTensor) -> IntTensor { - unary_op!(int(tensor) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { + struct Abs; + + #[cube] + impl NumericUnaryOp for Abs { + type Options = (); + + fn execute(input: Line, _options: &Self::Options) -> Line { Line::abs(input) } - execute::expand::(context, tensor) - }) + } + + impl NumericUnaryOpFamily for Abs { + type Options = (); + type Unary = Self; + } + + launch_unary_numeric::(tensor, |_| ()) } fn int_into_float(tensor: IntTensor) -> FloatTensor { diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index 5632425198..d0d5be8468 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -137,5 +137,5 @@ pub fn remainder_scalar(lhs: JitTensor, rhs: E) } pub fn pow(lhs: JitTensor, rhs: JitTensor) -> JitTensor { - launch_binop::(lhs, rhs) + launch_binop::>(lhs, rhs) } diff --git a/crates/burn-jit/src/tensor/base.rs b/crates/burn-jit/src/tensor/base.rs index e114b2f8e6..b586c4a6b7 100644 --- a/crates/burn-jit/src/tensor/base.rs +++ b/crates/burn-jit/src/tensor/base.rs @@ -1,5 +1,5 @@ use crate::element::JitElement; -use crate::kernel::{launch_unary, unary_op, UnaryOp}; +use crate::kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}; use crate::JitRuntime; use burn_tensor::quantization::QTensorPrimitive; use burn_tensor::{DType, Shape, TensorMetadata}; @@ -314,15 +314,29 @@ where /// Copy the current tensor. pub fn copy(&self) -> Self { - execute_with_dtype!(self.dtype, E, { - unary_op!(numeric(self.clone()) => |context, tensor| { - #[cube] - fn execute(input: Line) -> Line { - input - } - execute::expand::(context, tensor) - }) - }) + struct Copy; + + #[cube] + impl NumericUnaryOp for Copy { + type Options = (); + + fn execute(input: Line, _options: &Self::Options) -> Line { + input + } + } + + impl NumericUnaryOpFamily for Copy { + type Options = (); + type Unary = Self; + } + + let tensor = self.clone(); + + execute_with_dtype!( + tensor.dtype, + E, + launch_unary_numeric::(tensor, |_| ()) + ) } /// Check if the tensor is safe to mutate. From 7c2d590bfeae42bfd9e615f098bd2c68702724c5 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 8 Jan 2025 11:46:13 -0500 Subject: [PATCH 05/61] Fix load_file usage to keep using model (#2672) --- burn-book/src/saving-and-loading.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/burn-book/src/saving-and-loading.md b/burn-book/src/saving-and-loading.md index 13a96cc94d..77f7c863d6 100644 --- a/burn-book/src/saving-and-loading.md +++ b/burn-book/src/saving-and-loading.md @@ -22,7 +22,7 @@ Now that you have a trained model saved to your disk, you can easily load it in ```rust, ignore // Load model in full precision from MessagePack file let recorder = NamedMpkFileRecorder::::new(); -model +model = model .load_file(model_path, &recorder, device) .expect("Should be able to load the model weights from the provided file"); ``` From 9367b1667a2c3a1fa3215cfbd247a5842fd9c83d Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 8 Jan 2025 11:53:35 -0500 Subject: [PATCH 06/61] Fix output float dtype in fusion (#2671) --- crates/burn-fusion/src/ops/float.rs | 67 ++++++++++------------------- 1 file changed, 22 insertions(+), 45 deletions(-) diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 57cfbd4132..b3e2a80432 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -278,9 +278,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -323,7 +321,7 @@ impl FloatTensorOps for Fusion { let dtype = tensor.dtype; let out = tensor .client - .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + .tensor_uninitialized(tensor.shape.clone(), dtype); let desc = ClampOperationDescription { tensor: tensor.into_description(), @@ -375,9 +373,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), rhs: rhs.elem(), @@ -428,9 +424,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -481,9 +475,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -534,9 +526,7 @@ impl FloatTensorOps for Fusion { let stream = lhs.stream; let dtype = lhs.dtype; - let out = lhs - .client - .tensor_uninitialized(lhs.shape.clone(), B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(lhs.shape.clone(), dtype); let desc = ScalarOperationDescription { lhs: lhs.into_description(), @@ -567,9 +557,7 @@ impl FloatTensorOps for Fusion { shape[ndims - 2] = lhs.shape[ndims - 2]; shape[ndims - 1] = rhs.shape[ndims - 1]; - let out = lhs - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = lhs.client.tensor_uninitialized(shape, dtype); let desc = BinaryOperationDescription { lhs: lhs.into_description(), rhs: rhs.into_description(), @@ -601,13 +589,12 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; + let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim1] = tensor.shape[dim2]; shape[dim2] = tensor.shape[dim1]; - let mut out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let mut out = tensor.client.tensor_uninitialized(shape, dtype); let desc = SwapDimsDescription { input: tensor.into_description(), @@ -641,9 +628,8 @@ impl FloatTensorOps for Fusion { } let stream = tensor.stream; - let out = tensor - .client - .tensor_uninitialized(shape.dims, B::FloatElem::dtype()); + let dtype = tensor.dtype; + let out = tensor.client.tensor_uninitialized(shape.dims, dtype); let desc = ReshapeDescription { input: tensor.into_description(), @@ -1300,9 +1286,7 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let dtype = tensor.dtype; - let out = tensor - .client - .tensor_uninitialized(vec![1], B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(vec![1], dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1327,9 +1311,7 @@ impl FloatTensorOps for Fusion { let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1352,9 +1334,8 @@ impl FloatTensorOps for Fusion { unary_float_ops!(ProdOps, B::float_prod, reduce); let stream = tensor.stream; - let out = tensor - .client - .tensor_uninitialized(vec![1], B::FloatElem::dtype()); + let dtype = tensor.dtype; + let out = tensor.client.tensor_uninitialized(vec![1], dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1363,7 +1344,7 @@ impl FloatTensorOps for Fusion { out.client.register( vec![stream], OperationDescription::NumericFloat( - FloatElem::::dtype(), + dtype, NumericOperationDescription::Prod(desc.clone()), ), ProdOps::::new(desc), @@ -1376,11 +1357,10 @@ impl FloatTensorOps for Fusion { scalar_float_ops!(ProdDimOps, B::float_prod_dim, usize, noconvert); let stream = tensor.stream; + let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1404,9 +1384,7 @@ impl FloatTensorOps for Fusion { let stream = tensor.stream; let dtype = tensor.dtype; - let out = tensor - .client - .tensor_uninitialized(vec![1], B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(vec![1], dtype); let desc = UnaryOperationDescription { input: tensor.into_description(), @@ -1431,9 +1409,7 @@ impl FloatTensorOps for Fusion { let dtype = tensor.dtype; let mut shape = tensor.shape.clone(); shape[dim] = 1; - let out = tensor - .client - .tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = tensor.client.tensor_uninitialized(shape, dtype); let desc = ScalarOperationDescription { lhs: tensor.into_description(), @@ -1716,6 +1692,7 @@ impl FloatTensorOps for Fusion { } let tensor_first = tensors.first().unwrap(); + let dtype = tensor_first.dtype; let client = tensor_first.client.clone(); // Calculate the output shape @@ -1726,7 +1703,7 @@ impl FloatTensorOps for Fusion { shape[dim] += tensor.shape[dim]; } - let out = client.tensor_uninitialized(shape, B::FloatElem::dtype()); + let out = client.tensor_uninitialized(shape, dtype); let desc = CatOperationDescription { tensors: tensors.into_iter().map(|t| t.into_description()).collect(), From e588632679eb83e689195eea9438796b197a955f Mon Sep 17 00:00:00 2001 From: Sylvain Benner Date: Wed, 8 Jan 2025 12:25:03 -0500 Subject: [PATCH 07/61] [burnbench] Import code from github-device-flow crate (#2667) * [burnbench] Import code from github-device-flow crate We only use some of the code of the crate and some of its code is redundant with our own implementation. Moreover we don't need some of its dependencies like clap which is on a previous major version. Fixes #2548 * Fix format * Fix lint --- Cargo.lock | 435 +++--------------- Cargo.toml | 4 +- NOTICES.md | 32 +- backend-comparison/Cargo.toml | 2 +- .../burnbenchapp/{auth.rs => auth/base.rs} | 12 +- .../burnbenchapp/auth/github_device_flow.rs | 232 ++++++++++ .../src/burnbenchapp/auth/mod.rs | 4 + 7 files changed, 341 insertions(+), 380 deletions(-) rename backend-comparison/src/burnbenchapp/{auth.rs => auth/base.rs} (97%) create mode 100644 backend-comparison/src/burnbenchapp/auth/github_device_flow.rs create mode 100644 backend-comparison/src/burnbenchapp/auth/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 128196fae0..20d136df9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -310,17 +310,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a" -[[package]] -name = "atty" -version = "0.2.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" -dependencies = [ - "hermit-abi 0.1.19", - "libc", - "winapi", -] - [[package]] name = "autocfg" version = "1.4.0" @@ -361,10 +350,10 @@ dependencies = [ "base64 0.22.1", "bytes", "futures-util", - "http 1.2.0", - "http-body 1.0.1", + "http", + "http-body", "http-body-util", - "hyper 1.5.2", + "hyper", "hyper-util", "itoa", "matchit", @@ -378,7 +367,7 @@ dependencies = [ "serde_path_to_error", "serde_urlencoded", "sha1", - "sync_wrapper 1.0.2", + "sync_wrapper", "tokio", "tokio-tungstenite 0.24.0", "tower", @@ -396,13 +385,13 @@ dependencies = [ "async-trait", "bytes", "futures-util", - "http 1.2.0", - "http-body 1.0.1", + "http", + "http-body", "http-body-util", "mime", "pin-project-lite", "rustversion", - "sync_wrapper 1.0.2", + "sync_wrapper", "tower-layer", "tower-service", "tracing", @@ -415,19 +404,19 @@ dependencies = [ "arboard", "burn", "burn-common", - "clap 4.5.23", + "chrono", + "clap", "colored", "cubecl", "derive-new 0.7.0", "dirs", - "github-device-flow", "half", "indicatif", "log", "os_info", "percent-encoding", "rand", - "reqwest 0.12.12", + "reqwest", "rstest", "serde", "serde_json", @@ -461,12 +450,6 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" -[[package]] -name = "base64" -version = "0.21.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" - [[package]] name = "base64" version = "0.22.1" @@ -673,7 +656,7 @@ dependencies = [ "getrandom", "indicatif", "rayon", - "reqwest 0.12.12", + "reqwest", "serde", "tokio", "web-time", @@ -1248,61 +1231,31 @@ dependencies = [ [[package]] name = "clap" -version = "3.2.25" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123" -dependencies = [ - "atty", - "bitflags 1.3.2", - "clap_derive 3.2.25", - "clap_lex 0.2.4", - "indexmap 1.9.3", - "once_cell", - "strsim 0.10.0", - "termcolor", - "textwrap", -] - -[[package]] -name = "clap" -version = "4.5.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +checksum = "9560b07a799281c7e0958b9296854d6fafd4c5f31444a7e5bb1ad6dde5ccf1bd" dependencies = [ "clap_builder", - "clap_derive 4.5.18", + "clap_derive", ] [[package]] name = "clap_builder" -version = "4.5.23" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +checksum = "874e0dd3eb68bf99058751ac9712f622e61e6f393a94f7128fa26e3f02f5c7cd" dependencies = [ "anstream", "anstyle", - "clap_lex 0.7.4", - "strsim 0.11.1", -] - -[[package]] -name = "clap_derive" -version = "3.2.25" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae6371b8bdc8b7d3959e9cf7b22d4435ef3e79e138688421ec654acf8c81b008" -dependencies = [ - "heck 0.4.1", - "proc-macro-error", - "proc-macro2", - "quote", - "syn 1.0.109", + "clap_lex", + "strsim", ] [[package]] name = "clap_derive" -version = "4.5.18" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -1310,15 +1263,6 @@ dependencies = [ "syn 2.0.95", ] -[[package]] -name = "clap_lex" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5" -dependencies = [ - "os_str_bytes", -] - [[package]] name = "clap_lex" version = "0.7.4" @@ -1894,7 +1838,7 @@ version = "0.16.0" dependencies = [ "burn", "csv", - "reqwest 0.12.12", + "reqwest", "serde", ] @@ -1976,7 +1920,7 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "strsim 0.11.1", + "strsim", "syn 2.0.95", ] @@ -2808,20 +2752,6 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" -[[package]] -name = "github-device-flow" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98852ab71f5613dac02a0d1b41f3ffaf993b69449904dd13a10575612a56074d" -dependencies = [ - "chrono", - "clap 3.2.25", - "reqwest 0.11.27", - "serde", - "serde_derive", - "serde_json", -] - [[package]] name = "gix-features" version = "0.39.1" @@ -3009,25 +2939,6 @@ dependencies = [ "serde", ] -[[package]] -name = "h2" -version = "0.3.26" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81fe527a889e1532da5c525686d96d4c2e74cdd345badf8dfef9f6b39dd5f5e8" -dependencies = [ - "bytes", - "fnv", - "futures-core", - "futures-sink", - "futures-util", - "http 0.2.12", - "indexmap 2.7.0", - "slab", - "tokio", - "tokio-util", - "tracing", -] - [[package]] name = "h2" version = "0.4.7" @@ -3039,8 +2950,8 @@ dependencies = [ "fnv", "futures-core", "futures-sink", - "http 1.2.0", - "indexmap 2.7.0", + "http", + "indexmap", "slab", "tokio", "tokio-util", @@ -3072,12 +2983,6 @@ dependencies = [ "serde", ] -[[package]] -name = "hashbrown" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" - [[package]] name = "hashbrown" version = "0.13.2" @@ -3133,15 +3038,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.1.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33" -dependencies = [ - "libc", -] - [[package]] name = "hermit-abi" version = "0.3.9" @@ -3201,17 +3097,6 @@ version = "3.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" -[[package]] -name = "http" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" -dependencies = [ - "bytes", - "fnv", - "itoa", -] - [[package]] name = "http" version = "1.2.0" @@ -3223,17 +3108,6 @@ dependencies = [ "itoa", ] -[[package]] -name = "http-body" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" -dependencies = [ - "bytes", - "http 0.2.12", - "pin-project-lite", -] - [[package]] name = "http-body" version = "1.0.1" @@ -3241,7 +3115,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.2.0", + "http", ] [[package]] @@ -3252,8 +3126,8 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http 1.2.0", - "http-body 1.0.1", + "http", + "http-body", "pin-project-lite", ] @@ -3275,30 +3149,6 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" -[[package]] -name = "hyper" -version = "0.14.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41dfc780fdec9373c01bae43289ea34c972e40ee3c9f6b3c8801a35f35586ce7" -dependencies = [ - "bytes", - "futures-channel", - "futures-core", - "futures-util", - "h2 0.3.26", - "http 0.2.12", - "http-body 0.4.6", - "httparse", - "httpdate", - "itoa", - "pin-project-lite", - "socket2", - "tokio", - "tower-service", - "tracing", - "want", -] - [[package]] name = "hyper" version = "1.5.2" @@ -3308,9 +3158,9 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.7", - "http 1.2.0", - "http-body 1.0.1", + "h2", + "http", + "http-body", "httparse", "httpdate", "itoa", @@ -3327,8 +3177,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" dependencies = [ "futures-util", - "http 1.2.0", - "hyper 1.5.2", + "http", + "hyper", "hyper-util", "rustls", "rustls-native-certs 0.8.1", @@ -3338,19 +3188,6 @@ dependencies = [ "tower-service", ] -[[package]] -name = "hyper-tls" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" -dependencies = [ - "bytes", - "hyper 0.14.32", - "native-tls", - "tokio", - "tokio-native-tls", -] - [[package]] name = "hyper-tls" version = "0.6.0" @@ -3359,7 +3196,7 @@ checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" dependencies = [ "bytes", "http-body-util", - "hyper 1.5.2", + "hyper", "hyper-util", "native-tls", "tokio", @@ -3376,9 +3213,9 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.2.0", - "http-body 1.0.1", - "hyper 1.5.2", + "http", + "http-body", + "hyper", "pin-project-lite", "socket2", "tokio", @@ -3629,16 +3466,6 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" -[[package]] -name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", -] - [[package]] name = "indexmap" version = "2.7.0" @@ -3846,7 +3673,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -4226,7 +4053,7 @@ dependencies = [ "cfg_aliases 0.1.1", "codespan-reporting", "hexf-parse", - "indexmap 2.7.0", + "indexmap", "log", "rustc-hash 1.1.0", "spirv", @@ -4463,7 +4290,7 @@ version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" dependencies = [ - "hermit-abi 0.3.9", + "hermit-abi", "libc", ] @@ -4665,14 +4492,14 @@ dependencies = [ "chrono", "futures", "humantime", - "hyper 1.5.2", + "hyper", "itertools 0.13.0", "md-5", "parking_lot 0.12.3", "percent-encoding", "quick-xml", "rand", - "reqwest 0.12.12", + "reqwest", "ring", "serde", "serde_json", @@ -4836,12 +4663,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "os_str_bytes" -version = "6.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" - [[package]] name = "overload" version = "0.1.1" @@ -4963,7 +4784,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.7.0", + "indexmap", ] [[package]] @@ -5143,7 +4964,7 @@ dependencies = [ "either", "hashbrown 0.14.5", "hashbrown 0.15.2", - "indexmap 2.7.0", + "indexmap", "num-traits", "once_cell", "polars-arrow", @@ -5237,7 +5058,7 @@ dependencies = [ "pyo3", "rayon", "regex", - "reqwest 0.12.12", + "reqwest", "ryu", "serde", "serde_json", @@ -5258,7 +5079,7 @@ dependencies = [ "chrono", "fallible-streaming-iterator", "hashbrown 0.15.2", - "indexmap 2.7.0", + "indexmap", "itoa", "num-traits", "polars-arrow", @@ -5332,7 +5153,7 @@ dependencies = [ "either", "hashbrown 0.15.2", "hex", - "indexmap 2.7.0", + "indexmap", "memchr", "num-traits", "polars-arrow", @@ -5469,7 +5290,7 @@ version = "0.44.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d88667f770291cefa2e8cd366a54f29dc6fe362e9a263914c903db411a58ac1d" dependencies = [ - "indexmap 2.7.0", + "indexmap", "polars-error", "polars-utils", "serde", @@ -5560,7 +5381,7 @@ dependencies = [ "bytes", "compact_str", "hashbrown 0.15.2", - "indexmap 2.7.0", + "indexmap", "libc", "memmap2 0.7.1", "num-traits", @@ -5640,30 +5461,6 @@ dependencies = [ "toml_edit", ] -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro2" version = "1.0.92" @@ -5726,7 +5523,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "322330e133eab455718444b4e033ebfac7c6528972c784fcde28d2cc783c6257" dependencies = [ "anyhow", - "indexmap 2.7.0", + "indexmap", "log", "protobuf", "protobuf-support", @@ -6267,46 +6064,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832" -[[package]] -name = "reqwest" -version = "0.11.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62" -dependencies = [ - "base64 0.21.7", - "bytes", - "encoding_rs", - "futures-core", - "futures-util", - "h2 0.3.26", - "http 0.2.12", - "http-body 0.4.6", - "hyper 0.14.32", - "hyper-tls 0.5.0", - "ipnet", - "js-sys", - "log", - "mime", - "native-tls", - "once_cell", - "percent-encoding", - "pin-project-lite", - "rustls-pemfile 1.0.4", - "serde", - "serde_json", - "serde_urlencoded", - "sync_wrapper 0.1.2", - "system-configuration 0.5.1", - "tokio", - "tokio-native-tls", - "tower-service", - "url", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", - "winreg", -] - [[package]] name = "reqwest" version = "0.12.12" @@ -6319,13 +6076,13 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2 0.4.7", - "http 1.2.0", - "http-body 1.0.1", + "h2", + "http", + "http-body", "http-body-util", - "hyper 1.5.2", + "hyper", "hyper-rustls", - "hyper-tls 0.6.0", + "hyper-tls", "hyper-util", "ipnet", "js-sys", @@ -6338,13 +6095,13 @@ dependencies = [ "quinn", "rustls", "rustls-native-certs 0.8.1", - "rustls-pemfile 2.2.0", + "rustls-pemfile", "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", - "sync_wrapper 1.0.2", - "system-configuration 0.6.1", + "sync_wrapper", + "system-configuration", "tokio", "tokio-native-tls", "tokio-rustls", @@ -6531,7 +6288,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" dependencies = [ "openssl-probe", - "rustls-pemfile 2.2.0", + "rustls-pemfile", "rustls-pki-types", "schannel", "security-framework 2.11.1", @@ -6549,15 +6306,6 @@ dependencies = [ "security-framework 3.1.0", ] -[[package]] -name = "rustls-pemfile" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" -dependencies = [ - "base64 0.21.7", -] - [[package]] name = "rustls-pemfile" version = "2.2.0" @@ -7139,12 +6887,6 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe895eb47f22e2ddd4dabc02bce419d2e643c8e3b585c78158b349195bc24d82" -[[package]] -name = "strsim" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" - [[package]] name = "strsim" version = "0.11.1" @@ -7201,12 +6943,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "sync_wrapper" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" - [[package]] name = "sync_wrapper" version = "1.0.2" @@ -7269,17 +7005,6 @@ dependencies = [ "windows 0.57.0", ] -[[package]] -name = "system-configuration" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" -dependencies = [ - "bitflags 1.3.2", - "core-foundation 0.9.4", - "system-configuration-sys 0.5.0", -] - [[package]] name = "system-configuration" version = "0.6.1" @@ -7288,17 +7013,7 @@ checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags 2.6.0", "core-foundation 0.9.4", - "system-configuration-sys 0.6.0", -] - -[[package]] -name = "system-configuration-sys" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" -dependencies = [ - "core-foundation-sys", - "libc", + "system-configuration-sys", ] [[package]] @@ -7442,12 +7157,6 @@ dependencies = [ "rgb", ] -[[package]] -name = "textwrap" -version = "0.16.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" - [[package]] name = "thiserror" version = "1.0.69" @@ -7719,7 +7428,7 @@ version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ - "indexmap 2.7.0", + "indexmap", "serde", "serde_spanned", "toml_datetime", @@ -7750,7 +7459,7 @@ dependencies = [ "futures-core", "futures-util", "pin-project-lite", - "sync_wrapper 1.0.2", + "sync_wrapper", "tokio", "tower-layer", "tower-service", @@ -7776,7 +7485,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58fccce80a2ef6bc32a512514a53cf853d438a44abaea286a4acb0c9f8566860" dependencies = [ "anyhow", - "clap 4.5.23", + "clap", "derive_more 0.99.18", "env_logger", "log", @@ -7888,7 +7597,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.2.0", + "http", "httparse", "log", "rand", @@ -7906,7 +7615,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.2.0", + "http", "httparse", "log", "rand", @@ -8363,7 +8072,7 @@ dependencies = [ "bitflags 2.6.0", "cfg_aliases 0.1.1", "document-features", - "indexmap 2.7.0", + "indexmap", "log", "naga", "once_cell", @@ -8770,16 +8479,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "winreg" -version = "0.50.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" -dependencies = [ - "cfg-if", - "windows-sys 0.48.0", -] - [[package]] name = "wrapcenum-derive" version = "0.4.1" @@ -9004,7 +8703,7 @@ dependencies = [ "crc32fast", "crossbeam-utils", "displaydoc", - "indexmap 2.7.0", + "indexmap", "num_enum", "thiserror 1.0.69", ] @@ -9025,7 +8724,7 @@ dependencies = [ "displaydoc", "flate2", "hmac", - "indexmap 2.7.0", + "indexmap", "lzma-rs", "memchr", "pbkdf2 0.12.2", diff --git a/Cargo.toml b/Cargo.toml index 863943f0b6..f51aa31ce3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ version = "0.16.0" atomic_float = "1" bytemuck = "1.21.0" candle-core = { version = "0.8" } -clap = { version = "4.5.23", features = ["derive"] } +clap = { version = "4.5.24", features = ["derive"] } colored = "2.1.0" console_error_panic_hook = "0.1.7" csv = "1.3.1" @@ -104,8 +104,8 @@ text_placeholder = "0.5.1" wgpu = "23.0.0" # Benchmarks and Burnbench +chrono = "0.4.39" arboard = "3.4.1" -github-device-flow = "0.2.0" os_info = "3.9.0" wsl = "0.1.0" diff --git a/NOTICES.md b/NOTICES.md index a11156a772..c41a90d952 100644 --- a/NOTICES.md +++ b/NOTICES.md @@ -9,7 +9,7 @@ repository copied or derived from. License: BSD 3-Clause License -Copyright (c) 2017, +Copyright (c) 2017, All rights reserved. Redistribution and use in source and binary forms, with or without @@ -572,4 +572,32 @@ SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. \ No newline at end of file +DEALINGS IN THE SOFTWARE. + +## github-device-flow + +**Source**: +- Part of: https://github.com/jakewilkins/gh-device-flow/blob/main/src/lib.rs +- https://github.com/jakewilkins/gh-device-flow/blob/main/src/util.rs + +MIT License + +Copyright (c) 2022 Jake Wilkins + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index ee5f0bd8a2..ab726fa7be 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -37,10 +37,10 @@ burn-common = { path = "../crates/burn-common", version = "0.16.0" } clap = { workspace = true } colored = { workspace = true } +chrono = { workspace = true } cubecl = { workspace = true, features = ["wgpu"], default-features = true } derive-new = { workspace = true } dirs = { workspace = true } -github-device-flow = { workspace = true } half = { workspace = true } indicatif = { workspace = true } os_info = { workspace = true } diff --git a/backend-comparison/src/burnbenchapp/auth.rs b/backend-comparison/src/burnbenchapp/auth/base.rs similarity index 97% rename from backend-comparison/src/burnbenchapp/auth.rs rename to backend-comparison/src/burnbenchapp/auth/base.rs index 3e9470b2bd..1cff65f690 100644 --- a/backend-comparison/src/burnbenchapp/auth.rs +++ b/backend-comparison/src/burnbenchapp/auth/base.rs @@ -1,6 +1,5 @@ use arboard::Clipboard; use burn::serde::{Deserialize, Serialize}; -use github_device_flow::{self, DeviceFlow}; use reqwest; #[cfg(unix)] use std::os::unix::fs::PermissionsExt; @@ -64,7 +63,7 @@ pub(crate) fn get_tokens() -> Option { pub(crate) fn get_username(access_token: &str) -> Option { let client = reqwest::blocking::Client::new(); let response = client - .get(format!("{}users/me", super::USER_BENCHMARK_SERVER_URL)) + .get(format!("{}users/me", USER_BENCHMARK_SERVER_URL)) .header(reqwest::header::USER_AGENT, "burnbench") .header(reqwest::header::CONTENT_TYPE, "application/json") .header( @@ -77,7 +76,7 @@ pub(crate) fn get_username(access_token: &str) -> Option { } fn auth() -> Option { - let mut flow = match DeviceFlow::start(CLIENT_ID, None) { + let mut flow = match DeviceFlow::start(CLIENT_ID, None, None) { Ok(flow) => flow, Err(e) => { eprintln!("Error authenticating: {}", e); @@ -142,10 +141,7 @@ fn refresh_tokens(tokens: &Tokens) -> Option { println!("Refreshing token..."); let client = reqwest::blocking::Client::new(); let response = client - .post(format!( - "{}auth/refresh-token", - super::USER_BENCHMARK_SERVER_URL - )) + .post(format!("{}auth/refresh-token", USER_BENCHMARK_SERVER_URL)) .header(reqwest::header::USER_AGENT, "burnbench") .header(reqwest::header::CONTENT_TYPE, "application/json") .header( @@ -189,6 +185,8 @@ fn save_tokens(tokens: &Tokens) { #[cfg(test)] use serial_test::serial; +use crate::burnbenchapp::{auth::github_device_flow::DeviceFlow, USER_BENCHMARK_SERVER_URL}; + #[cfg(test)] mod tests { use super::*; diff --git a/backend-comparison/src/burnbenchapp/auth/github_device_flow.rs b/backend-comparison/src/burnbenchapp/auth/github_device_flow.rs new file mode 100644 index 0000000000..55aa00f73e --- /dev/null +++ b/backend-comparison/src/burnbenchapp/auth/github_device_flow.rs @@ -0,0 +1,232 @@ +// Initially from: https://github.com/jakewilkins/gh-device-flow +use std::collections::HashMap; +use std::{fmt, result::Result, thread, time}; + +use chrono::offset::Utc; +use chrono::{DateTime, Duration}; +use serde::{Deserialize, Serialize}; + +pub fn credential_error(msg: String) -> DeviceFlowError { + DeviceFlowError::GitHubError(msg) +} + +pub fn send_request( + device_flow: &mut DeviceFlow, + url: String, + body: String, +) -> Option> { + let client = reqwest::blocking::Client::new(); + let response_struct = client + .post(&url) + .header("Accept", "application/json") + .body(body) + .send(); + + match response_struct { + Ok(resp) => match resp.json::>() { + Ok(hm) => Some(hm), + Err(err) => { + device_flow.state = DeviceFlowState::Failure(err.into()); + None + } + }, + Err(err) => { + device_flow.state = DeviceFlowState::Failure(err.into()); + None + } + } +} + +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct Credential { + pub token: String, + pub expiry: String, + pub refresh_token: String, +} + +impl Credential { + fn empty() -> Credential { + Credential { + token: String::new(), + expiry: String::new(), + refresh_token: String::new(), + } + } +} + +#[derive(Debug, Clone)] +pub enum DeviceFlowError { + HttpError(String), + GitHubError(String), +} + +impl fmt::Display for DeviceFlowError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + DeviceFlowError::HttpError(string) => write!(f, "DeviceFlowError: {}", string), + DeviceFlowError::GitHubError(string) => write!(f, "DeviceFlowError: {}", string), + } + } +} + +impl std::error::Error for DeviceFlowError {} + +impl From for DeviceFlowError { + fn from(e: reqwest::Error) -> Self { + DeviceFlowError::HttpError(format!("{:?}", e)) + } +} + +#[derive(Debug, Clone)] +pub enum DeviceFlowState { + Pending, + Processing(time::Duration), + Success(Credential), + Failure(DeviceFlowError), +} + +#[derive(Clone)] +pub struct DeviceFlow { + pub host: String, + pub client_id: String, + pub scope: String, + pub user_code: Option, + pub device_code: Option, + pub verification_uri: Option, + pub state: DeviceFlowState, +} + +const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0); + +impl DeviceFlow { + pub fn new(client_id: &str, maybe_host: Option<&str>, scope: Option<&str>) -> Self { + Self { + client_id: String::from(client_id), + scope: match scope { + Some(string) => String::from(string), + None => String::new(), + }, + host: match maybe_host { + Some(string) => String::from(string), + None => String::from("github.com"), + }, + user_code: None, + device_code: None, + verification_uri: None, + state: DeviceFlowState::Pending, + } + } + + pub fn start( + client_id: &str, + maybe_host: Option<&str>, + scope: Option<&str>, + ) -> Result { + let mut flow = DeviceFlow::new(client_id, maybe_host, scope); + + flow.setup(); + + match flow.state { + DeviceFlowState::Processing(_) => Ok(flow.to_owned()), + DeviceFlowState::Failure(err) => Err(err), + _ => Err(credential_error( + "Something truly unexpected happened".into(), + )), + } + } + + pub fn setup(&mut self) { + let body = format!("client_id={}&scope={}", &self.client_id, &self.scope); + let entry_url = format!("https://{}/login/device/code", &self.host); + + if let Some(res) = send_request(self, entry_url, body) { + if res.contains_key("error") && res.contains_key("error_description") { + self.state = DeviceFlowState::Failure(credential_error( + res["error_description"].as_str().unwrap().into(), + )) + } else if res.contains_key("error") { + self.state = DeviceFlowState::Failure(credential_error(format!( + "Error response: {:?}", + res["error"].as_str().unwrap() + ))) + } else { + self.user_code = Some(String::from(res["user_code"].as_str().unwrap())); + self.device_code = Some(String::from(res["device_code"].as_str().unwrap())); + self.verification_uri = + Some(String::from(res["verification_uri"].as_str().unwrap())); + self.state = DeviceFlowState::Processing(FIVE_SECONDS); + } + }; + } + + pub fn poll(&mut self, iterations: u32) -> Result { + for count in 0..iterations { + self.update(); + + if let DeviceFlowState::Processing(interval) = self.state { + if count == iterations { + return Err(credential_error("Max poll iterations reached".into())); + } + + thread::sleep(interval); + } else { + break; + } + } + + match &self.state { + DeviceFlowState::Success(cred) => Ok(cred.to_owned()), + DeviceFlowState::Failure(err) => Err(err.to_owned()), + _ => Err(credential_error( + "Unable to fetch credential, sorry :/".into(), + )), + } + } + + pub fn update(&mut self) { + let poll_url = format!("https://{}/login/oauth/access_token", self.host); + let poll_payload = format!( + "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code", + self.client_id, + &self.device_code.clone().unwrap() + ); + + if let Some(res) = send_request(self, poll_url, poll_payload) { + if res.contains_key("error") { + match res["error"].as_str().unwrap() { + "authorization_pending" => {} + "slow_down" => { + if let DeviceFlowState::Processing(current_interval) = self.state { + self.state = + DeviceFlowState::Processing(current_interval + FIVE_SECONDS); + }; + } + other_reason => { + self.state = DeviceFlowState::Failure(credential_error(format!( + "Error checking for token: {}", + other_reason + ))); + } + } + } else { + let mut this_credential = Credential::empty(); + this_credential.token = res["access_token"].as_str().unwrap().to_string(); + + if let Some(expires_in) = res.get("expires_in") { + this_credential.expiry = calculate_expiry(expires_in.as_i64().unwrap()); + this_credential.refresh_token = + res["refresh_token"].as_str().unwrap().to_string(); + } + + self.state = DeviceFlowState::Success(this_credential); + } + } + } +} + +fn calculate_expiry(expires_in: i64) -> String { + let expires_in = Duration::seconds(expires_in); + let mut expiry: DateTime = Utc::now(); + expiry += expires_in; + expiry.to_rfc3339() +} diff --git a/backend-comparison/src/burnbenchapp/auth/mod.rs b/backend-comparison/src/burnbenchapp/auth/mod.rs new file mode 100644 index 0000000000..7e1e4539d7 --- /dev/null +++ b/backend-comparison/src/burnbenchapp/auth/mod.rs @@ -0,0 +1,4 @@ +mod base; +pub(crate) mod github_device_flow; + +pub(crate) use base::*; From da8de562b0f67869c8a8c629b8535f938fd317f9 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Wed, 8 Jan 2025 15:11:59 -0500 Subject: [PATCH 08/61] Fix/autotune error handling (#2670) --- Cargo.lock | 24 +- Cargo.toml | 4 +- .../src/fusion/matmul/optimization.rs | 3 +- .../burn-jit/src/kernel/conv/conv2d/base.rs | 16 +- .../burn-jit/src/kernel/conv/conv2d/col2im.rs | 20 +- .../burn-jit/src/kernel/conv/conv2d/direct.rs | 6 +- .../src/kernel/conv/conv2d/gemm/launch.rs | 12 +- .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 23 +- .../src/kernel/conv/conv2d/implicit_gemm.rs | 6 +- .../kernel/conv/conv2d/transpose_direct.rs | 6 +- .../src/kernel/conv/conv2d/tune/conv2d.rs | 40 +- .../burn-jit/src/kernel/conv/deform_conv2d.rs | 10 +- .../kernel/conv/deform_conv_transpose2d.rs | 32 +- crates/burn-jit/src/kernel/conv/error.rs | 20 + crates/burn-jit/src/kernel/conv/mod.rs | 2 + crates/burn-jit/src/kernel/matmul/base.rs | 12 +- .../burn-jit/src/kernel/matmul/tune/base.rs | 12 +- crates/burn-jit/src/kernel/reduce/base.rs | 4 +- .../src/kernel/reduce/naive/kernel.rs | 4 +- crates/burn-jit/src/kernel/reduce/prod.rs | 2 +- .../src/kernel/reduce/shared/kernel.rs | 4 +- .../src/kernel/reduce/subcube/kernel.rs | 4 +- crates/burn-jit/src/kernel/reduce/sum.rs | 2 +- crates/burn-jit/src/ops/float_ops.rs | 12 +- crates/burn-jit/src/ops/int_ops.rs | 10 +- crates/burn-jit/src/ops/module_ops.rs | 6 +- crates/burn-jit/src/tests/mod.rs | 2 - crates/burn-jit/src/tests/reduce.rs | 566 ------------------ 28 files changed, 158 insertions(+), 706 deletions(-) create mode 100644 crates/burn-jit/src/kernel/conv/error.rs delete mode 100644 crates/burn-jit/src/tests/reduce.rs diff --git a/Cargo.lock b/Cargo.lock index 20d136df9c..11ae2b3de3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1581,7 +1581,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1613,7 +1613,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1630,7 +1630,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1649,7 +1649,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1663,7 +1663,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1679,7 +1679,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1705,7 +1705,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "bytemuck", "cubecl-core", @@ -1717,7 +1717,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1732,7 +1732,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1769,7 +1769,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "async-channel", "async-lock", @@ -1790,7 +1790,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1804,7 +1804,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4d6f50f3af4c8dd664619b61e6adf437e4b09e2e#4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" +source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index f51aa31ce3..3a03582fea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/fusion/matmul/optimization.rs b/crates/burn-jit/src/fusion/matmul/optimization.rs index b1d8431c67..d0cd8749ad 100644 --- a/crates/burn-jit/src/fusion/matmul/optimization.rs +++ b/crates/burn-jit/src/fusion/matmul/optimization.rs @@ -122,7 +122,8 @@ impl MatmulOptimization { rhs_tensor, None, matmul::MatmulStrategy::default(), - ); + ) + .unwrap(); (out_tensor, out) }; context diff --git a/crates/burn-jit/src/kernel/conv/conv2d/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/base.rs index 0b3a35dc45..f015677a2b 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/base.rs @@ -1,6 +1,8 @@ use burn_tensor::ops::{ConvOptions, ConvTransposeOptions}; -use crate::{tensor::JitTensor, FloatElement, IntElement, JitRuntime}; +use crate::{ + kernel::conv::ConvLaunchError, tensor::JitTensor, FloatElement, IntElement, JitRuntime, +}; #[cfg(feature = "autotune")] use super::{conv2d_autotune, conv_transpose2d_autotune}; @@ -75,11 +77,11 @@ pub fn conv2d( bias: Option>, options: ConvOptions<2>, strategy: Conv2dStrategy, -) -> JitTensor { +) -> Result, ConvLaunchError> { match strategy { Conv2dStrategy::Direct => conv2d_direct::(input, weight, bias, options), #[cfg(feature = "autotune")] - Conv2dStrategy::Autotune => conv2d_autotune::(input, weight, bias, options), + Conv2dStrategy::Autotune => Ok(conv2d_autotune::(input, weight, bias, options)), Conv2dStrategy::Gemm => conv2d_im2col::(input, weight, bias, options), Conv2dStrategy::ImplicitGemm => conv2d_implicit_gemm::(input, weight, bias, options), Conv2dStrategy::ImplicitGemmComplex => { @@ -102,15 +104,15 @@ pub fn conv_transpose2d( bias: Option>, options: ConvTransposeOptions<2>, strategy: ConvTranspose2dStrategy, -) -> JitTensor { +) -> Result, ConvLaunchError> { match strategy { ConvTranspose2dStrategy::Direct => { conv_transpose2d_direct::(input, weight, bias, options) } #[cfg(feature = "autotune")] - ConvTranspose2dStrategy::Autotune => { - conv_transpose2d_autotune::(input, weight, bias, options) - } + ConvTranspose2dStrategy::Autotune => Ok(conv_transpose2d_autotune::( + input, weight, bias, options, + )), ConvTranspose2dStrategy::Gemm => { conv_transpose2d_col2im::(input, weight, bias, options) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 0d9c48dc30..11fb3b4aee 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -6,6 +6,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ kernel::{ + conv::ConvLaunchError, into_contiguous, matmul::{matmul, MatmulStrategy}, slice, @@ -29,7 +30,7 @@ pub fn conv_transpose2d_col2im( weight: JitTensor, bias: Option>, options: ConvTransposeOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.shape.dims(); let [batch_size, _, input_h, input_w] = input.shape.dims(); let groups = options.groups; @@ -94,9 +95,12 @@ pub fn conv_transpose2d_col2im( options.clone(), kernel_h, kernel_w, - ); + )?; } - reshape(image, Shape::new([batch_size, im_channels, im_h, im_w])) + Ok(reshape( + image, + Shape::new([batch_size, im_channels, im_h, im_w]), + )) } else { let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]); let image = empty_device::(input.client.clone(), input.device.clone(), im_shape); @@ -108,8 +112,8 @@ pub fn conv_transpose2d_col2im( options, kernel_h, kernel_w, - ); - image + )?; + Ok(image) } } @@ -135,7 +139,7 @@ fn execute( options: ConvTransposeOptions<2>, kernel_h: usize, kernel_w: usize, -) { +) -> Result<(), ConvLaunchError> { let [batch_size, _, input_h, input_w] = input.shape.dims(); let [groups, col_shape_0, input_ch_per_group] = weight.shape.dims(); @@ -145,12 +149,14 @@ fn execute( let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]); let input = reshape(input, input_shape); - let columns = matmul::(weight, input, None, MatmulStrategy::default()); + let columns = matmul::(weight, input, None, MatmulStrategy::default())?; let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); col2im::( columns, bias, image, kernel_h, kernel_w, input_h, input_w, options, ); + + Ok(()) } #[allow(clippy::too_many_arguments)] diff --git a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs index d5154ecc4b..c724cfc3a3 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs @@ -5,7 +5,7 @@ use burn_tensor::{ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ - kernel::into_contiguous, + kernel::{conv::ConvLaunchError, into_contiguous}, ops::{ numeric::{empty_device, zeros_device}, reshape, @@ -125,7 +125,7 @@ pub fn conv2d_direct( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [batch_size, _, in_height, in_width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let channels_per_group = out_channels / options.groups; @@ -193,5 +193,5 @@ pub fn conv2d_direct( kernel_w_unroll, ); - output + Ok(output) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs index c99861c82d..abc94d1a9a 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs @@ -23,7 +23,7 @@ use crate::{ algorithm::{Algorithm, ImplicitCmmaConv}, base::{ConvolutionLaunch, ConvolutionProblem}, }, - nchw_to_nhwc, Conv2dAutotuneKey, + nchw_to_nhwc, Conv2dAutotuneKey, ConvLaunchError, }, into_contiguous, }, @@ -44,7 +44,7 @@ pub fn conv2d_gemm_cmma_large_m( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { conv2d_gemm_cmma_strategy::(input, weight, bias, options) } @@ -60,7 +60,7 @@ pub fn conv2d_gemm_cmma_balanced( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { conv2d_gemm_cmma_strategy::(input, weight, bias, options) } @@ -74,7 +74,7 @@ fn conv2d_gemm_cmma_strategy< weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { if TypeId::of::() == TypeId::of::() { conv2d_gemm_with_algo::(input, weight, bias, options) } else if TypeId::of::() == TypeId::of::() || TypeId::of::() == TypeId::of::() @@ -102,7 +102,7 @@ pub fn conv2d_gemm_with_algo< weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor +) -> Result, ConvLaunchError> where SP::EG: JitElement, { @@ -221,7 +221,7 @@ where // Reset to NCHW let out = reshape(out, Shape::new([batch_size, out_h, out_w, out_channels])); - permute(out, &[0, 3, 1, 2]) + Ok(permute(out, &[0, 3, 1, 2])) } pub fn problem_from_key( diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index a65c29466c..7f9914989a 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -6,7 +6,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ kernel::{ - conv::index, + conv::{index, ConvLaunchError}, into_contiguous, launch_binop, matmul::{matmul, MatmulStrategy}, AddOp, @@ -188,7 +188,7 @@ pub fn conv2d_im2col( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [batch_size, in_channels, in_height, in_width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let groups = options.groups; @@ -237,13 +237,13 @@ pub fn conv2d_im2col( options.clone(), out_h, out_w, - ); + )?; } let out = swap_dims(out, 1, 2); reshape(out, Shape::new([batch_size, out_channels, out_h, out_w])) } else { let out = empty_device::(input.client.clone(), input.device.clone(), matmul_shape); - execute::(input, weight, out.clone(), options, out_h, out_w); + execute::(input, weight, out.clone(), options, out_h, out_w)?; let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); swap_dims(out, 0, 1) }; @@ -252,7 +252,8 @@ pub fn conv2d_im2col( let bias = reshape(bias, Shape::new([1, out_channels, 1, 1])); out = launch_binop::(out, bias) } - out + + Ok(out) } fn execute_1x1_kernel( @@ -260,7 +261,7 @@ fn execute_1x1_kernel( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let [batch_size, _, height, width] = input.shape.dims(); let [out_channels, in_c_per_grp, _, _] = weight.shape.dims(); let groups = options.groups; @@ -271,7 +272,7 @@ fn execute_1x1_kernel( let weight = reshape(weight, Shape::new([groups, out_c_per_grp, in_c_per_grp])); let in_shape = Shape::new([groups, in_c_per_grp, batch_size * height * width]); let input = reshape(input, in_shape); - let out = matmul::(weight, input, None, MatmulStrategy::default()); + let out = matmul::(weight, input, None, MatmulStrategy::default())?; let mut out = reshape(out, Shape::new([out_channels, batch_size, height, width])); if let Some(bias) = bias { @@ -279,7 +280,7 @@ fn execute_1x1_kernel( out = launch_binop::(out, bias) } - swap_dims(out, 0, 1) + Ok(swap_dims(out, 0, 1)) } fn execute( @@ -289,7 +290,7 @@ fn execute( options: ConvOptions<2>, out_h: usize, out_w: usize, -) { +) -> Result<(), ConvLaunchError> { let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); let groups = options.groups; @@ -301,5 +302,7 @@ fn execute( let columns = reshape(columns, Shape::new([groups, col_shape_0, col_shape_1])); let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0])); - matmul::(weight, columns, Some(out), Default::default()); + matmul::(weight, columns, Some(out), Default::default())?; + + Ok(()) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 2e4f469068..2e8e147170 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -12,7 +12,7 @@ use cubecl::{ use half::f16; use crate::{ - kernel::{into_contiguous, slice, slice_assign}, + kernel::{conv::ConvLaunchError, into_contiguous, slice, slice_assign}, ops::{ numeric::{empty_device, zeros_device}, permute, @@ -35,7 +35,7 @@ pub fn conv2d_implicit_gemm( weight: JitTensor, bias: Option>, options: ConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let is_tf32 = F::as_elem_native_unchecked() == Elem::Float(FloatKind::F32) && input .client @@ -210,7 +210,7 @@ pub fn conv2d_implicit_gemm( let out = slice::(out, &[0..batch_size, 0..out_h, 0..out_w, 0..out_channels]); // Reset to NCHW - permute(out, &[0, 3, 1, 2]) + Ok(permute(out, &[0, 3, 1, 2])) } fn find_common_vec(channels: usize, elems_per_thread: u32, supported_vecs: &[u8]) -> u8 { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index 6a97ab8759..d3e91d5947 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -2,7 +2,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ element::JitElement, - kernel::into_contiguous, + kernel::{conv::ConvLaunchError, into_contiguous}, ops::{ numeric::{empty_device, zeros_device}, reshape, @@ -126,7 +126,7 @@ pub fn conv_transpose2d_direct( weight: JitTensor, bias: Option>, options: ConvTransposeOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let input = into_contiguous(input); let weight = into_contiguous(weight); let [batch_size, _, in_height, in_width] = input.shape.dims(); @@ -184,5 +184,5 @@ pub fn conv_transpose2d_direct( ), ); - output + Ok(output) } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs index 157d4d443d..c6eb31ea9c 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs @@ -97,34 +97,6 @@ macro_rules! check_algo { _ => can_launch::<$algo, R, ($float, $float, f32)>($input, $problem), } }; - - ($algo:tt, $input:expr, $problem:expr) => { - let plane_dim = 32; - let conv_problem = $problem; - - let (selection, config_input) = $algo::select_kernel::(plane_dim); - let cube_dim = ImplicitCmmaConv::cube_dim(&selection); - let cube_count = ImplicitCmmaConv::cube_count(&selection, &conv_problem); - - let advanced_config = Default::default(); - let config = ImplicitCmmaConv::make_config( - config_input, - &conv_problem, - &cube_dim, - &cube_count, - &advanced_config, - ); - - match config { - Ok(config) => ImplicitCmmaConv::can_launch::( - &op.input.client, - &conv_problem, - &config, - &selection, - ), - Err(_) => false, - } - }; } fn should_run( @@ -180,13 +152,21 @@ fn can_launch, R: JitRuntime, CS: ConvPrecisio input: &JitTensor, conv_problem: &ConvolutionProblem, ) -> bool { - let plane_dim = 32; + let plane_dim = match input + .client + .properties() + .hardware_properties() + .defined_plane_size() + { + Some(val) => val, + None => return false, + }; let (selection, config_input) = S::select_kernel::(plane_dim); let cube_dim = ImplicitCmmaConv::cube_dim(&selection); let cube_count = ImplicitCmmaConv::cube_count(&selection, conv_problem); - let advanced_config = Default::default(); + let config = ImplicitCmmaConv::make_config( config_input, conv_problem, diff --git a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs index b22821aef1..300d714335 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs @@ -19,6 +19,8 @@ use crate::{ FloatElement, JitRuntime, }; +use super::ConvLaunchError; + #[derive(CubeLaunch)] struct DeformConv2dArgs { conv_stride_h: u32, @@ -262,7 +264,7 @@ pub(crate) fn deform_conv2d( mask: Option>, bias: Option>, options: DeformConvOptions<2>, -) -> JitTensor { +) -> Result, ConvLaunchError> { let input = into_contiguous(input); let offset = into_contiguous(offset); let weight = into_contiguous(weight); @@ -298,15 +300,15 @@ pub(crate) fn deform_conv2d( let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0])); let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); - let out = matmul::(weight, columns, None, MatmulStrategy::default()); + let out = matmul::(weight, columns, None, MatmulStrategy::default())?; let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); let out = swap_dims(out, 0, 1); if let Some(bias) = bias { let bias = reshape(bias, Shape::new([1, out_channels, 1, 1])); - launch_binop::(out, bias) + Ok(launch_binop::(out, bias)) } else { - out + Ok(out) } } diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index b75ac43182..ad9e11c6c5 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -19,7 +19,7 @@ use crate::{ FloatElement, IntElement, JitBackend, JitRuntime, }; -use super::{bilinear_interpolate, deform_im2col, index}; +use super::{bilinear_interpolate, deform_im2col, index, ConvLaunchError}; /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. #[allow(clippy::single_range_in_vec_init)] @@ -36,7 +36,7 @@ pub(crate) fn deform_conv2d_backward< bias: Option>, out_grad: JitTensor, options: DeformConvOptions<2>, -) -> DeformConv2dBackward> { +) -> Result>, ConvLaunchError> { let [_, _, out_h, out_w] = out_grad.shape.dims(); let [_, _, kernel_h, kernel_w] = weight.shape.dims(); @@ -60,7 +60,7 @@ pub(crate) fn deform_conv2d_backward< out_grad.clone(), &options, (kernel_h, kernel_w), - ); + )?; let weight_grad = compute_weight_grad::( input, @@ -70,15 +70,15 @@ pub(crate) fn deform_conv2d_backward< options, (kernel_h, kernel_w), (out_h, out_w), - ); + )?; - DeformConv2dBackward::new( + Ok(DeformConv2dBackward::new( input_gradient, offset_gradient, weight_grad, mask_gradient, gradient_bias, - ) + )) } fn compute_weight_grad( @@ -89,7 +89,7 @@ fn compute_weight_grad( options: DeformConvOptions<2>, kernel_dims: (usize, usize), out_dims: (usize, usize), -) -> JitTensor { +) -> Result, ConvLaunchError> { let [_, in_channels, _, _] = input.shape.dims(); let [_, out_channels, _, _] = out_grad.shape.dims(); let (kernel_h, kernel_w) = kernel_dims; @@ -108,12 +108,12 @@ fn compute_weight_grad( let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); let columns = swap_dims(columns, 1, 2); - let grad_weight = matmul::(out_grad, columns, None, MatmulStrategy::default()); + let grad_weight = matmul::(out_grad, columns, None, MatmulStrategy::default())?; - reshape( + Ok(reshape( grad_weight, Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]), - ) + )) } type InputGradients = (JitTensor, JitTensor, Option>); @@ -126,7 +126,7 @@ fn backward_gradient_inputs( out_grad: JitTensor, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), -) -> InputGradients { +) -> Result, ConvLaunchError> { let client = out_grad.client.clone(); let device = out_grad.device.clone(); @@ -150,7 +150,7 @@ fn backward_gradient_inputs( for group in 0..groups { let weight = swap_dims(index::(weight.clone(), group), 0, 1); let out_grad = index::(out_grad.clone(), group); - let values = matmul::(weight, out_grad, None, MatmulStrategy::default()); + let values = matmul::(weight, out_grad, None, MatmulStrategy::default())?; let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1])); columns = slice_assign::( columns, @@ -169,12 +169,12 @@ fn backward_gradient_inputs( mask.clone(), options, kernel_dims, - ); + )?; let input_gradient = compute_input_grad::(columns, offset, mask, options, kernel_dims, input_shape); - (input_gradient, offset_gradient, mask_gradient) + Ok((input_gradient, offset_gradient, mask_gradient)) } fn compute_offset_and_mask_gradient( @@ -184,7 +184,7 @@ fn compute_offset_and_mask_gradient( mask: Option>, options: &DeformConvOptions<2>, kernel_dims: (usize, usize), -) -> (JitTensor, Option>) { +) -> Result<(JitTensor, Option>), ConvLaunchError> { let client = offset.client.clone(); let device = offset.device.clone(); let (kernel_height, kernel_width) = kernel_dims; @@ -238,7 +238,7 @@ fn compute_offset_and_mask_gradient( }; let mask_gradient = if use_mask { Some(grad_mask) } else { None }; - (grad_offset, mask_gradient) + Ok((grad_offset, mask_gradient)) } #[derive(CubeLaunch)] diff --git a/crates/burn-jit/src/kernel/conv/error.rs b/crates/burn-jit/src/kernel/conv/error.rs new file mode 100644 index 0000000000..2f15bc9886 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/error.rs @@ -0,0 +1,20 @@ +use cubecl::{linalg::matmul::kernels::MatmulLaunchError, tune::AutotuneError}; + +#[derive(Debug)] +pub enum ConvLaunchError { + Matmul(MatmulLaunchError), + Unknown, +} + +impl From for ConvLaunchError { + fn from(value: MatmulLaunchError) -> Self { + Self::Matmul(value) + } +} + +#[allow(clippy::from_over_into)] +impl Into for ConvLaunchError { + fn into(self) -> AutotuneError { + AutotuneError::Unknown(format!("{self:?}")) + } +} diff --git a/crates/burn-jit/src/kernel/conv/mod.rs b/crates/burn-jit/src/kernel/conv/mod.rs index 5d6794495f..04794e9b42 100644 --- a/crates/burn-jit/src/kernel/conv/mod.rs +++ b/crates/burn-jit/src/kernel/conv/mod.rs @@ -3,11 +3,13 @@ mod conv3d; mod conv_transpose3d; mod deform_conv2d; mod deform_conv_transpose2d; +mod error; pub(crate) use conv2d::*; pub(crate) use conv3d::*; pub(crate) use conv_transpose3d::*; pub(crate) use deform_conv2d::*; pub(crate) use deform_conv_transpose2d::*; +pub(crate) use error::*; pub use conv2d::{conv2d, conv_transpose2d, nchw_to_nhwc, Conv2dStrategy, ConvTranspose2dStrategy}; diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 7fa141cf67..611f1e32d4 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -1,3 +1,5 @@ +use cubecl::linalg::matmul::kernels::MatmulLaunchError; + use super::init_matmul_output; use crate::{tensor::JitTensor, FloatElement, JitRuntime}; @@ -30,7 +32,7 @@ pub fn matmul( rhs: JitTensor, out: Option>, strategy: MatmulStrategy, -) -> JitTensor { +) -> Result, MatmulLaunchError> { match strategy { MatmulStrategy::Cube => { let out = out.unwrap_or_else(|| init_matmul_output::(&lhs, &rhs)); @@ -43,11 +45,11 @@ pub fn matmul( &lhs.as_handle_ref(), &rhs.as_handle_ref(), &out.as_handle_ref(), - ) - .unwrap(); - out + )?; + + Ok(out) } #[cfg(feature = "autotune")] - MatmulStrategy::Autotune => matmul_autotune::(lhs, rhs, out), + MatmulStrategy::Autotune => Ok(matmul_autotune::(lhs, rhs, out)), } } diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 3f3232db10..46b1dfacc6 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -83,7 +83,7 @@ fn matmul_accelerated( lhs: JitTensor, rhs: JitTensor, out: JitTensor, -) { +) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Standard, &lhs.client, @@ -91,14 +91,14 @@ fn matmul_accelerated( &rhs.as_handle_ref(), &out.as_handle_ref(), ) - .unwrap(); + .map_err(|err| format!("{err:?}")) } fn matmul_tiling2d( lhs: JitTensor, rhs: JitTensor, out: JitTensor, -) { +) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Tiling2D(Tiling2dConfig::default()), &lhs.client, @@ -106,14 +106,14 @@ fn matmul_tiling2d( &rhs.as_handle_ref(), &out.as_handle_ref(), ) - .unwrap(); + .map_err(|err| format!("{err:?}")) } fn matmul_simple( lhs: JitTensor, rhs: JitTensor, out: JitTensor, -) { +) -> Result<(), String> { cubecl::linalg::matmul::launch_ref::( &Strategy::Simple, &lhs.client, @@ -121,5 +121,5 @@ fn matmul_simple( &rhs.as_handle_ref(), &out.as_handle_ref(), ) - .unwrap(); + .map_err(|err| format!("{err:?}")) } diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs index 730cc83f37..57cdf13b1e 100644 --- a/crates/burn-jit/src/kernel/reduce/base.rs +++ b/crates/burn-jit/src/kernel/reduce/base.rs @@ -63,13 +63,13 @@ macro_rules! reduce_operation { tensor: JitTensor, dim: usize, strategy: ReduceStrategy, - ) -> JitTensor { + ) -> Result, String> { match strategy { ReduceStrategy::Naive => reduce_dim_naive::<$ops, R, EI, EO>(tensor, dim), ReduceStrategy::SharedMemory => reduce_dim_shared::<$ops, R, EI, EO>(tensor, dim), ReduceStrategy::Subcube => reduce_dim_subcube::<$ops, R, EI, EO>(tensor, dim), #[cfg(feature = "autotune")] - ReduceStrategy::Autotune => reduce_dim_autotune::<$ops, R, EI, EO>(tensor, dim), + ReduceStrategy::Autotune => Ok(reduce_dim_autotune::<$ops, R, EI, EO>(tensor, dim)), } } }; diff --git a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs b/crates/burn-jit/src/kernel/reduce/naive/kernel.rs index a3a1a5441b..c862e7070d 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/kernel.rs @@ -50,7 +50,7 @@ fn naive_reduce, EI: Numeric, EO: Numeric>( pub fn reduce_dim_naive( input: JitTensor, dim: usize, -) -> JitTensor { +) -> Result, String> { let output = init_reduce_output::(&input, dim); let cube_dim = CubeDim::default(); @@ -67,5 +67,5 @@ pub fn reduce_dim_naive( let shape = Shape::new([input.shape.num_elements()]); let input: JitTensor = JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - prod_dim::(input, 0, strategy) + prod_dim::(input, 0, strategy).unwrap() } diff --git a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs b/crates/burn-jit/src/kernel/reduce/shared/kernel.rs index 1b2dcb356e..1c15e4523f 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/kernel.rs @@ -85,7 +85,7 @@ pub fn reduce_dim_shared< >( input: JitTensor, dim: usize, -) -> JitTensor { +) -> Result, String> { let output = init_reduce_output::(&input, dim); let num_elems_output = output.shape.num_elements(); @@ -113,5 +113,5 @@ pub fn reduce_dim_shared< divisible_shape, ); - output + Ok(output) } diff --git a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs index 4a32b5d641..26f65f5d68 100644 --- a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs +++ b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs @@ -88,7 +88,7 @@ pub fn reduce_dim_subcube< >( input: JitTensor, dim: usize, -) -> JitTensor { +) -> Result, String> { let topology = input.client.properties().hardware_properties(); if !input.client.properties().feature_enabled(Feature::Plane) @@ -130,5 +130,5 @@ pub fn reduce_dim_subcube< divisible_shape, ); - output + Ok(output) } diff --git a/crates/burn-jit/src/kernel/reduce/sum.rs b/crates/burn-jit/src/kernel/reduce/sum.rs index fea80bccf0..d3c9416dc1 100644 --- a/crates/burn-jit/src/kernel/reduce/sum.rs +++ b/crates/burn-jit/src/kernel/reduce/sum.rs @@ -11,5 +11,5 @@ pub fn sum( let shape = Shape::new([input.shape.num_elements()]); let input: JitTensor = JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - sum_dim::(input, 0, strategy) + sum_dim::(input, 0, strategy).unwrap() } diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 6090e895ae..c59b9df83c 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -165,7 +165,7 @@ where execute_with_dtype!( float(lhs.dtype, rhs.dtype), E, - matmul::(lhs, rhs, None, MatmulStrategy::default()) + matmul::(lhs, rhs, None, MatmulStrategy::default()).unwrap() ) } @@ -363,7 +363,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::sum_dim::(tensor, dim, Default::default()) + reduce::sum_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -371,7 +371,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::mean_dim::(tensor, dim, Default::default()) + reduce::mean_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -387,7 +387,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::prod_dim::(tensor, dim, Default::default()) + reduce::prod_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -467,7 +467,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::argmax::(tensor, dim, Default::default()) + reduce::argmax::(tensor, dim, Default::default()).unwrap() ) } @@ -475,7 +475,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::argmin::(tensor, dim, Default::default()) + reduce::argmin::(tensor, dim, Default::default()).unwrap() ) } diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index a0e181a9c7..ed99258826 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -197,7 +197,7 @@ where } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::sum_dim::(tensor, dim, Default::default()) + kernel::reduce::sum_dim::(tensor, dim, Default::default()).unwrap() } fn int_prod(tensor: IntTensor) -> IntTensor { @@ -205,19 +205,19 @@ where } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::prod_dim::(tensor, dim, Default::default()) + kernel::reduce::prod_dim::(tensor, dim, Default::default()).unwrap() } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::mean_dim::(tensor, dim, Default::default()) + kernel::reduce::mean_dim::(tensor, dim, Default::default()).unwrap() } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmax::(tensor, dim, Default::default()) + kernel::reduce::argmax::(tensor, dim, Default::default()).unwrap() } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmin::(tensor, dim, Default::default()) + kernel::reduce::argmin::(tensor, dim, Default::default()).unwrap() } fn int_clamp( diff --git a/crates/burn-jit/src/ops/module_ops.rs b/crates/burn-jit/src/ops/module_ops.rs index b5c96058f9..c7f7b18b32 100644 --- a/crates/burn-jit/src/ops/module_ops.rs +++ b/crates/burn-jit/src/ops/module_ops.rs @@ -25,7 +25,7 @@ where bias: Option>, options: ConvOptions<2>, ) -> FloatTensor { - kernel::conv::conv2d::(x, weight, bias, options, Conv2dStrategy::default()) + kernel::conv::conv2d::(x, weight, bias, options, Conv2dStrategy::default()).unwrap() } fn deform_conv2d( @@ -36,7 +36,7 @@ where bias: Option>, options: DeformConvOptions<2>, ) -> FloatTensor { - kernel::conv::deform_conv2d::(x, offset, weight, mask, bias, options) + kernel::conv::deform_conv2d::(x, offset, weight, mask, bias, options).unwrap() } fn deform_conv2d_backward( @@ -57,6 +57,7 @@ where output_grad, options, ) + .unwrap() } fn conv3d( @@ -81,6 +82,7 @@ where options, ConvTranspose2dStrategy::default(), ) + .unwrap() } fn conv_transpose3d( diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index f60edc2a1b..378eb035ed 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -17,7 +17,6 @@ mod max_pool2d; mod max_pool2d_backward; mod normal; mod quantization; -mod reduce; mod repeat_dim; mod scatter; mod select; @@ -48,7 +47,6 @@ macro_rules! testgen_all { mod kernel { use super::*; - burn_jit::testgen_reduction!(); burn_jit::testgen_conv2d!(); burn_jit::testgen_conv3d!(); burn_jit::testgen_conv_transpose2d!(); diff --git a/crates/burn-jit/src/tests/reduce.rs b/crates/burn-jit/src/tests/reduce.rs deleted file mode 100644 index 3e8f81fa8c..0000000000 --- a/crates/burn-jit/src/tests/reduce.rs +++ /dev/null @@ -1,566 +0,0 @@ -#[burn_tensor_testgen::testgen(reduction)] -mod reduction { - use super::*; - use burn_jit::kernel::reduce::{ - argmax, argmin, mean_dim, prod, prod_dim, sum, sum_dim, ReduceStrategy, - }; - use burn_tensor::{ - backend::Backend, ops::IntTensorOps, Distribution, Int, Shape, Tensor, TensorData, - TensorPrimitive, - }; - - #[test] - fn reduction_sum_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - ))); - let val_ref = tensor_ref.sum_dim(1); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_prod_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(prod_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - ))); - let val_ref = tensor_ref.prod_dim(1); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_argmin_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(argmin::( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - )); - let val_ref = tensor_ref.argmin(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_argmax_dim_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(argmax::( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Naive, - )); - let val_ref = tensor_ref.argmax(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn sum_dim_should_work_with_int() { - let summed_shape = Shape::new([1]); - let data = TensorData::from([1, 2, 3, 4]); - let tensor = TestBackend::int_from_data(data, &Default::default()); - - let val = Tensor::::from_primitive(sum_dim::( - tensor, - 0, - ReduceStrategy::Naive, - )); - - let sum_as_data = TensorData::from([10]); - val.into_data().assert_approx_eq(&sum_as_data, 1); - } - - #[test] - fn mean_dim_should_work_with_int() { - let mean_shape = Shape::new([1]); - let data = TensorData::from([1, 2, 3, 4]); - let tensor = TestBackend::int_from_data(data, &Default::default()); - - let val = Tensor::::from_primitive(mean_dim::( - tensor, - 0, - ReduceStrategy::Naive, - )); - - // Mean calculation truncates to an integer - let mean_as_data = TensorData::from([2]); - val.into_data().assert_approx_eq(&mean_as_data, 1); - } - - #[test] - fn reduction_sum_dim_shared_memory_small() { - let tensor = - Tensor::::random([700], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_subcube_small() { - let tensor = - Tensor::::random([700], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_shared_memory_medium_divisible() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_subcube_medium_divisible() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_shared_memory_medium_not_divisible() { - let tensor = - Tensor::::random([6, 1025], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_subcube_medium_not_divisible() { - let tensor = - Tensor::::random([6, 1025], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_shared_memory_large() { - let tensor = Tensor::::random( - [4, 1024, 50], - Distribution::Default, - &Default::default(), - ); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_sum_dim_subcube_large() { - let tensor = Tensor::::random( - [4, 1024, 50], - Distribution::Default, - &Default::default(), - ); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.sum_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_mean_dim_shared_memory_medium() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(mean_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.mean_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_mean_dim_subcube_medium() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 0; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(mean_dim::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.mean_dim(reduce_dim); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_argmin_shared_memory_medium() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmin::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.argmin(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_argmin_subcube_medium() { - let tensor = - Tensor::::random([6, 1024], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmin::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.argmin(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_argmax_shared_memory_medium() { - let tensor = Tensor::::random( - [10, 3000], - Distribution::Default, - &Default::default(), - ); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmax::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::SharedMemory, - ))); - let val_ref = tensor_ref.argmax(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_argmax_subcube_medium() { - let tensor = Tensor::::random( - [10, 3000], - Distribution::Default, - &Default::default(), - ); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - let reduce_dim = 1; - - let val = Tensor::::from_primitive(TensorPrimitive::Float(argmax::< - TestRuntime, - f32, - f32, - >( - tensor.into_primitive().tensor(), - reduce_dim, - ReduceStrategy::Subcube, - ))); - let val_ref = tensor_ref.argmax(reduce_dim); - - val_ref.into_data().assert_eq(&val.into_data(), false); - } - - #[test] - fn reduction_sum_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 256], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - - let val = Tensor::::from_primitive(TensorPrimitive::Float(sum::< - _, - ::FloatElem, - >( - tensor.into_primitive().tensor(), - ReduceStrategy::default(), - ))); - let val_ref = tensor_ref.sum(); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_prod_should_match_reference_backend() { - let tensor = - Tensor::::random([6, 256], Distribution::Default, &Default::default()); - let tensor_ref = - Tensor::::from_data(tensor.to_data(), &Default::default()); - - let val = Tensor::::from_primitive(TensorPrimitive::Float(prod::< - _, - ::FloatElem, - >( - tensor.into_primitive().tensor(), - ReduceStrategy::default(), - ))); - let val_ref = tensor_ref.prod(); - - val_ref.into_data().assert_approx_eq(&val.into_data(), 2); - } - - #[test] - fn reduction_argmax_shared_memory_extreme_values_float() { - let data = TensorData::from([-999999., -999997., -999998.]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmax::( - tensor.into_primitive().tensor(), - 0, - ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 1, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); - } - - #[test] - fn reduction_argmin_shared_memory_extreme_values_float() { - let data = TensorData::from([999999., 999998., 999997.]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmin::( - tensor.into_primitive().tensor(), - 0, - ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 2, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); - } - - #[test] - fn reduction_argmin_shared_memory_extreme_values_i32() { - let data = TensorData::from([999999, 999998, 999997]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmin::( - tensor.into_primitive(), - 0, - ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 2, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); - } - - #[test] - fn reduction_argmax_shared_memory_extreme_values_i32() { - let data = TensorData::from([-999999, -999997, -999998]); - let tensor = Tensor::::from_data(data, &Default::default()); - - let val_shared = - Tensor::::from_primitive(argmax::( - tensor.into_primitive(), - 0, - ReduceStrategy::SharedMemory, - )); - - assert_eq!( - 1, - val_shared - .into_data() - .as_slice::<::IntElem>() - .unwrap()[0] - ); - } -} From 95593fc92ce45b25b25b4bd29ca0459e8edb9c17 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 10 Jan 2025 11:21:09 -0500 Subject: [PATCH 09/61] Fix parallel macro + CI (#2678) * Fix rayon issues * Fix typos * Fix for_each no std * Fix clippy --- .../src/burnbenchapp/auth/base.rs | 2 +- crates/burn-common/src/lib.rs | 3 + crates/burn-common/src/parallel.rs | 75 ++++++++++++++----- crates/burn-jit/src/fusion/on_write/ir.rs | 2 +- crates/burn-ndarray/src/ops/deform_conv.rs | 52 +++++++------ .../src/tensor/quantization/bytes.rs | 2 +- .../src/tensor/quantization/strategy.rs | 43 +++++++---- 7 files changed, 119 insertions(+), 60 deletions(-) diff --git a/backend-comparison/src/burnbenchapp/auth/base.rs b/backend-comparison/src/burnbenchapp/auth/base.rs index 1cff65f690..f956b0a204 100644 --- a/backend-comparison/src/burnbenchapp/auth/base.rs +++ b/backend-comparison/src/burnbenchapp/auth/base.rs @@ -133,7 +133,7 @@ fn verify_tokens(tokens: &Tokens) -> bool { ) .header(GITHUB_API_VERSION_HEADER, GITHUB_API_VERSION) .send(); - response.map_or(false, |resp| resp.status().is_success()) + response.is_ok_and(|resp| resp.status().is_success()) } fn refresh_tokens(tokens: &Tokens) -> Option { diff --git a/crates/burn-common/src/lib.rs b/crates/burn-common/src/lib.rs index 77faa0195d..efe3b6d7d2 100644 --- a/crates/burn-common/src/lib.rs +++ b/crates/burn-common/src/lib.rs @@ -11,6 +11,9 @@ pub mod id; pub use cubecl_common::*; +#[cfg(feature = "rayon")] +pub use rayon; + extern crate alloc; /// Network utilities. diff --git a/crates/burn-common/src/parallel.rs b/crates/burn-common/src/parallel.rs index 93c0da7d2d..969683f9e7 100644 --- a/crates/burn-common/src/parallel.rs +++ b/crates/burn-common/src/parallel.rs @@ -1,51 +1,90 @@ /// Macro for running a function in parallel. +#[cfg(feature = "rayon")] #[macro_export(local_inner_macros)] macro_rules! run_par { ( $func:expr ) => {{ - #[cfg(feature = "rayon")] - use rayon::prelude::*; + use $crate::rayon::prelude::*; - #[cfg(feature = "rayon")] #[allow(clippy::redundant_closure_call)] - let output = rayon::scope(|_| $func()); + $crate::rayon::scope(|_| $func()) + }}; +} - #[cfg(not(feature = "rayon"))] - let output = $func(); +/// Macro for running a function in parallel. +#[cfg(not(feature = "rayon"))] +#[macro_export(local_inner_macros)] +macro_rules! run_par { + ( + $func:expr + ) => {{ + $func() + }}; +} - output +/// Macro for iterating in parallel. +#[cfg(not(feature = "rayon"))] +#[macro_export(local_inner_macros)] +macro_rules! iter_par { + ( + $iter:expr + ) => {{ + $iter }}; } /// Macro for iterating in parallel. +#[cfg(feature = "rayon")] #[macro_export(local_inner_macros)] macro_rules! iter_par { ( $iter:expr ) => {{ - #[cfg(feature = "rayon")] - let output = $iter.into_par_iter(); + $iter.into_par_iter() + }}; +} - #[cfg(not(feature = "rayon"))] - let output = $iter; +/// Macro for iterating in parallel. +#[cfg(feature = "rayon")] +#[macro_export(local_inner_macros)] +macro_rules! iter_slice_par { + ( + $slice:expr + ) => {{ + $slice.into_par_iter() + }}; +} - output +/// Macro for iterating in parallel. +#[cfg(not(feature = "rayon"))] +#[macro_export(local_inner_macros)] +macro_rules! iter_slice_par { + ( + $slice:expr + ) => {{ + $slice.iter() }}; } /// Macro for iterating over a range in parallel. +#[cfg(feature = "rayon")] #[macro_export(local_inner_macros)] macro_rules! iter_range_par { ( $start:expr, $end:expr ) => {{ - #[cfg(feature = "rayon")] - let output = ($start..$end).into_par_iter(); - - #[cfg(not(feature = "rayon"))] - let output = ($start..$end); + ($start..$end).into_par_iter() + }}; +} - output +/// Macro for iterating over a range in parallel. +#[cfg(not(feature = "rayon"))] +#[macro_export(local_inner_macros)] +macro_rules! iter_range_par { + ( + $start:expr, $end:expr + ) => {{ + ($start..$end) }}; } diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index 0b6272b4f1..0cec2d29c7 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -154,7 +154,7 @@ impl GlobalArgsLaunch<'_, R> { } } - /// Resolve the [argument](Arg) to a [tensor arguemnt](TensorArg). + /// Resolve the [argument](Arg) to a [tensor argument](TensorArg). /// /// # Panics /// diff --git a/crates/burn-ndarray/src/ops/deform_conv.rs b/crates/burn-ndarray/src/ops/deform_conv.rs index ac2e25ae86..504c1a8c59 100644 --- a/crates/burn-ndarray/src/ops/deform_conv.rs +++ b/crates/burn-ndarray/src/ops/deform_conv.rs @@ -6,7 +6,7 @@ use burn_tensor::{ use core::ops::AddAssign; use ndarray::{ s, Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim, - Ix4, + Ix4, Zip, }; #[cfg(not(feature = "std"))] use num_traits::Float; @@ -593,31 +593,37 @@ pub mod backward { AtomicF32::new(0.0) }); + let compute_for_each = |(in_channel, kernel_y, kernel_x, batch, out_y, out_x), col: &F| { + let group = in_channel / channels_per_offset_group; + let offset = offset.slice(s![batch, .., out_y, out_x]); + let offset = offset + .to_shape((offs_groups, kernel_h, kernel_w, 2)) + .unwrap(); + let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); + let offset = [offset[0], offset[1]]; + let mask = mask + .as_ref() + .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32()); + let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) + - F::from_elem(args.padding[0]) + + offset[0]; + let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) + - F::from_elem(args.padding[1]) + + offset[1]; + let grad_in = grad_in.slice(s![batch, in_channel, .., ..]); + deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in); + }; + + // `for_each` expects a 2-tuple argument with `.into_par_iter()`, but 2 separate arguments otherwise + #[cfg(feature = "std")] run_par!(|| { - iter_par!(columns.indexed_iter()).for_each( - |((in_channel, kernel_y, kernel_x, batch, out_y, out_x), col)| { - let group = in_channel / channels_per_offset_group; - let offset = offset.slice(s![batch, .., out_y, out_x]); - let offset = offset - .to_shape((offs_groups, kernel_h, kernel_w, 2)) - .unwrap(); - let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); - let offset = [offset[0], offset[1]]; - let mask = mask - .as_ref() - .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32()); - let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) - - F::from_elem(args.padding[0]) - + offset[0]; - let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) - - F::from_elem(args.padding[1]) - + offset[1]; - let grad_in = grad_in.slice(s![batch, in_channel, .., ..]); - deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in); - }, - ) + iter_par!(Zip::indexed(columns)) + .for_each(|(args0, args1)| compute_for_each(args0, args1)) }); + #[cfg(not(feature = "std"))] + run_par!(|| { iter_par!(Zip::indexed(columns).for_each(compute_for_each)) }); + let grad_in: Array1 = grad_in .into_iter() .map(|it| F::from_elem(it.into_inner())) diff --git a/crates/burn-tensor/src/tensor/quantization/bytes.rs b/crates/burn-tensor/src/tensor/quantization/bytes.rs index 9091c37960..6d880cc923 100644 --- a/crates/burn-tensor/src/tensor/quantization/bytes.rs +++ b/crates/burn-tensor/src/tensor/quantization/bytes.rs @@ -100,7 +100,7 @@ impl QuantizedBytes { /// Splits the quantized values of the tensor from the quantization parameters. /// - /// Returns the packed values and a newly allocated vector containining the quantization parameters. + /// Returns the packed values and a newly allocated vector containing the quantization parameters. fn split_values_off(self) -> (Vec, Vec) { // The bytes can be created either from packed u32 or existing bytes with the same representation. let mut values = match self.bytes.align() { diff --git a/crates/burn-tensor/src/tensor/quantization/strategy.rs b/crates/burn-tensor/src/tensor/quantization/strategy.rs index bb8b4c6bfb..73f1b1c0b0 100644 --- a/crates/burn-tensor/src/tensor/quantization/strategy.rs +++ b/crates/burn-tensor/src/tensor/quantization/strategy.rs @@ -4,7 +4,7 @@ use core::{ }; use alloc::vec::Vec; -use burn_common::{iter_par, run_par}; +use burn_common::{iter_slice_par, run_par}; use num_traits::{Float, PrimInt}; use serde::{Deserialize, Serialize}; @@ -35,7 +35,7 @@ impl QuantizationStrategy { /// Quantization scheme to convert elements of a higher precision data type `E` to a lower precision /// data type `Q` and vice-versa. -pub trait Quantization { +pub trait Quantization { /// Create a new quantization scheme for an input range `[alpha, beta]`. fn new(alpha: E, beta: E) -> Self; /// Convert the values to a lower precision data type. @@ -48,7 +48,7 @@ pub trait Quantization { /// /// Note that the accumulation type `A` should have a bigger range than quantized type `Q`. #[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub struct AffineQuantization { +pub struct AffineQuantization { /// The scaling factor. pub scale: E, /// The zero-point offset. @@ -66,7 +66,7 @@ fn valid_scale(mut scale: E) -> E { scale } -impl AffineQuantization { +impl AffineQuantization { /// Initialize an affine quantization scheme with the given parameters. pub fn init(scale: E, offset: Q) -> Self { Self { @@ -77,7 +77,9 @@ impl AffineQuantization { } } -impl Quantization for AffineQuantization { +impl Quantization + for AffineQuantization +{ fn new(alpha: E, beta: E) -> Self { // Q range `[a, b]` let a = E::from(Q::min_value()).unwrap(); @@ -107,7 +109,7 @@ impl Quantization for AffineQuantization // x_q = clamp(round(x / scale + offset), a, b) let z = E::from(self.offset).unwrap(); run_par!(|| { - iter_par!(values.iter()) + iter_slice_par!(values) .map(|x| Q::from(x.div(self.scale).add(z).round().clamp(a, b)).unwrap()) .collect() }) @@ -116,7 +118,7 @@ impl Quantization for AffineQuantization fn dequantize(&self, values: &[Q]) -> Vec { // x = scale * (x_q - offset) run_par!(|| { - iter_par!(values.iter()) + iter_slice_par!(values) .map(|x_q| { self.scale * (E::from( @@ -133,14 +135,14 @@ impl Quantization for AffineQuantization /// Symmetric quantization scheme. #[derive(Debug, Clone, Copy, Serialize, Deserialize)] -pub struct SymmetricQuantization { +pub struct SymmetricQuantization { /// The scaling factor. pub scale: E, /// The quantized type. _q: PhantomData, } -impl SymmetricQuantization { +impl SymmetricQuantization { /// Initialize a symmetric quantization scheme with the given parameters. pub fn init(scale: E) -> Self { Self { @@ -150,7 +152,9 @@ impl SymmetricQuantization { } } -impl Quantization for SymmetricQuantization { +impl Quantization + for SymmetricQuantization +{ fn new(alpha: E, beta: E) -> Self { assert!( !Q::min_value().is_zero(), @@ -214,7 +218,9 @@ fn canonicalize_signed_zero(x: T) -> T { x + T::zero() } -impl Hash for AffineQuantization { +impl Hash + for AffineQuantization +{ fn hash(&self, state: &mut H) { // Hash raw bits. let bits = raw_double_bits(&canonicalize_signed_zero(self.scale)); @@ -223,15 +229,20 @@ impl Hash for AffineQuantization PartialEq for AffineQuantization { +impl PartialEq + for AffineQuantization +{ fn eq(&self, other: &Self) -> bool { self.scale == other.scale && self.offset == other.offset } } -impl Eq for AffineQuantization {} +impl Eq + for AffineQuantization +{ +} -impl Hash for SymmetricQuantization { +impl Hash for SymmetricQuantization { fn hash(&self, state: &mut H) { // Hash raw bits. let bits = raw_double_bits(&canonicalize_signed_zero(self.scale)); @@ -239,13 +250,13 @@ impl Hash for SymmetricQuantization { } } -impl PartialEq for SymmetricQuantization { +impl PartialEq for SymmetricQuantization { fn eq(&self, other: &Self) -> bool { self.scale == other.scale } } -impl Eq for SymmetricQuantization {} +impl Eq for SymmetricQuantization {} #[cfg(test)] mod tests { From 5b3079ae092e597c754a221636774a63beee5fc7 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Sat, 11 Jan 2025 12:37:19 -0500 Subject: [PATCH 10/61] Add checks for even padding when kernel size is even (#2677) * Add checks for even padding when lkernel size is even * Move the check to config init --- crates/burn-core/src/nn/conv/checks.rs | 11 +++++++++++ crates/burn-core/src/nn/conv/conv1d.rs | 15 +++++++++++++++ crates/burn-core/src/nn/conv/conv2d.rs | 15 +++++++++++++++ crates/burn-core/src/nn/conv/conv3d.rs | 11 +++++++++++ crates/burn-core/src/nn/conv/deform_conv2d.rs | 15 +++++++++++++++ crates/burn-core/src/nn/pool/avg_pool1d.rs | 19 +++++++++++++++---- crates/burn-core/src/nn/pool/avg_pool2d.rs | 18 +++++++++++++++--- crates/burn-core/src/nn/pool/max_pool1d.rs | 15 +++++++++++++++ crates/burn-core/src/nn/pool/max_pool2d.rs | 15 +++++++++++++++ 9 files changed, 127 insertions(+), 7 deletions(-) diff --git a/crates/burn-core/src/nn/conv/checks.rs b/crates/burn-core/src/nn/conv/checks.rs index cd346163ad..36932621f1 100644 --- a/crates/burn-core/src/nn/conv/checks.rs +++ b/crates/burn-core/src/nn/conv/checks.rs @@ -9,3 +9,14 @@ pub(crate) fn checks_channels_div_groups(channels_in: usize, channels_out: usize ); } } + +// https://github.com/tracel-ai/burn/issues/2676 +/// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel +/// size is not supported as it will not produce the same output size. +pub(crate) fn check_same_padding_support(kernel_size: &[usize]) { + for k in kernel_size.iter() { + if k % 2 == 0 { + unimplemented!("Same padding with an even kernel size is not supported"); + } + } +} diff --git a/crates/burn-core/src/nn/conv/conv1d.rs b/crates/burn-core/src/nn/conv/conv1d.rs index 0b64eab324..c3f61a6b07 100644 --- a/crates/burn-core/src/nn/conv/conv1d.rs +++ b/crates/burn-core/src/nn/conv/conv1d.rs @@ -28,6 +28,10 @@ pub struct Conv1dConfig { #[config(default = "1")] pub groups: usize, /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig1d::Valid")] pub padding: PaddingConfig1d, /// If bias should be added to the output. @@ -87,6 +91,9 @@ impl Conv1dConfig { /// Initialize a new [conv1d](Conv1d) module. pub fn init(&self, device: &B::Device) -> Conv1d { checks::checks_channels_div_groups(self.channels_in, self.channels_out, self.groups); + if self.padding == PaddingConfig1d::Same { + checks::check_same_padding_support(&[self.kernel_size]); + } let shape = [ self.channels_out, @@ -175,6 +182,14 @@ mod tests { .assert_approx_eq(&TensorData::zeros::(conv.weight.shape()), 3); } + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let device = Default::default(); + let config = Conv1dConfig::new(5, 5, 4).with_padding(PaddingConfig1d::Same); + let _ = config.init::(&device); + } + #[test] fn display() { let config = Conv1dConfig::new(5, 5, 5); diff --git a/crates/burn-core/src/nn/conv/conv2d.rs b/crates/burn-core/src/nn/conv/conv2d.rs index 73be36d357..72c00187be 100644 --- a/crates/burn-core/src/nn/conv/conv2d.rs +++ b/crates/burn-core/src/nn/conv/conv2d.rs @@ -30,6 +30,10 @@ pub struct Conv2dConfig { #[config(default = "1")] pub groups: usize, /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// If bias should be added to the output. @@ -68,6 +72,9 @@ impl Conv2dConfig { /// Initialize a new [conv2d](Conv2d) module. pub fn init(&self, device: &B::Device) -> Conv2d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); + if self.padding == PaddingConfig2d::Same { + checks::check_same_padding_support(&self.kernel_size); + } let shape = [ self.channels[1], @@ -228,6 +235,14 @@ mod tests { let _ = config.init::(&device); } + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let device = Default::default(); + let config = Conv2dConfig::new([4, 4], [2, 2]).with_padding(PaddingConfig2d::Same); + let _ = config.init::(&device); + } + #[test] fn display() { let config = Conv2dConfig::new([5, 1], [5, 5]); diff --git a/crates/burn-core/src/nn/conv/conv3d.rs b/crates/burn-core/src/nn/conv/conv3d.rs index de7fb1ce2b..0b5d530c5a 100644 --- a/crates/burn-core/src/nn/conv/conv3d.rs +++ b/crates/burn-core/src/nn/conv/conv3d.rs @@ -68,6 +68,9 @@ impl Conv3dConfig { /// Initialize a new [conv3d](Conv3d) module. pub fn init(&self, device: &B::Device) -> Conv3d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.groups); + if self.padding == PaddingConfig3d::Same { + checks::check_same_padding_support(&self.kernel_size); + } let shape = [ self.channels[1], @@ -228,6 +231,14 @@ mod tests { assert_eq!(config.initializer, init); } + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let device = Default::default(); + let config = Conv3dConfig::new([4, 4], [2, 2, 2]).with_padding(PaddingConfig3d::Same); + let _ = config.init::(&device); + } + #[test] fn display() { let config = Conv3dConfig::new([5, 1], [5, 5, 5]); diff --git a/crates/burn-core/src/nn/conv/deform_conv2d.rs b/crates/burn-core/src/nn/conv/deform_conv2d.rs index 03becd9d4e..2baff11d07 100644 --- a/crates/burn-core/src/nn/conv/deform_conv2d.rs +++ b/crates/burn-core/src/nn/conv/deform_conv2d.rs @@ -33,6 +33,10 @@ pub struct DeformConv2dConfig { #[config(default = "1")] pub offset_groups: usize, /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// If bias should be added to the output. @@ -73,6 +77,9 @@ impl DeformConv2dConfig { /// Initialize a new [DeformConv2d](DeformConv2d) module. pub fn init(&self, device: &B::Device) -> DeformConv2d { checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups); + if self.padding == PaddingConfig2d::Same { + checks::check_same_padding_support(&self.kernel_size); + } let shape = [ self.channels[1], @@ -250,6 +257,14 @@ mod tests { let _ = config.init::(&device); } + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let device = Default::default(); + let config = DeformConv2dConfig::new([4, 4], [2, 2]).with_padding(PaddingConfig2d::Same); + let _ = config.init::(&device); + } + #[test] fn display() { let config = DeformConv2dConfig::new([5, 1], [5, 5]); diff --git a/crates/burn-core/src/nn/pool/avg_pool1d.rs b/crates/burn-core/src/nn/pool/avg_pool1d.rs index 949160fd5b..24ec8ff972 100644 --- a/crates/burn-core/src/nn/pool/avg_pool1d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool1d.rs @@ -1,4 +1,5 @@ use crate as burn; +use crate::nn::conv::checks::check_same_padding_support; use crate::config::Config; use crate::module::{Content, DisplaySettings, ModuleDisplay}; @@ -18,6 +19,10 @@ pub struct AvgPool1dConfig { #[config(default = "1")] pub stride: usize, /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig1d::Valid")] pub padding: PaddingConfig1d, /// If the padding is counted in the denominator when computing the average. @@ -36,10 +41,6 @@ pub struct AvgPool1dConfig { /// legitimate values, and they contribute to the denominator /// when calculating the average. This is equivalent to /// `torch.nn.AvgPool2d` with `count_include_pad=True`. -/// -/// TODO: Add support for `count_include_pad=False`, see -/// [Issue 636](https://github.com/tracel-ai/burn/issues/636) - #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct AvgPool1d { @@ -73,6 +74,9 @@ impl ModuleDisplay for AvgPool1d { impl AvgPool1dConfig { /// Initialize a new [avg pool 1d](AvgPool1d) module. pub fn init(&self) -> AvgPool1d { + if self.padding == PaddingConfig1d::Same { + check_same_padding_support(&[self.kernel_size]); + } AvgPool1d { stride: self.stride, kernel_size: self.kernel_size, @@ -111,6 +115,13 @@ impl AvgPool1d { mod tests { use super::*; + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let config = AvgPool1dConfig::new(2).with_padding(PaddingConfig1d::Same); + let _ = config.init(); + } + #[test] fn display() { let config = AvgPool1dConfig::new(3); diff --git a/crates/burn-core/src/nn/pool/avg_pool2d.rs b/crates/burn-core/src/nn/pool/avg_pool2d.rs index 6c6ffc87ed..343d59922b 100644 --- a/crates/burn-core/src/nn/pool/avg_pool2d.rs +++ b/crates/burn-core/src/nn/pool/avg_pool2d.rs @@ -1,4 +1,5 @@ use crate as burn; +use crate::nn::conv::checks::check_same_padding_support; use crate::config::Config; use crate::module::{Content, DisplaySettings, ModuleDisplay}; @@ -18,6 +19,10 @@ pub struct AvgPool2dConfig { #[config(default = "[1, 1]")] pub strides: [usize; 2], /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// If the padding is counted in the denominator when computing the average. @@ -36,9 +41,6 @@ pub struct AvgPool2dConfig { /// legitimate values, and they contribute to the denominator /// when calculating the average. This is equivalent to /// `torch.nn.AvgPool2d` with `count_include_pad=True`. -/// -/// TODO: Add support for `count_include_pad=False`, see -/// [Issue 636](https://github.com/tracel-ai/burn/issues/636) #[derive(Module, Clone, Debug)] #[module(custom_display)] pub struct AvgPool2d { @@ -72,6 +74,9 @@ impl ModuleDisplay for AvgPool2d { impl AvgPool2dConfig { /// Initialize a new [avg pool 2d](AvgPool2d) module. pub fn init(&self) -> AvgPool2d { + if self.padding == PaddingConfig2d::Same { + check_same_padding_support(&self.kernel_size); + } AvgPool2d { stride: self.strides, kernel_size: self.kernel_size, @@ -110,6 +115,13 @@ impl AvgPool2d { mod tests { use super::*; + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let config = AvgPool2dConfig::new([2, 2]).with_padding(PaddingConfig2d::Same); + let _ = config.init(); + } + #[test] fn display() { let config = AvgPool2dConfig::new([3, 3]); diff --git a/crates/burn-core/src/nn/pool/max_pool1d.rs b/crates/burn-core/src/nn/pool/max_pool1d.rs index 5be363e908..71041e6155 100644 --- a/crates/burn-core/src/nn/pool/max_pool1d.rs +++ b/crates/burn-core/src/nn/pool/max_pool1d.rs @@ -1,4 +1,5 @@ use crate as burn; +use crate::nn::conv::checks::check_same_padding_support; use crate::config::Config; use crate::module::{Content, DisplaySettings, ModuleDisplay}; @@ -18,6 +19,10 @@ pub struct MaxPool1dConfig { #[config(default = "1")] pub stride: usize, /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig1d::Valid")] pub padding: PaddingConfig1d, /// The dilation. @@ -61,6 +66,9 @@ impl ModuleDisplay for MaxPool1d { impl MaxPool1dConfig { /// Initialize a new [max pool 1d](MaxPool1d) module. pub fn init(&self) -> MaxPool1d { + if self.padding == PaddingConfig1d::Same { + check_same_padding_support(&[self.kernel_size]); + } MaxPool1d { stride: self.stride, kernel_size: self.kernel_size, @@ -93,6 +101,13 @@ impl MaxPool1d { mod tests { use super::*; + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let config = MaxPool1dConfig::new(2).with_padding(PaddingConfig1d::Same); + let _ = config.init(); + } + #[test] fn display() { let config = MaxPool1dConfig::new(3); diff --git a/crates/burn-core/src/nn/pool/max_pool2d.rs b/crates/burn-core/src/nn/pool/max_pool2d.rs index ab9c60d276..3eb94f5db5 100644 --- a/crates/burn-core/src/nn/pool/max_pool2d.rs +++ b/crates/burn-core/src/nn/pool/max_pool2d.rs @@ -1,4 +1,5 @@ use crate as burn; +use crate::nn::conv::checks::check_same_padding_support; use crate::config::Config; use crate::module::{Content, DisplaySettings, ModuleDisplay}; @@ -18,6 +19,10 @@ pub struct MaxPool2dConfig { #[config(default = "[1, 1]")] pub strides: [usize; 2], /// The padding configuration. + /// + /// ### Warning + /// Only symmetric padding is currently supported. As such, using `Same` padding with an even kernel + /// size is not supported as it will not produce the same output size. #[config(default = "PaddingConfig2d::Valid")] pub padding: PaddingConfig2d, /// The dilation. @@ -61,6 +66,9 @@ impl ModuleDisplay for MaxPool2d { impl MaxPool2dConfig { /// Initialize a new [max pool 2d](MaxPool2d) module. pub fn init(&self) -> MaxPool2d { + if self.padding == PaddingConfig2d::Same { + check_same_padding_support(&self.kernel_size); + } MaxPool2d { stride: self.strides, kernel_size: self.kernel_size, @@ -93,6 +101,13 @@ impl MaxPool2d { mod tests { use super::*; + #[test] + #[should_panic = "Same padding with an even kernel size is not supported"] + fn same_with_even_kernel_is_invalid() { + let config = MaxPool2dConfig::new([2, 2]).with_padding(PaddingConfig2d::Same); + let _ = config.init(); + } + #[test] fn display() { let config = MaxPool2dConfig::new([3, 3]); From 51b742f6a617835cf7bfbc994e455201d816a081 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Sun, 12 Jan 2025 11:45:12 -0500 Subject: [PATCH 11/61] Update cubecl (#2680) --- Cargo.lock | 24 +++++++++---------- Cargo.toml | 4 ++-- .../kernel/conv/deform_conv_transpose2d.rs | 8 +++---- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 11ae2b3de3..3d24a99fa0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1581,7 +1581,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1613,7 +1613,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1630,7 +1630,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1649,7 +1649,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1663,7 +1663,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1679,7 +1679,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1705,7 +1705,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "bytemuck", "cubecl-core", @@ -1717,7 +1717,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1732,7 +1732,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1769,7 +1769,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "async-channel", "async-lock", @@ -1790,7 +1790,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1804,7 +1804,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d#5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" +source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 3a03582fea..1f58e6e751 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8244dbb4660e373ff1ffb780feb73a5b899e5977" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8244dbb4660e373ff1ffb780feb73a5b899e5977" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index ad9e11c6c5..5e51623505 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -509,7 +509,7 @@ fn deform_col2img_kernel( offset: &Tensor, mask: &Tensor, columns: &Tensor, - grad_input: &mut Tensor, + grad_input: &mut Tensor>, args: &DeformConv2dCol2ImgArgs, #[comptime] use_mask: bool, ) { @@ -589,14 +589,14 @@ fn deform_col2img_kernel( } #[cube] -fn float_atomic_add(ptr: &mut AtomicU32, value: f32) { +fn float_atomic_add(ptr: &mut Atomic, value: f32) { if value != 0.0 { - let mut v = AtomicU32::load(ptr); + let mut v = Atomic::::load(ptr); loop { let prev = v; let v_float = f32::bitcast_from(v); let new = u32::bitcast_from(v_float + value); - v = AtomicU32::compare_and_swap(ptr, v, new); + v = Atomic::::compare_and_swap(ptr, v, new); if prev == v { break; } From 8115815129ae98c2f8a1e144c1e0d71b80be17e4 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 13 Jan 2025 08:09:36 -0500 Subject: [PATCH 12/61] Combined PRs (#2687) * Bump clap from 4.5.24 to 4.5.26 Bumps [clap](https://github.com/clap-rs/clap) from 4.5.24 to 4.5.26. - [Release notes](https://github.com/clap-rs/clap/releases) - [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md) - [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.24...clap_complete-v4.5.26) --- updated-dependencies: - dependency-name: clap dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump serde_json from 1.0.134 to 1.0.135 Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.134 to 1.0.135. - [Release notes](https://github.com/serde-rs/json/releases) - [Commits](https://github.com/serde-rs/json/compare/v1.0.134...v1.0.135) --- updated-dependencies: - dependency-name: serde_json dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump thiserror from 2.0.9 to 2.0.11 Bumps [thiserror](https://github.com/dtolnay/thiserror) from 2.0.9 to 2.0.11. - [Release notes](https://github.com/dtolnay/thiserror/releases) - [Commits](https://github.com/dtolnay/thiserror/compare/2.0.9...2.0.11) --- updated-dependencies: - dependency-name: thiserror dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump proc-macro2 from 1.0.92 to 1.0.93 Bumps [proc-macro2](https://github.com/dtolnay/proc-macro2) from 1.0.92 to 1.0.93. - [Release notes](https://github.com/dtolnay/proc-macro2/releases) - [Commits](https://github.com/dtolnay/proc-macro2/compare/1.0.92...1.0.93) --- updated-dependencies: - dependency-name: proc-macro2 dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Cargo.lock | 46 +++++++++++++++++++++++----------------------- Cargo.toml | 8 ++++---- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3d24a99fa0..342747e651 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -696,7 +696,7 @@ dependencies = [ "serde_json", "spin", "tempfile", - "thiserror 2.0.9", + "thiserror 2.0.11", "uuid", ] @@ -744,7 +744,7 @@ dependencies = [ "strum", "strum_macros", "tempfile", - "thiserror 2.0.9", + "thiserror 2.0.11", ] [[package]] @@ -805,7 +805,7 @@ dependencies = [ "serde", "serde_json", "syn 2.0.95", - "thiserror 2.0.9", + "thiserror 2.0.11", "tracing-core", "tracing-subscriber", "zip 2.2.2", @@ -1231,9 +1231,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.24" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9560b07a799281c7e0958b9296854d6fafd4c5f31444a7e5bb1ad6dde5ccf1bd" +checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" dependencies = [ "clap_builder", "clap_derive", @@ -1241,9 +1241,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.24" +version = "4.5.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "874e0dd3eb68bf99058751ac9712f622e61e6f393a94f7128fa26e3f02f5c7cd" +checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" dependencies = [ "anstream", "anstyle", @@ -2782,7 +2782,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b5eccc17194ed0e67d49285e4853307e4147e95407f91c1c3e4a13ba9f4e4ce" dependencies = [ "faster-hex", - "thiserror 2.0.9", + "thiserror 2.0.11", ] [[package]] @@ -3673,7 +3673,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -4586,7 +4586,7 @@ dependencies = [ "flate2", "native-tls", "tar", - "thiserror 2.0.9", + "thiserror 2.0.11", "ureq", ] @@ -5463,9 +5463,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.92" +version = "1.0.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" dependencies = [ "unicode-ident", ] @@ -5684,7 +5684,7 @@ dependencies = [ "rustc-hash 2.1.0", "rustls", "socket2", - "thiserror 2.0.9", + "thiserror 2.0.11", "tokio", "tracing", ] @@ -5703,7 +5703,7 @@ dependencies = [ "rustls", "rustls-pki-types", "slab", - "thiserror 2.0.9", + "thiserror 2.0.11", "tinyvec", "tracing", "web-time", @@ -6524,9 +6524,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.134" +version = "1.0.135" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d" +checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" dependencies = [ "itoa", "memchr", @@ -7168,11 +7168,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.9" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f072643fd0190df67a8bab670c20ef5d8737177d6ac6b2e9a236cb096206b2cc" +checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" dependencies = [ - "thiserror-impl 2.0.9", + "thiserror-impl 2.0.11", ] [[package]] @@ -7188,9 +7188,9 @@ dependencies = [ [[package]] name = "thiserror-impl" -version = "2.0.9" +version = "2.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b50fa271071aae2e6ee85f842e2e28ba8cd2c5fb67f11fcb1fd70b276f9e7d4" +checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", @@ -7620,7 +7620,7 @@ dependencies = [ "log", "rand", "sha1", - "thiserror 2.0.9", + "thiserror 2.0.11", "utf-8", ] @@ -8730,7 +8730,7 @@ dependencies = [ "pbkdf2 0.12.2", "rand", "sha1", - "thiserror 2.0.9", + "thiserror 2.0.11", "time", "zeroize", "zopfli", diff --git a/Cargo.toml b/Cargo.toml index 1f58e6e751..ad5e799b8b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ version = "0.16.0" atomic_float = "1" bytemuck = "1.21.0" candle-core = { version = "0.8" } -clap = { version = "4.5.24", features = ["derive"] } +clap = { version = "4.5.26", features = ["derive"] } colored = "2.1.0" console_error_panic_hook = "0.1.7" csv = "1.3.1" @@ -56,7 +56,7 @@ paste = "1" percent-encoding = "2.3.1" polars = { version = "0.44.2", features = ["lazy"] } pretty_assertions = "1.4.1" -proc-macro2 = "1.0.92" +proc-macro2 = "1.0.93" protobuf = "3.7.1" protobuf-codegen = "3.7.1" quote = "1.0.38" @@ -84,7 +84,7 @@ strum = "0.26.3" strum_macros = "0.26.4" syn = { version = "2.0.95", features = ["full", "extra-traits"] } tempfile = "3.14.0" -thiserror = "2.0.9" +thiserror = "2.0.11" tokio = { version = "1.42.0", features = ["rt", "macros"] } tracing-appender = "0.2.3" tracing-core = "0.1.33" @@ -140,7 +140,7 @@ serde = { version = "1.0.217", default-features = false, features = [ "derive", "alloc", ] } # alloc is for no_std, derive is needed -serde_json = { version = "1.0.134", default-features = false } +serde_json = { version = "1.0.135", default-features = false } uuid = { version = "1.11.0", default-features = false } libc = "0.2.169" From 2d71bcee850ad794306b69a32de4a3f39da9182f Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 13 Jan 2025 09:09:51 -0500 Subject: [PATCH 13/61] Bump axum from 0.7.9 to 0.8.1 (#2683) * Bump axum from 0.7.9 to 0.8.1 Bumps [axum](https://github.com/tokio-rs/axum) from 0.7.9 to 0.8.1. - [Release notes](https://github.com/tokio-rs/axum/releases) - [Changelog](https://github.com/tokio-rs/axum/blob/main/CHANGELOG.md) - [Commits](https://github.com/tokio-rs/axum/commits) --- updated-dependencies: - dependency-name: axum dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Fix axum breaking change (Bytes type) --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Guillaume Lagrange --- Cargo.lock | 51 ++++++--------------------- crates/burn-remote/Cargo.toml | 2 +- crates/burn-remote/src/server/base.rs | 5 ++- 3 files changed, 15 insertions(+), 43 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 342747e651..b4c69e8859 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -341,14 +341,14 @@ dependencies = [ [[package]] name = "axum" -version = "0.7.9" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" dependencies = [ - "async-trait", "axum-core", "base64 0.22.1", "bytes", + "form_urlencoded", "futures-util", "http", "http-body", @@ -369,7 +369,7 @@ dependencies = [ "sha1", "sync_wrapper", "tokio", - "tokio-tungstenite 0.24.0", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -378,11 +378,10 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.4.5" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" dependencies = [ - "async-trait", "bytes", "futures-util", "http", @@ -883,7 +882,7 @@ dependencies = [ "serde", "serde_bytes", "tokio", - "tokio-tungstenite 0.26.1", + "tokio-tungstenite", "tracing-core", "tracing-subscriber", ] @@ -3827,9 +3826,9 @@ dependencies = [ [[package]] name = "matchit" -version = "0.7.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" [[package]] name = "matrixmultiply" @@ -7364,18 +7363,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-tungstenite" -version = "0.24.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite 0.24.0", -] - [[package]] name = "tokio-tungstenite" version = "0.26.1" @@ -7385,7 +7372,7 @@ dependencies = [ "futures-util", "log", "tokio", - "tungstenite 0.26.1", + "tungstenite", ] [[package]] @@ -7588,24 +7575,6 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" -[[package]] -name = "tungstenite" -version = "0.24.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" -dependencies = [ - "byteorder", - "bytes", - "data-encoding", - "http", - "httparse", - "log", - "rand", - "sha1", - "thiserror 1.0.69", - "utf-8", -] - [[package]] name = "tungstenite" version = "0.26.1" diff --git a/crates/burn-remote/Cargo.toml b/crates/burn-remote/Cargo.toml index 16a236ac94..fa6034c681 100644 --- a/crates/burn-remote/Cargo.toml +++ b/crates/burn-remote/Cargo.toml @@ -39,7 +39,7 @@ async-channel = { workspace = true, optional = true } tokio-tungstenite = { version = "0.26", optional = true } # Server dependencies -axum = { version = "0.7.9", features = ["ws"], optional = true } +axum = { version = "0.8.1", features = ["ws"], optional = true } tracing-core = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true } diff --git a/crates/burn-remote/src/server/base.rs b/crates/burn-remote/src/server/base.rs index 1a364c87e9..87c0d7f7cf 100644 --- a/crates/burn-remote/src/server/base.rs +++ b/crates/burn-remote/src/server/base.rs @@ -102,7 +102,10 @@ impl WsServer { let response = callback.recv().unwrap(); let bytes = rmp_serde::to_vec(&response).unwrap(); - socket.send(ws::Message::Binary(bytes)).await.unwrap(); + socket + .send(ws::Message::Binary(bytes.into())) + .await + .unwrap(); } } Err(err) => panic!("Can't start the response handler {err:?}"), From ad89dcc0e4976948979054a470e31631df7a9602 Mon Sep 17 00:00:00 2001 From: kitterion <12870762+kitterion@users.noreply.github.com> Date: Mon, 13 Jan 2025 15:10:07 +0100 Subject: [PATCH 14/61] Don't mention a fixed candle bug (#2689) --- burn-book/src/import/pytorch-model.md | 8 ++------ .../pytorch-tests/tests/config/mod.rs | 16 ++++------------ 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/burn-book/src/import/pytorch-model.md b/burn-book/src/import/pytorch-model.md index 5c17eee3e9..1f584cdc9f 100644 --- a/burn-book/src/import/pytorch-model.md +++ b/burn-book/src/import/pytorch-model.md @@ -162,17 +162,13 @@ struct NetConfig { n_head: usize, n_layer: usize, d_model: usize, - // Candle's pickle has a bug with float serialization - // https://github.com/huggingface/candle/issues/1729 - // some_float: f64, + some_float: f64, some_int: i32, some_bool: bool, some_str: String, some_list_int: Vec, some_list_str: Vec, - // Candle's pickle has a bug with float serialization - // https://github.com/huggingface/candle/issues/1729 - // some_list_float: Vec, + some_list_float: Vec, some_dict: HashMap, } diff --git a/crates/burn-import/pytorch-tests/tests/config/mod.rs b/crates/burn-import/pytorch-tests/tests/config/mod.rs index 8e67015592..10d19e4f4d 100644 --- a/crates/burn-import/pytorch-tests/tests/config/mod.rs +++ b/crates/burn-import/pytorch-tests/tests/config/mod.rs @@ -9,17 +9,13 @@ struct NetConfig { n_head: usize, n_layer: usize, d_model: usize, - // Candle's pickle has a bug with float serialization - // https://github.com/huggingface/candle/issues/1729 - // some_float: f64, + some_float: f64, some_int: i32, some_bool: bool, some_str: String, some_list_int: Vec, some_list_str: Vec, - // Candle's pickle has a bug with float serialization - // https://github.com/huggingface/candle/issues/1729 - // some_list_float: Vec, + some_list_float: Vec, some_dict: HashMap, } @@ -35,17 +31,13 @@ mod tests { n_head: 2, n_layer: 3, d_model: 512, - // Candle's pickle has a bug with float serialization - // https://github.com/huggingface/candle/issues/1729 - // some_float: 0.1, + some_float: 0.1, some_int: 1, some_bool: true, some_str: "hello".to_string(), some_list_int: vec![1, 2, 3], some_list_str: vec!["hello".to_string(), "world".to_string()], - // Candle's pickle has a bug with float serialization - // https://github.com/huggingface/candle/issues/1729 - // some_list_float: vec![0.1, 0.2, 0.3], + some_list_float: vec![0.1, 0.2, 0.3], some_dict: { let mut map = HashMap::new(); map.insert("some_key".to_string(), "some_value".to_string()); From 28f2b67711cde19fa0edbd1df7b2a16cb56c9b4a Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Mon, 13 Jan 2025 12:37:17 -0500 Subject: [PATCH 15/61] Fusion experimental feature flag (#2690) * Fusion experimental feature flag * Minor change to vec init --------- Co-authored-by: Guillaume Lagrange --- crates/burn-jit/Cargo.toml | 2 ++ crates/burn-jit/src/fusion/base.rs | 16 ++++++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index 0811374fd1..5458180203 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -25,6 +25,8 @@ export_tests = [ "paste", ] fusion = ["burn-fusion"] +fusion-experimental = ["fusion"] + std = ["cubecl/std"] template = [] diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 48587a1bf9..e8e4d82659 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -125,16 +125,20 @@ impl FusionRuntime for FusionJitRuntime { fn optimizations( device: R::Device, ) -> Vec>> { - vec![ - Box::new(ElementWiseBuilder::::new( + let mut optimizations: Vec>> = + vec![Box::new(ElementWiseBuilder::::new( device.clone(), BT::as_elem_native_unchecked().into(), - )), - Box::new(MatmulBuilder::::new( + ))]; + + if cfg!(feature = "fusion-experimental") { + optimizations.push(Box::new(MatmulBuilder::::new( device.clone(), BT::as_elem_native_unchecked().into(), - )), - ] + ))); + } + + optimizations } } From c5362a49a3b1e1b363143d9634f3890f2c89ef30 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Mon, 13 Jan 2025 19:00:27 +0100 Subject: [PATCH 16/61] Use float intrinsics for deform_conv2d backward, fix `into_data` for padded tensors (#2681) * Clamp into_data byte arrays to expected size to deal with padding * Check tf32 is supported for CMMA convolution * Use atomic float add when available, update cubecl * Update Cargo.lock * Disable f32 atomics to work around CUDA bug * Use correct elem size and run unchecked to fix CUDA * Remove unused import * Use generic float for offset calculation * remove leftover comment --- Cargo.lock | 34 ++- Cargo.toml | 6 +- .../src/kernel/conv/conv2d/gemm/launch.rs | 4 +- .../kernel/conv/deform_conv_transpose2d.rs | 205 ++++++++++++------ crates/burn-jit/src/ops/base.rs | 9 +- 5 files changed, 171 insertions(+), 87 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b4c69e8859..81a24f75c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1580,7 +1580,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1612,7 +1612,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1629,7 +1629,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1648,7 +1648,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1662,7 +1662,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1678,7 +1678,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "bytemuck", "cubecl-common 0.4.0", @@ -1704,7 +1704,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "bytemuck", "cubecl-core", @@ -1716,7 +1716,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "cubecl-common 0.4.0", "darling", @@ -1731,7 +1731,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1741,6 +1741,7 @@ dependencies = [ "petgraph", "smallvec", "stable-vec", + "type-map", ] [[package]] @@ -1768,7 +1769,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "async-channel", "async-lock", @@ -1789,7 +1790,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "cubecl-common 0.4.0", "cubecl-core", @@ -1803,7 +1804,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=8244dbb4660e373ff1ffb780feb73a5b899e5977#8244dbb4660e373ff1ffb780feb73a5b899e5977" +source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" dependencies = [ "ash", "async-channel", @@ -7593,6 +7594,15 @@ dependencies = [ "utf-8", ] +[[package]] +name = "type-map" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deb68604048ff8fa93347f02441e4487594adc20bb8a084f9e564d2b827a0a9f" +dependencies = [ + "rustc-hash 1.1.0", +] + [[package]] name = "typenum" version = "1.17.0" diff --git a/Cargo.toml b/Cargo.toml index ad5e799b8b..614ab7a76b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -104,8 +104,8 @@ text_placeholder = "0.5.1" wgpu = "23.0.0" # Benchmarks and Burnbench -chrono = "0.4.39" arboard = "3.4.1" +chrono = "0.4.39" os_info = "3.9.0" wsl = "0.1.0" @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8244dbb4660e373ff1ffb780feb73a5b899e5977" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "8244dbb4660e373ff1ffb780feb73a5b899e5977" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b179e368c5404e871176155262cfd5221ae0ed60" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b179e368c5404e871176155262cfd5221ae0ed60" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs index abc94d1a9a..032368b08a 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs @@ -80,8 +80,10 @@ fn conv2d_gemm_cmma_strategy< } else if TypeId::of::() == TypeId::of::() || TypeId::of::() == TypeId::of::() { conv2d_gemm_with_algo::(input, weight, bias, options) - } else { + } else if has_tf32(&input) { conv2d_gemm_with_algo::(input, weight, bias, options) + } else { + conv2d_gemm_with_algo::(input, weight, bias, options) } } diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index 5e51623505..ddee1360e4 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -1,8 +1,13 @@ +use std::marker::PhantomData; + use burn_tensor::{ ops::{DeformConv2dBackward, DeformConvOptions, FloatTensorOps as _}, Shape, }; -use cubecl::{calculate_cube_count_elemwise, cube, prelude::*, CubeDim, CubeLaunch}; +use cubecl::{ + calculate_cube_count_elemwise, cube, ir::Elem, prelude::*, AtomicFeature, CubeDim, CubeLaunch, + Feature, +}; use crate::{ element::BoolElement, @@ -439,15 +444,29 @@ fn compute_input_grad( let client = offset.client.clone(); let device = offset.device.clone(); + let kind = match E::as_elem_native_unchecked() { + Elem::Float(kind) => kind, + _ => unreachable!("Should be float"), + }; + let props = client.properties(); + + let supports_fadd = props.feature_enabled(Feature::AtomicFloat(AtomicFeature::Add)); + let supports_same_type = props.feature_enabled(Feature::Type(Elem::AtomicFloat(kind))); + let [batch_size, in_channels, height, width] = input_shape.dims(); let (kernel_height, kernel_width) = kernel_dims; - // Force `f32` to enable bitcasting as `u32` - let grad_in = zeros_device::( - client.clone(), - device.clone(), - Shape::new([batch_size, in_channels, height, width]), - ); + let shape = Shape::new([batch_size, in_channels, height, width]); + let grad_in = match supports_fadd && supports_same_type { + // Use type as is to save a cast + true => zeros_device::(client.clone(), device.clone(), shape), + // Force `f32` to enable bitcasting as `u32`, or use intrinsic when supported + false => zeros_device::(client.clone(), device.clone(), shape), + }; + let grad_arg = match supports_fadd && supports_same_type { + true => grad_in.as_tensor_arg::(1), + false => grad_in.as_tensor_arg::(1), + }; let use_mask = mask.is_some(); let mask = mask.unwrap_or_else(|| { @@ -458,43 +477,60 @@ fn compute_input_grad( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(num_elements, cube_dim); - deform_col2img_kernel::launch::( - &offset.client, - cube_count, - cube_dim, - offset.as_tensor_arg::(1), - mask.as_tensor_arg::(1), - columns.as_tensor_arg::(1), - grad_in.as_tensor_arg::(1), - DeformConv2dCol2ImgArgsLaunch::new( - ScalarArg::new(options.stride[0] as u32), - ScalarArg::new(options.stride[1] as u32), - ScalarArg::new(options.dilation[0] as u32), - ScalarArg::new(options.dilation[1] as u32), - ScalarArg::new(options.padding[0] as f32), - ScalarArg::new(options.padding[1] as f32), - ScalarArg::new(options.offset_groups as u32), - ScalarArg::new(batch_size as u32), - ScalarArg::new(in_channels as u32), - ScalarArg::new(height as u32), - ScalarArg::new(width as u32), - ScalarArg::new(kernel_height as u32), - ScalarArg::new(kernel_width as u32), - ), - use_mask, - ); - - cast::(grad_in) + let launch = match (supports_fadd, supports_same_type) { + // use same type intrinsic if supported + (true, true) => deform_col2img_kernel::launch_unchecked::, R>, + // use f32 intrinsic if float add is supported at all + (true, false) => { + deform_col2img_kernel::launch_unchecked::, R> + } + // fall back to compare and swap + _ => deform_col2img_kernel::launch_unchecked::, + }; + + unsafe { + launch( + &offset.client, + cube_count, + cube_dim, + offset.as_tensor_arg::(1), + mask.as_tensor_arg::(1), + columns.as_tensor_arg::(1), + grad_arg, + DeformConv2dCol2ImgArgsLaunch::new( + ScalarArg::new(options.stride[0] as u32), + ScalarArg::new(options.stride[1] as u32), + ScalarArg::new(options.dilation[0] as u32), + ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(E::new(options.padding[0] as f32)), + ScalarArg::new(E::new(options.padding[1] as f32)), + ScalarArg::new(options.offset_groups as u32), + ScalarArg::new(batch_size as u32), + ScalarArg::new(in_channels as u32), + ScalarArg::new(height as u32), + ScalarArg::new(width as u32), + ScalarArg::new(kernel_height as u32), + ScalarArg::new(kernel_width as u32), + ), + use_mask, + ) + }; + + if !supports_same_type || !supports_fadd { + cast::(grad_in) + } else { + grad_in + } } #[derive(CubeLaunch)] -struct DeformConv2dCol2ImgArgs { +struct DeformConv2dCol2ImgArgs { stride_h: u32, stride_w: u32, dilation_h: u32, dilation_w: u32, - pad_h: f32, - pad_w: f32, + pad_h: F, + pad_w: F, offset_groups: u32, batch_size: u32, in_channels: u32, @@ -504,17 +540,19 @@ struct DeformConv2dCol2ImgArgs { kernel_width: u32, } -#[cube(launch)] -fn deform_col2img_kernel( +#[cube(launch_unchecked)] +fn deform_col2img_kernel( offset: &Tensor, mask: &Tensor, columns: &Tensor, - grad_input: &mut Tensor>, - args: &DeformConv2dCol2ImgArgs, + grad_input: &mut Tensor>, + args: &DeformConv2dCol2ImgArgs, #[comptime] use_mask: bool, ) { // Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] - let _ = mask[0]; // Keep mask in bind group + if ABSOLUTE_POS >= columns.len() { + return; + } let n_in_channels = args.in_channels; let height = args.height; @@ -545,8 +583,8 @@ fn deform_col2img_kernel( let offset_y_idx = (offset_idx * out_h + out_y) * out_w + out_x; let offset_x_idx = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; - let offset_y = f32::cast_from(offset[offset_base_idx + offset_y_idx]); - let offset_x = f32::cast_from(offset[offset_base_idx + offset_x_idx]); + let offset_y = offset[offset_base_idx + offset_y_idx]; + let offset_x = offset[offset_base_idx + offset_x_idx]; let mask_value = if use_mask { let mask_base_idx = @@ -558,47 +596,78 @@ fn deform_col2img_kernel( }; let y = - f32::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h) - args.pad_h + offset_y; + F::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h) - args.pad_h + offset_y; let x = - f32::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w) - args.pad_w + offset_x; + F::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w) - args.pad_w + offset_x; for dy in -1..=1 { #[unroll] for dx in -1..=1 { - let yp = f32::floor(y) + dy as f32; - let xp = f32::floor(x) + dx as f32; - - if yp >= 0.0 - && yp < height as f32 - && xp >= 0.0 - && xp < width as f32 - && f32::abs(y - yp) < 1.0 - && f32::abs(x - xp) < 1.0 + let yp = F::floor(y) + F::cast_from(dy); + let xp = F::floor(x) + F::cast_from(dx); + + if yp >= F::new(0.0) + && yp < F::cast_from(height) + && xp >= F::new(0.0) + && xp < F::cast_from(width) + && F::abs(y - yp) < F::new(1.0) + && F::abs(x - xp) < F::new(1.0) { let gradient_pos = - ((batch * n_in_channels + in_channel) * height + yp as u32) * width + xp as u32; + ((batch * n_in_channels + in_channel) * height + u32::cast_from(yp)) * width + + u32::cast_from(xp); - let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp)); + let weight = (F::new(1.0) - F::abs(y - yp)) * (F::new(1.0) - F::abs(x - xp)); let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS]; - float_atomic_add(&mut grad_input[gradient_pos], f32::cast_from(value)); + FAdd::float_atomic_add::(&mut grad_input[gradient_pos], value); } } } } #[cube] -fn float_atomic_add(ptr: &mut Atomic, value: f32) { - if value != 0.0 { - let mut v = Atomic::::load(ptr); - loop { - let prev = v; - let v_float = f32::bitcast_from(v); - let new = u32::bitcast_from(v_float + value); - v = Atomic::::compare_and_swap(ptr, v, new); - if prev == v { - break; +trait FloatAtomicAdd: Send + Sync + 'static { + type ProxyType: Numeric; + + fn float_atomic_add(ptr: &mut Atomic, value: F); +} + +#[derive(CubeType)] +struct IntrinsicFloatAtomicAdd { + _ty: PhantomData, +} + +#[derive(CubeType)] +struct CASFloatAtomicAdd; + +#[cube] +impl FloatAtomicAdd for IntrinsicFloatAtomicAdd { + type ProxyType = FAdd; + + fn float_atomic_add(ptr: &mut Atomic, value: F) { + let value = FAdd::cast_from(value); + Atomic::add(ptr, value); + } +} + +#[cube] +impl FloatAtomicAdd for CASFloatAtomicAdd { + type ProxyType = u32; + + fn float_atomic_add(ptr: &mut Atomic, value: F) { + let value = f32::cast_from(value); + if value != 0.0 { + let mut v = Atomic::load(ptr); + loop { + let prev = v; + let v_float = f32::bitcast_from(v); + let new = u32::bitcast_from(v_float + value); + v = Atomic::compare_and_swap(ptr, v, new); + if prev == v { + break; + } } } } diff --git a/crates/burn-jit/src/ops/base.rs b/crates/burn-jit/src/ops/base.rs index bce600604e..645aaf1535 100644 --- a/crates/burn-jit/src/ops/base.rs +++ b/crates/burn-jit/src/ops/base.rs @@ -17,7 +17,8 @@ pub(crate) async fn into_data(tensor: JitTensor let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; - TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) + let actual_len = tensor.shape.num_elements() * size_of::(); + TensorData::new(E::from_bytes(&bytes[..actual_len]).to_vec(), tensor.shape) } /// Read data from a `JitTensor` synchronously @@ -26,7 +27,8 @@ pub fn into_data_sync(tensor: JitTensor) -> Ten let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one(tensor.handle.binding()); - TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) + let actual_len = tensor.shape.num_elements() * size_of::(); + TensorData::new(E::from_bytes(&bytes[..actual_len]).to_vec(), tensor.shape) } pub(crate) async fn bool_into_data( @@ -34,8 +36,9 @@ pub(crate) async fn bool_into_data( ) -> TensorData { let tensor = kernel::into_contiguous(tensor); let bytes = tensor.client.read_one_async(tensor.handle.binding()).await; + let actual_len = tensor.shape.num_elements() * size_of::(); TensorData::new( - BT::from_bytes(&bytes) + BT::from_bytes(&bytes[..actual_len]) .iter() .map(|i| *i != BT::false_val()) .collect(), From 3e90b6eac7287cf1ce2e95ccbb8f936fd41aa60e Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Mon, 13 Jan 2025 13:25:27 -0500 Subject: [PATCH 17/61] Fix/web examples (#2691) * Fix web examples * Fix conflicting feature flags w/ `default-features = false` burn-tensor - Move rayon to default flags (but not enabled by std for now) burn-jit - Does not require burn-tensor default features - Enable burn-tensor/std w/ std flag burn-import - Split burn-tensor (only std flag) w/ and burn-ndarray - Use of burn-tensor w/ std and ndarray would enable rayon (for build dependency) burn-wgpu - Does not require burn-tensor default features * Update cubecl --- Cargo.lock | 103 ++++++------------- Cargo.toml | 4 +- crates/burn-import/Cargo.toml | 3 +- crates/burn-import/src/burn/graph.rs | 2 +- crates/burn-import/src/burn/node/base.rs | 3 +- crates/burn-jit/Cargo.toml | 4 +- crates/burn-tensor/Cargo.toml | 3 +- crates/burn-wgpu/Cargo.toml | 2 +- examples/image-classification-web/Cargo.toml | 3 +- examples/mnist-inference-web/Cargo.toml | 3 +- examples/mnist-inference-web/src/state.rs | 4 +- 11 files changed, 45 insertions(+), 89 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 81a24f75c3..b3c03aa56d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -650,7 +650,7 @@ dependencies = [ name = "burn-common" version = "0.16.0" dependencies = [ - "cubecl-common 0.4.0", + "cubecl-common", "dashmap", "getrandom", "indicatif", @@ -790,6 +790,7 @@ name = "burn-import" version = "0.16.0" dependencies = [ "burn", + "burn-ndarray", "candle-core", "derive-new 0.7.0", "half", @@ -1580,39 +1581,21 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ "cubecl-core", "cubecl-cuda", "cubecl-hip", "cubecl-linalg", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "cubecl-wgpu", "half", ] -[[package]] -name = "cubecl-common" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "51d402af454241d28d303a4cf4d2a861fae18404d65964c31934f746a40a6cf4" -dependencies = [ - "derive-new 0.6.0", - "embassy-futures", - "futures-lite", - "getrandom", - "log", - "portable-atomic", - "rand", - "serde", - "spin", - "web-time", -] - [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1629,12 +1612,12 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ "bytemuck", - "cubecl-common 0.4.0", + "cubecl-common", "cubecl-macros", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "derive-new 0.6.0", "derive_more 1.0.0", "half", @@ -1648,12 +1631,12 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ "bytemuck", - "cubecl-common 0.4.0", + "cubecl-common", "cubecl-core", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "derive-new 0.6.0", "half", "log", @@ -1662,13 +1645,13 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ "bytemuck", - "cubecl-common 0.4.0", + "cubecl-common", "cubecl-core", "cubecl-cpp", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "cudarc", "derive-new 0.6.0", "half", @@ -1678,14 +1661,14 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ "bytemuck", - "cubecl-common 0.4.0", + "cubecl-common", "cubecl-core", "cubecl-cpp", "cubecl-hip-sys", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "derive-new 0.6.0", "half", "log", @@ -1704,11 +1687,11 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ "bytemuck", "cubecl-core", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "half", "serde", ] @@ -1716,9 +1699,9 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ - "cubecl-common 0.4.0", + "cubecl-common", "darling", "derive-new 0.6.0", "ident_case", @@ -1731,9 +1714,9 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ - "cubecl-common 0.4.0", + "cubecl-common", "cubecl-core", "float-ord", "log", @@ -1744,37 +1727,15 @@ dependencies = [ "type-map", ] -[[package]] -name = "cubecl-runtime" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3468467f412dff4bbf97fb5061a3557445f017299e2fb73ef7b96c6cdb799bc3" -dependencies = [ - "async-channel", - "async-lock", - "cfg_aliases 0.2.1", - "cubecl-common 0.3.0", - "derive-new 0.6.0", - "dirs", - "hashbrown 0.14.5", - "log", - "md5", - "sanitize-filename 0.5.0", - "serde", - "serde_json", - "spin", - "wasm-bindgen-futures", -] - [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ "async-channel", "async-lock", "cfg_aliases 0.2.1", - "cubecl-common 0.4.0", + "cubecl-common", "derive-new 0.6.0", "dirs", "hashbrown 0.14.5", @@ -1790,12 +1751,12 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ - "cubecl-common 0.4.0", + "cubecl-common", "cubecl-core", "cubecl-opt", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "half", "hashbrown 0.14.5", "rspirv", @@ -1804,16 +1765,16 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=b179e368c5404e871176155262cfd5221ae0ed60#b179e368c5404e871176155262cfd5221ae0ed60" +source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" dependencies = [ "ash", "async-channel", "bytemuck", "cfg-if", "cfg_aliases 0.2.1", - "cubecl-common 0.4.0", + "cubecl-common", "cubecl-core", - "cubecl-runtime 0.4.0", + "cubecl-runtime", "cubecl-spirv", "derive-new 0.6.0", "hashbrown 0.14.5", @@ -3438,7 +3399,6 @@ dependencies = [ "burn-candle", "burn-import", "console_error_panic_hook", - "cubecl-runtime 0.3.0", "js-sys", "log", "serde", @@ -3983,7 +3943,6 @@ version = "0.16.0" dependencies = [ "burn", "console_error_panic_hook", - "cubecl-runtime 0.3.0", "js-sys", "serde", "wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index 614ab7a76b..a741216cfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b179e368c5404e871176155262cfd5221ae0ed60" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "b179e368c5404e871176155262cfd5221ae0ed60" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-import/Cargo.toml b/crates/burn-import/Cargo.toml index a75985e71e..14a8bffa32 100644 --- a/crates/burn-import/Cargo.toml +++ b/crates/burn-import/Cargo.toml @@ -20,7 +20,8 @@ onnx = [] pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"] [dependencies] -burn = { path = "../burn", version = "0.16.0", features = ["ndarray"] } +burn = { path = "../burn", version = "0.16.0", default-features = false, features = ["std"]} +burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", default-features = false } onnx-ir = { path = "../onnx-ir", version = "0.16.0" } candle-core = { workspace = true } derive-new = { workspace = true } diff --git a/crates/burn-import/src/burn/graph.rs b/crates/burn-import/src/burn/graph.rs index f6dee5479d..2411a60e27 100644 --- a/crates/burn-import/src/burn/graph.rs +++ b/crates/burn-import/src/burn/graph.rs @@ -50,7 +50,7 @@ pub struct BurnGraph { } // The backend used for recording. -type Backend = burn::backend::ndarray::NdArray; +type Backend = burn_ndarray::NdArray; impl BurnGraph { /// Register a new operation node into the graph. diff --git a/crates/burn-import/src/burn/node/base.rs b/crates/burn-import/src/burn/node/base.rs index f945cb0dce..480d3a1f1c 100644 --- a/crates/burn-import/src/burn/node/base.rs +++ b/crates/burn-import/src/burn/node/base.rs @@ -17,13 +17,12 @@ use super::{ unsqueeze::UnsqueezeNode, }; use crate::burn::{BurnImports, Scope, Type}; -use burn::backend::NdArray; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use serde::Serialize; /// Backend used for serialization. -pub type SerializationBackend = NdArray; +pub type SerializationBackend = burn_ndarray::NdArray; /// Codegen trait that should be implemented by all [node](Node) entries. pub trait NodeCodegen: std::fmt::Debug { diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index 5458180203..b6836a1b82 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -26,14 +26,14 @@ export_tests = [ ] fusion = ["burn-fusion"] fusion-experimental = ["fusion"] +std = ["cubecl/std", "burn-tensor/std"] -std = ["cubecl/std"] template = [] [dependencies] burn-common = { path = "../burn-common", version = "0.16.0" } burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ +burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ "cubecl", "repr", ] } diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index 0eb79c2e88..55e14e174c 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -16,7 +16,7 @@ cubecl = ["dep:cubecl"] cubecl-cuda = ["cubecl", "cubecl/cuda"] cubecl-hip = ["cubecl", "cubecl/hip"] cubecl-wgpu = ["cubecl", "cubecl/wgpu"] -default = ["std", "repr"] +default = ["std", "repr", "burn-common/rayon"] doc = ["default"] experimental-named-tensor = [] export_tests = ["burn-tensor-testgen", "cubecl"] @@ -26,7 +26,6 @@ std = [ "half/std", "num-traits/std", "burn-common/std", - "burn-common/rayon", "colored", ] diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index d3975faad3..055b53ae2f 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -26,7 +26,7 @@ cubecl = { workspace = true, features = ["wgpu"] } burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ +burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ "cubecl-wgpu", ] } diff --git a/examples/image-classification-web/Cargo.toml b/examples/image-classification-web/Cargo.toml index 5f036532ae..44591bdad1 100644 --- a/examples/image-classification-web/Cargo.toml +++ b/examples/image-classification-web/Cargo.toml @@ -17,7 +17,6 @@ half_precision = [] burn = { path = "../../crates/burn", version = "0.16.0", default-features = false, features = [ "ndarray", "wgpu", ] } -cubecl-runtime = { version = "0.3.0", features = ["channel-mpsc"] } # missing feature flags burn-candle = { path = "../../crates/burn-candle", version = "0.16.0", default-features = false } log = { workspace = true } @@ -35,4 +34,4 @@ js-sys = "0.3" [build-dependencies] # Used to generate code from ONNX model -burn-import = { path = "../../crates/burn-import" } +burn-import = { path = "../../crates/burn-import", default-features = false, features = ["onnx"]} diff --git a/examples/mnist-inference-web/Cargo.toml b/examples/mnist-inference-web/Cargo.toml index c8f7803607..a72b3d3c1b 100644 --- a/examples/mnist-inference-web/Cargo.toml +++ b/examples/mnist-inference-web/Cargo.toml @@ -13,11 +13,10 @@ crate-type = ["cdylib"] default = ["ndarray"] ndarray = ["burn/ndarray"] -wgpu = ["burn/wgpu", "cubecl-runtime"] +wgpu = ["burn/wgpu"] [dependencies] burn = { path = "../../crates/burn", default-features = false } -cubecl-runtime = { version = "0.3.0", optional = true, features = ["channel-mpsc"] } # missing feature flag serde = { workspace = true } console_error_panic_hook = { workspace = true } diff --git a/examples/mnist-inference-web/src/state.rs b/examples/mnist-inference-web/src/state.rs index 5516fe90fd..c54290bb7b 100644 --- a/examples/mnist-inference-web/src/state.rs +++ b/examples/mnist-inference-web/src/state.rs @@ -5,7 +5,7 @@ use burn::{ }; #[cfg(feature = "wgpu")] -use burn::backend::wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice}; +use burn::backend::wgpu::{init_setup_async, AutoGraphicsApi, Wgpu, WgpuDevice}; #[cfg(feature = "wgpu")] pub type Backend = Wgpu; @@ -18,7 +18,7 @@ static STATE_ENCODED: &[u8] = include_bytes!("../model.bin"); /// Builds and loads trained parameters into the model. pub async fn build_and_load_model() -> Model { #[cfg(feature = "wgpu")] - init_async::(&WgpuDevice::default(), Default::default()).await; + init_setup_async::(&WgpuDevice::default(), Default::default()).await; let model: Model = Model::new(&Default::default()); let record = BinBytesRecorder::::default() From 3990a8ab2b56383c90989638a91f7e290f95a460 Mon Sep 17 00:00:00 2001 From: Maxime Tremblay Date: Mon, 13 Jan 2025 14:48:10 -0500 Subject: [PATCH 18/61] implement benchmark for reduce kernel (#2692) --- backend-comparison/Cargo.toml | 4 + backend-comparison/README.md | 1 + backend-comparison/benches/reduce.rs | 102 ++++++++++++++++++++ backend-comparison/src/burnbenchapp/base.rs | 2 + backend-comparison/src/lib.rs | 23 +++-- 5 files changed, 122 insertions(+), 10 deletions(-) create mode 100644 backend-comparison/benches/reduce.rs diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index ab726fa7be..9e5054c6af 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -124,6 +124,10 @@ path = "benches/resnet.rs" harness = false name = "autodiff" +[[bench]] +harness = false +name = "reduce" + [[bin]] name = "burnbench" path = "src/bin/burnbench.rs" diff --git a/backend-comparison/README.md b/backend-comparison/README.md index ba7042bbc1..6f4b547a22 100644 --- a/backend-comparison/README.md +++ b/backend-comparison/README.md @@ -57,6 +57,7 @@ Available Benchmarks: - conv-transpose3d - conv2d - conv3d +- reduce ``` #### Run benchmarks diff --git a/backend-comparison/benches/reduce.rs b/backend-comparison/benches/reduce.rs new file mode 100644 index 0000000000..df365f2306 --- /dev/null +++ b/backend-comparison/benches/reduce.rs @@ -0,0 +1,102 @@ +use backend_comparison::persistence::save; +use burn::tensor::{backend::Backend, Distribution, Shape, Tensor}; +use burn_common::benchmark::{run_benchmark, Benchmark}; + +enum Instruction { + ArgMin(usize), + SumDim(usize), + Sum, +} + +struct ReduceBenchmark { + instruction: Instruction, + shape: Shape, + device: B::Device, + tensor: Tensor, +} + +impl ReduceBenchmark { + pub fn new(instruction: Instruction, device: B::Device) -> Self { + let shape = Shape::new([4096, 512, 64]); + let tensor = Tensor::random(shape.clone(), Distribution::Default, &device); + Self { + instruction, + shape, + device, + tensor, + } + } +} + +impl Benchmark for ReduceBenchmark { + type Args = (); + + fn prepare(&self) -> Self::Args {} + + fn execute(&self, _: Self::Args) { + match self.instruction { + Instruction::ArgMin(axis) => { + self.tensor.clone().argmin(axis); + } + Instruction::SumDim(axis) => { + self.tensor.clone().sum_dim(axis); + } + Instruction::Sum => { + self.tensor.clone().sum(); + } + } + } + + fn name(&self) -> String { + match self.instruction { + Instruction::ArgMin(axis) => format!("reduce-argmin-{axis}"), + Instruction::SumDim(axis) => format!("reduce-sum-{axis}"), + Instruction::Sum => String::from("reduce-sum-full"), + } + } + + fn sync(&self) { + B::sync(&self.device) + } + + fn shapes(&self) -> Vec> { + vec![self.shape.dims.clone()] + } +} + +#[allow(dead_code)] +fn bench( + device: &B::Device, + feature_name: &str, + url: Option<&str>, + token: Option<&str>, +) { + let mut benchmarks = Vec::new(); + + for axis in 0..3 { + benchmarks.push(ReduceBenchmark::::new( + Instruction::ArgMin(axis), + device.clone(), + )); + + benchmarks.push(ReduceBenchmark::::new( + Instruction::SumDim(axis), + device.clone(), + )); + } + + benchmarks.push(ReduceBenchmark::::new(Instruction::Sum, device.clone())); + + save::( + benchmarks.into_iter().map(run_benchmark).collect(), + device, + feature_name, + url, + token, + ) + .unwrap(); +} + +fn main() { + backend_comparison::bench_on_backend!(); +} diff --git a/backend-comparison/src/burnbenchapp/base.rs b/backend-comparison/src/burnbenchapp/base.rs index 83c5060a6b..9eba1485b3 100644 --- a/backend-comparison/src/burnbenchapp/base.rs +++ b/backend-comparison/src/burnbenchapp/base.rs @@ -123,6 +123,8 @@ enum BenchmarkValues { Conv2d, #[strum(to_string = "conv3d")] Conv3d, + #[strum(to_string = "reduce")] + Reduce, } pub fn execute() { diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index 03e2d70444..26b08bc3b8 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -54,6 +54,9 @@ fn update_panic_hook() { #[macro_export] macro_rules! bench_on_backend { () => { + $crate::bench_on_backend!(bench) + }; + ($fn_name:ident) => { use std::env; backend_comparison::init_log().unwrap(); @@ -99,14 +102,14 @@ macro_rules! bench_on_backend { { use burn::backend::wgpu::{Wgpu, WgpuDevice}; - bench::>(&WgpuDevice::default(), feature_name, url, token); + $fn_name::>(&WgpuDevice::default(), feature_name, url, token); } #[cfg(any(feature = "wgpu-spirv"))] { use burn::backend::wgpu::{Wgpu, WgpuDevice}; - bench::>(&WgpuDevice::default(), feature_name, url, token); + $fn_name::>(&WgpuDevice::default(), feature_name, url, token); } #[cfg(feature = "tch-gpu")] @@ -117,7 +120,7 @@ macro_rules! bench_on_backend { let device = LibTorchDevice::Cuda(0); #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; - bench::>(&device, feature_name, url, token); + $fn_name::>(&device, feature_name, url, token); } #[cfg(feature = "tch-cpu")] @@ -125,7 +128,7 @@ macro_rules! bench_on_backend { use burn::backend::{libtorch::LibTorchDevice, LibTorch}; let device = LibTorchDevice::Cpu; - bench::(&device, feature_name, url, token); + $fn_name::(&device, feature_name, url, token); } #[cfg(any( @@ -139,7 +142,7 @@ macro_rules! bench_on_backend { use burn::backend::NdArray; let device = NdArrayDevice::Cpu; - bench::(&device, feature_name, url, token); + $fn_name::(&device, feature_name, url, token); } #[cfg(feature = "candle-cpu")] @@ -148,7 +151,7 @@ macro_rules! bench_on_backend { use burn::backend::Candle; let device = CandleDevice::Cpu; - bench::(&device, feature_name, url, token); + $fn_name::(&device, feature_name, url, token); } #[cfg(feature = "candle-cuda")] @@ -157,7 +160,7 @@ macro_rules! bench_on_backend { use burn::backend::Candle; let device = CandleDevice::cuda(0); - bench::(&device, feature_name, url, token); + $fn_name::(&device, feature_name, url, token); } #[cfg(feature = "candle-metal")] @@ -166,21 +169,21 @@ macro_rules! bench_on_backend { use burn::backend::Candle; let device = CandleDevice::metal(0); - bench::(&device, feature_name, url, token); + $fn_name::(&device, feature_name, url, token); } #[cfg(feature = "cuda-jit")] { use burn::backend::cuda_jit::{Cuda, CudaDevice}; - bench::>(&CudaDevice::default(), feature_name, url, token); + $fn_name::>(&CudaDevice::default(), feature_name, url, token); } #[cfg(feature = "hip-jit")] { use burn::backend::hip_jit::{Hip, HipDevice}; - bench::>(&HipDevice::default(), feature_name, url, token); + $fn_name::>(&HipDevice::default(), feature_name, url, token); } }; } From 9228d9156de05dc32c08fdf7a6fa689c1357647d Mon Sep 17 00:00:00 2001 From: Maxime Tremblay Date: Mon, 13 Jan 2025 17:04:53 -0500 Subject: [PATCH 19/61] Merge reduce (#2673) --- Cargo.lock | 480 +++++++++--------- Cargo.toml | 4 +- crates/burn-jit/Cargo.toml | 2 +- crates/burn-jit/src/kernel/reduce/base.rs | 142 +++--- crates/burn-jit/src/kernel/reduce/mod.rs | 7 - .../src/kernel/reduce/naive/argmax.rs | 36 -- .../src/kernel/reduce/naive/argmin.rs | 36 -- .../burn-jit/src/kernel/reduce/naive/base.rs | 25 - .../src/kernel/reduce/naive/kernel.rs | 71 --- .../src/kernel/reduce/naive/mean_dim.rs | 27 - .../burn-jit/src/kernel/reduce/naive/mod.rs | 7 - .../src/kernel/reduce/naive/prod_dim.rs | 26 - .../src/kernel/reduce/naive/sum_dim.rs | 26 - crates/burn-jit/src/kernel/reduce/prod.rs | 15 - .../src/kernel/reduce/shared/argmax.rs | 63 --- .../src/kernel/reduce/shared/argmin.rs | 64 --- .../burn-jit/src/kernel/reduce/shared/base.rs | 33 -- .../src/kernel/reduce/shared/kernel.rs | 117 ----- .../src/kernel/reduce/shared/mean_dim.rs | 44 -- .../burn-jit/src/kernel/reduce/shared/mod.rs | 7 - .../src/kernel/reduce/shared/prod_dim.rs | 43 -- .../src/kernel/reduce/shared/sum_dim.rs | 43 -- .../src/kernel/reduce/subcube/argmax.rs | 54 -- .../src/kernel/reduce/subcube/argmin.rs | 54 -- .../src/kernel/reduce/subcube/base.rs | 15 - .../src/kernel/reduce/subcube/kernel.rs | 134 ----- .../src/kernel/reduce/subcube/mean_dim.rs | 45 -- .../burn-jit/src/kernel/reduce/subcube/mod.rs | 7 - .../src/kernel/reduce/subcube/prod_dim.rs | 44 -- .../src/kernel/reduce/subcube/sum_dim.rs | 44 -- crates/burn-jit/src/kernel/reduce/sum.rs | 15 - crates/burn-jit/src/kernel/reduce/tune.rs | 222 ++++++++ .../burn-jit/src/kernel/reduce/tune/base.rs | 94 ---- crates/burn-jit/src/kernel/reduce/tune/key.rs | 39 -- crates/burn-jit/src/kernel/reduce/tune/mod.rs | 7 - crates/burn-jit/src/ops/float_ops.rs | 14 +- crates/burn-jit/src/ops/int_ops.rs | 16 +- crates/burn-jit/src/tests/mod.rs | 3 + crates/burn-jit/src/tests/reduce.rs | 128 +++++ crates/burn-jit/src/tune_key.rs | 4 +- crates/burn-tensor/src/tensor/shape.rs | 7 + 41 files changed, 708 insertions(+), 1556 deletions(-) delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/argmax.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/argmin.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/base.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/kernel.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/mod.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/prod.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/argmax.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/argmin.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/base.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/kernel.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/mod.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/argmax.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/argmin.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/base.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/kernel.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/mean_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/mod.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/prod_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/subcube/sum_dim.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/sum.rs create mode 100644 crates/burn-jit/src/kernel/reduce/tune.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/tune/base.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/tune/key.rs delete mode 100644 crates/burn-jit/src/kernel/reduce/tune/mod.rs create mode 100644 crates/burn-jit/src/tests/reduce.rs diff --git a/Cargo.lock b/Cargo.lock index b3c03aa56d..feff4ed96a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -139,19 +139,20 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.6" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", + "once_cell", "windows-sys 0.59.0", ] [[package]] name = "anyhow" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1fd03a028ef38ba2276dce7e33fcd6369c158a1bca17946c4b1b701891c1ff7" +checksum = "34ac096ce696dc2fcabef30516bb13c0a68a11d30131d3df6f04711467681b04" [[package]] name = "arbitrary" @@ -188,7 +189,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -269,18 +270,18 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] name = "async-trait" -version = "0.1.83" +version = "0.1.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -510,9 +511,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "1be3f42a67d6d345ecd59f675f3f012d6974981560836e938c22b424b85ce1be" dependencies = [ "serde", ] @@ -594,9 +595,9 @@ dependencies = [ [[package]] name = "bstr" -version = "1.11.1" +version = "1.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "786a307d683a5bf92e6fd5fd69a7eb613751668d1d8d67d802846dfe367c62c8" +checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" dependencies = [ "memchr", "serde", @@ -753,7 +754,7 @@ dependencies = [ "derive-new 0.7.0", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -804,7 +805,7 @@ dependencies = [ "rust-format", "serde", "serde_json", - "syn 2.0.95", + "syn 2.0.96", "thiserror 2.0.11", "tracing-core", "tracing-subscriber", @@ -987,13 +988,13 @@ dependencies = [ [[package]] name = "bytemuck_derive" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcfcc3cd946cb52f0bbfdbbcfa2f4e24f75ebb6c0e1002f7c25904fada18b9ec" +checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1043,9 +1044,9 @@ dependencies = [ [[package]] name = "candle-core" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1e306c8a4276ba57ce9fac76d823cc8c8a7fca14bf222ac20ad8b12c4273152" +checksum = "855dfedff437d2681d68e1f34ae559d88b0dd84aa5a6b63f2c8e75ebdd875bbf" dependencies = [ "accelerate-src", "byteorder", @@ -1073,18 +1074,18 @@ dependencies = [ [[package]] name = "candle-kernels" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbd8ea6588f3c6286ea89a52dad3365f0536fd0b71e729fa998cc2347f1df3b6" +checksum = "53343628fa470b7075c28c589b98735b4220b464e37ddbb8e117040e199f4787" dependencies = [ "bindgen_cuda", ] [[package]] name = "candle-metal-kernels" -version = "0.8.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbc6621c7e2202f4f129bcc3185c2c6d4fa2fc6b8f3f2b07eaf7c06042910c83" +checksum = "50fa64274a009a5d95c542b10bf3a4ea809bd394654c6ae99233bcc35b3a33ef" dependencies = [ "metal 0.27.0", "once_cell", @@ -1118,9 +1119,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.4" +version = "1.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9157bbaa6b165880c27a4293a474c91cdcf265cc68cc829bf10be0964a391caf" +checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" dependencies = [ "jobserver", "libc", @@ -1260,7 +1261,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1333,9 +1334,9 @@ dependencies = [ [[package]] name = "compact_str" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6050c3a16ddab2e412160b31f2c871015704239bca62f72f6e5f0be631d3f644" +checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32" dependencies = [ "castaway", "cfg-if", @@ -1522,7 +1523,7 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "crossterm_winapi", "mio", "parking_lot 0.12.3", @@ -1581,12 +1582,13 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "cubecl-core", "cubecl-cuda", "cubecl-hip", "cubecl-linalg", + "cubecl-reduce", "cubecl-runtime", "cubecl-wgpu", "half", @@ -1595,7 +1597,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1612,7 +1614,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "bytemuck", "cubecl-common", @@ -1631,7 +1633,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "bytemuck", "cubecl-common", @@ -1645,7 +1647,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "bytemuck", "cubecl-common", @@ -1661,7 +1663,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "bytemuck", "cubecl-common", @@ -1677,9 +1679,9 @@ dependencies = [ [[package]] name = "cubecl-hip-sys" -version = "6.3.0" +version = "6.3.1000" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9974218b3ff1f1e7b2f11ce254fd90b3ebcc2af6b4d084f7f6a0c351fb16112c" +checksum = "d4d987c1720eab39c72c515377a8001f683a4c4d99232a29fc0de389d9a8ce4f" dependencies = [ "libc", ] @@ -1687,7 +1689,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "bytemuck", "cubecl-core", @@ -1699,7 +1701,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "cubecl-common", "darling", @@ -1708,13 +1710,13 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "cubecl-common", "cubecl-core", @@ -1727,10 +1729,20 @@ dependencies = [ "type-map", ] +[[package]] +name = "cubecl-reduce" +version = "0.4.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +dependencies = [ + "cubecl-core", + "cubecl-runtime", + "num-traits", +] + [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "async-channel", "async-lock", @@ -1751,7 +1763,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "cubecl-common", "cubecl-core", @@ -1765,7 +1777,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=4c42d0b54ac9069ff520c7719e7ef77833248e34#4c42d0b54ac9069ff520c7719e7ef77833248e34" +source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" dependencies = [ "ash", "async-channel", @@ -1882,7 +1894,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1893,7 +1905,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1939,7 +1951,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1950,7 +1962,7 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1961,7 +1973,7 @@ checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1982,7 +1994,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -1992,7 +2004,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2003,7 +2015,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2023,7 +2035,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", "unicode-xid", ] @@ -2079,7 +2091,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2162,7 +2174,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2174,14 +2186,14 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] name = "env_filter" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +checksum = "186e05a59d4c50738528153b83b0b0194d3a29507dfec16eccd4b342903397d0" dependencies = [ "log", "regex", @@ -2189,9 +2201,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.5" +version = "0.11.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" dependencies = [ "anstream", "anstyle", @@ -2236,9 +2248,9 @@ checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" [[package]] name = "event-listener" -version = "5.3.1" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ "concurrent-queue", "parking", @@ -2272,9 +2284,9 @@ dependencies = [ [[package]] name = "fake" -version = "3.0.1" +version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "661cb0601b5f4050d1e65452c5b0ea555c0b3e88fb5ed7855906adc6c42523ef" +checksum = "aef603df4ba9adbca6a332db7da6f614f21eafefbaf8e087844e452fdec152d0" dependencies = [ "deunicode", "rand", @@ -2373,9 +2385,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foldhash" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" [[package]] name = "foreign-types" @@ -2404,7 +2416,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2488,9 +2500,9 @@ checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cef40d21ae2c515b51041df9ed313ed21e572df340ea58a922a0aefe7e8891a1" +checksum = "f5edaec856126859abb19ed65f39e90fea3a9574b9707f13539acf4abf7eb532" dependencies = [ "fastrand", "futures-core", @@ -2507,7 +2519,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -2727,9 +2739,9 @@ dependencies = [ [[package]] name = "gix-fs" -version = "0.12.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34740384d8d763975858fa2c176b68652a6fcc09f616e24e3ce967b0d370e4d8" +checksum = "3b3d4fac505a621f97e5ce2c69fdc425742af00c0920363ca4074f0eb48b1db9" dependencies = [ "fastrand", "gix-features", @@ -2791,9 +2803,9 @@ dependencies = [ [[package]] name = "glob" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "globset" @@ -2814,7 +2826,7 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf760ebf69878d9fd8f110c89703d90ce35095324d1f1edcb595c63945ee757" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "ignore", "walkdir", ] @@ -2833,9 +2845,9 @@ dependencies = [ [[package]] name = "glutin_wgl_sys" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a4e1951bbd9434a81aa496fe59ccc2235af3820d27b85f9314e279609211e2c" +checksum = "2c4ee00b289aba7a9e5306d57c2d05499b2e5dc427f84ac708bd2c090212cf3e" dependencies = [ "gl_generator", ] @@ -2846,7 +2858,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbcd2dba93594b227a1f57ee09b8b9da8892c34d55aa332e034a228d0fe6a171" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "gpu-alloc-types", ] @@ -2856,7 +2868,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98ff03b468aa837d70984d55f5d3f846f6ec31fe34bbb97c4f85219caeee1ca4" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", ] [[package]] @@ -2873,13 +2885,13 @@ dependencies = [ [[package]] name = "gpu-descriptor" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c08c1f623a8d0b722b8b99f821eb0ba672a1618f0d3b16ddbee1cedd2dd8557" +checksum = "dcf29e94d6d243368b7a56caa16bc213e4f9f8ed38c4d9557069527b5d5281ca" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "gpu-descriptor-types", - "hashbrown 0.14.5", + "hashbrown 0.15.2", ] [[package]] @@ -2888,7 +2900,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdf242682df893b86f33a73828fb09ca4b2d3bb6cc95249707fc684d27484b91" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", ] [[package]] @@ -3133,9 +3145,9 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.3" +version = "0.27.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", "http", @@ -3322,7 +3334,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -3412,9 +3424,9 @@ dependencies = [ [[package]] name = "image-webp" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e031e8e3d94711a9ccb5d6ea357439ef3dcbed361798bd4071dc4d9793fbe22f" +checksum = "b77d01e822461baa8409e156015a1d91735549f0f2c17691bd2d996bef238f7f" dependencies = [ "byteorder-lite", "quick-error", @@ -3467,16 +3479,15 @@ dependencies = [ [[package]] name = "instability" -version = "0.3.3" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b829f37dead9dc39df40c2d3376c179fdfd2ac771f53f55d3c30dc096a3c0c6e" +checksum = "0bf9fed6d91cfb734e7476a06bde8300a1b94e217e1b523b6f0cd1a01998c71d" dependencies = [ "darling", "indoc", - "pretty_assertions", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -3496,7 +3507,7 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -3573,9 +3584,9 @@ checksum = "f5d4a7da358eff58addd2877a45865158f0d78c911d43a5784ceb7bbf52833b0" [[package]] name = "js-sys" -version = "0.3.76" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" +checksum = "1cfaf33c695fc6e08064efbc1f72ec937429614f25eef83af942d0e227c3a28f" dependencies = [ "once_cell", "wasm-bindgen", @@ -3648,7 +3659,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "libc", "redox_syscall 0.5.8", ] @@ -3666,9 +3677,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "litemap" @@ -3724,9 +3735,9 @@ dependencies = [ [[package]] name = "lz4" -version = "1.28.0" +version = "1.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d1febb2b4a79ddd1980eede06a8f7902197960aa0383ffcfdd62fe723036725" +checksum = "a20b523e860d03443e98350ceaac5e71c6ba89aea7d960769ec3ce37f4de5af4" dependencies = [ "lz4-sys", ] @@ -3870,7 +3881,7 @@ version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block", "core-graphics-types", "foreign-types 0.5.0", @@ -3885,7 +3896,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block", "core-graphics-types", "foreign-types 0.5.0", @@ -3975,7 +3986,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -4008,7 +4019,7 @@ checksum = "364f94bc34f61332abebe8cad6f6cd82a5b65cff22c828d05d0968911462ca4f" dependencies = [ "arrayvec", "bit-set", - "bitflags 2.6.0", + "bitflags 2.7.0", "cfg_aliases 0.1.1", "codespan-reporting", "hexf-parse", @@ -4199,7 +4210,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -4271,7 +4282,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -4295,7 +4306,7 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c9bff0aa1d48904a1385ea2a8b97576fbdcbc9a3cfccd0d31fe978e1c4038c5" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "libloading", "nvml-wrapper-sys", "static_assertions", @@ -4344,7 +4355,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block2", "libc", "objc2", @@ -4360,7 +4371,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block2", "objc2", "objc2-foundation", @@ -4390,7 +4401,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block2", "libc", "objc2", @@ -4402,7 +4413,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block2", "objc2", "objc2-foundation", @@ -4414,7 +4425,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "block2", "objc2", "objc2-foundation", @@ -4432,9 +4443,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.5" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "memchr", ] @@ -4536,9 +4547,9 @@ dependencies = [ [[package]] name = "openblas-build" -version = "0.10.10" +version = "0.10.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ca8f8c64eb5b43f5538059ccbc71391420bba14d987d7e8ab99ed62ed33e26b" +checksum = "b8140c0c1afaf88d2d30c48abad86b3bdd2334d691e08f7325a960d784240647" dependencies = [ "anyhow", "cc", @@ -4567,7 +4578,7 @@ version = "0.10.68" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "cfg-if", "foreign-types 0.3.2", "libc", @@ -4584,7 +4595,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -4613,9 +4624,9 @@ checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" [[package]] name = "os_info" -version = "3.9.0" +version = "3.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ca711d8b83edbb00b44d504503cd247c9c0bd8b0fa2694f2a1a3d8165379ce" +checksum = "6e6520c8cc998c5741ee68ec1dc369fc47e5f0ea5320018ecf2a1ccd6328f48b" dependencies = [ "log", "serde", @@ -4748,18 +4759,18 @@ dependencies = [ [[package]] name = "phf" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ade2d8b8f33c7333b51bcf0428d37e217e9f32192ae4772156f65063b8ce03dc" +checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078" dependencies = [ "phf_shared", ] [[package]] name = "phf_codegen" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8d39688d359e6b34654d328e262234662d16cc0f60ec8dcbe5e718709342a5a" +checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" dependencies = [ "phf_generator", "phf_shared", @@ -4767,9 +4778,9 @@ dependencies = [ [[package]] name = "phf_generator" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", "rand", @@ -4777,18 +4788,18 @@ dependencies = [ [[package]] name = "phf_shared" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fcb95eef784c2ac79119d1dd819e162b5da872ce6f3c3abe1e8ca1c082f72b" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" dependencies = [ "siphasher", ] [[package]] name = "pin-project-lite" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" @@ -4813,9 +4824,9 @@ dependencies = [ [[package]] name = "png" -version = "0.17.15" +version = "0.17.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67582bd5b65bdff614270e2ea89a1cf15bef71245cc1e5f7ea126977144211d" +checksum = "82151a2fc869e011c153adc57cf2789ccb8d9906ce52c0b39a6b5697749d7526" dependencies = [ "bitflags 1.3.2", "crc32fast", @@ -4915,7 +4926,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd5df9b55e614088a3270b06f8649dce76537c268d6b1ca4d9c37008b2be5949" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.7.0", "bytemuck", "chrono", "chrono-tz", @@ -4964,7 +4975,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea1b431ed816cba1120cff200f06b962748001bbb2e615ce53cfbbdf701cc136" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.7.0", "hashbrown 0.15.2", "num-traits", "once_cell", @@ -5056,7 +5067,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a8ca74f42e7b47cad241b36b98d991cc7fbb51b8d0695a055eb937588d1f310" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.7.0", "memchr", "once_cell", "polars-arrow", @@ -5201,7 +5212,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23de436f33f4d1134c58f24e7059a221b957ec20730807e0ef0c80c8e4b3d06a" dependencies = [ "ahash", - "bitflags 2.6.0", + "bitflags 2.7.0", "bytemuck", "bytes", "chrono", @@ -5403,12 +5414,12 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.25" +version = "0.2.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" +checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" dependencies = [ "proc-macro2", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -5445,7 +5456,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30" dependencies = [ "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -5568,7 +5579,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -5581,7 +5592,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -5670,9 +5681,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.8" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52cd4b1eff68bf27940dd39811292c49e007f4d0b4c357358dc9b0197be6b527" +checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904" dependencies = [ "cfg_aliases 0.2.1", "libc", @@ -5765,7 +5776,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabd94c2f37801c20583fc49dd5cd6b0ba68c716787c2dd6ed18571e1e63117b" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "cassowary", "compact_str", "crossterm", @@ -5846,7 +5857,7 @@ version = "11.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", ] [[package]] @@ -5915,7 +5926,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -5933,7 +5944,7 @@ version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", ] [[package]] @@ -5964,7 +5975,7 @@ checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -6157,7 +6168,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.95", + "syn 2.0.96", "unicode-ident", ] @@ -6167,7 +6178,7 @@ version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "fallible-iterator", "fallible-streaming-iterator", "hashlink", @@ -6214,11 +6225,11 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.42" +version = "0.38.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93dc38ecbab2eb790ff964bb77fa94faf256fd3e73285fd7ba0903b76bedb85" +checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "errno", "libc", "linux-raw-sys", @@ -6227,9 +6238,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.20" +version = "0.23.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5065c3f250cbd332cd894be57c40fa52387247659b14a2d6041d121547903b1b" +checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" dependencies = [ "log", "once_cell", @@ -6262,7 +6273,7 @@ dependencies = [ "openssl-probe", "rustls-pki-types", "schannel", - "security-framework 3.1.0", + "security-framework 3.2.0", ] [[package]] @@ -6296,9 +6307,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" +checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" @@ -6356,9 +6367,9 @@ dependencies = [ [[package]] name = "scc" -version = "2.2.6" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b13f8ea6177672c49d12ed964cca44836f59621981b04a3e26b87e675181de" +checksum = "28e1c91382686d21b5ac7959341fcb9780fa7c03773646995a87c950fa7be640" dependencies = [ "sdd", ] @@ -6399,7 +6410,7 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "core-foundation 0.9.4", "core-foundation-sys", "libc", @@ -6408,11 +6419,11 @@ dependencies = [ [[package]] name = "security-framework" -version = "3.1.0" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81d3f8c9bfcc3cbb6b0179eb57042d75b1582bdc65c3cb95f3fa999509c03cbc" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "core-foundation 0.10.0", "core-foundation-sys", "libc", @@ -6421,9 +6432,9 @@ dependencies = [ [[package]] name = "security-framework-sys" -version = "2.13.0" +version = "2.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1863fd3768cd83c56a7f60faa4dc0d403f1b6df0a38c3c25f44b7894e45370d5" +checksum = "49db231d56a190491cb4aeda9527f1ad45345af50b0851622a7adb8c03b01c32" dependencies = [ "core-foundation-sys", "libc", @@ -6478,7 +6489,7 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -6556,7 +6567,7 @@ checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -6685,9 +6696,9 @@ dependencies = [ [[package]] name = "siphasher" -version = "0.3.11" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" [[package]] name = "slab" @@ -6767,7 +6778,7 @@ version = "0.3.0+sdk-1.3.268.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", ] [[package]] @@ -6871,7 +6882,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -6893,9 +6904,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.95" +version = "2.0.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46f71c0377baf4ef1cc3e3402ded576dccc315800fbc62dfc7fe04b009773b4a" +checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" dependencies = [ "proc-macro2", "quote", @@ -6919,7 +6930,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -6928,7 +6939,7 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "byteorder", "enum-as-inner", "libc", @@ -6970,7 +6981,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "core-foundation 0.9.4", "system-configuration-sys", ] @@ -7054,12 +7065,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.14.0" +version = "3.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" dependencies = [ "cfg-if", "fastrand", + "getrandom", "once_cell", "rustix", "windows-sys 0.59.0", @@ -7142,7 +7154,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -7153,7 +7165,7 @@ checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -7231,9 +7243,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" dependencies = [ "tinyvec_macros", ] @@ -7278,9 +7290,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.42.0" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" +checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" dependencies = [ "backtrace", "bytes", @@ -7294,13 +7306,13 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -7452,7 +7464,7 @@ checksum = "5a3a646485f7cd8f580749ab94718ad3d344bcc0cc5b0fefe43c15fdd898bb96" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -7487,7 +7499,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -7753,9 +7765,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.11.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +checksum = "b913a3b5fe84142e269d63cc62b64319ccaf89b748fc31fe025177f767a756c4" dependencies = [ "getrandom", "rand", @@ -7835,34 +7847,35 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" +checksum = "1edc8929d7499fc4e8f0be2262a241556cfc54a0bea223790e71446f2aab1ef5" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" +checksum = "2f0a0651a5c2bc21487bde11ee802ccaf4c51935d0d3d42a6101f98161700bc6" dependencies = [ "bumpalo", "log", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-futures" -version = "0.4.49" +version = "0.4.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2" +checksum = "555d470ec0bc3bb57890405e5d4322cc9ea83cebb085523ced7be4144dac1e61" dependencies = [ "cfg-if", "js-sys", @@ -7873,9 +7886,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" +checksum = "7fe63fc6d09ed3792bd0897b314f53de8e16568c2b3f7982f468c0bf9bd0b407" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -7883,22 +7896,25 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" +checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.99" +version = "0.2.100" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" +checksum = "1a05d73b933a847d6cccdda8f838a22ff101ad9bf93e33684f39c1f5f0eece3d" +dependencies = [ + "unicode-ident", +] [[package]] name = "wasm-logger" @@ -7941,9 +7957,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.76" +version = "0.3.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc" +checksum = "33b6dd2ef9186f1f2072e409e99cd22a975331a6b3591b12c764e0e55c60d5d2" dependencies = [ "js-sys", "wasm-bindgen", @@ -8007,7 +8023,7 @@ checksum = "d63c3c478de8e7e01786479919c8769f62a22eec16788d8c2ac77ce2c132778a" dependencies = [ "arrayvec", "bit-vec", - "bitflags 2.6.0", + "bitflags 2.7.0", "cfg_aliases 0.1.1", "document-features", "indexmap", @@ -8034,7 +8050,7 @@ dependencies = [ "arrayvec", "ash", "bit-set", - "bitflags 2.6.0", + "bitflags 2.7.0", "block", "bytemuck", "cfg_aliases 0.1.1", @@ -8075,7 +8091,7 @@ version = "23.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "610f6ff27778148c31093f3b03abc4840f9636d58d597ca2f5977433acfe0068" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.7.0", "js-sys", "web-sys", ] @@ -8185,7 +8201,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8196,7 +8212,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8207,7 +8223,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8218,7 +8234,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8410,9 +8426,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.20" +version = "0.6.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" dependencies = [ "memchr", ] @@ -8426,7 +8442,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8466,9 +8482,9 @@ checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d" [[package]] name = "xattr" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" +checksum = "e105d177a3871454f754b33bb0ee637ecaaac997446375fd3e5d43a2ed00c909" dependencies = [ "libc", "linux-raw-sys", @@ -8477,9 +8493,9 @@ dependencies = [ [[package]] name = "xml-rs" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea8b391c9a790b496184c29f7f93b9ed5b16abb306c05415b68bcc16e4d06432" +checksum = "c5b940ebc25896e71dd073bad2dbaa2abfe97b0a391415e22ad1326d9c54e3c4" [[package]] name = "xtask" @@ -8493,9 +8509,9 @@ dependencies = [ [[package]] name = "xxhash-rust" -version = "0.8.12" +version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a5cbf750400958819fb6178eaa83bee5cd9c29a26a40cc241df8c70fdd46984" +checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" [[package]] name = "yansi" @@ -8523,7 +8539,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", "synstructure", ] @@ -8545,7 +8561,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8565,7 +8581,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", "synstructure", ] @@ -8586,7 +8602,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] @@ -8608,7 +8624,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.95", + "syn 2.0.96", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index a741216cfe..5dfebaf2b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4c42d0b54ac9069ff520c7719e7ef77833248e34" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "707093234f11b78fb6630b98fea5d13870f94282" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "707093234f11b78fb6630b98fea5d13870f94282" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index b6836a1b82..2bd0ba6f7e 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -37,7 +37,7 @@ burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = "cubecl", "repr", ] } -cubecl = { workspace = true, features = ["linalg"] } +cubecl = { workspace = true, features = ["linalg", "reduce"] } bytemuck = { workspace = true } derive-new = { workspace = true } diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs index 57cdf13b1e..9ab1f5d2b6 100644 --- a/crates/burn-jit/src/kernel/reduce/base.rs +++ b/crates/burn-jit/src/kernel/reduce/base.rs @@ -1,83 +1,101 @@ -use cubecl::prelude::Numeric; - -#[cfg(feature = "autotune")] -use crate::kernel::reduce::reduce_dim_autotune; use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; -use super::{ - naive::{base::ReduceDimNaiveFamily, kernel::reduce_dim_naive}, - shared::{base::ReduceDimShared, kernel::reduce_dim_shared}, - subcube::{base::ReduceDimSubcube, kernel::reduce_dim_subcube}, -}; +use super::autotune_reduce; + +pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum}; -#[allow(dead_code)] -pub(crate) trait ReduceDimAlgorithm: - core::fmt::Debug + ReduceDimNaiveFamily + ReduceDimShared + ReduceDimSubcube -{ +/// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). +/// +/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. +/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid. +/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`. +/// +/// If there is no error, the output is a tensor with decreasing strides +/// where the shape of reduced dim is set to 1 but all shape are similar to the input. +pub fn reduce( + mut input: JitTensor, + strategy: ReduceStrategy, +) -> Result, cubecl::reduce::ReduceError> { + input.shape = input.shape.flatten(); + input.strides = vec![1]; + reduce_dim::(input, 0, strategy) } -/// Creates an empty output tensor with reduce output shape -pub fn init_reduce_output( - input: &JitTensor, - reduce_dim: usize, -) -> JitTensor { - let mut shape_out = input.shape.clone(); - shape_out.dims[reduce_dim] = 1; +/// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). +/// +/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. +/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid. +/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`. +/// +/// If there is no error, the output is a tensor with decreasing strides +/// where the shape of reduced dim is set to 1 but all shape are similar to the input. +pub fn reduce_dim( + input: JitTensor, + dim: usize, + strategy: ReduceStrategy, +) -> Result, cubecl::reduce::ReduceError> { + let client = input.client.clone(); + let output = init_reduce_output::(&input, dim).ok_or( + cubecl::reduce::ReduceError::InvalidAxis { + axis: dim, + rank: input.shape.num_dims(), + }, + )?; + let result = match strategy { + ReduceStrategy::Unspecified => cubecl::reduce::reduce::( + &client, + input.as_handle_ref(), + output.as_handle_ref(), + dim, + None, + ), + ReduceStrategy::Specific(strategy) => cubecl::reduce::reduce::( + &client, + input.as_handle_ref(), + output.as_handle_ref(), + dim, + Some(strategy), + ), + #[cfg(feature = "autotune")] + ReduceStrategy::Autotune => { + autotune_reduce::(&client, input, output.clone(), dim) + } + }; + result.map(|_| output) +} - empty_device::(input.client.clone(), input.device.clone(), shape_out) +/// Creates an empty output tensor with the proper shape and decreasing strides to reduce the given `axis` of `input` +/// or return `None` if `axis` is out-of-bound. +pub fn init_reduce_output( + input: &JitTensor, + dim: usize, +) -> Option> { + (dim < input.shape.num_dims()).then(|| { + let mut shape_out = input.shape.clone(); + shape_out.dims[dim] = 1; + empty_device::(input.client.clone(), input.device.clone(), shape_out) + }) } +/// Select a strategy to perform a reduction. #[derive(Copy, Clone, Debug)] -#[allow(missing_docs)] pub enum ReduceStrategy { - /// Naive - Naive, - /// Use shared memory as an accumulator - SharedMemory, - /// Use subcube functions - Subcube, + /// Use a best-effort strategy based on the hardware capacity. + /// This differs from Autotune as it doesn't try and compare many strategies to select the best. + Unspecified, + /// Fix the exact strategy for the reduction. + Specific(cubecl::reduce::ReduceStrategy), + /// Use autotune to find the best strategy given the hardware and the inputs. #[cfg(feature = "autotune")] Autotune, } impl Default for ReduceStrategy { fn default() -> Self { - // if autotune is enabled, default to autotune #[cfg(feature = "autotune")] - return ReduceStrategy::Autotune; + return Self::Autotune; #[cfg(not(feature = "autotune"))] - ReduceStrategy::Naive + return Self::Unspecified; } } - -macro_rules! reduce_operation { - ($name:ident, $ops:ident) => { - #[derive(Debug)] - pub(crate) struct $ops; - - impl ReduceDimAlgorithm for $ops {} - - /// Executes the reduce operation with the given strategy. - pub fn $name( - tensor: JitTensor, - dim: usize, - strategy: ReduceStrategy, - ) -> Result, String> { - match strategy { - ReduceStrategy::Naive => reduce_dim_naive::<$ops, R, EI, EO>(tensor, dim), - ReduceStrategy::SharedMemory => reduce_dim_shared::<$ops, R, EI, EO>(tensor, dim), - ReduceStrategy::Subcube => reduce_dim_subcube::<$ops, R, EI, EO>(tensor, dim), - #[cfg(feature = "autotune")] - ReduceStrategy::Autotune => Ok(reduce_dim_autotune::<$ops, R, EI, EO>(tensor, dim)), - } - } - }; -} - -// Autotunable reduce operation variants -reduce_operation!(sum_dim, SumDim); -reduce_operation!(mean_dim, MeanDim); -reduce_operation!(prod_dim, ProdDim); -reduce_operation!(argmin, Argmin); -reduce_operation!(argmax, Argmax); diff --git a/crates/burn-jit/src/kernel/reduce/mod.rs b/crates/burn-jit/src/kernel/reduce/mod.rs index 2401f9467e..8ff38a9da7 100644 --- a/crates/burn-jit/src/kernel/reduce/mod.rs +++ b/crates/burn-jit/src/kernel/reduce/mod.rs @@ -1,12 +1,5 @@ mod base; -mod naive; -mod prod; -mod shared; -mod subcube; -mod sum; mod tune; pub use base::*; -pub use prod::*; -pub use sum::*; pub use tune::*; diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs b/crates/burn-jit/src/kernel/reduce/naive/argmax.rs deleted file mode 100644 index d577d3decf..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs +++ /dev/null @@ -1,36 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::Argmax; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for Argmax { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for Argmax { - type Accumulator = (EI, u32); - - fn initialize_naive() -> Self::Accumulator { - // TODO: switch to using f32::NEG_INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 - (EI::min_value(), 0u32) - } - - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { - let (max, index) = accumulator; - if current_value > *max { - *max = current_value; - *index = i; - } - } - - fn assign_naive( - output: &mut Tensor, - accumulator: Self::Accumulator, - _shape_reduce_dim: u32, - ) { - let (_, index) = accumulator; - output[ABSOLUTE_POS] = EO::cast_from(index); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs b/crates/burn-jit/src/kernel/reduce/naive/argmin.rs deleted file mode 100644 index 2302a2b205..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs +++ /dev/null @@ -1,36 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::Argmin; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for Argmin { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for Argmin { - type Accumulator = (EI, u32); - - fn initialize_naive() -> Self::Accumulator { - // TODO: switch to using f32::INFINITY when it's supported: https://github.com/tracel-ai/cubecl/issues/68 - (EI::max_value(), 0u32) - } - - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32) { - let (min, index) = accumulator; - if current_value < *min { - *min = current_value; - *index = i; - } - } - - fn assign_naive( - output: &mut Tensor, - accumulator: Self::Accumulator, - _shape_reduce_dim: u32, - ) { - let (_, index) = accumulator; - output[ABSOLUTE_POS] = EO::cast_from(index); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/base.rs b/crates/burn-jit/src/kernel/reduce/naive/base.rs deleted file mode 100644 index 7512103ebb..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/base.rs +++ /dev/null @@ -1,25 +0,0 @@ -use cubecl::prelude::*; - -pub trait ReduceDimNaiveFamily: Send + Sync + 'static { - type Reduce: ReduceDimNaive; -} - -/// Specifies the reduce dim algorithm in use -#[cube] -pub trait ReduceDimNaive: Send + Sync + 'static { - /// The reduction accumulator - type Accumulator: CubeType; - - /// Initialization for naive algorithm - fn initialize_naive() -> Self::Accumulator; - - /// Inner loop for naive algorithm - fn inner_loop_naive(accumulator: &mut Self::Accumulator, current_value: EI, i: u32); - - /// Assignation for naive algorithm - fn assign_naive( - output: &mut Tensor, - accumulator: Self::Accumulator, - shape_reduce_dim: u32, - ); -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs b/crates/burn-jit/src/kernel/reduce/naive/kernel.rs deleted file mode 100644 index c862e7070d..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/kernel.rs +++ /dev/null @@ -1,71 +0,0 @@ -use crate::{ - element::JitElement, kernel::reduce::init_reduce_output, tensor::JitTensor, JitRuntime, -}; -use cubecl::calculate_cube_count_elemwise; -use cubecl::prelude::*; - -use super::base::ReduceDimNaive; -use super::base::ReduceDimNaiveFamily; - -#[cube(launch_unchecked)] -pub(crate) fn naive_reduce_dim_kernel( - input: &Tensor, - output: &mut Tensor, - dim: u32, -) { - naive_reduce::, EI, EO>(input, output, dim) -} - -#[cube] -fn naive_reduce, EI: Numeric, EO: Numeric>( - input: &Tensor, - output: &mut Tensor, - dim: u32, -) { - if ABSOLUTE_POS >= output.len() { - return; - } - - let mut offset_input = 0; - - for i in 0..input.rank() { - let mut offset_local = ABSOLUTE_POS / output.stride(i); - offset_local %= output.shape(i); - if i != dim { - offset_input += offset_local * input.stride(i); - } - } - - let mut accumulator = RD::initialize_naive(); - - for i in 0..input.shape(dim) { - let index = i * input.stride(dim) + offset_input; - RD::inner_loop_naive(&mut accumulator, input[index], i); - } - - RD::assign_naive::(output, accumulator, input.shape(dim)); -} - -/// Executes the naive kernel for reduce dim -pub fn reduce_dim_naive( - input: JitTensor, - dim: usize, -) -> Result, String> { - let output = init_reduce_output::(&input, dim); - - let cube_dim = CubeDim::default(); - let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim); - - unsafe { - naive_reduce_dim_kernel::launch_unchecked::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - ScalarArg::new(dim as u32), - ); - } - - Ok(output) -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs deleted file mode 100644 index 774c9b251c..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/mean_dim.rs +++ /dev/null @@ -1,27 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::MeanDim; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for MeanDim { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for MeanDim { - type Accumulator = EI; - - fn initialize_naive() -> EI { - EI::from_int(0) - } - - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { - *accumulator += current_value; - } - - fn assign_naive(output: &mut Tensor, accumulator: EI, shape_reduce_dim: u32) { - let mean = accumulator / EI::cast_from(shape_reduce_dim); - output[ABSOLUTE_POS] = EO::cast_from(mean); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/mod.rs b/crates/burn-jit/src/kernel/reduce/naive/mod.rs deleted file mode 100644 index b11ee5e2da..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub(crate) mod argmax; -pub(crate) mod argmin; -pub(crate) mod base; -pub(crate) mod kernel; -pub(crate) mod mean_dim; -pub(crate) mod prod_dim; -pub(crate) mod sum_dim; diff --git a/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs deleted file mode 100644 index 1ea52a149c..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/prod_dim.rs +++ /dev/null @@ -1,26 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::ProdDim; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for ProdDim { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for ProdDim { - type Accumulator = EI; - - fn initialize_naive() -> EI { - EI::from_int(1) - } - - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { - *accumulator *= current_value; - } - - fn assign_naive(output: &mut Tensor, accumulator: EI, _shape_reduce_dim: u32) { - output[ABSOLUTE_POS] = EO::cast_from(accumulator); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs deleted file mode 100644 index 7168e07ff3..0000000000 --- a/crates/burn-jit/src/kernel/reduce/naive/sum_dim.rs +++ /dev/null @@ -1,26 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::SumDim; - -use super::base::{ReduceDimNaive, ReduceDimNaiveFamily}; - -impl ReduceDimNaiveFamily for SumDim { - type Reduce = Self; -} - -#[cube] -impl ReduceDimNaive for SumDim { - type Accumulator = EI; - - fn initialize_naive() -> EI { - EI::from_int(0) - } - - fn inner_loop_naive(accumulator: &mut EI, current_value: EI, _i: u32) { - *accumulator += current_value; - } - - fn assign_naive(output: &mut Tensor, accumulator: EI, _shape_reduce_dim: u32) { - output[ABSOLUTE_POS] = EO::cast_from(accumulator); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/prod.rs b/crates/burn-jit/src/kernel/reduce/prod.rs deleted file mode 100644 index 1b156157fa..0000000000 --- a/crates/burn-jit/src/kernel/reduce/prod.rs +++ /dev/null @@ -1,15 +0,0 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use burn_tensor::Shape; - -use super::{prod_dim, ReduceStrategy}; - -/// Multiply all elements in the input buffer. -pub fn prod( - input: JitTensor, - strategy: ReduceStrategy, -) -> JitTensor { - let shape = Shape::new([input.shape.num_elements()]); - let input: JitTensor = - JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - prod_dim::(input, 0, strategy).unwrap() -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs deleted file mode 100644 index 43c03c09ce..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs +++ /dev/null @@ -1,63 +0,0 @@ -use crate::kernel::reduce::Argmax; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for Argmax { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - /// Initialization for shared algorithm - fn initialize_shared( - shared_memory_size: u32, - write_position: u32, - ) -> (SharedMemory, SharedMemory) { - let mut value_shared = SharedMemory::new(shared_memory_size); - let mut index_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::min_value(); - index_shared[write_position] = 0; - (value_shared, index_shared) - } - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut (SharedMemory, SharedMemory), - write_position: u32, - value: (EIn, u32), - ) { - let (values, indices) = shared_memory; - let (value, index) = value; - - if value > values[write_position] { - values[write_position] = value; - indices[write_position] = index; - } - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> (EIn, u32) { - (input[read_position], i) - } - - /// How to read from shared memory - fn read_from_shared( - shared_memory: &(SharedMemory, SharedMemory), - read_position: u32, - ) -> (EIn, u32) { - let (values, indices) = shared_memory; - (values[read_position], indices[read_position]) - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &(SharedMemory, SharedMemory), - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - let (_, indices) = shared_memory; - output[write_position] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs deleted file mode 100644 index 0e47693c5a..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs +++ /dev/null @@ -1,64 +0,0 @@ -use cubecl::prelude::*; - -use crate::kernel::reduce::Argmin; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for Argmin { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - /// Initialization for shared algorithm - fn initialize_shared( - shared_memory_size: u32, - write_position: u32, - ) -> (SharedMemory, SharedMemory) { - let mut value_shared = SharedMemory::new(shared_memory_size); - let mut index_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::max_value(); - index_shared[write_position] = 0; - (value_shared, index_shared) - } - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut (SharedMemory, SharedMemory), - write_position: u32, - value: (EIn, u32), - ) { - let (values, indices) = shared_memory; - let (value, index) = value; - - if value < values[write_position] { - values[write_position] = value; - indices[write_position] = index; - } - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> (EIn, u32) { - (input[read_position], i) - } - - /// How to read from shared memory - fn read_from_shared( - shared_memory: &(SharedMemory, SharedMemory), - read_position: u32, - ) -> (EIn, u32) { - let (values, indices) = shared_memory; - (values[read_position], indices[read_position]) - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &(SharedMemory, SharedMemory), - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - let (_, indices) = shared_memory; - output[write_position] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/base.rs b/crates/burn-jit/src/kernel/reduce/shared/base.rs deleted file mode 100644 index 256123fe1b..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/base.rs +++ /dev/null @@ -1,33 +0,0 @@ -use cubecl::prelude::*; - -/// Specifies the reduce dim algorithm in use -#[cube] -pub trait ReduceDimShared: Send + Sync + 'static { - /// The reduction accumulator - type Accumulator: CubeType; - type Value: CubeType; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> Self::Accumulator; - - /// How to write to shared memory - fn write_to_shared( - shared_memory: &mut Self::Accumulator, - write_position: u32, - value: Self::Value, - ); - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, i: u32) -> Self::Value; - - /// How to read from shared memory - fn read_from_shared(shared_memory: &Self::Accumulator, read_position: u32) -> Self::Value; - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &Self::Accumulator, - output: &mut Tensor, - write_position: u32, - shape_reduce_dim: u32, - ); -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs b/crates/burn-jit/src/kernel/reduce/shared/kernel.rs deleted file mode 100644 index 1c15e4523f..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/kernel.rs +++ /dev/null @@ -1,117 +0,0 @@ -use cubecl::prelude::*; - -use crate::{kernel::reduce::init_reduce_output, tensor::JitTensor, JitElement, JitRuntime}; - -use super::base::ReduceDimShared; - -#[cube(launch)] -pub fn reduce_dim_shared_kernel< - RD: ReduceDimShared, - EIn: JitElement, - EOut: JitElement, ->( - input: &Tensor, - output: &mut Tensor, - #[comptime] dim: u32, - #[comptime] smem_size: u32, - #[comptime] elems_per_thread: u32, - #[comptime] divisible_shape: bool, -) { - let reduce_group_id = CUBE_POS; - - let stride_reduce_dim_input = input.stride(dim); - let shape_reduce_dim_input = input.shape(dim); - - let mut shared_memory = RD::initialize_shared(smem_size, UNIT_POS); - - let mut index_offset = 0; - - for i in 0..input.rank() { - let num_block = reduce_group_id / output.stride(i) % output.shape(i); - index_offset += num_block * input.stride(i); - } - - for i in 0..elems_per_thread { - let nth = i * CUBE_DIM + UNIT_POS; - - #[allow(clippy::collapsible_else_if)] - if divisible_shape { - let current_pos = nth * stride_reduce_dim_input + index_offset; - - let new_value = RD::read_from_input(input, current_pos, nth); - RD::write_to_shared(&mut shared_memory, UNIT_POS, new_value); - } else { - if nth < shape_reduce_dim_input { - let current_pos = nth * stride_reduce_dim_input + index_offset; - - let new_value = RD::read_from_input(input, current_pos, nth); - RD::write_to_shared(&mut shared_memory, UNIT_POS, new_value); - } - } - } - - sync_units(); - - let mut n_threads = CUBE_DIM; - - while n_threads > 1 { - n_threads /= 2; - - if UNIT_POS < n_threads { - let read_pos = n_threads + UNIT_POS; - let read_value = RD::read_from_shared(&shared_memory, read_pos); - RD::write_to_shared(&mut shared_memory, UNIT_POS, read_value); - } - - sync_units(); - } - - if UNIT_POS == 0 { - RD::assign_shared( - &shared_memory, - output, - reduce_group_id, - shape_reduce_dim_input, - ); - } -} - -/// Executes the shared memory kernel for reduce dim -pub fn reduce_dim_shared< - RD: ReduceDimShared, - R: JitRuntime, - EI: JitElement, - EO: JitElement, ->( - input: JitTensor, - dim: usize, -) -> Result, String> { - let output = init_reduce_output::(&input, dim); - - let num_elems_output = output.shape.num_elements(); - let cube_dim = CubeDim::default(); - let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32)); - let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x); - let cube_count = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1); - - let reduce_group_size = input.shape.dims[dim]; - let n_invocation_per_cube = cube_dim.num_elems(); - let elems_per_thread = - f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; - - let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; - - reduce_dim_shared_kernel::launch::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - dim as u32, - cube_dim.num_elems(), - elems_per_thread, - divisible_shape, - ); - - Ok(output) -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs deleted file mode 100644 index eef8f5f478..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/mean_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use crate::kernel::reduce::MeanDim; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for MeanDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(0); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] += value; - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - shape_reduce_dim: u32, - ) { - let mean = shared_memory[0] / EIn::cast_from(shape_reduce_dim); - output[write_position] = EOut::cast_from(mean); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/mod.rs b/crates/burn-jit/src/kernel/reduce/shared/mod.rs deleted file mode 100644 index b11ee5e2da..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub(crate) mod argmax; -pub(crate) mod argmin; -pub(crate) mod base; -pub(crate) mod kernel; -pub(crate) mod mean_dim; -pub(crate) mod prod_dim; -pub(crate) mod sum_dim; diff --git a/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs deleted file mode 100644 index 594f2fec11..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/prod_dim.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::kernel::reduce::ProdDim; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for ProdDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(1); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] *= value; - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - output[write_position] = EOut::cast_from(shared_memory[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs deleted file mode 100644 index 476dd554a4..0000000000 --- a/crates/burn-jit/src/kernel/reduce/shared/sum_dim.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::kernel::reduce::SumDim; -use cubecl::prelude::*; - -use super::base::ReduceDimShared; - -#[cube] -impl ReduceDimShared for SumDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - /// Initialization for shared algorithm - fn initialize_shared(shared_memory_size: u32, write_position: u32) -> SharedMemory { - let mut value_shared = SharedMemory::new(shared_memory_size); - value_shared[write_position] = EIn::from_int(0); - value_shared - } - - /// How to write to shared memory - fn write_to_shared(shared_memory: &mut SharedMemory, write_position: u32, value: EIn) { - shared_memory[write_position] += value; - } - - /// How to read from input in shared algorithm - fn read_from_input(input: &Tensor, read_position: u32, _i: u32) -> EIn { - input[read_position] - } - - /// How to read from shared memory - fn read_from_shared(shared_memory: &SharedMemory, read_position: u32) -> EIn { - shared_memory[read_position] - } - - /// How to assign from shared memory - fn assign_shared( - shared_memory: &SharedMemory, - output: &mut Tensor, - write_position: u32, - _shape_reduce_dim: u32, - ) { - output[write_position] = EOut::cast_from(shared_memory[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/argmax.rs b/crates/burn-jit/src/kernel/reduce/subcube/argmax.rs deleted file mode 100644 index c8e567e816..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/argmax.rs +++ /dev/null @@ -1,54 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::Argmax; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for Argmax { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - let value_shared = SharedMemory::new(size); - let index_shared = SharedMemory::new(size); - (value_shared, index_shared) - } - - fn init_value() -> Self::Value { - (comptime![EIn::min_value()], 0u32) - } - - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value { - (input[pos], i) - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - let (values, indices) = acc; - (values[pos], indices[pos]) - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - let (current_val, current_idx) = current; - let (new_val, new_idx) = new; - *current_val = Max::max(*current_val, new_val); - *current_idx = select(*current_val == new_val, new_idx, *current_idx); - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let (val, index) = value; - let (val_smem, index_smem) = acc; - let max = plane_max(val); - - if max == val { - val_smem[write_position] = val; - index_smem[write_position] = index; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - let (_, indices) = acc; - out[pos] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/argmin.rs b/crates/burn-jit/src/kernel/reduce/subcube/argmin.rs deleted file mode 100644 index b7950ebfe2..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/argmin.rs +++ /dev/null @@ -1,54 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::Argmin; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for Argmin { - /// The reduction accumulator - type Accumulator = (SharedMemory, SharedMemory); - type Value = (EIn, u32); - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - let value_shared = SharedMemory::new(size); - let index_shared = SharedMemory::new(size); - (value_shared, index_shared) - } - - fn init_value() -> Self::Value { - (comptime![EIn::max_value()], 0u32) - } - - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value { - (input[pos], i) - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - let (values, indices) = acc; - (values[pos], indices[pos]) - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - let (current_val, current_idx) = current; - let (new_val, new_idx) = new; - *current_val = Min::min(*current_val, new_val); - *current_idx = select(*current_val == new_val, new_idx, *current_idx); - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let (val, index) = value; - let (val_smem, index_smem) = acc; - let min = plane_min(val); - - if min == val { - val_smem[write_position] = val; - index_smem[write_position] = index; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - let (_, indices) = acc; - out[pos] = EOut::cast_from(indices[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/base.rs b/crates/burn-jit/src/kernel/reduce/subcube/base.rs deleted file mode 100644 index f20e538914..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/base.rs +++ /dev/null @@ -1,15 +0,0 @@ -use cubecl::prelude::*; - -#[cube] -pub trait ReduceDimSubcube: Send + Sync + 'static { - type Accumulator: CubeType; - type Value: CubeType; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator; - fn init_value() -> Self::Value; - fn read_value(input: &Tensor, pos: u32, i: u32) -> Self::Value; - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value; - fn update_value(current: &mut Self::Value, new: Self::Value); - fn reduce_subcube(acc: &mut Self::Accumulator, pos: u32, value: Self::Value); - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, dim_len: u32); -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs b/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs deleted file mode 100644 index 26f65f5d68..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/kernel.rs +++ /dev/null @@ -1,134 +0,0 @@ -use cubecl::{prelude::*, CubeCount, CubeDim, Feature}; - -use crate::{ - kernel::reduce::{init_reduce_output, shared::kernel::reduce_dim_shared, ReduceDimAlgorithm}, - tensor::JitTensor, - JitElement, JitRuntime, -}; - -use super::base::ReduceDimSubcube; - -#[cube(launch)] -pub fn reduce_dim_subcube_kernel< - RD: ReduceDimSubcube, - EIn: JitElement, - EOut: JitElement, ->( - input: &Tensor, - output: &mut Tensor, - #[comptime] dim: u32, - #[comptime] subcube_size: u32, - #[comptime] elems_per_thread: u32, - #[comptime] divisible_shape: bool, -) { - let reduce_group_id = CUBE_POS; - - let stride_reduce_dim_input = input.stride(dim); - let shape_reduce_dim_input = input.shape(dim); - - let should_unroll = elems_per_thread <= 8; - - let warp_id = plane_broadcast(UNIT_POS / PLANE_DIM, 0); - - let mut shared_memory = RD::init_shared(subcube_size); - - let mut index_offset = 0; - - for i in 0..input.rank() { - let num_block = reduce_group_id / output.stride(i) % output.shape(i); - index_offset += num_block * input.stride(i); - } - - let mut value = RD::init_value(); - - #[unroll(should_unroll)] - for i in 0..elems_per_thread { - let nth = i * CUBE_DIM + UNIT_POS; - let current_pos = nth * stride_reduce_dim_input + index_offset; - - #[allow(clippy::collapsible_else_if)] - if divisible_shape { - let next = RD::read_value(input, current_pos, nth); - RD::update_value(&mut value, next); - } else { - if nth < shape_reduce_dim_input { - let next = RD::read_value(input, current_pos, nth); - RD::update_value(&mut value, next); - } - } - } - - RD::reduce_subcube(&mut shared_memory, warp_id, value); - - sync_units(); - - if UNIT_POS >= PLANE_DIM { - return; - } - - let value = RD::read_from_shared(&shared_memory, UNIT_POS); - RD::reduce_subcube(&mut shared_memory, 0, value); - - if UNIT_POS == 0 { - RD::store( - &shared_memory, - output, - reduce_group_id, - shape_reduce_dim_input, - ); - } -} - -/// Executes the shared memory kernel for reduce dim -pub fn reduce_dim_subcube< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement, - EO: JitElement, ->( - input: JitTensor, - dim: usize, -) -> Result, String> { - let topology = input.client.properties().hardware_properties(); - - if !input.client.properties().feature_enabled(Feature::Plane) - || topology.plane_size_min != topology.plane_size_max - { - return reduce_dim_shared::(input, dim); - } - - let subcube_size = topology.plane_size_min; - - let output = init_reduce_output::(&input, dim); - - let num_elems_output = output.shape.num_elements(); - let cube_dim = CubeDim { - x: subcube_size, - y: subcube_size, - z: 1, - }; - let cube_count_x = f32::ceil(f32::sqrt(num_elems_output as f32)); - let cube_count_y = f32::ceil(num_elems_output as f32 / cube_count_x); - let cube_count = CubeCount::Static(cube_count_x as u32, cube_count_y as u32, 1); - - let reduce_group_size = input.shape.dims[dim]; - let n_invocation_per_cube = cube_dim.num_elems(); - let elems_per_thread = - f32::ceil(reduce_group_size as f32 / n_invocation_per_cube as f32) as u32; - - let divisible_shape = n_invocation_per_cube * elems_per_thread == reduce_group_size as u32; - - reduce_dim_subcube_kernel::launch::( - &input.client, - cube_count, - cube_dim, - input.as_tensor_arg::(1), - output.as_tensor_arg::(1), - dim as u32, - subcube_size, - elems_per_thread, - divisible_shape, - ); - - Ok(output) -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/mean_dim.rs b/crates/burn-jit/src/kernel/reduce/subcube/mean_dim.rs deleted file mode 100644 index fb8c0b41d6..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/mean_dim.rs +++ /dev/null @@ -1,45 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::MeanDim; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for MeanDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - EIn::cast_from(0u32) - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current += new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let sum = plane_sum(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = sum; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, dim_length: u32) { - let denom = EIn::cast_from(dim_length); - out[pos] = EOut::cast_from(acc[0] / denom); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/mod.rs b/crates/burn-jit/src/kernel/reduce/subcube/mod.rs deleted file mode 100644 index 183c1e2daf..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod argmax; -pub mod argmin; -pub mod base; -pub mod kernel; -pub mod mean_dim; -pub mod prod_dim; -pub mod sum_dim; diff --git a/crates/burn-jit/src/kernel/reduce/subcube/prod_dim.rs b/crates/burn-jit/src/kernel/reduce/subcube/prod_dim.rs deleted file mode 100644 index cccec95167..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/prod_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::ProdDim; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for ProdDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - EIn::from_int(1) - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current *= new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let prod = plane_prod(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = prod; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - out[pos] = EOut::cast_from(acc[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/subcube/sum_dim.rs b/crates/burn-jit/src/kernel/reduce/subcube/sum_dim.rs deleted file mode 100644 index 1059432eb2..0000000000 --- a/crates/burn-jit/src/kernel/reduce/subcube/sum_dim.rs +++ /dev/null @@ -1,44 +0,0 @@ -use cubecl::{cube, prelude::*}; - -use crate::kernel::reduce::SumDim; - -use super::base::ReduceDimSubcube; - -#[cube] -impl ReduceDimSubcube for SumDim { - /// The reduction accumulator - type Accumulator = SharedMemory; - type Value = EIn; - - fn init_shared(#[comptime] size: u32) -> Self::Accumulator { - SharedMemory::new(size) - } - - fn init_value() -> Self::Value { - EIn::cast_from(0u32) - } - - fn read_value(input: &Tensor, pos: u32, _i: u32) -> Self::Value { - input[pos] - } - - fn read_from_shared(acc: &Self::Accumulator, pos: u32) -> Self::Value { - acc[pos] - } - - fn update_value(current: &mut Self::Value, new: Self::Value) { - *current += new; - } - - fn reduce_subcube(acc: &mut Self::Accumulator, write_position: u32, value: Self::Value) { - let sum = plane_sum(value); - - if UNIT_POS % PLANE_DIM == 0 { - acc[write_position] = sum; - } - } - - fn store(acc: &Self::Accumulator, out: &mut Tensor, pos: u32, _layout: u32) { - out[pos] = EOut::cast_from(acc[0]); - } -} diff --git a/crates/burn-jit/src/kernel/reduce/sum.rs b/crates/burn-jit/src/kernel/reduce/sum.rs deleted file mode 100644 index d3c9416dc1..0000000000 --- a/crates/burn-jit/src/kernel/reduce/sum.rs +++ /dev/null @@ -1,15 +0,0 @@ -use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use burn_tensor::Shape; - -use super::{sum_dim, ReduceStrategy}; - -/// Sum all elements in the input buffer. -pub fn sum( - input: JitTensor, - strategy: ReduceStrategy, -) -> JitTensor { - let shape = Shape::new([input.shape.num_elements()]); - let input: JitTensor = - JitTensor::new_contiguous(input.client, input.device, shape, input.handle, input.dtype); - sum_dim::(input, 0, strategy).unwrap() -} diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs new file mode 100644 index 0000000000..6816196a37 --- /dev/null +++ b/crates/burn-jit/src/kernel/reduce/tune.rs @@ -0,0 +1,222 @@ +#![allow(missing_docs)] + +use burn_tensor::ElementConversion; +use cubecl::{ + client::ComputeClient, + tune, + tune::{local_tuner, tune_with, LocalTuner}, + AutotuneKey, +}; +use serde::{Deserialize, Serialize}; + +use crate::{ + kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::JitTensor, + JitAutotuneKey, JitElement, JitRuntime, JitTuneId, +}; + +/// Executes autotune on reduce operations. +pub fn autotune_reduce< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, +>( + client: &ComputeClient, + input: JitTensor, + output: JitTensor, + dim: usize, +) -> Result<(), cubecl::reduce::ReduceError> { + static TUNER: LocalTuner = local_tuner!(); + + TUNER.execute( + &JitTuneId::new::(&input.device), + client, + Box::new(ReduceOps::::new(input, output, dim)), + ); + + Ok(()) +} + +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +/// Autotune key representative of redue versions +pub struct ReduceAutotuneKey { + dtype: burn_tensor::DType, + #[autotune(anchor)] + reduce_axis_shape: usize, + #[autotune(anchor)] + reduce_axis_stride: usize, + #[autotune(anchor)] + outer_axes_product: usize, // The product of the shapes of all axes with greater strides. +} + +impl ReduceAutotuneKey { + pub(crate) fn generate(input: &JitTensor, axis: usize) -> Self { + let rank = input.shape.num_dims(); + + if axis > rank { + panic!("axis {axis} is out-of-bound for a rank of {rank}"); + } + + let dtype = input.dtype; + let reduce_axis_shape = input.shape.dims[axis]; + let reduce_axis_stride = input.strides[axis]; + + let outer_axes_product = input + .strides + .iter() + .zip(input.shape.dims.iter()) + .filter_map(|(stride, shape)| (*stride > reduce_axis_stride).then_some(shape)) + .product(); + + Self { + dtype, + reduce_axis_shape, + reduce_axis_stride, + outer_axes_product, + } + } +} + +pub(crate) fn create_key( + input: &JitTensor, + _output: &JitTensor, + dim: &usize, +) -> JitAutotuneKey { + JitAutotuneKey::Reduce(ReduceAutotuneKey::generate(input, *dim)) +} + +pub use reduce_ops::*; +mod reduce_ops { + #![allow(missing_docs)] + + use super::*; + + #[tune( + operations(reduce, reduce_shared, reduce_plane, reduce_shared_plane), + create_key = create_key::, + should_run = should_run +)] + fn reduce_ops( + key: JitAutotuneKey, + input: JitTensor, + output: JitTensor, + dim: usize, + ) { + let random_bounds: (In, In) = ((-10.0_f32).elem::(), (10.0_f32).elem::()); + let input = random_like_uniform(input, random_bounds.0, random_bounds.1); + + let output = empty_device::( + output.client.clone(), + output.device.clone(), + output.shape.clone(), + ); + + tune_with!(input, output, dim) + } + + fn should_run( + op: &ReduceOps, + _key: &JitAutotuneKey, + index: usize, + ) -> bool { + match index { + // if strategy uses planes + 2 | 3 => { + let properties = op.input.client.properties(); + properties.feature_enabled(cubecl::Feature::Plane) + && properties + .hardware_properties() + .defined_plane_size() + .is_some() + } + _ => true, + } + } + + fn reduce( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: false, + use_planes: false, + }), + ) + .map_err(|e| format!("{e}")) + } + + fn reduce_shared< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, + >( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: true, + use_planes: false, + }), + ) + .map_err(|e| format!("{e}")) + } + + fn reduce_plane< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, + >( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: false, + use_planes: true, + }), + ) + .map_err(|e| format!("{e}")) + } + + fn reduce_shared_plane< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, + >( + input: JitTensor, + output: JitTensor, + axis: usize, + ) -> Result<(), String> { + cubecl::reduce::reduce::( + &input.client, + input.as_handle_ref(), + output.as_handle_ref(), + axis, + Some(cubecl::reduce::ReduceStrategy { + shared: true, + use_planes: true, + }), + ) + .map_err(|e| format!("{e}")) + } +} diff --git a/crates/burn-jit/src/kernel/reduce/tune/base.rs b/crates/burn-jit/src/kernel/reduce/tune/base.rs deleted file mode 100644 index f52bfd7ca0..0000000000 --- a/crates/burn-jit/src/kernel/reduce/tune/base.rs +++ /dev/null @@ -1,94 +0,0 @@ -use burn_tensor::{Element, ElementConversion}; -use cubecl::tune::{local_tuner, tune_with, LocalTuner}; -use cubecl::{tune, Feature}; - -use crate::{ - element::JitElement, - kernel::{ - prng::random_like_uniform, - reduce::{ - naive::kernel::reduce_dim_naive, shared::kernel::reduce_dim_shared, - subcube::kernel::reduce_dim_subcube, ReduceDimAlgorithm, - }, - }, - tensor::JitTensor, - tune_key::JitAutotuneKey, - JitRuntime, JitTuneId, -}; - -use super::create_key; - -/// Set of reduce_dim implementations available for autotune -/// Autotune key is given by concatenating the closest upper power of 2 of -/// dim to reduce, and product of others -#[tune( - operations(reduce_dim_naive, reduce_dim_shared, reduce_dim_subcube), - create_key = create_key::, - should_run = should_run -)] -pub fn reduce_dim_operations< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - key: JitAutotuneKey, - input: JitTensor, - reduce_dim: usize, -) -> JitTensor { - let random_bounds: (EI, EI) = ((-10.0).elem::(), (10.0).elem::()); - let input = random_like_uniform(input, random_bounds.0, random_bounds.1); - - tune_with!(input, reduce_dim) -} - -/// Executes autotune on reduce_dim operation -pub(crate) fn reduce_dim_autotune< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - input: JitTensor, - reduce_dim: usize, -) -> JitTensor { - let client = input.client.clone(); - - let id = JitTuneId::new::(&input.device); - - let operation_set = Box::new(ReduceDimOperations::::new(input, reduce_dim)); - - static TUNER: LocalTuner = local_tuner!(); - - TUNER.execute(&id, &client, operation_set) -} - -fn should_run< - RD: ReduceDimAlgorithm, - R: JitRuntime, - EI: JitElement + Element, - EO: JitElement + Element, ->( - op: &ReduceDimOperations, - key: &JitAutotuneKey, - index: usize, -) -> bool { - let JitAutotuneKey::ReduceDim(key) = key else { - unreachable!() - }; - - match index { - // Naive - 0 => key.reduce_dim_length <= 8192, - // Shared - 1 => key.reduce_dim_length >= 16, - // Subcube - 2 => { - let props = op.input.client.properties(); - let hardware = props.hardware_properties(); - props.feature_enabled(Feature::Plane) - && hardware.plane_size_min == hardware.plane_size_max - } - _ => true, - } -} diff --git a/crates/burn-jit/src/kernel/reduce/tune/key.rs b/crates/burn-jit/src/kernel/reduce/tune/key.rs deleted file mode 100644 index 3634022bc7..0000000000 --- a/crates/burn-jit/src/kernel/reduce/tune/key.rs +++ /dev/null @@ -1,39 +0,0 @@ -use cubecl::AutotuneKey; -use serde::{Deserialize, Serialize}; - -use burn_tensor::DType; - -use crate::{tensor::JitTensor, JitAutotuneKey, JitElement, JitRuntime}; - -/// Autotune key representative of reduce versions -#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] -pub struct ReduceAutotuneKey { - #[autotune(anchor)] - pub(crate) reduce_dim_length: usize, - #[autotune(anchor)] - pub(crate) reduce_dim_stride: usize, - #[autotune(anchor)] - pub(crate) others_product: usize, - dtype: DType, -} - -pub(crate) fn create_key( - input: &JitTensor, - reduce_dim: &usize, -) -> JitAutotuneKey { - let dims = &input.shape.dims; - let reduce_dim = *reduce_dim; - - let mut others_product = 1; - for (d, len) in dims.iter().enumerate() { - if d != reduce_dim { - others_product *= len - } - } - JitAutotuneKey::ReduceDim(ReduceAutotuneKey::new( - dims[reduce_dim], - input.strides[reduce_dim], - others_product, - EI::dtype(), - )) -} diff --git a/crates/burn-jit/src/kernel/reduce/tune/mod.rs b/crates/burn-jit/src/kernel/reduce/tune/mod.rs deleted file mode 100644 index aee5569b6b..0000000000 --- a/crates/burn-jit/src/kernel/reduce/tune/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -#[cfg(feature = "autotune")] -mod base; -mod key; - -#[cfg(feature = "autotune")] -pub(crate) use base::*; -pub use key::*; diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index c59b9df83c..d32de97436 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -355,7 +355,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::sum::(tensor, Default::default()) + reduce::reduce::(tensor, Default::default()).unwrap() ) } @@ -363,7 +363,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::sum_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -371,7 +371,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::mean_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -379,7 +379,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::prod::(tensor, Default::default()) + reduce::reduce::(tensor, Default::default()).unwrap() ) } @@ -387,7 +387,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::prod_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -467,7 +467,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::argmax::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } @@ -475,7 +475,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::argmin::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() ) } diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index ed99258826..5702a90849 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -1,5 +1,5 @@ use super::{expand, numeric, permute}; -use crate::kernel::{launch_unary_numeric, NumericUnaryOp, NumericUnaryOpFamily}; +use crate::kernel::{launch_unary_numeric, reduce, NumericUnaryOp, NumericUnaryOpFamily}; use crate::{ element::BoolElement, kernel::prng::{random_bernoulli, random_normal, random_uniform}, @@ -193,31 +193,31 @@ where } fn int_sum(tensor: IntTensor) -> IntTensor { - kernel::reduce::sum::(tensor, Default::default()) + reduce::reduce::(tensor, Default::default()).unwrap() } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::sum_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_prod(tensor: IntTensor) -> IntTensor { - kernel::reduce::prod::(tensor, Default::default()) + reduce::reduce::(tensor, Default::default()).unwrap() } fn int_prod_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::prod_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::mean_dim::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmax::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::reduce::argmin::(tensor, dim, Default::default()).unwrap() + reduce::reduce_dim::(tensor, dim, Default::default()).unwrap() } fn int_clamp( diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index 378eb035ed..a79ac3c437 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -17,6 +17,7 @@ mod max_pool2d; mod max_pool2d_backward; mod normal; mod quantization; +mod reduce; mod repeat_dim; mod scatter; mod select; @@ -78,6 +79,8 @@ macro_rules! testgen_all { burn_jit::testgen_clamp!(); burn_jit::testgen_unary!(); + burn_jit::testgen_reduce!(); + burn_jit::testgen_quantization!(); } } diff --git a/crates/burn-jit/src/tests/reduce.rs b/crates/burn-jit/src/tests/reduce.rs new file mode 100644 index 0000000000..8e533361e9 --- /dev/null +++ b/crates/burn-jit/src/tests/reduce.rs @@ -0,0 +1,128 @@ +#[burn_tensor_testgen::testgen(reduce)] +mod reduce { + use super::*; + use burn_jit::kernel::reduce::{ + reduce, reduce_dim, ArgMax, ArgMin, Mean, Prod, ReduceStrategy, Sum, + }; + use burn_tensor::{ + backend::Backend, ops::IntTensorOps, Distribution, Int, Shape, Tensor, TensorData, + TensorPrimitive, + }; + + const RANK: usize = 4; + const SHAPE: [usize; RANK] = [2, 4, 8, 16]; + + #[test] + fn reduction_argmax_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .argmax(dim) + .into_data() + .assert_eq(&tensor_ref.clone().argmax(dim).into_data(), false); + } + } + + #[test] + fn reduction_argmin_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .argmin(dim) + .into_data() + .assert_eq(&tensor_ref.clone().argmin(dim).into_data(), false); + } + } + + #[test] + fn reduction_mean_dim_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .mean_dim(dim) + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().mean_dim(dim).into_data(), 1e-6); + } + } + + #[test] + fn reduction_mean_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + tensor + .clone() + .mean() + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().mean().into_data(), 1e-6); + } + + #[test] + fn reduction_prod_dim_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .prod_dim(dim) + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().prod_dim(dim).into_data(), 1e-6); + } + } + + #[test] + fn reduction_prod_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + tensor + .clone() + .prod() + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().prod().into_data(), 1e-6); + } + + #[test] + fn reduction_sum_dim_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + for dim in 0..RANK { + tensor + .clone() + .sum_dim(dim) + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().sum_dim(dim).into_data(), 1e-6); + } + } + + #[test] + fn reduction_sum_should_match_reference_backend() { + let tensor = + Tensor::::random(SHAPE, Distribution::Default, &Default::default()); + let tensor_ref = + Tensor::::from_data(tensor.to_data(), &Default::default()); + tensor + .clone() + .sum() + .into_data() + .assert_approx_eq_diff(&tensor_ref.clone().sum().into_data(), 1e-6); + } +} diff --git a/crates/burn-jit/src/tune_key.rs b/crates/burn-jit/src/tune_key.rs index 0a7ae855b9..cb29e2fe0c 100644 --- a/crates/burn-jit/src/tune_key.rs +++ b/crates/burn-jit/src/tune_key.rs @@ -13,7 +13,7 @@ pub enum JitAutotuneKey { /// Key for matmul operation Matmul(MatmulAutotuneKey), /// Key for reduce dim operations - ReduceDim(ReduceAutotuneKey), + Reduce(ReduceAutotuneKey), /// Key for convolution operations Conv2d(Conv2dAutotuneKey), /// Key for transpose convolution operations @@ -24,7 +24,7 @@ impl Display for JitAutotuneKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { JitAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), - JitAutotuneKey::ReduceDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + JitAutotuneKey::Reduce(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), JitAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), JitAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), } diff --git a/crates/burn-tensor/src/tensor/shape.rs b/crates/burn-tensor/src/tensor/shape.rs index 8ad54ba4d9..29eebd549e 100644 --- a/crates/burn-tensor/src/tensor/shape.rs +++ b/crates/burn-tensor/src/tensor/shape.rs @@ -33,6 +33,13 @@ impl Shape { dims[..D].copy_from_slice(&self.dims[..D]); dims } + + /// Change the shape to one dimensional with the same number of elements. + pub fn flatten(&self) -> Self { + Self { + dims: [self.dims.iter().product()].into(), + } + } } impl From<[usize; D]> for Shape { From 1902c90d067faaa29b20b97917d05b5f912d4748 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Mon, 13 Jan 2025 18:34:47 -0500 Subject: [PATCH 20/61] Update cubecl (#2693) --- Cargo.lock | 26 +++++++++++++------------- Cargo.toml | 4 ++-- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index feff4ed96a..02c5f0fda5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1582,7 +1582,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1597,7 +1597,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1614,7 +1614,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "bytemuck", "cubecl-common", @@ -1633,7 +1633,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "bytemuck", "cubecl-common", @@ -1647,7 +1647,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "bytemuck", "cubecl-common", @@ -1663,7 +1663,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "bytemuck", "cubecl-common", @@ -1689,7 +1689,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "bytemuck", "cubecl-core", @@ -1701,7 +1701,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "cubecl-common", "darling", @@ -1716,7 +1716,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "cubecl-common", "cubecl-core", @@ -1732,7 +1732,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1742,7 +1742,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "async-channel", "async-lock", @@ -1763,7 +1763,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "cubecl-common", "cubecl-core", @@ -1777,7 +1777,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=707093234f11b78fb6630b98fea5d13870f94282#707093234f11b78fb6630b98fea5d13870f94282" +source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 5dfebaf2b0..508868e381 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "707093234f11b78fb6630b98fea5d13870f94282" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "707093234f11b78fb6630b98fea5d13870f94282" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c63a62be7238bd28b999160aba6a6bbdabdfb7d3" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c63a62be7238bd28b999160aba6a6bbdabdfb7d3" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } From ddc43398037b50c94fc2665c54ccf447c3eb179c Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 09:05:56 -0500 Subject: [PATCH 21/61] Add dropout prob check (#2695) * Add dropout prob check * Add test --- crates/burn-core/src/nn/dropout.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/crates/burn-core/src/nn/dropout.rs b/crates/burn-core/src/nn/dropout.rs index d03e95c1f3..79fc12ecbf 100644 --- a/crates/burn-core/src/nn/dropout.rs +++ b/crates/burn-core/src/nn/dropout.rs @@ -30,6 +30,12 @@ pub struct Dropout { impl DropoutConfig { /// Initialize a new [dropout](Dropout) module. pub fn init(&self) -> Dropout { + if self.prob < 0.0 || self.prob > 1.0 { + panic!( + "Dropout probability should be between 0 and 1, but got {}", + self.prob + ); + } Dropout { prob: self.prob } } } @@ -108,4 +114,11 @@ mod tests { assert_eq!(alloc::format!("{}", layer), "Dropout {prob: 0.5}"); } + + #[test] + #[should_panic = "Dropout probability should be between 0 and 1,"] + fn dropout_prob_invalid() { + let config = DropoutConfig::new(-10.); + let _layer = config.init(); + } } From d30f71c53343879d5d429fccc2fec74c026d298d Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Tue, 14 Jan 2025 13:36:13 -0500 Subject: [PATCH 22/61] Fix reduce autotune key no anchor (#2696) --- Cargo.lock | 26 +++++++++++------------ Cargo.toml | 4 ++-- crates/burn-jit/src/kernel/reduce/tune.rs | 4 ++-- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 02c5f0fda5..3c8a8b414f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1582,7 +1582,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1597,7 +1597,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1614,7 +1614,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "bytemuck", "cubecl-common", @@ -1633,7 +1633,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "bytemuck", "cubecl-common", @@ -1647,7 +1647,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "bytemuck", "cubecl-common", @@ -1663,7 +1663,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "bytemuck", "cubecl-common", @@ -1689,7 +1689,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "bytemuck", "cubecl-core", @@ -1701,7 +1701,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "cubecl-common", "darling", @@ -1716,7 +1716,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "cubecl-common", "cubecl-core", @@ -1732,7 +1732,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1742,7 +1742,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "async-channel", "async-lock", @@ -1763,7 +1763,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "cubecl-common", "cubecl-core", @@ -1777,7 +1777,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=c63a62be7238bd28b999160aba6a6bbdabdfb7d3#c63a62be7238bd28b999160aba6a6bbdabdfb7d3" +source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index 508868e381..bf14a247d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c63a62be7238bd28b999160aba6a6bbdabdfb7d3" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "c63a62be7238bd28b999160aba6a6bbdabdfb7d3" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs index 6816196a37..c5659cc1cc 100644 --- a/crates/burn-jit/src/kernel/reduce/tune.rs +++ b/crates/burn-jit/src/kernel/reduce/tune.rs @@ -68,12 +68,12 @@ impl ReduceAutotuneKey { .filter_map(|(stride, shape)| (*stride > reduce_axis_stride).then_some(shape)) .product(); - Self { + Self::new( dtype, reduce_axis_shape, reduce_axis_stride, outer_axes_product, - } + ) } } From cdcff034f57ce8d17108ffb8bbaa6b42ed840f88 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 14:52:05 -0500 Subject: [PATCH 23/61] Set cubecl version (#2697) --- Cargo.lock | 39 ++++++++++++++++++++++++++------------- Cargo.toml | 8 ++++---- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3c8a8b414f..09f9e5dc14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1582,7 +1582,8 @@ dependencies = [ [[package]] name = "cubecl" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aecf090429a4172d94c819e2977f440d7f5846c09f31d36937de309f986c878e" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1597,7 +1598,8 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10239ee4800968f367fbc4828250d38acf5d14fa53e8d0370d5f474387591322" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1614,7 +1616,8 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d249976814abe45ee5d04bdfd5e2359558b409affdc03914625bea778dab5ade" dependencies = [ "bytemuck", "cubecl-common", @@ -1633,7 +1636,8 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8463629d0bdf4d09d47150bce35132236c1a597f65eba213b45073406048a596" dependencies = [ "bytemuck", "cubecl-common", @@ -1647,7 +1651,8 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12c0b49113ba986e984538cf54c3d7390c0af934a80f083b6c99cad737d22c59" dependencies = [ "bytemuck", "cubecl-common", @@ -1663,7 +1668,8 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "976e150315f9d7d6bb84c51cb13c19221ea5d185bb6d61347a3c392dd29720de" dependencies = [ "bytemuck", "cubecl-common", @@ -1689,7 +1695,8 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "640c379e225fecb1336f963affd3b8f1ff66b9320a972dfe92d8158dca8b6382" dependencies = [ "bytemuck", "cubecl-core", @@ -1701,7 +1708,8 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f05d95f3be436814f909a3ac97209159f63076d3d2b254914bc02db2ac7faefb" dependencies = [ "cubecl-common", "darling", @@ -1716,7 +1724,8 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42c0593efee028e010a1a7e8646a8a405f6a653fe194bc8c5b46189245ecaeec" dependencies = [ "cubecl-common", "cubecl-core", @@ -1732,7 +1741,8 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0912890b52cc6f9636e0070320ff93dec27af15d57453789081b9a8bdb49786d" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1742,7 +1752,8 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e84f4ae5a096e4d0c410db01d18b673d6efcd6eea1724d1a001ab60484df87" dependencies = [ "async-channel", "async-lock", @@ -1763,7 +1774,8 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5d88e7d35a58a40991e42c4492739d4b89b6046ac75126cb5f10b190032012c" dependencies = [ "cubecl-common", "cubecl-core", @@ -1777,7 +1789,8 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.4.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=3c083cb136214404d8eb594258534d10a118a077#3c083cb136214404d8eb594258534d10a118a077" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cf8105d01ef4cd103d4e31bee9ae583fabc807253234923fb08218b28db7d15" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index bf14a247d7..e79690fd28 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,14 +153,14 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } +# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } +# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } ### For the release. ### -# cubecl = { version = "0.3.0", default-features = false } -# cubecl-common = { version = "0.3.0", default-features = false } +cubecl = { version = "0.4.0", default-features = false } +cubecl-common = { version = "0.4.0", default-features = false } ### For xtask crate ### tracel-xtask = { version = "=1.1.8" } From 59a2e3bc395cfaca694dff170727dc1632e6c078 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 15:14:22 -0500 Subject: [PATCH 24/61] Add burn-remote to publish workflow --- .github/workflows/publish.yml | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 8b8916c631..46446ced01 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -6,6 +6,31 @@ on: - "v*" jobs: + publish-burn-router: + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 + with: + crate: burn-router + needs: + - publish-burn-common + - publish-burn-tensor + # dev dependencies + - publish-burn-autodiff + - publish-burn-ndarray + - publish-burn-wgpu + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} + + publish-burn-remote: + uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 + with: + crate: burn-derive + needs: + - publish-burn-common + - publish-burn-tensor + - publish-burn-router + secrets: + CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }} + publish-burn-derive: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 with: @@ -162,6 +187,7 @@ jobs: - publish-burn-tch - publish-burn-ndarray - publish-burn-candle + - publish-burn-remote with: crate: burn-core secrets: From 1b91e32e86d44c33ec4a87d468954a45dcb3e554 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 15:25:24 -0500 Subject: [PATCH 25/61] Fix licenses --- NOTICES.md | 43 +++++++++++++++++++++++++++++++++++++++++++ deny.toml | 3 +++ 2 files changed, 46 insertions(+) diff --git a/NOTICES.md b/NOTICES.md index c41a90d952..0f559e27a4 100644 --- a/NOTICES.md +++ b/NOTICES.md @@ -601,3 +601,46 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +## ICU + +UNICODE LICENSE V3 + +COPYRIGHT AND PERMISSION NOTICE + +Copyright © 2016-2024 Unicode, Inc. + +NOTICE TO USER: Carefully read the following legal agreement. BY +DOWNLOADING, INSTALLING, COPYING OR OTHERWISE USING DATA FILES, AND/OR +SOFTWARE, YOU UNEQUIVOCALLY ACCEPT, AND AGREE TO BE BOUND BY, ALL OF THE +TERMS AND CONDITIONS OF THIS AGREEMENT. IF YOU DO NOT AGREE, DO NOT +DOWNLOAD, INSTALL, COPY, DISTRIBUTE OR USE THE DATA FILES OR SOFTWARE. + +Permission is hereby granted, free of charge, to any person obtaining a +copy of data files and any associated documentation (the "Data Files") or +software and any associated documentation (the "Software") to deal in the +Data Files or Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, and/or sell +copies of the Data Files or Software, and to permit persons to whom the +Data Files or Software are furnished to do so, provided that either (a) +this copyright and permission notice appear with all copies of the Data +Files or Software, or (b) this copyright and permission notice appear in +associated Documentation. + +THE DATA FILES AND SOFTWARE ARE PROVIDED "AS IS", WITHOUT WARRANTY OF ANY +KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF +THIRD PARTY RIGHTS. + +IN NO EVENT SHALL THE COPYRIGHT HOLDER OR HOLDERS INCLUDED IN THIS NOTICE +BE LIABLE FOR ANY CLAIM, OR ANY SPECIAL INDIRECT OR CONSEQUENTIAL DAMAGES, +OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, +ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE DATA +FILES OR SOFTWARE. + +Except as contained in this notice, the name of a copyright holder shall +not be used in advertising or otherwise to promote the sale, use or other +dealings in these Data Files or Software without prior written +authorization of the copyright holder. diff --git a/deny.toml b/deny.toml index a9a4506064..e8a251eb1c 100644 --- a/deny.toml +++ b/deny.toml @@ -75,12 +75,14 @@ allow = [ "Apache-2.0 WITH LLVM-exception", "Apache-2.0", "BSD-3-Clause", + "BSD-2-Clause", "CC0-1.0", "ISC", "MIT", "MPL-2.0", "OpenSSL", "Unicode-DFS-2016", + "Unicode-3.0", "Unlicense", "Zlib", ] @@ -90,4 +92,5 @@ exceptions = [ # Each entry is the crate and version constraint, and its specific allow # list #{ allow = ["license_name"], name = "crate", version = "*" }, + { allow = ["BSL-1.0"], name = "clipboard-win", version = "*" }, # in NOTICES.md ] From 3a6a456d2bad0501cd4de213208d82f746e57745 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 15:35:53 -0500 Subject: [PATCH 26/61] Fix typo --- .github/workflows/publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 46446ced01..de956c243e 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -23,7 +23,7 @@ jobs: publish-burn-remote: uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1 with: - crate: burn-derive + crate: burn-remote needs: - publish-burn-common - publish-burn-tensor From dd628ec91c8dafa7f5767d85f822d46dec8f4707 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 15:49:32 -0500 Subject: [PATCH 27/61] Remove self dep on burn-remote --- Cargo.lock | 1 - crates/burn-remote/Cargo.toml | 4 ---- deny.toml | 2 +- 3 files changed, 1 insertion(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 09f9e5dc14..92be93458c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -874,7 +874,6 @@ dependencies = [ "async-channel", "axum", "burn-common", - "burn-remote", "burn-router", "burn-tensor", "derive-new 0.7.0", diff --git a/crates/burn-remote/Cargo.toml b/crates/burn-remote/Cargo.toml index fa6034c681..de772861c9 100644 --- a/crates/burn-remote/Cargo.toml +++ b/crates/burn-remote/Cargo.toml @@ -43,10 +43,6 @@ axum = { version = "0.8.1", features = ["ws"], optional = true } tracing-core = { workspace = true, optional = true } tracing-subscriber = { workspace = true, optional = true } -[dev-dependencies] -# We activate the features client and server during dev. -burn-remote = { path = ".", version = "0.16.0", features=["client", "server"] } - [package.metadata.docs.rs] features = ["doc"] rustdoc-args = ["--cfg", "docsrs"] diff --git a/deny.toml b/deny.toml index e8a251eb1c..ac64c923fe 100644 --- a/deny.toml +++ b/deny.toml @@ -76,6 +76,7 @@ allow = [ "Apache-2.0", "BSD-3-Clause", "BSD-2-Clause", + "BSL-1.0", # in NOTICES.md "CC0-1.0", "ISC", "MIT", @@ -92,5 +93,4 @@ exceptions = [ # Each entry is the crate and version constraint, and its specific allow # list #{ allow = ["license_name"], name = "crate", version = "*" }, - { allow = ["BSL-1.0"], name = "clipboard-win", version = "*" }, # in NOTICES.md ] From 93cafc41b509e618ca46b680b3015659b97a0de3 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 14 Jan 2025 18:43:58 -0500 Subject: [PATCH 28/61] Bump next version of Burn to 0.17.0 (#2698) --- Cargo.lock | 88 ++++++++++---------- Cargo.toml | 2 +- backend-comparison/Cargo.toml | 2 +- burn-book/src/advanced/no-std.md | 2 +- burn-book/src/basic-workflow/README.md | 2 +- burn-book/src/basic-workflow/model.md | 2 +- burn-book/src/import/onnx-model.md | 2 +- crates/burn-autodiff/Cargo.toml | 8 +- crates/burn-candle/Cargo.toml | 8 +- crates/burn-core/Cargo.toml | 32 +++---- crates/burn-cuda/Cargo.toml | 8 +- crates/burn-dataset/Cargo.toml | 2 +- crates/burn-fusion/Cargo.toml | 4 +- crates/burn-hip/Cargo.toml | 8 +- crates/burn-import/Cargo.toml | 6 +- crates/burn-jit/Cargo.toml | 12 +-- crates/burn-ndarray/Cargo.toml | 10 +-- crates/burn-no-std-tests/Cargo.toml | 4 +- crates/burn-remote/Cargo.toml | 6 +- crates/burn-router/Cargo.toml | 12 +-- crates/burn-tch/Cargo.toml | 6 +- crates/burn-tensor/Cargo.toml | 4 +- crates/burn-train/Cargo.toml | 6 +- crates/burn-wgpu/Cargo.toml | 8 +- crates/burn/Cargo.toml | 4 +- examples/image-classification-web/Cargo.toml | 4 +- examples/pytorch-import/Cargo.toml | 2 +- examples/pytorch-import/model/Cargo.toml | 2 +- examples/raspberry-pi-pico/Cargo.lock | 34 ++++---- examples/server/Cargo.toml | 2 +- xtask/Cargo.toml | 2 +- 31 files changed, 147 insertions(+), 147 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 92be93458c..c34fb9cd03 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -399,7 +399,7 @@ dependencies = [ [[package]] name = "backend-comparison" -version = "0.16.0" +version = "0.17.0" dependencies = [ "arboard", "burn", @@ -617,7 +617,7 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "burn" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-core", "burn-train", @@ -625,7 +625,7 @@ dependencies = [ [[package]] name = "burn-autodiff" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -637,7 +637,7 @@ dependencies = [ [[package]] name = "burn-candle" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-tch", @@ -649,7 +649,7 @@ dependencies = [ [[package]] name = "burn-common" -version = "0.16.0" +version = "0.17.0" dependencies = [ "cubecl-common", "dashmap", @@ -664,7 +664,7 @@ dependencies = [ [[package]] name = "burn-core" -version = "0.16.0" +version = "0.17.0" dependencies = [ "ahash", "bincode", @@ -702,7 +702,7 @@ dependencies = [ [[package]] name = "burn-cuda" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -717,7 +717,7 @@ dependencies = [ [[package]] name = "burn-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "csv", @@ -749,7 +749,7 @@ dependencies = [ [[package]] name = "burn-derive" -version = "0.16.0" +version = "0.17.0" dependencies = [ "derive-new 0.7.0", "proc-macro2", @@ -759,7 +759,7 @@ dependencies = [ [[package]] name = "burn-fusion" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -773,7 +773,7 @@ dependencies = [ [[package]] name = "burn-hip" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -788,7 +788,7 @@ dependencies = [ [[package]] name = "burn-import" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-ndarray", @@ -814,7 +814,7 @@ dependencies = [ [[package]] name = "burn-jit" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-common", @@ -840,7 +840,7 @@ dependencies = [ [[package]] name = "burn-ndarray" -version = "0.16.0" +version = "0.17.0" dependencies = [ "atomic_float", "blas-src", @@ -860,7 +860,7 @@ dependencies = [ [[package]] name = "burn-no-std-tests" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-ndarray", @@ -869,7 +869,7 @@ dependencies = [ [[package]] name = "burn-remote" -version = "0.16.0" +version = "0.17.0" dependencies = [ "async-channel", "axum", @@ -890,7 +890,7 @@ dependencies = [ [[package]] name = "burn-router" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-common", @@ -904,7 +904,7 @@ dependencies = [ [[package]] name = "burn-tch" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-tensor", @@ -917,7 +917,7 @@ dependencies = [ [[package]] name = "burn-tensor" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bincode", "burn-common", @@ -938,7 +938,7 @@ dependencies = [ [[package]] name = "burn-tensor-testgen" -version = "0.16.0" +version = "0.17.0" dependencies = [ "proc-macro2", "quote", @@ -946,7 +946,7 @@ dependencies = [ [[package]] name = "burn-train" -version = "0.16.0" +version = "0.17.0" dependencies = [ "async-channel", "burn-core", @@ -966,7 +966,7 @@ dependencies = [ [[package]] name = "burn-wgpu" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -1819,7 +1819,7 @@ dependencies = [ [[package]] name = "custom-csv-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "csv", @@ -1829,7 +1829,7 @@ dependencies = [ [[package]] name = "custom-cubecl-kernel" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-jit", @@ -1842,7 +1842,7 @@ dependencies = [ [[package]] name = "custom-image-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "flate2", @@ -1851,7 +1851,7 @@ dependencies = [ [[package]] name = "custom-renderer" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "bytemuck", @@ -1863,7 +1863,7 @@ dependencies = [ [[package]] name = "custom-training-loop" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "bytemuck", @@ -1875,7 +1875,7 @@ dependencies = [ [[package]] name = "custom-wgpu-kernel" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "bytemuck", @@ -2917,7 +2917,7 @@ dependencies = [ [[package]] name = "guide" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "log", @@ -3417,7 +3417,7 @@ dependencies = [ [[package]] name = "image-classification-web" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-candle", @@ -3953,7 +3953,7 @@ dependencies = [ [[package]] name = "mnist" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "log", @@ -3962,7 +3962,7 @@ dependencies = [ [[package]] name = "mnist-inference-web" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "console_error_panic_hook", @@ -3974,7 +3974,7 @@ dependencies = [ [[package]] name = "model" -version = "0.5.0" +version = "0.6.0" dependencies = [ "burn", "burn-import", @@ -4046,7 +4046,7 @@ dependencies = [ [[package]] name = "named-tensor" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "serde", @@ -4522,7 +4522,7 @@ dependencies = [ [[package]] name = "onnx-inference" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-import", @@ -4531,7 +4531,7 @@ dependencies = [ [[package]] name = "onnx-ir" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bytemuck", "half", @@ -4548,7 +4548,7 @@ dependencies = [ [[package]] name = "onnx-tests" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-import", @@ -5609,7 +5609,7 @@ dependencies = [ [[package]] name = "pytorch-import" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-import", @@ -5618,7 +5618,7 @@ dependencies = [ [[package]] name = "pytorch-tests" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "burn-autodiff", @@ -6584,7 +6584,7 @@ dependencies = [ [[package]] name = "server" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "cfg-if", @@ -6697,7 +6697,7 @@ checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" [[package]] name = "simple-regression" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "log", @@ -7100,7 +7100,7 @@ dependencies = [ [[package]] name = "text-classification" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "derive-new 0.7.0", @@ -7110,7 +7110,7 @@ dependencies = [ [[package]] name = "text-generation" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "derive-new 0.7.0", @@ -8511,7 +8511,7 @@ checksum = "c5b940ebc25896e71dd073bad2dbaa2abfe97b0a391415e22ad1326d9c54e3c4" [[package]] name = "xtask" -version = "1.1.0" +version = "1.2.0" dependencies = [ "log", "rstest", diff --git a/Cargo.toml b/Cargo.toml index e79690fd28..870902a9ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ exclude = [ edition = "2021" license = "MIT OR Apache-2.0" readme = "README.md" -version = "0.16.0" +version = "0.17.0" [workspace.dependencies] atomic_float = "1" diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 9e5054c6af..265dbeaaf0 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -33,7 +33,7 @@ wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"] [dependencies] arboard = { workspace = true } burn = { path = "../crates/burn", default-features = false } -burn-common = { path = "../crates/burn-common", version = "0.16.0" } +burn-common = { path = "../crates/burn-common", version = "0.17.0" } clap = { workspace = true } colored = { workspace = true } diff --git a/burn-book/src/advanced/no-std.md b/burn-book/src/advanced/no-std.md index 7689d25354..5f5621cc51 100644 --- a/burn-book/src/advanced/no-std.md +++ b/burn-book/src/advanced/no-std.md @@ -23,7 +23,7 @@ Some other dependencies have to be added ```toml [dependencies] embedded-alloc = "0.5.1" # Only if there is no default allocator for your chip -burn = { version = "0.16", default-features = false, features = ["ndarray"] } # Backend must be ndarray +burn = { version = "0.17", default-features = false, features = ["ndarray"] } # Backend must be ndarray [build-dependencies] burn-import = { version = "0.14" } # Used to auto generate the rust code to import the model diff --git a/burn-book/src/basic-workflow/README.md b/burn-book/src/basic-workflow/README.md index 5b32591a58..8515d73d2c 100644 --- a/burn-book/src/basic-workflow/README.md +++ b/burn-book/src/basic-workflow/README.md @@ -14,7 +14,7 @@ automatically add the missing imports as you add the code snippets to your code. Be sure to checkout the git branch corresponding to the version of Burn you are using to follow this guide. -The current version of Burn is `0.16` and the corresponding branch to checkout is `main`. +The current version of Burn is `0.17` and the corresponding branch to checkout is `main`. The code for this demo can be executed from Burn's base directory using the command: diff --git a/burn-book/src/basic-workflow/model.md b/burn-book/src/basic-workflow/model.md index adce46b297..ac4b16dbce 100644 --- a/burn-book/src/basic-workflow/model.md +++ b/burn-book/src/basic-workflow/model.md @@ -20,7 +20,7 @@ version = "0.1.0" edition = "2021" [dependencies] -burn = { version = "~0.16", features = ["train", "wgpu", "vision"] } +burn = { version = "~0.17", features = ["train", "wgpu", "vision"] } ``` Our goal will be to create a basic convolutional neural network used for image classification. We diff --git a/burn-book/src/import/onnx-model.md b/burn-book/src/import/onnx-model.md index 05b9d5de81..9b3b7917fd 100644 --- a/burn-book/src/import/onnx-model.md +++ b/burn-book/src/import/onnx-model.md @@ -74,7 +74,7 @@ First, add the `burn-import` crate to your `Cargo.toml`: ```toml [build-dependencies] -burn-import = "~0.16" +burn-import = "~0.17" ``` Then, in your `build.rs` file: diff --git a/crates/burn-autodiff/Cargo.toml b/crates/burn-autodiff/Cargo.toml index 2144d46885..5e221f887f 100644 --- a/crates/burn-autodiff/Cargo.toml +++ b/crates/burn-autodiff/Cargo.toml @@ -18,16 +18,16 @@ std = [] async = [] # Require std [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0" } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false } -burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true } +burn-common = { path = "../burn-common", version = "0.17.0" } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } derive-new = { workspace = true } spin = { workspace = true } log = { workspace = true } [dev-dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-candle/Cargo.toml b/crates/burn-candle/Cargo.toml index 62af31d5fb..65fbf416ca 100644 --- a/crates/burn-candle/Cargo.toml +++ b/crates/burn-candle/Cargo.toml @@ -21,17 +21,17 @@ accelerate = ["candle-core/accelerate"] [dependencies] derive-new = { workspace = true } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } half = { workspace = true } candle-core = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tch = { path = "../burn-tch", version = "0.16.0", default-features = false, features = [ +burn-tch = { path = "../burn-tch", version = "0.17.0", default-features = false, features = [ ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index e63af0fba5..b968e28a68 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -129,21 +129,21 @@ test-wgpu-spirv = [ # ** Please make sure all dependencies support no_std when std is disabled ** -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false } -burn-dataset = { path = "../burn-dataset", version = "0.16.0", optional = true, default-features = false } -burn-derive = { path = "../burn-derive", version = "0.16.0" } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } +burn-dataset = { path = "../burn-dataset", version = "0.17.0", optional = true, default-features = false } +burn-derive = { path = "../burn-derive", version = "0.17.0" } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } # Backends -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", optional = true } -burn-candle = { path = "../burn-candle", version = "0.16.0", optional = true } -burn-cuda = { path = "../burn-cuda", version = "0.16.0", optional = true, default-features = false } -burn-hip = { path = "../burn-hip", version = "0.16.0", optional = true, default-features = false } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true, default-features = false } -burn-remote = { path = "../burn-remote", version = "0.16.0", default-features = false, optional = true } -burn-router = { path = "../burn-router", version = "0.16.0", default-features = false, optional = true } -burn-tch = { path = "../burn-tch", version = "0.16.0", optional = true } -burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", optional = true, default-features = false } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } +burn-candle = { path = "../burn-candle", version = "0.17.0", optional = true } +burn-cuda = { path = "../burn-cuda", version = "0.17.0", optional = true, default-features = false } +burn-hip = { path = "../burn-hip", version = "0.17.0", optional = true, default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true, default-features = false } +burn-remote = { path = "../burn-remote", version = "0.17.0", default-features = false, optional = true } +burn-router = { path = "../burn-router", version = "0.17.0", default-features = false, optional = true } +burn-tch = { path = "../burn-tch", version = "0.17.0", optional = true } +burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", optional = true, default-features = false } data-encoding = { workspace = true } uuid = { workspace = true } @@ -173,13 +173,13 @@ thiserror = { workspace = true, optional = true } portable-atomic-util = { workspace = true } [dev-dependencies] -burn-dataset = { path = "../burn-dataset", version = "0.16.0", features = [ +burn-dataset = { path = "../burn-dataset", version = "0.17.0", features = [ "fake", ] } tempfile = { workspace = true } -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0" } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", default-features = false } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0" } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-cuda/Cargo.toml b/crates/burn-cuda/Cargo.toml index c366386b0e..1a92e695b2 100644 --- a/crates/burn-cuda/Cargo.toml +++ b/crates/burn-cuda/Cargo.toml @@ -19,9 +19,9 @@ fusion = ["burn-fusion", "burn-jit/fusion"] std = ["burn-jit/std", "cubecl/std"] [dependencies] -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = [ +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", features = [ "cubecl-cuda", ] } cubecl = { workspace = true, features = ["cuda"] } @@ -34,7 +34,7 @@ log = { workspace = true } [dev-dependencies] -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } paste = { workspace = true } diff --git a/crates/burn-dataset/Cargo.toml b/crates/burn-dataset/Cargo.toml index 0237765973..c7ddbebc41 100644 --- a/crates/burn-dataset/Cargo.toml +++ b/crates/burn-dataset/Cargo.toml @@ -30,7 +30,7 @@ __sqlite-shared = [ dataframe = ["dep:polars"] [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0", optional = true, features = [ +burn-common = { path = "../burn-common", version = "0.17.0", optional = true, features = [ "network", ] } csv = { workspace = true } diff --git a/crates/burn-fusion/Cargo.toml b/crates/burn-fusion/Cargo.toml index eb4296097b..1f2f785940 100644 --- a/crates/burn-fusion/Cargo.toml +++ b/crates/burn-fusion/Cargo.toml @@ -17,8 +17,8 @@ std = ["serde/std"] doc = ["default"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0" } -burn-common = { path = "../burn-common", version = "0.16.0" } +burn-tensor = { path = "../burn-tensor", version = "0.17.0" } +burn-common = { path = "../burn-common", version = "0.17.0" } hashbrown = { workspace = true } derive-new = {workspace = true } spin = { workspace = true } diff --git a/crates/burn-hip/Cargo.toml b/crates/burn-hip/Cargo.toml index d5f0bb70f5..206f56e8fe 100644 --- a/crates/burn-hip/Cargo.toml +++ b/crates/burn-hip/Cargo.toml @@ -20,9 +20,9 @@ std = ["burn-jit/std", "cubecl/std"] [dependencies] cubecl = { workspace = true, features = ["hip"] } -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", features = ["cubecl-hip"] } -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", features = ["cubecl-hip"] } +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } half = { workspace = true } bytemuck = { workspace = true } @@ -31,7 +31,7 @@ log = { workspace = true } derive-new = { workspace = true } [dev-dependencies] -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } paste = { workspace = true } diff --git a/crates/burn-import/Cargo.toml b/crates/burn-import/Cargo.toml index 14a8bffa32..ee7c1c559e 100644 --- a/crates/burn-import/Cargo.toml +++ b/crates/burn-import/Cargo.toml @@ -20,9 +20,9 @@ onnx = [] pytorch = ["burn/record-item-custom-serde", "thiserror", "zip"] [dependencies] -burn = { path = "../burn", version = "0.16.0", default-features = false, features = ["std"]} -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", default-features = false } -onnx-ir = { path = "../onnx-ir", version = "0.16.0" } +burn = { path = "../burn", version = "0.17.0", default-features = false, features = ["std"]} +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false } +onnx-ir = { path = "../onnx-ir", version = "0.17.0" } candle-core = { workspace = true } derive-new = { workspace = true } half = { workspace = true } diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index 2bd0ba6f7e..214b21eef3 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -31,9 +31,9 @@ std = ["cubecl/std", "burn-tensor/std"] template = [] [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0" } -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-common = { path = "../burn-common", version = "0.17.0" } +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "cubecl", "repr", ] } @@ -54,12 +54,12 @@ futures-lite = { workspace = true, features = ["std"] } serde = { workspace = true } text_placeholder = { workspace = true, features = ["struct_context"] } -burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true } +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } hashbrown = { workspace = true } # When exporting tests -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, optional = true } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", optional = true } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, optional = true } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", optional = true } paste = { workspace = true, optional = true } serial_test = { workspace = true, optional = true } diff --git a/crates/burn-ndarray/Cargo.toml b/crates/burn-ndarray/Cargo.toml index 89253cd7e8..167cf88c1a 100644 --- a/crates/burn-ndarray/Cargo.toml +++ b/crates/burn-ndarray/Cargo.toml @@ -43,9 +43,9 @@ blas-openblas-system = [ # ** Please make sure all dependencies support no_std when std is disabled ** -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", optional = true } -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = ["repr"] } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = ["repr"] } atomic_float = { workspace = true } blas-src = { workspace = true, default-features = false, optional = true } # no-std compatible @@ -62,10 +62,10 @@ spin = { workspace = true } # usi portable-atomic-util = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-no-std-tests/Cargo.toml b/crates/burn-no-std-tests/Cargo.toml index e15ce56d15..77c7524f6f 100644 --- a/crates/burn-no-std-tests/Cargo.toml +++ b/crates/burn-no-std-tests/Cargo.toml @@ -14,7 +14,7 @@ version.workspace = true # ** Please make sure all dependencies support no_std ** -burn = { path = "../burn", version = "0.16.0", default-features = false } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0", default-features = false } +burn = { path = "../burn", version = "0.17.0", default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0", default-features = false } serde = { workspace = true } diff --git a/crates/burn-remote/Cargo.toml b/crates/burn-remote/Cargo.toml index de772861c9..9ebd0c8568 100644 --- a/crates/burn-remote/Cargo.toml +++ b/crates/burn-remote/Cargo.toml @@ -19,9 +19,9 @@ server = ["axum", "tracing-core", "tracing-subscriber"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = true, features = ["repr"]} -burn-common = { path = "../burn-common", version = "0.16.0", default-features = true} -burn-router = { path = "../burn-router", version = "0.16.0", default-features = true} +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = true, features = ["repr"]} +burn-common = { path = "../burn-common", version = "0.17.0", default-features = true} +burn-router = { path = "../burn-router", version = "0.17.0", default-features = true} # Basic dependencies derive-new = {workspace = true } diff --git a/crates/burn-router/Cargo.toml b/crates/burn-router/Cargo.toml index f6df54e59f..6f21d63640 100644 --- a/crates/burn-router/Cargo.toml +++ b/crates/burn-router/Cargo.toml @@ -17,22 +17,22 @@ std = ["burn-tensor/std", "burn-common/std"] doc = ["default"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = ["repr"]} -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false} +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = ["repr"]} +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false} hashbrown = { workspace = true } spin = { workspace = true } log = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" } -burn-wgpu = { path = "../burn-wgpu", version = "0.16.0", default-features = false } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } +burn-wgpu = { path = "../burn-wgpu", version = "0.17.0", default-features = false } [package.metadata.docs.rs] diff --git a/crates/burn-tch/Cargo.toml b/crates/burn-tch/Cargo.toml index 44702c21ba..69b0240c34 100644 --- a/crates/burn-tch/Cargo.toml +++ b/crates/burn-tch/Cargo.toml @@ -16,7 +16,7 @@ default = [] doc = ["tch/doc-only"] [dependencies] -burn-tensor = { path = "../burn-tensor", version = "0.16.0" } +burn-tensor = { path = "../burn-tensor", version = "0.17.0" } half = { workspace = true, features = ["std"] } libc = { workspace = true } @@ -25,10 +25,10 @@ tch = { workspace = true, features = ["download-libtorch"] } log = { workspace = true } [dev-dependencies] -burn-autodiff = { path = "../burn-autodiff", version = "0.16.0", default-features = false, features = [ +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, features = [ "export_tests", ] } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "export_tests", ] } diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index 55e14e174c..318912b2f7 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -30,8 +30,8 @@ std = [ ] [dependencies] -burn-common = { path = "../burn-common", version = "0.16.0", default-features = false } -burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.16.0", optional = true } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } +burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } cubecl = { workspace = true, optional = true, default-features = true } bytemuck = { workspace = true, features = ["extern_crate_alloc"] } diff --git a/crates/burn-train/Cargo.toml b/crates/burn-train/Cargo.toml index 35707f5052..b922a1a59e 100644 --- a/crates/burn-train/Cargo.toml +++ b/crates/burn-train/Cargo.toml @@ -18,7 +18,7 @@ metrics = ["nvml-wrapper", "sysinfo", "systemstat"] tui = ["ratatui"] [dependencies] -burn-core = { path = "../burn-core", version = "0.16.0", features = [ +burn-core = { path = "../burn-core", version = "0.17.0", features = [ "dataset", "std", ], default-features = false } @@ -40,11 +40,11 @@ ratatui = { workspace = true, optional = true, features = ["all-widgets", "cross derive-new = { workspace = true } serde = { workspace = true, features = ["std", "derive"] } async-channel = { workspace = true } -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } rstest.workspace = true [dev-dependencies] -burn-ndarray = { path = "../burn-ndarray", version = "0.16.0" } +burn-ndarray = { path = "../burn-ndarray", version = "0.17.0" } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index 055b53ae2f..c2e034ada5 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -24,15 +24,15 @@ template = ["burn-jit/template", "cubecl/template"] [dependencies] cubecl = { workspace = true, features = ["wgpu"] } -burn-fusion = { path = "../burn-fusion", version = "0.16.0", optional = true } -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false } -burn-tensor = { path = "../burn-tensor", version = "0.16.0", default-features = false, features = [ +burn-fusion = { path = "../burn-fusion", version = "0.17.0", optional = true } +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false } +burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = [ "cubecl-wgpu", ] } [dev-dependencies] -burn-jit = { path = "../burn-jit", version = "0.16.0", default-features = false, features = [ +burn-jit = { path = "../burn-jit", version = "0.17.0", default-features = false, features = [ "export_tests", ] } half = { workspace = true } diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 7f6af14fbb..0e7ff51e88 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -74,5 +74,5 @@ record-item-custom-serde = ["burn-core/record-item-custom-serde"] # ** Please make sure all dependencies support no_std when std is disabled ** -burn-core = { path = "../burn-core", version = "0.16.0", default-features = false } -burn-train = { path = "../burn-train", version = "0.16.0", optional = true, default-features = false } +burn-core = { path = "../burn-core", version = "0.17.0", default-features = false } +burn-train = { path = "../burn-train", version = "0.17.0", optional = true, default-features = false } diff --git a/examples/image-classification-web/Cargo.toml b/examples/image-classification-web/Cargo.toml index 44591bdad1..9429b24d25 100644 --- a/examples/image-classification-web/Cargo.toml +++ b/examples/image-classification-web/Cargo.toml @@ -14,10 +14,10 @@ default = [] half_precision = [] [dependencies] -burn = { path = "../../crates/burn", version = "0.16.0", default-features = false, features = [ +burn = { path = "../../crates/burn", version = "0.17.0", default-features = false, features = [ "ndarray", "wgpu", ] } -burn-candle = { path = "../../crates/burn-candle", version = "0.16.0", default-features = false } +burn-candle = { path = "../../crates/burn-candle", version = "0.17.0", default-features = false } log = { workspace = true } serde = { workspace = true } diff --git a/examples/pytorch-import/Cargo.toml b/examples/pytorch-import/Cargo.toml index a7b3305689..dd2b56e92d 100644 --- a/examples/pytorch-import/Cargo.toml +++ b/examples/pytorch-import/Cargo.toml @@ -4,7 +4,7 @@ edition = "2021" license = "MIT OR Apache-2.0" name = "pytorch-import" publish = false -version = "0.16.0" +version = "0.17.0" [dependencies] burn = { path = "../../crates/burn", features = [ diff --git a/examples/pytorch-import/model/Cargo.toml b/examples/pytorch-import/model/Cargo.toml index 894ac7e48f..f2678bfcbc 100644 --- a/examples/pytorch-import/model/Cargo.toml +++ b/examples/pytorch-import/model/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "model" -version = "0.5.0" +version = "0.6.0" edition = "2021" [dependencies] diff --git a/examples/raspberry-pi-pico/Cargo.lock b/examples/raspberry-pi-pico/Cargo.lock index 2cbc8fb721..a2f5e866d3 100644 --- a/examples/raspberry-pi-pico/Cargo.lock +++ b/examples/raspberry-pi-pico/Cargo.lock @@ -286,7 +286,7 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "burn" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-core", "burn-train", @@ -294,7 +294,7 @@ dependencies = [ [[package]] name = "burn-autodiff" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -305,7 +305,7 @@ dependencies = [ [[package]] name = "burn-candle" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-tensor", "candle-core", @@ -315,7 +315,7 @@ dependencies = [ [[package]] name = "burn-common" -version = "0.16.0" +version = "0.17.0" dependencies = [ "cubecl-common", "data-encoding", @@ -326,7 +326,7 @@ dependencies = [ [[package]] name = "burn-core" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bincode", "burn-autodiff", @@ -357,7 +357,7 @@ dependencies = [ [[package]] name = "burn-cuda" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -371,7 +371,7 @@ dependencies = [ [[package]] name = "burn-dataset" -version = "0.16.0" +version = "0.17.0" dependencies = [ "csv", "derive-new", @@ -395,7 +395,7 @@ dependencies = [ [[package]] name = "burn-derive" -version = "0.16.0" +version = "0.17.0" dependencies = [ "derive-new", "proc-macro2", @@ -405,7 +405,7 @@ dependencies = [ [[package]] name = "burn-fusion" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-tensor", @@ -418,7 +418,7 @@ dependencies = [ [[package]] name = "burn-import" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn", "candle-core", @@ -441,7 +441,7 @@ dependencies = [ [[package]] name = "burn-jit" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "burn-fusion", @@ -461,7 +461,7 @@ dependencies = [ [[package]] name = "burn-ndarray" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-autodiff", "burn-common", @@ -478,7 +478,7 @@ dependencies = [ [[package]] name = "burn-tch" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-tensor", "half", @@ -489,7 +489,7 @@ dependencies = [ [[package]] name = "burn-tensor" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-common", "bytemuck", @@ -507,7 +507,7 @@ dependencies = [ [[package]] name = "burn-train" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-core", "crossterm", @@ -525,7 +525,7 @@ dependencies = [ [[package]] name = "burn-wgpu" -version = "0.16.0" +version = "0.17.0" dependencies = [ "burn-fusion", "burn-jit", @@ -2959,7 +2959,7 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "onnx-ir" -version = "0.16.0" +version = "0.17.0" dependencies = [ "bytemuck", "half", diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index 5d06497e08..bb4824fba9 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -15,4 +15,4 @@ ndarray = ["burn/ndarray"] [dependencies] cfg-if = { workspace = true } -burn = { path = "../../crates/burn", version = "0.16.0", features = ["server"] } +burn = { path = "../../crates/burn", version = "0.17.0", features = ["server"] } diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index ce796eb7b1..63ac5e4c70 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xtask" -version = "1.1.0" +version = "1.2.0" edition = "2021" license = "MIT OR Apache-2.0" From ad81344821bbdca2fde1b5debda4d83d96370876 Mon Sep 17 00:00:00 2001 From: tiruka <33803972+tiruka@users.noreply.github.com> Date: Thu, 16 Jan 2025 01:44:50 +0900 Subject: [PATCH 29/61] Feature add new one hot function meeting multi-dimensions (ranks) (#2613) * add one hot with axis and values function * update one hot multidimentional function * implementing on numeric.rs * update one hot method in numeric * update one hot function to deal with additional dims add one hot test * added tests for one hot * modify function name modify format add tests * modify to respond to difference between Tensor type and values type * fix clippy point out and doc test * do refactoring modify comments * update burn book to publish one hot plus method * modify one_hot_plus to one_hot_fill and args names * modify one_hot function in int impl and float impl modify one_hot tests * modify numeric to clear logic * modify miscs due to validation, linnter and formatter * modify documents for tensor api * modify codes to follow review comments * modify codes to follow reviews * modify tests to follow reviews comments * Improve check message --------- Co-authored-by: Guillaume Lagrange --- burn-book/src/building-blocks/tensor.md | 4 +- crates/burn-tensor/src/tensor/api/check.rs | 34 +++--- crates/burn-tensor/src/tensor/api/float.rs | 34 +----- crates/burn-tensor/src/tensor/api/int.rs | 30 ----- crates/burn-tensor/src/tensor/api/numeric.rs | 97 ++++++++++++++++ crates/burn-tensor/src/tests/ops/one_hot.rs | 112 +++++++++++++------ 6 files changed, 193 insertions(+), 118 deletions(-) diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index fb429ffd0f..8a7c01bbc9 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -228,6 +228,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. | `tensor.neg()` or `-tensor` | `-tensor` | | `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` | | `tensor.ones_like()` | `torch.ones_like(tensor)` | +| `tensor.one_hot(num_classes)` | `torch.nn.functional.one_hot` | +| `tensor.one_hot_fill(num_classes, on_value, off_value, axis)` | N/A | | `tensor.pad(pads, value)` | `torch.nn.functional.pad(input, pad, value)` | | `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` | | `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` | @@ -258,7 +260,6 @@ Those operations are only available for `Float` tensors. | Burn API | PyTorch Equivalent | | --------------------------------------------- | ---------------------------------- | -| `Tensor::one_hot(index, num_classes, device)` | N/A | | `tensor.cast(dtype)` | `tensor.to(dtype)` | | `tensor.ceil()` | `tensor.ceil()` | | `tensor.cos()` | `tensor.cos()` | @@ -296,7 +297,6 @@ Those operations are only available for `Int` tensors. | `tensor.from_ints(ints)` | N/A | | `tensor.int_random(shape, distribution, device)` | N/A | | `tensor.cartesian_grid(shape, device)` | N/A | -| `tensor.one_hot(num_classes)` | N/A | ### Bool Operations diff --git a/crates/burn-tensor/src/tensor/api/check.rs b/crates/burn-tensor/src/tensor/api/check.rs index d4ab13faf4..8a6fb2ad78 100644 --- a/crates/burn-tensor/src/tensor/api/check.rs +++ b/crates/burn-tensor/src/tensor/api/check.rs @@ -1,4 +1,4 @@ -use crate::{backend::Backend, BasicOps, Int, Shape, Tensor}; +use crate::{backend::Backend, BasicOps, Numeric, Shape, Tensor}; use alloc::format; use alloc::string::{String, ToString}; use alloc::vec; @@ -447,22 +447,8 @@ impl TensorCheck { check } - pub(crate) fn one_hot_index(index: usize, num_classes: usize) -> Self { - let mut check = Self::Ok; - if index >= num_classes { - check = check.register( - "One Hot", - TensorError::new(format!( - "Can't create a one hot tensor with index ({index}) greater or equal to the number of classes ({num_classes})", - )), - ); - } - - check - } - - pub(crate) fn one_hot_tensor( - index_tensor: Tensor, + pub(crate) fn one_hot_tensor>( + index_tensor: Tensor, num_classes: usize, ) -> Self { let mut check = Self::Ok; @@ -487,6 +473,20 @@ impl TensorCheck { check } + pub(crate) fn one_hot_tensor_rank() -> Self { + let mut check = Self::Ok; + if D + 1 != D2 { + check = check.register( + "One Hot", + TensorError::new( + "The one-hot tensor rank must correspond to the rank of the tensor + 1", + ) + .details(format!("Expected D2={}, got {D2}", D + 1)), + ); + } + check + } + pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { let mut check = Self::Ok; diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index a6f59f6e88..b50d0d0596 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -1,11 +1,8 @@ -use alloc::vec::Vec; -use core::convert::TryInto; - use crate::check::TensorCheck; use crate::quantization::{QuantizationParameters, QuantizationScheme}; use crate::tensor::backend::Backend; use crate::tensor::stats; -use crate::tensor::{Distribution, Shape, TensorData}; +use crate::tensor::{Distribution, TensorData}; use crate::Tensor; use crate::{check, FloatDType}; use crate::{Int, TensorPrimitive}; @@ -174,35 +171,6 @@ where ))) } - /// Create a one hot tensor. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::Tensor; - /// - /// fn example() { - /// let device = Default::default(); - /// let one_hot = Tensor::::one_hot(2, 10, &device); - /// println!("{}", one_hot.to_data()); - /// // [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - /// } - /// ``` - pub fn one_hot(index: usize, num_classes: usize, device: &B::Device) -> Self { - check!(TensorCheck::one_hot_index(index, num_classes)); - - let mut dims = [1; D]; - dims[D - 1] = num_classes; - let shape = Shape::new(dims); - let ranges: Vec<_> = shape.dims.iter().map(|dim| 0..*dim).collect(); - let tensor = Tensor::zeros(shape, device); - let mut ranges: [core::ops::Range; D] = ranges.try_into().unwrap(); - ranges[D - 1] = index..index + 1; - - tensor.slice_assign(ranges, Tensor::ones(Shape::new([1; D]), device)) - } - /// Applies the matrix multiplication operation. /// /// `C = AB` diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index 08bdab0fe7..e882a107c7 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -1,5 +1,3 @@ -use crate::check; -use crate::check::TensorCheck; use crate::{ backend::Backend, cartesian_grid, Float, Int, Shape, Tensor, TensorData, TensorPrimitive, }; @@ -29,34 +27,6 @@ where pub fn arange_step(range: Range, step: usize, device: &B::Device) -> Self { Tensor::new(B::int_arange_step(range, step, device)) } - - /// Create a one hot tensor from an index tensor. - /// - /// # Arguments - /// - /// * `num_classes` - The number of classes to use in encoding. - /// - /// # Example - /// - /// ```rust - /// use burn_tensor::backend::Backend; - /// use burn_tensor::{Tensor, Int}; - /// - /// fn example() { - /// let device = B::Device::default(); - /// let indices: Tensor = Tensor::from_ints([0, 1, 2, 3], &device); - /// let one_hot = indices.one_hot(4); - /// println!("{}", one_hot.to_data()); - /// // [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] - /// } - /// ``` - pub fn one_hot(self, num_classes: usize) -> Tensor { - check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); - let [num_samples] = self.dims(); - let indices = self.unsqueeze_dim(1); - let values = indices.ones_like(); - Tensor::zeros([num_samples, num_classes], &indices.device()).scatter(1, indices, values) - } } impl Tensor diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 59dc44b7e6..b82175c3fe 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -2034,6 +2034,103 @@ where // Assign the original tensor data to the appropriate slice of the padded tensor padded_tensor.slice_assign(ranges, self) } + /// Create a one hot tensor. + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::Tensor; + /// + /// fn example(){ + /// let device = Default::default(); + /// let indices: Tensor = Tensor::from_floats([0.0, 1.0, 2.0, 3.0], &device); + /// let one_hot: Tensor = indices.one_hot(4); + /// println!("{}", one_hot.to_data()); + /// // [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]] + /// } + /// ``` + pub fn one_hot(self, num_classes: usize) -> Tensor { + check!(TensorCheck::one_hot_tensor(self.clone(), num_classes)); + self.one_hot_fill(num_classes, 1.0, 0.0, -1) + } + + /// Create a one-hot encoded tensor with configurable `num_classes`, `on_value`, `off_value`, and `axis` including high-ranked tensors. + /// + /// # Arguments + /// + /// * `num_classes`: The number of classes for the one-hot encoding, which defines the size of the one-hot dimension. + /// * `on_value`: The value to assign for active positions (corresponding to indices). + /// * `off_value`: The value to assign for inactive positions. + /// * `axis`: The axis along which the one-hot dimension is added. Supports negative indexing. + /// + /// # Returns + /// + /// A tensor with one additional dimension for the one-hot encoding, where active positions are filled with `on_value` and others with `off_value`. + /// + /// # Example + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Float}; + /// fn example>>() { + /// let device = B::Device::default(); + /// let indices: Tensor = Tensor::from_floats([[0., 2.], [1., -1.]], &device); + /// // One-hot encoding + /// let tensor:Tensor = indices.one_hot_fill(3, 5.0.into(), 0.0.into(), -1); + /// println!("{tensor}"); + /// // [[[5.0, 0.0, 0.0], + /// // [0.0, 0.0, 5.0]], + /// // [[0.0, 5.0, 0.0], + /// // [0.0, 0.0, 5.0]]] + /// } + /// ``` + pub fn one_hot_fill( + self, + num_classes: usize, + on_value: f32, + off_value: f32, + axis: i64, + ) -> Tensor { + check!(TensorCheck::one_hot_tensor_rank::()); + // Initialize shape from the current tensor dimensions and prepare for modification + let mut shape = self.shape().dims::().to_vec(); + let device = self.device(); + let rank = self.dims().len(); + + // Adjust negative axis to a positive index + let axis = if axis < 0 { + axis + rank as i64 + 1 + } else { + axis + }; + + // Ensure axis is within valid range + if axis < 0 || axis > rank as i64 { + panic!("Axis out of range. Accepted range is [-r-1, r] where r = rank(indices)."); + } + // Convert the input tensor to integer indices + let indices: Tensor = + Tensor::from_data(self.to_data().convert::(), &device); + // Insert the new dimension for the one-hot representation + shape.insert(axis as usize, num_classes); + // Adjust indices to valid range and handle invalid indices + let adjusted_indices = indices + .clone() + .mask_fill(self.clone().lower_elem(0), num_classes as i64) // Handle negative indices + .add(indices.clone().mask_fill(self.clone().greater_elem(0), 0)); // Handle positive indices + // Unsqueeze the indices tensor along the specified axis + let indices_unsqueezed: Tensor = adjusted_indices.unsqueeze_dim(axis as usize); + + // Initialize the output tensor with the off_value + let output = Tensor::full(shape.clone(), off_value, &device); + + // Prepare scatter tensor for on_value and off_value adjustments + let scatter_on_values = Tensor::full(indices_unsqueezed.shape(), on_value, &device) + - Tensor::full(indices_unsqueezed.shape(), off_value, &self.device()); + + // Scatter on_value at the appropriate indices to create the one-hot representation + output.scatter(axis as usize, indices_unsqueezed, scatter_on_values) + } /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN. /// diff --git a/crates/burn-tensor/src/tests/ops/one_hot.rs b/crates/burn-tensor/src/tests/ops/one_hot.rs index 310399119f..24e8f24b38 100644 --- a/crates/burn-tensor/src/tests/ops/one_hot.rs +++ b/crates/burn-tensor/src/tests/ops/one_hot.rs @@ -1,74 +1,114 @@ #[burn_tensor_testgen::testgen(one_hot)] mod tests { use super::*; - use burn_tensor::{Int, TensorData}; + use burn_tensor::{ + as_type, + backend::Backend, + tests::{Float as _, Int as _}, + Float, Int, Numeric, Shape, Tensor, TensorData, + }; #[test] fn float_should_support_one_hot() { - let device = Default::default(); - - let tensor = TestTensor::<1>::one_hot(0, 5, &device); - let expected = TensorData::from([1., 0., 0., 0., 0.]); - tensor.into_data().assert_eq(&expected, false); - - let tensor = TestTensor::<1>::one_hot(1, 5, &device); - let expected = TensorData::from([0., 1., 0., 0., 0.]); - tensor.into_data().assert_eq(&expected, false); - - let tensor = TestTensor::<1>::one_hot(4, 5, &device); - let expected = TensorData::from([0., 0., 0., 0., 1.]); - tensor.into_data().assert_eq(&expected, false); + let tensor = TestTensor::<1>::from([0.0, 1.0, 4.0]); + let one_hot_tensor: Tensor = tensor.one_hot(5); + let expected = TensorData::from([ + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ]); + one_hot_tensor.into_data().assert_eq(&expected, false); + } - let tensor = TestTensor::<1>::one_hot(1, 2, &device); - let expected = TensorData::from([0., 1.]); - tensor.into_data().assert_eq(&expected, false); + #[test] + fn float_should_support_one_hot_index() { + let tensor = TestTensor::<1>::from([2.0]); + let one_hot_tensor: Tensor = tensor.one_hot::<2>(10); + let expected = TensorData::from([[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]); + one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn float_one_hot_should_panic_when_index_exceeds_number_of_classes() { - let device = Default::default(); - let tensor = TestTensor::<1>::one_hot(1, 1, &device); + let tensor = TestTensor::<1>::from([5.0]); + let result: Tensor = tensor.one_hot(5); } #[test] #[should_panic] fn float_one_hot_should_panic_when_number_of_classes_is_zero() { - let device = Default::default(); - let tensor = TestTensor::<1>::one_hot(0, 0, &device); + let tensor = TestTensor::<1>::from([0.0]); + let result: Tensor = tensor.one_hot(0); } #[test] fn int_should_support_one_hot() { - let device = Default::default(); - - let index_tensor = TestTensorInt::<1>::arange(0..5, &device); - let one_hot_tensor = index_tensor.one_hot(5); - let expected = TestTensorInt::eye(5, &device).into_data(); + let tensor = TestTensorInt::<1>::from([0, 1, 4]); + let one_hot_tensor: Tensor = tensor.one_hot(5); + let expected = TensorData::from([[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]]); one_hot_tensor.into_data().assert_eq(&expected, false); } #[test] #[should_panic] fn int_one_hot_should_panic_when_index_exceeds_number_of_classes() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..6, &device); - let one_hot_tensor = index_tensor.one_hot(5); + let tensor = TestTensorInt::<1>::from([5]); + let result: Tensor = tensor.one_hot(5); } #[test] #[should_panic] fn int_one_hot_should_panic_when_number_of_classes_is_zero() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..3, &device); - let one_hot_tensor = index_tensor.one_hot(0); + let tensor = TestTensorInt::<1>::from([2]); + let result: Tensor = tensor.one_hot(0); + } + + #[test] + fn one_hot_fill_with_positive_axis_and_indices() { + let tensor = TestTensorInt::<2>::from([[1, 9], [2, 4]]); + let expected = TensorData::from(as_type!(IntType: [ + [[1, 1], [3, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 3]], + [[1, 1], [1, 1], [3, 1], [1, 1], [1, 3], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1]] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3.0, 1.0, 1); + + one_hot_tensor.into_data().assert_eq(&expected, true); + } + + #[test] + fn one_hot_fill_with_negative_axis_and_indices() { + let tensor = TestTensor::<2>::from([[0, 2], [1, -1]]); + let expected = TensorData::from(as_type!(FloatType: [ + [[5.0, 0.0, 0.0], [0.0, 0.0, 5.0]], + [[0.0, 5.0, 0.0], [0.0, 0.0, 5.0]] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(3, 5.0, 0.0, -1); + + one_hot_tensor.into_data().assert_eq(&expected, true); } #[test] + fn one_hot_fill_with_negative_indices() { + let tensor = TestTensor::<1>::from([0.0, -7.0, -8.0]); + let expected = TensorData::from(as_type!(FloatType: [ + [3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ])); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(10, 3.0, 1.0, 1); + + one_hot_tensor.into_data().assert_eq(&expected, true); + } + #[should_panic] - fn int_one_hot_should_panic_when_number_of_classes_is_1() { - let device = Default::default(); - let index_tensor = TestTensorInt::<1>::arange(0..3, &device); - let one_hot_tensor = index_tensor.one_hot(1); + #[test] + fn one_hot_fill_should_panic_when_axis_out_range_of_rank() { + let tensor = TestTensor::<2>::from([[0.0, 2.0], [1.0, -1.0]]); + + let one_hot_tensor: Tensor = tensor.one_hot_fill(2, 5.0, 0.0, 3); } } From f630b3bc7d2d7fac0b972ea001c33daa7c32dd22 Mon Sep 17 00:00:00 2001 From: jiawen wang Date: Thu, 16 Jan 2025 00:45:20 +0800 Subject: [PATCH 30/61] Wasserstein Generative Adversarial Network (#2660) * Add files via upload Wasserstein Generative Adversarial Network * Delete examples/wgan/readme * Create README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update cli.rs * Update cli.rs * Update model.rs * Update training.rs * Update main.rs * Update model.rs * Update training.rs * Update training.rs * Update main.rs * Update training.rs * Update model.rs * Update training.rs * Update cli.rs * Update cli.rs * Update generating.rs * Update lib.rs * Update model.rs * Update training.rs * Update main.rs * Update generating.rs * Update model.rs * Update training.rs * Update generating.rs * Update model.rs * Update training.rs * Update training.rs * Update dataset.rs * Update generating.rs * Update model.rs * Update training.rs * Update training.rs * Update training.rs * Restructure as workspace example * Add support for single range slice (fixes clippy) * Update example usage + list --------- Co-authored-by: Guillaume Lagrange --- Cargo.lock | 8 + README.md | 2 + burn-book/src/examples.md | 1 + crates/burn-tensor/src/tensor/api/base.rs | 8 + crates/burn-tensor/src/tests/ops/slice.rs | 11 ++ examples/wgan/Cargo.toml | 18 ++ examples/wgan/README.md | 40 ++++ examples/wgan/examples/wgan-generate.rs | 95 ++++++++++ examples/wgan/examples/wgan-mnist.rs | 107 +++++++++++ examples/wgan/src/dataset.rs | 49 +++++ examples/wgan/src/infer.rs | 41 +++++ examples/wgan/src/lib.rs | 4 + examples/wgan/src/model.rs | 157 ++++++++++++++++ examples/wgan/src/training.rs | 211 ++++++++++++++++++++++ 14 files changed, 752 insertions(+) create mode 100644 examples/wgan/Cargo.toml create mode 100644 examples/wgan/README.md create mode 100644 examples/wgan/examples/wgan-generate.rs create mode 100644 examples/wgan/examples/wgan-mnist.rs create mode 100644 examples/wgan/src/dataset.rs create mode 100644 examples/wgan/src/infer.rs create mode 100644 examples/wgan/src/lib.rs create mode 100644 examples/wgan/src/model.rs create mode 100644 examples/wgan/src/training.rs diff --git a/Cargo.lock b/Cargo.lock index c34fb9cd03..1af3585919 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8002,6 +8002,14 @@ version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" +[[package]] +name = "wgan" +version = "0.1.0" +dependencies = [ + "burn", + "image", +] + [[package]] name = "wgpu" version = "23.0.1" diff --git a/README.md b/README.md index a0780dcc16..d0ccbcf411 100644 --- a/README.md +++ b/README.md @@ -567,6 +567,8 @@ Additional examples: sample. - [Text Generation](./examples/text-generation) : Trains a text generation transformer model on the DbPedia dataset. +- [Wasserstein GAN MNIST](./examples/wgan) : Trains a WGAN model to generate new handwritten digits + based on MNIST. For more practical insights, you can clone the repository and run any of them directly on your computer! diff --git a/burn-book/src/examples.md b/burn-book/src/examples.md index c9703a4389..2b083b6fbe 100644 --- a/burn-book/src/examples.md +++ b/burn-book/src/examples.md @@ -85,6 +85,7 @@ The following additional examples are currently available if you want to check t | [PyTorch Import Inference](https://github.com/tracel-ai/burn/tree/main/examples/pytorch-import) | Imports a PyTorch model pre-trained on MNIST to perform inference on a sample image with Burn. | | [Text Classification](https://github.com/tracel-ai/burn/tree/main/examples/text-classification) | Trains a text classification transformer model on the AG News or DbPedia datasets. The trained model can then be used to classify a text sample. | | [Text Generation](https://github.com/tracel-ai/burn/tree/main/examples/text-generation) | Trains a text generation transformer model on the DbPedia dataset. | +| [Wasserstein GAN MNIST](https://github.com/tracel-ai/burn/tree/main/examples/wgan) | Trains a WGAN model to generate new handwritten digits based on MNIST. | For more information on each example, see their respective `README.md` file. Be sure to check out the [examples](https://github.com/tracel-ai/burn/tree/main/examples) directory for an up-to-date diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index fabf321d96..4bbc522f49 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -805,6 +805,7 @@ where /// # Arguments /// /// * `ranges` - A type implementing the `RangesArg` trait, which can be: + /// - A single `core::ops::Range` (slice the first dimension) /// - An array of `core::ops::Range` /// - An array of `Option<(i64, i64)>` /// - An array of `(i64, i64)` tuples @@ -2988,6 +2989,13 @@ impl RangesArg for [(i64, i64); D2] { } } +impl RangesArg<1> for core::ops::Range { + fn into_ranges(self, shape: Shape) -> [core::ops::Range; 1] { + let (start, end) = Self::clamp_range(self.start, self.end, shape.dims[0]); + [(start..end)] + } +} + /// Trait used for reshape arguments. pub trait ReshapeArgs { /// Converts to a shape. diff --git a/crates/burn-tensor/src/tests/ops/slice.rs b/crates/burn-tensor/src/tests/ops/slice.rs index 61725a506a..1be5b76315 100644 --- a/crates/burn-tensor/src/tests/ops/slice.rs +++ b/crates/burn-tensor/src/tests/ops/slice.rs @@ -47,6 +47,17 @@ mod tests { output.into_data().assert_eq(&expected, false); } + #[test] + fn should_support_slice_range_first_dim() { + let data = TensorData::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let tensor = TestTensor::<2>::from_data(data, &Default::default()); + + let output = tensor.slice(0..1); + let expected = TensorData::from([[0.0, 1.0, 2.0]]); + + output.into_data().assert_eq(&expected, false); + } + #[test] fn should_support_partial_sliceing_3d() { let tensor = TestTensor::<3>::from_floats( diff --git a/examples/wgan/Cargo.toml b/examples/wgan/Cargo.toml new file mode 100644 index 0000000000..48d5680f51 --- /dev/null +++ b/examples/wgan/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "wgan" +version = "0.1.0" +edition = "2021" + +[features] +ndarray = ["burn/ndarray"] +ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] +ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] +ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] +tch-cpu = ["burn/tch"] +tch-gpu = ["burn/tch"] +wgpu = ["burn/wgpu"] +cuda-jit = ["burn/cuda-jit"] + +[dependencies] +burn = { path = "../../crates/burn", features=["train", "vision"] } +image = { workspace = true } diff --git a/examples/wgan/README.md b/examples/wgan/README.md new file mode 100644 index 0000000000..d7252ba520 --- /dev/null +++ b/examples/wgan/README.md @@ -0,0 +1,40 @@ +# Wasserstein Generative Adversarial Network + +A burn implementation of examplar WGAN model to generate MNIST digits inspired by +[the PyTorch implementation](https://bytepawn.com/training-a-pytorch-wasserstain-mnist-gan-on-google-colab.html). +Please note that better performance maybe gained by adopting a convolution layer in +[some other models](https://github.com/Lornatang/WassersteinGAN-PyTorch). + +## Usage + + +## Training + +```sh +# Cuda backend +cargo run --example wgan-mnist --release --features cuda-jit + +# Wgpu backend +cargo run --example wgan-mnist --release --features wgpu + +# Tch GPU backend +export TORCH_CUDA_VERSION=cu121 # Set the cuda version +cargo run --example wgan-mnist --release --features tch-gpu + +# Tch CPU backend +cargo run --example wgan-mnist --release --features tch-cpu + +# NdArray backend (CPU) +cargo run --example wgan-mnist --release --features ndarray # f32 - single thread +cargo run --example wgan-mnist --release --features ndarray-blas-openblas # f32 - blas with openblas +cargo run --example wgan-mnist --release --features ndarray-blas-netlib # f32 - blas with netlib +``` + + +### Generating + +To generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend. + +```sh +cargo run --example wgan-generate --release --features cuda-jit +``` diff --git a/examples/wgan/examples/wgan-generate.rs b/examples/wgan/examples/wgan-generate.rs new file mode 100644 index 0000000000..fa66623ca3 --- /dev/null +++ b/examples/wgan/examples/wgan-generate.rs @@ -0,0 +1,95 @@ +use burn::tensor::backend::Backend; + +pub fn launch(device: B::Device) { + wgan::infer::generate::("/tmp/wgan-mnist", device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::{ + ndarray::{NdArray, NdArrayDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::>(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::{wgpu::Wgpu, Autodiff}; + + pub fn run() { + launch::>(Default::default()); + } +} + +#[cfg(feature = "cuda-jit")] +mod cuda_jit { + use crate::launch; + use burn::backend::{Autodiff, CudaJit}; + + pub fn run() { + launch::>(Default::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda-jit")] + cuda_jit::run(); +} diff --git a/examples/wgan/examples/wgan-mnist.rs b/examples/wgan/examples/wgan-mnist.rs new file mode 100644 index 0000000000..d964b07844 --- /dev/null +++ b/examples/wgan/examples/wgan-mnist.rs @@ -0,0 +1,107 @@ +use burn::{optim::RmsPropConfig, tensor::backend::AutodiffBackend}; + +use wgan::{model::ModelConfig, training::TrainingConfig}; + +pub fn launch(device: B::Device) { + let config = TrainingConfig::new( + ModelConfig::new(), + RmsPropConfig::new() + .with_alpha(0.99) + .with_momentum(0.0) + .with_epsilon(0.00000008) + .with_weight_decay(None) + .with_centered(false), + ); + + wgan::training::train::("/tmp/wgan-mnist", config, device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::{ + ndarray::{NdArray, NdArrayDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::>(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::{wgpu::Wgpu, Autodiff}; + + pub fn run() { + launch::>(Default::default()); + } +} + +#[cfg(feature = "cuda-jit")] +mod cuda_jit { + use crate::launch; + use burn::backend::{cuda_jit::CudaDevice, Autodiff, CudaJit}; + + pub fn run() { + launch::>(CudaDevice::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda-jit")] + cuda_jit::run(); +} diff --git a/examples/wgan/src/dataset.rs b/examples/wgan/src/dataset.rs new file mode 100644 index 0000000000..46848d4ffb --- /dev/null +++ b/examples/wgan/src/dataset.rs @@ -0,0 +1,49 @@ +use burn::{ + data::{dataloader::batcher::Batcher, dataset::vision::MnistItem}, + prelude::*, +}; + +#[derive(Clone, Debug)] +pub struct MnistBatcher { + device: B::Device, +} + +#[derive(Clone, Debug)] +pub struct MnistBatch { + pub images: Tensor, + pub targets: Tensor, +} + +impl MnistBatcher { + pub fn new(device: B::Device) -> Self { + Self { device } + } +} + +impl Batcher> for MnistBatcher { + fn batch(&self, items: Vec) -> MnistBatch { + let images = items + .iter() + .map(|item| TensorData::from(item.image)) + .map(|data| Tensor::::from_data(data.convert::(), &self.device)) + .map(|tensor| tensor.reshape([1, 28, 28])) + // Set std=0.5 and mean=0.5 to keep consistent with pytorch WGAN example + .map(|tensor| ((tensor / 255) - 0.5) / 0.5) + .collect(); + + let targets = items + .iter() + .map(|item| { + Tensor::::from_data( + TensorData::from([(item.label as i64).elem::()]), + &self.device, + ) + }) + .collect(); + + let images = Tensor::stack(images, 0); + let targets = Tensor::cat(targets, 0); + + MnistBatch { images, targets } + } +} diff --git a/examples/wgan/src/infer.rs b/examples/wgan/src/infer.rs new file mode 100644 index 0000000000..25ca984feb --- /dev/null +++ b/examples/wgan/src/infer.rs @@ -0,0 +1,41 @@ +use crate::training::{save_image, TrainingConfig}; +use burn::{ + prelude::*, + record::{CompactRecorder, Recorder}, + tensor::Distribution, +}; + +pub fn generate(artifact_dir: &str, device: B::Device) { + // Loading model + let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) + .expect("Config should exist for the model; run train first"); + let record = CompactRecorder::new() + .load(format!("{artifact_dir}/generator").into(), &device) + .expect("Trained model should exist; run train first"); + let (mut generator, _) = config.model.init::(&device); + generator = generator.load_record(record); + + // Get a batch of noise + let noise = Tensor::::random( + [config.batch_size, config.model.latent_dim], + Distribution::Normal(0.0, 1.0), + &device, + ); + let fake_images = generator.forward(noise); // [batch_size, channesl*height*width] + let fake_images = fake_images.reshape([ + config.batch_size, + config.model.channels, + config.model.image_size, + config.model.image_size, + ]); + // [B, C, H, W] to [B, H, C, W] to [B, H, W, C] + let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25); + // Normalize the images. The Rgb32 images should be in range 0.0-1.0 + let fake_images = (fake_images.clone() - fake_images.clone().min().reshape([1, 1, 1, 1])) + / (fake_images.clone().max().reshape([1, 1, 1, 1]) + - fake_images.clone().min().reshape([1, 1, 1, 1])); + // Add 0.5 after unnormalizing to [0, 255] to round to the nearest integer, refer to pytorch save_image source + let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0); + // Save images in artifact directory + save_image::(fake_images, 5, format!("{artifact_dir}/fake_image.png")).unwrap(); +} diff --git a/examples/wgan/src/lib.rs b/examples/wgan/src/lib.rs new file mode 100644 index 0000000000..021f62278a --- /dev/null +++ b/examples/wgan/src/lib.rs @@ -0,0 +1,4 @@ +pub mod dataset; +pub mod infer; +pub mod model; +pub mod training; diff --git a/examples/wgan/src/model.rs b/examples/wgan/src/model.rs new file mode 100644 index 0000000000..ddb84ff6d3 --- /dev/null +++ b/examples/wgan/src/model.rs @@ -0,0 +1,157 @@ +use burn::{ + module::{Module, ModuleMapper, ParamId}, + nn::BatchNorm, + prelude::*, + tensor::backend::AutodiffBackend, +}; + +/// Layer block of generator model +#[derive(Module, Debug)] +pub struct LayerBlock { + fc: nn::Linear, + bn: nn::BatchNorm, + leakyrelu: nn::LeakyRelu, +} + +impl LayerBlock { + pub fn new(input: usize, output: usize, device: &B::Device) -> Self { + let fc = nn::LinearConfig::new(input, output) + .with_bias(true) + .init(device); + let bn: BatchNorm = nn::BatchNormConfig::new(output) + .with_epsilon(0.8) + .init(device); + let leakyrelu = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); + + Self { fc, bn, leakyrelu } + } + + pub fn forward(&self, input: Tensor) -> Tensor { + let output = self.fc.forward(input); // output: [Batch, x] + let output = self.bn.forward(output); // output: [Batch, x] + + self.leakyrelu.forward(output) // output: [Batch, x] + } +} + +/// Generator model +#[derive(Module, Debug)] +pub struct Generator { + layer1: LayerBlock, + layer2: LayerBlock, + layer3: LayerBlock, + layer4: LayerBlock, + fc: nn::Linear, + tanh: nn::Tanh, +} + +impl Generator { + /// Applies the forward pass on the input tensor by specified order + pub fn forward(&self, noise: Tensor) -> Tensor { + let output = self.layer1.forward(noise); + let output = self.layer2.forward(output); + let output = self.layer3.forward(output); + let output = self.layer4.forward(output); + let output = self.fc.forward(output); + + self.tanh.forward(output) // [batch_size, channels*height*width] + } +} + +/// Discriminator model +#[derive(Module, Debug)] +pub struct Discriminator { + fc1: nn::Linear, + leakyrelu1: nn::LeakyRelu, + fc2: nn::Linear, + leakyrelu2: nn::LeakyRelu, + fc3: nn::Linear, +} + +impl Discriminator { + /// Applies the forward pass on the input tensor by specified order. + /// The input image shape is [batch, channels, height, width] + pub fn forward(&self, images: Tensor) -> Tensor { + // Full connection for each batch + let output = images.flatten(1, 3); // output: [batch, channels*height*width] + let output = self.fc1.forward(output); // output: [batch, 512] + let output = self.leakyrelu1.forward(output); // output: [batch, 512] + let output = self.fc2.forward(output); // output: [batch, 256] + let output = self.leakyrelu2.forward(output); // output: [batch, 256] + + self.fc3.forward(output) // output: [batch, 1] + } +} + +// Use model config to construct a generative and adverserial model +#[derive(Config, Debug)] +pub struct ModelConfig { + /// Dimensionality of the latent space + #[config(default = 100)] + pub latent_dim: usize, + #[config(default = 28)] + pub image_size: usize, + #[config(default = 1)] + pub channels: usize, +} + +impl ModelConfig { + /// "init" is used to create other objects, while "new" is usally used to create itself. + pub fn init(&self, device: &B::Device) -> (Generator, Discriminator) { + // Construct the initialized generator + let layer1 = LayerBlock::new(self.latent_dim, 128, device); + let layer2 = LayerBlock::new(128, 256, device); + let layer3 = LayerBlock::new(256, 512, device); + let layer4 = LayerBlock::new(512, 1024, device); + let fc = nn::LinearConfig::new(1024, self.channels * self.image_size * self.image_size) + .with_bias(true) + .init(device); + + let generator = Generator { + layer1, + layer2, + layer3, + layer4, + fc, + tanh: nn::Tanh::new(), + }; + + // Construct the initialized discriminator + let fc1 = nn::LinearConfig::new(self.channels * self.image_size * self.image_size, 512) + .init(device); + let leakyrelu1 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); + let fc2 = nn::LinearConfig::new(512, 256).init(device); + let leakyrelu2 = nn::LeakyReluConfig::new().with_negative_slope(0.2).init(); + let fc3 = nn::LinearConfig::new(256, 1).init(device); + + let discriminator = Discriminator { + fc1, + leakyrelu1, + fc2, + leakyrelu2, + fc3, + }; + + (generator, discriminator) + } +} + +/// Clip module mapper to clip all module parameters between a range of values +#[derive(Module, Clone, Debug)] +pub struct Clip { + pub min: f32, + pub max: f32, +} + +impl ModuleMapper for Clip { + fn map_float(&mut self, _id: ParamId, tensor: Tensor) -> Tensor { + let is_require_grad = tensor.is_require_grad(); + + let mut tensor = Tensor::from_inner(tensor.inner().clamp(self.min, self.max)); + + if is_require_grad { + tensor = tensor.require_grad(); + } + tensor + } +} diff --git a/examples/wgan/src/training.rs b/examples/wgan/src/training.rs new file mode 100644 index 0000000000..db1f594b46 --- /dev/null +++ b/examples/wgan/src/training.rs @@ -0,0 +1,211 @@ +use crate::dataset::MnistBatcher; +use crate::model::{Clip, ModelConfig}; +use burn::optim::{GradientsParams, Optimizer, RmsPropConfig}; +use burn::{ + data::{dataloader::DataLoaderBuilder, dataset::vision::MnistDataset}, + prelude::*, + record::CompactRecorder, + tensor::{backend::AutodiffBackend, Distribution}, +}; +use image::{buffer::ConvertBuffer, error::ImageResult, Rgb32FImage, RgbImage}; +use std::path::Path; + +#[derive(Config)] +pub struct TrainingConfig { + pub model: ModelConfig, + pub optimizer: RmsPropConfig, + + #[config(default = 200)] + pub num_epochs: usize, + #[config(default = 512)] + pub batch_size: usize, + #[config(default = 8)] + pub num_workers: usize, + #[config(default = 5)] + pub seed: u64, + #[config(default = 5e-5)] + pub lr: f64, + + /// Number of training steps for discriminator before generator is trained per iteration + #[config(default = 5)] + pub num_critic: usize, + /// Lower and upper clip value for disc. weights + #[config(default = 0.01)] + pub clip_value: f32, + /// Save a sample of images every `sample_interval` epochs + #[config(default = 10)] + pub sample_interval: usize, +} + +// Create the directory to save the model and model config +fn create_artifact_dir(artifact_dir: &str) { + // Remove existing artifacts + std::fs::remove_dir_all(artifact_dir).ok(); + std::fs::create_dir_all(artifact_dir).ok(); +} + +/// Save the generated images +// The images format is [B, H, W, C] +pub fn save_image>( + images: Tensor, + nrow: u32, + path: Q, +) -> ImageResult<()> { + let ncol = (images.dims()[0] as f32 / nrow as f32).ceil() as u32; + + let width = images.dims()[2] as u32; + let height = images.dims()[1] as u32; + + // Supports both 1 and 3 channels image + let channels = match images.dims()[3] { + 1 => 3, + 3 => 1, + _ => panic!("Wrong channels number"), + }; + + let mut imgbuf = RgbImage::new(nrow * width, ncol * height); + // Write images into a nrow*ncol grid layout + for row in 0..nrow { + for col in 0..ncol { + let image: Tensor = images + .clone() + .slice((row * nrow + col) as usize..(row * nrow + col + 1) as usize) + .squeeze(0); + // The Rgb32 should be in range 0.0-1.0 + let image = image.into_data().iter::().collect::>(); + // Supports both 1 and 3 channels image + let image = image + .into_iter() + .flat_map(|n| std::iter::repeat(n).take(channels)) + .collect(); + + let image = Rgb32FImage::from_vec(width, height, image).unwrap(); + let image: RgbImage = image.convert(); + for (x, y, pixel) in image.enumerate_pixels() { + imgbuf.put_pixel(row * width + x, col * height + y, *pixel); + } + } + } + imgbuf.save(path) +} + +pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { + create_artifact_dir(artifact_dir); + + // Create the Clip module mapper + let mut clip = Clip { + min: -config.clip_value, + max: config.clip_value, + }; + + // Save training config + config + .save(format!("{artifact_dir}/config.json")) + .expect("Config should be saved successfully"); + B::seed(config.seed); + + // Create the model and optimizer + let (mut generator, mut discriminator) = config.model.init::(&device); + let mut optimizer_g = config.optimizer.init(); + let mut optimizer_d = config.optimizer.init(); + + // Create the dataset batcher + let batcher_train = MnistBatcher::::new(device.clone()); + + // Create the dataloaders + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(MnistDataset::train()); + + // Iterate over our training for X epochs + for epoch in 0..config.num_epochs { + // Implement our training loop + for (iteration, batch) in dataloader_train.iter().enumerate() { + // Generate a batch of fake images from noise (standarded normal distribution) + let noise = Tensor::::random( + [config.batch_size, config.model.latent_dim], + Distribution::Normal(0.0, 1.0), + &device, + ); + // datach: do not update gerenator, only discriminator is updated + let fake_images = generator.forward(noise.clone()).detach(); // [batch_size, channels*height*width] + let fake_images = fake_images.reshape([ + config.batch_size, + config.model.channels, + config.model.image_size, + config.model.image_size, + ]); + // Adversarial loss + let loss_d = -discriminator.forward(batch.images).mean() + + discriminator.forward(fake_images.clone()).mean(); + + // Gradients for the current backward pass + let grads = loss_d.backward(); + // Gradients linked to each parameter of the discriminator + let grads = GradientsParams::from_grads(grads, &discriminator); + // Update the discriminator using the optimizer + discriminator = optimizer_d.step(config.lr, discriminator, grads); + // Clip parameters (weights) of discriminator + discriminator = discriminator.map(&mut clip); + + // Train the generator every num_critic iterations + if iteration % config.num_critic == 0 { + // Generate a batch of images again without detaching + let critic_fake_images = generator.forward(noise.clone()); + let critic_fake_images = critic_fake_images.reshape([ + config.batch_size, + config.model.channels, + config.model.image_size, + config.model.image_size, + ]); + // Adversarial loss. Minimize it to make the fake images as truth + let loss_g = -discriminator.forward(critic_fake_images).mean(); + + let grads = loss_g.backward(); + let grads = GradientsParams::from_grads(grads, &generator); + generator = optimizer_g.step(config.lr, generator, grads); + + // Print the progression + let batch_num = (dataloader_train.num_items() as f32 / config.batch_size as f32) + .ceil() as usize; + println!( + "[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}]", + epoch + 1, + config.num_epochs, + iteration, + batch_num, + loss_d.into_scalar(), + loss_g.into_scalar() + ); + } + // If at save interval => save the first 25 generated images + if epoch % config.sample_interval == 0 && iteration == 0 { + // [B, C, H, W] to [B, H, C, W] to [B, H, W, C] + let fake_images = fake_images.swap_dims(2, 1).swap_dims(3, 2).slice(0..25); + // Normalize the images. The Rgb32 images should be in range 0.0-1.0 + let fake_images = (fake_images.clone() + - fake_images.clone().min().reshape([1, 1, 1, 1])) + / (fake_images.clone().max().reshape([1, 1, 1, 1]) + - fake_images.clone().min().reshape([1, 1, 1, 1])); + // Add 0.5/255.0 to the images, refer to pytorch save_image source + let fake_images = (fake_images + 0.5 / 255.0).clamp(0.0, 1.0); + // Save images in artifact directory + let path = format!("{artifact_dir}/image-{}.png", epoch); + save_image::(fake_images, 5, path).unwrap(); + } + } + } + + // Save the trained models + generator + .save_file(format!("{artifact_dir}/generator"), &CompactRecorder::new()) + .expect("Generator should be saved successfully"); + discriminator + .save_file( + format!("{artifact_dir}/discriminator"), + &CompactRecorder::new(), + ) + .expect("Discriminator should be saved successfully"); +} From 93f8bad67198a0d9cf576c412b39c1864ed040a0 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 16 Jan 2025 09:30:32 -0500 Subject: [PATCH 31/61] Remove deprecated Data and DataSerialize (#2703) --- README.md | 35 +- crates/burn-core/Cargo.toml | 3 - crates/burn-core/src/record/primitive.rs | 21 +- crates/burn-core/src/record/tensor.rs | 50 +-- crates/burn-tensor/src/lib.rs | 4 +- crates/burn-tensor/src/tensor/data.rs | 400 +---------------------- crates/burn/Cargo.toml | 1 - 7 files changed, 37 insertions(+), 477 deletions(-) diff --git a/README.md b/README.md index d0ccbcf411..951b2a9f24 100644 --- a/README.md +++ b/README.md @@ -621,19 +621,20 @@ leads to more reliable, bug-free solutions built faster (after some practice
> **Deprecation Note**
Since `0.14.0`, the internal structure for tensor data has changed. The -> previous `Data` struct is being deprecated in favor of the new `TensorData` struct, which allows -> for more flexibility by storing the underlying data as bytes and keeping the data type as a field. -> If you are using `Data` in your code, make sure to switch to `TensorData`. +> previous `Data` struct was deprecated and officially removed since `0.17.0` in favor of the new +> `TensorData` struct, which allows for more flexibility by storing the underlying data as bytes and +> keeping the data type as a field. If you are using `Data` in your code, make sure to switch to +> `TensorData`.
@@ -642,8 +643,9 @@ Loading Model Records From Previous Versions ⚠️
-In the event that you are trying to load a model record saved in a previous version, make sure to -enable the `record-backward-compat` feature flag. +In the event that you are trying to load a model record saved in a version older than `0.14.0`, make +sure to use a compatible version (`0.14`, `0.15` or `0.16`) with the `record-backward-compat` +feature flag. ``` features = [..., "record-backward-compat"] @@ -652,13 +654,14 @@ features = [..., "record-backward-compat"] Otherwise, the record won't be deserialized correctly and you will get an error message. This error will also point you to the backward compatible feature flag. -The backward compatibility is maintained for deserialization when loading records. Therefore, as -soon as you have saved the record again it will be saved according to the new structure and you -won't need the backward compatible feature flag anymore. +The backward compatibility was maintained for deserialization when loading records. Therefore, as +soon as you have saved the record again it will be saved according to the new structure and you can +upgrade back to the current version Please note that binary formats are not backward compatible. Thus, you will need to load your record in a previous version and save it in any of the other self-describing record format (e.g., using the -`NamedMpkFileRecorder`) before using the new version with the `record-backward-compat` feature flag. +`NamedMpkFileRecorder`) before using a compatible version (as described) with the +`record-backward-compat` feature flag.
diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index b968e28a68..e895cc4572 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -113,9 +113,6 @@ record-item-custom-serde = ["thiserror", "regex"] # Serialization formats experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] -# Backwards compatibility with previous serialized data format. -record-backward-compat = [] - test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray. test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray. test-tch = ["tch"] # To use tch during testing, default uses ndarray. diff --git a/crates/burn-core/src/record/primitive.rs b/crates/burn-core/src/record/primitive.rs index 9dd921e824..2f9fa3e83c 100644 --- a/crates/burn-core/src/record/primitive.rs +++ b/crates/burn-core/src/record/primitive.rs @@ -5,9 +5,7 @@ use super::tensor::{BoolTensorSerde, FloatTensorSerde, IntTensorSerde}; use super::{PrecisionSettings, Record}; use crate::module::{Param, ParamId}; -#[allow(deprecated)] -use burn_tensor::DataSerialize; -use burn_tensor::{backend::Backend, Bool, Element, Int, Tensor}; +use burn_tensor::{backend::Backend, Bool, Int, Tensor}; use hashbrown::HashMap; use serde::{ @@ -143,23 +141,6 @@ where } } -#[allow(deprecated)] -impl Record for DataSerialize -where - E: Element, - B: Backend, -{ - type Item = DataSerialize; - - fn into_item(self) -> Self::Item { - self.convert() - } - - fn from_item(item: Self::Item, _device: &B::Device) -> Self { - item.convert() - } -} - /// (De)serialize parameters into a clean format. #[derive(new, Debug, Clone, Serialize, Deserialize)] pub struct ParamSerde { diff --git a/crates/burn-core/src/record/tensor.rs b/crates/burn-core/src/record/tensor.rs index ab6f448b7e..a07453bcba 100644 --- a/crates/burn-core/src/record/tensor.rs +++ b/crates/burn-core/src/record/tensor.rs @@ -4,20 +4,7 @@ use super::{PrecisionSettings, Record}; use burn_tensor::{backend::Backend, Bool, DType, Element, Int, Tensor, TensorData}; use serde::{Deserialize, Serialize}; -#[cfg(not(feature = "record-backward-compat"))] use alloc::format; -#[cfg(feature = "record-backward-compat")] -use burn_tensor::DataSerialize; - -/// Versioned serde data deserialization to maintain backward compatibility between formats. -#[cfg(feature = "record-backward-compat")] -#[allow(deprecated)] -#[derive(Serialize, Deserialize)] -#[serde(untagged)] -enum TensorDataSerde { - V1(DataSerialize), - V2(TensorData), -} /// Deserialize the value into [`TensorData`]. fn deserialize_data<'de, E, De>(deserializer: De) -> Result @@ -25,31 +12,18 @@ where E: Element + Deserialize<'de>, De: serde::Deserializer<'de>, { - #[cfg(feature = "record-backward-compat")] - { - let data = match TensorDataSerde::::deserialize(deserializer)? { - TensorDataSerde::V1(data) => data.into_tensor_data(), - // NOTE: loading f32 weights with f16 precision will deserialize the f32 weights (bytes) first and then convert to f16 - TensorDataSerde::V2(data) => data.convert::(), - }; - Ok(data) - } - - #[cfg(not(feature = "record-backward-compat"))] - { - let data = TensorData::deserialize(deserializer).map_err(|e| { - serde::de::Error::custom(format!( - "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag. Once you have saved the record in the new format, you can disable the feature flag.\n", - e - )) - })?; - let data = if let DType::QFloat(_) = data.dtype { - data // do not convert quantized tensors - } else { - data.convert::() - }; - Ok(data) - } + let data = TensorData::deserialize(deserializer).map_err(|e| { + serde::de::Error::custom(format!( + "{:?}\nThe internal data format has changed since version 0.14.0. If you are trying to load a record saved in a previous version, use the `record-backward-compat` feature flag with a previous version (<=0.16.0). Once you have saved the record in the new format, you can upgrade back to the current version.\n", + e + )) + })?; + let data = if let DType::QFloat(_) = data.dtype { + data // do not convert quantized tensors + } else { + data.convert::() + }; + Ok(data) } /// This struct implements serde to lazily serialize and deserialize a float tensor diff --git a/crates/burn-tensor/src/lib.rs b/crates/burn-tensor/src/lib.rs index d3cb280e90..0376da57a2 100644 --- a/crates/burn-tensor/src/lib.rs +++ b/crates/burn-tensor/src/lib.rs @@ -1,8 +1,6 @@ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] -// Allow deprecated `Data` and `DataSerialize` -#![allow(deprecated)] //! This library provides multiple tensor implementations hidden behind an easy to use API //! that supports reverse mode automatic differentiation. @@ -59,6 +57,8 @@ mod cube_wgpu { use crate::backend::{DeviceId, DeviceOps}; use cubecl::wgpu::WgpuDevice; + // Allow deprecated `WgpuDevice::BestAvailable` + #[allow(deprecated)] impl DeviceOps for WgpuDevice { fn id(&self) -> DeviceId { match self { diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index 5fa6f765fc..bd144e397f 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -1,7 +1,4 @@ -use core::{ - any::{Any, TypeId}, - f32, -}; +use core::f32; use alloc::boxed::Box; use alloc::format; @@ -14,7 +11,7 @@ use crate::{ quantization::{ Quantization, QuantizationScheme, QuantizationStrategy, QuantizationType, QuantizedBytes, }, - tensor::{bytes::Bytes, Shape}, + tensor::bytes::Bytes, DType, Distribution, Element, ElementConversion, }; @@ -777,396 +774,6 @@ impl core::fmt::Display for TensorData { } } -/// Data structure for serializing and deserializing tensor data. -#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone, new)] -#[deprecated( - since = "0.14.0", - note = "the internal data format has changed, please use `TensorData` instead" -)] -pub struct DataSerialize { - /// The values of the tensor. - pub value: Vec, - /// The shape of the tensor. - pub shape: Vec, -} - -/// Data structure for tensors. -#[derive(new, Debug, Clone, PartialEq, Eq)] -#[deprecated( - since = "0.14.0", - note = "the internal data format has changed, please use `TensorData` instead" -)] -pub struct Data { - /// The values of the tensor. - pub value: Vec, - - /// The shape of the tensor. - pub shape: Shape, -} - -#[allow(deprecated)] -impl Data { - /// Converts the data to a different element type. - pub fn convert(self) -> Data { - let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); - - Data { - value, - shape: self.shape, - } - } - - /// Asserts each value is within a given range. - /// - /// # Arguments - /// - /// * `range` - The range. - /// - /// # Panics - /// - /// If any value is not within the half-open range bounded inclusively below - /// and exclusively above (`start..end`). - pub fn assert_within_range(&self, range: core::ops::Range) { - let start = range.start.elem::(); - let end = range.end.elem::(); - - for elem in self.value.iter() { - let elem = elem.elem::(); - if elem < start || elem >= end { - panic!("Element ({elem:?}) is not within range {range:?}"); - } - } - } -} - -#[allow(deprecated)] -impl DataSerialize { - /// Converts the data to a different element type. - pub fn convert(self) -> DataSerialize { - if TypeId::of::() == TypeId::of::() { - let cast: Box = Box::new(self); - let cast: Box> = cast.downcast().unwrap(); - return *cast; - } - - let value: Vec = self.value.into_iter().map(|a| a.elem()).collect(); - - DataSerialize { - value, - shape: self.shape, - } - } - - /// Converts the data to the new [TensorData] format. - pub fn into_tensor_data(self) -> TensorData { - TensorData::new(self.value, self.shape) - } -} - -#[allow(deprecated)] -impl Data { - /// Populates the data with random values. - pub fn random(shape: Shape, distribution: Distribution, rng: &mut R) -> Self { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(E::random(distribution, rng)); - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data -where - E: Element, -{ - /// Populates the data with zeros. - pub fn zeros>(shape: S) -> Data { - let shape = shape.into(); - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(0.elem()); - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data -where - E: Element, -{ - /// Populates the data with ones. - pub fn ones(shape: Shape) -> Data { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - - for _ in 0..num_elements { - data.push(1.elem()); - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data -where - E: Element, -{ - /// Populates the data with the given value - pub fn full(shape: Shape, fill_value: E) -> Data { - let num_elements = shape.num_elements(); - let mut data = Vec::with_capacity(num_elements); - for _ in 0..num_elements { - data.push(fill_value) - } - - Data::new(data, shape) - } -} - -#[allow(deprecated)] -impl Data { - /// Serializes the data. - /// - /// # Returns - /// - /// The serialized data. - pub fn serialize(&self) -> DataSerialize { - DataSerialize { - value: self.value.clone(), - shape: self.shape.dims.to_vec(), - } - } -} - -#[allow(deprecated)] -impl + Clone + core::fmt::Debug + PartialEq + Element, const D: usize> Data { - /// Asserts the data is approximately equal to another data. - /// - /// # Arguments - /// - /// * `other` - The other data. - /// * `precision` - The precision of the comparison. - /// - /// # Panics - /// - /// Panics if the data is not approximately equal. - #[track_caller] - pub fn assert_approx_eq(&self, other: &Self, precision: usize) { - let tolerance = 0.1.pow(precision as f64); - - self.assert_approx_eq_diff(other, tolerance) - } - - /// Asserts the data is approximately equal to another data. - /// - /// # Arguments - /// - /// * `other` - The other data. - /// * `tolerance` - The tolerance of the comparison. - /// - /// # Panics - /// - /// Panics if the data is not approximately equal. - #[track_caller] - pub fn assert_approx_eq_diff(&self, other: &Self, tolerance: f64) { - let mut message = String::new(); - if self.shape != other.shape { - message += format!( - "\n => Shape is different: {:?} != {:?}", - self.shape.dims, other.shape.dims - ) - .as_str(); - } - - let iter = self.value.clone().into_iter().zip(other.value.clone()); - - let mut num_diff = 0; - let max_num_diff = 5; - - for (i, (a, b)) in iter.enumerate() { - let a: f64 = a.into(); - let b: f64 = b.into(); - - //if they are both nan, then they are equally nan - let both_nan = a.is_nan() && b.is_nan(); - //this works for both infinities - let both_inf = a.is_infinite() && b.is_infinite() && ((a > 0.) == (b > 0.)); - - if both_nan || both_inf { - continue; - } - - let err = (a - b).abs(); - - if E::dtype().is_float() { - if let Some((err, tolerance)) = compare_floats(a, b, E::dtype(), tolerance) { - // Only print the first 5 different values. - if num_diff < max_num_diff { - message += format!( - "\n => Position {i}: {a} != {b} | difference {err} > tolerance \ - {tolerance}" - ) - .as_str(); - } - num_diff += 1; - } - } else if err > tolerance || err.is_nan() { - // Only print the first 5 different values. - if num_diff < max_num_diff { - message += format!( - "\n => Position {i}: {a} != {b} | difference {err} > tolerance \ - {tolerance}" - ) - .as_str(); - } - num_diff += 1; - } - } - - if num_diff >= max_num_diff { - message += format!("\n{} more errors...", num_diff - 5).as_str(); - } - - if !message.is_empty() { - panic!("Tensors are not approx eq:{}", message); - } - } -} - -#[allow(deprecated)] -impl Data { - /// Converts the usize data to a different element type. - pub fn from_usize(self) -> Data { - let value: Vec = self - .value - .into_iter() - .map(|a| num_traits::FromPrimitive::from_usize(a).unwrap()) - .collect(); - - Data { - value, - shape: self.shape, - } - } -} - -#[allow(deprecated)] -impl From<&DataSerialize> for Data { - fn from(data: &DataSerialize) -> Self { - let mut dims = [0; D]; - dims[..D].copy_from_slice(&data.shape[..D]); - Data::new(data.value.clone(), Shape::new(dims)) - } -} - -#[allow(deprecated)] -impl From> for Data { - fn from(data: DataSerialize) -> Self { - let mut dims = [0; D]; - dims[..D].copy_from_slice(&data.shape[..D]); - Data::new(data.value, Shape::new(dims)) - } -} - -#[allow(deprecated)] -impl From<[E; A]> for Data { - fn from(elems: [E; A]) -> Self { - let mut data = Vec::with_capacity(2 * A); - for elem in elems.into_iter() { - data.push(elem); - } - - Data::new(data, Shape::new([A])) - } -} - -#[allow(deprecated)] -impl From<&[E]> for Data { - fn from(elems: &[E]) -> Self { - let mut data = Vec::with_capacity(elems.len()); - for elem in elems.iter() { - data.push(*elem); - } - - Data::new(data, Shape::new([elems.len()])) - } -} - -#[allow(deprecated)] -impl From<[[E; B]; A]> for Data { - fn from(elems: [[E; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B); - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - data.push(elem); - } - } - - Data::new(data, Shape::new([A, B])) - } -} - -#[allow(deprecated)] -impl - From<[[[E; C]; B]; A]> for Data -{ - fn from(elems: [[[E; C]; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B * C); - - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - for elem in elem.into_iter().take(C) { - data.push(elem); - } - } - } - - Data::new(data, Shape::new([A, B, C])) - } -} - -#[allow(deprecated)] -impl< - E: core::fmt::Debug + Copy, - const A: usize, - const B: usize, - const C: usize, - const D: usize, - > From<[[[[E; D]; C]; B]; A]> for Data -{ - fn from(elems: [[[[E; D]; C]; B]; A]) -> Self { - let mut data = Vec::with_capacity(A * B * C * D); - - for elem in elems.into_iter().take(A) { - for elem in elem.into_iter().take(B) { - for elem in elem.into_iter().take(C) { - for elem in elem.into_iter().take(D) { - data.push(elem); - } - } - } - } - - Data::new(data, Shape::new([A, B, C, D])) - } -} - -#[allow(deprecated)] -impl core::fmt::Display for Data { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.write_str(format!("{:?}", &self.value).as_str()) - } -} - fn compare_floats(value: f64, other: f64, ty: DType, tolerance: f64) -> Option<(f64, f64)> { let epsilon_deviations = tolerance / f32::EPSILON as f64; let epsilon = match ty { @@ -1192,9 +799,8 @@ fn compare_floats(value: f64, other: f64, ty: DType, tolerance: f64) -> Option<( } #[cfg(test)] -#[allow(deprecated)] mod tests { - use crate::quantization::AffineQuantization; + use crate::{quantization::AffineQuantization, Shape}; use super::*; use alloc::vec; diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index 0e7ff51e88..d54233f993 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -67,7 +67,6 @@ network = ["burn-core/network"] experimental-named-tensor = ["burn-core/experimental-named-tensor"] # Records -record-backward-compat = ["burn-core/record-backward-compat"] record-item-custom-serde = ["burn-core/record-item-custom-serde"] [dependencies] From 9d9ea8b7013313ceb992d9eb4ef9d3e30c804851 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 16 Jan 2025 10:07:31 -0500 Subject: [PATCH 32/61] Add hardsigmoid formula and fix WGAN doc + default lr (#2706) --- crates/burn-tensor/src/tensor/activation/base.rs | 2 ++ examples/wgan/src/model.rs | 2 +- examples/wgan/src/training.rs | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/crates/burn-tensor/src/tensor/activation/base.rs b/crates/burn-tensor/src/tensor/activation/base.rs index cc5990d375..15fcc7ab50 100644 --- a/crates/burn-tensor/src/tensor/activation/base.rs +++ b/crates/burn-tensor/src/tensor/activation/base.rs @@ -144,6 +144,8 @@ pub fn sigmoid(tensor: Tensor) -> Tensor } /// Applies the hard sigmoid function +/// +/// `hard_sigmoid(x) = max(0, min(1, alpha * x + beta))` pub fn hard_sigmoid( tensor: Tensor, alpha: f64, diff --git a/examples/wgan/src/model.rs b/examples/wgan/src/model.rs index ddb84ff6d3..b9615f5270 100644 --- a/examples/wgan/src/model.rs +++ b/examples/wgan/src/model.rs @@ -96,7 +96,7 @@ pub struct ModelConfig { } impl ModelConfig { - /// "init" is used to create other objects, while "new" is usally used to create itself. + /// Initialize the generator and discriminator models based on the config. pub fn init(&self, device: &B::Device) -> (Generator, Discriminator) { // Construct the initialized generator let layer1 = LayerBlock::new(self.latent_dim, 128, device); diff --git a/examples/wgan/src/training.rs b/examples/wgan/src/training.rs index db1f594b46..25fbef21c1 100644 --- a/examples/wgan/src/training.rs +++ b/examples/wgan/src/training.rs @@ -23,7 +23,7 @@ pub struct TrainingConfig { pub num_workers: usize, #[config(default = 5)] pub seed: u64, - #[config(default = 5e-5)] + #[config(default = 3e-4)] pub lr: f64, /// Number of training steps for discriminator before generator is trained per iteration From 9daf0486ec71c2d11d23bffd7d5d8ebdcfd6de37 Mon Sep 17 00:00:00 2001 From: Nathan Whitehead Date: Thu, 16 Jan 2025 08:08:07 -0700 Subject: [PATCH 33/61] Fix GRU (#2704) * Fix GRU to match pytorch (#2701). Update GRU implementation of new gate to match pytorch implementation. This can change numerical output in some cases. Add GRU unit test with sequence length > 1. Fix GRU input state dimensions and hidden state handling. This is an API change since the dimensions of the optional hidden state input are being corrected to the right sizes. Just updating to the correct dimensions seems like the best thing since the previous implementation was incorrect, not just different than pytorch. * Add GruConfig option reset_after to allow both reset behaviors. * Fix clippy and keep previous test --------- Co-authored-by: Guillaume Lagrange --- crates/burn-core/src/nn/rnn/gru.rs | 198 ++++++++++++++++++++--------- 1 file changed, 141 insertions(+), 57 deletions(-) diff --git a/crates/burn-core/src/nn/rnn/gru.rs b/crates/burn-core/src/nn/rnn/gru.rs index c66ad631b6..e2f8b2425e 100644 --- a/crates/burn-core/src/nn/rnn/gru.rs +++ b/crates/burn-core/src/nn/rnn/gru.rs @@ -20,6 +20,21 @@ pub struct GruConfig { pub d_hidden: usize, /// If a bias should be applied during the Gru transformation. pub bias: bool, + /// If reset gate should be applied after weight multiplication. + /// + /// This configuration option controls how the reset gate is applied to the hidden state. + /// * `true` - (Default) Match the initial arXiv version of the paper [Learning Phrase Representations using RNN Encoder-Decoder for + /// Statistical Machine Translation (v1)](https://arxiv.org/abs/1406.1078v1) and apply the reset gate after multiplication by + /// the weights. This matches the behavior of [PyTorch GRU](https://pytorch.org/docs/stable/generated/torch.nn.GRU.html#torch.nn.GRU). + /// * `false` - Match the most recent revision of [Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine + /// Translation (v3)](https://arxiv.org/abs/1406.1078) and apply the reset gate before the weight multiplication. + /// + /// The differing implementations can give slightly different numerical results and have different efficiencies. For more + /// motivation for why the `true` can be more efficient see [Optimizing RNNs with Differentiable Graphs](https://svail.github.io/diff_graphs). + /// + /// To set this field to `false` use [`with_reset_after`](`GruConfig::with_reset_after`). + #[config(default = "true")] + pub reset_after: bool, /// Gru initializer #[config(default = "Initializer::XavierNormal{gain:1.0}")] pub initializer: Initializer, @@ -41,6 +56,8 @@ pub struct Gru { pub new_gate: GateController, /// The size of the hidden state. pub d_hidden: usize, + /// If reset gate should be applied after weight multiplication. + pub reset_after: bool, } impl ModuleDisplay for Gru { @@ -58,6 +75,7 @@ impl ModuleDisplay for Gru { .add("d_input", &d_input) .add("d_hidden", &self.d_hidden) .add("bias", &bias) + .add("reset_after", &self.reset_after) .optional() } } @@ -94,86 +112,92 @@ impl GruConfig { reset_gate, new_gate, d_hidden: self.d_hidden, + reset_after: self.reset_after, } } } impl Gru { /// Applies the forward pass on the input tensor. This GRU implementation - /// returns a single state tensor with dimensions [batch_size, sequence_length, hidden_size]. + /// returns a state tensor with dimensions `[batch_size, sequence_length, hidden_size]`. /// - /// # Shapes + /// # Parameters /// - batched_input: `[batch_size, sequence_length, input_size]`. - /// - state: An optional tensor representing an initial cell state with the same dimensions - /// as batched_input. If none is provided, one will be generated. - /// - output: `[batch_size, sequence_length, hidden_size]`. + /// - state: An optional tensor representing an initial cell state with dimensions + /// `[batch_size, hidden_size]`. If none is provided, an empty state will be used. + /// + /// # Returns + /// - output: `[batch_size, sequence_length, hidden_size]` pub fn forward( &self, batched_input: Tensor, - state: Option>, + state: Option>, ) -> Tensor { + let device = batched_input.device(); let [batch_size, seq_length, _] = batched_input.shape().dims(); - let mut hidden_state = match state { + let mut batched_hidden_state = + Tensor::empty([batch_size, seq_length, self.d_hidden], &device); + + let mut hidden_t = match state { Some(state) => state, - None => Tensor::zeros( - [batch_size, seq_length, self.d_hidden], - &batched_input.device(), - ), + None => Tensor::zeros([batch_size, self.d_hidden], &device), }; - for (t, (input_t, hidden_t)) in batched_input - .iter_dim(1) - .zip(hidden_state.clone().iter_dim(1)) - .enumerate() - { + for (t, input_t) in batched_input.iter_dim(1).enumerate() { let input_t = input_t.squeeze(1); - let hidden_t = hidden_t.squeeze(1); // u(pdate)g(ate) tensors - let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate); + let biased_ug_input_sum = + self.gate_product(&input_t, &hidden_t, None, &self.update_gate); let update_values = activation::sigmoid(biased_ug_input_sum); // Colloquially referred to as z(t) // r(eset)g(ate) tensors - let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate); + let biased_rg_input_sum = + self.gate_product(&input_t, &hidden_t, None, &self.reset_gate); let reset_values = activation::sigmoid(biased_rg_input_sum); // Colloquially referred to as r(t) - let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate // n(ew)g(ate) tensor - let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate); + let biased_ng_input_sum = if self.reset_after { + self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate) + } else { + let reset_t = hidden_t.clone().mul(reset_values); // Passed as input to new_gate + self.gate_product(&input_t, &reset_t, None, &self.new_gate) + }; let candidate_state = biased_ng_input_sum.tanh(); // Colloquially referred to as g(t) // calculate linear interpolation between previous hidden state and candidate state: // g(t) * (1 - z(t)) + z(t) * hidden_t - let state_vector = candidate_state + hidden_t = candidate_state .clone() .mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1) + update_values.clone().mul(hidden_t); - let current_shape = state_vector.shape().dims; - let unsqueezed_shape = [current_shape[0], 1, current_shape[1]]; - let reshaped_state_vector = state_vector.reshape(unsqueezed_shape); - hidden_state = hidden_state.slice_assign( + let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1); + + batched_hidden_state = batched_hidden_state.slice_assign( [0..batch_size, t..(t + 1), 0..self.d_hidden], - reshaped_state_vector, + unsqueezed_hidden_state, ); } - hidden_state + batched_hidden_state } /// Helper function for performing weighted matrix product for a gate and adds - /// bias, if any. + /// bias, if any, and optionally applies reset to hidden state. /// - /// Mathematically, performs `Wx*X + Wh*H + b`, where: + /// Mathematically, performs `Wx*X + r .* (Wh*H + b)`, where: /// Wx = weight matrix for the connection to input vector X /// Wh = weight matrix for the connection to hidden state H /// X = input vector /// H = hidden state /// b = bias terms + /// r = reset state fn gate_product( &self, input: &Tensor, hidden: &Tensor, + reset: Option<&Tensor>, gate: &GateController, ) -> Tensor { let input_product = input.clone().matmul(gate.input_transform.weight.val()); @@ -190,13 +214,29 @@ impl Gru { .as_ref() .map(|bias_param| bias_param.val()); - match (input_bias, hidden_bias) { - (Some(input_bias), Some(hidden_bias)) => { + match (input_bias, hidden_bias, reset) { + (Some(input_bias), Some(hidden_bias), Some(r)) => { + input_product + + input_bias.unsqueeze() + + r.clone().mul(hidden_product + hidden_bias.unsqueeze()) + } + (Some(input_bias), Some(hidden_bias), None) => { input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze() } - (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product, - (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(), - (None, None) => input_product + hidden_product, + (Some(input_bias), None, Some(r)) => { + input_product + input_bias.unsqueeze() + r.clone().mul(hidden_product) + } + (Some(input_bias), None, None) => { + input_product + input_bias.unsqueeze() + hidden_product + } + (None, Some(hidden_bias), Some(r)) => { + input_product + r.clone().mul(hidden_product + hidden_bias.unsqueeze()) + } + (None, Some(hidden_bias), None) => { + input_product + hidden_product + hidden_bias.unsqueeze() + } + (None, None, Some(r)) => input_product + r.clone().mul(hidden_product), + (None, None, None) => input_product + hidden_product, } } } @@ -207,29 +247,16 @@ mod tests { use crate::tensor::{Distribution, TensorData}; use crate::{module::Param, nn::LinearRecord, TestBackend}; - /// Test forward pass with simple input vector. - /// - /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125 - /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150 - /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699 - /// - /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341 - #[test] - fn tests_forward_single_input_single_feature() { - TestBackend::seed(0); - let config = GruConfig::new(1, 1, false); - let device = Default::default(); - let mut gru = config.init::(&device); - - fn create_gate_controller( + fn init_gru(reset_after: bool, device: &B::Device) -> Gru { + fn create_gate_controller( weights: f32, biases: f32, d_input: usize, d_output: usize, bias: bool, initializer: Initializer, - device: &::Device, - ) -> GateController { + device: &B::Device, + ) -> GateController { let record_1 = LinearRecord { weight: Param::from_data(TensorData::from([[weights]]), device), bias: Some(Param::from_data(TensorData::from([biases]), device)), @@ -248,6 +275,9 @@ mod tests { ) } + let config = GruConfig::new(1, 1, false).with_reset_after(reset_after); + let mut gru = config.init::(device); + gru.update_gate = create_gate_controller( 0.5, 0.0, @@ -255,7 +285,7 @@ mod tests { 1, false, Initializer::XavierNormal { gain: 1.0 }, - &device, + device, ); gru.reset_gate = create_gate_controller( 0.6, @@ -264,7 +294,7 @@ mod tests { 1, false, Initializer::XavierNormal { gain: 1.0 }, - &device, + device, ); gru.new_gate = create_gate_controller( 0.7, @@ -273,18 +303,72 @@ mod tests { 1, false, Initializer::XavierNormal { gain: 1.0 }, - &device, + device, ); + gru + } + + /// Test forward pass with simple input vector. + /// + /// z_t = sigmoid(0.5*0.1 + 0.5*0) = 0.5125 + /// r_t = sigmoid(0.6*0.1 + 0.*0) = 0.5150 + /// g_t = tanh(0.7*0.1 + 0.7*0) = 0.0699 + /// + /// h_t = z_t * h' + (1 - z_t) * g_t = 0.0341 + #[test] + fn tests_forward_single_input_single_feature() { + TestBackend::seed(0); + let device = Default::default(); + let mut gru = init_gru::(false, &device); let input = Tensor::::from_data(TensorData::from([[[0.1]]]), &device); + let expected = TensorData::from([[0.034]]); + // Reset gate applied to hidden state before the matrix multiplication + let state = gru.forward(input.clone(), None); + + let output = state + .select(0, Tensor::arange(0..1, &device)) + .squeeze::<2>(0); + + output.to_data().assert_approx_eq(&expected, 3); + + // Reset gate applied to hidden state after the matrix multiplication + gru.reset_after = true; // override forward behavior + let state = gru.forward(input, None); + + let output = state + .select(0, Tensor::arange(0..1, &device)) + .squeeze::<2>(0); + + output.to_data().assert_approx_eq(&expected, 3); + } + + #[test] + fn tests_forward_seq_len_3() { + TestBackend::seed(0); + let device = Default::default(); + let mut gru = init_gru::(true, &device); + + let input = + Tensor::::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device); + let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]); + + let result = gru.forward(input.clone(), None); + let output = result + .select(0, Tensor::arange(0..1, &device)) + .squeeze::<2>(0); + + output.to_data().assert_approx_eq(&expected, 3); + + // Reset gate applied to hidden state before the matrix multiplication + gru.reset_after = false; // override forward behavior let state = gru.forward(input, None); let output = state .select(0, Tensor::arange(0..1, &device)) .squeeze::<2>(0); - let expected = TensorData::from([[0.034]]); output.to_data().assert_approx_eq(&expected, 3); } @@ -308,7 +392,7 @@ mod tests { assert_eq!( alloc::format!("{}", layer), - "Gru {d_input: 2, d_hidden: 8, bias: true, params: 288}" + "Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}" ); } } From 05925f187fea94fd7cf6ed6bb087a5e8fbb3eea0 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 16 Jan 2025 12:54:04 -0500 Subject: [PATCH 34/61] Clean up train system metrics (#2707) --- crates/burn-train/Cargo.toml | 6 ++-- crates/burn-train/src/metric/mod.rs | 49 +++++++++++------------------ crates/burn/Cargo.toml | 2 +- 3 files changed, 23 insertions(+), 34 deletions(-) diff --git a/crates/burn-train/Cargo.toml b/crates/burn-train/Cargo.toml index b922a1a59e..8c024c88f8 100644 --- a/crates/burn-train/Cargo.toml +++ b/crates/burn-train/Cargo.toml @@ -12,9 +12,9 @@ documentation = "https://docs.rs/burn-train" version.workspace = true [features] -default = ["metrics", "tui"] +default = ["sys-metrics", "tui"] doc = ["default"] -metrics = ["nvml-wrapper", "sysinfo", "systemstat"] +sys-metrics = ["nvml-wrapper", "sysinfo", "systemstat"] tui = ["ratatui"] [dependencies] @@ -28,7 +28,7 @@ tracing-subscriber = { workspace = true } tracing-appender = { workspace = true } tracing-core = { workspace = true } -# Metrics +# System Metrics nvml-wrapper = { workspace = true, optional = true } sysinfo = { workspace = true, optional = true } systemstat = { workspace = true, optional = true } diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs index 191099a383..ac8211e884 100644 --- a/crates/burn-train/src/metric/mod.rs +++ b/crates/burn-train/src/metric/mod.rs @@ -3,65 +3,54 @@ pub mod state; /// Module responsible to save and exposes data collected during training. pub mod store; +// System metrics +#[cfg(feature = "sys-metrics")] +mod cpu_temp; +#[cfg(feature = "sys-metrics")] +mod cpu_use; +#[cfg(feature = "sys-metrics")] +mod cuda; +#[cfg(feature = "sys-metrics")] +mod memory_use; +#[cfg(feature = "sys-metrics")] +pub use cpu_temp::*; +#[cfg(feature = "sys-metrics")] +pub use cpu_use::*; +#[cfg(feature = "sys-metrics")] +pub use cuda::*; +#[cfg(feature = "sys-metrics")] +pub use memory_use::*; + +// Training metrics mod acc; mod auroc; mod base; -#[cfg(feature = "metrics")] mod confusion_stats; -#[cfg(feature = "metrics")] -mod cpu_temp; -#[cfg(feature = "metrics")] -mod cpu_use; -#[cfg(feature = "metrics")] -mod cuda; -#[cfg(feature = "metrics")] mod fbetascore; mod hamming; -#[cfg(feature = "metrics")] mod iteration; mod learning_rate; mod loss; -#[cfg(feature = "metrics")] -mod memory_use; -#[cfg(feature = "metrics")] mod precision; -#[cfg(feature = "metrics")] mod recall; -#[cfg(feature = "metrics")] mod top_k_acc; pub use acc::*; pub use auroc::*; pub use base::*; -#[cfg(feature = "metrics")] pub use confusion_stats::ConfusionStatsInput; -#[cfg(feature = "metrics")] -pub use cpu_temp::*; -#[cfg(feature = "metrics")] -pub use cpu_use::*; -#[cfg(feature = "metrics")] -pub use cuda::*; -#[cfg(feature = "metrics")] pub use fbetascore::*; pub use hamming::*; -#[cfg(feature = "metrics")] pub use iteration::*; pub use learning_rate::*; pub use loss::*; -#[cfg(feature = "metrics")] -pub use memory_use::*; -#[cfg(feature = "metrics")] pub use precision::*; -#[cfg(feature = "metrics")] pub use recall::*; -#[cfg(feature = "metrics")] pub use top_k_acc::*; -#[cfg(feature = "metrics")] pub(crate) mod classification; pub(crate) mod processor; -#[cfg(feature = "metrics")] pub use crate::metric::classification::ClassReduction; // Expose `ItemLazy` so it can be implemented for custom types pub use processor::ItemLazy; diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index d54233f993..cd13682a4b 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -24,7 +24,7 @@ train = ["burn-train", "autodiff", "dataset"] tui = ["burn-train?/tui"] ## Includes system info metrics (CPU/GPU usage, etc) -metrics = ["burn-train?/metrics"] +metrics = ["burn-train?/sys-metrics"] # Datasets dataset = ["burn-core/dataset"] From 6750fd689059e628ed763e389c6bd5d8350ce77f Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 16 Jan 2025 14:33:28 -0600 Subject: [PATCH 35/61] Code generation bug fix for ONNX import (#2708) --- crates/burn-import/src/burn/codegen.rs | 5 +++-- crates/burn-import/src/burn/node/resize.rs | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/crates/burn-import/src/burn/codegen.rs b/crates/burn-import/src/burn/codegen.rs index 798636c323..7f511dafd4 100644 --- a/crates/burn-import/src/burn/codegen.rs +++ b/crates/burn-import/src/burn/codegen.rs @@ -5,8 +5,9 @@ use burn::nn::PaddingConfig1d; use burn::nn::PaddingConfig2d; use burn::nn::PaddingConfig3d; -fn convert_primitive(primitive: T) -> TokenStream { - let value = primitive.to_string(); +fn convert_primitive(primitive: T) -> TokenStream { + let value = format!("{:?}", primitive); + value.parse().unwrap() } diff --git a/crates/burn-import/src/burn/node/resize.rs b/crates/burn-import/src/burn/node/resize.rs index 59afcfb607..606f3ef38d 100644 --- a/crates/burn-import/src/burn/node/resize.rs +++ b/crates/burn-import/src/burn/node/resize.rs @@ -228,7 +228,7 @@ mod tests { TensorType::new_float("tensor1", 3), TensorType::new_float("tensor2", 3), "cubic".to_string(), - vec![], + vec![2.0], vec![20], )); @@ -253,7 +253,7 @@ mod tests { pub fn new(device: &B::Device) -> Self { let resize = Interpolate1dConfig::new() .with_output_size(Some(20)) - .with_scale_factor(None) + .with_scale_factor(Some(2.0)) .with_mode(InterpolateMode::Cubic) .init(); Self { From b4d9d54990498cf49dab2f7d272d485a50a7cfe2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 20 Jan 2025 08:17:46 -0500 Subject: [PATCH 36/61] Combined PRs (#2716) * Bump log from 0.4.22 to 0.4.25 Bumps [log](https://github.com/rust-lang/log) from 0.4.22 to 0.4.25. - [Release notes](https://github.com/rust-lang/log/releases) - [Changelog](https://github.com/rust-lang/log/blob/master/CHANGELOG.md) - [Commits](https://github.com/rust-lang/log/compare/0.4.22...0.4.25) --- updated-dependencies: - dependency-name: log dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump uuid from 1.11.1 to 1.12.0 Bumps [uuid](https://github.com/uuid-rs/uuid) from 1.11.1 to 1.12.0. - [Release notes](https://github.com/uuid-rs/uuid/releases) - [Commits](https://github.com/uuid-rs/uuid/compare/1.11.1...1.12.0) --- updated-dependencies: - dependency-name: uuid dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Bump data-encoding from 2.6.0 to 2.7.0 Bumps [data-encoding](https://github.com/ia0/data-encoding) from 2.6.0 to 2.7.0. - [Commits](https://github.com/ia0/data-encoding/commits) --- updated-dependencies: - dependency-name: data-encoding dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Cargo.lock | 12 ++++++------ Cargo.toml | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1af3585919..70011e1e80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1936,9 +1936,9 @@ dependencies = [ [[package]] name = "data-encoding" -version = "2.6.0" +version = "2.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" +checksum = "0e60eed09d8c01d3cee5b7d30acb059b76614c918fa0f992e0dd6eeb10daad6f" [[package]] name = "deflate64" @@ -3723,9 +3723,9 @@ checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "loop9" @@ -7777,9 +7777,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.11.1" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b913a3b5fe84142e269d63cc62b64319ccaf89b748fc31fe025177f767a756c4" +checksum = "744018581f9a3454a9e15beb8a33b017183f1e7c0cd170232a2d1453b23a51c4" dependencies = [ "getrandom", "rand", diff --git a/Cargo.toml b/Cargo.toml index 870902a9ac..5a38273c20 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,7 +34,7 @@ colored = "2.1.0" console_error_panic_hook = "0.1.7" csv = "1.3.1" dashmap = "6.1.0" -data-encoding = { version = "2.6.0", default-features = false, features = [ +data-encoding = { version = "2.7.0", default-features = false, features = [ "alloc", ] } dirs = "5.0.1" @@ -50,7 +50,7 @@ image = "0.25.5" indicatif = "0.17.9" js-sys = "0.3.72" libm = "0.2.11" -log = { default-features = false, version = "0.4.22" } +log = { default-features = false, version = "0.4.25" } md5 = "0.7.0" paste = "1" percent-encoding = "2.3.1" @@ -141,7 +141,7 @@ serde = { version = "1.0.217", default-features = false, features = [ "alloc", ] } # alloc is for no_std, derive is needed serde_json = { version = "1.0.135", default-features = false } -uuid = { version = "1.11.0", default-features = false } +uuid = { version = "1.12.0", default-features = false } libc = "0.2.169" nvml-wrapper = "0.10.0" From 949e77fec6b0282dca91d7cb6c25a7c1580b0fd8 Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Mon, 20 Jan 2025 15:14:24 +0100 Subject: [PATCH 37/61] Migrate to type magic autotune (#2710) --- Cargo.lock | 168 +++++++++--------- Cargo.toml | 8 +- .../src/kernel/conv/conv2d/gemm/algorithm.rs | 31 +--- .../src/kernel/conv/conv2d/gemm/base.rs | 8 +- .../conv/conv2d/gemm/homogeneous/base.rs | 9 +- .../src/kernel/conv/conv2d/gemm/launch.rs | 60 +------ .../src/kernel/conv/conv2d/implicit_gemm.rs | 61 ++++--- .../src/kernel/conv/conv2d/tune/conv2d.rs | 160 +++-------------- .../conv/conv2d/tune/conv_transpose2d.rs | 55 +++--- crates/burn-jit/src/kernel/conv/error.rs | 20 ++- .../burn-jit/src/kernel/matmul/tune/base.rs | 52 ++---- crates/burn-jit/src/kernel/reduce/tune.rs | 65 +++---- crates/burn-ndarray/src/ops/deform_conv.rs | 4 +- 13 files changed, 249 insertions(+), 452 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 70011e1e80..eef5ff3962 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -511,9 +511,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.7.0" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1be3f42a67d6d345ecd59f675f3f012d6974981560836e938c22b424b85ce1be" +checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" dependencies = [ "serde", ] @@ -1522,7 +1522,7 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "829d955a0bb380ef178a640b91779e3987da38c9aea133b20614cfed8cdea9c6" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "crossterm_winapi", "mio", "parking_lot 0.12.3", @@ -1580,9 +1580,8 @@ dependencies = [ [[package]] name = "cubecl" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aecf090429a4172d94c819e2977f440d7f5846c09f31d36937de309f986c878e" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1596,9 +1595,8 @@ dependencies = [ [[package]] name = "cubecl-common" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10239ee4800968f367fbc4828250d38acf5d14fa53e8d0370d5f474387591322" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ "derive-new 0.6.0", "embassy-futures", @@ -1614,10 +1612,10 @@ dependencies = [ [[package]] name = "cubecl-core" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d249976814abe45ee5d04bdfd5e2359558b409affdc03914625bea778dab5ade" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ + "bitflags 2.8.0", "bytemuck", "cubecl-common", "cubecl-macros", @@ -1634,9 +1632,8 @@ dependencies = [ [[package]] name = "cubecl-cpp" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8463629d0bdf4d09d47150bce35132236c1a597f65eba213b45073406048a596" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ "bytemuck", "cubecl-common", @@ -1649,9 +1646,8 @@ dependencies = [ [[package]] name = "cubecl-cuda" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12c0b49113ba986e984538cf54c3d7390c0af934a80f083b6c99cad737d22c59" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ "bytemuck", "cubecl-common", @@ -1666,9 +1662,8 @@ dependencies = [ [[package]] name = "cubecl-hip" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "976e150315f9d7d6bb84c51cb13c19221ea5d185bb6d61347a3c392dd29720de" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ "bytemuck", "cubecl-common", @@ -1693,9 +1688,8 @@ dependencies = [ [[package]] name = "cubecl-linalg" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "640c379e225fecb1336f963affd3b8f1ff66b9320a972dfe92d8158dca8b6382" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ "bytemuck", "cubecl-core", @@ -1706,9 +1700,8 @@ dependencies = [ [[package]] name = "cubecl-macros" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f05d95f3be436814f909a3ac97209159f63076d3d2b254914bc02db2ac7faefb" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ "cubecl-common", "darling", @@ -1722,9 +1715,8 @@ dependencies = [ [[package]] name = "cubecl-opt" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42c0593efee028e010a1a7e8646a8a405f6a653fe194bc8c5b46189245ecaeec" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ "cubecl-common", "cubecl-core", @@ -1739,9 +1731,8 @@ dependencies = [ [[package]] name = "cubecl-reduce" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0912890b52cc6f9636e0070320ff93dec27af15d57453789081b9a8bdb49786d" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1750,9 +1741,8 @@ dependencies = [ [[package]] name = "cubecl-runtime" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75e84f4ae5a096e4d0c410db01d18b673d6efcd6eea1724d1a001ab60484df87" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ "async-channel", "async-lock", @@ -1767,15 +1757,16 @@ dependencies = [ "serde", "serde_json", "spin", + "variadics_please", "wasm-bindgen-futures", ] [[package]] name = "cubecl-spirv" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5d88e7d35a58a40991e42c4492739d4b89b6046ac75126cb5f10b190032012c" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ + "bitflags 2.8.0", "cubecl-common", "cubecl-core", "cubecl-opt", @@ -1787,9 +1778,8 @@ dependencies = [ [[package]] name = "cubecl-wgpu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cf8105d01ef4cd103d4e31bee9ae583fabc807253234923fb08218b28db7d15" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" dependencies = [ "ash", "async-channel", @@ -2838,7 +2828,7 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0bf760ebf69878d9fd8f110c89703d90ce35095324d1f1edcb595c63945ee757" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "ignore", "walkdir", ] @@ -2870,7 +2860,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbcd2dba93594b227a1f57ee09b8b9da8892c34d55aa332e034a228d0fe6a171" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "gpu-alloc-types", ] @@ -2880,7 +2870,7 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98ff03b468aa837d70984d55f5d3f846f6ec31fe34bbb97c4f85219caeee1ca4" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", ] [[package]] @@ -2901,7 +2891,7 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dcf29e94d6d243368b7a56caa16bc213e4f9f8ed38c4d9557069527b5d5281ca" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "gpu-descriptor-types", "hashbrown 0.15.2", ] @@ -2912,7 +2902,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdf242682df893b86f33a73828fb09ca4b2d3bb6cc95249707fc684d27484b91" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", ] [[package]] @@ -3671,7 +3661,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "libc", "redox_syscall 0.5.8", ] @@ -3893,7 +3883,7 @@ version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "block", "core-graphics-types", "foreign-types 0.5.0", @@ -3908,7 +3898,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "block", "core-graphics-types", "foreign-types 0.5.0", @@ -4031,14 +4021,14 @@ checksum = "364f94bc34f61332abebe8cad6f6cd82a5b65cff22c828d05d0968911462ca4f" dependencies = [ "arrayvec", "bit-set", - "bitflags 2.7.0", + "bitflags 2.8.0", "cfg_aliases 0.1.1", "codespan-reporting", "hexf-parse", "indexmap", "log", "rustc-hash 1.1.0", - "spirv", + "spirv 0.3.0+sdk-1.3.268.0", "termcolor", "thiserror 1.0.69", "unicode-xid", @@ -4318,7 +4308,7 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c9bff0aa1d48904a1385ea2a8b97576fbdcbc9a3cfccd0d31fe978e1c4038c5" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "libloading", "nvml-wrapper-sys", "static_assertions", @@ -4367,7 +4357,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "block2", "libc", "objc2", @@ -4383,7 +4373,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "block2", "objc2", "objc2-foundation", @@ -4413,7 +4403,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "block2", "libc", "objc2", @@ -4425,7 +4415,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "block2", "objc2", "objc2-foundation", @@ -4437,7 +4427,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "block2", "objc2", "objc2-foundation", @@ -4590,7 +4580,7 @@ version = "0.10.68" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "cfg-if", "foreign-types 0.3.2", "libc", @@ -4938,7 +4928,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd5df9b55e614088a3270b06f8649dce76537c268d6b1ca4d9c37008b2be5949" dependencies = [ "ahash", - "bitflags 2.7.0", + "bitflags 2.8.0", "bytemuck", "chrono", "chrono-tz", @@ -4987,7 +4977,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea1b431ed816cba1120cff200f06b962748001bbb2e615ce53cfbbdf701cc136" dependencies = [ "ahash", - "bitflags 2.7.0", + "bitflags 2.8.0", "hashbrown 0.15.2", "num-traits", "once_cell", @@ -5079,7 +5069,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a8ca74f42e7b47cad241b36b98d991cc7fbb51b8d0695a055eb937588d1f310" dependencies = [ "ahash", - "bitflags 2.7.0", + "bitflags 2.8.0", "memchr", "once_cell", "polars-arrow", @@ -5224,7 +5214,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "23de436f33f4d1134c58f24e7059a221b957ec20730807e0ef0c80c8e4b3d06a" dependencies = [ "ahash", - "bitflags 2.7.0", + "bitflags 2.8.0", "bytemuck", "bytes", "chrono", @@ -5788,7 +5778,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabd94c2f37801c20583fc49dd5cd6b0ba68c716787c2dd6ed18571e1e63117b" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "cassowary", "compact_str", "crossterm", @@ -5869,7 +5859,7 @@ version = "11.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", ] [[package]] @@ -5956,7 +5946,7 @@ version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", ] [[package]] @@ -6146,12 +6136,11 @@ dependencies = [ [[package]] name = "rspirv" -version = "0.12.0+sdk-1.3.268.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d" +version = "0.12.0+sdk-1.3.296.0" +source = "git+https://github.com/gfx-rs/rspirv.git?rev=e19c11fdb30295127cff1d018189bd436892415e#e19c11fdb30295127cff1d018189bd436892415e" dependencies = [ "rustc-hash 1.1.0", - "spirv", + "spirv 0.3.0+sdk-1.3.296.0", ] [[package]] @@ -6190,7 +6179,7 @@ version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "fallible-iterator", "fallible-streaming-iterator", "hashlink", @@ -6241,7 +6230,7 @@ version = "0.38.43" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "errno", "libc", "linux-raw-sys", @@ -6422,7 +6411,7 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "core-foundation 0.9.4", "core-foundation-sys", "libc", @@ -6435,7 +6424,7 @@ version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "core-foundation 0.10.0", "core-foundation-sys", "libc", @@ -6790,7 +6779,15 @@ version = "0.3.0+sdk-1.3.268.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", +] + +[[package]] +name = "spirv" +version = "0.3.0+sdk-1.3.296.0" +source = "git+https://github.com/gfx-rs/rspirv.git?rev=e19c11fdb30295127cff1d018189bd436892415e#e19c11fdb30295127cff1d018189bd436892415e" +dependencies = [ + "bitflags 2.8.0", ] [[package]] @@ -6951,7 +6948,7 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "byteorder", "enum-as-inner", "libc", @@ -6993,7 +6990,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "core-foundation 0.9.4", "system-configuration-sys", ] @@ -7814,6 +7811,17 @@ dependencies = [ "ryu", ] +[[package]] +name = "variadics_please" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41b6d82be61465f97d42bd1d15bf20f3b0a3a0905018f38f9d6f6962055b0b5c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "vcpkg" version = "0.2.15" @@ -8043,7 +8051,7 @@ checksum = "d63c3c478de8e7e01786479919c8769f62a22eec16788d8c2ac77ce2c132778a" dependencies = [ "arrayvec", "bit-vec", - "bitflags 2.7.0", + "bitflags 2.8.0", "cfg_aliases 0.1.1", "document-features", "indexmap", @@ -8070,7 +8078,7 @@ dependencies = [ "arrayvec", "ash", "bit-set", - "bitflags 2.7.0", + "bitflags 2.8.0", "block", "bytemuck", "cfg_aliases 0.1.1", @@ -8111,7 +8119,7 @@ version = "23.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "610f6ff27778148c31093f3b03abc4840f9636d58d597ca2f5977433acfe0068" dependencies = [ - "bitflags 2.7.0", + "bitflags 2.8.0", "js-sys", "web-sys", ] diff --git a/Cargo.toml b/Cargo.toml index 5a38273c20..846d0f565c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,14 +153,14 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -# cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } -# cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "3c083cb136214404d8eb594258534d10a118a077" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } ### For the release. ### -cubecl = { version = "0.4.0", default-features = false } -cubecl-common = { version = "0.4.0", default-features = false } +# cubecl = { version = "0.4.0", default-features = false } +# cubecl-common = { version = "0.4.0", default-features = false } ### For xtask crate ### tracel-xtask = { version = "=1.1.8" } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs index 374a03be29..ffff3675bd 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs @@ -5,7 +5,7 @@ use cubecl::{ tile::{accelerated::Accelerated, TileMatmulFamily}, InvalidConfigError, }, - kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, + kernels::matmul::AdvancedConfig, }, prelude::*, }; @@ -13,7 +13,6 @@ use cubecl::{ use super::{ base::{ConvolutionConfigFactory, ConvolutionFamily, ConvolutionProblem}, homogeneous::base::ImplicitGemmConvolutionFamily, - precision::ConvPrecision, selection::ConvSelection, }; @@ -47,34 +46,6 @@ pub trait Algorithm { Self::GlobalConvolution::check_config(&config)?; Ok(config) } - - /// Check availability of the matmul algorithm - fn check_availability( - client: &ComputeClient, - config: &::Config, - ) -> Result<(), MatmulAvailabilityError> { - Self::GlobalConvolution::check_availability::(client, config) - } - - /// Determine whether the given convolution problem is valid to launch (within hardware limits) - fn can_launch( - client: &ComputeClient, - problem: &ConvolutionProblem, - config: &::Config, - selection: &Self::Selection, - ) -> bool { - if problem.options.groups > 1 || Self::check_availability::(client, config).is_err() - { - return false; - } - - let cube_count = Self::cube_count(selection, problem); - let (max_x, max_y, max_z) = R::max_cube_count(); - match cube_count { - CubeCount::Static(x, y, z) => x <= max_x && y <= max_y && z <= max_z, - _ => true, - } - } } /// Cmma convolution diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs index e69b33b40f..a78082950a 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs @@ -6,7 +6,7 @@ use cubecl::linalg::{ stage::{StageMatmul, StageMatmulFamily}, InvalidConfigError, MatmulProblem, MatrixLayout, }, - kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, + kernels::matmul::AdvancedConfig, }, tensor::{ReadWrite, VirtualTensor}, }; @@ -91,12 +91,6 @@ pub trait ConvolutionConfigFactory: Send + Sync + 'static { /// Asserts that the configuration for this matmul will lead to a valid computation fn check_config(config: &Self::Config) -> Result<(), InvalidConfigError>; - /// Checks if the client can handle the features used in this computation - fn check_availability( - client: &ComputeClient, - config: &Self::Config, - ) -> Result<(), MatmulAvailabilityError>; - fn make_config( input: Self::Input, problem: &ConvolutionProblem, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs index 2f32c8471e..988cd0ead6 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/homogeneous/base.rs @@ -16,7 +16,7 @@ use cubecl::{ }, Ident, InvalidConfigError, MatrixLayout, StageDim, }, - kernels::{matmul::AdvancedConfig, MatmulAvailabilityError}, + kernels::matmul::AdvancedConfig, }, tensor::{ReadWrite, VirtualTensor}, }, @@ -194,13 +194,6 @@ where SMM::check_config(&config.to_smm_config()) } - fn check_availability( - client: &ComputeClient, - config: &Self::Config, - ) -> Result<(), MatmulAvailabilityError> { - SMM::check_availability::(client, &config.to_smm_config()) - } - fn make_config( input: Self::Input, problem: &ConvolutionProblem, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs index 032368b08a..f36f89bdf5 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs @@ -7,7 +7,7 @@ use burn_tensor::{ use cubecl::{ flex32, ir::{Elem, FloatKind}, - linalg::matmul::{self, components::MatrixLayout}, + linalg::matmul::{self}, tensor_line_size, tf32, Feature, }; use half::{bf16, f16}; @@ -23,7 +23,7 @@ use crate::{ algorithm::{Algorithm, ImplicitCmmaConv}, base::{ConvolutionLaunch, ConvolutionProblem}, }, - nchw_to_nhwc, Conv2dAutotuneKey, ConvLaunchError, + nchw_to_nhwc, ConvLaunchError, }, into_contiguous, }, @@ -108,6 +108,10 @@ pub fn conv2d_gemm_with_algo< where SP::EG: JitElement, { + if options.groups != 1 { + return Err(ConvLaunchError::Groups(options.groups)); + } + let [batch_size, in_channels, height, width] = input.shape.dims(); let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims(); @@ -226,58 +230,6 @@ where Ok(permute(out, &[0, 3, 1, 2])) } -pub fn problem_from_key( - key: &Conv2dAutotuneKey, - out_h: usize, - out_w: usize, -) -> ConvolutionProblem { - let in_stride_2 = key.in_channels; - let in_stride_1 = key.width * in_stride_2; - let in_stride_0 = key.height * in_stride_1; - - let m = key.batch_size * out_h * out_w; - let n = key.out_channels; - let k = key.kernel_size[0] * key.kernel_size[1] * key.in_channels; - - let options = ConvOptions { - stride: key.stride, - padding: key.padding, - dilation: key.dilation, - groups: key.groups, - }; - - // Target 128 bit accesses - let available_vectorizations = R::supported_line_sizes() - .iter() - .copied() - .filter(|it| *it as usize * size_of::() <= 16) - .collect::>(); - let lhs_line_size = tensor_line_size( - &available_vectorizations, - &[key.batch_size, key.height, key.width, key.in_channels], - &[in_stride_0, in_stride_1, in_stride_2, 1], - 3, - ); - let rhs_line_size = tensor_line_size(&available_vectorizations, &[k, n], &[n, 1], 1); - let out_line_size = tensor_line_size(&available_vectorizations, &[m, n], &[n, 1], 1); - - ConvolutionProblem { - m, - n, - k, - lhs_layout: MatrixLayout::RowMajor, - rhs_layout: MatrixLayout::RowMajor, - lhs_line_size, - rhs_line_size, - out_line_size, - kernel_size: (key.kernel_size[0] as u32, key.kernel_size[1] as u32), - options, - out_shape_y: out_h, - out_shape_x: out_w, - has_bias: key.has_bias, - } -} - pub(crate) fn has_tf32(c: &JitTensor) -> bool { c.client .properties() diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs index 2e8e147170..9c8edf0103 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -6,6 +6,7 @@ use cmma::{Matrix, MatrixIdent, MatrixLayout}; use cubecl::{ cube, ir::{Elem, FloatKind}, + linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}, prelude::*, Compiler, CubeCount, CubeDim, Feature, }; @@ -66,7 +67,7 @@ pub fn conv2d_implicit_gemm( let padded_batch_size = padded_batch_size(batch_size, out_h, out_w); - if !can_do_implicit_gemm::( + check_availability::( batch_size, in_channels, out_channels, @@ -75,15 +76,7 @@ pub fn conv2d_implicit_gemm( out_h, out_w, &input.client, - ) { - panic!( - "Requirements for implicit GEMM not met: -- CMMA must be available -- `groups` must be 1 -- subcube size must be non-variable (might not hold on Intel) - " - ); - } + )?; // If input is contiguous NCHW, use custom transpose kernel let input = match input.is_contiguous() { @@ -643,7 +636,7 @@ fn load_weight_tile( } #[allow(clippy::too_many_arguments)] -pub(crate) fn can_do_implicit_gemm( +pub(crate) fn check_availability( batch_size: usize, in_channels: usize, out_channels: usize, @@ -652,7 +645,7 @@ pub(crate) fn can_do_implicit_gemm( out_h: usize, out_w: usize, client: &ComputeClient, -) -> bool { +) -> Result<(), ConvLaunchError> { let cmma_k = match ( E::as_elem_native_unchecked(), client @@ -672,19 +665,43 @@ pub(crate) fn can_do_implicit_gemm( let gemm_n = out_channels; let gemm_k = in_channels * kernel_h * kernel_w; - let size = find_cmma_size::(client, gemm_m as u32, gemm_k as u32, gemm_n as u32); - - if let Some((cmma_m, cmma_k, cmma_n)) = size { - let warps_per_cube = 8; + let (cmma_m, cmma_n, cmma_k) = + find_cmma_size::(client, gemm_m as u32, gemm_k as u32, gemm_n as u32).ok_or_else( + || { + ConvLaunchError::Matmul(MatmulLaunchError::Unavailable( + MatmulAvailabilityError::CmmaInstructionUnavailable { + input: E::as_elem_native_unchecked(), + output: E::as_elem_native_unchecked(), + m: 16, + n: 16, + k: cmma_k as u32, + }, + )) + }, + )?; + + let warps_per_cube = 8; + + let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::(); + if ::max_shared_memory_size() < smem_size { + return Err(ConvLaunchError::Matmul(MatmulLaunchError::InvalidConfig( + Box::new("Not enough shared memory"), + ))); + } - let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::(); - let topology = client.properties().hardware_properties(); - let not_intel = topology.plane_size_min >= 32; + let topology = client.properties().hardware_properties(); + if topology.plane_size_min < 32 { + return Err(ConvLaunchError::Matmul(MatmulLaunchError::Unavailable( + MatmulAvailabilityError::PlaneDimUnsupported { + plane_dim: topology.plane_size_min, + }, + ))); + } - ::max_shared_memory_size() >= smem_size && groups == 1 && not_intel - } else { - false + if groups != 1 { + return Err(ConvLaunchError::Groups(groups)); } + Ok(()) } fn padded_k( diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs index c6eb31ea9c..36d12e2255 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs @@ -1,26 +1,12 @@ -use burn_tensor::{ - ops::{conv::calculate_conv_output_size, ConvOptions}, - ElementConversion, Shape, -}; -use cubecl::{ - ir::{Elem, FloatKind}, - tf32, tune, - tune::{local_tuner, tune_with, LocalTuner}, -}; -use half::f16; +use burn_tensor::{ops::ConvOptions, ElementConversion, Shape}; +use cubecl::tune::{local_tuner, LocalTuner, TunableSet}; use super::Conv2dAutotuneKey; use crate::{ kernel::{ conv::{ - algorithm::{Algorithm, ImplicitCmmaConv}, - batches_per_run, can_do_implicit_gemm, - conv2d::gemm::base::ConvolutionProblem, conv2d_direct, conv2d_gemm_cmma_balanced, conv2d_gemm_cmma_large_m, conv2d_im2col, - conv2d_implicit_gemm, has_tf32, - precision::ConvPrecision, - problem_from_key, - selection::{Balanced, ConvSelector, Large}, + conv2d_implicit_gemm, }, prng::random_uniform, }, @@ -39,31 +25,33 @@ pub fn conv2d_autotune( static TUNER: LocalTuner = local_tuner!(); + let tunables = TunableSet::new(create_key::, create_conv2d_input::) + .with_tunable(conv2d_direct::) + .with_tunable(conv2d_im2col::) + .with_tunable(conv2d_implicit_gemm::) + .with_tunable(conv2d_gemm_cmma_large_m::) + .with_tunable(conv2d_gemm_cmma_balanced::); + TUNER.execute( &JitTuneId::new::(&input.device), &client, - Box::new(Conv2dOperations::::new(input, weights, bias, options)), + &tunables, + (input, weights, bias, options), ) } -#[tune( - operations( - conv2d_direct, - conv2d_im2col, - conv2d_implicit_gemm, - conv2d_gemm_cmma_large_m, - conv2d_gemm_cmma_balanced - ), - create_key = create_key::, - should_run = should_run -)] -pub fn conv2d_operations( - key: JitAutotuneKey, - input: JitTensor, - weights: JitTensor, - bias: Option>, - options: ConvOptions<2>, -) -> JitTensor { +pub fn create_conv2d_input( + key: &JitAutotuneKey, + input: &JitTensor, + _weights: &JitTensor, + _bias: &Option>, + options: &ConvOptions<2>, +) -> ( + JitTensor, + JitTensor, + Option>, + ConvOptions<2>, +) { let device = &input.device; let key = match key { JitAutotuneKey::Conv2d(key) => key, @@ -82,105 +70,7 @@ pub fn conv2d_operations( .has_bias .then(|| random_uniform(bias_shape, device, random_bounds.0, random_bounds.1)); - tune_with!(input, weights, bias, options) -} - -macro_rules! check_algo { - ($algo:tt, $float:ty, $input:expr, $problem:expr) => { - match (<$float>::as_elem_native_unchecked(), has_tf32(&$input)) { - (Elem::Float(FloatKind::F32), true) => { - can_launch::<$algo, R, ($float, tf32, f32)>($input, $problem) - } - (Elem::Float(FloatKind::Flex32), _) => { - can_launch::<$algo, R, ($float, f16, f32)>($input, $problem) - } - _ => can_launch::<$algo, R, ($float, $float, f32)>($input, $problem), - } - }; -} - -fn should_run( - op: &Conv2dOperations, - key: &JitAutotuneKey, - index: usize, -) -> bool { - let key = match key { - JitAutotuneKey::Conv2d(key) => key, - _ => unreachable!(), - }; - - let out_h = calculate_conv_output_size( - key.kernel_size[0], - key.stride[0], - key.padding[0], - key.dilation[0], - key.height, - ); - let out_w = calculate_conv_output_size( - key.kernel_size[1], - key.stride[1], - key.padding[1], - key.dilation[1], - key.width, - ); - - let conv_problem = problem_from_key::(key, out_h, out_w); - - match index { - // im2col - 1 => batches_per_run(key.batch_size, out_h, out_w).is_some(), - // Implicit gemm. - 2 => can_do_implicit_gemm::( - key.batch_size, - key.in_channels, - key.out_channels, - key.kernel_size, - op.options.groups, - out_h, - out_w, - &op.input.client, - ), - // GEMM large m - 3 => check_algo!(Large, F, &op.input, &conv_problem), - // GEMM balanced - 4 => check_algo!(Balanced, F, &op.input, &conv_problem), - _ => true, - } -} - -fn can_launch, R: JitRuntime, CS: ConvPrecision>( - input: &JitTensor, - conv_problem: &ConvolutionProblem, -) -> bool { - let plane_dim = match input - .client - .properties() - .hardware_properties() - .defined_plane_size() - { - Some(val) => val, - None => return false, - }; - - let (selection, config_input) = S::select_kernel::(plane_dim); - let cube_dim = ImplicitCmmaConv::cube_dim(&selection); - let cube_count = ImplicitCmmaConv::cube_count(&selection, conv_problem); - let advanced_config = Default::default(); - - let config = ImplicitCmmaConv::make_config( - config_input, - conv_problem, - &cube_dim, - &cube_count, - &advanced_config, - ); - - match config { - Ok(config) => { - ImplicitCmmaConv::can_launch::(&input.client, conv_problem, &config, &selection) - } - Err(_) => false, - } + (input, weights, bias, options.clone()) } fn create_key( diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs index c2d546151a..df0159b75d 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs @@ -1,12 +1,9 @@ use burn_tensor::{ops::ConvTransposeOptions, ElementConversion, Shape}; -use cubecl::{ - tune, - tune::{local_tuner, tune_with, LocalTuner}, -}; +use cubecl::tune::{local_tuner, LocalTuner, TunableSet}; use crate::{ kernel::{ - conv::{batches_per_run, conv_transpose2d_col2im, conv_transpose2d_direct}, + conv::{conv_transpose2d_col2im, conv_transpose2d_direct}, prng::random_uniform, }, tensor::JitTensor, @@ -26,23 +23,30 @@ pub fn conv_transpose2d_autotune( static TUNER: LocalTuner = local_tuner!(); + let tune_set = TunableSet::new(create_key::, create_transpose2d_input::) + .with_tunable(conv_transpose2d_direct::) + .with_tunable(conv_transpose2d_col2im::); + TUNER.execute( &JitTuneId::new::(&input.device), &client, - Box::new(ConvTranspose2dOperations::::new( - input, weights, bias, options, - )), + &tune_set, + (input, weights, bias, options), ) } -#[tune(operations(conv_transpose2d_direct, conv_transpose2d_col2im), create_key = create_key::, should_run = should_run)] -pub fn conv_transpose2d_operations( - key: JitAutotuneKey, - input: JitTensor, - weights: JitTensor, - bias: Option>, - options: ConvTransposeOptions<2>, -) -> JitTensor { +pub fn create_transpose2d_input( + key: &JitAutotuneKey, + input: &JitTensor, + _weights: &JitTensor, + _bias: &Option>, + options: &ConvTransposeOptions<2>, +) -> ( + JitTensor, + JitTensor, + Option>, + ConvTransposeOptions<2>, +) { let key = match key { JitAutotuneKey::ConvTranspose2d(key) => key, _ => unreachable!(), @@ -60,7 +64,7 @@ pub fn conv_transpose2d_operations( let bias = key .has_bias .then(|| random_uniform(bias_shape, device, random_bounds.0, random_bounds.1)); - tune_with!(input, weights, bias, options) + (input, weights, bias, options.clone()) } fn create_key( @@ -94,20 +98,3 @@ fn create_key( E::dtype(), )) } - -fn should_run( - _op: &ConvTranspose2dOperations, - key: &JitAutotuneKey, - index: usize, -) -> bool { - let key = match key { - JitAutotuneKey::ConvTranspose2d(key) => key, - _ => unreachable!(), - }; - - match index { - // im2col - 1 => batches_per_run(key.batch_size, key.height, key.width).is_some(), - _ => true, - } -} diff --git a/crates/burn-jit/src/kernel/conv/error.rs b/crates/burn-jit/src/kernel/conv/error.rs index 2f15bc9886..99c91fc751 100644 --- a/crates/burn-jit/src/kernel/conv/error.rs +++ b/crates/burn-jit/src/kernel/conv/error.rs @@ -1,11 +1,29 @@ +use core::fmt::Debug; use cubecl::{linalg::matmul::kernels::MatmulLaunchError, tune::AutotuneError}; -#[derive(Debug)] pub enum ConvLaunchError { Matmul(MatmulLaunchError), + Groups(usize), Unknown, } +impl Debug for ConvLaunchError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ConvLaunchError::Matmul(err) => { + write!(f, "{err:?}") + } + ConvLaunchError::Groups(groups) => { + writeln!( + f, + "Unable to launch matmul because groups must be one, is actually {groups}", + ) + } + ConvLaunchError::Unknown => write!(f, "Unknown"), + } + } +} + impl From for ConvLaunchError { fn from(value: MatmulLaunchError) -> Self { Self::Matmul(value) diff --git a/crates/burn-jit/src/kernel/matmul/tune/base.rs b/crates/burn-jit/src/kernel/matmul/tune/base.rs index 46b1dfacc6..dacd2693b9 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/base.rs @@ -1,10 +1,7 @@ use burn_tensor::{Element, ElementConversion}; use cubecl::{ - ir::{Elem, FloatKind}, linalg::matmul::{kernels::tiling2d::Tiling2dConfig, Strategy}, - tune, - tune::{local_tuner, tune_with, LocalTuner}, - Feature, + tune::{local_tuner, LocalTuner, TunableSet}, }; use crate::{ @@ -18,44 +15,19 @@ use crate::{ use super::key::create_key; -#[tune( - operations(matmul_tiling2d, matmul_accelerated, matmul_simple), - create_key = create_key::, - should_run = should_run -)] -fn matmul_ops( - key: JitAutotuneKey, - lhs: JitTensor, - rhs: JitTensor, - out: JitTensor, -) { +fn matmul_input_gen( + _key: &JitAutotuneKey, + lhs: &JitTensor, + rhs: &JitTensor, + out: &JitTensor, +) -> (JitTensor, JitTensor, JitTensor) { let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); let lhs = random_like_uniform(lhs, random_bounds.0, random_bounds.1); let rhs = random_like_uniform(rhs, random_bounds.0, random_bounds.1); let out = empty_device::(out.client.clone(), out.device.clone(), out.shape.clone()); - tune_with!(lhs, rhs, out) -} - -fn should_run( - op: &MatmulOps, - _key: &JitAutotuneKey, - index: usize, -) -> bool { - match index { - // Accelerated - // TODO: Add way to query actual requirements from cubecl - 1 => op.lhs.client.properties().feature_enabled(Feature::Cmma { - a: Elem::Float(FloatKind::F16), - b: Elem::Float(FloatKind::F16), - c: Elem::Float(FloatKind::F32), - m: 16, - k: 16, - n: 16, - }), - _ => true, - } + (lhs, rhs, out) } /// Executes autotune on matmul operations @@ -70,10 +42,16 @@ pub fn matmul_autotune( static TUNER: LocalTuner = local_tuner!(); + let tunables = TunableSet::new(create_key::, matmul_input_gen::) + .with_tunable(matmul_tiling2d::) + .with_tunable(matmul_accelerated::) + .with_tunable(matmul_simple::); + TUNER.execute( &JitTuneId::new::(&lhs.device), &client, - Box::new(MatmulOps::::new(lhs, rhs, output.clone())), + &tunables, + (lhs, rhs, output.clone()), ); output diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs index c5659cc1cc..b364907238 100644 --- a/crates/burn-jit/src/kernel/reduce/tune.rs +++ b/crates/burn-jit/src/kernel/reduce/tune.rs @@ -3,8 +3,7 @@ use burn_tensor::ElementConversion; use cubecl::{ client::ComputeClient, - tune, - tune::{local_tuner, tune_with, LocalTuner}, + tune::{local_tuner, LocalTuner, TunableSet}, AutotuneKey, }; use serde::{Deserialize, Serialize}; @@ -13,6 +12,7 @@ use crate::{ kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::JitTensor, JitAutotuneKey, JitElement, JitRuntime, JitTuneId, }; +use reduce_ops::*; /// Executes autotune on reduce operations. pub fn autotune_reduce< @@ -28,10 +28,17 @@ pub fn autotune_reduce< ) -> Result<(), cubecl::reduce::ReduceError> { static TUNER: LocalTuner = local_tuner!(); + let tunables = TunableSet::new(create_key::, reduce_input_gen::) + .with_tunable(reduce::) + .with_tunable(reduce_shared::) + .with_tunable(reduce_plane::) + .with_tunable(reduce_shared_plane::); + TUNER.execute( &JitTuneId::new::(&input.device), client, - Box::new(ReduceOps::::new(input, output, dim)), + &tunables, + (input, output, dim), ); Ok(()) @@ -85,23 +92,17 @@ pub(crate) fn create_key( JitAutotuneKey::Reduce(ReduceAutotuneKey::generate(input, *dim)) } -pub use reduce_ops::*; mod reduce_ops { #![allow(missing_docs)] use super::*; - #[tune( - operations(reduce, reduce_shared, reduce_plane, reduce_shared_plane), - create_key = create_key::, - should_run = should_run -)] - fn reduce_ops( - key: JitAutotuneKey, - input: JitTensor, - output: JitTensor, - dim: usize, - ) { + pub(crate) fn reduce_input_gen( + _key: &JitAutotuneKey, + input: &JitTensor, + output: &JitTensor, + dim: &usize, + ) -> (JitTensor, JitTensor, usize) { let random_bounds: (In, In) = ((-10.0_f32).elem::(), (10.0_f32).elem::()); let input = random_like_uniform(input, random_bounds.0, random_bounds.1); @@ -111,29 +112,15 @@ mod reduce_ops { output.shape.clone(), ); - tune_with!(input, output, dim) + (input, output, *dim) } - fn should_run( - op: &ReduceOps, - _key: &JitAutotuneKey, - index: usize, - ) -> bool { - match index { - // if strategy uses planes - 2 | 3 => { - let properties = op.input.client.properties(); - properties.feature_enabled(cubecl::Feature::Plane) - && properties - .hardware_properties() - .defined_plane_size() - .is_some() - } - _ => true, - } - } - - fn reduce( + pub(crate) fn reduce< + Run: JitRuntime, + In: JitElement, + Out: JitElement, + Rd: cubecl::reduce::Reduce, + >( input: JitTensor, output: JitTensor, axis: usize, @@ -151,7 +138,7 @@ mod reduce_ops { .map_err(|e| format!("{e}")) } - fn reduce_shared< + pub(crate) fn reduce_shared< Run: JitRuntime, In: JitElement, Out: JitElement, @@ -174,7 +161,7 @@ mod reduce_ops { .map_err(|e| format!("{e}")) } - fn reduce_plane< + pub(crate) fn reduce_plane< Run: JitRuntime, In: JitElement, Out: JitElement, @@ -197,7 +184,7 @@ mod reduce_ops { .map_err(|e| format!("{e}")) } - fn reduce_shared_plane< + pub(crate) fn reduce_shared_plane< Run: JitRuntime, In: JitElement, Out: JitElement, diff --git a/crates/burn-ndarray/src/ops/deform_conv.rs b/crates/burn-ndarray/src/ops/deform_conv.rs index 504c1a8c59..56b969a67c 100644 --- a/crates/burn-ndarray/src/ops/deform_conv.rs +++ b/crates/burn-ndarray/src/ops/deform_conv.rs @@ -622,7 +622,9 @@ pub mod backward { }); #[cfg(not(feature = "std"))] - run_par!(|| { iter_par!(Zip::indexed(columns).for_each(compute_for_each)) }); + run_par!(|| { + iter_par!(Zip::indexed(columns)).for_each(|args0, args1| compute_for_each(args0, args1)) + }); let grad_in: Array1 = grad_in .into_iter() From e54d03dbbf02ff2d8667f142f693f768bc31596d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 20 Jan 2025 09:57:01 -0500 Subject: [PATCH 38/61] Bump serde_json from 1.0.135 to 1.0.137 (#2713) Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.135 to 1.0.137. - [Release notes](https://github.com/serde-rs/json/releases) - [Commits](https://github.com/serde-rs/json/compare/v1.0.135...v1.0.137) --- updated-dependencies: - dependency-name: serde_json dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index eef5ff3962..58191ae8bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6495,9 +6495,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.135" +version = "1.0.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" +checksum = "930cfb6e6abf99298aaad7d29abbef7a9999a9a8806a40088f55f0dcec03146b" dependencies = [ "itoa", "memchr", diff --git a/Cargo.toml b/Cargo.toml index 846d0f565c..22ed0b2644 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -140,7 +140,7 @@ serde = { version = "1.0.217", default-features = false, features = [ "derive", "alloc", ] } # alloc is for no_std, derive is needed -serde_json = { version = "1.0.135", default-features = false } +serde_json = { version = "1.0.137", default-features = false } uuid = { version = "1.12.0", default-features = false } libc = "0.2.169" From 140ea757ef8e5ca2cdf46e84e0bee9166983181d Mon Sep 17 00:00:00 2001 From: sunxunle <163647374+sunxunle@users.noreply.github.com> Date: Mon, 20 Jan 2025 23:06:06 +0800 Subject: [PATCH 39/61] chore: fix some comments (#2717) Signed-off-by: sunxunle --- crates/burn-import/src/pytorch/recorder.rs | 2 +- crates/burn-train/src/checkpoint/strategy/metric.rs | 2 +- examples/custom-image-dataset/src/dataset.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/burn-import/src/pytorch/recorder.rs b/crates/burn-import/src/pytorch/recorder.rs index 170f64a9d3..32dea273c9 100644 --- a/crates/burn-import/src/pytorch/recorder.rs +++ b/crates/burn-import/src/pytorch/recorder.rs @@ -11,7 +11,7 @@ use serde::{de::DeserializeOwned, Serialize}; use super::reader::from_file; -/// A recorder that that loads PyTorch files (`.pt`) into Burn modules. +/// A recorder that loads PyTorch files (`.pt`) into Burn modules. /// /// LoadArgs can be used to remap keys or file path. /// See [LoadArgs](struct.LoadArgs.html) for more information. diff --git a/crates/burn-train/src/checkpoint/strategy/metric.rs b/crates/burn-train/src/checkpoint/strategy/metric.rs index 4efcf14028..68f36bcd7f 100644 --- a/crates/burn-train/src/checkpoint/strategy/metric.rs +++ b/crates/burn-train/src/checkpoint/strategy/metric.rs @@ -114,7 +114,7 @@ mod tests { process_train(&mut processor, 0.3, epoch); end_epoch(&mut processor, epoch); - // Should save the current record and delete the pervious one. + // Should save the current record and delete the previous one. assert_eq!( vec![CheckpointingAction::Delete(1), CheckpointingAction::Save], strategy.checkpointing(epoch, &store) diff --git a/examples/custom-image-dataset/src/dataset.rs b/examples/custom-image-dataset/src/dataset.rs index eee2bdf9fc..b40396d3c6 100644 --- a/examples/custom-image-dataset/src/dataset.rs +++ b/examples/custom-image-dataset/src/dataset.rs @@ -5,7 +5,7 @@ use tar::Archive; use burn::data::{dataset::vision::ImageFolderDataset, network::downloader}; /// CIFAR-10 mirror from [fastai](https://github.com/fastai/fastai/blob/master/fastai/data/external.py#L44). -/// Licensed under the [Appache License](https://github.com/fastai/fastai/blob/master/LICENSE). +/// Licensed under the [Apache License](https://github.com/fastai/fastai/blob/master/LICENSE). const URL: &str = "https://s3.amazonaws.com/fast-ai-sample/cifar10.tgz"; /// The [CIFAR-10](https://www.cs.toronto.edu/%7Ekriz/cifar.html) dataset consists of 60,000 32x32 From b33bd24f88db3bf59519bb17378515ed829da05f Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 21 Jan 2025 14:37:40 -0500 Subject: [PATCH 40/61] Fix no default features flags + update cubecl (#2725) --- Cargo.lock | 157 +++++++++++++----- Cargo.toml | 6 +- crates/burn-autodiff/Cargo.toml | 2 +- crates/burn-core/src/lib.rs | 1 + .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 2 +- crates/burn-jit/src/kernel/mod.rs | 3 +- crates/burn-jit/src/template/base.rs | 3 +- crates/burn-ndarray/Cargo.toml | 2 +- crates/burn-router/src/lib.rs | 1 + crates/burn-tensor/Cargo.toml | 2 +- crates/burn-wgpu/src/lib.rs | 2 +- examples/guide/src/bin/infer.rs | 1 + examples/image-classification-web/src/lib.rs | 1 + examples/server/src/lib.rs | 2 + 14 files changed, 132 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 58191ae8bc..84a407e07b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1143,12 +1143,6 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" -[[package]] -name = "cfg_aliases" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" - [[package]] name = "cfg_aliases" version = "0.2.1" @@ -1581,7 +1575,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1596,13 +1590,17 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ + "bytemuck", "derive-new 0.6.0", + "derive_more 1.0.0", "embassy-futures", "futures-lite", "getrandom", + "half", "log", + "num-traits", "portable-atomic", "rand", "serde", @@ -1613,11 +1611,12 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "bitflags 2.8.0", "bytemuck", "cubecl-common", + "cubecl-ir", "cubecl-macros", "cubecl-runtime", "derive-new 0.6.0", @@ -1633,7 +1632,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "bytemuck", "cubecl-common", @@ -1647,7 +1646,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "bytemuck", "cubecl-common", @@ -1663,7 +1662,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "bytemuck", "cubecl-common", @@ -1686,10 +1685,23 @@ dependencies = [ "libc", ] +[[package]] +name = "cubecl-ir" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +dependencies = [ + "cubecl-common", + "float-ord", + "half", + "num-traits", + "serde", + "type_hash", +] + [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "bytemuck", "cubecl-core", @@ -1701,7 +1713,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "cubecl-common", "darling", @@ -1716,10 +1728,10 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "cubecl-common", - "cubecl-core", + "cubecl-ir", "float-ord", "log", "num", @@ -1732,7 +1744,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1742,11 +1754,11 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "async-channel", "async-lock", - "cfg_aliases 0.2.1", + "cfg_aliases", "cubecl-common", "derive-new 0.6.0", "dirs", @@ -1764,7 +1776,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1779,13 +1791,13 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2cc42af02671d90255ab823e29a4a3ad2e564333#2cc42af02671d90255ab823e29a4a3ad2e564333" +source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" dependencies = [ "ash", "async-channel", "bytemuck", "cfg-if", - "cfg_aliases 0.2.1", + "cfg_aliases", "cubecl-common", "cubecl-core", "cubecl-runtime", @@ -2835,9 +2847,9 @@ dependencies = [ [[package]] name = "glow" -version = "0.14.2" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d51fa363f025f5c111e03f13eda21162faeacb6911fe8caa0c0349f9cf0c4483" +checksum = "c5e5ea60d70410161c8bf5da3fdfeaa1c72ed2c15f8bbb9d19fe3a4fad085f08" dependencies = [ "js-sys", "slotmap", @@ -3907,6 +3919,21 @@ dependencies = [ "paste", ] +[[package]] +name = "metal" +version = "0.31.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f569fb946490b5743ad69813cb19629130ce9374034abe31614a36402d18f99e" +dependencies = [ + "bitflags 2.8.0", + "block", + "core-graphics-types", + "foreign-types 0.5.0", + "log", + "objc", + "paste", +] + [[package]] name = "mime" version = "0.3.17" @@ -4015,22 +4042,23 @@ dependencies = [ [[package]] name = "naga" -version = "23.1.0" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "364f94bc34f61332abebe8cad6f6cd82a5b65cff22c828d05d0968911462ca4f" +checksum = "e380993072e52eef724eddfcde0ed013b0c023c3f0417336ed041aa9f076994e" dependencies = [ "arrayvec", "bit-set", "bitflags 2.8.0", - "cfg_aliases 0.1.1", + "cfg_aliases", "codespan-reporting", "hexf-parse", "indexmap", "log", "rustc-hash 1.1.0", "spirv 0.3.0+sdk-1.3.268.0", + "strum", "termcolor", - "thiserror 1.0.69", + "thiserror 2.0.11", "unicode-xid", ] @@ -4624,6 +4652,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ordered-float" +version = "4.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bb71e1b3fa6ca1c61f383464aaf2bb0e2f8e772a1f01d486832464de363b951" +dependencies = [ + "num-traits", +] + [[package]] name = "os_info" version = "3.9.2" @@ -5687,7 +5724,7 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904" dependencies = [ - "cfg_aliases 0.2.1", + "cfg_aliases", "libc", "once_cell", "socket2", @@ -7583,6 +7620,37 @@ dependencies = [ "rustc-hash 1.1.0", ] +[[package]] +name = "type_hash" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03c86f48f11992d3e379358c63cb25736c0b23944ff000d1583bbccad2b0b7c6" +dependencies = [ + "type_hash_core", + "type_hash_macros", +] + +[[package]] +name = "type_hash_core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87b1e93e2cd97790892dbe2d2813fbaa6eebaeb960265f59e363e79e51e4997a" +dependencies = [ + "fnv", +] + +[[package]] +name = "type_hash_macros" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "746fc164e076483ef087b3989f7aa80ffd9320fa558f3cb72cecfb9bb1dbc41e" +dependencies = [ + "either", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "typenum" version = "1.17.0" @@ -8020,12 +8088,13 @@ dependencies = [ [[package]] name = "wgpu" -version = "23.0.1" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80f70000db37c469ea9d67defdc13024ddf9a5f1b89cb2941b812ad7cde1735a" +checksum = "e41253fc7b660735e2a2d9a58c563f2a047d3cc3445293d8f4095538c9e8afbe" dependencies = [ "arrayvec", - "cfg_aliases 0.1.1", + "bitflags 2.8.0", + "cfg_aliases", "document-features", "js-sys", "log", @@ -8045,14 +8114,14 @@ dependencies = [ [[package]] name = "wgpu-core" -version = "23.0.1" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d63c3c478de8e7e01786479919c8769f62a22eec16788d8c2ac77ce2c132778a" +checksum = "82a39b8842dc9ffcbe34346e3ab6d496b32a47f6497e119d762c97fcaae3cb37" dependencies = [ "arrayvec", "bit-vec", "bitflags 2.8.0", - "cfg_aliases 0.1.1", + "cfg_aliases", "document-features", "indexmap", "log", @@ -8063,16 +8132,16 @@ dependencies = [ "raw-window-handle", "rustc-hash 1.1.0", "smallvec", - "thiserror 1.0.69", + "thiserror 2.0.11", "wgpu-hal", "wgpu-types", ] [[package]] name = "wgpu-hal" -version = "23.0.1" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89364b8a0b211adc7b16aeaf1bd5ad4a919c1154b44c9ce27838213ba05fd821" +checksum = "5a782e5056b060b0b4010881d1decddd059e44f2ecd01e2db2971b48ad3627e5" dependencies = [ "android_system_properties", "arrayvec", @@ -8081,7 +8150,7 @@ dependencies = [ "bitflags 2.8.0", "block", "bytemuck", - "cfg_aliases 0.1.1", + "cfg_aliases", "core-graphics-types", "glow", "glutin_wgl_sys", @@ -8093,11 +8162,12 @@ dependencies = [ "libc", "libloading", "log", - "metal 0.29.0", + "metal 0.31.0", "naga", "ndk-sys", "objc", "once_cell", + "ordered-float", "parking_lot 0.12.3", "profiling", "range-alloc", @@ -8105,7 +8175,7 @@ dependencies = [ "renderdoc-sys", "rustc-hash 1.1.0", "smallvec", - "thiserror 1.0.69", + "thiserror 2.0.11", "wasm-bindgen", "web-sys", "wgpu-types", @@ -8115,12 +8185,13 @@ dependencies = [ [[package]] name = "wgpu-types" -version = "23.0.0" +version = "24.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "610f6ff27778148c31093f3b03abc4840f9636d58d597ca2f5977433acfe0068" +checksum = "50ac044c0e76c03a0378e7786ac505d010a873665e2d51383dcff8dd227dc69c" dependencies = [ "bitflags 2.8.0", "js-sys", + "log", "web-sys", ] diff --git a/Cargo.toml b/Cargo.toml index 22ed0b2644..c64b00f0dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -101,7 +101,7 @@ ratatui = "0.29.0" # WGPU stuff text_placeholder = "0.5.1" -wgpu = "23.0.0" +wgpu = "24.0.0" # Benchmarks and Burnbench arboard = "3.4.1" @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-autodiff/Cargo.toml b/crates/burn-autodiff/Cargo.toml index 5e221f887f..df7040f835 100644 --- a/crates/burn-autodiff/Cargo.toml +++ b/crates/burn-autodiff/Cargo.toml @@ -18,7 +18,7 @@ std = [] async = [] # Require std [dependencies] -burn-common = { path = "../burn-common", version = "0.17.0" } +burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false } burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } diff --git a/crates/burn-core/src/lib.rs b/crates/burn-core/src/lib.rs index f554518430..ade8d64db7 100644 --- a/crates/burn-core/src/lib.rs +++ b/crates/burn-core/src/lib.rs @@ -1,6 +1,7 @@ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![recursion_limit = "135"] //! The core crate of Burn. diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index 7f9914989a..6b738ab988 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -99,7 +99,7 @@ fn im2col_kernel( #[cfg(not(test))] pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { - let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::PLANE_DIM_APPROX); + let cube_count_per_batch = (out_h * out_w).div_ceil(burn_common::PLANE_DIM_APPROX); let max_cube_count = u16::MAX as usize; let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size); if max_simultaneous == 0 { diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index 660ae2f6fd..fd23cd2e2d 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -15,7 +15,8 @@ pub use mask::*; pub(crate) use unary_float::*; pub(crate) use unary_numeric::*; -pub use cubecl::{Kernel, PLANE_DIM_APPROX}; +pub use burn_common::PLANE_DIM_APPROX; +pub use cubecl::Kernel; /// Convolution kernels pub mod conv; diff --git a/crates/burn-jit/src/template/base.rs b/crates/burn-jit/src/template/base.rs index 54e50468fb..cfdf3319fe 100644 --- a/crates/burn-jit/src/template/base.rs +++ b/crates/burn-jit/src/template/base.rs @@ -1,5 +1,6 @@ use crate::{element::JitElement, tensor::JitTensor, JitRuntime}; -use cubecl::{prelude::*, Compiler, ExecutionMode, KernelId}; +use burn_common::ExecutionMode; +use cubecl::{prelude::*, Compiler, KernelId}; use super::SourceTemplate; diff --git a/crates/burn-ndarray/Cargo.toml b/crates/burn-ndarray/Cargo.toml index 167cf88c1a..111649ab25 100644 --- a/crates/burn-ndarray/Cargo.toml +++ b/crates/burn-ndarray/Cargo.toml @@ -43,7 +43,7 @@ blas-openblas-system = [ # ** Please make sure all dependencies support no_std when std is disabled ** -burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true } +burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, optional = true } burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = ["repr"] } diff --git a/crates/burn-router/src/lib.rs b/crates/burn-router/src/lib.rs index 644f65ee67..773235f781 100644 --- a/crates/burn-router/src/lib.rs +++ b/crates/burn-router/src/lib.rs @@ -1,6 +1,7 @@ #![cfg_attr(not(feature = "std"), no_std)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![recursion_limit = "138"] //! Burn multi-backend router. diff --git a/crates/burn-tensor/Cargo.toml b/crates/burn-tensor/Cargo.toml index 318912b2f7..7428408292 100644 --- a/crates/burn-tensor/Cargo.toml +++ b/crates/burn-tensor/Cargo.toml @@ -32,7 +32,7 @@ std = [ [dependencies] burn-common = { path = "../burn-common", version = "0.17.0", default-features = false } burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true } -cubecl = { workspace = true, optional = true, default-features = true } +cubecl = { workspace = true, optional = true, default-features = false } bytemuck = { workspace = true, features = ["extern_crate_alloc"] } colored = { workspace = true, optional = true } diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index deb6a8ebd8..7aab106b29 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -12,8 +12,8 @@ pub use burn_jit::{ pub use burn_jit::{tensor::JitTensor, JitBackend}; pub use burn_jit::{BoolElement, FloatElement, IntElement}; pub use cubecl::flex32; -pub use cubecl::ir::CubeDim; pub use cubecl::wgpu::*; +pub use cubecl::CubeDim; pub type Wgsl = cubecl::wgpu::WgslCompiler; #[cfg(feature = "spirv")] diff --git a/examples/guide/src/bin/infer.rs b/examples/guide/src/bin/infer.rs index 3c64879bc5..6a246d85f0 100644 --- a/examples/guide/src/bin/infer.rs +++ b/examples/guide/src/bin/infer.rs @@ -1,3 +1,4 @@ +#![recursion_limit = "131"] use burn::{backend::Wgpu, data::dataset::Dataset}; use guide::inference; diff --git a/examples/image-classification-web/src/lib.rs b/examples/image-classification-web/src/lib.rs index 3881123eaf..3d528f2e9d 100644 --- a/examples/image-classification-web/src/lib.rs +++ b/examples/image-classification-web/src/lib.rs @@ -1,4 +1,5 @@ #![cfg_attr(not(test), no_std)] +#![recursion_limit = "135"] pub mod model; pub mod web; diff --git a/examples/server/src/lib.rs b/examples/server/src/lib.rs index 92cba57a2a..70705a0876 100644 --- a/examples/server/src/lib.rs +++ b/examples/server/src/lib.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "141"] + pub fn start() { let port = std::env::var("REMOTE_BACKEND_PORT") .map(|port| match port.parse::() { From 245fbcdccdbc9da2a0f3c3208aa7556b77f98ce7 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Wed, 22 Jan 2025 11:55:09 -0500 Subject: [PATCH 41/61] Feat/fused matmul tune (#2726) --- Cargo.lock | 28 +-- Cargo.toml | 4 +- backend-comparison/benches/matmul_fused.rs | 10 +- crates/burn-fusion/src/stream/context.rs | 78 ++++++++ crates/burn-jit/src/fusion/base.rs | 16 +- crates/burn-jit/src/fusion/matmul/builder.rs | 8 +- crates/burn-jit/src/fusion/matmul/mod.rs | 1 + .../src/fusion/matmul/optimization.rs | 176 +++++++++++++++--- crates/burn-jit/src/fusion/matmul/tune.rs | 133 +++++++++++++ crates/burn-jit/src/fusion/mod.rs | 1 + crates/burn-jit/src/fusion/tune.rs | 108 +++++++++++ crates/burn-jit/src/kernel/matmul/tune/key.rs | 2 +- crates/burn-tensor/src/repr/handle.rs | 18 ++ .../examples/ag-news-train.rs | 2 + 14 files changed, 525 insertions(+), 60 deletions(-) create mode 100644 crates/burn-jit/src/fusion/matmul/tune.rs create mode 100644 crates/burn-jit/src/fusion/tune.rs diff --git a/Cargo.lock b/Cargo.lock index 84a407e07b..de9444c5ff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1575,7 +1575,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1590,7 +1590,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1611,7 +1611,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1632,7 +1632,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-common", @@ -1646,7 +1646,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-common", @@ -1662,7 +1662,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-common", @@ -1688,7 +1688,7 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-common", "float-ord", @@ -1701,7 +1701,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bytemuck", "cubecl-core", @@ -1713,7 +1713,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-common", "darling", @@ -1728,7 +1728,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1744,7 +1744,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1754,7 +1754,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "async-channel", "async-lock", @@ -1776,7 +1776,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1791,7 +1791,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c#dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" +source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index c64b00f0dd..f731d063a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/backend-comparison/benches/matmul_fused.rs b/backend-comparison/benches/matmul_fused.rs index 375be97b4e..fbec64c648 100644 --- a/backend-comparison/benches/matmul_fused.rs +++ b/backend-comparison/benches/matmul_fused.rs @@ -1,5 +1,9 @@ use backend_comparison::persistence::save; -use burn::tensor::{activation::relu, backend::Backend, Distribution, Shape, Tensor}; +use burn::tensor::{ + activation::{gelu, relu}, + backend::Backend, + Distribution, Shape, Tensor, +}; use burn_common::benchmark::{run_benchmark, Benchmark}; use derive_new::new; @@ -14,7 +18,7 @@ impl Benchmark for MatmulBenchmark { type Args = (Tensor, Tensor, Tensor); fn name(&self) -> String { - "matmul_bias_relu".into() + "matmul_relu_bias_gelu".into() } fn shapes(&self) -> Vec> { @@ -23,7 +27,7 @@ impl Benchmark for MatmulBenchmark { fn execute(&self, (lhs, rhs, bias): Self::Args) { let bias = bias.unsqueeze(); - relu(lhs.matmul(rhs) + bias); + gelu(relu(lhs.matmul(rhs)) + bias); } fn prepare(&self) -> Self::Args { diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index d5e1ee9e38..0f9c75fc94 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -59,6 +59,84 @@ pub(crate) struct OperationConverter { scalar_u8: Vec, } +/// Fork of a [context](Context) which owns its data. +pub struct ContextOwned { + tensors: HashMap, + handles: HandleContainer, + scalar_f32: Vec, + scalar_f16: Vec, + scalar_bf16: Vec, + scalar_i64: Vec, + scalar_i32: Vec, + scalar_i16: Vec, + scalar_i8: Vec, + scalar_u64: Vec, + scalar_u32: Vec, + scalar_u16: Vec, + scalar_u8: Vec, +} + +impl ContextOwned { + /// Convert into [context](Context). + pub fn as_context(&mut self) -> Context<'_, H> { + Context { + tensors: &mut self.tensors, + handles: &mut self.handles, + scalar_f32: &self.scalar_f32, + scalar_f16: &self.scalar_f16, + scalar_bf16: &self.scalar_bf16, + scalar_i64: &self.scalar_i64, + scalar_i32: &self.scalar_i32, + scalar_i16: &self.scalar_i16, + scalar_i8: &self.scalar_i8, + scalar_u64: &self.scalar_u64, + scalar_u32: &self.scalar_u32, + scalar_u16: &self.scalar_u16, + scalar_u8: &self.scalar_u8, + } + } + + /// Fork the context again. + pub fn fork(&self) -> ContextOwned { + ContextOwned { + tensors: self.tensors.clone(), + handles: self.handles.fork(), + scalar_f32: self.scalar_f32.clone(), + scalar_f16: self.scalar_f16.clone(), + scalar_bf16: self.scalar_bf16.clone(), + scalar_i64: self.scalar_i64.clone(), + scalar_i32: self.scalar_i32.clone(), + scalar_i16: self.scalar_i16.clone(), + scalar_i8: self.scalar_i8.clone(), + scalar_u64: self.scalar_u64.clone(), + scalar_u32: self.scalar_u32.clone(), + scalar_u16: self.scalar_u16.clone(), + scalar_u8: self.scalar_u8.clone(), + } + } +} + +impl Context<'_, H> { + /// Fork the context into an [owned context](ContextOwned). + pub fn fork(&self) -> ContextOwned { + ContextOwned { + tensors: self.tensors.clone(), + handles: self.handles.fork(), + scalar_f32: self.scalar_f32.clone(), + scalar_f16: self.scalar_f16.clone(), + scalar_bf16: self.scalar_bf16.clone(), + scalar_i64: self.scalar_i64.clone(), + scalar_i32: self.scalar_i32.clone(), + scalar_i16: self.scalar_i16.clone(), + scalar_i8: self.scalar_i8.clone(), + scalar_u64: self.scalar_u64.clone(), + scalar_u32: self.scalar_u32.clone(), + scalar_u16: self.scalar_u16.clone(), + scalar_u8: self.scalar_u8.clone(), + } + } +} + pub(crate) trait RelativeOps { /// Convert (usually an [`OperationDescription`]) to a relative form. /// diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index e8e4d82659..48587a1bf9 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -125,20 +125,16 @@ impl FusionRuntime for FusionJitRuntime { fn optimizations( device: R::Device, ) -> Vec>> { - let mut optimizations: Vec>> = - vec![Box::new(ElementWiseBuilder::::new( + vec![ + Box::new(ElementWiseBuilder::::new( device.clone(), BT::as_elem_native_unchecked().into(), - ))]; - - if cfg!(feature = "fusion-experimental") { - optimizations.push(Box::new(MatmulBuilder::::new( + )), + Box::new(MatmulBuilder::::new( device.clone(), BT::as_elem_native_unchecked().into(), - ))); - } - - optimizations + )), + ] } } diff --git a/crates/burn-jit/src/fusion/matmul/builder.rs b/crates/burn-jit/src/fusion/matmul/builder.rs index 986332914f..f197237819 100644 --- a/crates/burn-jit/src/fusion/matmul/builder.rs +++ b/crates/burn-jit/src/fusion/matmul/builder.rs @@ -47,7 +47,13 @@ impl OptimizationBuilder> for MatmulBuilder let rhs = self.builder.input_unhandled(&op.rhs); let out = self.builder.output_unhandled(&op.out); - self.matmul = Some(FusedMatmul::new(lhs, rhs, out, op.clone())); + self.matmul = Some(FusedMatmul::new( + lhs, + rhs, + out, + op.clone(), + Default::default(), + )); } else { self.builder.close(); } diff --git a/crates/burn-jit/src/fusion/matmul/mod.rs b/crates/burn-jit/src/fusion/matmul/mod.rs index 1afeef9c88..cddec5983a 100644 --- a/crates/burn-jit/src/fusion/matmul/mod.rs +++ b/crates/burn-jit/src/fusion/matmul/mod.rs @@ -2,3 +2,4 @@ pub(crate) mod args; pub(crate) mod builder; pub(crate) mod optimization; pub(crate) mod spec; +pub(crate) mod tune; diff --git a/crates/burn-jit/src/fusion/matmul/optimization.rs b/crates/burn-jit/src/fusion/matmul/optimization.rs index d0cd8749ad..9a020df62c 100644 --- a/crates/burn-jit/src/fusion/matmul/optimization.rs +++ b/crates/burn-jit/src/fusion/matmul/optimization.rs @@ -12,7 +12,9 @@ use burn_tensor::Shape; use cubecl::linalg::matmul::components; use cubecl::linalg::matmul::components::tile::accelerated::Accelerated; use cubecl::linalg::matmul::components::MatmulProblem; -use cubecl::linalg::matmul::kernels::matmul::{MatmulSelector, StandardSelector}; +use cubecl::linalg::matmul::kernels::matmul::{ + MatmulSelector, PipelinedSelector, SpecializedSelector, StandardSelector, +}; use cubecl::linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}; use cubecl::linalg::tensor::{matrix_layout, MatrixLayout}; use cubecl::{client::ComputeClient, prelude::*}; @@ -26,16 +28,18 @@ use crate::fusion::on_write::{ use super::args::FusedMatmulInputLaunch; use super::spec::FusedMatmulSpec; +use super::tune::fused_matmul_autotune; -#[derive(new)] /// Fuse matmul operation followed by elemwise operations into a single kernel. pub struct MatmulOptimization { trace: FuseOnWriteTrace, trace_fallback: FuseOnWriteTrace, - client: ComputeClient, - device: R::Device, - len: usize, - matmul: FusedMatmul, + pub(crate) client: ComputeClient, + pub(crate) device: R::Device, + pub(crate) len: usize, + pub(crate) matmul_standard: FusedMatmul, + pub(crate) matmul_pipelined: FusedMatmul, + pub(crate) matmul_specialized: FusedMatmul, } #[derive(Serialize, Deserialize, Debug)] @@ -43,13 +47,46 @@ pub struct MatmulOptimization { pub struct MatmulOptimizationState { trace: FuseOnWriteTrace, trace_fallback: FuseOnWriteTrace, - matmul: FusedMatmul, + matmul_standard: FusedMatmul, + matmul_pipelined: FusedMatmul, + matmul_specialized: FusedMatmul, len: usize, } impl MatmulOptimization { + pub fn new( + trace: FuseOnWriteTrace, + trace_fallback: FuseOnWriteTrace, + client: ComputeClient, + device: R::Device, + len: usize, + matmul: FusedMatmul, + ) -> Self { + let mut matmul_standard = matmul.clone(); + let mut matmul_specialized = matmul.clone(); + let mut matmul_pipelined = matmul; + + matmul_standard.selector = FusedMatmulSelector::Standard; + matmul_specialized.selector = FusedMatmulSelector::Specialized; + matmul_pipelined.selector = FusedMatmulSelector::Pipelined; + + Self { + trace, + trace_fallback, + client, + device, + len, + matmul_standard, + matmul_pipelined, + matmul_specialized, + } + } /// Execute the optimization. pub fn execute(&mut self, context: &mut Context<'_, JitFusionHandle>) { + #[cfg(feature = "autotune")] + fused_matmul_autotune::(self, context); + + #[cfg(not(feature = "autotune"))] if self.execute_fused::(context).is_err() { self.execute_fallback::(context); } @@ -68,7 +105,9 @@ impl MatmulOptimization { len: state.len, client: R::client(device), device: device.clone(), - matmul: state.matmul.clone(), + matmul_standard: state.matmul_standard.clone(), + matmul_specialized: state.matmul_specialized.clone(), + matmul_pipelined: state.matmul_pipelined.clone(), } } @@ -77,21 +116,51 @@ impl MatmulOptimization { MatmulOptimizationState { trace: self.trace.clone(), trace_fallback: self.trace_fallback.clone(), - matmul: self.matmul.clone(), + matmul_standard: self.matmul_standard.clone(), + matmul_specialized: self.matmul_specialized.clone(), + matmul_pipelined: self.matmul_pipelined.clone(), len: self.len, } } - fn execute_fused( - &mut self, + pub fn execute_standard_fused( + &self, context: &mut Context<'_, JitFusionHandle>, ) -> Result<(), FusedMatmulError> { - self.trace - .run::(&self.client, &self.device, context, &self.matmul) + self.trace.run::( + &self.client, + &self.device, + context, + &self.matmul_standard, + ) } - fn execute_fallback(&mut self, context: &mut Context<'_, JitFusionHandle>) { - match self.matmul.lhs.precision() { + pub fn execute_specialized_fused( + &self, + context: &mut Context<'_, JitFusionHandle>, + ) -> Result<(), FusedMatmulError> { + self.trace.run::( + &self.client, + &self.device, + context, + &self.matmul_specialized, + ) + } + + pub fn execute_pipelined_fused( + &self, + context: &mut Context<'_, JitFusionHandle>, + ) -> Result<(), FusedMatmulError> { + self.trace.run::( + &self.client, + &self.device, + context, + &self.matmul_pipelined, + ) + } + + pub fn execute_fallback(&self, context: &mut Context<'_, JitFusionHandle>) { + match self.matmul_standard.lhs.precision() { ElemwisePrecision::F32 => self.run_fallback::(context), ElemwisePrecision::F16 => self.run_fallback::(context), ElemwisePrecision::BF16 => self.run_fallback::(context), @@ -100,13 +169,25 @@ impl MatmulOptimization { } fn run_fallback( - &mut self, + &self, context: &mut Context<'_, JitFusionHandle>, ) { let (out_tensor, out_desc) = { - let lhs = context.tensors.get(&self.matmul.op.lhs.id).unwrap().clone(); - let rhs = context.tensors.get(&self.matmul.op.rhs.id).unwrap().clone(); - let out = context.tensors.get(&self.matmul.op.out.id).unwrap().clone(); + let lhs = context + .tensors + .get(&self.matmul_standard.op.lhs.id) + .unwrap() + .clone(); + let rhs = context + .tensors + .get(&self.matmul_standard.op.rhs.id) + .unwrap() + .clone(); + let out = context + .tensors + .get(&self.matmul_standard.op.out.id) + .unwrap() + .clone(); let lhs_handle = context.handles.get_handle(&lhs.id, &TensorStatus::ReadOnly); let rhs_handle = context.handles.get_handle(&rhs.id, &TensorStatus::ReadOnly); @@ -136,12 +217,21 @@ impl MatmulOptimization { } } +#[derive(Default, Clone, Serialize, Deserialize, Debug)] +pub enum FusedMatmulSelector { + #[default] + Standard, + Pipelined, + Specialized, +} + #[derive(new, Clone, Serialize, Deserialize, Debug)] pub struct FusedMatmul { lhs: Arg, rhs: Arg, out: Arg, - op: BinaryOperationDescription, + pub(crate) op: BinaryOperationDescription, + pub(crate) selector: FusedMatmulSelector, } #[derive(Debug)] @@ -261,15 +351,43 @@ impl FusedMatmul { } }; - match matmul_launch_kernel::>( - client, - FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), - outputs, - problem, - plane_size, - ) { - Ok(_) => Ok(()), - Err(err) => Err(FusedMatmulError::LaunchError(err)), + match self.selector { + FusedMatmulSelector::Standard => { + match matmul_launch_kernel::>( + client, + FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), + outputs, + problem, + plane_size, + ) { + Ok(_) => Ok(()), + Err(err) => Err(FusedMatmulError::LaunchError(err)), + } + } + FusedMatmulSelector::Pipelined => { + match matmul_launch_kernel::>( + client, + FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), + outputs, + problem, + plane_size, + ) { + Ok(_) => Ok(()), + Err(err) => Err(FusedMatmulError::LaunchError(err)), + } + } + FusedMatmulSelector::Specialized => { + match matmul_launch_kernel::>( + client, + FusedMatmulInputLaunch::new(inputs, config, &self.lhs, &self.rhs, &self.out), + outputs, + problem, + plane_size, + ) { + Ok(_) => Ok(()), + Err(err) => Err(FusedMatmulError::LaunchError(err)), + } + } } } } diff --git a/crates/burn-jit/src/fusion/matmul/tune.rs b/crates/burn-jit/src/fusion/matmul/tune.rs new file mode 100644 index 0000000000..0f6e42c486 --- /dev/null +++ b/crates/burn-jit/src/fusion/matmul/tune.rs @@ -0,0 +1,133 @@ +use crate::{ + fusion::{ + tune::{TuneContext, TuneInput}, + JitFusionHandle, + }, + kernel::matmul::MatmulAutotuneKey, + BoolElement, JitRuntime, JitTuneId, +}; +use burn_fusion::stream::Context; +use cubecl::{ + tune::{local_tuner, LocalTuner, TunableSet}, + AutotuneKey, +}; +use serde::{Deserialize, Serialize}; + +use super::optimization::MatmulOptimization; + +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +pub struct FusedMatmulAutotuneKey { + matmul_key: MatmulAutotuneKey, + #[autotune(anchor)] + num_ops_fused: usize, +} + +/// Executes autotune on matmul operations +pub fn fused_matmul_autotune( + optimization: &MatmulOptimization, + context: &mut Context>, +) { + static TUNER: LocalTuner = local_tuner!(); + + let tunables = TunableSet::new(create_key::, input_gen::) + .with_tunable(tune_standard_fused::) + .with_tunable(tune_specialized_fused::) + .with_tunable(tune_pipelined_fused::) + .with_tunable(tune_fallback::); + + TUNER.execute( + &JitTuneId::new::(&optimization.device), + &optimization.client, + &tunables, + TuneInput::new(context, optimization), + ); +} + +pub(crate) fn create_key( + input: &TuneInput>, +) -> FusedMatmulAutotuneKey { + let opt = input.optimization(); + let context = match input.context() { + TuneContext::Original(context) => context, + TuneContext::Fork(_) => panic!("Not supported when generating key"), + }; + + let lhs = context.tensors.get(&opt.matmul_standard.op.lhs.id).unwrap(); + let rhs = context.tensors.get(&opt.matmul_standard.op.rhs.id).unwrap(); + let out = context.tensors.get(&opt.matmul_standard.op.out.id).unwrap(); + + let key = MatmulAutotuneKey::from_shape( + &lhs.shape.clone().into(), + &rhs.shape.clone().into(), + out.dtype, + ); + FusedMatmulAutotuneKey::new(key, opt.len) +} + +fn input_gen( + _key: &FusedMatmulAutotuneKey, + input: &TuneInput>, +) -> TuneInput> { + input.clone() +} + +fn tune_standard_fused( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_standard_fused::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_standard_fused::(&mut context_owned.as_context()) + } + } + .map_err(|e| format!("{e:?}")) +} + +fn tune_specialized_fused( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_specialized_fused::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_specialized_fused::(&mut context_owned.as_context()) + } + } + .map_err(|e| format!("{e:?}")) +} + +fn tune_pipelined_fused( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_pipelined_fused::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_pipelined_fused::(&mut context_owned.as_context()) + } + } + .map_err(|e| format!("{e:?}")) +} + +fn tune_fallback( + input: TuneInput>, +) -> Result<(), String> { + let optimization = input.optimization(); + let context = input.context(); + + match context { + TuneContext::Original(context) => optimization.execute_fallback::(context), + TuneContext::Fork(mut context_owned) => { + optimization.execute_fallback::(&mut context_owned.as_context()) + } + }; + + Ok(()) +} diff --git a/crates/burn-jit/src/fusion/mod.rs b/crates/burn-jit/src/fusion/mod.rs index 4c44770b4e..96e1704964 100644 --- a/crates/burn-jit/src/fusion/mod.rs +++ b/crates/burn-jit/src/fusion/mod.rs @@ -3,5 +3,6 @@ mod base; pub(crate) mod elemwise; pub(crate) mod matmul; pub(crate) mod on_write; +pub(crate) mod tune; pub use base::*; diff --git a/crates/burn-jit/src/fusion/tune.rs b/crates/burn-jit/src/fusion/tune.rs new file mode 100644 index 0000000000..8c45f93bb0 --- /dev/null +++ b/crates/burn-jit/src/fusion/tune.rs @@ -0,0 +1,108 @@ +use super::JitFusionHandle; +use crate::JitRuntime; +use burn_fusion::stream::{Context, ContextOwned}; + +/// Fusion context used when tuning kernels. +/// +/// Either the original context is returned or a fork of the original. +/// The fork is only given when performing autotuning, and not when actually performing the +/// operation. +pub enum TuneContext<'a, R: JitRuntime> { + Original(&'a mut Context<'a, JitFusionHandle>), + Fork(Box>>), +} + +/// Fusion input wrapper containing the context and the optimization. +/// +/// # Safety +/// +/// This should only be used with the [tuner](cubecl::tune::LocalTuner), since safety assumptions +/// are made based on its behavior. +pub struct TuneInput { + context: UnsafeTuneContext, + optimization: *const O, +} + +/// Unsafe wrapper around the context. +/// +/// # Safety +/// +/// The wrapper removes the context lifetime. +/// +/// For it to be correct, the context must not be used after the invocation of the +/// [cubecl::tune::LocalTuner::execute] function. This is the case, since autotune functions are +/// tuned using a cloned version of the input; therefore, a fork of the context will be used to find +/// the best kernel to use, which can be async. +enum UnsafeTuneContext { + Original(*mut Context<'static, JitFusionHandle>), + Fork(Box>>), +} + +unsafe impl Send for UnsafeTuneContext {} +unsafe impl Send for TuneInput {} + +impl TuneInput { + /// Create a new autotune input from the [context](Context) and an optimization. + pub fn new(context: &mut Context>, optimization: &O) -> Self { + let context = UnsafeTuneContext::new(context); + // We can erase the lifetime for the same reason we do with the context. + let optimization = core::ptr::from_ref(optimization); + + Self { + context, + optimization, + } + } + + /// Retrieve the [autotune context](TuneContext) for the current input. + pub fn context(&self) -> TuneContext<'static, R> { + self.context.get() + } + + /// Retrieve the optimization for the current input. + pub fn optimization(&self) -> &O { + unsafe { self.optimization.as_ref().unwrap() } + } +} + +impl UnsafeTuneContext { + fn new(context: &mut Context<'_, JitFusionHandle>) -> Self { + let ptr = core::ptr::from_mut(context); + + // It is necessary for the lifetime. + #[allow(clippy::unnecessary_cast)] + Self::Original(ptr as *mut Context<'static, _>) + } + + fn get(&self) -> TuneContext<'static, R> { + match self { + UnsafeTuneContext::Original(ptr) => { + TuneContext::Original(unsafe { ptr.as_mut().unwrap() }) + } + UnsafeTuneContext::Fork(context) => TuneContext::Fork(Box::new(context.fork())), + } + } +} + +impl Clone for TuneInput { + fn clone(&self) -> Self { + Self { + context: self.context.clone(), + optimization: self.optimization, + } + } +} + +impl Clone for UnsafeTuneContext { + fn clone(&self) -> Self { + let context = match self { + UnsafeTuneContext::Original(ptr) => { + let context: &mut Context<'static, JitFusionHandle> = + unsafe { ptr.as_mut().unwrap() }; + context.fork() + } + UnsafeTuneContext::Fork(context) => context.fork(), + }; + UnsafeTuneContext::Fork(Box::new(context)) + } +} diff --git a/crates/burn-jit/src/kernel/matmul/tune/key.rs b/crates/burn-jit/src/kernel/matmul/tune/key.rs index d25cce3023..44cb079399 100644 --- a/crates/burn-jit/src/kernel/matmul/tune/key.rs +++ b/crates/burn-jit/src/kernel/matmul/tune/key.rs @@ -22,7 +22,7 @@ pub struct MatmulAutotuneKey { } impl MatmulAutotuneKey { - fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self { + pub(crate) fn from_shape(lhs_shape: &Shape, rhs_shape: &Shape, dtype: DType) -> Self { let ndims = lhs_shape.num_dims(); let m = lhs_shape.dims[ndims - 2]; let k = lhs_shape.dims[ndims - 1]; diff --git a/crates/burn-tensor/src/repr/handle.rs b/crates/burn-tensor/src/repr/handle.rs index 85e18ec444..dce51f5ee2 100644 --- a/crates/burn-tensor/src/repr/handle.rs +++ b/crates/burn-tensor/src/repr/handle.rs @@ -26,6 +26,23 @@ pub struct HandleContainer { pub handles_orphan: Vec, } +impl HandleContainer { + /// Fork the container, useful for autotune. + pub fn fork(&self) -> Self { + let mut handles = HashMap::with_capacity(self.handles.len()); + + for (id, handle) in self.handles.iter() { + handles.insert(*id, handle.clone()); + } + + Self { + handles, + counter: self.counter, + handles_orphan: self.handles_orphan.clone(), + } + } +} + impl core::fmt::Debug for HandleContainer { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { f.debug_struct("HandleContainer") @@ -37,6 +54,7 @@ impl core::fmt::Debug for HandleContainer { } /// Backend [tensor handle](ReprBackend::Handle) wrapper tracking their creation state +#[derive(Clone)] pub enum Handle { /// No [tensor handle](ReprBackend::Handle) has been created yet NotInit, diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index bf12a0b6d9..1be9803a15 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -1,3 +1,5 @@ +#![recursion_limit = "256"] + use burn::{ nn::transformer::TransformerEncoderConfig, optim::{decay::WeightDecayConfig, AdamConfig}, From dd0396dd5bb2d464ddf34ae8ba03664bfd7bcf94 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Wed, 22 Jan 2025 15:29:52 -0500 Subject: [PATCH 42/61] Fix db-pedia-infer backend (#2736) --- .../examples/db-pedia-infer.rs | 32 ++++++------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/examples/text-classification/examples/db-pedia-infer.rs b/examples/text-classification/examples/db-pedia-infer.rs index 490ed3b97e..027eb76122 100644 --- a/examples/text-classification/examples/db-pedia-infer.rs +++ b/examples/text-classification/examples/db-pedia-infer.rs @@ -1,6 +1,6 @@ use text_classification::DbPediaDataset; -use burn::tensor::backend::AutodiffBackend; +use burn::tensor::backend::Backend; #[cfg(not(feature = "f16"))] #[allow(dead_code)] @@ -8,7 +8,7 @@ type ElemType = f32; #[cfg(feature = "f16")] type ElemType = burn::tensor::f16; -pub fn launch(device: B::Device) { +pub fn launch(device: B::Device) { text_classification::inference::infer::( device, "/tmp/text-classification-db-pedia", @@ -34,24 +34,18 @@ pub fn launch(device: B::Device) { feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::{ - ndarray::{NdArray, NdArrayDevice}, - Autodiff, - }; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(NdArrayDevice::Cpu); + launch::>(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::{ - libtorch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::{launch, ElemType}; @@ -61,35 +55,29 @@ mod tch_gpu { #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; - launch::>>(device); + launch::>(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::{ - tch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::tch::{LibTorch, LibTorchDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(LibTorchDevice::Cpu); + launch::>(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { - use burn::backend::{ - wgpu::{Wgpu, WgpuDevice}, - Autodiff, - }; + use burn::backend::wgpu::{Wgpu, WgpuDevice}; use crate::{launch, ElemType}; pub fn run() { - launch::>>(WgpuDevice::default()); + launch::>(WgpuDevice::default()); } } From e40c69b89363de48119d95cdb4b36e76bad9c702 Mon Sep 17 00:00:00 2001 From: Kai Shang <44828636+xmy314@users.noreply.github.com> Date: Wed, 22 Jan 2025 13:06:07 -0800 Subject: [PATCH 43/61] Fixed typo in the burn book chapter advanced unit no-std. (#2731) * Fixed typo in the burn book chapter advanced unit no-std. "deice" -> "device" * Fixed typo in the accompanying example as well. --- burn-book/src/advanced/no-std.md | 4 ++-- examples/raspberry-pi-pico/src/bin/main.rs | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/burn-book/src/advanced/no-std.md b/burn-book/src/advanced/no-std.md index 5f5621cc51..e55afc904d 100644 --- a/burn-book/src/advanced/no-std.md +++ b/burn-book/src/advanced/no-std.md @@ -68,7 +68,7 @@ We are using ndarray, so we just need to define the NdArray backend as usual use burn::{backend::NdArray, tensor::Tensor}; type Backend = NdArray; -type BackendDeice = ::Device; +type BackendDevice = ::Device; ``` Then inside the `main` function add @@ -76,7 +76,7 @@ Then inside the `main` function add use your_model::Model; // Get a default device for the backend -let device = BackendDeice::default(); +let device = BackendDevice::default(); // Create a new model and load the state let model: Model = Model::default(); diff --git a/examples/raspberry-pi-pico/src/bin/main.rs b/examples/raspberry-pi-pico/src/bin/main.rs index 1b7f6acdf0..a502a8193e 100644 --- a/examples/raspberry-pi-pico/src/bin/main.rs +++ b/examples/raspberry-pi-pico/src/bin/main.rs @@ -10,7 +10,7 @@ use embassy_rp as _; use embedded_alloc::Heap; type Backend = NdArray; -type BackendDeice = ::Device; +type BackendDevice = ::Device; #[global_allocator] static HEAP: Heap = Heap::empty(); @@ -25,7 +25,7 @@ async fn main(_spawner: Spawner) { } // Get a default device for the backend - let device = BackendDeice::default(); + let device = BackendDevice::default(); // Create a new model and load the state let model: Model = Model::default(); @@ -47,7 +47,7 @@ async fn main(_spawner: Spawner) { } } -fn run_model<'a>(model: &Model, device: &BackendDeice, input: f32) -> Tensor { +fn run_model<'a>(model: &Model, device: &BackendDevice, input: f32) -> Tensor { // Define the tensor let input = Tensor::::from_floats([[input]], &device); From e73c2d967b14bf1432334b6416407b24857991ca Mon Sep 17 00:00:00 2001 From: quinton Date: Fri, 24 Jan 2025 14:07:15 +0000 Subject: [PATCH 44/61] feat: bitwise-ops-for-tensors (#2498) * feat: bitwise-ops-for-tensors * add bitwise ops for jit * patch: address-requested-changes * feat: jit-binary-int-ops * cargo lock * feat: jit-backend bitwise not unary op * feat: bitwise left shift and right shift ops * patch: resolve review request changes * patch: remove-dtype-int-op-desc * refactor requested changes * Add bitwise int ops to book + remove dead code --------- Co-authored-by: Guillaume Lagrange --- burn-book/src/building-blocks/tensor.md | 145 ++++----- crates/burn-autodiff/src/ops/int_tensor.rs | 44 +++ crates/burn-candle/src/ops/int_tensor.rs | 43 +++ crates/burn-fusion/src/ops/int.rs | 263 +++++++++++++++++ crates/burn-fusion/src/stream/context.rs | 76 +++++ crates/burn-jit/src/element.rs | 1 + crates/burn-jit/src/kernel/binary_int.rs | 276 ++++++++++++++++++ crates/burn-jit/src/kernel/mod.rs | 4 + crates/burn-jit/src/kernel/unary_int.rs | 148 ++++++++++ crates/burn-jit/src/ops/int_ops.rs | 59 +++- crates/burn-jit/src/ops/numeric.rs | 38 ++- crates/burn-ndarray/src/ops/int_tensor.rs | 67 +++++ crates/burn-router/src/ops/op_int.rs | 197 +++++++++++++ crates/burn-router/src/runner.rs | 33 +++ crates/burn-tch/src/ops/base.rs | 114 ++++++++ crates/burn-tch/src/ops/int_tensor.rs | 59 ++++ crates/burn-tensor/src/repr/operation.rs | 77 +++++ crates/burn-tensor/src/tensor/api/int.rs | 55 ++++ .../burn-tensor/src/tensor/ops/int_tensor.rs | 33 +++ crates/burn-tensor/src/tests/mod.rs | 1 + crates/burn-tensor/src/tests/ops/bitwise.rs | 172 +++++++++++ crates/burn-tensor/src/tests/ops/mod.rs | 1 + 22 files changed, 1836 insertions(+), 70 deletions(-) create mode 100644 crates/burn-jit/src/kernel/binary_int.rs create mode 100644 crates/burn-jit/src/kernel/unary_int.rs create mode 100644 crates/burn-tensor/src/tests/ops/bitwise.rs diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index 8a7c01bbc9..410d531d74 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -131,47 +131,47 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. -| Burn | PyTorch Equivalent | -| ------------------------------------- | ------------------------------------------------------------------------- | -| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` | -| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` | -| `Tensor::from_primitive(primitive)` | N/A | -| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` | -| `tensor.all()` | `tensor.all()` | -| `tensor.all_dim(dim)` | `tensor.all(dim)` | -| `tensor.any()` | `tensor.any()` | -| `tensor.any_dim(dim)` | `tensor.any(dim)` | -| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` | -| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` | -| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` | -| `tensor.device()` | `tensor.device` | -| `tensor.dtype()` | `tensor.dtype` | -| `tensor.dims()` | `tensor.size()` | -| `tensor.equal(other)` | `x == y` | -| `tensor.expand(shape)` | `tensor.expand(shape)` | -| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` | -| `tensor.flip(axes)` | `tensor.flip(axes)` | -| `tensor.into_data()` | N/A | -| `tensor.into_primitive()` | N/A | -| `tensor.into_scalar()` | `tensor.item()` | -| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | -| `tensor.not_equal(other)` | `x != y` | -| `tensor.permute(axes)` | `tensor.permute(axes)` | -| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | -| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | -| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | -| `tensor.reshape(shape)` | `tensor.view(shape)` | -| `tensor.shape()` | `tensor.shape` | -| `tensor.slice(ranges)` | `tensor[(*ranges,)]` | -| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` | -| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` | -| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | -| `tensor.to_data()` | N/A | -| `tensor.to_device(device)` | `tensor.to(device)` | -| `tensor.transpose()` | `tensor.T` | -| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` | -| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` | -| `tensor.unsqueeze_dims(dims)` | N/A | +| Burn | PyTorch Equivalent | +| ------------------------------------------- | ------------------------------------------------------------------------- | +| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` | +| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` | +| `Tensor::from_primitive(primitive)` | N/A | +| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` | +| `tensor.all()` | `tensor.all()` | +| `tensor.all_dim(dim)` | `tensor.all(dim)` | +| `tensor.any()` | `tensor.any()` | +| `tensor.any_dim(dim)` | `tensor.any(dim)` | +| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` | +| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` | +| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` | +| `tensor.device()` | `tensor.device` | +| `tensor.dtype()` | `tensor.dtype` | +| `tensor.dims()` | `tensor.size()` | +| `tensor.equal(other)` | `x == y` | +| `tensor.expand(shape)` | `tensor.expand(shape)` | +| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` | +| `tensor.flip(axes)` | `tensor.flip(axes)` | +| `tensor.into_data()` | N/A | +| `tensor.into_primitive()` | N/A | +| `tensor.into_scalar()` | `tensor.item()` | +| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | +| `tensor.not_equal(other)` | `x != y` | +| `tensor.permute(axes)` | `tensor.permute(axes)` | +| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | +| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | +| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | +| `tensor.reshape(shape)` | `tensor.view(shape)` | +| `tensor.shape()` | `tensor.shape` | +| `tensor.slice(ranges)` | `tensor[(*ranges,)]` | +| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` | +| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` | +| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | +| `tensor.to_data()` | N/A | +| `tensor.to_device(device)` | `tensor.to(device)` | +| `tensor.transpose()` | `tensor.T` | +| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` | +| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` | +| `tensor.unsqueeze_dims(dims)` | N/A | ### Numeric Operations @@ -258,32 +258,32 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. Those operations are only available for `Float` tensors. -| Burn API | PyTorch Equivalent | -| --------------------------------------------- | ---------------------------------- | -| `tensor.cast(dtype)` | `tensor.to(dtype)` | -| `tensor.ceil()` | `tensor.ceil()` | -| `tensor.cos()` | `tensor.cos()` | -| `tensor.erf()` | `tensor.erf()` | -| `tensor.exp()` | `tensor.exp()` | -| `tensor.floor()` | `tensor.floor()` | -| `tensor.from_floats(floats, device)` | N/A | -| `tensor.from_full_precision(tensor)` | N/A | -| `tensor.int()` | Similar to `tensor.to(torch.long)` | -| `tensor.log()` | `tensor.log()` | -| `tensor.log1p()` | `tensor.log1p()` | -| `tensor.matmul(other)` | `tensor.matmul(other)` | -| `tensor.random(shape, distribution, device)` | N/A | -| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform | -| `tensor.recip()` | `tensor.reciprocal()` | -| `tensor.round()` | `tensor.round()` | -| `tensor.sin()` | `tensor.sin()` | -| `tensor.sqrt()` | `tensor.sqrt()` | -| `tensor.tanh()` | `tensor.tanh()` | -| `tensor.to_full_precision()` | `tensor.to(torch.float)` | -| `tensor.var(dim)` | `tensor.var(dim)` | -| `tensor.var_bias(dim)` | N/A | -| `tensor.var_mean(dim)` | N/A | -| `tensor.var_mean_bias(dim)` | N/A | +| Burn API | PyTorch Equivalent | +| -------------------------------------------- | ---------------------------------- | +| `tensor.cast(dtype)` | `tensor.to(dtype)` | +| `tensor.ceil()` | `tensor.ceil()` | +| `tensor.cos()` | `tensor.cos()` | +| `tensor.erf()` | `tensor.erf()` | +| `tensor.exp()` | `tensor.exp()` | +| `tensor.floor()` | `tensor.floor()` | +| `tensor.from_floats(floats, device)` | N/A | +| `tensor.from_full_precision(tensor)` | N/A | +| `tensor.int()` | Similar to `tensor.to(torch.long)` | +| `tensor.log()` | `tensor.log()` | +| `tensor.log1p()` | `tensor.log1p()` | +| `tensor.matmul(other)` | `tensor.matmul(other)` | +| `tensor.random(shape, distribution, device)` | N/A | +| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform | +| `tensor.recip()` | `tensor.reciprocal()` | +| `tensor.round()` | `tensor.round()` | +| `tensor.sin()` | `tensor.sin()` | +| `tensor.sqrt()` | `tensor.sqrt()` | +| `tensor.tanh()` | `tensor.tanh()` | +| `tensor.to_full_precision()` | `tensor.to(torch.float)` | +| `tensor.var(dim)` | `tensor.var(dim)` | +| `tensor.var_bias(dim)` | N/A | +| `tensor.var_mean(dim)` | N/A | +| `tensor.var_mean_bias(dim)` | N/A | ### Int Operations @@ -293,6 +293,17 @@ Those operations are only available for `Int` tensors. | ------------------------------------------------ | ------------------------------------------------------- | | `Tensor::arange(5..10, device)` | `tensor.arange(start=5, end=10, device=device)` | | `Tensor::arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` | +| `tensor.bitwise_and(other)` | `torch.bitwise_and(tensor, other)` | +| `tensor.bitwise_and_scalar(scalar)` | `torch.bitwise_and(tensor, scalar)` | +| `tensor.bitwise_not()` | `torch.bitwise_not(tensor)` | +| `tensor.bitwise_left_shift(other)` | `torch.bitwise_left_shift(tensor, other)` | +| `tensor.bitwise_left_shift_scalar(scalar)` | `torch.bitwise_left_shift(tensor, scalar)` | +| `tensor.bitwise_right_shift(other)` | `torch.bitwise_right_shift(tensor, other)` | +| `tensor.bitwise_right_shift_scalar(scalar)` | `torch.bitwise_right_shift(tensor, scalar)` | +| `tensor.bitwise_or(other)` | `torch.bitwise_or(tensor, other)` | +| `tensor.bitwise_or_scalar(scalar)` | `torch.bitwise_or(tensor, scalar)` | +| `tensor.bitwise_xor(other)` | `torch.bitwise_xor(tensor, other)` | +| `tensor.bitwise_xor_scalar(scalar)` | `torch.bitwise_xor(tensor, scalar)` | | `tensor.float()` | `tensor.to(torch.float)` | | `tensor.from_ints(ints)` | N/A | | `tensor.int_random(shape, distribution, device)` | N/A | diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs index 4aad98bb46..f3439d1cad 100644 --- a/crates/burn-autodiff/src/ops/int_tensor.rs +++ b/crates/burn-autodiff/src/ops/int_tensor.rs @@ -348,4 +348,48 @@ impl IntTensorOps for Autodiff { fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { B::int_argsort(tensor, dim, descending) } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_and(lhs, rhs) + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_and_scalar(lhs, rhs) + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_or(lhs, rhs) + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_or_scalar(lhs, rhs) + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_xor(lhs, rhs) + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_xor_scalar(lhs, rhs) + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + B::bitwise_not(tensor) + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_left_shift(lhs, rhs) + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_left_shift_scalar(lhs, rhs) + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + B::bitwise_right_shift(lhs, rhs) + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: B::IntElem) -> IntTensor { + B::bitwise_right_shift_scalar(lhs, rhs) + } } diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs index 4ae0c53de7..08b84251fa 100644 --- a/crates/burn-candle/src/ops/int_tensor.rs +++ b/crates/burn-candle/src/ops/int_tensor.rs @@ -372,4 +372,47 @@ impl IntTensorOps for Candle) -> IntTensor { sign(tensor) } + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_and is not implemented for Candle IntTensor"); + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_and_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_or is not implemented for Candle IntTensor"); + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_or_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_xor is not implemented for Candle IntTensor"); + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_xor_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + unimplemented!("bitwise_not is not implemented for Candle IntTensor"); + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_left_shift is not implemented for Candle IntTensor"); + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + unimplemented!("bitwise_right_shift is not implemented for Candle IntTensor"); + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_left_shift_scalar is not implemented for Candle IntTensor"); + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + unimplemented!("bitwise_right_shift_scalar is not implemented for Candle IntTensor"); + } } diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index bdb47df02c..e2115cbf6a 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -1819,4 +1819,267 @@ impl IntTensorOps for Fusion { out } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseAndOps, B::bitwise_and); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseAnd(desc.clone())), + BitwiseAndOps::::new(desc), + ); + + out + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseAndOps, B::bitwise_and_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseAndScalar( + desc.clone(), + )), + BitwiseAndOps::::new(desc), + ); + + out + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseOrOps, B::bitwise_or); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseOr(desc.clone())), + BitwiseOrOps::::new(desc), + ); + + out + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseOrOps, B::bitwise_or_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseOrScalar(desc.clone())), + BitwiseOrOps::::new(desc), + ); + + out + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseXorOps, B::bitwise_xor); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseXor(desc.clone())), + BitwiseXorOps::::new(desc), + ); + + out + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseXorOps, B::bitwise_xor_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseXorScalar( + desc.clone(), + )), + BitwiseXorOps::::new(desc), + ); + + out + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + unary_int_ops!(BitwiseNotOps, B::bitwise_not); + + let stream = tensor.stream; + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseNot(desc.clone())), + BitwiseNotOps::::new(desc), + ); + + out + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseLeftShift( + desc.clone(), + )), + BitwiseLeftShiftOps::::new(desc), + ); + + out + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseLeftShiftScalar( + desc.clone(), + )), + BitwiseLeftShiftOps::::new(desc), + ); + + out + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseRightShift( + desc.clone(), + )), + BitwiseRightShiftOps::::new(desc), + ); + + out + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseRightShiftScalar( + desc.clone(), + )), + BitwiseRightShiftOps::::new(desc), + ); + + out + } } diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 0f9c75fc94..d85e06cc09 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -694,6 +694,82 @@ impl RelativeOps for IntOperationDescription { out: desc.out.to_relative(converter), }) } + IntOperationDescription::BitwiseAnd(desc) => { + IntOperationDescription::BitwiseAnd(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseAndScalar(desc) => { + IntOperationDescription::BitwiseAndScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseOr(desc) => { + IntOperationDescription::BitwiseOr(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseOrScalar(desc) => { + IntOperationDescription::BitwiseOrScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseXor(desc) => { + IntOperationDescription::BitwiseXor(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseXorScalar(desc) => { + IntOperationDescription::BitwiseXorScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseNot(desc) => { + IntOperationDescription::BitwiseNot(UnaryOperationDescription { + input: desc.input.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseLeftShift(desc) => { + IntOperationDescription::BitwiseLeftShift(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseLeftShiftScalar(desc) => { + IntOperationDescription::BitwiseLeftShiftScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseRightShift(desc) => { + IntOperationDescription::BitwiseRightShift(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseRightShiftScalar(desc) => { + IntOperationDescription::BitwiseRightShiftScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } } } } diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs index f0e15352cf..a1bbab7f5f 100644 --- a/crates/burn-jit/src/element.rs +++ b/crates/burn-jit/src/element.rs @@ -57,6 +57,7 @@ impl IntElement for i64 {} impl IntElement for i32 {} impl IntElement for i16 {} impl IntElement for i8 {} +impl IntElement for u32 {} impl BoolElement for u8 {} impl BoolElement for u32 {} diff --git a/crates/burn-jit/src/kernel/binary_int.rs b/crates/burn-jit/src/kernel/binary_int.rs new file mode 100644 index 0000000000..06706a7d28 --- /dev/null +++ b/crates/burn-jit/src/kernel/binary_int.rs @@ -0,0 +1,276 @@ +use crate::{ops::numeric::empty_device, tensor::JitTensor, IntElement, JitRuntime}; +use burn_tensor::Shape; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +use super::into_contiguous; + +pub(crate) trait BinaryOpIntFamily: Send + Sync + 'static { + type BinaryOp: BinaryOpInt; +} + +#[cube] +pub(crate) trait BinaryOpInt: 'static + Send + Sync { + /// Execute a binary operation. + fn execute(lhs: Line, rhs: Line) -> Line; +} + +pub(crate) struct BitwiseAndOp; +pub(crate) struct BitwiseOrOp; +pub(crate) struct BitwiseXorOp; +pub(crate) struct BitwiseShrOp; +pub(crate) struct BitwiseShlOp; + +impl BinaryOpIntFamily for BitwiseAndOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseOrOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseXorOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseShrOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseShlOp { + type BinaryOp = Self; +} + +#[cube] +impl BinaryOpInt for BitwiseAndOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs & rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseOrOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs | rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseXorOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs ^ rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseShrOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs >> rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseShlOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs << rhs + } +} + +#[cube(launch_unchecked)] +pub(crate) fn kernel_scalar_binop_int( + input: &Tensor>, + scalar: C, + output: &mut Tensor>, +) { + if ABSOLUTE_POS >= output.len() { + return; + } + + output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); +} + +#[cube(launch_unchecked)] +pub(crate) fn kernel_binop_int( + lhs: &Tensor>, + rhs: &Tensor>, + out: &mut Tensor>, + #[comptime] rank: Option, + #[comptime] to_contiguous_lhs: bool, + #[comptime] to_contiguous_rhs: bool, +) { + let offset_out = ABSOLUTE_POS; + let mut offset_lhs = ABSOLUTE_POS; + let mut offset_rhs = ABSOLUTE_POS; + + if offset_out >= out.len() { + return; + } + + if to_contiguous_lhs { + offset_lhs = index_offset_with_layout::( + lhs, + out, + offset_out, + 0, + rank.unwrap_or_else(|| out.rank()), + rank.is_some(), + ); + } + + if to_contiguous_rhs { + offset_rhs = index_offset_with_layout::( + rhs, + out, + offset_out, + 0, + rank.unwrap_or_else(|| out.rank()), + rank.is_some(), + ); + } + + out[offset_out] = O::BinaryOp::::execute(lhs[offset_lhs], rhs[offset_rhs]); +} + +pub(crate) fn launch_binop_int( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + let ndims = lhs.shape.num_dims(); + let line_size_lhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &lhs.shape.dims, + &lhs.strides, + ndims - 1, + ); + let line_size_rhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &rhs.shape.dims, + &rhs.strides, + ndims - 1, + ); + let line_size = Ord::min(line_size_lhs, line_size_rhs); + + let mut shape_out = vec![0; ndims]; + lhs.shape + .dims + .iter() + .zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + + let shape_out = Shape::from(shape_out); + let client = lhs.client.clone(); + let num_elems = shape_out.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if lhs.can_mut_broadcast(&rhs) { + kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(0), + None, + false, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + ); + + lhs + } else if rhs.can_mut_broadcast(&lhs) { + kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(1), + None, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + false, + ); + + rhs + } else { + let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); + let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; + let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape; + + kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + None, + to_contiguous_lhs, + to_contiguous_rhs, + ); + + output + } + } +} + +pub(crate) fn launch_scalar_binop_int( + mut tensor: JitTensor, + scalar: E, +) -> JitTensor { + if !tensor.is_contiguous_buffer() { + tensor = into_contiguous(tensor); + } + + // Vectorization is only enabled when the last dimension is contiguous. + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if tensor.can_mut() { + kernel_scalar_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + TensorArg::alias(0), + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + kernel_scalar_binop_int::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + output.as_tensor_arg::(line_size), + ); + + output + } + } +} diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index fd23cd2e2d..93d2833976 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -1,4 +1,5 @@ mod binary; +mod binary_int; mod cast; mod clamp; mod comparison; @@ -6,13 +7,16 @@ mod contiguous; mod index; mod mask; mod unary_float; +mod unary_int; mod unary_numeric; pub(crate) use binary::*; +pub(crate) use binary_int::*; pub use cast::*; pub use contiguous::*; pub use mask::*; pub(crate) use unary_float::*; +pub(crate) use unary_int::*; pub(crate) use unary_numeric::*; pub use burn_common::PLANE_DIM_APPROX; diff --git a/crates/burn-jit/src/kernel/unary_int.rs b/crates/burn-jit/src/kernel/unary_int.rs new file mode 100644 index 0000000000..5e60898699 --- /dev/null +++ b/crates/burn-jit/src/kernel/unary_int.rs @@ -0,0 +1,148 @@ +use crate::{ops::numeric::empty_device, tensor::JitTensor, IntElement, JitRuntime}; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +pub(crate) trait IntUnaryOpFamily: 'static + Send + Sync { + type Options: LaunchArg; + type Unary: IntUnaryOp>; +} + +#[cube] +pub(crate) trait IntUnaryOp: 'static + Send + Sync { + type Options: LaunchArg; + + fn execute(input: Line, options: &Self::Options) -> Line; +} + +#[cube(launch_unchecked)] +pub(crate) fn unary_int( + input: &Tensor>, + output: &mut Tensor>, + options: &O::Options, + #[comptime] rank: Option, + #[comptime] to_contiguous: bool, +) { + let offset_output = ABSOLUTE_POS; + + if offset_output >= output.len() { + return; + } + + if comptime![to_contiguous] { + let offset_input = index_offset_with_layout::( + input, + output, + offset_output, + 0, + rank.unwrap_or_else(|| output.rank()), + rank.is_some(), + ); + + output[offset_output] = O::Unary::::execute(input[offset_input], options); + } else { + output[offset_output] = O::Unary::::execute(input[offset_output], options); + } +} + +pub(crate) fn launch_unary_int(tensor: JitTensor, args: Args) -> JitTensor +where + for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, + R: JitRuntime, + E: IntElement + Int, + O: IntUnaryOpFamily, +{ + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + let is_contiguous = tensor.is_contiguous(); + + unsafe { + if tensor.can_mut() && tensor.is_contiguous_buffer() { + unary_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + TensorArg::alias(0), + args(&()), + None, + false, + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + unary_int::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + args(&()), + Some(ndims as u32), + !is_contiguous, + ); + output + } + } +} + +pub(crate) mod unary_basic_int { + + use super::*; + + pub(crate) fn launch(tensor: JitTensor, args: Args) -> JitTensor + where + R: JitRuntime, + for<'a> Args: FnOnce(&'a ()) -> &'a BasicIntUnaryKind, + I: IntElement, + { + launch_unary_int::(tensor, |input| { + BasicIntUnaryOptionsLaunch::new(args(input)) + }) + } + + #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)] + pub enum BasicIntUnaryKind { + BitwiseNot, + } + + #[derive(CubeLaunch)] + struct BasicIntUnaryOptions { + #[cube(comptime)] + kind: BasicIntUnaryKind, + } + struct BasicIntUnary; + + #[cube] + impl IntUnaryOp for BasicIntUnary { + type Options = BasicIntUnaryOptions; + + fn execute(input: Line, options: &Self::Options) -> Line { + match comptime![options.kind] { + BasicIntUnaryKind::BitwiseNot => Line::bitwise_not(input), + } + } + } + + impl IntUnaryOpFamily for BasicIntUnary { + type Options = BasicIntUnaryOptions; + type Unary = Self; + } +} diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 5702a90849..d7b84e5e64 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -1,5 +1,10 @@ +use self::unary_basic_int::BasicIntUnaryKind; + use super::{expand, numeric, permute}; -use crate::kernel::{launch_unary_numeric, reduce, NumericUnaryOp, NumericUnaryOpFamily}; +use crate::kernel::{ + launch_binop_int, launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int, + BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, +}; use crate::{ element::BoolElement, kernel::prng::{random_bernoulli, random_normal, random_uniform}, @@ -293,4 +298,56 @@ where fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { kernel::flip::(tensor, axes) } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + numeric::bitwise_and::(lhs, rhs) + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_and_scalar::(lhs, rhs) + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + numeric::bitwise_or::(lhs, rhs) + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_or_scalar(lhs, rhs) + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + numeric::bitwise_xor::(lhs, rhs) + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_xor_scalar(lhs, rhs) + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + unary_basic_int::launch::(tensor, |_| &BasicIntUnaryKind::BitwiseNot) + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let lhs_cast = kernel::cast::(lhs); + let rhs_cast = kernel::cast::(rhs); + launch_binop_int::(lhs_cast, rhs_cast) + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let lhs_cast = kernel::cast::(lhs); + let rhs_cast = rhs.elem::(); + launch_scalar_binop_int::(lhs_cast, rhs_cast) + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let lhs_cast = kernel::cast::(lhs); + let rhs_cast = kernel::cast::(rhs); + launch_binop_int::(lhs_cast, rhs_cast) + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let lhs_cast = kernel::cast::(lhs); + let rhs_cast = rhs.elem::(); + launch_scalar_binop_int::(lhs_cast, rhs_cast) + } } diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index d0d5be8468..2c2c7987ab 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -1,8 +1,9 @@ use crate::kernel::{ - launch_binop, launch_scalar_binop, AddOp, DivOp, MulOp, PowOp, RemainderOp, SubOp, + launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int, AddOp, + BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp, }; use crate::{element::JitElement, tensor::JitTensor}; -use crate::{FloatElement, JitRuntime}; +use crate::{FloatElement, IntElement, JitRuntime}; use burn_tensor::{ElementConversion, Shape}; use cubecl::client::ComputeClient; use cubecl::tensor_vectorization_factor; @@ -139,3 +140,36 @@ pub fn remainder_scalar(lhs: JitTensor, rhs: E) pub fn pow(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::>(lhs, rhs) } + +pub fn bitwise_and( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_and_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) +} + +pub fn bitwise_or( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_or_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) +} + +pub fn bitwise_xor( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_xor_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) +} diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index 9009b5c4a8..43c7cdb100 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -351,4 +351,71 @@ impl IntTensorOps fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { NdArrayOps::expand(tensor, shape) } + + fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() & (b.elem::())).elem() + }) + } + + fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() & rhs.elem::()).elem() + }) + } + + fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() | (b.elem::())).elem() + }) + } + + fn bitwise_or_scalar( + lhs: burn_tensor::ops::IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> burn_tensor::ops::IntTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() | rhs.elem::()).elem() + }) + } + + fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() ^ (b.elem::())).elem() + }) + } + + fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() ^ rhs.elem::()).elem() + }) + } + + fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::()).elem()) + } + + fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() << (b.elem::())).elem() + }) + } + + fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() << rhs.elem::()).elem() + }) + } + + fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() >> (b.elem::())).elem() + }) + } + + fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() >> rhs.elem::()).elem() + }) + } } diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index db81602d4f..5d84131e32 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -1173,4 +1173,201 @@ impl IntTensorOps for BackendRouter { out } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseAnd(desc), + )); + + out + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseOr(desc), + )); + + out + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseXor(desc), + )); + + out + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseNot(desc), + )); + + out + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseAndScalar(desc), + )); + + out + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseOrScalar(desc), + )); + + out + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseXorScalar(desc), + )); + + out + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseLeftShift(desc), + )); + + out + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseLeftShiftScalar(desc), + )); + + out + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseRightShift(desc), + )); + + out + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseRightShiftScalar(desc), + )); + + out + } } diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index 04f93a4769..9521cf66ae 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -792,6 +792,39 @@ impl RunnerClient for Runner { let output = B::int_into_float(tensor); handles.register_float_tensor::(&desc.out.id, output); } + IntOperationDescription::BitwiseAnd(desc) => { + binary_int_ops!(handles, desc, B::bitwise_and) + } + IntOperationDescription::BitwiseAndScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_and_scalar) + } + IntOperationDescription::BitwiseOr(desc) => { + binary_int_ops!(handles, desc, B::bitwise_or) + } + IntOperationDescription::BitwiseOrScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_or_scalar) + } + IntOperationDescription::BitwiseXor(desc) => { + binary_int_ops!(handles, desc, B::bitwise_xor) + } + IntOperationDescription::BitwiseXorScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_xor_scalar) + } + IntOperationDescription::BitwiseNot(desc) => { + unary_int_ops!(handles, desc, B::bitwise_not) + } + IntOperationDescription::BitwiseLeftShift(desc) => { + binary_int_ops!(handles, desc, B::bitwise_left_shift) + } + IntOperationDescription::BitwiseRightShift(desc) => { + binary_int_ops!(handles, desc, B::bitwise_right_shift) + } + IntOperationDescription::BitwiseLeftShiftScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_left_shift_scalar) + } + IntOperationDescription::BitwiseRightShiftScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_right_shift_scalar) + } }, OperationDescription::Float(_dtype, op) => match op { FloatOperationDescription::Exp(desc) => { diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index 7b04207871..704c6176cc 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -477,4 +477,118 @@ impl TchOps { pub fn argsort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor { TchTensor::new(tensor.tensor.argsort(dim as i64, descending)) } + + pub fn bitwise_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_and_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_and_tensor_(lhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_and_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_and_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_and_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_and(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_or_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_or_tensor_(lhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_or_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_or_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_or_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_or(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_xor(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_xor_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_xor_tensor_(lhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_xor_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_xor_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_xor_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_xor(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_not(tensor: TchTensor) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_not_().unwrap(), + |tensor| tensor.f_bitwise_not().unwrap(), + ) + } + + pub fn bitwise_left_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_left_shift_(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(), + ) + } + + pub fn bitwise_left_shift_scalar + Clone>( + tensor: TchTensor, + scalar: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| { + tensor + .f_bitwise_left_shift_tensor_scalar_(scalar.clone().into()) + .unwrap() + }, + |tensor| { + tensor + .f_bitwise_left_shift_tensor_scalar(scalar.clone().into()) + .unwrap() + }, + ) + } + + pub fn bitwise_right_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_right_shift_(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(), + ) + } + + pub fn bitwise_right_shift_scalar + Clone>( + tensor: TchTensor, + scalar: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| { + tensor + .f_bitwise_right_shift_tensor_scalar_(scalar.clone().into()) + .unwrap() + }, + |tensor| { + tensor + .f_bitwise_right_shift_tensor_scalar(scalar.clone().into()) + .unwrap() + }, + ) + } } diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index 0da31fe430..0ac829abaf 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -416,4 +416,63 @@ impl IntTensorOps for LibTorch { fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { TchOps::argsort(tensor, dim, descending) } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_and(lhs, rhs) + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_or(lhs, rhs) + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_xor(lhs, rhs) + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + TchOps::bitwise_not(tensor) + } + + fn bitwise_and_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_and_scalar(lhs, rhs) + } + + fn bitwise_or_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_or_scalar(lhs, rhs) + } + + fn bitwise_xor_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_xor_scalar(lhs, rhs) + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_left_shift(lhs, rhs) + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_right_shift(lhs, rhs) + } + + fn bitwise_left_shift_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_left_shift_scalar(lhs, rhs) + } + + fn bitwise_right_shift_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_right_shift_scalar(lhs, rhs) + } } diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 001b9d6e83..0d7fe2493b 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -520,6 +520,50 @@ pub enum NumericOperationDescription { pub enum IntOperationDescription { /// Operation corresponding to [into float](crate::ops::IntTensorOps::int_into_float). IntoFloat(UnaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise and](crate::ops::IntTensorOps::bitwise_and). + BitwiseAnd(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise and scalar](crate::ops::IntTensorOps::bitwise_and_scalar). + BitwiseAndScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise or](crate::ops::IntTensorOps::bitwise_or). + BitwiseOr(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise or scalar](crate::ops::IntTensorOps::bitwise_or_scalar). + BitwiseOrScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise xor](crate::ops::IntTensorOps::bitwise_xor). + BitwiseXor(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise xor scalar](crate::ops::IntTensorOps::bitwise_xor_scalar). + BitwiseXorScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise not](crate::ops::IntTensorOps::bitwise_not). + BitwiseNot(UnaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise left shift](crate::ops::IntTensorOps::bitwise_left_shift). + BitwiseLeftShift(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise left shift scalar](crate::ops::IntTensorOps::bitwise_left_shift_scalar). + BitwiseLeftShiftScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise right shift](crate::ops::IntTensorOps::bitwise_right_shift). + BitwiseRightShift(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise right shift scalar](crate::ops::IntTensorOps::bitwise_right_shift_scalar). + BitwiseRightShiftScalar(ScalarOperationDescription), } /// Operation description specific to a bool tensor. @@ -1544,6 +1588,39 @@ impl IntOperationDescription { fn nodes(&self) -> Vec<&TensorDescription> { match self { IntOperationDescription::IntoFloat(desc) => vec![&desc.input, &desc.out], + IntOperationDescription::BitwiseAnd(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseAndScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseOr(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseOrScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseXor(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseXorScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseNot(desc) => { + vec![&desc.input, &desc.out] + } + IntOperationDescription::BitwiseLeftShift(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseLeftShiftScalar(desc) => { + vec![&desc.lhs, &desc.out] + } + IntOperationDescription::BitwiseRightShift(desc) => { + vec![&desc.lhs, &desc.rhs, &desc.out] + } + IntOperationDescription::BitwiseRightShiftScalar(desc) => { + vec![&desc.lhs, &desc.out] + } } } } diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs index e882a107c7..5d65b68ceb 100644 --- a/crates/burn-tensor/src/tensor/api/int.rs +++ b/crates/burn-tensor/src/tensor/api/int.rs @@ -99,4 +99,59 @@ where ) -> Tensor { cartesian_grid::(shape, device) } + + /// Applies the bitwise logical and operation with each bit representing the integer. + pub fn bitwise_and(self, other: Self) -> Self { + Self::new(B::bitwise_and(self.primitive, other.primitive)) + } + + /// Applies the bitwise logical or operation with another tensor. + pub fn bitwise_or(self, other: Self) -> Self { + Self::new(B::bitwise_or(self.primitive, other.primitive)) + } + + /// Applies the bitwise logical xor operation with another tensor. + pub fn bitwise_xor(self, other: Self) -> Self { + Self::new(B::bitwise_xor(self.primitive, other.primitive)) + } + + /// Applies the bitwise logical not operation. + pub fn bitwise_not(self) -> Self { + Self::new(B::bitwise_not(self.primitive)) + } + + /// Applies the bitwise logical and operation with each bit in the scalar and the integers in the tensor. + pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_and_scalar(self.primitive, other)) + } + + /// Applies the bitwise logical or operation with each bit in the scalar and the integers in the tensor. + pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_or_scalar(self.primitive, other)) + } + + /// Applies bitwise logical xor operation with each bit in the scalar and the integers in the tensor. + pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_xor_scalar(self.primitive, other)) + } + + /// Applies the bitwise left shift operation with the integers in the tensor. + pub fn bitwise_left_shift(self, other: Self) -> Self { + Self::new(B::bitwise_left_shift(self.primitive, other.primitive)) + } + + /// Applies the bitwise right shift operation with the integers in the tensor. + pub fn bitwise_right_shift(self, other: Self) -> Self { + Self::new(B::bitwise_right_shift(self.primitive, other.primitive)) + } + + /// Applies the bitwise left shift operation with the integers in the tensor. + pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_left_shift_scalar(self.primitive, other)) + } + + /// Applies the bitwise right shift operation with the integers in the tensor. + pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self { + Self::new(B::bitwise_right_shift_scalar(self.primitive, other)) + } } diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index abdd2e54ba..81b73eb2dd 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -1185,4 +1185,37 @@ pub trait IntTensorOps { fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { argsort::(tensor, dim, descending) } + + /// Bitwise AND operation for Int Tensors + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise AND operation for Int Tensors with a scalar + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise OR operation for Int Tensors + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise OR operation for Int Tensors with a scalar + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise XOR operation for Int Tensors + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise XOR operation for Int Tensors with a scalar + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise NOT operation for Int Tensors + fn bitwise_not(tensor: IntTensor) -> IntTensor; + + /// Bitwise left shift operation for Int Tensors + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise left shift operation for Int Tensors with a scalar + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; + + /// Bitwise right shift operation for Int Tensors + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor; + + /// Bitwise right shift operation for Int Tensors with a scalar + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor; } diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index 8aa41ee24d..ee9aec9fe8 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -310,6 +310,7 @@ macro_rules! testgen_with_int_param { burn_tensor::testgen_sub!(); burn_tensor::testgen_transpose!(); burn_tensor::testgen_gather_scatter!(); + burn_tensor::testgen_bitwise!(); // test stats burn_tensor::testgen_eye!(); diff --git a/crates/burn-tensor/src/tests/ops/bitwise.rs b/crates/burn-tensor/src/tests/ops/bitwise.rs new file mode 100644 index 0000000000..73702a716e --- /dev/null +++ b/crates/burn-tensor/src/tests/ops/bitwise.rs @@ -0,0 +1,172 @@ +#[burn_tensor_testgen::testgen(bitwise)] +mod tests { + use super::*; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_apply_bitwise_and_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); + + let output = tensor_1.bitwise_and(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[2, 4, 0], [9, 2, 8]]), false); + } + + #[test] + fn should_apply_bitwise_and_1d() { + let tensor_1 = TestTensorInt::<1>::from([13, 7]); + let tensor_2 = TestTensorInt::from([11, 3]); + + let output = tensor_1.bitwise_and(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([9, 3]), false); + } + + #[test] + fn should_apply_bitwise_and_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 5; + + let output = tensor_1.bitwise_and_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[1, 4, 5], [1, 1, 0]]), false); + } + + #[test] + fn should_apply_bitwise_not_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + + let output = tensor_1.bitwise_not(); + + output + .into_data() + .assert_eq(&TensorData::from([[-4, -5, -6], [-10, -4, -9]]), false); + } + + #[test] + fn should_apply_bitwise_or_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 5; + + let output = tensor_1.bitwise_or_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[7, 5, 5], [13, 7, 13]]), false); + } + + #[test] + fn should_apply_bitwise_or_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); + + let output = tensor_1.bitwise_or(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[7, 7, 13], [9, 11, 15]]), false); + } + + #[test] + fn should_apply_bitwise_or_1d() { + let tensor_1 = TestTensorInt::<1>::from([13, 7]); + let tensor_2 = TestTensorInt::from([11, 3]); + + let output = tensor_1.bitwise_or(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([15, 7]), false); + } + + #[test] + fn should_apply_bitwise_xor_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 5; + + let output = tensor_1.bitwise_xor_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[6, 1, 0], [12, 6, 13]]), false); + } + + #[test] + fn should_apply_bitwise_xor_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]); + + let output = tensor_1.bitwise_xor(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[5, 3, 13], [0, 9, 7]]), false); + } + + #[test] + fn should_apply_bitwise_xor_1d() { + let tensor_1 = TestTensorInt::<1>::from([13, 7]); + let tensor_2 = TestTensorInt::from([11, 3]); + + let output = tensor_1.bitwise_xor(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([6, 4]), false); + } + + #[test] + fn should_apply_bitwise_left_shift_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]); + + let output = tensor_1.bitwise_left_shift(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[6, 16, 40], [144, 96, 512]]), false); + } + + #[test] + fn should_apply_bitwise_left_shift_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 2; + + let output = tensor_1.bitwise_left_shift_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[12, 16, 20], [36, 12, 32]]), false); + } + + #[test] + fn should_apply_bitwise_right_shift_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]); + + let output = tensor_1.bitwise_right_shift(tensor_2); + + output + .into_data() + .assert_eq(&TensorData::from([[1, 1, 0], [0, 0, 0]]), false); + } + + #[test] + fn should_apply_bitwise_right_shift_scalar_2d() { + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); + let scalar = 2; + + let output = tensor_1.bitwise_right_shift_scalar(scalar); + + output + .into_data() + .assert_eq(&TensorData::from([[0, 1, 1], [2, 0, 2]]), false); + } +} diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs index b1096e0216..32bdd0f4ba 100644 --- a/crates/burn-tensor/src/tests/ops/mod.rs +++ b/crates/burn-tensor/src/tests/ops/mod.rs @@ -7,6 +7,7 @@ mod arange; mod arange_step; mod arg; mod argwhere_nonzero; +mod bitwise; mod bool; mod cartesian_grid; mod cast; From e586b1739510ec0a2869e58667cf3c8685d17cdf Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 24 Jan 2025 10:34:26 -0500 Subject: [PATCH 45/61] Fix bce loss log (#2741) --- .../src/nn/loss/binary_cross_entropy.rs | 38 +++++++++++++++++-- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs index f645c84fd9..54b80f4f60 100644 --- a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs +++ b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs @@ -118,9 +118,9 @@ impl BinaryCrossEntropyLoss { (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits) } else { // - (target * log(input) + (1 - target) * log(1 - input)) - (targets_float.clone() * logits.clone().log() - + (targets_float.neg() + 1.) * (logits.neg() + 1.).log()) - .neg() + // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values + (targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0) + - targets_float * logits.log().clamp_min(-100.0) }; if let Some(weights) = &self.weights { @@ -171,6 +171,38 @@ mod tests { use crate::tensor::{activation::sigmoid, TensorData}; use crate::TestBackend; + #[test] + fn test_binary_cross_entropy_preds_all_correct() { + let device = Default::default(); + let preds = Tensor::::from_floats([1.0, 0.0, 1.0, 0.0], &device); + let targets = + Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device); + + let loss_actual = BinaryCrossEntropyLossConfig::new() + .init(&device) + .forward(preds, targets) + .into_data(); + + let loss_expected = TensorData::from([0.000]); + loss_actual.assert_approx_eq(&loss_expected, 3); + } + + #[test] + fn test_binary_cross_entropy_preds_all_incorrect() { + let device = Default::default(); + let preds = Tensor::::from_floats([0.0, 1.0, 0.0, 1.0], &device); + let targets = + Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device); + + let loss_actual = BinaryCrossEntropyLossConfig::new() + .init(&device) + .forward(preds, targets) + .into_data(); + + let loss_expected = TensorData::from([100.000]); // clamped value + loss_actual.assert_approx_eq(&loss_expected, 3); + } + #[test] fn test_binary_cross_entropy() { // import torch From 7ddb5afe49a0c7474ff1bfd74758846f4272c42a Mon Sep 17 00:00:00 2001 From: Maxime Tremblay Date: Fri, 24 Jan 2025 12:17:07 -0500 Subject: [PATCH 46/61] Feat/shared sum (#2737) * bump cubecl version * bump cubecl version * import new specialized sum reduction from cubecl * commit missing autotune key * improve chained reduction * fix reduce shape issue * fix typos and dead code --- Cargo.lock | 28 +++---- Cargo.toml | 4 +- crates/burn-jit/src/kernel/reduce/base.rs | 86 ++++++++++++++++++--- crates/burn-jit/src/kernel/reduce/tune.rs | 93 +++++++++++++++++++++-- crates/burn-jit/src/ops/float_ops.rs | 2 +- crates/burn-jit/src/ops/int_ops.rs | 2 +- crates/burn-jit/src/tune_key.rs | 5 +- 7 files changed, 185 insertions(+), 35 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index de9444c5ff..b5a1835a3f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1575,7 +1575,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1590,7 +1590,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1611,7 +1611,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1632,7 +1632,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "bytemuck", "cubecl-common", @@ -1646,7 +1646,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "bytemuck", "cubecl-common", @@ -1662,7 +1662,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "bytemuck", "cubecl-common", @@ -1688,7 +1688,7 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "cubecl-common", "float-ord", @@ -1701,7 +1701,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "bytemuck", "cubecl-core", @@ -1713,7 +1713,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "cubecl-common", "darling", @@ -1728,7 +1728,7 @@ dependencies = [ [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1744,7 +1744,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1754,7 +1754,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "async-channel", "async-lock", @@ -1776,7 +1776,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1791,7 +1791,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=2a6dd3e60b686230a8f686aafd246342259f7003#2a6dd3e60b686230a8f686aafd246342259f7003" +source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index f731d063a9..7cf3ddd008 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2a6dd3e60b686230a8f686aafd246342259f7003" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs index 9ab1f5d2b6..9ec677ee93 100644 --- a/crates/burn-jit/src/kernel/reduce/base.rs +++ b/crates/burn-jit/src/kernel/reduce/base.rs @@ -1,31 +1,94 @@ -use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime}; +use super::{autotune_reduce, autotune_sum}; +use crate::{ + element::JitElement, + ops::{from_data, numeric::empty_device}, + tensor::JitTensor, + JitRuntime, +}; +use burn_tensor::{Shape, TensorData}; +pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum}; +use cubecl::reduce::shared_sum; -use super::autotune_reduce; +/// Specialize reduce function to compute the sum of all elements of the `input` tensor and return +/// the value into a single-element tensor of shape `1 x 1 x 1 x ...` with the same rank as `input`. +/// +/// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` instruction. +/// +/// Return an error if the `client` doesn't support atomic add for the type `E`. +pub fn sum( + tensor: JitTensor, + cube_count: SumStrategy, +) -> Result, cubecl::reduce::ReduceError> { + let client = tensor.client.clone(); + let device = tensor.device.clone(); -pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum}; + match cube_count { + SumStrategy::OneShot(cube_count) => { + let output = shared_sum::(&client, tensor.as_handle_ref(), cube_count)?; + Ok(from_data::( + TensorData::new(vec![output], vec![1]), + &device, + )) + } + SumStrategy::Chained(strategy) => reduce::(tensor, strategy), + SumStrategy::Autotune => Ok(autotune_sum::(&client, tensor)), + } +} + +/// Select a strategy to perform a sum. +pub enum SumStrategy { + /// Run a single kernel with many cubes working in parallel to sum all elements. + /// The provided value is the number of elements summed per unit (up-to-rounding ) + OneShot(u32), + /// Use multiple kernels + Chained(ReduceStrategy), + /// Use autotune to find the best cube count given the hardware and the input. + #[cfg(feature = "autotune")] + Autotune, +} + +impl Default for SumStrategy { + fn default() -> Self { + #[cfg(feature = "autotune")] + return Self::Autotune; + + #[cfg(not(feature = "autotune"))] + return Self::Static(4); + } +} /// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). /// /// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. -/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid. -/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`. /// /// If there is no error, the output is a tensor with decreasing strides /// where the shape of reduced dim is set to 1 but all shape are similar to the input. pub fn reduce( - mut input: JitTensor, + mut tensor: JitTensor, strategy: ReduceStrategy, ) -> Result, cubecl::reduce::ReduceError> { - input.shape = input.shape.flatten(); - input.strides = vec![1]; - reduce_dim::(input, 0, strategy) + // In practice, it looks like starting by the axis with the smallest shape + // and going in increasing order lead to the fastest calculation. + let sorted_axis = argsort(&tensor.shape.dims); + for axis in sorted_axis { + tensor = reduce_dim::(tensor, axis, strategy)?; + } + // reshape to scalar tensor + tensor.shape = Shape::new([1]); + tensor.strides = vec![1]; + Ok(tensor) +} + +fn argsort(shape: &[usize]) -> Vec { + let mut indices = (0..shape.len()).collect::>(); + indices.sort_by_key(|&i| &shape[i]); + indices } /// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy). /// /// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`. /// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid. -/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`. /// /// If there is no error, the output is a tensor with decreasing strides /// where the shape of reduced dim is set to 1 but all shape are similar to the input. @@ -58,7 +121,8 @@ pub fn reduce_dim { - autotune_reduce::(&client, input, output.clone(), dim) + autotune_reduce::(&client, input, output.clone(), dim); + Ok(()) } }; result.map(|_| output) diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs index b364907238..c397af1a04 100644 --- a/crates/burn-jit/src/kernel/reduce/tune.rs +++ b/crates/burn-jit/src/kernel/reduce/tune.rs @@ -12,7 +12,6 @@ use crate::{ kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::JitTensor, JitAutotuneKey, JitElement, JitRuntime, JitTuneId, }; -use reduce_ops::*; /// Executes autotune on reduce operations. pub fn autotune_reduce< @@ -25,7 +24,9 @@ pub fn autotune_reduce< input: JitTensor, output: JitTensor, dim: usize, -) -> Result<(), cubecl::reduce::ReduceError> { +) { + use reduce_ops::*; + static TUNER: LocalTuner = local_tuner!(); let tunables = TunableSet::new(create_key::, reduce_input_gen::) @@ -40,12 +41,10 @@ pub fn autotune_reduce< &tunables, (input, output, dim), ); - - Ok(()) } #[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] -/// Autotune key representative of redue versions +/// Autotune key representative of reduce versions pub struct ReduceAutotuneKey { dtype: burn_tensor::DType, #[autotune(anchor)] @@ -207,3 +206,87 @@ mod reduce_ops { .map_err(|e| format!("{e}")) } } + +/// Executes autotune on reduce operations. +pub fn autotune_sum( + client: &ComputeClient, + input: JitTensor, +) -> JitTensor { + use sum_ops::*; + + static TUNER: LocalTuner = local_tuner!(); + + let tunables = TunableSet::new(create_key_sum::, sum_input_gen::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_one_shot::) + .with_tunable(sum_chained::); + + TUNER.execute( + &JitTuneId::new::(&input.device), + client, + &tunables, + input, + ) +} + +pub(crate) fn create_key_sum(input: &JitTensor) -> JitAutotuneKey { + JitAutotuneKey::Sum(SumAutotuneKey::generate(input)) +} + +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +/// Autotune key representative of sum versions +pub struct SumAutotuneKey { + dtype: burn_tensor::DType, + #[autotune(anchor)] + length: usize, +} + +impl SumAutotuneKey { + pub(crate) fn generate(input: &JitTensor) -> Self { + let dtype = input.dtype; + let length = input.shape.num_elements(); + Self { dtype, length } + } +} +mod sum_ops { + #![allow(missing_docs)] + + use burn_tensor::TensorData; + use cubecl::reduce::instructions::Sum; + + use crate::ops::from_data; + + use super::*; + + pub(crate) fn sum_input_gen( + _key: &JitAutotuneKey, + input: &JitTensor, + ) -> JitTensor { + let random_bounds: (E, E) = ((-10.0_f32).elem::(), (10.0_f32).elem::()); + random_like_uniform(input, random_bounds.0, random_bounds.1) + } + + pub(crate) fn sum_one_shot( + input: JitTensor, + ) -> Result, String> { + let device = input.device.clone(); + cubecl::reduce::shared_sum::(&input.client, input.as_handle_ref(), C) + .map(|output| from_data::(TensorData::new(vec![output], vec![1]), &device)) + .map_err(|e| e.to_string()) + } + + pub(crate) fn sum_chained( + input: JitTensor, + ) -> Result, String> { + crate::kernel::reduce::reduce::( + input, + crate::kernel::reduce::ReduceStrategy::Autotune, + ) + .map_err(|e| e.to_string()) + } +} diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index d32de97436..17b775361e 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -355,7 +355,7 @@ where execute_with_dtype!( float(tensor.dtype), E, - reduce::reduce::(tensor, Default::default()).unwrap() + reduce::sum::(tensor, Default::default()).unwrap() ) } diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index d7b84e5e64..068c1269d9 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -198,7 +198,7 @@ where } fn int_sum(tensor: IntTensor) -> IntTensor { - reduce::reduce::(tensor, Default::default()).unwrap() + reduce::sum::(tensor, Default::default()).unwrap() } fn int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { diff --git a/crates/burn-jit/src/tune_key.rs b/crates/burn-jit/src/tune_key.rs index cb29e2fe0c..9a86a85483 100644 --- a/crates/burn-jit/src/tune_key.rs +++ b/crates/burn-jit/src/tune_key.rs @@ -1,7 +1,7 @@ use crate::kernel::{ conv::{Conv2dAutotuneKey, ConvTranspose2dAutotuneKey}, matmul::MatmulAutotuneKey, - reduce::ReduceAutotuneKey, + reduce::{ReduceAutotuneKey, SumAutotuneKey}, }; use cubecl::tune::AutotuneKey; use serde::{Deserialize, Serialize}; @@ -14,6 +14,8 @@ pub enum JitAutotuneKey { Matmul(MatmulAutotuneKey), /// Key for reduce dim operations Reduce(ReduceAutotuneKey), + /// Key for sum operations + Sum(SumAutotuneKey), /// Key for convolution operations Conv2d(Conv2dAutotuneKey), /// Key for transpose convolution operations @@ -25,6 +27,7 @@ impl Display for JitAutotuneKey { match self { JitAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), JitAutotuneKey::Reduce(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + JitAutotuneKey::Sum(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), JitAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), JitAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), } From f978e7ba47660c88d1fad5742fc8d360f568b79d Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 24 Jan 2025 15:51:00 -0500 Subject: [PATCH 47/61] Add new FromData operation description (#2735) * Add new FromData operation description * Only hash tensor desc --- crates/burn-fusion/src/ops/boolean.rs | 69 ++++++++++++++++++------ crates/burn-fusion/src/ops/float.rs | 69 ++++++++++++++++++------ crates/burn-fusion/src/ops/int.rs | 68 +++++++++++++++++------ crates/burn-fusion/src/ops/qtensor.rs | 47 +++++++++++----- crates/burn-fusion/src/stream/context.rs | 6 +++ crates/burn-router/src/ops/op_bool.rs | 19 +++++-- crates/burn-router/src/ops/op_float.rs | 27 +++++++--- crates/burn-router/src/ops/op_int.rs | 27 +++++++--- crates/burn-router/src/runner.rs | 12 +++++ crates/burn-tensor/src/repr/operation.rs | 27 ++++++++-- 10 files changed, 286 insertions(+), 85 deletions(-) diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs index baa5169db3..658907bf3e 100644 --- a/crates/burn-fusion/src/ops/boolean.rs +++ b/crates/burn-fusion/src/ops/boolean.rs @@ -1,5 +1,6 @@ use burn_tensor::{ ops::{binary_ops_shape, FloatTensor, IntTensor}, + repr::{FromDataOperationDescription, TensorDescription}, DType, Element, TensorData, }; use std::marker::PhantomData; @@ -24,15 +25,32 @@ use burn_tensor::{ impl BoolTensorOps for Fusion { fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { + #[derive(new)] + struct EmptyOps { + desc: TensorDescription, + device: Device, + } + + impl Operation for EmptyOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::bool_empty(Shape::from(&self.desc.shape), &self.device); + handles.register_bool_tensor::(&self.desc.id, output); + } + } + + let stream = StreamId::current(); let client = get_client::(&device.clone()); - let tensor = B::bool_empty(shape.clone(), device); + let out = client.tensor_uninitialized(shape.dims.clone(), DType::Bool); - client.register_tensor( - B::bool_tensor_handle(tensor), - shape.dims, - StreamId::current(), - DType::Bool, - ) + let desc = out.to_description_out(); + + client.register( + vec![stream], + OperationDescription::BaseBool(BaseOperationDescription::Empty(desc.clone())), + EmptyOps::::new(desc, device.clone()), + ); + + out } async fn bool_into_data(tensor: BoolTensor) -> TensorData { @@ -40,16 +58,35 @@ impl BoolTensorOps for Fusion { } fn bool_from_data(data: burn_tensor::TensorData, device: &Device) -> BoolTensor { + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::bool_from_data(self.desc.data, &self.device); + handles.register_bool_tensor::(&self.desc.out.id, output); + } + } + + let stream = StreamId::current(); let client = get_client::(&device.clone()); - let tensor = B::bool_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); - - client.register_tensor( - B::bool_tensor_handle(tensor), - shape.dims, - StreamId::current(), - DType::Bool, - ) + let out = client.tensor_uninitialized(data.shape.clone(), DType::Bool); + + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseBool(BaseOperationDescription::FromData(desc.clone())), + FromDataOps::::new(desc, device.clone()), + ); + + out } fn bool_into_int(tensor: BoolTensor) -> IntTensor { diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index b3e2a80432..1ba2717bfb 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -16,16 +16,35 @@ use std::{marker::PhantomData, ops::Range}; impl FloatTensorOps for Fusion { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::float_from_data(self.desc.data, &self.device); + handles.register_float_tensor::(&self.desc.out.id, output); + } + } + + let stream = StreamId::current(); let client = get_client::(&device.clone()); - let tensor = B::float_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); - - client.register_tensor( - B::float_tensor_handle(tensor), - shape.dims, - StreamId::current(), - B::FloatElem::dtype(), - ) + let out = client.tensor_uninitialized(data.shape.clone(), B::FloatElem::dtype()); + + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::FromData(desc.clone())), + FromDataOps::::new(desc, device.clone()), + ); + + out } fn float_random( @@ -233,16 +252,32 @@ impl FloatTensorOps for Fusion { } fn float_empty(shape: Shape, device: &Device) -> FloatTensor { - let client = get_client::(&device.clone()); + #[derive(new)] + struct EmptyOps { + desc: TensorDescription, + device: Device, + } + + impl Operation for EmptyOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::float_empty(Shape::from(&self.desc.shape), &self.device); + handles.register_float_tensor::(&self.desc.id, output); + } + } + let stream = StreamId::current(); - let tensor = B::float_empty(shape.clone(), device); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(shape.dims.clone(), B::FloatElem::dtype()); - client.register_tensor( - B::float_tensor_handle(tensor), - shape.dims, - stream, - B::FloatElem::dtype(), - ) + let desc = out.to_description_out(); + + client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::Empty(desc.clone())), + EmptyOps::::new(desc, device.clone()), + ); + + out } fn float_add(lhs: FloatTensor, rhs: FloatTensor) -> FloatTensor { diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index e2115cbf6a..bf88bbd25b 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -15,16 +15,32 @@ use std::marker::PhantomData; impl IntTensorOps for Fusion { fn int_empty(shape: Shape, device: &Device) -> IntTensor { - let client = get_client::(&device.clone()); - let tensor = B::int_empty(shape.clone(), device); + #[derive(new)] + struct EmptyOps { + desc: TensorDescription, + device: Device, + } + + impl Operation for EmptyOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::int_empty(Shape::from(&self.desc.shape), &self.device); + handles.register_int_tensor::(&self.desc.id, output); + } + } + let stream = StreamId::current(); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(shape.dims.clone(), B::IntElem::dtype()); - client.register_tensor( - B::int_tensor_handle(tensor), - shape.dims, - stream, - B::IntElem::dtype(), - ) + let desc = out.to_description_out(); + + client.register( + vec![stream], + OperationDescription::BaseInt(BaseOperationDescription::Empty(desc.clone())), + EmptyOps::::new(desc, device.clone()), + ); + + out } async fn int_into_data(tensor: IntTensor) -> TensorData { @@ -32,17 +48,35 @@ impl IntTensorOps for Fusion { } fn int_from_data(data: TensorData, device: &Device) -> IntTensor { - let client = get_client::(&device.clone()); - let tensor = B::int_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::int_from_data(self.desc.data, &self.device); + handles.register_int_tensor::(&self.desc.out.id, output); + } + } + let stream = StreamId::current(); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(data.shape.clone(), B::IntElem::dtype()); - client.register_tensor( - B::int_tensor_handle(tensor), - shape.dims, - stream, - B::IntElem::dtype(), - ) + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseInt(BaseOperationDescription::FromData(desc.clone())), + FromDataOps::::new(desc, device.clone()), + ); + + out } fn int_device(tensor: &IntTensor) -> Device { diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs index 41bc7ccde6..1449a485af 100644 --- a/crates/burn-fusion/src/ops/qtensor.rs +++ b/crates/burn-fusion/src/ops/qtensor.rs @@ -4,8 +4,9 @@ use burn_tensor::{ ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor}, quantization::{QuantizationParametersPrimitive, QuantizationScheme}, repr::{ - DequantizeOperationDescription, FloatOperationDescription, HandleContainer, - OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription, + BaseOperationDescription, DequantizeOperationDescription, FloatOperationDescription, + FromDataOperationDescription, HandleContainer, OperationDescription, + QuantizationParametersDescription, QuantizeOperationDescription, }, DType, Device, Element, Shape, TensorData, }; @@ -19,19 +20,41 @@ use crate::{ impl QTensorOps for Fusion { fn q_from_data(data: TensorData, device: &Device) -> QuantizedTensor { + #[derive(new)] + struct FromDataOps { + desc: FromDataOperationDescription, + device: Device, + } + + impl Operation for FromDataOps { + fn execute(self: Box, handles: &mut HandleContainer) { + let output = B::q_from_data(self.desc.data, &self.device); + handles.register_quantized_tensor::(&self.desc.out.id, output); + } + } + match data.dtype { DType::QFloat(_scheme) => { let dtype = data.dtype; - let client = get_client::(device); - let tensor = B::q_from_data(data, device); - let shape = burn_tensor::TensorMetadata::shape(&tensor); - - client.register_tensor( - B::quantized_tensor_handle(tensor), - shape.dims, - StreamId::current(), - dtype, - ) + + let stream = StreamId::current(); + let client = get_client::(&device.clone()); + let out = client.tensor_uninitialized(data.shape.clone(), dtype); + + let desc = FromDataOperationDescription { + out: out.to_description_out(), + data, + }; + + client.register( + vec![stream], + OperationDescription::BaseFloat(BaseOperationDescription::FromData( + desc.clone(), + )), + FromDataOps::::new(desc, device.clone()), + ); + + out } _ => panic!( "Invalid dtype (expected DType::QFloat, got {:?})", diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index d85e06cc09..ed1a1902f8 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -1210,6 +1210,12 @@ impl RelativeOps for BaseOperationDescription { BaseOperationDescription::Empty(desc) => { BaseOperationDescription::Empty(desc.to_relative(converter)) } + BaseOperationDescription::FromData(desc) => { + BaseOperationDescription::FromData(FromDataOperationDescription { + data: desc.data.clone(), + out: desc.out.to_relative(converter), + }) + } } } } diff --git a/crates/burn-router/src/ops/op_bool.rs b/crates/burn-router/src/ops/op_bool.rs index 25c46ae854..5d01ddd7d3 100644 --- a/crates/burn-router/src/ops/op_bool.rs +++ b/crates/burn-router/src/ops/op_bool.rs @@ -4,9 +4,9 @@ use burn_tensor::ops::{BoolTensor, BoolTensorOps, FloatElem, FloatTensor, IntEle use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, - OperationDescription, PermuteOperationDescription, RepeatDimOperationDescription, - ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, - SwapDimsDescription, UnaryOperationDescription, + FromDataOperationDescription, OperationDescription, PermuteOperationDescription, + RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription, + SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, }; use burn_tensor::{DType, Device, Element, Shape, TensorData, TensorMetadata}; @@ -31,7 +31,18 @@ impl BoolTensorOps for BackendRouter { fn bool_from_data(data: TensorData, device: &Device) -> BoolTensor { let client = get_client::(device); - client.register_tensor_data(data.convert::()) + let out = client.register_empty_tensor(data.shape.clone(), DType::Bool); + + let desc = FromDataOperationDescription { + data, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseBool( + BaseOperationDescription::FromData(desc), + )); + + out } fn bool_into_int(tensor: BoolTensor) -> IntTensor { diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs index dda01990e0..1cf211701c 100644 --- a/crates/burn-router/src/ops/op_float.rs +++ b/crates/burn-router/src/ops/op_float.rs @@ -8,13 +8,13 @@ use burn_tensor::ops::{ use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription, - MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, - PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - RepeatDimOperationDescription, ReshapeDescription, ScalarOperationDescription, - ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, - SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, - UnaryOperationDescription, + FloatOperationDescription, FromDataOperationDescription, GatherOperationDescription, + MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, + OperationDescription, PermuteOperationDescription, RandomOperationDescription, + ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ReshapeDescription, + ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, + SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, + SwapDimsDescription, UnaryOperationDescription, }; use burn_tensor::{ DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, @@ -25,7 +25,18 @@ use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient}; impl FloatTensorOps for BackendRouter { fn float_from_data(data: TensorData, device: &Device) -> FloatTensor { let client = get_client::(device); - client.register_tensor_data(data.convert::<::FloatElem>()) + let out = client.register_empty_tensor(data.shape.clone(), FloatElem::::dtype()); + + let desc = FromDataOperationDescription { + data, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseFloat( + BaseOperationDescription::FromData(desc), + )); + + out } fn float_random( diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index 5d84131e32..997bf5b9e6 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -8,13 +8,13 @@ use burn_tensor::ops::{ use burn_tensor::repr::{ BaseOperationDescription, BinaryOperationDescription, CatOperationDescription, ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription, - GatherOperationDescription, IntOperationDescription, MaskFillOperationDescription, - MaskWhereOperationDescription, NumericOperationDescription, OperationDescription, - PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription, - RepeatDimOperationDescription, ReshapeDescription, ScalarOperationDescription, - ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription, - SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription, - UnaryOperationDescription, + FromDataOperationDescription, GatherOperationDescription, IntOperationDescription, + MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription, + OperationDescription, PermuteOperationDescription, RandomOperationDescription, + ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ReshapeDescription, + ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription, + SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription, + SwapDimsDescription, UnaryOperationDescription, }; use burn_tensor::{ DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata, @@ -45,7 +45,18 @@ impl IntTensorOps for BackendRouter { fn int_from_data(data: TensorData, device: &Device) -> IntTensor { let client = get_client::(device); - client.register_tensor_data(data.convert::<::IntElem>()) + let out = client.register_empty_tensor(data.shape.clone(), IntElem::::dtype()); + + let desc = FromDataOperationDescription { + data, + out: out.to_description_out(), + }; + + client.register(OperationDescription::BaseInt( + BaseOperationDescription::FromData(desc), + )); + + out } fn int_device(tensor: &IntTensor) -> Device { diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index 9521cf66ae..7443be94f9 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -245,6 +245,10 @@ impl RunnerClient for Runner { let output = B::float_empty(shape, &self.device); handles.register_float_tensor::(&desc.id, output); } + BaseOperationDescription::FromData(desc) => { + let output = B::float_from_data(desc.data.clone(), &self.device); + handles.register_float_tensor::(&desc.out.id, output); + } }, OperationDescription::BaseInt(op) => match op { BaseOperationDescription::ToDevice(_) => unreachable!(), @@ -316,6 +320,10 @@ impl RunnerClient for Runner { let output = B::int_empty(shape, &self.device); handles.register_int_tensor::(&desc.id, output); } + BaseOperationDescription::FromData(desc) => { + let output = B::int_from_data(desc.data.clone(), &self.device); + handles.register_int_tensor::(&desc.out.id, output); + } }, OperationDescription::BaseBool(op) => match op { BaseOperationDescription::ToDevice(_) => unreachable!(), @@ -391,6 +399,10 @@ impl RunnerClient for Runner { let output = B::bool_empty(shape, &self.device); handles.register_bool_tensor::(&desc.id, output); } + BaseOperationDescription::FromData(desc) => { + let output = B::bool_from_data(desc.data.clone(), &self.device); + handles.register_bool_tensor::(&desc.out.id, output); + } }, OperationDescription::NumericFloat(_dtype, op) => match op { NumericOperationDescription::Add(desc) => { diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 0d7fe2493b..e4b0f3ccaf 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -6,6 +6,7 @@ use alloc::borrow::ToOwned; use alloc::boxed::Box; use alloc::{string::String, vec, vec::Vec}; +use crate::TensorData; use crate::{ ops::{ ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions, @@ -197,6 +198,12 @@ pub enum ModuleOperationDescription { /// Basic operations that can be done on any tensor type. #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] pub enum BaseOperationDescription { + /// Operation corresponding to: + /// + /// Float => [from_data](crate::ops::FloatTensorOps::float_from_data). + /// Int => [from_data](crate::ops::IntTensorOps::int_from_data). + /// Bool => [from_data](crate::ops::BoolTensorOps::bool_from_data). + FromData(FromDataOperationDescription), /// Operation corresponding to: /// /// Float => [to device](crate::ops::FloatTensorOps::float_to_device). @@ -272,9 +279,9 @@ pub enum BaseOperationDescription { /// Operation corresponding to: /// - /// Float => [equal](crate::ops::FloatTensorOps::float_empty). - /// Int => [equal](crate::ops::IntTensorOps::int_empty). - /// Bool => [equal](crate::ops::BoolTensorOps::bool_empty). + /// Float => [empty](crate::ops::FloatTensorOps::float_empty). + /// Int => [empty](crate::ops::IntTensorOps::int_empty). + /// Bool => [empty](crate::ops::BoolTensorOps::bool_empty). Empty(TensorDescription), } @@ -630,6 +637,13 @@ pub struct RandomOperationDescription { pub distribution: Distribution, } +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[allow(missing_docs)] +pub struct FromDataOperationDescription { + pub out: TensorDescription, + pub data: TensorData, +} + #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[allow(missing_docs)] pub struct ReshapeDescription { @@ -1408,6 +1422,7 @@ impl BaseOperationDescription { BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(), BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out], BaseOperationDescription::Empty(desc) => vec![desc], + BaseOperationDescription::FromData(desc) => vec![&desc.out], } } } @@ -1754,6 +1769,12 @@ impl ModuleOperationDescription { } } +impl core::hash::Hash for FromDataOperationDescription { + fn hash(&self, state: &mut H) { + self.out.hash(state); + } +} + impl core::hash::Hash for RandomOperationDescription { fn hash(&self, state: &mut H) { self.out.hash(state); From fae641fabf3e3a412a96f9e111c506ffa168d187 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 08:15:13 -0500 Subject: [PATCH 48/61] Combined PRs (#2749) * Bump wgpu from 24.0.0 to 24.0.1 Bumps [wgpu](https://github.com/gfx-rs/wgpu) from 24.0.0 to 24.0.1. - [Release notes](https://github.com/gfx-rs/wgpu/releases) - [Changelog](https://github.com/gfx-rs/wgpu/blob/wgpu-v24.0.1/CHANGELOG.md) - [Commits](https://github.com/gfx-rs/wgpu/compare/v24.0.0...wgpu-v24.0.1) --- updated-dependencies: - dependency-name: wgpu dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump uuid from 1.12.0 to 1.12.1 Bumps [uuid](https://github.com/uuid-rs/uuid) from 1.12.0 to 1.12.1. - [Release notes](https://github.com/uuid-rs/uuid/releases) - [Commits](https://github.com/uuid-rs/uuid/compare/1.12.0...1.12.1) --- updated-dependencies: - dependency-name: uuid dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] * Bump clap from 4.5.26 to 4.5.27 Bumps [clap](https://github.com/clap-rs/clap) from 4.5.26 to 4.5.27. - [Release notes](https://github.com/clap-rs/clap/releases) - [Changelog](https://github.com/clap-rs/clap/blob/master/CHANGELOG.md) - [Commits](https://github.com/clap-rs/clap/compare/clap_complete-v4.5.26...clap_complete-v4.5.27) --- updated-dependencies: - dependency-name: clap dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Cargo.lock | 16 ++++++++-------- Cargo.toml | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b5a1835a3f..f739c341ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1225,9 +1225,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8eb5e908ef3a6efbe1ed62520fb7287959888c88485abe072543190ecc66783" +checksum = "769b0145982b4b48713e01ec42d61614425f27b7058bda7180a3a41f30104796" dependencies = [ "clap_builder", "clap_derive", @@ -1235,9 +1235,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.26" +version = "4.5.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96b01801b5fc6a0a232407abc821660c9c6d25a1cafc0d4f85f29fb8d9afc121" +checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" dependencies = [ "anstream", "anstyle", @@ -7842,9 +7842,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.12.0" +version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744018581f9a3454a9e15beb8a33b017183f1e7c0cd170232a2d1453b23a51c4" +checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" dependencies = [ "getrandom", "rand", @@ -8088,9 +8088,9 @@ dependencies = [ [[package]] name = "wgpu" -version = "24.0.0" +version = "24.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e41253fc7b660735e2a2d9a58c563f2a047d3cc3445293d8f4095538c9e8afbe" +checksum = "47f55718f85c2fa756edffa0e7f0e0a60aba463d1362b57e23123c58f035e4b6" dependencies = [ "arrayvec", "bitflags 2.8.0", diff --git a/Cargo.toml b/Cargo.toml index 7cf3ddd008..af983f62cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ version = "0.17.0" atomic_float = "1" bytemuck = "1.21.0" candle-core = { version = "0.8" } -clap = { version = "4.5.26", features = ["derive"] } +clap = { version = "4.5.27", features = ["derive"] } colored = "2.1.0" console_error_panic_hook = "0.1.7" csv = "1.3.1" @@ -101,7 +101,7 @@ ratatui = "0.29.0" # WGPU stuff text_placeholder = "0.5.1" -wgpu = "24.0.0" +wgpu = "24.0.1" # Benchmarks and Burnbench arboard = "3.4.1" @@ -141,7 +141,7 @@ serde = { version = "1.0.217", default-features = false, features = [ "alloc", ] } # alloc is for no_std, derive is needed serde_json = { version = "1.0.137", default-features = false } -uuid = { version = "1.12.0", default-features = false } +uuid = { version = "1.12.1", default-features = false } libc = "0.2.169" nvml-wrapper = "0.10.0" From 51a29730a1fa23e66bf615effbed26a4ece39ce5 Mon Sep 17 00:00:00 2001 From: Cameron Braid Date: Tue, 28 Jan 2025 00:20:34 +1100 Subject: [PATCH 49/61] typo - correct `smp_serde` to `rmp_serde` as per crate's name in url (#2744) --- burn-book/src/saving-and-loading.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/burn-book/src/saving-and-loading.md b/burn-book/src/saving-and-loading.md index 77f7c863d6..24b52dd22a 100644 --- a/burn-book/src/saving-and-loading.md +++ b/burn-book/src/saving-and-loading.md @@ -4,7 +4,7 @@ Saving your trained machine learning model is quite easy, no matter the output f mentioned in the [Record](./building-blocks/record.md) section, different formats are supported to serialize/deserialize models. By default, we use the `NamedMpkFileRecorder` which uses the [MessagePack](https://msgpack.org/) binary serialization format with the help of -[smp_serde](https://docs.rs/rmp-serde/). +[rmp_serde](https://docs.rs/rmp-serde/). ```rust, ignore // Save model in MessagePack format with full precision From 894fdbc54ab37066af4366fa876036c818a47225 Mon Sep 17 00:00:00 2001 From: Cameron Braid Date: Tue, 28 Jan 2025 00:21:15 +1100 Subject: [PATCH 50/61] typo - missing `tick` which was breaking formatting (#2745) --- burn-book/src/building-blocks/tensor.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index 410d531d74..3713ee2571 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -339,7 +339,7 @@ strategies. | Burn API | PyTorch Equivalent | | ------------------------------------------------ | -------------------------------------------------- | | `activation::gelu(tensor)` | `nn.functional.gelu(tensor)` | -| `activation::hard_sigmoid(tensor, alpha, beta) | `nn.functional.hardsigmoid(tensor)` | +| `activation::hard_sigmoid(tensor, alpha, beta)` | `nn.functional.hardsigmoid(tensor)` | | `activation::leaky_relu(tensor, negative_slope)` | `nn.functional.leaky_relu(tensor, negative_slope)` | | `activation::log_sigmoid(tensor)` | `nn.functional.log_sigmoid(tensor)` | | `activation::log_softmax(tensor, dim)` | `nn.functional.log_softmax(tensor, dim)` | From 29c383b87d58190499c2671eb08f47cd48cc4e43 Mon Sep 17 00:00:00 2001 From: Maxime Tremblay Date: Mon, 27 Jan 2025 09:57:28 -0500 Subject: [PATCH 51/61] Replace return with terminate (#2742) * replace return with terminate * bump cubecl * cargo fmt --- Cargo.lock | 41 ++++++++++++------- Cargo.toml | 4 +- crates/burn-jit/src/kernel/binary.rs | 4 +- crates/burn-jit/src/kernel/binary_int.rs | 4 +- crates/burn-jit/src/kernel/cast/base.rs | 2 +- crates/burn-jit/src/kernel/comparison.rs | 4 +- .../burn-jit/src/kernel/conv/conv2d/col2im.rs | 2 +- .../burn-jit/src/kernel/conv/conv2d/direct.rs | 2 +- .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 2 +- .../src/kernel/conv/conv2d/layout_swap.rs | 4 +- .../kernel/conv/conv2d/transpose_direct.rs | 2 +- crates/burn-jit/src/kernel/conv/conv3d.rs | 2 +- .../kernel/conv/deform_conv_transpose2d.rs | 4 +- crates/burn-jit/src/kernel/index/flip.rs | 2 +- crates/burn-jit/src/kernel/index/gather.rs | 2 +- .../burn-jit/src/kernel/index/repeat_dim.rs | 2 +- crates/burn-jit/src/kernel/index/scatter.rs | 2 +- crates/burn-jit/src/kernel/index/select.rs | 2 +- .../src/kernel/index/select_assign.rs | 2 +- crates/burn-jit/src/kernel/index/slice.rs | 2 +- .../src/kernel/interpolate/bicubic.rs | 2 +- .../src/kernel/interpolate/bilinear.rs | 2 +- .../src/kernel/interpolate/nearest.rs | 2 +- .../kernel/interpolate/nearest_backward.rs | 2 +- crates/burn-jit/src/kernel/mask/mask_fill.rs | 4 +- crates/burn-jit/src/kernel/mask/mask_where.rs | 4 +- .../src/kernel/pool/avg_pool2d_backward.rs | 2 +- .../src/kernel/pool/max_pool2d_backward.rs | 2 +- .../src/kernel/quantization/dequantize.rs | 4 +- .../src/kernel/quantization/quantize.rs | 10 ++--- crates/burn-jit/src/kernel/unary_float.rs | 2 +- crates/burn-jit/src/kernel/unary_int.rs | 2 +- crates/burn-jit/src/kernel/unary_numeric.rs | 2 +- crates/burn-jit/src/ops/numeric.rs | 2 +- examples/custom-cubecl-kernel/src/kernel.rs | 2 +- 35 files changed, 74 insertions(+), 61 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f739c341ad..5f5b948a86 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1575,7 +1575,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1590,7 +1590,7 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "derive-new 0.6.0", @@ -1611,7 +1611,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1632,7 +1632,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1646,7 +1646,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1662,7 +1662,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-common", @@ -1688,9 +1688,11 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", + "cubecl-macros-internal", + "derive_more 1.0.0", "float-ord", "half", "num-traits", @@ -1701,7 +1703,7 @@ dependencies = [ [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bytemuck", "cubecl-core", @@ -1713,7 +1715,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", "darling", @@ -1725,10 +1727,21 @@ dependencies = [ "syn 2.0.96", ] +[[package]] +name = "cubecl-macros-internal" +version = "0.5.0" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 2.0.96", +] + [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1744,7 +1757,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1754,7 +1767,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "async-channel", "async-lock", @@ -1776,7 +1789,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1791,7 +1804,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=a43015e2069e2728274a46242e928db189e56982#a43015e2069e2728274a46242e928db189e56982" +source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" dependencies = [ "ash", "async-channel", diff --git a/Cargo.toml b/Cargo.toml index af983f62cb..db1073ad67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a43015e2069e2728274a46242e928db189e56982" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/kernel/binary.rs b/crates/burn-jit/src/kernel/binary.rs index d7c4d789ab..f0da764a7a 100644 --- a/crates/burn-jit/src/kernel/binary.rs +++ b/crates/burn-jit/src/kernel/binary.rs @@ -112,7 +112,7 @@ pub(crate) fn kernel_scalar_binop( output: &mut Tensor>, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); @@ -132,7 +132,7 @@ pub(crate) fn kernel_binop( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git a/crates/burn-jit/src/kernel/binary_int.rs b/crates/burn-jit/src/kernel/binary_int.rs index 06706a7d28..390bfc479e 100644 --- a/crates/burn-jit/src/kernel/binary_int.rs +++ b/crates/burn-jit/src/kernel/binary_int.rs @@ -85,7 +85,7 @@ pub(crate) fn kernel_scalar_binop_int( output: &mut Tensor>, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); @@ -105,7 +105,7 @@ pub(crate) fn kernel_binop_int( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git a/crates/burn-jit/src/kernel/cast/base.rs b/crates/burn-jit/src/kernel/cast/base.rs index 798b79a0f0..43b24f071a 100644 --- a/crates/burn-jit/src/kernel/cast/base.rs +++ b/crates/burn-jit/src/kernel/cast/base.rs @@ -12,7 +12,7 @@ pub(crate) fn cast_element( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } let offset_input = index_offset_with_layout::( diff --git a/crates/burn-jit/src/kernel/comparison.rs b/crates/burn-jit/src/kernel/comparison.rs index e33687fb5a..a6de9025bb 100644 --- a/crates/burn-jit/src/kernel/comparison.rs +++ b/crates/burn-jit/src/kernel/comparison.rs @@ -82,7 +82,7 @@ pub(crate) fn kernel_scalar_cmp>( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } output[ABSOLUTE_POS] = Line::cast_from(O::execute(input[ABSOLUTE_POS], Line::new(scalar))); @@ -102,7 +102,7 @@ pub(crate) fn kernel_cmp>( let mut offset_rhs = ABSOLUTE_POS; if offset_out >= out.len() { - return; + terminate!(); } if to_contiguous_lhs { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs index 11fb3b4aee..4f6931f86d 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -241,7 +241,7 @@ fn col2im_kernel( #[comptime] has_bias: bool, ) { if ABSOLUTE_POS >= image.len() { - return; + terminate!(); } let im_x = ABSOLUTE_POS % image.shape(3) + args.pad_w; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs index c724cfc3a3..1cd24f7c0c 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs @@ -35,7 +35,7 @@ fn direct_conv2d_kernel( #[comptime] kernel_size_1_unroll: Option, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_channels = weight.shape(1); diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index 6b738ab988..f74cdaf8bc 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -53,7 +53,7 @@ fn im2col_kernel( let out_w = args.out_w; if ABSOLUTE_POS > args.num_elements { - return; + terminate!(); } let out_x = ABSOLUTE_POS % out_w; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs index 62f0e56d8f..7cbe09dbc0 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/layout_swap.rs @@ -107,7 +107,7 @@ fn nchw_to_nhwc_kernel( let batch = CUBE_POS_Z; if batch >= input.shape(0) { - return; + terminate!(); } let batch_offset = batch * input.stride(0); @@ -163,7 +163,7 @@ fn nchw_to_nhwc_kernel( let hw = base_hw + mat_hw; if hw >= shape_hw { - return; + terminate!(); } let mat_c_start = mat_hw_start; diff --git a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index d3e91d5947..a8cd1ceb7f 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -32,7 +32,7 @@ fn conv_transpose2d_direct_kernel( args: ConvArgs, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_c_per_group = weight.shape(0) / args.groups; diff --git a/crates/burn-jit/src/kernel/conv/conv3d.rs b/crates/burn-jit/src/kernel/conv/conv3d.rs index 157610794b..a616c432b9 100644 --- a/crates/burn-jit/src/kernel/conv/conv3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv3d.rs @@ -41,7 +41,7 @@ fn conv3d_kernel( #[comptime] kernel_size_2_unroll: Option, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let in_channels = weight.shape(1); diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs index ddee1360e4..5840f4dc9f 100644 --- a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -275,7 +275,7 @@ fn deform_col2img_coord_kernel( // Alternatively : [batch, offset_channels, out_h, out_w] if ABSOLUTE_POS >= grad_offset.len() { - return; + terminate!(); } let offset_channels = offset.shape(1); @@ -551,7 +551,7 @@ fn deform_col2img_kernel( ) { // Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] if ABSOLUTE_POS >= columns.len() { - return; + terminate!(); } let n_in_channels = args.in_channels; diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index 583e0346d3..a682a76eac 100644 --- a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -11,7 +11,7 @@ fn flip_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index 9e9b5685bb..c1aa56072e 100644 --- a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -12,7 +12,7 @@ fn gather_kernel( dim: &u32, ) { if ABSOLUTE_POS >= indices.len() { - return; + terminate!(); } let index = indices[ABSOLUTE_POS]; diff --git a/crates/burn-jit/src/kernel/index/repeat_dim.rs b/crates/burn-jit/src/kernel/index/repeat_dim.rs index 3887bfbd8b..b19f9e2b21 100644 --- a/crates/burn-jit/src/kernel/index/repeat_dim.rs +++ b/crates/burn-jit/src/kernel/index/repeat_dim.rs @@ -4,7 +4,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*}; #[cube(launch_unchecked)] fn repeat_dim_kernel(input: &Tensor, output: &mut Tensor, dim: u32) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs index 4ddd9c00fb..4cca94f824 100644 --- a/crates/burn-jit/src/kernel/index/scatter.rs +++ b/crates/burn-jit/src/kernel/index/scatter.rs @@ -46,7 +46,7 @@ fn scatter_kernel( let should_stop = ABSOLUTE_POS >= num_elems; if should_stop { - return; + terminate!(); } for i in 0..shape_value { diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index b104bf504f..fe664ab420 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -10,7 +10,7 @@ fn select_kernel( dim: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index a0fed49dbd..cd4c013f63 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -29,7 +29,7 @@ fn select_assign_kernel( } if ABSOLUTE_POS >= num_elems { - return; + terminate!(); } let strides_tensor_dim = tensor.stride(dim); diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index 7f20f033b8..b6daba8da5 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -52,7 +52,7 @@ fn slice_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let mut offset_input = 0; diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index 1d545d79c7..3f77ef1302 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch)] fn interpolate_bicubic_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/bilinear.rs b/crates/burn-jit/src/kernel/interpolate/bilinear.rs index 3557fcdbb8..f0cb95b536 100644 --- a/crates/burn-jit/src/kernel/interpolate/bilinear.rs +++ b/crates/burn-jit/src/kernel/interpolate/bilinear.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch)] fn interpolate_bilinear_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs index 0743a13567..0e6ba32552 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch_unchecked)] fn interpolate_nearest_kernel(input: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs index 5ea860a7ae..f0442ec92e 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs @@ -5,7 +5,7 @@ use crate::{tensor::JitTensor, FloatElement, JitRuntime}; #[cube(launch_unchecked)] fn interpolate_nearest_backward_kernel(grad: &Tensor, output: &mut Tensor) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let out_h = output.shape(2); diff --git a/crates/burn-jit/src/kernel/mask/mask_fill.rs b/crates/burn-jit/src/kernel/mask/mask_fill.rs index 386e7a5039..95096c7994 100644 --- a/crates/burn-jit/src/kernel/mask/mask_fill.rs +++ b/crates/burn-jit/src/kernel/mask/mask_fill.rs @@ -16,7 +16,7 @@ fn mask_fill_readonly_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); @@ -35,7 +35,7 @@ fn mask_fill_inplace_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { - return; + terminate!(); } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); diff --git a/crates/burn-jit/src/kernel/mask/mask_where.rs b/crates/burn-jit/src/kernel/mask/mask_where.rs index 5518e9648b..99384fde98 100644 --- a/crates/burn-jit/src/kernel/mask/mask_where.rs +++ b/crates/burn-jit/src/kernel/mask/mask_where.rs @@ -16,7 +16,7 @@ fn mask_where_readonly_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let index_input = index_offset_with_layout(input, output, ABSOLUTE_POS, 0, rank, true); @@ -36,7 +36,7 @@ fn mask_where_inplace_kernel( #[comptime] rank: u32, ) { if ABSOLUTE_POS >= input.len() { - return; + terminate!(); } let index_mask = index_offset_with_layout(mask, input, ABSOLUTE_POS, 0, rank, true); diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs index bba68c7166..d2a5a21d0a 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs @@ -24,7 +24,7 @@ fn avg_pool2d_backward_kernel( #[comptime] count_include_pad: bool, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index 6da6e2b37c..40259c4573 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -16,7 +16,7 @@ fn max_pool2d_with_indices_backward_kernel( #[comptime] kernel_size_1: i32, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let batch = ABSOLUTE_POS / output.stride(0) % output.shape(0); diff --git a/crates/burn-jit/src/kernel/quantization/dequantize.rs b/crates/burn-jit/src/kernel/quantization/dequantize.rs index 72040d8839..270e32f854 100644 --- a/crates/burn-jit/src/kernel/quantization/dequantize.rs +++ b/crates/burn-jit/src/kernel/quantization/dequantize.rs @@ -48,7 +48,7 @@ pub(crate) fn dequantize_per_tensor_affine_int8_kernel( ) { // Last two positions contain the qparams if ABSOLUTE_POS >= input.len() - 2 { - return; + terminate!(); } let qparams = QParams::new(scheme); @@ -85,7 +85,7 @@ pub(crate) fn dequantize_per_tensor_symmetric_int8_kernel( ) { // Last position contains the qparam if ABSOLUTE_POS >= input.len() - 1 { - return; + terminate!(); } let qparams = QParams::new(scheme); diff --git a/crates/burn-jit/src/kernel/quantization/quantize.rs b/crates/burn-jit/src/kernel/quantization/quantize.rs index e9494aa987..0a7b0ea553 100644 --- a/crates/burn-jit/src/kernel/quantization/quantize.rs +++ b/crates/burn-jit/src/kernel/quantization/quantize.rs @@ -34,7 +34,7 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let scale = scale[0]; @@ -43,13 +43,13 @@ pub(crate) fn quantize_per_tensor_affine_int8_kernel( // Cast the scale to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 1 { output[ABSOLUTE_POS] = u32::bitcast_from(scale); - return; + terminate!(); } // Cast the offset to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 2 { output[ABSOLUTE_POS] = u32::bitcast_from(offset); - return; + terminate!(); } let line_size = comptime!(input.line_size()); @@ -120,7 +120,7 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( output: &mut Array, ) { if ABSOLUTE_POS >= output.len() { - return; + terminate!(); } let scale = scale[0]; @@ -128,7 +128,7 @@ pub(crate) fn quantize_per_tensor_symmetric_int8_kernel( // Cast the scale to u32 and write the value in the output if ABSOLUTE_POS == output.len() - 1 { output[ABSOLUTE_POS] = u32::bitcast_from(scale); - return; + terminate!(); } let line_size = comptime!(input.line_size()); diff --git a/crates/burn-jit/src/kernel/unary_float.rs b/crates/burn-jit/src/kernel/unary_float.rs index 33a311ecbc..4664d3c0b3 100644 --- a/crates/burn-jit/src/kernel/unary_float.rs +++ b/crates/burn-jit/src/kernel/unary_float.rs @@ -27,7 +27,7 @@ pub(crate) fn unary_float( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } if comptime![to_contiguous] { diff --git a/crates/burn-jit/src/kernel/unary_int.rs b/crates/burn-jit/src/kernel/unary_int.rs index 5e60898699..17bced52d1 100644 --- a/crates/burn-jit/src/kernel/unary_int.rs +++ b/crates/burn-jit/src/kernel/unary_int.rs @@ -27,7 +27,7 @@ pub(crate) fn unary_int( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } if comptime![to_contiguous] { diff --git a/crates/burn-jit/src/kernel/unary_numeric.rs b/crates/burn-jit/src/kernel/unary_numeric.rs index 0b8dcb2cbc..aaeadbb685 100644 --- a/crates/burn-jit/src/kernel/unary_numeric.rs +++ b/crates/burn-jit/src/kernel/unary_numeric.rs @@ -27,7 +27,7 @@ pub(crate) fn unary_numeric( let offset_output = ABSOLUTE_POS; if offset_output >= output.len() { - return; + terminate!(); } if comptime![to_contiguous] { diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index 2c2c7987ab..cf15916aab 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -31,7 +31,7 @@ pub fn full_device( #[cube(launch)] pub fn full_kernel(tensor: &mut Tensor, value: C) { if ABSOLUTE_POS >= tensor.len() { - return; + terminate!(); } tensor[ABSOLUTE_POS] = value; diff --git a/examples/custom-cubecl-kernel/src/kernel.rs b/examples/custom-cubecl-kernel/src/kernel.rs index 0809971327..08d4ded4d7 100644 --- a/examples/custom-cubecl-kernel/src/kernel.rs +++ b/examples/custom-cubecl-kernel/src/kernel.rs @@ -17,7 +17,7 @@ pub fn fused_matmul_add_relu_kernel( let dim_k = rhs.shape(rhs.rank() - 1); if row >= n_rows || col >= n_cols { - return; + terminate!(); } let offset_output = batch * n_rows * n_cols; From 2d9e9b9a1900729fdf9c643f49d3a18d1d5b937d Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 28 Jan 2025 09:05:48 -0500 Subject: [PATCH 52/61] Clean up -jit suffix in feature flags and modules (#2705) --- backend-comparison/Cargo.toml | 8 +-- backend-comparison/src/lib.rs | 20 +++---- crates/burn-core/Cargo.toml | 17 +++--- crates/burn-core/src/backend.rs | 22 ++++--- crates/burn-hip/src/lib.rs | 5 +- crates/burn-wgpu/Cargo.toml | 9 ++- crates/burn-wgpu/src/lib.rs | 57 ++++++++++++------- crates/burn/Cargo.toml | 7 ++- crates/burn/src/lib.rs | 6 +- .../examples/custom-renderer.rs | 4 +- examples/custom-training-loop/Cargo.toml | 2 +- .../examples/custom-training-loop.rs | 4 +- examples/guide/Cargo.toml | 2 +- examples/guide/src/bin/infer.rs | 4 +- examples/guide/src/bin/print.rs | 4 +- examples/guide/src/bin/train.rs | 4 +- examples/image-classification-web/src/web.rs | 6 +- examples/server/Cargo.toml | 8 +-- examples/server/src/lib.rs | 10 ++-- examples/text-classification/Cargo.toml | 6 +- examples/text-classification/README.md | 4 +- .../examples/ag-news-infer.rs | 12 ++-- .../examples/ag-news-train.rs | 20 +++---- examples/wgan/Cargo.toml | 2 +- examples/wgan/README.md | 4 +- examples/wgan/examples/wgan-generate.rs | 12 ++-- examples/wgan/examples/wgan-mnist.rs | 12 ++-- xtask/src/commands/test.rs | 2 +- 28 files changed, 153 insertions(+), 120 deletions(-) diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index 265dbeaaf0..821d189fe0 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -15,10 +15,10 @@ candle-accelerate = ["burn/candle", "burn/accelerate"] candle-cpu = ["burn/candle"] candle-cuda = ["burn/candle-cuda"] candle-metal = ["burn/candle", "burn/metal"] -cuda-jit = ["burn/cuda-jit"] -cuda-jit-fusion = ["cuda-jit", "burn/fusion"] +cuda = ["burn/cuda"] +cuda-fusion = ["cuda", "burn/fusion"] default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"] -hip-jit = ["burn/hip-jit"] +hip = ["burn/hip"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] @@ -27,7 +27,7 @@ tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu", "burn/autotune"] wgpu-fusion = ["wgpu", "burn/fusion"] -wgpu-spirv = ["burn/wgpu-spirv", "burn/autotune"] +wgpu-spirv = ["burn/vulkan", "burn/autotune"] wgpu-spirv-fusion = ["wgpu-spirv", "burn/fusion"] [dependencies] diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index 26b08bc3b8..b3351e9dd5 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -91,12 +91,12 @@ macro_rules! bench_on_backend { let feature_name = "wgpu-spirv"; #[cfg(feature = "wgpu-spirv-fusion")] let feature_name = "wgpu-spirv-fusion"; - #[cfg(feature = "cuda-jit")] - let feature_name = "cuda-jit"; - #[cfg(feature = "cuda-jit-fusion")] - let feature_name = "cuda-jit-fusion"; - #[cfg(feature = "hip-jit")] - let feature_name = "hip-jit"; + #[cfg(feature = "cuda")] + let feature_name = "cuda"; + #[cfg(feature = "cuda-fusion")] + let feature_name = "cuda-fusion"; + #[cfg(feature = "hip")] + let feature_name = "hip"; #[cfg(any(feature = "wgpu"))] { @@ -172,16 +172,16 @@ macro_rules! bench_on_backend { $fn_name::(&device, feature_name, url, token); } - #[cfg(feature = "cuda-jit")] + #[cfg(feature = "cuda")] { - use burn::backend::cuda_jit::{Cuda, CudaDevice}; + use burn::backend::cuda::{Cuda, CudaDevice}; $fn_name::>(&CudaDevice::default(), feature_name, url, token); } - #[cfg(feature = "hip-jit")] + #[cfg(feature = "hip")] { - use burn::backend::hip_jit::{Hip, HipDevice}; + use burn::backend::hip::{Hip, HipDevice}; $fn_name::>(&HipDevice::default(), feature_name, url, token); } diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index e895cc4572..5d926cd0b3 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -36,8 +36,8 @@ doc = [ "ndarray", "tch", "wgpu", - "cuda-jit", - "hip-jit", + "cuda", + "hip", "audio", "vision", "autodiff", @@ -100,12 +100,13 @@ template = ["burn-wgpu?/template"] candle = ["burn-candle"] candle-cuda = ["candle", "burn-candle/cuda"] -cuda-jit = ["burn-cuda"] -hip-jit = ["burn-hip"] +cuda = ["burn-cuda"] +hip = ["burn-hip"] ndarray = ["burn-ndarray"] tch = ["burn-tch"] wgpu = ["burn-wgpu"] -wgpu-spirv = ["wgpu", "burn-wgpu/spirv"] +vulkan = ["wgpu", "burn-wgpu/vulkan"] +webgpu = ["wgpu", "burn-wgpu/webgpu"] # Custom deserializer for Record that is helpful for importing data, such as PyTorch pt files. record-item-custom-serde = ["thiserror", "regex"] @@ -113,13 +114,13 @@ record-item-custom-serde = ["thiserror", "regex"] # Serialization formats experimental-named-tensor = ["burn-tensor/experimental-named-tensor"] -test-cuda = ["cuda-jit"] # To use cuda during testing, default uses ndarray. -test-hip = ["hip-jit"] # To use hip during testing, default uses ndarray. +test-cuda = ["cuda"] # To use cuda during testing, default uses ndarray. +test-hip = ["hip"] # To use hip during testing, default uses ndarray. test-tch = ["tch"] # To use tch during testing, default uses ndarray. test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray. test-wgpu-spirv = [ "test-wgpu", - "wgpu-spirv", + "vulkan", ] # To use wgpu-spirv during testing, default uses ndarray. [dependencies] diff --git a/crates/burn-core/src/backend.rs b/crates/burn-core/src/backend.rs index bd4c959302..31ac3a8c41 100644 --- a/crates/burn-core/src/backend.rs +++ b/crates/burn-core/src/backend.rs @@ -21,11 +21,17 @@ pub use burn_wgpu as wgpu; #[cfg(feature = "wgpu")] pub use burn_wgpu::Wgpu; -#[cfg(feature = "cuda-jit")] -pub use burn_cuda as cuda_jit; +#[cfg(feature = "webgpu")] +pub use burn_wgpu::WebGpu; -#[cfg(feature = "cuda-jit")] -pub use burn_cuda::Cuda as CudaJit; +#[cfg(feature = "vulkan")] +pub use burn_wgpu::Vulkan; + +#[cfg(feature = "cuda")] +pub use burn_cuda as cuda; + +#[cfg(feature = "cuda")] +pub use burn_cuda::Cuda; #[cfg(feature = "candle")] pub use burn_candle as candle; @@ -33,11 +39,11 @@ pub use burn_candle as candle; #[cfg(feature = "candle")] pub use burn_candle::Candle; -#[cfg(feature = "hip-jit")] -pub use burn_hip as hip_jit; +#[cfg(feature = "hip")] +pub use burn_hip as hip; -#[cfg(feature = "hip-jit")] -pub use burn_hip::Hip as HipJit; +#[cfg(feature = "hip")] +pub use burn_hip::Hip; #[cfg(feature = "tch")] pub use burn_tch as libtorch; diff --git a/crates/burn-hip/src/lib.rs b/crates/burn-hip/src/lib.rs index fc8f704e74..13f5239637 100644 --- a/crates/burn-hip/src/lib.rs +++ b/crates/burn-hip/src/lib.rs @@ -26,7 +26,8 @@ pub type Hip = burn_fusion::Fusion( +/// burn::backend::wgpu::init_setup::( /// &device, /// Default::default(), /// ); /// } /// ``` /// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API. -/// It's also possible to use an existing wgpu device, by using `init_existing_device`. +/// It's also possible to use an existing wgpu device, by using `init_device`. /// /// # Notes /// @@ -60,7 +60,7 @@ type Bool = u32; /// /// You can disable the `fusion` feature flag to remove that functionality, which might be /// necessary on `wasm` for now. -pub type Wgpu = +pub type Wgpu = burn_fusion::Fusion, F, I, B>>; #[cfg(not(feature = "fusion"))] @@ -79,14 +79,14 @@ pub type Wgpu = /// ```rust, ignore /// fn custom_init() { /// let device = Default::default(); -/// burn::backend::wgpu::init_sync::( +/// burn::backend::wgpu::init_setup::( /// &device, /// Default::default(), /// ); /// } /// ``` /// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API. -/// It's also possible to use an existing wgpu device, by using `init_existing_device`. +/// It's also possible to use an existing wgpu device, by using `init_device`. /// /// # Notes /// @@ -95,20 +95,33 @@ pub type Wgpu = /// /// You can enable the `fusion` feature flag to add that functionality, which might improve /// performance. -pub type Wgpu = +pub type Wgpu = JitBackend, F, I, B>; +#[cfg(feature = "vulkan")] +/// Tensor backend that leverages the Vulkan graphics API to execute GPU compute shaders compiled to SPIR-V. +pub type Vulkan = Wgpu; + +#[cfg(feature = "webgpu")] +/// Tensor backend that uses the wgpu crate to execute GPU compute shaders written in WGSL. +pub type WebGpu = Wgpu; + #[cfg(test)] mod tests { use burn_jit::JitBackend; - #[cfg(feature = "spirv")] + #[cfg(feature = "vulkan")] pub use half::f16; - pub type TestRuntime = cubecl::wgpu::WgpuRuntime; + + #[cfg(feature = "cubecl-spirv")] + type Compiler = cubecl::wgpu::spirv::VkSpirvCompiler; + #[cfg(not(feature = "cubecl-spirv"))] + type Compiler = cubecl::wgpu::WgslCompiler; + pub type TestRuntime = cubecl::wgpu::WgpuRuntime; // Don't test `flex32` for now, burn sees it as `f32` but is actually `f16` precision, so it // breaks a lot of tests from precision issues - #[cfg(feature = "spirv")] + #[cfg(feature = "vulkan")] burn_jit::testgen_all!([f16, f32], [i8, i16, i32, i64], [u8, u32]); - #[cfg(not(feature = "spirv"))] + #[cfg(not(feature = "vulkan"))] burn_jit::testgen_all!([f32], [i32], [u32]); } diff --git a/crates/burn/Cargo.toml b/crates/burn/Cargo.toml index cd13682a4b..b0abf7d178 100644 --- a/crates/burn/Cargo.toml +++ b/crates/burn/Cargo.toml @@ -50,15 +50,16 @@ openblas-system = ["burn-core/openblas-system"] template = ["burn-core/template"] candle = ["burn-core/candle"] -cuda-jit = ["burn-core/cuda-jit"] -hip-jit = ["burn-core/hip-jit"] +cuda = ["burn-core/cuda"] +hip = ["burn-core/hip"] ndarray = ["burn-core/ndarray"] remote = ["burn-core/remote"] router = ["burn-core/router"] server = ["burn-core/server"] tch = ["burn-core/tch"] wgpu = ["burn-core/wgpu"] -wgpu-spirv = ["burn-core/wgpu-spirv"] +vulkan = ["burn-core/vulkan"] +webgpu = ["burn-core/webgpu"] # Network utils network = ["burn-core/network"] diff --git a/crates/burn/src/lib.rs b/crates/burn/src/lib.rs index b0ecf06a71..203d1a802d 100644 --- a/crates/burn/src/lib.rs +++ b/crates/burn/src/lib.rs @@ -76,12 +76,14 @@ //! - `vision`: Enables vision datasets (MnistDataset) //! - Backends //! - `wgpu`: Makes available the WGPU backend -//! - `wgpu-spirv`: Makes available the `wgpu` backend with the alternative SPIR-V compiler +//! - `webgpu`: Makes available the `wgpu` backend with the WebGPU Shading Language (WGSL) compiler +//! - `vulkan`: Makes available the `wgpu` backend with the alternative SPIR-V compiler +//! - `cuda`: Makes available the CUDA backend +//! - `hip`: Makes available the HIP backend //! - `candle`: Makes available the Candle backend //! - `tch`: Makes available the LibTorch backend //! - `ndarray`: Makes available the NdArray backend //! - Backend specifications -//! - `cuda`: If supported, CUDA will be used //! - `accelerate`: If supported, Accelerate will be used //! - `blas-netlib`: If supported, Blas Netlib will be use //! - `openblas`: If supported, Openblas will be use diff --git a/examples/custom-renderer/examples/custom-renderer.rs b/examples/custom-renderer/examples/custom-renderer.rs index ea580833df..aa344b1d2b 100644 --- a/examples/custom-renderer/examples/custom-renderer.rs +++ b/examples/custom-renderer/examples/custom-renderer.rs @@ -1,5 +1,5 @@ -use burn::backend::{wgpu::WgpuDevice, Autodiff, Wgpu}; +use burn::backend::{wgpu::WgpuDevice, Autodiff, WebGpu}; fn main() { - custom_renderer::run::>(WgpuDevice::default()); + custom_renderer::run::>(WgpuDevice::default()); } diff --git a/examples/custom-training-loop/Cargo.toml b/examples/custom-training-loop/Cargo.toml index 536307fdba..6e1fca1e92 100644 --- a/examples/custom-training-loop/Cargo.toml +++ b/examples/custom-training-loop/Cargo.toml @@ -7,7 +7,7 @@ publish = false version.workspace = true [dependencies] -burn = {path = "../../crates/burn", features=["autodiff", "wgpu", "vision"]} +burn = {path = "../../crates/burn", features=["autodiff", "webgpu", "vision"]} guide = {path = "../guide"} # Serialization diff --git a/examples/custom-training-loop/examples/custom-training-loop.rs b/examples/custom-training-loop/examples/custom-training-loop.rs index a418ede196..ec9d55f42a 100644 --- a/examples/custom-training-loop/examples/custom-training-loop.rs +++ b/examples/custom-training-loop/examples/custom-training-loop.rs @@ -1,5 +1,5 @@ -use burn::backend::{Autodiff, Wgpu}; +use burn::backend::{Autodiff, WebGpu}; fn main() { - custom_training_loop::run::>(Default::default()); + custom_training_loop::run::>(Default::default()); } diff --git a/examples/guide/Cargo.toml b/examples/guide/Cargo.toml index e60b8d45e5..aea61f5e25 100644 --- a/examples/guide/Cargo.toml +++ b/examples/guide/Cargo.toml @@ -10,7 +10,7 @@ version.workspace = true default = ["burn/default"] [dependencies] -burn = {path = "../../crates/burn", features = ["wgpu", "train", "vision"]} +burn = {path = "../../crates/burn", features = ["webgpu", "train", "vision"]} # Serialization log = {workspace = true} diff --git a/examples/guide/src/bin/infer.rs b/examples/guide/src/bin/infer.rs index 6a246d85f0..44c5b1dabc 100644 --- a/examples/guide/src/bin/infer.rs +++ b/examples/guide/src/bin/infer.rs @@ -1,9 +1,9 @@ #![recursion_limit = "131"] -use burn::{backend::Wgpu, data::dataset::Dataset}; +use burn::{backend::WebGpu, data::dataset::Dataset}; use guide::inference; fn main() { - type MyBackend = Wgpu; + type MyBackend = WebGpu; let device = burn::backend::wgpu::WgpuDevice::default(); diff --git a/examples/guide/src/bin/print.rs b/examples/guide/src/bin/print.rs index 9432aa93a4..6f3b710c25 100644 --- a/examples/guide/src/bin/print.rs +++ b/examples/guide/src/bin/print.rs @@ -1,8 +1,8 @@ -use burn::backend::Wgpu; +use burn::backend::WebGpu; use guide::model::ModelConfig; fn main() { - type MyBackend = Wgpu; + type MyBackend = WebGpu; let device = Default::default(); let model = ModelConfig::new(10, 512).init::(&device); diff --git a/examples/guide/src/bin/train.rs b/examples/guide/src/bin/train.rs index 04f1f44146..a4acf02b69 100644 --- a/examples/guide/src/bin/train.rs +++ b/examples/guide/src/bin/train.rs @@ -1,5 +1,5 @@ use burn::{ - backend::{Autodiff, Wgpu}, + backend::{Autodiff, WebGpu}, data::dataset::Dataset, optim::AdamConfig, }; @@ -10,7 +10,7 @@ use guide::{ }; fn main() { - type MyBackend = Wgpu; + type MyBackend = WebGpu; type MyAutodiffBackend = Autodiff; // Create a default Wgpu device diff --git a/examples/image-classification-web/src/web.rs b/examples/image-classification-web/src/web.rs index 4b20507abc..a9868099f6 100644 --- a/examples/image-classification-web/src/web.rs +++ b/examples/image-classification-web/src/web.rs @@ -14,7 +14,7 @@ use burn::{ tensor::activation::softmax, }; -use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice}; +use burn::backend::wgpu::{graphics::AutoGraphicsApi, WebGpu, WgpuDevice}; use burn_candle::Candle; use serde::Serialize; @@ -37,8 +37,8 @@ pub enum ModelType { /// The model is loaded to the NdArray backend WithNdArrayBackend(Model>), - /// The model is loaded to the Wgpu backend - WithWgpuBackend(Model>), + /// The model is loaded to the WebGpu backend + WithWgpuBackend(Model>), } /// The image is 224x224 pixels with 3 channels (RGB) diff --git a/examples/server/Cargo.toml b/examples/server/Cargo.toml index bb4824fba9..f9f80bdb8d 100644 --- a/examples/server/Cargo.toml +++ b/examples/server/Cargo.toml @@ -7,10 +7,10 @@ publish = false version.workspace = true [features] -default = ["wgpu"] -cuda-jit = ["burn/cuda-jit"] -wgpu = ["burn/wgpu"] -wgpu-spirv = ["wgpu", "burn/wgpu-spirv"] +default = ["webgpu"] +cuda = ["burn/cuda"] +webgpu = ["burn/webgpu"] +vulkan = ["burn/vulkan"] ndarray = ["burn/ndarray"] [dependencies] diff --git a/examples/server/src/lib.rs b/examples/server/src/lib.rs index 70705a0876..014a5e2cf5 100644 --- a/examples/server/src/lib.rs +++ b/examples/server/src/lib.rs @@ -11,10 +11,12 @@ pub fn start() { cfg_if::cfg_if! { if #[cfg(feature = "ndarray")]{ burn::server::start::(Default::default(), port); - } else if #[cfg(feature = "cuda-jit")]{ - burn::server::start::(Default::default(), port); - } else if #[cfg(feature = "wgpu")] { - burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "cuda")]{ + burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "webgpu")] { + burn::server::start::(Default::default(), port); + } else if #[cfg(feature = "vulkan")] { + burn::server::start::(Default::default(), port); } else { panic!("No backend selected, can't start server on port {port}"); } diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml index 4ec5d7c89a..043c61672d 100644 --- a/examples/text-classification/Cargo.toml +++ b/examples/text-classification/Cargo.toml @@ -16,10 +16,10 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] -wgpu-spirv = ["wgpu", "burn/wgpu-spirv"] +vulkan = ["wgpu", "burn/vulkan"] remote = ["burn/remote"] -cuda-jit = ["burn/cuda-jit"] -hip-jit = ["burn/hip-jit"] +cuda = ["burn/cuda"] +hip = ["burn/hip"] [dependencies] # Burn diff --git a/examples/text-classification/README.md b/examples/text-classification/README.md index 8bc611361f..9d62606706 100644 --- a/examples/text-classification/README.md +++ b/examples/text-classification/README.md @@ -102,6 +102,6 @@ cd burn # Use the --release flag to really speed up training. # AG News -cargo run --example ag-news-train --release --features cuda-jit # Train on the ag news dataset -cargo run --example ag-news-infer --release --features cuda-jit # Run inference on the ag news dataset +cargo run --example ag-news-train --release --features cuda # Train on the ag news dataset +cargo run --example ag-news-infer --release --features cuda # Run inference on the ag news dataset ``` diff --git a/examples/text-classification/examples/ag-news-infer.rs b/examples/text-classification/examples/ag-news-infer.rs index 9af5c6c6eb..77626e0b60 100644 --- a/examples/text-classification/examples/ag-news-infer.rs +++ b/examples/text-classification/examples/ag-news-infer.rs @@ -81,13 +81,13 @@ mod wgpu { } } -#[cfg(feature = "cuda-jit")] -mod cuda_jit { +#[cfg(feature = "cuda")] +mod cuda { use crate::{launch, ElemType}; - use burn::backend::{cuda_jit::CudaDevice, CudaJit}; + use burn::backend::{cuda::CudaDevice, Cuda}; pub fn run() { - launch::>(CudaDevice::default()); + launch::>(CudaDevice::default()); } } @@ -105,6 +105,6 @@ fn main() { tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); - #[cfg(feature = "cuda-jit")] - cuda_jit::run(); + #[cfg(feature = "cuda")] + cuda::run(); } diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index 1be9803a15..9a9cab44bd 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -103,18 +103,18 @@ mod remote { } } -#[cfg(feature = "cuda-jit")] -mod cuda_jit { +#[cfg(feature = "cuda")] +mod cuda { use crate::{launch, ElemType}; - use burn::backend::{Autodiff, CudaJit}; + use burn::backend::{Autodiff, Cuda}; pub fn run() { - launch::>>(vec![Default::default()]); + launch::>>(vec![Default::default()]); } } -#[cfg(feature = "hip-jit")] -mod hip_jit { +#[cfg(feature = "hip")] +mod hip { use crate::{launch, ElemType}; use burn::backend::{Autodiff, HipJit}; @@ -137,10 +137,10 @@ fn main() { tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); - #[cfg(feature = "cuda-jit")] - cuda_jit::run(); - #[cfg(feature = "hip-jit")] - hip_jit::run(); + #[cfg(feature = "cuda")] + cuda::run(); + #[cfg(feature = "hip")] + hip::run(); #[cfg(feature = "remote")] remote::run(); } diff --git a/examples/wgan/Cargo.toml b/examples/wgan/Cargo.toml index 48d5680f51..d6ee6345b1 100644 --- a/examples/wgan/Cargo.toml +++ b/examples/wgan/Cargo.toml @@ -11,7 +11,7 @@ ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu"] -cuda-jit = ["burn/cuda-jit"] +cuda = ["burn/cuda"] [dependencies] burn = { path = "../../crates/burn", features=["train", "vision"] } diff --git a/examples/wgan/README.md b/examples/wgan/README.md index d7252ba520..0828145f61 100644 --- a/examples/wgan/README.md +++ b/examples/wgan/README.md @@ -12,7 +12,7 @@ Please note that better performance maybe gained by adopting a convolution layer ```sh # Cuda backend -cargo run --example wgan-mnist --release --features cuda-jit +cargo run --example wgan-mnist --release --features cuda # Wgpu backend cargo run --example wgan-mnist --release --features wgpu @@ -36,5 +36,5 @@ cargo run --example wgan-mnist --release --features ndarray-blas-netlib # f32 To generate a sample of images, you can use `wgan-generate`. The same feature flags are used to select a backend. ```sh -cargo run --example wgan-generate --release --features cuda-jit +cargo run --example wgan-generate --release --features cuda ``` diff --git a/examples/wgan/examples/wgan-generate.rs b/examples/wgan/examples/wgan-generate.rs index fa66623ca3..1b4a51a535 100644 --- a/examples/wgan/examples/wgan-generate.rs +++ b/examples/wgan/examples/wgan-generate.rs @@ -66,13 +66,13 @@ mod wgpu { } } -#[cfg(feature = "cuda-jit")] -mod cuda_jit { +#[cfg(feature = "cuda")] +mod cuda { use crate::launch; - use burn::backend::{Autodiff, CudaJit}; + use burn::backend::{Autodiff, Cuda}; pub fn run() { - launch::>(Default::default()); + launch::>(Default::default()); } } @@ -90,6 +90,6 @@ fn main() { tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); - #[cfg(feature = "cuda-jit")] - cuda_jit::run(); + #[cfg(feature = "cuda")] + cuda::run(); } diff --git a/examples/wgan/examples/wgan-mnist.rs b/examples/wgan/examples/wgan-mnist.rs index d964b07844..787acfec94 100644 --- a/examples/wgan/examples/wgan-mnist.rs +++ b/examples/wgan/examples/wgan-mnist.rs @@ -78,13 +78,13 @@ mod wgpu { } } -#[cfg(feature = "cuda-jit")] -mod cuda_jit { +#[cfg(feature = "cuda")] +mod cuda { use crate::launch; - use burn::backend::{cuda_jit::CudaDevice, Autodiff, CudaJit}; + use burn::backend::{cuda::CudaDevice, Autodiff, Cuda}; pub fn run() { - launch::>(CudaDevice::default()); + launch::>(CudaDevice::default()); } } @@ -102,6 +102,6 @@ fn main() { tch_cpu::run(); #[cfg(feature = "wgpu")] wgpu::run(); - #[cfg(feature = "cuda-jit")] - cuda_jit::run(); + #[cfg(feature = "cuda")] + cuda::run(); } diff --git a/xtask/src/commands/test.rs b/xtask/src/commands/test.rs index 47e50f80ed..5b94b2909e 100644 --- a/xtask/src/commands/test.rs +++ b/xtask/src/commands/test.rs @@ -83,7 +83,7 @@ pub(crate) fn handle_command( vec!["--features", "test-wgpu-spirv"], None, None, - "std wgpu-spirv", + "std vulkan", )?; } } From 04c5e74daac2adbfda4f5293746bc37ab6be56ea Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Tue, 28 Jan 2025 09:30:00 -0500 Subject: [PATCH 53/61] Fix types under autotune flag (#2750) --- crates/burn-core/Cargo.toml | 2 +- crates/burn-jit/src/fusion/matmul/optimization.rs | 2 +- crates/burn-jit/src/kernel/reduce/base.rs | 4 +++- crates/burn-jit/src/kernel/reduce/tune.rs | 2 ++ 4 files changed, 7 insertions(+), 3 deletions(-) diff --git a/crates/burn-core/Cargo.toml b/crates/burn-core/Cargo.toml index 5d926cd0b3..423dc784d8 100644 --- a/crates/burn-core/Cargo.toml +++ b/crates/burn-core/Cargo.toml @@ -88,7 +88,7 @@ fusion = ["burn-wgpu?/fusion", "burn-cuda?/fusion"] ## Backend features accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"] -autotune = ["burn-wgpu?/autotune"] +autotune = ["burn-wgpu?/autotune", "burn-cuda?/autotune", "burn-hip?/autotune"] blas-netlib = ["burn-ndarray?/blas-netlib"] metal = ["burn-candle?/metal"] openblas = ["burn-ndarray?/blas-openblas"] diff --git a/crates/burn-jit/src/fusion/matmul/optimization.rs b/crates/burn-jit/src/fusion/matmul/optimization.rs index 9a020df62c..804628613d 100644 --- a/crates/burn-jit/src/fusion/matmul/optimization.rs +++ b/crates/burn-jit/src/fusion/matmul/optimization.rs @@ -87,7 +87,7 @@ impl MatmulOptimization { fused_matmul_autotune::(self, context); #[cfg(not(feature = "autotune"))] - if self.execute_fused::(context).is_err() { + if self.execute_standard_fused::(context).is_err() { self.execute_fallback::(context); } } diff --git a/crates/burn-jit/src/kernel/reduce/base.rs b/crates/burn-jit/src/kernel/reduce/base.rs index 9ec677ee93..ccfcc3ef9e 100644 --- a/crates/burn-jit/src/kernel/reduce/base.rs +++ b/crates/burn-jit/src/kernel/reduce/base.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "autotune")] use super::{autotune_reduce, autotune_sum}; use crate::{ element::JitElement, @@ -31,6 +32,7 @@ pub fn sum( )) } SumStrategy::Chained(strategy) => reduce::(tensor, strategy), + #[cfg(feature = "autotune")] SumStrategy::Autotune => Ok(autotune_sum::(&client, tensor)), } } @@ -53,7 +55,7 @@ impl Default for SumStrategy { return Self::Autotune; #[cfg(not(feature = "autotune"))] - return Self::Static(4); + return Self::OneShot(4); } } diff --git a/crates/burn-jit/src/kernel/reduce/tune.rs b/crates/burn-jit/src/kernel/reduce/tune.rs index c397af1a04..cd5cd61157 100644 --- a/crates/burn-jit/src/kernel/reduce/tune.rs +++ b/crates/burn-jit/src/kernel/reduce/tune.rs @@ -208,6 +208,7 @@ mod reduce_ops { } /// Executes autotune on reduce operations. +#[cfg(feature = "autotune")] pub fn autotune_sum( client: &ComputeClient, input: JitTensor, @@ -280,6 +281,7 @@ mod sum_ops { .map_err(|e| e.to_string()) } + #[cfg(feature = "autotune")] pub(crate) fn sum_chained( input: JitTensor, ) -> Result, String> { From 390317259b12cbe43b5a2e38ded5cecfdb2136ab Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Thu, 30 Jan 2025 12:07:50 -0500 Subject: [PATCH 54/61] Upgrade to polars 0.46.0 (#2757) --- Cargo.lock | 743 +++---------------- Cargo.toml | 2 +- crates/burn-dataset/src/dataset/dataframe.rs | 22 +- 3 files changed, 116 insertions(+), 651 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5f5b948a86..cb75c0ac6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -62,21 +62,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1" -[[package]] -name = "alloc-no-stdlib" -version = "2.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc7bb162ec39d46ab1ca8c77bf72e890535becd1751bb45f64c597edb4c8c6b3" - -[[package]] -name = "alloc-stdlib" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94fb8275041c72129eb51b7d0322c29b8387a0386127718b096429201a5d6ece" -dependencies = [ - "alloc-no-stdlib", -] - [[package]] name = "allocator-api2" version = "0.2.21" @@ -207,12 +192,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" -[[package]] -name = "arrayref" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" - [[package]] name = "arrayvec" version = "0.7.6" @@ -284,20 +263,11 @@ dependencies = [ "syn 2.0.96", ] -[[package]] -name = "atoi" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" -dependencies = [ - "num-traits", -] - [[package]] name = "atoi_simd" -version = "0.15.6" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ae037714f313c1353189ead58ef9eec30a8e8dc101b2622d461418fd59e28a9" +checksum = "4790f9e8961209112beb783d85449b508673cf4a6a419c8449b210743ac4dbe9" [[package]] name = "atomic-waker" @@ -524,19 +494,6 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2" -[[package]] -name = "blake3" -version = "1.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8ee0c1824c4dea5b5f81736aff91bae041d2c07ee1192bec91054e10e3e601e" -dependencies = [ - "arrayref", - "arrayvec", - "cc", - "cfg-if", - "constant_time_eq 0.3.1", -] - [[package]] name = "blas-src" version = "0.10.0" @@ -572,27 +529,6 @@ dependencies = [ "objc2", ] -[[package]] -name = "brotli" -version = "6.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", - "brotli-decompressor", -] - -[[package]] -name = "brotli-decompressor" -version = "4.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362" -dependencies = [ - "alloc-no-stdlib", - "alloc-stdlib", -] - [[package]] name = "bstr" version = "1.11.3" @@ -1013,6 +949,9 @@ name = "bytes" version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +dependencies = [ + "serde", +] [[package]] name = "bytesize" @@ -1055,7 +994,7 @@ dependencies = [ "gemm", "half", "libc", - "memmap2 0.9.5", + "memmap2", "metal 0.27.0", "num-traits", "num_cpus", @@ -1159,16 +1098,15 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", - "serde", "wasm-bindgen", "windows-targets 0.52.6", ] [[package]] name = "chrono-tz" -version = "0.8.6" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d59ae0466b83e838b81a54256c39d5d7c20b9d7daa10510a242d9b75abd5936e" +checksum = "9c6ac4f2c0bf0f44e9161aec9675e1050aa4a530663c4a9e37e108fa948bca9f" dependencies = [ "chrono", "chrono-tz-build", @@ -1177,42 +1115,14 @@ dependencies = [ [[package]] name = "chrono-tz-build" -version = "0.2.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "433e39f13c9a060046954e0592a8d0a4bcb1040125cbf91cb8ee58964cfb350f" +checksum = "e94fea34d77a245229e7746bd2beb786cd2a896f306ff491fb8cecb3074b10a7" dependencies = [ "parse-zoneinfo", - "phf", "phf_codegen", ] -[[package]] -name = "ciborium" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" -dependencies = [ - "ciborium-io", - "ciborium-ll", - "serde", -] - -[[package]] -name = "ciborium-io" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" - -[[package]] -name = "ciborium-ll" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" -dependencies = [ - "ciborium-io", - "half", -] - [[package]] name = "cipher" version = "0.4.4" @@ -1251,7 +1161,7 @@ version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.96", @@ -1394,16 +1304,6 @@ dependencies = [ "libc", ] -[[package]] -name = "core-foundation" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1417,7 +1317,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c07782be35f9e1140080c6b96f0d44b739e2278479f64e02fdab4e32dfd8b081" dependencies = [ "bitflags 1.3.2", - "core-foundation 0.9.4", + "core-foundation", "core-graphics-types", "foreign-types 0.5.0", "libc", @@ -1430,7 +1330,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" dependencies = [ "bitflags 1.3.2", - "core-foundation 0.9.4", + "core-foundation", "libc", ] @@ -2121,12 +2021,6 @@ dependencies = [ "syn 2.0.96", ] -[[package]] -name = "doc-comment" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" - [[package]] name = "document-features" version = "0.2.10" @@ -2167,9 +2061,6 @@ name = "either" version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" -dependencies = [ - "serde", -] [[package]] name = "embassy-futures" @@ -2198,7 +2089,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.96", @@ -2332,10 +2223,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" [[package]] -name = "fast-float" -version = "0.2.0" +name = "fast-float2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95765f67b4b18863968b4a1bd5bb576f732b29a4a28c7cd84c09fa3e2875f33c" +checksum = "f8eb564c5c7423d25c886fb561d1e4ee69f72354d16918afa32c08811f6b6a55" [[package]] name = "faster-hex" @@ -2467,16 +2358,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "fs4" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8c6b3bd49c37d2aa3f3f2220233b29a7cd23f79d1fe70e5337d25fb390793de" -dependencies = [ - "rustix", - "windows-sys 0.52.0", -] - [[package]] name = "futures" version = "0.3.31" @@ -2973,16 +2854,6 @@ dependencies = [ "serde", ] -[[package]] -name = "halfbrown" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8588661a8607108a5ca69cab034063441a0413a0b041c13618a7dd348021ef6f" -dependencies = [ - "hashbrown 0.14.5", - "serde", -] - [[package]] name = "hashbrown" version = "0.13.2" @@ -3026,12 +2897,6 @@ dependencies = [ "hashbrown 0.14.5", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -3181,7 +3046,6 @@ dependencies = [ "hyper", "hyper-util", "rustls", - "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", "tokio-rustls", @@ -3582,12 +3446,6 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" -[[package]] -name = "itoap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9028f49264629065d057f340a86acb84867925865f73bbf8d47b4d149a7e88b8" - [[package]] name = "jni-sys" version = "0.3.0" @@ -3852,16 +3710,6 @@ dependencies = [ "rayon", ] -[[package]] -name = "md-5" -version = "0.10.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" -dependencies = [ - "cfg-if", - "digest", -] - [[package]] name = "md5" version = "0.7.0" @@ -3874,15 +3722,6 @@ version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" -[[package]] -name = "memmap2" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" -dependencies = [ - "libc", -] - [[package]] name = "memmap2" version = "0.9.5" @@ -3893,15 +3732,6 @@ dependencies = [ "stable_deref_trait", ] -[[package]] -name = "memoffset" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" -dependencies = [ - "autocfg", -] - [[package]] name = "metal" version = "0.27.0" @@ -4031,28 +3861,6 @@ dependencies = [ "syn 2.0.96", ] -[[package]] -name = "multiversion" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4851161a11d3ad0bf9402d90ffc3967bf231768bfd7aeb61755ad06dbf1a142" -dependencies = [ - "multiversion-macros", - "target-features", -] - -[[package]] -name = "multiversion-macros" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79a74ddee9e0c27d2578323c13905793e91622148f138ba29738f9dddb835e90" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", - "target-features", -] - [[package]] name = "naga" version = "24.0.0" @@ -4067,7 +3875,7 @@ dependencies = [ "hexf-parse", "indexmap", "log", - "rustc-hash 1.1.0", + "rustc-hash", "spirv 0.3.0+sdk-1.3.268.0", "strum", "termcolor", @@ -4095,7 +3903,7 @@ dependencies = [ "openssl-probe", "openssl-sys", "schannel", - "security-framework 2.11.1", + "security-framework", "security-framework-sys", "tempfile", ] @@ -4493,36 +4301,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "object_store" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6da452820c715ce78221e8202ccc599b4a52f3e1eb3eedb487b680c81a8e3f3" -dependencies = [ - "async-trait", - "base64 0.22.1", - "bytes", - "chrono", - "futures", - "humantime", - "hyper", - "itertools 0.13.0", - "md-5", - "parking_lot 0.12.3", - "percent-encoding", - "quick-xml", - "rand", - "reqwest", - "ring", - "serde", - "serde_json", - "snafu", - "tokio", - "tracing", - "url", - "walkdir", -] - [[package]] name = "once_cell" version = "1.20.2" @@ -4889,9 +4667,9 @@ dependencies = [ [[package]] name = "polars" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65c6aa86d991a64c95416a61202f7952da2f8cccefa448f9a23c1b8f2301ecc" +checksum = "72571dde488ecccbe799798bf99ab7308ebdb7cf5d95bcc498dbd5a132f0da4d" dependencies = [ "getrandom", "polars-arrow", @@ -4909,12 +4687,11 @@ dependencies = [ [[package]] name = "polars-arrow" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87dbb24d29ddea5abb73d7954df8b8d3d4bb7f02a3e5c96d1519cdad9e816a3d" +checksum = "6611c758d52e799761cc25900666b71552e6c929d88052811bc9daad4b3321a8" dependencies = [ "ahash", - "atoi", "atoi_simd", "bytemuck", "chrono", @@ -4922,21 +4699,16 @@ dependencies = [ "dyn-clone", "either", "ethnum", - "fast-float", "getrandom", "hashbrown 0.15.2", "itoa", - "itoap", "lz4", - "multiversion", "num-traits", "parking_lot 0.12.3", "polars-arrow-format", "polars-error", "polars-schema", "polars-utils", - "ryu", - "serde", "simdutf8", "streaming-iterator", "strength_reduce", @@ -4957,25 +4729,30 @@ dependencies = [ [[package]] name = "polars-compute" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbdb1071147452a4c4b25560f23d2fbaffef255b04757291131b22fc2c0d35b2" +checksum = "332f2547dbb27599a8ffe68e56159f5996ba03d1dad0382ccb62c109ceacdeb6" dependencies = [ + "atoi_simd", "bytemuck", + "chrono", "either", + "fast-float2", + "itoa", "num-traits", "polars-arrow", "polars-error", "polars-utils", + "ryu", "strength_reduce", "version_check", ] [[package]] name = "polars-core" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd5df9b55e614088a3270b06f8649dce76537c268d6b1ca4d9c37008b2be5949" +checksum = "796d06eae7e6e74ed28ea54a8fccc584ebac84e6cf0e1e9ba41ffc807b169a01" dependencies = [ "ahash", "bitflags 2.8.0", @@ -4987,6 +4764,7 @@ dependencies = [ "hashbrown 0.14.5", "hashbrown 0.15.2", "indexmap", + "itoa", "num-traits", "once_cell", "polars-arrow", @@ -4999,32 +4777,29 @@ dependencies = [ "rand_distr", "rayon", "regex", - "serde", - "serde_json", "strum_macros", - "thiserror 1.0.69", + "thiserror 2.0.11", "version_check", "xxhash-rust", ] [[package]] name = "polars-error" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4643898a644f30c83737db85f942f8c8956b0c11190b39afec745218eae1746b" +checksum = "19d6529cae0d1db5ed690e47de41fac9b35ae0c26d476830c2079f130887b847" dependencies = [ - "object_store", "polars-arrow-format", "regex", "simdutf8", - "thiserror 1.0.69", + "thiserror 2.0.11", ] [[package]] name = "polars-expr" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea1b431ed816cba1120cff200f06b962748001bbb2e615ce53cfbbdf701cc136" +checksum = "c8e639991a8ad4fb12880ab44bcc3cf44a5703df003142334d9caf86d77d77e7" dependencies = [ "ahash", "bitflags 2.8.0", @@ -5046,80 +4821,50 @@ dependencies = [ [[package]] name = "polars-io" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2fab2c016635cb416b49461fd6419b0208c6c13a4fd065bd65e4a87dbb66314" +checksum = "719a77e94480f6be090512da196e378cbcbeb3584c6fe1134c600aee906e38ab" dependencies = [ "ahash", "async-trait", "atoi_simd", - "blake3", "bytes", "chrono", - "fast-float", - "fs4", + "fast-float2", "futures", "glob", "hashbrown 0.15.2", "home", "itoa", "memchr", - "memmap2 0.7.1", + "memmap2", "num-traits", - "object_store", "once_cell", "percent-encoding", "polars-arrow", "polars-core", "polars-error", - "polars-json", "polars-parquet", "polars-schema", "polars-time", "polars-utils", - "pyo3", "rayon", "regex", - "reqwest", "ryu", - "serde", - "serde_json", - "simd-json", "simdutf8", "tokio", "tokio-util", - "url", -] - -[[package]] -name = "polars-json" -version = "0.44.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5c8c057ef04feaf34b6ce52096bdea3a766fa4725f50442078c8a4ee86397bf" -dependencies = [ - "ahash", - "chrono", - "fallible-streaming-iterator", - "hashbrown 0.15.2", - "indexmap", - "itoa", - "num-traits", - "polars-arrow", - "polars-error", - "polars-utils", - "ryu", - "simd-json", - "streaming-iterator", ] [[package]] name = "polars-lazy" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a8ca74f42e7b47cad241b36b98d991cc7fbb51b8d0695a055eb937588d1f310" +checksum = "a0a731a672dfc8ac38c1f73c9a4b2ae38d2fc8ac363bfb64c5f3a3e072ffc5ad" dependencies = [ "ahash", "bitflags 2.8.0", + "chrono", "memchr", "once_cell", "polars-arrow", @@ -5139,32 +4884,28 @@ dependencies = [ [[package]] name = "polars-mem-engine" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a32614e5b52c9b83856d80c7e2880b79d83055bfd59969bd1d0b148f9cfdc7a" +checksum = "33442189bcbf2e2559aa7914db3835429030a13f4f18e43af5fba9d1b018cf12" dependencies = [ - "futures", - "memmap2 0.7.1", + "memmap2", "polars-arrow", "polars-core", "polars-error", "polars-expr", "polars-io", - "polars-json", "polars-ops", "polars-plan", "polars-time", "polars-utils", - "pyo3", "rayon", - "tokio", ] [[package]] name = "polars-ops" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "035c800fbe5bbd820afeb8313713ed345853bb014e0f821a4025d40cf0d60e1a" +checksum = "cbb83218b0c216104f0076cd1a005128be078f958125f3d59b094ee73d78c18e" dependencies = [ "ahash", "argminmax", @@ -5178,6 +4919,7 @@ dependencies = [ "indexmap", "memchr", "num-traits", + "once_cell", "polars-arrow", "polars-compute", "polars-core", @@ -5187,39 +4929,33 @@ dependencies = [ "rayon", "regex", "regex-syntax 0.8.5", - "serde", "strum_macros", + "unicode-normalization", "unicode-reverse", "version_check", ] [[package]] name = "polars-parquet" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91dcf1d9f048079376949eaf2e24e240b313ff4a102fb83b57c9a5f807cdca52" +checksum = "5c60ee85535590a38db6c703a21be4cb25342e40f573f070d1e16f9d84a53ac7" dependencies = [ "ahash", "async-stream", "base64 0.22.1", - "brotli", "bytemuck", "ethnum", - "flate2", "futures", "hashbrown 0.15.2", - "lz4", "num-traits", "polars-arrow", "polars-compute", "polars-error", "polars-parquet-format", "polars-utils", - "serde", "simdutf8", - "snap", "streaming-decompression", - "zstd 0.13.2", ] [[package]] @@ -5234,15 +4970,16 @@ dependencies = [ [[package]] name = "polars-pipe" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05936f2b3981eecb2fe74d8ef092bb75a93d2a056b3e4f339f4ac20c71c9e331" +checksum = "42d238fb76698f56e51ddfa89b135e4eda56a4767c6e8859eed0ab78386fcd52" dependencies = [ "crossbeam-channel", "crossbeam-queue", "enum_dispatch", "hashbrown 0.15.2", "num-traits", + "once_cell", "polars-arrow", "polars-compute", "polars-core", @@ -5259,9 +4996,9 @@ dependencies = [ [[package]] name = "polars-plan" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23de436f33f4d1134c58f24e7059a221b957ec20730807e0ef0c80c8e4b3d06a" +checksum = "4f03533a93aa66127fcb909a87153a3c7cfee6f0ae59f497e73d7736208da54c" dependencies = [ "ahash", "bitflags 2.8.0", @@ -5269,65 +5006,59 @@ dependencies = [ "bytes", "chrono", "chrono-tz", - "ciborium", "either", - "futures", "hashbrown 0.15.2", - "memmap2 0.7.1", + "memmap2", "num-traits", "once_cell", "percent-encoding", "polars-arrow", + "polars-compute", "polars-core", "polars-io", - "polars-json", "polars-ops", - "polars-parquet", "polars-time", "polars-utils", - "pyo3", "rayon", "recursive", "regex", - "serde", "strum_macros", "version_check", ] [[package]] name = "polars-row" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3823d3de3e614509bba6929798f1f3d5ae05c1cdfc4eb7029d2ec6ad77201da2" +checksum = "6bf47f7409f8e75328d7d034be390842924eb276716d0458607be0bddb8cc839" dependencies = [ + "bitflags 2.8.0", "bytemuck", "polars-arrow", + "polars-compute", "polars-error", "polars-utils", ] [[package]] name = "polars-schema" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d88667f770291cefa2e8cd366a54f29dc6fe362e9a263914c903db411a58ac1d" +checksum = "416621ae82b84466cf4ff36838a9b0aeb4a67e76bd3065edc8c9cb7da19b1bc7" dependencies = [ "indexmap", "polars-error", "polars-utils", - "serde", "version_check", ] [[package]] name = "polars-sql" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69451f08363bb497407f6ebebe00bc01972a51716d20d115b75f9b5326f1f3c8" +checksum = "edaab553b90aa4d6743bb538978e1982368acb58a94408d7dd3299cad49c7083" dependencies = [ "hex", - "once_cell", - "polars-arrow", "polars-core", "polars-error", "polars-lazy", @@ -5336,22 +5067,22 @@ dependencies = [ "polars-time", "polars-utils", "rand", + "regex", "serde", - "serde_json", "sqlparser", ] [[package]] name = "polars-stream" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "188622b0a4bc4530cf91a288134254ffa065d18932e261075377914225e757c2" +checksum = "498997b656c779610c1496b3d96a59fe569ef22a5b81ccfe5325cb3df8dff2fd" dependencies = [ "atomic-waker", "crossbeam-deque", "crossbeam-utils", "futures", - "memmap2 0.7.1", + "memmap2", "parking_lot 0.12.3", "pin-project-lite", "polars-core", @@ -5359,6 +5090,7 @@ dependencies = [ "polars-expr", "polars-io", "polars-mem-engine", + "polars-ops", "polars-parquet", "polars-plan", "polars-utils", @@ -5372,31 +5104,33 @@ dependencies = [ [[package]] name = "polars-time" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90f36e4d6b19f2c406faea585b9a1814f422fc5b310f65ccf8a55216df0754ef" +checksum = "d192efbdab516d28b3fab1709a969e3385bd5cda050b7c9aa9e2502a01fda879" dependencies = [ - "atoi", + "atoi_simd", "bytemuck", "chrono", "chrono-tz", "now", + "num-traits", "once_cell", "polars-arrow", + "polars-compute", "polars-core", "polars-error", "polars-ops", "polars-utils", + "rayon", "regex", - "serde", "strum_macros", ] [[package]] name = "polars-utils" -version = "0.44.2" +version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96186b70bda00c90b5027bf2f69193c5c40571e80d3e8ec505c22cdc8e3e39aa" +checksum = "a8f6c8166a4a7fbc15b87c81645ed9e1f0651ff2e8c96cafc40ac5bf43441a10" dependencies = [ "ahash", "bytemuck", @@ -5405,16 +5139,15 @@ dependencies = [ "hashbrown 0.15.2", "indexmap", "libc", - "memmap2 0.7.1", + "memmap2", "num-traits", "once_cell", "polars-error", - "pyo3", + "rand", "raw-cpuid 11.2.0", "rayon", - "serde", "stacker", - "sysinfo 0.31.4", + "sysinfo 0.33.1", "version_check", ] @@ -5584,69 +5317,6 @@ dependencies = [ "reborrow", ] -[[package]] -name = "pyo3" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e00b96a521718e08e03b1a622f01c8a8deb50719335de3f60b3b3950f069d8" -dependencies = [ - "cfg-if", - "indoc", - "libc", - "memoffset", - "parking_lot 0.12.3", - "portable-atomic", - "pyo3-build-config", - "pyo3-ffi", - "pyo3-macros", - "unindent", -] - -[[package]] -name = "pyo3-build-config" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7883df5835fafdad87c0d888b266c8ec0f4c9ca48a5bed6bbb592e8dedee1b50" -dependencies = [ - "once_cell", - "target-lexicon", -] - -[[package]] -name = "pyo3-ffi" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01be5843dc60b916ab4dad1dca6d20b9b4e6ddc8e15f50c47fe6d85f1fb97403" -dependencies = [ - "libc", - "pyo3-build-config", -] - -[[package]] -name = "pyo3-macros" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77b34069fc0682e11b31dbd10321cbf94808394c56fd996796ce45217dfac53c" -dependencies = [ - "proc-macro2", - "pyo3-macros-backend", - "quote", - "syn 2.0.96", -] - -[[package]] -name = "pyo3-macros-backend" -version = "0.21.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08260721f32db5e1a5beae69a55553f56b99bd0e1c3e6e0a5e8851a9d0f5a85c" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "pyo3-build-config", - "quote", - "syn 2.0.96", -] - [[package]] name = "pytorch-import" version = "0.17.0" @@ -5683,68 +5353,6 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" -[[package]] -name = "quick-xml" -version = "0.36.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7649a7b4df05aed9ea7ec6f628c67c9953a43869b8bc50929569b2999d443fe" -dependencies = [ - "memchr", - "serde", -] - -[[package]] -name = "quinn" -version = "0.11.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" -dependencies = [ - "bytes", - "pin-project-lite", - "quinn-proto", - "quinn-udp", - "rustc-hash 2.1.0", - "rustls", - "socket2", - "thiserror 2.0.11", - "tokio", - "tracing", -] - -[[package]] -name = "quinn-proto" -version = "0.11.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" -dependencies = [ - "bytes", - "getrandom", - "rand", - "ring", - "rustc-hash 2.1.0", - "rustls", - "rustls-pki-types", - "slab", - "thiserror 2.0.11", - "tinyvec", - "tracing", - "web-time", -] - -[[package]] -name = "quinn-udp" -version = "0.5.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904" -dependencies = [ - "cfg_aliases", - "libc", - "once_cell", - "socket2", - "tracing", - "windows-sys 0.59.0", -] - [[package]] name = "quote" version = "1.0.38" @@ -6010,26 +5618,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "ref-cast" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccf0a6f84d5f1d581da8b41b47ec8600871962f2a528115b542b362d4b744931" -dependencies = [ - "ref-cast-impl", -] - -[[package]] -name = "ref-cast-impl" -version = "1.0.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcc303e793d3734489387d205e9b186fac9c6cfacedd98cbb2e8a5943595f3e6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.96", -] - [[package]] name = "regex" version = "1.11.1" @@ -6114,11 +5702,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "quinn", - "rustls", - "rustls-native-certs 0.8.1", "rustls-pemfile", - "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", @@ -6126,14 +5710,11 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", - "tokio-rustls", - "tokio-util", "tower", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", - "wasm-streams", "web-sys", "windows-registry", ] @@ -6189,7 +5770,7 @@ name = "rspirv" version = "0.12.0+sdk-1.3.296.0" source = "git+https://github.com/gfx-rs/rspirv.git?rev=e19c11fdb30295127cff1d018189bd436892415e#e19c11fdb30295127cff1d018189bd436892415e" dependencies = [ - "rustc-hash 1.1.0", + "rustc-hash", "spirv 0.3.0+sdk-1.3.296.0", ] @@ -6259,12 +5840,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" -[[package]] -name = "rustc-hash" -version = "2.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" - [[package]] name = "rustc_version" version = "0.4.1" @@ -6312,19 +5887,7 @@ dependencies = [ "rustls-pemfile", "rustls-pki-types", "schannel", - "security-framework 2.11.1", -] - -[[package]] -name = "rustls-native-certs" -version = "0.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" -dependencies = [ - "openssl-probe", - "rustls-pki-types", - "schannel", - "security-framework 3.2.0", + "security-framework", ] [[package]] @@ -6341,9 +5904,6 @@ name = "rustls-pki-types" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" -dependencies = [ - "web-time", -] [[package]] name = "rustls-webpki" @@ -6462,20 +6022,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ "bitflags 2.8.0", - "core-foundation 0.9.4", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework" -version = "3.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" -dependencies = [ - "bitflags 2.8.0", - "core-foundation 0.10.0", + "core-foundation", "core-foundation-sys", "libc", "security-framework-sys", @@ -6702,23 +6249,6 @@ version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" -[[package]] -name = "simd-json" -version = "0.14.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2bcf6c6e164e81bc7a5d49fc6988b3d515d9e8c07457d7b74ffb9324b9cd40" -dependencies = [ - "ahash", - "getrandom", - "halfbrown", - "once_cell", - "ref-cast", - "serde", - "serde_json", - "simdutf8", - "value-trait", -] - [[package]] name = "simd_helpers" version = "0.1.0" @@ -6775,34 +6305,6 @@ version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" -[[package]] -name = "snafu" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" -dependencies = [ - "doc-comment", - "snafu-derive", -] - -[[package]] -name = "snafu-derive" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "990079665f075b699031e9c08fd3ab99be5029b96f3b78dc0709e8f77e4efebf" -dependencies = [ - "heck 0.4.1", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "snap" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" - [[package]] name = "socket2" version = "0.5.8" @@ -6854,9 +6356,9 @@ dependencies = [ [[package]] name = "sqlparser" -version = "0.49.0" +version = "0.53.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4a404d0e14905361b918cb8afdb73605e25c1d5029312bd9785142dcb3aa49e" +checksum = "05a528114c392209b3264855ad491fcce534b94a38771b0a0b97a79379275ce8" dependencies = [ "log", ] @@ -6937,7 +6439,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "rustversion", @@ -7008,29 +6510,29 @@ dependencies = [ [[package]] name = "sysinfo" -version = "0.31.4" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "355dbe4f8799b304b05e1b0f05fc59b2a18d36645cf169607da45bde2f69a1be" +checksum = "4c33cd241af0f2e9e3b5c32163b873b29956890b5342e6745b917ce9d490f4af" dependencies = [ "core-foundation-sys", "libc", "memchr", "ntapi", + "rayon", + "serde", "windows 0.57.0", ] [[package]] name = "sysinfo" -version = "0.32.1" +version = "0.33.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c33cd241af0f2e9e3b5c32163b873b29956890b5342e6745b917ce9d490f4af" +checksum = "4fc858248ea01b66f19d8e8a6d55f41deaf91e9d495246fd01368d99935c6c01" dependencies = [ "core-foundation-sys", "libc", "memchr", "ntapi", - "rayon", - "serde", "windows 0.57.0", ] @@ -7041,7 +6543,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" dependencies = [ "bitflags 2.8.0", - "core-foundation 0.9.4", + "core-foundation", "system-configuration-sys", ] @@ -7062,7 +6564,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a3e535eb8dded36d55ec13eddacd30dec501792ff23a0b1682c38601b8cf2349" dependencies = [ "cfg-expr", - "heck 0.5.0", + "heck", "pkg-config", "toml", "version-compare", @@ -7093,12 +6595,6 @@ dependencies = [ "xattr", ] -[[package]] -name = "target-features" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" - [[package]] name = "target-lexicon" version = "0.12.16" @@ -7630,7 +7126,7 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "deb68604048ff8fa93347f02441e4487594adc20bb8a084f9e564d2b827a0a9f" dependencies = [ - "rustc-hash 1.1.0", + "rustc-hash", ] [[package]] @@ -7786,12 +7282,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" -[[package]] -name = "unindent" -version = "0.2.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" - [[package]] name = "untrusted" version = "0.9.0" @@ -7810,7 +7300,7 @@ dependencies = [ "native-tls", "once_cell", "rustls", - "rustls-native-certs 0.7.3", + "rustls-native-certs", "rustls-pki-types", "serde", "serde_json", @@ -7880,18 +7370,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" -[[package]] -name = "value-trait" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9170e001f458781e92711d2ad666110f153e4e50bfd5cbd02db6547625714187" -dependencies = [ - "float-cmp", - "halfbrown", - "itoa", - "ryu", -] - [[package]] name = "variadics_please" version = "1.1.0" @@ -8028,19 +7506,6 @@ dependencies = [ "web-sys", ] -[[package]] -name = "wasm-streams" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" -dependencies = [ - "futures-util", - "js-sys", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - [[package]] name = "wasm-timer" version = "0.2.5" @@ -8143,7 +7608,7 @@ dependencies = [ "parking_lot 0.12.3", "profiling", "raw-window-handle", - "rustc-hash 1.1.0", + "rustc-hash", "smallvec", "thiserror 2.0.11", "wgpu-hal", @@ -8186,7 +7651,7 @@ dependencies = [ "range-alloc", "raw-window-handle", "renderdoc-sys", - "rustc-hash 1.1.0", + "rustc-hash", "smallvec", "thiserror 2.0.11", "wasm-bindgen", diff --git a/Cargo.toml b/Cargo.toml index db1073ad67..09b944468c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,7 +54,7 @@ log = { default-features = false, version = "0.4.25" } md5 = "0.7.0" paste = "1" percent-encoding = "2.3.1" -polars = { version = "0.44.2", features = ["lazy"] } +polars = { version = "0.46.0", features = ["lazy"] } pretty_assertions = "1.4.1" proc-macro2 = "1.0.93" protobuf = "3.7.1" diff --git a/crates/burn-dataset/src/dataset/dataframe.rs b/crates/burn-dataset/src/dataset/dataframe.rs index 023b357454..c851e8a3e3 100644 --- a/crates/burn-dataset/src/dataset/dataframe.rs +++ b/crates/burn-dataset/src/dataset/dataframe.rs @@ -269,20 +269,20 @@ mod tests { } fn create_test_dataframe() -> DataFrame { - let s0 = Column::Series(Series::new("int32".into(), &[1i32, 2i32, 3i32])); - let s1 = Column::Series(Series::new("bool".into(), &[true, false, true])); - let s2 = Column::Series(Series::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64])); - let s3 = Column::Series(Series::new("string".into(), &["Boo", "Boo2", "Boo3"])); - let s6 = Column::Series(Series::new("int16".into(), &[1i16, 2i16, 3i16])); - let s8 = Column::Series(Series::new("uint32".into(), &[1u32, 2u32, 3u32])); - let s9 = Column::Series(Series::new("uint64".into(), &[1u64, 2u64, 3u64])); - let s10 = Column::Series(Series::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32])); - let s11 = Column::Series(Series::new("int64".into(), &[1i64, 2i64, 3i64])); - let s12 = Column::Series(Series::new("int8".into(), &[1i8, 2i8, 3i8])); + let s0 = Column::new("int32".into(), &[1i32, 2i32, 3i32]); + let s1 = Column::new("bool".into(), &[true, false, true]); + let s2 = Column::new("float64".into(), &[1.1f64, 2.2f64, 3.3f64]); + let s3 = Column::new("string".into(), &["Boo", "Boo2", "Boo3"]); + let s6 = Column::new("int16".into(), &[1i16, 2i16, 3i16]); + let s8 = Column::new("uint32".into(), &[1u32, 2u32, 3u32]); + let s9 = Column::new("uint64".into(), &[1u64, 2u64, 3u64]); + let s10 = Column::new("float32".into(), &[1.1f32, 2.2f32, 3.3f32]); + let s11 = Column::new("int64".into(), &[1i64, 2i64, 3i64]); + let s12 = Column::new("int8".into(), &[1i8, 2i8, 3i8]); let binary_data: Vec<&[u8]> = vec![&[1, 2, 3], &[4, 5, 6], &[7, 8, 9]]; - let s13 = Column::Series(Series::new("binary".into(), binary_data)); + let s13 = Column::new("binary".into(), binary_data); DataFrame::new(vec![s0, s1, s2, s3, s6, s8, s9, s10, s11, s12, s13]).unwrap() } From 6d8dd6920ebd66bf56d072d35fa6ca74c54e54cd Mon Sep 17 00:00:00 2001 From: Sylvain Benner Date: Thu, 30 Jan 2025 12:27:56 -0500 Subject: [PATCH 55/61] Fix BackendValues in backend-comparison after removal of jit suffix (#2756) --- backend-comparison/src/burnbenchapp/base.rs | 14 +++++++------- crates/burn-remote/src/server/session.rs | 4 ++-- .../burn-tensor/src/tensor/backend/conversion.rs | 4 ++-- examples/wgan/src/model.rs | 2 +- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/backend-comparison/src/burnbenchapp/base.rs b/backend-comparison/src/burnbenchapp/base.rs index 9eba1485b3..4fb31edab8 100644 --- a/backend-comparison/src/burnbenchapp/base.rs +++ b/backend-comparison/src/burnbenchapp/base.rs @@ -62,6 +62,13 @@ enum BackendValues { CandleCuda, #[strum(to_string = "candle-metal")] CandleMetal, + #[strum(to_string = "cuda")] + Cuda, + #[strum(to_string = "cuda-fusion")] + CudaFusion, + #[cfg(target_os = "linux")] + #[strum(to_string = "hip")] + Hip, #[strum(to_string = "ndarray")] Ndarray, #[strum(to_string = "ndarray-blas-accelerate")] @@ -82,13 +89,6 @@ enum BackendValues { WgpuSpirv, #[strum(to_string = "wgpu-spirv-fusion")] WgpuSpirvFusion, - #[strum(to_string = "cuda-jit")] - CudaJit, - #[strum(to_string = "cuda-jit-fusion")] - CudaJitFusion, - #[cfg(target_os = "linux")] - #[strum(to_string = "hip-jit")] - HipJit, } #[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)] diff --git a/crates/burn-remote/src/server/session.rs b/crates/burn-remote/src/server/session.rs index 7d32d04b74..3da6b2afa1 100644 --- a/crates/burn-remote/src/server/session.rs +++ b/crates/burn-remote/src/server/session.rs @@ -101,12 +101,12 @@ impl SessionManager { impl Session { fn new(runner: Runner) -> Self { - let (sender, reveiver) = std::sync::mpsc::sync_channel(1); + let (sender, receiver) = std::sync::mpsc::sync_channel(1); Self { runner, streams: Default::default(), sender, - receiver: Some(reveiver), + receiver: Some(receiver), } } diff --git a/crates/burn-tensor/src/tensor/backend/conversion.rs b/crates/burn-tensor/src/tensor/backend/conversion.rs index 46b0423b71..6aebe06463 100644 --- a/crates/burn-tensor/src/tensor/backend/conversion.rs +++ b/crates/burn-tensor/src/tensor/backend/conversion.rs @@ -188,7 +188,7 @@ mod tests { } #[test] - fn should_build_indices_2d_complexe() { + fn should_build_indices_2d_complex() { let shape = Shape::new([2, 3]); let indices = build_indices(&shape, Order::Left); @@ -206,7 +206,7 @@ mod tests { } #[test] - fn should_build_indices_3d_complexe() { + fn should_build_indices_3d_complex() { let shape = Shape::new([2, 5, 3]); let indices = build_indices(&shape, Order::Left); diff --git a/examples/wgan/src/model.rs b/examples/wgan/src/model.rs index b9615f5270..755d8e9e1d 100644 --- a/examples/wgan/src/model.rs +++ b/examples/wgan/src/model.rs @@ -83,7 +83,7 @@ impl Discriminator { } } -// Use model config to construct a generative and adverserial model +// Use model config to construct a generative and adversarial model #[derive(Config, Debug)] pub struct ModelConfig { /// Dimensionality of the latent space From cb0854c636fd7e219675db3c0afc5c1db168574b Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 31 Jan 2025 09:18:51 -0500 Subject: [PATCH 56/61] Remove autodiff from generate (#2759) --- examples/wgan/examples/wgan-generate.rs | 29 +++++++++---------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/examples/wgan/examples/wgan-generate.rs b/examples/wgan/examples/wgan-generate.rs index 1b4a51a535..1d0a4fd87d 100644 --- a/examples/wgan/examples/wgan-generate.rs +++ b/examples/wgan/examples/wgan-generate.rs @@ -11,24 +11,18 @@ pub fn launch(device: B::Device) { feature = "ndarray-blas-accelerate", ))] mod ndarray { - use burn::backend::{ - ndarray::{NdArray, NdArrayDevice}, - Autodiff, - }; + use burn::backend::ndarray::{NdArray, NdArrayDevice}; use crate::launch; pub fn run() { - launch::>(NdArrayDevice::Cpu); + launch::(NdArrayDevice::Cpu); } } #[cfg(feature = "tch-gpu")] mod tch_gpu { - use burn::backend::{ - libtorch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::launch; @@ -38,41 +32,38 @@ mod tch_gpu { #[cfg(target_os = "macos")] let device = LibTorchDevice::Mps; - launch::>(device); + launch::(device); } } #[cfg(feature = "tch-cpu")] mod tch_cpu { - use burn::backend::{ - libtorch::{LibTorch, LibTorchDevice}, - Autodiff, - }; + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; use crate::launch; pub fn run() { - launch::>(LibTorchDevice::Cpu); + launch::(LibTorchDevice::Cpu); } } #[cfg(feature = "wgpu")] mod wgpu { use crate::launch; - use burn::backend::{wgpu::Wgpu, Autodiff}; + use burn::backend::wgpu::Wgpu; pub fn run() { - launch::>(Default::default()); + launch::(Default::default()); } } #[cfg(feature = "cuda")] mod cuda { use crate::launch; - use burn::backend::{Autodiff, Cuda}; + use burn::backend::Cuda; pub fn run() { - launch::>(Default::default()); + launch::(Default::default()); } } From c8f385cf8f4c965bf3dbe3afac3f7f14c9b2000c Mon Sep 17 00:00:00 2001 From: Genna Wingert <1058083+wingertge@users.noreply.github.com> Date: Mon, 3 Feb 2025 14:32:08 +0100 Subject: [PATCH 57/61] Update cubecl (#2764) * Update cubecl * Update to scope merge * Fix bitwise shift * Update * Update lock for OpenSSL fix --- Cargo.lock | 396 +++++++++--------- Cargo.toml | 4 +- crates/burn-jit/src/fusion/matmul/args.rs | 2 +- crates/burn-jit/src/fusion/on_write/ir.rs | 4 +- .../src/kernel/conv/conv2d/gemm/launch.rs | 12 +- .../burn-jit/src/kernel/conv/conv2d/im2col.rs | 38 +- crates/burn-jit/src/kernel/conv/error.rs | 11 +- crates/burn-jit/src/ops/int_ops.rs | 16 +- .../src/tensor/quantization/scheme.rs | 2 +- crates/burn-tensor/src/tests/ops/bitwise.rs | 4 + 10 files changed, 253 insertions(+), 236 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cb75c0ac6e..4151733570 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,7 +41,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", "zerocopy", @@ -174,7 +174,7 @@ checksum = "0ae92a5119aa49cdbcf6b9f893fe4e1d98b04ccbf82ee0584ad948a44a734dea" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -249,18 +249,18 @@ checksum = "c7c24de15d275a1ecfd47a380fb4d5ec9bfe0933f309ed5e705b775596a3574d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "async-trait" -version = "0.1.85" +version = "0.1.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" +checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -547,9 +547,9 @@ checksum = "c360505aed52b7ec96a3636c3f039d99103c37d1d9b4f7a8c743d3ea9ffcd03b" [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "burn" @@ -589,7 +589,7 @@ version = "0.17.0" dependencies = [ "cubecl-common", "dashmap", - "getrandom", + "getrandom 0.2.15", "indicatif", "rayon", "reqwest", @@ -690,7 +690,7 @@ dependencies = [ "derive-new 0.7.0", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -741,7 +741,7 @@ dependencies = [ "rust-format", "serde", "serde_json", - "syn 2.0.96", + "syn 2.0.98", "thiserror 2.0.11", "tracing-core", "tracing-subscriber", @@ -929,7 +929,7 @@ checksum = "3fa76293b4f7bb636ab88fd78228235b5248b4d05cc589aed610f954af5d7c7a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -946,9 +946,9 @@ checksum = "8f1fe948ff07f4bd06c30984e69f5b4899c516a3ef74f34df92a2df2ab535495" [[package]] name = "bytes" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" +checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" dependencies = [ "serde", ] @@ -1057,9 +1057,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.9" +version = "1.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8293772165d9345bdaaa39b45b2109591e63fe5e6fbc23c6ff930a048aa310b" +checksum = "e4730490333d58093109dc02c23174c3f4d490998c3fed3cc8e82d57afedb9cf" dependencies = [ "jobserver", "libc", @@ -1164,7 +1164,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1184,9 +1184,9 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.52" +version = "0.1.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c682c223677e0e5b6b7f63a64b9351844c3f1b1678a68b7ee617e30fb082620e" +checksum = "e24a03c8b52922d68a1589ad61032f2c1aa5a8158d2aa0d93c6e9534944bbad6" dependencies = [ "cc", ] @@ -1336,9 +1336,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -1437,9 +1437,9 @@ dependencies = [ [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "crypto-common" @@ -1475,7 +1475,7 @@ dependencies = [ [[package]] name = "cubecl" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-core", "cubecl-cuda", @@ -1490,14 +1490,14 @@ dependencies = [ [[package]] name = "cubecl-common" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "derive-new 0.6.0", "derive_more 1.0.0", "embassy-futures", "futures-lite", - "getrandom", + "getrandom 0.2.15", "half", "log", "num-traits", @@ -1511,7 +1511,7 @@ dependencies = [ [[package]] name = "cubecl-core" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bitflags 2.8.0", "bytemuck", @@ -1532,7 +1532,7 @@ dependencies = [ [[package]] name = "cubecl-cpp" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-common", @@ -1546,7 +1546,7 @@ dependencies = [ [[package]] name = "cubecl-cuda" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-common", @@ -1562,7 +1562,7 @@ dependencies = [ [[package]] name = "cubecl-hip" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-common", @@ -1578,9 +1578,9 @@ dependencies = [ [[package]] name = "cubecl-hip-sys" -version = "6.3.1000" +version = "6.3.1001" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4d987c1720eab39c72c515377a8001f683a4c4d99232a29fc0de389d9a8ce4f" +checksum = "c7e92df7f9feff6a469932fc4d4b349d28000af9e6f34e583eb4f8df70038d48" dependencies = [ "libc", ] @@ -1588,22 +1588,25 @@ dependencies = [ [[package]] name = "cubecl-ir" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-common", "cubecl-macros-internal", "derive_more 1.0.0", "float-ord", + "fnv", "half", + "hashbrown 0.14.5", "num-traits", + "portable-atomic", "serde", - "type_hash", + "variadics_please", ] [[package]] name = "cubecl-linalg" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bytemuck", "cubecl-core", @@ -1615,7 +1618,7 @@ dependencies = [ [[package]] name = "cubecl-macros" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-common", "darling", @@ -1624,24 +1627,24 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "cubecl-macros-internal" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "cubecl-opt" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-common", "cubecl-ir", @@ -1657,7 +1660,7 @@ dependencies = [ [[package]] name = "cubecl-reduce" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "cubecl-core", "cubecl-runtime", @@ -1667,7 +1670,7 @@ dependencies = [ [[package]] name = "cubecl-runtime" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "async-channel", "async-lock", @@ -1689,7 +1692,7 @@ dependencies = [ [[package]] name = "cubecl-spirv" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "bitflags 2.8.0", "cubecl-common", @@ -1704,7 +1707,7 @@ dependencies = [ [[package]] name = "cubecl-wgpu" version = "0.5.0" -source = "git+https://github.com/tracel-ai/cubecl?rev=ff34667accfe077d4a1cd48ae419868e142acfd6#ff34667accfe077d4a1cd48ae419868e142acfd6" +source = "git+https://github.com/tracel-ai/cubecl?rev=a172f6760052bef392e6f0e44e912460960f2c1b#a172f6760052bef392e6f0e44e912460960f2c1b" dependencies = [ "ash", "async-channel", @@ -1724,9 +1727,9 @@ dependencies = [ [[package]] name = "cudarc" -version = "0.12.2" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cd76de2aa3a7bdb9a65941ea5a3c688d941688f736a81b2fc5beb88747a7f25" +checksum = "38cd60a9a42ec83a2ed7effb0b1f073270264ea99da7acfc44f7e8d74dee0384" dependencies = [ "half", "libloading", @@ -1821,7 +1824,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1832,7 +1835,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" dependencies = [ "darling_core", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1878,7 +1881,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1889,7 +1892,7 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1900,7 +1903,7 @@ checksum = "30542c1ad912e0e3d22a1935c290e12e8a29d704a420177a31faad4a601a0800" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1921,7 +1924,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1931,7 +1934,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" dependencies = [ "derive_builder_core", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1942,7 +1945,7 @@ checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -1962,7 +1965,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "unicode-xid", ] @@ -2018,7 +2021,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2042,9 +2045,9 @@ dependencies = [ [[package]] name = "dyn-clone" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" +checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" [[package]] name = "dyn-stack" @@ -2092,7 +2095,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2104,7 +2107,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2334,7 +2337,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2427,7 +2430,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -2613,10 +2616,22 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets 0.52.6", +] + [[package]] name = "gif" version = "0.13.1" @@ -2684,15 +2699,15 @@ dependencies = [ [[package]] name = "gix-trace" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04bdde120c29f1fc23a24d3e115aeeea3d60d8e65bab92cc5f9d90d9302eb952" +checksum = "7c396a2036920c69695f760a65e7f2677267ccf483f25046977d87e4cb2665f7" [[package]] name = "gix-utils" -version = "0.1.13" +version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba427e3e9599508ed98a6ddf8ed05493db114564e338e41f6a996d2e4790335f" +checksum = "ff08f24e03ac8916c478c8419d7d3c33393da9bb41fa4c24455d5406aeefd35f" dependencies = [ "fastrand", "unicode-normalization", @@ -2998,9 +3013,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.9.5" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" +checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" [[package]] name = "httpdate" @@ -3016,9 +3031,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "hyper" -version = "1.5.2" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" dependencies = [ "bytes", "futures-channel", @@ -3225,7 +3240,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -3331,9 +3346,9 @@ checksum = "d0263a3d970d5c054ed9312c0057b4f3bde9c0b33836d3637361d4a9e6e7a408" [[package]] name = "indexmap" -version = "2.7.0" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" +checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -3342,9 +3357,9 @@ dependencies = [ [[package]] name = "indicatif" -version = "0.17.9" +version = "0.17.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cbf675b85ed934d3c67b5c5469701eec7db22689d0a2139d856e0925fa28b281" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" dependencies = [ "console", "number_prefix", @@ -3378,7 +3393,7 @@ dependencies = [ "indoc", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -3398,14 +3413,14 @@ checksum = "c34819042dc3d3971c46c2190835914dfbe0c3c13f61449b2997f4e9722dfa60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "ipnet" -version = "2.10.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is_terminal_polyfill" @@ -3514,9 +3529,9 @@ checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libfuzzer-sys" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b9569d2f74e257076d8c6bfa73fb505b46b851e51ddaecc825944aa3bed17fa" +checksum = "cf78f52d400cf2d84a3a973a78a592b4adc535739e0a5597a0da6f0c357adc75" dependencies = [ "arbitrary", "cc", @@ -3791,9 +3806,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" +checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" dependencies = [ "adler2", "simd-adler32", @@ -3807,7 +3822,7 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" dependencies = [ "libc", "log", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -3858,7 +3873,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -3893,9 +3908,9 @@ dependencies = [ [[package]] name = "native-tls" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +checksum = "0dab59f8e050d5df8e4dd87d9206fb6f65a483e20ac9fda365ade4fab353196c" dependencies = [ "libc", "log", @@ -4061,7 +4076,7 @@ checksum = "ed3955f1a9c7c0c15e092f9c887db08b1fc683305fdf6eb6684f22555355e202" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -4133,7 +4148,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -4242,9 +4257,9 @@ dependencies = [ [[package]] name = "objc2-encode" -version = "4.0.3" +version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7891e71393cd1f227313c9379a26a584ff3d7e6e7159e988851f0934c993f0f8" +checksum = "ef25abbcd74fb2609453eb695bd2f860d389e457f67dc17cafc8b8cbc89d0c33" [[package]] name = "objc2-foundation" @@ -4395,9 +4410,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.68" +version = "0.10.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6174bc48f102d208783c2c84bf931bb75927a617866870de8a4ea85597f871f5" +checksum = "61cfb4e166a8bb8c9b55c500bc2308550148ece889be90f609377e58140f42c6" dependencies = [ "bitflags 2.8.0", "cfg-if", @@ -4416,20 +4431,20 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" [[package]] name = "openssl-sys" -version = "0.9.104" +version = "0.9.105" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45abf306cbf99debc8195b66b7346498d7b10c210de50418b5ccd7ceba08c741" +checksum = "8b22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc" dependencies = [ "cc", "libc", @@ -4671,7 +4686,7 @@ version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72571dde488ecccbe799798bf99ab7308ebdb7cf5d95bcc498dbd5a132f0da4d" dependencies = [ - "getrandom", + "getrandom 0.2.15", "polars-arrow", "polars-core", "polars-error", @@ -4699,7 +4714,7 @@ dependencies = [ "dyn-clone", "either", "ethnum", - "getrandom", + "getrandom 0.2.15", "hashbrown 0.15.2", "itoa", "lz4", @@ -5144,7 +5159,7 @@ dependencies = [ "once_cell", "polars-error", "rand", - "raw-cpuid 11.2.0", + "raw-cpuid 11.3.0", "rayon", "stacker", "sysinfo 0.33.1", @@ -5156,6 +5171,9 @@ name = "portable-atomic" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +dependencies = [ + "serde", +] [[package]] name = "portable-atomic-util" @@ -5204,7 +5222,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" dependencies = [ "proc-macro2", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -5241,7 +5259,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a65f2e60fbf1063868558d69c6beacf412dc755f9fc020f514b7955fc914fe30" dependencies = [ "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -5411,7 +5429,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] [[package]] @@ -5426,9 +5444,9 @@ dependencies = [ [[package]] name = "range-alloc" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8a99fddc9f0ba0a85884b8d14e3592853e787d581ca1816c91349b10e4eeab" +checksum = "c3d6831663a5098ea164f89cff59c6284e95f4e3c76ce9848d4529f5ccca9bde" [[package]] name = "ratatui" @@ -5513,9 +5531,9 @@ dependencies = [ [[package]] name = "raw-cpuid" -version = "11.2.0" +version = "11.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" +checksum = "c6928fa44c097620b706542d428957635951bade7143269085389d42c8a4927e" dependencies = [ "bitflags 2.8.0", ] @@ -5586,7 +5604,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -5613,7 +5631,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom", + "getrandom 0.2.15", "libredox", "thiserror 1.0.69", ] @@ -5736,7 +5754,7 @@ checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", "spin", "untrusted", @@ -5800,7 +5818,7 @@ dependencies = [ "regex", "relative-path", "rustc_version", - "syn 2.0.96", + "syn 2.0.98", "unicode-ident", ] @@ -5851,9 +5869,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.43" +version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" +checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ "bitflags 2.8.0", "errno", @@ -5864,9 +5882,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.21" +version = "0.23.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8" +checksum = "9fb9263ab4eb695e42321db096e3b8fbd715a59b154d5c88d82db2175b681ba7" dependencies = [ "log", "once_cell", @@ -5901,9 +5919,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" [[package]] name = "rustls-webpki" @@ -5924,9 +5942,9 @@ checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "safetensors" @@ -6040,9 +6058,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.24" +version = "1.0.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" +checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" [[package]] name = "seq-macro" @@ -6087,14 +6105,14 @@ checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] name = "serde_json" -version = "1.0.137" +version = "1.0.138" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "930cfb6e6abf99298aaad7d29abbef7a9999a9a8806a40088f55f0dcec03146b" +checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" dependencies = [ "itoa", "memchr", @@ -6165,7 +6183,7 @@ checksum = "5d69265a08751de7844521fd15003ae0a888e035773ba05695c5c759a6f89eef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6443,7 +6461,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6459,15 +6477,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", - "quote", "unicode-ident", ] [[package]] name = "syn" -version = "2.0.96" +version = "2.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5d0adab1ae378d7f53bdebc67a39f1f151407ef230f0ce2883572f5d8985c80" +checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" dependencies = [ "proc-macro2", "quote", @@ -6491,7 +6508,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6620,13 +6637,13 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.15.0" +version = "3.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" +checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" dependencies = [ "cfg-if", "fastrand", - "getrandom", + "getrandom 0.3.1", "once_cell", "rustix", "windows-sys 0.59.0", @@ -6709,7 +6726,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6720,7 +6737,7 @@ checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6820,7 +6837,7 @@ dependencies = [ "aho-corasick", "derive_builder", "esaxx-rs", - "getrandom", + "getrandom 0.2.15", "hf-hub", "itertools 0.12.1", "lazy_static", @@ -6867,7 +6884,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -6938,9 +6955,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.22" +version = "0.22.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +checksum = "02a8b472d1a3d7c18e2d61a489aee3453fd9031c33e4f55bd533f4a7adca1bee" dependencies = [ "indexmap", "serde", @@ -7019,7 +7036,7 @@ checksum = "5a3a646485f7cd8f580749ab94718ad3d344bcc0cc5b0fefe43c15fdd898bb96" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7054,7 +7071,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7129,37 +7146,6 @@ dependencies = [ "rustc-hash", ] -[[package]] -name = "type_hash" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03c86f48f11992d3e379358c63cb25736c0b23944ff000d1583bbccad2b0b7c6" -dependencies = [ - "type_hash_core", - "type_hash_macros", -] - -[[package]] -name = "type_hash_core" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87b1e93e2cd97790892dbe2d2813fbaa6eebaeb960265f59e363e79e51e4997a" -dependencies = [ - "fnv", -] - -[[package]] -name = "type_hash_macros" -version = "0.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "746fc164e076483ef087b3989f7aa80ffd9320fa558f3cb72cecfb9bb1dbc41e" -dependencies = [ - "either", - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "typenum" version = "1.17.0" @@ -7210,9 +7196,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" [[package]] name = "unicode-normalization" @@ -7349,7 +7335,7 @@ version = "1.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b" dependencies = [ - "getrandom", + "getrandom 0.2.15", "rand", ] @@ -7366,9 +7352,9 @@ dependencies = [ [[package]] name = "valuable" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" [[package]] name = "variadics_please" @@ -7378,7 +7364,7 @@ checksum = "41b6d82be61465f97d42bd1d15bf20f3b0a3a0905018f38f9d6f6962055b0b5c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7424,6 +7410,15 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + [[package]] name = "wasm-bindgen" version = "0.2.100" @@ -7446,7 +7441,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "wasm-bindgen-shared", ] @@ -7481,7 +7476,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -7543,9 +7538,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.7" +version = "0.26.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e" +checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" dependencies = [ "rustls-pki-types", ] @@ -7778,7 +7773,7 @@ checksum = "9107ddc059d5b6fbfbffdfa7a7fe3e22a226def0b2608f72e9d552763d3e1ad7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7789,7 +7784,7 @@ checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7800,7 +7795,7 @@ checksum = "29bee4b38ea3cde66011baa44dba677c432a78593e202392d1e9070cf2a7fca7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -7811,7 +7806,7 @@ checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8003,13 +7998,22 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.24" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8d71a593cc5c42ad7876e2c1fda56f314f3754c084128833e64f1345ff8a03a" +checksum = "7e49d2d35d3fad69b39b94139037ecfb4f359f08958b9c11e7315ce770462419" dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags 2.8.0", +] + [[package]] name = "wrapcenum-derive" version = "0.4.1" @@ -8019,7 +8023,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8116,7 +8120,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "synstructure", ] @@ -8138,7 +8142,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8158,7 +8162,7 @@ checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", "synstructure", ] @@ -8179,7 +8183,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] @@ -8201,7 +8205,7 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.96", + "syn 2.0.98", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 09b944468c..7263dec57a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false } portable-atomic-util = { version = "0.2.4", features = ["alloc"] } ### For the main burn branch. ### -cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } -cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "ff34667accfe077d4a1cd48ae419868e142acfd6" } +cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" } +cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "a172f6760052bef392e6f0e44e912460960f2c1b" } ### For local development. ### # cubecl = { path = "../cubecl/crates/cubecl", default-features = false } # cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false } diff --git a/crates/burn-jit/src/fusion/matmul/args.rs b/crates/burn-jit/src/fusion/matmul/args.rs index 1dbbf3baea..bba18e88f9 100644 --- a/crates/burn-jit/src/fusion/matmul/args.rs +++ b/crates/burn-jit/src/fusion/matmul/args.rs @@ -247,7 +247,7 @@ impl CubeType for FusedMatmulState { } impl Init for FusedMatmulStateExpand { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _context: &mut Scope) -> Self { self } } diff --git a/crates/burn-jit/src/fusion/on_write/ir.rs b/crates/burn-jit/src/fusion/on_write/ir.rs index 0cec2d29c7..36c8e402a0 100644 --- a/crates/burn-jit/src/fusion/on_write/ir.rs +++ b/crates/burn-jit/src/fusion/on_write/ir.rs @@ -45,13 +45,13 @@ impl CubeType for Arg { } impl Init for Arg { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _context: &mut Scope) -> Self { self } } impl IntoRuntime for Arg { - fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType { + fn __expand_runtime_method(self, _context: &mut Scope) -> Self::ExpandType { self } } diff --git a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs index f36f89bdf5..ad70a9b825 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs @@ -7,7 +7,7 @@ use burn_tensor::{ use cubecl::{ flex32, ir::{Elem, FloatKind}, - linalg::matmul::{self}, + linalg::matmul::{self, kernels::MatmulLaunchError}, tensor_line_size, tf32, Feature, }; use half::{bf16, f16}; @@ -195,18 +195,14 @@ where let cube_count = Alg::cube_count(&selection, &problem); let advanced_config = Default::default(); - let config = match Alg::make_config( + let config = Alg::make_config( config_input, &problem, &cube_dim, &cube_count, &advanced_config, - ) { - Ok(val) => val, - Err(err) => { - panic!("Can't launch conv kernel because of an invalid config: {err}") - } - }; + ) + .map_err(MatmulLaunchError::InvalidConfig)?; let bias = bias.unwrap_or_else(|| { empty_device::(input.client.clone(), input.device.clone(), Shape::new([1])) diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs index f74cdaf8bc..09ce56898b 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -98,25 +98,38 @@ fn im2col_kernel( } #[cfg(not(test))] -pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { - let cube_count_per_batch = (out_h * out_w).div_ceil(burn_common::PLANE_DIM_APPROX); +pub(crate) fn batches_per_run( + batch_size: usize, + out_h: usize, + out_w: usize, +) -> Result { + use cubecl::linalg::matmul::kernels::MatmulAvailabilityError; + + let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::PLANE_DIM_APPROX); let max_cube_count = u16::MAX as usize; let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size); if max_simultaneous == 0 { - return None; + return Err(MatmulAvailabilityError::CubeCountTooBig(CubeCount::Static( + cube_count_per_batch as u32, + 1, + 1, + )) + .into()); } - Some( - (0..=max_simultaneous) - .rev() - .find(|per_run| batch_size % per_run == 0) - .expect("Logically not possible"), - ) + Ok((0..=max_simultaneous) + .rev() + .find(|per_run| batch_size % per_run == 0) + .expect("Logically not possible")) } #[cfg(test)] #[allow(unused)] -pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option { - Some(1) +pub(crate) fn batches_per_run( + batch_size: usize, + out_h: usize, + out_w: usize, +) -> Result { + Ok(1) } fn im2col( @@ -214,8 +227,7 @@ pub fn conv2d_im2col( return execute_1x1_kernel::(input, weight, bias, options); } - let batches_per_run = batches_per_run(batch_size, out_h, out_w) - .expect("Image too large to run even one batch at once"); + let batches_per_run = batches_per_run(batch_size, out_h, out_w)?; let matmul_shape = Shape::new([groups, out_c_per_group, batches_per_run * out_h * out_w]); let mut out = if batches_per_run != batch_size { diff --git a/crates/burn-jit/src/kernel/conv/error.rs b/crates/burn-jit/src/kernel/conv/error.rs index 99c91fc751..2654a20e24 100644 --- a/crates/burn-jit/src/kernel/conv/error.rs +++ b/crates/burn-jit/src/kernel/conv/error.rs @@ -1,5 +1,8 @@ use core::fmt::Debug; -use cubecl::{linalg::matmul::kernels::MatmulLaunchError, tune::AutotuneError}; +use cubecl::{ + linalg::matmul::kernels::{MatmulAvailabilityError, MatmulLaunchError}, + tune::AutotuneError, +}; pub enum ConvLaunchError { Matmul(MatmulLaunchError), @@ -30,6 +33,12 @@ impl From for ConvLaunchError { } } +impl From for ConvLaunchError { + fn from(value: MatmulAvailabilityError) -> Self { + Self::Matmul(MatmulLaunchError::Unavailable(value)) + } +} + #[allow(clippy::from_over_into)] impl Into for ConvLaunchError { fn into(self) -> AutotuneError { diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 068c1269d9..8da778d1e8 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -328,26 +328,18 @@ where } fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = kernel::cast::(rhs); - launch_binop_int::(lhs_cast, rhs_cast) + launch_binop_int::(lhs, rhs) } fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = rhs.elem::(); - launch_scalar_binop_int::(lhs_cast, rhs_cast) + launch_scalar_binop_int::(lhs, rhs) } fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = kernel::cast::(rhs); - launch_binop_int::(lhs_cast, rhs_cast) + launch_binop_int::(lhs, rhs) } fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { - let lhs_cast = kernel::cast::(lhs); - let rhs_cast = rhs.elem::(); - launch_scalar_binop_int::(lhs_cast, rhs_cast) + launch_scalar_binop_int::(lhs, rhs) } } diff --git a/crates/burn-tensor/src/tensor/quantization/scheme.rs b/crates/burn-tensor/src/tensor/quantization/scheme.rs index fb141ee16d..27fa996ad6 100644 --- a/crates/burn-tensor/src/tensor/quantization/scheme.rs +++ b/crates/burn-tensor/src/tensor/quantization/scheme.rs @@ -37,7 +37,7 @@ impl CubeType for QuantizationScheme { } #[cfg(feature = "cubecl")] impl cubecl::frontend::Init for QuantizationScheme { - fn init(self, _context: &mut CubeContext) -> Self { + fn init(self, _scope: &mut cubecl::ir::Scope) -> Self { self } } diff --git a/crates/burn-tensor/src/tests/ops/bitwise.rs b/crates/burn-tensor/src/tests/ops/bitwise.rs index 73702a716e..c85f5edcc5 100644 --- a/crates/burn-tensor/src/tests/ops/bitwise.rs +++ b/crates/burn-tensor/src/tests/ops/bitwise.rs @@ -124,6 +124,10 @@ mod tests { #[test] fn should_apply_bitwise_left_shift_2d() { + if (IntType::MAX as u32) < 512 { + return; + } + let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]); let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]); From e0c641934fb44b67e6d6d143e5b72d06cd68ba41 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 08:32:58 -0500 Subject: [PATCH 58/61] Bump indicatif from 0.17.9 to 0.17.11 (#2769) Bumps [indicatif](https://github.com/console-rs/indicatif) from 0.17.9 to 0.17.11. - [Release notes](https://github.com/console-rs/indicatif/releases) - [Commits](https://github.com/console-rs/indicatif/compare/0.17.9...0.17.11) --- updated-dependencies: - dependency-name: indicatif dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 7263dec57a..7287eae729 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ globwalk = "0.9.1" hashbrown = "0.15.2" hound = "3.5.1" image = "0.25.5" -indicatif = "0.17.9" +indicatif = "0.17.11" js-sys = "0.3.72" libm = "0.2.11" log = { default-features = false, version = "0.4.25" } From 6b2e66bd36bd4ddc0e2cd1e94690ece6212562fa Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 3 Feb 2025 09:19:27 -0500 Subject: [PATCH 59/61] Bump sysinfo from 0.32.1 to 0.33.1 (#2771) * Bump sysinfo from 0.32.1 to 0.33.1 Bumps [sysinfo](https://github.com/GuillaumeGomez/sysinfo) from 0.32.1 to 0.33.1. - [Changelog](https://github.com/GuillaumeGomez/sysinfo/blob/master/CHANGELOG.md) - [Commits](https://github.com/GuillaumeGomez/sysinfo/compare/v0.32.1...v0.33.1) --- updated-dependencies: - dependency-name: sysinfo dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] * Fix Hip backend name * Fix refresh kind methods --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Guillaume Lagrange --- Cargo.lock | 25 +++++-------------- Cargo.toml | 2 +- .../src/persistence/system_info.rs | 2 +- crates/burn-train/src/metric/cpu_use.rs | 4 ++- .../examples/ag-news-train.rs | 4 +-- 5 files changed, 13 insertions(+), 24 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4151733570..ab9eddb0bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -393,7 +393,7 @@ dependencies = [ "serial_test", "strum", "strum_macros", - "sysinfo 0.32.1", + "sysinfo", "tracing-subscriber", "wgpu", "wsl", @@ -893,7 +893,7 @@ dependencies = [ "ratatui", "rstest", "serde", - "sysinfo 0.32.1", + "sysinfo", "systemstat", "tracing-appender", "tracing-core", @@ -3544,7 +3544,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" dependencies = [ "cfg-if", - "windows-targets 0.52.6", + "windows-targets 0.48.5", ] [[package]] @@ -5162,7 +5162,7 @@ dependencies = [ "raw-cpuid 11.3.0", "rayon", "stacker", - "sysinfo 0.33.1", + "sysinfo", "version_check", ] @@ -6525,21 +6525,6 @@ dependencies = [ "walkdir", ] -[[package]] -name = "sysinfo" -version = "0.32.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c33cd241af0f2e9e3b5c32163b873b29956890b5342e6745b917ce9d490f4af" -dependencies = [ - "core-foundation-sys", - "libc", - "memchr", - "ntapi", - "rayon", - "serde", - "windows 0.57.0", -] - [[package]] name = "sysinfo" version = "0.33.1" @@ -6550,6 +6535,8 @@ dependencies = [ "libc", "memchr", "ntapi", + "rayon", + "serde", "windows 0.57.0", ] diff --git a/Cargo.toml b/Cargo.toml index 7287eae729..169d668aa8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -145,7 +145,7 @@ uuid = { version = "1.12.1", default-features = false } libc = "0.2.169" nvml-wrapper = "0.10.0" -sysinfo = "0.32.1" +sysinfo = "0.33.1" systemstat = "0.2.3" tch = "0.15.0" diff --git a/backend-comparison/src/persistence/system_info.rs b/backend-comparison/src/persistence/system_info.rs index 287b629c21..3fe24bc955 100644 --- a/backend-comparison/src/persistence/system_info.rs +++ b/backend-comparison/src/persistence/system_info.rs @@ -38,7 +38,7 @@ impl BenchmarkSystemInfo { fn enumerate_cpus() -> Vec { let system = sysinfo::System::new_with_specifics( - sysinfo::RefreshKind::new().with_cpu(sysinfo::CpuRefreshKind::everything()), + sysinfo::RefreshKind::nothing().with_cpu(sysinfo::CpuRefreshKind::everything()), ); let cpu_names: HashSet = system .cpus() diff --git a/crates/burn-train/src/metric/cpu_use.rs b/crates/burn-train/src/metric/cpu_use.rs index 2769793088..d06d8429db 100644 --- a/crates/burn-train/src/metric/cpu_use.rs +++ b/crates/burn-train/src/metric/cpu_use.rs @@ -26,7 +26,9 @@ impl CpuUse { } fn refresh(sys: &mut System) -> f64 { - sys.refresh_specifics(RefreshKind::new().with_cpu(CpuRefreshKind::new().with_cpu_usage())); + sys.refresh_specifics( + RefreshKind::nothing().with_cpu(CpuRefreshKind::nothing().with_cpu_usage()), + ); let cpus = sys.cpus(); let num_cpus = cpus.len(); diff --git a/examples/text-classification/examples/ag-news-train.rs b/examples/text-classification/examples/ag-news-train.rs index 9a9cab44bd..927c190b2c 100644 --- a/examples/text-classification/examples/ag-news-train.rs +++ b/examples/text-classification/examples/ag-news-train.rs @@ -116,10 +116,10 @@ mod cuda { #[cfg(feature = "hip")] mod hip { use crate::{launch, ElemType}; - use burn::backend::{Autodiff, HipJit}; + use burn::backend::{Autodiff, Hip}; pub fn run() { - launch::>>(vec![Default::default()]); + launch::>>(vec![Default::default()]); } } From 9f003203d05a0b260c3cd5ad44a7460dfcffc67d Mon Sep 17 00:00:00 2001 From: SalvoMcL <64030770+salvomcl@users.noreply.github.com> Date: Mon, 3 Feb 2025 16:05:14 +0100 Subject: [PATCH 60/61] Feat: Add PoissonNLL loss (#2765) * added PoissonNLLLossConfig * added PoissonNLLLoss * added tests * update docs * added requested changes --- burn-book/src/building-blocks/module.md | 1 + crates/burn-core/src/nn/loss/mod.rs | 2 + crates/burn-core/src/nn/loss/poisson.rs | 390 ++++++++++++++++++++++++ 3 files changed, 393 insertions(+) create mode 100644 crates/burn-core/src/nn/loss/poisson.rs diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index 0f5aca7f24..9598d6e39e 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -294,3 +294,4 @@ Burn comes with built-in modules that you can use to build your own modules. | `CrossEntropyLoss` | `nn.CrossEntropyLoss` | | `MseLoss` | `nn.MSELoss` | | `HuberLoss` | `nn.HuberLoss` | +| `PoissonNllLoss` | `nn.PoissonNLLLoss` | diff --git a/crates/burn-core/src/nn/loss/mod.rs b/crates/burn-core/src/nn/loss/mod.rs index cca7b4541b..475364e63b 100644 --- a/crates/burn-core/src/nn/loss/mod.rs +++ b/crates/burn-core/src/nn/loss/mod.rs @@ -2,10 +2,12 @@ mod binary_cross_entropy; mod cross_entropy; mod huber; mod mse; +mod poisson; mod reduction; pub use binary_cross_entropy::*; pub use cross_entropy::*; pub use huber::*; pub use mse::*; +pub use poisson::*; pub use reduction::*; diff --git a/crates/burn-core/src/nn/loss/poisson.rs b/crates/burn-core/src/nn/loss/poisson.rs new file mode 100644 index 0000000000..3cc989ad8e --- /dev/null +++ b/crates/burn-core/src/nn/loss/poisson.rs @@ -0,0 +1,390 @@ +use core::f32::consts::PI; + +use crate as burn; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; +use crate::tensor::backend::Backend; +use crate::tensor::Tensor; +use crate::{config::Config, module::Module}; + +use super::Reduction; + +/// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance. +/// +/// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss +/// behavior, such as whether the input is in log-space, whether to include the Stirling +/// approximation term, and a small epsilon value to avoid numerical instability. +#[derive(Config, Debug)] +pub struct PoissonNllLossConfig { + /// If `true`, the predictions are expected to be in log-space. + /// + /// When `log_input` is `true`, the loss is computed as: + /// ```text + /// L(predictions, target) = exp(predictions) - target * predictions + /// ``` + /// When `log_input` is `false`, the loss is computed as: + /// ```text + /// L(predictions, target) = predictions - target * log(predictions + eps) + /// ``` + #[config(default = true)] + pub log_input: bool, + /// Whether to compute the full loss, including the Stirling approximation term. + /// + /// When `full` is `true`, the Stirling approximation term is added to the loss: + /// ```text + /// target * log(target) - target + 0.5 * log(2 * PI * target) + /// ``` + #[config(default = false)] + pub full: bool, + /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`. + /// + /// This epsilon value is added to the predictions to ensure numerical stability + /// when computing the logarithm. + #[config(default = 1e-8)] + pub eps: f64, +} + +impl PoissonNllLossConfig { + /// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration. + /// + /// # Panics + /// - Panics if `eps` is not a positive number. + pub fn init(&self) -> PoissonNllLoss { + self.assertions(); + PoissonNllLoss { + log_input: self.log_input, + full: self.full, + eps: self.eps, + } + } + + /// Validates the configuration parameters. + /// + /// # Panics + /// - Panics if `eps` is not a positive number. + fn assertions(&self) { + assert!( + self.eps > 0., + "eps for PoissonNllLoss must be a positive number." + ); + } +} + +/// Negative Log Likelihood (NLL) loss with a Poisson distribution assumption for the target. +/// +/// This loss function is used when the target values are assumed to follow a Poisson distribution. +/// The loss is defined as: +/// ```text +/// target ~ Poisson(input) +/// L(predictions, target) = predictions - target * log(predictions) + log(target!) +/// ``` +/// The last term (`log(target!)`) can be omitted or approximated using Stirling's formula. +/// The approximation is applied for `target > 1`, while for `target <= 1`, zeros are added to the loss. +/// +/// For more details, see: +/// +#[derive(Module, Debug, Clone)] +#[module(custom_display)] +pub struct PoissonNllLoss { + /// If `true`, the predictions are expected to be in log-space. + pub log_input: bool, + /// Whether to compute the full loss, including the Stirling approximation term. + pub full: bool, + /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`. + pub eps: f64, +} + +impl ModuleDisplay for PoissonNllLoss { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("log_input", &self.log_input) + .add("full", &self.full) + .add("eps", &self.eps) + .optional() + } +} + +impl PoissonNllLoss { + /// Computes the loss element-wise for the given predictions and targets, then reduces + /// the result to a single loss value. + /// + /// # Arguments + /// - `predictions`: The predicted values. + /// - `targets`: The target values. + /// - `reduction`: The reduction method to apply. `Reduction::Auto` behaves as `Reduction::Mean`. + /// + /// # Shapes + /// - `predictions`: `[...dims]` + /// - `targets`: `[...dims]` + /// - `output`: `[1]` + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + pub fn forward( + &self, + predictions: Tensor, + targets: Tensor, + reduction: Reduction, + ) -> Tensor { + let loss = self.forward_no_reduction(predictions, targets); + match reduction { + Reduction::Mean | Reduction::Auto => loss.mean(), + Reduction::Sum => loss.sum(), + } + } + + /// Computes the loss element-wise for the given predictions and targets without reduction. + /// + /// # Arguments + /// - `predictions`: The predicted values. + /// - `targets`: The target values. + /// + /// # Shapes + /// - `predictions`: `[...dims]` + /// - `targets`: `[...dims]` + /// - `output`: `[...dims]` + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + pub fn forward_no_reduction( + &self, + predictions: Tensor, + targets: Tensor, + ) -> Tensor { + self.assertions(&predictions, &targets); + let mut loss; + if self.log_input { + loss = predictions.clone().exp() - targets.clone() * predictions; + } else { + loss = predictions.clone() - targets.clone() * (predictions + self.eps).log(); + } + if self.full { + let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone() + + (targets.clone() * 2. * PI).log() * 0.5; + loss = loss + + log_stirling_term + .mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like()); + } + loss + } + + /// Validates the input tensors for the loss computation. + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + fn assertions( + &self, + predictions: &Tensor, + targets: &Tensor, + ) { + let predictions_dims = predictions.dims(); + let targets_dims = targets.dims(); + assert!( + predictions_dims == targets_dims, + "Shape of targets ({:?}) should correspond to outer shape of predictions ({:?}).", + targets_dims, + predictions_dims + ); + assert!( + targets.clone().greater_equal_elem(0.).all().into_scalar(), + "All the values of `targets` must be non-negative." + ); + if !self.log_input { + assert!( + predictions.clone().greater_equal_elem(0.).all().into_scalar(), + "When `log_input` is `false`, all the values of `predictions` must be non-negative." + ); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::TensorData; + use crate::TestBackend; + type TestTensor = Tensor; + + #[test] + fn test_poisson_nll_loss() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); + let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + let loss_no_reduction = poisson.forward_no_reduction(predictions, targets); + + let expected = TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([21.0321]); + loss.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([126.1929]); + loss_sum.into_data().assert_approx_eq(&expected, 5); + } + + #[test] + fn test_poisson_nll_loss_no_log_input() { + let predictions = TensorData::from([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]); + let targets = TensorData::from([2., 3., 1., 4.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); + + let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + + let expected = TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + } + + #[test] + fn test_poisson_nll_loss_full() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_full(true).init(); + + let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); + let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + let loss_no_reduction = poisson.forward_no_reduction(predictions, targets); + + let expected = TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([21.9920]); + loss.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([131.9518]); + loss_sum.into_data().assert_approx_eq(&expected, 5); + } + + #[cfg(feature = "std")] + #[test] + fn test_poisson_nll_loss_gradients() { + type TestAutodiffTensor = Tensor; + + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions1 = TestAutodiffTensor::from_data(predictions, &device).require_grad(); + let predictions2 = predictions1.clone(); + let targets = TestAutodiffTensor::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_full(false).init(); + let poisson_full = PoissonNllLossConfig::new().with_full(true).init(); + + let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum); + let loss_full_sum = + poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum); + + let grads = loss_sum.backward(); + let grads_full = loss_full_sum.backward(); + + let grads_predictions1 = predictions1.grad(&grads).unwrap(); + let grads_predictions2 = predictions2.grad(&grads_full).unwrap(); + + let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]); + + grads_predictions1 + .into_data() + .assert_approx_eq(&expected, 5); + grads_predictions2 + .into_data() + .assert_approx_eq(&expected, 5); + } + + #[test] + #[should_panic = "eps for PoissonNllLoss must be a positive number."] + fn test_negative_eps() { + let _poisson = PoissonNllLossConfig::new().with_eps(0.).init(); + } + + #[test] + #[should_panic = "All the values of `targets` must be non-negative."] + fn test_targets_with_negative_values() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3., 4.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2., -0.42]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + } + + #[test] + #[should_panic = "Shape of targets"] + fn test_shape_tensors() { + let predictions = TensorData::from([0., 1., 2.]); + let targets = TensorData::from([0., 1.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + } + + #[test] + #[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."] + fn test_exp_predictions_non_negative() { + let predictions = TensorData::from([0.3, -0.1, 0.4]); + let targets = TensorData::from([0., 1., 0.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); + + let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + } + + #[test] + fn display() { + let config = PoissonNllLossConfig::new(); + let loss = config.init(); + + assert_eq!( + alloc::format!("{}", loss), + "PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}" + ); + } +} From e2fa935a7522be64fd8d4ba94750653532705f74 Mon Sep 17 00:00:00 2001 From: jiawen wang Date: Tue, 4 Feb 2025 00:24:15 +0800 Subject: [PATCH 61/61] modern lstm (#2752) * modern lstm * format * formatting * formatting * formatting * formatting * fix a typo * Update examples/modern-lstm/Cargo.toml Co-authored-by: Guillaume Lagrange * use generic backend * remove Cargo.lock * use backend for inference * update readme * Update README + fix main changes * Fix clippy --------- Co-authored-by: Guillaume Lagrange --- Cargo.lock | 11 + examples/modern-lstm/Cargo.toml | 27 ++ examples/modern-lstm/README.md | 46 +++ examples/modern-lstm/examples/lstm-infer.rs | 86 +++++ examples/modern-lstm/examples/lstm-train.rs | 104 ++++++ examples/modern-lstm/src/dataset.rs | 110 ++++++ examples/modern-lstm/src/inference.rs | 45 +++ examples/modern-lstm/src/lib.rs | 4 + examples/modern-lstm/src/model.rs | 362 ++++++++++++++++++++ examples/modern-lstm/src/training.rs | 131 +++++++ 10 files changed, 926 insertions(+) create mode 100644 examples/modern-lstm/Cargo.toml create mode 100644 examples/modern-lstm/README.md create mode 100644 examples/modern-lstm/examples/lstm-infer.rs create mode 100644 examples/modern-lstm/examples/lstm-train.rs create mode 100644 examples/modern-lstm/src/dataset.rs create mode 100644 examples/modern-lstm/src/inference.rs create mode 100644 examples/modern-lstm/src/lib.rs create mode 100644 examples/modern-lstm/src/model.rs create mode 100644 examples/modern-lstm/src/training.rs diff --git a/Cargo.lock b/Cargo.lock index ab9eddb0bd..c9ce522699 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3855,6 +3855,17 @@ dependencies = [ "burn-import", ] +[[package]] +name = "modern-lstm" +version = "0.1.0" +dependencies = [ + "burn", + "polars", + "rand", + "rand_distr", + "serde", +] + [[package]] name = "monostate" version = "0.1.13" diff --git a/examples/modern-lstm/Cargo.toml b/examples/modern-lstm/Cargo.toml new file mode 100644 index 0000000000..86855e9ad4 --- /dev/null +++ b/examples/modern-lstm/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "modern-lstm" +version = "0.1.0" +edition = "2021" + +[features] +ndarray = ["burn/ndarray"] +ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] +ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] +ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"] +tch-cpu = ["burn/tch"] +tch-gpu = ["burn/tch"] +wgpu = ["burn/wgpu"] +cuda = ["burn/cuda"] + +[dependencies] +burn = { path = "../../crates/burn", features=["train"] } + +# Random number generator +rand = { workspace = true } +rand_distr = { workspace = true } + +# Serialization +serde = {workspace = true, features = ["std", "derive"]} + +# Organise the results in dataframe +polars = { workspace = true } diff --git a/examples/modern-lstm/README.md b/examples/modern-lstm/README.md new file mode 100644 index 0000000000..832851a1f0 --- /dev/null +++ b/examples/modern-lstm/README.md @@ -0,0 +1,46 @@ +# Advanced LSTM Implementation with Burn + +A more advanced implementation of Long Short-Term Memory (LSTM) networks in Burn with combined +weight matrices for the input and hidden states, based on the +[PyTorch implementation](https://github.com/shiv08/Advanced-LSTM-Implementation-with-PyTorch). + +`LstmNetwork` is the top-level module with bidirectional and regularization support. The LSTM +variants differ by `bidirectional` and `num_layers` settings: + +- LSTM: `num_layers = 1` and `bidirectional = false` +- Stacked LSTM: `num_layers > 1` and `bidirectional = false` +- Bidirectional LSTM: `num_layers = 1` and `bidirectional = true` +- Bidirectional Stacked LSTM: `num_layers > 1` and `bidirectional = true` + +This implementation is complementary to Burn's official LSTM, users can choose either one depends on +the project's specific needs. + +## Usage + +## Training + +```sh +# Cuda backend +cargo run --example lstm-train --release --features cuda-jit + +# Wgpu backend +cargo run --example lstm-train --release --features wgpu + +# Tch GPU backend +export TORCH_CUDA_VERSION=cu121 # Set the cuda version +cargo run --example lstm-train --release --features tch-gpu + +# Tch CPU backend +cargo run --example lstm-train --release --features tch-cpu + +# NdArray backend (CPU) +cargo run --example lstm-train --release --features ndarray +cargo run --example lstm-train --release --features ndarray-blas-openblas +cargo run --example lstm-train --release --features ndarray-blas-netlib +``` + +### Inference + +```sh +cargo run --example lstm-infer --release --features cuda-jit +``` diff --git a/examples/modern-lstm/examples/lstm-infer.rs b/examples/modern-lstm/examples/lstm-infer.rs new file mode 100644 index 0000000000..f601d08c79 --- /dev/null +++ b/examples/modern-lstm/examples/lstm-infer.rs @@ -0,0 +1,86 @@ +use burn::tensor::backend::Backend; + +pub fn launch(device: B::Device) { + modern_lstm::inference::infer::("/tmp/modern-lstm", device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::ndarray::{NdArray, NdArrayDevice}; + + use crate::launch; + + pub fn run() { + launch::(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + + use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::libtorch::{LibTorch, LibTorchDevice}; + + use crate::launch; + + pub fn run() { + launch::(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::wgpu::Wgpu; + + pub fn run() { + launch::(Default::default()); + } +} + +#[cfg(feature = "cuda")] +mod cuda { + use crate::launch; + use burn::backend::Cuda; + + pub fn run() { + launch::(Default::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda")] + cuda::run(); +} diff --git a/examples/modern-lstm/examples/lstm-train.rs b/examples/modern-lstm/examples/lstm-train.rs new file mode 100644 index 0000000000..454263d331 --- /dev/null +++ b/examples/modern-lstm/examples/lstm-train.rs @@ -0,0 +1,104 @@ +use burn::{ + grad_clipping::GradientClippingConfig, optim::AdamConfig, tensor::backend::AutodiffBackend, +}; +use modern_lstm::{model::LstmNetworkConfig, training::TrainingConfig}; + +pub fn launch(device: B::Device) { + let config = TrainingConfig::new( + LstmNetworkConfig::new(), + // Gradient clipping via optimizer config + AdamConfig::new().with_grad_clipping(Some(GradientClippingConfig::Norm(1.0))), + ); + + modern_lstm::training::train::("/tmp/modern-lstm", config, device); +} + +#[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", +))] +mod ndarray { + use burn::backend::{ + ndarray::{NdArray, NdArrayDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(NdArrayDevice::Cpu); + } +} + +#[cfg(feature = "tch-gpu")] +mod tch_gpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + #[cfg(not(target_os = "macos"))] + let device = LibTorchDevice::Cuda(0); + #[cfg(target_os = "macos")] + let device = LibTorchDevice::Mps; + + launch::>(device); + } +} + +#[cfg(feature = "tch-cpu")] +mod tch_cpu { + use burn::backend::{ + libtorch::{LibTorch, LibTorchDevice}, + Autodiff, + }; + + use crate::launch; + + pub fn run() { + launch::>(LibTorchDevice::Cpu); + } +} + +#[cfg(feature = "wgpu")] +mod wgpu { + use crate::launch; + use burn::backend::{wgpu::Wgpu, Autodiff}; + + pub fn run() { + launch::>(Default::default()); + } +} + +#[cfg(feature = "cuda")] +mod cuda { + use crate::launch; + use burn::backend::{cuda::CudaDevice, Autodiff, Cuda}; + + pub fn run() { + launch::>(CudaDevice::default()); + } +} + +fn main() { + #[cfg(any( + feature = "ndarray", + feature = "ndarray-blas-netlib", + feature = "ndarray-blas-openblas", + feature = "ndarray-blas-accelerate", + ))] + ndarray::run(); + #[cfg(feature = "tch-gpu")] + tch_gpu::run(); + #[cfg(feature = "tch-cpu")] + tch_cpu::run(); + #[cfg(feature = "wgpu")] + wgpu::run(); + #[cfg(feature = "cuda")] + cuda::run(); +} diff --git a/examples/modern-lstm/src/dataset.rs b/examples/modern-lstm/src/dataset.rs new file mode 100644 index 0000000000..b2d04d525f --- /dev/null +++ b/examples/modern-lstm/src/dataset.rs @@ -0,0 +1,110 @@ +use burn::{ + data::{ + dataloader::batcher::Batcher, + dataset::{Dataset, InMemDataset}, + }, + prelude::*, +}; +use rand::Rng; +use rand_distr::{Distribution, Normal}; +use serde::{Deserialize, Serialize}; + +// Dataset parameters +pub const NUM_SEQUENCES: usize = 1000; +pub const SEQ_LENGTH: usize = 10; +pub const NOISE_LEVEL: f32 = 0.1; +pub const RANDOM_SEED: u64 = 5; + +// Generate a sequence where each number is the sum of previous two numbers plus noise +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SequenceDatasetItem { + pub sequence: Vec, + pub target: f32, +} + +impl SequenceDatasetItem { + pub fn new(seq_length: usize, noise_level: f32) -> Self { + // Start with two random numbers between 0 and 1 + let mut seq = vec![rand::thread_rng().gen(), rand::thread_rng().gen()]; + + // Generate sequence + for _i in 0..seq_length { + // Next number is sum of previous two plus noise + let normal = Normal::new(0.0, noise_level).unwrap(); + let next_val = + seq[seq.len() - 2] + seq[seq.len() - 1] + normal.sample(&mut rand::thread_rng()); + seq.push(next_val); + } + + Self { + // Convert to sequence and target + sequence: seq[0..seq.len() - 1].to_vec(), // All but last + target: seq[seq.len() - 1], // Last value + } + } +} + +// Custom Dataset for Sequence Data +pub struct SequenceDataset { + dataset: InMemDataset, +} + +impl SequenceDataset { + pub fn new(num_sequences: usize, seq_length: usize, noise_level: f32) -> Self { + let mut items = vec![]; + for _i in 0..num_sequences { + items.push(SequenceDatasetItem::new(seq_length, noise_level)); + } + let dataset = InMemDataset::new(items); + + Self { dataset } + } +} + +impl Dataset for SequenceDataset { + fn get(&self, index: usize) -> Option { + self.dataset.get(index) + } + + fn len(&self) -> usize { + self.dataset.len() + } +} + +#[derive(Clone, Debug)] +pub struct SequenceBatcher { + device: B::Device, +} + +#[derive(Clone, Debug)] +pub struct SequenceBatch { + pub sequences: Tensor, // [batch_size, seq_length, input_size] + pub targets: Tensor, // [batch_size, 1] +} + +impl SequenceBatcher { + pub fn new(device: B::Device) -> Self { + Self { device } + } +} + +impl Batcher> for SequenceBatcher { + fn batch(&self, items: Vec) -> SequenceBatch { + let mut sequences: Vec> = Vec::new(); + + for item in items.iter() { + let seq_tensor = Tensor::::from_floats(item.sequence.as_slice(), &self.device); + // Add feature dimension, the input_size is 1 implicitly. We can change the input_size here with some operations + sequences.push(seq_tensor.unsqueeze_dims(&[-1])); + } + let sequences = Tensor::stack(sequences, 0); + + let targets = items + .iter() + .map(|item| Tensor::::from_floats([item.target], &self.device)) + .collect(); + let targets = Tensor::stack(targets, 0); + + SequenceBatch { sequences, targets } + } +} diff --git a/examples/modern-lstm/src/inference.rs b/examples/modern-lstm/src/inference.rs new file mode 100644 index 0000000000..bad0af2996 --- /dev/null +++ b/examples/modern-lstm/src/inference.rs @@ -0,0 +1,45 @@ +use crate::{ + dataset::{ + SequenceBatcher, SequenceDataset, SequenceDatasetItem, NOISE_LEVEL, NUM_SEQUENCES, + SEQ_LENGTH, + }, + model::LstmNetwork, + training::TrainingConfig, +}; +use burn::{ + data::{dataloader::batcher::Batcher, dataset::Dataset}, + prelude::*, + record::{CompactRecorder, Recorder}, +}; +use polars::prelude::*; + +pub fn infer(artifact_dir: &str, device: B::Device) { + // Loading model + let config = TrainingConfig::load(format!("{artifact_dir}/config.json")) + .expect("Config should exist for the model; run train first"); + let record = CompactRecorder::new() + .load(format!("{artifact_dir}/model").into(), &device) + .expect("Trained model should exist; run train first"); + + let model: LstmNetwork = config.model.init(&device).load_record(record); + + let dataset = SequenceDataset::new(NUM_SEQUENCES / 5, SEQ_LENGTH, NOISE_LEVEL); + let items: Vec = dataset.iter().collect(); + + let batcher = SequenceBatcher::new(device); + // Put all items in one batch + let batch = batcher.batch(items); + let predicted = model.forward(batch.sequences, None); + let targets = batch.targets; + + let predicted = predicted.squeeze::<1>(1).into_data(); + let expected = targets.squeeze::<1>(1).into_data(); + + // Display the predicted vs expected values + let results = df![ + "predicted" => &predicted.to_vec::().unwrap(), + "expected" => &expected.to_vec::().unwrap(), + ] + .unwrap(); + println!("{}", &results.head(Some(10))); +} diff --git a/examples/modern-lstm/src/lib.rs b/examples/modern-lstm/src/lib.rs new file mode 100644 index 0000000000..1a167ffd75 --- /dev/null +++ b/examples/modern-lstm/src/lib.rs @@ -0,0 +1,4 @@ +pub mod dataset; +pub mod inference; +pub mod model; +pub mod training; diff --git a/examples/modern-lstm/src/model.rs b/examples/modern-lstm/src/model.rs new file mode 100644 index 0000000000..268de59a0b --- /dev/null +++ b/examples/modern-lstm/src/model.rs @@ -0,0 +1,362 @@ +use burn::{ + nn::{ + Dropout, DropoutConfig, Initializer, LayerNorm, LayerNormConfig, Linear, LinearConfig, + LstmState, Sigmoid, Tanh, + }, + prelude::*, +}; + +/// LSTM Cell implementation with layer normalization. +/// +/// Mathematical formulation of LSTM: +/// f_t = σ(W_f · [h_{t-1}, x_t] + b_f) # Forget gate +/// i_t = σ(W_i · [h_{t-1}, x_t] + b_i] # Input gate +/// g_t = tanh(W_g · [h_{t-1}, x_t] + b_g] # Candidate cell state +/// o_t = σ(W_o · [h_{t-1}, x_t] + b_o) # Output gate +/// +/// c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t # New cell state +/// h_t = o_t ⊙ tanh(c_t) # New hidden state +/// +/// where: +/// - σ is the sigmoid function +/// - ⊙ is the element-wise multiplication +/// - [h_{t-1}, x_t] represents concatenation + +#[derive(Module, Debug)] +pub struct LstmCell { + pub hidden_size: usize, + // Combined weight matrices for efficiency + // weight_ih layer uses combined weights for [i_t, f_t, g_t, o_t] for input x_t + // weight_hh layer uses combined weights for [i_t, f_t, g_t, o_t] for hidden state h_{t-1} + pub weight_ih: Linear, + pub weight_hh: Linear, + // Layer Normalization for better training stability. Don't use BatchNorm because the input distribution is always changing for LSTM. + pub norm_x: LayerNorm, // Normalize gate pre-activations + pub norm_h: LayerNorm, // Normalize hidden state + pub norm_c: LayerNorm, // Normalize cell state + pub dropout: Dropout, +} + +/// Configuration to create a Lstm module using the init function. +#[derive(Config, Debug)] +pub struct LstmCellConfig { + // The size of the input features + pub input_size: usize, + // The size of the hidden state + pub hidden_size: usize, + // The number of hidden layers + pub dropout: f64, +} + +impl LstmCellConfig { + // Initialize parameters using best practices: + // 1. Orthogonal initialization for better gradient flow (here we use Xavier because of the lack of Orthogonal in burn) + // 2. Initialize forget gate bias to 1.0 to prevent forgetting at start of training + #[allow(clippy::single_range_in_vec_init)] + pub fn init(&self, device: &B::Device) -> LstmCell { + let initializer = Initializer::XavierNormal { gain: 1.0 }; + let init_bias = Tensor::::ones([self.hidden_size], device); + + let mut weight_ih = LinearConfig::new(self.input_size, 4 * self.hidden_size) + .with_initializer(initializer.clone()) + .init(device); + // Set forget gate bias to 1.0 (helps with learning long sequences) + let bias = weight_ih + .bias + .clone() + .unwrap() + .val() + .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias.clone()); + weight_ih.bias = weight_ih.bias.map(|p| p.map(|_t| bias)); + + let mut weight_hh = LinearConfig::new(self.hidden_size, 4 * self.hidden_size) + .with_initializer(initializer) + .init(device); + let bias = weight_hh + .bias + .clone() + .unwrap() + .val() + .slice_assign([self.hidden_size..2 * self.hidden_size], init_bias); + weight_hh.bias = weight_hh.bias.map(|p| p.map(|_t| bias)); + + LstmCell { + hidden_size: self.hidden_size, + weight_ih, + weight_hh, + norm_x: LayerNormConfig::new(4 * self.hidden_size).init(device), + norm_h: LayerNormConfig::new(self.hidden_size).init(device), + norm_c: LayerNormConfig::new(self.hidden_size).init(device), + dropout: DropoutConfig::new(self.dropout).init(), + } + } +} + +impl LstmCell { + /// Forward pass of LSTM cell. + /// Args: + /// x: Input tensor of shape (batch_size, input_size) + /// state: Tuple of (h_{t-1}, c_{t-1}) each of shape (batch_size, hidden_size) + /// Returns: + /// Tuple of (h_t, c_t) representing new hidden and cell states + pub fn forward(&self, x: Tensor, state: LstmState) -> LstmState { + let (h_prev, c_prev) = (state.hidden, state.cell); + + // Combined matrix multiplication for all gates + // Shape: (batch_size, 4 * hidden_size) + let gates_x = self.weight_ih.forward(x); // Transform input + let gates_h = self.weight_hh.forward(h_prev); // Transform previous hidden state + + // Apply layer normalization + let gates_x = self.norm_x.forward(gates_x); + // Combined gate pre-activations + let gates = gates_x + gates_h; + + // Split into individual gates + // Each gate shape: (batch_size, hidden_size) + let gates = gates.chunk(4, 1); + let i_gate = gates[0].clone(); + let f_gate = gates[1].clone(); + let g_gate = gates[2].clone(); + let o_gate = gates[3].clone(); + + // Apply gate non-linearities + let i_t = Sigmoid::new().forward(i_gate); + let f_t = Sigmoid::new().forward(f_gate); + let g_t = Tanh::new().forward(g_gate); + let o_t = Sigmoid::new().forward(o_gate); + + // Update cell state: c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t + let c_t = f_t * c_prev + i_t * g_t; + let c_t = self.norm_c.forward(c_t); + + // Update cell state: h_t = o_t ⊙ tanh(c_t) + let h_t = o_t * Tanh::new().forward(c_t.clone()); + let h_t = self.norm_h.forward(h_t); + + let h_t = self.dropout.forward(h_t); + + LstmState::new(h_t, c_t) + } + + // Initialize cell state and hidden state if provided or with zeros + pub fn init_state(&self, batch_size: usize, device: &B::Device) -> LstmState { + let cell = Tensor::zeros([batch_size, self.hidden_size], device); + let hidden = Tensor::zeros([batch_size, self.hidden_size], device); + + LstmState::new(cell, hidden) + } +} + +/// Stacked LSTM implementation supporting multiple layers +/// Each layer processes the output of the previous layer +#[derive(Module, Debug)] +pub struct StackedLstm { + pub layers: Vec>, +} + +#[derive(Config, Debug)] +pub struct StackedLstmConfig { + pub input_size: usize, + pub hidden_size: usize, + pub num_layers: usize, + pub dropout: f64, +} + +impl StackedLstmConfig { + pub fn init(&self, device: &B::Device) -> StackedLstm { + let mut layers: Vec> = vec![]; + // Create list of LSTM cells, one for each layer + for i in 0..self.num_layers { + if i == 0 { + if i < self.num_layers - 1 { + layers.push( + LstmCellConfig::new(self.input_size, self.hidden_size, self.dropout) + .init(device), + ); + } else { + // No dropout on last layer + layers.push( + LstmCellConfig::new(self.input_size, self.hidden_size, 0.0).init(device), + ); + } + } else if i < self.num_layers - 1 { + layers.push( + LstmCellConfig::new(self.hidden_size, self.hidden_size, self.dropout) + .init(device), + ); + } else { + // No dropout on last layer + layers.push( + LstmCellConfig::new(self.hidden_size, self.hidden_size, 0.0).init(device), + ); + } + } + StackedLstm { layers } + } +} + +impl StackedLstm { + /// Process input sequence through stacked LSTM layers. + /// + /// Args: + /// x: Input tensor of shape (batch_size, seq_length, input_size) + /// states: Optional initial states for each layer + /// + /// Returns: + /// Tuple of (output, states) where output has shape (batch_size, seq_length, hidden_size) + /// and states is a vector of length num_layers, both cell and hidden state in each element have shape (batch_size, hidden_size) + pub fn forward( + &self, + x: Tensor, + states: Option>>, + ) -> (Tensor, Vec>) { + let [batch_size, seq_length, _] = x.dims(); + let device = x.device(); + + let mut states = match states { + None => { + let mut temp: Vec> = vec![]; + for layer in self.layers.iter() { + temp.push(layer.init_state(batch_size, &device)); + } + temp + } + _ => states.unwrap(), + }; + + let mut layer_outputs = vec![]; + for t in 0..seq_length { + let mut input_t = x + .clone() + .slice([None, Some((t as i64, t as i64 + 1)), None]) + .squeeze::<2>(1); + for (i, lstm_cell) in self.layers.iter().enumerate() { + let mut state: LstmState = + LstmState::new(states[i].cell.clone(), states[i].hidden.clone()); + state = lstm_cell.forward(input_t, state); + input_t = state.hidden.clone(); + states[i] = state; + } + layer_outputs.push(input_t); + } + + // Stack output along sequence dimension + let output = Tensor::stack(layer_outputs, 1); + + (output, states) + } +} + +/// Complete LSTM network with bidirectional support. +/// +/// In bidirectional mode: +/// - Forward LSTM processes sequence from left to right +/// - Backward LSTM processes sequence from right to left +/// - Outputs are concatenated for final prediction +#[derive(Module, Debug)] +pub struct LstmNetwork { + // Forward direction LSTM + pub stacked_lstm: StackedLstm, + // Optional backward direction LSTM for bidirectional processing + pub reverse_lstm: Option>, + pub dropout: Dropout, + pub fc: Linear, +} + +#[derive(Config, Debug)] +pub struct LstmNetworkConfig { + #[config(default = 1)] + pub input_size: usize, // Single feature (number sequence) + #[config(default = 32)] + pub hidden_size: usize, // Size of LSTM hidden state + #[config(default = 2)] + pub num_layers: usize, // Number of LSTM layers + #[config(default = 1)] + pub output_size: usize, // Predict one number + #[config(default = 0.1)] + pub dropout: f64, + #[config(default = true)] + pub bidirectional: bool, // Use bidirectional LSTM +} + +impl LstmNetworkConfig { + pub fn init(&self, device: &B::Device) -> LstmNetwork { + // Forward direction LSTM + let stacked_lstm = StackedLstmConfig::new( + self.input_size, + self.hidden_size, + self.num_layers, + self.dropout, + ) + .init(device); + + // Optional backward direction LSTM for bidirectional processing + let (reverse_lstm, hidden_size) = if self.bidirectional { + let lstm = StackedLstmConfig::new( + self.input_size, + self.hidden_size, + self.num_layers, + self.dropout, + ) + .init(device); + (Some(lstm), 2 * self.hidden_size) + } else { + (None, self.hidden_size) + }; + + let fc = LinearConfig::new(hidden_size, self.output_size).init(device); + let dropout = DropoutConfig::new(self.dropout).init(); + + LstmNetwork { + stacked_lstm, + reverse_lstm, + dropout, + fc, + } + } +} + +impl LstmNetwork { + /// Forward pass of the network. + /// + /// For bidirectional processing: + /// 1. Process sequence normally with forward LSTM + /// 2. Process reversed sequence with backward LSTM + /// 3. Concatenate both outputs + /// 4. Apply final linear transformation + /// + /// Args: + /// x: Input tensor of shape (batch_size, seq_length, input_size) + /// states: Optional initial states + /// + /// Returns: + /// Output tensor of shape (batch_size, output_size) + pub fn forward(&self, x: Tensor, states: Option>>) -> Tensor { + let seq_length = x.dims()[1] as i64; + // Forward direction + let (mut output, _states) = self.stacked_lstm.forward(x.clone(), states); + + output = match &self.reverse_lstm { + Some(reverse_lstm) => { + //Process sequence in reverse direction + let (mut reverse_output, _states) = reverse_lstm.forward(x.flip([1]), None); + // Flip back to align with forward sequence + reverse_output = reverse_output.flip([1]); + // Concatenate forward and backward outputs along the feature dimension + output = Tensor::cat(vec![output, reverse_output], 2); + output + } + None => output, + }; + + // Apply dropout before final layer + output = self.dropout.forward(output); + // Use final timestep output for prediction + self.fc.forward( + output + .slice([None, Some((seq_length - 1, seq_length)), None]) + .squeeze::<2>(1), + ) + } +} diff --git a/examples/modern-lstm/src/training.rs b/examples/modern-lstm/src/training.rs new file mode 100644 index 0000000000..9f6af81328 --- /dev/null +++ b/examples/modern-lstm/src/training.rs @@ -0,0 +1,131 @@ +use crate::dataset::{ + SequenceBatcher, SequenceDataset, NOISE_LEVEL, NUM_SEQUENCES, RANDOM_SEED, SEQ_LENGTH, +}; +use crate::model::{LstmNetwork, LstmNetworkConfig}; +use burn::{ + data::dataloader::DataLoaderBuilder, + module::AutodiffModule, + nn::loss::{MseLoss, Reduction::Mean}, + optim::{AdamConfig, GradientsParams, Optimizer}, + prelude::*, + record::CompactRecorder, + tensor::backend::AutodiffBackend, +}; + +#[derive(Config)] +pub struct TrainingConfig { + pub model: LstmNetworkConfig, + pub optimizer: AdamConfig, + + #[config(default = 30)] + pub num_epochs: usize, + #[config(default = 32)] + pub batch_size: usize, + #[config(default = 2)] + pub num_workers: usize, + #[config(default = 1e-3)] + pub lr: f64, +} + +// Create the directory to save the model and model config +fn create_artifact_dir(artifact_dir: &str) { + // Remove existing artifacts + std::fs::remove_dir_all(artifact_dir).ok(); + std::fs::create_dir_all(artifact_dir).ok(); +} + +pub fn train(artifact_dir: &str, config: TrainingConfig, device: B::Device) { + create_artifact_dir(artifact_dir); + + // Save training config + config + .save(format!("{artifact_dir}/config.json")) + .expect("Config should be saved successfully"); + B::seed(RANDOM_SEED); + + // Create the model and optimizer + let mut model = config.model.init::(&device); + let mut optim = config.optimizer.init::>(); + + // Create the batcher + let batcher_train = SequenceBatcher::::new(device.clone()); + let batcher_valid = SequenceBatcher::::new(device.clone()); + + // Create the dataloaders + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(config.batch_size) + .shuffle(RANDOM_SEED) + .num_workers(config.num_workers) + .build(SequenceDataset::new(NUM_SEQUENCES, SEQ_LENGTH, NOISE_LEVEL)); + + let dataloader_valid = DataLoaderBuilder::new(batcher_valid) + .batch_size(config.batch_size) + .shuffle(RANDOM_SEED) + .num_workers(config.num_workers) + // 20% size of training + .build(SequenceDataset::new( + NUM_SEQUENCES / 5, + SEQ_LENGTH, + NOISE_LEVEL, + )); + + let train_num_items = dataloader_train.num_items(); + let valid_num_items = dataloader_valid.num_items(); + + println!("Starting training..."); + // Iterate over our training for X epochs + for epoch in 1..config.num_epochs + 1 { + // Initialize the training and validation metrics at the start of each epoch + let mut train_losses = vec![]; + let mut train_loss = 0.0; + let mut valid_losses = vec![]; + let mut valid_loss = 0.0; + + // Implement our training loop + for batch in dataloader_train.iter() { + let output = model.forward(batch.sequences, None); + let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean); + train_loss += loss.clone().into_scalar().elem::() * batch.targets.dims()[0] as f32; + + // Gradients for the current backward pass + let grads = loss.backward(); + // Gradients linked to each parameter of the model + let grads = GradientsParams::from_grads(grads, &model); + // Update the model using the optimizer + model = optim.step(config.lr, model, grads); + } + + // The averaged train loss per epoch + let avg_train_loss = train_loss / train_num_items as f32; + train_losses.push(avg_train_loss); + + // Get the model without autodiff + let valid_model = model.valid(); + + // Implement our validation loop + for batch in dataloader_valid.iter() { + let output = valid_model.forward(batch.sequences, None); + let loss = MseLoss::new().forward(output, batch.targets.clone(), Mean); + valid_loss += loss.clone().into_scalar().elem::() * batch.targets.dims()[0] as f32; + } + // The averaged train loss per epoch + let avg_valid_loss = valid_loss / valid_num_items as f32; + valid_losses.push(avg_valid_loss); + + // Display the averaged training and validataion metrics every 10 epochs + if (epoch + 1) % 5 == 0 { + println!( + "Epoch {}/{}, Avg Loss {:.4}, Avg Val Loss: {:.4}", + epoch + 1, + config.num_epochs, + avg_train_loss, + avg_valid_loss, + ); + } + } + + // Save the trained model + model + .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()) + .expect("Trained model should be saved successfully"); +}