diff --git a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs index f645c84fd9..54b80f4f60 100644 --- a/crates/burn-core/src/nn/loss/binary_cross_entropy.rs +++ b/crates/burn-core/src/nn/loss/binary_cross_entropy.rs @@ -118,9 +118,9 @@ impl BinaryCrossEntropyLoss { (targets_float.neg() + 1.) * logits.clone() - log_sigmoid(logits) } else { // - (target * log(input) + (1 - target) * log(1 - input)) - (targets_float.clone() * logits.clone().log() - + (targets_float.neg() + 1.) * (logits.neg() + 1.).log()) - .neg() + // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values + (targets_float.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0) + - targets_float * logits.log().clamp_min(-100.0) }; if let Some(weights) = &self.weights { @@ -171,6 +171,38 @@ mod tests { use crate::tensor::{activation::sigmoid, TensorData}; use crate::TestBackend; + #[test] + fn test_binary_cross_entropy_preds_all_correct() { + let device = Default::default(); + let preds = Tensor::::from_floats([1.0, 0.0, 1.0, 0.0], &device); + let targets = + Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device); + + let loss_actual = BinaryCrossEntropyLossConfig::new() + .init(&device) + .forward(preds, targets) + .into_data(); + + let loss_expected = TensorData::from([0.000]); + loss_actual.assert_approx_eq(&loss_expected, 3); + } + + #[test] + fn test_binary_cross_entropy_preds_all_incorrect() { + let device = Default::default(); + let preds = Tensor::::from_floats([0.0, 1.0, 0.0, 1.0], &device); + let targets = + Tensor::::from_data(TensorData::from([1, 0, 1, 0]), &device); + + let loss_actual = BinaryCrossEntropyLossConfig::new() + .init(&device) + .forward(preds, targets) + .into_data(); + + let loss_expected = TensorData::from([100.000]); // clamped value + loss_actual.assert_approx_eq(&loss_expected, 3); + } + #[test] fn test_binary_cross_entropy() { // import torch