Skip to content

Commit

Permalink
added requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
salvomcl committed Feb 3, 2025
1 parent 940a67a commit 852a33a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
2 changes: 1 addition & 1 deletion burn-book/src/building-blocks/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -294,4 +294,4 @@ Burn comes with built-in modules that you can use to build your own modules.
| `CrossEntropyLoss` | `nn.CrossEntropyLoss` |
| `MseLoss` | `nn.MSELoss` |
| `HuberLoss` | `nn.HuberLoss` |
| `PoissonNLLLoss` | `nn.PoissonNLLLoss` |
| `PoissonNllLoss` | `nn.PoissonNLLLoss` |
47 changes: 24 additions & 23 deletions crates/burn-core/src/nn/loss/poisson.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::f32::consts::PI;
use core::f32::consts::PI;

use crate as burn;
use crate::module::{Content, DisplaySettings, ModuleDisplay};
Expand All @@ -8,13 +8,13 @@ use crate::{config::Config, module::Module};

use super::Reduction;

/// Configuration for creating a [PoissonNLLLoss](PoissonNLLLoss) instance.
/// Configuration for creating a [PoissonNllLoss](PoissonNllLoss) instance.
///
/// This configuration allows customization of the Poisson Negative Log Likelihood (NLL) loss
/// behavior, such as whether the input is in log-space, whether to include the Stirling
/// approximation term, and a small epsilon value to avoid numerical instability.
#[derive(Config, Debug)]
pub struct PoissonNLLLossConfig {
pub struct PoissonNllLossConfig {
/// If `true`, the predictions are expected to be in log-space.
///
/// When `log_input` is `true`, the loss is computed as:
Expand Down Expand Up @@ -43,14 +43,14 @@ pub struct PoissonNLLLossConfig {
pub eps: f64,
}

impl PoissonNLLLossConfig {
/// Initializes a [PoissonNLLLoss](PoissonNLLLoss) instance with the current configuration.
impl PoissonNllLossConfig {
/// Initializes a [PoissonNllLoss](PoissonNllLoss) instance with the current configuration.
///
/// # Panics
/// - Panics if `eps` is not a positive number.
pub fn init(&self) -> PoissonNLLLoss {
pub fn init(&self) -> PoissonNllLoss {
self.assertions();
PoissonNLLLoss {
PoissonNllLoss {
log_input: self.log_input,
full: self.full,
eps: self.eps,
Expand All @@ -64,7 +64,7 @@ impl PoissonNLLLossConfig {
fn assertions(&self) {
assert!(
self.eps > 0.,
"eps for PoissonNLLLoss must be a positive number."
"eps for PoissonNllLoss must be a positive number."
);
}
}
Expand All @@ -84,7 +84,7 @@ impl PoissonNLLLossConfig {
/// <https://en.wikipedia.org/wiki/Poisson_regression#Maximum_likelihood-based_parameter_estimation>
#[derive(Module, Debug, Clone)]
#[module(custom_display)]
pub struct PoissonNLLLoss {
pub struct PoissonNllLoss {
/// If `true`, the predictions are expected to be in log-space.
pub log_input: bool,
/// Whether to compute the full loss, including the Stirling approximation term.
Expand All @@ -93,7 +93,7 @@ pub struct PoissonNLLLoss {
pub eps: f64,
}

impl ModuleDisplay for PoissonNLLLoss {
impl ModuleDisplay for PoissonNllLoss {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
Expand All @@ -109,7 +109,7 @@ impl ModuleDisplay for PoissonNLLLoss {
}
}

impl PoissonNLLLoss {
impl PoissonNllLoss {
/// Computes the loss element-wise for the given predictions and targets, then reduces
/// the result to a single loss value.
///
Expand Down Expand Up @@ -226,7 +226,7 @@ mod tests {
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);

let poisson = PoissonNLLLossConfig::new().init();
let poisson = PoissonNllLossConfig::new().init();

let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
Expand All @@ -252,7 +252,7 @@ mod tests {
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);

let poisson = PoissonNLLLossConfig::new().with_log_input(false).init();
let poisson = PoissonNllLossConfig::new().with_log_input(false).init();

let loss_no_reduction = poisson.forward_no_reduction(predictions.clone(), targets.clone());

Expand All @@ -270,7 +270,7 @@ mod tests {
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);

let poisson = PoissonNLLLossConfig::new().with_full(true).init();
let poisson = PoissonNllLossConfig::new().with_full(true).init();

let loss_sum = poisson.forward(predictions.clone(), targets.clone(), Reduction::Sum);
let loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
Expand All @@ -286,6 +286,7 @@ mod tests {
loss_sum.into_data().assert_approx_eq(&expected, 5);
}

#[cfg(feature = "std")]
#[test]
fn test_poisson_nll_loss_gradients() {
type TestAutodiffTensor = Tensor<crate::TestAutodiffBackend, 1>;
Expand All @@ -299,8 +300,8 @@ mod tests {
let predictions2 = predictions1.clone();
let targets = TestAutodiffTensor::from_data(targets, &device);

let poisson = PoissonNLLLossConfig::new().with_full(false).init();
let poisson_full = PoissonNLLLossConfig::new().with_full(true).init();
let poisson = PoissonNllLossConfig::new().with_full(false).init();
let poisson_full = PoissonNllLossConfig::new().with_full(true).init();

let loss_sum = poisson.forward(predictions1.clone(), targets.clone(), Reduction::Sum);
let loss_full_sum =
Expand All @@ -323,9 +324,9 @@ mod tests {
}

#[test]
#[should_panic = "eps for PoissonNLLLoss must be a positive number."]
#[should_panic = "eps for PoissonNllLoss must be a positive number."]
fn test_negative_eps() {
let _poisson = PoissonNLLLossConfig::new().with_eps(0.).init();
let _poisson = PoissonNllLossConfig::new().with_eps(0.).init();
}

#[test]
Expand All @@ -339,7 +340,7 @@ mod tests {
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);

let poisson = PoissonNLLLossConfig::new().init();
let poisson = PoissonNllLossConfig::new().init();

let _loss = poisson.forward(predictions.clone(), targets.clone(), Reduction::Auto);
}
Expand All @@ -355,7 +356,7 @@ mod tests {
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);

let poisson = PoissonNLLLossConfig::new().init();
let poisson = PoissonNllLossConfig::new().init();

let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
}
Expand All @@ -371,19 +372,19 @@ mod tests {
let predictions = TestTensor::<1>::from_data(predictions, &device);
let targets = TestTensor::<1>::from_data(targets, &device);

let poisson = PoissonNLLLossConfig::new().with_log_input(false).init();
let poisson = PoissonNllLossConfig::new().with_log_input(false).init();

let _loss = poisson.forward_no_reduction(predictions.clone(), targets.clone());
}

#[test]
fn display() {
let config = PoissonNLLLossConfig::new();
let config = PoissonNllLossConfig::new();
let loss = config.init();

assert_eq!(
alloc::format!("{}", loss),
"PoissonNLLLoss {log_input: true, full: false, eps: 0.00000001}"
"PoissonNllLoss {log_input: true, full: false, eps: 0.00000001}"
);
}
}

0 comments on commit 852a33a

Please sign in to comment.