Skip to content

Commit

Permalink
Fix error in linear backward.
Browse files Browse the repository at this point in the history
Signed-off-by: Dennis Liu <[email protected]>
  • Loading branch information
Victarry committed Apr 16, 2024
1 parent d79d661 commit 883178e
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def forward(
def backward(
ctx, grad_output: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
if isinstance(grad_output[0], Float8Tensor):
if isinstance(grad_output, Float8Tensor):
ctx.fp8_meta["scaling_bwd"].scale_inv[
tex.FP8BwdTensors.GRAD_OUTPUT1] = grad_output._scale_inv

Expand Down

0 comments on commit 883178e

Please sign in to comment.