[PyTorch] Adjusted the logic of MHA and DPA to enable speculative decoding #668

Merged · 14 commits · Mar 6, 2024
Changes from 8 commits
70 changes: 69 additions & 1 deletion tests/pytorch/test_numerics.py
@@ -21,7 +21,7 @@
)
from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm
MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
@@ -1362,3 +1362,71 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
y_bshd = block_bshd(x_bshd)

assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()])



model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(8, 1e-5, 1, 8, 12, 64),
}

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("use_flash_attn", all_boolean)
def test_te_layer_kv_cache_accuracy(dtype, bs, model, use_RoPE, use_flash_attn):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0" # suppress for testing

if use_flash_attn:
os.environ["NVTE_FLASH_ATTN"] = "1"

config = model_configs_inference[model]

S = config.seq_len
B = bs
H = config.num_attention_heads
D = config.hidden_size
head_size = config.embed
layer_number = 1

# Limits the max size of KV-cache
B_max = B
S_max = 2 * S

TE_layer = (
TransformerLayer(
hidden_size=D,
ffn_hidden_size=4 * D,
num_attention_heads=H,
attn_input_format='sbhd',
layer_number=layer_number,
)
.to(dtype=dtype)
.cuda()
)

inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=dtype, device="cuda")

input = torch.randn((S, B, D), dtype=dtype, device="cuda")
incremental_output = [None] * S

# Generate output for the entire sequence
full_output = TE_layer(
hidden_states=input,
rotary_pos_emb=rotary_freqs if use_RoPE else None)

# Incrementally generate outputs using the KV-cache
for i in range(S):
incremental_output[i] = TE_layer(
hidden_states=input[i].view(1,B,D),
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None)
# Advance the cache offset so the next token is appended after this one
inference_params.sequence_len_offset += 1

atol = {
torch.float32 : 2e-2,
torch.half : 2e-2,
torch.bfloat16: 5e-2,
}
# Check if the fully generated output matches the one generated incrementally
assert_allclose(full_output, incremental_output, atol[dtype])
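
The test above advances the KV-cache one token per step. Speculative decoding only differs in that several draft tokens are verified in a single forward pass, which the same cache path supports because whole chunks of keys and values are written at the current offset. The following is a minimal illustrative sketch, not part of this PR, using the `TransformerLayer` and `InferenceParams` APIs exercised by the test; sizes are arbitrary and the draft tokens are stand-in random tensors rather than the output of a real draft model.

```python
import torch
from transformer_engine.pytorch import TransformerLayer, InferenceParams

# Illustrative sizes only
B, D, H, S_max = 2, 1024, 16, 256

layer = (
    TransformerLayer(
        hidden_size=D,
        ffn_hidden_size=4 * D,
        num_attention_heads=H,
        attn_input_format="sbhd",
        layer_number=1,
    )
    .to(dtype=torch.bfloat16)
    .cuda()
)

params = InferenceParams(max_batch_size=B, max_sequence_length=S_max)

# Prefill: the prompt's keys/values are written into the cache at offset 0
prompt = torch.randn(32, B, D, dtype=torch.bfloat16, device="cuda")
_ = layer(prompt, inference_params=params)
params.sequence_len_offset += prompt.size(0)

# Speculative step: verify a chunk of draft tokens in one forward pass
draft = torch.randn(4, B, D, dtype=torch.bfloat16, device="cuda")
out = layer(draft, inference_params=params)

# Advance the offset only by the number of accepted tokens; cache rows past
# the offset are simply overwritten by the next call
accepted = 3
params.sequence_len_offset += accepted
```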
67 changes: 43 additions & 24 deletions transformer_engine/pytorch/attention.py
@@ -1180,7 +1180,7 @@ def apply_rotary_pos_emb(
Parameters
----------
t: torch.Tensor
Input tensor of shape `[s, b, h, d]`, `[s, b, h, d]` or `[t, h, d]`, on which
Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
rotary positional embedding will be applied.
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
@@ -2523,6 +2523,7 @@ def forward(
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
@@ -2616,6 +2617,8 @@ def forward(
to the attention score of query i and key j.
fast_zero_fill: bool, default = `True`
Whether to use the fast path to set output tensors to 0 or not.
inference_params: Optional[InferenceParams], default = `None`
Inference context holding the KV-cache used during incremental decoding.
"""

assert (
@@ -2643,6 +2646,28 @@ def forward(
if qkv_format is None:
qkv_format = self.qkv_format

if inference_params is not None:
assert qkv_format == "sbhd"
assert self.layer_number is not None
(inference_key_memory, inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]

batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)

sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)

# Copy keys and values into KV-cache
inference_key_memory[
sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
inference_value_memory[
sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
@@ -2721,12 +2746,15 @@ def forward(
use_flash_attention = False

# Filter: cross attention + causal mask.
if (_flash_attn_2_1_plus
# (in training mode)
if (inference_params is None
and _flash_attn_2_1_plus
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv):
and max_seqlen_q != max_seqlen_kv
):
warnings.warn(
"Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
"for causal mask in cross attention. See "
"In training mode, disable the use of FlashAttention since version 2.1+ has "
"changed its behavior for causal mask in cross attention. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
@@ -2753,7 +2781,11 @@ def forward(
if attn_mask_type == "arbitrary":
use_flash_attention = False
use_fused_attention = False
if "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:

if (inference_params is None
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv
):
use_unfused_attention = False

# Filter: bias.
@@ -3441,7 +3473,7 @@ def forward(

if inference_params and self.layer_number is not None:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_len
inf_max_seq_len = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
@@ -3593,29 +3625,15 @@ def forward(
rotary_pos_emb = ((rotary_pos_emb,) * 2)

if inference_params and self.layer_number is not None:
batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)
# Copy key and values.
inference_key_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = key_layer
inference_value_memory[
sequence_start:sequence_end, batch_start:batch_end, ...
] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[
:sequence_end, batch_start:batch_end, ...
]

# adjust the key rotary positional embedding
if rotary_pos_emb is not None:
sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
delta = sequence_end - sequence_start
q_pos_emb, k_pos_emb = rotary_pos_emb
q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
q_pos_emb = q_pos_emb[:delta, :, :, :]
k_pos_emb = k_pos_emb[:delta, :, :, :]
rotary_pos_emb = (q_pos_emb, k_pos_emb)

# ==================================
@@ -3643,6 +3661,7 @@ def forward(
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=fast_zero_fill,
inference_params=inference_params,
)

# =================
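
Because `DotProductAttention.forward` now accepts `inference_params` directly, the cache can also be exercised below the `MultiheadAttention` level. The sketch below is a heavily hedged illustration: the constructor arguments, cache shapes, and dtype are my assumptions based on the diff, and the per-layer cache that `MultiheadAttention` would normally allocate is created by hand.

```python
import torch
from transformer_engine.pytorch import DotProductAttention, InferenceParams

H, D_head, B, S_max = 16, 64, 2, 128
layer_number = 1

dpa = DotProductAttention(
    num_attention_heads=H,
    kv_channels=D_head,
    layer_number=layer_number,  # used as the key into key_value_memory_dict
    qkv_format="sbhd",          # the KV-cache path asserts this layout
)

params = InferenceParams(max_batch_size=B, max_sequence_length=S_max)

# Hand-rolled per-layer cache of shape [S_max, B_max, heads, head_dim]
key_cache = torch.zeros(S_max, B, H, D_head, dtype=torch.bfloat16, device="cuda")
value_cache = torch.zeros_like(key_cache)
params.key_value_memory_dict[layer_number] = (key_cache, value_cache)

q = torch.randn(4, B, H, D_head, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = dpa(q, k, v, inference_params=params)  # keys/values appended at offset 0
```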
75 changes: 40 additions & 35 deletions transformer_engine/pytorch/softmax.py
@@ -20,11 +20,17 @@

_default_causal_mask = {}

def _get_default_causal_mask(sq: int) -> torch.Tensor:
def _get_default_causal_mask(sq: int, sk: int) -> torch.Tensor:
"""Return the causal upper triangular mask for softmax input"""
if sq not in _default_causal_mask:
_default_causal_mask[sq] = torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()
return _default_causal_mask[sq]
if sq == 1:
return torch.zeros((1, sk), dtype=torch.bool, device="cuda")

matrix_shape = (sq, sk)
if matrix_shape not in _default_causal_mask:
diagonal_offset = sk - sq + 1
_default_causal_mask[matrix_shape] = torch.triu(torch.ones(sq, sk, device="cuda"),
diagonal=diagonal_offset).bool()
return _default_causal_mask[matrix_shape]
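
For rectangular score matrices the mask is now aligned to the bottom-right corner, which is what KV-caching needs: the last query row corresponds to the last (newest) key position. A quick illustration of the `diagonal_offset` logic, with arbitrary sizes:

```python
import torch

sq, sk = 2, 5                      # 2 new queries attending to 5 cached key positions
offset = sk - sq + 1               # = 4
mask = torch.triu(torch.ones(sq, sk), diagonal=offset).bool()
# mask (True = masked out):
# [[False, False, False, False,  True],   # query at position 3 cannot see key 4
#  [False, False, False, False, False]]   # query at position 4 sees all keys
```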


def _get_onnx_export_causal_mask(
@@ -334,47 +340,46 @@ def is_kernel_available(self, mask: torch.Tensor, b: int, np: int, sq: int, sk:
attn_batches = b * np

if ( # pylint: disable=too-many-boolean-expressions
self.scaled_masked_softmax_fusion # user wants to fuse
and self.input_in_float16 # input must be fp16
and 16 <= sk <= 16384 # sk must be 16 ~ 16384
and sk % 8 == 0 # sk must be divisor of 8
and sq % 4 == 0 # sq must be divisor of 4
and attn_batches % 4 == 0 # np * b must be divisor of 4
not self.scaled_masked_softmax_fusion # user doesn't want to fuse
or not self.input_in_float16 # input must be fp16
or sk < 16
or sk > 16384 # sk must be within [16, 16384]
or sk % 8 != 0 # sk must be divisible by 8
or self.attn_mask_type == "arbitrary" # Custom masks not supported
):
return False

if self.attn_mask_type == "causal": # unfused causal softmax kernel
return True

if (sq % 4 == 0 # sq must be divisible by 4
and attn_batches % 4 == 0 # np * b must be divisible by 4
and self.attn_mask_type != "arbitrary" # Custom masks not supported
):
if 0 <= sk <= 16384:
batch_per_block = self.get_batch_per_block(int(sk))

if self.attn_mask_type == "causal":
if attn_batches % batch_per_block == 0:
return True
elif self.attn_mask_type == "padding":
if (
mask is not None
and sq % batch_per_block == 0
and mask.shape[-2] == sq
and mask.shape[-1] == sk
):
return True
else:
if sq % batch_per_block == 0:
return True
batch_per_block = self.get_batch_per_block(int(sk))

if self.attn_mask_type == "padding":
if (
mask is not None
and sq % batch_per_block == 0
and mask.shape[-2] == sq
and mask.shape[-1] == sk
):
return True
else:
if sq % batch_per_block == 0:
return True
return False

def forward_fused_softmax(
self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
) -> torch.Tensor:
"""Fused masked softmax kernel"""
b, np, sq, sk = inp.size()
scale = 1.0 if scale is None else scale

if self.attn_mask_type == "causal":
assert sq == sk, "causal mask is only for self attention"
return ScaledAlignedCausalMaskedSoftmax.apply(inp, scale)

# input is 3D tensor (attn_batches, sq, sk)
inp = inp.view(-1, sq, sk)
probs = ScaledUpperTriangMaskedSoftmax.apply(inp, scale)
return probs.view(b, np, sq, sk)
# input is 4D tensor (b, np, sq, sk)
if mask is not None and self.attn_mask_type != "no_mask":
return ScaledMaskedSoftmax.apply(inp, mask, scale)
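
`ScaledAlignedCausalMaskedSoftmax` itself is not part of this diff, so the following eager-mode reference is only my reading of what it should compute (scale, bottom-right-aligned causal mask, softmax); it can serve as a sanity check against the fused kernel.

```python
import torch

def aligned_causal_softmax_ref(inp: torch.Tensor, scale: float) -> torch.Tensor:
    """inp: [b, np, sq, sk] attention scores; the causal mask is aligned so that
    the last query row corresponds to the last key position."""
    sq, sk = inp.shape[-2], inp.shape[-1]
    mask = torch.triu(
        torch.ones(sq, sk, dtype=torch.bool, device=inp.device),
        diagonal=sk - sq + 1,
    )
    scores = (inp * scale).masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1)
```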
@@ -391,12 +396,12 @@ def forward_torch_softmax(
inp = inp * scale

if self.attn_mask_type == "causal":
seq_len_q, seq_len_k = inp.size(2), inp.size(3)
if is_in_onnx_export_mode() and self.kvcache_max_seq > 0:
seq_len_q, seq_len_k = inp.size(2), inp.size(3)
assert self.kvcache_max_seq >= seq_len_k
mask = _get_onnx_export_causal_mask(seq_len_q, seq_len_k, self.onnx_causal_mask)
else:
mask = _get_default_causal_mask(inp.size(2))
mask = _get_default_causal_mask(seq_len_q, seq_len_k)

mask_output = inp
if mask is not None and self.attn_mask_type != "no_mask":