Skip to content

Commit

Permalink
Handle zero-token inputs for cast_to_fp8.
Browse files Browse the repository at this point in the history
  • Loading branch information
Victarry committed Mar 6, 2024
1 parent 71b7472 commit 8df68fc
Showing 1 changed file with 12 additions and 18 deletions.
30 changes: 12 additions & 18 deletions transformer_engine/pytorch/cpp_extensions/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,18 @@ def cast_to_fp8(
"""Cast input to FP8"""

if out is not None:
torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
inp,
fp8_meta_tensor.scale,
out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype
)
return None
return torch.ops.tex_ts.cast_to_fp8_ts(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
)
if inp.nelement() > 0:
torch.ops.tex_ts.cast_to_fp8_noalloc_ts(
inp,
fp8_meta_tensor.scale,
out,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype
)
else:
return inp


def cast_from_fp8(
Expand Down

0 comments on commit 8df68fc

Please sign in to comment.