Commit cd848b1

Add is_nan and contains_nan tensor ops (#2088)
* Add is_nan and contains_nan tensor ops
* Enable nan test for burn-candle
* Disabling tests due to #2089
1 parent 27d42cd commit cd848b1

File tree: 6 files changed (+68, -5 lines changed)

burn-book/src/building-blocks/tensor.md (+8, -5)

@@ -190,6 +190,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
 | `tensor.clamp(min, max)` | `torch.clamp(tensor, min=min, max=max)` |
 | `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
 | `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
+| `tensor.contains_nan()` | N/A |
 | `tensor.div(other)` or `tensor / other` | `tensor / other` |
 | `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
 | `tensor.equal_elem(other)` | `tensor.eq(other)` |
@@ -199,6 +200,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
 | `tensor.greater_equal(other)` | `tensor.ge(other)` |
 | `tensor.greater_equal_elem(scalar)` | `tensor.ge(scalar)` |
 | `tensor.is_close(other, atol, rtol)` | `torch.isclose(tensor, other, atol, rtol)` |
+| `tensor.is_nan()` | `torch.isnan(tensor)` |
 | `tensor.lower(other)` | `tensor.lt(other)` |
 | `tensor.lower_elem(scalar)` | `tensor.lt(scalar)` |
 | `tensor.lower_equal(other)` | `tensor.le(other)` |
@@ -304,12 +306,13 @@ Those operations are only available for `Bool` tensors.
 
 ### Quantization Operations
 
-Those operations are only available for `Float` tensors on backends that implement quantization strategies.
+Those operations are only available for `Float` tensors on backends that implement quantization
+strategies.
 
-| Burn API | PyTorch Equivalent |
-| ------------------------------------ | ------------------------------- |
-| `tensor.quantize(scheme, qparams)` | N/A |
-| `tensor.dequantize()` | N/A |
+| Burn API | PyTorch Equivalent |
+| ---------------------------------- | ------------------ |
+| `tensor.quantize(scheme, qparams)` | N/A |
+| `tensor.dequantize()` | N/A |
 
 ## Activation Functions
 
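For readers following the book table, here is a minimal usage sketch of the two new entries. The `burn` facade crate with the `ndarray` feature, the `NdArray` backend, the `from_floats` constructor, and the default device are illustrative assumptions; only `is_nan()` and `contains_nan()` come from this commit, with `is_nan()` playing the role of `torch.isnan(tensor)`.

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    // 2x2 float tensor with one NaN entry (backend and setup are assumed for illustration)
    let t = Tensor::<NdArray, 2>::from_floats([[0.0, f32::NAN], [2.0, 3.0]], &device);

    // Element-wise boolean mask, the Burn counterpart of torch.isnan(tensor)
    let mask = t.is_nan();
    println!("{mask}");

    // Single Bool element: does the tensor contain any NaN at all?
    let any_nan: bool = t.contains_nan().into_scalar();
    assert!(any_nan);
}
```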

crates/burn-candle/src/lib.rs (+1)

@@ -87,6 +87,7 @@ mod tests {
     burn_tensor::testgen_flip!();
     burn_tensor::testgen_argwhere_nonzero!();
     burn_tensor::testgen_sign!();
+    burn_tensor::testgen_nan!();
 
     // TODO: https://github.com/tracel-ai/burn/issues/1237
     //

crates/burn-tensor/src/tensor/api/numeric.rs (+26)

@@ -778,6 +778,32 @@ where
         // Assign the original tensor data to the appropriate slice of the padded tensor
         padded_tensor.slice_assign(ranges, self)
     }
+
+    /// Returns a new tensor with boolean elements indicating whether each element of the input is NaN.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor where `true` indicates NaN and `false` indicates a non-NaN value.
+    pub fn is_nan(&self) -> Tensor<B, D, Bool> {
+        // Check if the input tensor is NaN by comparing it to itself
+        // NaN is the only value that is not equal to itself
+        K::not_equal(self.primitive.clone(), self.primitive.clone())
+    }
+
+    /// Checks if the tensor contains any NaN values.
+    ///
+    /// # Returns
+    ///
+    /// A boolean tensor with a single element indicating whether the tensor contains any NaN values.
+    pub fn contains_nan(&self) -> Tensor<B, 1, Bool> {
+        // Summing the tensor will result in NaN if the tensor contains any NaN values
+        // This is faster than checking each element individually
+        // because it rolls up the NaN values into a single value
+        let sum = K::sum(self.primitive.clone());
+
+        // Check if the sum is NaN by comparing it to itself
+        K::not_equal(sum.clone(), sum)
+    }
 }
 
 impl<B, K> Tensor<B, 2, K>
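Both helpers rest on the IEEE-754 facts stated in the inline comments: NaN is the only value that compares unequal to itself, and a NaN anywhere in a sum makes the whole sum NaN. A standalone sketch with plain `f32` values (no Burn types involved) that checks both properties:

```rust
fn main() {
    // NaN is the only f32 value that compares unequal to itself,
    // which is the comparison is_nan() performs via K::not_equal.
    let nan = f32::NAN;
    assert!(nan != nan);

    // A NaN propagates through a sum, so reducing the tensor to a single
    // value and testing that value is enough for contains_nan().
    let clean = [0.0_f32, 1.0, 2.0];
    let dirty = [0.0_f32, f32::NAN, 2.0];
    let clean_sum: f32 = clean.iter().sum();
    let dirty_sum: f32 = dirty.iter().sum();
    assert!(clean_sum == clean_sum); // not NaN
    assert!(dirty_sum != dirty_sum); // NaN survived the reduction
    println!("NaN self-inequality and sum propagation hold");
}
```

One caveat worth noting about the sum-based shortcut: adding `+inf` and `-inf` also yields NaN, so a tensor containing opposite infinities but no NaN would make `contains_nan()` report `true`.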

crates/burn-tensor/src/tests/mod.rs (+1)

@@ -103,6 +103,7 @@ macro_rules! testgen_all {
        burn_tensor::testgen_topk!();
        burn_tensor::testgen_remainder!();
        burn_tensor::testgen_cartesian_grid!();
+       burn_tensor::testgen_nan!();
 
        // test stats
        burn_tensor::testgen_var!();

crates/burn-tensor/src/tests/ops/mod.rs (+1)

@@ -34,6 +34,7 @@ mod matmul;
 mod maxmin;
 mod movedim;
 mod mul;
+mod nan;
 mod narrow;
 mod neg;
 mod one_hot;
crates/burn-tensor/src/tests/ops/nan.rs (new file, +31)

@@ -0,0 +1,31 @@
+#[burn_tensor_testgen::testgen(nan)]
+mod tests {
+    use super::*;
+    use burn_tensor::{Int, Tensor, TensorData};
+
+    #[test]
+    #[ignore = "https://github.com/tracel-ai/burn/issues/2089"]
+    fn is_nan() {
+        let no_nan = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+        let no_nan_expected =
+            TestTensorBool::<2>::from([[false, false, false], [false, false, false]]);
+
+        let with_nan = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [f32::NAN, 4.0, 5.0]]);
+        let with_nan_expected =
+            TestTensorBool::<2>::from([[false, true, false], [true, false, false]]);
+
+        assert_eq!(no_nan_expected.into_data(), no_nan.is_nan().into_data());
+
+        assert_eq!(with_nan_expected.into_data(), with_nan.is_nan().into_data());
+    }
+
+    #[test]
+    #[ignore = "https://github.com/tracel-ai/burn/issues/2089"]
+    fn contains_nan() {
+        let no_nan = TestTensor::<2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+        assert!(!no_nan.contains_nan().into_scalar());
+
+        let with_nan = TestTensor::<2>::from([[0.0, f32::NAN, 2.0], [3.0, 4.0, 5.0]]);
+        assert!(with_nan.contains_nan().into_scalar());
+    }
+}
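The ignored tests read the result back with `into_scalar()`; the same pattern can serve as a cheap runtime NaN guard. A rough sketch, assuming a generic backend `B` and a scalar `loss` tensor produced by surrounding training code (that code, and the function name, are illustrative only):

```rust
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

/// Returns false (and skips the update) when the loss has gone NaN.
fn step_if_finite<B: Backend>(loss: &Tensor<B, 1>) -> bool {
    // contains_nan() performs a single reduction and yields one Bool
    // element, so the check costs one readback per step.
    if loss.contains_nan().into_scalar() {
        // e.g. log the event and skip the optimizer step
        return false;
    }
    // ... optimizer step would go here ...
    true
}
```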
