diff --git a/candle-nn/src/loss.rs b/candle-nn/src/loss.rs
index 669a8c684..2e9dbd2fc 100644
--- a/candle-nn/src/loss.rs
+++ b/candle-nn/src/loss.rs
@@ -62,10 +62,9 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
 /// The resulting tensor is a scalar containing the average value over the batch.
 pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
     let inp = crate::ops::sigmoid(inp)?;
-    let one_tensor = Tensor::new(1.0, &inp.device())?;
 
     let left_side = target * inp.log()?;
-    let right_side = (one_tensor.broadcast_sub(&target)?) * (one_tensor.broadcast_sub(&inp)?.log()?);
+    let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;
 
     let loss = left_side? + right_side?;
     let loss = loss?.neg()?.mean_all()?;
diff --git a/candle-nn/tests/loss.rs b/candle-nn/tests/loss.rs
index 718fb0719..9027de7f6 100644
--- a/candle-nn/tests/loss.rs
+++ b/candle-nn/tests/loss.rs
@@ -40,6 +40,25 @@ fn nll_and_cross_entropy() -> Result<()> {
     Ok(())
 }
 
+/* Equivalent python code:
+import torch
+import torch.nn.functional as F
+
+inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
+        [ 0.0419, 0.0763, -1.0457, -1.6692],
+        [-1.0494, 0.8111, 1.5723, 1.2315],
+        [ 1.3081, 0.6641, 1.1802, -0.2547],
+        [ 0.5292, 0.7636, 0.3692, -0.8318]])
+
+target = torch.Tensor([[0., 1., 0., 0.],
+        [0., 1., 0., 0.],
+        [0., 0., 0., 1.],
+        [1., 0., 0., 0.],
+        [0., 0., 1., 0.]])
+
+print(F.binary_cross_entropy_with_logits(inp, target))
+*/
+
 #[test]
 fn binary_cross_entropy_with_logit() -> Result<()> {
     let cpu = Device::Cpu;
@@ -48,22 +67,12 @@ fn binary_cross_entropy_with_logit() -> Result<()> {
         [ 0.0419, 0.0763, -1.0457, -1.6692],
         [-1.0494, 0.8111, 1.5723, 1.2315],
         [ 1.3081, 0.6641, 1.1802, -0.2547],
-        [ 0.5292, 0.7636, 0.3692, -0.8318],
-        [ 0.5100, 0.9849, -1.2905, 0.2821],
-        [ 1.4662, 0.4550, 0.9875, 0.3143],
-        [-1.2121, 0.1262, 0.0598, -1.6363],
-        [ 0.3214, -0.8689, 0.0689, -2.5094],
-        [ 1.1320, -0.6824, 0.1657, -0.0687]];
+        [ 0.5292, 0.7636, 0.3692, -0.8318]];
 
     let target = [[0.0f64, 1., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [1., 0., 0., 0.],
-        [0., 0., 1., 0.],
-        [1., 0., 0., 0.],
-        [0., 0., 1., 0.],
-        [0., 0., 1., 0.],
-        [0., 1., 0., 0.],
         [0., 0., 1., 0.]];
 
     let inp = Tensor::new(&inp, &cpu)?;
@@ -71,6 +80,6 @@ fn binary_cross_entropy_with_logit() -> Result<()> {
     let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;
 
-    assert_eq!(to_vec0_round(&loss, 4)?, 0.7739);
+    assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
 
     Ok(())
 }
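
Note on the `loss.rs` change: `Tensor::affine(mul, add)` computes `mul * x + add` elementwise, so `affine(-1., 1.)` evaluates `1 - x` directly instead of materializing a scalar ones tensor and broadcast-subtracting against it. After the sigmoid, the function computes the standard binary cross-entropy averaged over all elements (matching PyTorch's default `reduction='mean'`); for reference:

```latex
% Binary cross-entropy with logits: sigmoid, two log terms, negated mean.
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N}
  \Bigl[\, t_i \log \sigma(x_i) + (1 - t_i) \log\bigl(1 - \sigma(x_i)\bigr) \Bigr]
```

`target * inp.log()` is the first term, the two `affine(-1., 1.)` calls form the second, and `neg()?.mean_all()?` applies the leading minus and the average over all elements of the batch.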
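To sanity-check the refactor, here is a minimal sketch (not part of the patch; it assumes the core crate is imported as `candle`, as candle-nn does, and uses only existing tensor ops: `broadcast_sub`, `affine`, `to_vec2`) showing that the old and new formulations of `1 - x` agree:

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Values in (0, 1), as produced by the sigmoid inside the loss.
    let x = Tensor::new(&[[0.1f64, 0.7], [0.4, 0.9]], &dev)?;

    // Old formulation: broadcast-subtract from a scalar ones tensor.
    let one = Tensor::new(1.0f64, &dev)?;
    let old = one.broadcast_sub(&x)?;

    // New formulation: a single fused op computing -1.0 * x + 1.0.
    let new = x.affine(-1., 1.)?;

    // Both compute 1 - x elementwise (bit-identical here).
    assert_eq!(old.to_vec2::<f64>()?, new.to_vec2::<f64>()?);
    Ok(())
}
```

The refactor only changes how `1 - target` and `1 - sigmoid(inp)` are computed; the expected loss in the test moves from 0.7739 to 0.8224 because the fixture shrank from ten rows to the five mirrored in the Python snippet, not because the math changed.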