From 852a33ae04de85fd965f30c561660618c332b5ad Mon Sep 17 00:00:00 2001 From: salvomcl Date: Mon, 3 Feb 2025 15:38:25 +0100 Subject: [PATCH] added requested changes --- burn-book/src/building-blocks/module.md | 2 +- crates/burn-core/src/nn/loss/poisson.rs | 47 +++++++++++++------------ 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index d10243a37f..9598d6e39e 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -294,4 +294,4 @@ Burn comes with built-in modules that you can use to build your own modules. | `CrossEntropyLoss` | `nn.CrossEntropyLoss` | | `MseLoss` | `nn.MSELoss` | | `HuberLoss` | `nn.HuberLoss` | -| `PoissonNLLLoss` | `nn.PoissonNLLLoss` | +| `PoissonNllLoss` | `nn.PoissonNLLLoss` | diff --git a/crates/burn-core/src/nn/loss/poisson.rs b/crates/burn-core/src/nn/loss/poisson.rs index 8b4e56bb38..3cc989ad8e 100644 --- a/crates/burn-core/src/nn/loss/poisson.rs +++ b/crates/burn-core/src/nn/loss/poisson.rs @@ -1,4 +1,4 @@ -use std::f32::consts::PI; +use core::f32::consts::PI; use crate as burn; use crate::module::{Content, DisplaySettings, ModuleDisplay}; @@ -8,13 +8,13 @@ use crate::{config::Config, module::Module}; use super::Reduction; -/// Configuration for creating a [PoissonNLLLoss](PoissonNLLLoss) instance. +/// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance. /// /// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss /// behavior, such as whether the input is in log-space, whether to include the Stirling /// approximation term, and a small epsilon value to avoid numerical instability. #[derive(Config, Debug)] -pub struct PoissonNLLLossConfig { +pub struct PoissonNllLossConfig { /// If `true`, the predictions are expected to be in log-space. /// /// When `log_input` is `true`, the loss is computed as: @@ -43,14 +43,14 @@ pub struct PoissonNLLLossConfig { pub eps: f64, } -impl PoissonNLLLossConfig { - /// Initializes a [PoissonNLLLoss](PoissonNLLLoss) instance with the current configuration. +impl PoissonNllLossConfig { + /// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration. /// /// # Panics /// - Panics if `eps` is not a positive number. - pub fn init(&self) -> PoissonNLLLoss { + pub fn init(&self) -> PoissonNllLoss { self.assertions(); - PoissonNLLLoss { + PoissonNllLoss { log_input: self.log_input, full: self.full, eps: self.eps, @@ -64,7 +64,7 @@ impl PoissonNLLLossConfig { fn assertions(&self) { assert!( self.eps > 0., - "eps for PoissonNLLLoss must be a positive number." + "eps for PoissonNllLoss must be a positive number." ); } } @@ -84,7 +84,7 @@ impl PoissonNLLLossConfig { /// #[derive(Module, Debug, Clone)] #[module(custom_display)] -pub struct PoissonNLLLoss { +pub struct PoissonNllLoss { /// If `true`, the predictions are expected to be in log-space. pub log_input: bool, /// Whether to compute the full loss, including the Stirling approximation term. @@ -93,7 +93,7 @@ pub struct PoissonNLLLoss { pub eps: f64, } -impl ModuleDisplay for PoissonNLLLoss { +impl ModuleDisplay for PoissonNllLoss { fn custom_settings(&self) -> Option { DisplaySettings::new() .with_new_line_after_attribute(false) @@ -109,7 +109,7 @@ impl ModuleDisplay for PoissonNLLLoss { } } -impl PoissonNLLLoss { +impl PoissonNllLoss { /// Computes the loss element-wise for the given predictions and targets, then reduces /// the result to a single loss value. /// @@ -226,7 +226,7 @@ mod tests { let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); - let poisson = PoissonNLLLossConfig::new().init(); + let poisson = PoissonNllLossConfig::new().init(); let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); @@ -252,7 +252,7 @@ mod tests { let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); - let poisson = PoissonNLLLossConfig::new().with_log_input(false).init(); + let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone()); @@ -270,7 +270,7 @@ mod tests { let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); - let poisson = PoissonNLLLossConfig::new().with_full(true).init(); + let poisson = PoissonNllLossConfig::new().with_full(true).init(); let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum); let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); @@ -286,6 +286,7 @@ mod tests { loss_sum.into_data().assert_approx_eq(&expected, 5); } + #[cfg(feature = "std")] #[test] fn test_poisson_nll_loss_gradients() { type TestAutodiffTensor = Tensor; @@ -299,8 +300,8 @@ mod tests { let predictions2 = predictions1.clone(); let targets = TestAutodiffTensor::from_data(targets, &device); - let poisson = PoissonNLLLossConfig::new().with_full(false).init(); - let poisson_full = PoissonNLLLossConfig::new().with_full(true).init(); + let poisson = PoissonNllLossConfig::new().with_full(false).init(); + let poisson_full = PoissonNllLossConfig::new().with_full(true).init(); let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum); let loss_full_sum = @@ -323,9 +324,9 @@ mod tests { } #[test] - #[should_panic = "eps for PoissonNLLLoss must be a positive number."] + #[should_panic = "eps for PoissonNllLoss must be a positive number."] fn test_negative_eps() { - let _poisson = PoissonNLLLossConfig::new().with_eps(0.).init(); + let _poisson = PoissonNllLossConfig::new().with_eps(0.).init(); } #[test] @@ -339,7 +340,7 @@ mod tests { let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); - let poisson = PoissonNLLLossConfig::new().init(); + let poisson = PoissonNllLossConfig::new().init(); let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto); } @@ -355,7 +356,7 @@ mod tests { let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); - let poisson = PoissonNLLLossConfig::new().init(); + let poisson = PoissonNllLossConfig::new().init(); let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); } @@ -371,19 +372,19 @@ mod tests { let predictions = TestTensor::<1>::from_data(predictions, &device); let targets = TestTensor::<1>::from_data(targets, &device); - let poisson = PoissonNLLLossConfig::new().with_log_input(false).init(); + let poisson = PoissonNllLossConfig::new().with_log_input(false).init(); let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone()); } #[test] fn display() { - let config = PoissonNLLLossConfig::new(); + let config = PoissonNllLossConfig::new(); let loss = config.init(); assert_eq!( alloc::format!("{}", loss), - "PoissonNLLLoss {log_input: true, full: false, eps: 0.00000001}" + "PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}" ); } }