Skip to content

Commit

Permalink
Handle zero-token (empty) inputs for cast_to_fp8.
Browse files Browse the repository at this point in the history
  • Loading branch information
Victarry committed Mar 6, 2024
1 parent 27d5eda commit 9b520c6
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions transformer_engine/pytorch/cpp_extensions/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,27 @@ def cast_to_fp8(
"""Cast input to FP8"""

if out is not None:
tex.cast_to_fp8_noalloc(
if inp.nelement() > 0:
tex.cast_to_fp8_noalloc(
inp,
fp8_meta_tensor.scale[fp8_tensor],
out,
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype
)
return None
if inp.nelement() > 0:
return torch.ops.tex_ts.cast_to_fp8_ts(
inp,
fp8_meta_tensor.scale[fp8_tensor],
out,
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
)
return None
return torch.ops.tex_ts.cast_to_fp8_ts(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
)
else:
return inp


def cast_from_fp8(
Expand Down

0 comments on commit 9b520c6

Please sign in to comment.