Skip to content

Commit 3e4af41

Browse files
authored
Fix sort descending for 1d case (#1494)
1 parent e8863da commit 3e4af41

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

crates/burn-tensor/src/tensor/api/sort.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ where
8989
let dims = data.shape.dims;
9090
if D == 1 {
9191
// 1D sort
92-
data.value.sort_unstable_by(|&a, &b| a.cmp(&b));
92+
data.value
93+
.sort_unstable_by(|&a, &b| compare(&a, &b, descending));
9394
} else {
9495
sort_slice::<B, D, K>(&mut data.value, &dims, dim, None, false, descending);
9596
}

crates/burn-tensor/src/tests/ops/sort_argsort.rs

+21
Original file line numberDiff line numberDiff line change
@@ -359,4 +359,25 @@ mod tests {
359359
let values_expected = Data::from([[-0.5, 0.94], [-0.3, f32::NAN], [0., f32::NAN]]);
360360
values_expected.assert_approx_eq(&values_actual, 5);
361361
}
362+
363+
#[test]
364+
fn test_sort_descending_1d() {
365+
let tensor = TestTensorInt::from([1, 2, 3, 4, 5]);
366+
367+
// Sort along dim=0
368+
let values = tensor.sort_descending(0);
369+
let values_actual = values.into_data();
370+
371+
let values_expected = Data::from([5, 4, 3, 2, 1]);
372+
assert_eq!(values_expected, values_actual);
373+
374+
let tensor = TestTensor::from([1., 2., 3., 4., 5.]);
375+
376+
// Sort along dim=0
377+
let values = tensor.sort_descending(0);
378+
let values_actual = values.into_data();
379+
380+
let values_expected = Data::from([5., 4., 3., 2., 1.]);
381+
values_expected.assert_approx_eq(&values_actual, 5);
382+
}
362383
}

0 commit comments

Comments
 (0)