diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index d2cc1e681..c7372c8cc 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -2741,10 +2741,12 @@ def build_prefill_specialization(
         Dict[str, Union[int, str]]
             A dictionary defining the prefill specialization.
         """
-        if prefill_seq_len == 1 and self.continuous_batching:
+        if not self.continuous_batching:
+            exec_batch_size = batch_size
+        elif prefill_seq_len == 1:
             exec_batch_size = full_batch_size
         else:
-            exec_batch_size = 1 if self.continuous_batching else batch_size
+            exec_batch_size = 1

         if hasattr(self.model, "get_specializations"):
             spec = self.model.get_specializations(
@@ -2755,7 +2757,7 @@ def build_prefill_specialization(
             )[0]
         else:
             spec = {
-                "batch_size": 1 if self.continuous_batching else batch_size,
+                "batch_size": exec_batch_size,
                 "seq_len": prefill_seq_len,
                 "ctx_len": ctx_len,
             }
@@ -2766,8 +2768,9 @@ def build_prefill_specialization(
             spec["full_batch_size"] = kv_cache_batch_size
         else:
             spec["batch_size"] = kv_cache_batch_size
+        # TODO: remove this; not required
         if full_batch_size:
-            spec["full_batch_exec_size"] = full_batch_size
+            spec["full_batch_exec_size"] = exec_batch_size
         return {k: v for k, v in spec.items() if v is not None}

     def build_decode_specialization(
@@ -2805,9 +2808,6 @@ def build_decode_specialization(
             A dictionary defining the decode specialization, or None if it would be a duplicate
             of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching).
         """
-        if prefill_seq_len == 1 and not self.continuous_batching:
-            return None  # Avoid duplication with prefill
-
         if hasattr(self.model, "get_specializations"):
             spec = self.model.get_specializations(
                 batch_size=full_batch_size if self.continuous_batching else batch_size,
@@ -3025,7 +3025,7 @@ def compile(
                 )
             )

-        if prefill_only is None or not prefill_only:
+        if (prefill_only is None or not prefill_only) and prefill_seq_len != 1:
             if self.comp_ctx_lengths_decode is not None:
                 # Adding elements from self.comp_ctx_lengths_decode to decode_specialization
                 for i in range(0, len(self.comp_ctx_lengths_decode)):
@@ -3054,6 +3054,8 @@ def compile(
             if decode_spec:
                 specializations.append(decode_spec)

+        if kw_spec := compiler_options.pop("specializations", None):
+            specializations = kw_spec
         # --- Compilation ---
         kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
         custom_io = {}
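Illustrative sketch (not part of the patch): the first hunk above changes how build_prefill_specialization picks the execution batch size. The standalone helper below mirrors that branching for clarity only; its name is hypothetical, and the real logic lives inside the method and reads self.continuous_batching.

def resolve_prefill_exec_batch_size(continuous_batching: bool, prefill_seq_len: int,
                                    batch_size: int, full_batch_size: int) -> int:
    # No continuous batching: prefill runs at the user-requested batch size.
    if not continuous_batching:
        return batch_size
    # Continuous batching with a seq_len-1 prefill: run at the full batch size.
    if prefill_seq_len == 1:
        return full_batch_size
    # Continuous batching with a longer prefill: one sequence at a time.
    return 1


assert resolve_prefill_exec_batch_size(False, 128, batch_size=2, full_batch_size=4) == 2
assert resolve_prefill_exec_batch_size(True, 1, batch_size=2, full_batch_size=4) == 4
assert resolve_prefill_exec_batch_size(True, 128, batch_size=2, full_batch_size=4) == 1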
diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py
index ead636759..e5cd35986 100644
--- a/tests/transformers/models/test_causal_lm_models.py
+++ b/tests/transformers/models/test_causal_lm_models.py
@@ -153,6 +153,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     config: Optional[AutoConfig] = None,
     pytorch_hf_tokens: Optional[list] = None,
     qaic_config: Optional[dict] = None,
+    retain_full_kv: Optional[bool] = None,
 ):
     """
     Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching.
@@ -211,6 +212,7 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         prefill_only=prefill_only,
         enable_qnn=enable_qnn,
         qnn_config=qnn_config,
+        retain_full_kv=retain_full_kv,
     )
     exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR)
     cloud_ai_100_tokens = exec_info.generated_ids[0][
@@ -260,6 +262,24 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
     if not get_available_device_id():
         pytest.skip("No available devices to run model on Cloud AI 100")

+    compiler_options = {}
+    if prompt_len == 1:
+        prefill_spec = {
+            "batch_size": batch_size,
+            "seq_len": 1,
+            "ctx_len": ctx_len,
+            "full_batch_size": full_batch_size,
+            "sliding_window": 128,
+        }
+        decode_spec = {
+            "batch_size": full_batch_size,
+            "seq_len": 1,
+            "ctx_len": ctx_len,
+            "full_batch_size": full_batch_size,
+            "sliding_window": 128,
+        }
+        compiler_options = {"specializations": [prefill_spec, decode_spec]}
+
     # TODO: add prefill_only tests
     qpc_path = qeff_model.compile(
         prefill_seq_len=prompt_len,
@@ -267,10 +287,13 @@ def check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
         num_cores=14,
         mxfp6=False,
         aic_enable_depth_first=False,
+        batch_size=batch_size,
         full_batch_size=full_batch_size,
         num_speculative_tokens=num_speculative_tokens,
         enable_qnn=enable_qnn,
         qnn_config=qnn_config,
+        retain_full_kv=retain_full_kv,
+        **compiler_options,
     )

     exec_info_fbs = qeff_model.generate(tokenizer, prompts=fbs_prompts)
@@ -370,6 +393,24 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name):
     )


+@pytest.mark.nightly
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("retain_full_kv", [True, False])
+def test_causal_lm_gpt_oss_pytorch_vs_kv_vs_ort_vs_ai100_pl1(retain_full_kv):
+    """
+    Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model for openai/gpt-oss-20b with a prompt length of 1, both with and without continuous batching.
+    ``Mandatory`` Args:
+        :retain_full_kv (bool): Passed through to ``compile`` via ``check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100``.
+    """
+    model_name = "openai/gpt-oss-20b"
+    n_layer = get_custom_n_layers(model_name)
+    prompt_len = 1
+
+    check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(
+        model_name=model_name, n_layer=n_layer, prompt_len=prompt_len, retain_full_kv=retain_full_kv
+    )
+
+
 @pytest.mark.on_qaic
 @pytest.mark.regular
 @pytest.mark.qnn
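Illustrative sketch (not part of the patch), based on the hunks above: any `specializations` entry left in the extra compiler options is popped inside `compile()` and used verbatim in place of the auto-built prefill/decode specializations, which is how the prompt_len == 1 test path supplies both specializations itself. The model name and shape values below are placeholders.

from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b", continuous_batching=True)
prefill_spec = {"batch_size": 1, "seq_len": 1, "ctx_len": 256, "full_batch_size": 4, "sliding_window": 128}
decode_spec = {"batch_size": 4, "seq_len": 1, "ctx_len": 256, "full_batch_size": 4, "sliding_window": 128}
model.compile(
    prefill_seq_len=1,  # with prefill_seq_len == 1 the decode specialization is no longer auto-built ...
    ctx_len=256,
    num_cores=16,
    full_batch_size=4,
    specializations=[prefill_spec, decode_spec],  # ... so both are passed explicitly and consumed via compiler_options.pop("specializations")
)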
diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py
index 6358940df..bd4db9a77 100644
--- a/tests/transformers/models/test_disagg_mode.py
+++ b/tests/transformers/models/test_disagg_mode.py
@@ -10,13 +10,13 @@
 import numpy as np
 import pytest
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, HybridCache

 from QEfficient import QEFFAutoModelForCausalLM
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers

-model_id = "openai/gpt-oss-120b"  # weights are not required to convert to fp32
+model_id = "openai/gpt-oss-20b"  # weights are not required to convert to fp32

 prompt2 = """
 Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
@@ -190,3 +190,309 @@ def test_disagg_mode_prefill_chunked(model_id, prompt):
     del prefill_session
     # Check QAIC output isclose with QEFF pytorch output
     assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 8e-2
+
+
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("model_id", [model_id])
+@pytest.mark.parametrize("prompt", [prompt1])
+def test_disagg_mode_prefill_only_and_decode_only(model_id, prompt):
+    # Run prefill for original pytorch model
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    PREFILL_SEQ_LEN = 256
+    CTX_LEN = 256
+    inputs = tokenizer(prompt, return_tensors="np", padding=True)
+    padded_len = inputs["input_ids"].shape[1]
+    num_chunks = -(padded_len // -PREFILL_SEQ_LEN)  # ceil divide without float
+    padded_len = num_chunks * PREFILL_SEQ_LEN  # Convert to a multiple of prompt_len
+
+    replace_transformers_quantizers()
+    model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
+    config = model.config
+    inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+    inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+    inputs.pop("token_type_ids", None)
+    inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()}
+    cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN)
+    ins = tokenizer(prompt, return_tensors="pt")
+    orig_out = model(**ins, past_key_values=cache)
+
+    position_ids = inputs["position_ids"]
+    generated_ids = []
+    generation_len = 10
+    out = orig_out
+    for _ in range(1, generation_len):
+        next_token_id = out["logits"][:, -1, :].argmax(-1).reshape(-1, 1)
+        generated_ids.append(next_token_id)
+        position_ids = position_ids.max(1, keepdim=True).values + 1
+        decode_inputs = {
+            "input_ids": next_token_id,
+            "position_ids": position_ids,
+            "past_key_values": out["past_key_values"],
+        }
+        out = model(**decode_inputs)
+
+    generated_ids.append(out["logits"][:, -1, :].argmax(-1).reshape(-1, 1))
+    generated_ids = np.concatenate(generated_ids, axis=1)
+    predicted_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+    print("Original HF Model Outputs (Torch CPU): \n")
+    print("Prompt:", repr(prompt))
+    print("Completion:", repr(predicted_string))
+
+    undo_transformers_quantizers()
+
+    prefill_qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
+    prefill_qeff_model.prefill(enable=True)
+    config = prefill_qeff_model.model.config
+    past_key_values = []
+    for i in range(config.num_hidden_layers):
+        cache_len = 128 if i % 2 == 0 else PREFILL_SEQ_LEN
+        pad_shape = (1, 8, cache_len, 64)
+        past_key = torch.zeros((pad_shape), dtype=torch.float32)
+        past_value = torch.zeros((pad_shape), dtype=torch.float32)
+        pkv = (past_key, past_value)
+        past_key_values.append(pkv)
+    inputs["past_key_values"] = past_key_values
+
+    prefill_qeff_out = prefill_qeff_model.model(**inputs)
+
+    # Check our pytorch implementation
+    assert (prefill_qeff_out.logits - orig_out.logits[:, -1, :]).abs().max() < 1e-4
+
+    decode_qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
+    decode_qeff_model.prefill(enable=False)
+    qeff_out = prefill_qeff_out
+
+    position_ids = inputs["position_ids"]
+    qeff_generated_ids = []
+    for _ in range(1, generation_len):
+        next_token_id = qeff_out["logits"][:, -1, :].argmax(-1).reshape(-1, 1)
+        qeff_generated_ids.append(next_token_id)
+        position_ids = position_ids.max(1, keepdim=True).values + 1
+        decode_inputs = {
+            "input_ids": next_token_id,
"position_ids": position_ids, + "past_key_values": qeff_out["past_key_values"], + } + qeff_out = decode_qeff_model.model(**decode_inputs) + + qeff_generated_ids.append(out["logits"][:, -1, :].argmax(-1).reshape(-1, 1)) + qeff_generated_ids = np.concatenate(qeff_generated_ids, axis=1) + predicted_string = tokenizer.batch_decode(qeff_generated_ids, skip_special_tokens=True) + print("QEFF Transformed Model Outputs (Torch CPU): \n") + print("Prompt:", repr(prompt)) + print("Completion:", repr(predicted_string)) + + assert (qeff_generated_ids == generated_ids).all() + + prefill_qpc_path = prefill_qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + ) + + prefill_session = QAICInferenceSession(prefill_qpc_path) + logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + inputs.pop("past_key_values") + inputs = {k: v.detach().numpy() for k, v in inputs.items()} + qpc_out = prefill_session.run(inputs) + del prefill_session + # Check QAIC output isclose with QEFF pytorch output + assert (torch.from_numpy(qpc_out["logits"]) - prefill_qeff_out.logits).abs().max() < 5e-2 + + decode_qpc_path = decode_qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + ) + + qpc_outputs = [] + decode_session = QAICInferenceSession(decode_qpc_path) + decode_session.set_buffers({"logits": logits_out_placeholder}) + + decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, + } + + qpc_outputs.append(decode_inputs["input_ids"][0][0]) + for i in range(config.num_hidden_layers): + if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window: + k = qpc_out[f"past_key.{i}_RetainedState"] + v = qpc_out[f"past_value.{i}_RetainedState"] + mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window + decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2) + decode_inputs[f"past_value.{i}"] = np.concatenate( + (v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2 + ) + else: + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + + decode_out = decode_session.run(decode_inputs) + decode_session.skip_buffers( + [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] + ) + pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 + for i in range(generation_len - 1): + loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + qpc_outputs.append(loop_decode_inputs["input_ids"][0][0]) + decode_out = decode_session.run(loop_decode_inputs) + pos_id += 1 + + print("QPC Outputs (AIC): \n") + print("Prompt:", repr(prompt)) + print("Completion:", repr(tokenizer.decode(qpc_outputs))) + assert (qeff_generated_ids == qpc_outputs).all() + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_id", [model_id]) 
+@pytest.mark.parametrize("prompt", [prompt1]) +def test_disagg_mode_prefix_caching(model_id, prompt): + PREFILL_SEQ_LEN = 128 + CTX_LEN = 128 * 3 + config = AutoConfig.from_pretrained(model_id, num_hidden_layers=2) + prefill_qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, num_hidden_layers=2, continuous_batching=True + ) + prefill_qeff_model.prefill(enable=True, enable_chunking=True) + prefill_qpc_path = prefill_qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + full_batch_size=1, + kv_cache_batch_size=2, + ) + + decode_qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, num_hidden_layers=2, continuous_batching=True + ) + decode_qeff_model.prefill(enable=False) + decode_qpc_path = decode_qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + full_batch_size=1, + kv_cache_batch_size=2, + retain_full_kv=True, + ) + + out1, ids1 = prefix_caching_inference(model_id, prefill_qpc_path, decode_qpc_path, prompt, decode_batch_id=0) + out2, ids2 = prefix_caching_inference(model_id, prefill_qpc_path, decode_qpc_path, prompt, decode_batch_id=1) + + for i in range(config.num_hidden_layers): + assert ( + np.abs( + out1[f"past_key.{i}_RetainedState"][0, :, :, :] - out2[f"past_key.{i}_RetainedState"][1, :, :, :] + ).max() + < 5e-2 + ) + assert ( + np.abs( + out1[f"past_value.{i}_RetainedState"][0, :, :, :] - out2[f"past_value.{i}_RetainedState"][1, :, :, :] + ).max() + < 5e-2 + ) + + +def prefix_caching_inference(model_id, prefill_qpc_path, decode_qpc_path, prompt, decode_batch_id): + PREFILL_SEQ_LEN = 128 + tokenizer = AutoTokenizer.from_pretrained(model_id) + config = AutoConfig.from_pretrained(model_id, num_hidden_layers=2) + inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs["batch_index"] = np.array([[decode_batch_id]], dtype=np.int64) + + prefill_session = QAICInferenceSession(prefill_qpc_path) + logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + qpc_out = prefill_session.run(chunk_inputs) + del prefill_session + + qpc_outputs = [] + decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, + "batch_index": inputs["batch_index"], + } + qpc_outputs.append(decode_inputs["input_ids"][0][0]) + + decode_session = 
+    decode_session.set_buffers({"logits": logits_out_placeholder})
+    generation_len = 5
+
+    for i in range(config.num_hidden_layers):
+        if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window:
+            k = qpc_out[f"past_key.{i}_RetainedState"]
+            v = qpc_out[f"past_value.{i}_RetainedState"]
+            mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window
+            decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2)
+            decode_inputs[f"past_value.{i}"] = np.concatenate(
+                (v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2
+            )
+        else:
+            decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
+            decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]
+
+    decode_out = decode_session.run(decode_inputs)
+    pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1
+    for i in range(generation_len - 1):
+        loop_decode_inputs = {
+            "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
+            "position_ids": pos_id,
+            "batch_index": inputs["batch_index"],
+        }
+        for i in range(config.num_hidden_layers):
+            loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
+            loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]
+        qpc_outputs.append(loop_decode_inputs["input_ids"][0][0])
+        decode_out = decode_session.run(loop_decode_inputs)
+        pos_id += 1
+
+    print("QPC Outputs (AIC): \n")
+    print("Prompt:", repr(prompt))
+    print("Completion:", repr(tokenizer.decode(qpc_outputs)))
+    return qpc_out, qpc_outputs
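Illustrative sketch (not part of the patch): both new tests hand the prefill QPC's retained sliding-window caches to the decode QPC after splitting them at the `mod_pos_id` offset and moving the tail of the buffer to the front. The same rotation, factored into a standalone NumPy helper; the helper name is hypothetical, and the tests inline this logic, applying it only to the even-indexed (sliding-window) layers once the next position reaches `config.sliding_window`.

import numpy as np


def rotate_sliding_window_cache(retained: np.ndarray, next_position_id: int, sliding_window: int) -> np.ndarray:
    # retained has shape [batch, heads, sliding_window, head_dim].
    if next_position_id < sliding_window:
        return retained  # window not yet full; no rotation needed
    # Same split the tests compute as `mod_pos_id`: the tail of the retained buffer
    # is moved to the front along the cache-length axis before the first decode step.
    split = sliding_window - next_position_id % sliding_window
    return np.concatenate((retained[:, :, split:, :], retained[:, :, :split, :]), axis=-2)

In the tests this corresponds to the per-layer handling of `past_key.{i}_RetainedState` and `past_value.{i}_RetainedState` before the first `decode_session.run(...)` call.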