[PyTorch] Adjusted the logic of MHA and DPA to enable speculative decoding #668

Merged: 14 commits, Mar 6, 2024
117 changes: 116 additions & 1 deletion tests/pytorch/test_numerics.py
@@ -22,7 +22,7 @@
)
from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
@@ -1397,3 +1397,118 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
y_bshd = block_bshd(x_bshd)

assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()])


model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
}
backends_inference = ["FlashAttention", "UnfusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"

if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"

config = model_configs_inference[model_key]

S = config.seq_len
B = bs
H = config.num_attention_heads
D = config.hidden_size
head_size = config.embed
layer_number = 1

# Limits the max size of KV-cache
B_max = B
S_max = S + 2

if module == "TransformerLayer":
model = (
TransformerLayer(
hidden_size=D,
ffn_hidden_size=4 * D,
num_attention_heads=H,
attn_input_format=input_format,
layer_number=layer_number,
attention_dropout=0.0
)
.to(dtype=dtype)
.cuda()
.eval()
)
else:
model = (
MultiheadAttention(
hidden_size=D,
num_attention_heads=H,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout=0.0
)
.to(dtype=dtype)
.cuda()
.eval()
)

inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")

input = torch.randn((S, B, D), dtype=dtype, device="cuda")
if input_format == "bshd":
input = input.transpose(0, 1).contiguous()

incremental_output = torch.zeros_like(input)

# Generate output for the entire sequence
full_output = model(
hidden_states=input,
rotary_pos_emb=rotary_freqs if use_RoPE else None)

# Incrementally generate outputs using the KV-cache
for i in range(S):
if input_format == "sbhd":
incremental_input = input[i].view(1,B,D)
else:
incremental_input = input[:, i, :].view(B,1,D)

line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None)

inference_params.sequence_len_offset += 1

if input_format == "sbhd":
incremental_output[i] = line_output.view(B,D)
else:
incremental_output[:, i, :] = line_output.view(B,D)

if module == "TransformerLayer":
atol = {
torch.float32 : 5e-3,
torch.half : 5e-3,
torch.bfloat16: 5e-2,
}
else:
atol = {
torch.float32 : 1e-3,
torch.half : 1e-3,
torch.bfloat16: 1e-2,
}

# Check if the fully generated output matches the one generated incrementally
assert_allclose(full_output, incremental_output, atol[dtype])
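For reference, the call pattern that test_kv_cache_accuracy exercises can be reduced to the short sketch below. It reuses only names that appear in this diff (MultiheadAttention, InferenceParams, sequence_len_offset); the shapes, dtype, and hyperparameters are illustrative, not a recommended configuration.

# Sketch: single-token incremental decoding with the KV-cache, mirroring the test above.
import torch
from transformer_engine.pytorch import MultiheadAttention, InferenceParams

B, S_max, D, H = 2, 64, 768, 12
attn = (
    MultiheadAttention(hidden_size=D, num_attention_heads=H, qkv_format="bshd",
                       layer_number=1, attention_dropout=0.0)
    .to(dtype=torch.bfloat16)
    .cuda()
    .eval()
)
inference_params = InferenceParams(max_batch_size=B, max_sequence_length=S_max)

hidden = torch.randn((B, S_max, D), dtype=torch.bfloat16, device="cuda")
steps = []
with torch.no_grad():
    for i in range(S_max):
        # Feed one new position per step; keys/values of earlier positions come from the cache.
        out = attn(hidden_states=hidden[:, i:i + 1, :],
                   inference_params=inference_params)
        inference_params.sequence_len_offset += 1  # advance only after the complete forward pass
        steps.append(out)
incremental = torch.cat(steps, dim=1)  # [B, S_max, D]; should match a single full-sequence call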
135 changes: 86 additions & 49 deletions transformer_engine/pytorch/attention.py
@@ -84,7 +84,6 @@

__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]


class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
@@ -1180,7 +1179,7 @@ def apply_rotary_pos_emb(
Parameters
----------
t: torch.Tensor
Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
rotary positional embedding will be applied.
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
@@ -2523,6 +2522,7 @@ def forward(
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
@@ -2616,6 +2616,16 @@
to the attention score of query i and key j.
fast_zero_fill: bool, default = `True`
Whether to use the fast path to set output tensors to 0 or not.
inference_params: Optional[InferenceParams], default = `None`
Optimizes execution performance during inference by caching the keys and values of the
current decoding iteration. The cached values are appended to the keys and values
computed in previous iterations, so they do not need to be recomputed for the
entire sequence.
`inference_params` must be initialized before use so that sufficient memory is
allocated for the cache.
`sequence_len_offset` should be incremented only after a complete forward pass.
If rotary positional embeddings (RoPE) are used, they must be prepared beforehand.
Supports the "sbhd" and "bshd" layouts, with "sbhd" being the more efficient one.
"""

assert (
@@ -2643,6 +2653,39 @@
if qkv_format is None:
qkv_format = self.qkv_format

if inference_params is not None:
assert self.layer_number is not None, "Layer number must be set!"

if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)

(inference_key_memory, inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]

batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)

sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)

# Copy keys and values into KV-cache
inference_key_memory[
sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
inference_value_memory[
sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)

key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
@@ -2721,12 +2764,15 @@ def forward(
use_flash_attention = False

# Filter: cross attention + causal mask.
# (in training mode)
if (inference_params is None
and _flash_attn_2_1_plus
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv
):
warnings.warn(
"Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
"for causal mask in cross attention. See "
"In training mode, disable the use of FlashAttention since version 2.1+ has "
"changed its behavior for causal mask in cross attention. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
@@ -2753,7 +2799,11 @@
if attn_mask_type == "arbitrary":
use_flash_attention = False
use_fused_attention = False
if "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:

if (inference_params is None
ptrendx marked this conversation as resolved.
Show resolved Hide resolved
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv
):
use_unfused_attention = False

# Filter: bias.
@@ -3446,12 +3496,12 @@ def forward(
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

# =================================================
# Pre-allocate memory for key-values for inference
# =================================================

if inference_params and self.layer_number is not None:
if self.layer_number not in inference_params.key_value_memory_dict:
inf_max_seq_len = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
@@ -3469,9 +3519,9 @@
inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]

# ======================
# Query, Key, and Value
# ======================

if self.attention_type == "self":
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
@@ -3593,51 +3643,37 @@ def forward(
)
query_layer = query_layer.view(*new_tensor_shape)

# ======================================================
# Apply relative positional encoding (rotary embedding)
# ======================================================

if rotary_pos_emb is not None:
# duplicate the pos_emb for self attention
if not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = ((rotary_pos_emb,) * 2)

q_pos_emb, k_pos_emb = rotary_pos_emb

# adjust key and value for inference
if inference_params is not None:
if self.qkv_format == "sbhd":
sequence_length = key_layer.size(0)
elif self.qkv_format == "bshd":
sequence_length = key_layer.size(1)

sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + sequence_length

q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)

# ===========================
# Core attention computation
# ===========================

context_layer = self.core_attention(
query_layer,
key_layer,
@@ -3653,11 +3689,12 @@
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=fast_zero_fill,
inference_params=inference_params,
)

# ===================
# Output. [sq, b, h]
# ===================

projection_output = self.proj(
context_layer, is_first_microbatch=is_first_microbatch
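For readers skimming the diff, the cache bookkeeping that DotProductAttention.forward now performs reduces to the indexing sketched below. This is a simplified standalone sketch rather than library code: the helper name append_step_to_kv_cache is ours, and the cache buffers are assumed to be pre-allocated in the [s, b, h, d] layout produced by _allocate_memory.

import torch

def append_step_to_kv_cache(key_layer, value_layer, key_cache, value_cache,
                            sequence_len_offset, batch_size_offset=0):
    # key_layer/value_layer: [s_step, b, h, d]; key_cache/value_cache: [s_max, b_max, h, d].
    seq_start = sequence_len_offset
    seq_end = seq_start + key_layer.size(0)
    batch_start = batch_size_offset
    batch_end = batch_start + key_layer.size(1)
    assert seq_end <= key_cache.size(0) and batch_end <= key_cache.size(1)

    # Write the current step's keys/values into the pre-allocated cache ...
    key_cache[seq_start:seq_end, batch_start:batch_end, ...] = key_layer
    value_cache[seq_start:seq_end, batch_start:batch_end, ...] = value_layer

    # ... and return the full prefix seen so far, so attention covers every cached position.
    return (key_cache[:seq_end, batch_start:batch_end, ...].contiguous(),
            value_cache[:seq_end, batch_start:batch_end, ...].contiguous())

The rotary table is sliced with the same offset (q_pos_emb[sequence_start:sequence_end, ...]), so each newly decoded position receives the embedding for its absolute index in the sequence.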