diff --git a/src/winml/modelkit/analyze/runtime_checker/check_ops.py b/src/winml/modelkit/analyze/runtime_checker/check_ops.py index 606386585..98d58efa5 100644 --- a/src/winml/modelkit/analyze/runtime_checker/check_ops.py +++ b/src/winml/modelkit/analyze/runtime_checker/check_ops.py @@ -15,7 +15,6 @@ python -m winml.modelkit.analyze.runtime_checker.check_ops --all_ops """ -import hashlib import json from pathlib import Path from typing import Any @@ -26,100 +25,24 @@ from google.protobuf import json_format from onnx.defs import SchemaError +from ... import winml from ...onnx import ONNXDomain from ...pattern.op_input_gen import ( OpInputGenerator, get_registered_operators, get_runtime_checker_op, - normalize_constraint_dict, ) from ...pattern.op_input_gen.qdq_gen import QDQGenerator from ...sysinfo import SysInfo from ...utils import constants from ..utils.model_utils import get_op_since_version +from ..utils.op_utils import compute_case_signature, hash_case_signature from .ep_checker import EPChecker -def _compute_case_signature(case: dict, *, namespace: str) -> str: - """Compute a signature for a test case based on its content. - - The signature is used to match test cases across different runs, - allowing delta detection when the input generator changes. - - Args: - case: Test case dictionary containing type_vars, attrs, input_constraints, etc. - - Returns: - A string signature that uniquely identifies the test case. 
- """ - # Extract the key fields that define a test case - sig_parts = [] - - if namespace: - # Namespacing keeps case_index stable per output file when signatures collide across files - sig_parts.append(f"ns:{namespace}") - - def _safe_dump(obj: Any) -> str: - def _default(o: Any): - if isinstance(o, onnx.TensorProto): - return json.loads(json_format.MessageToJson(o)) - if isinstance(o, np.ndarray): - return o.tolist() - if isinstance(o, np.generic): - return o.item() - raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable") - - return json.dumps(obj, sort_keys=True, default=_default) - - def _is_empty_top_level(value: Any) -> bool: - if value is None: - return True - if isinstance(value, (dict, list, tuple, set)): - return len(value) == 0 - return False - - # Type variables (e.g., T=FLOAT) - if "type_vars" in case: - type_vars = case["type_vars"] - sig_parts.append(f"types:{_safe_dump(type_vars)}") - - # Attributes - if "attrs" in case: - attrs = case["attrs"] - if not _is_empty_top_level(attrs): - sig_parts.append(f"attrs:{_safe_dump(attrs)}") - - # Input constraints (shapes/values) - if "input_constraints" in case: - constraints = { - k: normalize_constraint_dict(v) if isinstance(v, dict) else v - for k, v in case["input_constraints"].items() - } - sig_parts.append(f"inputs:{_safe_dump(constraints)}") - - # Input is constant flags - if "input_is_constant" in case: - is_const = case["input_is_constant"] - sig_parts.append(f"const:{_safe_dump(is_const)}") - - # Dynamic axes configuration - if "dynamic_axes" in case: - dynamic_axes = case["dynamic_axes"] - if not _is_empty_top_level(dynamic_axes): - sig_parts.append(f"dynamic:{_safe_dump(dynamic_axes)}") - - # QDQ configuration: include only when present to keep non-QDQ signatures stable. 
- if "qdq_types" in case: - qdq_types = case["qdq_types"] - if not _is_empty_top_level(qdq_types): - sig_parts.append(f"qdq:{_safe_dump(qdq_types)}") - - return "|".join(sig_parts) - - -def _hash_case_signature(signature: str) -> str: - """Return a stable hash value for a case signature.""" - return hashlib.sha256(signature.encode("utf-8")).hexdigest() +# Register WinML EPs at module level before any ORT session is created. +# This must stay at the top of the file so EPs are available for all downstream usage. +winml.register_execution_providers(ort=True) class CheckResultWriter: @@ -189,7 +112,7 @@ def __init__( if self._contains_not_run_reason(case): continue - sig = _compute_case_signature(case, namespace=self.case_namespace) + sig = compute_case_signature(case, namespace=self.case_namespace) self.existing_signatures[sig] = case check_result = case.get("check_result", {}) @@ -215,10 +138,10 @@ def should_skip_case(self, case: dict) -> bool: Returns: True if the case should be skipped. """ - sig = _compute_case_signature(case, namespace=self.case_namespace) + sig = compute_case_signature(case, namespace=self.case_namespace) if self.filter_case_indices is not None: assert self._filter_case_index_set is not None - return _hash_case_signature(sig) not in self._filter_case_index_set + return hash_case_signature(sig) not in self._filter_case_index_set if self.delta_only: # Only run brand-new cases; skip anything we already have @@ -237,7 +160,7 @@ def append_result(self, case: dict[str, Any]) -> None: Args: case: Test case dictionary """ - sig = _compute_case_signature(case, namespace=self.case_namespace) + sig = compute_case_signature(case, namespace=self.case_namespace) if sig in self.output_signatures: self.duplicate_skipped_count += 1 return @@ -260,10 +183,10 @@ def reuse_existing_result(self, case: dict) -> bool: Returns: True if existing result was found and reused, False otherwise. 
""" - sig = _compute_case_signature(case, namespace=self.case_namespace) + sig = compute_case_signature(case, namespace=self.case_namespace) if self.filter_case_indices is not None: assert self._filter_case_index_set is not None - if _hash_case_signature(sig) not in self._filter_case_index_set: + if hash_case_signature(sig) not in self._filter_case_index_set: return False existing_case = self.existing_signatures.get(sig) @@ -280,8 +203,8 @@ def reuse_existing_result(self, case: dict) -> bool: def _set_case_index_signature(self, case: dict[str, Any]) -> None: """Set case_index to a stable hash derived from normalized signature.""" - signature = _compute_case_signature(case, namespace=self.case_namespace) - case["case_index"] = _hash_case_signature(signature) + signature = compute_case_signature(case, namespace=self.case_namespace) + case["case_index"] = hash_case_signature(signature) def _contains_not_run_reason(self, case: dict[str, Any]) -> bool: """Check whether compile/run reason contains a not_run placeholder.""" @@ -846,9 +769,6 @@ def build_parser(): def run_from_args(args: Any) -> None: """Run check_ops from parsed CLI args.""" - from ... import winml - - winml.register_execution_providers(ort=True) available_ops = get_registered_operators() ops_to_check = available_ops if args.all_ops else args.ops ep_checker = get_ep_checker(args.ep, device=args.device) diff --git a/src/winml/modelkit/analyze/utils/op_utils.py b/src/winml/modelkit/analyze/utils/op_utils.py new file mode 100644 index 000000000..9090f07bf --- /dev/null +++ b/src/winml/modelkit/analyze/utils/op_utils.py @@ -0,0 +1,97 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
def compute_case_signature(case: dict, *, namespace: str) -> str:
    """Compute a content-based signature for a test case.

    The signature is used to match test cases across different runs,
    allowing delta detection when the input generator changes.

    The result is a ``|``-joined sequence of ``label:json`` parts emitted in a
    fixed order: ``ns``, ``types``, ``attrs``, ``inputs``, ``const``,
    ``dynamic``, ``qdq``. A part is emitted only when its key is present in
    ``case``. ``attrs``, ``dynamic_axes`` and ``qdq_types`` are additionally
    dropped when they are ``None`` or an empty container, so adding an empty
    field does not change the signature. JSON is dumped with sorted keys, so
    dict key order never matters.

    Args:
        case: Test case dictionary. The keys read are ``type_vars``, ``attrs``,
            ``input_constraints``, ``input_is_constant``, ``dynamic_axes`` and
            ``qdq_types``; all are optional.
        namespace: Prefix that scopes the signature (emitted as ``ns:<value>``).
            An empty string omits the prefix entirely.

    Returns:
        A string signature identifying the test case within ``namespace``.

    Raises:
        TypeError: If a value inside ``case`` cannot be converted to JSON.
    """

    def _safe_dump(obj: Any) -> str:
        def _default(o: Any):
            # Sets are accepted by _is_empty_top_level below, so they must be
            # serializable too. Sort them because set iteration order is not
            # stable and the signature has to be reproducible across runs.
            if isinstance(o, (set, frozenset)):
                try:
                    return sorted(o)
                except TypeError:
                    # Elements are not mutually comparable; order by repr.
                    return sorted(o, key=repr)
            if isinstance(o, np.ndarray):
                return o.tolist()
            if isinstance(o, np.generic):
                return o.item()
            if isinstance(o, onnx.TensorProto):
                return json.loads(json_format.MessageToJson(o))
            raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")

        return json.dumps(obj, sort_keys=True, default=_default)

    def _is_empty_top_level(value: Any) -> bool:
        if value is None:
            return True
        if isinstance(value, (dict, list, tuple, set)):
            return len(value) == 0
        return False

    sig_parts: list[str] = []

    if namespace:
        # Namespacing keeps case_index stable per output file when signatures collide across files
        sig_parts.append(f"ns:{namespace}")

    # Type variables (e.g., T=FLOAT): always emitted when the key is present.
    if "type_vars" in case:
        sig_parts.append(f"types:{_safe_dump(case['type_vars'])}")

    # Attributes: omitted when empty.
    if "attrs" in case:
        attrs = case["attrs"]
        if not _is_empty_top_level(attrs):
            sig_parts.append(f"attrs:{_safe_dump(attrs)}")

    # Input constraints (shapes/values): dict-valued constraints are passed
    # through normalize_constraint_dict before dumping.
    if "input_constraints" in case:
        constraints = {
            k: normalize_constraint_dict(v) if isinstance(v, dict) else v
            for k, v in case["input_constraints"].items()
        }
        sig_parts.append(f"inputs:{_safe_dump(constraints)}")

    # Input-is-constant flags: always emitted when the key is present.
    if "input_is_constant" in case:
        sig_parts.append(f"const:{_safe_dump(case['input_is_constant'])}")

    # Dynamic axes configuration: omitted when empty.
    if "dynamic_axes" in case:
        dynamic_axes = case["dynamic_axes"]
        if not _is_empty_top_level(dynamic_axes):
            sig_parts.append(f"dynamic:{_safe_dump(dynamic_axes)}")

    # QDQ configuration: include only when present to keep non-QDQ signatures stable.
    if "qdq_types" in case:
        qdq_types = case["qdq_types"]
        if not _is_empty_top_level(qdq_types):
            sig_parts.append(f"qdq:{_safe_dump(qdq_types)}")

    return "|".join(sig_parts)


def hash_case_signature(signature: str) -> str:
    """Return a stable hash value for a case signature.

    Args:
        signature: Signature string, e.g. from ``compute_case_signature``.

    Returns:
        The SHA-256 digest of the UTF-8 encoded signature as a hex string.
    """
    return hashlib.sha256(signature.encode("utf-8")).hexdigest()
_compute_case_signature(case, namespace="") == _compute_case_signature( + assert compute_case_signature(case, namespace="") == compute_case_signature( case, namespace="" ) @@ -89,7 +89,7 @@ def test_top_level_input_constraints_key_order_does_not_matter(self) -> None: }, } - assert _compute_case_signature(case_xy, namespace="") == _compute_case_signature( + assert compute_case_signature(case_xy, namespace="") == compute_case_signature( case_yx, namespace="" ) @@ -105,7 +105,7 @@ def test_different_values_produce_different_signatures(self) -> None: case_a = self._case({"X": InputValueConstraint(arr_a).to_dict()}) case_b = self._case({"X": InputValueConstraint(arr_b).to_dict()}) - assert _compute_case_signature(case_a, namespace="") != _compute_case_signature( + assert compute_case_signature(case_a, namespace="") != compute_case_signature( case_b, namespace="" ) @@ -117,7 +117,7 @@ def test_different_shapes_produce_different_signatures(self) -> None: case_a = self._case({"X": InputValueConstraint(arr_2x3).to_dict()}) case_b = self._case({"X": InputValueConstraint(arr_3x2).to_dict()}) - assert _compute_case_signature(case_a, namespace="") != _compute_case_signature( + assert compute_case_signature(case_a, namespace="") != compute_case_signature( case_b, namespace="" ) @@ -128,19 +128,19 @@ def test_different_shapes_produce_different_signatures(self) -> None: def test_namespace_is_included_in_signature(self) -> None: """Different namespaces produce different signatures for the same case.""" case = self._case({"X": InputValueConstraint(np.ones((2,), dtype=np.float32)).to_dict()}) - assert _compute_case_signature(case, namespace="file_a") != _compute_case_signature( + assert compute_case_signature(case, namespace="file_a") != compute_case_signature( case, namespace="file_b" ) def test_empty_namespace_omitted_from_signature(self) -> None: """An empty namespace string does not appear in the signature.""" case = {"type_vars": {"T": "FLOAT"}} - assert "ns:" not in 
_compute_case_signature(case, namespace="") + assert "ns:" not in compute_case_signature(case, namespace="") def test_empty_attrs_excluded_from_signature(self) -> None: """An empty attrs dict does not affect the signature.""" base = {"type_vars": {"T": "FLOAT"}} with_empty_attrs = {**base, "attrs": {}} - assert _compute_case_signature(base, namespace="") == _compute_case_signature( + assert compute_case_signature(base, namespace="") == compute_case_signature( with_empty_attrs, namespace="" )