Fix bf16 te.Linear with no tokens.
Victarry committed Feb 22, 2024
1 parent 142a349 commit 9163446
Showing 1 changed file with 2 additions and 0 deletions.
transformer_engine/pytorch/cpp_extensions/gemm.py (2 additions, 0 deletions)
@@ -176,6 +176,8 @@ def gemm(
         grad_bias = empty_tensor
 
     bias = bias if use_bias else empty_tensor
+    if A.nelement() == 0 or B.nelement() == 0:
+        return out, gelu_input
 
     assert A.dtype == dtype and B.dtype == dtype, \
         f'Expected dtype={dtype}, but found A.dtype={A.dtype} and B.dtype={B.dtype}'
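
For context, a minimal sketch of the scenario this commit addresses (the module name te.Linear, the params_dtype argument, and the sizes below are assumptions for illustration, not taken from the commit): a bf16 te.Linear called on a batch that contains zero tokens. With the guard above, gemm() returns the preallocated out and gelu_input tensors instead of invoking the underlying GEMM with empty operands, which presumably is what failed before this fix.

    import torch
    import transformer_engine.pytorch as te

    # Hypothetical layer sizes; bf16 parameters via the (assumed) params_dtype argument.
    linear = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()

    # A batch with no tokens: 0 rows, hidden size 1024.
    inp = torch.empty(0, 1024, dtype=torch.bfloat16, device="cuda")

    out = linear(inp)    # previously reached the GEMM with empty inputs; now short-circuits
    print(out.shape)     # expected: torch.Size([0, 1024])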
