Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 13 additions & 93 deletions src/winml/modelkit/analyze/runtime_checker/check_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
python -m winml.modelkit.analyze.runtime_checker.check_ops --all_ops
"""

import hashlib
import json
from pathlib import Path
from typing import Any
Expand All @@ -26,100 +25,24 @@
from google.protobuf import json_format
from onnx.defs import SchemaError

from ... import winml
from ...onnx import ONNXDomain
from ...pattern.op_input_gen import (
OpInputGenerator,
get_registered_operators,
get_runtime_checker_op,
normalize_constraint_dict,
)
from ...pattern.op_input_gen.qdq_gen import QDQGenerator
from ...sysinfo import SysInfo
from ...utils import constants
from ..utils.model_utils import get_op_since_version
from ..utils.op_utils import compute_case_signature, hash_case_signature
from .ep_checker import EPChecker


def _compute_case_signature(case: dict, *, namespace: str) -> str:
    """Compute a signature for a test case based on its content.

    The signature is used to match test cases across different runs,
    allowing delta detection when the input generator changes.

    Args:
        case: Test case dictionary. Only the keys ``type_vars``, ``attrs``,
            ``input_constraints``, ``input_is_constant``, ``dynamic_axes`` and
            ``qdq_types`` contribute to the signature; any other key is ignored.
        namespace: When non-empty, added as a leading ``ns:<namespace>`` part so
            that identical cases in different namespaces get different signatures.

    Returns:
        The signature parts joined with ``|``.

    Raises:
        TypeError: If a contributing value cannot be converted to JSON.
    """
    # Extract the key fields that define a test case
    sig_parts = []

    if namespace:
        # Namespacing keeps case_index stable per output file when signatures collide across files
        sig_parts.append(f"ns:{namespace}")

    def _safe_dump(obj: Any) -> str:
        def _default(o: Any):
            if isinstance(o, np.ndarray):
                return o.tolist()
            if isinstance(o, np.generic):
                return o.item()
            if isinstance(o, (set, frozenset)):
                # Sets have no defined iteration order; sort them so the dump
                # (and therefore the signature) is deterministic.
                try:
                    return sorted(o)
                except TypeError:
                    # Elements are not mutually orderable: fall back to repr order.
                    return sorted(o, key=repr)
            if isinstance(o, onnx.TensorProto):
                return json.loads(json_format.MessageToJson(o))
            raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")

        # sort_keys makes the dump independent of dict insertion order.
        return json.dumps(obj, sort_keys=True, default=_default)

    def _is_empty_top_level(value: Any) -> bool:
        if value is None:
            return True
        if isinstance(value, (dict, list, tuple, set)):
            return len(value) == 0
        return False

    # Type variables (e.g., T=FLOAT)
    if "type_vars" in case:
        type_vars = case["type_vars"]
        sig_parts.append(f"types:{_safe_dump(type_vars)}")

    # Attributes
    if "attrs" in case:
        attrs = case["attrs"]
        if not _is_empty_top_level(attrs):
            sig_parts.append(f"attrs:{_safe_dump(attrs)}")

    # Input constraints (shapes/values)
    if "input_constraints" in case:
        # Dict-valued constraints go through normalize_constraint_dict before dumping.
        constraints = {
            k: normalize_constraint_dict(v) if isinstance(v, dict) else v
            for k, v in case["input_constraints"].items()
        }
        sig_parts.append(f"inputs:{_safe_dump(constraints)}")

    # Input is constant flags
    if "input_is_constant" in case:
        is_const = case["input_is_constant"]
        sig_parts.append(f"const:{_safe_dump(is_const)}")

    # Dynamic axes configuration
    if "dynamic_axes" in case:
        dynamic_axes = case["dynamic_axes"]
        if not _is_empty_top_level(dynamic_axes):
            sig_parts.append(f"dynamic:{_safe_dump(dynamic_axes)}")

    # QDQ configuration: include only when present to keep non-QDQ signatures stable.
    if "qdq_types" in case:
        qdq_types = case["qdq_types"]
        if not _is_empty_top_level(qdq_types):
            sig_parts.append(f"qdq:{_safe_dump(qdq_types)}")

    return "|".join(sig_parts)


def _hash_case_signature(signature: str) -> str:
    """Return the SHA-256 hex digest of the UTF-8 encoded case signature."""
    encoded = signature.encode("utf-8")
    digest = hashlib.sha256(encoded)
    return digest.hexdigest()
# Register WinML EPs at module level before any ORT session is created.
# This must stay at the top of the file so EPs are available for all downstream usage.
# NOTE: this is an import-time side effect -- importing this module triggers the call.
winml.register_execution_providers(ort=True)


class CheckResultWriter:
Expand Down Expand Up @@ -189,7 +112,7 @@ def __init__(
if self._contains_not_run_reason(case):
continue

sig = _compute_case_signature(case, namespace=self.case_namespace)
sig = compute_case_signature(case, namespace=self.case_namespace)
self.existing_signatures[sig] = case

check_result = case.get("check_result", {})
Expand All @@ -215,10 +138,10 @@ def should_skip_case(self, case: dict) -> bool:
Returns:
True if the case should be skipped.
"""
sig = _compute_case_signature(case, namespace=self.case_namespace)
sig = compute_case_signature(case, namespace=self.case_namespace)
if self.filter_case_indices is not None:
assert self._filter_case_index_set is not None
return _hash_case_signature(sig) not in self._filter_case_index_set
return hash_case_signature(sig) not in self._filter_case_index_set

if self.delta_only:
# Only run brand-new cases; skip anything we already have
Expand All @@ -237,7 +160,7 @@ def append_result(self, case: dict[str, Any]) -> None:
Args:
case: Test case dictionary
"""
sig = _compute_case_signature(case, namespace=self.case_namespace)
sig = compute_case_signature(case, namespace=self.case_namespace)
if sig in self.output_signatures:
self.duplicate_skipped_count += 1
return
Expand All @@ -260,10 +183,10 @@ def reuse_existing_result(self, case: dict) -> bool:
Returns:
True if existing result was found and reused, False otherwise.
"""
sig = _compute_case_signature(case, namespace=self.case_namespace)
sig = compute_case_signature(case, namespace=self.case_namespace)
if self.filter_case_indices is not None:
assert self._filter_case_index_set is not None
if _hash_case_signature(sig) not in self._filter_case_index_set:
if hash_case_signature(sig) not in self._filter_case_index_set:
return False

existing_case = self.existing_signatures.get(sig)
Expand All @@ -280,8 +203,8 @@ def reuse_existing_result(self, case: dict) -> bool:

def _set_case_index_signature(self, case: dict[str, Any]) -> None:
"""Set case_index to a stable hash derived from normalized signature."""
signature = _compute_case_signature(case, namespace=self.case_namespace)
case["case_index"] = _hash_case_signature(signature)
signature = compute_case_signature(case, namespace=self.case_namespace)
case["case_index"] = hash_case_signature(signature)

def _contains_not_run_reason(self, case: dict[str, Any]) -> bool:
"""Check whether compile/run reason contains a not_run placeholder."""
Expand Down Expand Up @@ -846,9 +769,6 @@ def build_parser():

def run_from_args(args: Any) -> None:
"""Run check_ops from parsed CLI args."""
from ... import winml

winml.register_execution_providers(ort=True)
available_ops = get_registered_operators()
ops_to_check = available_ops if args.all_ops else args.ops
ep_checker = get_ep_checker(args.ep, device=args.device)
Expand Down
97 changes: 97 additions & 0 deletions src/winml/modelkit/analyze/utils/op_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Utility functions for operator check result management."""

import hashlib
import json
from typing import Any

import numpy as np
import onnx
from google.protobuf import json_format

from ...pattern.op_input_gen import normalize_constraint_dict


def compute_case_signature(case: dict, *, namespace: str) -> str:
    """Compute a signature for a test case based on its content.

    The signature is used to match test cases across different runs,
    allowing delta detection when the input generator changes.

    Args:
        case: Test case dictionary. Only the keys ``type_vars``, ``attrs``,
            ``input_constraints``, ``input_is_constant``, ``dynamic_axes`` and
            ``qdq_types`` contribute to the signature; any other key is ignored.
        namespace: When non-empty, added as a leading ``ns:<namespace>`` part so
            that identical cases in different namespaces get different signatures.

    Returns:
        The signature parts joined with ``|``.

    Raises:
        TypeError: If a contributing value cannot be converted to JSON.
    """
    # Extract the key fields that define a test case
    sig_parts = []

    if namespace:
        # Namespacing keeps case_index stable per output file when signatures collide across files
        sig_parts.append(f"ns:{namespace}")

    def _safe_dump(obj: Any) -> str:
        def _default(o: Any):
            if isinstance(o, np.ndarray):
                return o.tolist()
            if isinstance(o, np.generic):
                return o.item()
            if isinstance(o, (set, frozenset)):
                # Sets have no defined iteration order; sort them so the dump
                # (and therefore the signature) is deterministic.
                try:
                    return sorted(o)
                except TypeError:
                    # Elements are not mutually orderable: fall back to repr order.
                    return sorted(o, key=repr)
            if isinstance(o, onnx.TensorProto):
                return json.loads(json_format.MessageToJson(o))
            raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable")

        # sort_keys makes the dump independent of dict insertion order.
        return json.dumps(obj, sort_keys=True, default=_default)

    def _is_empty_top_level(value: Any) -> bool:
        if value is None:
            return True
        if isinstance(value, (dict, list, tuple, set)):
            return len(value) == 0
        return False

    # Type variables (e.g., T=FLOAT)
    if "type_vars" in case:
        type_vars = case["type_vars"]
        sig_parts.append(f"types:{_safe_dump(type_vars)}")

    # Attributes
    if "attrs" in case:
        attrs = case["attrs"]
        if not _is_empty_top_level(attrs):
            sig_parts.append(f"attrs:{_safe_dump(attrs)}")

    # Input constraints (shapes/values)
    if "input_constraints" in case:
        # Dict-valued constraints go through normalize_constraint_dict before dumping.
        constraints = {
            k: normalize_constraint_dict(v) if isinstance(v, dict) else v
            for k, v in case["input_constraints"].items()
        }
        sig_parts.append(f"inputs:{_safe_dump(constraints)}")

    # Input is constant flags
    if "input_is_constant" in case:
        is_const = case["input_is_constant"]
        sig_parts.append(f"const:{_safe_dump(is_const)}")

    # Dynamic axes configuration
    if "dynamic_axes" in case:
        dynamic_axes = case["dynamic_axes"]
        if not _is_empty_top_level(dynamic_axes):
            sig_parts.append(f"dynamic:{_safe_dump(dynamic_axes)}")

    # QDQ configuration: include only when present to keep non-QDQ signatures stable.
    if "qdq_types" in case:
        qdq_types = case["qdq_types"]
        if not _is_empty_top_level(qdq_types):
            sig_parts.append(f"qdq:{_safe_dump(qdq_types)}")

    return "|".join(sig_parts)


def hash_case_signature(signature: str) -> str:
    """Return the SHA-256 hex digest of the UTF-8 encoded case signature."""
    encoded = signature.encode("utf-8")
    digest = hashlib.sha256(encoded)
    return digest.hexdigest()
18 changes: 9 additions & 9 deletions tests/unit/analyze/test_check_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from winml.modelkit.analyze.runtime_checker.check_ops import _compute_case_signature
from winml.modelkit.analyze.utils.op_utils import compute_case_signature
from winml.modelkit.pattern.op_input_gen import InputValueConstraint


Expand Down Expand Up @@ -51,7 +51,7 @@ def test_all_same_value_array_compact_matches_expanded(self) -> None:
compact_case = self._case({"X": compact_dict})
expanded_case = self._case({"X": expanded_dict})

assert _compute_case_signature(compact_case, namespace="") == _compute_case_signature(
assert compute_case_signature(compact_case, namespace="") == compute_case_signature(
expanded_case, namespace=""
)

Expand All @@ -65,7 +65,7 @@ def test_mixed_value_array_uses_expanded_form_stable(self) -> None:
assert "value" in full_dict and "same_value" not in full_dict

case = self._case({"X": full_dict})
assert _compute_case_signature(case, namespace="") == _compute_case_signature(
assert compute_case_signature(case, namespace="") == compute_case_signature(
case, namespace=""
)

Expand All @@ -89,7 +89,7 @@ def test_top_level_input_constraints_key_order_does_not_matter(self) -> None:
},
}

assert _compute_case_signature(case_xy, namespace="") == _compute_case_signature(
assert compute_case_signature(case_xy, namespace="") == compute_case_signature(
case_yx, namespace=""
)

Expand All @@ -105,7 +105,7 @@ def test_different_values_produce_different_signatures(self) -> None:
case_a = self._case({"X": InputValueConstraint(arr_a).to_dict()})
case_b = self._case({"X": InputValueConstraint(arr_b).to_dict()})

assert _compute_case_signature(case_a, namespace="") != _compute_case_signature(
assert compute_case_signature(case_a, namespace="") != compute_case_signature(
case_b, namespace=""
)

Expand All @@ -117,7 +117,7 @@ def test_different_shapes_produce_different_signatures(self) -> None:
case_a = self._case({"X": InputValueConstraint(arr_2x3).to_dict()})
case_b = self._case({"X": InputValueConstraint(arr_3x2).to_dict()})

assert _compute_case_signature(case_a, namespace="") != _compute_case_signature(
assert compute_case_signature(case_a, namespace="") != compute_case_signature(
case_b, namespace=""
)

Expand All @@ -128,19 +128,19 @@ def test_different_shapes_produce_different_signatures(self) -> None:
def test_namespace_is_included_in_signature(self) -> None:
"""Different namespaces produce different signatures for the same case."""
case = self._case({"X": InputValueConstraint(np.ones((2,), dtype=np.float32)).to_dict()})
assert _compute_case_signature(case, namespace="file_a") != _compute_case_signature(
assert compute_case_signature(case, namespace="file_a") != compute_case_signature(
case, namespace="file_b"
)

def test_empty_namespace_omitted_from_signature(self) -> None:
"""An empty namespace string does not appear in the signature."""
case = {"type_vars": {"T": "FLOAT"}}
assert "ns:" not in _compute_case_signature(case, namespace="")
assert "ns:" not in compute_case_signature(case, namespace="")

def test_empty_attrs_excluded_from_signature(self) -> None:
"""An empty attrs dict does not affect the signature."""
base = {"type_vars": {"T": "FLOAT"}}
with_empty_attrs = {**base, "attrs": {}}
assert _compute_case_signature(base, namespace="") == _compute_case_signature(
assert compute_case_signature(base, namespace="") == compute_case_signature(
with_empty_attrs, namespace=""
)
Loading