Make Binary cross entropy with logit numerically stable for high logit values #2562

Open · wants to merge 2 commits into main
9 changes: 3 additions & 6 deletions candle-nn/src/loss.rs
@@ -55,18 +55,15 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
 ///
 /// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
 ///   of categories. This is expected to raw logits.
-/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
+/// * [target]: The ground truth labels as a tensor of dimension `N, C` where `N` is the batch size and `C` the number
 ///   of categories.
 ///
 /// The resulting tensor is a scalar containing the average value over the batch.
 pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
-    let inp = crate::ops::sigmoid(inp)?;
+    let log_sigmoid_input = crate::ops::sigmoid(inp)?.log()?;
 
-    let left_side = target * inp.log()?;
-    let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;
+    let loss = (1.0 - target)?.mul(inp)?.sub(&log_sigmoid_input)?.mean_all()?;
 
-    let loss = left_side? + right_side?;
-    let loss = loss?.neg()?.mean_all()?;
 
     Ok(loss)
 }
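For reference, the rewritten body relies on the standard rearrangement of the per-element loss. Since log(1 - σ(x)) = log σ(-x) = log σ(x) - x, we have

    -[t · log σ(x) + (1 - t) · log(1 - σ(x))]
        = -[t · log σ(x) + (1 - t) · (log σ(x) - x)]
        = (1 - t) · x - log σ(x)

which is exactly what `(1.0 - target)?.mul(inp)?.sub(&log_sigmoid_input)?` computes. Evaluating log(1 - σ(x)) directly, as the removed code did, underflows to log(0) once σ(x) rounds to 1 for large positive logits.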
50 changes: 50 additions & 0 deletions candle-nn/tests/loss.rs
@@ -86,3 +86,53 @@ fn binary_cross_entropy_with_logit() -> Result<()> {
    assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
    Ok(())
}

/*
Test high logit

Equivalent python code:
import torch
import torch.nn.functional as F

inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
[ 0.0419, 0.0763, -1.0457, -1.6692],
[-1.0494, 0.8111, 1.5723, 1.2315],
[ 1.3081, 0.6641, 1.1802, -0.2547],
[ 0.5292, 0.7636, 0.3692, 28.8318]])

target = torch.Tensor([[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.]])

print(F.binary_cross_entropy_with_logits(inp, target))
*/
#[test]
fn binary_cross_entropy_with_high_logit() -> Result<()> {
    let cpu = Device::Cpu;

    let inp = [
        [2.3611f32, -0.8813, -0.5006, -0.2178],
        [0.0419, 0.0763, -1.0457, -1.6692],
        [-1.0494, 0.8111, 1.5723, 1.2315],
        [1.3081, 0.6641, 1.1802, -0.2547],
        [0.5292, 0.7636, 0.3692, 28.8318],
    ];

    let target = [
        [0.0f32, 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
    ];

    let inp = Tensor::new(&inp, &cpu)?;
    let target = Tensor::new(&target, &cpu)?;

    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;

    assert_eq!(to_vec0_round(&loss, 4)?, 2.246);
    Ok(())
}
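As a sanity check on the motivation for the change (a minimal standalone sketch, not part of this PR; the variable names are illustrative only), plain f32 arithmetic shows how the old formulation degenerates at the 28.8318 logit with a 0 target, while the rearranged per-element formula used by the new implementation stays finite:

fn main() {
    let x = 28.8318f32; // the large logit from the test above
    let t = 0.0f32; // its target

    let sigmoid = 1.0 / (1.0 + (-x).exp());

    // Old formulation: -(t * ln(sigmoid) + (1 - t) * ln(1 - sigmoid)).
    // In f32, sigmoid rounds to exactly 1.0, so ln(1 - sigmoid) = ln(0) = -inf
    // and the per-element loss becomes +inf.
    let old_loss = -(t * sigmoid.ln() + (1.0 - t) * (1.0 - sigmoid).ln());

    // Rearranged per-element formula: (1 - t) * x - ln(sigmoid),
    // which evaluates to roughly 28.8318 here.
    let new_loss = (1.0 - t) * x - sigmoid.ln();

    println!("old: {old_loss}, new: {new_loss}");
}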