diff --git a/transformer_engine/pytorch/cpp_extensions/cast.py b/transformer_engine/pytorch/cpp_extensions/cast.py
index 034c71dc2e..a86222d958 100644
--- a/transformer_engine/pytorch/cpp_extensions/cast.py
+++ b/transformer_engine/pytorch/cpp_extensions/cast.py
@@ -33,16 +33,15 @@ def cast_to_fp8(
             otype
         )
         return None
-    if inp.nelement() > 0:
-        return torch.ops.tex_ts.cast_to_fp8_ts(
-            inp,
-            fp8_meta_tensor.scale,
-            fp8_meta_tensor.amax_history,
-            fp8_meta_tensor.scale_inv,
-            fp8_tensor,
-            otype,
-        )
-    return inp
+
+    return torch.ops.tex_ts.cast_to_fp8_ts(
+        inp,
+        fp8_meta_tensor.scale,
+        fp8_meta_tensor.amax_history,
+        fp8_meta_tensor.scale_inv,
+        fp8_tensor,
+        otype,
+    )
 
 def cast_from_fp8(
     inp: torch.Tensor,
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cu b/transformer_engine/pytorch/csrc/extensions/cast.cu
index 80975069de..c798a39df5 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cu
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cu
@@ -19,6 +19,9 @@ at::Tensor cast_to_fp8(const at::Tensor &input,
   auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));
 
+  if (input.numel() == 0)
+    return output;
+
   auto input_cu = makeTransformerEngineTensor(input);
   auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype,
                                               amax.data_ptr(), scale.data_ptr(),