Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def get_opset_version_range(op_name: str, start_opset_version: int, op_domain: s
target_domain = "" if args.opset_domain == "ai.onnx" else args.opset_domain
domain_str_for_filename = args.opset_domain # Keep original for filename matching

json_files = list(input_dir.rglob("*.json"))
json_files = list(input_dir.glob("*.json"))

if not json_files:
print(f"No JSON files found in {input_dir}")
Expand Down Expand Up @@ -694,7 +694,6 @@ def get_opset_version_range(op_name: str, start_opset_version: int, op_domain: s
f"_opset{since_version}{qdq_suffix}.json"
)
json_file = input_dir / expected_filename

print(f"Processing {expected_filename}...", end=" ")

if not json_file.exists():
Expand Down
1 change: 1 addition & 0 deletions src/winml/modelkit/pattern/op_input_gen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .flatten_input_generator import FlattenInputGenerator
from .global_pooling_input_generator import *
from .indexing_input_generator import (
GatherBlockQuantizedInputGenerator,
GatherInputGenerator,
ScatterNDInputGenerator,
SplitInputGenerator,
Expand Down
189 changes: 189 additions & 0 deletions src/winml/modelkit/pattern/op_input_gen/indexing_input_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,18 @@
This module contains input generators for operators that perform indexing
and shape manipulation operations:
- Gather: Gathers entries along an axis using indices
- GatherBlockQuantized: Fused gather + block-wise dequantize (com.microsoft)
- ScatterND: Scatters updates into a copy of data at specified indices
- Unsqueeze: Inserts single-dimensional entries to shape
- Split: Splits a tensor into multiple outputs
"""

import math
from typing import Any

import numpy as np

from ...onnx import SupportedONNXType
from .op_input_gen import (
InputConstraint,
InputShapeConstraint,
Expand Down Expand Up @@ -858,3 +861,189 @@ def get_qdq_config(self):
"input": QDQParameterConfig(support_activation=True),
"split": QDQParameterConfig(support_non_qdq=True),
}


@register_runtime_checker_op
class GatherBlockQuantizedInputGenerator(OpInputGenerator):
    """Input generator for com.microsoft::GatherBlockQuantized operator.

    GatherBlockQuantized is a fused gather + block-wise dequantize operator.
    It gathers rows from a quantized weight tensor and dequantizes them on the fly.

    Inputs:
        - data (T1): Block-wise quantized weight (INT4/UINT4/UINT8), always a constant initializer
        - indices (Tind): Gather indices (INT32/INT64), the runtime input
        - scales (T2): Dequantization scales (FLOAT/FLOAT16), always a constant initializer
        - zero_points (T1, optional): Dequantization zero points, always a constant initializer

    Attributes:
        - bits: 4 for INT4/UINT4 data, 8 for UINT8 data
        - block_size: Quantization block size (power of 2, >= 16)
        - gather_axis: Axis to gather on (UINT8 requires gather_axis=0)
        - quantize_axis: Axis along which data was quantized (must differ from gather_axis)

    Output (T2): Dequantized gathered tensor.

    The op's inputs (INT4/UINT4/UINT8 data, indices, scales, optional zero_points) are
    not wrapped by external DQ nodes — they are already quantized. The float output can
    be followed by a QuantizeLinear node, so get_qdq_config() marks the output as
    support_activation=True and all inputs as support_non_qdq (pass-through).

    Coverage (base models, no QDQ):
        - T1: INT4 (bits=4), UINT4 (bits=4), UINT8 (bits=8)
        - T2: FLOAT, FLOAT16
        - Tind: INT32, INT64
        - block_size: 16, 32
        - gather_axis: 0, 1 for INT4/UINT4; 0 only for UINT8 (spec constraint)
        - zero_points: present / absent (doubles the count)

    Base count: 2 INT4 gather_axes x 2 block_sizes x 2 T2 x 2 Tind x 2 zp = 32
              + 2 UINT4 gather_axes x 2 block_sizes x 2 T2 x 2 Tind x 2 zp = 32
              + 1 UINT8 gather_axis x 2 block_sizes x 2 T2 x 2 Tind x 2 zp = 16
              = 80

    QDQ models (output wrapped by Q): only the T2=FLOAT half of the base combos
    (40 of 80) yields a Q-wrapped model, because a T2=FLOAT16 output fails the
    Q input type check: 40 FLOAT base x 4 activation types = 160.
    """

    op_name = "GatherBlockQuantized"
    expand_optionals = False  # zero_points presence is enumerated explicitly in iter()

    def get_finite_attribute_sets(self) -> dict[str, list]:
        """Not used: attribute enumeration is handled in iter() to couple bits with T1 type."""
        return {}

    def get_input_and_infinite_attribute_combinations(self) -> list[dict]:
        """Not used: combinations are enumerated directly in iter()."""
        return []

    def get_qdq_config(self) -> dict[str, QDQParameterConfig]:
        """Return QDQ config: output wrappable by Q; all inputs are pass-through.

        GatherBlockQuantized inputs are already quantized (INT4/UINT4/UINT8) and
        must not be wrapped by DQ nodes. Only the float output can be followed by
        a QuantizeLinear node (support_activation).
        """
        return {
            "data": QDQParameterConfig(support_non_qdq=True),
            "indices": QDQParameterConfig(support_non_qdq=True),
            "scales": QDQParameterConfig(support_non_qdq=True),
            "zero_points": QDQParameterConfig(support_non_qdq=True),
            "output": QDQParameterConfig(support_activation=True),
        }

    def _iter_constant_combinations(self, kwargs: dict) -> Any:
        """Yield one constant map: data/scales/zero_points are weights; indices is runtime.

        Only keys accepted by ``self._is_input_key`` whose value is not None are
        included; every such key maps to True except ``"indices"``, which maps to False.
        """
        is_constant_map = {
            k: k != "indices" for k, v in kwargs.items() if self._is_input_key(k) and v is not None
        }
        yield is_constant_map

    def iter(self) -> Any:
        """Enumerate all valid (T1, bits, T2, Tind, shape, axis, block_size, zp) combos.

        Yields:
            ``(kwargs, tags)`` pairs. ``kwargs`` holds the input arrays and attribute
            values after ``filter_kwargs_by_opset``; ``tags`` carries the type
            variables, the per-input shape constraints and the attribute values.
        """
        import itertools

        import ml_dtypes

        # One representative 2-D embedding-style data shape
        data_shape = (32, 64)

        block_sizes = [16, 32]
        t2_types = [
            (SupportedONNXType.FLOAT.np_type, SupportedONNXType.FLOAT.annotation),
            (SupportedONNXType.FLOAT16.np_type, SupportedONNXType.FLOAT16.annotation),
        ]
        tind_types = [
            (np.int32, SupportedONNXType.INT32.annotation),
            (np.int64, SupportedONNXType.INT64.annotation),
        ]
        # (np_dtype, annotation, bits, valid_gather_axes)
        t1_configs = [
            (np.dtype(ml_dtypes.int4), SupportedONNXType.INT4.annotation, 4, [0, 1]),
            (np.dtype(ml_dtypes.uint4), SupportedONNXType.UINT4.annotation, 4, [0, 1]),
            (np.dtype(np.uint8), SupportedONNXType.UINT8.annotation, 8, [0]),
        ]
        # Fixed seed keeps the generated tensors reproducible across runs.
        rng = np.random.default_rng(42)
        indices_shape = (2, 4)

        # T2 x Tind x zero_points presence; product() varies the right-most factor
        # fastest, which matches nesting the three loops in this order.
        inner_cases = list(itertools.product(t2_types, tind_types, (False, True)))

        for t1_dtype, t1_annotation, bits, gather_axes in t1_configs:
            for gather_axis in gather_axes:
                quantize_axis = 1 - gather_axis  # 2-D: the other axis
                axis_size = data_shape[gather_axis]
                for block_size in block_sizes:
                    # scales / zero_points hold one entry per block along quantize_axis.
                    sc_dims: list[int] = list(data_shape)
                    sc_dims[quantize_axis] = math.ceil(sc_dims[quantize_axis] / block_size)
                    sc_shape = tuple(sc_dims)

                    for (t2_dtype, t2_annotation), (tind_dtype, tind_annotation), with_zp in (
                        inner_cases
                    ):
                        # Draw order (data, scales, indices) is part of the
                        # reproducible stream; do not reorder these three calls.
                        data_val = rng.integers(0, 7, size=data_shape, dtype=np.int8).astype(
                            t1_dtype
                        )
                        scales_val = rng.random(sc_shape).astype(t2_dtype)
                        indices_val = rng.integers(
                            0, axis_size, size=indices_shape, dtype=tind_dtype
                        )

                        kwargs: dict[str, Any] = {
                            "data": data_val,
                            "indices": indices_val,
                            "scales": scales_val,
                            "bits": bits,
                            "block_size": block_size,
                            "gather_axis": gather_axis,
                            "quantize_axis": quantize_axis,
                        }
                        if with_zp:
                            kwargs["zero_points"] = np.zeros(sc_shape, dtype=t1_dtype)

                        type_vars = {
                            f"T1_{self.op_name}": t1_annotation,
                            f"Tind_{self.op_name}": tind_annotation,
                            f"T2_{self.op_name}": t2_annotation,
                        }
                        attrs = {k: v for k, v in kwargs.items() if k in self.op_attribute_names}
                        input_constraints = {
                            k: {"type": "shape", "shape": list(v.shape)}
                            for k, v in kwargs.items()
                            if self._is_input_key(k) and v is not None
                        }
                        tags = {
                            self.type_vars_key: type_vars,
                            "input_constraints": input_constraints,
                            "attrs": attrs,
                        }
                        yield self.filter_kwargs_by_opset(kwargs), tags

    def _run_op_on_cpu(self, kwargs: dict, tags: dict) -> Any:
        """Skip CPU validation for GatherBlockQuantized.

        This op is a com.microsoft fused op not supported by the CPU EP.
        The quantized data inputs (INT4/UINT4/UINT8) are constant initializers
        and cannot be fed as runtime inputs; the base class builds an all-dynamic
        model for CPU validation, which would fail on sub-byte dtypes.
        Our combinations are valid by construction, so CPU pre-validation is not needed.

        Returns:
            Always an empty list; ``kwargs`` and ``tags`` are ignored.
        """
        return []

    def derive_properties(self, properties: dict) -> dict:
        """Derive filter properties from node inputs and attributes.

        Returns a shallow copy of ``properties`` with ``data_dim`` and
        ``indices_dim`` added: the lengths of ``data_shape`` and ``indices_shape``
        (0 when the key is absent).
        """
        item = properties.copy()
        item["data_dim"] = len(item.get("data_shape", ()))
        item["indices_dim"] = len(item.get("indices_shape", ()))
        return item

    def get_infinite_property_names(self) -> list[str]:
        """Return names of properties with infinite possible values."""
        return [
            "data_shape",
            "indices_shape",
            "attr_gather_axis",
            "attr_quantize_axis",
            "attr_block_size",
            # attr_bits is redundant with T1 type (INT4/UINT4 → 4, UINT8 → 8);
            # some models omit the bits attribute entirely (attr_bits_is_none=True),
            # so exclude both from table matching to avoid false gaps.
            "attr_bits",
            "attr_bits_is_none",
        ]
Original file line number Diff line number Diff line change
Expand Up @@ -421,27 +421,34 @@ def get_qdq_config(self) -> dict[str, QDQParameterConfig] | None:
# ============================================================================
# LpNormalization - NOT IMPLEMENTED in ONNXRuntime
# ============================================================================
#
# NOTE: LpNormalization(22) exists in the ONNX spec but is NOT IMPLEMENTED
# in ONNXRuntime as of the current version. The validation fails with:
# "NOT_IMPLEMENTED: Could not find an implementation for LpNormalization(22)"
#
# Uncomment and use the implementation below when runtime support is added:
#
# @register_runtime_checker_op
# class LpNormalizationInputGenerator(NormalizationInputGenerator):
# """Input generator for LpNormalization operator."""
# op_name = "LpNormalization"
# def get_finite_attribute_sets(self) -> dict[str, list]:
# return {"p": [1, 2]}
# def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, InputConstraint]]:
# combinations = []
# for shape in self.get_common_data_shapes():
# if len(shape) < 3:
# continue
# # TODO: add axis
# combinations.append({"input": InputShapeConstraint(shape)})
# return combinations


@register_runtime_checker_op
class LpNormalizationInputGenerator(NormalizationInputGenerator):
    """Runtime-checker input generator covering the LpNormalization operator."""

    op_name = "LpNormalization"

    def get_finite_attribute_sets(self) -> dict[str, list]:
        """Enumerate the finite attributes: ``p`` is exercised with the values 1 and 2."""
        finite_attrs: dict[str, list] = {"p": [1, 2]}
        return finite_attrs

    def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, InputConstraint]]:
        """Pair every common data shape of rank >= 3 with each tested ``axis`` value.

        Shapes come from ``get_common_data_shapes()``; lower-rank shapes are dropped.
        Entries are ordered shape-major, with the axis varying fastest.
        """
        tested_axes = (0, 1, -1, 2)
        return [
            {"input": InputShapeConstraint(shape), "axis": axis}
            for shape in self.get_common_data_shapes()
            if len(shape) >= 3
            for axis in tested_axes
        ]

    def get_qdq_config(self) -> dict[str, QDQParameterConfig] | None:
        """Mark the operator's first input as an activation that supports QDQ."""
        first_input = self.op_input_names[0]
        return {first_input: QDQParameterConfig(support_activation=True)}


# ============================================================================
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/analyze/core/test_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -1106,6 +1106,7 @@ class TestIterQDQCombinations:
("GlobalAveragePool", 3 * 4), # 12
("InstanceNormalization", 3 * 16), # 48
("LayerNormalization", 5 * 2 * 2 * 16), # 320
("LpNormalization", 3 * 2 * 4 * 4), # 96: 3 shapes (>=3D) x 2 p x 4 axis x 4 act types
("MatMul", 36 * (16 * 2 - 4 + 4)), # 1152: +4/shape for B=INT4
(
"MaxPool",
Expand Down Expand Up @@ -1169,3 +1170,36 @@ def test_qdq_total_count(self, op_name: str, expected_count: int) -> None:

# For rerun, could track in https://github.com/gim-home/ModelKit/issues/278
assert count == expected_count, "If changes, either bug or need to rerun"


class TestIterMSQDQCombinations:
    """QDQ model-count checks for operators in the com.microsoft domain."""

    @pytest.mark.parametrize(
        "op_name,expected_count",
        [
            # Only T2=FLOAT combos produce QDQ output models (T2=FLOAT16 fails Q input type check).
            # FLOAT base combos:
            #   2 INT4 gather_axes x 2 block_sizes x 2 Tind x 2 zp = 16
            #   2 UINT4 gather_axes x 2 block_sizes x 2 Tind x 2 zp = 16
            #   1 UINT8 gather_axis x 2 block_sizes x 2 Tind x 2 zp = 8
            #   total = 40
            # x 4 activation output types (INT8/UINT8/INT16/UINT16) = 160
            ("GatherBlockQuantized", 160),
        ],
    )
    def test_com_microsoft_op_qdq_model_count(self, op_name: str, expected_count: int) -> None:
        """Count every model generated for the op and compare with the pinned total."""
        from winml.modelkit.pattern.op_input_gen import get_runtime_checker_op
        from winml.modelkit.pattern.op_input_gen.qdq_gen import QDQGenerator

        # Opset version 1 of the com.microsoft domain is used for schema and QDQ generator.
        op_schema = ONNXDomain.COM_MICROSOFT.get_op_schema(op_name, 1)
        qdq_generator = QDQGenerator(opset_version=1, domain=ONNXDomain.COM_MICROSOFT)
        generator_cls = get_runtime_checker_op(op_name)
        generator = generator_cls(op_schema, qdq_generator=qdq_generator)

        # One unit per model produced for each (kwargs, tags) combination.
        produced = sum(
            1
            for kwargs, tags in generator.iter()
            for _ in generator.iter_const_and_dynamic_models(kwargs, tags)
        )

        assert produced == expected_count, "If count changes, update both code and this comment"
Loading