15 changes: 15 additions & 0 deletions QEfficient/generation/embedding_handler.py
@@ -260,6 +260,21 @@ def prepare_vlm_inputs(self, image_url: str, query: str, prefill_seq_len: int) -
inputs = self._qeff_model.model.prepare_inputs_for_generation(
inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
)
# Qwen3-VL dense and MoE variants share the same input-preparation path
if (
    hasattr(self._qeff_model.model.config, "model_type")
    and self._qeff_model.model.config.model_type in ("qwen3_vl", "qwen3_vl_moe")
):
    inputs = self._qeff_model.model.prepare_inputs_for_generation(
        inputs=inputs, prefill_seq_len=prefill_seq_len, batch_size=inputs["input_ids"].shape[0]
    )

# Convert to float32 if needed
if "pixel_values" in inputs:
33 changes: 32 additions & 1 deletion QEfficient/generation/vlm_generation.py
@@ -149,6 +149,9 @@ def __init__(
self.is_qwen2_5_vl = (
hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl"
)
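# Qwen3-VL mirrors Qwen2.5-VL's multi-axis position-id handling, so track it
# with its own flag for the specialized decode-input updates below.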
self.is_qwen3_vl = (
hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen3_vl"
)
self.qeff_model = qeff_model
self.processor = processor
self.tokenizer = tokenizer
@@ -259,6 +262,8 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):

if self.is_qwen2_5_vl:
_ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id)
elif self.is_qwen3_vl:
_ = self.update_decode_inputs_qwen3_vl(outputs, position_ids, generation_len, decode_batch_id)
else:
_ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

@@ -283,6 +288,27 @@ def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len,
self.generation_len[decode_batch_id or slice(None)] = generation_len
return next_token_id

def update_decode_inputs_qwen3_vl(self, outputs, position_ids, generation_len, decode_batch_id=None):
"""
Updates the decode input with the generated values.
Args:
outputs (dict): The outputs of the model.
position_ids (array): The position IDs.
generation_len (int): The generation length.
decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None.

Returns:
next_token_id (array): The next token ID.
"""
next_token_id = self._fetch_next_token_id(outputs)

# Store the generated values.
self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id
self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1)
self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1)
self.generation_len[decode_batch_id or slice(None)] = generation_len
return next_token_id

def _execute_chunked_prefill(
self,
lang_inputs: Dict[str, np.ndarray],
@@ -583,7 +609,8 @@ def _generate_continuous_batching(self, vision_prompts, generation_len, stream,
self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length)
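# Qwen2.5-VL and Qwen3-VL carry multi-axis rotary position components, so
# their decode position buffer is (4, batch, 1) rather than the usual (batch, 1).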
if self.is_qwen2_5_vl or self.is_qwen3_vl:
    self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64)
# Create prompt queue
prompt_queue = deque(vision_prompts)

@@ -696,6 +723,10 @@ def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation
self.update_decode_inputs_qwen2_5_vl(
outputs, position_ids_decode, generation_len_final, decode_batch_id
)
elif self.is_qwen3_vl:
self.update_decode_inputs_qwen3_vl(
outputs, position_ids_decode, generation_len_final, decode_batch_id
)
else:
self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id)
else:
52 changes: 47 additions & 5 deletions QEfficient/transformers/cache_utils.py
@@ -55,6 +55,12 @@ def _get_invalid_idx_value(cls):


class QEffDynamicLayer(DynamicLayer):
def lazy_initialization(self, key_states: torch.Tensor):
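# Capture dtype/device from the first key states and start from empty
# key/value tensors, marking the layer initialized.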
self.dtype, self.device = key_states.dtype, key_states.device
self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
self.values = torch.tensor([], dtype=self.dtype, device=self.device)
self.is_initialized = True

def read_only(self, cache_kwargs):
"""
Reads the `key_states` and `value_states` for the layer.
@@ -185,11 +191,14 @@ def update(
Return:
A tuple containing the updated key and value states.
"""
# Update the cache
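# First call adopts the incoming key/value states directly; subsequent calls
# scatter new states into the cache at the given position_ids (and batch_index
# when provided).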
if self.keys is None:
self.keys = key_states
self.values = value_states
k_out, v_out = self.keys, self.values
self.is_initialized = True
else:
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs
@@ -306,15 +315,47 @@ class QEffDynamicCache(DynamicCache):

"""

def __init__(
self,
ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
config=None,
offloading: bool = False,
offload_only_non_sliding: bool = False,
*args,
**kwargs,
):
# Remove layer arguments if present to avoid duplicate-argument errors
kwargs.pop("layer_classes", None)
kwargs.pop("layers", None)
from transformers.cache_utils import Cache # Import here to avoid circular import

layers = []
# A passed `config` could be used to infer per-layer types here; for now no
# layers are pre-built, so the cache always replicates QEffDynamicLayer.
if len(layers) == 0:
    Cache.__init__(
        self,
        layer_class_to_replicate=QEffDynamicLayer,
        offloading=offloading,
        offload_only_non_sliding=offload_only_non_sliding,
    )
else:
    Cache.__init__(
        self,
        layers=layers,
        offloading=offloading,
        offload_only_non_sliding=offload_only_non_sliding,
    )

if ddp_cache_data is not None:
    for key_states, value_states in ddp_cache_data:
        # Initialize a fresh layer per entry of the DDP data and populate it
        # via `update`; the first update simply stores the tensors.
        layer = QEffDynamicLayer()
        layer.update(key_states, value_states)
        self.layers.append(layer)
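# Example (hypothetical tensors): QEffDynamicCache(ddp_cache_data=[(k0, v0), (k1, v1)])
# seeds a two-layer cache from per-layer key/value pairs.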

def read_only(self, layer_idx, cache_kwargs):
"""
@@ -329,6 +370,7 @@ def read_only(self, layer_idx, cache_kwargs):
Return:
A tuple containing the updated key and value states.
"""

return self.layers[layer_idx].read_only(cache_kwargs)

def read_only_blockedKV(self, start_index, end_index, layer_idx, cache_kwargs):
19 changes: 18 additions & 1 deletion QEfficient/transformers/models/modeling_auto.py
@@ -1458,7 +1458,6 @@ def compile(
use_onnx_subfunctions=use_onnx_subfunctions,
**compiler_options,
)

# Custom NPI file options
if hasattr(self.model, "get_npi_file") and "node_precision_info" not in compiler_options:
compiler_options["node_precision_info"] = self.model.get_npi_file(self.model.name_or_path)
@@ -1686,6 +1685,14 @@ def kv_offload_generate(

vision_inputs_fp16 = {"pixel_values", "image_masks"}
vision_inputs.update({k: vision_inputs[k].astype("float16") for k in vision_inputs_fp16 if k in vision_inputs})
pixel_values_shape = list(vision_inputs["pixel_values"].shape)
# Pick the compiled specialization whose pixel_values binding (index 2)
# matches the runtime pixel_values shape.
idx = next(i for i, inner in enumerate(vision_session.allowed_shapes) if (2, pixel_values_shape) in inner)

buffer_set = {
    "vision_embeds": np.zeros(vision_session.allowed_shapes[idx][2][1], dtype=np.float16),
    "image_grid_thw": np.zeros(vision_session.allowed_shapes[idx][0][1], dtype=np.int64),
}
vision_session.set_buffers(buffer_set)

vision_start = perf_counter()

@@ -1712,6 +1719,16 @@
vision_session.deactivate()
lang_session.activate()

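# Pad vision_embeds along the sequence axis up to the length the language
# session was compiled for, so its shape matches the vision_embeds input the
# language session expects.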
vision_outputs["vision_embeds"] = np.pad(
vision_outputs["vision_embeds"],
pad_width=(
(0, 0),
(0, lang_session.allowed_shapes[0][1][1][1] - vision_session.allowed_shapes[idx][2][1][1]),
(0, 0),
), # pad axis=1 only
mode="constant",
constant_values=0,
)
lang_session.set_buffers(vision_outputs)

if self.comp_ctx_lengths_prefill is not None:
28 changes: 28 additions & 0 deletions QEfficient/transformers/models/pytorch_transforms.py
@@ -194,6 +194,16 @@
Qwen3MoeRotaryEmbedding,
Qwen3MoeSparseMoeBlock,
)
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
Qwen3VLForConditionalGeneration,
Qwen3VLModel,
Qwen3VLTextAttention,
Qwen3VLTextDecoderLayer,
Qwen3VLTextModel,
Qwen3VLTextRMSNorm,
Qwen3VLVisionAttention,
Qwen3VLVisionModel,
)
from transformers.models.starcoder2.modeling_starcoder2 import (
Starcoder2Attention,
Starcoder2DecoderLayer,
@@ -432,6 +442,15 @@
QEffQwen3MoeRotaryEmbedding,
QEffQwen3MoeSparseMoeBlock,
)
from QEfficient.transformers.models.qwen3_vl.modeling_qwen3_vl import (
QEffQwen3VLForConditionalGeneration,
QEffQwen3VLModel,
QEffQwen3VLTextAttention,
QEffQwen3VLTextDecoderLayer,
QEffQwen3VLTextModel,
QEffQwen3VLVisionAttention,
QEffQwen3VLVisionModel,
)
from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import (
QEffStarcoder2Attention,
QEFFStarcoder2DecoderLayer,
@@ -478,6 +497,7 @@ class CustomOpsTransform(ModuleMappingTransform):
GraniteMoeRMSNorm: CustomRMSNormAIC,
Qwen3MoeRMSNorm: CustomRMSNormAIC,
Gemma3RMSNorm: QEffGemma3CustomRMSNormAIC,
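# Qwen3-VL's text RMSNorm maps onto the generic AIC custom RMSNorm op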
Qwen3VLTextRMSNorm: CustomRMSNormAIC,
Olmo2RMSNorm: CustomRMSNormAIC,
}

@@ -628,6 +648,14 @@ class KVCacheTransform(ModuleMappingTransform):
Qwen2_5_VisionTransformerPretrainedModel: QEffQwen2_5_VisionTransformerPretrainedModel,
Qwen2_5_VLVisionAttention: QEffQwen2_5_VLVisionAttention,
Qwen2_5_VLTextModel: QEffQwen2_5_VLTextModel,
# Qwen3 VL
Qwen3VLForConditionalGeneration: QEffQwen3VLForConditionalGeneration,
Qwen3VLModel: QEffQwen3VLModel,
Qwen3VLTextAttention: QEffQwen3VLTextAttention,
Qwen3VLTextDecoderLayer: QEffQwen3VLTextDecoderLayer,
Qwen3VLVisionAttention: QEffQwen3VLVisionAttention,
Qwen3VLVisionModel: QEffQwen3VLVisionModel,
Qwen3VLTextModel: QEffQwen3VLTextModel,
# Starcoder2
Starcoder2Attention: QEffStarcoder2Attention,
Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer,
6 changes: 6 additions & 0 deletions QEfficient/transformers/models/qwen3_vl/__init__.py
@@ -0,0 +1,6 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------