Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def check_df_consistent(
for group_key, group_df in grouped:
eval_df = group_df
if placeholder_col in group_df.columns:
eval_df = group_df[not group_df[placeholder_col]]
eval_df = group_df[group_df[placeholder_col].isna() | (group_df[placeholder_col] == "")]

# If all rows are placeholders, this group should not trigger conflicts.
if eval_df.empty:
Expand Down
10 changes: 10 additions & 0 deletions src/winml/modelkit/pattern/op_input_gen/conv_input_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,16 @@ def get_infinite_property_names(self) -> list[str]:
"B_shape",
]

def get_qdq_config(self):
"""Return QDQ configuration for Conv operator inputs."""
# B/Y can be non-QDQ from P1 models
return {
"X": QDQParameterConfig(support_activation=True),
"W": QDQParameterConfig(support_weight=True),
"B": QDQParameterConfig(support_non_qdq=True, qdq_types=[SupportedONNXType.INT32]),
"Y": QDQParameterConfig(support_non_qdq=True, support_activation=True),
}


@register_runtime_checker_op
class ConvTransposeInputGenerator(ConvInputGenerator):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -304,5 +304,5 @@ def get_qdq_config(self):
return {
"A": QDQParameterConfig(support_activation=True),
"B": QDQParameterConfig(support_weight=True),
"C": QDQParameterConfig(qdq_types=[SupportedONNXType.INT32]),
"C": QDQParameterConfig(support_non_qdq=True, qdq_types=[SupportedONNXType.INT32]),
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,22 @@ def get_input_and_infinite_attribute_combinations(self) -> list[dict[str, InputC
}
)

# p1 model - segmentation
combinations.append(
{
"coordinate_transformation_mode": "pytorch_half_pixel",
"cubic_coeff_a": -0.75,
"mode": "linear",
"nearest_mode": "floor",
"X": InputShapeConstraint(x_shape),
# Explicit roi to avoid empty-optional rejection on older schemas
"roi": InputValueConstraint(np.zeros(2 * ndim, dtype=np.float32)),
"scales": InputValueConstraint(scales_up),
"extrapolation_value": 0.0,
"axes": list(range(ndim)), # All axes
}
)

# Combination using scales (downsample) - empty sizes
combinations.append(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,5 +126,5 @@ def get_infinite_property_names(self) -> list[str]:
def get_qdq_config(self):
"""Return QDQ configuration for Transpose operator inputs."""
return {
"data": QDQParameterConfig(support_activation=True),
"data": QDQParameterConfig(support_non_qdq=True, support_activation=True),
}
41 changes: 27 additions & 14 deletions tests/unit/analyze/core/test_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,9 @@ class TestIterQDQCombinationsTagSchema:
- Optional QDQ input not provided: present in 'qdq_types' as ''

Note: When an operator has pass-through inputs, 'input_is_constant' contains only
those pass-through inputs (Gather). When no pass-through inputs exist, all inputs
are in 'input_is_constant' from the outer constant-combination loop (Gemm).
those pass-through inputs (Gather). When an optional input supports both QDQ and
non-QDQ (pass-through) modes (Gemm C), 'input_is_constant' contains that input only
in the non-QDQ combination; pure QDQ inputs (A, B) never appear in 'input_is_constant'.
"""

@pytest.fixture
Expand Down Expand Up @@ -970,29 +971,41 @@ def test_gemm_weight_b_present_in_qdq_types(self, gemm_gen) -> None:
assert final_tags["qdq_types"]["B"] != ""

def test_gemm_qdq_inputs_not_in_input_is_constant(self, gemm_gen) -> None:
"""A (activation) and B (weight) not exist in input_is_constant.
"""A (activation) and B (weight) never appear in input_is_constant.

Gemm has no pass-through inputs, so input_is_constant does not exist.
C may appear in input_is_constant for its non-QDQ (pass-through) combination,
but pure QDQ inputs A and B are never pass-through.
"""
gen = gemm_gen
kwargs, tags = self._gemm_float_c_provided_kwargs_tags(gen)
results = list(gen.iter_const_and_dynamic_models(kwargs, tags))
assert len(results) > 0
for _, final_tags in results:
assert "input_is_constant" not in final_tags
ic = final_tags.get("input_is_constant", {})
assert "A" not in ic
assert "B" not in ic

# ---- Gemm: optional QDQ input (C) ----

def test_gemm_optional_c_provided_has_int32_type(self, gemm_gen) -> None:
"""When optional C is provided as constant, qdq_types['C'] is INT32 annotation."""
"""When optional C is provided and quantized, qdq_types['C'] is INT32 annotation.

C supports both QDQ (INT32) and non-QDQ (pass-through) modes. In the non-QDQ
combination qdq_types['C'] is None; when quantized it must be INT32.
"""
gen = gemm_gen
int32_ann = dtypes.SupportedONNXType.INT32.annotation
kwargs, tags = self._gemm_float_c_provided_kwargs_tags(gen)
results = list(gen.iter_const_and_dynamic_models(kwargs, tags))
assert len(results) > 0
int32_seen = False
for _, final_tags in results:
assert "C" in final_tags["qdq_types"]
assert final_tags["qdq_types"]["C"] == int32_ann
c_type = final_tags["qdq_types"]["C"]
if c_type is not None:
assert c_type == int32_ann
int32_seen = True
assert int32_seen, "Expected at least one result with C quantized as INT32"

def test_gemm_optional_c_not_provided_recorded_as_empty_in_qdq_types(self, gemm_gen) -> None:
"""When optional C is not provided (None), qdq_types['C'] is '' (not omitted)."""
Expand Down Expand Up @@ -1045,8 +1058,8 @@ class TestIterQDQCombinations:
("Concat", 240), # 15 base shapes/axes * 4 variadic counts * 4 activation types
(
"Conv",
1536,
), # shape 3 * auto_pad 4 * group_opts 2 * kernel shape 2 * optional b 2 * 16 = 1536
1536 * 4,
), # shape 3 * attrs 4 * 2 * kernel shape 2 * opt B 2 * 16 * B/Y non qdq 4
(
"ConvTranspose",
3072,
Expand All @@ -1071,8 +1084,8 @@ class TestIterQDQCombinations:
("Gelu", unary_input_shapes * 4 * 2), # 64
(
"Gemm",
2304,
), # attributes 2 * 2 * 3 * 3 * C dim 4 * 16 = 2304
36 * 16 * (4 + 3 * 2),
), # attributes (2 * 2 * 3 * 3) * QDQ * C (qdq + non-qdq * opt)
("GlobalAveragePool", 3 * 4), # 12
("InstanceNormalization", 3 * 16), # 48
("LayerNormalization", 5 * 2 * 2 * 16), # 320
Expand All @@ -1094,9 +1107,9 @@ class TestIterQDQCombinations:
("Reshape", 36 * 4 * 2 * 2), # allowzero 2 * is_constant 2
(
"Resize",
2880,
3456,
), # shape 4 * T2 3 * QDQ 4 * antialias 2
# * attribute 5 * (optional input 4 + 2)
# * attribute 6 * (optional input 4 + 2)
(
"ScatterND",
1680,
Expand All @@ -1112,7 +1125,7 @@ class TestIterQDQCombinations:
# All unary use this and it is enough
("Tanh", unary_input_shapes * 4), # 32
("TopK", 768), # QDQ 4 * example 12 * k is_constant 2 * parameter 8
("Transpose", 11 * 4 * 2), # 88
("Transpose", 11 * 4 * 2 * 2), # cases * QDQ * opt perm * non_qdq data
("Unsqueeze", 208), # 26 * 4 QDQ types * 2 is_constant axes
(
"Where",
Expand Down
Loading