10 changes: 4 additions & 6 deletions QEfficient/base/modeling_qeff.py
@@ -442,7 +442,6 @@ def _compile(
mdp_dump_json_path = compiler_options.pop("mdp_dump_partition_config", None)
mdp_ts_json_path = compiler_options.pop("mdp_load_partition_config", None)
mdp_ts_json = None
user_provided_load_config = False

if mdp_dump_json_path:
if mdp_ts_json_path:
@@ -453,12 +452,14 @@ def _compile(
elif mdp_ts_json_path:
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")
mdp_ts_json = load_json(str(mdp_ts_json_path))
user_provided_load_config = True
elif mdp_ts_num_devices > 1:
# Generate mdp config only if neither dump nor load is provided and num_devices > 1
mdp_ts_json = generate_mdp_partition_config(
mdp_ts_num_devices, compiler_options.get("aic_num_cores", constants.DEFAULT_AIC_NUM_CORES)
)
mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
create_json(str(mdp_ts_json_path), mdp_ts_json)
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")

for key, value in compiler_options.items():
option = "-" + key.replace("_", "-")
@@ -495,10 +496,7 @@ def _compile(
shutil.rmtree(qpc_path)

# Write the generated MDP partition config file (not if user provided it)
if mdp_ts_json is not None and not user_provided_load_config:
mdp_ts_json_path = compile_dir / f"mdp_ts_{mdp_ts_num_devices}.json"
create_json(str(mdp_ts_json_path), mdp_ts_json)
command.append(f"-mdp-load-partition-config={mdp_ts_json_path}")


# Write specializations.json file
if specializations is not None:
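Note on this hunk: the generated multi-device partitioning (MDP) config is now created, written to compile_dir, and appended to the compiler command inside the `elif mdp_ts_num_devices > 1` branch, which is why the `user_provided_load_config` flag and the post-compile write-back block are removed. A minimal, self-contained sketch of the resulting branch order, assuming the helpers behave as named in the diff (the function name and the JSON shape below are placeholders, not the library's):

    import json
    from pathlib import Path
    from typing import List, Optional

    def build_mdp_args(command: List[str], compile_dir: Path, num_devices: int,
                       dump_path: Optional[str], load_path: Optional[str], num_cores: int):
        """Illustrative only: mirrors the branch order introduced by this change."""
        mdp_ts_json = None
        if dump_path:
            pass  # dump branch body is elided in this diff and unchanged by it
        elif load_path:
            # User-provided config: forward it and read it so it can enter the compile hash.
            command.append(f"-mdp-load-partition-config={load_path}")
            mdp_ts_json = json.loads(Path(load_path).read_text())
        elif num_devices > 1:
            # Neither dump nor load given: generate, write, and forward a config up front,
            # instead of writing it back after compilation as the removed block did.
            mdp_ts_json = {"num_devices": num_devices, "num_cores": num_cores}  # placeholder shape
            generated = Path(compile_dir) / f"mdp_ts_{num_devices}.json"
            generated.write_text(json.dumps(mdp_ts_json, indent=2))
            command.append(f"-mdp-load-partition-config={generated}")
        return mdp_ts_json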
18 changes: 10 additions & 8 deletions QEfficient/transformers/models/modeling_auto.py
@@ -2747,10 +2747,12 @@ def build_prefill_specialization(
Dict[str, Union[int, str]]
A dictionary defining the prefill specialization.
"""
if prefill_seq_len == 1 and self.continuous_batching:
if not self.continuous_batching:
exec_batch_size = batch_size
elif prefill_seq_len == 1:
exec_batch_size = full_batch_size
else:
exec_batch_size = 1 if self.continuous_batching else batch_size
exec_batch_size = 1

if hasattr(self.model, "get_specializations"):
spec = self.model.get_specializations(
@@ -2761,7 +2763,7 @@
)[0]
else:
spec = {
"batch_size": 1 if self.continuous_batching else batch_size,
"batch_size": exec_batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
}
@@ -2772,8 +2774,9 @@
spec["full_batch_size"] = kv_cache_batch_size
else:
spec["batch_size"] = kv_cache_batch_size
# TODO: remove this; not required
if full_batch_size:
spec["full_batch_exec_size"] = full_batch_size
spec["full_batch_exec_size"] = exec_batch_size
return {k: v for k, v in spec.items() if v is not None}
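In other words, build_prefill_specialization now picks the execution batch size through three explicit cases rather than a nested conditional. An illustrative restatement with example values (the helper name is hypothetical; only the branch logic comes from the diff):

    def pick_exec_batch_size(continuous_batching: bool, prefill_seq_len: int,
                             batch_size: int, full_batch_size: int) -> int:
        """Illustrative restatement of the new branch order."""
        if not continuous_batching:
            return batch_size          # regular batching: prefill at the requested batch size
        if prefill_seq_len == 1:
            return full_batch_size     # continuous batching with seq_len 1: prefill at full_batch_size
        return 1                       # continuous batching with longer prefill: one sequence at a time

    # e.g. pick_exec_batch_size(True, 1, batch_size=2, full_batch_size=8)   -> 8
    #      pick_exec_batch_size(True, 128, batch_size=2, full_batch_size=8) -> 1
    #      pick_exec_batch_size(False, 128, batch_size=2, full_batch_size=8) -> 2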

def build_decode_specialization(
@@ -2811,9 +2814,6 @@ def build_decode_specialization(
A dictionary defining the decode specialization, or None if it would be a duplicate
of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching).
"""
if prefill_seq_len == 1 and not self.continuous_batching:
return None # Avoid duplication with prefill

if hasattr(self.model, "get_specializations"):
spec = self.model.get_specializations(
batch_size=full_batch_size if self.continuous_batching else batch_size,
@@ -3031,7 +3031,7 @@ def compile(
)
)

if prefill_only is None or not prefill_only:
if (prefill_only is None or not prefill_only) and prefill_seq_len != 1:
Reviewer comment (Contributor): Can we set prefill_only to False as the default value?
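If prefill_only did default to False as the reviewer suggests, the None check in the guard above could be dropped. A small hypothetical sketch of the resulting condition (the helper name and defaults are made up for illustration):

    def should_add_decode_spec(prefill_only: bool = False, prefill_seq_len: int = 32) -> bool:
        # Hypothetical helper: with a False default there is no None to special-case.
        return (not prefill_only) and prefill_seq_len != 1

    assert should_add_decode_spec() is True
    assert should_add_decode_spec(prefill_only=True) is False
    assert should_add_decode_spec(prefill_seq_len=1) is False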

if self.comp_ctx_lengths_decode is not None:
# Adding elements from self.comp_ctx_lengths_decode to decode_specialization
for i in range(0, len(self.comp_ctx_lengths_decode)):
@@ -3060,6 +3060,8 @@ def compile(
if decode_spec:
specializations.append(decode_spec)

if kw_spec := compiler_options.pop("specializations", None):
Reviewer comment (Contributor): nit: This can be simplified to
    if "specializations" in compiler_options:
        specializations = compiler_options.pop("specializations")

specializations = kw_spec
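For reference, a standalone comparison of the two spellings (the dictionary contents are made up for illustration):

    compiler_options = {"specializations": [{"batch_size": 1, "seq_len": 1, "ctx_len": 256}]}
    specializations = [{"batch_size": 1, "seq_len": 128, "ctx_len": 256}]  # as built earlier in compile()

    # Walrus form used in the diff: pops the key and overrides only when the value is truthy.
    if kw_spec := compiler_options.pop("specializations", None):
        specializations = kw_spec

    # Reviewer's suggested form gives the same result for a non-empty list; unlike the
    # walrus form, it would also apply an empty list rather than silently skipping it.
    # if "specializations" in compiler_options:
    #     specializations = compiler_options.pop("specializations")

    print(specializations)  # -> the spec list supplied via compiler_options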
# --- Compilation ---
kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
custom_io = {}
41 changes: 41 additions & 0 deletions tests/transformers/models/test_causal_lm_models.py
@@ -153,6 +153,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
config: Optional[AutoConfig] = None,
pytorch_hf_tokens: Optional[list] = None,
qaic_config: Optional[dict] = None,
retain_full_kv: Optional[bool] = None,
):
"""
Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
@@ -211,6 +212,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
prefill_only=prefill_only,
enable_qnn=enable_qnn,
qnn_config=qnn_config,
retain_full_kv=retain_full_kv,
)
exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
cloud_ai_100_tokens = exec_info.generated_ids[0][
@@ -260,17 +262,38 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
if not get_available_device_id():
pytest.skip("No available devices to run model on Cloud AI 100")

compiler_options = {}
if prompt_len == 1:
prefill_spec = {
"batch_size": batch_size,
"seq_len": 1,
"ctx_len": ctx_len,
"full_batch_size": full_batch_size,
"sliding_window": 128,
}
decode_spec = {
"batch_size": full_batch_size,
"seq_len": 1,
"ctx_len": ctx_len,
"full_batch_size": full_batch_size,
"sliding_window": 128,
}
compiler_options = {"specializations": [prefill_spec, decode_spec]}

# TODO: add prefill_only tests
qpc_path = qeff_model.compile(
prefill_seq_len=prompt_len,
ctx_len=ctx_len,
num_cores=14,
mxfp6=False,
aic_enable_depth_first=False,
batch_size=batch_size,
full_batch_size=full_batch_size,
num_speculative_tokens=num_speculative_tokens,
enable_qnn=enable_qnn,
qnn_config=qnn_config,
retain_full_kv=retain_full_kv,
**compiler_options,
)
exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)

@@ -370,6 +393,24 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
)


@pytest.mark.nightly
@pytest.mark.on_qaic
@pytest.mark.parametrize("retain_full_kv", [True, False])
def test_causal_lm_gpt_oss_pytorch_vs_kv_vs_ort_vs_ai100_pl1(retain_full_kv):
"""
Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model for ``openai/gpt-oss-20b`` with a prompt length of 1, both with and without continuous batching.
``Mandatory`` Args:
:retain_full_kv (bool): Parametrized flag forwarded to ``compile``.
"""
model_name = "openai/gpt-oss-20b"
n_layer = get_custom_n_layers(model_name)
prompt_len = 1

check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
model_name=model_name, n_layer=n_layer, prompt_len=prompt_len, retain_full_kv=retain_full_kv
)


@pytest.mark.on_qaic
@pytest.mark.regular
@pytest.mark.qnn