diff --git a/.gitignore b/.gitignore index 2184e9f72..17d96986b 100644 --- a/.gitignore +++ b/.gitignore @@ -241,6 +241,10 @@ nul export_config.json *_perf.json shape_config.json +/out/ + +# ONNX external data files at repo root (e.g. EPContext .data blobs) +/*.data # UV / pip uv.lock diff --git a/scripts/export_qwen3_embeddings_lm_head.py b/scripts/export_qwen3_embeddings_lm_head.py deleted file mode 100644 index 4769e4924..000000000 --- a/scripts/export_qwen3_embeddings_lm_head.py +++ /dev/null @@ -1,185 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""One-shot export of the Qwen3 embeddings + lm_head sub-models. - -Companion to ``export_qwen3_transformer_only.py``. That script builds the -transformer (prefill + decode). This one builds the two remaining pieces of the -split Qwen3 graph: - - - ``embeddings`` — input embedding table (``input_ids`` -> ``input_hidden_states``). - Built FLOAT (``precision="fp32"``); NOT quantized. Produces a ``Gather``. - - ``lm_head`` — vocab projection (``output_hidden_states`` -> ``logits``). - Quantized weight-only to int4 via MatMulNBits/RTN (``precision="w4a32"``). - Produces a ``MatMulNBits`` node (no float ``MatMul``). - -Both are standalone ``model_type`` builds, invoked separately. 
- -Usage:: - - # Build (or reuse cached) both ONNX, print their paths + node summary: - uv run python scripts/export_qwen3_embeddings_lm_head.py - - # Build only one of them: - uv run python scripts/export_qwen3_embeddings_lm_head.py --only embeddings - uv run python scripts/export_qwen3_embeddings_lm_head.py --only lm_head - - # Copy the ONNX (with external data) into a folder: - uv run python scripts/export_qwen3_embeddings_lm_head.py --output-dir out/qwen3 - # -> writes embeddings_fp32.onnx and lm_head_w4a32.onnx - - # Different model / device / seq geometry, force a rebuild: - uv run python scripts/export_qwen3_embeddings_lm_head.py \ - --model-id Qwen/Qwen3-0.6B --device npu --seq-len 64 --force-rebuild -""" - -from __future__ import annotations - -import argparse -import collections -import sys -import time -from pathlib import Path - -import onnx - -from winml.modelkit.models.auto import WinMLAutoModel -from winml.modelkit.onnx import copy_onnx_model - - -# Per-component build settings: which model_type to register against and the -# precision that drives its quant policy. Embeddings stay float (``fp32``); the -# lm_head is weight-only int4 with fp32 activations (``w4a32`` — the activations -# are NOT quantized, this is RTN/MatMulNBits). -_COMPONENTS = { - "embeddings": {"model_type": "qwen3_embeddings_only", "precision": "fp32"}, - "lm_head": {"model_type": "qwen3_lm_head_only", "precision": "w4a32"}, -} - -# Component -> output file stem (when --output-dir is given). The precision -# suffix is carried in the filename so the two pieces self-document their scheme. -_OUTPUT_STEMS = { - "embeddings": f"embeddings_{_COMPONENTS['embeddings']['precision']}", - "lm_head": f"lm_head_{_COMPONENTS['lm_head']['precision']}", -} - -# Default EP per device; CPU/NPU/GPU map to their canonical providers. 
-_DEVICE_TO_EP = { - "cpu": "CPUExecutionProvider", - "npu": "QNNExecutionProvider", - "gpu": "DmlExecutionProvider", -} - - -def node_summary(path: str | Path) -> str: - """Return a one-line structural summary of the graph's key ops. - - The interesting markers for these two sub-models are: - - embeddings: ``Gather`` present, no ``MatMulNBits`` / no QDQ (stays float). - - lm_head: ``MatMulNBits`` present, float ``MatMul`` gone (int4 weight-only). - """ - model = onnx.load(str(path), load_external_data=False) - counts = collections.Counter(n.op_type for n in model.graph.node) - return ( - f"Gather={counts['Gather']} MatMul={counts['MatMul']} " - f"MatMulNBits={counts['MatMulNBits']} " - f"Q={counts['QuantizeLinear']} DQ={counts['DequantizeLinear']}" - ) - - -def build_component(name: str, args: argparse.Namespace): - """Build (or reuse cached) one standalone sub-model and return it.""" - spec = _COMPONENTS[name] - print(f"\n=== building {name} (model_type={spec['model_type']}, " - f"precision={spec['precision']}) ===") - return WinMLAutoModel.from_pretrained( - args.model_id, - task="feature-extraction", - model_type=spec["model_type"], - precision=spec["precision"], - device=args.device, - ep=_DEVICE_TO_EP[args.device], - no_compile=args.no_compile, - use_cache=True, - force_rebuild=args.force_rebuild, - shape_config={"seq_len": args.seq_len}, - ) - - -def parse_args(argv: list[str] | None = None) -> argparse.Namespace: - """Parse command-line arguments.""" - p = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - p.add_argument("--model-id", default="Qwen/Qwen3-0.6B", help="HF model id or local path.") - p.add_argument( - "--device", - default="cpu", - choices=sorted(_DEVICE_TO_EP), - help="Target device (selects the canonical EP). Default: cpu.", - ) - p.add_argument( - "--only", - choices=sorted(_COMPONENTS), - default=None, - help="Build only this component. 
Default: build both.", - ) - p.add_argument( - "--seq-len", - type=int, - default=64, - help="Static sequence length baked into both sub-models. Default: 64.", - ) - p.add_argument( - "--no-compile", - dest="no_compile", - action="store_true", - default=True, - help="Skip EPContext compilation (default; these are consumed pre-compile).", - ) - p.add_argument( - "--compile", - dest="no_compile", - action="store_false", - help="Enable EPContext compilation (requires the device's compiler/SDK).", - ) - p.add_argument("--force-rebuild", action="store_true", help="Rebuild even if cached.") - p.add_argument( - "--output-dir", - type=Path, - default=None, - help="If set, copy the ONNX (with external data) here as " - "embeddings_fp32.onnx / lm_head_w4a32.onnx.", - ) - return p.parse_args(argv) - - -def main(argv: list[str] | None = None) -> int: - """Build (or reuse) the requested sub-models and report/copy them.""" - args = parse_args(argv) - - names = [args.only] if args.only else ["embeddings", "lm_head"] - - output_dir: Path | None = args.output_dir - if output_dir is not None: - output_dir.mkdir(parents=True, exist_ok=True) - - t0 = time.monotonic() - for name in names: - model = build_component(name, args) - src = Path(model.onnx_path) - print(f"[{name}] {src}") - print(f" {node_summary(src)}") - if output_dir is not None: - dst = output_dir / f"{_OUTPUT_STEMS[name]}.onnx" - copy_onnx_model(src, dst) - print(f" -> copied to {dst}") - - print(f"\n=== done in {time.monotonic() - t0:.1f}s ===") - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/export_qwen3_transformer_only.py b/scripts/export_qwen3_transformer_only.py deleted file mode 100644 index 6894af518..000000000 --- a/scripts/export_qwen3_transformer_only.py +++ /dev/null @@ -1,171 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------------------- -"""One-shot export of the Qwen3 transformer-only prefill + decode pair. - -Leverages the registered ``WinMLQwen3TransformerOnlyModel`` composite to build -BOTH transformer-only sub-models in a single call: - - - ``decoder_prefill`` — context graph, ``seq_len`` = --prefill-seq-len (64) - - ``decoder_gen`` — iteration graph, ``seq_len`` = 1 - -Each sub-model is built through the standard ``build_hf_model`` pipeline, so the -model-type quant finalizer is applied (int8 weight / uint16 activation, GQA -excluded from QDQ). Embeddings and the LM head are NOT part of this graph — they -run separately (e.g. from the bundle). - -Usage:: - - # Build (or reuse cached) both ONNX, print their paths + node summary: - uv run python scripts/export_qwen3_transformer_only.py - - # Copy the two ONNX (with external data) into a folder: - uv run python scripts/export_qwen3_transformer_only.py --output-dir out/qwen3 - - # Different model / device / cache geometry, force a rebuild: - uv run python scripts/export_qwen3_transformer_only.py \ - --model-id Qwen/Qwen3-0.6B --device npu \ - --max-cache-len 256 --prefill-seq-len 64 --force-rebuild -""" - -from __future__ import annotations - -import argparse -import collections -import sys -import time -from pathlib import Path - -import onnx - -from winml.modelkit.models.hf.qwen3.qwen_transformer_only import ( - WinMLQwen3TransformerOnlyModel, -) -from winml.modelkit.onnx import copy_onnx_model - - -# Component name -> output file stem used when --output-dir is given. -_OUTPUT_STEMS = { - "decoder_prefill": "prefill", - "decoder_gen": "decode", -} - -# Default EP per device; CPU/NPU/GPU map to their canonical providers. 
-_DEVICE_TO_EP = { - "cpu": "CPUExecutionProvider", - "npu": "QNNExecutionProvider", - "gpu": "DmlExecutionProvider", -} - - -def node_summary(path: str | Path) -> str: - """Return a one-line QDQ/GQA structural summary of an ONNX graph.""" - model = onnx.load(str(path), load_external_data=False) - counts = collections.Counter(n.op_type for n in model.graph.node) - gqa_io: set[str] = set() - for node in model.graph.node: - if node.op_type == "GroupQueryAttention": - gqa_io.update(node.input) - gqa_io.update(node.output) - qdq_touching_gqa = sum( - 1 - for n in model.graph.node - if n.op_type in ("QuantizeLinear", "DequantizeLinear") - and (set(n.input) & gqa_io or set(n.output) & gqa_io) - ) - return ( - f"Q={counts['QuantizeLinear']} DQ={counts['DequantizeLinear']} " - f"GQA={counts['GroupQueryAttention']} QDQ-touching-GQA={qdq_touching_gqa}" - ) - - -def parse_args(argv: list[str] | None = None) -> argparse.Namespace: - """Parse command-line arguments.""" - p = argparse.ArgumentParser( - description=__doc__, - formatter_class=argparse.RawDescriptionHelpFormatter, - ) - p.add_argument("--model-id", default="Qwen/Qwen3-0.6B", help="HF model id or local path.") - p.add_argument( - "--device", - default="cpu", - choices=sorted(_DEVICE_TO_EP), - help="Target device (selects the canonical EP). Default: cpu.", - ) - p.add_argument("--precision", default="w8a16", help="Build precision. 
Default: w8a16.") - p.add_argument("--max-cache-len", type=int, default=256, help="Static KV cache length.") - p.add_argument( - "--prefill-seq-len", - type=int, - default=64, - help="Prefill/context sequence length.", - ) - p.add_argument( - "--no-compile", - dest="no_compile", - action="store_true", - default=True, - help="Skip EPContext compilation (default; transformer-only is consumed pre-compile).", - ) - p.add_argument( - "--compile", - dest="no_compile", - action="store_false", - help="Enable EPContext compilation (requires the device's compiler/SDK).", - ) - p.add_argument("--force-rebuild", action="store_true", help="Rebuild even if cached.") - p.add_argument( - "--output-dir", - type=Path, - default=None, - help="If set, copy the two ONNX (with external data) here as prefill.onnx / decode.onnx.", - ) - return p.parse_args(argv) - - -def main(argv: list[str] | None = None) -> int: - """Build (or reuse) both transformer-only ONNX and report/copy them.""" - args = parse_args(argv) - - t0 = time.monotonic() - model = WinMLQwen3TransformerOnlyModel.from_pretrained( - args.model_id, - device=args.device, - precision=args.precision, - ep=_DEVICE_TO_EP[args.device], - no_compile=args.no_compile, - use_cache=True, - force_rebuild=args.force_rebuild, - sub_model_kwargs={ - "decoder_prefill": { - "shape_config": { - "max_cache_len": args.max_cache_len, - "seq_len": args.prefill_seq_len, - } - }, - "decoder_gen": {"shape_config": {"max_cache_len": args.max_cache_len, "seq_len": 1}}, - }, - ) - elapsed = time.monotonic() - t0 - - print(f"\n=== transformer-only build done in {elapsed:.1f}s ===") - - output_dir: Path | None = args.output_dir - if output_dir is not None: - output_dir.mkdir(parents=True, exist_ok=True) - - for name, sub in model.sub_models.items(): - src = Path(sub.onnx_path) - print(f"\n[{name}] {src}") - print(f" {node_summary(src)}") - if output_dir is not None: - dst = output_dir / f"{_OUTPUT_STEMS.get(name, name)}.onnx" - copy_onnx_model(src, dst) - 
print(f" -> copied to {dst}") - - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/qwen3.py b/scripts/qwen3.py new file mode 100644 index 000000000..1233c362d --- /dev/null +++ b/scripts/qwen3.py @@ -0,0 +1,270 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""Qwen3 genai bundle export. + +Builds (or reuses) all four components of the Qwen3 genai bundle and assembles +them into an onnxruntime-genai directory: + + - ``ctx.onnx`` — transformer prefill graph (QDQ-quantized) + - ``iter.onnx`` — transformer decode graph (QDQ-quantized) + - ``embeddings.onnx`` — token embedding table (fp32) + - ``lm_head.onnx`` — vocab projection (w4a32 MatMulNBits) + - ``genai_config.json`` + HF tokenizer files + +Inference over the assembled bundle is provided separately (see the genai +inference session), so this script only covers bundle generation. + +Usage:: + + # Full export to out/bundle (default context_length=2048): + uv run python scripts/qwen3.py export --device npu --output out/bundle + + # Force rebuild from scratch: + uv run python scripts/qwen3.py export --device npu --output out/bundle --force-rebuild +""" + +from __future__ import annotations + +import argparse +import collections +import sys +import time +from pathlib import Path + +import onnx + +from winml.modelkit.models.auto import WinMLAutoModel +from winml.modelkit.models.hf.qwen3.qwen_transformer_only import ( + WinMLQwen3TransformerOnlyModel, +) +from winml.modelkit.onnx import strip_node_attrs + + +_DEVICE_TO_EP = { + "cpu": "CPUExecutionProvider", + "npu": "QNNExecutionProvider", + "gpu": "DmlExecutionProvider", +} + +# Build specs for the two CPU-side companion models. 
+_COMPANION_COMPONENTS: dict[str, dict[str, str]] = { + "embeddings": {"model_type": "qwen3_embeddings_only", "precision": "fp32"}, + "lm_head": {"model_type": "qwen3_lm_head_only", "precision": "w4a32"}, +} + +_REPO_ROOT = Path(__file__).resolve().parent.parent +_DEFAULT_BUNDLE = _REPO_ROOT / "out" / "bundle" + +# Attributes that com.microsoft::GroupQueryAttention requires for Qwen3. +# Any other attributes (e.g. k_quant_type, local_window_size, qk_output, +# smooth_softmax, v_quant_type) are default-valued extras injected by the +# TorchScript exporter from the ORT op schema; strip them so the bundle +# matches the expected minimal attribute set. +_GQA_KEEP_ATTRS = frozenset({"do_rotary", "kv_num_heads", "num_heads"}) + + +def _strip_gqa_default_attrs(model: onnx.ModelProto) -> onnx.ModelProto: + """Remove exporter-injected default attributes from GQA nodes.""" + return strip_node_attrs(model, "GroupQueryAttention", _GQA_KEEP_ATTRS, domain="com.microsoft") + + +# --------------------------------------------------------------------------- +# Helpers shared between sub-commands +# --------------------------------------------------------------------------- + + +def _node_summary(path: str | Path) -> str: + """One-line op-type summary of an ONNX graph (external tensor data is not loaded).""" + model = onnx.load(str(path), load_external_data=False) + counts = collections.Counter(n.op_type for n in model.graph.node) + gqa_io: set[str] = set() + for node in model.graph.node: + if node.op_type == "GroupQueryAttention": + gqa_io.update(node.input) + gqa_io.update(node.output) + qdq_touching_gqa = sum( + 1 + for n in model.graph.node + if n.op_type in ("QuantizeLinear", "DequantizeLinear") + and (set(n.input) & gqa_io or set(n.output) & gqa_io) + ) + return ( + f"Gather={counts['Gather']} MatMulNBits={counts['MatMulNBits']} " + f"Q={counts['QuantizeLinear']} DQ={counts['DequantizeLinear']} " + f"GQA={counts['GroupQueryAttention']} QDQ@GQA={qdq_touching_gqa}" + ) + + +# 
--------------------------------------------------------------------------- +# export sub-command +# --------------------------------------------------------------------------- + + +def _add_export_parser(sub: argparse._SubParsersAction) -> None: # type: ignore[type-arg] + p = sub.add_parser( + "export", + help="Build the full Qwen3 genai bundle (transformer + embeddings + lm_head).", + formatter_class=argparse.RawDescriptionHelpFormatter, + description=( + "Exports all four Qwen3 genai components and assembles them into an " + "onnxruntime-genai bundle directory. Transformer stages (ctx/iter) are " + "built for the target device; embeddings and lm_head always run on CPU." + ), + ) + p.add_argument("--model-id", default="Qwen/Qwen3-0.6B", help="HF model id or local path.") + p.add_argument( + "--device", + default="npu", + choices=sorted(_DEVICE_TO_EP), + help="Target device for transformer stages. Default: npu.", + ) + p.add_argument("--precision", default="w8a16", help="Transformer precision. Default: w8a16.") + p.add_argument( + "--max-cache-len", + type=int, + default=2048, + help="Static KV cache length (context_length). Default: 2048.", + ) + p.add_argument( + "--prefill-seq-len", + type=int, + default=64, + help="Prefill/context sequence length baked into ctx.onnx. Default: 64.", + ) + p.add_argument( + "--output", + type=Path, + default=_DEFAULT_BUNDLE, + metavar="DIR", + help=f"Bundle output directory. 
Default: {_DEFAULT_BUNDLE}.", + ) + p.add_argument( + "--embeddings", + type=Path, + default=None, + metavar="ONNX", + help="Override path to a pre-built embeddings ONNX (skips auto-build).", + ) + p.add_argument( + "--lm-head", + type=Path, + default=None, + metavar="ONNX", + help="Override path to a pre-built lm_head ONNX (skips auto-build).", + ) + p.add_argument("--force-rebuild", action="store_true", help="Rebuild even if cached.") + p.set_defaults(func=_cmd_export) + + +def _cmd_export(args: argparse.Namespace) -> int: + """Build all components and write the genai bundle.""" + from winml.modelkit.models.hf.qwen3.genai import write_genai_bundle + + t0 = time.monotonic() + + # --- Transformer (ctx + iter) --- + print(f"\n=== building transformer stages (device={args.device}) ===") + transformer = WinMLQwen3TransformerOnlyModel.from_pretrained( + args.model_id, + device=args.device, + precision=args.precision, + ep=_DEVICE_TO_EP[args.device], + no_compile=True, + use_cache=True, + force_rebuild=args.force_rebuild, + sub_model_kwargs={ + "decoder_prefill": { + "shape_config": { + "max_cache_len": args.max_cache_len, + "seq_len": args.prefill_seq_len, + } + }, + "decoder_gen": {"shape_config": {"max_cache_len": args.max_cache_len, "seq_len": 1}}, + }, + ) + prefill_path = Path(transformer.sub_models["decoder_prefill"].onnx_path) + decode_path = Path(transformer.sub_models["decoder_gen"].onnx_path) + for name, path in (("ctx", prefill_path), ("iter", decode_path)): + print(f" [{name}] {path}") + print(f" {_node_summary(path)}") + + # --- Companion models (embeddings + lm_head) --- + embeddings_src = args.embeddings + lm_head_src = args.lm_head + + for key, override in (("embeddings", embeddings_src), ("lm_head", lm_head_src)): + if override is not None: + print(f"\n=== using provided {key}: {override} ===") + continue + spec = _COMPANION_COMPONENTS[key] + print( + f"\n=== building {key} " + f"(model_type={spec['model_type']}, precision={spec['precision']}) ===" + ) 
+ companion = WinMLAutoModel.from_pretrained( + args.model_id, + task="feature-extraction", + model_type=spec["model_type"], + precision=spec["precision"], + device="cpu", + ep=_DEVICE_TO_EP["cpu"], + no_compile=True, + use_cache=True, + force_rebuild=args.force_rebuild, + ) + companion_path = Path(companion.onnx_path) + print(f" [{key}] {companion_path}") + print(f" {_node_summary(companion_path)}") + if key == "embeddings": + embeddings_src = companion_path + else: + lm_head_src = companion_path + + # --- Assemble bundle --- + print(f"\n=== assembling bundle -> {args.output} ===") + config_path = write_genai_bundle( + args.output, + context_onnx=prefill_path, + iterator_onnx=decode_path, + model_id=args.model_id, + max_cache_len=args.max_cache_len, + prefill_seq_len=args.prefill_seq_len, + embeddings_src=embeddings_src, + lm_head_src=lm_head_src, + ep="qnn" if args.device == "npu" else args.device, + transformer_onnx_passes=[_strip_gqa_default_attrs], + ) + print(f" genai_config.json -> {config_path}") + + elapsed = time.monotonic() - t0 + print(f"\n=== export complete in {elapsed:.1f}s ===") + return 0 + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + + +def main(argv: list[str] | None = None) -> int: + """Parse sub-command and dispatch to the appropriate handler.""" + p = argparse.ArgumentParser( + prog="qwen3", + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + sub = p.add_subparsers(dest="command", metavar="") + sub.required = True + + _add_export_parser(sub) + + args = p.parse_args(argv) + # Each subparser registers its handler via set_defaults(func=...); dispatch + # generically so new subcommands route to their own handler (not export). 
+ return args.func(args) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/winml/modelkit/models/hf/qwen3/__init__.py b/src/winml/modelkit/models/hf/qwen3/__init__.py index 332fb9234..dbabe2d60 100644 --- a/src/winml/modelkit/models/hf/qwen3/__init__.py +++ b/src/winml/modelkit/models/hf/qwen3/__init__.py @@ -3,4 +3,30 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Qwen3 transformer-only export support (modeling, export ops, IO configs).""" +"""Qwen3 transformer-only export + genai bundle support. + +Modules: + qwen_transformer_only — OnnxConfig, build config, composite model class. + qwen3_modeling — winml-owned Qwen3 module definitions (forward bindings). + qwen3_export_ops — custom ONNX symbolic ops (LpNorm, GQA, 1x1 Conv). + genai — genai_config.json generator + bundle assembler. +""" + +from .genai import ( + DecoderIOMapping, + PipelineStage, + build_decoder_pipeline_stages, + build_genai_config, + build_qwen3_transformer_only_stages, + write_genai_bundle, +) + + +__all__ = [ + "DecoderIOMapping", + "PipelineStage", + "build_decoder_pipeline_stages", + "build_genai_config", + "build_qwen3_transformer_only_stages", + "write_genai_bundle", +] diff --git a/src/winml/modelkit/models/hf/qwen3/genai.py b/src/winml/modelkit/models/hf/qwen3/genai.py new file mode 100644 index 000000000..c1568ffa7 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3/genai.py @@ -0,0 +1,220 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Qwen3 genai bundle support built on :mod:`winml.modelkit.utils.genai`. 
+ +The generic, execution-provider-agnostic machinery (``PipelineStage``, +``DecoderIOMapping``, ``build_genai_config``, ``build_decoder_pipeline_stages``, +``write_genai_bundle``) lives in :mod:`winml.modelkit.utils.genai` so it can be +reused by other model families. + +This module adds the **Qwen3-specific** layer on top: the Qwen3 transformer +stages target the QNN HTP (NPU) backend, so this is where the QNN +``session_options`` are constructed. Keeping the EP-specific logic here lets the +generic utilities stay universal while the Qwen3 bundle keeps emitting the exact +same ``genai_config.json`` as before. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ....utils.genai import ( + DEFAULT_CONTEXT_FILENAME, + DEFAULT_EMBEDDINGS_FILENAME, + DEFAULT_ITERATOR_FILENAME, + DEFAULT_LM_HEAD_FILENAME, + DecoderIOMapping, + PipelineStage, + build_decoder_pipeline_stages, + build_genai_config, +) +from ....utils.genai import ( + write_genai_bundle as _write_genai_bundle, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from pathlib import Path + + import onnx + + +# --------------------------------------------------------------------------- +# Qwen3-specific QNN execution-provider routing +# --------------------------------------------------------------------------- + + +def qnn_stage_session_options(log_id: str, soc_model: str = "60") -> dict: + """Return the ``session_options`` block that routes a stage to QNN HTP. + + Args: + log_id: ORT log identifier (shown in ORT logs), e.g. + ``"onnxruntime-genai.context"``. + soc_model: Snapdragon SoC model number passed to the QNN HTP backend. + ``"60"`` targets Snapdragon 8 Gen 3 (X Elite). Change for other + SoCs (e.g. ``"55"`` for 8 Gen 2, ``"73"`` for 8 Elite). + + Returns: + Dict suitable for the ``session_options`` key of a pipeline stage in + ``genai_config.json``. 
+ """ + return { + "log_id": log_id, + "provider_options": [ + { + "qnn": { + "backend_path": "QnnHtp.dll", + "htp_performance_mode": "burst", + "htp_graph_finalization_optimization_mode": "3", + "soc_model": soc_model, + } + } + ], + "intra_op_num_threads": 2, + "inter_op_num_threads": 1, + } + + +def _stage_session_options(ep: str, soc_model: str) -> tuple[dict | None, dict | None]: + """Return ``(context, iterator)`` session_options for the given EP. + + ``ep="qnn"`` routes the transformer stages to the QNN HTP (NPU) backend; any + other value (e.g. ``"cpu"``) leaves them on the default CPU provider. + """ + if ep == "qnn": + return ( + qnn_stage_session_options("onnxruntime-genai.context", soc_model=soc_model), + qnn_stage_session_options("onnxruntime-genai.iterator", soc_model=soc_model), + ) + return None, None + + +# --------------------------------------------------------------------------- +# Qwen3-specific stage factory + bundle assembler +# --------------------------------------------------------------------------- + + +def build_qwen3_transformer_only_stages( + context_onnx: str | Path, + iterator_onnx: str | Path, + num_layers: int, + *, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", +) -> tuple[list[PipelineStage], DecoderIOMapping]: + """Build the Qwen3 4-stage pipeline, routing ctx/iter to QNN when ``ep="qnn"``. + + Qwen3-specific wrapper over + :func:`winml.modelkit.utils.genai.build_decoder_pipeline_stages` that injects + the QNN ``session_options`` for the transformer stages. Tensor names are + still discovered by introspecting the ONNX graphs, so nothing is hardcoded. + + Args: + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. 
+ num_layers: Number of transformer layers (``hf_config.num_hidden_layers``). + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. + ep: ``"qnn"`` injects QNN HTP ``session_options`` into the ``context`` + and ``iterator`` stages so they run on the NPU while ``embeddings`` + and ``lm_head`` stay on CPU. ``"cpu"`` (default) omits them. + soc_model: Snapdragon SoC model number forwarded to the QNN backend when + ``ep="qnn"``. Default ``"60"`` targets Snapdragon 8 Gen 3. + + Returns: + ``(stages, decoder_io)`` — see + :func:`~winml.modelkit.utils.genai.build_decoder_pipeline_stages`. + """ + ctx_opts, iter_opts = _stage_session_options(ep, soc_model) + return build_decoder_pipeline_stages( + context_onnx, + iterator_onnx, + num_layers, + context_filename=context_filename, + iterator_filename=iterator_filename, + embeddings_filename=embeddings_filename, + lm_head_filename=lm_head_filename, + context_session_options=ctx_opts, + iterator_session_options=iter_opts, + ) + + +def write_genai_bundle( + output_dir: str | Path, + *, + context_onnx: str | Path, + iterator_onnx: str | Path, + model_id: str, + max_cache_len: int, + prefill_seq_len: int, + embeddings_src: str | Path | None = None, + lm_head_src: str | Path | None = None, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + ep: str = "cpu", + soc_model: str = "60", + transformer_onnx_passes: Sequence[Callable[[onnx.ModelProto], onnx.ModelProto]] | None = None, +) -> Path: + """Assemble a Qwen3 genai bundle, routing ctx/iter to QNN when ``ep="qnn"``. 
+ + Qwen3-specific wrapper over + :func:`winml.modelkit.utils.genai.write_genai_bundle` that supplies the QNN + ``session_options`` for the transformer stages. See the generic function for + the description of every other argument. + + Args: + ep: ``"qnn"`` routes the transformer (context/iterator) stages to the QNN + HTP (NPU) backend; ``"cpu"`` (default) keeps every stage on CPU. + soc_model: Snapdragon SoC model passed to the QNN backend when + ``ep="qnn"``. Default ``"60"`` = Snapdragon 8 Gen 3 / X Elite. + transformer_onnx_passes: Optional ONNX graph transforms applied to the + copied context/iterator models before ``genai_config.json`` is + written. Forwarded verbatim to the generic assembler. + + Returns: + Path to the written ``genai_config.json``. + """ + ctx_opts, iter_opts = _stage_session_options(ep, soc_model) + return _write_genai_bundle( + output_dir, + context_onnx=context_onnx, + iterator_onnx=iterator_onnx, + model_id=model_id, + max_cache_len=max_cache_len, + prefill_seq_len=prefill_seq_len, + embeddings_src=embeddings_src, + lm_head_src=lm_head_src, + context_filename=context_filename, + iterator_filename=iterator_filename, + embeddings_filename=embeddings_filename, + lm_head_filename=lm_head_filename, + context_session_options=ctx_opts, + iterator_session_options=iter_opts, + transformer_onnx_passes=transformer_onnx_passes, + ) + + +__all__ = [ + "DEFAULT_CONTEXT_FILENAME", + "DEFAULT_EMBEDDINGS_FILENAME", + "DEFAULT_ITERATOR_FILENAME", + "DEFAULT_LM_HEAD_FILENAME", + "DecoderIOMapping", + "PipelineStage", + "build_decoder_pipeline_stages", + "build_genai_config", + "build_qwen3_transformer_only_stages", + "qnn_stage_session_options", + "write_genai_bundle", +] diff --git a/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py index 140c2dac8..e691038d5 100644 --- a/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py +++ b/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py @@ -96,8 +96,18 
@@ class WinMLQwen3Attention(nn.Module): v_proj: nn.Module o_proj: nn.Module - def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: - """Optionally swap the Q/K/V/O projections for 1x1 convs.""" + def prepare_for_onnx_export( + self, *, matmul_to_conv: bool, max_rope_len: int | None = None + ) -> None: + """Optionally swap the Q/K/V/O projections for 1x1 convs. + + ``max_rope_len`` pins the cos/sin rope cache to exactly that many + positions so the exported ONNX constant has shape ``[max_rope_len, 64]`` + instead of ``[config.max_position_embeddings, 64]``. It must be a + plain Python ``int`` (not a tensor) so the ``torch.arange`` call in + ``forward`` bakes a static-shape constant into the ONNX graph rather + than a dynamic symbolic range. + """ if matmul_to_conv: self.q_proj = TransposeConv2d1x1Transpose.from_linear_module( cast("nn.Linear", self.q_proj) @@ -112,6 +122,9 @@ def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: cast("nn.Linear", self.o_proj) ) self._matmul_to_conv = matmul_to_conv + self._max_rope_len: int = ( + int(max_rope_len) if max_rope_len is not None else self.config.max_position_embeddings + ) def forward( self, @@ -155,7 +168,7 @@ def forward( cos, sin = cast("nn.Module", self.rotary_emb)( value_states, - torch.arange(self.config.max_position_embeddings).unsqueeze(0), + torch.arange(self._max_rope_len).unsqueeze(0), ) cos = cos.squeeze(0)[:, : cos.shape[-1] // 2] sin = sin.squeeze(0)[:, : sin.shape[-1] // 2] @@ -273,7 +286,7 @@ def forward( def apply_transformer_only_export_prep( - causal_lm: nn.Module, *, matmul_to_conv: bool = True + causal_lm: nn.Module, *, matmul_to_conv: bool = True, max_rope_len: int | None = None ) -> None: """Mutate ``Qwen3ForCausalLM`` in-place into the export topology. @@ -287,6 +300,11 @@ def apply_transformer_only_export_prep( causal_lm: A ``transformers.Qwen3ForCausalLM`` instance. matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so QNN sees them as Conv. 
+ max_rope_len: Number of positions to pre-compute for the cos/sin rope + cache. Pass the model's ``max_cache_len`` (e.g. 4096) so the + exported constant has shape ``[max_rope_len, 64]`` instead of + ``[config.max_position_embeddings, 64]`` (40960). Must be a plain + Python ``int`` — see ``WinMLQwen3Attention.prepare_for_onnx_export``. Raises: RuntimeError: If any expected Qwen3 submodule class is not found, @@ -321,7 +339,9 @@ def _is(module: nn.Module, name: str) -> bool: for mod in causal_lm.modules(): if _is(mod, "Qwen3Attention"): WinMLQwen3Attention.prepare_for_onnx_export( - cast("WinMLQwen3Attention", mod), matmul_to_conv=matmul_to_conv + cast("WinMLQwen3Attention", mod), + matmul_to_conv=matmul_to_conv, + max_rope_len=max_rope_len, ) _bind(mod, WinMLQwen3Attention) patched["Qwen3Attention"] += 1 diff --git a/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py index 46cd13a6f..06cd3bdd9 100644 --- a/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py @@ -75,6 +75,12 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: self.model = model self.num_layers = num_layers self.config: PretrainedConfig = cast("PretrainedConfig", model.config) + # ``max_rope_len`` is intentionally left at its default here (the rope + # cache spans ``config.max_position_embeddings``). Pinning it to the + # build's ``max_cache_len`` needs that value threaded down to this + # load-time hook through the generic model loader, which is deferred to + # the follow-up PR. The ``apply_transformer_only_export_prep(..., + # max_rope_len=...)`` path is already implemented and unit-tested for it. 
def strip_node_attrs(
    model: onnx.ModelProto,
    op_type: str,
    keep_attrs: frozenset[str] | set[str],
    domain: str = "",
) -> onnx.ModelProto:
    """Remove all attributes from matching nodes except those in *keep_attrs*.

    Useful for stripping default-valued optional attributes that an exporter
    fills in automatically but that are not needed at inference time.

    Modifies *model* **in-place** and also returns it for convenient chaining.

    Args:
        model: ONNX model proto to modify.
        op_type: Operator type string (e.g. ``"GroupQueryAttention"``).
        keep_attrs: Attribute names to retain; every other attribute is removed.
        domain: Operator domain to match (e.g. ``"com.microsoft"``). The
            comparison is an exact string match against ``node.domain``; the
            default ``""`` selects nodes whose domain field is empty.

    Returns:
        The same *model* object (mutated in-place).
    """
    for node in model.graph.node:
        if node.op_type != op_type or node.domain != domain:
            continue
        attrs = node.attribute
        # Walk the list backwards so that deleting an entry never shifts the
        # indices that are still to be visited. This drops every unwanted
        # attribute in a single pass and keeps the survivors in their
        # original order.
        for idx in range(len(attrs) - 1, -1, -1):
            if attrs[idx].name not in keep_attrs:
                del attrs[idx]
    return model
if TYPE_CHECKING: from .calibration import get_quant_finalizer from .quantizer import Quantizer, expand_precision, quantize_onnx @@ -62,7 +64,6 @@ "quantize_onnx": (".quantizer", "quantize_onnx"), "Quantizer": (".quantizer", "Quantizer"), "expand_precision": (".quantizer", "expand_precision"), - "get_quant_finalizer": (".calibration", "get_quant_finalizer"), } diff --git a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py index 04901ca7c..3e2637be6 100644 --- a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py +++ b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py @@ -370,6 +370,8 @@ def finalize_transformer_only_quant_config( # quantizer dispatch on ``config.mode`` (fp16/rtn/static), so a build whose # precision policy resolved to "fp16"/"rtn" would otherwise bypass QDQ and # silently ignore the calibration reader + GQA exclusion set below. + # uint16 activations give higher precision than uint8; the trade-off is that + # ORT requires opset >= 21 for 16-bit QDQ, so ctx/iter will export at opset 21. quant.mode = "static" quant.weight_type = "int8" quant.activation_type = "uint16" diff --git a/src/winml/modelkit/utils/genai.py b/src/winml/modelkit/utils/genai.py new file mode 100644 index 000000000..0931d56d4 --- /dev/null +++ b/src/winml/modelkit/utils/genai.py @@ -0,0 +1,707 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""Generic onnxruntime-genai bundle utilities for decoder-pipeline models. + +The bundle is a directory that ``onnxruntime-genai`` can load directly via +``og.Config(str(bundle_dir))``. 
It contains: + + genai_config.json — pipeline config consumed by onnxruntime-genai + ctx.onnx — prefill/context ONNX + iter.onnx — iteration/decode ONNX + embeddings.onnx — embedding-lookup ONNX + lm_head.onnx — lm_head ONNX + tokenizer.json — HF tokenizer files (downloaded from the model repo) + tokenizer_config.json + vocab.json / merges.txt / generation_config.json + +The pipeline follows the standard 4-stage decoder layout: + + input_ids → [embeddings] → input_hidden_states + → [context | iterator] → output_hidden_states + present KVs + → [lm_head] → logits + +The context stage runs on the prompt (prefill); the iterator stage runs on each +subsequent decode step. Both share the same KV cache buffer via genai's +``past_present_share_buffer`` mode. + +Per-stage execution-provider routing (e.g. running the transformer stages on an +NPU) is expressed through the generic ``PipelineStage.session_options`` field and +is supplied by the caller — this module is itself execution-provider-agnostic and +hardcodes no EP-specific settings. 
@dataclass
class PipelineStage:
    """One stage in an onnxruntime-genai multi-model pipeline.

    Attributes:
        name: Stage key used inside the ``pipeline`` list of ``genai_config.json``.
        filename: ONNX filename inside the bundle directory.
        run_on_prompt: Whether genai runs this stage during the prefill pass.
        run_on_token_gen: Whether genai runs this stage during decode steps.
        inputs: Actual ONNX input tensor names (not format strings).
        outputs: Actual ONNX output tensor names (not format strings).
        is_lm_head: Set ``True`` for the final language-model head stage.
        session_options: Per-stage ORT session options (e.g.
            execution-provider selection and provider_options). When set to a
            non-empty dict, its contents are emitted as the
            ``session_options`` key of the serialized stage. Leave ``None``
            (default) for stages that should run on the default (CPU)
            provider.
    """

    name: str
    filename: str
    run_on_prompt: bool
    run_on_token_gen: bool
    inputs: list[str]
    outputs: list[str]
    is_lm_head: bool = False
    session_options: dict | None = None

    def to_dict(self) -> dict:
        """Serialize to the dict format expected by ``genai_config.json``.

        The returned dict owns its data: ``inputs``/``outputs`` are copied and
        ``session_options`` is deep-copied, so editing the serialized config
        never mutates this stage (or a options dict shared between stages).

        Returns:
            A dict with ``filename``, ``inputs``, ``outputs``,
            ``run_on_prompt`` and ``run_on_token_gen``; ``session_options`` is
            added only when non-empty and ``is_lm_head`` only when ``True``.
        """
        # Local import keeps this block self-contained without touching the
        # module-level import section.
        from copy import deepcopy

        d: dict = {
            "filename": self.filename,
            "inputs": list(self.inputs),
            "outputs": list(self.outputs),
            "run_on_prompt": self.run_on_prompt,
            "run_on_token_gen": self.run_on_token_gen,
        }
        if self.session_options:
            # Deep copy: provider options are nested (dicts inside lists), so
            # a shallow copy would still alias the caller's inner objects.
            d["session_options"] = deepcopy(self.session_options)
        if self.is_lm_head:
            d["is_lm_head"] = True
        return d
def build_genai_config(
    hf_config: Any,
    *,
    max_cache_len: int,
    prefill_seq_len: int | None = None,
    pipeline: list[PipelineStage],
    decoder_io: DecoderIOMapping | None = None,
) -> dict:
    """Build a ``genai_config.json`` dict for any decoder-pipeline model.

    This function is architecture-agnostic: the caller supplies the pipeline
    stages and the I/O name mapping so no tensor names are hardcoded here.

    Args:
        hf_config: A ``transformers.PretrainedConfig``-like object. Reads:
            ``num_hidden_layers``, ``hidden_size``, ``num_attention_heads``,
            ``num_key_value_heads``, ``head_dim`` (optional; when missing or
            ``None`` it falls back to ``hidden_size // num_attention_heads``),
            ``bos_token_id``, ``eos_token_id``, ``pad_token_id`` (optional;
            when missing or ``None`` it falls back to ``bos_token_id``) and
            ``vocab_size``.
        max_cache_len: Static KV cache length → ``context_length`` and
            ``search.max_length``.
        prefill_seq_len: When given, emits a ``sliding_window`` section with
            ``window_size=prefill_seq_len``. Pass ``None`` to omit.
        pipeline: Ordered list of :class:`PipelineStage` describing each
            model in the genai pipeline.
        decoder_io: Format-string mapping from genai's abstract I/O names to
            actual ONNX tensor names. Defaults to
            :class:`DecoderIOMapping` (the standard names).

    Returns:
        A ``dict`` suitable for ``json.dumps`` as ``genai_config.json``.
    """
    if decoder_io is None:
        decoder_io = DecoderIOMapping()

    num_layers: int = hf_config.num_hidden_layers

    # A config may declare ``head_dim`` yet leave it as ``None``. A plain
    # ``getattr`` default only covers the *missing* case and would write
    # ``"head_size": null`` into the config, so both cases are handled here.
    head_size: int | None = getattr(hf_config, "head_dim", None)
    if head_size is None:
        head_size = hf_config.hidden_size // hf_config.num_attention_heads

    eos_token_id: int | list[int] = hf_config.eos_token_id
    # Pass lists through unchanged — ORT genai accepts a JSON array of EOS token
    # IDs and treats any of them as a valid stop signal. Truncating to [0] would
    # silently discard secondary EOS tokens (e.g. Qwen3 uses [151645, 151643])
    # and cause generation to run until max_length instead of stopping early.

    pad_token_id = getattr(hf_config, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = hf_config.bos_token_id

    decoder_section: dict = {
        "hidden_size": hf_config.hidden_size,
        "num_attention_heads": hf_config.num_attention_heads,
        "num_key_value_heads": hf_config.num_key_value_heads,
        "num_hidden_layers": num_layers,
        "head_size": head_size,
    }

    if prefill_seq_len is not None:
        decoder_section["sliding_window"] = {
            "window_size": prefill_seq_len,
            "pad_value": 0,
            "alignment": "left",
            "slide_inputs": True,
            "slide_key_value_cache": False,
        }

    decoder_section["inputs"] = decoder_io.inputs_dict()
    decoder_section["outputs"] = decoder_io.outputs_dict()
    # genai expects each pipeline entry as a single-key mapping {stage_name: stage}.
    decoder_section["pipeline"] = [{s.name: s.to_dict()} for s in pipeline]

    return {
        "model": {
            "type": "decoder-pipeline",
            "bos_token_id": hf_config.bos_token_id,
            "eos_token_id": eos_token_id,
            "pad_token_id": pad_token_id,
            "vocab_size": hf_config.vocab_size,
            "context_length": max_cache_len,
            "decoder": decoder_section,
        },
        "search": {
            "max_length": max_cache_len,
            "min_length": 0,
            "do_sample": False,
            "past_present_share_buffer": True,
        },
    }
" + "Install it with: pip install onnx" + ) from exc + model = onnx.load(str(onnx_path), load_external_data=False) + return ( + [inp.name for inp in model.graph.input], + [out.name for out in model.graph.output], + ) + + +def _detect_format_patterns(names: list[str], num_layers: int) -> dict[str, str]: + """Detect ``prefix%d`` patterns from a list of indexed tensor names. + + Scans *names* for entries matching ```` where exactly + *num_layers* consecutive zero-based indices are present. + + Returns: + ``{prefix: "prefix%d"}`` for each qualifying group, in the order the + prefixes first appear in *names*. Only groups covering the full + ``[0, num_layers)`` index range are returned. + + Examples:: + + >>> _detect_format_patterns( + ... ["past_keys_0", "past_keys_1", "past_values_0", "past_values_1"], + ... num_layers=2, + ... ) + {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"} + """ + groups: dict[str, list[int]] = {} + for name in names: + m = _KV_INDEXED_RE.match(name) + if m: + prefix, idx = m.group(1), int(m.group(2)) + groups.setdefault(prefix, []).append(idx) + + return { + prefix: f"{prefix}%d" + for prefix, indices in groups.items() + if len(indices) == num_layers and sorted(indices) == list(range(num_layers)) + } + + +def _sort_patterns_by_first_occurrence(patterns: dict[str, str], names: list[str]) -> list[str]: + """Sort *patterns* keys by when ``0`` first appears in *names*.""" + + def _key(prefix: str) -> int: + try: + return names.index(f"{prefix}0") + except ValueError: + return len(names) + + return sorted(patterns.keys(), key=_key) + + +# --------------------------------------------------------------------------- +# Generic decoder-pipeline stage factory +# --------------------------------------------------------------------------- + + +def build_decoder_pipeline_stages( + context_onnx: str | Path, + iterator_onnx: str | Path, + num_layers: int, + *, + context_filename: str = DEFAULT_CONTEXT_FILENAME, + iterator_filename: str = 
DEFAULT_ITERATOR_FILENAME, + embeddings_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + lm_head_filename: str = DEFAULT_LM_HEAD_FILENAME, + context_session_options: dict | None = None, + iterator_session_options: dict | None = None, +) -> tuple[list[PipelineStage], DecoderIOMapping]: + """Build pipeline stages by introspecting the built ONNX models. + + Reads actual tensor names from *context_onnx* and *iterator_onnx* so the + generated ``genai_config.json`` can never drift out of sync with the real + model I/O — no tensor names are hardcoded. + + Args: + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. + num_layers: Number of transformer layers (``hf_config.num_hidden_layers``). + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. + context_session_options: Optional ORT ``session_options`` dict attached + verbatim to the ``context`` stage (e.g. to route it to an + accelerator EP). ``None`` (default) runs the stage on CPU. This + function stays execution-provider-agnostic — the caller decides the + contents; no EP-specific values are constructed here. + iterator_session_options: Same as *context_session_options* but for the + ``iterator`` stage. + + Returns: + ``(stages, decoder_io)`` — a 4-element :class:`PipelineStage` list and + the :class:`DecoderIOMapping` derived from the introspected tensor names. + """ + ctx_inputs, ctx_outputs = _introspect_onnx_io(Path(context_onnx)) + iter_inputs, iter_outputs = _introspect_onnx_io(Path(iterator_onnx)) + + # Detect per-layer KV format-string patterns in the context model. 
+ input_patterns = _detect_format_patterns(ctx_inputs, num_layers) + output_patterns = _detect_format_patterns(ctx_outputs, num_layers) + + in_sorted = _sort_patterns_by_first_occurrence(input_patterns, ctx_inputs) + out_sorted = _sort_patterns_by_first_occurrence(output_patterns, ctx_outputs) + + # Assign key/value patterns by name (look for "key"/"val" in the prefix), + # falling back to positional order only when names are ambiguous. Pure + # positional assignment would silently swap KV if a model lists values + # before keys in its ONNX graph. + def _pick_kv( + sorted_prefixes: list[str], + patterns: dict[str, str], + key_default: str, + val_default: str, + ) -> tuple[str, str]: + key_prefix = next((p for p in sorted_prefixes if "key" in p.lower()), None) + val_prefix = next((p for p in sorted_prefixes if "val" in p.lower()), None) + if key_prefix and val_prefix: + return patterns[key_prefix], patterns[val_prefix] + # Fallback: positional (preserves original behaviour for unambiguous names) + key_fmt = patterns[sorted_prefixes[0]] if len(sorted_prefixes) > 0 else key_default + val_fmt = patterns[sorted_prefixes[1]] if len(sorted_prefixes) > 1 else val_default + return key_fmt, val_fmt + + past_key_fmt, past_val_fmt = _pick_kv( + in_sorted, input_patterns, "past_keys_%d", "past_values_%d" + ) + pres_key_fmt, pres_val_fmt = _pick_kv( + out_sorted, output_patterns, "present_keys_%d", "present_values_%d" + ) + + # Non-indexed inputs: hidden-state tensor + scalar seq-length scalars. 
def _patch_seq_dim_dynamic(onnx_path: Path, dim_index: int = 1) -> None:
    """Replace fixed sizes at *dim_index* of every graph input/output with ``"seq_len"``.

    Rationale (from the bundle assembler): ort-genai feeds the embeddings and
    lm_head models the whole prompt during prefill and a single token on each
    decode step, so a sequence length baked in by the exporter has to become
    symbolic for the runtime to accept both.

    The file is loaded with ``load_external_data=False``, so external weight
    data is not read; only shape annotations in the protobuf are edited. The
    file is rewritten in place, and only when at least one dimension changed.

    Args:
        onnx_path: ONNX file to patch in place.
        dim_index: Zero-based dimension to make symbolic (default ``1``).
    """
    import onnx

    proto = onnx.load(str(onnx_path), load_external_data=False)
    patched = False
    for tensor_info in [*proto.graph.input, *proto.graph.output]:
        tensor_shape = tensor_info.type.tensor_type.shape
        if not tensor_shape or len(tensor_shape.dim) <= dim_index:
            # Rank too small to have this dimension — nothing to patch.
            continue
        target_dim = tensor_shape.dim[dim_index]
        if not target_dim.HasField("dim_value"):
            # No concrete integer stored here (e.g. already symbolic).
            continue
        target_dim.ClearField("dim_value")
        target_dim.dim_param = "seq_len"
        patched = True

    if not patched:
        return
    onnx.save(proto, str(onnx_path))
    logger.info("Patched seq_len dim to dynamic in %s", onnx_path.name)
+ Tensor names in the config are derived by introspecting the built ONNX + files rather than being hardcoded, so this works for any model that + follows the 4-stage decoder-pipeline layout. + + Args: + output_dir: Destination directory (created if absent). + context_onnx: Path to the built prefill/context ONNX. + iterator_onnx: Path to the built decode/iterator ONNX. + model_id: HuggingFace model ID or local path for config + tokenizer. + max_cache_len: Static KV cache length (= ``context_length`` in genai). + prefill_seq_len: Prefill sequence length (= ``sliding_window.window_size``). + embeddings_src: Source path of the embeddings ONNX. ``None`` = skip. + lm_head_src: Source path of the lm_head ONNX. ``None`` = skip. + context_filename: Bundle filename for the context model. + iterator_filename: Bundle filename for the iterator model. + embeddings_filename: Bundle filename for the embeddings model. + lm_head_filename: Bundle filename for the lm_head model. + context_session_options: Optional ORT ``session_options`` dict attached + verbatim to the ``context`` stage. ``None`` (default) runs it on + CPU. This assembler is execution-provider-agnostic; the caller + supplies any EP-specific options. + iterator_session_options: Same as *context_session_options* but for the + ``iterator`` stage. + transformer_onnx_passes: Optional list of callables applied to each + transformer ONNX (context + iterator) **after** copying to the + bundle directory. Each callable receives an + ``onnx.ModelProto``, may modify it in-place, and must return it. + Passes are applied in order to the context model first, then the + iterator model. Use this to strip exporter-injected default + attributes that are not needed at inference time. + + Returns: + Path to the written ``genai_config.json``. 
+ """ + import onnx as _onnx + from transformers import AutoConfig, AutoTokenizer + + from ..onnx import copy_onnx_model + + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + context_onnx = Path(context_onnx) + iterator_onnx = Path(iterator_onnx) + + # 1. Copy winml-built transformer ONNX files. + logger.info("Copying context ONNX: %s -> %s", context_onnx.name, context_filename) + copy_onnx_model(context_onnx, output_dir / context_filename) + + logger.info("Copying iterator ONNX: %s -> %s", iterator_onnx.name, iterator_filename) + copy_onnx_model(iterator_onnx, output_dir / iterator_filename) + + # 1b. Apply optional post-copy ONNX passes to the transformer stages. + if transformer_onnx_passes: + for fname in (context_filename, iterator_filename): + dst = output_dir / fname + model = _onnx.load(str(dst), load_external_data=False) + for pass_fn in transformer_onnx_passes: + model = pass_fn(model) + _onnx.save(model, str(dst)) + logger.info("Applied %d ONNX pass(es) to %s", len(transformer_onnx_passes), fname) + + # 2. Copy embeddings + lm_head models. + if embeddings_src is not None: + logger.info("Copying embeddings: %s -> %s", Path(embeddings_src).name, embeddings_filename) + dst_embeddings = output_dir / embeddings_filename + copy_onnx_model(Path(embeddings_src), dst_embeddings) + # Patch seq_len to dynamic: ort-genai calls embeddings with the full + # prompt on prefill and with a single token on every decode step, so the + # seq_len dimension must be symbolic, not a fixed value. 
+ _patch_seq_dim_dynamic(dst_embeddings) + else: + logger.warning( + "embeddings_src not provided — '%s' is missing from bundle.", + embeddings_filename, + ) + + if lm_head_src is not None: + logger.info("Copying lm_head: %s -> %s", Path(lm_head_src).name, lm_head_filename) + dst_lm_head = output_dir / lm_head_filename + copy_onnx_model(Path(lm_head_src), dst_lm_head) + # Same reason as embeddings: lm_head is called with prefill seq_len + # and with seq_len=1 on each decode step. + _patch_seq_dim_dynamic(dst_lm_head) + else: + logger.warning( + "lm_head_src not provided — '%s' is missing from bundle.", + lm_head_filename, + ) + + # 3. Save tokenizer files from the HF snapshot. + logger.info("Saving tokenizer from '%s' to %s", model_id, output_dir) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.save_pretrained(str(output_dir)) + + # 4. Build pipeline stages by introspecting the source ONNX files. + hf_config = AutoConfig.from_pretrained(model_id) + stages, decoder_io = build_decoder_pipeline_stages( + context_onnx, + iterator_onnx, + num_layers=hf_config.num_hidden_layers, + context_filename=context_filename, + iterator_filename=iterator_filename, + embeddings_filename=embeddings_filename, + lm_head_filename=lm_head_filename, + context_session_options=context_session_options, + iterator_session_options=iterator_session_options, + ) + + # 5. Write genai_config.json. 
+ config = build_genai_config( + hf_config, + max_cache_len=max_cache_len, + prefill_seq_len=prefill_seq_len, + pipeline=stages, + decoder_io=decoder_io, + ) + config_path = output_dir / "genai_config.json" + config_path.write_text(json.dumps(config, indent=2), encoding="utf-8") + logger.info("Wrote genai_config.json -> %s", config_path) + + _log_bundle_summary(output_dir, config_path) + return config_path + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _log_bundle_summary(bundle_dir: Path, config_path: Path) -> None: + """Print a human-readable summary of the assembled bundle.""" + files = sorted(bundle_dir.iterdir()) + lines = [f"\n=== genai bundle: {bundle_dir} ==="] + for f in files: + size_kb = f.stat().st_size / 1024 + tag = "" + if f.name == "genai_config.json": + tag = " <- pipeline config" + elif f.name.endswith(".onnx"): + tag = " <- ONNX graph" + elif f.name.endswith(".data"): + tag = " <- ONNX external weights" + lines.append(f" {f.name:<45} {size_kb:>8.1f} KB{tag}") + lines.append(f"\nConfig written to: {config_path}") + logger.info("\n".join(lines)) + + +__all__ = [ + "DEFAULT_CONTEXT_FILENAME", + "DEFAULT_EMBEDDINGS_FILENAME", + "DEFAULT_ITERATOR_FILENAME", + "DEFAULT_LM_HEAD_FILENAME", + "DecoderIOMapping", + "PipelineStage", + "build_decoder_pipeline_stages", + "build_genai_config", + "write_genai_bundle", +] diff --git a/tests/unit/models/qwen3/test_genai_config.py b/tests/unit/models/qwen3/test_genai_config.py new file mode 100644 index 000000000..66553d589 --- /dev/null +++ b/tests/unit/models/qwen3/test_genai_config.py @@ -0,0 +1,601 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +"""Unit tests for the Qwen3 genai config builder.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import ClassVar +from unittest.mock import patch + +from winml.modelkit.models.hf.qwen3 import ( + DecoderIOMapping, + PipelineStage, + build_genai_config, + build_qwen3_transformer_only_stages, + write_genai_bundle, +) +from winml.modelkit.models.hf.qwen3.genai import ( + DEFAULT_CONTEXT_FILENAME, + DEFAULT_EMBEDDINGS_FILENAME, + DEFAULT_ITERATOR_FILENAME, + DEFAULT_LM_HEAD_FILENAME, +) +from winml.modelkit.utils.genai import _detect_format_patterns + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _mock_config( + *, + num_hidden_layers: int = 28, + hidden_size: int = 1024, + num_attention_heads: int = 16, + num_key_value_heads: int = 8, + head_dim: int = 128, + bos_token_id: int = 151643, + eos_token_id: int = 151645, + pad_token_id: int = 151643, + vocab_size: int = 151936, +) -> SimpleNamespace: + """Return a minimal stand-in for a HF PretrainedConfig.""" + return SimpleNamespace( + num_hidden_layers=num_hidden_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_key_value_heads=num_key_value_heads, + head_dim=head_dim, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + vocab_size=vocab_size, + ) + + +def _make_pipeline( + num_layers: int = 28, + *, + emb_filename: str = DEFAULT_EMBEDDINGS_FILENAME, + ctx_filename: str = DEFAULT_CONTEXT_FILENAME, + iter_filename: str = DEFAULT_ITERATOR_FILENAME, + lmh_filename: str = DEFAULT_LM_HEAD_FILENAME, +) -> list[PipelineStage]: + """Build a standard 4-stage pipeline for use in unit tests.""" + ctx_inputs = [ + "input_hidden_states", + "past_seq_len", + "total_seq_len", + *[f"past_keys_{i}" for i in 
range(num_layers)], + *[f"past_values_{i}" for i in range(num_layers)], + ] + ctx_outputs = [ + "output_hidden_states", + *[f"present_keys_{i}" for i in range(num_layers)], + *[f"present_values_{i}" for i in range(num_layers)], + ] + return [ + PipelineStage( + "embeddings", emb_filename, True, True, ["input_ids"], ["input_hidden_states"] + ), + PipelineStage("context", ctx_filename, True, False, ctx_inputs, ctx_outputs), + PipelineStage("iterator", iter_filename, False, True, ctx_inputs, ctx_outputs), + PipelineStage( + "lm_head", + lmh_filename, + True, + True, + ["output_hidden_states"], + ["logits"], + is_lm_head=True, + ), + ] + + +# --------------------------------------------------------------------------- +# Tests: build_genai_config +# --------------------------------------------------------------------------- + + +class TestBuildGenaiConfig: + def setup_method(self) -> None: + self.cfg = _mock_config() + self.result = build_genai_config( + self.cfg, + max_cache_len=256, + prefill_seq_len=64, + pipeline=_make_pipeline(), + ) + + def test_top_level_model_type(self) -> None: + assert self.result["model"]["type"] == "decoder-pipeline" + + def test_token_ids(self) -> None: + m = self.result["model"] + assert m["bos_token_id"] == 151643 + assert m["eos_token_id"] == 151645 + assert m["pad_token_id"] == 151643 + assert m["vocab_size"] == 151936 + + def test_context_length_equals_max_cache_len(self) -> None: + assert self.result["model"]["context_length"] == 256 + + def test_search_max_length_equals_context_length(self) -> None: + assert self.result["search"]["max_length"] == self.result["model"]["context_length"] + + def test_search_past_present_share_buffer(self) -> None: + assert self.result["search"]["past_present_share_buffer"] is True + + def test_decoder_architecture_params(self) -> None: + dec = self.result["model"]["decoder"] + assert dec["hidden_size"] == 1024 + assert dec["num_attention_heads"] == 16 + assert dec["num_key_value_heads"] == 8 + assert 
dec["num_hidden_layers"] == 28 + assert dec["head_size"] == 128 + + def test_sliding_window_present_when_prefill_seq_len_given(self) -> None: + sw = self.result["model"]["decoder"]["sliding_window"] + assert sw["window_size"] == 64 + assert sw["slide_inputs"] is True + assert sw["slide_key_value_cache"] is False + + def test_sliding_window_absent_when_prefill_seq_len_none(self) -> None: + result = build_genai_config( + self.cfg, + max_cache_len=256, + prefill_seq_len=None, + pipeline=_make_pipeline(), + ) + assert "sliding_window" not in result["model"]["decoder"] + + def test_decoder_io_tensor_names(self) -> None: + inputs = self.result["model"]["decoder"]["inputs"] + assert inputs["past_sequence_length"] == "past_seq_len" + assert inputs["total_sequence_length"] == "total_seq_len" + assert inputs["past_key_names"] == "past_keys_%d" + assert inputs["past_value_names"] == "past_values_%d" + outputs = self.result["model"]["decoder"]["outputs"] + assert outputs["logits"] == "logits" + assert outputs["present_key_names"] == "present_keys_%d" + assert outputs["present_value_names"] == "present_values_%d" + + def test_custom_decoder_io_mapping(self) -> None: + custom_io = DecoderIOMapping( + past_key_names="k_%d", + past_value_names="v_%d", + present_key_names="pk_%d", + present_value_names="pv_%d", + ) + result = build_genai_config( + self.cfg, + max_cache_len=256, + prefill_seq_len=64, + pipeline=_make_pipeline(), + decoder_io=custom_io, + ) + dec_inputs = result["model"]["decoder"]["inputs"] + assert dec_inputs["past_key_names"] == "k_%d" + assert dec_inputs["past_value_names"] == "v_%d" + dec_outputs = result["model"]["decoder"]["outputs"] + assert dec_outputs["present_key_names"] == "pk_%d" + assert dec_outputs["present_value_names"] == "pv_%d" + + def test_pipeline_has_four_stages(self) -> None: + pipeline = self.result["model"]["decoder"]["pipeline"] + assert len(pipeline) == 4 + stage_names = [next(iter(s.keys())) for s in pipeline] + assert stage_names == 
["embeddings", "context", "iterator", "lm_head"] + + def test_embeddings_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][0]["embeddings"] + assert stage["filename"] == DEFAULT_EMBEDDINGS_FILENAME + assert stage["inputs"] == ["input_ids"] + assert stage["outputs"] == ["input_hidden_states"] + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is True + + def test_context_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][1]["context"] + assert stage["filename"] == DEFAULT_CONTEXT_FILENAME + assert "input_hidden_states" in stage["inputs"] + assert "past_seq_len" in stage["inputs"] + assert "total_seq_len" in stage["inputs"] + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is False + + def test_iterator_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][2]["iterator"] + assert stage["filename"] == DEFAULT_ITERATOR_FILENAME + assert stage["run_on_prompt"] is False + assert stage["run_on_token_gen"] is True + + def test_lm_head_stage(self) -> None: + stage = self.result["model"]["decoder"]["pipeline"][3]["lm_head"] + assert stage["filename"] == DEFAULT_LM_HEAD_FILENAME + assert stage["inputs"] == ["output_hidden_states"] + assert stage["outputs"] == ["logits"] + assert stage["is_lm_head"] is True + assert stage["run_on_prompt"] is True + assert stage["run_on_token_gen"] is True + + def test_context_kv_inputs_count(self) -> None: + """context.inputs must include all 28 past_keys + 28 past_values.""" + inputs = self.result["model"]["decoder"]["pipeline"][1]["context"]["inputs"] + past_keys = [x for x in inputs if x.startswith("past_keys_")] + past_values = [x for x in inputs if x.startswith("past_values_")] + assert len(past_keys) == 28 + assert len(past_values) == 28 + assert set(past_keys) == {f"past_keys_{i}" for i in range(28)} + assert set(past_values) == {f"past_values_{i}" for i in range(28)} + + def test_context_outputs_kv_count(self) -> None: 
+ outputs = self.result["model"]["decoder"]["pipeline"][1]["context"]["outputs"] + present_keys = [x for x in outputs if x.startswith("present_keys_")] + present_values = [x for x in outputs if x.startswith("present_values_")] + assert len(present_keys) == 28 + assert len(present_values) == 28 + + def test_context_and_iterator_have_same_io(self) -> None: + ctx = self.result["model"]["decoder"]["pipeline"][1]["context"] + itr = self.result["model"]["decoder"]["pipeline"][2]["iterator"] + assert ctx["inputs"] == itr["inputs"] + assert ctx["outputs"] == itr["outputs"] + + def test_custom_filenames(self) -> None: + result = build_genai_config( + self.cfg, + max_cache_len=512, + prefill_seq_len=128, + pipeline=_make_pipeline( + emb_filename="emb.onnx", + ctx_filename="prefill.onnx", + iter_filename="decode.onnx", + lmh_filename="head.onnx", + ), + ) + pipeline = result["model"]["decoder"]["pipeline"] + assert pipeline[0]["embeddings"]["filename"] == "emb.onnx" + assert pipeline[1]["context"]["filename"] == "prefill.onnx" + assert pipeline[2]["iterator"]["filename"] == "decode.onnx" + assert pipeline[3]["lm_head"]["filename"] == "head.onnx" + + def test_eos_token_id_list_preserved(self) -> None: + cfg = _mock_config(eos_token_id=[151645, 151643]) + result = build_genai_config( + cfg, max_cache_len=256, prefill_seq_len=64, pipeline=_make_pipeline() + ) + # ORT genai accepts a list of EOS token IDs; all must be preserved so that + # any secondary stop token (e.g. 151643 in some Qwen3 variants) is honoured. 
+ assert result["model"]["eos_token_id"] == [151645, 151643] + + def test_head_size_derived_when_head_dim_missing(self) -> None: + cfg = SimpleNamespace( + num_hidden_layers=2, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=4, + # no head_dim attribute + bos_token_id=0, + eos_token_id=1, + pad_token_id=0, + vocab_size=32000, + ) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2) + ) + # head_size = hidden_size // num_attention_heads = 512 // 8 = 64 + assert result["model"]["decoder"]["head_size"] == 64 + + def test_pad_token_id_falls_back_to_bos(self) -> None: + cfg = SimpleNamespace( + num_hidden_layers=2, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=64, + bos_token_id=0, + eos_token_id=1, + pad_token_id=None, + vocab_size=32000, + ) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2) + ) + assert result["model"]["pad_token_id"] == 0 # falls back to bos_token_id + + def test_pad_token_id_zero_is_preserved(self) -> None: + """A valid pad_token_id of 0 must not be overwritten by bos_token_id. + + 0 is a common, valid pad id; a truthiness check would silently swap it + for bos_token_id and corrupt batched-padding generation. 
+ """ + cfg = SimpleNamespace( + num_hidden_layers=2, + hidden_size=512, + num_attention_heads=8, + num_key_value_heads=4, + head_dim=64, + bos_token_id=5, + eos_token_id=1, + pad_token_id=0, + vocab_size=32000, + ) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(2) + ) + assert result["model"]["pad_token_id"] == 0 + + def test_different_layer_count(self) -> None: + cfg = _mock_config(num_hidden_layers=4) + result = build_genai_config( + cfg, max_cache_len=128, prefill_seq_len=32, pipeline=_make_pipeline(4) + ) + inputs = result["model"]["decoder"]["pipeline"][1]["context"]["inputs"] + past_keys = [x for x in inputs if x.startswith("past_keys_")] + assert len(past_keys) == 4 + assert {f"past_keys_{i}" for i in range(4)} == set(past_keys) + + +# --------------------------------------------------------------------------- +# Tests: _detect_format_patterns +# --------------------------------------------------------------------------- + + +class TestDetectFormatPatterns: + def test_detects_two_kv_groups(self) -> None: + names = [ + "input_hidden_states", + "past_seq_len", + "past_keys_0", + "past_keys_1", + "past_keys_2", + "past_values_0", + "past_values_1", + "past_values_2", + ] + result = _detect_format_patterns(names, num_layers=3) + assert result == {"past_keys_": "past_keys_%d", "past_values_": "past_values_%d"} + + def test_ignores_incomplete_index_range(self) -> None: + # Missing index 1 — should not be detected + names = ["prefix_0", "prefix_2"] + result = _detect_format_patterns(names, num_layers=3) + assert "prefix_" not in result + + def test_ignores_wrong_num_layers(self) -> None: + # 3 entries but num_layers=5 + names = ["kv_0", "kv_1", "kv_2"] + result = _detect_format_patterns(names, num_layers=5) + assert len(result) == 0 + + def test_empty_input(self) -> None: + assert _detect_format_patterns([], num_layers=4) == {} + + def test_non_indexed_names_ignored(self) -> None: + names = 
["input_hidden_states", "past_seq_len", "total_seq_len"] + result = _detect_format_patterns(names, num_layers=3) + assert result == {} + + def test_single_layer_model(self) -> None: + names = ["keys_0", "vals_0"] + result = _detect_format_patterns(names, num_layers=1) + assert result == {"keys_": "keys_%d", "vals_": "vals_%d"} + + +# --------------------------------------------------------------------------- +# Tests: build_qwen3_transformer_only_stages +# --------------------------------------------------------------------------- + + +class TestBuildQwen3TransformerOnlyStages: + """Uses mocked onnx.load so no real ONNX files are required.""" + + def _ctx_inputs(self, n: int = 4) -> list[str]: + return [ + "input_hidden_states", + "past_seq_len", + "total_seq_len", + *[f"past_keys_{i}" for i in range(n)], + *[f"past_values_{i}" for i in range(n)], + ] + + def _ctx_outputs(self, n: int = 4) -> list[str]: + return [ + "output_hidden_states", + *[f"present_keys_{i}" for i in range(n)], + *[f"present_values_{i}" for i in range(n)], + ] + + def _patch_onnx(self, n: int = 4): + ctx_io = (self._ctx_inputs(n), self._ctx_outputs(n)) + iter_io = (self._ctx_inputs(n), self._ctx_outputs(n)) + return patch( + "winml.modelkit.utils.genai._introspect_onnx_io", + side_effect=[ctx_io, iter_io], + ) + + def test_returns_four_stages(self) -> None: + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages("ctx.onnx", "iter.onnx", num_layers=4) + assert len(stages) == 4 + assert [s.name for s in stages] == ["embeddings", "context", "iterator", "lm_head"] + + def test_detected_kv_format_patterns(self) -> None: + with self._patch_onnx(): + _, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + assert decoder_io.past_key_names == "past_keys_%d" + assert decoder_io.past_value_names == "past_values_%d" + assert decoder_io.present_key_names == "present_keys_%d" + assert decoder_io.present_value_names == "present_values_%d" + + 
def test_detected_seq_len_names(self) -> None: + with self._patch_onnx(): + _, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + assert decoder_io.past_sequence_length == "past_seq_len" + assert decoder_io.total_sequence_length == "total_seq_len" + + def test_context_stage_inputs_from_onnx(self) -> None: + with self._patch_onnx(n=4): + stages, _ = build_qwen3_transformer_only_stages("ctx.onnx", "iter.onnx", num_layers=4) + ctx_stage = next(s for s in stages if s.name == "context") + assert "input_hidden_states" in ctx_stage.inputs + assert "past_keys_0" in ctx_stage.inputs + assert "past_values_3" in ctx_stage.inputs + + def test_custom_filenames(self) -> None: + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", + "iter.onnx", + num_layers=4, + context_filename="prefill.onnx", + iterator_filename="decode.onnx", + embeddings_filename="emb.onnx", + lm_head_filename="head.onnx", + ) + names = {s.name: s.filename for s in stages} + assert names["context"] == "prefill.onnx" + assert names["iterator"] == "decode.onnx" + assert names["embeddings"] == "emb.onnx" + assert names["lm_head"] == "head.onnx" + + def test_roundtrip_with_build_genai_config(self) -> None: + """build_qwen3_transformer_only_stages output feeds build_genai_config cleanly.""" + with self._patch_onnx(n=4): + stages, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4 + ) + cfg = _mock_config(num_hidden_layers=4) + result = build_genai_config( + cfg, + max_cache_len=128, + prefill_seq_len=32, + pipeline=stages, + decoder_io=decoder_io, + ) + assert result["model"]["type"] == "decoder-pipeline" + assert len(result["model"]["decoder"]["pipeline"]) == 4 + + def test_cpu_ep_no_session_options(self) -> None: + """Default cpu ep: context/iterator stages have no session_options.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", 
num_layers=4, ep="cpu" + ) + ctx = next(s for s in stages if s.name == "context") + itr = next(s for s in stages if s.name == "iterator") + assert ctx.session_options is None + assert itr.session_options is None + + def test_qnn_ep_injects_session_options(self) -> None: + """ep='qnn': context/iterator get QNN session_options; emb/lm_head do not.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn" + ) + stage_map = {s.name: s for s in stages} + assert stage_map["embeddings"].session_options is None + assert stage_map["lm_head"].session_options is None + ctx_opts = stage_map["context"].session_options + itr_opts = stage_map["iterator"].session_options + assert ctx_opts is not None + assert itr_opts is not None + assert ctx_opts["provider_options"][0]["qnn"]["backend_path"] == "QnnHtp.dll" + assert itr_opts["log_id"] == "onnxruntime-genai.iterator" + + def test_qnn_session_options_in_serialized_config(self) -> None: + """QNN session_options appear in genai_config.json pipeline output.""" + with self._patch_onnx(): + stages, decoder_io = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn" + ) + cfg = build_genai_config( + _mock_config(num_hidden_layers=4), + max_cache_len=256, + prefill_seq_len=64, + pipeline=stages, + decoder_io=decoder_io, + ) + pipeline = cfg["model"]["decoder"]["pipeline"] + ctx_dict = next(s for s in pipeline if "context" in s)["context"] + itr_dict = next(s for s in pipeline if "iterator" in s)["iterator"] + emb_dict = next(s for s in pipeline if "embeddings" in s)["embeddings"] + assert "session_options" in ctx_dict + assert "session_options" in itr_dict + assert "session_options" not in emb_dict + + def test_custom_soc_model(self) -> None: + """soc_model parameter propagates to QNN provider_options.""" + with self._patch_onnx(): + stages, _ = build_qwen3_transformer_only_stages( + "ctx.onnx", "iter.onnx", num_layers=4, ep="qnn", 
soc_model="73" + ) + ctx = next(s for s in stages if s.name == "context") + assert ctx.session_options["provider_options"][0]["qnn"]["soc_model"] == "73" + + +# --------------------------------------------------------------------------- +# Tests: write_genai_bundle wrapper (ep routing + transformer_onnx_passes) +# --------------------------------------------------------------------------- + + +class TestWriteGenaiBundleWrapper: + """The Qwen3 ``write_genai_bundle`` wrapper injects the QNN ``session_options`` + and forwards the generic keyword arguments — notably ``transformer_onnx_passes`` + — to :func:`winml.modelkit.utils.genai.write_genai_bundle`. The generic + assembler is mocked so no ONNX/tokenizer files are needed. + """ + + _COMMON: ClassVar[dict] = { + "context_onnx": "ctx.onnx", + "iterator_onnx": "iter.onnx", + "model_id": "Qwen/Qwen3-0.6B", + "max_cache_len": 2048, + "prefill_seq_len": 64, + } + + @staticmethod + def _patch_generic(): + return patch("winml.modelkit.models.hf.qwen3.genai._write_genai_bundle") + + def test_forwards_transformer_onnx_passes(self) -> None: + """A supplied transformer_onnx_passes list reaches the generic assembler.""" + + def _identity_pass(model): + return model + + with self._patch_generic() as mock_write: + write_genai_bundle("out", transformer_onnx_passes=[_identity_pass], **self._COMMON) + assert mock_write.call_count == 1 + assert mock_write.call_args.kwargs["transformer_onnx_passes"] == [_identity_pass] + + def test_transformer_onnx_passes_defaults_to_none(self) -> None: + """Omitting transformer_onnx_passes forwards None (no passes).""" + with self._patch_generic() as mock_write: + write_genai_bundle("out", **self._COMMON) + assert mock_write.call_args.kwargs["transformer_onnx_passes"] is None + + def test_qnn_ep_forwards_session_options(self) -> None: + """ep='qnn' forwards QNN session_options for the context and iterator stages.""" + with self._patch_generic() as mock_write: + write_genai_bundle("out", ep="qnn", 
**self._COMMON) + kwargs = mock_write.call_args.kwargs + assert kwargs["context_session_options"]["log_id"] == "onnxruntime-genai.context" + assert kwargs["iterator_session_options"]["log_id"] == "onnxruntime-genai.iterator" + + def test_cpu_ep_forwards_no_session_options(self) -> None: + """ep='cpu' (default) forwards None session_options for both stages.""" + with self._patch_generic() as mock_write: + write_genai_bundle("out", **self._COMMON) + kwargs = mock_write.call_args.kwargs + assert kwargs["context_session_options"] is None + assert kwargs["iterator_session_options"] is None diff --git a/tests/unit/models/qwen3/test_qwen3_modeling.py b/tests/unit/models/qwen3/test_qwen3_modeling.py new file mode 100644 index 000000000..6826a95bf --- /dev/null +++ b/tests/unit/models/qwen3/test_qwen3_modeling.py @@ -0,0 +1,167 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for WinMLQwen3Attention rope cache — static _max_rope_len approach. + +The rope cache length is now pinned to a plain Python int attribute +``_max_rope_len`` set by ``prepare_for_onnx_export``, so the ONNX exporter +bakes a static-shape constant rather than a symbolic range derived from +``total_seq_len.item()``. 
+""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import torch + +from winml.modelkit.models.hf.qwen3.qwen3_modeling import WinMLQwen3Attention + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_attention_module( + max_position_embeddings: int = 40960, + num_heads: int = 16, + num_kv_heads: int = 8, + head_dim: int = 64, +) -> MagicMock: + """Build a minimal mock bound to WinMLQwen3Attention methods.""" + hidden_size = num_heads * head_dim + kv_size = num_kv_heads * head_dim + + mod = MagicMock() + mod.head_dim = head_dim + mod._matmul_to_conv = False + mod.config = SimpleNamespace( + num_attention_heads=num_heads, + num_key_value_heads=num_kv_heads, + max_position_embeddings=max_position_embeddings, + ) + + # Identity projections: return float32 tensors of the right shape + mod.q_proj.side_effect = lambda x: torch.zeros(1, x.shape[1], hidden_size) + mod.k_proj.side_effect = lambda x: torch.zeros(1, x.shape[1], kv_size) + mod.v_proj.side_effect = lambda x: torch.zeros(1, x.shape[1], kv_size) + + # q_norm / k_norm: identity + mod.q_norm.side_effect = lambda x: x + mod.k_norm.side_effect = lambda x: x + + return mod + + +def _run_forward( + mod: WinMLQwen3Attention, + seq_len: int, + max_cache_len: int, + kv_dtype: torch.dtype = torch.float16, +) -> list[torch.Tensor]: + """Invoke WinMLQwen3Attention.forward and capture rotary_emb position_ids.""" + hidden = torch.zeros(1, seq_len, mod.config.num_attention_heads * mod.head_dim) + past_keys = torch.zeros( + 1, mod.config.num_key_value_heads, max_cache_len, mod.head_dim, dtype=kv_dtype + ) + past_vals = torch.zeros( + 1, mod.config.num_key_value_heads, max_cache_len, mod.head_dim, dtype=kv_dtype + ) + past_seq_len = torch.zeros(1, 1, dtype=torch.int32) + total_seq_len = torch.tensor([max_cache_len], 
dtype=torch.int32) + + captured_pos_ids: list[torch.Tensor] = [] + + def _fake_rotary_emb(values, position_ids): + captured_pos_ids.append(position_ids) + seq_dim = position_ids.shape[-1] + cos = torch.ones(1, seq_dim, mod.head_dim, dtype=values.dtype) + sin = torch.zeros(1, seq_dim, mod.head_dim, dtype=values.dtype) + return cos, sin + + mod.rotary_emb.side_effect = _fake_rotary_emb + + # GQA op: return (attn_out, present_keys, present_values) + with patch( + "winml.modelkit.models.hf.qwen3.qwen3_modeling.GroupQueryAttentionOnnxExport.apply" + ) as mock_gqa: + attn_out = torch.zeros( + 1, seq_len, mod.config.num_attention_heads * mod.head_dim, dtype=kv_dtype + ) + mock_gqa.return_value = (attn_out, past_keys, past_vals) + WinMLQwen3Attention.forward( + mod, + hidden, + past_key_value=(past_keys, past_vals), + past_seq_len=past_seq_len, + total_seq_len=total_seq_len, + ) + + return captured_pos_ids + + +# --------------------------------------------------------------------------- +# Tests for prepare_for_onnx_export — _max_rope_len attribute +# --------------------------------------------------------------------------- + + +class TestPrepareForOnnxExport: + def test_explicit_max_rope_len_is_stored(self): + """prepare_for_onnx_export stores the supplied max_rope_len as a Python int.""" + mod = _make_attention_module(max_position_embeddings=40960) + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=False, max_rope_len=4096) + assert mod._max_rope_len == 4096 + assert isinstance(mod._max_rope_len, int) + + def test_fallback_to_max_position_embeddings(self): + """Without max_rope_len, _max_rope_len falls back to max_position_embeddings.""" + mod = _make_attention_module(max_position_embeddings=512) + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=False) + assert mod._max_rope_len == 512 + + def test_max_rope_len_is_plain_int(self): + """_max_rope_len must be a plain int, not a tensor or other type.""" + mod = _make_attention_module() + 
WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=False, max_rope_len=4096) + assert type(mod._max_rope_len) is int + + +# --------------------------------------------------------------------------- +# Tests for forward — uses _max_rope_len, ignores total_seq_len for rope +# --------------------------------------------------------------------------- + + +class TestRopeCacheSizing: + def test_forward_uses_max_rope_len_attribute(self): + """forward passes torch.arange(_max_rope_len) to rotary_emb regardless of total_seq_len.""" + mod = _make_attention_module(max_position_embeddings=40960) + mod._max_rope_len = 256 # explicitly set; different from total_seq_len=4096 below + + pos_ids = _run_forward(mod, seq_len=1, max_cache_len=4096) + assert len(pos_ids) == 1 + assert pos_ids[0].shape[-1] == 256, ( + f"Expected rope cache length 256 (from _max_rope_len) but got {pos_ids[0].shape[-1]}" + ) + + def test_rope_cache_does_not_depend_on_total_seq_len(self): + """Changing total_seq_len does NOT change the rope cache length.""" + mod_a = _make_attention_module() + mod_a._max_rope_len = 512 + pos_a = _run_forward(mod_a, seq_len=1, max_cache_len=128) + + mod_b = _make_attention_module() + mod_b._max_rope_len = 512 + pos_b = _run_forward(mod_b, seq_len=1, max_cache_len=4096) + + # Both use _max_rope_len=512; total_seq_len differs but rope length must match + assert pos_a[0].shape[-1] == pos_b[0].shape[-1] == 512 + + def test_fallback_rope_len_is_max_position_embeddings(self): + """Without max_rope_len arg, forward uses max_position_embeddings as rope len.""" + mod = _make_attention_module(max_position_embeddings=512) + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=False) + pos_ids = _run_forward(mod, seq_len=1, max_cache_len=256) + assert pos_ids[0].shape[-1] == 512 diff --git a/tests/unit/onnx/test_utils.py b/tests/unit/onnx/test_utils.py new file mode 100644 index 000000000..de3b07aac --- /dev/null +++ b/tests/unit/onnx/test_utils.py @@ -0,0 
+1,149 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for onnx/utils.py — strip_node_attrs.""" + +from __future__ import annotations + +from onnx import ModelProto, NodeProto, TensorProto, helper + +from winml.modelkit.onnx import strip_node_attrs + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_gqa_model(attr_dict: dict[str, int]) -> ModelProto: + """Build a minimal model with a single com.microsoft::GroupQueryAttention node. + + Uses explicit ``make_attribute`` calls so attr names match what PyTorch's + TorchScript ONNX exporter produces (no ``_i`` suffix). + """ + x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 64, 512]) + y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 64, 512]) + node = NodeProto() + node.op_type = "GroupQueryAttention" + node.domain = "com.microsoft" + node.input.append("x") + node.output.append("y") + for name, value in attr_dict.items(): + attr = helper.make_attribute(name, value) + node.attribute.append(attr) + graph = helper.make_graph([node], "gqa_graph", [x], [y]) + return helper.make_model(graph, opset_imports=[helper.make_opsetid("com.microsoft", 1)]) + + +def _attr_names(model: ModelProto) -> set[str]: + return {a.name for n in model.graph.node for a in n.attribute} + + +# --------------------------------------------------------------------------- +# strip_node_attrs +# --------------------------------------------------------------------------- + + +def test_strip_removes_extra_attrs(): + """Attributes not in keep_attrs are removed from matching nodes.""" + model = _make_gqa_model( + { + "do_rotary": 1, + "num_heads": 16, + "kv_num_heads": 8, + 
"local_window_size": -1, + "smooth_softmax": -1, + } + ) + keep = frozenset({"do_rotary", "num_heads", "kv_num_heads"}) + result = strip_node_attrs(model, "GroupQueryAttention", keep, domain="com.microsoft") + remaining = _attr_names(result) + assert remaining == keep + + +def test_strip_keep_attrs_preserved(): + """Attributes listed in keep_attrs survive stripping.""" + model = _make_gqa_model({"do_rotary": 1, "num_heads": 16, "kv_num_heads": 8}) + keep = frozenset({"do_rotary", "num_heads", "kv_num_heads"}) + strip_node_attrs(model, "GroupQueryAttention", keep, domain="com.microsoft") + remaining = _attr_names(model) + assert "do_rotary" in remaining + assert "num_heads" in remaining + assert "kv_num_heads" in remaining + + +def test_strip_no_matching_nodes_is_noop(): + """strip_node_attrs is a no-op when no nodes match op_type.""" + x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4]) + y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4]) + node = helper.make_node("Relu", ["x"], ["y"]) + graph = helper.make_graph([node], "g", [x], [y]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + + result = strip_node_attrs(model, "GroupQueryAttention", frozenset(), domain="com.microsoft") + assert result is model # same object returned + + +def test_strip_domain_mismatch_is_noop(): + """Nodes with a different domain are not modified.""" + model = _make_gqa_model({"do_rotary_i": 1, "extra_i": 0}) + before = _attr_names(model) + # Pass wrong domain — nothing should be removed + strip_node_attrs(model, "GroupQueryAttention", frozenset({"do_rotary"}), domain="wrong.domain") + assert _attr_names(model) == before + + +def test_strip_keep_all_attrs(): + """When keep_attrs contains all attr names, nothing is removed.""" + model = _make_gqa_model({"do_rotary": 1, "num_heads": 16}) + keep = frozenset({"do_rotary", "num_heads"}) + strip_node_attrs(model, "GroupQueryAttention", keep, domain="com.microsoft") + assert 
_attr_names(model) == keep + + +def test_strip_empty_keep_attrs_removes_all(): + """An empty keep_attrs set removes every attribute from matching nodes.""" + model = _make_gqa_model({"do_rotary": 1, "num_heads": 16}) + strip_node_attrs(model, "GroupQueryAttention", frozenset(), domain="com.microsoft") + assert _attr_names(model) == set() + + +def test_strip_returns_same_model_object(): + """strip_node_attrs mutates in-place and returns the same object.""" + model = _make_gqa_model({"do_rotary": 1}) + result = strip_node_attrs( + model, "GroupQueryAttention", frozenset({"do_rotary"}), domain="com.microsoft" + ) + assert result is model + + +def test_strip_multiple_gqa_nodes(): + """All matching nodes in a multi-node graph are stripped.""" + x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 64, 512]) + y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 64, 512]) + z = helper.make_tensor_value_info("z", TensorProto.FLOAT, [1, 64, 512]) + + def _gqa_node(name: str, inp: str, out: str) -> NodeProto: + node = NodeProto() + node.op_type = "GroupQueryAttention" + node.domain = "com.microsoft" + node.name = name + node.input.append(inp) + node.output.append(out) + for attr_name, value in [("do_rotary", 1), ("num_heads", 16), ("local_window_size", -1)]: + node.attribute.append(helper.make_attribute(attr_name, value)) + return node + + graph = helper.make_graph( + [_gqa_node("gqa0", "x", "y"), _gqa_node("gqa1", "y", "z")], + "multi_gqa", + [x], + [z], + value_info=[y], + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("com.microsoft", 1)]) + keep = frozenset({"do_rotary", "num_heads"}) + strip_node_attrs(model, "GroupQueryAttention", keep, domain="com.microsoft") + for node in model.graph.node: + assert {a.name for a in node.attribute} == keep diff --git a/tests/unit/quant/calibration/test_qwen3_calibration.py b/tests/unit/quant/calibration/test_qwen3_calibration.py index ad53ef352..6881f0f72 100644 --- 
a/tests/unit/quant/calibration/test_qwen3_calibration.py +++ b/tests/unit/quant/calibration/test_qwen3_calibration.py @@ -16,8 +16,9 @@ from types import SimpleNamespace import numpy as np -import onnx import torch +from onnx import TensorProto, helper +from onnx import save as onnx_save from winml.modelkit.quant.calibration.qwen3_transformer_only import ( Qwen3DecodeTrajectoryCalibReader, @@ -47,27 +48,27 @@ def _fake_config() -> SimpleNamespace: def _build_tiny_onnx(path, *, seq_len: int, max_cache_len: int) -> None: """Write a minimal graph carrying the inputs the readers introspect.""" inputs = [ - onnx.helper.make_tensor_value_info( - "input_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN] + helper.make_tensor_value_info( + "input_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] ), - onnx.helper.make_tensor_value_info( - "past_keys_0", onnx.TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] + helper.make_tensor_value_info( + "past_keys_0", TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] ), ] - out = onnx.helper.make_tensor_value_info( - "output_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN] + out = helper.make_tensor_value_info( + "output_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] ) - gqa = onnx.helper.make_node( + gqa = helper.make_node( "GroupQueryAttention", ["input_hidden_states"], ["attn_out"], name="gqa_layer_0", domain="com.microsoft", ) - identity = onnx.helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) - graph = onnx.helper.make_graph([gqa, identity], "tiny", inputs, [out]) - model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 18)]) - onnx.save(model, str(path)) + identity = helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) + graph = helper.make_graph([gqa, identity], "tiny", inputs, [out]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + onnx_save(model, str(path)) def 
test_graph_shapes_and_gqa_nodes(tmp_path):