From 35fb2c15609fe99cfef0d590c280eefc05400352 Mon Sep 17 00:00:00 2001 From: Laurent Date: Sat, 21 Sep 2024 14:57:15 +0200 Subject: [PATCH] Add a test. --- candle-core/tests/tensor_tests.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 567b49f1db..4a76035c17 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -193,6 +193,19 @@ fn unary_op(device: &Device) -> Result<()> { tensor.sign()?.to_vec1::()?, [-1., -1., -1., 0., 0., 1., 1., 1., 1.] ); + let tensor = Tensor::new(&[-1.0f32, 0., -2., 3.], device)?; + let y = tensor.elu(2.)?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [-1.2642, 0.0000, -1.7293, 3.0000] + ); + // This test failed on metal prior to the following PR: + // https://github.com/huggingface/candle/pull/2490 + let y = tensor.reshape((2, 2))?.t()?.elu(2.)?.flatten_all()?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [-1.2642, -1.7293, 0.0000, 3.0000] + ); Ok(()) }