From 9f003203d05a0b260c3cd5ad44a7460dfcffc67d Mon Sep 17 00:00:00 2001 From: SalvoMcL <64030770+salvomcl@users.noreply.github.com> Date: Mon, 3 Feb 2025 16:05:14 +0100 Subject: [PATCH] Feat: Add PoissonNLL loss (#2765) * added PoissonNLLLossConfig * added PoissonNLLLoss * added tests * update docs * added requested changes --- burn-book/src/building-blocks/module.md | 1 + crates/burn-core/src/nn/loss/mod.rs | 2 + crates/burn-core/src/nn/loss/poisson.rs | 390 ++++++++++++++++++++++++ 3 files changed, 393 insertions(+) create mode 100644 crates/burn-core/src/nn/loss/poisson.rs diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index 0f5aca7f24..9598d6e39e 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -294,3 +294,4 @@ Burn comes with built-in modules that you can use to build your own modules. | `CrossEntropyLoss` | `nn.CrossEntropyLoss` | | `MseLoss` | `nn.MSELoss` | | `HuberLoss` | `nn.HuberLoss` | +| `PoissonNllLoss` | `nn.PoissonNLLLoss` | diff --git a/crates/burn-core/src/nn/loss/mod.rs b/crates/burn-core/src/nn/loss/mod.rs index cca7b4541b..475364e63b 100644 --- a/crates/burn-core/src/nn/loss/mod.rs +++ b/crates/burn-core/src/nn/loss/mod.rs @@ -2,10 +2,12 @@ mod binary_cross_entropy; mod cross_entropy; mod huber; mod mse; +mod poisson; mod reduction; pub use binary_cross_entropy::*; pub use cross_entropy::*; pub use huber::*; pub use mse::*; +pub use poisson::*; pub use reduction::*; diff --git a/crates/burn-core/src/nn/loss/poisson.rs b/crates/burn-core/src/nn/loss/poisson.rs new file mode 100644 index 0000000000..3cc989ad8e --- /dev/null +++ b/crates/burn-core/src/nn/loss/poisson.rs @@ -0,0 +1,390 @@ +use core::f32::consts::PI; + +use crate as burn; +use crate::module::{Content, DisplaySettings, ModuleDisplay}; +use crate::tensor::backend::Backend; +use crate::tensor::Tensor; +use crate::{config::Config, module::Module}; + +use super::Reduction; + +/// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance. +/// +/// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss +/// behavior, such as whether the input is in log-space, whether to include the Stirling +/// approximation term, and a small epsilon value to avoid numerical instability. +#[derive(Config, Debug)] +pub struct PoissonNllLossConfig { + /// If `true`, the predictions are expected to be in log-space. + /// + /// When `log_input` is `true`, the loss is computed as: + /// ```text + /// L(predictions, target) = exp(predictions) - target * predictions + /// ``` + /// When `log_input` is `false`, the loss is computed as: + /// ```text + /// L(predictions, target) = predictions - target * log(predictions + eps) + /// ``` + #[config(default = true)] + pub log_input: bool, + /// Whether to compute the full loss, including the Stirling approximation term. + /// + /// When `full` is `true`, the Stirling approximation term is added to the loss: + /// ```text + /// target * log(target) - target + 0.5 * log(2 * PI * target) + /// ``` + #[config(default = false)] + pub full: bool, + /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`. + /// + /// This epsilon value is added to the predictions to ensure numerical stability + /// when computing the logarithm. + #[config(default = 1e-8)] + pub eps: f64, +} + +impl PoissonNllLossConfig { + /// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration. + /// + /// # Panics + /// - Panics if `eps` is not a positive number. + pub fn init(&self) -> PoissonNllLoss { + self.assertions(); + PoissonNllLoss { + log_input: self.log_input, + full: self.full, + eps: self.eps, + } + } + + /// Validates the configuration parameters. + /// + /// # Panics + /// - Panics if `eps` is not a positive number. + fn assertions(&self) { + assert!( + self.eps > 0., + "eps for PoissonNllLoss must be a positive number." + ); + } +} + +/// Negative Log Likelihood (NLL) loss with a Poisson distribution assumption for the target. +/// +/// This loss function is used when the target values are assumed to follow a Poisson distribution. +/// The loss is defined as: +/// ```text +/// target ~ Poisson(input) +/// L(predictions, target) = predictions - target * log(predictions) + log(target!) +/// ``` +/// The last term (`log(target!)`) can be omitted or approximated using Stirling's formula. +/// The approximation is applied for `target > 1`, while for `target <= 1`, zeros are added to the loss. +/// +/// For more details, see: +/// +#[derive(Module, Debug, Clone)] +#[module(custom_display)] +pub struct PoissonNllLoss { + /// If `true`, the predictions are expected to be in log-space. + pub log_input: bool, + /// Whether to compute the full loss, including the Stirling approximation term. + pub full: bool, + /// A small value to avoid evaluation of `log(0)` when `log_input` is `false`. + pub eps: f64, +} + +impl ModuleDisplay for PoissonNllLoss { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("log_input", &self.log_input) + .add("full", &self.full) + .add("eps", &self.eps) + .optional() + } +} + +impl PoissonNllLoss { + /// Computes the loss element-wise for the given predictions and targets, then reduces + /// the result to a single loss value. + /// + /// # Arguments + /// - `predictions`: The predicted values. + /// - `targets`: The target values. + /// - `reduction`: The reduction method to apply. `Reduction::Auto` behaves as `Reduction::Mean`. + /// + /// # Shapes + /// - `predictions`: `[...dims]` + /// - `targets`: `[...dims]` + /// - `output`: `[1]` + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + pub fn forward( + &self, + predictions: Tensor, + targets: Tensor, + reduction: Reduction, + ) -> Tensor { + let loss = self.forward_no_reduction(predictions, targets); + match reduction { + Reduction::Mean | Reduction::Auto => loss.mean(), + Reduction::Sum => loss.sum(), + } + } + + /// Computes the loss element-wise for the given predictions and targets without reduction. + /// + /// # Arguments + /// - `predictions`: The predicted values. + /// - `targets`: The target values. + /// + /// # Shapes + /// - `predictions`: `[...dims]` + /// - `targets`: `[...dims]` + /// - `output`: `[...dims]` + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + pub fn forward_no_reduction( + &self, + predictions: Tensor, + targets: Tensor, + ) -> Tensor { + self.assertions(&predictions, &targets); + let mut loss; + if self.log_input { + loss = predictions.clone().exp() - targets.clone() * predictions; + } else { + loss = predictions.clone() - targets.clone() * (predictions + self.eps).log(); + } + if self.full { + let log_stirling_term = targets.clone() * targets.clone().log() - targets.clone() + + (targets.clone() * 2. * PI).log() * 0.5; + loss = loss + + log_stirling_term + .mask_where(targets.clone().lower_equal_elem(1), targets.zeros_like()); + } + loss + } + + /// Validates the input tensors for the loss computation. + /// + /// # Panics + /// - Panics if the shapes of `predictions` and `targets` do not match. + /// - Panics if any target value is negative. + /// - Panics if `log_input` is `false` and any prediction value is negative. + fn assertions( + &self, + predictions: &Tensor, + targets: &Tensor, + ) { + let predictions_dims = predictions.dims(); + let targets_dims = targets.dims(); + assert!( + predictions_dims == targets_dims, + "Shape of targets ({:?}) should correspond to outer shape of predictions ({:?}).", + targets_dims, + predictions_dims + ); + assert!( + targets.clone().greater_equal_elem(0.).all().into_scalar(), + "All the values of `targets` must be non-negative." + ); + if !self.log_input { + assert!( + predictions.clone().greater_equal_elem(0.).all().into_scalar(), + "When `log_input` is `false`, all the values of `predictions` must be non-negative." + ); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::TensorData; + use crate::TestBackend; + type TestTensor = Tensor; + + #[test] + fn test_poisson_nll_loss() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); + let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + let loss_no_reduction = poisson.forward_no_reduction(predictions, targets); + + let expected = TensorData::from([1.0000, 1.0000, 100.0000, 2.7183, 7.3891, 14.0855]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([21.0321]); + loss.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([126.1929]); + loss_sum.into_data().assert_approx_eq(&expected, 5); + } + + #[test] + fn test_poisson_nll_loss_no_log_input() { + let predictions = TensorData::from([0.0, 0.5, 1.0, 1.0, 2.71828, 7.38905, 20.0855]); + let targets = TensorData::from([2., 3., 1., 4.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); + + let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + + let expected = TensorData::from([36.84136, 2.579441, 1.0, 1.0, 2.71828, 7.38905, 14.0855]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + } + + #[test] + fn test_poisson_nll_loss_full() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_full(true).init(); + + let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); + let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + let loss_no_reduction = poisson.forward_no_reduction(predictions, targets); + + let expected = TensorData::from([1.0000, 4.9393, 101.1678, 2.7183, 7.3891, 14.7373]); + loss_no_reduction.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([21.9920]); + loss.into_data().assert_approx_eq(&expected, 5); + + let expected = TensorData::from([131.9518]); + loss_sum.into_data().assert_approx_eq(&expected, 5); + } + + #[cfg(feature = "std")] + #[test] + fn test_poisson_nll_loss_gradients() { + type TestAutodiffTensor = Tensor; + + let predictions = TensorData::from([0., 0., -40., 1., 2., 3.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2.]); + + let device = Default::default(); + + let predictions1 = TestAutodiffTensor::from_data(predictions, &device).require_grad(); + let predictions2 = predictions1.clone(); + let targets = TestAutodiffTensor::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_full(false).init(); + let poisson_full = PoissonNllLossConfig::new().with_full(true).init(); + + let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum); + let loss_full_sum = + poisson_full.forward(predictions2.clone(), targets.clone(), Reduction::Sum); + + let grads = loss_sum.backward(); + let grads_full = loss_full_sum.backward(); + + let grads_predictions1 = predictions1.grad(&grads).unwrap(); + let grads_predictions2 = predictions2.grad(&grads_full).unwrap(); + + let expected = TensorData::from([0.0000, -3.5000, -2.5000, 2.7183, 7.3891, 18.0855]); + + grads_predictions1 + .into_data() + .assert_approx_eq(&expected, 5); + grads_predictions2 + .into_data() + .assert_approx_eq(&expected, 5); + } + + #[test] + #[should_panic = "eps for PoissonNllLoss must be a positive number."] + fn test_negative_eps() { + let _poisson = PoissonNllLossConfig::new().with_eps(0.).init(); + } + + #[test] + #[should_panic = "All the values of `targets` must be non-negative."] + fn test_targets_with_negative_values() { + let predictions = TensorData::from([0., 0., -40., 1., 2., 3., 4.]); + let targets = TensorData::from([1., 4.5, 2.5, 0., 0., 2., -0.42]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); + } + + #[test] + #[should_panic = "Shape of targets"] + fn test_shape_tensors() { + let predictions = TensorData::from([0., 1., 2.]); + let targets = TensorData::from([0., 1.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().init(); + + let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + } + + #[test] + #[should_panic = "When `log_input` is `false`, all the values of `predictions` must be non-negative."] + fn test_exp_predictions_non_negative() { + let predictions = TensorData::from([0.3, -0.1, 0.4]); + let targets = TensorData::from([0., 1., 0.]); + + let device = Default::default(); + + let predictions = TestTensor::<1>::from_data(predictions, &device); + let targets = TestTensor::<1>::from_data(targets, &device); + + let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); + + let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); + } + + #[test] + fn display() { + let config = PoissonNllLossConfig::new(); + let loss = config.init(); + + assert_eq!( + alloc::format!("{}", loss), + "PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}" + ); + } +}