diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index faadaba6b..2f2639344 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -696,9 +696,10 @@ def write_only(
         else:
             position_ids = cache_kwargs.get("position_ids")
             is_sliding_layer = cache_kwargs.get("is_sliding")
+            sliding_window = cache_kwargs.get("sliding_window")
             _, _, ctx_len, _ = self.key_cache[layer_idx].shape
             if is_sliding_layer:
-                kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1)
+                kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
             self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
             self.value_cache[layer_idx] = CtxScatterFunc.apply(
                 self.value_cache[layer_idx], kv_position_ids, value_states
@@ -713,6 +714,45 @@ def write_only(
         k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
         return k_out, v_out
 
+    def read_only_blockedKV(
+        self,
+        start_idx: torch.Tensor,
+        end_idx: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        position_ids = cache_kwargs.get("position_ids")
+        is_sliding_layer = cache_kwargs.get("is_sliding")
+        sliding_window = cache_kwargs.get("sliding_window")
+        batch_index = cache_kwargs.get("batch_index", None)  # fetch batch index, if any, from the kwargs
+
+        k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+        # Context length, mirroring read_only: full cache for sliding layers, CCL-capped otherwise
+        if is_sliding_layer:
+            ctx_len = self.key_cache[layer_idx].shape[2]
+        else:
+            ctx_len = cache_kwargs.get("CCL", self.key_cache[layer_idx].shape[2])
+
+        # Gather only the [start_idx, end_idx) block of the cache
+        ctx_indices = torch.arange(start=start_idx, end=end_idx)[None, None, ...]
+        gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
+        invalid_mask = ctx_indices > gather_limit
+        # Mark invalid positions: INT32_MAX during ONNX export, index 0 otherwise
+        if torch.onnx.is_in_onnx_export():
+            invalid_idx_value = torch.iinfo(torch.int32).max
+        else:
+            invalid_idx_value = 0
+        ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
+
+        if batch_index is not None:
+            k_out = CtxGatherFuncBlockedKVCB.apply(k_out, batch_index, ctx_indices)
+            v_out = CtxGatherFuncBlockedKVCB.apply(v_out, batch_index, ctx_indices)
+        else:
+            k_out = CtxGatherFuncBlockedKV.apply(k_out, ctx_indices)
+            v_out = CtxGatherFuncBlockedKV.apply(v_out, ctx_indices)
+
+        # Zero out values gathered for invalid (future) positions
+        v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
+        return k_out, v_out
+
     def update(
         self,
         key_states: torch.Tensor,
diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
index 3efe890b8..ecce68488 100644
--- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -6,7 +6,7 @@
 # -----------------------------------------------------------------------------
 import math
 import os
-from typing import Callable, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union
 
 import torch
 from torch import nn
@@ -671,6 +672,123 @@ def eager_attention_forward_blocked(
     output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous()
     return output, output
 
+
+def eager_attention_forward_blockedKV(
+    module: nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    scaling: float,
+    num_kv_blocks: Optional[int] = None,
+    skip_threshold: Optional[float] = None,
+    cache_kwargs: Optional[Dict[str, Any]] = None,
+    layer_idx: Optional[int] = None,
+    past_key_value: Optional[Cache] = None,
+    dropout: float = 0.0,
+    **kwargs,
+):
+    # Result accumulator
+    output = torch.zeros_like(query)
+
+    # Running maximum and denominator for the online (streaming) softmax
+    batch_size, num_heads, seq_len, head_dim = query.shape
+    current_max = torch.full((batch_size, num_heads, seq_len), float(MIN_MASKED_ATTENTION_VALUE), dtype=query.dtype)
+    current_denominator = torch.zeros(batch_size, num_heads, seq_len, dtype=query.dtype)
+
+    past_seen_tokens = cache_kwargs.get("past_seen_tokens")
+    position_ids = cache_kwargs.get("position_ids")
+    block_size = -(-past_seen_tokens // num_kv_blocks)  # ceil division
+    masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
+
+    sinks = module.sinks.reshape(1, -1, 1, 1).expand(batch_size, -1, seq_len, -1)
+
+    if skip_threshold:
+        threshold = math.log(skip_threshold / past_seen_tokens)
+
+    for j in range(num_kv_blocks):
+        start_index = j * block_size
+        end_index = (j + 1) * block_size
+
+        if not torch.onnx.is_in_onnx_export():
+            if (position_ids < start_index).all():
+                break  # every later block is fully masked, so stop early
+
+        # During ONNX export we cannot break out of the loop, so express the
+        # same early-exit condition as a tensor and select with torch.where.
+        skip_condition = (position_ids < start_index).all()
+
+        K_block, V_block = past_key_value.read_only_blockedKV(start_index, end_index, layer_idx, cache_kwargs)
+        K_block_states = repeat_kv(K_block, module.num_key_value_groups)
+        V_block_states = repeat_kv(V_block, module.num_key_value_groups)
+
+        past_seen_tokens_start = start_index
+        past_seen_tokens_end = torch.where(
+            torch.tensor(past_seen_tokens, dtype=torch.int) < torch.tensor(end_index, dtype=torch.int),
+            past_seen_tokens,
+            end_index,
+        )
+
+        causal_mask_block = _create_causal_mask(
+            position_ids=position_ids, target_length=past_seen_tokens_end, start_index=past_seen_tokens_start
+        )
+
+        attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling
+
+        if attention_mask is not None:
+            attn_weights_block = torch.where(causal_mask_block, masked_tensor, attn_weights_block)
+
+        block_max = attn_weights_block.max(dim=-1, keepdim=True).values
+
+        prev_max = current_max
+        current_max = torch.maximum(prev_max, block_max.squeeze(-1))
+        delta_max = prev_max - current_max
+
+        # Skip condition: the block is negligible when its local max falls more
+        # than `threshold` below the running max
+        if skip_threshold:
+            local_delta = current_max.unsqueeze(-1) - block_max > threshold
+
+        attn_weights_block_stable = attn_weights_block - current_max.unsqueeze(-1)
+        kv_exp = torch.exp(attn_weights_block_stable)
+
+        prev_denominator = current_denominator
+        current_denominator = prev_denominator * torch.exp(delta_max) + kv_exp.sum(dim=-1)
+
+        block_contribution = torch.matmul(kv_exp, V_block_states)
+        prev_output = output
+        output_curr_step = prev_output * torch.exp(delta_max).unsqueeze(-1) + block_contribution
+
+        if torch.onnx.is_in_onnx_export():
+            output_curr_step = torch.where(skip_condition, output, output_curr_step)
+
+        if skip_threshold:
+            # Keep the previous output where the block is negligible, effectively skipping it
+            output = torch.where(local_delta, output, output_curr_step)
+        else:
+            output = output_curr_step
+
+    # Fold the attention sinks into the softmax as extra, value-free logits
+    prev_max = current_max
+    sink_max = sinks.max(dim=-1, keepdim=True).values.squeeze(-1)
+    current_max = torch.maximum(prev_max, sink_max)
+    delta_max = prev_max - current_max
+
+    sinks_stable = sinks - current_max.unsqueeze(-1)
+    sink_exp = torch.exp(sinks_stable)
+
+    prev_denominator = current_denominator
+    current_denominator = prev_denominator * torch.exp(delta_max) + sink_exp.sum(dim=-1)
+
+    output = output * torch.exp(delta_max).unsqueeze(-1)
+
+    output = output / current_denominator.unsqueeze(-1)
+
+    attn_output = output.transpose(1, 2).contiguous()
+    attn_weights = None
+
+    return attn_output, attn_weights
+
 
 def opt_eager_attention_forward_blocked(
     module: nn.Module,
@@ -915,6 +1033,8 @@ def forward(
         batch_index: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         sliding_mask=None,
+        num_kv_blocks: Optional[int] = None,
+        skip_threshold: Optional[float] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, torch.Tensor]:
         input_shape = hidden_states.shape[:-1]
@@ -929,27 +1049,50 @@ def forward(
         query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            # sin and cos are specific to RoPE models; cache_position needed for the static cache
-            cache_kwargs = {
-                "sin": sin,
-                "cos": cos,
-                "batch_index": batch_index,
-                "position_ids": position_ids,
-                "config": self.config,
-                "is_sliding": self.sliding_window is not None,
-                "sliding_window": past_key_value.sliding_window_len,
-            }
-            if comp_ctx_lengths is not None:
-                attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
-                cache_kwargs["CCL"] = attention_mask.shape[-1]
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            if num_kv_blocks is not None:
+                past_seen_tokens = past_key_value.get_seq_length(self.layer_idx)
+                cache_kwargs = {
+                    "sin": sin,
+                    "cos": cos,
+                    "batch_index": batch_index,
+                    "position_ids": position_ids,
+                    "past_seen_tokens": past_seen_tokens,
+                    "config": self.config,
+                    "is_sliding": self.sliding_window is not None,
+                    "sliding_window": past_key_value.sliding_window_len,
+                }
+                past_key_value.write_only(key_states, value_states, self.layer_idx, cache_kwargs)
+            else:
+                # sin and cos are specific to RoPE models; cache_position needed for the static cache
+                cache_kwargs = {
+                    "sin": sin,
+                    "cos": cos,
+                    "batch_index": batch_index,
+                    "position_ids": position_ids,
+                    "config": self.config,
+                    "is_sliding": self.sliding_window is not None,
+                    "sliding_window": past_key_value.sliding_window_len,
+                }
+                if comp_ctx_lengths is not None:
+                    attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
+                    cache_kwargs["CCL"] = attention_mask.shape[-1]
+                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
         if self.sliding_window is not None:
             attention_mask = sliding_mask
         else:
             attention_mask = attention_mask
 
-        attention_interface: Callable = eager_attention_forward
+        if num_kv_blocks is not None:
+            attention_interface: Callable = eager_attention_forward_blockedKV
+            if query_states.shape[2] != 1:  # never skip blocks during prefill
+                skip_threshold = None
+        else:
+            attention_interface: Callable = eager_attention_forward
+
+        kwargs["cache_kwargs"] = cache_kwargs
 
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
@@ -958,8 +1101,12 @@ def forward(
             attention_mask,
             dropout=0.0 if not self.training else self.attention_dropout,
             scaling=self.scaling,
+            num_kv_blocks=num_kv_blocks,
+            skip_threshold=skip_threshold,
             sliding_window=self.sliding_window,
             s_aux=self.sinks,  # diff with Llama
+            past_key_value=past_key_value,
+            layer_idx=self.layer_idx,
             **kwargs,
         )
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index d2cc1e681..de95c6037 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -53,6 +53,7 @@
     RevertPrefillKeepAttentionTransform,
     RevertPrefillOnlyTransform,
     SamplerTransform,
+    SkipSoftmaxTransform,
     SpDTransform,
     VlmKVOffloadTransform,
     VlmNoKVOffloadTransform,
@@ -2405,6 +2406,11 @@ def __init__(
 
         if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
             BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))
+        if self.model.qaic_config is not None and self.model.qaic_config.get("skip_threshold", None) is not None:
+            if self.model.qaic_config.get("num_kv_blocks", None) is None:
+                warnings.warn("Skip softmax requires KV blocking to be enabled by passing num_kv_blocks in qaic_config.")
+            else:
+                SkipSoftmaxTransform.apply(self.model, skip_threshold=self.model.qaic_config.get("skip_threshold"))
 
     def __repr__(self) -> str:
         return self.__class__.__name__ + "\n" + self.model.__repr__()
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index b978b6193..f1336710b 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -923,6 +923,7 @@ class BlockedKVAttentionTransform:
     _module_mapping = {
         QEffLlamaAttention,
         QEffQwen2_5_VLAttention,
+        QEffGptOssAttention,
     }
 
     @classmethod
@@ -937,3 +938,21 @@ def apply(cls, model: nn.Module, num_kv_blocks) -> Tuple[nn.Module, bool]:
             elif module.__class__.__name__.endswith("Attention") and type(module) not in cls._module_mapping:
                 warnings.warn(f"KV blocking is not yet supported for {type(module)}.")
         return model, transformed
+
+
+class SkipSoftmaxTransform:
+    _module_mapping = {
+        QEffGptOssAttention,
+    }
+
+    @classmethod
+    def apply(cls, model: nn.Module, skip_threshold) -> Tuple[nn.Module, bool]:
+        transformed = False
+        for module in model.modules():
+            if type(module) in cls._module_mapping:
+                repl_module = type(module)
+                module.__class__ = repl_module
+                module.forward = MethodType(partial(repl_module.forward, skip_threshold=skip_threshold), module)
+                transformed = True  # at least one module was transformed
+            elif module.__class__.__name__.endswith("Attention") and type(module) not in cls._module_mapping:
+                warnings.warn(f"Skip softmax is not yet supported for {type(module)}.")
+        return model, transformed
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index 854c1134a..ab40de256 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -209,6 +209,7 @@ class Constants:
     SDK_PLATFORM_XML = (
         "/opt/qti-aic/versions/platform.xml"  # This xml file is parsed to find out the SDK platform version.
     )
+    SKIP_THRESHOLD = 500
 
 
 @dataclass
diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py
index ead636759..2e87722e0 100644
--- a/tests/transformers/models/test_causal_lm_models.py
+++ b/tests/transformers/models/test_causal_lm_models.py
@@ -70,12 +70,19 @@
 test_models_blockedKV = [
     # "meta-llama/Llama-3.3-70B-Instruct",
     "meta-llama/Llama-3.2-1B",
+    "openai/gpt-oss-20b",
 ]
 
+test_models_skipSoftmax = [
+    "openai/gpt-oss-20b",
+]
+
 
 def get_custom_n_layers(model_name):
     """
-    Function to set number layers of the variuos types of models such as swiftkv models and others
+    Function to set the number of layers for the various types of models, such as swiftkv models and others
     --------
 
     :model_name: str
@@ -193,6 +200,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), (
         "Tokens don't match for HF PyTorch model output and KV PyTorch model output"
     )
+
     onnx_model_path = qeff_model.export()
     ort_tokens = api_runner.run_kv_model_on_ort(onnx_model_path, is_tlm=is_tlm)
     gen_len = ort_tokens.shape[-1]
@@ -513,6 +521,19 @@ def test_causal_blockedKV_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
     qaic_config = dict(num_kv_blocks=Constants.NUM_KV_BLOCKS)
     check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config)
 
+
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_name", test_models_skipSoftmax)
+def test_causal_skipSoftmax_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
+    """
+    Test function to validate softmax skipping on top of KV blocking: the original PyTorch model,
+    the PyTorch model after the KV changes, the ONNX model, and the Cloud AI 100 model,
+    both with and without continuous batching.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+    """
+    n_layer = get_custom_n_layers(model_name)
+
+    qaic_config = dict(num_kv_blocks=Constants.NUM_KV_BLOCKS, skip_threshold=Constants.SKIP_THRESHOLD)
+    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name=model_name, n_layer=n_layer, qaic_config=qaic_config)
+
 
 @pytest.mark.on_qaic
 @pytest.mark.parametrize("model_name", test_models_blockedKV)
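
For reviewers: the core of `eager_attention_forward_blockedKV` is a FlashAttention-style online softmax, where each KV block updates a running max and denominator and previously accumulated partial outputs are rescaled by `exp(m_prev - m_new)`. Below is a minimal, self-contained sketch of that recurrence (the `blocked_attention` helper is hypothetical, with no caches, sinks, masking, or skip logic) that checks it against full softmax attention:

```python
import torch


def blocked_attention(q, k, v, num_blocks):
    """Online-softmax attention over KV blocks; q: (B, H, S, D), k/v: (B, H, T, D)."""
    B, H, S, D = q.shape
    T = k.shape[2]
    block = -(-T // num_blocks)  # ceil division, as in the PR
    scale = D ** -0.5

    out = torch.zeros_like(q)
    m = torch.full((B, H, S, 1), float("-inf"))  # running max
    d = torch.zeros(B, H, S, 1)                  # running denominator

    for j in range(num_blocks):
        kj = k[:, :, j * block : (j + 1) * block]
        vj = v[:, :, j * block : (j + 1) * block]
        s_j = torch.matmul(q, kj.transpose(2, 3)) * scale  # block scores

        m_new = torch.maximum(m, s_j.max(dim=-1, keepdim=True).values)
        p = torch.exp(s_j - m_new)      # stabilized block weights
        alpha = torch.exp(m - m_new)    # rescales previously accumulated partials
        d = d * alpha + p.sum(dim=-1, keepdim=True)
        out = out * alpha + torch.matmul(p, vj)
        m = m_new

    return out / d


q, k, v = (torch.randn(1, 2, 5, 8) for _ in range(3))
ref = torch.softmax((q @ k.transpose(2, 3)) * 8 ** -0.5, dim=-1) @ v
assert torch.allclose(blocked_attention(q, k, v, num_blocks=3), ref, atol=1e-5)
```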
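The decode-time skip rule drops a block when its score maximum falls more than `log(skip_threshold / past_seen_tokens)` below the running maximum, i.e. when the block's softmax weights are bounded by `past_seen_tokens / skip_threshold` times the current peak. A quick illustrative check of that arithmetic (only `Constants.SKIP_THRESHOLD = 500` comes from this PR; the other values are made up):

```python
import math

skip_threshold = 500    # Constants.SKIP_THRESHOLD introduced by this PR
past_seen_tokens = 128  # example context length (illustrative)
threshold = math.log(skip_threshold / past_seen_tokens)  # ~= 1.36 nats

running_max, block_max = 4.0, 2.0
# current_max - block_max > threshold means every weight in the block is at most
# exp(-threshold) = past_seen_tokens / skip_threshold of the running peak, so the
# block's contribution is dropped and the previous accumulator is kept.
print(running_max - block_max > threshold)  # True -> block skipped
```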
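Assuming the existing `QEFFAutoModelForCausalLM.from_pretrained(..., qaic_config=...)` entry point, which the `modeling_auto.py` hunk above reads these keys from, enabling both features might look like the sketch below; the model card and block count are illustrative, not prescribed by this PR:

```python
from QEfficient import QEFFAutoModelForCausalLM

qaic_config = dict(
    num_kv_blocks=4,     # required: enables BlockedKVAttentionTransform
    skip_threshold=500,  # optional: enables SkipSoftmaxTransform (warns if num_kv_blocks is absent)
)

model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b", qaic_config=qaic_config)
```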