From c7792079e23132c357092e5eae56c87213127fd6 Mon Sep 17 00:00:00 2001 From: Dennis Liu Date: Sun, 7 Apr 2024 00:26:09 -0700 Subject: [PATCH] Fix unittest. Signed-off-by: Dennis Liu --- .../pytorch/cpp_extensions/cast.py | 19 +++++++++---------- .../pytorch/csrc/extensions/cast.cu | 3 +++ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py index 034c71dc2e..a86222d958 100644 --- a/transformer_engine/pytorch/cpp_extensions/cast.py +++ b/transformer_engine/pytorch/cpp_extensions/cast.py @@ -33,16 +33,15 @@ def cast_to_fp8( otype ) return None - if inp.nelement() > 0: - return torch.ops.tex_ts.cast_to_fp8_ts( - inp, - fp8_meta_tensor.scale, - fp8_meta_tensor.amax_history, - fp8_meta_tensor.scale_inv, - fp8_tensor, - otype, - ) - return inp + + return torch.ops.tex_ts.cast_to_fp8_ts( + inp, + fp8_meta_tensor.scale, + fp8_meta_tensor.amax_history, + fp8_meta_tensor.scale_inv, + fp8_tensor, + otype, + ) def cast_from_fp8( inp: torch.Tensor, diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu index 80975069de..c798a39df5 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cu +++ b/transformer_engine/pytorch/csrc/extensions/cast.cu @@ -19,6 +19,9 @@ at::Tensor cast_to_fp8(const at::Tensor &input, auto output = at::empty_like(input, at::CUDA(GetATenDType(otype))); + if (input.numel() == 0) + return output; + auto input_cu = makeTransformerEngineTensor(input); auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype, amax.data_ptr(), scale.data_ptr(),