Skip to content

Commit

Permalink
Fix unittest.
Browse files Browse the repository at this point in the history
Signed-off-by: Dennis Liu <[email protected]>
  • Loading branch information
Victarry committed Apr 7, 2024
1 parent 28ec889 commit c779207
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
19 changes: 9 additions & 10 deletions transformer_engine/pytorch/cpp_extensions/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ def cast_to_fp8(
otype
)
return None
if inp.nelement() > 0:
return torch.ops.tex_ts.cast_to_fp8_ts(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
)
return inp

return torch.ops.tex_ts.cast_to_fp8_ts(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor,
otype,
)

def cast_from_fp8(
inp: torch.Tensor,
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ at::Tensor cast_to_fp8(const at::Tensor &input,

auto output = at::empty_like(input, at::CUDA(GetATenDType(otype)));

if (input.numel() == 0)
return output;

auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), shape, otype,
amax.data_ptr(), scale.data_ptr(),
Expand Down

0 comments on commit c779207

Please sign in to comment.