Skip to content

Commit

Permalink
fix test cases and formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
ToluClassics committed Oct 23, 2023
1 parent e4c07b5 commit 36d7866
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 13 deletions.
3 changes: 1 addition & 2 deletions candle-nn/src/loss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
(inp - target)?.sqr()?.mean_all()
}


/// The binary cross-entropy with logit loss.
///
/// Arguments
Expand All @@ -68,6 +67,6 @@ pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<

let loss = left_side? + right_side?;
let loss = loss?.neg()?.mean_all()?;

Ok(loss)
}
25 changes: 14 additions & 11 deletions candle-nn/tests/loss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ fn nll_and_cross_entropy() -> Result<()> {
Ok(())
}


/* Equivalent python code:
import torch
import torch.nn.functional as F
Expand All @@ -63,17 +62,21 @@ print(F.binary_cross_entropy_with_logits(inp, target))
fn binary_cross_entropy_with_logit() -> Result<()> {
let cpu = Device::Cpu;

let inp = [[ 2.3611f64, -0.8813, -0.5006, -0.2178],
[ 0.0419, 0.0763, -1.0457, -1.6692],
[-1.0494, 0.8111, 1.5723, 1.2315],
[ 1.3081, 0.6641, 1.1802, -0.2547],
[ 0.5292, 0.7636, 0.3692, -0.8318]];
let inp = [
[2.3611f32, -0.8813, -0.5006, -0.2178],
[0.0419, 0.0763, -1.0457, -1.6692],
[-1.0494, 0.8111, 1.5723, 1.2315],
[1.3081, 0.6641, 1.1802, -0.2547],
[0.5292, 0.7636, 0.3692, -0.8318],
];

let target = [[0.0f64, 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.]];
let target = [
[0.0f32, 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.],
];

let inp = Tensor::new(&inp, &cpu)?;
let target = Tensor::new(&target, &cpu)?;
Expand Down

0 comments on commit 36d7866

Please sign in to comment.