18 changes: 10 additions & 8 deletions QEfficient/transformers/models/modeling_auto.py
@@ -2741,10 +2741,12 @@ def build_prefill_specialization(
         Dict[str, Union[int, str]]
             A dictionary defining the prefill specialization.
         """
-        if prefill_seq_len == 1 and self.continuous_batching:
+        if not self.continuous_batching:
+            exec_batch_size = batch_size
+        elif prefill_seq_len == 1:
             exec_batch_size = full_batch_size
         else:
-            exec_batch_size = 1 if self.continuous_batching else batch_size
+            exec_batch_size = 1
 
         if hasattr(self.model, "get_specializations"):
             spec = self.model.get_specializations(
@@ -2755,7 +2757,7 @@ def build_prefill_specialization(
             )[0]
         else:
             spec = {
-                "batch_size": 1 if self.continuous_batching else batch_size,
+                "batch_size": exec_batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
             }
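Taken together, the two hunks above replace the single-expression batch-size selection with an explicit three-way branch and reuse the result when building the fallback spec. A minimal standalone sketch of the new selection logic (the function name and bare arguments are illustrative, not part of the change):

def pick_prefill_exec_batch_size(continuous_batching, prefill_seq_len, batch_size, full_batch_size):
    # No continuous batching: prefill runs at the user-supplied batch size.
    if not continuous_batching:
        return batch_size
    # Continuous batching with seq_len == 1: the prefill graph has the same
    # shape as decode, so it executes at the full (KV-cache) batch size.
    if prefill_seq_len == 1:
        return full_batch_size
    # Ordinary continuous-batching prefill processes one sequence at a time.
    return 1

pick_prefill_exec_batch_size(False, 128, 4, None)  # -> 4
pick_prefill_exec_batch_size(True, 1, 4, 8)        # -> 8
pick_prefill_exec_batch_size(True, 128, 4, 8)      # -> 1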
@@ -2766,8 +2768,9 @@ def build_prefill_specialization(
             spec["full_batch_size"] = kv_cache_batch_size
         else:
             spec["batch_size"] = kv_cache_batch_size
+        # TODO: remove this; not required
         if full_batch_size:
-            spec["full_batch_exec_size"] = full_batch_size
+            spec["full_batch_exec_size"] = exec_batch_size
         return {k: v for k, v in spec.items() if v is not None}
 
     def build_decode_specialization(
@@ -2805,9 +2808,6 @@ def build_decode_specialization(
             A dictionary defining the decode specialization, or None if it would be a duplicate
             of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching).
         """
-        if prefill_seq_len == 1 and not self.continuous_batching:
-            return None  # Avoid duplication with prefill
-
         if hasattr(self.model, "get_specializations"):
             spec = self.model.get_specializations(
                 batch_size=full_batch_size if self.continuous_batching else batch_size,
@@ -3025,7 +3025,7 @@ def compile(
                 )
             )
 
-        if prefill_only is None or not prefill_only:
+        if (prefill_only is None or not prefill_only) and prefill_seq_len != 1:
            if self.comp_ctx_lengths_decode is not None:
                 # Adding elements from self.comp_ctx_lengths_decode to decode_specialization
                 for i in range(0, len(self.comp_ctx_lengths_decode)):
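The two hunks above relocate the duplicate-specialization guard: build_decode_specialization no longer returns None early, and compile() instead skips building a decode specialization whenever prefill_seq_len is 1, where the seq_len-1 prefill specialization already has the decode shape. A condensed, illustrative view of the call-site logic (not the literal method body):

# Before: the guard lived inside build_decode_specialization
#     if prefill_seq_len == 1 and not self.continuous_batching:
#         return None  # decode spec would duplicate the prefill spec
# After: the guard lives at the call site in compile()
if (prefill_only is None or not prefill_only) and prefill_seq_len != 1:
    decode_spec = self.build_decode_specialization(...)  # arguments elided
    if decode_spec:
        specializations.append(decode_spec)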
@@ -3054,6 +3054,8 @@ def compile(
             if decode_spec:
                 specializations.append(decode_spec)
 
+        if kw_spec := compiler_options.pop("specializations", None):
+            specializations = kw_spec
         # --- Compilation ---
         kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
         custom_io = {}
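The last hunk adds an override: a specializations list passed through the extra compiler options replaces the automatically built prefill/decode specializations wholesale. A usage sketch, assuming the public QEFFAutoModelForCausalLM entry point with placeholder model and shape values:

from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("gpt2", continuous_batching=True)
custom_specs = [
    {"batch_size": 1, "seq_len": 32, "ctx_len": 256, "full_batch_size": 4},  # prefill-shaped
    {"batch_size": 4, "seq_len": 1, "ctx_len": 256, "full_batch_size": 4},   # decode-shaped
]
# compile() pops "specializations" from the extra compiler options and uses
# the supplied list verbatim instead of the specs it would have derived.
model.compile(
    prefill_seq_len=32,
    ctx_len=256,
    full_batch_size=4,
    num_cores=14,
    specializations=custom_specs,
)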
41 changes: 41 additions & 0 deletions tests/transformers/models/test_causal_lm_models.py
@@ -153,6 +153,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     config: Optional[AutoConfig] = None,
     pytorch_hf_tokens: Optional[list] = None,
     qaic_config: Optional[dict] = None,
+    retain_full_kv: Optional[bool] = None,
 ):
     """
     Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
@@ -211,6 +212,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         prefill_only=prefill_only,
         enable_qnn=enable_qnn,
         qnn_config=qnn_config,
+        retain_full_kv=retain_full_kv,
     )
     exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0][
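These two hunks thread an optional retain_full_kv flag through the shared test helper and forward it to qeff_model.compile(); the flag defaults to None, so existing callers of the helper are unaffected.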
@@ -260,17 +262,38 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     if not get_available_device_id():
         pytest.skip("No available devices to run model on Cloud AI 100")
 
+    compiler_options = {}
+    if prompt_len == 1:
+        prefill_spec = {
+            "batch_size": batch_size,
+            "seq_len": 1,
+            "ctx_len": ctx_len,
+            "full_batch_size": full_batch_size,
+            "sliding_window": 128,
+        }
+        decode_spec = {
+            "batch_size": full_batch_size,
+            "seq_len": 1,
+            "ctx_len": ctx_len,
+            "full_batch_size": full_batch_size,
+            "sliding_window": 128,
+        }
+        compiler_options = {"specializations": [prefill_spec, decode_spec]}
+
+    # TODO: add prefill_only tests
     qpc_path = qeff_model.compile(
         prefill_seq_len=prompt_len,
         ctx_len=ctx_len,
         num_cores=14,
         mxfp6=False,
         aic_enable_depth_first=False,
         batch_size=batch_size,
         full_batch_size=full_batch_size,
         num_speculative_tokens=num_speculative_tokens,
         enable_qnn=enable_qnn,
         qnn_config=qnn_config,
+        retain_full_kv=retain_full_kv,
+        **compiler_options,
     )
     exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
@@ -370,6 +393,24 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
     )
 
 
+@pytest.mark.nightly
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("retain_full_kv", [True, False])
+def test_causal_lm_gpt_oss_pytorch_vs_kv_vs_ort_vs_ai100_pl1(retain_full_kv):
+    """
+    Test function to validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
+    ``Mandatory`` Args:
+        :model_name (str): Hugging Face Model Card name, Example: ``gpt2``
+    """
+    model_name = "openai/gpt-oss-20b"
+    n_layer = get_custom_n_layers(model_name)
+    prompt_len = 1
+
+    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
+        model_name=model_name, n_layer=n_layer, prompt_len=prompt_len, retain_full_kv=retain_full_kv
+    )
+
+
 @pytest.mark.on_qaic
 @pytest.mark.regular
 @pytest.mark.qnn