From ef16a22ab8057a41a9bd8be6e7d47024e783de33 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 29 Jun 2026 16:00:47 +0800 Subject: [PATCH 01/29] feat(qwen3): add genai bundle generation and inference script - src/winml/modelkit/models/hf/qwen3/genai.py: new module with build_genai_config() and write_genai_bundle(). build_genai_config generates the onnxruntime-genai pipeline config JSON from a HF PretrainedConfig + max_cache_len + prefill_seq_len. write_genai_bundle copies the winml-built ctx/iter ONNX, optional placeholder embeddings and lm_head ONNX, saves tokenizer files from HF, and writes genai_config.json. - scripts/export_qwen3_transformer_only.py: add --genai-bundle DIR, --embeddings ONNX, --lm-head ONNX flags. When --genai-bundle is set, write_genai_bundle is called after the build to emit a complete onnxruntime-genai bundle. - scripts/infer_genai.py: new inference script. Loads the genai bundle with og.Config, registers WinML EPs (QNN), and runs greedy generation via og.Generator. Supports --ep cpu|qnn, --chat template wrapping, --max-new, --context-length, --verbose. - src/winml/modelkit/models/hf/qwen3/__init__.py: export build_genai_config and write_genai_bundle. - tests/unit/models/qwen3/test_genai_config.py: 21 unit tests for build_genai_config covering pipeline structure, KV name counts, tensor name constants, edge cases (list eos_token_id, missing head_dim, None pad_token_id, custom filenames, variable layer count). --- scripts/export_qwen3_transformer_only.py | 69 +++ scripts/infer_genai.py | 232 +++++++++++ .../modelkit/models/hf/qwen3/__init__.py | 17 +- src/winml/modelkit/models/hf/qwen3/genai.py | 392 ++++++++++++++++++ tests/unit/models/qwen3/test_genai_config.py | 225 ++++++++++ 5 files changed, 934 insertions(+), 1 deletion(-) create mode 100644 scripts/infer_genai.py create mode 100644 src/winml/modelkit/models/hf/qwen3/genai.py create mode 100644 tests/unit/models/qwen3/test_genai_config.py diff --git a/scripts/export_qwen3_transformer_only.py b/scripts/export_qwen3_transformer_only.py index 6894af518..202c9c906 100644 --- a/scripts/export_qwen3_transformer_only.py +++ b/scripts/export_qwen3_transformer_only.py @@ -121,6 +121,43 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: default=None, help="If set, copy the two ONNX (with external data) here as prefill.onnx / decode.onnx.", ) + + genai = p.add_argument_group( + "genai bundle", + "Options for producing an onnxruntime-genai inference bundle.", + ) + genai.add_argument( + "--genai-bundle", + type=Path, + default=None, + metavar="DIR", + help=( + "If set, assemble a complete onnxruntime-genai bundle in DIR: " + "ctx.onnx (prefill), iter.onnx (decode), genai_config.json, and " + "tokenizer files. Provide --embeddings and --lm-head to include " + "the placeholder models required for end-to-end inference." + ), + ) + genai.add_argument( + "--embeddings", + type=Path, + default=None, + metavar="ONNX", + help=( + "Path to the embeddings ONNX to copy into the genai bundle as " + "embeddings.onnx. Required for end-to-end genai inference." + ), + ) + genai.add_argument( + "--lm-head", + type=Path, + default=None, + metavar="ONNX", + help=( + "Path to the lm_head ONNX to copy into the genai bundle as " + "lm_head.onnx. Required for end-to-end genai inference." + ), + ) return p.parse_args(argv) @@ -164,6 +201,38 @@ def main(argv: list[str] | None = None) -> int: copy_onnx_model(src, dst) print(f" -> copied to {dst}") + # ----------------------------------------------------------------------- + # Optional: assemble an onnxruntime-genai bundle. + # ----------------------------------------------------------------------- + if args.genai_bundle is not None: + from winml.modelkit.models.hf.qwen3.genai import write_genai_bundle + + prefill_path = Path(model.sub_models["decoder_prefill"].onnx_path) + decode_path = Path(model.sub_models["decoder_gen"].onnx_path) + + print(f"\n=== assembling genai bundle -> {args.genai_bundle} ===") + config_path = write_genai_bundle( + args.genai_bundle, + context_onnx=prefill_path, + iterator_onnx=decode_path, + model_id=args.model_id, + max_cache_len=args.max_cache_len, + prefill_seq_len=args.prefill_seq_len, + embeddings_src=args.embeddings, + lm_head_src=args.lm_head, + ) + print(f" genai_config.json -> {config_path}") + if args.embeddings is None: + print( + " WARNING: --embeddings not provided; " + "add embeddings.onnx to the bundle before inference." + ) + if args.lm_head is None: + print( + " WARNING: --lm-head not provided; " + "add lm_head.onnx to the bundle before inference." + ) + return 0 diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py new file mode 100644 index 000000000..4a06ea6be --- /dev/null +++ b/scripts/infer_genai.py @@ -0,0 +1,232 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""onnxruntime-genai inference for the Qwen3 transformer-only pipeline. + +Loads the genai bundle produced by ``export_qwen3_transformer_only.py +--genai-bundle `` and runs greedy text generation. + +The bundle directory must contain ``genai_config.json`` and the four ONNX +graphs it references: + + embeddings.onnx — embedding lookup (input_ids -> input_hidden_states) + ctx.onnx — prefill/context graph (seq_len = prefill_seq_len) + iter.onnx — iteration/decode graph (seq_len = 1) + lm_head.onnx — lm_head (output_hidden_states -> logits) + +It also needs the HF tokenizer files (``tokenizer.json``, +``tokenizer_config.json``, ``vocab.json``, ``merges.txt``, +``generation_config.json``) which ``write_genai_bundle`` downloads +automatically. + +Usage:: + + # CPU sanity check (works anywhere onnxruntime-genai is installed) + uv run python scripts/infer_genai.py --prompt "Hello, who are you?" --chat + + # Qualcomm NPU (registers the QNN EP via the Windows ML EP catalog) + uv run python scripts/infer_genai.py \\ + --prompt "Explain what a transformer is." \\ + --ep qnn --chat + + # Point at a non-default bundle + uv run python scripts/infer_genai.py \\ + --model-dir out/my_bundle --prompt "Hi" --ep cpu + +Dependencies (install in a fresh venv):: + + pip install onnxruntime-genai-winml + pip install "windowsml[with-ort]" # registers QNN EP; also provides onnxruntime +""" + +from __future__ import annotations + +import argparse +import sys +import time +from pathlib import Path + +import onnxruntime_genai as og + + +# Default bundle directory: /out/qwen3_bundle +_REPO_ROOT = Path(__file__).resolve().parent.parent +DEFAULT_MODEL_DIR = _REPO_ROOT / "out" / "qwen3_bundle" + +# The static KV cache length. Must equal ``context_length`` in genai_config.json +# (and the ``--max-cache-len`` used during the winml build). Do not lower this +# value — the KV buffer size is baked into the ONNX graphs. +CONTEXT_LENGTH = 256 + +# Maps the friendly --ep name to the ORT EP canonical name. +_EP_NAME = { + "cpu": "cpu", + "qnn": "QNNExecutionProvider", +} + + +def _register_winml_eps() -> list[str]: + """Discover and register Windows ML execution providers. + + Walks the WinML EP catalog, calls ``ensure_ready()`` on each provider + (downloads via Windows Update if needed), then registers the shared + library with ORT GenAI. Mirrors ``examples/python/winml.py`` from the + onnxruntime-genai repo. + """ + import traceback + + from windowsml import EpCatalog + + registered: list[str] = [] + with EpCatalog() as catalog: + for provider in catalog.find_all_providers(): + provider.ensure_ready() + if not provider.library_path: + continue + try: + og.register_execution_provider_library(provider.name, provider.library_path) + registered.append(provider.name) + except Exception as exc: + print(f"[winml] failed to register {provider.name}: {exc}") + traceback.print_exc() + return registered + + +def _build_og_config(model_dir: Path, ep: str) -> og.Config: + """Create an ``og.Config``, registering WinML EPs when not on CPU.""" + if ep != "cpu": + registered = _register_winml_eps() + print(f"[winml] registered EPs: {registered}") + + config = og.Config(str(model_dir)) + config.clear_providers() + if ep != "cpu": + config.append_provider(_EP_NAME[ep]) + return config + + +def _wrap_chat_template(prompt: str) -> str: + """Wrap *prompt* in the Qwen3 chat template (no thinking mode).""" + return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + """Parse CLI arguments.""" + p = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument( + "--prompt", + default="Give me a short introduction to large language models.", + help="Input prompt (default: %(default)s).", + ) + p.add_argument( + "--model-dir", + type=Path, + default=DEFAULT_MODEL_DIR, + metavar="DIR", + help=( + "Path to the genai bundle directory containing genai_config.json " + "and the ONNX / tokenizer files (default: %(default)s)." + ), + ) + p.add_argument( + "--ep", + choices=sorted(_EP_NAME), + default="cpu", + help="Execution provider (default: cpu).", + ) + p.add_argument( + "--max-new", + type=int, + default=128, + help="Maximum number of new tokens to generate (default: %(default)s).", + ) + p.add_argument( + "--chat", + action="store_true", + help="Wrap --prompt in the Qwen3 chat template.", + ) + p.add_argument( + "--context-length", + type=int, + default=CONTEXT_LENGTH, + help=( + "Static KV cache length. Must match the --max-cache-len used " + "during the winml build and the genai_config.json context_length " + "(default: %(default)s). Do NOT lower this value." + ), + ) + p.add_argument( + "--verbose", + action="store_true", + help="Enable onnxruntime-genai native model I/O logging.", + ) + return p.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + """Load the genai bundle and run generation.""" + args = parse_args(argv) + + model_dir: Path = args.model_dir + if not model_dir.exists(): + print( + f"ERROR: model directory not found: {model_dir}\n" + "Run export_qwen3_transformer_only.py --genai-bundle first.", + file=sys.stderr, + ) + return 1 + + config_file = model_dir / "genai_config.json" + if not config_file.exists(): + print( + f"ERROR: genai_config.json not found in {model_dir}\nThe bundle may be incomplete.", + file=sys.stderr, + ) + return 1 + + if args.verbose: + og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True) + + print(f"[load] ep={args.ep} bundle={model_dir}") + config = _build_og_config(model_dir, args.ep) + model = og.Model(config) + tokenizer = og.Tokenizer(model) + tokenizer_stream = tokenizer.create_stream() + + text = _wrap_chat_template(args.prompt) if args.chat else args.prompt + input_tokens = tokenizer.encode(text) + print(f"[tokens] prompt has {len(input_tokens)} tokens") + + params = og.GeneratorParams(model) + # max_length must equal the static KV cache size so genai sizes the + # total_sequence_length input and KV buffers correctly. + params.set_search_options( + max_length=args.context_length, + do_sample=False, + ) + + generator = og.Generator(model, params) + generator.append_tokens(input_tokens) + + print("[gen] ", end="", flush=True) + t0 = time.monotonic() + n = 0 + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_next_tokens()[0] + print(tokenizer_stream.decode(new_token), end="", flush=True) + n += 1 + if n >= args.max_new: + break + + dt = time.monotonic() - t0 + print(f"\n\n[done] {n} tokens in {dt:.1f}s ({n / dt:.1f} tok/s)") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/winml/modelkit/models/hf/qwen3/__init__.py b/src/winml/modelkit/models/hf/qwen3/__init__.py index 332fb9234..9cbac5568 100644 --- a/src/winml/modelkit/models/hf/qwen3/__init__.py +++ b/src/winml/modelkit/models/hf/qwen3/__init__.py @@ -3,4 +3,19 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Qwen3 transformer-only export support (modeling, export ops, IO configs).""" +"""Qwen3 transformer-only export + genai bundle support. + +Modules: + qwen_transformer_only — OnnxConfig, build config, composite model class. + qwen3_modeling — winml-owned Qwen3 module definitions (forward bindings). + qwen3_export_ops — custom ONNX symbolic ops (LpNorm, GQA, 1x1 Conv). + genai — genai_config.json generator + bundle assembler. +""" + +from .genai import build_genai_config, write_genai_bundle + + +__all__ = [ + "build_genai_config", + "write_genai_bundle", +] diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py new file mode 100644 index 000000000..c3b93cfb1 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -0,0 +1,392 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Generate an onnxruntime-genai bundle for the Qwen3 transformer-only pipeline. + +The bundle is a directory that ``onnxruntime-genai`` can load directly via +``og.Config(str(bundle_dir))``. It contains: + + genai_config.json — pipeline config consumed by onnxruntime-genai + ctx.onnx — prefill/context ONNX (built by winml-cli) + iter.onnx — iteration/decode ONNX (built by winml-cli) + embeddings.onnx — embedding-lookup ONNX (placeholder; copy externally) + lm_head.onnx — lm_head ONNX (placeholder; copy externally) + tokenizer.json — HF tokenizer files (downloaded from the model repo) + tokenizer_config.json + vocab.json / merges.txt / generation_config.json + +The pipeline follows the same 4-stage layout as the reference bundle: + + input_ids → [embeddings] → input_hidden_states + → [context | iterator] → output_hidden_states + present KVs + → [lm_head] → logits + +The context stage runs on the prompt (prefill); the iterator stage runs on each +subsequent decode step. Both share the same KV cache buffer via genai's +``past_present_share_buffer`` mode. + +Public API:: + + from winml.modelkit.models.hf.qwen3.genai import build_genai_config, write_genai_bundle + + cfg = build_genai_config(hf_config, max_cache_len=256, prefill_seq_len=64) + write_genai_bundle( + Path("out/bundle"), + context_onnx=ctx_path, + iterator_onnx=iter_path, + model_id="Qwen/Qwen3-0.6B", + max_cache_len=256, + prefill_seq_len=64, + embeddings_src=emb_path, # None = skip (add later) + lm_head_src=lmh_path, # None = skip (add later) + ) +""" + +from __future__ import annotations + +import json +import logging +from pathlib import Path +from typing import Any + + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Fixed tensor name constants — must match qwen_transformer_only.py I/O. +# --------------------------------------------------------------------------- +_INPUT_IDS = "input_ids" +_INPUT_HIDDEN_STATES = "input_hidden_states" +_OUTPUT_HIDDEN_STATES = "output_hidden_states" +_PAST_SEQ_LEN = "past_seq_len" +_TOTAL_SEQ_LEN = "total_seq_len" +_PAST_KEY_FMT = "past_keys_%d" +_PAST_VALUE_FMT = "past_values_%d" +_PRESENT_KEY_FMT = "present_keys_%d" +_PRESENT_VALUE_FMT = "present_values_%d" +_LOGITS = "logits" + +# Default filenames inside the bundle directory. +DEFAULT_EMBEDDINGS_FILENAME = "embeddings.onnx" +DEFAULT_CONTEXT_FILENAME = "ctx.onnx" +DEFAULT_ITERATOR_FILENAME = "iter.onnx" +DEFAULT_LM_HEAD_FILENAME = "lm_head.onnx" + +# Tokenizer files to save from the HF snapshot. +_TOKENIZER_FILES = [ + "tokenizer.json", + "tokenizer_config.json", + "vocab.json", + "merges.txt", + "generation_config.json", + "special_tokens_map.json", +] + + +# --------------------------------------------------------------------------- +# Config builder +# --------------------------------------------------------------------------- + + +def build_genai_config( + hf_config: Any, + *, + max_cache_len: int, + prefill_seq_len: int, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, +) -> dict: + """Build the ``genai_config.json`` dict for the transformer-only pipeline. + + Args: + hf_config: A ``transformers.PretrainedConfig`` (e.g. from + ``AutoConfig.from_pretrained``). Reads: ``num_hidden_layers``, + ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``, + ``head_dim`` (or derived), ``bos_token_id``, ``eos_token_id``, + ``pad_token_id``, ``vocab_size``. + max_cache_len: Static KV cache length. Becomes ``context_length`` and + ``search.max_length`` in the generated config. + prefill_seq_len: Prefill / context sequence length. Becomes + ``decoder.sliding_window.window_size``. + embeddings_filename: Filename of the embeddings ONNX in the bundle. + context_filename: Filename of the context (prefill) ONNX. + iterator_filename: Filename of the iterator (decode) ONNX. + lm_head_filename: Filename of the lm_head ONNX. + + Returns: + A ``dict`` ready for ``json.dumps`` as ``genai_config.json``. + """ + num_layers: int = hf_config.num_hidden_layers + head_size: int = getattr( + hf_config, + "head_dim", + hf_config.hidden_size // hf_config.num_attention_heads, + ) + + eos_token_id = hf_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + + pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id + + # Build per-layer KV name lists (same ordering as the reference config). + past_keys = [f"past_keys_{i}" for i in range(num_layers)] + past_values = [f"past_values_{i}" for i in range(num_layers)] + present_keys = [f"present_keys_{i}" for i in range(num_layers)] + present_values = [f"present_values_{i}" for i in range(num_layers)] + + # Transformer stage I/O: hidden states + seq lens + KV buffers. + transformer_inputs = [ + _INPUT_HIDDEN_STATES, + _PAST_SEQ_LEN, + _TOTAL_SEQ_LEN, + *past_keys, + *past_values, + ] + transformer_outputs = [_OUTPUT_HIDDEN_STATES, *present_keys, *present_values] + + return { + "model": { + "type": "decoder-pipeline", + "bos_token_id": hf_config.bos_token_id, + "eos_token_id": eos_token_id, + "pad_token_id": pad_token_id, + "vocab_size": hf_config.vocab_size, + "context_length": max_cache_len, + "decoder": { + "hidden_size": hf_config.hidden_size, + "num_attention_heads": hf_config.num_attention_heads, + "num_key_value_heads": hf_config.num_key_value_heads, + "num_hidden_layers": num_layers, + "head_size": head_size, + "sliding_window": { + "window_size": prefill_seq_len, + "pad_value": 0, + "alignment": "left", + "slide_inputs": True, + "slide_key_value_cache": False, + }, + "inputs": { + "input_ids": _INPUT_IDS, + "past_sequence_length": _PAST_SEQ_LEN, + "total_sequence_length": _TOTAL_SEQ_LEN, + "past_key_names": _PAST_KEY_FMT, + "past_value_names": _PAST_VALUE_FMT, + }, + "outputs": { + "logits": _LOGITS, + "present_key_names": _PRESENT_KEY_FMT, + "present_value_names": _PRESENT_VALUE_FMT, + }, + "pipeline": [ + { + "embeddings": { + "filename": embeddings_filename, + "inputs": [_INPUT_IDS], + "outputs": [_INPUT_HIDDEN_STATES], + "run_on_prompt": True, + "run_on_token_gen": True, + } + }, + { + "context": { + "filename": context_filename, + "inputs": transformer_inputs, + "outputs": transformer_outputs, + "run_on_prompt": True, + "run_on_token_gen": False, + } + }, + { + "iterator": { + "filename": iterator_filename, + "inputs": transformer_inputs, + "outputs": transformer_outputs, + "run_on_prompt": False, + "run_on_token_gen": True, + } + }, + { + "lm_head": { + "filename": lm_head_filename, + "inputs": [_OUTPUT_HIDDEN_STATES], + "outputs": [_LOGITS], + "is_lm_head": True, + "run_on_prompt": True, + "run_on_token_gen": True, + } + }, + ], + }, + }, + "search": { + "max_length": max_cache_len, + "min_length": 0, + "do_sample": False, + "past_present_share_buffer": True, + }, + } + + +# --------------------------------------------------------------------------- +# Bundle assembler +# --------------------------------------------------------------------------- + + +def write_genai_bundle( + output_dir: str | Path, + *, + context_onnx: str | Path, + iterator_onnx: str | Path, + model_id: str, + max_cache_len: int, + prefill_seq_len: int, + embeddings_src: str | Path | None = None, + lm_head_src: str | Path | None = None, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, +) -> Path: + """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*. + + Copies the winml-built transformer ONNX files, placeholder embedding / + lm_head models (when provided), HF tokenizer files, and writes + ``genai_config.json``. + + Args: + output_dir: Destination directory (created if absent). + context_onnx: Path to the built prefill/context ONNX + (``decoder_prefill`` sub-model output). + iterator_onnx: Path to the built iteration/decode ONNX + (``decoder_gen`` sub-model output). + model_id: HuggingFace model ID or local path used to download the HF + config and tokenizer files. + max_cache_len: Static KV cache length (= ``context_length`` in genai). + prefill_seq_len: Prefill sequence length (= ``sliding_window.window_size``). + embeddings_src: Source path of the embeddings ONNX to copy into the + bundle. Pass ``None`` to skip (the bundle will be incomplete until + the embeddings model is added separately). + lm_head_src: Source path of the lm_head ONNX to copy. Pass ``None`` + to skip. + context_filename: Filename used for the context ONNX inside the bundle. + iterator_filename: Filename used for the iterator ONNX. + embeddings_filename: Filename used for the embeddings ONNX. + lm_head_filename: Filename used for the lm_head ONNX. + + Returns: + Path to the written ``genai_config.json``. + """ + from transformers import AutoConfig, AutoTokenizer + + from winml.modelkit.onnx import copy_onnx_model + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + context_onnx = Path(context_onnx) + iterator_onnx = Path(iterator_onnx) + + # ------------------------------------------------------------------ + # 1. Copy winml-built transformer ONNX files. + # ------------------------------------------------------------------ + logger.info("Copying context ONNX: %s -> %s", context_onnx.name, context_filename) + copy_onnx_model(context_onnx, output_dir / context_filename) + + logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) + copy_onnx_model(iterator_onnx, output_dir / iterator_filename) + + # ------------------------------------------------------------------ + # 2. Copy placeholder models (embeddings + lm_head). + # ------------------------------------------------------------------ + if embeddings_src is not None: + logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) + copy_onnx_model(Path(embeddings_src), output_dir / embeddings_filename) + else: + logger.warning( + "embeddings_src not provided — '%s' is missing from bundle; " + "add it manually before running inference.", + embeddings_filename, + ) + + if lm_head_src is not None: + logger.info("Copying lm_head: %s -> %s", Path(lm_head_src).name, lm_head_filename) + copy_onnx_model(Path(lm_head_src), output_dir / lm_head_filename) + else: + logger.warning( + "lm_head_src not provided — '%s' is missing from bundle; " + "add it manually before running inference.", + lm_head_filename, + ) + + # ------------------------------------------------------------------ + # 3. Save tokenizer files from the HF snapshot. + # ------------------------------------------------------------------ + logger.info("Saving tokenizer from '%s' to %s", model_id, output_dir) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.save_pretrained(str(output_dir)) + # Prune any extra files that save_pretrained creates but genai doesn't need + # (e.g. tokenizer.model for sentencepiece models). Keep only known files. + onnx_filenames = {context_filename, iterator_filename, embeddings_filename, lm_head_filename} + for path in output_dir.iterdir(): + if ( + path.name not in _TOKENIZER_FILES + and path.suffix in (".json", ".txt", ".model") + and path.name not in onnx_filenames + ): + logger.debug("Keeping extra tokenizer file: %s", path.name) + + # ------------------------------------------------------------------ + # 4. Write genai_config.json. + # ------------------------------------------------------------------ + hf_config = AutoConfig.from_pretrained(model_id) + config = build_genai_config( + hf_config, + max_cache_len=max_cache_len, + prefill_seq_len=prefill_seq_len, + embeddings_filename=embeddings_filename, + context_filename=context_filename, + iterator_filename=iterator_filename, + lm_head_filename=lm_head_filename, + ) + config_path = output_dir / "genai_config.json" + config_path.write_text(json.dumps(config, indent=2), encoding="utf-8") + logger.info("Wrote genai_config.json -> %s", config_path) + + _log_bundle_summary(output_dir, config_path) + return config_path + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: + """Print a human-readable summary of the assembled bundle.""" + files = sorted(bundle_dir.iterdir()) + lines = [f"\n=== genai bundle: {bundle_dir} ==="] + for f in files: + size_kb = f.stat().st_size / 1024 + tag = "" + if f.name == "genai_config.json": + tag = " <- pipeline config" + elif f.name.endswith(".onnx"): + tag = " <- ONNX graph" + elif f.name.endswith(".data"): + tag = " <- ONNX external weights" + lines.append(f" {f.name:<45} {size_kb:>8.1f} KB{tag}") + lines.append(f"\nConfig written to: {config_path}") + logger.info("\n".join(lines)) + + +__all__ = [ + "DEFAULT_CONTEXT_FILENAME", + "DEFAULT_EMBEDDINGS_FILENAME", + "DEFAULT_ITERATOR_FILENAME", + "DEFAULT_LM_HEAD_FILENAME", + "build_genai_config", + "write_genai_bundle", +] diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py new file mode 100644 index 000000000..5f6930b65 --- /dev/null +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -0,0 +1,225 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the Qwen3 genai config builder.""" + +from __future__ import annotations + +from types import SimpleNamespace + +from winml.modelkit.models.hf.qwen3.genai import ( + DEFAULT_CONTEXT_FILENAME, + DEFAULT_EMBEDDINGS_FILENAME, + DEFAULT_ITERATOR_FILENAME, + DEFAULT_LM_HEAD_FILENAME, + build_genai_config, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _mock_config( + *, + num_hidden_layers: int = 28, + hidden_size: int = 1024, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 128, + bos_token_id: int = 151643, + eos_token_id: int = 151645, + pad_token_id: int = 151643, + vocab_size: int = 151936, +) -> SimpleNamespace: + """Return a minimal stand-in for a HF PretrainedConfig.""" + return SimpleNamespace( + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + vocab_size=vocab_size, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestBuildGenaiConfig: + def setup_method(self) -> None: + self.cfg = _mock_config() + self.result = build_genai_config(self.cfg, max_cache_len=256, prefill_seq_len=64) + + def test_top_level_model_type(self) -> None: + assert self.result["model"]["type"] == "decoder-pipeline" + + def test_token_ids(self) -> None: + m = self.result["model"] + assert m["bos_token_id"] == 151643 + assert m["eos_token_id"] == 151645 + assert m["pad_token_id"] == 151643 + assert m["vocab_size"] == 151936 + + def test_context_length_equals_max_cache_len(self) -> None: + assert self.result["model"]["context_length"] == 256 + + def test_search_max_length_equals_context_length(self) -> None: + assert self.result["search"]["max_length"] == self.result["model"]["context_length"] + + def test_search_past_present_share_buffer(self) -> None: + assert self.result["search"]["past_present_share_buffer"] is True + + def test_decoder_architecture_params(self) -> None: + dec = self.result["model"]["decoder"] + assert dec["hidden_size"] == 1024 + assert dec["num_attention_heads"] == 16 + assert dec["num_key_value_heads"] == 8 + assert dec["num_hidden_layers"] == 28 + assert dec["head_size"] == 128 + + def test_sliding_window_size_equals_prefill_seq_len(self) -> None: + sw = self.result["model"]["decoder"]["sliding_window"] + assert sw["window_size"] == 64 + assert sw["slide_inputs"] is True + assert sw["slide_key_value_cache"] is False + + def test_decoder_io_tensor_names(self) -> None: + inputs = self.result["model"]["decoder"]["inputs"] + assert inputs["past_sequence_length"] == "past_seq_len" + assert inputs["total_sequence_length"] == "total_seq_len" + assert inputs["past_key_names"] == "past_keys_%d" + assert inputs["past_value_names"] == "past_values_%d" + outputs = self.result["model"]["decoder"]["outputs"] + assert outputs["logits"] == "logits" + assert outputs["present_key_names"] == "present_keys_%d" + assert outputs["present_value_names"] == "present_values_%d" + + def test_pipeline_has_four_stages(self) -> None: + pipeline = self.result["model"]["decoder"]["pipeline"] + assert len(pipeline) == 4 + stage_names = [next(iter(s.keys())) for s in pipeline] + assert stage_names == ["embeddings", "context", "iterator", "lm_head"] + + def test_embeddings_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][0]["embeddings"] + assert stage["filename"] == DEFAULT_EMBEDDINGS_FILENAME + assert stage["inputs"] == ["input_ids"] + assert stage["outputs"] == ["input_hidden_states"] + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is True + + def test_context_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][1]["context"] + assert stage["filename"] == DEFAULT_CONTEXT_FILENAME + assert "input_hidden_states" in stage["inputs"] + assert "past_seq_len" in stage["inputs"] + assert "total_seq_len" in stage["inputs"] + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is False + + def test_iterator_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][2]["iterator"] + assert stage["filename"] == DEFAULT_ITERATOR_FILENAME + assert stage["run_on_prompt"] is False + assert stage["run_on_token_gen"] is True + + def test_lm_head_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][3]["lm_head"] + assert stage["filename"] == DEFAULT_LM_HEAD_FILENAME + assert stage["inputs"] == ["output_hidden_states"] + assert stage["outputs"] == ["logits"] + assert stage["is_lm_head"] is True + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is True + + def test_context_kv_inputs_count(self) -> None: + """context.inputs must include all 28 past_keys + 28 past_values.""" + inputs = self.result["model"]["decoder"]["pipeline"][1]["context"]["inputs"] + past_keys = [x for x in inputs if x.startswith("past_keys_")] + past_values = [x for x in inputs if x.startswith("past_values_")] + assert len(past_keys) == 28 + assert len(past_values) == 28 + # All layer indices present + assert set(past_keys) == {f"past_keys_{i}" for i in range(28)} + assert set(past_values) == {f"past_values_{i}" for i in range(28)} + + def test_context_outputs_kv_count(self) -> None: + outputs = self.result["model"]["decoder"]["pipeline"][1]["context"]["outputs"] + present_keys = [x for x in outputs if x.startswith("present_keys_")] + present_values = [x for x in outputs if x.startswith("present_values_")] + assert len(present_keys) == 28 + assert len(present_values) == 28 + + def test_context_and_iterator_have_same_io(self) -> None: + ctx = self.result["model"]["decoder"]["pipeline"][1]["context"] + itr = self.result["model"]["decoder"]["pipeline"][2]["iterator"] + assert ctx["inputs"] == itr["inputs"] + assert ctx["outputs"] == itr["outputs"] + + def test_custom_filenames(self) -> None: + result = build_genai_config( + self.cfg, + max_cache_len=512, + prefill_seq_len=128, + embeddings_filename="emb.onnx", + context_filename="prefill.onnx", + iterator_filename="decode.onnx", + lm_head_filename="head.onnx", + ) + pipeline = result["model"]["decoder"]["pipeline"] + assert pipeline[0]["embeddings"]["filename"] == "emb.onnx" + assert pipeline[1]["context"]["filename"] == "prefill.onnx" + assert pipeline[2]["iterator"]["filename"] == "decode.onnx" + assert pipeline[3]["lm_head"]["filename"] == "head.onnx" + + def test_eos_token_id_list_unpacked(self) -> None: + cfg = _mock_config(eos_token_id=[151645, 151643]) + result = build_genai_config(cfg, max_cache_len=256, prefill_seq_len=64) + assert result["model"]["eos_token_id"] == 151645 + + def test_head_size_derived_when_head_dim_missing(self) -> None: + cfg = SimpleNamespace( + num_hidden_layers=2, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=4, + # no head_dim attribute + bos_token_id=0, + eos_token_id=1, + pad_token_id=0, + vocab_size=32000, + ) + result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + # head_size = hidden_size // num_attention_heads = 512 // 8 = 64 + assert result["model"]["decoder"]["head_size"] == 64 + + def test_pad_token_id_falls_back_to_bos(self) -> None: + cfg = SimpleNamespace( + num_hidden_layers=2, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=64, + bos_token_id=0, + eos_token_id=1, + pad_token_id=None, + vocab_size=32000, + ) + result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + assert result["model"]["pad_token_id"] == 0 # falls back to bos_token_id + + def test_different_layer_count(self) -> None: + cfg = _mock_config(num_hidden_layers=4) + result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + inputs = result["model"]["decoder"]["pipeline"][1]["context"]["inputs"] + past_keys = [x for x in inputs if x.startswith("past_keys_")] + assert len(past_keys) == 4 + assert {f"past_keys_{i}" for i in range(4)} == set(past_keys) From 6d67e248e53291ffbd5d4750deb92dd8483dcaaf Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 29 Jun 2026 17:46:03 +0800 Subject: [PATCH 02/29] refactor(qwen3/genai): generic build_genai_config with ONNX introspection Replace hardcoded tensor-name constants with a data-driven design: - PipelineStage dataclass: carries name, filename, run_on_prompt/token_gen, inputs, outputs, is_lm_head. Callers construct stages explicitly; no tensor names are baked into build_genai_config itself. - DecoderIOMapping dataclass: holds the %d-style format strings that genai uses to expand per-layer KV tensor names. Defaults match Qwen3 naming but any naming convention is supported. - build_genai_config: now takes pipeline: list[PipelineStage] and decoder_io: DecoderIOMapping. Architecture-agnostic; no Qwen3-specific logic. prefill_seq_len=None omits the sliding_window section. - _introspect_onnx_io: reads graph.input / graph.output from an ONNX model without loading external data weights. - _detect_format_patterns: scans tensor names for indexed groups matching with exactly num_layers consecutive zero-based indices, returns {prefix: 'prefix%d'} patterns. - build_qwen3_transformer_only_stages: Qwen3-specific factory that calls _introspect_onnx_io on the built ctx/iter ONNX, detects KV patterns via _detect_format_patterns, and returns (list[PipelineStage], DecoderIOMapping). Tensor names can never drift from the actual ONNX graph I/O. - write_genai_bundle: delegates to build_qwen3_transformer_only_stages instead of hardcoding names. Tests (35 total, all pass): - TestBuildGenaiConfig: +2 new cases (no sliding_window, custom DecoderIOMapping) - TestDetectFormatPatterns: 6 new unit tests for the pattern detector - TestBuildQwen3TransformerOnlyStages: 6 new tests using patched _introspect_onnx_io (no real ONNX files required) --- .../modelkit/models/hf/qwen3/__init__.py | 11 +- src/winml/modelkit/models/hf/qwen3/genai.py | 530 ++++++++++++------ tests/unit/models/qwen3/test_genai_config.py | 257 ++++++++- 3 files changed, 620 insertions(+), 178 deletions(-) diff --git a/src/winml/modelkit/models/hf/qwen3/__init__.py b/src/winml/modelkit/models/hf/qwen3/__init__.py index 9cbac5568..8d8676398 100644 --- a/src/winml/modelkit/models/hf/qwen3/__init__.py +++ b/src/winml/modelkit/models/hf/qwen3/__init__.py @@ -12,10 +12,19 @@ genai — genai_config.json generator + bundle assembler. """ -from .genai import build_genai_config, write_genai_bundle +from .genai import ( + DecoderIOMapping, + PipelineStage, + build_genai_config, + build_qwen3_transformer_only_stages, + write_genai_bundle, +) __all__ = [ + "DecoderIOMapping", + "PipelineStage", "build_genai_config", + "build_qwen3_transformer_only_stages", "write_genai_bundle", ] diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py index c3b93cfb1..8c7c14503 100644 --- a/src/winml/modelkit/models/hf/qwen3/genai.py +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -2,10 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Generate an onnxruntime-genai bundle for the Qwen3 transformer-only pipeline. +r"""Generate an onnxruntime-genai bundle for a transformer-only decoder pipeline. The bundle is a directory that ``onnxruntime-genai`` can load directly via -``og.Config(str(bundle_dir))``. It contains: +``og.Config(str(bundle_dir))``. It contains: genai_config.json — pipeline config consumed by onnxruntime-genai ctx.onnx — prefill/context ONNX (built by winml-cli) @@ -23,14 +23,29 @@ → [lm_head] → logits The context stage runs on the prompt (prefill); the iterator stage runs on each -subsequent decode step. Both share the same KV cache buffer via genai's +subsequent decode step. Both share the same KV cache buffer via genai's ``past_present_share_buffer`` mode. Public API:: - from winml.modelkit.models.hf.qwen3.genai import build_genai_config, write_genai_bundle + from winml.modelkit.models.hf.qwen3.genai import ( + build_genai_config, + build_qwen3_transformer_only_stages, + write_genai_bundle, + DecoderIOMapping, + PipelineStage, + ) + + # High-level: derive everything from the built ONNX files + stages, decoder_io = build_qwen3_transformer_only_stages( + ctx_path, iter_path, num_layers=hf_config.num_hidden_layers + ) + cfg = build_genai_config( + hf_config, max_cache_len=256, prefill_seq_len=64, + pipeline=stages, decoder_io=decoder_io, + ) - cfg = build_genai_config(hf_config, max_cache_len=256, prefill_seq_len=64) + # Or one-shot bundle assembly write_genai_bundle( Path("out/bundle"), context_onnx=ctx_path, @@ -47,33 +62,21 @@ import json import logging +import re +from dataclasses import dataclass from pathlib import Path from typing import Any logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- -# Fixed tensor name constants — must match qwen_transformer_only.py I/O. -# --------------------------------------------------------------------------- -_INPUT_IDS = "input_ids" -_INPUT_HIDDEN_STATES = "input_hidden_states" -_OUTPUT_HIDDEN_STATES = "output_hidden_states" -_PAST_SEQ_LEN = "past_seq_len" -_TOTAL_SEQ_LEN = "total_seq_len" -_PAST_KEY_FMT = "past_keys_%d" -_PAST_VALUE_FMT = "past_values_%d" -_PRESENT_KEY_FMT = "present_keys_%d" -_PRESENT_VALUE_FMT = "present_values_%d" -_LOGITS = "logits" - # Default filenames inside the bundle directory. DEFAULT_EMBEDDINGS_FILENAME = "embeddings.onnx" DEFAULT_CONTEXT_FILENAME = "ctx.onnx" DEFAULT_ITERATOR_FILENAME = "iter.onnx" DEFAULT_LM_HEAD_FILENAME = "lm_head.onnx" -# Tokenizer files to save from the HF snapshot. +# Tokenizer files written by AutoTokenizer.save_pretrained. _TOKENIZER_FILES = [ "tokenizer.json", "tokenizer_config.json", @@ -83,9 +86,93 @@ "special_tokens_map.json", ] +# Regex for detecting indexed tensor names such as ``past_keys_3``. +_KV_INDEXED_RE = re.compile(r"^(.+?)(\d+)$") + # --------------------------------------------------------------------------- -# Config builder +# Pipeline data structures +# --------------------------------------------------------------------------- + + +@dataclass +class PipelineStage: + """One stage in an onnxruntime-genai multi-model pipeline. + + Attributes: + name: Stage key used inside the ``pipeline`` list of ``genai_config.json``. + filename: ONNX filename inside the bundle directory. + run_on_prompt: Whether genai runs this stage during the prefill pass. + run_on_token_gen: Whether genai runs this stage during decode steps. + inputs: Actual ONNX input tensor names (not format strings). + outputs: Actual ONNX output tensor names (not format strings). + is_lm_head: Set ``True`` for the final language-model head stage. + """ + + name: str + filename: str + run_on_prompt: bool + run_on_token_gen: bool + inputs: list[str] + outputs: list[str] + is_lm_head: bool = False + + def to_dict(self) -> dict: + """Serialize to the dict format expected by ``genai_config.json``.""" + d: dict = { + "filename": self.filename, + "inputs": list(self.inputs), + "outputs": list(self.outputs), + "run_on_prompt": self.run_on_prompt, + "run_on_token_gen": self.run_on_token_gen, + } + if self.is_lm_head: + d["is_lm_head"] = True + return d + + +@dataclass +class DecoderIOMapping: + """Maps genai's abstract I/O concepts to ONNX tensor name format strings. + + The ``*_names`` fields use ``%d`` as the layer-index placeholder, which is + the convention genai uses to expand per-layer KV cache tensor names + (e.g. ``"past_keys_%d"`` → ``"past_keys_0"``, ``"past_keys_1"``, …). + + All fields default to the names produced by the Qwen3 transformer-only + export. + """ + + input_ids: str = "input_ids" + past_sequence_length: str = "past_seq_len" + total_sequence_length: str = "total_seq_len" + past_key_names: str = "past_keys_%d" + past_value_names: str = "past_values_%d" + logits: str = "logits" + present_key_names: str = "present_keys_%d" + present_value_names: str = "present_values_%d" + + def inputs_dict(self) -> dict: + """Return the ``decoder.inputs`` mapping dict for ``genai_config.json``.""" + return { + "input_ids": self.input_ids, + "past_sequence_length": self.past_sequence_length, + "total_sequence_length": self.total_sequence_length, + "past_key_names": self.past_key_names, + "past_value_names": self.past_value_names, + } + + def outputs_dict(self) -> dict: + """Return the ``decoder.outputs`` mapping dict for ``genai_config.json``.""" + return { + "logits": self.logits, + "present_key_names": self.present_key_names, + "present_value_names": self.present_value_names, + } + + +# --------------------------------------------------------------------------- +# Generic config builder # --------------------------------------------------------------------------- @@ -93,32 +180,37 @@ def build_genai_config( hf_config: Any, *, max_cache_len: int, - prefill_seq_len: int, - embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, - context_filename: str = DEFAULT_CONTEXT_FILENAME, - iterator_filename: str = DEFAULT_ITERATOR_FILENAME, - lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + prefill_seq_len: int | None = None, + pipeline: list[PipelineStage], + decoder_io: DecoderIOMapping | None = None, ) -> dict: - """Build the ``genai_config.json`` dict for the transformer-only pipeline. + """Build a ``genai_config.json`` dict for any decoder-pipeline model. + + This function is architecture-agnostic: the caller supplies the pipeline + stages and the I/O name mapping so no tensor names are hardcoded here. Args: - hf_config: A ``transformers.PretrainedConfig`` (e.g. from - ``AutoConfig.from_pretrained``). Reads: ``num_hidden_layers``, - ``hidden_size``, ``num_attention_heads``, ``num_key_value_heads``, - ``head_dim`` (or derived), ``bos_token_id``, ``eos_token_id``, - ``pad_token_id``, ``vocab_size``. - max_cache_len: Static KV cache length. Becomes ``context_length`` and - ``search.max_length`` in the generated config. - prefill_seq_len: Prefill / context sequence length. Becomes - ``decoder.sliding_window.window_size``. - embeddings_filename: Filename of the embeddings ONNX in the bundle. - context_filename: Filename of the context (prefill) ONNX. - iterator_filename: Filename of the iterator (decode) ONNX. - lm_head_filename: Filename of the lm_head ONNX. + hf_config: A ``transformers.PretrainedConfig``. Reads: + ``num_hidden_layers``, ``hidden_size``, ``num_attention_heads``, + ``num_key_value_heads``, ``head_dim`` (optional, falls back to + ``hidden_size // num_attention_heads``), ``bos_token_id``, + ``eos_token_id``, ``pad_token_id``, ``vocab_size``. + max_cache_len: Static KV cache length → ``context_length`` and + ``search.max_length``. + prefill_seq_len: When given, emits a ``sliding_window`` section with + ``window_size=prefill_seq_len``. Pass ``None`` to omit. + pipeline: Ordered list of :class:`PipelineStage` describing each + model in the genai pipeline. + decoder_io: Format-string mapping from genai's abstract I/O names to + actual ONNX tensor names. Defaults to + :class:`DecoderIOMapping` (the Qwen3 default names). Returns: - A ``dict`` ready for ``json.dumps`` as ``genai_config.json``. + A ``dict`` suitable for ``json.dumps`` as ``genai_config.json``. """ + if decoder_io is None: + decoder_io = DecoderIOMapping() + num_layers: int = hf_config.num_hidden_layers head_size: int = getattr( hf_config, @@ -132,21 +224,26 @@ def build_genai_config( pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id - # Build per-layer KV name lists (same ordering as the reference config). - past_keys = [f"past_keys_{i}" for i in range(num_layers)] - past_values = [f"past_values_{i}" for i in range(num_layers)] - present_keys = [f"present_keys_{i}" for i in range(num_layers)] - present_values = [f"present_values_{i}" for i in range(num_layers)] - - # Transformer stage I/O: hidden states + seq lens + KV buffers. - transformer_inputs = [ - _INPUT_HIDDEN_STATES, - _PAST_SEQ_LEN, - _TOTAL_SEQ_LEN, - *past_keys, - *past_values, - ] - transformer_outputs = [_OUTPUT_HIDDEN_STATES, *present_keys, *present_values] + decoder_section: dict = { + "hidden_size": hf_config.hidden_size, + "num_attention_heads": hf_config.num_attention_heads, + "num_key_value_heads": hf_config.num_key_value_heads, + "num_hidden_layers": num_layers, + "head_size": head_size, + } + + if prefill_seq_len is not None: + decoder_section["sliding_window"] = { + "window_size": prefill_seq_len, + "pad_value": 0, + "alignment": "left", + "slide_inputs": True, + "slide_key_value_cache": False, + } + + decoder_section["inputs"] = decoder_io.inputs_dict() + decoder_section["outputs"] = decoder_io.outputs_dict() + decoder_section["pipeline"] = [{s.name: s.to_dict()} for s in pipeline] return { "model": { @@ -156,71 +253,7 @@ def build_genai_config( "pad_token_id": pad_token_id, "vocab_size": hf_config.vocab_size, "context_length": max_cache_len, - "decoder": { - "hidden_size": hf_config.hidden_size, - "num_attention_heads": hf_config.num_attention_heads, - "num_key_value_heads": hf_config.num_key_value_heads, - "num_hidden_layers": num_layers, - "head_size": head_size, - "sliding_window": { - "window_size": prefill_seq_len, - "pad_value": 0, - "alignment": "left", - "slide_inputs": True, - "slide_key_value_cache": False, - }, - "inputs": { - "input_ids": _INPUT_IDS, - "past_sequence_length": _PAST_SEQ_LEN, - "total_sequence_length": _TOTAL_SEQ_LEN, - "past_key_names": _PAST_KEY_FMT, - "past_value_names": _PAST_VALUE_FMT, - }, - "outputs": { - "logits": _LOGITS, - "present_key_names": _PRESENT_KEY_FMT, - "present_value_names": _PRESENT_VALUE_FMT, - }, - "pipeline": [ - { - "embeddings": { - "filename": embeddings_filename, - "inputs": [_INPUT_IDS], - "outputs": [_INPUT_HIDDEN_STATES], - "run_on_prompt": True, - "run_on_token_gen": True, - } - }, - { - "context": { - "filename": context_filename, - "inputs": transformer_inputs, - "outputs": transformer_outputs, - "run_on_prompt": True, - "run_on_token_gen": False, - } - }, - { - "iterator": { - "filename": iterator_filename, - "inputs": transformer_inputs, - "outputs": transformer_outputs, - "run_on_prompt": False, - "run_on_token_gen": True, - } - }, - { - "lm_head": { - "filename": lm_head_filename, - "inputs": [_OUTPUT_HIDDEN_STATES], - "outputs": [_LOGITS], - "is_lm_head": True, - "run_on_prompt": True, - "run_on_token_gen": True, - } - }, - ], - }, + "decoder": decoder_section, }, "search": { "max_length": max_cache_len, @@ -231,6 +264,186 @@ def build_genai_config( } +# --------------------------------------------------------------------------- +# ONNX introspection helpers +# --------------------------------------------------------------------------- + + +def _introspect_onnx_io(onnx_path: Path) -> tuple[list[str], list[str]]: + """Return ``(input_names, output_names)`` from an ONNX model graph header. + + External data is intentionally not loaded — only the graph topology is read, + so this is fast even for large quantized models. + """ + try: + import onnx + except ImportError as exc: + raise ImportError( + "The 'onnx' package is required for ONNX introspection. " + "Install it with: pip install onnx" + ) from exc + model = onnx.load(str(onnx_path), load_external_data=False) + return ( + [inp.name for inp in model.graph.input], + [out.name for out in model.graph.output], + ) + + +def _detect_format_patterns(names: list[str], num_layers: int) -> dict[str, str]: + """Detect ``prefix%d`` patterns from a list of indexed tensor names. + + Scans *names* for entries matching ```` where exactly + *num_layers* consecutive zero-based indices are present. + + Returns: + ``{prefix: "prefix%d"}`` for each qualifying group, in the order the + prefixes first appear in *names*. Only groups covering the full + ``[0, num_layers)`` index range are returned. + + Examples:: + + >>> _detect_format_patterns( + ... ["past_keys_0", "past_keys_1", "past_values_0", "past_values_1"], + ... num_layers=2, + ... ) + {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"} + """ + groups: dict[str, list[int]] = {} + for name in names: + m = _KV_INDEXED_RE.match(name) + if m: + prefix, idx = m.group(1), int(m.group(2)) + groups.setdefault(prefix, []).append(idx) + + return { + prefix: f"{prefix}%d" + for prefix, indices in groups.items() + if len(indices) == num_layers and sorted(indices) == list(range(num_layers)) + } + + +def _sort_patterns_by_first_occurrence(patterns: dict[str, str], names: list[str]) -> list[str]: + """Sort *patterns* keys by when ``0`` first appears in *names*.""" + + def _key(prefix: str) -> int: + try: + return names.index(f"{prefix}0") + except ValueError: + return len(names) + + return sorted(patterns.keys(), key=_key) + + +# --------------------------------------------------------------------------- +# Qwen3 transformer-only pipeline factory +# --------------------------------------------------------------------------- + + +def build_qwen3_transformer_only_stages( + context_onnx: str | Path, + iterator_onnx: str | Path, + num_layers: int, + *, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, +) -> tuple[list[PipelineStage], DecoderIOMapping]: + """Build pipeline stages by introspecting the built ONNX models. + + Reads actual tensor names from *context_onnx* and *iterator_onnx* so the + generated ``genai_config.json`` can never drift out of sync with the real + model I/O — no tensor names are hardcoded. + + Args: + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. + num_layers: Number of transformer layers (``hf_config.num_hidden_layers``). + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. + + Returns: + ``(stages, decoder_io)`` — a 4-element :class:`PipelineStage` list and + the :class:`DecoderIOMapping` derived from the introspected tensor names. + """ + ctx_inputs, ctx_outputs = _introspect_onnx_io(Path(context_onnx)) + iter_inputs, iter_outputs = _introspect_onnx_io(Path(iterator_onnx)) + + # Detect per-layer KV format-string patterns in the context model. + input_patterns = _detect_format_patterns(ctx_inputs, num_layers) + output_patterns = _detect_format_patterns(ctx_outputs, num_layers) + + in_sorted = _sort_patterns_by_first_occurrence(input_patterns, ctx_inputs) + out_sorted = _sort_patterns_by_first_occurrence(output_patterns, ctx_outputs) + + past_key_fmt = input_patterns[in_sorted[0]] if len(in_sorted) > 0 else "past_keys_%d" + past_val_fmt = input_patterns[in_sorted[1]] if len(in_sorted) > 1 else "past_values_%d" + pres_key_fmt = output_patterns[out_sorted[0]] if len(out_sorted) > 0 else "present_keys_%d" + pres_val_fmt = output_patterns[out_sorted[1]] if len(out_sorted) > 1 else "present_values_%d" + + # Non-indexed inputs: hidden-state tensor + scalar seq-length scalars. + non_indexed = [n for n in ctx_inputs if not _KV_INDEXED_RE.match(n)] + seq_len_names = [n for n in non_indexed if re.search(r"seq|len", n, re.IGNORECASE)] + hidden_state_in = next( + (n for n in non_indexed if n not in seq_len_names), "input_hidden_states" + ) + past_seq_name = next((n for n in seq_len_names if "past" in n.lower()), "past_seq_len") + total_seq_name = next((n for n in seq_len_names if "total" in n.lower()), "total_seq_len") + + # Non-indexed output: hidden-state output of the transformer stack. + hidden_state_out = next( + (n for n in ctx_outputs if not _KV_INDEXED_RE.match(n)), "output_hidden_states" + ) + + decoder_io = DecoderIOMapping( + past_sequence_length=past_seq_name, + total_sequence_length=total_seq_name, + past_key_names=past_key_fmt, + past_value_names=past_val_fmt, + present_key_names=pres_key_fmt, + present_value_names=pres_val_fmt, + ) + + stages: list[PipelineStage] = [ + PipelineStage( + name="embeddings", + filename=embeddings_filename, + run_on_prompt=True, + run_on_token_gen=True, + inputs=[decoder_io.input_ids], + outputs=[hidden_state_in], + ), + PipelineStage( + name="context", + filename=context_filename, + run_on_prompt=True, + run_on_token_gen=False, + inputs=ctx_inputs, + outputs=ctx_outputs, + ), + PipelineStage( + name="iterator", + filename=iterator_filename, + run_on_prompt=False, + run_on_token_gen=True, + inputs=iter_inputs, + outputs=iter_outputs, + ), + PipelineStage( + name="lm_head", + filename=lm_head_filename, + run_on_prompt=True, + run_on_token_gen=True, + inputs=[hidden_state_out], + outputs=[decoder_io.logits], + is_lm_head=True, + ), + ] + return stages, decoder_io + + # --------------------------------------------------------------------------- # Bundle assembler # --------------------------------------------------------------------------- @@ -255,27 +468,22 @@ def write_genai_bundle( Copies the winml-built transformer ONNX files, placeholder embedding / lm_head models (when provided), HF tokenizer files, and writes - ``genai_config.json``. + ``genai_config.json``. Tensor names in the config are derived by + introspecting the built ONNX files rather than being hardcoded. Args: output_dir: Destination directory (created if absent). - context_onnx: Path to the built prefill/context ONNX - (``decoder_prefill`` sub-model output). - iterator_onnx: Path to the built iteration/decode ONNX - (``decoder_gen`` sub-model output). - model_id: HuggingFace model ID or local path used to download the HF - config and tokenizer files. + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. + model_id: HuggingFace model ID or local path for config + tokenizer. max_cache_len: Static KV cache length (= ``context_length`` in genai). prefill_seq_len: Prefill sequence length (= ``sliding_window.window_size``). - embeddings_src: Source path of the embeddings ONNX to copy into the - bundle. Pass ``None`` to skip (the bundle will be incomplete until - the embeddings model is added separately). - lm_head_src: Source path of the lm_head ONNX to copy. Pass ``None`` - to skip. - context_filename: Filename used for the context ONNX inside the bundle. - iterator_filename: Filename used for the iterator ONNX. - embeddings_filename: Filename used for the embeddings ONNX. - lm_head_filename: Filename used for the lm_head ONNX. + embeddings_src: Source path of the embeddings ONNX. ``None`` = skip. + lm_head_src: Source path of the lm_head ONNX. ``None`` = skip. + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. Returns: Path to the written ``genai_config.json``. @@ -289,25 +497,20 @@ def write_genai_bundle( context_onnx = Path(context_onnx) iterator_onnx = Path(iterator_onnx) - # ------------------------------------------------------------------ # 1. Copy winml-built transformer ONNX files. - # ------------------------------------------------------------------ logger.info("Copying context ONNX: %s -> %s", context_onnx.name, context_filename) copy_onnx_model(context_onnx, output_dir / context_filename) logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) copy_onnx_model(iterator_onnx, output_dir / iterator_filename) - # ------------------------------------------------------------------ # 2. Copy placeholder models (embeddings + lm_head). - # ------------------------------------------------------------------ if embeddings_src is not None: logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) copy_onnx_model(Path(embeddings_src), output_dir / embeddings_filename) else: logger.warning( - "embeddings_src not provided — '%s' is missing from bundle; " - "add it manually before running inference.", + "embeddings_src not provided — '%s' is missing from bundle.", embeddings_filename, ) @@ -316,40 +519,34 @@ def write_genai_bundle( copy_onnx_model(Path(lm_head_src), output_dir / lm_head_filename) else: logger.warning( - "lm_head_src not provided — '%s' is missing from bundle; " - "add it manually before running inference.", + "lm_head_src not provided — '%s' is missing from bundle.", lm_head_filename, ) - # ------------------------------------------------------------------ # 3. Save tokenizer files from the HF snapshot. - # ------------------------------------------------------------------ logger.info("Saving tokenizer from '%s' to %s", model_id, output_dir) tokenizer = AutoTokenizer.from_pretrained(model_id) tokenizer.save_pretrained(str(output_dir)) - # Prune any extra files that save_pretrained creates but genai doesn't need - # (e.g. tokenizer.model for sentencepiece models). Keep only known files. - onnx_filenames = {context_filename, iterator_filename, embeddings_filename, lm_head_filename} - for path in output_dir.iterdir(): - if ( - path.name not in _TOKENIZER_FILES - and path.suffix in (".json", ".txt", ".model") - and path.name not in onnx_filenames - ): - logger.debug("Keeping extra tokenizer file: %s", path.name) - - # ------------------------------------------------------------------ - # 4. Write genai_config.json. - # ------------------------------------------------------------------ + + # 4. Build pipeline stages by introspecting the source ONNX files. hf_config = AutoConfig.from_pretrained(model_id) + stages, decoder_io = build_qwen3_transformer_only_stages( + context_onnx, + iterator_onnx, + num_layers=hf_config.num_hidden_layers, + context_filename=context_filename, + iterator_filename=iterator_filename, + embeddings_filename=embeddings_filename, + lm_head_filename=lm_head_filename, + ) + + # 5. Write genai_config.json. config = build_genai_config( hf_config, max_cache_len=max_cache_len, prefill_seq_len=prefill_seq_len, - embeddings_filename=embeddings_filename, - context_filename=context_filename, - iterator_filename=iterator_filename, - lm_head_filename=lm_head_filename, + pipeline=stages, + decoder_io=decoder_io, ) config_path = output_dir / "genai_config.json" config_path.write_text(json.dumps(config, indent=2), encoding="utf-8") @@ -387,6 +584,9 @@ def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: "DEFAULT_EMBEDDINGS_FILENAME", "DEFAULT_ITERATOR_FILENAME", "DEFAULT_LM_HEAD_FILENAME", + "DecoderIOMapping", + "PipelineStage", "build_genai_config", + "build_qwen3_transformer_only_stages", "write_genai_bundle", ] diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py index 5f6930b65..012c71cba 100644 --- a/tests/unit/models/qwen3/test_genai_config.py +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -7,13 +7,18 @@ from __future__ import annotations from types import SimpleNamespace +from unittest.mock import patch from winml.modelkit.models.hf.qwen3.genai import ( DEFAULT_CONTEXT_FILENAME, DEFAULT_EMBEDDINGS_FILENAME, DEFAULT_ITERATOR_FILENAME, DEFAULT_LM_HEAD_FILENAME, + DecoderIOMapping, + PipelineStage, + _detect_format_patterns, build_genai_config, + build_qwen3_transformer_only_stages, ) @@ -48,15 +53,59 @@ def _mock_config( ) +def _make_pipeline( + num_layers: int = 28, + *, + emb_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + ctx_filename: str = DEFAULT_CONTEXT_FILENAME, + iter_filename: str = DEFAULT_ITERATOR_FILENAME, + lmh_filename: str = DEFAULT_LM_HEAD_FILENAME, +) -> list[PipelineStage]: + """Build a standard 4-stage pipeline for use in unit tests.""" + ctx_inputs = [ + "input_hidden_states", + "past_seq_len", + "total_seq_len", + *[f"past_keys_{i}" for i in range(num_layers)], + *[f"past_values_{i}" for i in range(num_layers)], + ] + ctx_outputs = [ + "output_hidden_states", + *[f"present_keys_{i}" for i in range(num_layers)], + *[f"present_values_{i}" for i in range(num_layers)], + ] + return [ + PipelineStage( + "embeddings", emb_filename, True, True, ["input_ids"], ["input_hidden_states"] + ), + PipelineStage("context", ctx_filename, True, False, ctx_inputs, ctx_outputs), + PipelineStage("iterator", iter_filename, False, True, ctx_inputs, ctx_outputs), + PipelineStage( + "lm_head", + lmh_filename, + True, + True, + ["output_hidden_states"], + ["logits"], + is_lm_head=True, + ), + ] + + # --------------------------------------------------------------------------- -# Tests +# Tests: build_genai_config # --------------------------------------------------------------------------- class TestBuildGenaiConfig: def setup_method(self) -> None: self.cfg = _mock_config() - self.result = build_genai_config(self.cfg, max_cache_len=256, prefill_seq_len=64) + self.result = build_genai_config( + self.cfg, + max_cache_len=256, + prefill_seq_len=64, + pipeline=_make_pipeline(), + ) def test_top_level_model_type(self) -> None: assert self.result["model"]["type"] == "decoder-pipeline" @@ -85,12 +134,21 @@ def test_decoder_architecture_params(self) -> None: assert dec["num_hidden_layers"] == 28 assert dec["head_size"] == 128 - def test_sliding_window_size_equals_prefill_seq_len(self) -> None: + def test_sliding_window_present_when_prefill_seq_len_given(self) -> None: sw = self.result["model"]["decoder"]["sliding_window"] assert sw["window_size"] == 64 assert sw["slide_inputs"] is True assert sw["slide_key_value_cache"] is False + def test_sliding_window_absent_when_prefill_seq_len_none(self) -> None: + result = build_genai_config( + self.cfg, + max_cache_len=256, + prefill_seq_len=None, + pipeline=_make_pipeline(), + ) + assert "sliding_window" not in result["model"]["decoder"] + def test_decoder_io_tensor_names(self) -> None: inputs = self.result["model"]["decoder"]["inputs"] assert inputs["past_sequence_length"] == "past_seq_len" @@ -102,6 +160,27 @@ def test_decoder_io_tensor_names(self) -> None: assert outputs["present_key_names"] == "present_keys_%d" assert outputs["present_value_names"] == "present_values_%d" + def test_custom_decoder_io_mapping(self) -> None: + custom_io = DecoderIOMapping( + past_key_names="k_%d", + past_value_names="v_%d", + present_key_names="pk_%d", + present_value_names="pv_%d", + ) + result = build_genai_config( + self.cfg, + max_cache_len=256, + prefill_seq_len=64, + pipeline=_make_pipeline(), + decoder_io=custom_io, + ) + dec_inputs = result["model"]["decoder"]["inputs"] + assert dec_inputs["past_key_names"] == "k_%d" + assert dec_inputs["past_value_names"] == "v_%d" + dec_outputs = result["model"]["decoder"]["outputs"] + assert dec_outputs["present_key_names"] == "pk_%d" + assert dec_outputs["present_value_names"] == "pv_%d" + def test_pipeline_has_four_stages(self) -> None: pipeline = self.result["model"]["decoder"]["pipeline"] assert len(pipeline) == 4 @@ -147,7 +226,6 @@ def test_context_kv_inputs_count(self) -> None: past_values = [x for x in inputs if x.startswith("past_values_")] assert len(past_keys) == 28 assert len(past_values) == 28 - # All layer indices present assert set(past_keys) == {f"past_keys_{i}" for i in range(28)} assert set(past_values) == {f"past_values_{i}" for i in range(28)} @@ -169,10 +247,12 @@ def test_custom_filenames(self) -> None: self.cfg, max_cache_len=512, prefill_seq_len=128, - embeddings_filename="emb.onnx", - context_filename="prefill.onnx", - iterator_filename="decode.onnx", - lm_head_filename="head.onnx", + pipeline=_make_pipeline( + emb_filename="emb.onnx", + ctx_filename="prefill.onnx", + iter_filename="decode.onnx", + lmh_filename="head.onnx", + ), ) pipeline = result["model"]["decoder"]["pipeline"] assert pipeline[0]["embeddings"]["filename"] == "emb.onnx" @@ -182,7 +262,9 @@ def test_custom_filenames(self) -> None: def test_eos_token_id_list_unpacked(self) -> None: cfg = _mock_config(eos_token_id=[151645, 151643]) - result = build_genai_config(cfg, max_cache_len=256, prefill_seq_len=64) + result = build_genai_config( + cfg, max_cache_len=256, prefill_seq_len=64, pipeline=_make_pipeline() + ) assert result["model"]["eos_token_id"] == 151645 def test_head_size_derived_when_head_dim_missing(self) -> None: @@ -197,7 +279,9 @@ def test_head_size_derived_when_head_dim_missing(self) -> None: pad_token_id=0, vocab_size=32000, ) - result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2) + ) # head_size = hidden_size // num_attention_heads = 512 // 8 = 64 assert result["model"]["decoder"]["head_size"] == 64 @@ -213,13 +297,162 @@ def test_pad_token_id_falls_back_to_bos(self) -> None: pad_token_id=None, vocab_size=32000, ) - result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2) + ) assert result["model"]["pad_token_id"] == 0 # falls back to bos_token_id def test_different_layer_count(self) -> None: cfg = _mock_config(num_hidden_layers=4) - result = build_genai_config(cfg, max_cache_len=128, prefill_seq_len=32) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(4) + ) inputs = result["model"]["decoder"]["pipeline"][1]["context"]["inputs"] past_keys = [x for x in inputs if x.startswith("past_keys_")] assert len(past_keys) == 4 assert {f"past_keys_{i}" for i in range(4)} == set(past_keys) + + +# --------------------------------------------------------------------------- +# Tests: _detect_format_patterns +# --------------------------------------------------------------------------- + + +class TestDetectFormatPatterns: + def test_detects_two_kv_groups(self) -> None: + names = [ + "input_hidden_states", + "past_seq_len", + "past_keys_0", + "past_keys_1", + "past_keys_2", + "past_values_0", + "past_values_1", + "past_values_2", + ] + result = _detect_format_patterns(names, num_layers=3) + assert result == {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"} + + def test_ignores_incomplete_index_range(self) -> None: + # Missing index 1 — should not be detected + names = ["prefix_0", "prefix_2"] + result = _detect_format_patterns(names, num_layers=3) + assert "prefix_" not in result + + def test_ignores_wrong_num_layers(self) -> None: + # 3 entries but num_layers=5 + names = ["kv_0", "kv_1", "kv_2"] + result = _detect_format_patterns(names, num_layers=5) + assert len(result) == 0 + + def test_empty_input(self) -> None: + assert _detect_format_patterns([], num_layers=4) == {} + + def test_non_indexed_names_ignored(self) -> None: + names = ["input_hidden_states", "past_seq_len", "total_seq_len"] + result = _detect_format_patterns(names, num_layers=3) + assert result == {} + + def test_single_layer_model(self) -> None: + names = ["keys_0", "vals_0"] + result = _detect_format_patterns(names, num_layers=1) + assert result == {"keys_": "keys_%d", "vals_": "vals_%d"} + + +# --------------------------------------------------------------------------- +# Tests: build_qwen3_transformer_only_stages +# --------------------------------------------------------------------------- + + +class TestBuildQwen3TransformerOnlyStages: + """Uses mocked onnx.load so no real ONNX files are required.""" + + def _ctx_inputs(self, n: int = 4) -> list[str]: + return [ + "input_hidden_states", + "past_seq_len", + "total_seq_len", + *[f"past_keys_{i}" for i in range(n)], + *[f"past_values_{i}" for i in range(n)], + ] + + def _ctx_outputs(self, n: int = 4) -> list[str]: + return [ + "output_hidden_states", + *[f"present_keys_{i}" for i in range(n)], + *[f"present_values_{i}" for i in range(n)], + ] + + def _patch_onnx(self, n: int = 4): + ctx_io = (self._ctx_inputs(n), self._ctx_outputs(n)) + iter_io = (self._ctx_inputs(n), self._ctx_outputs(n)) + return patch( + "winml.modelkit.models.hf.qwen3.genai._introspect_onnx_io", + side_effect=[ctx_io, iter_io], + ) + + def test_returns_four_stages(self) -> None: + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages("ctx.onnx", "iter.onnx", num_layers=4) + assert len(stages) == 4 + assert [s.name for s in stages] == ["embeddings", "context", "iterator", "lm_head"] + + def test_detected_kv_format_patterns(self) -> None: + with self._patch_onnx(): + _, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + assert decoder_io.past_key_names == "past_keys_%d" + assert decoder_io.past_value_names == "past_values_%d" + assert decoder_io.present_key_names == "present_keys_%d" + assert decoder_io.present_value_names == "present_values_%d" + + def test_detected_seq_len_names(self) -> None: + with self._patch_onnx(): + _, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + assert decoder_io.past_sequence_length == "past_seq_len" + assert decoder_io.total_sequence_length == "total_seq_len" + + def test_context_stage_inputs_from_onnx(self) -> None: + with self._patch_onnx(n=4): + stages, _ = build_qwen3_transformer_only_stages("ctx.onnx", "iter.onnx", num_layers=4) + ctx_stage = next(s for s in stages if s.name == "context") + assert "input_hidden_states" in ctx_stage.inputs + assert "past_keys_0" in ctx_stage.inputs + assert "past_values_3" in ctx_stage.inputs + + def test_custom_filenames(self) -> None: + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", + "iter.onnx", + num_layers=4, + context_filename="prefill.onnx", + iterator_filename="decode.onnx", + embeddings_filename="emb.onnx", + lm_head_filename="head.onnx", + ) + names = {s.name: s.filename for s in stages} + assert names["context"] == "prefill.onnx" + assert names["iterator"] == "decode.onnx" + assert names["embeddings"] == "emb.onnx" + assert names["lm_head"] == "head.onnx" + + def test_roundtrip_with_build_genai_config(self) -> None: + """build_qwen3_transformer_only_stages output feeds build_genai_config cleanly.""" + with self._patch_onnx(n=4): + stages, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + cfg = _mock_config(num_hidden_layers=4) + result = build_genai_config( + cfg, + max_cache_len=128, + prefill_seq_len=32, + pipeline=stages, + decoder_io=decoder_io, + ) + assert result["model"]["type"] == "decoder-pipeline" + assert len(result["model"]["decoder"]["pipeline"]) == 4 From 2f8f884fce73a5e63315381d120f67ecab30689f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 29 Jun 2026 18:13:28 +0800 Subject: [PATCH 03/29] feat(session): add GenaiSession for onnxruntime-genai inference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - GenaiSession drives og.Model + og.Generator lifecycle for autoregressive text generation; peer class to WinMLSession (not a subclass) - GenerationConfig dataclass: temperature, top_p, top_k, max_new_tokens, repetition_penalty, do_sample - Lazy onnxruntime_genai import via _import_og() — class importable without the package installed (raises GenaiNotInstalledError on first use) - Reuses WinMLEPRegistry for EP discovery/registration (idempotent) - EP support: cpu (clear_providers only), qnn, dml - context_length read from genai_config.json; overridable at construction - generate_streaming() yields decoded token strings; generator del'd in finally - generate() returns joined string; auto-load on first call if not loaded - 33 unit tests; all use patch.dict(sys.modules) to avoid real hardware --- scripts/infer_genai.py | 144 +------ src/winml/modelkit/session/__init__.py | 12 + src/winml/modelkit/session/genai_session.py | 421 ++++++++++++++++++++ tests/unit/session/test_genai_session.py | 349 ++++++++++++++++ 4 files changed, 804 insertions(+), 122 deletions(-) create mode 100644 src/winml/modelkit/session/genai_session.py create mode 100644 tests/unit/session/test_genai_session.py diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py index 4a06ea6be..69139ec3e 100644 --- a/scripts/infer_genai.py +++ b/scripts/infer_genai.py @@ -5,20 +5,12 @@ r"""onnxruntime-genai inference for the Qwen3 transformer-only pipeline. Loads the genai bundle produced by ``export_qwen3_transformer_only.py ---genai-bundle `` and runs greedy text generation. +--genai-bundle `` and runs greedy text generation using +:class:`~winml.modelkit.session.GenaiSession`. The bundle directory must contain ``genai_config.json`` and the four ONNX -graphs it references: - - embeddings.onnx — embedding lookup (input_ids -> input_hidden_states) - ctx.onnx — prefill/context graph (seq_len = prefill_seq_len) - iter.onnx — iteration/decode graph (seq_len = 1) - lm_head.onnx — lm_head (output_hidden_states -> logits) - -It also needs the HF tokenizer files (``tokenizer.json``, -``tokenizer_config.json``, ``vocab.json``, ``merges.txt``, -``generation_config.json``) which ``write_genai_bundle`` downloads -automatically. +graphs it references (``embeddings.onnx``, ``ctx.onnx``, ``iter.onnx``, +``lm_head.onnx``) plus HF tokenizer files. Usage:: @@ -47,63 +39,14 @@ import time from pathlib import Path -import onnxruntime_genai as og +from winml.modelkit.session import GenaiSession, GenerationConfig # Default bundle directory: /out/qwen3_bundle _REPO_ROOT = Path(__file__).resolve().parent.parent DEFAULT_MODEL_DIR = _REPO_ROOT / "out" / "qwen3_bundle" -# The static KV cache length. Must equal ``context_length`` in genai_config.json -# (and the ``--max-cache-len`` used during the winml build). Do not lower this -# value — the KV buffer size is baked into the ONNX graphs. -CONTEXT_LENGTH = 256 - -# Maps the friendly --ep name to the ORT EP canonical name. -_EP_NAME = { - "cpu": "cpu", - "qnn": "QNNExecutionProvider", -} - - -def _register_winml_eps() -> list[str]: - """Discover and register Windows ML execution providers. - - Walks the WinML EP catalog, calls ``ensure_ready()`` on each provider - (downloads via Windows Update if needed), then registers the shared - library with ORT GenAI. Mirrors ``examples/python/winml.py`` from the - onnxruntime-genai repo. - """ - import traceback - - from windowsml import EpCatalog - - registered: list[str] = [] - with EpCatalog() as catalog: - for provider in catalog.find_all_providers(): - provider.ensure_ready() - if not provider.library_path: - continue - try: - og.register_execution_provider_library(provider.name, provider.library_path) - registered.append(provider.name) - except Exception as exc: - print(f"[winml] failed to register {provider.name}: {exc}") - traceback.print_exc() - return registered - - -def _build_og_config(model_dir: Path, ep: str) -> og.Config: - """Create an ``og.Config``, registering WinML EPs when not on CPU.""" - if ep != "cpu": - registered = _register_winml_eps() - print(f"[winml] registered EPs: {registered}") - - config = og.Config(str(model_dir)) - config.clear_providers() - if ep != "cpu": - config.append_provider(_EP_NAME[ep]) - return config +_SUPPORTED_EPS = ["cpu", "qnn", "dml"] def _wrap_chat_template(prompt: str) -> str: @@ -134,7 +77,7 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: ) p.add_argument( "--ep", - choices=sorted(_EP_NAME), + choices=_SUPPORTED_EPS, default="cpu", help="Execution provider (default: cpu).", ) @@ -149,16 +92,6 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: action="store_true", help="Wrap --prompt in the Qwen3 chat template.", ) - p.add_argument( - "--context-length", - type=int, - default=CONTEXT_LENGTH, - help=( - "Static KV cache length. Must match the --max-cache-len used " - "during the winml build and the genai_config.json context_length " - "(default: %(default)s). Do NOT lower this value." - ), - ) p.add_argument( "--verbose", action="store_true", @@ -171,57 +104,24 @@ def main(argv: list[str] | None = None) -> int: """Load the genai bundle and run generation.""" args = parse_args(argv) - model_dir: Path = args.model_dir - if not model_dir.exists(): - print( - f"ERROR: model directory not found: {model_dir}\n" - "Run export_qwen3_transformer_only.py --genai-bundle first.", - file=sys.stderr, - ) - return 1 + text = _wrap_chat_template(args.prompt) if args.chat else args.prompt + gen_cfg = GenerationConfig(max_new_tokens=args.max_new, do_sample=False) - config_file = model_dir / "genai_config.json" - if not config_file.exists(): - print( - f"ERROR: genai_config.json not found in {model_dir}\nThe bundle may be incomplete.", - file=sys.stderr, - ) + try: + session = GenaiSession(args.model_dir, ep=args.ep, verbose=args.verbose) + except FileNotFoundError as exc: + print(f"ERROR: {exc}", file=sys.stderr) return 1 - if args.verbose: - og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True) - - print(f"[load] ep={args.ep} bundle={model_dir}") - config = _build_og_config(model_dir, args.ep) - model = og.Model(config) - tokenizer = og.Tokenizer(model) - tokenizer_stream = tokenizer.create_stream() - - text = _wrap_chat_template(args.prompt) if args.chat else args.prompt - input_tokens = tokenizer.encode(text) - print(f"[tokens] prompt has {len(input_tokens)} tokens") - - params = og.GeneratorParams(model) - # max_length must equal the static KV cache size so genai sizes the - # total_sequence_length input and KV buffers correctly. - params.set_search_options( - max_length=args.context_length, - do_sample=False, - ) - - generator = og.Generator(model, params) - generator.append_tokens(input_tokens) - - print("[gen] ", end="", flush=True) - t0 = time.monotonic() - n = 0 - while not generator.is_done(): - generator.generate_next_token() - new_token = generator.get_next_tokens()[0] - print(tokenizer_stream.decode(new_token), end="", flush=True) - n += 1 - if n >= args.max_new: - break + print(f"[load] ep={args.ep} bundle={args.model_dir}") + with session: + print(f"[ctx] context_length={session.context_length}") + print("[gen] ", end="", flush=True) + t0 = time.monotonic() + n = 0 + for token_str in session.generate_streaming(text, gen_cfg): + print(token_str, end="", flush=True) + n += 1 dt = time.monotonic() - t0 print(f"\n\n[done] {n} tokens in {dt:.1f}s ({n / dt:.1f} tok/s)") diff --git a/src/winml/modelkit/session/__init__.py b/src/winml/modelkit/session/__init__.py index 5148da0b3..d11673961 100644 --- a/src/winml/modelkit/session/__init__.py +++ b/src/winml/modelkit/session/__init__.py @@ -5,6 +5,13 @@ """WinMLSession - ONNX Runtime session manager with WinML EP integration.""" from .ep_registry import WinMLEPRegistry +from .genai_session import ( + GenaiLoadError, + GenaiNotInstalledError, + GenaiSession, + GenaiSessionError, + GenerationConfig, +) from .monitor.ep_monitor import EPMonitor, NullEPMonitor from .monitor.hw_monitor import HWMonitor from .monitor.openvino_monitor import OpenVinoMonitor @@ -17,6 +24,11 @@ __all__ = [ "EPMonitor", + "GenaiLoadError", + "GenaiNotInstalledError", + "GenaiSession", + "GenaiSessionError", + "GenerationConfig", "HWMonitor", "InferenceError", "NullEPMonitor", diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py new file mode 100644 index 000000000..21b0c5b31 --- /dev/null +++ b/src/winml/modelkit/session/genai_session.py @@ -0,0 +1,421 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""GenaiSession — onnxruntime-genai session for multi-model decoder pipelines. + +Manages ``og.Model`` + ``og.Generator`` lifecycle for autoregressive text +generation. Reuses :class:`WinMLEPRegistry` for EP discovery and registration +so EPs are downloaded / registered at most once per process. + +Unlike :class:`WinMLSession` (which wraps ``ort.InferenceSession`` for +single-shot inference), ``GenaiSession`` drives a streaming token-by-token +generation loop. The two classes are peers — neither inherits from the other. + +Bundle directory layout expected by ``onnxruntime-genai``:: + + / + genai_config.json ← required; controls pipeline & search + ctx.onnx ← prefill transformer graph + iter.onnx ← decode transformer graph + embeddings.onnx ← embedding lookup + lm_head.onnx ← logit projection + tokenizer.json ← HF tokenizer files + tokenizer_config.json + ... + +Usage:: + + # Context manager (recommended — auto-loads and unloads) + with GenaiSession("out/qwen3_bundle", ep="qnn") as session: + for token_str in session.generate_streaming("Hello, who are you?"): + print(token_str, end="", flush=True) + + # Manual lifecycle + session = GenaiSession("out/qwen3_bundle", ep="cpu") + session.load() + result = session.generate("What is a transformer?") + session.unload() + +Dependencies:: + + pip install onnxruntime-genai-winml + pip install "windowsml[with-ort]" # registers QNN EP; also provides ORT +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING + +from .ep_registry import WinMLEPRegistry + + +if TYPE_CHECKING: + from collections.abc import Iterator + + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# EP name mapping: user-friendly short name → ORT GenAI provider string. +# None means "do not append a provider" (= default CPU execution). +# --------------------------------------------------------------------------- +_EP_PROVIDER_MAP: dict[str, str | None] = { + "cpu": None, + "qnn": "QNNExecutionProvider", + "dml": "DmlExecutionProvider", +} + + +# --------------------------------------------------------------------------- +# Data classes +# --------------------------------------------------------------------------- + + +@dataclass +class GenerationConfig: + """Search / sampling parameters for a single generation call. + + All parameters are forwarded to ``og.GeneratorParams.set_search_options``. + ``max_length`` is **not** configurable here — it is set to the bundle's + ``context_length`` (read from ``genai_config.json``) because the static KV + cache size is baked into the ONNX graphs at export time. + + Attributes: + max_new_tokens: Soft cap on the number of new tokens to generate. + Generation stops when the model signals EOS, when the KV buffer is + exhausted (``context_length``), or when this limit is reached, + whichever comes first. + do_sample: Enable sampling (``True``) vs greedy (``False``). + temperature: Sampling temperature. Ignored when ``do_sample=False``. + top_p: Nucleus sampling probability mass. Ignored when + ``do_sample=False``. + top_k: Top-K sampling. ``0`` disables the filter. Ignored when + ``do_sample=False``. + repetition_penalty: Multiplicative penalty for repeated tokens + (``1.0`` = no penalty). + """ + + max_new_tokens: int = 128 + do_sample: bool = False + temperature: float = 1.0 + top_p: float = 1.0 + top_k: int = 0 + repetition_penalty: float = 1.0 + + +# --------------------------------------------------------------------------- +# Exceptions +# --------------------------------------------------------------------------- + + +class GenaiSessionError(Exception): + """Base exception for GenaiSession.""" + + +class GenaiNotInstalledError(GenaiSessionError): + """``onnxruntime-genai`` (or ``onnxruntime-genai-winml``) is not installed.""" + + +class GenaiLoadError(GenaiSessionError): + """The bundle could not be loaded (bad config, EP unavailable, etc.).""" + + +# --------------------------------------------------------------------------- +# Session +# --------------------------------------------------------------------------- + + +class GenaiSession: + """ORT GenAI session for multi-model decoder-pipeline inference. + + Wraps ``og.Model`` + ``og.Generator`` to provide a clean generation API. + + The session is **stateless across calls**: each :meth:`generate_streaming` + call creates a fresh ``og.Generator`` so KV state does not persist between + prompts. Thread-safety within a single session is not guaranteed. + + Args: + bundle_dir: Path to the genai bundle directory. Must contain + ``genai_config.json`` and the ONNX files it references. + ep: Execution provider short name (``"cpu"``, ``"qnn"``, ``"dml"``). + Non-CPU EPs trigger WinML EP discovery and registration. + context_length: Override for the static KV cache length. When + ``None`` (default), read from ``genai_config.json``. + Must match the ``--max-cache-len`` used during the winml-cli build. + verbose: Enable ``onnxruntime-genai`` native model I/O logging. + """ + + def __init__( + self, + bundle_dir: str | Path, + ep: str = "cpu", + *, + context_length: int | None = None, + verbose: bool = False, + ) -> None: + self._bundle_dir = Path(bundle_dir) + self._ep = ep.lower() + self._context_length_override = context_length + self._verbose = verbose + + # Resolved at load() time. + self._context_length: int | None = None + + # og.* handles — None until load() is called. + self._model: object | None = None + self._tokenizer: object | None = None + + if not self._bundle_dir.exists(): + raise FileNotFoundError(f"Bundle directory not found: {self._bundle_dir}") + config_path = self._bundle_dir / "genai_config.json" + if not config_path.exists(): + raise FileNotFoundError( + f"genai_config.json not found in {self._bundle_dir}. " + "Run export_qwen3_transformer_only.py --genai-bundle first." + ) + if self._ep not in _EP_PROVIDER_MAP: + raise ValueError(f"Unknown EP {ep!r}. Supported: {sorted(_EP_PROVIDER_MAP)}") + + logger.info("GenaiSession initialized: bundle=%s ep=%s", self._bundle_dir, self._ep) + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def load(self) -> None: + """Load ``og.Model`` and tokenizer from the bundle directory. + + Idempotent: calling ``load()`` on an already-loaded session is a no-op. + + Raises: + GenaiNotInstalledError: ``onnxruntime_genai`` is not installed. + GenaiLoadError: The model could not be loaded (EP error, bad config, + missing ONNX files, …). + """ + if self._model is not None: + return + + og = self._import_og() + + # Register WinML EPs to ORT GenAI (skipped for CPU; idempotent). + if self._ep != "cpu": + self._register_eps(og) + + if self._verbose: + og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True) + + try: + config = og.Config(str(self._bundle_dir)) + config.clear_providers() + provider = _EP_PROVIDER_MAP[self._ep] + if provider is not None: + config.append_provider(provider) + self._model = og.Model(config) + self._tokenizer = og.Tokenizer(self._model) + except Exception as exc: + self._model = None + self._tokenizer = None + raise GenaiLoadError( + f"Failed to load genai bundle from {self._bundle_dir}: {exc}" + ) from exc + + self._context_length = self._context_length_override or self._read_context_length() + logger.info( + "GenaiSession loaded: ep=%s context_length=%d", + self._ep, + self._context_length, + ) + + def unload(self) -> None: + """Release ``og.Model`` and tokenizer handles. + + Safe to call on an unloaded session. + """ + self._model = None + self._tokenizer = None + self._context_length = None + logger.info("GenaiSession unloaded: bundle=%s", self._bundle_dir) + + def __enter__(self) -> GenaiSession: + self.load() + return self + + def __exit__(self, *_: object) -> None: + self.unload() + + # ------------------------------------------------------------------ + # Generation + # ------------------------------------------------------------------ + + def generate( + self, + prompt: str | list[int], + config: GenerationConfig | None = None, + ) -> str: + """Generate text and return the full response as a single string. + + This is a convenience wrapper around :meth:`generate_streaming`. + + Args: + prompt: Input text (auto-encoded) or a pre-encoded token-ID list. + config: Generation parameters. Uses :class:`GenerationConfig` + defaults when ``None``. + + Returns: + The generated text (not including the prompt). + """ + return "".join(self.generate_streaming(prompt, config)) + + def generate_streaming( + self, + prompt: str | list[int], + config: GenerationConfig | None = None, + ) -> Iterator[str]: + """Generate text token-by-token, yielding decoded token strings. + + The method auto-loads the session on the first call (lazy-load + equivalent of :meth:`load`). + + Each yield is the decoded string for a single new token. Callers + typically ``print(token, end="", flush=True)`` to stream output. + + Args: + prompt: Input text (auto-encoded via the bundle tokenizer) or a + pre-encoded token-ID list. Pass a pre-formatted string when + chat templates or special tokens are needed — the session is + not aware of any particular model's template format. + config: Generation parameters. Uses :class:`GenerationConfig` + defaults when ``None``. + + Yields: + Decoded string for each newly generated token. + """ + self._ensure_loaded() + og = self._import_og() + cfg = config or GenerationConfig() + + tokens = ( + self._tokenizer.encode(prompt) # type: ignore[union-attr] + if isinstance(prompt, str) + else prompt + ) + + params = og.GeneratorParams(self._model) + params.set_search_options( + max_length=self._context_length, + do_sample=cfg.do_sample, + temperature=cfg.temperature, + top_p=cfg.top_p, + top_k=cfg.top_k, + repetition_penalty=cfg.repetition_penalty, + ) + + generator = og.Generator(self._model, params) + generator.append_tokens(tokens) + + stream = self._tokenizer.create_stream() # type: ignore[union-attr] + n = 0 + try: + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_next_tokens()[0] + yield stream.decode(new_token) + n += 1 + if n >= cfg.max_new_tokens: + break + finally: + # Explicit deletion releases the KV cache buffer held by the generator. + del generator + + # ------------------------------------------------------------------ + # Tokenizer helpers + # ------------------------------------------------------------------ + + def encode(self, text: str) -> list[int]: + """Encode *text* to a list of token IDs using the bundle tokenizer.""" + self._ensure_loaded() + return self._tokenizer.encode(text).tolist() # type: ignore[union-attr] + + def decode(self, tokens: list[int]) -> str: + """Decode a list of token IDs to a string.""" + self._ensure_loaded() + return self._tokenizer.decode(tokens) # type: ignore[union-attr] + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def is_loaded(self) -> bool: + """``True`` if the model is loaded and ready for generation.""" + return self._model is not None + + @property + def bundle_dir(self) -> Path: + """Path to the genai bundle directory.""" + return self._bundle_dir + + @property + def ep(self) -> str: + """Execution provider short name (as passed to ``__init__``).""" + return self._ep + + @property + def context_length(self) -> int | None: + """Static KV cache length, populated after :meth:`load`.""" + return self._context_length + + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _ensure_loaded(self) -> None: + if self._model is None: + self.load() + + @staticmethod + def _import_og() -> object: + """Import and return the ``onnxruntime_genai`` module. + + Raises: + GenaiNotInstalledError: Package not found. + """ + try: + import onnxruntime_genai as og + + return og + except ImportError as exc: + raise GenaiNotInstalledError( + "onnxruntime_genai is not installed. " + "Install it with: pip install onnxruntime-genai-winml" + ) from exc + + def _register_eps(self, og: object) -> None: + """Register WinML EPs with ORT GenAI (idempotent, best-effort).""" + try: + registry = WinMLEPRegistry.get_instance() + if registry.winml_available: + result = registry.register_execution_providers(ort_genai=True) + registered = result.get("onnxruntime_genai", []) + logger.info("WinML EPs registered for ORT GenAI: %s", registered) + except Exception as exc: + logger.warning("WinML EP registration skipped: %s", exc) + + def _read_context_length(self) -> int: + """Read ``model.context_length`` from ``genai_config.json``.""" + cfg = json.loads((self._bundle_dir / "genai_config.json").read_text(encoding="utf-8")) + return int(cfg["model"]["context_length"]) + + +__all__ = [ + "GenaiLoadError", + "GenaiNotInstalledError", + "GenaiSession", + "GenaiSessionError", + "GenerationConfig", +] diff --git a/tests/unit/session/test_genai_session.py b/tests/unit/session/test_genai_session.py new file mode 100644 index 000000000..dbc815f23 --- /dev/null +++ b/tests/unit/session/test_genai_session.py @@ -0,0 +1,349 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for GenaiSession. + +All tests that touch load() / generate*() mock onnxruntime_genai so no +real model files or GPU/NPU hardware is required. +""" + +from __future__ import annotations + +import json +import sys +from typing import TYPE_CHECKING +from unittest.mock import MagicMock, patch + +import pytest + +from winml.modelkit.session.genai_session import ( + GenaiLoadError, + GenaiNotInstalledError, + GenaiSession, + GenaiSessionError, + GenerationConfig, +) + + +if TYPE_CHECKING: + from pathlib import Path + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def bundle_dir(tmp_path: Path) -> Path: + """Create a minimal genai bundle directory with genai_config.json.""" + cfg = { + "model": { + "type": "decoder-pipeline", + "context_length": 256, + "decoder": {}, + }, + "search": {"max_length": 256}, + } + (tmp_path / "genai_config.json").write_text(json.dumps(cfg), encoding="utf-8") + return tmp_path + + +@pytest.fixture +def mock_og() -> MagicMock: + """Return a fully mocked onnxruntime_genai module.""" + og = MagicMock(name="onnxruntime_genai") + og.Config.return_value = MagicMock() + og.Model.return_value = MagicMock() + og.Tokenizer.return_value = MagicMock() + og.GeneratorParams.return_value = MagicMock() + + # Generator that yields two tokens then is_done() + gen = MagicMock() + gen.is_done.side_effect = [False, False, True] + gen.get_next_tokens.side_effect = [ + MagicMock(__getitem__=lambda s, i: 10), + MagicMock(__getitem__=lambda s, i: 20), + ] + og.Generator.return_value = gen + + # TokenizerStream decodes tokens to text + stream = MagicMock() + stream.decode.side_effect = ["Hello", " world"] + og.Tokenizer.return_value.create_stream.return_value = stream + + return og + + +def _patch_og(mock: MagicMock): + """Context manager: inject mock_og as onnxruntime_genai in sys.modules.""" + return patch.dict(sys.modules, {"onnxruntime_genai": mock}) + + +# --------------------------------------------------------------------------- +# Tests: GenaiSession.__init__ +# --------------------------------------------------------------------------- + + +class TestGenaiSessionInit: + def test_missing_bundle_dir_raises(self, tmp_path: Path) -> None: + with pytest.raises(FileNotFoundError, match="Bundle directory not found"): + GenaiSession(tmp_path / "nonexistent") + + def test_missing_config_raises(self, tmp_path: Path) -> None: + # Dir exists but no genai_config.json + with pytest.raises(FileNotFoundError, match=r"genai_config\.json not found"): + GenaiSession(tmp_path) + + def test_unknown_ep_raises(self, bundle_dir: Path) -> None: + with pytest.raises(ValueError, match="Unknown EP"): + GenaiSession(bundle_dir, ep="tensorrt") + + def test_default_ep_is_cpu(self, bundle_dir: Path) -> None: + session = GenaiSession(bundle_dir) + assert session.ep == "cpu" + + def test_not_loaded_after_init(self, bundle_dir: Path) -> None: + session = GenaiSession(bundle_dir) + assert not session.is_loaded + assert session.context_length is None + + def test_bundle_dir_property(self, bundle_dir: Path) -> None: + session = GenaiSession(bundle_dir) + assert session.bundle_dir == bundle_dir + + def test_supported_eps(self, bundle_dir: Path) -> None: + for ep in ("cpu", "qnn", "dml"): + session = GenaiSession(bundle_dir, ep=ep) + assert session.ep == ep + + +# --------------------------------------------------------------------------- +# Tests: load / unload +# --------------------------------------------------------------------------- + + +class TestGenaiSessionLoad: + def test_load_sets_is_loaded(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + session.load() + assert session.is_loaded + + def test_load_reads_context_length_from_config( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + session.load() + assert session.context_length == 256 + + def test_context_length_override(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir, context_length=512) + session.load() + assert session.context_length == 512 + + def test_load_is_idempotent(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + session.load() + session.load() # second call is a no-op + assert mock_og.Model.call_count == 1 + + def test_unload_clears_state(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + session.load() + session.unload() + assert not session.is_loaded + assert session.context_length is None + + def test_unload_on_unloaded_session_is_safe(self, bundle_dir: Path) -> None: + session = GenaiSession(bundle_dir) + session.unload() # should not raise + + def test_context_manager_loads_and_unloads(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + assert session.is_loaded + assert not session.is_loaded + + def test_genai_not_installed_raises(self, bundle_dir: Path) -> None: + with patch.dict(sys.modules, {"onnxruntime_genai": None}): # type: ignore[dict-item] + session = GenaiSession(bundle_dir) + with pytest.raises(GenaiNotInstalledError): + session.load() + + def test_og_load_error_raises_genai_load_error( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + mock_og.Model.side_effect = RuntimeError("driver not found") + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + with pytest.raises(GenaiLoadError, match="driver not found"): + session.load() + + def test_og_load_error_leaves_session_unloaded( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + mock_og.Model.side_effect = RuntimeError("driver not found") + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + with pytest.raises(GenaiLoadError): + session.load() + assert not session.is_loaded + + +# --------------------------------------------------------------------------- +# Tests: EP registration +# --------------------------------------------------------------------------- + + +class TestEPRegistration: + def test_cpu_skips_winml_registration(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with ( + _patch_og(mock_og), + patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls, + ): + session = GenaiSession(bundle_dir, ep="cpu") + session.load() + mock_reg_cls.assert_not_called() + + def test_non_cpu_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) -> None: + mock_registry = MagicMock() + mock_registry.winml_available = True + mock_registry.register_execution_providers.return_value = { + "onnxruntime_genai": ["QNNExecutionProvider"] + } + with ( + _patch_og(mock_og), + patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls, + ): + mock_reg_cls.get_instance.return_value = mock_registry + session = GenaiSession(bundle_dir, ep="qnn") + session.load() + mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True) + + def test_non_cpu_appends_provider_to_config(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with ( + _patch_og(mock_og), + patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls, + ): + mock_reg_cls.get_instance.return_value = MagicMock(winml_available=False) + session = GenaiSession(bundle_dir, ep="qnn") + session.load() + mock_og.Config.return_value.append_provider.assert_called_once_with("QNNExecutionProvider") + + def test_cpu_does_not_append_provider(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir, ep="cpu") + session.load() + mock_og.Config.return_value.append_provider.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests: generate / generate_streaming +# --------------------------------------------------------------------------- + + +class TestGenerate: + def test_generate_streaming_yields_decoded_tokens( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + tokens = list(session.generate_streaming("hi")) + assert tokens == ["Hello", " world"] + + def test_generate_returns_joined_string(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + result = session.generate("hi") + assert result == "Hello world" + + def test_generate_respects_max_new_tokens(self, bundle_dir: Path, mock_og: MagicMock) -> None: + # Generator never signals done; we stop at max_new_tokens=1 + gen = mock_og.Generator.return_value + gen.is_done.side_effect = None + gen.is_done.return_value = False + gen.get_next_tokens.return_value = MagicMock(__getitem__=lambda s, i: 99) + mock_og.Tokenizer.return_value.create_stream.return_value.decode.return_value = "x" + + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + tokens = list(session.generate_streaming("hi", GenerationConfig(max_new_tokens=1))) + assert len(tokens) == 1 + + def test_generate_with_token_list_input(self, bundle_dir: Path, mock_og: MagicMock) -> None: + """Pre-encoded token IDs are forwarded directly to append_tokens.""" + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + list(session.generate_streaming([1, 2, 3])) + gen = mock_og.Generator.return_value + gen.append_tokens.assert_called_once_with([1, 2, 3]) + + def test_generate_deletes_generator_after_iteration( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + """Generator is deleted (not leaked) even on normal completion.""" + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + list(session.generate_streaming("hi")) + # No assertions needed — test passes if no ResourceWarning / hang + + def test_generate_with_custom_config(self, bundle_dir: Path, mock_og: MagicMock) -> None: + cfg = GenerationConfig(max_new_tokens=64, do_sample=True, temperature=0.7) + with _patch_og(mock_og), GenaiSession(bundle_dir) as session: + list(session.generate_streaming("hi", cfg)) + params = mock_og.GeneratorParams.return_value + params.set_search_options.assert_called_once() + call_kwargs = params.set_search_options.call_args.kwargs + assert call_kwargs["do_sample"] is True + assert call_kwargs["temperature"] == 0.7 + + def test_generate_uses_context_length_as_max_length( + self, bundle_dir: Path, mock_og: MagicMock + ) -> None: + with _patch_og(mock_og), GenaiSession(bundle_dir, context_length=128) as session: + list(session.generate_streaming("hi")) + params = mock_og.GeneratorParams.return_value + call_kwargs = params.set_search_options.call_args.kwargs + assert call_kwargs["max_length"] == 128 + + def test_auto_load_on_first_generate(self, bundle_dir: Path, mock_og: MagicMock) -> None: + with _patch_og(mock_og): + session = GenaiSession(bundle_dir) + assert not session.is_loaded + list(session.generate_streaming("hi")) + assert session.is_loaded + + +# --------------------------------------------------------------------------- +# Tests: GenerationConfig defaults +# --------------------------------------------------------------------------- + + +class TestGenerationConfig: + def test_defaults(self) -> None: + cfg = GenerationConfig() + assert cfg.max_new_tokens == 128 + assert cfg.do_sample is False + assert cfg.temperature == 1.0 + assert cfg.top_p == 1.0 + assert cfg.top_k == 0 + assert cfg.repetition_penalty == 1.0 + + def test_custom_values(self) -> None: + cfg = GenerationConfig(max_new_tokens=32, do_sample=True, top_k=50) + assert cfg.max_new_tokens == 32 + assert cfg.do_sample is True + assert cfg.top_k == 50 + + +# --------------------------------------------------------------------------- +# Tests: exception hierarchy +# --------------------------------------------------------------------------- + + +class TestExceptions: + def test_genai_not_installed_is_genai_session_error(self) -> None: + assert issubclass(GenaiNotInstalledError, GenaiSessionError) + + def test_genai_load_error_is_genai_session_error(self) -> None: + assert issubclass(GenaiLoadError, GenaiSessionError) From 0cf6d48d32e858559b51f71e809931198a3527a4 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Mon, 29 Jun 2026 18:27:24 +0800 Subject: [PATCH 04/29] feat(session): add GenaiSession.apply_chatml_template static method - Moves chat template logic from infer_genai.py into GenaiSession - Supports optional system prompt - ChatML is not Qwen3-specific; used by Qwen2/3, Yi, Mistral, etc. - infer_genai.py _wrap_chat_template now delegates to the static method - Updated --chat flag help text and script docstring - 4 new tests covering user-only, with-system, no-system-turn, assistant-priming --- scripts/infer_genai.py | 8 ++--- src/winml/modelkit/session/genai_session.py | 37 +++++++++++++++++++++ tests/unit/session/test_genai_session.py | 24 +++++++++++++ 3 files changed, 65 insertions(+), 4 deletions(-) diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py index 69139ec3e..5144fa7bc 100644 --- a/scripts/infer_genai.py +++ b/scripts/infer_genai.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -r"""onnxruntime-genai inference for the Qwen3 transformer-only pipeline. +r"""onnxruntime-genai inference for a genai bundle (decoder-pipeline). Loads the genai bundle produced by ``export_qwen3_transformer_only.py --genai-bundle `` and runs greedy text generation using @@ -50,8 +50,8 @@ def _wrap_chat_template(prompt: str) -> str: - """Wrap *prompt* in the Qwen3 chat template (no thinking mode).""" - return f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" + """Wrap *prompt* in the ChatML chat template.""" + return GenaiSession.apply_chatml_template(prompt) def parse_args(argv: list[str] | None = None) -> argparse.Namespace: @@ -90,7 +90,7 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: p.add_argument( "--chat", action="store_true", - help="Wrap --prompt in the Qwen3 chat template.", + help="Wrap --prompt in the ChatML template (<|im_start|>user/assistant).", ) p.add_argument( "--verbose", diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 21b0c5b31..229b4e83a 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -332,6 +332,43 @@ def generate_streaming( # Explicit deletion releases the KV cache buffer held by the generator. del generator + # ------------------------------------------------------------------ + # Chat-template helpers + # ------------------------------------------------------------------ + + @staticmethod + def apply_chatml_template( + prompt: str, + system: str | None = None, + ) -> str: + r"""Wrap *prompt* in the ChatML format used by Qwen2/3, Yi, Mistral, etc. + + Produces:: + + <|im_start|>system + {system}<|im_end|> + <|im_start|>user + {prompt}<|im_end|> + <|im_start|>assistant + + The trailing ``<|im_start|>assistant\\n`` primes the model to respond + as the assistant role with no leading newline in its output. + + Args: + prompt: The user turn text. + system: Optional system prompt. When ``None`` no system turn is + prepended. + + Returns: + Formatted string ready to pass to :meth:`generate` / + :meth:`generate_streaming`. + """ + parts: list[str] = [] + if system is not None: + parts.append(f"<|im_start|>system\n{system}<|im_end|>\n") + parts.append(f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n") + return "".join(parts) + # ------------------------------------------------------------------ # Tokenizer helpers # ------------------------------------------------------------------ diff --git a/tests/unit/session/test_genai_session.py b/tests/unit/session/test_genai_session.py index dbc815f23..4859ef11e 100644 --- a/tests/unit/session/test_genai_session.py +++ b/tests/unit/session/test_genai_session.py @@ -314,6 +314,30 @@ def test_auto_load_on_first_generate(self, bundle_dir: Path, mock_og: MagicMock) assert session.is_loaded +# --------------------------------------------------------------------------- +# Tests: apply_chatml_template +# --------------------------------------------------------------------------- + + +class TestApplyChatmlTemplate: + def test_user_only(self) -> None: + result = GenaiSession.apply_chatml_template("Hello") + assert result == "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" + + def test_with_system(self) -> None: + result = GenaiSession.apply_chatml_template("Hello", system="You are helpful.") + assert result.startswith("<|im_start|>system\nYou are helpful.<|im_end|>\n") + assert "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" in result + + def test_no_system_no_system_turn(self) -> None: + result = GenaiSession.apply_chatml_template("Hi") + assert "<|im_start|>system" not in result + + def test_ends_with_assistant_priming(self) -> None: + result = GenaiSession.apply_chatml_template("Hi") + assert result.endswith("<|im_start|>assistant\n") + + # --------------------------------------------------------------------------- # Tests: GenerationConfig defaults # --------------------------------------------------------------------------- From f3a64bc3948186c6decfb62cdb9ad23adad5fc18 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 12:39:25 +0800 Subject: [PATCH 05/29] feat(qwen3/genai): NPU+CPU hybrid EP support in genai_config - PipelineStage gains session_options: dict | None = None field; PipelineStage.to_dict() emits it when set - Add _qnn_stage_session_options(log_id, soc_model) helper that produces QNN HTP provider_options for a pipeline stage - build_qwen3_transformer_only_stages gains ep='cpu' and soc_model='60' params; when ep='qnn' the context and iterator stages receive QNN session_options, embeddings and lm_head stay on CPU (no session_options) - write_genai_bundle threads ep/soc_model through - export_qwen3_transformer_only.py passes ep='qnn' when --device npu - 5 new tests covering cpu/qnn ep routing and soc_model propagation (39 total, all pass) --- scripts/export_qwen3_transformer_only.py | 1 + src/winml/modelkit/models/hf/qwen3/genai.py | 77 ++++++++++++++++++++ tests/unit/models/qwen3/test_genai_config.py | 57 +++++++++++++++ 3 files changed, 135 insertions(+) diff --git a/scripts/export_qwen3_transformer_only.py b/scripts/export_qwen3_transformer_only.py index 202c9c906..856123be2 100644 --- a/scripts/export_qwen3_transformer_only.py +++ b/scripts/export_qwen3_transformer_only.py @@ -220,6 +220,7 @@ def main(argv: list[str] | None = None) -> int: prefill_seq_len=args.prefill_seq_len, embeddings_src=args.embeddings, lm_head_src=args.lm_head, + ep="qnn" if args.device == "npu" else args.device, ) print(f" genai_config.json -> {config_path}") if args.embeddings is None: diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py index 8c7c14503..4a63de45b 100644 --- a/src/winml/modelkit/models/hf/qwen3/genai.py +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -116,6 +116,13 @@ class PipelineStage: inputs: list[str] outputs: list[str] is_lm_head: bool = False + session_options: dict | None = None + """Per-stage ORT session options (e.g. provider_options for QNN). + + When set, emitted verbatim as the ``session_options`` key in the + ``genai_config.json`` pipeline stage. Leave ``None`` (default) for + stages that should run on the default (CPU) provider. + """ def to_dict(self) -> dict: """Serialize to the dict format expected by ``genai_config.json``.""" @@ -126,6 +133,8 @@ def to_dict(self) -> dict: "run_on_prompt": self.run_on_prompt, "run_on_token_gen": self.run_on_token_gen, } + if self.session_options: + d["session_options"] = self.session_options if self.is_lm_head: d["is_lm_head"] = True return d @@ -334,6 +343,42 @@ def _key(prefix: str) -> int: return sorted(patterns.keys(), key=_key) +# --------------------------------------------------------------------------- +# Per-EP stage session_options helpers +# --------------------------------------------------------------------------- + + +def _qnn_stage_session_options(log_id: str, soc_model: str = "60") -> dict: + """Return the ``session_options`` block that routes a stage to QNN HTP. + + Args: + log_id: ORT log identifier (shown in ORT logs), e.g. + ``"onnxruntime-genai.context"``. + soc_model: Snapdragon SoC model number passed to the QNN HTP backend. + ``"60"`` targets Snapdragon 8 Gen 3 (X Elite). Change for other + SoCs (e.g. ``"55"`` for 8 Gen 2, ``"73"`` for 8 Elite). + + Returns: + Dict suitable for the ``session_options`` key of a pipeline stage in + ``genai_config.json``. + """ + return { + "log_id": log_id, + "provider_options": [ + { + "qnn": { + "backend_path": "QnnHtp.dll", + "htp_performance_mode": "burst", + "htp_graph_finalization_optimization_mode": "3", + "soc_model": soc_model, + } + } + ], + "intra_op_num_threads": 2, + "inter_op_num_threads": 1, + } + + # --------------------------------------------------------------------------- # Qwen3 transformer-only pipeline factory # --------------------------------------------------------------------------- @@ -348,6 +393,8 @@ def build_qwen3_transformer_only_stages( iterator_filename: str = DEFAULT_ITERATOR_FILENAME, embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", ) -> tuple[list[PipelineStage], DecoderIOMapping]: """Build pipeline stages by introspecting the built ONNX models. @@ -363,6 +410,13 @@ def build_qwen3_transformer_only_stages( iterator_filename: Bundle filename for the iterator model. embeddings_filename: Bundle filename for the embeddings model. lm_head_filename: Bundle filename for the lm_head model. + ep: Execution provider for the transformer stages. ``"qnn"`` injects + QNN HTP ``session_options`` into the ``context`` and ``iterator`` + stages so they run on the NPU while ``embeddings`` and ``lm_head`` + continue on CPU. ``"cpu"`` (default) omits ``session_options`` + from all stages. + soc_model: Snapdragon SoC model number forwarded to the QNN backend + when ``ep="qnn"``. Default ``"60"`` targets Snapdragon 8 Gen 3. Returns: ``(stages, decoder_io)`` — a 4-element :class:`PipelineStage` list and @@ -406,6 +460,17 @@ def build_qwen3_transformer_only_stages( present_value_names=pres_val_fmt, ) + # Per-stage session_options: NPU stages get QNN config; CPU and others get None. + ctx_session_opts: dict | None = None + iter_session_opts: dict | None = None + if ep == "qnn": + ctx_session_opts = _qnn_stage_session_options( + "onnxruntime-genai.context", soc_model=soc_model + ) + iter_session_opts = _qnn_stage_session_options( + "onnxruntime-genai.iterator", soc_model=soc_model + ) + stages: list[PipelineStage] = [ PipelineStage( name="embeddings", @@ -422,6 +487,7 @@ def build_qwen3_transformer_only_stages( run_on_token_gen=False, inputs=ctx_inputs, outputs=ctx_outputs, + session_options=ctx_session_opts, ), PipelineStage( name="iterator", @@ -430,6 +496,7 @@ def build_qwen3_transformer_only_stages( run_on_token_gen=True, inputs=iter_inputs, outputs=iter_outputs, + session_options=iter_session_opts, ), PipelineStage( name="lm_head", @@ -463,6 +530,8 @@ def write_genai_bundle( iterator_filename: str = DEFAULT_ITERATOR_FILENAME, embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", ) -> Path: """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*. @@ -484,6 +553,12 @@ def write_genai_bundle( iterator_filename: Bundle filename for the iterator model. embeddings_filename: Bundle filename for the embeddings model. lm_head_filename: Bundle filename for the lm_head model. + ep: Execution provider for the transformer (context/iterator) stages. + ``"qnn"`` injects QNN HTP ``session_options`` so those stages run + on the NPU while embeddings and lm_head run on CPU. + ``"cpu"`` (default) omits ``session_options`` (all stages on CPU). + soc_model: Snapdragon SoC model passed to the QNN backend when + ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite. Returns: Path to the written ``genai_config.json``. @@ -538,6 +613,8 @@ def write_genai_bundle( iterator_filename=iterator_filename, embeddings_filename=embeddings_filename, lm_head_filename=lm_head_filename, + ep=ep, + soc_model=soc_model, ) # 5. Write genai_config.json. diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py index 012c71cba..900f8b664 100644 --- a/tests/unit/models/qwen3/test_genai_config.py +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -456,3 +456,60 @@ def test_roundtrip_with_build_genai_config(self) -> None: ) assert result["model"]["type"] == "decoder-pipeline" assert len(result["model"]["decoder"]["pipeline"]) == 4 + + def test_cpu_ep_no_session_options(self) -> None: + """Default cpu ep: context/iterator stages have no session_options.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="cpu" + ) + ctx = next(s for s in stages if s.name == "context") + itr = next(s for s in stages if s.name == "iterator") + assert ctx.session_options is None + assert itr.session_options is None + + def test_qnn_ep_injects_session_options(self) -> None: + """ep='qnn': context/iterator get QNN session_options; emb/lm_head do not.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn" + ) + stage_map = {s.name: s for s in stages} + assert stage_map["embeddings"].session_options is None + assert stage_map["lm_head"].session_options is None + ctx_opts = stage_map["context"].session_options + itr_opts = stage_map["iterator"].session_options + assert ctx_opts is not None + assert itr_opts is not None + assert ctx_opts["provider_options"][0]["qnn"]["backend_path"] == "QnnHtp.dll" + assert itr_opts["log_id"] == "onnxruntime-genai.iterator" + + def test_qnn_session_options_in_serialized_config(self) -> None: + """QNN session_options appear in genai_config.json pipeline output.""" + with self._patch_onnx(): + stages, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn" + ) + cfg = build_genai_config( + _mock_config(num_hidden_layers=4), + max_cache_len=256, + prefill_seq_len=64, + pipeline=stages, + decoder_io=decoder_io, + ) + pipeline = cfg["model"]["decoder"]["pipeline"] + ctx_dict = next(s for s in pipeline if "context" in s)["context"] + itr_dict = next(s for s in pipeline if "iterator" in s)["iterator"] + emb_dict = next(s for s in pipeline if "embeddings" in s)["embeddings"] + assert "session_options" in ctx_dict + assert "session_options" in itr_dict + assert "session_options" not in emb_dict + + def test_custom_soc_model(self) -> None: + """soc_model parameter propagates to QNN provider_options.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn", soc_model="73" + ) + ctx = next(s for s in stages if s.name == "context") + assert ctx.session_options["provider_options"][0]["qnn"]["soc_model"] == "73" From 6e51029c27c4421ae753082667917aa5a2c8c186 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 13:28:00 +0800 Subject: [PATCH 06/29] fix(genai_session): let genai_config.json drive EP routing, add mixed EP Remove clear_providers/append_provider calls from GenaiSession.load(). EP placement is fully driven by per-stage session_options in genai_config.json. clear_providers() only clears the top-level provider and cannot override per-stage session_options embedded in the pipeline config. - Add 'mixed' EP (use genai_config.json as-is; default for infer_genai.py) - _NEEDS_WINML_EPS covers mixed/qnn/dml to trigger EP registration - Replace _EP_PROVIDER_MAP with _VALID_EPS + _NEEDS_WINML_EPS sets - Update tests: remove append_provider assertions, add mixed/config-not-modified tests - infer_genai.py default EP changed from 'cpu' to 'mixed' Result: NPU bundle (out/qwen3_bundle_npu) now runs at 9.3 tok/s vs 1.2 tok/s CPU --- scripts/infer_genai.py | 7 +++-- src/winml/modelkit/session/genai_session.py | 31 +++++++++++---------- tests/unit/session/test_genai_session.py | 19 +++++++++---- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py index 5144fa7bc..f31847b9b 100644 --- a/scripts/infer_genai.py +++ b/scripts/infer_genai.py @@ -46,7 +46,7 @@ _REPO_ROOT = Path(__file__).resolve().parent.parent DEFAULT_MODEL_DIR = _REPO_ROOT / "out" / "qwen3_bundle" -_SUPPORTED_EPS = ["cpu", "qnn", "dml"] +_SUPPORTED_EPS = ["cpu", "mixed", "qnn", "dml"] def _wrap_chat_template(prompt: str) -> str: @@ -78,8 +78,9 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: p.add_argument( "--ep", choices=_SUPPORTED_EPS, - default="cpu", - help="Execution provider (default: cpu).", + default="mixed", + help="Execution provider: 'mixed' uses genai_config.json as-is (default); " + "'cpu' forces all stages to CPU; 'qnn'/'dml' for full NPU/GPU.", ) p.add_argument( "--max-new", diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 229b4e83a..61423efe1 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -61,14 +61,15 @@ logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- -# EP name mapping: user-friendly short name → ORT GenAI provider string. -# None means "do not append a provider" (= default CPU execution). +# Valid EP short names. +# "mixed" = use genai_config.json as-is (embeddings/lm_head on CPU, +# ctx/iter on the target accelerator). +# EP routing is driven entirely by per-stage session_options in the bundle's +# genai_config.json — GenaiSession never calls clear_providers/append_provider. # --------------------------------------------------------------------------- -_EP_PROVIDER_MAP: dict[str, str | None] = { - "cpu": None, - "qnn": "QNNExecutionProvider", - "dml": "DmlExecutionProvider", -} +_VALID_EPS: frozenset[str] = frozenset({"cpu", "mixed", "qnn", "dml"}) +# EPs that require WinML EP discovery + registration before og.Model() init. +_NEEDS_WINML_EPS: frozenset[str] = frozenset({"mixed", "qnn", "dml"}) # --------------------------------------------------------------------------- @@ -178,8 +179,8 @@ def __init__( f"genai_config.json not found in {self._bundle_dir}. " "Run export_qwen3_transformer_only.py --genai-bundle first." ) - if self._ep not in _EP_PROVIDER_MAP: - raise ValueError(f"Unknown EP {ep!r}. Supported: {sorted(_EP_PROVIDER_MAP)}") + if self._ep not in _VALID_EPS: + raise ValueError(f"Unknown EP {ep!r}. Supported: {sorted(_VALID_EPS)}") logger.info("GenaiSession initialized: bundle=%s ep=%s", self._bundle_dir, self._ep) @@ -202,8 +203,8 @@ def load(self) -> None: og = self._import_og() - # Register WinML EPs to ORT GenAI (skipped for CPU; idempotent). - if self._ep != "cpu": + # Register WinML EPs to ORT GenAI when the bundle may use a hardware EP. + if self._ep in _NEEDS_WINML_EPS: self._register_eps(og) if self._verbose: @@ -211,10 +212,10 @@ def load(self) -> None: try: config = og.Config(str(self._bundle_dir)) - config.clear_providers() - provider = _EP_PROVIDER_MAP[self._ep] - if provider is not None: - config.append_provider(provider) + # EP routing is driven entirely by genai_config.json (per-stage + # session_options). Do NOT call clear_providers/append_provider — + # those only touch the top-level provider and cannot override + # per-stage session_options already embedded in the pipeline config. self._model = og.Model(config) self._tokenizer = og.Tokenizer(self._model) except Exception as exc: diff --git a/tests/unit/session/test_genai_session.py b/tests/unit/session/test_genai_session.py index 4859ef11e..4dcf7ea1c 100644 --- a/tests/unit/session/test_genai_session.py +++ b/tests/unit/session/test_genai_session.py @@ -114,7 +114,7 @@ def test_bundle_dir_property(self, bundle_dir: Path) -> None: assert session.bundle_dir == bundle_dir def test_supported_eps(self, bundle_dir: Path) -> None: - for ep in ("cpu", "qnn", "dml"): + for ep in ("cpu", "mixed", "qnn", "dml"): session = GenaiSession(bundle_dir, ep=ep) assert session.ep == ep @@ -225,20 +225,27 @@ def test_non_cpu_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) session.load() mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True) - def test_non_cpu_appends_provider_to_config(self, bundle_dir: Path, mock_og: MagicMock) -> None: + def test_mixed_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) -> None: + mock_registry = MagicMock() + mock_registry.winml_available = True + mock_registry.register_execution_providers.return_value = { + "onnxruntime_genai": ["QNNExecutionProvider"] + } with ( _patch_og(mock_og), patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls, ): - mock_reg_cls.get_instance.return_value = MagicMock(winml_available=False) - session = GenaiSession(bundle_dir, ep="qnn") + mock_reg_cls.get_instance.return_value = mock_registry + session = GenaiSession(bundle_dir, ep="mixed") session.load() - mock_og.Config.return_value.append_provider.assert_called_once_with("QNNExecutionProvider") + mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True) - def test_cpu_does_not_append_provider(self, bundle_dir: Path, mock_og: MagicMock) -> None: + def test_config_not_modified_at_load(self, bundle_dir: Path, mock_og: MagicMock) -> None: + # EP routing is driven by genai_config.json — we must NOT touch the config. with _patch_og(mock_og): session = GenaiSession(bundle_dir, ep="cpu") session.load() + mock_og.Config.return_value.clear_providers.assert_not_called() mock_og.Config.return_value.append_provider.assert_not_called() From 8cfa8dc4e456f4e81f03d15fd41aee4262c281ac Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 16:36:47 +0800 Subject: [PATCH 07/29] feat: add --compile flag to infer_genai.py for EPContext pre-compilation - GenaiSession gains compile=True parameter - _prepare_compiled_bundle(): detects QNN stages from genai_config.json, compiles each stage to EPContext ONNX via ort.ModelCompiler in a subprocess - _compile_stage(): 5-minute timeout per stage to handle QNN SDK hang (known bug: w8a16 + multi-token prefill hangs indefinitely) - Compiled artifacts cached in bundle_dir/_compiled/; reused on subsequent runs - _mirror_non_onnx_files(): symlinks/copies tokenizer files so og.Config can load from the compiled sub-directory - infer_genai.py --compile flag wired through to GenaiSession --- scripts/infer_genai.py | 20 ++- src/winml/modelkit/session/genai_session.py | 190 +++++++++++++++++++- 2 files changed, 205 insertions(+), 5 deletions(-) diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py index f31847b9b..47aa6683e 100644 --- a/scripts/infer_genai.py +++ b/scripts/infer_genai.py @@ -26,6 +26,11 @@ uv run python scripts/infer_genai.py \\ --model-dir out/my_bundle --prompt "Hi" --ep cpu + # Pre-compile QNN stages to EPContext on first run; reuse cache on subsequent runs. + # Eliminates per-run JIT overhead (~60-90 s saved on Snapdragon X Elite). + uv run python scripts/infer_genai.py \\ + --prompt "Hello" --ep mixed --compile + Dependencies (install in a fresh venv):: pip install onnxruntime-genai-winml @@ -93,6 +98,17 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: action="store_true", help="Wrap --prompt in the ChatML template (<|im_start|>user/assistant).", ) + p.add_argument( + "--compile", + action="store_true", + help=( + "Pre-compile QNN pipeline stages to EPContext ONNX before loading. " + "On first use this triggers ort.ModelCompiler per stage (~60-90 s for iter). " + "Compiled artifacts are cached in bundle_dir/_compiled/; " + "subsequent runs reuse the cache and skip JIT. " + "Has no effect when --ep cpu." + ), + ) p.add_argument( "--verbose", action="store_true", @@ -109,7 +125,9 @@ def main(argv: list[str] | None = None) -> int: gen_cfg = GenerationConfig(max_new_tokens=args.max_new, do_sample=False) try: - session = GenaiSession(args.model_dir, ep=args.ep, verbose=args.verbose) + session = GenaiSession( + args.model_dir, ep=args.ep, verbose=args.verbose, compile=args.compile + ) except FileNotFoundError as exc: print(f"ERROR: {exc}", file=sys.stderr) return 1 diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 61423efe1..2069d3076 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -47,6 +47,7 @@ import json import logging +import shutil from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING @@ -149,8 +150,17 @@ class GenaiSession: ``None`` (default), read from ``genai_config.json``. Must match the ``--max-cache-len`` used during the winml-cli build. verbose: Enable ``onnxruntime-genai`` native model I/O logging. + compile: Pre-compile QNN pipeline stages to EPContext ONNX on first + run (inside ``bundle_dir/_compiled/``). Subsequent calls reuse + the cached EPContext files, eliminating per-run JIT overhead. + Only stages that can be compiled without hanging are attempted; + stages that fail compilation fall back to the original ONNX. + Has no effect when ``ep="cpu"``. """ + # Sub-directory within the bundle that holds pre-compiled EPContext ONNX files. + _COMPILED_SUBDIR: str = "_compiled" + def __init__( self, bundle_dir: str | Path, @@ -158,11 +168,13 @@ def __init__( *, context_length: int | None = None, verbose: bool = False, + compile: bool = False, ) -> None: self._bundle_dir = Path(bundle_dir) self._ep = ep.lower() self._context_length_override = context_length self._verbose = verbose + self._compile = compile # Resolved at load() time. self._context_length: int | None = None @@ -210,8 +222,13 @@ def load(self) -> None: if self._verbose: og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True) + # Determine which bundle directory og.Config should load from. + load_dir = self._bundle_dir + if self._compile and self._ep in _NEEDS_WINML_EPS: + load_dir = self._prepare_compiled_bundle() + try: - config = og.Config(str(self._bundle_dir)) + config = og.Config(str(load_dir)) # EP routing is driven entirely by genai_config.json (per-stage # session_options). Do NOT call clear_providers/append_provider — # those only touch the top-level provider and cannot override @@ -221,9 +238,7 @@ def load(self) -> None: except Exception as exc: self._model = None self._tokenizer = None - raise GenaiLoadError( - f"Failed to load genai bundle from {self._bundle_dir}: {exc}" - ) from exc + raise GenaiLoadError(f"Failed to load genai bundle from {load_dir}: {exc}") from exc self._context_length = self._context_length_override or self._read_context_length() logger.info( @@ -416,6 +431,173 @@ def _ensure_loaded(self) -> None: if self._model is None: self.load() + def _prepare_compiled_bundle(self) -> Path: + """Create (or reuse) a *compiled* bundle directory. + + Reads ``genai_config.json``, finds QNN-accelerated stages (those with + ``QNNExecutionProvider`` in their ``session_options``), and tries to + compile their ONNX to EPContext format using ``ort.ModelCompiler``. + + The compiled bundle is stored under ``bundle_dir/_compiled/``. On + every call the helper checks whether the cached EPContext file is + newer than the source ONNX; if so, it skips recompilation. + + Returns: + Path to the compiled bundle directory (may equal ``bundle_dir`` + if no compilable stages were found, or if all compilations failed). + """ + compiled_dir = self._bundle_dir / self._COMPILED_SUBDIR + config_src = self._bundle_dir / "genai_config.json" + cfg = json.loads(config_src.read_text(encoding="utf-8")) + + # Collect pipeline stages that use QNNExecutionProvider. + # genai_config pipeline entries: {"ctx": {...}, "iter": {...}, ...} + pipeline: dict = cfg.get("model", {}).get("decoder", {}) + qnn_stages: list[tuple[str, str]] = [] # [(stage_key, onnx_filename), ...] + for stage_key, stage_cfg in pipeline.items(): + if not isinstance(stage_cfg, dict): + continue + so = stage_cfg.get("session_options", {}) + providers = so.get("provider_options", []) + for p in providers: + if isinstance(p, dict) and "QNNExecutionProvider" in p: + onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") + qnn_stages.append((stage_key, onnx_filename)) + break + + if not qnn_stages: + logger.info("No QNN stages found in genai_config.json; skipping compilation") + return self._bundle_dir + + compiled_dir.mkdir(exist_ok=True) + modified_cfg = json.loads(config_src.read_text(encoding="utf-8")) + any_compiled = False + + for stage_key, onnx_filename in qnn_stages: + src_onnx = self._bundle_dir / onnx_filename + ctx_onnx = compiled_dir / f"{stage_key}_ctx.onnx" + + # Skip recompilation if cache is up-to-date. + if ctx_onnx.exists() and ctx_onnx.stat().st_mtime >= src_onnx.stat().st_mtime: + logger.info("Stage %r: reusing cached EPContext %s", stage_key, ctx_onnx.name) + self._patch_stage_filename(modified_cfg, stage_key, str(ctx_onnx)) + any_compiled = True + continue + + # Attempt compilation. + success = self._compile_stage(src_onnx, ctx_onnx, stage_key) + if success: + self._patch_stage_filename(modified_cfg, stage_key, str(ctx_onnx)) + any_compiled = True + else: + logger.warning( + "Stage %r: compilation failed or was skipped; using original ONNX", stage_key + ) + + if not any_compiled: + return self._bundle_dir + + # Write the modified genai_config into the compiled sub-directory so that + # ort-genai can resolve all ONNX paths (absolute paths are used). + # Also symlink/copy every other file that og.Config expects. + compiled_config = compiled_dir / "genai_config.json" + compiled_config.write_text( + json.dumps(modified_cfg, indent=2, ensure_ascii=False), encoding="utf-8" + ) + self._mirror_non_onnx_files(compiled_dir) + + logger.info("Compiled bundle prepared at %s", compiled_dir) + return compiled_dir + + @staticmethod + def _patch_stage_filename(cfg: dict, stage_key: str, abs_path: str) -> None: + """Rewrite a pipeline stage's ``filename`` to an absolute path.""" + decoder: dict = cfg.get("model", {}).get("decoder", {}) + if stage_key in decoder and isinstance(decoder[stage_key], dict): + decoder[stage_key]["filename"] = abs_path + + def _compile_stage(self, src_onnx: Path, ctx_out: Path, stage_key: str) -> bool: + """Compile *src_onnx* to EPContext format via ``ort.ModelCompiler``. + + Runs in a subprocess so that a ModelCompiler hang (a known QNN SDK bug + with w8a16 + multi-token prefill) does not block the caller. + + Args: + src_onnx: Source ONNX file path. + ctx_out: Destination EPContext ONNX path. + stage_key: Human-readable label for logging. + + Returns: + ``True`` if compilation succeeded; ``False`` on timeout or error. + """ + import multiprocessing + + compile_timeout_s = 300 # 5 minutes; iter compiles in ~67s normally + + logger.info("Compiling stage %r: %s → %s", stage_key, src_onnx.name, ctx_out.name) + + def _do_compile(src: str, dst: str) -> None: + import onnxruntime as ort + + from winml.modelkit.session.ep_registry import WinMLEPRegistry + from winml.modelkit.winml import add_ep_for_device + + registry = WinMLEPRegistry.get_instance() + registry.register_execution_providers() + so = ort.SessionOptions() + so.add_session_config_entry("ep.context_enable", "1") + so.add_session_config_entry("ep.context_file_path", dst) + add_ep_for_device(so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU) + mc = ort.ModelCompiler(so, src, embed_compiled_data_into_model=False) + mc.compile_to_file(dst) + + ctx = multiprocessing.get_context("spawn") + proc = ctx.Process(target=_do_compile, args=(str(src_onnx), str(ctx_out))) + proc.start() + proc.join(timeout=compile_timeout_s) + + if proc.is_alive(): + logger.warning( + "Stage %r compilation timed out after %ds (QNN SDK hang). " + "This is a known issue with multi-token prefill + w8a16 quantization. " + "Falling back to JIT compilation for this stage.", + stage_key, + compile_timeout_s, + ) + proc.kill() + proc.join() + # Remove partial output file. + ctx_out.unlink(missing_ok=True) + return False + + if proc.exitcode != 0: + logger.warning("Stage %r compilation failed (exit %d)", stage_key, proc.exitcode) + ctx_out.unlink(missing_ok=True) + return False + + logger.info("Stage %r compiled successfully → %s", stage_key, ctx_out) + return True + + def _mirror_non_onnx_files(self, compiled_dir: Path) -> None: + """Create symlinks (or copies on Windows) for every non-ONNX file. + + Files are linked/copied into *compiled_dir* so that ``og.Config`` + finds tokenizer files, specials maps, etc. Existing files are left + untouched. + """ + for src in self._bundle_dir.iterdir(): + if src.name == self._COMPILED_SUBDIR: + continue + dst = compiled_dir / src.name + if dst.exists(): + continue + if src.is_file(): + try: + dst.symlink_to(src.resolve()) + except (OSError, NotImplementedError): + # Symlinks may require elevated privileges on Windows; fall back to copy. + shutil.copy2(src, dst) + @staticmethod def _import_og() -> object: """Import and return the ``onnxruntime_genai`` module. From c9926819b0280c853ecdb30e8b432dd54be2a9aa Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 16:46:05 +0800 Subject: [PATCH 08/29] fix: resolve prefill compilation hang by forcing htp_graph_finalization_optimization_mode=0 Root cause: QNN SDK ModelCompiler deadlocks when compiling w8a16 quantized ONNX with multi-token static input shapes (seq_len > 1) at graph finalization optimization levels 1-3. The genai_config uses level 3 for runtime inference, which triggers the hang when passed to ModelCompiler directly. Fix: _compile_stage now forces htp_graph_finalization_optimization_mode=0 for compilation. This lets ModelCompiler finish (ctx ~41s, iter ~67s) while runtime inference still uses the full level-3 optimization from genai_config (EPContext loading bypasses compilation entirely, so the runtime option is irrelevant). Also fixes: - Pipeline stage detection: genai_config uses 'qnn' key (not 'QNNExecutionProvider') in provider_options; detection and option extraction now uses the correct key - _patch_stage_filename: genai_config pipeline is a list, not a dict; updated to iterate list entries correctly - _prepare_compiled_bundle: passes QNN provider options from each stage's session_options to _compile_stage so soc_model, backend_path, etc. are respected - Removed the 'prefill fallback to JIT' warning since the hang is now fixed --- src/winml/modelkit/session/genai_session.py | 103 +++++++++++++------- 1 file changed, 70 insertions(+), 33 deletions(-) diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 2069d3076..25b2ce311 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -450,20 +450,25 @@ def _prepare_compiled_bundle(self) -> Path: config_src = self._bundle_dir / "genai_config.json" cfg = json.loads(config_src.read_text(encoding="utf-8")) - # Collect pipeline stages that use QNNExecutionProvider. - # genai_config pipeline entries: {"ctx": {...}, "iter": {...}, ...} - pipeline: dict = cfg.get("model", {}).get("decoder", {}) - qnn_stages: list[tuple[str, str]] = [] # [(stage_key, onnx_filename), ...] - for stage_key, stage_cfg in pipeline.items(): - if not isinstance(stage_cfg, dict): + # Collect pipeline stages that use a QNN EP ("qnn" key in provider_options). + # genai_config pipeline entries: [{"context": {...}}, {"iterator": {...}}, ...] + # provider_options format: [{"qnn": {...}}] + pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", []) + # [(stage_key, onnx_filename, qnn_opts), ...] + qnn_stages: list[tuple[str, str, dict]] = [] + for stage_entry in pipeline_list: + if not isinstance(stage_entry, dict): continue - so = stage_cfg.get("session_options", {}) - providers = so.get("provider_options", []) - for p in providers: - if isinstance(p, dict) and "QNNExecutionProvider" in p: - onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") - qnn_stages.append((stage_key, onnx_filename)) - break + for stage_key, stage_cfg in stage_entry.items(): + if not isinstance(stage_cfg, dict): + continue + so = stage_cfg.get("session_options", {}) + providers = so.get("provider_options", []) + for p in providers: + if isinstance(p, dict) and "qnn" in p: + onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") + qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]))) + break if not qnn_stages: logger.info("No QNN stages found in genai_config.json; skipping compilation") @@ -473,7 +478,7 @@ def _prepare_compiled_bundle(self) -> Path: modified_cfg = json.loads(config_src.read_text(encoding="utf-8")) any_compiled = False - for stage_key, onnx_filename in qnn_stages: + for stage_key, onnx_filename, qnn_opts in qnn_stages: src_onnx = self._bundle_dir / onnx_filename ctx_onnx = compiled_dir / f"{stage_key}_ctx.onnx" @@ -485,13 +490,13 @@ def _prepare_compiled_bundle(self) -> Path: continue # Attempt compilation. - success = self._compile_stage(src_onnx, ctx_onnx, stage_key) + success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts) if success: self._patch_stage_filename(modified_cfg, stage_key, str(ctx_onnx)) any_compiled = True else: logger.warning( - "Stage %r: compilation failed or was skipped; using original ONNX", stage_key + "Stage %r: compilation failed; using original ONNX (JIT fallback)", stage_key ) if not any_compiled: @@ -512,31 +517,64 @@ def _prepare_compiled_bundle(self) -> Path: @staticmethod def _patch_stage_filename(cfg: dict, stage_key: str, abs_path: str) -> None: """Rewrite a pipeline stage's ``filename`` to an absolute path.""" - decoder: dict = cfg.get("model", {}).get("decoder", {}) - if stage_key in decoder and isinstance(decoder[stage_key], dict): - decoder[stage_key]["filename"] = abs_path - - def _compile_stage(self, src_onnx: Path, ctx_out: Path, stage_key: str) -> bool: + pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", []) + for stage_entry in pipeline_list: + if isinstance(stage_entry, dict) and stage_key in stage_entry: + stage_cfg = stage_entry[stage_key] + if isinstance(stage_cfg, dict): + stage_cfg["filename"] = abs_path + return + + def _compile_stage( + self, + src_onnx: Path, + ctx_out: Path, + stage_key: str, + qnn_opts: dict | None = None, + ) -> bool: """Compile *src_onnx* to EPContext format via ``ort.ModelCompiler``. - Runs in a subprocess so that a ModelCompiler hang (a known QNN SDK bug - with w8a16 + multi-token prefill) does not block the caller. + Runs in a subprocess so that a ModelCompiler failure does not block + the caller. The QNN options from ``genai_config.json`` are forwarded + to the compilation session, with ``htp_graph_finalization_optimization_mode`` + forced to ``"0"``. This avoids a QNN SDK deadlock that occurs when + compiling w8a16 quantized models with multi-token static input shapes + (``seq_len > 1``) at higher optimization levels. + + The resulting EPContext ONNX is identical in interface to the original; + at runtime, ort-genai loads the pre-compiled QNN binary and the + inference-time ``htp_graph_finalization_optimization_mode`` from + ``genai_config.json`` governs any further JIT compilation. Args: src_onnx: Source ONNX file path. ctx_out: Destination EPContext ONNX path. stage_key: Human-readable label for logging. + qnn_opts: QNN provider options from genai_config (e.g. backend_path, + htp_performance_mode, soc_model). ``htp_graph_finalization_ + optimization_mode`` is always overridden to ``"0"``. Returns: ``True`` if compilation succeeded; ``False`` on timeout or error. """ import multiprocessing - compile_timeout_s = 300 # 5 minutes; iter compiles in ~67s normally + # Force graph-finalization optimization off. Levels 1-3 deadlock QNN + # ModelCompiler for w8a16 quantized models with multi-token input shapes. + compile_qnn_opts = dict(qnn_opts or {}) + compile_qnn_opts["htp_graph_finalization_optimization_mode"] = "0" - logger.info("Compiling stage %r: %s → %s", stage_key, src_onnx.name, ctx_out.name) + compile_timeout_s = 300 # 5 minutes; ctx compiles in ~41s, iter in ~67s - def _do_compile(src: str, dst: str) -> None: + logger.info( + "Compiling stage %r: %s → %s (qnn_opts=%s)", + stage_key, + src_onnx.name, + ctx_out.name, + compile_qnn_opts, + ) + + def _do_compile(src: str, dst: str, qnn_options: dict) -> None: import onnxruntime as ort from winml.modelkit.session.ep_registry import WinMLEPRegistry @@ -547,26 +585,25 @@ def _do_compile(src: str, dst: str) -> None: so = ort.SessionOptions() so.add_session_config_entry("ep.context_enable", "1") so.add_session_config_entry("ep.context_file_path", dst) - add_ep_for_device(so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU) + add_ep_for_device( + so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU, qnn_options + ) mc = ort.ModelCompiler(so, src, embed_compiled_data_into_model=False) mc.compile_to_file(dst) ctx = multiprocessing.get_context("spawn") - proc = ctx.Process(target=_do_compile, args=(str(src_onnx), str(ctx_out))) + proc = ctx.Process(target=_do_compile, args=(str(src_onnx), str(ctx_out), compile_qnn_opts)) proc.start() proc.join(timeout=compile_timeout_s) if proc.is_alive(): - logger.warning( - "Stage %r compilation timed out after %ds (QNN SDK hang). " - "This is a known issue with multi-token prefill + w8a16 quantization. " - "Falling back to JIT compilation for this stage.", + logger.error( + "Stage %r compilation timed out after %ds — killing subprocess.", stage_key, compile_timeout_s, ) proc.kill() proc.join() - # Remove partial output file. ctx_out.unlink(missing_ok=True) return False From 5b0a0e2826478ea1e7ebd88c0c9ad9185cb5580d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 16:55:25 +0800 Subject: [PATCH 09/29] fix: move _do_compile to module-level _qnn_compile_worker for Windows spawn Windows multiprocessing spawn serialises the subprocess target via pickle. Local functions (closures) defined inside a method cannot be pickled, which caused 'AttributeError: Can't pickle local function' at runtime. Moved the compilation logic to a module-level function _qnn_compile_worker so it is importable by name in the spawned subprocess. Also fix ONNX filename in compiled genai_config: use ctx_onnx.name (just the filename) instead of str(ctx_onnx) (absolute path). ort-genai resolves filenames relative to the directory passed to og.Config, so an absolute path causes double-path concatenation and a 'file not found' error. --- src/winml/modelkit/session/genai_session.py | 60 +++++++++++++-------- 1 file changed, 37 insertions(+), 23 deletions(-) diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 25b2ce311..5943c0266 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -61,6 +61,33 @@ logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Module-level compilation worker (must be at module scope for multiprocessing +# spawn on Windows, which serialises the target via pickle). +# --------------------------------------------------------------------------- + + +def _qnn_compile_worker(src: str, dst: str, qnn_options: dict) -> None: + """Compile *src* ONNX to an EPContext ONNX at *dst* using QNN HTP. + + Executed in a subprocess by :meth:`GenaiSession._compile_stage`. + """ + import onnxruntime as ort + + from winml.modelkit.session.ep_registry import WinMLEPRegistry + from winml.modelkit.winml import add_ep_for_device + + registry = WinMLEPRegistry.get_instance() + registry.register_execution_providers() + so = ort.SessionOptions() + so.add_session_config_entry("ep.context_enable", "1") + so.add_session_config_entry("ep.context_file_path", dst) + add_ep_for_device(so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU, qnn_options) + mc = ort.ModelCompiler(so, src, embed_compiled_data_into_model=False) + mc.compile_to_file(dst) + + # --------------------------------------------------------------------------- # Valid EP short names. # "mixed" = use genai_config.json as-is (embeddings/lm_head on CPU, @@ -485,14 +512,16 @@ def _prepare_compiled_bundle(self) -> Path: # Skip recompilation if cache is up-to-date. if ctx_onnx.exists() and ctx_onnx.stat().st_mtime >= src_onnx.stat().st_mtime: logger.info("Stage %r: reusing cached EPContext %s", stage_key, ctx_onnx.name) - self._patch_stage_filename(modified_cfg, stage_key, str(ctx_onnx)) + # Use just the filename — genai_config.json lives in compiled_dir, + # so ort-genai resolves filenames relative to compiled_dir. + self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) any_compiled = True continue # Attempt compilation. success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts) if success: - self._patch_stage_filename(modified_cfg, stage_key, str(ctx_onnx)) + self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) any_compiled = True else: logger.warning( @@ -502,9 +531,9 @@ def _prepare_compiled_bundle(self) -> Path: if not any_compiled: return self._bundle_dir - # Write the modified genai_config into the compiled sub-directory so that - # ort-genai can resolve all ONNX paths (absolute paths are used). - # Also symlink/copy every other file that og.Config expects. + # Write the modified genai_config into the compiled sub-directory. + # ONNX filenames are relative to compiled_dir; ort-genai resolves them + # from the directory it loads og.Config from. compiled_config = compiled_dir / "genai_config.json" compiled_config.write_text( json.dumps(modified_cfg, indent=2, ensure_ascii=False), encoding="utf-8" @@ -574,25 +603,10 @@ def _compile_stage( compile_qnn_opts, ) - def _do_compile(src: str, dst: str, qnn_options: dict) -> None: - import onnxruntime as ort - - from winml.modelkit.session.ep_registry import WinMLEPRegistry - from winml.modelkit.winml import add_ep_for_device - - registry = WinMLEPRegistry.get_instance() - registry.register_execution_providers() - so = ort.SessionOptions() - so.add_session_config_entry("ep.context_enable", "1") - so.add_session_config_entry("ep.context_file_path", dst) - add_ep_for_device( - so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU, qnn_options - ) - mc = ort.ModelCompiler(so, src, embed_compiled_data_into_model=False) - mc.compile_to_file(dst) - ctx = multiprocessing.get_context("spawn") - proc = ctx.Process(target=_do_compile, args=(str(src_onnx), str(ctx_out), compile_qnn_opts)) + proc = ctx.Process( + target=_qnn_compile_worker, args=(str(src_onnx), str(ctx_out), compile_qnn_opts) + ) proc.start() proc.join(timeout=compile_timeout_s) From 1c962e55399b0d16442f3773d404572cc335a462 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 17:08:51 +0800 Subject: [PATCH 10/29] perf: use configured htp_graph_finalization_optimization_mode for gen stages MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously _compile_stage forced mode='0' for ALL stages to avoid a QNN SDK deadlock on w8a16 + multi-token prefill. This also silently capped the iter (generation) stage at mode 0, producing under-optimized kernels (~10 tok/s). Fix: only force mode=0 for prefill stages (run_on_prompt=true, seq_len>1 where the deadlock occurs). Generation stages (run_on_token_gen=true, seq_len=1) use the configured mode from genai_config.json (typically '3'), which is safe for single-token input and produces fully-optimized kernels. Performance: Before: 10.4 tok/s (both ctx+iter compiled with mode 0) After: 43.4 tok/s (ctx mode 0, iter mode 3) — matches reference ~45 tok/s _prepare_compiled_bundle now passes is_prefill flag per stage based on run_on_prompt / run_on_token_gen fields in genai_config.json pipeline config. --- src/winml/modelkit/session/genai_session.py | 50 +++++++++++++-------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 5943c0266..8896295b7 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -481,8 +481,12 @@ def _prepare_compiled_bundle(self) -> Path: # genai_config pipeline entries: [{"context": {...}}, {"iterator": {...}}, ...] # provider_options format: [{"qnn": {...}}] pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", []) - # [(stage_key, onnx_filename, qnn_opts), ...] - qnn_stages: list[tuple[str, str, dict]] = [] + # [(stage_key, onnx_filename, qnn_opts, is_prefill), ...] + # is_prefill=True when run_on_prompt=True and run_on_token_gen=False. + # Prefill stages with seq_len>1 require htp_graph_finalization_optimization_mode="0" + # to avoid a QNN SDK deadlock; generation stages (seq_len=1) can use the full + # configured optimization level for maximum throughput. + qnn_stages: list[tuple[str, str, dict, bool]] = [] for stage_entry in pipeline_list: if not isinstance(stage_entry, dict): continue @@ -494,7 +498,11 @@ def _prepare_compiled_bundle(self) -> Path: for p in providers: if isinstance(p, dict) and "qnn" in p: onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") - qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]))) + is_prefill = bool( + stage_cfg.get("run_on_prompt", False) + and not stage_cfg.get("run_on_token_gen", False) + ) + qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]), is_prefill)) break if not qnn_stages: @@ -505,7 +513,7 @@ def _prepare_compiled_bundle(self) -> Path: modified_cfg = json.loads(config_src.read_text(encoding="utf-8")) any_compiled = False - for stage_key, onnx_filename, qnn_opts in qnn_stages: + for stage_key, onnx_filename, qnn_opts, is_prefill in qnn_stages: src_onnx = self._bundle_dir / onnx_filename ctx_onnx = compiled_dir / f"{stage_key}_ctx.onnx" @@ -519,7 +527,7 @@ def _prepare_compiled_bundle(self) -> Path: continue # Attempt compilation. - success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts) + success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts, is_prefill) if success: self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) any_compiled = True @@ -560,38 +568,42 @@ def _compile_stage( ctx_out: Path, stage_key: str, qnn_opts: dict | None = None, + is_prefill: bool = False, ) -> bool: """Compile *src_onnx* to EPContext format via ``ort.ModelCompiler``. Runs in a subprocess so that a ModelCompiler failure does not block - the caller. The QNN options from ``genai_config.json`` are forwarded - to the compilation session, with ``htp_graph_finalization_optimization_mode`` - forced to ``"0"``. This avoids a QNN SDK deadlock that occurs when - compiling w8a16 quantized models with multi-token static input shapes - (``seq_len > 1``) at higher optimization levels. + the caller. QNN options from ``genai_config.json`` are forwarded to + the compilation session. - The resulting EPContext ONNX is identical in interface to the original; - at runtime, ort-genai loads the pre-compiled QNN binary and the - inference-time ``htp_graph_finalization_optimization_mode`` from - ``genai_config.json`` governs any further JIT compilation. + For prefill stages (``is_prefill=True``) ``htp_graph_finalization_ + optimization_mode`` is forced to ``"0"`` to avoid a QNN SDK deadlock + that occurs when compiling w8a16 quantized models with multi-token + static input shapes (``seq_len > 1``) at higher optimization levels. + For generation stages (``is_prefill=False``, ``seq_len=1``) the + configured optimization level is preserved so that the compiled kernels + are as fast as the JIT path. Args: src_onnx: Source ONNX file path. ctx_out: Destination EPContext ONNX path. stage_key: Human-readable label for logging. qnn_opts: QNN provider options from genai_config (e.g. backend_path, - htp_performance_mode, soc_model). ``htp_graph_finalization_ - optimization_mode`` is always overridden to ``"0"``. + htp_performance_mode, soc_model). + is_prefill: ``True`` when the stage runs only on prompt (ctx) and has + multi-token input; forces ``htp_graph_finalization_optimization_mode`` + to ``"0"`` to avoid the QNN SDK deadlock. Returns: ``True`` if compilation succeeded; ``False`` on timeout or error. """ import multiprocessing - # Force graph-finalization optimization off. Levels 1-3 deadlock QNN - # ModelCompiler for w8a16 quantized models with multi-token input shapes. compile_qnn_opts = dict(qnn_opts or {}) - compile_qnn_opts["htp_graph_finalization_optimization_mode"] = "0" + if is_prefill: + # QNN SDK deadlocks at levels 1-3 for w8a16 models with seq_len > 1. + compile_qnn_opts["htp_graph_finalization_optimization_mode"] = "0" + # else: keep the configured mode (typically "3") for generation stages. compile_timeout_s = 300 # 5 minutes; ctx compiles in ~41s, iter in ~67s From 837cd8459d4daf043cd195e23ed02c20a0d7d1a3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 17:28:26 +0800 Subject: [PATCH 11/29] simplify: remove htp_graph_finalization_optimization_mode override in _compile_stage The original mode=0 override was added to avoid a QNN SDK deadlock when compiling w8a16 prefill (seq_len>1) at higher optimization levels. Testing revealed the deadlock only occurs when QNN provider options are NOT passed to ort.ModelCompiler at all (causing it to fall back to a broken default path). With correct QNN options (backend_path, soc_model, etc.) forwarded, mode=3 compiles successfully for both ctx (~73s) and iter (~67s) with no hang. Remove the is_prefill flag and mode override entirely. _compile_stage now passes genai_config QNN options unchanged, giving fully-optimized kernels for all stages. Performance (hot NPU, EPContext loaded): ctx+iter both mode=3: ~44.5 tok/s vs reference ~45 tok/s --- src/winml/modelkit/session/genai_session.py | 45 +++++---------------- 1 file changed, 11 insertions(+), 34 deletions(-) diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 8896295b7..34e2ef5e4 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -481,12 +481,8 @@ def _prepare_compiled_bundle(self) -> Path: # genai_config pipeline entries: [{"context": {...}}, {"iterator": {...}}, ...] # provider_options format: [{"qnn": {...}}] pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", []) - # [(stage_key, onnx_filename, qnn_opts, is_prefill), ...] - # is_prefill=True when run_on_prompt=True and run_on_token_gen=False. - # Prefill stages with seq_len>1 require htp_graph_finalization_optimization_mode="0" - # to avoid a QNN SDK deadlock; generation stages (seq_len=1) can use the full - # configured optimization level for maximum throughput. - qnn_stages: list[tuple[str, str, dict, bool]] = [] + # [(stage_key, onnx_filename, qnn_opts), ...] + qnn_stages: list[tuple[str, str, dict]] = [] for stage_entry in pipeline_list: if not isinstance(stage_entry, dict): continue @@ -498,11 +494,7 @@ def _prepare_compiled_bundle(self) -> Path: for p in providers: if isinstance(p, dict) and "qnn" in p: onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") - is_prefill = bool( - stage_cfg.get("run_on_prompt", False) - and not stage_cfg.get("run_on_token_gen", False) - ) - qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]), is_prefill)) + qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]))) break if not qnn_stages: @@ -513,7 +505,7 @@ def _prepare_compiled_bundle(self) -> Path: modified_cfg = json.loads(config_src.read_text(encoding="utf-8")) any_compiled = False - for stage_key, onnx_filename, qnn_opts, is_prefill in qnn_stages: + for stage_key, onnx_filename, qnn_opts in qnn_stages: src_onnx = self._bundle_dir / onnx_filename ctx_onnx = compiled_dir / f"{stage_key}_ctx.onnx" @@ -527,7 +519,7 @@ def _prepare_compiled_bundle(self) -> Path: continue # Attempt compilation. - success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts, is_prefill) + success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts) if success: self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) any_compiled = True @@ -568,31 +560,21 @@ def _compile_stage( ctx_out: Path, stage_key: str, qnn_opts: dict | None = None, - is_prefill: bool = False, ) -> bool: """Compile *src_onnx* to EPContext format via ``ort.ModelCompiler``. Runs in a subprocess so that a ModelCompiler failure does not block - the caller. QNN options from ``genai_config.json`` are forwarded to - the compilation session. - - For prefill stages (``is_prefill=True``) ``htp_graph_finalization_ - optimization_mode`` is forced to ``"0"`` to avoid a QNN SDK deadlock - that occurs when compiling w8a16 quantized models with multi-token - static input shapes (``seq_len > 1``) at higher optimization levels. - For generation stages (``is_prefill=False``, ``seq_len=1``) the - configured optimization level is preserved so that the compiled kernels - are as fast as the JIT path. + the caller. The QNN options from ``genai_config.json`` are forwarded + unchanged to the compilation session, so each stage is compiled at + exactly the optimization level configured in the bundle. Args: src_onnx: Source ONNX file path. ctx_out: Destination EPContext ONNX path. stage_key: Human-readable label for logging. qnn_opts: QNN provider options from genai_config (e.g. backend_path, - htp_performance_mode, soc_model). - is_prefill: ``True`` when the stage runs only on prompt (ctx) and has - multi-token input; forces ``htp_graph_finalization_optimization_mode`` - to ``"0"`` to avoid the QNN SDK deadlock. + htp_performance_mode, htp_graph_finalization_optimization_mode, + soc_model). Returns: ``True`` if compilation succeeded; ``False`` on timeout or error. @@ -600,12 +582,7 @@ def _compile_stage( import multiprocessing compile_qnn_opts = dict(qnn_opts or {}) - if is_prefill: - # QNN SDK deadlocks at levels 1-3 for w8a16 models with seq_len > 1. - compile_qnn_opts["htp_graph_finalization_optimization_mode"] = "0" - # else: keep the configured mode (typically "3") for generation stages. - - compile_timeout_s = 300 # 5 minutes; ctx compiles in ~41s, iter in ~67s + compile_timeout_s = 300 # 5 minutes; ctx compiles in ~73s, iter in ~67s logger.info( "Compiling stage %r: %s → %s (qnn_opts=%s)", From 74ca8cfab2a872ec3e2ad6867091b22bbcc3304b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 30 Jun 2026 21:44:25 +0800 Subject: [PATCH 12/29] refactor(genai): move generic bundle logic to utils/genai, qwen3/genai as shim - Extract all architecture-agnostic logic (PipelineStage, DecoderIOMapping, build_genai_config, build_decoder_pipeline_stages, write_genai_bundle, qnn_stage_session_options, ONNX introspection helpers) into src/winml/modelkit/utils/genai.py so other model families can reuse it - Reduce qwen3/genai.py to a thin re-export shim with a backward-compatible build_qwen3_transformer_only_stages alias for existing callers - fix(codeql): remove unused _TOKENIZER_FILES from utils/genai.py - fix(codeql): remove unnecessary del generator in GenaiSession.generate_streaming - fix(codeql): add missing Protocol body ellipsis in QuantConfigFinalizer.finalize - fix(codeql): import get_quant_finalizer directly in quant/__init__.py - fix(test): update mock patch path to winml.modelkit.utils.genai._introspect_onnx_io - fix(test): replace bare 'import onnx' with 'from onnx import ...' in test_qwen3_calibration.py --- .../modelkit/models/hf/qwen3/__init__.py | 2 + src/winml/modelkit/models/hf/qwen3/genai.py | 672 +----------------- src/winml/modelkit/quant/__init__.py | 11 +- src/winml/modelkit/quant/calibration/base.py | 1 + src/winml/modelkit/session/genai_session.py | 21 +- src/winml/modelkit/utils/genai.py | 663 +++++++++++++++++ tests/unit/models/qwen3/test_genai_config.py | 2 +- .../calibration/test_qwen3_calibration.py | 25 +- 8 files changed, 724 insertions(+), 673 deletions(-) create mode 100644 src/winml/modelkit/utils/genai.py diff --git a/src/winml/modelkit/models/hf/qwen3/__init__.py b/src/winml/modelkit/models/hf/qwen3/__init__.py index 8d8676398..dbabe2d60 100644 --- a/src/winml/modelkit/models/hf/qwen3/__init__.py +++ b/src/winml/modelkit/models/hf/qwen3/__init__.py @@ -15,6 +15,7 @@ from .genai import ( DecoderIOMapping, PipelineStage, + build_decoder_pipeline_stages, build_genai_config, build_qwen3_transformer_only_stages, write_genai_bundle, @@ -24,6 +25,7 @@ __all__ = [ "DecoderIOMapping", "PipelineStage", + "build_decoder_pipeline_stages", "build_genai_config", "build_qwen3_transformer_only_stages", "write_genai_bundle", diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py index 4a63de45b..9e65908f5 100644 --- a/src/winml/modelkit/models/hf/qwen3/genai.py +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -2,659 +2,40 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -r"""Generate an onnxruntime-genai bundle for a transformer-only decoder pipeline. +"""Qwen3 genai bundle support — thin shim over :mod:`winml.modelkit.utils.genai`. -The bundle is a directory that ``onnxruntime-genai`` can load directly via -``og.Config(str(bundle_dir))``. It contains: +All generic logic (``PipelineStage``, ``DecoderIOMapping``, ``build_genai_config``, +``build_decoder_pipeline_stages``, ``write_genai_bundle``) lives in +:mod:`winml.modelkit.utils.genai` so it can be reused by other model families. - genai_config.json — pipeline config consumed by onnxruntime-genai - ctx.onnx — prefill/context ONNX (built by winml-cli) - iter.onnx — iteration/decode ONNX (built by winml-cli) - embeddings.onnx — embedding-lookup ONNX (placeholder; copy externally) - lm_head.onnx — lm_head ONNX (placeholder; copy externally) - tokenizer.json — HF tokenizer files (downloaded from the model repo) - tokenizer_config.json - vocab.json / merges.txt / generation_config.json - -The pipeline follows the same 4-stage layout as the reference bundle: - - input_ids → [embeddings] → input_hidden_states - → [context | iterator] → output_hidden_states + present KVs - → [lm_head] → logits - -The context stage runs on the prompt (prefill); the iterator stage runs on each -subsequent decode step. Both share the same KV cache buffer via genai's -``past_present_share_buffer`` mode. - -Public API:: - - from winml.modelkit.models.hf.qwen3.genai import ( - build_genai_config, - build_qwen3_transformer_only_stages, - write_genai_bundle, - DecoderIOMapping, - PipelineStage, - ) - - # High-level: derive everything from the built ONNX files - stages, decoder_io = build_qwen3_transformer_only_stages( - ctx_path, iter_path, num_layers=hf_config.num_hidden_layers - ) - cfg = build_genai_config( - hf_config, max_cache_len=256, prefill_seq_len=64, - pipeline=stages, decoder_io=decoder_io, - ) - - # Or one-shot bundle assembly - write_genai_bundle( - Path("out/bundle"), - context_onnx=ctx_path, - iterator_onnx=iter_path, - model_id="Qwen/Qwen3-0.6B", - max_cache_len=256, - prefill_seq_len=64, - embeddings_src=emb_path, # None = skip (add later) - lm_head_src=lmh_path, # None = skip (add later) - ) +This module re-exports that API unchanged and adds +``build_qwen3_transformer_only_stages`` as a backward-compatible alias for +``build_decoder_pipeline_stages``. New code should prefer the generic names. """ from __future__ import annotations -import json -import logging -import re -from dataclasses import dataclass -from pathlib import Path -from typing import Any - - -logger = logging.getLogger(__name__) - -# Default filenames inside the bundle directory. -DEFAULT_EMBEDDINGS_FILENAME = "embeddings.onnx" -DEFAULT_CONTEXT_FILENAME = "ctx.onnx" -DEFAULT_ITERATOR_FILENAME = "iter.onnx" -DEFAULT_LM_HEAD_FILENAME = "lm_head.onnx" - -# Tokenizer files written by AutoTokenizer.save_pretrained. -_TOKENIZER_FILES = [ - "tokenizer.json", - "tokenizer_config.json", - "vocab.json", - "merges.txt", - "generation_config.json", - "special_tokens_map.json", -] - -# Regex for detecting indexed tensor names such as ``past_keys_3``. -_KV_INDEXED_RE = re.compile(r"^(.+?)(\d+)$") - - -# --------------------------------------------------------------------------- -# Pipeline data structures -# --------------------------------------------------------------------------- - - -@dataclass -class PipelineStage: - """One stage in an onnxruntime-genai multi-model pipeline. - - Attributes: - name: Stage key used inside the ``pipeline`` list of ``genai_config.json``. - filename: ONNX filename inside the bundle directory. - run_on_prompt: Whether genai runs this stage during the prefill pass. - run_on_token_gen: Whether genai runs this stage during decode steps. - inputs: Actual ONNX input tensor names (not format strings). - outputs: Actual ONNX output tensor names (not format strings). - is_lm_head: Set ``True`` for the final language-model head stage. - """ - - name: str - filename: str - run_on_prompt: bool - run_on_token_gen: bool - inputs: list[str] - outputs: list[str] - is_lm_head: bool = False - session_options: dict | None = None - """Per-stage ORT session options (e.g. provider_options for QNN). - - When set, emitted verbatim as the ``session_options`` key in the - ``genai_config.json`` pipeline stage. Leave ``None`` (default) for - stages that should run on the default (CPU) provider. - """ - - def to_dict(self) -> dict: - """Serialize to the dict format expected by ``genai_config.json``.""" - d: dict = { - "filename": self.filename, - "inputs": list(self.inputs), - "outputs": list(self.outputs), - "run_on_prompt": self.run_on_prompt, - "run_on_token_gen": self.run_on_token_gen, - } - if self.session_options: - d["session_options"] = self.session_options - if self.is_lm_head: - d["is_lm_head"] = True - return d - - -@dataclass -class DecoderIOMapping: - """Maps genai's abstract I/O concepts to ONNX tensor name format strings. - - The ``*_names`` fields use ``%d`` as the layer-index placeholder, which is - the convention genai uses to expand per-layer KV cache tensor names - (e.g. ``"past_keys_%d"`` → ``"past_keys_0"``, ``"past_keys_1"``, …). - - All fields default to the names produced by the Qwen3 transformer-only - export. - """ - - input_ids: str = "input_ids" - past_sequence_length: str = "past_seq_len" - total_sequence_length: str = "total_seq_len" - past_key_names: str = "past_keys_%d" - past_value_names: str = "past_values_%d" - logits: str = "logits" - present_key_names: str = "present_keys_%d" - present_value_names: str = "present_values_%d" - - def inputs_dict(self) -> dict: - """Return the ``decoder.inputs`` mapping dict for ``genai_config.json``.""" - return { - "input_ids": self.input_ids, - "past_sequence_length": self.past_sequence_length, - "total_sequence_length": self.total_sequence_length, - "past_key_names": self.past_key_names, - "past_value_names": self.past_value_names, - } - - def outputs_dict(self) -> dict: - """Return the ``decoder.outputs`` mapping dict for ``genai_config.json``.""" - return { - "logits": self.logits, - "present_key_names": self.present_key_names, - "present_value_names": self.present_value_names, - } - - -# --------------------------------------------------------------------------- -# Generic config builder -# --------------------------------------------------------------------------- - - -def build_genai_config( - hf_config: Any, - *, - max_cache_len: int, - prefill_seq_len: int | None = None, - pipeline: list[PipelineStage], - decoder_io: DecoderIOMapping | None = None, -) -> dict: - """Build a ``genai_config.json`` dict for any decoder-pipeline model. - - This function is architecture-agnostic: the caller supplies the pipeline - stages and the I/O name mapping so no tensor names are hardcoded here. - - Args: - hf_config: A ``transformers.PretrainedConfig``. Reads: - ``num_hidden_layers``, ``hidden_size``, ``num_attention_heads``, - ``num_key_value_heads``, ``head_dim`` (optional, falls back to - ``hidden_size // num_attention_heads``), ``bos_token_id``, - ``eos_token_id``, ``pad_token_id``, ``vocab_size``. - max_cache_len: Static KV cache length → ``context_length`` and - ``search.max_length``. - prefill_seq_len: When given, emits a ``sliding_window`` section with - ``window_size=prefill_seq_len``. Pass ``None`` to omit. - pipeline: Ordered list of :class:`PipelineStage` describing each - model in the genai pipeline. - decoder_io: Format-string mapping from genai's abstract I/O names to - actual ONNX tensor names. Defaults to - :class:`DecoderIOMapping` (the Qwen3 default names). - - Returns: - A ``dict`` suitable for ``json.dumps`` as ``genai_config.json``. - """ - if decoder_io is None: - decoder_io = DecoderIOMapping() - - num_layers: int = hf_config.num_hidden_layers - head_size: int = getattr( - hf_config, - "head_dim", - hf_config.hidden_size // hf_config.num_attention_heads, - ) - - eos_token_id = hf_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] - - pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id - - decoder_section: dict = { - "hidden_size": hf_config.hidden_size, - "num_attention_heads": hf_config.num_attention_heads, - "num_key_value_heads": hf_config.num_key_value_heads, - "num_hidden_layers": num_layers, - "head_size": head_size, - } - - if prefill_seq_len is not None: - decoder_section["sliding_window"] = { - "window_size": prefill_seq_len, - "pad_value": 0, - "alignment": "left", - "slide_inputs": True, - "slide_key_value_cache": False, - } - - decoder_section["inputs"] = decoder_io.inputs_dict() - decoder_section["outputs"] = decoder_io.outputs_dict() - decoder_section["pipeline"] = [{s.name: s.to_dict()} for s in pipeline] - - return { - "model": { - "type": "decoder-pipeline", - "bos_token_id": hf_config.bos_token_id, - "eos_token_id": eos_token_id, - "pad_token_id": pad_token_id, - "vocab_size": hf_config.vocab_size, - "context_length": max_cache_len, - "decoder": decoder_section, - }, - "search": { - "max_length": max_cache_len, - "min_length": 0, - "do_sample": False, - "past_present_share_buffer": True, - }, - } - - -# --------------------------------------------------------------------------- -# ONNX introspection helpers -# --------------------------------------------------------------------------- - - -def _introspect_onnx_io(onnx_path: Path) -> tuple[list[str], list[str]]: - """Return ``(input_names, output_names)`` from an ONNX model graph header. - - External data is intentionally not loaded — only the graph topology is read, - so this is fast even for large quantized models. - """ - try: - import onnx - except ImportError as exc: - raise ImportError( - "The 'onnx' package is required for ONNX introspection. " - "Install it with: pip install onnx" - ) from exc - model = onnx.load(str(onnx_path), load_external_data=False) - return ( - [inp.name for inp in model.graph.input], - [out.name for out in model.graph.output], - ) - - -def _detect_format_patterns(names: list[str], num_layers: int) -> dict[str, str]: - """Detect ``prefix%d`` patterns from a list of indexed tensor names. - - Scans *names* for entries matching ```` where exactly - *num_layers* consecutive zero-based indices are present. - - Returns: - ``{prefix: "prefix%d"}`` for each qualifying group, in the order the - prefixes first appear in *names*. Only groups covering the full - ``[0, num_layers)`` index range are returned. - - Examples:: - - >>> _detect_format_patterns( - ... ["past_keys_0", "past_keys_1", "past_values_0", "past_values_1"], - ... num_layers=2, - ... ) - {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"} - """ - groups: dict[str, list[int]] = {} - for name in names: - m = _KV_INDEXED_RE.match(name) - if m: - prefix, idx = m.group(1), int(m.group(2)) - groups.setdefault(prefix, []).append(idx) - - return { - prefix: f"{prefix}%d" - for prefix, indices in groups.items() - if len(indices) == num_layers and sorted(indices) == list(range(num_layers)) - } - - -def _sort_patterns_by_first_occurrence(patterns: dict[str, str], names: list[str]) -> list[str]: - """Sort *patterns* keys by when ``0`` first appears in *names*.""" - - def _key(prefix: str) -> int: - try: - return names.index(f"{prefix}0") - except ValueError: - return len(names) - - return sorted(patterns.keys(), key=_key) - - -# --------------------------------------------------------------------------- -# Per-EP stage session_options helpers -# --------------------------------------------------------------------------- - - -def _qnn_stage_session_options(log_id: str, soc_model: str = "60") -> dict: - """Return the ``session_options`` block that routes a stage to QNN HTP. - - Args: - log_id: ORT log identifier (shown in ORT logs), e.g. - ``"onnxruntime-genai.context"``. - soc_model: Snapdragon SoC model number passed to the QNN HTP backend. - ``"60"`` targets Snapdragon 8 Gen 3 (X Elite). Change for other - SoCs (e.g. ``"55"`` for 8 Gen 2, ``"73"`` for 8 Elite). - - Returns: - Dict suitable for the ``session_options`` key of a pipeline stage in - ``genai_config.json``. - """ - return { - "log_id": log_id, - "provider_options": [ - { - "qnn": { - "backend_path": "QnnHtp.dll", - "htp_performance_mode": "burst", - "htp_graph_finalization_optimization_mode": "3", - "soc_model": soc_model, - } - } - ], - "intra_op_num_threads": 2, - "inter_op_num_threads": 1, - } - - -# --------------------------------------------------------------------------- -# Qwen3 transformer-only pipeline factory -# --------------------------------------------------------------------------- - - -def build_qwen3_transformer_only_stages( - context_onnx: str | Path, - iterator_onnx: str | Path, - num_layers: int, - *, - context_filename: str = DEFAULT_CONTEXT_FILENAME, - iterator_filename: str = DEFAULT_ITERATOR_FILENAME, - embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, - lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, - ep: str = "cpu", - soc_model: str = "60", -) -> tuple[list[PipelineStage], DecoderIOMapping]: - """Build pipeline stages by introspecting the built ONNX models. - - Reads actual tensor names from *context_onnx* and *iterator_onnx* so the - generated ``genai_config.json`` can never drift out of sync with the real - model I/O — no tensor names are hardcoded. - - Args: - context_onnx: Path to the built prefill/context ONNX. - iterator_onnx: Path to the built decode/iterator ONNX. - num_layers: Number of transformer layers (``hf_config.num_hidden_layers``). - context_filename: Bundle filename for the context model. - iterator_filename: Bundle filename for the iterator model. - embeddings_filename: Bundle filename for the embeddings model. - lm_head_filename: Bundle filename for the lm_head model. - ep: Execution provider for the transformer stages. ``"qnn"`` injects - QNN HTP ``session_options`` into the ``context`` and ``iterator`` - stages so they run on the NPU while ``embeddings`` and ``lm_head`` - continue on CPU. ``"cpu"`` (default) omits ``session_options`` - from all stages. - soc_model: Snapdragon SoC model number forwarded to the QNN backend - when ``ep="qnn"``. Default ``"60"`` targets Snapdragon 8 Gen 3. - - Returns: - ``(stages, decoder_io)`` — a 4-element :class:`PipelineStage` list and - the :class:`DecoderIOMapping` derived from the introspected tensor names. - """ - ctx_inputs, ctx_outputs = _introspect_onnx_io(Path(context_onnx)) - iter_inputs, iter_outputs = _introspect_onnx_io(Path(iterator_onnx)) - - # Detect per-layer KV format-string patterns in the context model. - input_patterns = _detect_format_patterns(ctx_inputs, num_layers) - output_patterns = _detect_format_patterns(ctx_outputs, num_layers) - - in_sorted = _sort_patterns_by_first_occurrence(input_patterns, ctx_inputs) - out_sorted = _sort_patterns_by_first_occurrence(output_patterns, ctx_outputs) - - past_key_fmt = input_patterns[in_sorted[0]] if len(in_sorted) > 0 else "past_keys_%d" - past_val_fmt = input_patterns[in_sorted[1]] if len(in_sorted) > 1 else "past_values_%d" - pres_key_fmt = output_patterns[out_sorted[0]] if len(out_sorted) > 0 else "present_keys_%d" - pres_val_fmt = output_patterns[out_sorted[1]] if len(out_sorted) > 1 else "present_values_%d" - - # Non-indexed inputs: hidden-state tensor + scalar seq-length scalars. - non_indexed = [n for n in ctx_inputs if not _KV_INDEXED_RE.match(n)] - seq_len_names = [n for n in non_indexed if re.search(r"seq|len", n, re.IGNORECASE)] - hidden_state_in = next( - (n for n in non_indexed if n not in seq_len_names), "input_hidden_states" - ) - past_seq_name = next((n for n in seq_len_names if "past" in n.lower()), "past_seq_len") - total_seq_name = next((n for n in seq_len_names if "total" in n.lower()), "total_seq_len") - - # Non-indexed output: hidden-state output of the transformer stack. - hidden_state_out = next( - (n for n in ctx_outputs if not _KV_INDEXED_RE.match(n)), "output_hidden_states" - ) - - decoder_io = DecoderIOMapping( - past_sequence_length=past_seq_name, - total_sequence_length=total_seq_name, - past_key_names=past_key_fmt, - past_value_names=past_val_fmt, - present_key_names=pres_key_fmt, - present_value_names=pres_val_fmt, - ) - - # Per-stage session_options: NPU stages get QNN config; CPU and others get None. - ctx_session_opts: dict | None = None - iter_session_opts: dict | None = None - if ep == "qnn": - ctx_session_opts = _qnn_stage_session_options( - "onnxruntime-genai.context", soc_model=soc_model - ) - iter_session_opts = _qnn_stage_session_options( - "onnxruntime-genai.iterator", soc_model=soc_model - ) - - stages: list[PipelineStage] = [ - PipelineStage( - name="embeddings", - filename=embeddings_filename, - run_on_prompt=True, - run_on_token_gen=True, - inputs=[decoder_io.input_ids], - outputs=[hidden_state_in], - ), - PipelineStage( - name="context", - filename=context_filename, - run_on_prompt=True, - run_on_token_gen=False, - inputs=ctx_inputs, - outputs=ctx_outputs, - session_options=ctx_session_opts, - ), - PipelineStage( - name="iterator", - filename=iterator_filename, - run_on_prompt=False, - run_on_token_gen=True, - inputs=iter_inputs, - outputs=iter_outputs, - session_options=iter_session_opts, - ), - PipelineStage( - name="lm_head", - filename=lm_head_filename, - run_on_prompt=True, - run_on_token_gen=True, - inputs=[hidden_state_out], - outputs=[decoder_io.logits], - is_lm_head=True, - ), - ] - return stages, decoder_io - - -# --------------------------------------------------------------------------- -# Bundle assembler -# --------------------------------------------------------------------------- - - -def write_genai_bundle( - output_dir: str | Path, - *, - context_onnx: str | Path, - iterator_onnx: str | Path, - model_id: str, - max_cache_len: int, - prefill_seq_len: int, - embeddings_src: str | Path | None = None, - lm_head_src: str | Path | None = None, - context_filename: str = DEFAULT_CONTEXT_FILENAME, - iterator_filename: str = DEFAULT_ITERATOR_FILENAME, - embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, - lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, - ep: str = "cpu", - soc_model: str = "60", -) -> Path: - """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*. - - Copies the winml-built transformer ONNX files, placeholder embedding / - lm_head models (when provided), HF tokenizer files, and writes - ``genai_config.json``. Tensor names in the config are derived by - introspecting the built ONNX files rather than being hardcoded. - - Args: - output_dir: Destination directory (created if absent). - context_onnx: Path to the built prefill/context ONNX. - iterator_onnx: Path to the built decode/iterator ONNX. - model_id: HuggingFace model ID or local path for config + tokenizer. - max_cache_len: Static KV cache length (= ``context_length`` in genai). - prefill_seq_len: Prefill sequence length (= ``sliding_window.window_size``). - embeddings_src: Source path of the embeddings ONNX. ``None`` = skip. - lm_head_src: Source path of the lm_head ONNX. ``None`` = skip. - context_filename: Bundle filename for the context model. - iterator_filename: Bundle filename for the iterator model. - embeddings_filename: Bundle filename for the embeddings model. - lm_head_filename: Bundle filename for the lm_head model. - ep: Execution provider for the transformer (context/iterator) stages. - ``"qnn"`` injects QNN HTP ``session_options`` so those stages run - on the NPU while embeddings and lm_head run on CPU. - ``"cpu"`` (default) omits ``session_options`` (all stages on CPU). - soc_model: Snapdragon SoC model passed to the QNN backend when - ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite. - - Returns: - Path to the written ``genai_config.json``. - """ - from transformers import AutoConfig, AutoTokenizer - - from winml.modelkit.onnx import copy_onnx_model - - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - context_onnx = Path(context_onnx) - iterator_onnx = Path(iterator_onnx) - - # 1. Copy winml-built transformer ONNX files. - logger.info("Copying context ONNX: %s -> %s", context_onnx.name, context_filename) - copy_onnx_model(context_onnx, output_dir / context_filename) - - logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) - copy_onnx_model(iterator_onnx, output_dir / iterator_filename) - - # 2. Copy placeholder models (embeddings + lm_head). - if embeddings_src is not None: - logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) - copy_onnx_model(Path(embeddings_src), output_dir / embeddings_filename) - else: - logger.warning( - "embeddings_src not provided — '%s' is missing from bundle.", - embeddings_filename, - ) - - if lm_head_src is not None: - logger.info("Copying lm_head: %s -> %s", Path(lm_head_src).name, lm_head_filename) - copy_onnx_model(Path(lm_head_src), output_dir / lm_head_filename) - else: - logger.warning( - "lm_head_src not provided — '%s' is missing from bundle.", - lm_head_filename, - ) - - # 3. Save tokenizer files from the HF snapshot. - logger.info("Saving tokenizer from '%s' to %s", model_id, output_dir) - tokenizer = AutoTokenizer.from_pretrained(model_id) - tokenizer.save_pretrained(str(output_dir)) - - # 4. Build pipeline stages by introspecting the source ONNX files. - hf_config = AutoConfig.from_pretrained(model_id) - stages, decoder_io = build_qwen3_transformer_only_stages( - context_onnx, - iterator_onnx, - num_layers=hf_config.num_hidden_layers, - context_filename=context_filename, - iterator_filename=iterator_filename, - embeddings_filename=embeddings_filename, - lm_head_filename=lm_head_filename, - ep=ep, - soc_model=soc_model, - ) - - # 5. Write genai_config.json. - config = build_genai_config( - hf_config, - max_cache_len=max_cache_len, - prefill_seq_len=prefill_seq_len, - pipeline=stages, - decoder_io=decoder_io, - ) - config_path = output_dir / "genai_config.json" - config_path.write_text(json.dumps(config, indent=2), encoding="utf-8") - logger.info("Wrote genai_config.json -> %s", config_path) - - _log_bundle_summary(output_dir, config_path) - return config_path - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- +from winml.modelkit.utils.genai import ( + DEFAULT_CONTEXT_FILENAME, + DEFAULT_EMBEDDINGS_FILENAME, + DEFAULT_ITERATOR_FILENAME, + DEFAULT_LM_HEAD_FILENAME, + DecoderIOMapping, + PipelineStage, + _detect_format_patterns, + build_decoder_pipeline_stages, + build_genai_config, + qnn_stage_session_options, + write_genai_bundle, +) -def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: - """Print a human-readable summary of the assembled bundle.""" - files = sorted(bundle_dir.iterdir()) - lines = [f"\n=== genai bundle: {bundle_dir} ==="] - for f in files: - size_kb = f.stat().st_size / 1024 - tag = "" - if f.name == "genai_config.json": - tag = " <- pipeline config" - elif f.name.endswith(".onnx"): - tag = " <- ONNX graph" - elif f.name.endswith(".data"): - tag = " <- ONNX external weights" - lines.append(f" {f.name:<45} {size_kb:>8.1f} KB{tag}") - lines.append(f"\nConfig written to: {config_path}") - logger.info("\n".join(lines)) +# Backward-compatible alias: existing callers that import +# ``build_qwen3_transformer_only_stages`` continue to work unchanged. +build_qwen3_transformer_only_stages = build_decoder_pipeline_stages +# Keep the internal helper accessible for tests that import it directly. +_qnn_stage_session_options = qnn_stage_session_options __all__ = [ "DEFAULT_CONTEXT_FILENAME", @@ -663,7 +44,10 @@ def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: "DEFAULT_LM_HEAD_FILENAME", "DecoderIOMapping", "PipelineStage", + "_detect_format_patterns", + "build_decoder_pipeline_stages", "build_genai_config", "build_qwen3_transformer_only_stages", + "qnn_stage_session_options", "write_genai_bundle", ] diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index 28555bcec..0bdd7ccd8 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -27,6 +27,7 @@ from typing import TYPE_CHECKING, Any +from .calibration import get_quant_finalizer from .config import QuantizeResult, WinMLQuantizationConfig from .passes import BaseQuantPass, FP16Pass, RTNPass, StaticPass @@ -49,10 +50,11 @@ ] -# Names below are loaded lazily via ``__getattr__`` to avoid pulling in -# onnxruntime.quantization/torch at import time. The TYPE_CHECKING re-imports -# give static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports -# without triggering the heavy imports at runtime. +# ``quantize_onnx`` is loaded lazily via ``__getattr__`` to avoid pulling in +# onnxruntime.quantization at import time. The TYPE_CHECKING re-import gives +# static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports. +# ``get_quant_finalizer`` is imported directly above — its module chain +# (calibration/__init__ -> registry) is lightweight and safe at import time. if TYPE_CHECKING: from .calibration import get_quant_finalizer from .quantizer import Quantizer, expand_precision, quantize_onnx @@ -62,7 +64,6 @@ "quantize_onnx": (".quantizer", "quantize_onnx"), "Quantizer": (".quantizer", "Quantizer"), "expand_precision": (".quantizer", "expand_precision"), - "get_quant_finalizer": (".calibration", "get_quant_finalizer"), } diff --git a/src/winml/modelkit/quant/calibration/base.py b/src/winml/modelkit/quant/calibration/base.py index 39b9543c5..213199811 100644 --- a/src/winml/modelkit/quant/calibration/base.py +++ b/src/winml/modelkit/quant/calibration/base.py @@ -38,3 +38,4 @@ def finalize( model_id: str | None = None, ) -> WinMLQuantizationConfig: """Return ``quant`` populated with the graph-derived quant settings.""" + ... diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 34e2ef5e4..312a7358c 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -363,17 +363,16 @@ def generate_streaming( stream = self._tokenizer.create_stream() # type: ignore[union-attr] n = 0 - try: - while not generator.is_done(): - generator.generate_next_token() - new_token = generator.get_next_tokens()[0] - yield stream.decode(new_token) - n += 1 - if n >= cfg.max_new_tokens: - break - finally: - # Explicit deletion releases the KV cache buffer held by the generator. - del generator + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_next_tokens()[0] + yield stream.decode(new_token) + n += 1 + if n >= cfg.max_new_tokens: + break + # ``generator`` (og.Generator) holds the KV cache buffer; releasing the + # reference here (end of scope) frees it before the caller processes the + # last yielded token, which is earlier than waiting for GC. # ------------------------------------------------------------------ # Chat-template helpers diff --git a/src/winml/modelkit/utils/genai.py b/src/winml/modelkit/utils/genai.py new file mode 100644 index 000000000..33645cdaf --- /dev/null +++ b/src/winml/modelkit/utils/genai.py @@ -0,0 +1,663 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""Generic onnxruntime-genai bundle utilities for decoder-pipeline models. + +The bundle is a directory that ``onnxruntime-genai`` can load directly via +``og.Config(str(bundle_dir))``. It contains: + + genai_config.json — pipeline config consumed by onnxruntime-genai + ctx.onnx — prefill/context ONNX + iter.onnx — iteration/decode ONNX + embeddings.onnx — embedding-lookup ONNX + lm_head.onnx — lm_head ONNX + tokenizer.json — HF tokenizer files (downloaded from the model repo) + tokenizer_config.json + vocab.json / merges.txt / generation_config.json + +The pipeline follows the standard 4-stage decoder layout: + + input_ids → [embeddings] → input_hidden_states + → [context | iterator] → output_hidden_states + present KVs + → [lm_head] → logits + +The context stage runs on the prompt (prefill); the iterator stage runs on each +subsequent decode step. Both share the same KV cache buffer via genai's +``past_present_share_buffer`` mode. + +Public API:: + + from winml.modelkit.utils.genai import ( + build_genai_config, + build_decoder_pipeline_stages, + write_genai_bundle, + DecoderIOMapping, + PipelineStage, + qnn_stage_session_options, + ) + + # Build stages by introspecting the ONNX I/O (no hardcoded tensor names) + stages, decoder_io = build_decoder_pipeline_stages( + ctx_path, iter_path, num_layers=hf_config.num_hidden_layers, ep="qnn" + ) + cfg = build_genai_config( + hf_config, max_cache_len=256, prefill_seq_len=64, + pipeline=stages, decoder_io=decoder_io, + ) + + # Or one-shot bundle assembly + write_genai_bundle( + Path("out/bundle"), + context_onnx=ctx_path, + iterator_onnx=iter_path, + model_id="Qwen/Qwen3-0.6B", + max_cache_len=256, + prefill_seq_len=64, + embeddings_src=emb_path, # None = skip (add later) + lm_head_src=lmh_path, # None = skip (add later) + ep="qnn", + ) +""" + +from __future__ import annotations + +import json +import logging +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Any + + +logger = logging.getLogger(__name__) + +# Default filenames inside the bundle directory. +DEFAULT_EMBEDDINGS_FILENAME = "embeddings.onnx" +DEFAULT_CONTEXT_FILENAME = "ctx.onnx" +DEFAULT_ITERATOR_FILENAME = "iter.onnx" +DEFAULT_LM_HEAD_FILENAME = "lm_head.onnx" + +# Regex for detecting indexed tensor names such as ``past_keys_3``. +_KV_INDEXED_RE = re.compile(r"^(.+?)(\d+)$") + + +# --------------------------------------------------------------------------- +# Pipeline data structures +# --------------------------------------------------------------------------- + + +@dataclass +class PipelineStage: + """One stage in an onnxruntime-genai multi-model pipeline. + + Attributes: + name: Stage key used inside the ``pipeline`` list of ``genai_config.json``. + filename: ONNX filename inside the bundle directory. + run_on_prompt: Whether genai runs this stage during the prefill pass. + run_on_token_gen: Whether genai runs this stage during decode steps. + inputs: Actual ONNX input tensor names (not format strings). + outputs: Actual ONNX output tensor names (not format strings). + is_lm_head: Set ``True`` for the final language-model head stage. + """ + + name: str + filename: str + run_on_prompt: bool + run_on_token_gen: bool + inputs: list[str] + outputs: list[str] + is_lm_head: bool = False + session_options: dict | None = None + """Per-stage ORT session options (e.g. provider_options for QNN). + + When set, emitted verbatim as the ``session_options`` key in the + ``genai_config.json`` pipeline stage. Leave ``None`` (default) for + stages that should run on the default (CPU) provider. + """ + + def to_dict(self) -> dict: + """Serialize to the dict format expected by ``genai_config.json``.""" + d: dict = { + "filename": self.filename, + "inputs": list(self.inputs), + "outputs": list(self.outputs), + "run_on_prompt": self.run_on_prompt, + "run_on_token_gen": self.run_on_token_gen, + } + if self.session_options: + d["session_options"] = self.session_options + if self.is_lm_head: + d["is_lm_head"] = True + return d + + +@dataclass +class DecoderIOMapping: + """Maps genai's abstract I/O concepts to ONNX tensor name format strings. + + The ``*_names`` fields use ``%d`` as the layer-index placeholder, which is + the convention genai uses to expand per-layer KV cache tensor names + (e.g. ``"past_keys_%d"`` → ``"past_keys_0"``, ``"past_keys_1"``, …). + + Defaults match the Qwen3 transformer-only export naming; override any field + when building bundles for models with different tensor names. + """ + + input_ids: str = "input_ids" + past_sequence_length: str = "past_seq_len" + total_sequence_length: str = "total_seq_len" + past_key_names: str = "past_keys_%d" + past_value_names: str = "past_values_%d" + logits: str = "logits" + present_key_names: str = "present_keys_%d" + present_value_names: str = "present_values_%d" + + def inputs_dict(self) -> dict: + """Return the ``decoder.inputs`` mapping dict for ``genai_config.json``.""" + return { + "input_ids": self.input_ids, + "past_sequence_length": self.past_sequence_length, + "total_sequence_length": self.total_sequence_length, + "past_key_names": self.past_key_names, + "past_value_names": self.past_value_names, + } + + def outputs_dict(self) -> dict: + """Return the ``decoder.outputs`` mapping dict for ``genai_config.json``.""" + return { + "logits": self.logits, + "present_key_names": self.present_key_names, + "present_value_names": self.present_value_names, + } + + +# --------------------------------------------------------------------------- +# Generic config builder +# --------------------------------------------------------------------------- + + +def build_genai_config( + hf_config: Any, + *, + max_cache_len: int, + prefill_seq_len: int | None = None, + pipeline: list[PipelineStage], + decoder_io: DecoderIOMapping | None = None, +) -> dict: + """Build a ``genai_config.json`` dict for any decoder-pipeline model. + + This function is architecture-agnostic: the caller supplies the pipeline + stages and the I/O name mapping so no tensor names are hardcoded here. + + Args: + hf_config: A ``transformers.PretrainedConfig``. Reads: + ``num_hidden_layers``, ``hidden_size``, ``num_attention_heads``, + ``num_key_value_heads``, ``head_dim`` (optional, falls back to + ``hidden_size // num_attention_heads``), ``bos_token_id``, + ``eos_token_id``, ``pad_token_id``, ``vocab_size``. + max_cache_len: Static KV cache length → ``context_length`` and + ``search.max_length``. + prefill_seq_len: When given, emits a ``sliding_window`` section with + ``window_size=prefill_seq_len``. Pass ``None`` to omit. + pipeline: Ordered list of :class:`PipelineStage` describing each + model in the genai pipeline. + decoder_io: Format-string mapping from genai's abstract I/O names to + actual ONNX tensor names. Defaults to + :class:`DecoderIOMapping` (the standard names). + + Returns: + A ``dict`` suitable for ``json.dumps`` as ``genai_config.json``. + """ + if decoder_io is None: + decoder_io = DecoderIOMapping() + + num_layers: int = hf_config.num_hidden_layers + head_size: int = getattr( + hf_config, + "head_dim", + hf_config.hidden_size // hf_config.num_attention_heads, + ) + + eos_token_id = hf_config.eos_token_id + if isinstance(eos_token_id, list): + eos_token_id = eos_token_id[0] + + pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id + + decoder_section: dict = { + "hidden_size": hf_config.hidden_size, + "num_attention_heads": hf_config.num_attention_heads, + "num_key_value_heads": hf_config.num_key_value_heads, + "num_hidden_layers": num_layers, + "head_size": head_size, + } + + if prefill_seq_len is not None: + decoder_section["sliding_window"] = { + "window_size": prefill_seq_len, + "pad_value": 0, + "alignment": "left", + "slide_inputs": True, + "slide_key_value_cache": False, + } + + decoder_section["inputs"] = decoder_io.inputs_dict() + decoder_section["outputs"] = decoder_io.outputs_dict() + decoder_section["pipeline"] = [{s.name: s.to_dict()} for s in pipeline] + + return { + "model": { + "type": "decoder-pipeline", + "bos_token_id": hf_config.bos_token_id, + "eos_token_id": eos_token_id, + "pad_token_id": pad_token_id, + "vocab_size": hf_config.vocab_size, + "context_length": max_cache_len, + "decoder": decoder_section, + }, + "search": { + "max_length": max_cache_len, + "min_length": 0, + "do_sample": False, + "past_present_share_buffer": True, + }, + } + + +# --------------------------------------------------------------------------- +# ONNX introspection helpers +# --------------------------------------------------------------------------- + + +def _introspect_onnx_io(onnx_path: Path) -> tuple[list[str], list[str]]: + """Return ``(input_names, output_names)`` from an ONNX model graph header. + + External data is intentionally not loaded — only the graph topology is read, + so this is fast even for large quantized models. + """ + try: + import onnx + except ImportError as exc: + raise ImportError( + "The 'onnx' package is required for ONNX introspection. " + "Install it with: pip install onnx" + ) from exc + model = onnx.load(str(onnx_path), load_external_data=False) + return ( + [inp.name for inp in model.graph.input], + [out.name for out in model.graph.output], + ) + + +def _detect_format_patterns(names: list[str], num_layers: int) -> dict[str, str]: + """Detect ``prefix%d`` patterns from a list of indexed tensor names. + + Scans *names* for entries matching ```` where exactly + *num_layers* consecutive zero-based indices are present. + + Returns: + ``{prefix: "prefix%d"}`` for each qualifying group, in the order the + prefixes first appear in *names*. Only groups covering the full + ``[0, num_layers)`` index range are returned. + + Examples:: + + >>> _detect_format_patterns( + ... ["past_keys_0", "past_keys_1", "past_values_0", "past_values_1"], + ... num_layers=2, + ... ) + {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"} + """ + groups: dict[str, list[int]] = {} + for name in names: + m = _KV_INDEXED_RE.match(name) + if m: + prefix, idx = m.group(1), int(m.group(2)) + groups.setdefault(prefix, []).append(idx) + + return { + prefix: f"{prefix}%d" + for prefix, indices in groups.items() + if len(indices) == num_layers and sorted(indices) == list(range(num_layers)) + } + + +def _sort_patterns_by_first_occurrence(patterns: dict[str, str], names: list[str]) -> list[str]: + """Sort *patterns* keys by when ``0`` first appears in *names*.""" + + def _key(prefix: str) -> int: + try: + return names.index(f"{prefix}0") + except ValueError: + return len(names) + + return sorted(patterns.keys(), key=_key) + + +# --------------------------------------------------------------------------- +# Per-EP stage session_options helpers +# --------------------------------------------------------------------------- + + +def qnn_stage_session_options(log_id: str, soc_model: str = "60") -> dict: + """Return the ``session_options`` block that routes a stage to QNN HTP. + + Args: + log_id: ORT log identifier (shown in ORT logs), e.g. + ``"onnxruntime-genai.context"``. + soc_model: Snapdragon SoC model number passed to the QNN HTP backend. + ``"60"`` targets Snapdragon 8 Gen 3 (X Elite). Change for other + SoCs (e.g. ``"55"`` for 8 Gen 2, ``"73"`` for 8 Elite). + + Returns: + Dict suitable for the ``session_options`` key of a pipeline stage in + ``genai_config.json``. + """ + return { + "log_id": log_id, + "provider_options": [ + { + "qnn": { + "backend_path": "QnnHtp.dll", + "htp_performance_mode": "burst", + "htp_graph_finalization_optimization_mode": "3", + "soc_model": soc_model, + } + } + ], + "intra_op_num_threads": 2, + "inter_op_num_threads": 1, + } + + +# --------------------------------------------------------------------------- +# Generic decoder-pipeline stage factory +# --------------------------------------------------------------------------- + + +def build_decoder_pipeline_stages( + context_onnx: str | Path, + iterator_onnx: str | Path, + num_layers: int, + *, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", +) -> tuple[list[PipelineStage], DecoderIOMapping]: + """Build pipeline stages by introspecting the built ONNX models. + + Reads actual tensor names from *context_onnx* and *iterator_onnx* so the + generated ``genai_config.json`` can never drift out of sync with the real + model I/O — no tensor names are hardcoded. + + Args: + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. + num_layers: Number of transformer layers (``hf_config.num_hidden_layers``). + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. + ep: Execution provider for the transformer stages. ``"qnn"`` injects + QNN HTP ``session_options`` into the ``context`` and ``iterator`` + stages so they run on the NPU while ``embeddings`` and ``lm_head`` + continue on CPU. ``"cpu"`` (default) omits ``session_options`` + from all stages. + soc_model: Snapdragon SoC model number forwarded to the QNN backend + when ``ep="qnn"``. Default ``"60"`` targets Snapdragon 8 Gen 3. + + Returns: + ``(stages, decoder_io)`` — a 4-element :class:`PipelineStage` list and + the :class:`DecoderIOMapping` derived from the introspected tensor names. + """ + ctx_inputs, ctx_outputs = _introspect_onnx_io(Path(context_onnx)) + iter_inputs, iter_outputs = _introspect_onnx_io(Path(iterator_onnx)) + + # Detect per-layer KV format-string patterns in the context model. + input_patterns = _detect_format_patterns(ctx_inputs, num_layers) + output_patterns = _detect_format_patterns(ctx_outputs, num_layers) + + in_sorted = _sort_patterns_by_first_occurrence(input_patterns, ctx_inputs) + out_sorted = _sort_patterns_by_first_occurrence(output_patterns, ctx_outputs) + + past_key_fmt = input_patterns[in_sorted[0]] if len(in_sorted) > 0 else "past_keys_%d" + past_val_fmt = input_patterns[in_sorted[1]] if len(in_sorted) > 1 else "past_values_%d" + pres_key_fmt = output_patterns[out_sorted[0]] if len(out_sorted) > 0 else "present_keys_%d" + pres_val_fmt = output_patterns[out_sorted[1]] if len(out_sorted) > 1 else "present_values_%d" + + # Non-indexed inputs: hidden-state tensor + scalar seq-length scalars. + non_indexed = [n for n in ctx_inputs if not _KV_INDEXED_RE.match(n)] + seq_len_names = [n for n in non_indexed if re.search(r"seq|len", n, re.IGNORECASE)] + hidden_state_in = next( + (n for n in non_indexed if n not in seq_len_names), "input_hidden_states" + ) + past_seq_name = next((n for n in seq_len_names if "past" in n.lower()), "past_seq_len") + total_seq_name = next((n for n in seq_len_names if "total" in n.lower()), "total_seq_len") + + # Non-indexed output: hidden-state output of the transformer stack. + hidden_state_out = next( + (n for n in ctx_outputs if not _KV_INDEXED_RE.match(n)), "output_hidden_states" + ) + + decoder_io = DecoderIOMapping( + past_sequence_length=past_seq_name, + total_sequence_length=total_seq_name, + past_key_names=past_key_fmt, + past_value_names=past_val_fmt, + present_key_names=pres_key_fmt, + present_value_names=pres_val_fmt, + ) + + # Per-stage session_options: NPU stages get QNN config; CPU and others get None. + ctx_session_opts: dict | None = None + iter_session_opts: dict | None = None + if ep == "qnn": + ctx_session_opts = qnn_stage_session_options( + "onnxruntime-genai.context", soc_model=soc_model + ) + iter_session_opts = qnn_stage_session_options( + "onnxruntime-genai.iterator", soc_model=soc_model + ) + + stages: list[PipelineStage] = [ + PipelineStage( + name="embeddings", + filename=embeddings_filename, + run_on_prompt=True, + run_on_token_gen=True, + inputs=[decoder_io.input_ids], + outputs=[hidden_state_in], + ), + PipelineStage( + name="context", + filename=context_filename, + run_on_prompt=True, + run_on_token_gen=False, + inputs=ctx_inputs, + outputs=ctx_outputs, + session_options=ctx_session_opts, + ), + PipelineStage( + name="iterator", + filename=iterator_filename, + run_on_prompt=False, + run_on_token_gen=True, + inputs=iter_inputs, + outputs=iter_outputs, + session_options=iter_session_opts, + ), + PipelineStage( + name="lm_head", + filename=lm_head_filename, + run_on_prompt=True, + run_on_token_gen=True, + inputs=[hidden_state_out], + outputs=[decoder_io.logits], + is_lm_head=True, + ), + ] + return stages, decoder_io + + +# --------------------------------------------------------------------------- +# Bundle assembler +# --------------------------------------------------------------------------- + + +def write_genai_bundle( + output_dir: str | Path, + *, + context_onnx: str | Path, + iterator_onnx: str | Path, + model_id: str, + max_cache_len: int, + prefill_seq_len: int, + embeddings_src: str | Path | None = None, + lm_head_src: str | Path | None = None, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", +) -> Path: + """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*. + + Copies the winml-built transformer ONNX files, optional embedding / + lm_head models, HF tokenizer files, and writes ``genai_config.json``. + Tensor names in the config are derived by introspecting the built ONNX + files rather than being hardcoded, so this works for any model that + follows the 4-stage decoder-pipeline layout. + + Args: + output_dir: Destination directory (created if absent). + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. + model_id: HuggingFace model ID or local path for config + tokenizer. + max_cache_len: Static KV cache length (= ``context_length`` in genai). + prefill_seq_len: Prefill sequence length (= ``sliding_window.window_size``). + embeddings_src: Source path of the embeddings ONNX. ``None`` = skip. + lm_head_src: Source path of the lm_head ONNX. ``None`` = skip. + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. + ep: Execution provider for the transformer (context/iterator) stages. + ``"qnn"`` injects QNN HTP ``session_options`` so those stages run + on the NPU while embeddings and lm_head run on CPU. + ``"cpu"`` (default) omits ``session_options`` (all stages on CPU). + soc_model: Snapdragon SoC model passed to the QNN backend when + ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite. + + Returns: + Path to the written ``genai_config.json``. + """ + from transformers import AutoConfig, AutoTokenizer + + from winml.modelkit.onnx import copy_onnx_model + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + context_onnx = Path(context_onnx) + iterator_onnx = Path(iterator_onnx) + + # 1. Copy winml-built transformer ONNX files. + logger.info("Copying context ONNX: %s -> %s", context_onnx.name, context_filename) + copy_onnx_model(context_onnx, output_dir / context_filename) + + logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) + copy_onnx_model(iterator_onnx, output_dir / iterator_filename) + + # 2. Copy placeholder models (embeddings + lm_head). + if embeddings_src is not None: + logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) + copy_onnx_model(Path(embeddings_src), output_dir / embeddings_filename) + else: + logger.warning( + "embeddings_src not provided — '%s' is missing from bundle.", + embeddings_filename, + ) + + if lm_head_src is not None: + logger.info("Copying lm_head: %s -> %s", Path(lm_head_src).name, lm_head_filename) + copy_onnx_model(Path(lm_head_src), output_dir / lm_head_filename) + else: + logger.warning( + "lm_head_src not provided — '%s' is missing from bundle.", + lm_head_filename, + ) + + # 3. Save tokenizer files from the HF snapshot. + logger.info("Saving tokenizer from '%s' to %s", model_id, output_dir) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.save_pretrained(str(output_dir)) + + # 4. Build pipeline stages by introspecting the source ONNX files. + hf_config = AutoConfig.from_pretrained(model_id) + stages, decoder_io = build_decoder_pipeline_stages( + context_onnx, + iterator_onnx, + num_layers=hf_config.num_hidden_layers, + context_filename=context_filename, + iterator_filename=iterator_filename, + embeddings_filename=embeddings_filename, + lm_head_filename=lm_head_filename, + ep=ep, + soc_model=soc_model, + ) + + # 5. Write genai_config.json. + config = build_genai_config( + hf_config, + max_cache_len=max_cache_len, + prefill_seq_len=prefill_seq_len, + pipeline=stages, + decoder_io=decoder_io, + ) + config_path = output_dir / "genai_config.json" + config_path.write_text(json.dumps(config, indent=2), encoding="utf-8") + logger.info("Wrote genai_config.json -> %s", config_path) + + _log_bundle_summary(output_dir, config_path) + return config_path + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: + """Print a human-readable summary of the assembled bundle.""" + files = sorted(bundle_dir.iterdir()) + lines = [f"\n=== genai bundle: {bundle_dir} ==="] + for f in files: + size_kb = f.stat().st_size / 1024 + tag = "" + if f.name == "genai_config.json": + tag = " <- pipeline config" + elif f.name.endswith(".onnx"): + tag = " <- ONNX graph" + elif f.name.endswith(".data"): + tag = " <- ONNX external weights" + lines.append(f" {f.name:<45} {size_kb:>8.1f} KB{tag}") + lines.append(f"\nConfig written to: {config_path}") + logger.info("\n".join(lines)) + + +__all__ = [ + "DEFAULT_CONTEXT_FILENAME", + "DEFAULT_EMBEDDINGS_FILENAME", + "DEFAULT_ITERATOR_FILENAME", + "DEFAULT_LM_HEAD_FILENAME", + "DecoderIOMapping", + "PipelineStage", + "build_decoder_pipeline_stages", + "build_genai_config", + "qnn_stage_session_options", + "write_genai_bundle", +] diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py index 900f8b664..e359f02e2 100644 --- a/tests/unit/models/qwen3/test_genai_config.py +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -387,7 +387,7 @@ def _patch_onnx(self, n: int = 4): ctx_io = (self._ctx_inputs(n), self._ctx_outputs(n)) iter_io = (self._ctx_inputs(n), self._ctx_outputs(n)) return patch( - "winml.modelkit.models.hf.qwen3.genai._introspect_onnx_io", + "winml.modelkit.utils.genai._introspect_onnx_io", side_effect=[ctx_io, iter_io], ) diff --git a/tests/unit/quant/calibration/test_qwen3_calibration.py b/tests/unit/quant/calibration/test_qwen3_calibration.py index ad53ef352..6881f0f72 100644 --- a/tests/unit/quant/calibration/test_qwen3_calibration.py +++ b/tests/unit/quant/calibration/test_qwen3_calibration.py @@ -16,8 +16,9 @@ from types import SimpleNamespace import numpy as np -import onnx import torch +from onnx import TensorProto, helper +from onnx import save as onnx_save from winml.modelkit.quant.calibration.qwen3_transformer_only import ( Qwen3DecodeTrajectoryCalibReader, @@ -47,27 +48,27 @@ def _fake_config() -> SimpleNamespace: def _build_tiny_onnx(path, *, seq_len: int, max_cache_len: int) -> None: """Write a minimal graph carrying the inputs the readers introspect.""" inputs = [ - onnx.helper.make_tensor_value_info( - "input_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN] + helper.make_tensor_value_info( + "input_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] ), - onnx.helper.make_tensor_value_info( - "past_keys_0", onnx.TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] + helper.make_tensor_value_info( + "past_keys_0", TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] ), ] - out = onnx.helper.make_tensor_value_info( - "output_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN] + out = helper.make_tensor_value_info( + "output_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] ) - gqa = onnx.helper.make_node( + gqa = helper.make_node( "GroupQueryAttention", ["input_hidden_states"], ["attn_out"], name="gqa_layer_0", domain="com.microsoft", ) - identity = onnx.helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) - graph = onnx.helper.make_graph([gqa, identity], "tiny", inputs, [out]) - model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 18)]) - onnx.save(model, str(path)) + identity = helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) + graph = helper.make_graph([gqa, identity], "tiny", inputs, [out]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + onnx_save(model, str(path)) def test_graph_shapes_and_gqa_nodes(tmp_path): From aa5e7d138133641501f896df98b75b90cca29710 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 1 Jul 2026 08:43:49 +0800 Subject: [PATCH 13/29] fix: address code review issues in genai bundle and session MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - fix(_mirror_non_onnx_files): skip .onnx/.data files to avoid duplicating multi-GB model weights into _compiled/ on first --compile run - fix(generate_streaming): restore try/finally around og.Generator so the KV cache buffer is freed immediately on early caller exit (GeneratorExit), not deferred until GC - fix(build_genai_config): preserve eos_token_id list unchanged — ORT genai accepts a JSON array; truncating to [0] silently discards secondary stop tokens (e.g. Qwen3's [151645, 151643]) - fix(build_decoder_pipeline_stages): use name-based KV pattern matching ('key'/'val' in prefix) instead of purely positional, so models that list past_values before past_keys in their ONNX graph don't get a silent swap - fix(qwen3/genai __all__): remove private _detect_format_patterns from __all__; tests now import it directly from winml.modelkit.utils.genai - test: update test_eos_token_id_list_preserved to expect full list --- src/winml/modelkit/models/hf/qwen3/genai.py | 5 ++- src/winml/modelkit/session/genai_session.py | 35 +++++++++++------- src/winml/modelkit/utils/genai.py | 37 ++++++++++++++++---- tests/unit/models/qwen3/test_genai_config.py | 8 +++-- 4 files changed, 60 insertions(+), 25 deletions(-) diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py index 9e65908f5..f99e473b3 100644 --- a/src/winml/modelkit/models/hf/qwen3/genai.py +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -22,7 +22,6 @@ DEFAULT_LM_HEAD_FILENAME, DecoderIOMapping, PipelineStage, - _detect_format_patterns, build_decoder_pipeline_stages, build_genai_config, qnn_stage_session_options, @@ -34,7 +33,8 @@ # ``build_qwen3_transformer_only_stages`` continue to work unchanged. build_qwen3_transformer_only_stages = build_decoder_pipeline_stages -# Keep the internal helper accessible for tests that import it directly. +# Keep the private EP helper importable under its old name for any callers +# that referenced it before the rename. _qnn_stage_session_options = qnn_stage_session_options __all__ = [ @@ -44,7 +44,6 @@ "DEFAULT_LM_HEAD_FILENAME", "DecoderIOMapping", "PipelineStage", - "_detect_format_patterns", "build_decoder_pipeline_stages", "build_genai_config", "build_qwen3_transformer_only_stages", diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 312a7358c..05f32c313 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -363,16 +363,20 @@ def generate_streaming( stream = self._tokenizer.create_stream() # type: ignore[union-attr] n = 0 - while not generator.is_done(): - generator.generate_next_token() - new_token = generator.get_next_tokens()[0] - yield stream.decode(new_token) - n += 1 - if n >= cfg.max_new_tokens: - break - # ``generator`` (og.Generator) holds the KV cache buffer; releasing the - # reference here (end of scope) frees it before the caller processes the - # last yielded token, which is earlier than waiting for GC. + try: + while not generator.is_done(): + generator.generate_next_token() + new_token = generator.get_next_tokens()[0] + yield stream.decode(new_token) + n += 1 + if n >= cfg.max_new_tokens: + break + finally: + # Explicit deletion releases the KV cache buffer held by the generator. + # This runs whether the caller exhausts the iterator normally, breaks + # out early, or the generator is garbage-collected — preventing the NPU + # memory from being held until a later GC cycle. + del generator # ------------------------------------------------------------------ # Chat-template helpers @@ -621,12 +625,19 @@ def _mirror_non_onnx_files(self, compiled_dir: Path) -> None: """Create symlinks (or copies on Windows) for every non-ONNX file. Files are linked/copied into *compiled_dir* so that ``og.Config`` - finds tokenizer files, specials maps, etc. Existing files are left - untouched. + finds tokenizer files, specials maps, etc. ONNX files are intentionally + skipped — compiled stages land at different filenames inside *compiled_dir*, + and non-compiled stages fall back to their original path via an absolute + filename written into the modified genai_config.json. Existing files are + left untouched. """ for src in self._bundle_dir.iterdir(): if src.name == self._COMPILED_SUBDIR: continue + if src.suffix in (".onnx", ".data"): + # Skip model files — compiled stages are already at their new paths; + # large ONNX weights (potentially several GB) must not be duplicated. + continue dst = compiled_dir / src.name if dst.exists(): continue diff --git a/src/winml/modelkit/utils/genai.py b/src/winml/modelkit/utils/genai.py index 33645cdaf..f0b947b81 100644 --- a/src/winml/modelkit/utils/genai.py +++ b/src/winml/modelkit/utils/genai.py @@ -219,9 +219,11 @@ def build_genai_config( hf_config.hidden_size // hf_config.num_attention_heads, ) - eos_token_id = hf_config.eos_token_id - if isinstance(eos_token_id, list): - eos_token_id = eos_token_id[0] + eos_token_id: int | list[int] = hf_config.eos_token_id + # Pass lists through unchanged — ORT genai accepts a JSON array of EOS token + # IDs and treats any of them as a valid stop signal. Truncating to [0] would + # silently discard secondary EOS tokens (e.g. Qwen3 uses [151645, 151643]) + # and cause generation to run until max_length instead of stopping early. pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id @@ -424,10 +426,31 @@ def build_decoder_pipeline_stages( in_sorted = _sort_patterns_by_first_occurrence(input_patterns, ctx_inputs) out_sorted = _sort_patterns_by_first_occurrence(output_patterns, ctx_outputs) - past_key_fmt = input_patterns[in_sorted[0]] if len(in_sorted) > 0 else "past_keys_%d" - past_val_fmt = input_patterns[in_sorted[1]] if len(in_sorted) > 1 else "past_values_%d" - pres_key_fmt = output_patterns[out_sorted[0]] if len(out_sorted) > 0 else "present_keys_%d" - pres_val_fmt = output_patterns[out_sorted[1]] if len(out_sorted) > 1 else "present_values_%d" + # Assign key/value patterns by name (look for "key"/"val" in the prefix), + # falling back to positional order only when names are ambiguous. Pure + # positional assignment would silently swap KV if a model lists values + # before keys in its ONNX graph. + def _pick_kv( + sorted_prefixes: list[str], + patterns: dict[str, str], + key_default: str, + val_default: str, + ) -> tuple[str, str]: + key_prefix = next((p for p in sorted_prefixes if "key" in p.lower()), None) + val_prefix = next((p for p in sorted_prefixes if "val" in p.lower()), None) + if key_prefix and val_prefix: + return patterns[key_prefix], patterns[val_prefix] + # Fallback: positional (preserves original behaviour for unambiguous names) + key_fmt = patterns[sorted_prefixes[0]] if len(sorted_prefixes) > 0 else key_default + val_fmt = patterns[sorted_prefixes[1]] if len(sorted_prefixes) > 1 else val_default + return key_fmt, val_fmt + + past_key_fmt, past_val_fmt = _pick_kv( + in_sorted, input_patterns, "past_keys_%d", "past_values_%d" + ) + pres_key_fmt, pres_val_fmt = _pick_kv( + out_sorted, output_patterns, "present_keys_%d", "present_values_%d" + ) # Non-indexed inputs: hidden-state tensor + scalar seq-length scalars. non_indexed = [n for n in ctx_inputs if not _KV_INDEXED_RE.match(n)] diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py index e359f02e2..1e6051d5d 100644 --- a/tests/unit/models/qwen3/test_genai_config.py +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -16,10 +16,10 @@ DEFAULT_LM_HEAD_FILENAME, DecoderIOMapping, PipelineStage, - _detect_format_patterns, build_genai_config, build_qwen3_transformer_only_stages, ) +from winml.modelkit.utils.genai import _detect_format_patterns # --------------------------------------------------------------------------- @@ -260,12 +260,14 @@ def test_custom_filenames(self) -> None: assert pipeline[2]["iterator"]["filename"] == "decode.onnx" assert pipeline[3]["lm_head"]["filename"] == "head.onnx" - def test_eos_token_id_list_unpacked(self) -> None: + def test_eos_token_id_list_preserved(self) -> None: cfg = _mock_config(eos_token_id=[151645, 151643]) result = build_genai_config( cfg, max_cache_len=256, prefill_seq_len=64, pipeline=_make_pipeline() ) - assert result["model"]["eos_token_id"] == 151645 + # ORT genai accepts a list of EOS token IDs; all must be preserved so that + # any secondary stop token (e.g. 151643 in some Qwen3 variants) is honoured. + assert result["model"]["eos_token_id"] == [151645, 151643] def test_head_size_derived_when_head_dim_missing(self) -> None: cfg = SimpleNamespace( From 18a8f03cdd0b97cc63146d15bcb3c2e1f5c1c9f5 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 1 Jul 2026 09:11:09 +0800 Subject: [PATCH 14/29] fix: import rule violations, fallback-stage path bug, and div-by-zero in perf print - src/: convert all absolute winml.modelkit.* imports to relative - qwen3/genai.py: from ....utils.genai import - utils/genai.py: from ..onnx import copy_onnx_model - session/genai_session.py: from .ep_registry / from ..winml (subprocess worker) - genai_session: patch failed-stage filename to absolute src path so ort-genai can resolve it when loading from compiled_dir (was crashing) - infer_genai.py: guard n/dt with max(dt,1e-9) to avoid ZeroDivisionError - tests: import GenaiSession symbols from package __init__ not submodule --- .gitignore | 4 ++++ scripts/infer_genai.py | 2 +- src/winml/modelkit/models/hf/qwen3/genai.py | 2 +- src/winml/modelkit/session/genai_session.py | 7 +++++-- src/winml/modelkit/utils/genai.py | 2 +- tests/unit/models/qwen3/test_genai_config.py | 10 ++++++---- tests/unit/session/test_genai_session.py | 2 +- 7 files changed, 19 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 2184e9f72..17d96986b 100644 --- a/.gitignore +++ b/.gitignore @@ -241,6 +241,10 @@ nul export_config.json *_perf.json shape_config.json +/out/ + +# ONNX external data files at repo root (e.g. EPContext .data blobs) +/*.data # UV / pip uv.lock diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py index 47aa6683e..fbf535500 100644 --- a/scripts/infer_genai.py +++ b/scripts/infer_genai.py @@ -143,7 +143,7 @@ def main(argv: list[str] | None = None) -> int: n += 1 dt = time.monotonic() - t0 - print(f"\n\n[done] {n} tokens in {dt:.1f}s ({n / dt:.1f} tok/s)") + print(f"\n\n[done] {n} tokens in {dt:.1f}s ({n / max(dt, 1e-9):.1f} tok/s)") return 0 diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py index f99e473b3..e6d583e4c 100644 --- a/src/winml/modelkit/models/hf/qwen3/genai.py +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -15,7 +15,7 @@ from __future__ import annotations -from winml.modelkit.utils.genai import ( +from ....utils.genai import ( DEFAULT_CONTEXT_FILENAME, DEFAULT_EMBEDDINGS_FILENAME, DEFAULT_ITERATOR_FILENAME, diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 05f32c313..3a2738a77 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -75,8 +75,8 @@ def _qnn_compile_worker(src: str, dst: str, qnn_options: dict) -> None: """ import onnxruntime as ort - from winml.modelkit.session.ep_registry import WinMLEPRegistry - from winml.modelkit.winml import add_ep_for_device + from ..winml import add_ep_for_device + from .ep_registry import WinMLEPRegistry registry = WinMLEPRegistry.get_instance() registry.register_execution_providers() @@ -530,6 +530,9 @@ def _prepare_compiled_bundle(self) -> Path: logger.warning( "Stage %r: compilation failed; using original ONNX (JIT fallback)", stage_key ) + # Patch to the absolute source path so ort-genai can find the + # file when loading from compiled_dir (where it doesn't exist). + self._patch_stage_filename(modified_cfg, stage_key, str(src_onnx.resolve())) if not any_compiled: return self._bundle_dir diff --git a/src/winml/modelkit/utils/genai.py b/src/winml/modelkit/utils/genai.py index f0b947b81..6717406b1 100644 --- a/src/winml/modelkit/utils/genai.py +++ b/src/winml/modelkit/utils/genai.py @@ -581,7 +581,7 @@ def write_genai_bundle( """ from transformers import AutoConfig, AutoTokenizer - from winml.modelkit.onnx import copy_onnx_model + from ..onnx import copy_onnx_model output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py index 1e6051d5d..747d9c92d 100644 --- a/tests/unit/models/qwen3/test_genai_config.py +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -9,15 +9,17 @@ from types import SimpleNamespace from unittest.mock import patch +from winml.modelkit.models.hf.qwen3 import ( + DecoderIOMapping, + PipelineStage, + build_genai_config, + build_qwen3_transformer_only_stages, +) from winml.modelkit.models.hf.qwen3.genai import ( DEFAULT_CONTEXT_FILENAME, DEFAULT_EMBEDDINGS_FILENAME, DEFAULT_ITERATOR_FILENAME, DEFAULT_LM_HEAD_FILENAME, - DecoderIOMapping, - PipelineStage, - build_genai_config, - build_qwen3_transformer_only_stages, ) from winml.modelkit.utils.genai import _detect_format_patterns diff --git a/tests/unit/session/test_genai_session.py b/tests/unit/session/test_genai_session.py index 4dcf7ea1c..5dc496b49 100644 --- a/tests/unit/session/test_genai_session.py +++ b/tests/unit/session/test_genai_session.py @@ -17,7 +17,7 @@ import pytest -from winml.modelkit.session.genai_session import ( +from winml.modelkit.session import ( GenaiLoadError, GenaiNotInstalledError, GenaiSession, From f729f929ff708aecb33806e1956608bb3e86c467 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 1 Jul 2026 10:44:59 +0800 Subject: [PATCH 15/29] feat(genai): auto-build embeddings+lm_head in export script using new model types Now that qwen3_embeddings_only and qwen3_lm_head_only are available (merged from main via PR #1008), remove the placeholder pattern from the genai bundle assembly: - export_qwen3_transformer_only.py: when --genai-bundle is set, automatically build embeddings (fp32) and lm_head (w4a32/MatMulNBits) via WinMLAutoModel if --embeddings / --lm-head override paths are not provided - --embeddings / --lm-head flags are kept as optional override paths for callers that want to supply a pre-built ONNX instead of building from model_id - Both companion models are built on CPU (task=feature-extraction, no_compile) since they run on CPU in the genai pipeline - Drop the now-stale WARNING messages about missing embeddings/lm_head --- scripts/export_qwen3_transformer_only.py | 78 ++++++++++++++++++------ 1 file changed, 58 insertions(+), 20 deletions(-) diff --git a/scripts/export_qwen3_transformer_only.py b/scripts/export_qwen3_transformer_only.py index 856123be2..903aef825 100644 --- a/scripts/export_qwen3_transformer_only.py +++ b/scripts/export_qwen3_transformer_only.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""One-shot export of the Qwen3 transformer-only prefill + decode pair. +r"""One-shot export of the Qwen3 transformer-only prefill + decode pair. Leverages the registered ``WinMLQwen3TransformerOnlyModel`` composite to build BOTH transformer-only sub-models in a single call: @@ -27,6 +27,10 @@ uv run python scripts/export_qwen3_transformer_only.py \ --model-id Qwen/Qwen3-0.6B --device npu \ --max-cache-len 256 --prefill-seq-len 64 --force-rebuild + + # Assemble a complete genai bundle (auto-builds embeddings + lm_head): + uv run python scripts/export_qwen3_transformer_only.py \\ + --device npu --genai-bundle out/bundle """ from __future__ import annotations @@ -39,12 +43,21 @@ import onnx +from winml.modelkit.models.auto import WinMLAutoModel from winml.modelkit.models.hf.qwen3.qwen_transformer_only import ( WinMLQwen3TransformerOnlyModel, ) from winml.modelkit.onnx import copy_onnx_model +# Build settings for the two companion sub-models. Embeddings stay float; +# lm_head is weight-only int4 (MatMulNBits / RTN). +_COMPANION_COMPONENTS: dict[str, dict[str, str]] = { + "embeddings": {"model_type": "qwen3_embeddings_only", "precision": "fp32"}, + "lm_head": {"model_type": "qwen3_lm_head_only", "precision": "w4a32"}, +} + + # Component name -> output file stem used when --output-dir is given. _OUTPUT_STEMS = { "decoder_prefill": "prefill", @@ -133,9 +146,11 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: metavar="DIR", help=( "If set, assemble a complete onnxruntime-genai bundle in DIR: " - "ctx.onnx (prefill), iter.onnx (decode), genai_config.json, and " - "tokenizer files. Provide --embeddings and --lm-head to include " - "the placeholder models required for end-to-end inference." + "ctx.onnx (prefill), iter.onnx (decode), embeddings.onnx, " + "lm_head.onnx, genai_config.json, and tokenizer files. " + "Embeddings (fp32) and lm_head (w4a32) are built automatically " + "from --model-id; use --embeddings / --lm-head to override with " + "a pre-built ONNX path instead." ), ) genai.add_argument( @@ -144,8 +159,8 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: default=None, metavar="ONNX", help=( - "Path to the embeddings ONNX to copy into the genai bundle as " - "embeddings.onnx. Required for end-to-end genai inference." + "Override path to the embeddings ONNX. When omitted and " + "--genai-bundle is set, the embeddings model is built automatically." ), ) genai.add_argument( @@ -154,8 +169,8 @@ def parse_args(argv: list[str] | None = None) -> argparse.Namespace: default=None, metavar="ONNX", help=( - "Path to the lm_head ONNX to copy into the genai bundle as " - "lm_head.onnx. Required for end-to-end genai inference." + "Override path to the lm_head ONNX. When omitted and " + "--genai-bundle is set, the lm_head model is built automatically." ), ) return p.parse_args(argv) @@ -210,6 +225,39 @@ def main(argv: list[str] | None = None) -> int: prefill_path = Path(model.sub_models["decoder_prefill"].onnx_path) decode_path = Path(model.sub_models["decoder_gen"].onnx_path) + # Resolve embeddings / lm_head: use override paths when provided, + # otherwise build them automatically from the same model_id. + embeddings_src = args.embeddings + lm_head_src = args.lm_head + + for key, override in (("embeddings", embeddings_src), ("lm_head", lm_head_src)): + if override is not None: + print(f"\n=== using provided {key} ONNX: {override} ===") + else: + spec = _COMPANION_COMPONENTS[key] + print( + f"\n=== building {key} " + f"(model_type={spec['model_type']}, precision={spec['precision']}) ===" + ) + companion = WinMLAutoModel.from_pretrained( + args.model_id, + task="feature-extraction", + model_type=spec["model_type"], + precision=spec["precision"], + device="cpu", + ep=_DEVICE_TO_EP["cpu"], + no_compile=True, + use_cache=True, + force_rebuild=args.force_rebuild, + shape_config={"seq_len": args.prefill_seq_len}, + ) + companion_path = Path(companion.onnx_path) + print(f" [{key}] {companion_path}") + if key == "embeddings": + embeddings_src = companion_path + else: + lm_head_src = companion_path + print(f"\n=== assembling genai bundle -> {args.genai_bundle} ===") config_path = write_genai_bundle( args.genai_bundle, @@ -218,21 +266,11 @@ def main(argv: list[str] | None = None) -> int: model_id=args.model_id, max_cache_len=args.max_cache_len, prefill_seq_len=args.prefill_seq_len, - embeddings_src=args.embeddings, - lm_head_src=args.lm_head, + embeddings_src=embeddings_src, + lm_head_src=lm_head_src, ep="qnn" if args.device == "npu" else args.device, ) print(f" genai_config.json -> {config_path}") - if args.embeddings is None: - print( - " WARNING: --embeddings not provided; " - "add embeddings.onnx to the bundle before inference." - ) - if args.lm_head is None: - print( - " WARNING: --lm-head not provided; " - "add lm_head.onnx to the bundle before inference." - ) return 0 From ebef5cfa3c18186052dcf8120219c22c45b22d3f Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 1 Jul 2026 12:39:41 +0800 Subject: [PATCH 16/29] fix(genai): patch embeddings+lm_head seq_len to dynamic; revert cpu-override; fix shape_config MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - utils/genai.py: add _patch_seq_dim_dynamic helper; apply it to both embeddings and lm_head ONNX after copy in write_genai_bundle — ort-genai calls these models with prompt_len tokens on prefill and seq_len=1 on each decode step, so the seq_len dimension must be symbolic not fixed - session/genai_session.py: revert _prepare_cpu_bundle and the ep==cpu hook (GenaiSession uses genai_config.json as-is; cpu override not supported) - export_qwen3_transformer_only.py: remove shape_config from companion build call — embeddings/lm_head have dynamic seq_len, no static shape needed --- scripts/export_qwen3_transformer_only.py | 4 ++- src/winml/modelkit/utils/genai.py | 44 ++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/scripts/export_qwen3_transformer_only.py b/scripts/export_qwen3_transformer_only.py index 903aef825..e9d11596c 100644 --- a/scripts/export_qwen3_transformer_only.py +++ b/scripts/export_qwen3_transformer_only.py @@ -239,6 +239,9 @@ def main(argv: list[str] | None = None) -> int: f"\n=== building {key} " f"(model_type={spec['model_type']}, precision={spec['precision']}) ===" ) + # Embeddings has dynamic seq_len (Gather op; no static shape needed). + # LM head also uses dynamic seq_len. Omit shape_config so the + # dynamic axes in the OnnxConfig are not overridden by a fixed value. companion = WinMLAutoModel.from_pretrained( args.model_id, task="feature-extraction", @@ -249,7 +252,6 @@ def main(argv: list[str] | None = None) -> int: no_compile=True, use_cache=True, force_rebuild=args.force_rebuild, - shape_config={"seq_len": args.prefill_seq_len}, ) companion_path = Path(companion.onnx_path) print(f" [{key}] {companion_path}") diff --git a/src/winml/modelkit/utils/genai.py b/src/winml/modelkit/utils/genai.py index 6717406b1..0bab4acd7 100644 --- a/src/winml/modelkit/utils/genai.py +++ b/src/winml/modelkit/utils/genai.py @@ -531,6 +531,35 @@ def _pick_kv( # --------------------------------------------------------------------------- +def _patch_seq_dim_dynamic(onnx_path: Path, dim_index: int = 1) -> None: + """Make dimension *dim_index* of all graph inputs/outputs symbolic. + + ort-genai calls the embeddings model with the full prompt on prefill + (seq_len = prompt_len) and with a single token on each decode step + (seq_len = 1). The ONNX export may bake in a concrete value; this + helper replaces it with the symbolic name ``"seq_len"`` so the runtime + accepts any sequence length. + + The model weights (external data) are not touched — only the protobuf + shape annotations are updated. + """ + import onnx + + model = onnx.load(str(onnx_path), load_external_data=False) + changed = False + for value_info in list(model.graph.input) + list(model.graph.output): + shape = value_info.type.tensor_type.shape + if shape and len(shape.dim) > dim_index: + dim = shape.dim[dim_index] + if dim.HasField("dim_value"): # it's a fixed integer + dim.ClearField("dim_value") + dim.dim_param = "seq_len" + changed = True + if changed: + onnx.save(model, str(onnx_path)) + logger.info("Patched seq_len dim to dynamic in %s", onnx_path.name) + + def write_genai_bundle( output_dir: str | Path, *, @@ -595,10 +624,15 @@ def write_genai_bundle( logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) copy_onnx_model(iterator_onnx, output_dir / iterator_filename) - # 2. Copy placeholder models (embeddings + lm_head). + # 2. Copy embeddings + lm_head models. if embeddings_src is not None: logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) - copy_onnx_model(Path(embeddings_src), output_dir / embeddings_filename) + dst_embeddings = output_dir / embeddings_filename + copy_onnx_model(Path(embeddings_src), dst_embeddings) + # Patch seq_len to dynamic: ort-genai calls embeddings with the full + # prompt on prefill and with a single token on every decode step, so the + # seq_len dimension must be symbolic, not a fixed value. + _patch_seq_dim_dynamic(dst_embeddings) else: logger.warning( "embeddings_src not provided — '%s' is missing from bundle.", @@ -607,7 +641,11 @@ def write_genai_bundle( if lm_head_src is not None: logger.info("Copying lm_head: %s -> %s", Path(lm_head_src).name, lm_head_filename) - copy_onnx_model(Path(lm_head_src), output_dir / lm_head_filename) + dst_lm_head = output_dir / lm_head_filename + copy_onnx_model(Path(lm_head_src), dst_lm_head) + # Same reason as embeddings: lm_head is called with prefill seq_len + # and with seq_len=1 on each decode step. + _patch_seq_dim_dynamic(dst_lm_head) else: logger.warning( "lm_head_src not provided — '%s' is missing from bundle.", From 4babc22b8267e2d75d7da5b93cb1d597cf329388 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 1 Jul 2026 12:57:59 +0800 Subject: [PATCH 17/29] fix(genai): mirror non-QNN ONNX files into compiled bundle _mirror_non_onnx_files previously skipped ALL .onnx/.data files, which meant embeddings.onnx and lm_head.onnx were inaccessible when ort-genai loaded from _compiled/. Now only the QNN-compiled stage files (those in qnn_stages) are excluded; CPU-side ONNX files are symlinked into the compiled bundle directory so ort-genai can find them. Also pass compiled_onnx_names from _prepare_compiled_bundle to the mirror helper so the skip set is driven by what was actually compiled. Verified: --compile produces valid EPContext for ctx/iter stages, embeddings.onnx and lm_head.onnx are symlinked, inference runs at ~37 tok/s on Snapdragon X Elite NPU. --- src/winml/modelkit/session/genai_session.py | 31 ++++++++++++--------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py index 3a2738a77..d920d96f3 100644 --- a/src/winml/modelkit/session/genai_session.py +++ b/src/winml/modelkit/session/genai_session.py @@ -544,7 +544,11 @@ def _prepare_compiled_bundle(self) -> Path: compiled_config.write_text( json.dumps(modified_cfg, indent=2, ensure_ascii=False), encoding="utf-8" ) - self._mirror_non_onnx_files(compiled_dir) + # Only skip the QNN-stage ONNX files (compiled → new path, or failed → + # absolute path patch). Non-QNN ONNX files (embeddings, lm_head) must + # be symlinked into compiled_dir so ort-genai can find them by filename. + compiled_onnx_names = {onnx_filename for _, onnx_filename, _ in qnn_stages} + self._mirror_non_onnx_files(compiled_dir, skip_filenames=compiled_onnx_names) logger.info("Compiled bundle prepared at %s", compiled_dir) return compiled_dir @@ -624,22 +628,24 @@ def _compile_stage( logger.info("Stage %r compiled successfully → %s", stage_key, ctx_out) return True - def _mirror_non_onnx_files(self, compiled_dir: Path) -> None: - """Create symlinks (or copies on Windows) for every non-ONNX file. + def _mirror_non_onnx_files( + self, compiled_dir: Path, skip_filenames: set[str] | None = None + ) -> None: + """Create symlinks (or copies on Windows) for files not being compiled. - Files are linked/copied into *compiled_dir* so that ``og.Config`` - finds tokenizer files, specials maps, etc. ONNX files are intentionally - skipped — compiled stages land at different filenames inside *compiled_dir*, - and non-compiled stages fall back to their original path via an absolute - filename written into the modified genai_config.json. Existing files are - left untouched. + Links files into *compiled_dir* so ``og.Config`` finds tokenizer files, + non-QNN ONNX models (embeddings, lm_head), etc. Only ONNX files listed + in *skip_filenames* (the QNN-compiled stages) and their external ``.data`` + siblings are skipped — everything else, including CPU-side ONNX files, is + linked. Existing files are left untouched. """ + skip = set(skip_filenames or []) + # Also skip .data sidecars of the compiled-stage ONNX files. + skip_data = {f"{name}.data" for name in skip} | {f"{name}.data" for name in skip} for src in self._bundle_dir.iterdir(): if src.name == self._COMPILED_SUBDIR: continue - if src.suffix in (".onnx", ".data"): - # Skip model files — compiled stages are already at their new paths; - # large ONNX weights (potentially several GB) must not be duplicated. + if src.name in skip or src.name in skip_data: continue dst = compiled_dir / src.name if dst.exists(): @@ -648,7 +654,6 @@ def _mirror_non_onnx_files(self, compiled_dir: Path) -> None: try: dst.symlink_to(src.resolve()) except (OSError, NotImplementedError): - # Symlinks may require elevated privileges on Windows; fall back to copy. shutil.copy2(src, dst) @staticmethod From 7428fa9b871b65f5e3f2d82305e5d20b73da79dc Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 1 Jul 2026 13:48:52 +0800 Subject: [PATCH 18/29] refactor(scripts): unify Qwen3 export + inference into qwen3.py Consolidate three scripts into a single unified CLI with sub-commands: qwen3.py export -- full genai bundle build (transformer + embeddings + lm_head), replaces export_qwen3_transformer_only.py and export_qwen3_embeddings_lm_head.py qwen3.py infer -- onnxruntime-genai streamed inference, replaces infer_genai.py Deleted: scripts/export_qwen3_embeddings_lm_head.py (obsolete since #1008 integrated embeddings/lm_head into the main export pipeline) scripts/export_qwen3_transformer_only.py scripts/infer_genai.py Changes: - Default --device is now npu (was cpu) to match the primary use-case - Default --max-cache-len is now 2048 (aligns with reference bundle) - --output replaces --genai-bundle for clarity - --bundle replaces --model-dir in the infer sub-command - --compile in export triggers EPContext pre-compilation via GenaiSession context-manager (no private API access) - node summary covers both transformer (GQA/QDQ) and companion models (Gather/MatMulNBits) --- scripts/export_qwen3_embeddings_lm_head.py | 185 ---------- scripts/export_qwen3_transformer_only.py | 281 --------------- scripts/infer_genai.py | 151 -------- scripts/qwen3.py | 379 +++++++++++++++++++++ 4 files changed, 379 insertions(+), 617 deletions(-) delete mode 100644 scripts/export_qwen3_embeddings_lm_head.py delete mode 100644 scripts/export_qwen3_transformer_only.py delete mode 100644 scripts/infer_genai.py create mode 100644 scripts/qwen3.py diff --git a/scripts/export_qwen3_embeddings_lm_head.py b/scripts/export_qwen3_embeddings_lm_head.py deleted file mode 100644 index 4769e4924..000000000 --- a/scripts/export_qwen3_embeddings_lm_head.py +++ /dev/null @@ -1,185 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""One-shot export of the Qwen3 embeddings + lm_head sub-models. - -Companion to ``export_qwen3_transformer_only.py``. That script builds the -transformer (prefill + decode). This one builds the two remaining pieces of the -split Qwen3 graph: - - - ``embeddings`` — input embedding table (``input_ids`` -> ``input_hidden_states``). - Built FLOAT (``precision="fp32"``); NOT quantized. Produces a ``Gather``. - - ``lm_head`` — vocab projection (``output_hidden_states`` -> ``logits``). - Quantized weight-only to int4 via MatMulNBits/RTN (``precision="w4a32"``). - Produces a ``MatMulNBits`` node (no float ``MatMul``). - -Both are standalone ``model_type`` builds, invoked separately. - -Usage:: - - # Build (or reuse cached) both ONNX, print their paths + node summary: - uv run python scripts/export_qwen3_embeddings_lm_head.py - - # Build only one of them: - uv run python scripts/export_qwen3_embeddings_lm_head.py --only embeddings - uv run python scripts/export_qwen3_embeddings_lm_head.py --only lm_head - - # Copy the ONNX (with external data) into a folder: - uv run python scripts/export_qwen3_embeddings_lm_head.py --output-dir out/qwen3 - # -> writes embeddings_fp32.onnx and lm_head_w4a32.onnx - - # Different model / device / seq geometry, force a rebuild: - uv run python scripts/export_qwen3_embeddings_lm_head.py \ - --model-id Qwen/Qwen3-0.6B --device npu --seq-len 64 --force-rebuild -""" - -from __future__ import annotations - -import argparse -import collections -import sys -import time -from pathlib import Path - -import onnx - -from winml.modelkit.models.auto import WinMLAutoModel -from winml.modelkit.onnx import copy_onnx_model - - -# Per-component build settings: which model_type to register against and the -# precision that drives its quant policy. Embeddings stay float (``fp32``); the -# lm_head is weight-only int4 with fp32 activations (``w4a32`` — the activations -# are NOT quantized, this is RTN/MatMulNBits). -_COMPONENTS = { - "embeddings": {"model_type": "qwen3_embeddings_only", "precision": "fp32"}, - "lm_head": {"model_type": "qwen3_lm_head_only", "precision": "w4a32"}, -} - -# Component -> output file stem (when --output-dir is given). The precision -# suffix is carried in the filename so the two pieces self-document their scheme. -_OUTPUT_STEMS = { - "embeddings": f"embeddings_{_COMPONENTS['embeddings']['precision']}", - "lm_head": f"lm_head_{_COMPONENTS['lm_head']['precision']}", -} - -# Default EP per device; CPU/NPU/GPU map to their canonical providers. -_DEVICE_TO_EP = { - "cpu": "CPUExecutionProvider", - "npu": "QNNExecutionProvider", - "gpu": "DmlExecutionProvider", -} - - -def node_summary(path: str | Path) -> str: - """Return a one-line structural summary of the graph's key ops. - - The interesting markers for these two sub-models are: - - embeddings: ``Gather`` present, no ``MatMulNBits`` / no QDQ (stays float). - - lm_head: ``MatMulNBits`` present, float ``MatMul`` gone (int4 weight-only). - """ - model = onnx.load(str(path), load_external_data=False) - counts = collections.Counter(n.op_type for n in model.graph.node) - return ( - f"Gather={counts['Gather']} MatMul={counts['MatMul']} " - f"MatMulNBits={counts['MatMulNBits']} " - f"Q={counts['QuantizeLinear']} DQ={counts['DequantizeLinear']}" - ) - - -def build_component(name: str, args: argparse.Namespace): - """Build (or reuse cached) one standalone sub-model and return it.""" - spec = _COMPONENTS[name] - print(f"\n=== building {name} (model_type={spec['model_type']}, " - f"precision={spec['precision']}) ===") - return WinMLAutoModel.from_pretrained( - args.model_id, - task="feature-extraction", - model_type=spec["model_type"], - precision=spec["precision"], - device=args.device, - ep=_DEVICE_TO_EP[args.device], - no_compile=args.no_compile, - use_cache=True, - force_rebuild=args.force_rebuild, - shape_config={"seq_len": args.seq_len}, - ) - - -def parse_args(argv: list[str] | None = None) -> argparse.Namespace: - """Parse command-line arguments.""" - p = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - p.add_argument("--model-id", default="Qwen/Qwen3-0.6B", help="HF model id or local path.") - p.add_argument( - "--device", - default="cpu", - choices=sorted(_DEVICE_TO_EP), - help="Target device (selects the canonical EP). Default: cpu.", - ) - p.add_argument( - "--only", - choices=sorted(_COMPONENTS), - default=None, - help="Build only this component. Default: build both.", - ) - p.add_argument( - "--seq-len", - type=int, - default=64, - help="Static sequence length baked into both sub-models. Default: 64.", - ) - p.add_argument( - "--no-compile", - dest="no_compile", - action="store_true", - default=True, - help="Skip EPContext compilation (default; these are consumed pre-compile).", - ) - p.add_argument( - "--compile", - dest="no_compile", - action="store_false", - help="Enable EPContext compilation (requires the device's compiler/SDK).", - ) - p.add_argument("--force-rebuild", action="store_true", help="Rebuild even if cached.") - p.add_argument( - "--output-dir", - type=Path, - default=None, - help="If set, copy the ONNX (with external data) here as " - "embeddings_fp32.onnx / lm_head_w4a32.onnx.", - ) - return p.parse_args(argv) - - -def main(argv: list[str] | None = None) -> int: - """Build (or reuse) the requested sub-models and report/copy them.""" - args = parse_args(argv) - - names = [args.only] if args.only else ["embeddings", "lm_head"] - - output_dir: Path | None = args.output_dir - if output_dir is not None: - output_dir.mkdir(parents=True, exist_ok=True) - - t0 = time.monotonic() - for name in names: - model = build_component(name, args) - src = Path(model.onnx_path) - print(f"[{name}] {src}") - print(f" {node_summary(src)}") - if output_dir is not None: - dst = output_dir / f"{_OUTPUT_STEMS[name]}.onnx" - copy_onnx_model(src, dst) - print(f" -> copied to {dst}") - - print(f"\n=== done in {time.monotonic() - t0:.1f}s ===") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/export_qwen3_transformer_only.py b/scripts/export_qwen3_transformer_only.py deleted file mode 100644 index e9d11596c..000000000 --- a/scripts/export_qwen3_transformer_only.py +++ /dev/null @@ -1,281 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -r"""One-shot export of the Qwen3 transformer-only prefill + decode pair. - -Leverages the registered ``WinMLQwen3TransformerOnlyModel`` composite to build -BOTH transformer-only sub-models in a single call: - - - ``decoder_prefill`` — context graph, ``seq_len`` = --prefill-seq-len (64) - - ``decoder_gen`` — iteration graph, ``seq_len`` = 1 - -Each sub-model is built through the standard ``build_hf_model`` pipeline, so the -model-type quant finalizer is applied (int8 weight / uint16 activation, GQA -excluded from QDQ). Embeddings and the LM head are NOT part of this graph — they -run separately (e.g. from the bundle). - -Usage:: - - # Build (or reuse cached) both ONNX, print their paths + node summary: - uv run python scripts/export_qwen3_transformer_only.py - - # Copy the two ONNX (with external data) into a folder: - uv run python scripts/export_qwen3_transformer_only.py --output-dir out/qwen3 - - # Different model / device / cache geometry, force a rebuild: - uv run python scripts/export_qwen3_transformer_only.py \ - --model-id Qwen/Qwen3-0.6B --device npu \ - --max-cache-len 256 --prefill-seq-len 64 --force-rebuild - - # Assemble a complete genai bundle (auto-builds embeddings + lm_head): - uv run python scripts/export_qwen3_transformer_only.py \\ - --device npu --genai-bundle out/bundle -""" - -from __future__ import annotations - -import argparse -import collections -import sys -import time -from pathlib import Path - -import onnx - -from winml.modelkit.models.auto import WinMLAutoModel -from winml.modelkit.models.hf.qwen3.qwen_transformer_only import ( - WinMLQwen3TransformerOnlyModel, -) -from winml.modelkit.onnx import copy_onnx_model - - -# Build settings for the two companion sub-models. Embeddings stay float; -# lm_head is weight-only int4 (MatMulNBits / RTN). -_COMPANION_COMPONENTS: dict[str, dict[str, str]] = { - "embeddings": {"model_type": "qwen3_embeddings_only", "precision": "fp32"}, - "lm_head": {"model_type": "qwen3_lm_head_only", "precision": "w4a32"}, -} - - -# Component name -> output file stem used when --output-dir is given. -_OUTPUT_STEMS = { - "decoder_prefill": "prefill", - "decoder_gen": "decode", -} - -# Default EP per device; CPU/NPU/GPU map to their canonical providers. -_DEVICE_TO_EP = { - "cpu": "CPUExecutionProvider", - "npu": "QNNExecutionProvider", - "gpu": "DmlExecutionProvider", -} - - -def node_summary(path: str | Path) -> str: - """Return a one-line QDQ/GQA structural summary of an ONNX graph.""" - model = onnx.load(str(path), load_external_data=False) - counts = collections.Counter(n.op_type for n in model.graph.node) - gqa_io: set[str] = set() - for node in model.graph.node: - if node.op_type == "GroupQueryAttention": - gqa_io.update(node.input) - gqa_io.update(node.output) - qdq_touching_gqa = sum( - 1 - for n in model.graph.node - if n.op_type in ("QuantizeLinear", "DequantizeLinear") - and (set(n.input) & gqa_io or set(n.output) & gqa_io) - ) - return ( - f"Q={counts['QuantizeLinear']} DQ={counts['DequantizeLinear']} " - f"GQA={counts['GroupQueryAttention']} QDQ-touching-GQA={qdq_touching_gqa}" - ) - - -def parse_args(argv: list[str] | None = None) -> argparse.Namespace: - """Parse command-line arguments.""" - p = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - p.add_argument("--model-id", default="Qwen/Qwen3-0.6B", help="HF model id or local path.") - p.add_argument( - "--device", - default="cpu", - choices=sorted(_DEVICE_TO_EP), - help="Target device (selects the canonical EP). Default: cpu.", - ) - p.add_argument("--precision", default="w8a16", help="Build precision. Default: w8a16.") - p.add_argument("--max-cache-len", type=int, default=256, help="Static KV cache length.") - p.add_argument( - "--prefill-seq-len", - type=int, - default=64, - help="Prefill/context sequence length.", - ) - p.add_argument( - "--no-compile", - dest="no_compile", - action="store_true", - default=True, - help="Skip EPContext compilation (default; transformer-only is consumed pre-compile).", - ) - p.add_argument( - "--compile", - dest="no_compile", - action="store_false", - help="Enable EPContext compilation (requires the device's compiler/SDK).", - ) - p.add_argument("--force-rebuild", action="store_true", help="Rebuild even if cached.") - p.add_argument( - "--output-dir", - type=Path, - default=None, - help="If set, copy the two ONNX (with external data) here as prefill.onnx / decode.onnx.", - ) - - genai = p.add_argument_group( - "genai bundle", - "Options for producing an onnxruntime-genai inference bundle.", - ) - genai.add_argument( - "--genai-bundle", - type=Path, - default=None, - metavar="DIR", - help=( - "If set, assemble a complete onnxruntime-genai bundle in DIR: " - "ctx.onnx (prefill), iter.onnx (decode), embeddings.onnx, " - "lm_head.onnx, genai_config.json, and tokenizer files. " - "Embeddings (fp32) and lm_head (w4a32) are built automatically " - "from --model-id; use --embeddings / --lm-head to override with " - "a pre-built ONNX path instead." - ), - ) - genai.add_argument( - "--embeddings", - type=Path, - default=None, - metavar="ONNX", - help=( - "Override path to the embeddings ONNX. When omitted and " - "--genai-bundle is set, the embeddings model is built automatically." - ), - ) - genai.add_argument( - "--lm-head", - type=Path, - default=None, - metavar="ONNX", - help=( - "Override path to the lm_head ONNX. When omitted and " - "--genai-bundle is set, the lm_head model is built automatically." - ), - ) - return p.parse_args(argv) - - -def main(argv: list[str] | None = None) -> int: - """Build (or reuse) both transformer-only ONNX and report/copy them.""" - args = parse_args(argv) - - t0 = time.monotonic() - model = WinMLQwen3TransformerOnlyModel.from_pretrained( - args.model_id, - device=args.device, - precision=args.precision, - ep=_DEVICE_TO_EP[args.device], - no_compile=args.no_compile, - use_cache=True, - force_rebuild=args.force_rebuild, - sub_model_kwargs={ - "decoder_prefill": { - "shape_config": { - "max_cache_len": args.max_cache_len, - "seq_len": args.prefill_seq_len, - } - }, - "decoder_gen": {"shape_config": {"max_cache_len": args.max_cache_len, "seq_len": 1}}, - }, - ) - elapsed = time.monotonic() - t0 - - print(f"\n=== transformer-only build done in {elapsed:.1f}s ===") - - output_dir: Path | None = args.output_dir - if output_dir is not None: - output_dir.mkdir(parents=True, exist_ok=True) - - for name, sub in model.sub_models.items(): - src = Path(sub.onnx_path) - print(f"\n[{name}] {src}") - print(f" {node_summary(src)}") - if output_dir is not None: - dst = output_dir / f"{_OUTPUT_STEMS.get(name, name)}.onnx" - copy_onnx_model(src, dst) - print(f" -> copied to {dst}") - - # ----------------------------------------------------------------------- - # Optional: assemble an onnxruntime-genai bundle. - # ----------------------------------------------------------------------- - if args.genai_bundle is not None: - from winml.modelkit.models.hf.qwen3.genai import write_genai_bundle - - prefill_path = Path(model.sub_models["decoder_prefill"].onnx_path) - decode_path = Path(model.sub_models["decoder_gen"].onnx_path) - - # Resolve embeddings / lm_head: use override paths when provided, - # otherwise build them automatically from the same model_id. - embeddings_src = args.embeddings - lm_head_src = args.lm_head - - for key, override in (("embeddings", embeddings_src), ("lm_head", lm_head_src)): - if override is not None: - print(f"\n=== using provided {key} ONNX: {override} ===") - else: - spec = _COMPANION_COMPONENTS[key] - print( - f"\n=== building {key} " - f"(model_type={spec['model_type']}, precision={spec['precision']}) ===" - ) - # Embeddings has dynamic seq_len (Gather op; no static shape needed). - # LM head also uses dynamic seq_len. Omit shape_config so the - # dynamic axes in the OnnxConfig are not overridden by a fixed value. - companion = WinMLAutoModel.from_pretrained( - args.model_id, - task="feature-extraction", - model_type=spec["model_type"], - precision=spec["precision"], - device="cpu", - ep=_DEVICE_TO_EP["cpu"], - no_compile=True, - use_cache=True, - force_rebuild=args.force_rebuild, - ) - companion_path = Path(companion.onnx_path) - print(f" [{key}] {companion_path}") - if key == "embeddings": - embeddings_src = companion_path - else: - lm_head_src = companion_path - - print(f"\n=== assembling genai bundle -> {args.genai_bundle} ===") - config_path = write_genai_bundle( - args.genai_bundle, - context_onnx=prefill_path, - iterator_onnx=decode_path, - model_id=args.model_id, - max_cache_len=args.max_cache_len, - prefill_seq_len=args.prefill_seq_len, - embeddings_src=embeddings_src, - lm_head_src=lm_head_src, - ep="qnn" if args.device == "npu" else args.device, - ) - print(f" genai_config.json -> {config_path}") - - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/infer_genai.py b/scripts/infer_genai.py deleted file mode 100644 index fbf535500..000000000 --- a/scripts/infer_genai.py +++ /dev/null @@ -1,151 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -r"""onnxruntime-genai inference for a genai bundle (decoder-pipeline). - -Loads the genai bundle produced by ``export_qwen3_transformer_only.py ---genai-bundle `` and runs greedy text generation using -:class:`~winml.modelkit.session.GenaiSession`. - -The bundle directory must contain ``genai_config.json`` and the four ONNX -graphs it references (``embeddings.onnx``, ``ctx.onnx``, ``iter.onnx``, -``lm_head.onnx``) plus HF tokenizer files. - -Usage:: - - # CPU sanity check (works anywhere onnxruntime-genai is installed) - uv run python scripts/infer_genai.py --prompt "Hello, who are you?" --chat - - # Qualcomm NPU (registers the QNN EP via the Windows ML EP catalog) - uv run python scripts/infer_genai.py \\ - --prompt "Explain what a transformer is." \\ - --ep qnn --chat - - # Point at a non-default bundle - uv run python scripts/infer_genai.py \\ - --model-dir out/my_bundle --prompt "Hi" --ep cpu - - # Pre-compile QNN stages to EPContext on first run; reuse cache on subsequent runs. - # Eliminates per-run JIT overhead (~60-90 s saved on Snapdragon X Elite). - uv run python scripts/infer_genai.py \\ - --prompt "Hello" --ep mixed --compile - -Dependencies (install in a fresh venv):: - - pip install onnxruntime-genai-winml - pip install "windowsml[with-ort]" # registers QNN EP; also provides onnxruntime -""" - -from __future__ import annotations - -import argparse -import sys -import time -from pathlib import Path - -from winml.modelkit.session import GenaiSession, GenerationConfig - - -# Default bundle directory: /out/qwen3_bundle -_REPO_ROOT = Path(__file__).resolve().parent.parent -DEFAULT_MODEL_DIR = _REPO_ROOT / "out" / "qwen3_bundle" - -_SUPPORTED_EPS = ["cpu", "mixed", "qnn", "dml"] - - -def _wrap_chat_template(prompt: str) -> str: - """Wrap *prompt* in the ChatML chat template.""" - return GenaiSession.apply_chatml_template(prompt) - - -def parse_args(argv: list[str] | None = None) -> argparse.Namespace: - """Parse CLI arguments.""" - p = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - p.add_argument( - "--prompt", - default="Give me a short introduction to large language models.", - help="Input prompt (default: %(default)s).", - ) - p.add_argument( - "--model-dir", - type=Path, - default=DEFAULT_MODEL_DIR, - metavar="DIR", - help=( - "Path to the genai bundle directory containing genai_config.json " - "and the ONNX / tokenizer files (default: %(default)s)." - ), - ) - p.add_argument( - "--ep", - choices=_SUPPORTED_EPS, - default="mixed", - help="Execution provider: 'mixed' uses genai_config.json as-is (default); " - "'cpu' forces all stages to CPU; 'qnn'/'dml' for full NPU/GPU.", - ) - p.add_argument( - "--max-new", - type=int, - default=128, - help="Maximum number of new tokens to generate (default: %(default)s).", - ) - p.add_argument( - "--chat", - action="store_true", - help="Wrap --prompt in the ChatML template (<|im_start|>user/assistant).", - ) - p.add_argument( - "--compile", - action="store_true", - help=( - "Pre-compile QNN pipeline stages to EPContext ONNX before loading. " - "On first use this triggers ort.ModelCompiler per stage (~60-90 s for iter). " - "Compiled artifacts are cached in bundle_dir/_compiled/; " - "subsequent runs reuse the cache and skip JIT. " - "Has no effect when --ep cpu." - ), - ) - p.add_argument( - "--verbose", - action="store_true", - help="Enable onnxruntime-genai native model I/O logging.", - ) - return p.parse_args(argv) - - -def main(argv: list[str] | None = None) -> int: - """Load the genai bundle and run generation.""" - args = parse_args(argv) - - text = _wrap_chat_template(args.prompt) if args.chat else args.prompt - gen_cfg = GenerationConfig(max_new_tokens=args.max_new, do_sample=False) - - try: - session = GenaiSession( - args.model_dir, ep=args.ep, verbose=args.verbose, compile=args.compile - ) - except FileNotFoundError as exc: - print(f"ERROR: {exc}", file=sys.stderr) - return 1 - - print(f"[load] ep={args.ep} bundle={args.model_dir}") - with session: - print(f"[ctx] context_length={session.context_length}") - print("[gen] ", end="", flush=True) - t0 = time.monotonic() - n = 0 - for token_str in session.generate_streaming(text, gen_cfg): - print(token_str, end="", flush=True) - n += 1 - - dt = time.monotonic() - t0 - print(f"\n\n[done] {n} tokens in {dt:.1f}s ({n / max(dt, 1e-9):.1f} tok/s)") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/qwen3.py b/scripts/qwen3.py new file mode 100644 index 000000000..3bebeff7b --- /dev/null +++ b/scripts/qwen3.py @@ -0,0 +1,379 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""Unified Qwen3 genai pipeline: full model export and inference. + +Sub-commands +------------ +export + Build (or reuse) all four components of the Qwen3 genai bundle and + assemble them into an onnxruntime-genai directory: + + - ``ctx.onnx`` — transformer prefill graph (QNN-quantized) + - ``iter.onnx`` — transformer decode graph (QNN-quantized) + - ``embeddings.onnx`` — token embedding table (fp32) + - ``lm_head.onnx`` — vocab projection (w4a32 MatMulNBits) + - ``genai_config.json`` + HF tokenizer files + +infer + Load a bundle produced by ``export`` and run streamed text generation + using :class:`~winml.modelkit.session.GenaiSession`. + +Usage:: + + # Full export to out/bundle (default context_length=2048): + uv run python scripts/qwen3.py export --device npu --output out/bundle + + # Export with EPContext pre-compilation: + uv run python scripts/qwen3.py export --device npu --output out/bundle --compile + + # Force rebuild from scratch: + uv run python scripts/qwen3.py export --device npu --output out/bundle --force-rebuild + + # Inference with the compiled bundle: + uv run python scripts/qwen3.py infer --bundle out/bundle --prompt "Hello" --chat + + # Pre-compile QNN stages on first inference run (reuses cache on subsequent runs): + uv run python scripts/qwen3.py infer --bundle out/bundle --prompt "Hi" --compile +""" + +from __future__ import annotations + +import argparse +import collections +import sys +import time +from pathlib import Path + +import onnx + +from winml.modelkit.models.auto import WinMLAutoModel +from winml.modelkit.models.hf.qwen3.qwen_transformer_only import ( + WinMLQwen3TransformerOnlyModel, +) +from winml.modelkit.session import GenaiSession, GenerationConfig + + +_DEVICE_TO_EP = { + "cpu": "CPUExecutionProvider", + "npu": "QNNExecutionProvider", + "gpu": "DmlExecutionProvider", +} + +# Build specs for the two CPU-side companion models. +_COMPANION_COMPONENTS: dict[str, dict[str, str]] = { + "embeddings": {"model_type": "qwen3_embeddings_only", "precision": "fp32"}, + "lm_head": {"model_type": "qwen3_lm_head_only", "precision": "w4a32"}, +} + +_REPO_ROOT = Path(__file__).resolve().parent.parent +_DEFAULT_BUNDLE = _REPO_ROOT / "out" / "bundle" + +_SUPPORTED_EPS = ["cpu", "mixed", "qnn", "dml"] + + +# --------------------------------------------------------------------------- +# Helpers shared between sub-commands +# --------------------------------------------------------------------------- + + +def _node_summary(path: str | Path) -> str: + """One-line op-type summary of an ONNX graph (loads shape metadata only).""" + model = onnx.load(str(path), load_external_data=False) + counts = collections.Counter(n.op_type for n in model.graph.node) + gqa_io: set[str] = set() + for node in model.graph.node: + if node.op_type == "GroupQueryAttention": + gqa_io.update(node.input) + gqa_io.update(node.output) + qdq_touching_gqa = sum( + 1 + for n in model.graph.node + if n.op_type in ("QuantizeLinear", "DequantizeLinear") + and (set(n.input) & gqa_io or set(n.output) & gqa_io) + ) + return ( + f"Gather={counts['Gather']} MatMulNBits={counts['MatMulNBits']} " + f"Q={counts['QuantizeLinear']} DQ={counts['DequantizeLinear']} " + f"GQA={counts['GroupQueryAttention']} QDQ@GQA={qdq_touching_gqa}" + ) + + +# --------------------------------------------------------------------------- +# export sub-command +# --------------------------------------------------------------------------- + + +def _add_export_parser(sub: argparse._SubParsersAction) -> None: # type: ignore[type-arg] + p = sub.add_parser( + "export", + help="Build the full Qwen3 genai bundle (transformer + embeddings + lm_head).", + formatter_class=argparse.RawDescriptionHelpFormatter, + description=( + "Exports all four Qwen3 genai components and assembles them into an " + "onnxruntime-genai bundle directory. Transformer stages (ctx/iter) are " + "built for the target device; embeddings and lm_head always run on CPU." + ), + ) + p.add_argument("--model-id", default="Qwen/Qwen3-0.6B", help="HF model id or local path.") + p.add_argument( + "--device", + default="npu", + choices=sorted(_DEVICE_TO_EP), + help="Target device for transformer stages. Default: npu.", + ) + p.add_argument("--precision", default="w8a16", help="Transformer precision. Default: w8a16.") + p.add_argument( + "--max-cache-len", + type=int, + default=2048, + help="Static KV cache length (context_length). Default: 2048.", + ) + p.add_argument( + "--prefill-seq-len", + type=int, + default=64, + help="Prefill/context sequence length baked into ctx.onnx. Default: 64.", + ) + p.add_argument( + "--output", + type=Path, + default=_DEFAULT_BUNDLE, + metavar="DIR", + help=f"Bundle output directory. Default: {_DEFAULT_BUNDLE}.", + ) + p.add_argument( + "--embeddings", + type=Path, + default=None, + metavar="ONNX", + help="Override path to a pre-built embeddings ONNX (skips auto-build).", + ) + p.add_argument( + "--lm-head", + type=Path, + default=None, + metavar="ONNX", + help="Override path to a pre-built lm_head ONNX (skips auto-build).", + ) + p.add_argument("--force-rebuild", action="store_true", help="Rebuild even if cached.") + p.add_argument( + "--compile", + action="store_true", + help=( + "Pre-compile QNN transformer stages to EPContext ONNX after export. " + "Compiled artifacts are cached in output/_compiled/ and reused on " + "subsequent runs." + ), + ) + + +def _cmd_export(args: argparse.Namespace) -> int: + """Build all components and write the genai bundle.""" + from winml.modelkit.models.hf.qwen3.genai import write_genai_bundle + + t0 = time.monotonic() + + # --- Transformer (ctx + iter) --- + print(f"\n=== building transformer stages (device={args.device}) ===") + transformer = WinMLQwen3TransformerOnlyModel.from_pretrained( + args.model_id, + device=args.device, + precision=args.precision, + ep=_DEVICE_TO_EP[args.device], + no_compile=True, + use_cache=True, + force_rebuild=args.force_rebuild, + sub_model_kwargs={ + "decoder_prefill": { + "shape_config": { + "max_cache_len": args.max_cache_len, + "seq_len": args.prefill_seq_len, + } + }, + "decoder_gen": {"shape_config": {"max_cache_len": args.max_cache_len, "seq_len": 1}}, + }, + ) + prefill_path = Path(transformer.sub_models["decoder_prefill"].onnx_path) + decode_path = Path(transformer.sub_models["decoder_gen"].onnx_path) + for name, path in (("ctx", prefill_path), ("iter", decode_path)): + print(f" [{name}] {path}") + print(f" {_node_summary(path)}") + + # --- Companion models (embeddings + lm_head) --- + embeddings_src = args.embeddings + lm_head_src = args.lm_head + + for key, override in (("embeddings", embeddings_src), ("lm_head", lm_head_src)): + if override is not None: + print(f"\n=== using provided {key}: {override} ===") + continue + spec = _COMPANION_COMPONENTS[key] + print( + f"\n=== building {key} " + f"(model_type={spec['model_type']}, precision={spec['precision']}) ===" + ) + companion = WinMLAutoModel.from_pretrained( + args.model_id, + task="feature-extraction", + model_type=spec["model_type"], + precision=spec["precision"], + device="cpu", + ep=_DEVICE_TO_EP["cpu"], + no_compile=True, + use_cache=True, + force_rebuild=args.force_rebuild, + ) + companion_path = Path(companion.onnx_path) + print(f" [{key}] {companion_path}") + print(f" {_node_summary(companion_path)}") + if key == "embeddings": + embeddings_src = companion_path + else: + lm_head_src = companion_path + + # --- Assemble bundle --- + print(f"\n=== assembling bundle -> {args.output} ===") + config_path = write_genai_bundle( + args.output, + context_onnx=prefill_path, + iterator_onnx=decode_path, + model_id=args.model_id, + max_cache_len=args.max_cache_len, + prefill_seq_len=args.prefill_seq_len, + embeddings_src=embeddings_src, + lm_head_src=lm_head_src, + ep="qnn" if args.device == "npu" else args.device, + ) + print(f" genai_config.json -> {config_path}") + + # --- Optional EPContext pre-compilation --- + if args.compile: + print(f"\n=== compiling QNN stages -> {args.output}/_compiled/ ===") + # Loading with compile=True triggers _prepare_compiled_bundle() and + # caches the EPContext ONNX files; we don't need to run generation. + with GenaiSession(args.output, ep="mixed", compile=True): + pass + print(f" compiled bundle -> {args.output / '_compiled'}") + + elapsed = time.monotonic() - t0 + print(f"\n=== export complete in {elapsed:.1f}s ===") + return 0 + + +# --------------------------------------------------------------------------- +# infer sub-command +# --------------------------------------------------------------------------- + + +def _add_infer_parser(sub: argparse._SubParsersAction) -> None: # type: ignore[type-arg] + p = sub.add_parser( + "infer", + help="Run streamed inference on a Qwen3 genai bundle.", + formatter_class=argparse.RawDescriptionHelpFormatter, + description=( + "Loads a genai bundle produced by 'qwen3.py export' and generates text " + "using onnxruntime-genai with the configured execution providers." + ), + ) + p.add_argument( + "--prompt", + default="Give me a short introduction to large language models.", + help="Input prompt (default: %(default)s).", + ) + p.add_argument( + "--bundle", + type=Path, + default=_DEFAULT_BUNDLE, + metavar="DIR", + help=f"Path to the genai bundle directory. Default: {_DEFAULT_BUNDLE}.", + ) + p.add_argument( + "--ep", + choices=_SUPPORTED_EPS, + default="mixed", + help=( + "Execution provider. 'mixed' uses genai_config.json as-is " + "(QNN for transformer, CPU for embeddings/lm_head). Default: mixed." + ), + ) + p.add_argument( + "--max-new", + type=int, + default=128, + help="Maximum number of new tokens to generate. Default: 128.", + ) + p.add_argument( + "--chat", + action="store_true", + help="Wrap --prompt in the ChatML template (<|im_start|>user/assistant).", + ) + p.add_argument( + "--compile", + action="store_true", + help=( + "Pre-compile QNN pipeline stages to EPContext ONNX before loading. " + "Compiled artifacts are cached in bundle/_compiled/ and reused on " + "subsequent runs." + ), + ) + p.add_argument( + "--verbose", + action="store_true", + help="Enable onnxruntime-genai native model I/O logging.", + ) + + +def _cmd_infer(args: argparse.Namespace) -> int: + """Load the bundle and run generation.""" + text = GenaiSession.apply_chatml_template(args.prompt) if args.chat else args.prompt + gen_cfg = GenerationConfig(max_new_tokens=args.max_new, do_sample=False) + + try: + session = GenaiSession(args.bundle, ep=args.ep, verbose=args.verbose, compile=args.compile) + except FileNotFoundError as exc: + print(f"ERROR: {exc}", file=sys.stderr) + return 1 + + print(f"[load] ep={args.ep} bundle={args.bundle}") + with session: + print(f"[ctx] context_length={session.context_length}") + print("[gen] ", end="", flush=True) + t0 = time.monotonic() + n = 0 + for token_str in session.generate_streaming(text, gen_cfg): + print(token_str, end="", flush=True) + n += 1 + + dt = time.monotonic() - t0 + print(f"\n\n[done] {n} tokens in {dt:.1f}s ({n / max(dt, 1e-9):.1f} tok/s)") + return 0 + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(argv: list[str] | None = None) -> int: + """Parse sub-command and dispatch to the appropriate handler.""" + p = argparse.ArgumentParser( + prog="qwen3", + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + sub = p.add_subparsers(dest="command", metavar="") + sub.required = True + + _add_export_parser(sub) + _add_infer_parser(sub) + + args = p.parse_args(argv) + if args.command == "export": + return _cmd_export(args) + return _cmd_infer(args) + + +if __name__ == "__main__": + sys.exit(main()) From b2d82701137fabd6f0516fa9cc8b09d9cdfb76f8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 3 Jul 2026 09:01:05 +0800 Subject: [PATCH 19/29] refactor(genai): make bundle machinery EP-agnostic; move QNN into qwen3 Keep utils/genai.py execution-provider-agnostic: build_decoder_pipeline_stages and write_genai_bundle now take opaque context/iterator session_options supplied by the caller instead of an ep/soc_model pair, and qnn_stage_session_options is removed. The QNN HTP session_options move into the Qwen3 module (models/hf/qwen3/genai.py), which wraps the generic builders so the emitted genai_config.json stays byte-identical to before. Remove session/genai_session.py and its test (covered by a separate PR); session/__init__.py no longer exports the genai session symbols. scripts/qwen3.py becomes export-only (drop the infer subcommand, --compile, and the GenaiSession import). --- scripts/qwen3.py | 148 +---- src/winml/modelkit/models/hf/qwen3/genai.py | 190 +++++- src/winml/modelkit/session/__init__.py | 12 - src/winml/modelkit/session/genai_session.py | 699 -------------------- src/winml/modelkit/utils/genai.py | 102 +-- tests/unit/session/test_genai_session.py | 380 ----------- 6 files changed, 215 insertions(+), 1316 deletions(-) delete mode 100644 src/winml/modelkit/session/genai_session.py delete mode 100644 tests/unit/session/test_genai_session.py diff --git a/scripts/qwen3.py b/scripts/qwen3.py index 3bebeff7b..57b6f486b 100644 --- a/scripts/qwen3.py +++ b/scripts/qwen3.py @@ -2,40 +2,27 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -r"""Unified Qwen3 genai pipeline: full model export and inference. +r"""Qwen3 genai bundle export. -Sub-commands ------------- -export - Build (or reuse) all four components of the Qwen3 genai bundle and - assemble them into an onnxruntime-genai directory: +Builds (or reuses) all four components of the Qwen3 genai bundle and assembles +them into an onnxruntime-genai directory: - - ``ctx.onnx`` — transformer prefill graph (QNN-quantized) - - ``iter.onnx`` — transformer decode graph (QNN-quantized) - - ``embeddings.onnx`` — token embedding table (fp32) - - ``lm_head.onnx`` — vocab projection (w4a32 MatMulNBits) - - ``genai_config.json`` + HF tokenizer files + - ``ctx.onnx`` — transformer prefill graph (QNN-quantized) + - ``iter.onnx`` — transformer decode graph (QNN-quantized) + - ``embeddings.onnx`` — token embedding table (fp32) + - ``lm_head.onnx`` — vocab projection (w4a32 MatMulNBits) + - ``genai_config.json`` + HF tokenizer files -infer - Load a bundle produced by ``export`` and run streamed text generation - using :class:`~winml.modelkit.session.GenaiSession`. +Inference over the assembled bundle is provided separately (see the genai +inference session), so this script only covers bundle generation. Usage:: # Full export to out/bundle (default context_length=2048): uv run python scripts/qwen3.py export --device npu --output out/bundle - # Export with EPContext pre-compilation: - uv run python scripts/qwen3.py export --device npu --output out/bundle --compile - # Force rebuild from scratch: uv run python scripts/qwen3.py export --device npu --output out/bundle --force-rebuild - - # Inference with the compiled bundle: - uv run python scripts/qwen3.py infer --bundle out/bundle --prompt "Hello" --chat - - # Pre-compile QNN stages on first inference run (reuses cache on subsequent runs): - uv run python scripts/qwen3.py infer --bundle out/bundle --prompt "Hi" --compile """ from __future__ import annotations @@ -52,7 +39,6 @@ from winml.modelkit.models.hf.qwen3.qwen_transformer_only import ( WinMLQwen3TransformerOnlyModel, ) -from winml.modelkit.session import GenaiSession, GenerationConfig _DEVICE_TO_EP = { @@ -70,8 +56,6 @@ _REPO_ROOT = Path(__file__).resolve().parent.parent _DEFAULT_BUNDLE = _REPO_ROOT / "out" / "bundle" -_SUPPORTED_EPS = ["cpu", "mixed", "qnn", "dml"] - # --------------------------------------------------------------------------- # Helpers shared between sub-commands @@ -158,15 +142,6 @@ def _add_export_parser(sub: argparse._SubParsersAction) -> None: # type: ignore help="Override path to a pre-built lm_head ONNX (skips auto-build).", ) p.add_argument("--force-rebuild", action="store_true", help="Rebuild even if cached.") - p.add_argument( - "--compile", - action="store_true", - help=( - "Pre-compile QNN transformer stages to EPContext ONNX after export. " - "Compiled artifacts are cached in output/_compiled/ and reused on " - "subsequent runs." - ), - ) def _cmd_export(args: argparse.Namespace) -> int: @@ -248,109 +223,11 @@ def _cmd_export(args: argparse.Namespace) -> int: ) print(f" genai_config.json -> {config_path}") - # --- Optional EPContext pre-compilation --- - if args.compile: - print(f"\n=== compiling QNN stages -> {args.output}/_compiled/ ===") - # Loading with compile=True triggers _prepare_compiled_bundle() and - # caches the EPContext ONNX files; we don't need to run generation. - with GenaiSession(args.output, ep="mixed", compile=True): - pass - print(f" compiled bundle -> {args.output / '_compiled'}") - elapsed = time.monotonic() - t0 print(f"\n=== export complete in {elapsed:.1f}s ===") return 0 -# --------------------------------------------------------------------------- -# infer sub-command -# --------------------------------------------------------------------------- - - -def _add_infer_parser(sub: argparse._SubParsersAction) -> None: # type: ignore[type-arg] - p = sub.add_parser( - "infer", - help="Run streamed inference on a Qwen3 genai bundle.", - formatter_class=argparse.RawDescriptionHelpFormatter, - description=( - "Loads a genai bundle produced by 'qwen3.py export' and generates text " - "using onnxruntime-genai with the configured execution providers." - ), - ) - p.add_argument( - "--prompt", - default="Give me a short introduction to large language models.", - help="Input prompt (default: %(default)s).", - ) - p.add_argument( - "--bundle", - type=Path, - default=_DEFAULT_BUNDLE, - metavar="DIR", - help=f"Path to the genai bundle directory. Default: {_DEFAULT_BUNDLE}.", - ) - p.add_argument( - "--ep", - choices=_SUPPORTED_EPS, - default="mixed", - help=( - "Execution provider. 'mixed' uses genai_config.json as-is " - "(QNN for transformer, CPU for embeddings/lm_head). Default: mixed." - ), - ) - p.add_argument( - "--max-new", - type=int, - default=128, - help="Maximum number of new tokens to generate. Default: 128.", - ) - p.add_argument( - "--chat", - action="store_true", - help="Wrap --prompt in the ChatML template (<|im_start|>user/assistant).", - ) - p.add_argument( - "--compile", - action="store_true", - help=( - "Pre-compile QNN pipeline stages to EPContext ONNX before loading. " - "Compiled artifacts are cached in bundle/_compiled/ and reused on " - "subsequent runs." - ), - ) - p.add_argument( - "--verbose", - action="store_true", - help="Enable onnxruntime-genai native model I/O logging.", - ) - - -def _cmd_infer(args: argparse.Namespace) -> int: - """Load the bundle and run generation.""" - text = GenaiSession.apply_chatml_template(args.prompt) if args.chat else args.prompt - gen_cfg = GenerationConfig(max_new_tokens=args.max_new, do_sample=False) - - try: - session = GenaiSession(args.bundle, ep=args.ep, verbose=args.verbose, compile=args.compile) - except FileNotFoundError as exc: - print(f"ERROR: {exc}", file=sys.stderr) - return 1 - - print(f"[load] ep={args.ep} bundle={args.bundle}") - with session: - print(f"[ctx] context_length={session.context_length}") - print("[gen] ", end="", flush=True) - t0 = time.monotonic() - n = 0 - for token_str in session.generate_streaming(text, gen_cfg): - print(token_str, end="", flush=True) - n += 1 - - dt = time.monotonic() - t0 - print(f"\n\n[done] {n} tokens in {dt:.1f}s ({n / max(dt, 1e-9):.1f} tok/s)") - return 0 - - # --------------------------------------------------------------------------- # Entry point # --------------------------------------------------------------------------- @@ -367,12 +244,9 @@ def main(argv: list[str] | None = None) -> int: sub.required = True _add_export_parser(sub) - _add_infer_parser(sub) args = p.parse_args(argv) - if args.command == "export": - return _cmd_export(args) - return _cmd_infer(args) + return _cmd_export(args) if __name__ == "__main__": diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py index e6d583e4c..315f95fe6 100644 --- a/src/winml/modelkit/models/hf/qwen3/genai.py +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -2,19 +2,24 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Qwen3 genai bundle support — thin shim over :mod:`winml.modelkit.utils.genai`. +"""Qwen3 genai bundle support built on :mod:`winml.modelkit.utils.genai`. -All generic logic (``PipelineStage``, ``DecoderIOMapping``, ``build_genai_config``, -``build_decoder_pipeline_stages``, ``write_genai_bundle``) lives in -:mod:`winml.modelkit.utils.genai` so it can be reused by other model families. +The generic, execution-provider-agnostic machinery (``PipelineStage``, +``DecoderIOMapping``, ``build_genai_config``, ``build_decoder_pipeline_stages``, +``write_genai_bundle``) lives in :mod:`winml.modelkit.utils.genai` so it can be +reused by other model families. -This module re-exports that API unchanged and adds -``build_qwen3_transformer_only_stages`` as a backward-compatible alias for -``build_decoder_pipeline_stages``. New code should prefer the generic names. +This module adds the **Qwen3-specific** layer on top: the Qwen3 transformer +stages target the QNN HTP (NPU) backend, so this is where the QNN +``session_options`` are constructed. Keeping the EP-specific logic here lets the +generic utilities stay universal while the Qwen3 bundle keeps emitting the exact +same ``genai_config.json`` as before. """ from __future__ import annotations +from typing import TYPE_CHECKING + from ....utils.genai import ( DEFAULT_CONTEXT_FILENAME, DEFAULT_EMBEDDINGS_FILENAME, @@ -24,18 +29,173 @@ PipelineStage, build_decoder_pipeline_stages, build_genai_config, - qnn_stage_session_options, - write_genai_bundle, ) +from ....utils.genai import ( + write_genai_bundle as _write_genai_bundle, +) + + +if TYPE_CHECKING: + from pathlib import Path + + +# --------------------------------------------------------------------------- +# Qwen3-specific QNN execution-provider routing +# --------------------------------------------------------------------------- + + +def qnn_stage_session_options(log_id: str, soc_model: str = "60") -> dict: + """Return the ``session_options`` block that routes a stage to QNN HTP. + + Args: + log_id: ORT log identifier (shown in ORT logs), e.g. + ``"onnxruntime-genai.context"``. + soc_model: Snapdragon SoC model number passed to the QNN HTP backend. + ``"60"`` targets Snapdragon 8 Gen 3 (X Elite). Change for other + SoCs (e.g. ``"55"`` for 8 Gen 2, ``"73"`` for 8 Elite). + + Returns: + Dict suitable for the ``session_options`` key of a pipeline stage in + ``genai_config.json``. + """ + return { + "log_id": log_id, + "provider_options": [ + { + "qnn": { + "backend_path": "QnnHtp.dll", + "htp_performance_mode": "burst", + "htp_graph_finalization_optimization_mode": "3", + "soc_model": soc_model, + } + } + ], + "intra_op_num_threads": 2, + "inter_op_num_threads": 1, + } + + +def _stage_session_options(ep: str, soc_model: str) -> tuple[dict | None, dict | None]: + """Return ``(context, iterator)`` session_options for the given EP. + + ``ep="qnn"`` routes the transformer stages to the QNN HTP (NPU) backend; any + other value (e.g. ``"cpu"``) leaves them on the default CPU provider. + """ + if ep == "qnn": + return ( + qnn_stage_session_options("onnxruntime-genai.context", soc_model=soc_model), + qnn_stage_session_options("onnxruntime-genai.iterator", soc_model=soc_model), + ) + return None, None + + +# --------------------------------------------------------------------------- +# Qwen3-specific stage factory + bundle assembler +# --------------------------------------------------------------------------- + + +def build_qwen3_transformer_only_stages( + context_onnx: str | Path, + iterator_onnx: str | Path, + num_layers: int, + *, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", +) -> tuple[list[PipelineStage], DecoderIOMapping]: + """Build the Qwen3 4-stage pipeline, routing ctx/iter to QNN when ``ep="qnn"``. + + Qwen3-specific wrapper over + :func:`winml.modelkit.utils.genai.build_decoder_pipeline_stages` that injects + the QNN ``session_options`` for the transformer stages. Tensor names are + still discovered by introspecting the ONNX graphs, so nothing is hardcoded. + + Args: + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. + num_layers: Number of transformer layers (``hf_config.num_hidden_layers``). + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. + ep: ``"qnn"`` injects QNN HTP ``session_options`` into the ``context`` + and ``iterator`` stages so they run on the NPU while ``embeddings`` + and ``lm_head`` stay on CPU. ``"cpu"`` (default) omits them. + soc_model: Snapdragon SoC model number forwarded to the QNN backend when + ``ep="qnn"``. Default ``"60"`` targets Snapdragon 8 Gen 3. + + Returns: + ``(stages, decoder_io)`` — see + :func:`~winml.modelkit.utils.genai.build_decoder_pipeline_stages`. + """ + ctx_opts, iter_opts = _stage_session_options(ep, soc_model) + return build_decoder_pipeline_stages( + context_onnx, + iterator_onnx, + num_layers, + context_filename=context_filename, + iterator_filename=iterator_filename, + embeddings_filename=embeddings_filename, + lm_head_filename=lm_head_filename, + context_session_options=ctx_opts, + iterator_session_options=iter_opts, + ) + + +def write_genai_bundle( + output_dir: str | Path, + *, + context_onnx: str | Path, + iterator_onnx: str | Path, + model_id: str, + max_cache_len: int, + prefill_seq_len: int, + embeddings_src: str | Path | None = None, + lm_head_src: str | Path | None = None, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", +) -> Path: + """Assemble a Qwen3 genai bundle, routing ctx/iter to QNN when ``ep="qnn"``. + + Qwen3-specific wrapper over + :func:`winml.modelkit.utils.genai.write_genai_bundle` that supplies the QNN + ``session_options`` for the transformer stages. See the generic function for + the description of every other argument. + Args: + ep: ``"qnn"`` routes the transformer (context/iterator) stages to the QNN + HTP (NPU) backend; ``"cpu"`` (default) keeps every stage on CPU. + soc_model: Snapdragon SoC model passed to the QNN backend when + ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite. -# Backward-compatible alias: existing callers that import -# ``build_qwen3_transformer_only_stages`` continue to work unchanged. -build_qwen3_transformer_only_stages = build_decoder_pipeline_stages + Returns: + Path to the written ``genai_config.json``. + """ + ctx_opts, iter_opts = _stage_session_options(ep, soc_model) + return _write_genai_bundle( + output_dir, + context_onnx=context_onnx, + iterator_onnx=iterator_onnx, + model_id=model_id, + max_cache_len=max_cache_len, + prefill_seq_len=prefill_seq_len, + embeddings_src=embeddings_src, + lm_head_src=lm_head_src, + context_filename=context_filename, + iterator_filename=iterator_filename, + embeddings_filename=embeddings_filename, + lm_head_filename=lm_head_filename, + context_session_options=ctx_opts, + iterator_session_options=iter_opts, + ) -# Keep the private EP helper importable under its old name for any callers -# that referenced it before the rename. -_qnn_stage_session_options = qnn_stage_session_options __all__ = [ "DEFAULT_CONTEXT_FILENAME", diff --git a/src/winml/modelkit/session/__init__.py b/src/winml/modelkit/session/__init__.py index d11673961..5148da0b3 100644 --- a/src/winml/modelkit/session/__init__.py +++ b/src/winml/modelkit/session/__init__.py @@ -5,13 +5,6 @@ """WinMLSession - ONNX Runtime session manager with WinML EP integration.""" from .ep_registry import WinMLEPRegistry -from .genai_session import ( - GenaiLoadError, - GenaiNotInstalledError, - GenaiSession, - GenaiSessionError, - GenerationConfig, -) from .monitor.ep_monitor import EPMonitor, NullEPMonitor from .monitor.hw_monitor import HWMonitor from .monitor.openvino_monitor import OpenVinoMonitor @@ -24,11 +17,6 @@ __all__ = [ "EPMonitor", - "GenaiLoadError", - "GenaiNotInstalledError", - "GenaiSession", - "GenaiSessionError", - "GenerationConfig", "HWMonitor", "InferenceError", "NullEPMonitor", diff --git a/src/winml/modelkit/session/genai_session.py b/src/winml/modelkit/session/genai_session.py deleted file mode 100644 index d920d96f3..000000000 --- a/src/winml/modelkit/session/genai_session.py +++ /dev/null @@ -1,699 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""GenaiSession — onnxruntime-genai session for multi-model decoder pipelines. - -Manages ``og.Model`` + ``og.Generator`` lifecycle for autoregressive text -generation. Reuses :class:`WinMLEPRegistry` for EP discovery and registration -so EPs are downloaded / registered at most once per process. - -Unlike :class:`WinMLSession` (which wraps ``ort.InferenceSession`` for -single-shot inference), ``GenaiSession`` drives a streaming token-by-token -generation loop. The two classes are peers — neither inherits from the other. - -Bundle directory layout expected by ``onnxruntime-genai``:: - - / - genai_config.json ← required; controls pipeline & search - ctx.onnx ← prefill transformer graph - iter.onnx ← decode transformer graph - embeddings.onnx ← embedding lookup - lm_head.onnx ← logit projection - tokenizer.json ← HF tokenizer files - tokenizer_config.json - ... - -Usage:: - - # Context manager (recommended — auto-loads and unloads) - with GenaiSession("out/qwen3_bundle", ep="qnn") as session: - for token_str in session.generate_streaming("Hello, who are you?"): - print(token_str, end="", flush=True) - - # Manual lifecycle - session = GenaiSession("out/qwen3_bundle", ep="cpu") - session.load() - result = session.generate("What is a transformer?") - session.unload() - -Dependencies:: - - pip install onnxruntime-genai-winml - pip install "windowsml[with-ort]" # registers QNN EP; also provides ORT -""" - -from __future__ import annotations - -import json -import logging -import shutil -from dataclasses import dataclass -from pathlib import Path -from typing import TYPE_CHECKING - -from .ep_registry import WinMLEPRegistry - - -if TYPE_CHECKING: - from collections.abc import Iterator - - -logger = logging.getLogger(__name__) - - -# --------------------------------------------------------------------------- -# Module-level compilation worker (must be at module scope for multiprocessing -# spawn on Windows, which serialises the target via pickle). -# --------------------------------------------------------------------------- - - -def _qnn_compile_worker(src: str, dst: str, qnn_options: dict) -> None: - """Compile *src* ONNX to an EPContext ONNX at *dst* using QNN HTP. - - Executed in a subprocess by :meth:`GenaiSession._compile_stage`. - """ - import onnxruntime as ort - - from ..winml import add_ep_for_device - from .ep_registry import WinMLEPRegistry - - registry = WinMLEPRegistry.get_instance() - registry.register_execution_providers() - so = ort.SessionOptions() - so.add_session_config_entry("ep.context_enable", "1") - so.add_session_config_entry("ep.context_file_path", dst) - add_ep_for_device(so, "QNNExecutionProvider", ort.OrtHardwareDeviceType.NPU, qnn_options) - mc = ort.ModelCompiler(so, src, embed_compiled_data_into_model=False) - mc.compile_to_file(dst) - - -# --------------------------------------------------------------------------- -# Valid EP short names. -# "mixed" = use genai_config.json as-is (embeddings/lm_head on CPU, -# ctx/iter on the target accelerator). -# EP routing is driven entirely by per-stage session_options in the bundle's -# genai_config.json — GenaiSession never calls clear_providers/append_provider. -# --------------------------------------------------------------------------- -_VALID_EPS: frozenset[str] = frozenset({"cpu", "mixed", "qnn", "dml"}) -# EPs that require WinML EP discovery + registration before og.Model() init. -_NEEDS_WINML_EPS: frozenset[str] = frozenset({"mixed", "qnn", "dml"}) - - -# --------------------------------------------------------------------------- -# Data classes -# --------------------------------------------------------------------------- - - -@dataclass -class GenerationConfig: - """Search / sampling parameters for a single generation call. - - All parameters are forwarded to ``og.GeneratorParams.set_search_options``. - ``max_length`` is **not** configurable here — it is set to the bundle's - ``context_length`` (read from ``genai_config.json``) because the static KV - cache size is baked into the ONNX graphs at export time. - - Attributes: - max_new_tokens: Soft cap on the number of new tokens to generate. - Generation stops when the model signals EOS, when the KV buffer is - exhausted (``context_length``), or when this limit is reached, - whichever comes first. - do_sample: Enable sampling (``True``) vs greedy (``False``). - temperature: Sampling temperature. Ignored when ``do_sample=False``. - top_p: Nucleus sampling probability mass. Ignored when - ``do_sample=False``. - top_k: Top-K sampling. ``0`` disables the filter. Ignored when - ``do_sample=False``. - repetition_penalty: Multiplicative penalty for repeated tokens - (``1.0`` = no penalty). - """ - - max_new_tokens: int = 128 - do_sample: bool = False - temperature: float = 1.0 - top_p: float = 1.0 - top_k: int = 0 - repetition_penalty: float = 1.0 - - -# --------------------------------------------------------------------------- -# Exceptions -# --------------------------------------------------------------------------- - - -class GenaiSessionError(Exception): - """Base exception for GenaiSession.""" - - -class GenaiNotInstalledError(GenaiSessionError): - """``onnxruntime-genai`` (or ``onnxruntime-genai-winml``) is not installed.""" - - -class GenaiLoadError(GenaiSessionError): - """The bundle could not be loaded (bad config, EP unavailable, etc.).""" - - -# --------------------------------------------------------------------------- -# Session -# --------------------------------------------------------------------------- - - -class GenaiSession: - """ORT GenAI session for multi-model decoder-pipeline inference. - - Wraps ``og.Model`` + ``og.Generator`` to provide a clean generation API. - - The session is **stateless across calls**: each :meth:`generate_streaming` - call creates a fresh ``og.Generator`` so KV state does not persist between - prompts. Thread-safety within a single session is not guaranteed. - - Args: - bundle_dir: Path to the genai bundle directory. Must contain - ``genai_config.json`` and the ONNX files it references. - ep: Execution provider short name (``"cpu"``, ``"qnn"``, ``"dml"``). - Non-CPU EPs trigger WinML EP discovery and registration. - context_length: Override for the static KV cache length. When - ``None`` (default), read from ``genai_config.json``. - Must match the ``--max-cache-len`` used during the winml-cli build. - verbose: Enable ``onnxruntime-genai`` native model I/O logging. - compile: Pre-compile QNN pipeline stages to EPContext ONNX on first - run (inside ``bundle_dir/_compiled/``). Subsequent calls reuse - the cached EPContext files, eliminating per-run JIT overhead. - Only stages that can be compiled without hanging are attempted; - stages that fail compilation fall back to the original ONNX. - Has no effect when ``ep="cpu"``. - """ - - # Sub-directory within the bundle that holds pre-compiled EPContext ONNX files. - _COMPILED_SUBDIR: str = "_compiled" - - def __init__( - self, - bundle_dir: str | Path, - ep: str = "cpu", - *, - context_length: int | None = None, - verbose: bool = False, - compile: bool = False, - ) -> None: - self._bundle_dir = Path(bundle_dir) - self._ep = ep.lower() - self._context_length_override = context_length - self._verbose = verbose - self._compile = compile - - # Resolved at load() time. - self._context_length: int | None = None - - # og.* handles — None until load() is called. - self._model: object | None = None - self._tokenizer: object | None = None - - if not self._bundle_dir.exists(): - raise FileNotFoundError(f"Bundle directory not found: {self._bundle_dir}") - config_path = self._bundle_dir / "genai_config.json" - if not config_path.exists(): - raise FileNotFoundError( - f"genai_config.json not found in {self._bundle_dir}. " - "Run export_qwen3_transformer_only.py --genai-bundle first." - ) - if self._ep not in _VALID_EPS: - raise ValueError(f"Unknown EP {ep!r}. Supported: {sorted(_VALID_EPS)}") - - logger.info("GenaiSession initialized: bundle=%s ep=%s", self._bundle_dir, self._ep) - - # ------------------------------------------------------------------ - # Lifecycle - # ------------------------------------------------------------------ - - def load(self) -> None: - """Load ``og.Model`` and tokenizer from the bundle directory. - - Idempotent: calling ``load()`` on an already-loaded session is a no-op. - - Raises: - GenaiNotInstalledError: ``onnxruntime_genai`` is not installed. - GenaiLoadError: The model could not be loaded (EP error, bad config, - missing ONNX files, …). - """ - if self._model is not None: - return - - og = self._import_og() - - # Register WinML EPs to ORT GenAI when the bundle may use a hardware EP. - if self._ep in _NEEDS_WINML_EPS: - self._register_eps(og) - - if self._verbose: - og.set_log_options(enabled=True, model_input_values=True, model_output_shapes=True) - - # Determine which bundle directory og.Config should load from. - load_dir = self._bundle_dir - if self._compile and self._ep in _NEEDS_WINML_EPS: - load_dir = self._prepare_compiled_bundle() - - try: - config = og.Config(str(load_dir)) - # EP routing is driven entirely by genai_config.json (per-stage - # session_options). Do NOT call clear_providers/append_provider — - # those only touch the top-level provider and cannot override - # per-stage session_options already embedded in the pipeline config. - self._model = og.Model(config) - self._tokenizer = og.Tokenizer(self._model) - except Exception as exc: - self._model = None - self._tokenizer = None - raise GenaiLoadError(f"Failed to load genai bundle from {load_dir}: {exc}") from exc - - self._context_length = self._context_length_override or self._read_context_length() - logger.info( - "GenaiSession loaded: ep=%s context_length=%d", - self._ep, - self._context_length, - ) - - def unload(self) -> None: - """Release ``og.Model`` and tokenizer handles. - - Safe to call on an unloaded session. - """ - self._model = None - self._tokenizer = None - self._context_length = None - logger.info("GenaiSession unloaded: bundle=%s", self._bundle_dir) - - def __enter__(self) -> GenaiSession: - self.load() - return self - - def __exit__(self, *_: object) -> None: - self.unload() - - # ------------------------------------------------------------------ - # Generation - # ------------------------------------------------------------------ - - def generate( - self, - prompt: str | list[int], - config: GenerationConfig | None = None, - ) -> str: - """Generate text and return the full response as a single string. - - This is a convenience wrapper around :meth:`generate_streaming`. - - Args: - prompt: Input text (auto-encoded) or a pre-encoded token-ID list. - config: Generation parameters. Uses :class:`GenerationConfig` - defaults when ``None``. - - Returns: - The generated text (not including the prompt). - """ - return "".join(self.generate_streaming(prompt, config)) - - def generate_streaming( - self, - prompt: str | list[int], - config: GenerationConfig | None = None, - ) -> Iterator[str]: - """Generate text token-by-token, yielding decoded token strings. - - The method auto-loads the session on the first call (lazy-load - equivalent of :meth:`load`). - - Each yield is the decoded string for a single new token. Callers - typically ``print(token, end="", flush=True)`` to stream output. - - Args: - prompt: Input text (auto-encoded via the bundle tokenizer) or a - pre-encoded token-ID list. Pass a pre-formatted string when - chat templates or special tokens are needed — the session is - not aware of any particular model's template format. - config: Generation parameters. Uses :class:`GenerationConfig` - defaults when ``None``. - - Yields: - Decoded string for each newly generated token. - """ - self._ensure_loaded() - og = self._import_og() - cfg = config or GenerationConfig() - - tokens = ( - self._tokenizer.encode(prompt) # type: ignore[union-attr] - if isinstance(prompt, str) - else prompt - ) - - params = og.GeneratorParams(self._model) - params.set_search_options( - max_length=self._context_length, - do_sample=cfg.do_sample, - temperature=cfg.temperature, - top_p=cfg.top_p, - top_k=cfg.top_k, - repetition_penalty=cfg.repetition_penalty, - ) - - generator = og.Generator(self._model, params) - generator.append_tokens(tokens) - - stream = self._tokenizer.create_stream() # type: ignore[union-attr] - n = 0 - try: - while not generator.is_done(): - generator.generate_next_token() - new_token = generator.get_next_tokens()[0] - yield stream.decode(new_token) - n += 1 - if n >= cfg.max_new_tokens: - break - finally: - # Explicit deletion releases the KV cache buffer held by the generator. - # This runs whether the caller exhausts the iterator normally, breaks - # out early, or the generator is garbage-collected — preventing the NPU - # memory from being held until a later GC cycle. - del generator - - # ------------------------------------------------------------------ - # Chat-template helpers - # ------------------------------------------------------------------ - - @staticmethod - def apply_chatml_template( - prompt: str, - system: str | None = None, - ) -> str: - r"""Wrap *prompt* in the ChatML format used by Qwen2/3, Yi, Mistral, etc. - - Produces:: - - <|im_start|>system - {system}<|im_end|> - <|im_start|>user - {prompt}<|im_end|> - <|im_start|>assistant - - The trailing ``<|im_start|>assistant\\n`` primes the model to respond - as the assistant role with no leading newline in its output. - - Args: - prompt: The user turn text. - system: Optional system prompt. When ``None`` no system turn is - prepended. - - Returns: - Formatted string ready to pass to :meth:`generate` / - :meth:`generate_streaming`. - """ - parts: list[str] = [] - if system is not None: - parts.append(f"<|im_start|>system\n{system}<|im_end|>\n") - parts.append(f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n") - return "".join(parts) - - # ------------------------------------------------------------------ - # Tokenizer helpers - # ------------------------------------------------------------------ - - def encode(self, text: str) -> list[int]: - """Encode *text* to a list of token IDs using the bundle tokenizer.""" - self._ensure_loaded() - return self._tokenizer.encode(text).tolist() # type: ignore[union-attr] - - def decode(self, tokens: list[int]) -> str: - """Decode a list of token IDs to a string.""" - self._ensure_loaded() - return self._tokenizer.decode(tokens) # type: ignore[union-attr] - - # ------------------------------------------------------------------ - # Properties - # ------------------------------------------------------------------ - - @property - def is_loaded(self) -> bool: - """``True`` if the model is loaded and ready for generation.""" - return self._model is not None - - @property - def bundle_dir(self) -> Path: - """Path to the genai bundle directory.""" - return self._bundle_dir - - @property - def ep(self) -> str: - """Execution provider short name (as passed to ``__init__``).""" - return self._ep - - @property - def context_length(self) -> int | None: - """Static KV cache length, populated after :meth:`load`.""" - return self._context_length - - # ------------------------------------------------------------------ - # Private helpers - # ------------------------------------------------------------------ - - def _ensure_loaded(self) -> None: - if self._model is None: - self.load() - - def _prepare_compiled_bundle(self) -> Path: - """Create (or reuse) a *compiled* bundle directory. - - Reads ``genai_config.json``, finds QNN-accelerated stages (those with - ``QNNExecutionProvider`` in their ``session_options``), and tries to - compile their ONNX to EPContext format using ``ort.ModelCompiler``. - - The compiled bundle is stored under ``bundle_dir/_compiled/``. On - every call the helper checks whether the cached EPContext file is - newer than the source ONNX; if so, it skips recompilation. - - Returns: - Path to the compiled bundle directory (may equal ``bundle_dir`` - if no compilable stages were found, or if all compilations failed). - """ - compiled_dir = self._bundle_dir / self._COMPILED_SUBDIR - config_src = self._bundle_dir / "genai_config.json" - cfg = json.loads(config_src.read_text(encoding="utf-8")) - - # Collect pipeline stages that use a QNN EP ("qnn" key in provider_options). - # genai_config pipeline entries: [{"context": {...}}, {"iterator": {...}}, ...] - # provider_options format: [{"qnn": {...}}] - pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", []) - # [(stage_key, onnx_filename, qnn_opts), ...] - qnn_stages: list[tuple[str, str, dict]] = [] - for stage_entry in pipeline_list: - if not isinstance(stage_entry, dict): - continue - for stage_key, stage_cfg in stage_entry.items(): - if not isinstance(stage_cfg, dict): - continue - so = stage_cfg.get("session_options", {}) - providers = so.get("provider_options", []) - for p in providers: - if isinstance(p, dict) and "qnn" in p: - onnx_filename = stage_cfg.get("filename", f"{stage_key}.onnx") - qnn_stages.append((stage_key, onnx_filename, dict(p["qnn"]))) - break - - if not qnn_stages: - logger.info("No QNN stages found in genai_config.json; skipping compilation") - return self._bundle_dir - - compiled_dir.mkdir(exist_ok=True) - modified_cfg = json.loads(config_src.read_text(encoding="utf-8")) - any_compiled = False - - for stage_key, onnx_filename, qnn_opts in qnn_stages: - src_onnx = self._bundle_dir / onnx_filename - ctx_onnx = compiled_dir / f"{stage_key}_ctx.onnx" - - # Skip recompilation if cache is up-to-date. - if ctx_onnx.exists() and ctx_onnx.stat().st_mtime >= src_onnx.stat().st_mtime: - logger.info("Stage %r: reusing cached EPContext %s", stage_key, ctx_onnx.name) - # Use just the filename — genai_config.json lives in compiled_dir, - # so ort-genai resolves filenames relative to compiled_dir. - self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) - any_compiled = True - continue - - # Attempt compilation. - success = self._compile_stage(src_onnx, ctx_onnx, stage_key, qnn_opts) - if success: - self._patch_stage_filename(modified_cfg, stage_key, ctx_onnx.name) - any_compiled = True - else: - logger.warning( - "Stage %r: compilation failed; using original ONNX (JIT fallback)", stage_key - ) - # Patch to the absolute source path so ort-genai can find the - # file when loading from compiled_dir (where it doesn't exist). - self._patch_stage_filename(modified_cfg, stage_key, str(src_onnx.resolve())) - - if not any_compiled: - return self._bundle_dir - - # Write the modified genai_config into the compiled sub-directory. - # ONNX filenames are relative to compiled_dir; ort-genai resolves them - # from the directory it loads og.Config from. - compiled_config = compiled_dir / "genai_config.json" - compiled_config.write_text( - json.dumps(modified_cfg, indent=2, ensure_ascii=False), encoding="utf-8" - ) - # Only skip the QNN-stage ONNX files (compiled → new path, or failed → - # absolute path patch). Non-QNN ONNX files (embeddings, lm_head) must - # be symlinked into compiled_dir so ort-genai can find them by filename. - compiled_onnx_names = {onnx_filename for _, onnx_filename, _ in qnn_stages} - self._mirror_non_onnx_files(compiled_dir, skip_filenames=compiled_onnx_names) - - logger.info("Compiled bundle prepared at %s", compiled_dir) - return compiled_dir - - @staticmethod - def _patch_stage_filename(cfg: dict, stage_key: str, abs_path: str) -> None: - """Rewrite a pipeline stage's ``filename`` to an absolute path.""" - pipeline_list: list = cfg.get("model", {}).get("decoder", {}).get("pipeline", []) - for stage_entry in pipeline_list: - if isinstance(stage_entry, dict) and stage_key in stage_entry: - stage_cfg = stage_entry[stage_key] - if isinstance(stage_cfg, dict): - stage_cfg["filename"] = abs_path - return - - def _compile_stage( - self, - src_onnx: Path, - ctx_out: Path, - stage_key: str, - qnn_opts: dict | None = None, - ) -> bool: - """Compile *src_onnx* to EPContext format via ``ort.ModelCompiler``. - - Runs in a subprocess so that a ModelCompiler failure does not block - the caller. The QNN options from ``genai_config.json`` are forwarded - unchanged to the compilation session, so each stage is compiled at - exactly the optimization level configured in the bundle. - - Args: - src_onnx: Source ONNX file path. - ctx_out: Destination EPContext ONNX path. - stage_key: Human-readable label for logging. - qnn_opts: QNN provider options from genai_config (e.g. backend_path, - htp_performance_mode, htp_graph_finalization_optimization_mode, - soc_model). - - Returns: - ``True`` if compilation succeeded; ``False`` on timeout or error. - """ - import multiprocessing - - compile_qnn_opts = dict(qnn_opts or {}) - compile_timeout_s = 300 # 5 minutes; ctx compiles in ~73s, iter in ~67s - - logger.info( - "Compiling stage %r: %s → %s (qnn_opts=%s)", - stage_key, - src_onnx.name, - ctx_out.name, - compile_qnn_opts, - ) - - ctx = multiprocessing.get_context("spawn") - proc = ctx.Process( - target=_qnn_compile_worker, args=(str(src_onnx), str(ctx_out), compile_qnn_opts) - ) - proc.start() - proc.join(timeout=compile_timeout_s) - - if proc.is_alive(): - logger.error( - "Stage %r compilation timed out after %ds — killing subprocess.", - stage_key, - compile_timeout_s, - ) - proc.kill() - proc.join() - ctx_out.unlink(missing_ok=True) - return False - - if proc.exitcode != 0: - logger.warning("Stage %r compilation failed (exit %d)", stage_key, proc.exitcode) - ctx_out.unlink(missing_ok=True) - return False - - logger.info("Stage %r compiled successfully → %s", stage_key, ctx_out) - return True - - def _mirror_non_onnx_files( - self, compiled_dir: Path, skip_filenames: set[str] | None = None - ) -> None: - """Create symlinks (or copies on Windows) for files not being compiled. - - Links files into *compiled_dir* so ``og.Config`` finds tokenizer files, - non-QNN ONNX models (embeddings, lm_head), etc. Only ONNX files listed - in *skip_filenames* (the QNN-compiled stages) and their external ``.data`` - siblings are skipped — everything else, including CPU-side ONNX files, is - linked. Existing files are left untouched. - """ - skip = set(skip_filenames or []) - # Also skip .data sidecars of the compiled-stage ONNX files. - skip_data = {f"{name}.data" for name in skip} | {f"{name}.data" for name in skip} - for src in self._bundle_dir.iterdir(): - if src.name == self._COMPILED_SUBDIR: - continue - if src.name in skip or src.name in skip_data: - continue - dst = compiled_dir / src.name - if dst.exists(): - continue - if src.is_file(): - try: - dst.symlink_to(src.resolve()) - except (OSError, NotImplementedError): - shutil.copy2(src, dst) - - @staticmethod - def _import_og() -> object: - """Import and return the ``onnxruntime_genai`` module. - - Raises: - GenaiNotInstalledError: Package not found. - """ - try: - import onnxruntime_genai as og - - return og - except ImportError as exc: - raise GenaiNotInstalledError( - "onnxruntime_genai is not installed. " - "Install it with: pip install onnxruntime-genai-winml" - ) from exc - - def _register_eps(self, og: object) -> None: - """Register WinML EPs with ORT GenAI (idempotent, best-effort).""" - try: - registry = WinMLEPRegistry.get_instance() - if registry.winml_available: - result = registry.register_execution_providers(ort_genai=True) - registered = result.get("onnxruntime_genai", []) - logger.info("WinML EPs registered for ORT GenAI: %s", registered) - except Exception as exc: - logger.warning("WinML EP registration skipped: %s", exc) - - def _read_context_length(self) -> int: - """Read ``model.context_length`` from ``genai_config.json``.""" - cfg = json.loads((self._bundle_dir / "genai_config.json").read_text(encoding="utf-8")) - return int(cfg["model"]["context_length"]) - - -__all__ = [ - "GenaiLoadError", - "GenaiNotInstalledError", - "GenaiSession", - "GenaiSessionError", - "GenerationConfig", -] diff --git a/src/winml/modelkit/utils/genai.py b/src/winml/modelkit/utils/genai.py index 0bab4acd7..c8db1d21a 100644 --- a/src/winml/modelkit/utils/genai.py +++ b/src/winml/modelkit/utils/genai.py @@ -26,6 +26,11 @@ subsequent decode step. Both share the same KV cache buffer via genai's ``past_present_share_buffer`` mode. +Per-stage execution-provider routing (e.g. running the transformer stages on an +NPU) is expressed through the generic ``PipelineStage.session_options`` field and +is supplied by the caller — this module is itself execution-provider-agnostic and +hardcodes no EP-specific settings. + Public API:: from winml.modelkit.utils.genai import ( @@ -34,12 +39,11 @@ write_genai_bundle, DecoderIOMapping, PipelineStage, - qnn_stage_session_options, ) # Build stages by introspecting the ONNX I/O (no hardcoded tensor names) stages, decoder_io = build_decoder_pipeline_stages( - ctx_path, iter_path, num_layers=hf_config.num_hidden_layers, ep="qnn" + ctx_path, iter_path, num_layers=hf_config.num_hidden_layers ) cfg = build_genai_config( hf_config, max_cache_len=256, prefill_seq_len=64, @@ -56,7 +60,6 @@ prefill_seq_len=64, embeddings_src=emb_path, # None = skip (add later) lm_head_src=lmh_path, # None = skip (add later) - ep="qnn", ) """ @@ -109,7 +112,8 @@ class PipelineStage: outputs: list[str] is_lm_head: bool = False session_options: dict | None = None - """Per-stage ORT session options (e.g. provider_options for QNN). + """Per-stage ORT session options (e.g. execution-provider selection and + provider_options). When set, emitted verbatim as the ``session_options`` key in the ``genai_config.json`` pipeline stage. Leave ``None`` (default) for @@ -337,42 +341,6 @@ def _key(prefix: str) -> int: return sorted(patterns.keys(), key=_key) -# --------------------------------------------------------------------------- -# Per-EP stage session_options helpers -# --------------------------------------------------------------------------- - - -def qnn_stage_session_options(log_id: str, soc_model: str = "60") -> dict: - """Return the ``session_options`` block that routes a stage to QNN HTP. - - Args: - log_id: ORT log identifier (shown in ORT logs), e.g. - ``"onnxruntime-genai.context"``. - soc_model: Snapdragon SoC model number passed to the QNN HTP backend. - ``"60"`` targets Snapdragon 8 Gen 3 (X Elite). Change for other - SoCs (e.g. ``"55"`` for 8 Gen 2, ``"73"`` for 8 Elite). - - Returns: - Dict suitable for the ``session_options`` key of a pipeline stage in - ``genai_config.json``. - """ - return { - "log_id": log_id, - "provider_options": [ - { - "qnn": { - "backend_path": "QnnHtp.dll", - "htp_performance_mode": "burst", - "htp_graph_finalization_optimization_mode": "3", - "soc_model": soc_model, - } - } - ], - "intra_op_num_threads": 2, - "inter_op_num_threads": 1, - } - - # --------------------------------------------------------------------------- # Generic decoder-pipeline stage factory # --------------------------------------------------------------------------- @@ -387,8 +355,8 @@ def build_decoder_pipeline_stages( iterator_filename: str = DEFAULT_ITERATOR_FILENAME, embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, - ep: str = "cpu", - soc_model: str = "60", + context_session_options: dict | None = None, + iterator_session_options: dict | None = None, ) -> tuple[list[PipelineStage], DecoderIOMapping]: """Build pipeline stages by introspecting the built ONNX models. @@ -404,13 +372,13 @@ def build_decoder_pipeline_stages( iterator_filename: Bundle filename for the iterator model. embeddings_filename: Bundle filename for the embeddings model. lm_head_filename: Bundle filename for the lm_head model. - ep: Execution provider for the transformer stages. ``"qnn"`` injects - QNN HTP ``session_options`` into the ``context`` and ``iterator`` - stages so they run on the NPU while ``embeddings`` and ``lm_head`` - continue on CPU. ``"cpu"`` (default) omits ``session_options`` - from all stages. - soc_model: Snapdragon SoC model number forwarded to the QNN backend - when ``ep="qnn"``. Default ``"60"`` targets Snapdragon 8 Gen 3. + context_session_options: Optional ORT ``session_options`` dict attached + verbatim to the ``context`` stage (e.g. to route it to an + accelerator EP). ``None`` (default) runs the stage on CPU. This + function stays execution-provider-agnostic — the caller decides the + contents; no EP-specific values are constructed here. + iterator_session_options: Same as *context_session_options* but for the + ``iterator`` stage. Returns: ``(stages, decoder_io)`` — a 4-element :class:`PipelineStage` list and @@ -475,17 +443,6 @@ def _pick_kv( present_value_names=pres_val_fmt, ) - # Per-stage session_options: NPU stages get QNN config; CPU and others get None. - ctx_session_opts: dict | None = None - iter_session_opts: dict | None = None - if ep == "qnn": - ctx_session_opts = qnn_stage_session_options( - "onnxruntime-genai.context", soc_model=soc_model - ) - iter_session_opts = qnn_stage_session_options( - "onnxruntime-genai.iterator", soc_model=soc_model - ) - stages: list[PipelineStage] = [ PipelineStage( name="embeddings", @@ -502,7 +459,7 @@ def _pick_kv( run_on_token_gen=False, inputs=ctx_inputs, outputs=ctx_outputs, - session_options=ctx_session_opts, + session_options=context_session_options, ), PipelineStage( name="iterator", @@ -511,7 +468,7 @@ def _pick_kv( run_on_token_gen=True, inputs=iter_inputs, outputs=iter_outputs, - session_options=iter_session_opts, + session_options=iterator_session_options, ), PipelineStage( name="lm_head", @@ -574,8 +531,8 @@ def write_genai_bundle( iterator_filename: str = DEFAULT_ITERATOR_FILENAME, embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, - ep: str = "cpu", - soc_model: str = "60", + context_session_options: dict | None = None, + iterator_session_options: dict | None = None, ) -> Path: """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*. @@ -598,12 +555,12 @@ def write_genai_bundle( iterator_filename: Bundle filename for the iterator model. embeddings_filename: Bundle filename for the embeddings model. lm_head_filename: Bundle filename for the lm_head model. - ep: Execution provider for the transformer (context/iterator) stages. - ``"qnn"`` injects QNN HTP ``session_options`` so those stages run - on the NPU while embeddings and lm_head run on CPU. - ``"cpu"`` (default) omits ``session_options`` (all stages on CPU). - soc_model: Snapdragon SoC model passed to the QNN backend when - ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite. + context_session_options: Optional ORT ``session_options`` dict attached + verbatim to the ``context`` stage. ``None`` (default) runs it on + CPU. This assembler is execution-provider-agnostic; the caller + supplies any EP-specific options. + iterator_session_options: Same as *context_session_options* but for the + ``iterator`` stage. Returns: Path to the written ``genai_config.json``. @@ -667,8 +624,8 @@ def write_genai_bundle( iterator_filename=iterator_filename, embeddings_filename=embeddings_filename, lm_head_filename=lm_head_filename, - ep=ep, - soc_model=soc_model, + context_session_options=context_session_options, + iterator_session_options=iterator_session_options, ) # 5. Write genai_config.json. @@ -719,6 +676,5 @@ def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: "PipelineStage", "build_decoder_pipeline_stages", "build_genai_config", - "qnn_stage_session_options", "write_genai_bundle", ] diff --git a/tests/unit/session/test_genai_session.py b/tests/unit/session/test_genai_session.py deleted file mode 100644 index 5dc496b49..000000000 --- a/tests/unit/session/test_genai_session.py +++ /dev/null @@ -1,380 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Unit tests for GenaiSession. - -All tests that touch load() / generate*() mock onnxruntime_genai so no -real model files or GPU/NPU hardware is required. -""" - -from __future__ import annotations - -import json -import sys -from typing import TYPE_CHECKING -from unittest.mock import MagicMock, patch - -import pytest - -from winml.modelkit.session import ( - GenaiLoadError, - GenaiNotInstalledError, - GenaiSession, - GenaiSessionError, - GenerationConfig, -) - - -if TYPE_CHECKING: - from pathlib import Path - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def bundle_dir(tmp_path: Path) -> Path: - """Create a minimal genai bundle directory with genai_config.json.""" - cfg = { - "model": { - "type": "decoder-pipeline", - "context_length": 256, - "decoder": {}, - }, - "search": {"max_length": 256}, - } - (tmp_path / "genai_config.json").write_text(json.dumps(cfg), encoding="utf-8") - return tmp_path - - -@pytest.fixture -def mock_og() -> MagicMock: - """Return a fully mocked onnxruntime_genai module.""" - og = MagicMock(name="onnxruntime_genai") - og.Config.return_value = MagicMock() - og.Model.return_value = MagicMock() - og.Tokenizer.return_value = MagicMock() - og.GeneratorParams.return_value = MagicMock() - - # Generator that yields two tokens then is_done() - gen = MagicMock() - gen.is_done.side_effect = [False, False, True] - gen.get_next_tokens.side_effect = [ - MagicMock(__getitem__=lambda s, i: 10), - MagicMock(__getitem__=lambda s, i: 20), - ] - og.Generator.return_value = gen - - # TokenizerStream decodes tokens to text - stream = MagicMock() - stream.decode.side_effect = ["Hello", " world"] - og.Tokenizer.return_value.create_stream.return_value = stream - - return og - - -def _patch_og(mock: MagicMock): - """Context manager: inject mock_og as onnxruntime_genai in sys.modules.""" - return patch.dict(sys.modules, {"onnxruntime_genai": mock}) - - -# --------------------------------------------------------------------------- -# Tests: GenaiSession.__init__ -# --------------------------------------------------------------------------- - - -class TestGenaiSessionInit: - def test_missing_bundle_dir_raises(self, tmp_path: Path) -> None: - with pytest.raises(FileNotFoundError, match="Bundle directory not found"): - GenaiSession(tmp_path / "nonexistent") - - def test_missing_config_raises(self, tmp_path: Path) -> None: - # Dir exists but no genai_config.json - with pytest.raises(FileNotFoundError, match=r"genai_config\.json not found"): - GenaiSession(tmp_path) - - def test_unknown_ep_raises(self, bundle_dir: Path) -> None: - with pytest.raises(ValueError, match="Unknown EP"): - GenaiSession(bundle_dir, ep="tensorrt") - - def test_default_ep_is_cpu(self, bundle_dir: Path) -> None: - session = GenaiSession(bundle_dir) - assert session.ep == "cpu" - - def test_not_loaded_after_init(self, bundle_dir: Path) -> None: - session = GenaiSession(bundle_dir) - assert not session.is_loaded - assert session.context_length is None - - def test_bundle_dir_property(self, bundle_dir: Path) -> None: - session = GenaiSession(bundle_dir) - assert session.bundle_dir == bundle_dir - - def test_supported_eps(self, bundle_dir: Path) -> None: - for ep in ("cpu", "mixed", "qnn", "dml"): - session = GenaiSession(bundle_dir, ep=ep) - assert session.ep == ep - - -# --------------------------------------------------------------------------- -# Tests: load / unload -# --------------------------------------------------------------------------- - - -class TestGenaiSessionLoad: - def test_load_sets_is_loaded(self, bundle_dir: Path, mock_og: MagicMock) -> None: - with _patch_og(mock_og): - session = GenaiSession(bundle_dir) - session.load() - assert session.is_loaded - - def test_load_reads_context_length_from_config( - self, bundle_dir: Path, mock_og: MagicMock - ) -> None: - with _patch_og(mock_og): - session = GenaiSession(bundle_dir) - session.load() - assert session.context_length == 256 - - def test_context_length_override(self, bundle_dir: Path, mock_og: MagicMock) -> None: - with _patch_og(mock_og): - session = GenaiSession(bundle_dir, context_length=512) - session.load() - assert session.context_length == 512 - - def test_load_is_idempotent(self, bundle_dir: Path, mock_og: MagicMock) -> None: - with _patch_og(mock_og): - session = GenaiSession(bundle_dir) - session.load() - session.load() # second call is a no-op - assert mock_og.Model.call_count == 1 - - def test_unload_clears_state(self, bundle_dir: Path, mock_og: MagicMock) -> None: - with _patch_og(mock_og): - session = GenaiSession(bundle_dir) - session.load() - session.unload() - assert not session.is_loaded - assert session.context_length is None - - def test_unload_on_unloaded_session_is_safe(self, bundle_dir: Path) -> None: - session = GenaiSession(bundle_dir) - session.unload() # should not raise - - def test_context_manager_loads_and_unloads(self, bundle_dir: Path, mock_og: MagicMock) -> None: - with _patch_og(mock_og), GenaiSession(bundle_dir) as session: - assert session.is_loaded - assert not session.is_loaded - - def test_genai_not_installed_raises(self, bundle_dir: Path) -> None: - with patch.dict(sys.modules, {"onnxruntime_genai": None}): # type: ignore[dict-item] - session = GenaiSession(bundle_dir) - with pytest.raises(GenaiNotInstalledError): - session.load() - - def test_og_load_error_raises_genai_load_error( - self, bundle_dir: Path, mock_og: MagicMock - ) -> None: - mock_og.Model.side_effect = RuntimeError("driver not found") - with _patch_og(mock_og): - session = GenaiSession(bundle_dir) - with pytest.raises(GenaiLoadError, match="driver not found"): - session.load() - - def test_og_load_error_leaves_session_unloaded( - self, bundle_dir: Path, mock_og: MagicMock - ) -> None: - mock_og.Model.side_effect = RuntimeError("driver not found") - with _patch_og(mock_og): - session = GenaiSession(bundle_dir) - with pytest.raises(GenaiLoadError): - session.load() - assert not session.is_loaded - - -# --------------------------------------------------------------------------- -# Tests: EP registration -# --------------------------------------------------------------------------- - - -class TestEPRegistration: - def test_cpu_skips_winml_registration(self, bundle_dir: Path, mock_og: MagicMock) -> None: - with ( - _patch_og(mock_og), - patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls, - ): - session = GenaiSession(bundle_dir, ep="cpu") - session.load() - mock_reg_cls.assert_not_called() - - def test_non_cpu_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) -> None: - mock_registry = MagicMock() - mock_registry.winml_available = True - mock_registry.register_execution_providers.return_value = { - "onnxruntime_genai": ["QNNExecutionProvider"] - } - with ( - _patch_og(mock_og), - patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls, - ): - mock_reg_cls.get_instance.return_value = mock_registry - session = GenaiSession(bundle_dir, ep="qnn") - session.load() - mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True) - - def test_mixed_registers_winml_eps(self, bundle_dir: Path, mock_og: MagicMock) -> None: - mock_registry = MagicMock() - mock_registry.winml_available = True - mock_registry.register_execution_providers.return_value = { - "onnxruntime_genai": ["QNNExecutionProvider"] - } - with ( - _patch_og(mock_og), - patch("winml.modelkit.session.genai_session.WinMLEPRegistry") as mock_reg_cls, - ): - mock_reg_cls.get_instance.return_value = mock_registry - session = GenaiSession(bundle_dir, ep="mixed") - session.load() - mock_registry.register_execution_providers.assert_called_once_with(ort_genai=True) - - def test_config_not_modified_at_load(self, bundle_dir: Path, mock_og: MagicMock) -> None: - # EP routing is driven by genai_config.json — we must NOT touch the config. - with _patch_og(mock_og): - session = GenaiSession(bundle_dir, ep="cpu") - session.load() - mock_og.Config.return_value.clear_providers.assert_not_called() - mock_og.Config.return_value.append_provider.assert_not_called() - - -# --------------------------------------------------------------------------- -# Tests: generate / generate_streaming -# --------------------------------------------------------------------------- - - -class TestGenerate: - def test_generate_streaming_yields_decoded_tokens( - self, bundle_dir: Path, mock_og: MagicMock - ) -> None: - with _patch_og(mock_og), GenaiSession(bundle_dir) as session: - tokens = list(session.generate_streaming("hi")) - assert tokens == ["Hello", " world"] - - def test_generate_returns_joined_string(self, bundle_dir: Path, mock_og: MagicMock) -> None: - with _patch_og(mock_og), GenaiSession(bundle_dir) as session: - result = session.generate("hi") - assert result == "Hello world" - - def test_generate_respects_max_new_tokens(self, bundle_dir: Path, mock_og: MagicMock) -> None: - # Generator never signals done; we stop at max_new_tokens=1 - gen = mock_og.Generator.return_value - gen.is_done.side_effect = None - gen.is_done.return_value = False - gen.get_next_tokens.return_value = MagicMock(__getitem__=lambda s, i: 99) - mock_og.Tokenizer.return_value.create_stream.return_value.decode.return_value = "x" - - with _patch_og(mock_og), GenaiSession(bundle_dir) as session: - tokens = list(session.generate_streaming("hi", GenerationConfig(max_new_tokens=1))) - assert len(tokens) == 1 - - def test_generate_with_token_list_input(self, bundle_dir: Path, mock_og: MagicMock) -> None: - """Pre-encoded token IDs are forwarded directly to append_tokens.""" - with _patch_og(mock_og), GenaiSession(bundle_dir) as session: - list(session.generate_streaming([1, 2, 3])) - gen = mock_og.Generator.return_value - gen.append_tokens.assert_called_once_with([1, 2, 3]) - - def test_generate_deletes_generator_after_iteration( - self, bundle_dir: Path, mock_og: MagicMock - ) -> None: - """Generator is deleted (not leaked) even on normal completion.""" - with _patch_og(mock_og), GenaiSession(bundle_dir) as session: - list(session.generate_streaming("hi")) - # No assertions needed — test passes if no ResourceWarning / hang - - def test_generate_with_custom_config(self, bundle_dir: Path, mock_og: MagicMock) -> None: - cfg = GenerationConfig(max_new_tokens=64, do_sample=True, temperature=0.7) - with _patch_og(mock_og), GenaiSession(bundle_dir) as session: - list(session.generate_streaming("hi", cfg)) - params = mock_og.GeneratorParams.return_value - params.set_search_options.assert_called_once() - call_kwargs = params.set_search_options.call_args.kwargs - assert call_kwargs["do_sample"] is True - assert call_kwargs["temperature"] == 0.7 - - def test_generate_uses_context_length_as_max_length( - self, bundle_dir: Path, mock_og: MagicMock - ) -> None: - with _patch_og(mock_og), GenaiSession(bundle_dir, context_length=128) as session: - list(session.generate_streaming("hi")) - params = mock_og.GeneratorParams.return_value - call_kwargs = params.set_search_options.call_args.kwargs - assert call_kwargs["max_length"] == 128 - - def test_auto_load_on_first_generate(self, bundle_dir: Path, mock_og: MagicMock) -> None: - with _patch_og(mock_og): - session = GenaiSession(bundle_dir) - assert not session.is_loaded - list(session.generate_streaming("hi")) - assert session.is_loaded - - -# --------------------------------------------------------------------------- -# Tests: apply_chatml_template -# --------------------------------------------------------------------------- - - -class TestApplyChatmlTemplate: - def test_user_only(self) -> None: - result = GenaiSession.apply_chatml_template("Hello") - assert result == "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" - - def test_with_system(self) -> None: - result = GenaiSession.apply_chatml_template("Hello", system="You are helpful.") - assert result.startswith("<|im_start|>system\nYou are helpful.<|im_end|>\n") - assert "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n" in result - - def test_no_system_no_system_turn(self) -> None: - result = GenaiSession.apply_chatml_template("Hi") - assert "<|im_start|>system" not in result - - def test_ends_with_assistant_priming(self) -> None: - result = GenaiSession.apply_chatml_template("Hi") - assert result.endswith("<|im_start|>assistant\n") - - -# --------------------------------------------------------------------------- -# Tests: GenerationConfig defaults -# --------------------------------------------------------------------------- - - -class TestGenerationConfig: - def test_defaults(self) -> None: - cfg = GenerationConfig() - assert cfg.max_new_tokens == 128 - assert cfg.do_sample is False - assert cfg.temperature == 1.0 - assert cfg.top_p == 1.0 - assert cfg.top_k == 0 - assert cfg.repetition_penalty == 1.0 - - def test_custom_values(self) -> None: - cfg = GenerationConfig(max_new_tokens=32, do_sample=True, top_k=50) - assert cfg.max_new_tokens == 32 - assert cfg.do_sample is True - assert cfg.top_k == 50 - - -# --------------------------------------------------------------------------- -# Tests: exception hierarchy -# --------------------------------------------------------------------------- - - -class TestExceptions: - def test_genai_not_installed_is_genai_session_error(self) -> None: - assert issubclass(GenaiNotInstalledError, GenaiSessionError) - - def test_genai_load_error_is_genai_session_error(self) -> None: - assert issubclass(GenaiLoadError, GenaiSessionError) From 0b589f375fa2c0ded2d3e2674d05c51d0cb6a665 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 1 Jul 2026 19:24:01 +0800 Subject: [PATCH 20/29] fix(quant): use uint8 activations for transformer-only w8a8 (matches reference model) Switch activation_type from uint16 to uint8 to align with the reference qwen3-genai-share model (w8a8 QDQ, int8 weights + uint8 activations). This keeps ctx.onnx / iter.onnx at opset 18 instead of opset 21. ORT forces opset >= 21 for 16-bit QDQ (uint16), so the previous uint16 choice caused an automatic opset bump to 21 that deviated from the reference graph layout. Update test name and assertion accordingly. --- .../quant/calibration/qwen3_transformer_only.py | 10 ++++++---- tests/unit/quant/calibration/test_qwen3_calibration.py | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py index 04901ca7c..3c2397936 100644 --- a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py +++ b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Config-driven w8a16 calibration for the transformer-only Qwen3 build. +"""Config-driven w8a8 calibration for the transformer-only Qwen3 build. The transformer-only export (``models.hf.qwen3.qwen_transformer_only``) emits a graph whose only quantization-relevant runtime inputs (the calibration feeds and the @@ -351,7 +351,7 @@ def finalize_transformer_only_quant_config( picks generic dtypes; the transformer-only scheme is fixed and reference- matched, so this hook is authoritative: - - **int8-symmetric weights** (zp=0) + **uint16 asymmetric activations**, + - **int8-symmetric weights** (zp=0) + **uint8 asymmetric activations**, - **MinMax** calibration, ``mode="static"`` (forces QDQ dispatch), - GroupQueryAttention nodes excluded from QDQ (read from the graph), - the matching :class:`CalibrationDataReader` (prefill vs. decode-trajectory, @@ -365,14 +365,16 @@ def finalize_transformer_only_quant_config( seq_len, max_cache_len = _graph_shapes(onnx_path) gqa_nodes = _gqa_node_names(onnx_path) - # Fixed, reference-matched w8a16 scheme (authoritative over policy dtypes). + # Fixed, reference-matched w8a8 scheme (authoritative over policy dtypes). # ``mode`` must be pinned to "static": the new precision-driven flow keys the # quantizer dispatch on ``config.mode`` (fp16/rtn/static), so a build whose # precision policy resolved to "fp16"/"rtn" would otherwise bypass QDQ and # silently ignore the calibration reader + GQA exclusion set below. + # uint8 activations (matching the reference model) keep ctx/iter at opset 18; + # uint16 would force opset 21 (ORT requires opset >= 21 for 16-bit QDQ). quant.mode = "static" quant.weight_type = "int8" - quant.activation_type = "uint16" + quant.activation_type = "uint8" quant.weight_symmetric = True quant.activation_symmetric = False quant.calibration_method = "minmax" diff --git a/tests/unit/quant/calibration/test_qwen3_calibration.py b/tests/unit/quant/calibration/test_qwen3_calibration.py index 6881f0f72..a31ea9879 100644 --- a/tests/unit/quant/calibration/test_qwen3_calibration.py +++ b/tests/unit/quant/calibration/test_qwen3_calibration.py @@ -234,12 +234,12 @@ def test_decode_trajectory_reader_respects_max_cache(): assert max(int(f["past_seq_len"][0, 0]) for f in feeds) == max_cache_len - 1 -def test_finalize_pins_static_w8a16_scheme(tmp_path, monkeypatch): +def test_finalize_pins_static_w8a8_scheme(tmp_path, monkeypatch): """The finalizer is authoritative over the precision policy. The precision-driven build keys the quantizer dispatch on ``config.mode``, so the transformer-only policy must pin ``mode="static"`` (QDQ) along with - the reference-matched w8a16 dtypes/symmetry + GQA exclusion — even when the + the reference-matched w8a8 dtypes/symmetry + GQA exclusion — even when the incoming config arrived as a non-QDQ mode (e.g. ``fp16``/``rtn``). """ from winml.modelkit.quant import WinMLQuantizationConfig @@ -275,7 +275,7 @@ def test_finalize_pins_static_w8a16_scheme(tmp_path, monkeypatch): assert result.mode == "static" assert result.weight_type == "int8" - assert result.activation_type == "uint16" + assert result.activation_type == "uint8" assert result.weight_symmetric is True assert result.activation_symmetric is False assert result.calibration_method == "minmax" From 776f328355573402b9cd11de44cc17488338a6c7 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 1 Jul 2026 19:38:48 +0800 Subject: [PATCH 21/29] revert(quant): restore w8a16 (uint16 activations) for transformer-only Revert the uint8 change. uint16 activations give better generation quality at the cost of opset 21 (required by ORT for 16-bit QDQ). This is the correct precision for the QNN NPU pipeline. --- .../quant/calibration/qwen3_transformer_only.py | 12 ++++++------ .../unit/quant/calibration/test_qwen3_calibration.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py index 3c2397936..3e2637be6 100644 --- a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py +++ b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Config-driven w8a8 calibration for the transformer-only Qwen3 build. +"""Config-driven w8a16 calibration for the transformer-only Qwen3 build. The transformer-only export (``models.hf.qwen3.qwen_transformer_only``) emits a graph whose only quantization-relevant runtime inputs (the calibration feeds and the @@ -351,7 +351,7 @@ def finalize_transformer_only_quant_config( picks generic dtypes; the transformer-only scheme is fixed and reference- matched, so this hook is authoritative: - - **int8-symmetric weights** (zp=0) + **uint8 asymmetric activations**, + - **int8-symmetric weights** (zp=0) + **uint16 asymmetric activations**, - **MinMax** calibration, ``mode="static"`` (forces QDQ dispatch), - GroupQueryAttention nodes excluded from QDQ (read from the graph), - the matching :class:`CalibrationDataReader` (prefill vs. decode-trajectory, @@ -365,16 +365,16 @@ def finalize_transformer_only_quant_config( seq_len, max_cache_len = _graph_shapes(onnx_path) gqa_nodes = _gqa_node_names(onnx_path) - # Fixed, reference-matched w8a8 scheme (authoritative over policy dtypes). + # Fixed, reference-matched w8a16 scheme (authoritative over policy dtypes). # ``mode`` must be pinned to "static": the new precision-driven flow keys the # quantizer dispatch on ``config.mode`` (fp16/rtn/static), so a build whose # precision policy resolved to "fp16"/"rtn" would otherwise bypass QDQ and # silently ignore the calibration reader + GQA exclusion set below. - # uint8 activations (matching the reference model) keep ctx/iter at opset 18; - # uint16 would force opset 21 (ORT requires opset >= 21 for 16-bit QDQ). + # uint16 activations give higher precision than uint8; the trade-off is that + # ORT requires opset >= 21 for 16-bit QDQ, so ctx/iter will export at opset 21. quant.mode = "static" quant.weight_type = "int8" - quant.activation_type = "uint8" + quant.activation_type = "uint16" quant.weight_symmetric = True quant.activation_symmetric = False quant.calibration_method = "minmax" diff --git a/tests/unit/quant/calibration/test_qwen3_calibration.py b/tests/unit/quant/calibration/test_qwen3_calibration.py index a31ea9879..6881f0f72 100644 --- a/tests/unit/quant/calibration/test_qwen3_calibration.py +++ b/tests/unit/quant/calibration/test_qwen3_calibration.py @@ -234,12 +234,12 @@ def test_decode_trajectory_reader_respects_max_cache(): assert max(int(f["past_seq_len"][0, 0]) for f in feeds) == max_cache_len - 1 -def test_finalize_pins_static_w8a8_scheme(tmp_path, monkeypatch): +def test_finalize_pins_static_w8a16_scheme(tmp_path, monkeypatch): """The finalizer is authoritative over the precision policy. The precision-driven build keys the quantizer dispatch on ``config.mode``, so the transformer-only policy must pin ``mode="static"`` (QDQ) along with - the reference-matched w8a8 dtypes/symmetry + GQA exclusion — even when the + the reference-matched w8a16 dtypes/symmetry + GQA exclusion — even when the incoming config arrived as a non-QDQ mode (e.g. ``fp16``/``rtn``). """ from winml.modelkit.quant import WinMLQuantizationConfig @@ -275,7 +275,7 @@ def test_finalize_pins_static_w8a8_scheme(tmp_path, monkeypatch): assert result.mode == "static" assert result.weight_type == "int8" - assert result.activation_type == "uint8" + assert result.activation_type == "uint16" assert result.weight_symmetric is True assert result.activation_symmetric is False assert result.calibration_method == "minmax" From 854710f9863823ec0bb9f8fb50b9add029e7f777 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 1 Jul 2026 19:59:39 +0800 Subject: [PATCH 22/29] Strip exporter-injected default GQA attrs from transformer bundle ONNX MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add strip_node_attrs() to winml.modelkit.onnx — a generic utility that removes all attributes from matching op nodes except those listed in a keep_attrs set. Operates in-place on an onnx.ModelProto; safe for models with external data (modifies only the graph proto, not weight files). Wire it into write_genai_bundle() via a new transformer_onnx_passes parameter: a list of callables applied to ctx.onnx / iter.onnx after they are copied into the bundle directory. In scripts/qwen3.py, pass _strip_gqa_default_attrs (which retains only do_rotary / kv_num_heads / num_heads) to remove the five extra attrs that PyTorch's TorchScript ONNX exporter injects from the ORT com.microsoft::GroupQueryAttention schema: k_quant_type, local_window_size, qk_output, smooth_softmax, v_quant_type These are all no-op defaults and are absent from the reference model; stripping them brings our bundle's GQA attribute set in line with the reference. 8 new unit tests cover: extra-attr removal, keep-all, remove-all, domain mismatch no-op, multi-node graphs, and identity (same object returned). --- scripts/qwen3.py | 16 +++ src/winml/modelkit/onnx/__init__.py | 3 +- src/winml/modelkit/onnx/utils.py | 35 +++++++ src/winml/modelkit/utils/genai.py | 27 ++++- tests/unit/onnx/test_utils.py | 150 ++++++++++++++++++++++++++++ 5 files changed, 229 insertions(+), 2 deletions(-) create mode 100644 tests/unit/onnx/test_utils.py diff --git a/scripts/qwen3.py b/scripts/qwen3.py index 57b6f486b..27d72ce34 100644 --- a/scripts/qwen3.py +++ b/scripts/qwen3.py @@ -39,6 +39,8 @@ from winml.modelkit.models.hf.qwen3.qwen_transformer_only import ( WinMLQwen3TransformerOnlyModel, ) +from winml.modelkit.onnx import strip_node_attrs +from winml.modelkit.session import GenaiSession, GenerationConfig _DEVICE_TO_EP = { @@ -56,6 +58,19 @@ _REPO_ROOT = Path(__file__).resolve().parent.parent _DEFAULT_BUNDLE = _REPO_ROOT / "out" / "bundle" +_SUPPORTED_EPS = ["cpu", "mixed", "qnn", "dml"] + +# Attributes that com.microsoft::GroupQueryAttention requires for Qwen3. +# Any other attributes (e.g. k_quant_type, local_window_size, qk_output, +# smooth_softmax, v_quant_type) are default-valued extras injected by the +# TorchScript exporter from the ORT op schema; strip them so the bundle +# matches the expected minimal attribute set. +_GQA_KEEP_ATTRS = frozenset({"do_rotary", "kv_num_heads", "num_heads"}) + + +def _strip_gqa_default_attrs(model: onnx.ModelProto) -> onnx.ModelProto: + """Remove exporter-injected default attributes from GQA nodes.""" + return strip_node_attrs(model, "GroupQueryAttention", _GQA_KEEP_ATTRS, domain="com.microsoft") # --------------------------------------------------------------------------- # Helpers shared between sub-commands @@ -220,6 +235,7 @@ def _cmd_export(args: argparse.Namespace) -> int: embeddings_src=embeddings_src, lm_head_src=lm_head_src, ep="qnn" if args.device == "npu" else args.device, + transformer_onnx_passes=[_strip_gqa_default_attrs], ) print(f" genai_config.json -> {config_path}") diff --git a/src/winml/modelkit/onnx/__init__.py b/src/winml/modelkit/onnx/__init__.py index 4c7fe312d..d2e315e0d 100644 --- a/src/winml/modelkit/onnx/__init__.py +++ b/src/winml/modelkit/onnx/__init__.py @@ -22,7 +22,7 @@ from .metadata import capture_metadata, restore_metadata from .persistence import ONNXSaveError, cleanup_onnx, load_onnx, save_onnx from .shape import infer_onnx_shapes, infer_shapes -from .utils import EXTERNAL_DATA_THRESHOLD, check_onnx_model, get_model_size +from .utils import EXTERNAL_DATA_THRESHOLD, check_onnx_model, get_model_size, strip_node_attrs __all__ = [ @@ -48,6 +48,7 @@ "remove_optional_from_type_annotation", "restore_metadata", "save_onnx", + "strip_node_attrs", ] diff --git a/src/winml/modelkit/onnx/utils.py b/src/winml/modelkit/onnx/utils.py index e42d3407e..b598fd79a 100644 --- a/src/winml/modelkit/onnx/utils.py +++ b/src/winml/modelkit/onnx/utils.py @@ -13,6 +13,41 @@ EXTERNAL_DATA_THRESHOLD = 100 * 1024 * 1024 # 100 MiB +def strip_node_attrs( + model: onnx.ModelProto, + op_type: str, + keep_attrs: frozenset[str] | set[str], + domain: str = "", +) -> onnx.ModelProto: + """Remove all attributes from matching nodes except those in *keep_attrs*. + + Useful for stripping default-valued optional attributes that an exporter + fills in automatically but that are not needed at inference time. + + Modifies *model* **in-place** and also returns it for convenient chaining. + + Args: + model: ONNX model proto to modify. + op_type: Operator type string (e.g. ``"GroupQueryAttention"``). + keep_attrs: Attribute names to retain; every other attribute is removed. + domain: Operator domain to match (e.g. ``"com.microsoft"``). The + empty string matches the default ONNX domain. + + Returns: + The same *model* object (mutated in-place). + """ + for node in model.graph.node: + if node.op_type != op_type or node.domain != domain: + continue + to_remove = [a.name for a in node.attribute if a.name not in keep_attrs] + for name in to_remove: + for i, a in enumerate(node.attribute): + if a.name == name: + del node.attribute[i] + break + return model + + def get_model_size(model: onnx.ModelProto) -> int: """Calculate the total size of an ONNX model in bytes. diff --git a/src/winml/modelkit/utils/genai.py b/src/winml/modelkit/utils/genai.py index c8db1d21a..1c144da8d 100644 --- a/src/winml/modelkit/utils/genai.py +++ b/src/winml/modelkit/utils/genai.py @@ -70,7 +70,13 @@ import re from dataclasses import dataclass from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + + import onnx logger = logging.getLogger(__name__) @@ -533,6 +539,7 @@ def write_genai_bundle( lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, context_session_options: dict | None = None, iterator_session_options: dict | None = None, + transformer_onnx_passes: Sequence[Callable[[onnx.ModelProto], onnx.ModelProto]] | None = None, ) -> Path: """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*. @@ -561,10 +568,18 @@ def write_genai_bundle( supplies any EP-specific options. iterator_session_options: Same as *context_session_options* but for the ``iterator`` stage. + transformer_onnx_passes: Optional list of callables applied to each + transformer ONNX (context + iterator) **after** copying to the + bundle directory. Each callable receives an + ``onnx.ModelProto``, may modify it in-place, and must return it. + Passes are applied in order to the context model first, then the + iterator model. Use this to strip exporter-injected default + attributes that are not needed at inference time. Returns: Path to the written ``genai_config.json``. """ + import onnx as _onnx from transformers import AutoConfig, AutoTokenizer from ..onnx import copy_onnx_model @@ -581,6 +596,16 @@ def write_genai_bundle( logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) copy_onnx_model(iterator_onnx, output_dir / iterator_filename) + # 1b. Apply optional post-copy ONNX passes to the transformer stages. + if transformer_onnx_passes: + for fname in (context_filename, iterator_filename): + dst = output_dir / fname + model = _onnx.load(str(dst), load_external_data=False) + for pass_fn in transformer_onnx_passes: + model = pass_fn(model) + _onnx.save(model, str(dst)) + logger.info("Applied %d ONNX pass(es) to %s", len(transformer_onnx_passes), fname) + # 2. Copy embeddings + lm_head models. if embeddings_src is not None: logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) diff --git a/tests/unit/onnx/test_utils.py b/tests/unit/onnx/test_utils.py new file mode 100644 index 000000000..026811f36 --- /dev/null +++ b/tests/unit/onnx/test_utils.py @@ -0,0 +1,150 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for onnx/utils.py — strip_node_attrs.""" + +from __future__ import annotations + +import onnx +from onnx import TensorProto, helper + +from winml.modelkit.onnx import strip_node_attrs + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_gqa_model(attr_dict: dict[str, int]) -> onnx.ModelProto: + """Build a minimal model with a single com.microsoft::GroupQueryAttention node. + + Uses explicit ``make_attribute`` calls so attr names match what PyTorch's + TorchScript ONNX exporter produces (no ``_i`` suffix). + """ + x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 64, 512]) + y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 64, 512]) + node = onnx.NodeProto() + node.op_type = "GroupQueryAttention" + node.domain = "com.microsoft" + node.input.append("x") + node.output.append("y") + for name, value in attr_dict.items(): + attr = helper.make_attribute(name, value) + node.attribute.append(attr) + graph = helper.make_graph([node], "gqa_graph", [x], [y]) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("com.microsoft", 1)]) + + +def _attr_names(model: onnx.ModelProto) -> set[str]: + return {a.name for n in model.graph.node for a in n.attribute} + + +# --------------------------------------------------------------------------- +# strip_node_attrs +# --------------------------------------------------------------------------- + + +def test_strip_removes_extra_attrs(): + """Attributes not in keep_attrs are removed from matching nodes.""" + model = _make_gqa_model( + { + "do_rotary": 1, + "num_heads": 16, + "kv_num_heads": 8, + "local_window_size": -1, + "smooth_softmax": -1, + } + ) + keep = frozenset({"do_rotary", "num_heads", "kv_num_heads"}) + result = strip_node_attrs(model, "GroupQueryAttention", keep, domain="com.microsoft") + remaining = _attr_names(result) + assert remaining == keep + + +def test_strip_keep_attrs_preserved(): + """Attributes listed in keep_attrs survive stripping.""" + model = _make_gqa_model({"do_rotary": 1, "num_heads": 16, "kv_num_heads": 8}) + keep = frozenset({"do_rotary", "num_heads", "kv_num_heads"}) + strip_node_attrs(model, "GroupQueryAttention", keep, domain="com.microsoft") + remaining = _attr_names(model) + assert "do_rotary" in remaining + assert "num_heads" in remaining + assert "kv_num_heads" in remaining + + +def test_strip_no_matching_nodes_is_noop(): + """strip_node_attrs is a no-op when no nodes match op_type.""" + x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4]) + y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4]) + node = helper.make_node("Relu", ["x"], ["y"]) + graph = helper.make_graph([node], "g", [x], [y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + + result = strip_node_attrs(model, "GroupQueryAttention", frozenset(), domain="com.microsoft") + assert result is model # same object returned + + +def test_strip_domain_mismatch_is_noop(): + """Nodes with a different domain are not modified.""" + model = _make_gqa_model({"do_rotary_i": 1, "extra_i": 0}) + before = _attr_names(model) + # Pass wrong domain — nothing should be removed + strip_node_attrs(model, "GroupQueryAttention", frozenset({"do_rotary"}), domain="wrong.domain") + assert _attr_names(model) == before + + +def test_strip_keep_all_attrs(): + """When keep_attrs contains all attr names, nothing is removed.""" + model = _make_gqa_model({"do_rotary": 1, "num_heads": 16}) + keep = frozenset({"do_rotary", "num_heads"}) + strip_node_attrs(model, "GroupQueryAttention", keep, domain="com.microsoft") + assert _attr_names(model) == keep + + +def test_strip_empty_keep_attrs_removes_all(): + """An empty keep_attrs set removes every attribute from matching nodes.""" + model = _make_gqa_model({"do_rotary": 1, "num_heads": 16}) + strip_node_attrs(model, "GroupQueryAttention", frozenset(), domain="com.microsoft") + assert _attr_names(model) == set() + + +def test_strip_returns_same_model_object(): + """strip_node_attrs mutates in-place and returns the same object.""" + model = _make_gqa_model({"do_rotary": 1}) + result = strip_node_attrs( + model, "GroupQueryAttention", frozenset({"do_rotary"}), domain="com.microsoft" + ) + assert result is model + + +def test_strip_multiple_gqa_nodes(): + """All matching nodes in a multi-node graph are stripped.""" + x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 64, 512]) + y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 64, 512]) + z = helper.make_tensor_value_info("z", TensorProto.FLOAT, [1, 64, 512]) + + def _gqa_node(name: str, inp: str, out: str) -> onnx.NodeProto: + node = onnx.NodeProto() + node.op_type = "GroupQueryAttention" + node.domain = "com.microsoft" + node.name = name + node.input.append(inp) + node.output.append(out) + for attr_name, value in [("do_rotary", 1), ("num_heads", 16), ("local_window_size", -1)]: + node.attribute.append(helper.make_attribute(attr_name, value)) + return node + + graph = helper.make_graph( + [_gqa_node("gqa0", "x", "y"), _gqa_node("gqa1", "y", "z")], + "multi_gqa", + [x], + [z], + value_info=[y], + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("com.microsoft", 1)]) + keep = frozenset({"do_rotary", "num_heads"}) + strip_node_attrs(model, "GroupQueryAttention", keep, domain="com.microsoft") + for node in model.graph.node: + assert {a.name for a in node.attribute} == keep From c7c84b9d1245d13af9e96d52268ee93709d573c6 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 1 Jul 2026 20:11:39 +0800 Subject: [PATCH 23/29] Truncate cos/sin rope cache to max_cache_len at export time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit WinMLQwen3Attention.forward was calling rotary_emb with torch.arange(config.max_position_embeddings) = 40960 positions, producing a 40960x64 cos/sin cache constant in every exported ONNX. The reference model uses a 4096x64 cache (= max_cache_len). Fix: use total_seq_len.item() (which equals max_cache_len at trace time, as set by _TransformerOnlySeqLenGenerator) instead of config.max_position_embeddings. This produces a cache of exactly max_cache_len rows — matching what will actually be needed at inference time and 10x smaller for the default Qwen3-0.6B export. Falls back to config.max_position_embeddings when total_seq_len is None (e.g. eager evaluation outside the export path). 4 new tests verify rope cache sizing across multiple max_cache_len values and the None fallback. --- .../models/hf/qwen3/qwen3_modeling.py | 6 +- .../unit/models/qwen3/test_qwen3_modeling.py | 163 ++++++++++++++++++ 2 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 tests/unit/models/qwen3/test_qwen3_modeling.py diff --git a/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py index 140c2dac8..ace6a6d81 100644 --- a/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py +++ b/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py @@ -155,7 +155,11 @@ def forward( cos, sin = cast("nn.Module", self.rotary_emb)( value_states, - torch.arange(self.config.max_position_embeddings).unsqueeze(0), + torch.arange( + int(total_seq_len.item()) + if total_seq_len is not None + else self.config.max_position_embeddings + ).unsqueeze(0), ) cos = cos.squeeze(0)[:, : cos.shape[-1] // 2] sin = sin.squeeze(0)[:, : sin.shape[-1] // 2] diff --git a/tests/unit/models/qwen3/test_qwen3_modeling.py b/tests/unit/models/qwen3/test_qwen3_modeling.py new file mode 100644 index 000000000..8d51a0252 --- /dev/null +++ b/tests/unit/models/qwen3/test_qwen3_modeling.py @@ -0,0 +1,163 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for WinMLQwen3Attention.forward — rope cache sizing.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import torch + +from winml.modelkit.models.hf.qwen3.qwen3_modeling import WinMLQwen3Attention + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_attention_module( + max_position_embeddings: int = 40960, + num_heads: int = 16, + num_kv_heads: int = 8, + head_dim: int = 64, +) -> MagicMock: + """Build a minimal mock bound to WinMLQwen3Attention.forward.""" + hidden_size = num_heads * head_dim + kv_size = num_kv_heads * head_dim + + mod = MagicMock() + mod.head_dim = head_dim + mod._matmul_to_conv = False + mod.config = SimpleNamespace( + num_attention_heads=num_heads, + num_key_value_heads=num_kv_heads, + max_position_embeddings=max_position_embeddings, + ) + + # Identity projections: return float32 tensors of the right shape + mod.q_proj.side_effect = lambda x: torch.zeros(1, x.shape[1], hidden_size) + mod.k_proj.side_effect = lambda x: torch.zeros(1, x.shape[1], kv_size) + mod.v_proj.side_effect = lambda x: torch.zeros(1, x.shape[1], kv_size) + + # q_norm / k_norm: identity + mod.q_norm.side_effect = lambda x: x + mod.k_norm.side_effect = lambda x: x + + return mod + + +def _run_forward( + mod: WinMLQwen3Attention, + seq_len: int, + max_cache_len: int, + kv_dtype: torch.dtype = torch.float16, +) -> list[torch.Tensor]: + """Invoke WinMLQwen3Attention.forward and capture rotary_emb position_ids.""" + hidden = torch.zeros(1, seq_len, mod.config.num_attention_heads * mod.head_dim) + past_keys = torch.zeros( + 1, mod.config.num_key_value_heads, max_cache_len, mod.head_dim, dtype=kv_dtype + ) + past_vals = torch.zeros( + 1, mod.config.num_key_value_heads, max_cache_len, mod.head_dim, dtype=kv_dtype + ) + past_seq_len = torch.zeros(1, 1, dtype=torch.int32) + total_seq_len = torch.tensor([max_cache_len], dtype=torch.int32) + + captured_pos_ids: list[torch.Tensor] = [] + + def _fake_rotary_emb(values, position_ids): + captured_pos_ids.append(position_ids) + seq_dim = position_ids.shape[-1] + cos = torch.ones(1, seq_dim, mod.head_dim, dtype=values.dtype) + sin = torch.zeros(1, seq_dim, mod.head_dim, dtype=values.dtype) + return cos, sin + + mod.rotary_emb.side_effect = _fake_rotary_emb + + # GQA op: return (attn_out, present_keys, present_values) + with patch( + "winml.modelkit.models.hf.qwen3.qwen3_modeling.GroupQueryAttentionOnnxExport.apply" + ) as mock_gqa: + attn_out = torch.zeros( + 1, seq_len, mod.config.num_attention_heads * mod.head_dim, dtype=kv_dtype + ) + mock_gqa.return_value = (attn_out, past_keys, past_vals) + WinMLQwen3Attention.forward( + mod, + hidden, + past_key_value=(past_keys, past_vals), + past_seq_len=past_seq_len, + total_seq_len=total_seq_len, + ) + + return captured_pos_ids + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestRopeCacheSizing: + def test_rope_cache_uses_total_seq_len_not_max_position_embeddings(self): + """rope cache length == max_cache_len, not max_position_embeddings.""" + mod = _make_attention_module(max_position_embeddings=40960) + pos_ids = _run_forward(mod, seq_len=64, max_cache_len=256) + assert len(pos_ids) == 1 + assert pos_ids[0].shape[-1] == 256, ( + f"Expected rope cache length 256 but got {pos_ids[0].shape[-1]}" + ) + + def test_rope_cache_matches_max_cache_len(self): + """rope cache length equals the max_cache_len used for KV cache.""" + for max_cache_len in (128, 512, 4096): + mod = _make_attention_module(max_position_embeddings=40960) + pos_ids = _run_forward(mod, seq_len=1, max_cache_len=max_cache_len) + assert pos_ids[0].shape[-1] == max_cache_len + + def test_rope_cache_much_smaller_than_max_position_embeddings(self): + """With max_cache_len=256, cache is 160x smaller than full rope.""" + mod = _make_attention_module(max_position_embeddings=40960) + pos_ids = _run_forward(mod, seq_len=1, max_cache_len=256) + assert pos_ids[0].shape[-1] < mod.config.max_position_embeddings + + def test_fallback_when_total_seq_len_is_none(self): + """When total_seq_len is None, falls back to max_position_embeddings.""" + mod = _make_attention_module(max_position_embeddings=512) + hidden = torch.zeros(1, 1, mod.config.num_attention_heads * mod.head_dim) + past_keys = torch.zeros( + 1, mod.config.num_key_value_heads, 256, mod.head_dim, dtype=torch.float16 + ) + past_vals = torch.zeros_like(past_keys) + + captured: list[torch.Tensor] = [] + + def _fake_rotary_emb(values, position_ids): + captured.append(position_ids) + seq_dim = position_ids.shape[-1] + cos = torch.ones(1, seq_dim, mod.head_dim, dtype=values.dtype) + sin = torch.zeros(1, seq_dim, mod.head_dim, dtype=values.dtype) + return cos, sin + + mod.rotary_emb.side_effect = _fake_rotary_emb + + with patch( + "winml.modelkit.models.hf.qwen3.qwen3_modeling.GroupQueryAttentionOnnxExport.apply" + ) as mock_gqa: + attn_out = torch.zeros( + 1, 1, mod.config.num_attention_heads * mod.head_dim, dtype=torch.float16 + ) + mock_gqa.return_value = (attn_out, past_keys, past_vals) + WinMLQwen3Attention.forward( + mod, + hidden, + past_key_value=(past_keys, past_vals), + past_seq_len=torch.zeros(1, 1, dtype=torch.int32), + total_seq_len=None, + ) + + assert captured[0].shape[-1] == 512 # falls back to max_position_embeddings From 0fc155a588c6e063a56a9491458170deb0d74a7b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 1 Jul 2026 20:50:51 +0800 Subject: [PATCH 24/29] fix: pin rope cache to static Python int to fix symbolic tracing bug torch.export.export (used by torch.onnx.export at opset 18+) treats int(total_seq_len.item()) as Sym(u0), causing torch.arange(Sym(u0)) to resolve to arange(1) at trace-time, baking cos_cache=[1,64] into the ONNX graph instead of the intended [40960,64]. Fix: WinMLQwen3Attention.prepare_for_onnx_export now stores _max_rope_len as a plain Python int (defaults to config.max_position_embeddings=40960). forward() uses torch.arange(self._max_rope_len) -- a concrete int literal that both the TorchScript and torch.export backends bake as a constant tensor, giving cos_cache shape [40960, 64] which satisfies GQA's check cos_cache.shape[0] >= total_seq_len for any total_seq_len <= 40960. apply_transformer_only_export_prep accepts an optional max_rope_len kwarg to allow callers to override the rope length (e.g., to match max_cache_len). Tests updated: replaced total_seq_len-based assertions with _max_rope_len attribute tests and forward() rope-length-independence tests. --- .../models/hf/qwen3/qwen3_modeling.py | 34 ++++-- .../unit/models/qwen3/test_qwen3_modeling.py | 110 +++++++++--------- 2 files changed, 82 insertions(+), 62 deletions(-) diff --git a/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py index ace6a6d81..e691038d5 100644 --- a/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py +++ b/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py @@ -96,8 +96,18 @@ class WinMLQwen3Attention(nn.Module): v_proj: nn.Module o_proj: nn.Module - def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: - """Optionally swap the Q/K/V/O projections for 1x1 convs.""" + def prepare_for_onnx_export( + self, *, matmul_to_conv: bool, max_rope_len: int | None = None + ) -> None: + """Optionally swap the Q/K/V/O projections for 1x1 convs. + + ``max_rope_len`` pins the cos/sin rope cache to exactly that many + positions so the exported ONNX constant has shape ``[max_rope_len, 64]`` + instead of ``[config.max_position_embeddings, 64]``. It must be a + plain Python ``int`` (not a tensor) so the ``torch.arange`` call in + ``forward`` bakes a static-shape constant into the ONNX graph rather + than a dynamic symbolic range. + """ if matmul_to_conv: self.q_proj = TransposeConv2d1x1Transpose.from_linear_module( cast("nn.Linear", self.q_proj) @@ -112,6 +122,9 @@ def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: cast("nn.Linear", self.o_proj) ) self._matmul_to_conv = matmul_to_conv + self._max_rope_len: int = ( + int(max_rope_len) if max_rope_len is not None else self.config.max_position_embeddings + ) def forward( self, @@ -155,11 +168,7 @@ def forward( cos, sin = cast("nn.Module", self.rotary_emb)( value_states, - torch.arange( - int(total_seq_len.item()) - if total_seq_len is not None - else self.config.max_position_embeddings - ).unsqueeze(0), + torch.arange(self._max_rope_len).unsqueeze(0), ) cos = cos.squeeze(0)[:, : cos.shape[-1] // 2] sin = sin.squeeze(0)[:, : sin.shape[-1] // 2] @@ -277,7 +286,7 @@ def forward( def apply_transformer_only_export_prep( - causal_lm: nn.Module, *, matmul_to_conv: bool = True + causal_lm: nn.Module, *, matmul_to_conv: bool = True, max_rope_len: int | None = None ) -> None: """Mutate ``Qwen3ForCausalLM`` in-place into the export topology. @@ -291,6 +300,11 @@ def apply_transformer_only_export_prep( causal_lm: A ``transformers.Qwen3ForCausalLM`` instance. matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so QNN sees them as Conv. + max_rope_len: Number of positions to pre-compute for the cos/sin rope + cache. Pass the model's ``max_cache_len`` (e.g. 4096) so the + exported constant has shape ``[max_rope_len, 64]`` instead of + ``[config.max_position_embeddings, 64]`` (40960). Must be a plain + Python ``int`` — see ``WinMLQwen3Attention.prepare_for_onnx_export``. Raises: RuntimeError: If any expected Qwen3 submodule class is not found, @@ -325,7 +339,9 @@ def _is(module: nn.Module, name: str) -> bool: for mod in causal_lm.modules(): if _is(mod, "Qwen3Attention"): WinMLQwen3Attention.prepare_for_onnx_export( - cast("WinMLQwen3Attention", mod), matmul_to_conv=matmul_to_conv + cast("WinMLQwen3Attention", mod), + matmul_to_conv=matmul_to_conv, + max_rope_len=max_rope_len, ) _bind(mod, WinMLQwen3Attention) patched["Qwen3Attention"] += 1 diff --git a/tests/unit/models/qwen3/test_qwen3_modeling.py b/tests/unit/models/qwen3/test_qwen3_modeling.py index 8d51a0252..6826a95bf 100644 --- a/tests/unit/models/qwen3/test_qwen3_modeling.py +++ b/tests/unit/models/qwen3/test_qwen3_modeling.py @@ -2,7 +2,13 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Unit tests for WinMLQwen3Attention.forward — rope cache sizing.""" +"""Unit tests for WinMLQwen3Attention rope cache — static _max_rope_len approach. + +The rope cache length is now pinned to a plain Python int attribute +``_max_rope_len`` set by ``prepare_for_onnx_export``, so the ONNX exporter +bakes a static-shape constant rather than a symbolic range derived from +``total_seq_len.item()``. +""" from __future__ import annotations @@ -25,7 +31,7 @@ def _make_attention_module( num_kv_heads: int = 8, head_dim: int = 64, ) -> MagicMock: - """Build a minimal mock bound to WinMLQwen3Attention.forward.""" + """Build a minimal mock bound to WinMLQwen3Attention methods.""" hidden_size = num_heads * head_dim kv_size = num_kv_heads * head_dim @@ -98,66 +104,64 @@ def _fake_rotary_emb(values, position_ids): # --------------------------------------------------------------------------- -# Tests +# Tests for prepare_for_onnx_export — _max_rope_len attribute +# --------------------------------------------------------------------------- + + +class TestPrepareForOnnxExport: + def test_explicit_max_rope_len_is_stored(self): + """prepare_for_onnx_export stores the supplied max_rope_len as a Python int.""" + mod = _make_attention_module(max_position_embeddings=40960) + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=False, max_rope_len=4096) + assert mod._max_rope_len == 4096 + assert isinstance(mod._max_rope_len, int) + + def test_fallback_to_max_position_embeddings(self): + """Without max_rope_len, _max_rope_len falls back to max_position_embeddings.""" + mod = _make_attention_module(max_position_embeddings=512) + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=False) + assert mod._max_rope_len == 512 + + def test_max_rope_len_is_plain_int(self): + """_max_rope_len must be a plain int, not a tensor or other type.""" + mod = _make_attention_module() + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=False, max_rope_len=4096) + assert type(mod._max_rope_len) is int + + +# --------------------------------------------------------------------------- +# Tests for forward — uses _max_rope_len, ignores total_seq_len for rope # --------------------------------------------------------------------------- class TestRopeCacheSizing: - def test_rope_cache_uses_total_seq_len_not_max_position_embeddings(self): - """rope cache length == max_cache_len, not max_position_embeddings.""" + def test_forward_uses_max_rope_len_attribute(self): + """forward passes torch.arange(_max_rope_len) to rotary_emb regardless of total_seq_len.""" mod = _make_attention_module(max_position_embeddings=40960) - pos_ids = _run_forward(mod, seq_len=64, max_cache_len=256) + mod._max_rope_len = 256 # explicitly set; different from total_seq_len=4096 below + + pos_ids = _run_forward(mod, seq_len=1, max_cache_len=4096) assert len(pos_ids) == 1 assert pos_ids[0].shape[-1] == 256, ( - f"Expected rope cache length 256 but got {pos_ids[0].shape[-1]}" + f"Expected rope cache length 256 (from _max_rope_len) but got {pos_ids[0].shape[-1]}" ) - def test_rope_cache_matches_max_cache_len(self): - """rope cache length equals the max_cache_len used for KV cache.""" - for max_cache_len in (128, 512, 4096): - mod = _make_attention_module(max_position_embeddings=40960) - pos_ids = _run_forward(mod, seq_len=1, max_cache_len=max_cache_len) - assert pos_ids[0].shape[-1] == max_cache_len + def test_rope_cache_does_not_depend_on_total_seq_len(self): + """Changing total_seq_len does NOT change the rope cache length.""" + mod_a = _make_attention_module() + mod_a._max_rope_len = 512 + pos_a = _run_forward(mod_a, seq_len=1, max_cache_len=128) - def test_rope_cache_much_smaller_than_max_position_embeddings(self): - """With max_cache_len=256, cache is 160x smaller than full rope.""" - mod = _make_attention_module(max_position_embeddings=40960) - pos_ids = _run_forward(mod, seq_len=1, max_cache_len=256) - assert pos_ids[0].shape[-1] < mod.config.max_position_embeddings + mod_b = _make_attention_module() + mod_b._max_rope_len = 512 + pos_b = _run_forward(mod_b, seq_len=1, max_cache_len=4096) + + # Both use _max_rope_len=512; total_seq_len differs but rope length must match + assert pos_a[0].shape[-1] == pos_b[0].shape[-1] == 512 - def test_fallback_when_total_seq_len_is_none(self): - """When total_seq_len is None, falls back to max_position_embeddings.""" + def test_fallback_rope_len_is_max_position_embeddings(self): + """Without max_rope_len arg, forward uses max_position_embeddings as rope len.""" mod = _make_attention_module(max_position_embeddings=512) - hidden = torch.zeros(1, 1, mod.config.num_attention_heads * mod.head_dim) - past_keys = torch.zeros( - 1, mod.config.num_key_value_heads, 256, mod.head_dim, dtype=torch.float16 - ) - past_vals = torch.zeros_like(past_keys) - - captured: list[torch.Tensor] = [] - - def _fake_rotary_emb(values, position_ids): - captured.append(position_ids) - seq_dim = position_ids.shape[-1] - cos = torch.ones(1, seq_dim, mod.head_dim, dtype=values.dtype) - sin = torch.zeros(1, seq_dim, mod.head_dim, dtype=values.dtype) - return cos, sin - - mod.rotary_emb.side_effect = _fake_rotary_emb - - with patch( - "winml.modelkit.models.hf.qwen3.qwen3_modeling.GroupQueryAttentionOnnxExport.apply" - ) as mock_gqa: - attn_out = torch.zeros( - 1, 1, mod.config.num_attention_heads * mod.head_dim, dtype=torch.float16 - ) - mock_gqa.return_value = (attn_out, past_keys, past_vals) - WinMLQwen3Attention.forward( - mod, - hidden, - past_key_value=(past_keys, past_vals), - past_seq_len=torch.zeros(1, 1, dtype=torch.int32), - total_seq_len=None, - ) - - assert captured[0].shape[-1] == 512 # falls back to max_position_embeddings + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=False) + pos_ids = _run_forward(mod, seq_len=1, max_cache_len=256) + assert pos_ids[0].shape[-1] == 512 From bbff1f2b41d78de4d247272949efc7a1091bfd31 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 3 Jul 2026 09:50:50 +0800 Subject: [PATCH 25/29] fix(scripts): drop broken genai session import and dead _SUPPORTED_EPS Remove the leftover 'from winml.modelkit.session import GenaiSession, GenerationConfig' import (those symbols were removed with genai_session.py, so the import would crash the export script) and the unused _SUPPORTED_EPS global. Both were flagged by CodeQL. Also correct the ctx/iter docstring from 'QNN-quantized' to 'QDQ-quantized' per review feedback. --- scripts/qwen3.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/scripts/qwen3.py b/scripts/qwen3.py index 27d72ce34..7d88938c1 100644 --- a/scripts/qwen3.py +++ b/scripts/qwen3.py @@ -7,8 +7,8 @@ Builds (or reuses) all four components of the Qwen3 genai bundle and assembles them into an onnxruntime-genai directory: - - ``ctx.onnx`` — transformer prefill graph (QNN-quantized) - - ``iter.onnx`` — transformer decode graph (QNN-quantized) + - ``ctx.onnx`` — transformer prefill graph (QDQ-quantized) + - ``iter.onnx`` — transformer decode graph (QDQ-quantized) - ``embeddings.onnx`` — token embedding table (fp32) - ``lm_head.onnx`` — vocab projection (w4a32 MatMulNBits) - ``genai_config.json`` + HF tokenizer files @@ -40,7 +40,6 @@ WinMLQwen3TransformerOnlyModel, ) from winml.modelkit.onnx import strip_node_attrs -from winml.modelkit.session import GenaiSession, GenerationConfig _DEVICE_TO_EP = { @@ -58,8 +57,6 @@ _REPO_ROOT = Path(__file__).resolve().parent.parent _DEFAULT_BUNDLE = _REPO_ROOT / "out" / "bundle" -_SUPPORTED_EPS = ["cpu", "mixed", "qnn", "dml"] - # Attributes that com.microsoft::GroupQueryAttention requires for Qwen3. # Any other attributes (e.g. k_quant_type, local_window_size, qk_output, # smooth_softmax, v_quant_type) are default-valued extras injected by the @@ -72,6 +69,7 @@ def _strip_gqa_default_attrs(model: onnx.ModelProto) -> onnx.ModelProto: """Remove exporter-injected default attributes from GQA nodes.""" return strip_node_attrs(model, "GroupQueryAttention", _GQA_KEEP_ATTRS, domain="com.microsoft") + # --------------------------------------------------------------------------- # Helpers shared between sub-commands # --------------------------------------------------------------------------- From 22d38ba827c92f1282a934124930cb90dbea8abc Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 3 Jul 2026 11:30:04 +0800 Subject: [PATCH 26/29] fix(qwen3): forward transformer_onnx_passes through genai bundle wrapper The Qwen3 write_genai_bundle wrapper did not accept or forward the transformer_onnx_passes argument used by the generic assembler and the export script (scripts/qwen3.py), so 'qwen3 export' crashed at bundle assembly with a TypeError. Add the parameter and forward it verbatim to winml.modelkit.utils.genai.write_genai_bundle. Add regression tests covering the pass-through and the ep-derived session_options forwarding. --- src/winml/modelkit/models/hf/qwen3/genai.py | 8 +++ tests/unit/models/qwen3/test_genai_config.py | 60 ++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py index 315f95fe6..c1568ffa7 100644 --- a/src/winml/modelkit/models/hf/qwen3/genai.py +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -36,8 +36,11 @@ if TYPE_CHECKING: + from collections.abc import Callable, Sequence from pathlib import Path + import onnx + # --------------------------------------------------------------------------- # Qwen3-specific QNN execution-provider routing @@ -161,6 +164,7 @@ def write_genai_bundle( lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, ep: str = "cpu", soc_model: str = "60", + transformer_onnx_passes: Sequence[Callable[[onnx.ModelProto], onnx.ModelProto]] | None = None, ) -> Path: """Assemble a Qwen3 genai bundle, routing ctx/iter to QNN when ``ep="qnn"``. @@ -174,6 +178,9 @@ def write_genai_bundle( HTP (NPU) backend; ``"cpu"`` (default) keeps every stage on CPU. soc_model: Snapdragon SoC model passed to the QNN backend when ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite. + transformer_onnx_passes: Optional ONNX graph transforms applied to the + copied context/iterator models before ``genai_config.json`` is + written. Forwarded verbatim to the generic assembler. Returns: Path to the written ``genai_config.json``. @@ -194,6 +201,7 @@ def write_genai_bundle( lm_head_filename=lm_head_filename, context_session_options=ctx_opts, iterator_session_options=iter_opts, + transformer_onnx_passes=transformer_onnx_passes, ) diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py index 747d9c92d..2b2fb9309 100644 --- a/tests/unit/models/qwen3/test_genai_config.py +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -7,6 +7,7 @@ from __future__ import annotations from types import SimpleNamespace +from typing import ClassVar from unittest.mock import patch from winml.modelkit.models.hf.qwen3 import ( @@ -14,6 +15,7 @@ PipelineStage, build_genai_config, build_qwen3_transformer_only_stages, + write_genai_bundle, ) from winml.modelkit.models.hf.qwen3.genai import ( DEFAULT_CONTEXT_FILENAME, @@ -517,3 +519,61 @@ def test_custom_soc_model(self) -> None: ) ctx = next(s for s in stages if s.name == "context") assert ctx.session_options["provider_options"][0]["qnn"]["soc_model"] == "73" + + +# --------------------------------------------------------------------------- +# Tests: write_genai_bundle wrapper (ep routing + transformer_onnx_passes) +# --------------------------------------------------------------------------- + + +class TestWriteGenaiBundleWrapper: + """The Qwen3 ``write_genai_bundle`` wrapper injects the QNN ``session_options`` + and forwards the generic keyword arguments — notably ``transformer_onnx_passes`` + — to :func:`winml.modelkit.utils.genai.write_genai_bundle`. The generic + assembler is mocked so no ONNX/tokenizer files are needed. + """ + + _COMMON: ClassVar[dict] = { + "context_onnx": "ctx.onnx", + "iterator_onnx": "iter.onnx", + "model_id": "Qwen/Qwen3-0.6B", + "max_cache_len": 2048, + "prefill_seq_len": 64, + } + + @staticmethod + def _patch_generic(): + return patch("winml.modelkit.models.hf.qwen3.genai._write_genai_bundle") + + def test_forwards_transformer_onnx_passes(self) -> None: + """A supplied transformer_onnx_passes list reaches the generic assembler.""" + + def _identity_pass(model): + return model + + with self._patch_generic() as mock_write: + write_genai_bundle("out", transformer_onnx_passes=[_identity_pass], **self._COMMON) + assert mock_write.call_count == 1 + assert mock_write.call_args.kwargs["transformer_onnx_passes"] == [_identity_pass] + + def test_transformer_onnx_passes_defaults_to_none(self) -> None: + """Omitting transformer_onnx_passes forwards None (no passes).""" + with self._patch_generic() as mock_write: + write_genai_bundle("out", **self._COMMON) + assert mock_write.call_args.kwargs["transformer_onnx_passes"] is None + + def test_qnn_ep_forwards_session_options(self) -> None: + """ep='qnn' forwards QNN session_options for the context and iterator stages.""" + with self._patch_generic() as mock_write: + write_genai_bundle("out", ep="qnn", **self._COMMON) + kwargs = mock_write.call_args.kwargs + assert kwargs["context_session_options"]["log_id"] == "onnxruntime-genai.context" + assert kwargs["iterator_session_options"]["log_id"] == "onnxruntime-genai.iterator" + + def test_cpu_ep_forwards_no_session_options(self) -> None: + """ep='cpu' (default) forwards None session_options for both stages.""" + with self._patch_generic() as mock_write: + write_genai_bundle("out", **self._COMMON) + kwargs = mock_write.call_args.kwargs + assert kwargs["context_session_options"] is None + assert kwargs["iterator_session_options"] is None From a8075a992558c911a7b424f952b0152bd5a065ad Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 3 Jul 2026 12:29:19 +0800 Subject: [PATCH 27/29] fix(genai): preserve pad_token_id==0 and clear CodeQL import/statement alerts - build_genai_config: keep a valid pad_token_id of 0 instead of falling back to bos_token_id (the `... or bos` form treated the falsy 0 as unset). - utils/genai.py: use a TYPE_CHECKING `from onnx import ModelProto` for the transformer_onnx_passes annotation so the module no longer carries a second `import onnx` (clears CodeQL py/repeated-import). - quant/calibration/base.py: drop the redundant `...` after the Protocol method docstring (clears CodeQL py/ineffectual-statement). - tests/unit/onnx/test_utils.py: import onnx symbols via a single `from onnx import ...` (clears CodeQL py/import-and-import-from). - tests: add a regression test asserting pad_token_id==0 is preserved. --- src/winml/modelkit/quant/calibration/base.py | 1 - src/winml/modelkit/utils/genai.py | 8 ++++--- tests/unit/models/qwen3/test_genai_config.py | 22 ++++++++++++++++++++ tests/unit/onnx/test_utils.py | 13 ++++++------ 4 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/winml/modelkit/quant/calibration/base.py b/src/winml/modelkit/quant/calibration/base.py index 213199811..39b9543c5 100644 --- a/src/winml/modelkit/quant/calibration/base.py +++ b/src/winml/modelkit/quant/calibration/base.py @@ -38,4 +38,3 @@ def finalize( model_id: str | None = None, ) -> WinMLQuantizationConfig: """Return ``quant`` populated with the graph-derived quant settings.""" - ... diff --git a/src/winml/modelkit/utils/genai.py b/src/winml/modelkit/utils/genai.py index 1c144da8d..0931d56d4 100644 --- a/src/winml/modelkit/utils/genai.py +++ b/src/winml/modelkit/utils/genai.py @@ -76,7 +76,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence - import onnx + from onnx import ModelProto logger = logging.getLogger(__name__) @@ -235,7 +235,9 @@ def build_genai_config( # silently discard secondary EOS tokens (e.g. Qwen3 uses [151645, 151643]) # and cause generation to run until max_length instead of stopping early. - pad_token_id = getattr(hf_config, "pad_token_id", None) or hf_config.bos_token_id + pad_token_id = getattr(hf_config, "pad_token_id", None) + if pad_token_id is None: + pad_token_id = hf_config.bos_token_id decoder_section: dict = { "hidden_size": hf_config.hidden_size, @@ -539,7 +541,7 @@ def write_genai_bundle( lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, context_session_options: dict | None = None, iterator_session_options: dict | None = None, - transformer_onnx_passes: Sequence[Callable[[onnx.ModelProto], onnx.ModelProto]] | None = None, + transformer_onnx_passes: Sequence[Callable[[ModelProto], ModelProto]] | None = None, ) -> Path: """Assemble a complete ``onnxruntime-genai`` bundle in *output_dir*. diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py index 2b2fb9309..66553d589 100644 --- a/tests/unit/models/qwen3/test_genai_config.py +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -308,6 +308,28 @@ def test_pad_token_id_falls_back_to_bos(self) -> None: ) assert result["model"]["pad_token_id"] == 0 # falls back to bos_token_id + def test_pad_token_id_zero_is_preserved(self) -> None: + """A valid pad_token_id of 0 must not be overwritten by bos_token_id. + + 0 is a common, valid pad id; a truthiness check would silently swap it + for bos_token_id and corrupt batched-padding generation. + """ + cfg = SimpleNamespace( + num_hidden_layers=2, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=64, + bos_token_id=5, + eos_token_id=1, + pad_token_id=0, + vocab_size=32000, + ) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2) + ) + assert result["model"]["pad_token_id"] == 0 + def test_different_layer_count(self) -> None: cfg = _mock_config(num_hidden_layers=4) result = build_genai_config( diff --git a/tests/unit/onnx/test_utils.py b/tests/unit/onnx/test_utils.py index 026811f36..de3b07aac 100644 --- a/tests/unit/onnx/test_utils.py +++ b/tests/unit/onnx/test_utils.py @@ -6,8 +6,7 @@ from __future__ import annotations -import onnx -from onnx import TensorProto, helper +from onnx import ModelProto, NodeProto, TensorProto, helper from winml.modelkit.onnx import strip_node_attrs @@ -17,7 +16,7 @@ # --------------------------------------------------------------------------- -def _make_gqa_model(attr_dict: dict[str, int]) -> onnx.ModelProto: +def _make_gqa_model(attr_dict: dict[str, int]) -> ModelProto: """Build a minimal model with a single com.microsoft::GroupQueryAttention node. Uses explicit ``make_attribute`` calls so attr names match what PyTorch's @@ -25,7 +24,7 @@ def _make_gqa_model(attr_dict: dict[str, int]) -> onnx.ModelProto: """ x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 64, 512]) y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 64, 512]) - node = onnx.NodeProto() + node = NodeProto() node.op_type = "GroupQueryAttention" node.domain = "com.microsoft" node.input.append("x") @@ -37,7 +36,7 @@ def _make_gqa_model(attr_dict: dict[str, int]) -> onnx.ModelProto: return helper.make_model(graph, opset_imports=[helper.make_opsetid("com.microsoft", 1)]) -def _attr_names(model: onnx.ModelProto) -> set[str]: +def _attr_names(model: ModelProto) -> set[str]: return {a.name for n in model.graph.node for a in n.attribute} @@ -125,8 +124,8 @@ def test_strip_multiple_gqa_nodes(): y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 64, 512]) z = helper.make_tensor_value_info("z", TensorProto.FLOAT, [1, 64, 512]) - def _gqa_node(name: str, inp: str, out: str) -> onnx.NodeProto: - node = onnx.NodeProto() + def _gqa_node(name: str, inp: str, out: str) -> NodeProto: + node = NodeProto() node.op_type = "GroupQueryAttention" node.domain = "com.microsoft" node.name = name From 0c22e2a0ce2facd8ff7cd1e6f3d9a0369dd6c96b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 3 Jul 2026 13:12:09 +0800 Subject: [PATCH 28/29] docs(qwen3): note deferred max_rope_len wiring at the transformer-only wrapper Explain why QwenTransformerOnlyDecoderWrapper leaves max_rope_len at its default: threading the build's max_cache_len down to this load-time hook needs generic model-loader plumbing, which is deferred to the follow-up PR. The apply_transformer_only_export_prep(..., max_rope_len=...) path is already implemented and unit-tested. --- src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py index 46cd13a6f..06cd3bdd9 100644 --- a/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py @@ -75,6 +75,12 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: self.model = model self.num_layers = num_layers self.config: PretrainedConfig = cast("PretrainedConfig", model.config) + # ``max_rope_len`` is intentionally left at its default here (the rope + # cache spans ``config.max_position_embeddings``). Pinning it to the + # build's ``max_cache_len`` needs that value threaded down to this + # load-time hook through the generic model loader, which is deferred to + # the follow-up PR. The ``apply_transformer_only_export_prep(..., + # max_rope_len=...)`` path is already implemented and unit-tested for it. apply_transformer_only_export_prep(model, matmul_to_conv=True) # Tag the config so the exporter resolves this variant's OnnxConfig # (registered under ``TRANSFORMER_ONLY_MODEL_TYPE``) rather than the From 272126e76461c4c91fad84004143392ee33b5001 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 3 Jul 2026 15:26:59 +0800 Subject: [PATCH 29/29] fix(qwen3): route export CLI via set_defaults(func) dispatch main() hard-dispatched to _cmd_export regardless of the parsed subcommand, so any future subcommand added to the add_subparsers scaffold would silently run export. Register each subparser's handler with set_defaults(func=...) and dispatch through args.func(args). --- scripts/qwen3.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/qwen3.py b/scripts/qwen3.py index 7d88938c1..1233c362d 100644 --- a/scripts/qwen3.py +++ b/scripts/qwen3.py @@ -155,6 +155,7 @@ def _add_export_parser(sub: argparse._SubParsersAction) -> None: # type: ignore help="Override path to a pre-built lm_head ONNX (skips auto-build).", ) p.add_argument("--force-rebuild", action="store_true", help="Rebuild even if cached.") + p.set_defaults(func=_cmd_export) def _cmd_export(args: argparse.Namespace) -> int: @@ -260,7 +261,9 @@ def main(argv: list[str] | None = None) -> int: _add_export_parser(sub) args = p.parse_args(argv) - return _cmd_export(args) + # Each subparser registers its handler via set_defaults(func=...); dispatch + # generically so new subcommands route to their own handler (not export). + return args.func(args) if __name__ == "__main__":