5 changes: 4 additions & 1 deletion QEfficient/compile/compile_helper.py
@@ -129,19 +129,22 @@ def compile_kv_model_on_cloud_ai_100(
raise FileNotFoundError(f"Please use 'QEfficient.compile', as {specializations_json} file was not found")
if not os.path.isfile(custom_io_path):
raise FileNotFoundError(f"{custom_io_path} file was not found!")
aic_version = kwargs.get("aic_hw_version", constants.DEFAULT_AIC_HW_VERSION)
command = [
"/opt/qti-aic/exec/qaic-exec",
f"-m={onnx_path}",
"-aic-hw",
f"-aic-hw-version={kwargs.pop('aic_hw_version', kwargs.pop('aic-hw-version', constants.DEFAULT_AIC_HW_VERSION))}",
f"-network-specialization-config={specializations_json}",
"-convert-to-fp16",
# "-convert-to-fp16",
"-retained-state",
f"-aic-num-cores={num_cores}",
f"-custom-IO-list-file={custom_io_path}",
"-compile-only",
f"-aic-binary-dir={aic_binary_dir}",
]
if aic_version == "ai100":
command.append("-convert-to-fp16")
if mxfp6:
command.append("-mxfp6-matmul")
if mos > 0:
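Below is a minimal, hypothetical sketch (not the actual `compile_kv_model_on_cloud_ai_100` helper) of the command-assembly logic this hunk introduces: `-convert-to-fp16` is appended only when the `aic_hw_version` kwarg resolves to `"ai100"`, so other hardware targets keep the model's native dtype. The `build_qaic_command` name, signature, and defaults are illustrative.

```python
# Sketch only: mirrors the conditional -convert-to-fp16 logic shown in the diff.
# build_qaic_command and its signature are hypothetical, not QEfficient API.
def build_qaic_command(onnx_path, specializations_json, custom_io_path,
                       aic_binary_dir, num_cores, aic_hw_version="ai100",
                       mxfp6=False):
    command = [
        "/opt/qti-aic/exec/qaic-exec",
        f"-m={onnx_path}",
        "-aic-hw",
        f"-aic-hw-version={aic_hw_version}",
        f"-network-specialization-config={specializations_json}",
        "-retained-state",
        f"-aic-num-cores={num_cores}",
        f"-custom-IO-list-file={custom_io_path}",
        "-compile-only",
        f"-aic-binary-dir={aic_binary_dir}",
    ]
    if aic_hw_version == "ai100":
        # fp16 conversion is only forced for the ai100 target.
        command.append("-convert-to-fp16")
    if mxfp6:
        command.append("-mxfp6-matmul")
    return command


assert "-convert-to-fp16" in build_qaic_command(
    "model.onnx", "specializations.json", "custom_io.yaml", "qpc", num_cores=14
)
```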
2 changes: 1 addition & 1 deletion QEfficient/transformers/models/gpt2/modeling_gpt2.py
@@ -40,7 +40,7 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
if attention_mask is not None:
# Apply the attention mask
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -84,9 +84,9 @@ def eager_attention_forward(

if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

8 changes: 4 additions & 4 deletions QEfficient/transformers/models/llama/modeling_llama.py
@@ -111,10 +111,10 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

@@ -147,7 +147,7 @@ def eager_attention_forward_blockedKV(
past_seen_tokens = cache_kwargs.get("past_seen_tokens")
position_ids = cache_kwargs.get("position_ids")
block_size = -(-past_seen_tokens // num_kv_blocks)
masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32)
masked_tensor = torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype)

for j in range(num_kv_blocks):
start_index = j * block_size
@@ -439,7 +439,7 @@ def forward(
# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
logits = self.lm_head(hidden_states).float()
logits = self.lm_head(hidden_states).to(self.config.torch_dtype)

return CausalLMOutputWithPast(
loss=None,
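The attention patches in this and the surrounding model files all follow one pattern: the mask fill value and the softmax run in `module.config.torch_dtype` rather than a hardcoded `torch.float32`. A self-contained sketch of that pattern is below; the `MIN_MASKED_ATTENTION_VALUE` value and the tensor shapes are illustrative, not taken from QEfficient.

```python
import torch
import torch.nn as nn

MIN_MASKED_ATTENTION_VALUE = -10000.0  # illustrative; the real constant lives in QEfficient


def masked_softmax(attn_weights, attention_mask, torch_dtype=torch.bfloat16):
    # Fill masked positions with a large negative value in the model's dtype,
    # then run softmax in that same dtype instead of forcing float32.
    if attention_mask is not None:
        attn_weights = torch.where(
            attention_mask,
            torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch_dtype),
            attn_weights,
        )
    return nn.functional.softmax(attn_weights, dim=-1, dtype=torch_dtype)


# Usage: scores of shape (batch, heads, q_len, kv_len) with a boolean mask.
scores = torch.randn(1, 8, 4, 4, dtype=torch.bfloat16)
mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
probs = masked_softmax(scores, mask)
assert probs.dtype == torch.bfloat16
```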
18 changes: 14 additions & 4 deletions QEfficient/transformers/models/modeling_auto.py
@@ -72,6 +72,11 @@
from QEfficient.utils.logging_utils import logger
from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs

DTYPE_TO_STRING_MAP = {
torch.float16: "float16",
torch.bfloat16: "bfloat16",
}


class QEFFTransformersBase(QEFFBaseModel):
"""
@@ -2659,7 +2664,9 @@ def export(
)
for i in range(self.num_layers):
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32))
example_inputs["past_key_values"][i].append(
torch.zeros(pkv_cache[0][0].shape, dtype=self.model.config.torch_dtype)
)
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
output_names.append(f"past_{kv}.{i}_RetainedState")

@@ -2682,7 +2689,9 @@ def export(

for i in range(self.num_layers):
for kv in ["key", "value"]:
example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
example_inputs["past_key_values"][i].append(
torch.zeros(kv_cache_shape, dtype=self.model.config.torch_dtype)
)
dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
output_names.append(f"past_{kv}.{i}_RetainedState")

@@ -3059,7 +3068,8 @@ def compile(
specializations.append(decode_spec)

# --- Compilation ---
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
needed_dtype = self.model.config.torch_dtype
kv_cache_dtype = "mxint8" if mxint8_kv_cache else DTYPE_TO_STRING_MAP[needed_dtype]
custom_io = {}

for suffix in ["", "_RetainedState"]:
@@ -3667,7 +3677,7 @@ def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
seq_len = constants.WAV2VEC2_MAX_SEQ_LEN

example_inputs = {
"input_values": torch.zeros((bs, seq_len), dtype=torch.float32),
"input_values": torch.zeros((bs, seq_len), dtype=self.model.config.torch_dtype),
}

dynamic_axes = {"input_values": {0: "batch_size", 1: "seq_len"}}
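A small sketch of how the new `DTYPE_TO_STRING_MAP` drives the KV-cache custom-IO dtype at compile time; `pick_kv_cache_dtype` is a hypothetical helper that condenses the logic shown in the hunk above, under the assumption that the model's `torch_dtype` is float16 or bfloat16 (the only keys in the map).

```python
import torch

DTYPE_TO_STRING_MAP = {
    torch.float16: "float16",
    torch.bfloat16: "bfloat16",
}


def pick_kv_cache_dtype(torch_dtype, mxint8_kv_cache=False):
    # mxint8 still wins when requested; otherwise the KV-cache custom IO dtype
    # follows the model's torch_dtype instead of always being float16.
    return "mxint8" if mxint8_kv_cache else DTYPE_TO_STRING_MAP[torch_dtype]


assert pick_kv_cache_dtype(torch.bfloat16) == "bfloat16"
assert pick_kv_cache_dtype(torch.float16, mxint8_kv_cache=True) == "mxint8"
```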
6 changes: 3 additions & 3 deletions QEfficient/transformers/models/qwen2/modeling_qwen2.py
@@ -125,9 +125,9 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

@@ -383,7 +383,7 @@ def forward(
# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
logits = self.lm_head(hidden_states).float()
logits = self.lm_head(hidden_states).to(torch.float32)

return CausalLMOutputWithPast(
loss=None,
6 changes: 3 additions & 3 deletions QEfficient/transformers/models/qwen3/modeling_qwen3.py
@@ -125,10 +125,10 @@ def eager_attention_forward(
attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
if attention_mask is not None:
attn_weights = torch.where(
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights
attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=module.config.torch_dtype), attn_weights
)

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=module.config.torch_dtype).to(query.dtype)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()

@@ -386,7 +386,7 @@ def forward(
# Cast to INT32 to avoid issue while running in ONNXRT
logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
hidden_states = outputs.last_hidden_state[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
logits = self.lm_head(hidden_states).float()
logits = self.lm_head(hidden_states).to(torch.float32)

return CausalLMOutputWithPast(
loss=None,
44 changes: 32 additions & 12 deletions QEfficient/utils/generate_inputs.py
@@ -18,6 +18,12 @@
padding_check_and_fix,
)

MODEL_DTYPE_TO_INPUT_DTYPE_MAP = {
torch.float32: np.float32,
torch.float16: np.float16,
torch.bfloat16: np.float16, # bfloat16 not supported by onnxruntime, so we cast it to float16 for now.
}


class InputHandler:
def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size):
@@ -100,8 +106,8 @@ def prepare_pytorch_inputs(self):
pad_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]]
else:
pad_shape = self.padding_shape
past_key = torch.zeros((pad_shape), dtype=torch.float32)
past_value = torch.zeros((pad_shape), dtype=torch.float32)
past_key = torch.zeros((pad_shape), dtype=self.config.torch_dtype)
past_value = torch.zeros((pad_shape), dtype=self.config.torch_dtype)
pkv = (past_key, past_value)
past_key_values.append(pkv)
inputs["past_key_values"] = tuple(past_key_values)
@@ -170,8 +176,12 @@ def prepare_ort_inputs(self):
if hasattr(self.config, "model_type") and self.config.model_type in DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH:
for i in range(self.n_layer):
cache_shape = self.global_shape if not self.is_chunked_attention[i] else self.sliding_shape
inputs["past_key." + str(i)] = np.zeros((cache_shape), dtype=np.float32)
inputs["past_value." + str(i)] = np.zeros((cache_shape), dtype=np.float32)
inputs["past_key." + str(i)] = np.zeros(
(cache_shape), dtype=MODEL_DTYPE_TO_INPUT_DTYPE_MAP[self.config.torch_dtype]
)
inputs["past_value." + str(i)] = np.zeros(
(cache_shape), dtype=MODEL_DTYPE_TO_INPUT_DTYPE_MAP[self.config.torch_dtype]
)
else:
for i in range(self.n_layer):
if (
@@ -181,8 +191,12 @@
pad_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]]
else:
pad_shape = self.padding_shape
inputs["past_key." + str(i)] = np.zeros((pad_shape), dtype=np.float32)
inputs["past_value." + str(i)] = np.zeros((pad_shape), dtype=np.float32)
inputs["past_key." + str(i)] = np.zeros(
(pad_shape), dtype=MODEL_DTYPE_TO_INPUT_DTYPE_MAP[self.config.torch_dtype]
)
inputs["past_value." + str(i)] = np.zeros(
(pad_shape), dtype=MODEL_DTYPE_TO_INPUT_DTYPE_MAP[self.config.torch_dtype]
)
if self.full_batch_size:
inputs["batch_index"] = np.arange(self.full_batch_size).reshape(-1, 1)
return inputs
@@ -324,17 +338,21 @@ def prepare_vlm_ort_inputs(self):
idx = cross_attention_layers.index(i)
assert idx == ((i - 3) // 5), f"{i}, {(i - 3) // 5}"
inputs["past_key." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, image_tokens_len, head_dim), dtype=np.float32
(self.batch_size, num_key_value_heads, image_tokens_len, head_dim),
dtype=MODEL_DTYPE_TO_INPUT_DTYPE_MAP[self.config.torch_dtype],
)
inputs["past_value." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, image_tokens_len, head_dim), dtype=np.float32
(self.batch_size, num_key_value_heads, image_tokens_len, head_dim),
dtype=MODEL_DTYPE_TO_INPUT_DTYPE_MAP[self.config.torch_dtype],
)
else:
inputs["past_key." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim),
dtype=MODEL_DTYPE_TO_INPUT_DTYPE_MAP[self.config.torch_dtype],
)
inputs["past_value." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim),
dtype=MODEL_DTYPE_TO_INPUT_DTYPE_MAP[self.config.torch_dtype],
)
lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
return vision_inputs, lang_inputs
@@ -474,10 +492,12 @@ def prepare_vlm_ort_inputs(self):

for i in range(num_hidden_layers):
inputs["past_key." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim),
dtype=MODEL_DTYPE_TO_INPUT_DTYPE_MAP[self.config.torch_dtype],
)
inputs["past_value." + str(i)] = np.zeros(
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim), dtype=np.float32
(self.batch_size, num_key_value_heads, self.ctx_len, head_dim),
dtype=MODEL_DTYPE_TO_INPUT_DTYPE_MAP[self.config.torch_dtype],
)
lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs}
return vision_inputs, lang_inputs
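The generate_inputs.py changes all route through `MODEL_DTYPE_TO_INPUT_DTYPE_MAP`, so the numpy dtype of the ONNX Runtime KV-cache inputs tracks the model's `torch_dtype` (with bfloat16 falling back to float16). A condensed, hypothetical sketch of that pattern; `make_past_kv` and its arguments are illustrative, not part of the InputHandler API.

```python
import numpy as np
import torch

MODEL_DTYPE_TO_INPUT_DTYPE_MAP = {
    torch.float32: np.float32,
    torch.float16: np.float16,
    torch.bfloat16: np.float16,  # onnxruntime has no numpy bfloat16 input type, so fall back to float16
}


def make_past_kv(batch_size, num_kv_heads, ctx_len, head_dim, torch_dtype, n_layer):
    # Build zero-initialized past_key.N / past_value.N ONNX Runtime inputs whose
    # numpy dtype follows the model's torch_dtype via the map above.
    np_dtype = MODEL_DTYPE_TO_INPUT_DTYPE_MAP[torch_dtype]
    shape = (batch_size, num_kv_heads, ctx_len, head_dim)
    inputs = {}
    for i in range(n_layer):
        inputs[f"past_key.{i}"] = np.zeros(shape, dtype=np_dtype)
        inputs[f"past_value.{i}"] = np.zeros(shape, dtype=np_dtype)
    return inputs


kv = make_past_kv(1, 8, 128, 64, torch.bfloat16, n_layer=2)
assert kv["past_key.0"].dtype == np.float16
```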
2 changes: 1 addition & 1 deletion QEfficient/utils/run_utils.py
@@ -309,7 +309,7 @@ def run_vlm_hf_model_on_pytorch_CB(self, model, images, queries):
# Process inputs
inputs = self.processor(images=image, text=prompt, return_tensors="pt")
if "pixel_values" in inputs:
inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
inputs["pixel_values"] = inputs["pixel_values"].to(self.config.torch_dtype)

# Generate tokens
output = model.generate(**inputs, max_new_tokens=self.gen_len, do_sample=False)