Skip to content
Merged
13 changes: 8 additions & 5 deletions src/winml/modelkit/analyze/runtime_checker/check_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
from google.protobuf import json_format
from onnx.defs import SchemaError

from ... import winml
from ...onnx import ONNXDomain
from ...pattern.op_input_gen import (
OpInputGenerator,
get_registered_operators,
get_runtime_checker_op,
normalize_constraint_dict,
)
from ...pattern.op_input_gen.qdq_gen import QDQGenerator
from ...sysinfo import SysInfo
Expand All @@ -40,9 +40,6 @@
from .ep_checker import EPChecker


winml.register_execution_providers(ort=True)


def _compute_case_signature(case: dict, *, namespace: str) -> str:
"""Compute a signature for a test case based on its content.

Expand Down Expand Up @@ -94,7 +91,10 @@ def _is_empty_top_level(value: Any) -> bool:

# Input constraints (shapes/values)
if "input_constraints" in case:
constraints = case["input_constraints"]
constraints = {
k: normalize_constraint_dict(v) if isinstance(v, dict) else v
for k, v in case["input_constraints"].items()
}
sig_parts.append(f"inputs:{_safe_dump(constraints)}")

# Input is constant flags
Expand Down Expand Up @@ -846,6 +846,9 @@ def build_parser():

def run_from_args(args: Any) -> None:
"""Run check_ops from parsed CLI args."""
from ... import winml

winml.register_execution_providers(ort=True)
Comment thread
xieofxie marked this conversation as resolved.
available_ops = get_registered_operators()
ops_to_check = available_ops if args.all_ops else args.ops
ep_checker = get_ep_checker(args.ep, device=args.device)
Expand Down
12 changes: 9 additions & 3 deletions src/winml/modelkit/analyze/runtime_checker/result_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@

from ...onnx import ONNXDomain
from ...pattern.base import get_pattern_input_generator
from ...pattern.op_input_gen import OpInputGenerator, get_runtime_checker_op
from ...pattern.op_input_gen import (
OpInputGenerator,
get_runtime_checker_op,
normalize_constraint_dict,
)
from ..utils.model_utils import get_op_since_version, make_hashable


Expand Down Expand Up @@ -138,7 +142,9 @@ def set_properties_for_dynamic_axes(input_name: str, is_constant: bool):
for element in constraint["elements"]
)
res[f"{input_name}_value"] = tuple(
element["value"] if element["type"] == "value" else None
normalize_constraint_dict(element)["value"]
if element["type"] == "value"
else None
for element in constraint["elements"]
)
if use_qdq:
Expand All @@ -164,7 +170,7 @@ def set_properties_for_dynamic_axes(input_name: str, is_constant: bool):
res[f"{input_name}_shape"] = constraint["shape"]
res[f"{input_name}_is_none"] = False
else: # value
res[f"{input_name}_value"] = constraint["value"]
res[f"{input_name}_value"] = normalize_constraint_dict(constraint)["value"]
res[f"{input_name}_is_none"] = False

if use_qdq and f"{input_name}_is_constant" not in res:
Expand Down
25 changes: 25 additions & 0 deletions src/winml/modelkit/pattern/op_input_gen/op_input_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,14 @@ def get_value(self, type_annotation: str = "") -> Any:
def to_dict(self) -> dict[str, Any]:
"""Convert to JSON-serializable dictionary."""
if isinstance(self.value, np.ndarray):
flat = self.value.ravel()
if flat.size > 0 and np.all(flat == flat[0]):
return {
"type": "value",
"same_value": flat[0].item(),
"same_value_shape": list(self.value.shape),
"dtype": str(self.value.dtype),
}
return {
"type": "value",
"value": self.value.tolist(), # Nested list structure reflects shape
Expand All @@ -175,6 +183,23 @@ def to_dict(self) -> dict[str, Any]:
raise TypeError(msg)


def normalize_constraint_dict(c: dict) -> dict:
"""Expand same_value/same_value_shape back to the canonical value list form.

Converts the compact representation produced by InputValueConstraint.to_dict()
when all values are equal, back to the full nested value list. Use this when
consuming serialized constraint dicts to ensure consistent handling regardless
of which representation was saved.
"""
if "same_value" in c and "same_value_shape" in c:
normalized = {k: v for k, v in c.items() if k not in ("same_value", "same_value_shape")}
dtype = np.dtype(c["dtype"]) if "dtype" in c else None
normalized["value"] = np.full(c["same_value_shape"], c["same_value"], dtype=dtype).tolist()

return normalized
return c


class InputShapeConstraint(InputConstraint):
"""Constraint on shape of the input tensor.

Expand Down
146 changes: 146 additions & 0 deletions tests/unit/analyze/test_check_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Unit tests for _compute_case_signature in check_ops."""

import numpy as np

from winml.modelkit.analyze.runtime_checker.check_ops import _compute_case_signature
from winml.modelkit.pattern.op_input_gen import InputValueConstraint


class TestComputeCaseSignature:
"""Tests for _compute_case_signature."""

# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------

@staticmethod
def _case(input_constraints: dict) -> dict:
return {"type_vars": {"T": "FLOAT"}, "input_constraints": input_constraints}

# ------------------------------------------------------------------
# Core: compact vs expanded representation
# ------------------------------------------------------------------

def test_all_same_value_array_compact_matches_expanded(self) -> None:
"""InputValueConstraint with all-same values serializes to compact form;
signature must match the equivalent fully-expanded value list (backward compat).

InputValueConstraint.to_dict() emits:
{"type": "value", "same_value": 1.0, "same_value_shape": [2, 3], "dtype": "float32"}
Old code stored the expanded form:
{"type": "value", "value": [[1.0, ...], ...], "dtype": "float32"}
Both must hash to the same signature.
"""
arr = np.ones((2, 3), dtype=np.float32)

compact_dict = InputValueConstraint(arr).to_dict()
# Compact form has same_value / same_value_shape keys
assert "same_value" in compact_dict and "same_value_shape" in compact_dict

# Manually build what the old code would have stored (expanded form)
expanded_dict = {
"type": "value",
"value": arr.tolist(),
"dtype": str(arr.dtype),
}

compact_case = self._case({"X": compact_dict})
expanded_case = self._case({"X": expanded_dict})

assert _compute_case_signature(compact_case, namespace="") == _compute_case_signature(
expanded_case, namespace=""
)

def test_mixed_value_array_uses_expanded_form_stable(self) -> None:
"""InputValueConstraint with non-uniform values always emits the full value list;
signature is identical to computing it from the raw to_dict() output.
"""
arr = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

full_dict = InputValueConstraint(arr).to_dict()
assert "value" in full_dict and "same_value" not in full_dict

case = self._case({"X": full_dict})
assert _compute_case_signature(case, namespace="") == _compute_case_signature(
case, namespace=""
)

def test_top_level_input_constraints_key_order_does_not_matter(self) -> None:
"""Multiple inputs: swapping their dict insertion order must not change the signature."""
arr_x = np.zeros((2,), dtype=np.float32)
arr_y = np.ones((2,), dtype=np.float32)

case_xy = {
"type_vars": {"T": "FLOAT"},
"input_constraints": {
"X": InputValueConstraint(arr_x).to_dict(),
"Y": InputValueConstraint(arr_y).to_dict(),
},
}
case_yx = {
"type_vars": {"T": "FLOAT"},
"input_constraints": {
"Y": InputValueConstraint(arr_y).to_dict(),
"X": InputValueConstraint(arr_x).to_dict(),
},
}

assert _compute_case_signature(case_xy, namespace="") == _compute_case_signature(
case_yx, namespace=""
)

# ------------------------------------------------------------------
# Basic discriminability
# ------------------------------------------------------------------

def test_different_values_produce_different_signatures(self) -> None:
"""Two constraints with different underlying values must not collide."""
arr_a = np.ones((2, 2), dtype=np.float32)
arr_b = np.zeros((2, 2), dtype=np.float32)

case_a = self._case({"X": InputValueConstraint(arr_a).to_dict()})
case_b = self._case({"X": InputValueConstraint(arr_b).to_dict()})

assert _compute_case_signature(case_a, namespace="") != _compute_case_signature(
case_b, namespace=""
)

def test_different_shapes_produce_different_signatures(self) -> None:
"""Same fill value but different shapes must produce different signatures."""
arr_2x3 = np.ones((2, 3), dtype=np.float32)
arr_3x2 = np.ones((3, 2), dtype=np.float32)

case_a = self._case({"X": InputValueConstraint(arr_2x3).to_dict()})
case_b = self._case({"X": InputValueConstraint(arr_3x2).to_dict()})

assert _compute_case_signature(case_a, namespace="") != _compute_case_signature(
case_b, namespace=""
)

# ------------------------------------------------------------------
# Namespace / other fields
# ------------------------------------------------------------------

def test_namespace_is_included_in_signature(self) -> None:
"""Different namespaces produce different signatures for the same case."""
case = self._case({"X": InputValueConstraint(np.ones((2,), dtype=np.float32)).to_dict()})
assert _compute_case_signature(case, namespace="file_a") != _compute_case_signature(
case, namespace="file_b"
)

def test_empty_namespace_omitted_from_signature(self) -> None:
"""An empty namespace string does not appear in the signature."""
case = {"type_vars": {"T": "FLOAT"}}
assert "ns:" not in _compute_case_signature(case, namespace="")

def test_empty_attrs_excluded_from_signature(self) -> None:
"""An empty attrs dict does not affect the signature."""
base = {"type_vars": {"T": "FLOAT"}}
with_empty_attrs = {**base, "attrs": {}}
assert _compute_case_signature(base, namespace="") == _compute_case_signature(
with_empty_attrs, namespace=""
)
Loading