Skip to content

Commit

Permalink
add test documentation and refactor function
Browse files Browse the repository at this point in the history
  • Loading branch information
ToluClassics committed Oct 23, 2023
1 parent 33a54f1 commit e4c07b5
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 14 deletions.
3 changes: 1 addition & 2 deletions candle-nn/src/loss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,9 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
/// The resulting tensor is a scalar containing the average value over the batch.
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
let inp = crate::ops::sigmoid(inp)?;
let one_tensor = Tensor::new(1.0, &inp.device())?;

let left_side = target * inp.log()?;
let right_side = (one_tensor.broadcast_sub(&target)?) * (one_tensor.broadcast_sub(&inp)?.log()?);
let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;

let loss = left_side? + right_side?;
let loss = loss?.neg()?.mean_all()?;
Expand Down
33 changes: 21 additions & 12 deletions candle-nn/tests/loss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ fn nll_and_cross_entropy() -> Result<()> {
Ok(())
}


/* Equivalent python code:
import torch
import torch.nn.functional as F
inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
[ 0.0419, 0.0763, -1.0457, -1.6692],
[-1.0494, 0.8111, 1.5723, 1.2315],
[ 1.3081, 0.6641, 1.1802, -0.2547],
[ 0.5292, 0.7636, 0.3692, -0.8318]])
target = torch.Tensor([[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.]])
print(F.binary_cross_entropy_with_logits(inp, target))
*/
#[test]
fn binary_cross_entropy_with_logit() -> Result<()> {
let cpu = Device::Cpu;
Expand All @@ -48,29 +67,19 @@ fn binary_cross_entropy_with_logit() -> Result<()> {
[ 0.0419, 0.0763, -1.0457, -1.6692],
[-1.0494, 0.8111, 1.5723, 1.2315],
[ 1.3081, 0.6641, 1.1802, -0.2547],
[ 0.5292, 0.7636, 0.3692, -0.8318],
[ 0.5100, 0.9849, -1.2905, 0.2821],
[ 1.4662, 0.4550, 0.9875, 0.3143],
[-1.2121, 0.1262, 0.0598, -1.6363],
[ 0.3214, -0.8689, 0.0689, -2.5094],
[ 1.1320, -0.6824, 0.1657, -0.0687]];
[ 0.5292, 0.7636, 0.3692, -0.8318]];

let target = [[0.0f64, 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 0., 1., 0.],
[0., 0., 1., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 0.]];

let inp = Tensor::new(&inp, &cpu)?;
let target = Tensor::new(&target, &cpu)?;

let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;

assert_eq!(to_vec0_round(&loss, 4)?, 0.7739);
assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
Ok(())
}

0 comments on commit e4c07b5

Please sign in to comment.