Skip to content

Commit

Permalink
Cleaned up the files per pylint
Browse files Browse the repository at this point in the history
Signed-off-by: Oleg Goncharov <[email protected]>
  • Loading branch information
Oleg-Goncharov committed Feb 26, 2024
1 parent 68d8f78 commit 4f8e7a9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2678,11 +2678,11 @@ def forward(
sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)

key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _get_default_causal_mask(sq: int, sk: int) -> torch.Tensor:
if matrix_shape not in _default_causal_mask:
diagonal_offset = sk - sq + 1
_default_causal_mask[matrix_shape] = torch.triu(
torch.ones(sq, sk, dtype=torch.bool, device="cuda"),
torch.ones(sq, sk, dtype=torch.bool, device="cuda"),
diagonal=diagonal_offset)
return _default_causal_mask[matrix_shape]

Expand Down

0 comments on commit 4f8e7a9

Please sign in to comment.