add fscore, reorganize mod and change error in doc. (#2648)

* add fscore, reorganize mod and change error in doc.

* add function args to doc

* fscore to fbetascore and formatting metrics doc in book

* Minor doc fixes

---------

Co-authored-by: Tiago Sanona <[email protected]>
Co-authored-by: Guillaume Lagrange <[email protected]>
3 people authored Jan 7, 2025
1 parent e2fa9c4 commit a644430
Showing 5 changed files with 227 additions and 22 deletions.
burn-book/src/building-blocks/metric.md (3 changes: 2 additions & 1 deletion)
@@ -4,11 +4,12 @@
throughout the training process. We currently offer a restricted range of metrics.

| Metric | Description |
|------------------|---------------------------------------------------------|
| ---------------- | ------------------------------------------------------- |
| Accuracy | Calculate the accuracy in percentage |
| TopKAccuracy | Calculate the top-k accuracy in percentage |
| Precision | Calculate precision in percentage |
| Recall | Calculate recall in percentage |
| FBetaScore | Calculate F<sub>β </sub>score in percentage |
| AUROC | Calculate the area under curve of ROC in percentage |
| Loss | Output the loss used for the backward pass |
| CPU Temperature | Fetch the temperature of CPUs |
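For context, and not part of this diff, here is a minimal sketch of how the new metric can be constructed next to the existing ones. It assumes the `burn` facade re-exports `burn-train` under `burn::train` and that the `train`, `metrics`, and `ndarray` features are enabled; the constructor signatures themselves are the ones added in this commit.

```rust
use burn::backend::NdArray;
use burn::train::metric::{ClassReduction, FBetaScoreMetric};

fn main() {
    // F1 (beta = 1) for binary classification, thresholding probabilities at 0.5.
    let _binary = FBetaScoreMetric::<NdArray>::binary(1.0, 0.5);

    // F2 (beta = 2) for multiclass classification: top-1 predictions, macro-averaged.
    let _multiclass = FBetaScoreMetric::<NdArray>::multiclass(2.0, 1, ClassReduction::Macro);

    // F0.5 for multi-label classification: per-class threshold of 0.5, micro-averaged.
    let _multilabel = FBetaScoreMetric::<NdArray>::multilabel(0.5, 0.5, ClassReduction::Micro);
}
```

Like the other numeric metrics in the table, the instances are then registered on the learner builder so they are tracked during training and validation.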
crates/burn-train/src/metric/fbetascore.rs (195 changes: 195 additions & 0 deletions)
@@ -0,0 +1,195 @@
use super::{
classification::{ClassReduction, ClassificationMetricConfig, DecisionRule},
confusion_stats::{ConfusionStats, ConfusionStatsInput},
state::{FormatOptions, NumericMetricState},
Metric, MetricEntry, MetricMetadata, Numeric,
};
use burn_core::{
prelude::{Backend, Tensor},
tensor::cast::ToElement,
};
use core::marker::PhantomData;
use std::num::NonZeroUsize;

/// The [F-beta score](https://en.wikipedia.org/wiki/F-score) metric.
#[derive(Default)]
pub struct FBetaScoreMetric<B: Backend> {
state: NumericMetricState,
_b: PhantomData<B>,
config: ClassificationMetricConfig,
beta: f64,
}

impl<B: Backend> FBetaScoreMetric<B> {
/// F-beta score metric for binary classification.
///
/// # Arguments
///
/// * `beta` - Positive real factor to weight recall's importance.
/// * `threshold` - The threshold to transform a probability into a binary prediction.
#[allow(dead_code)]
pub fn binary(beta: f64, threshold: f64) -> Self {
Self {
config: ClassificationMetricConfig {
decision_rule: DecisionRule::Threshold(threshold),
// binary classification results are the same independently of class_reduction
..Default::default()
},
beta,
..Default::default()
}
}

/// F-beta score metric for multiclass classification.
///
/// # Arguments
///
/// * `beta` - Positive real factor to weight recall's importance.
/// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multiclass(beta: f64, top_k: usize, class_reduction: ClassReduction) -> Self {
Self {
config: ClassificationMetricConfig {
decision_rule: DecisionRule::TopK(
NonZeroUsize::new(top_k).expect("top_k must be non-zero"),
),
class_reduction,
},
beta,
..Default::default()
}
}

/// F-beta score metric for multi-label classification.
///
/// # Arguments
///
/// * `beta` - Positive real factor to weight recall's importance.
/// * `threshold` - The threshold to transform a probability into a binary prediction.
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multilabel(beta: f64, threshold: f64, class_reduction: ClassReduction) -> Self {
Self {
config: ClassificationMetricConfig {
decision_rule: DecisionRule::Threshold(threshold),
class_reduction,
},
beta,
..Default::default()
}
}

fn class_average(&self, mut aggregated_metric: Tensor<B, 1>) -> f64 {
use ClassReduction::{Macro, Micro};
let avg_tensor = match self.config.class_reduction {
Micro => aggregated_metric,
Macro => {
if aggregated_metric.contains_nan().any().into_scalar() {
let nan_mask = aggregated_metric.is_nan();
aggregated_metric = aggregated_metric
.clone()
.select(0, nan_mask.bool_not().argwhere().squeeze(1))
}
aggregated_metric.mean()
}
};
avg_tensor.into_scalar().to_f64()
}
}

impl<B: Backend> Metric for FBetaScoreMetric<B> {
const NAME: &'static str = "FBetaScore";
type Input = ConfusionStatsInput<B>;

fn update(&mut self, input: &Self::Input, _metadata: &MetricMetadata) -> MetricEntry {
let [sample_size, _] = input.predictions.dims();

let cf_stats = ConfusionStats::new(input, &self.config);
let scaled_true_positive = cf_stats.clone().true_positive() * (1.0 + self.beta.powi(2));
let metric = self.class_average(
scaled_true_positive.clone()
/ (scaled_true_positive
+ cf_stats.clone().false_negative() * self.beta.powi(2)
+ cf_stats.false_positive()),
);

self.state.update(
100.0 * metric,
sample_size,
FormatOptions::new(Self::NAME).unit("%").precision(2),
)
}

fn clear(&mut self) {
self.state.reset()
}
}

impl<B: Backend> Numeric for FBetaScoreMetric<B> {
fn value(&self) -> f64 {
self.state.value()
}
}

#[cfg(test)]
mod tests {
use super::{
ClassReduction::{self, *},
FBetaScoreMetric, Metric, MetricMetadata, Numeric,
};
use crate::tests::{dummy_classification_input, ClassificationType, THRESHOLD};
use burn_core::tensor::TensorData;
use rstest::rstest;

#[rstest]
#[case::binary_b1(1.0, THRESHOLD, 0.5)]
#[case::binary_b2(2.0, THRESHOLD, 0.5)]
fn test_binary_fscore(#[case] beta: f64, #[case] threshold: f64, #[case] expected: f64) {
let input = dummy_classification_input(&ClassificationType::Binary).into();
let mut metric = FBetaScoreMetric::binary(beta, threshold);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value()])
.assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
}

#[rstest]
#[case::multiclass_b1_micro_k1(1.0, Micro, 1, 3.0/5.0)]
#[case::multiclass_b1_micro_k2(1.0, Micro, 2, 2.0/(5.0/4.0 + 10.0/4.0))]
#[case::multiclass_b1_macro_k1(1.0, Macro, 1, (0.5 + 2.0/(1.0 + 2.0) + 2.0/(2.0 + 1.0))/3.0)]
#[case::multiclass_b1_macro_k2(1.0, Macro, 2, (2.0/(1.0 + 2.0) + 2.0/(1.0 + 4.0) + 0.5)/3.0)]
#[case::multiclass_b2_micro_k1(2.0, Micro, 1, 3.0/5.0)]
#[case::multiclass_b2_micro_k2(2.0, Micro, 2, 5.0*4.0/(4.0*5.0 + 10.0))]
#[case::multiclass_b2_macro_k1(2.0, Macro, 1, (0.5 + 5.0/(4.0 + 2.0) + 5.0/(8.0 + 1.0))/3.0)]
#[case::multiclass_b2_macro_k2(2.0, Macro, 2, (5.0/(4.0 + 2.0) + 5.0/(4.0 + 4.0) + 0.5)/3.0)]
fn test_multiclass_fscore(
#[case] beta: f64,
#[case] class_reduction: ClassReduction,
#[case] top_k: usize,
#[case] expected: f64,
) {
let input = dummy_classification_input(&ClassificationType::Multiclass).into();
let mut metric = FBetaScoreMetric::multiclass(beta, top_k, class_reduction);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value()])
.assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
}

#[rstest]
#[case::multilabel_micro(1.0, Micro, THRESHOLD, 2.0/(9.0/5.0 + 8.0/5.0))]
#[case::multilabel_macro(1.0, Macro, THRESHOLD, (2.0/(2.0 + 3.0/2.0) + 2.0/(1.0 + 3.0/2.0) + 2.0/(3.0+2.0))/3.0)]
#[case::multilabel_micro(2.0, Micro, THRESHOLD, 5.0/(4.0*9.0/5.0 + 8.0/5.0))]
#[case::multilabel_macro(2.0, Macro, THRESHOLD, (5.0/(8.0 + 3.0/2.0) + 5.0/(4.0 + 3.0/2.0) + 5.0/(12.0+2.0))/3.0)]
fn test_multilabel_fscore(
#[case] beta: f64,
#[case] class_reduction: ClassReduction,
#[case] threshold: f64,
#[case] expected: f64,
) {
let input = dummy_classification_input(&ClassificationType::Multilabel).into();
let mut metric = FBetaScoreMetric::multilabel(beta, threshold, class_reduction);
let _entry = metric.update(&input, &MetricMetadata::fake());
TensorData::from([metric.value()])
.assert_approx_eq(&TensorData::from([expected * 100.0]), 3)
}
}
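For reference, and not shown in the diff itself, the quantity that `update` computes from the confusion counts is the standard F-beta score:

```latex
F_\beta = \frac{(1 + \beta^2)\,\mathrm{TP}}{(1 + \beta^2)\,\mathrm{TP} + \beta^2\,\mathrm{FN} + \mathrm{FP}}
```

With `ClassReduction::Micro`, the true-positive, false-positive, and false-negative counts are aggregated over all classes before the ratio is taken; with `ClassReduction::Macro`, the ratio is computed per class and then averaged, and `class_average` drops NaN entries, which arise for classes with no positive targets and no positive predictions. Setting beta to 1 recovers the usual F1 score, beta above 1 weights recall more heavily, and beta below 1 favors precision; the result is scaled to a percentage before being recorded.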
crates/burn-train/src/metric/mod.rs (41 changes: 23 additions & 18 deletions)
@@ -1,35 +1,48 @@
/// State module.
pub mod state;
/// Module responsible to save and exposes data collected during training.
pub mod store;

mod acc;
mod auroc;
mod base;
#[cfg(feature = "metrics")]
mod confusion_stats;
#[cfg(feature = "metrics")]
mod cpu_temp;
#[cfg(feature = "metrics")]
mod cpu_use;
#[cfg(feature = "metrics")]
mod cuda;
#[cfg(feature = "metrics")]
mod fbetascore;
mod hamming;
#[cfg(feature = "metrics")]
mod iteration;
mod learning_rate;
mod loss;
#[cfg(feature = "metrics")]
mod memory_use;

#[cfg(feature = "metrics")]
mod iteration;
mod precision;
#[cfg(feature = "metrics")]
mod recall;
#[cfg(feature = "metrics")]
mod top_k_acc;

pub use acc::*;
pub use auroc::*;
pub use base::*;
#[cfg(feature = "metrics")]
pub use confusion_stats::ConfusionStatsInput;
#[cfg(feature = "metrics")]
pub use cpu_temp::*;
#[cfg(feature = "metrics")]
pub use cpu_use::*;
#[cfg(feature = "metrics")]
pub use cuda::*;
#[cfg(feature = "metrics")]
pub use fbetascore::*;
pub use hamming::*;
#[cfg(feature = "metrics")]
pub use iteration::*;
@@ -38,25 +51,17 @@ pub use loss::*;
#[cfg(feature = "metrics")]
pub use memory_use::*;
#[cfg(feature = "metrics")]
pub use precision::*;
#[cfg(feature = "metrics")]
pub use recall::*;
#[cfg(feature = "metrics")]
pub use top_k_acc::*;

#[cfg(feature = "metrics")]
pub(crate) mod classification;
pub(crate) mod processor;
// Expose `ItemLazy` so it can be implemented for custom types
pub use processor::ItemLazy;

/// Module responsible to save and exposes data collected during training.
pub mod store;

pub(crate) mod classification;
#[cfg(feature = "metrics")]
pub use crate::metric::classification::ClassReduction;
mod confusion_stats;
pub use confusion_stats::ConfusionStatsInput;
#[cfg(feature = "metrics")]
mod precision;
#[cfg(feature = "metrics")]
pub use precision::*;
#[cfg(feature = "metrics")]
mod recall;
#[cfg(feature = "metrics")]
pub use recall::*;
// Expose `ItemLazy` so it can be implemented for custom types
pub use processor::ItemLazy;
crates/burn-train/src/metric/precision.rs (4 changes: 3 additions & 1 deletion)
@@ -42,6 +42,7 @@ impl<B: Backend> PrecisionMetric<B> {
/// # Arguments
///
/// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {
Self {
@@ -60,6 +61,7 @@ impl<B: Backend> PrecisionMetric<B> {
/// # Arguments
///
/// * `threshold` - The threshold to transform a probability into a binary value.
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {
Self {
@@ -129,7 +131,7 @@ mod tests {
use rstest::rstest;

#[rstest]
#[case::binary_macro(THRESHOLD, 0.5)]
#[case::binary(THRESHOLD, 0.5)]
fn test_binary_precision(#[case] threshold: f64, #[case] expected: f64) {
let input = dummy_classification_input(&ClassificationType::Binary).into();
let mut metric = PrecisionMetric::binary(threshold);
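A small usage sketch, not from this commit, of the constructors whose documentation is touched above; it makes the same assumptions as the earlier example (the `burn` facade re-exporting `burn-train`, with the `metrics` and `ndarray` features enabled):

```rust
use burn::backend::NdArray;
use burn::train::metric::{ClassReduction, PrecisionMetric};

fn main() {
    // Binary: probabilities above the 0.5 threshold count as the positive class.
    let _binary = PrecisionMetric::<NdArray>::binary(0.5);

    // Multiclass: top-1 predictions, macro-averaged over classes.
    let _multiclass = PrecisionMetric::<NdArray>::multiclass(1, ClassReduction::Macro);

    // Multi-label: per-class threshold of 0.5, micro-averaged.
    let _multilabel = PrecisionMetric::<NdArray>::multilabel(0.5, ClassReduction::Micro);
}
```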
crates/burn-train/src/metric/recall.rs (6 changes: 4 additions & 2 deletions)
@@ -11,7 +11,7 @@ use burn_core::{
use core::marker::PhantomData;
use std::num::NonZeroUsize;

///The Precision Metric
///The Recall Metric
#[derive(Default)]
pub struct RecallMetric<B: Backend> {
state: NumericMetricState,
@@ -42,6 +42,7 @@ impl<B: Backend> RecallMetric<B> {
/// # Arguments
///
/// * `top_k` - The number of highest predictions considered to find the correct label (typically `1`).
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multiclass(top_k: usize, class_reduction: ClassReduction) -> Self {
Self {
@@ -60,6 +61,7 @@ impl<B: Backend> RecallMetric<B> {
/// # Arguments
///
/// * `threshold` - The threshold to transform a probability into a binary prediction.
/// * `class_reduction` - [Class reduction](ClassReduction) type.
#[allow(dead_code)]
pub fn multilabel(threshold: f64, class_reduction: ClassReduction) -> Self {
Self {
@@ -128,7 +130,7 @@ mod tests {
use rstest::rstest;

#[rstest]
#[case::binary_macro(THRESHOLD, 0.5)]
#[case::binary(THRESHOLD, 0.5)]
fn test_binary_recall(#[case] threshold: f64, #[case] expected: f64) {
let input = dummy_classification_input(&ClassificationType::Binary).into();
let mut metric = RecallMetric::binary(threshold);
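`RecallMetric` exposes the same three constructors; a minimal sketch under the same assumptions as the precision example above:

```rust
use burn::backend::NdArray;
use burn::train::metric::{ClassReduction, RecallMetric};

fn main() {
    // Binary, multiclass (top-1, macro), and multi-label (threshold 0.5, micro) variants.
    let _binary = RecallMetric::<NdArray>::binary(0.5);
    let _multiclass = RecallMetric::<NdArray>::multiclass(1, ClassReduction::Macro);
    let _multilabel = RecallMetric::<NdArray>::multilabel(0.5, ClassReduction::Micro);
}
```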
