From a644430b0a07f3113b401b75c0c379415428631d Mon Sep 17 00:00:00 2001
From: Tiago Sanona <40792244+tsanona@users.noreply.github.com>
Date: Tue, 7 Jan 2025 21:08:27 +0100
Subject: [PATCH] add fscore, reorganize mod and change error in doc. (#2648)

* add fscore, reorganize mod and change error in doc.

* add function args to doc

* fscore to fbetascore and formatting metrics doc in book

* Minor doc fixes

---------

Co-authored-by: Tiago Sanona
Co-authored-by: Guillaume Lagrange
---
 burn-book/src/building-blocks/metric.md    |   3 +-
 crates/burn-train/src/metric/fbetascore.rs | 195 +++++++++++++++++++++
 crates/burn-train/src/metric/mod.rs        |  41 +++--
 crates/burn-train/src/metric/precision.rs  |   4 +-
 crates/burn-train/src/metric/recall.rs     |   6 +-
 5 files changed, 227 insertions(+), 22 deletions(-)
 create mode 100644 crates/burn-train/src/metric/fbetascore.rs

diff --git a/burn-book/src/building-blocks/metric.md b/burn-book/src/building-blocks/metric.md
index e029aca708..e5dd4eaae9 100644
--- a/burn-book/src/building-blocks/metric.md
+++ b/burn-book/src/building-blocks/metric.md
@@ -4,11 +4,12 @@ When working with the learner, you have the option to record metrics that will b
 throughout the training process. We currently offer a restricted range of metrics.

 | Metric           | Description                                               |
-|------------------|---------------------------------------------------------|
+| ---------------- | ------------------------------------------------------- |
 | Accuracy         | Calculate the accuracy in percentage                      |
 | TopKAccuracy     | Calculate the top-k accuracy in percentage                |
 | Precision        | Calculate precision in percentage                         |
 | Recall           | Calculate recall in percentage                            |
+| FBetaScore       | Calculate Fβ score in percentage                          |
 | AUROC            | Calculate the area under curve of ROC in percentage       |
 | Loss             | Output the loss used for the backward pass                |
 | CPU Temperature  | Fetch the temperature of CPUs                             |
diff --git a/crates/burn-train/src/metric/fbetascore.rs b/crates/burn-train/src/metric/fbetascore.rs
new file mode 100644
index 0000000000..5eeba0aa9c
--- /dev/null
+++ b/crates/burn-train/src/metric/fbetascore.rs
@@ -0,0 +1,195 @@
+use super::{
+    classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
+    confusion_stats::{ConfusionStats, ConfusionStatsInput},
+    state::{FormatOptions, NumericMetricState},
+    Metric, MetricEntry, MetricMetadata, Numeric,
+};
+use burn_core::{
+    prelude::{Backend, Tensor},
+    tensor::cast::ToElement,
+};
+use core::marker::PhantomData;
+use std::num::NonZeroUsize;
+
+/// The [F-beta score](https://en.wikipedia.org/wiki/F-score) metric.
+#[derive(Default)]
+pub struct FBetaScoreMetric<B: Backend> {
+    state: NumericMetricState,
+    _b: PhantomData<B>,
+    config: ClassificationMetricConfig,
+    beta: f64,
+}
+
+impl<B: Backend> FBetaScoreMetric<B> {
+    /// F-beta score metric for binary classification.
+    ///
+    /// # Arguments
+    ///
+    /// * `beta` - Positive real factor to weight recall's importance.
+    /// * `threshold` - The threshold to transform a probability into a binary prediction.
+    #[allow(dead_code)]
+    pub fn binary(beta: f64, threshold: f64) -> Self {
+        Self {
+            config: ClassificationMetricConfig {
+                decision_rule: DecisionRule::Threshold(threshold),
+                // binary classification results are the same independently of class_reduction
+                ..Default::default()
+            },
+            beta,
+            ..Default::default()
+        }
+    }
+
+    /// F-beta score metric for multiclass classification.
+    ///
+    /// # Arguments
+    ///
+    /// * `beta` - Positive real factor to weight recall's importance.
+    /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
+    /// * `class_reduction` - [Class reduction](ClassReduction) type.
+    #[allow(dead_code)]
+    pub fn multiclass(beta: f64, top_k: usize, class_reduction: ClassReduction) -> Self {
+        Self {
+            config: ClassificationMetricConfig {
+                decision_rule: DecisionRule::TopK(
+                    NonZeroUsize::new(top_k).expect("top_k must be non-zero"),
+                ),
+                class_reduction,
+            },
+            beta,
+            ..Default::default()
+        }
+    }
+
+    /// F-beta score metric for multi-label classification.
+    ///
+    /// # Arguments
+    ///
+    /// * `beta` - Positive real factor to weight recall's importance.
+    /// * `threshold` - The threshold to transform a probability into a binary prediction.
+    /// * `class_reduction` - [Class reduction](ClassReduction) type.
+    #[allow(dead_code)]
+    pub fn multilabel(beta: f64, threshold: f64, class_reduction: ClassReduction) -> Self {
+        Self {
+            config: ClassificationMetricConfig {
+                decision_rule: DecisionRule::Threshold(threshold),
+                class_reduction,
+            },
+            beta,
+            ..Default::default()
+        }
+    }
+
+    fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {
+        use ClassReduction::{Macro, Micro};
+        let avg_tensor = match self.config.class_reduction {
+            Micro => aggregated_metric,
+            Macro => {
+                if aggregated_metric.contains_nan().any().into_scalar() {
+                    let nan_mask = aggregated_metric.is_nan();
+                    aggregated_metric = aggregated_metric
+                        .clone()
+                        .select(0, nan_mask.bool_not().argwhere().squeeze(1))
+                }
+                aggregated_metric.mean()
+            }
+        };
+        avg_tensor.into_scalar().to_f64()
+    }
+}
+
+impl<B: Backend> Metric for FBetaScoreMetric<B> {
+    const NAME: &'static str = "FBetaScore";
+    type Input = ConfusionStatsInput<B>;
+
+    fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
+        let [sample_size, _] = input.predictions.dims();
+
+        let cf_stats = ConfusionStats::new(input, &self.config);
+        let scaled_true_positive = cf_stats.clone().true_positive() * (1.0 + self.beta.powi(2));
+        let metric = self.class_average(
+            scaled_true_positive.clone()
+                / (scaled_true_positive
+                    + cf_stats.clone().false_negative() * self.beta.powi(2)
+                    + cf_stats.false_positive()),
+        );
+
+        self.state.update(
+            100.0 * metric,
+            sample_size,
+            FormatOptions::new(Self::NAME).unit("%").precision(2),
+        )
+    }
+
+    fn clear(&mut self) {
+        self.state.reset()
+    }
+}
+
+impl<B: Backend> Numeric for FBetaScoreMetric<B> {
+    fn value(&self) -> f64 {
+        self.state.value()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{
+        ClassReduction::{self, *},
+        FBetaScoreMetric, Metric, MetricMetadata, Numeric,
+    };
+    use crate::tests::{dummy_classification_input, ClassificationType, THRESHOLD};
+    use burn_core::tensor::TensorData;
+    use rstest::rstest;
+
+    #[rstest]
+    #[case::binary_b1(1.0, THRESHOLD, 0.5)]
+    #[case::binary_b2(2.0, THRESHOLD, 0.5)]
+    fn test_binary_fscore(#[case] beta: f64, #[case] threshold: f64, #[case] expected: f64) {
+        let input = dummy_classification_input(&ClassificationType::Binary).into();
+        let mut metric = FBetaScoreMetric::binary(beta, threshold);
+        let _entry = metric.update(&input, &MetricMetadata::fake());
+        TensorData::from([metric.value()])
+            .assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
+    }
+
+    #[rstest]
+    #[case::multiclass_b1_micro_k1(1.0, Micro, 1, 3.0/5.0)]
+    #[case::multiclass_b1_micro_k2(1.0, Micro, 2, 2.0/(5.0/4.0 + 10.0/4.0))]
+    #[case::multiclass_b1_macro_k1(1.0, Macro, 1, (0.5 + 2.0/(1.0 + 2.0) + 2.0/(2.0 + 1.0))/3.0)]
+    #[case::multiclass_b1_macro_k2(1.0, Macro, 2, (2.0/(1.0 + 2.0) + 2.0/(1.0 + 4.0) + 0.5)/3.0)]
+    #[case::multiclass_b2_micro_k1(2.0, Micro, 1, 3.0/5.0)]
+    #[case::multiclass_b2_micro_k2(2.0, Micro, 2, 5.0*4.0/(4.0*5.0 + 10.0))]
+    #[case::multiclass_b2_macro_k1(2.0, Macro, 1, (0.5 + 5.0/(4.0 + 2.0) + 5.0/(8.0 + 1.0))/3.0)]
+    #[case::multiclass_b2_macro_k2(2.0, Macro, 2, (5.0/(4.0 + 2.0) + 5.0/(4.0 + 4.0) + 0.5)/3.0)]
+    fn test_multiclass_fscore(
+        #[case] beta: f64,
+        #[case] class_reduction: ClassReduction,
+        #[case] top_k: usize,
+        #[case] expected: f64,
+    ) {
+        let input = dummy_classification_input(&ClassificationType::Multiclass).into();
+        let mut metric = FBetaScoreMetric::multiclass(beta, top_k, class_reduction);
+        let _entry = metric.update(&input, &MetricMetadata::fake());
+        TensorData::from([metric.value()])
+            .assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
+    }
+
+    #[rstest]
+    #[case::multilabel_micro(1.0, Micro, THRESHOLD, 2.0/(9.0/5.0 + 8.0/5.0))]
+    #[case::multilabel_macro(1.0, Macro, THRESHOLD, (2.0/(2.0 + 3.0/2.0) + 2.0/(1.0 + 3.0/2.0) + 2.0/(3.0+2.0))/3.0)]
+    #[case::multilabel_micro(2.0, Micro, THRESHOLD, 5.0/(4.0*9.0/5.0 + 8.0/5.0))]
+    #[case::multilabel_macro(2.0, Macro, THRESHOLD, (5.0/(8.0 + 3.0/2.0) + 5.0/(4.0 + 3.0/2.0) + 5.0/(12.0+2.0))/3.0)]
+    fn test_multilabel_fscore(
+        #[case] beta: f64,
+        #[case] class_reduction: ClassReduction,
+        #[case] threshold: f64,
+        #[case] expected: f64,
+    ) {
+        let input = dummy_classification_input(&ClassificationType::Multilabel).into();
+        let mut metric = FBetaScoreMetric::multilabel(beta, threshold, class_reduction);
+        let _entry = metric.update(&input, &MetricMetadata::fake());
+        TensorData::from([metric.value()])
+            .assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
+    }
+}
diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs
index e6358e3023..191099a383 100644
--- a/crates/burn-train/src/metric/mod.rs
+++ b/crates/burn-train/src/metric/mod.rs
@@ -1,23 +1,32 @@
 /// State module.
 pub mod state;
+/// Module responsible to save and exposes data collected during training.
+pub mod store;

 mod acc;
 mod auroc;
 mod base;
 #[cfg(feature = "metrics")]
+mod confusion_stats;
+#[cfg(feature = "metrics")]
 mod cpu_temp;
 #[cfg(feature = "metrics")]
 mod cpu_use;
 #[cfg(feature = "metrics")]
 mod cuda;
+#[cfg(feature = "metrics")]
+mod fbetascore;
 mod hamming;
+#[cfg(feature = "metrics")]
+mod iteration;
 mod learning_rate;
 mod loss;
 #[cfg(feature = "metrics")]
 mod memory_use;
-
 #[cfg(feature = "metrics")]
-mod iteration;
+mod precision;
+#[cfg(feature = "metrics")]
+mod recall;
 #[cfg(feature = "metrics")]
 mod top_k_acc;

@@ -25,11 +34,15 @@ pub use acc::*;
 pub use auroc::*;
 pub use base::*;
 #[cfg(feature = "metrics")]
+pub use confusion_stats::ConfusionStatsInput;
+#[cfg(feature = "metrics")]
 pub use cpu_temp::*;
 #[cfg(feature = "metrics")]
 pub use cpu_use::*;
 #[cfg(feature = "metrics")]
 pub use cuda::*;
+#[cfg(feature = "metrics")]
+pub use fbetascore::*;
 pub use hamming::*;
 #[cfg(feature = "metrics")]
 pub use iteration::*;
@@ -38,25 +51,17 @@ pub use loss::*;
 #[cfg(feature = "metrics")]
 pub use memory_use::*;
 #[cfg(feature = "metrics")]
+pub use precision::*;
+#[cfg(feature = "metrics")]
+pub use recall::*;
+#[cfg(feature = "metrics")]
 pub use top_k_acc::*;

+#[cfg(feature = "metrics")]
+pub(crate) mod classification;
 pub(crate) mod processor;
-// Expose `ItemLazy` so it can be implemented for custom types
-pub use processor::ItemLazy;
-
-/// Module responsible to save and exposes data collected during training.
-pub mod store;
-pub(crate) mod classification;
 #[cfg(feature = "metrics")]
 pub use crate::metric::classification::ClassReduction;

-mod confusion_stats;
-pub use confusion_stats::ConfusionStatsInput;
-#[cfg(feature = "metrics")]
-mod precision;
-#[cfg(feature = "metrics")]
-pub use precision::*;
-#[cfg(feature = "metrics")]
-mod recall;
-#[cfg(feature = "metrics")]
-pub use recall::*;
+// Expose `ItemLazy` so it can be implemented for custom types
+pub use processor::ItemLazy;
diff --git a/crates/burn-train/src/metric/precision.rs b/crates/burn-train/src/metric/precision.rs
index 067261cbdf..375d368795 100644
--- a/crates/burn-train/src/metric/precision.rs
+++ b/crates/burn-train/src/metric/precision.rs
@@ -42,6 +42,7 @@ impl<B: Backend> PrecisionMetric<B> {
     /// # Arguments
     ///
     /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
+    /// * `class_reduction` - [Class reduction](ClassReduction) type.
     #[allow(dead_code)]
     pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {
         Self {
@@ -60,6 +61,7 @@ impl<B: Backend> PrecisionMetric<B> {
     /// # Arguments
     ///
     /// * `threshold` - The threshold to transform a probability into a binary value.
+    /// * `class_reduction` - [Class reduction](ClassReduction) type.
     #[allow(dead_code)]
     pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {
         Self {
@@ -129,7 +131,7 @@ mod tests {
     use rstest::rstest;

     #[rstest]
-    #[case::binary_macro(THRESHOLD, 0.5)]
+    #[case::binary(THRESHOLD, 0.5)]
     fn test_binary_precision(#[case] threshold: f64, #[case] expected: f64) {
         let input = dummy_classification_input(&ClassificationType::Binary).into();
         let mut metric = PrecisionMetric::binary(threshold);
diff --git a/crates/burn-train/src/metric/recall.rs b/crates/burn-train/src/metric/recall.rs
index 8ce4351396..5003ddcd03 100644
--- a/crates/burn-train/src/metric/recall.rs
+++ b/crates/burn-train/src/metric/recall.rs
@@ -11,7 +11,7 @@ use burn_core::{
 use core::marker::PhantomData;
 use std::num::NonZeroUsize;

-///The Precision Metric
+///The Recall Metric
 #[derive(Default)]
 pub struct RecallMetric<B: Backend> {
     state: NumericMetricState,
@@ -42,6 +42,7 @@ impl<B: Backend> RecallMetric<B> {
     /// # Arguments
     ///
     /// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
+    /// * `class_reduction` - [Class reduction](ClassReduction) type.
     #[allow(dead_code)]
     pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {
         Self {
@@ -60,6 +61,7 @@ impl<B: Backend> RecallMetric<B> {
     /// # Arguments
     ///
     /// * `threshold` - The threshold to transform a probability into a binary prediction.
+    /// * `class_reduction` - [Class reduction](ClassReduction) type.
     #[allow(dead_code)]
     pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {
         Self {
@@ -128,7 +130,7 @@ mod tests {
     use rstest::rstest;

     #[rstest]
-    #[case::binary_macro(THRESHOLD, 0.5)]
+    #[case::binary(THRESHOLD, 0.5)]
     fn test_binary_recall(#[case] threshold: f64, #[case] expected: f64) {
         let input = dummy_classification_input(&ClassificationType::Binary).into();
         let mut metric = RecallMetric::binary(threshold);
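
For reference (not part of the upstream patch): the `update` implementation in `fbetascore.rs` computes, per class, F_beta = (1 + beta^2) * TP / ((1 + beta^2) * TP + beta^2 * FN + FP) before class reduction and scaling to a percentage. Below is a minimal, self-contained sketch of that same formula on raw confusion counts, useful for checking the test expectations by hand. The helper name and the example counts are illustrative only and are not part of the burn API.

```rust
/// F-beta score from raw confusion counts for a single class.
/// With beta = 1 this reduces to the usual F1 score; beta > 1 weights recall more heavily.
fn f_beta(tp: f64, fp: f64, fn_: f64, beta: f64) -> f64 {
    let b2 = beta.powi(2);
    (1.0 + b2) * tp / ((1.0 + b2) * tp + b2 * fn_ + fp)
}

fn main() {
    // Example counts: 3 true positives, 2 false positives, 1 false negative.
    // Precision = 3/5 = 0.6, recall = 3/4 = 0.75.
    let f1 = f_beta(3.0, 2.0, 1.0, 1.0); // 2*3 / (2*3 + 1 + 2) = 6/9
    let f2 = f_beta(3.0, 2.0, 1.0, 2.0); // 5*3 / (5*3 + 4 + 2) = 15/21
    assert!((f1 - 6.0 / 9.0).abs() < 1e-12);
    assert!((f2 - 15.0 / 21.0).abs() < 1e-12);
    println!("F1 = {f1:.4}, F2 = {f2:.4}"); // F2 > F1 here because recall exceeds precision
}
```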