Skip to content

Commit b606007

Browse files
xieofxie and hualxie authored
update QDQ config for P1 models (#204)
* update for p1
* updated
* update test
* fix
* update test

---------

Co-authored-by: hualxie <hualxie@microsoft.com>
1 parent d436597 commit b606007

6 files changed

Lines changed: 56 additions & 17 deletions

File tree

src/winml/modelkit/analyze/runtime_checker/result_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def check_df_consistent(
267267
for group_key, group_df in grouped:
268268
eval_df = group_df
269269
if placeholder_col in group_df.columns:
270-
eval_df = group_df[not group_df[placeholder_col]]
270+
eval_df = group_df[group_df[placeholder_col].isna() | (group_df[placeholder_col] == "")]
271271

272272
# If all rows are placeholders, this group should not trigger conflicts.
273273
if eval_df.empty:

src/winml/modelkit/pattern/op_input_gen/conv_input_generator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,16 @@ def get_infinite_property_names(self) -> list[str]:
259259
"B_shape",
260260
]
261261

262+
def get_qdq_config(self):
263+
"""Return QDQ configuration for Conv operator inputs."""
264+
# B/Y can be non-QDQ from P1 models
265+
return {
266+
"X": QDQParameterConfig(support_activation=True),
267+
"W": QDQParameterConfig(support_weight=True),
268+
"B": QDQParameterConfig(support_non_qdq=True, qdq_types=[SupportedONNXType.INT32]),
269+
"Y": QDQParameterConfig(support_non_qdq=True, support_activation=True),
270+
}
271+
262272

263273
@register_runtime_checker_op
264274
class ConvTransposeInputGenerator(ConvInputGenerator):

src/winml/modelkit/pattern/op_input_gen/matmul_input_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,5 +304,5 @@ def get_qdq_config(self):
304304
return {
305305
"A": QDQParameterConfig(support_activation=True),
306306
"B": QDQParameterConfig(support_weight=True),
307-
"C": QDQParameterConfig(qdq_types=[SupportedONNXType.INT32]),
307+
"C": QDQParameterConfig(support_non_qdq=True, qdq_types=[SupportedONNXType.INT32]),
308308
}

src/winml/modelkit/pattern/op_input_gen/resize_input_generator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,22 @@ def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, InputC
9595
}
9696
)
9797

98+
# p1 model - segmentation
99+
combinations.append(
100+
{
101+
"coordinate_transformation_mode": "pytorch_half_pixel",
102+
"cubic_coeff_a": -0.75,
103+
"mode": "linear",
104+
"nearest_mode": "floor",
105+
"X": InputShapeConstraint(x_shape),
106+
# Explicit roi to avoid empty-optional rejection on older schemas
107+
"roi": InputValueConstraint(np.zeros(2 * ndim, dtype=np.float32)),
108+
"scales": InputValueConstraint(scales_up),
109+
"extrapolation_value": 0.0,
110+
"axes": list(range(ndim)), # All axes
111+
}
112+
)
113+
98114
# Combination using scales (downsample) - empty sizes
99115
combinations.append(
100116
{

src/winml/modelkit/pattern/op_input_gen/transpose_input_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,5 +126,5 @@ def get_infinite_property_names(self) -> list[str]:
126126
def get_qdq_config(self):
127127
"""Return QDQ configuration for Transpose operator inputs."""
128128
return {
129-
"data": QDQParameterConfig(support_activation=True),
129+
"data": QDQParameterConfig(support_non_qdq=True, support_activation=True),
130130
}

tests/unit/analyze/core/test_qdq.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -827,8 +827,9 @@ class TestIterQDQCombinationsTagSchema:
827827
- Optional QDQ input not provided: present in 'qdq_types' as ''
828828
829829
Note: When an operator has pass-through inputs, 'input_is_constant' contains only
830-
those pass-through inputs (Gather). When no pass-through inputs exist, all inputs
831-
are in 'input_is_constant' from the outer constant-combination loop (Gemm).
830+
those pass-through inputs (Gather). When an optional input supports both QDQ and
831+
non-QDQ (pass-through) modes (Gemm C), 'input_is_constant' contains that input only
832+
in the non-QDQ combination; pure QDQ inputs (A, B) never appear in 'input_is_constant'.
832833
"""
833834

834835
@pytest.fixture
@@ -970,29 +971,41 @@ def test_gemm_weight_b_present_in_qdq_types(self, gemm_gen) -> None:
970971
assert final_tags["qdq_types"]["B"] != ""
971972

972973
def test_gemm_qdq_inputs_not_in_input_is_constant(self, gemm_gen) -> None:
973-
"""A (activation) and B (weight) not exist in input_is_constant.
974+
"""A (activation) and B (weight) never appear in input_is_constant.
974975
975-
Gemm has no pass-through inputs, so input_is_constant does not exist.
976+
C may appear in input_is_constant for its non-QDQ (pass-through) combination,
977+
but pure QDQ inputs A and B are never pass-through.
976978
"""
977979
gen = gemm_gen
978980
kwargs, tags = self._gemm_float_c_provided_kwargs_tags(gen)
979981
results = list(gen.iter_const_and_dynamic_models(kwargs, tags))
980982
assert len(results) > 0
981983
for _, final_tags in results:
982-
assert "input_is_constant" not in final_tags
984+
ic = final_tags.get("input_is_constant", {})
985+
assert "A" not in ic
986+
assert "B" not in ic
983987

984988
# ---- Gemm: optional QDQ input (C) ----
985989

986990
def test_gemm_optional_c_provided_has_int32_type(self, gemm_gen) -> None:
987-
"""When optional C is provided as constant, qdq_types['C'] is INT32 annotation."""
991+
"""When optional C is provided and quantized, qdq_types['C'] is INT32 annotation.
992+
993+
C supports both QDQ (INT32) and non-QDQ (pass-through) modes. In the non-QDQ
994+
combination qdq_types['C'] is None; when quantized it must be INT32.
995+
"""
988996
gen = gemm_gen
989997
int32_ann = dtypes.SupportedONNXType.INT32.annotation
990998
kwargs, tags = self._gemm_float_c_provided_kwargs_tags(gen)
991999
results = list(gen.iter_const_and_dynamic_models(kwargs, tags))
9921000
assert len(results) > 0
1001+
int32_seen = False
9931002
for _, final_tags in results:
9941003
assert "C" in final_tags["qdq_types"]
995-
assert final_tags["qdq_types"]["C"] == int32_ann
1004+
c_type = final_tags["qdq_types"]["C"]
1005+
if c_type is not None:
1006+
assert c_type == int32_ann
1007+
int32_seen = True
1008+
assert int32_seen, "Expected at least one result with C quantized as INT32"
9961009

9971010
def test_gemm_optional_c_not_provided_recorded_as_empty_in_qdq_types(self, gemm_gen) -> None:
9981011
"""When optional C is not provided (None), qdq_types['C'] is '' (not omitted)."""
@@ -1045,8 +1058,8 @@ class TestIterQDQCombinations:
10451058
("Concat", 240), # 15 base shapes/axes * 4 variadic counts * 4 activation types
10461059
(
10471060
"Conv",
1048-
1536,
1049-
), # shape 3 * auto_pad 4 * group_opts 2 * kernel shape 2 * optional b 2 * 16 = 1536
1061+
1536 * 4,
1062+
), # shape 3 * attrs 4 * 2 * kernel shape 2 * opt B 2 * 16 * B/Y non qdq 4
10501063
(
10511064
"ConvTranspose",
10521065
3072,
@@ -1071,8 +1084,8 @@ class TestIterQDQCombinations:
10711084
("Gelu", unary_input_shapes * 4 * 2), # 64
10721085
(
10731086
"Gemm",
1074-
2304,
1075-
), # attributes 2 * 2 * 3 * 3 * C dim 4 * 16 = 2304
1087+
36 * 16 * (4 + 3 * 2),
1088+
), # attributes (2 * 2 * 3 * 3) * QDQ * C (qdq + non-qdq * opt)
10761089
("GlobalAveragePool", 3 * 4), # 12
10771090
("InstanceNormalization", 3 * 16), # 48
10781091
("LayerNormalization", 5 * 2 * 2 * 16), # 320
@@ -1094,9 +1107,9 @@ class TestIterQDQCombinations:
10941107
("Reshape", 36 * 4 * 2 * 2), # allowzero 2 * is_constant 2
10951108
(
10961109
"Resize",
1097-
2880,
1110+
3456,
10981111
), # shape 4 * T2 3 * QDQ 4 * antialias 2
1099-
# * attribute 5 * (optional input 4 + 2)
1112+
# * attribute 6 * (optional input 4 + 2)
11001113
(
11011114
"ScatterND",
11021115
1680,
@@ -1112,7 +1125,7 @@ class TestIterQDQCombinations:
11121125
# All unary use this and it is enough
11131126
("Tanh", unary_input_shapes * 4), # 32
11141127
("TopK", 768), # QDQ 4 * example 12 * k is_constant 2 * parameter 8
1115-
("Transpose", 11 * 4 * 2), # 88
1128+
("Transpose", 11 * 4 * 2 * 2), # cases * QDQ * opt perm * non_qdq data
11161129
("Unsqueeze", 208), # 26 * 4 QDQ types * 2 is_constant axes
11171130
(
11181131
"Where",

0 commit comments

Comments
 (0)