Skip to content

Commit

Permalink
Fix bce loss log
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Jan 24, 2025
1 parent e40c69b commit afd3279
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions crates/burn-core/src/nn/loss/binary_cross_entropy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ impl<B: Backend> BinaryCrossEntropyLoss<B> {
(targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits)
} else {
// - (target * log(input) + (1 - target) * log(1 - input))
(targets_float.clone() * logits.clone().log()
+ (targets_float.neg() + 1.) * (logits.neg() + 1.).log())
.neg()
// https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values
(targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
- targets_float * logits.log().clamp_min(-100.0)
};

if let Some(weights) = &self.weights {
Expand Down Expand Up @@ -171,6 +171,38 @@ mod tests {
use crate::tensor::{activation::sigmoid, TensorData};
use crate::TestBackend;

#[test]
fn test_binary_cross_entropy_preds_all_correct() {
let device = Default::default();
let preds = Tensor::<TestBackend, 1>::from_floats([1.0, 0.0, 1.0, 0.0], &device);
let targets =
Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

let loss_actual = BinaryCrossEntropyLossConfig::new()
.init(&device)
.forward(preds, targets)
.into_data();

let loss_expected = TensorData::from([0.000]);
loss_actual.assert_approx_eq(&loss_expected, 3);
}

#[test]
fn test_binary_cross_entropy_preds_all_incorrect() {
let device = Default::default();
let preds = Tensor::<TestBackend, 1>::from_floats([0.0, 1.0, 0.0, 1.0], &device);
let targets =
Tensor::<TestBackend, 1, Int>::from_data(TensorData::from([1, 0, 1, 0]), &device);

let loss_actual = BinaryCrossEntropyLossConfig::new()
.init(&device)
.forward(preds, targets)
.into_data();

let loss_expected = TensorData::from([100.000]); // clamped value
loss_actual.assert_approx_eq(&loss_expected, 3);
}

#[test]
fn test_binary_cross_entropy() {
// import torch
Expand Down

0 comments on commit afd3279

Please sign in to comment.