42 changes: 41 additions & 1 deletion QEfficient/transformers/cache_utils.py
@@ -696,9 +696,10 @@ def write_only(
else:
position_ids = cache_kwargs.get("position_ids")
is_sliding_layer = cache_kwargs.get("is_sliding")
sliding_window = cache_kwargs.get("sliding_window")
_, _, ctx_len, _ = self.key_cache[layer_idx].shape
if is_sliding_layer:
kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1)
kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
@@ -713,6 +714,45 @@ def write_only(
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
return k_out, v_out

def read_only_blockedKV(
self,
start_idx: torch.Tensor,
end_idx: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
position_ids = cache_kwargs.get("position_ids")
is_sliding_layer = cache_kwargs.get("is_sliding")
sliding_window = cache_kwargs.get("sliding_window")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value from the kwargs

k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Original Gather
if is_sliding_layer:
ctx_len = self.key_cache[layer_idx].shape[2]
else:
ctx_len = cache_kwargs.get("CCL", self.key_cache[layer_idx].shape[2])

ctx_indices = torch.arange(start=start_idx, end=end_idx)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
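# Positions beyond the query's last valid index get a sentinel: int32 max during ONNX export, 0 otherwise; their value rows are zeroed out below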
if torch.onnx.is_in_onnx_export():
invalid_idx_value = torch.iinfo(torch.int32).max
else:
invalid_idx_value = 0
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)

if batch_index is not None:
k_out = CtxGatherFuncBlockedKVCB.apply(k_out, batch_index, ctx_indices)
v_out = CtxGatherFuncBlockedKVCB.apply(v_out, batch_index, ctx_indices)
else:
k_out = CtxGatherFuncBlockedKV.apply(k_out, ctx_indices)
v_out = CtxGatherFuncBlockedKV.apply(v_out, ctx_indices)

v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out

def update(
self,
key_states: torch.Tensor,
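For intuition, the sliding-window write above boils down to a modulo mapping of absolute positions into cache slots; a minimal standalone sketch with a toy window size (independent of the cache classes):

import torch

sliding_window = 4  # toy window size, for illustration only
position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, -1]])  # -1 marks padding

# Same expression as the new write_only branch: keep -1 as-is, otherwise
# wrap the absolute position into a slot inside the sliding window.
kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
print(kv_position_ids)  # tensor([[ 0,  1,  2,  3,  0,  1,  2, -1]])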
179 changes: 163 additions & 16 deletions QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -6,7 +6,7 @@
# -----------------------------------------------------------------------------
import math
import os
from typing import Callable, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

import torch
from torch import nn
@@ -34,6 +34,7 @@
from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
from QEfficient.utils.logging_utils import logger


class QEffGptOssExperts(GptOssExperts):
@@ -671,6 +672,123 @@ def eager_attention_forward_blocked(
output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous()
return output, output

def eager_attention_forward_blockedKV(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
num_kv_blocks: Optional[int] = None,
skip_threshold: Optional[float] = None,
cache_kwargs: Optional[Dict[str, Any]] = None,
layer_idx: Optional[int] = None,
past_key_value: Optional[Cache] = None,
dropout: float = 0.0,
**kwargs,
):
# Initialize result tensor
output = torch.zeros_like(query)

# Initialize Running Maximum
batch_size, num_heads, seq_len, head_dim = query.shape
current_max = torch.full((batch_size, num_heads, seq_len), float(MIN_MASKED_ATTENTION_VALUE), dtype=query.dtype)

current_denominator = torch.zeros(batch_size, num_heads, seq_len, dtype=query.dtype)

past_seen_tokens = cache_kwargs.get("past_seen_tokens")
position_ids = cache_kwargs.get("position_ids")
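# Ceiling division: block_size = ceil(past_seen_tokens / num_kv_blocks), so the blocks together cover the whole cached context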
block_size = -(-past_seen_tokens // num_kv_blocks)
masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)

# Per-head sink logits broadcast to (batch, heads, q_len, 1)
sinks = module.sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1)

if skip_threshold:
threshold = math.log(skip_threshold / past_seen_tokens)

for j in range(num_kv_blocks):
start_index = j * block_size
end_index = (j + 1) * block_size

if not torch.onnx.is_in_onnx_export():
if (position_ids < start_index).all():
break # can skip future blocks since masked

# During ONNX export we cannot break out of the loop, so build a per-block condition for torch.where below instead
skip_condition = (position_ids < start_index).all()

K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs)
K_block_states = repeat_kv(K_block, module.num_key_value_groups)
V_block_states = repeat_kv(V_block, module.num_key_value_groups)

past_seen_tokens_start = start_index
past_seen_tokens_end = torch.where(
torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int),
past_seen_tokens,
end_index,
)

causal_mask_block = _create_causal_mask(
position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start
)

attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling

if attention_mask is not None:
attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block)

block_max = attn_weights_block.max(dim=-1, keepdim=True).values

prev_max = current_max
current_max = torch.maximum(prev_max, block_max.squeeze(-1))
delta_max = prev_max - current_max

# Skip condition: a block whose max score falls more than threshold below the running max is treated as negligible
if skip_threshold:
local_delta = current_max.unsqueeze(2) - block_max > threshold

attn_weights_block_stable = attn_weights_block - current_max.unsqueeze(-1)
kv_exp = torch.exp(attn_weights_block_stable)

prev_denominator = current_denominator
current_denominator = prev_denominator * torch.exp(delta_max) + kv_exp.sum(dim=-1)

block_contribution = torch.matmul(kv_exp, V_block_states)
prev_output = output
output_curr_step = prev_output * torch.exp(delta_max).unsqueeze(-1) + block_contribution

if torch.onnx.is_in_onnx_export():
output_curr_step = torch.where(skip_condition, output, output_curr_step)

if skip_threshold:
output = torch.where(local_delta, output, output_curr_step) # if current_max - block_max > threshold, keep the previous output, effectively skipping this block
else:
output = output_curr_step

# Apply attention sinks
prev_max = current_max
sink_max = sinks.max(dim=-1, keepdim=True).values.squeeze(-1)
current_max = torch.maximum(prev_max, sink_max)
delta_max = prev_max - current_max

sinks_stable = sinks - current_max.unsqueeze(-1)
sink_exp = torch.exp(sinks_stable)

prev_denominator = current_denominator
current_denominator = prev_denominator * torch.exp(delta_max) + sink_exp.sum(dim=-1)

output = output * torch.exp(delta_max).unsqueeze(-1)

output = output / current_denominator.unsqueeze(-1)

attn_output = output.transpose(1, 2).contiguous()
attn_weights = None

return attn_output, attn_weights

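# A minimal standalone sketch of the online-softmax recurrence used above (toy shapes, no causal
# mask, no sliding window, no block skipping), checked against a direct softmax with a sink term:
import torch

torch.manual_seed(0)
B, H, S, D, CTX, NUM_BLOCKS = 1, 2, 1, 8, 16, 4  # toy decode-step shapes, illustrative only
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, CTX, D)
v = torch.randn(B, H, CTX, D)
sinks = torch.randn(H)  # one sink logit per head (assumed shape)

scores = (q @ k.transpose(-1, -2)) / D**0.5  # (B, H, S, CTX)

# Reference: softmax over [scores, sink]; the sink adds to the denominator but has no value vector.
full = torch.cat([scores, sinks.view(1, H, 1, 1).expand(B, H, S, 1)], dim=-1)
ref = torch.softmax(full, dim=-1)[..., :CTX] @ v

# Blocked accumulation with a running max, denominator and output.
m = torch.full((B, H, S), -1e30)
d = torch.zeros(B, H, S)
o = torch.zeros(B, H, S, D)
block = CTX // NUM_BLOCKS
for j in range(NUM_BLOCKS):
    s_blk = scores[..., j * block : (j + 1) * block]
    m_new = torch.maximum(m, s_blk.max(dim=-1).values)
    scale = torch.exp(m - m_new)  # rescales the partial sums accumulated so far
    p = torch.exp(s_blk - m_new.unsqueeze(-1))
    d = d * scale + p.sum(dim=-1)
    o = o * scale.unsqueeze(-1) + p @ v[..., j * block : (j + 1) * block, :]
    m = m_new

# Fold in the attention sink, then normalize (mirrors the tail of the function above).
m_new = torch.maximum(m, sinks.view(1, H, 1).expand(B, H, S))
scale = torch.exp(m - m_new)
d = d * scale + torch.exp(sinks.view(1, H, 1) - m_new)
o = o * scale.unsqueeze(-1) / d.unsqueeze(-1)

assert torch.allclose(o, ref, atol=1e-5)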

def opt_eager_attention_forward_blocked(
module: nn.Module,
@@ -915,6 +1033,8 @@ def forward(
batch_index: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
sliding_mask=None,
num_kv_blocks: Optional[int] = None,
skip_threshold: Optional[float] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor, torch.Tensor]:
input_shape = hidden_states.shape[:-1]
@@ -929,27 +1049,50 @@ def forward(
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"batch_index": batch_index,
"position_ids": position_ids,
"config": self.config,
"is_sliding": self.sliding_window is not None,
"sliding_window": past_key_value.sliding_window_len,
}
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs["CCL"] = attention_mask.shape[-1]
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
if num_kv_blocks is not None:
past_seen_tokens = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
cache_kwargs = {
"sin": sin,
"cos": cos,
"batch_index": batch_index,
"position_ids": position_ids,
"past_seen_tokens": past_seen_tokens,
"config": self.config,
"is_sliding": self.sliding_window is not None,
"sliding_window": past_key_value.sliding_window_len,
}
past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs)
else:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {
"sin": sin,
"cos": cos,
"batch_index": batch_index,
"position_ids": position_ids,
"config": self.config,
"is_sliding": self.sliding_window is not None,
"sliding_window": past_key_value.sliding_window_len,
}
if comp_ctx_lengths is not None:
attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs["CCL"] = attention_mask.shape[-1]
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)


if self.sliding_window is not None:
attention_mask = sliding_mask
else:
attention_mask = attention_mask

attention_interface: Callable = eager_attention_forward
if num_kv_blocks is not None:
attention_interface = eager_attention_forward_blockedKV
if query_states.shape[2] != 1: # prefill (multi-token query): block skipping only applies to single-token decode
skip_threshold = None

kwargs["cache_kwargs"] = cache_kwargs

attn_output, attn_weights = attention_interface(
self,
query_states,
Expand All @@ -958,8 +1101,12 @@ def forward(
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
num_kv_blocks=num_kv_blocks,
skip_threshold=skip_threshold,
sliding_window=self.sliding_window,
s_aux=self.sinks, # diff with Llama
past_key_value=past_key_value,
layer_idx=self.layer_idx,
**kwargs,
)

6 changes: 6 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
@@ -53,6 +53,7 @@
RevertPrefillKeepAttentionTransform,
RevertPrefillOnlyTransform,
SamplerTransform,
SkipSoftmaxTransform,
SpDTransform,
VlmKVOffloadTransform,
VlmNoKVOffloadTransform,
@@ -2405,6 +2406,11 @@ def __init__(

if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))
if self.model.qaic_config is not None and self.model.qaic_config.get("skip_threshold", None) is not None:
if self.model.qaic_config.get("num_kv_blocks", None) is None:
warnings.warn(f"Skip Softmax require KV blocking be enabled by passing num_kv_blocks")
else:
SkipSoftmaxTransform.apply(self.model, skip_threshold=self.model.qaic_config.get("skip_threshold"))

def __repr__(self) -> str:
return self.__class__.__name__ + "\n" + self.model.__repr__()
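For context, both transforms are driven purely by qaic_config, mirroring the tests further down; a hypothetical invocation (model name and block count chosen for illustration, and assuming qaic_config is forwarded through from_pretrained as in the existing tests) might look like:

from QEfficient import QEFFAutoModelForCausalLM

# skip_threshold only takes effect when num_kv_blocks is also set; otherwise the
# warning above fires and SkipSoftmaxTransform is not applied.
model = QEFFAutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",
    qaic_config={"num_kv_blocks": 4, "skip_threshold": 500},
)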
19 changes: 19 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -923,6 +923,7 @@ class BlockedKVAttentionTransform:
_module_mapping = {
QEffLlamaAttention,
QEffQwen2_5_VLAttention,
QEffGptOssAttention,
}

@classmethod
@@ -937,3 +938,21 @@ def apply(cls, model: nn.Module, num_kv_blocks) -> Tuple[nn.Module, bool]:
elif module.__class__.__name__.endswith("Attention") and type(module) not in cls._module_mapping:
warnings.warn(f"KV blocking is not yet supported for {type(module)}.")
return model, transformed

class SkipSoftmaxTransform:
_module_mapping = {
QEffGptOssAttention,
}

@classmethod
def apply(cls, model: nn.Module, skip_threshold) -> Tuple[nn.Module, bool]:
transformed = False
for module in model.modules():
if type(module) in cls._module_mapping:
repl_module = type(module)
module.__class__ = repl_module
module.forward = MethodType(partial(repl_module.forward, skip_threshold=skip_threshold), module)
transformed = True # Set to True if at least one transformation occurs
elif module.__class__.__name__.endswith("Attention") and type(module) not in cls._module_mapping:
warnings.warn(f"Skip Softmax is not yet supported for {type(module)}.")
return model, transformed
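The rebinding used by SkipSoftmaxTransform (and BlockedKVAttentionTransform above it) is plain functools.partial plus types.MethodType; a standalone sketch on a toy module:

from functools import partial
from types import MethodType

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    def forward(self, x, skip_threshold=None):
        # Stand-in behaviour; the real modules dispatch to the blocked attention path instead.
        return x if skip_threshold is None else x * 0


toy = ToyAttention()
# Pre-fill skip_threshold and re-bind the method to the instance, as apply() does per matching module.
toy.forward = MethodType(partial(ToyAttention.forward, skip_threshold=500), toy)
print(toy.forward(torch.ones(2)))  # tensor([0., 0.]) -> the pre-filled kwarg is in effect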
1 change: 1 addition & 0 deletions QEfficient/utils/constants.py
@@ -209,6 +209,7 @@ class Constants:
SDK_PLATFORM_XML = (
"/opt/qti-aic/versions/platform.xml" # This xml file is parsed to find out the SDK platform version.
)
SKIP_THRESHOLD = 500


@dataclass
23 changes: 22 additions & 1 deletion tests/transformers/models/test_causal_lm_models.py
@@ -70,12 +70,19 @@
test_models_blockedKV = [
# "meta-llama/Llama-3.3-70B-Instruct",
"meta-llama/Llama-3.2-1B",
"openai/gpt-oss-20b",
]

test_models_skipSoftmax = [
"openai/gpt-oss-20b",
]




def get_custom_n_layers(model_name):
"""
Function to set number layers of the variuos types of models such as swiftkv models and others
Function to set number layers of the various types of models such as swiftkv models and others
--------

:model_name: str
@@ -193,6 +200,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
"Tokens don't match for HF PyTorch model output and KV PyTorch model output"
)

onnx_model_path = qeff_model.export()
ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
gen_len = ort_tokens.shape[-1]
@@ -513,6 +521,19 @@ def test_causal_blockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
qaic_config = dict(num_kv_blocks=Constants.NUM_KV_BLOCKS)
check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config)

@pytest.mark.on_qaic
@pytest.mark.parametrize("model_name", test_models_skipSoftmax)
def test_causal_skipSoftmax_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
"""
Test function to validate the PyTorch model with skip softmax on top of KV blocking, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
``Mandatory`` Args:
:model_name (str): Hugging Face Model Card name, Example: ``gpt2``
"""
n_layer = get_custom_n_layers(model_name)

qaic_config = dict(num_kv_blocks=Constants.NUM_KV_BLOCKS, skip_threshold=Constants.SKIP_THRESHOLD)
check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config)


@pytest.mark.on_qaic
@pytest.mark.parametrize("model_name", test_models_blockedKV)
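As with the existing blocked-KV test, the new skip-softmax test can be selected with the usual pytest filters, e.g. pytest tests/transformers/models/test_causal_lm_models.py -m on_qaic -k skipSoftmax, assuming a Cloud AI 100 device is available for the on_qaic-marked parts.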