diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index cfae873076..fb788aa2df 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -176,6 +176,8 @@ def gemm( grad_bias = empty_tensor bias = bias if use_bias else empty_tensor + if A.nelement() == 0 or B.nelement() == 0: + return out, grad_bias, gelu_input assert A.dtype == dtype and B.dtype == dtype, \ f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'