diff --git a/.pipelines/modelkit-official-build.yml b/.pipelines/modelkit-official-build.yml
index 6cd8243f4..e975a838e 100644
--- a/.pipelines/modelkit-official-build.yml
+++ b/.pipelines/modelkit-official-build.yml
@@ -72,7 +72,7 @@ extends:
displayName: 'Check Python version'
- powershell: |
- $rulesDir = "$(Build.SourcesDirectory)\ModelKitArtifacts\op_check_results\rules"
+ $rulesDir = "$(Build.SourcesDirectory)\ModelKitArtifacts\rules_zip"
$destDir = "$(Build.SourcesDirectory)\src\winml\modelkit\analyze\rules\runtime_check_rules"
$outDir = "$(ob_outputDirectory)"
New-Item -ItemType Directory -Path $outDir -Force | Out-Null
diff --git a/scripts/materialize_rules_zip.py b/scripts/materialize_rules_zip.py
new file mode 100644
index 000000000..6f08487db
--- /dev/null
+++ b/scripts/materialize_rules_zip.py
@@ -0,0 +1,94 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+"""Materialize runtime rule zip snapshots into full JSON payloads.
+
+This script resolves ``delta_v1`` snapshot chains and rewrites each JSON payload
+as a full dictionary without snapshot metadata keys.
+
+Usage:
+ uv run python scripts/materialize_rules_zip.py --rules-dir <rules_dir>
+ uv run python scripts/materialize_rules_zip.py --rules-dir <rules_dir> --output-dir <output_dir>
+"""
+
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+import sys
+
+
+def _load_materializer():
+ """Return ``expand_rules_zip_dir``, imported from the installed package or, failing that, from the repository's ``src`` directory."""
+ try:
+ from winml.modelkit.analyze.utils.rule_expander import expand_rules_zip_dir
+
+ return expand_rules_zip_dir
+ except ModuleNotFoundError:  # NOTE(review): also raised when a dependency of winml is missing, not only winml itself
+ repo_root = Path(__file__).resolve().parent.parent  # two directory levels above this script file
+ src_path = repo_root / "src"
+ if str(src_path) not in sys.path:  # do not insert the same entry twice
+ sys.path.insert(0, str(src_path))
+
+ from winml.modelkit.analyze.utils.rule_expander import expand_rules_zip_dir  # retry with src/ on sys.path; a second failure propagates
+
+ return expand_rules_zip_dir
+
+
+def main() -> None:  # CLI entry point: parse arguments, expand the rule zips, print a summary
+ parser = argparse.ArgumentParser(
+ description=(
+ "Materialize delta snapshots in runtime rule zips into full JSON payloads "
+ "(remove baseline dependencies)."
+ )
+ )
+ parser.add_argument(
+ "--rules-dir",
+ type=Path,
+ required=True,
+ help="Directory containing runtime rule zip files.",
+ )
+ parser.add_argument(
+ "--output-dir",
+ type=Path,
+ default=None,  # None is forwarded unchanged to expand_rules_zip_dir
+ help=(
+ "Output directory for materialized zips. "
+ "When omitted, zips are overwritten in-place."
+ ),
+ )
+ parser.add_argument(
+ "--glob",
+ type=str,
+ default="*.zip",
+ help="Filename glob used to select zip files (default: *.zip).",
+ )
+ args = parser.parse_args()
+ expand_rules_zip_dir = _load_materializer()  # the import is delegated to the helper, which may extend sys.path
+ summary = expand_rules_zip_dir(
+ args.rules_dir,
+ output_dir=args.output_dir,
+ glob_pattern=args.glob,
+ )
+
+ if not summary.per_zip:  # empty per-zip list: print the no-match message and stop before the totals
+ print(f"No zip files matched '{args.glob}' in {args.rules_dir.resolve()}")
+ return
+
+ for zip_name, json_count, materialized_count in summary.per_zip:  # one report line per entry
+ print(
+ f"[{zip_name}] json_entries={json_count}, "
+ f"materialized_delta_entries={materialized_count}"
+ )
+
+ print("\nDone.")
+ print(f" zip_files_processed: {summary.zip_files_processed}")
+ print(f" zip_files_with_delta: {summary.zip_files_with_delta}")
+ print(f" json_entries_processed: {summary.json_entries_processed}")
+ print(f" delta_entries_materialized: {summary.delta_entries_materialized}")
+ print(f" output_mode: {summary.output_mode}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/winml/modelkit/analyze/core/runtime_checker_query.py b/src/winml/modelkit/analyze/core/runtime_checker_query.py
index a54df2674..8ac93183b 100644
--- a/src/winml/modelkit/analyze/core/runtime_checker_query.py
+++ b/src/winml/modelkit/analyze/core/runtime_checker_query.py
@@ -9,6 +9,7 @@
import json
import logging
+import re
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -67,6 +68,14 @@
EG_RULE_DEBUG_DETAILS_KEY = "__debug_details"
EG_RULE_ERROR_KEY = "__error"
+# Snapshot metadata keys used by runtime-check rule artifacts.
+SNAPSHOT_TYPE_KEY = "__snapshot_type__"  # a payload is treated as a delta only when this equals SNAPSHOT_TYPE_DELTA
+SNAPSHOT_TYPE_DELTA = "delta_v1"
+SNAPSHOT_BASE_OPSET_KEY = "__base_opset__"  # integer opset of the snapshot the delta is applied on top of
+SNAPSHOT_CURRENT_OPSET_KEY = "__current_opset__"  # NOTE(review): not read by the loaders visible in this change
+SNAPSHOT_CHANGED_KEY = "__changed__"  # dict of entries to add or overwrite
+SNAPSHOT_DELETED_KEY = "__deleted__"  # list of string keys to remove
+
class _PseudoNode:
"""Lightweight stand-in for onnx.NodeProto used only for logging in _check_negative_rules."""
@@ -116,6 +125,110 @@ def _sanitize_df(df: pd.DataFrame) -> pd.DataFrame:
return df
+def _replace_opset_token(name: str, new_opset: int) -> str:
+ """Replace the first ``_opset<digits>`` token in ``name`` with ``_opset<new_opset>``; return ``name`` unchanged if there is none."""
+ return re.sub(r"_opset\d+", f"_opset{new_opset}", name, count=1)
+
+
+def _read_json_from_zip(zip_path: Path, file_name: str) -> dict[str, Any] | None:
+ """Return the JSON object stored as ``file_name`` inside ``zip_path``, or None if the zip or entry is missing, reading/parsing fails, or the payload is not a dict."""
+ if not zip_path.exists():
+ return None
+
+ import zipfile  # imported inside the function rather than at module level
+
+ try:
+ with zipfile.ZipFile(zip_path, "r") as zf:
+ if file_name not in zf.namelist():
+ return None
+ raw = json.loads(zf.read(file_name).decode("utf-8"))
+ except Exception as e:  # any open/read/decode/parse failure is logged at debug level and reported as None
+ logger.debug("Failed to read %s from %s: %s", file_name, zip_path, e)
+ return None
+
+ if not isinstance(raw, dict):  # valid JSON whose top level is not an object is rejected
+ logger.debug("Unexpected non-dict payload in %s from %s", file_name, zip_path)
+ return None
+ return raw
+
+
+def _apply_snapshot_delta(
+ base_payload: dict[str, Any], delta_payload: dict[str, Any]
+) -> dict[str, Any]:
+ """Return a new dict: a shallow copy of ``base_payload`` updated with the delta's ``__changed__`` entries, then stripped of its ``__deleted__`` keys."""
+ merged = dict(base_payload)  # shallow copy, so base_payload itself is not modified
+ changed = delta_payload.get(SNAPSHOT_CHANGED_KEY, {})
+ if isinstance(changed, dict):  # a non-dict value is ignored
+ merged.update(changed)
+
+ deleted = delta_payload.get(SNAPSHOT_DELETED_KEY, [])
+ if isinstance(deleted, list):  # a non-list value is ignored
+ for key in deleted:
+ if isinstance(key, str):  # non-string entries are skipped
+ merged.pop(key, None)  # deletions run after the update, and absent keys are not an error
+
+ return merged
+
+
+def _expand_snapshot_payload(
+ payload: dict[str, Any],
+ zip_path: Path,
+ file_name: str,
+ visited: set[tuple[str, str]] | None = None,
+) -> dict[str, Any]:
+ """Expand a snapshot payload into a full (non-delta) dict.
+
+ A payload whose ``__snapshot_type__`` is not ``delta_v1`` is returned unchanged (same object). A delta is applied on top of its recursively expanded base, located by swapping the ``_opset<digits>`` token in ``file_name`` and in the zip name for ``__base_opset__``.
+ If the base opset is not an int, the base cannot be read, or a cycle is detected, a warning is logged and the delta is applied to an empty dict. ``visited`` holds the (zip path, file name) pairs already on the chain; top-level calls omit it.
+ """
+ snapshot_type = payload.get(SNAPSHOT_TYPE_KEY)
+ if snapshot_type != SNAPSHOT_TYPE_DELTA:
+ return payload
+
+ base_opset = payload.get(SNAPSHOT_BASE_OPSET_KEY)
+ if not isinstance(base_opset, int):  # NOTE(review): bool values also pass this check because bool subclasses int
+ logger.warning(
+ "Delta snapshot %s in %s missing integer %s; "
+ "applying delta on empty base.",
+ file_name,
+ zip_path,
+ SNAPSHOT_BASE_OPSET_KEY,
+ )
+ return _apply_snapshot_delta({}, payload)
+
+ token = (str(zip_path), file_name)  # identifies this snapshot for cycle detection
+ if visited is None:
+ visited = set()
+ if token in visited:  # this snapshot is already on the chain being expanded
+ logger.warning(
+ "Detected cyclic snapshot chain while loading %s from %s; "
+ "applying delta on empty base.",
+ file_name,
+ zip_path,
+ )
+ return _apply_snapshot_delta({}, payload)
+
+ visited.add(token)
+
+ base_file_name = _replace_opset_token(file_name, base_opset)
+ base_zip_name = _replace_opset_token(zip_path.name, base_opset)
+ base_zip_path = resolve_rule_zip_path(base_zip_name)  # defined elsewhere; presumably searches the configured rule directories — TODO confirm
+
+ base_raw = _read_json_from_zip(base_zip_path, base_file_name)
+ if base_raw is None:  # missing, unreadable or non-dict base
+ logger.warning(
+ "Base snapshot %s not found in %s (from %s); "
+ "applying delta on empty base.",
+ base_file_name,
+ base_zip_path,
+ file_name,
+ )
+ return _apply_snapshot_delta({}, payload)
+
+ base_full = _expand_snapshot_payload(base_raw, base_zip_path, base_file_name, visited)  # the base may itself be a delta
+ return _apply_snapshot_delta(base_full, payload)
+
+
def _load_table_columns_file(zip_path: Path, columns_file_name: str | None) -> dict[str, list[str]]:
"""Load per-op table column metadata from a runtime-rules zip file.
@@ -126,22 +239,12 @@ def _load_table_columns_file(zip_path: Path, columns_file_name: str | None) -> d
if not columns_file_name or not zip_path.exists():
return {}
- try:
- import zipfile
-
- with zipfile.ZipFile(zip_path, "r") as zf:
- if columns_file_name not in zf.namelist():
- return {}
- raw_columns = json.loads(zf.read(columns_file_name).decode("utf-8"))
- except Exception as e:
- logger.debug(
- "Failed to load column file %s from %s: %s",
- columns_file_name,
- zip_path,
- e,
- )
+ raw_columns = _read_json_from_zip(zip_path, columns_file_name)
+ if raw_columns is None:
return {}
+ raw_columns = _expand_snapshot_payload(raw_columns, zip_path, columns_file_name)
+
if not isinstance(raw_columns, dict):
return {}
@@ -178,6 +281,8 @@ def __init__(
for op_name, columns in (preloaded_columns or {}).items()
}
self._loaded_tables: dict[str, pd.DataFrame] = {}
+ self._base_tables: LazyDomainTables | None = None
+ self._deleted_keys: set[str] = set()
self._loaded: bool = False
self._columns_loaded: bool = bool(self._raw_columns) or self._columns_file_name is None
@@ -193,7 +298,6 @@ def _ensure_loaded(self) -> None:
if self._loaded:
return
self._loaded = True
- import zipfile
if not self._zip_path.exists():
logger.warning(
@@ -203,23 +307,61 @@ def _ensure_loaded(self) -> None:
self._zip_path,
)
return
- try:
- with zipfile.ZipFile(self._zip_path, "r") as zf:
- if self._file_name in zf.namelist():
- self._raw_data = json.loads(zf.read(self._file_name).decode("utf-8"))
- else:
- logger.debug(f"Table file not found in zip: {self._file_name}")
- except Exception as e:
- logger.debug(f"Failed to load table file {self._file_name}: {e}")
+
+ raw_table_payload = _read_json_from_zip(self._zip_path, self._file_name)
+ if raw_table_payload is None:
+ logger.debug(f"Table file not found in zip: {self._file_name}")
+ return
+
+ if raw_table_payload.get(SNAPSHOT_TYPE_KEY) == SNAPSHOT_TYPE_DELTA:
+ changed = raw_table_payload.get(SNAPSHOT_CHANGED_KEY, {})
+ deleted = raw_table_payload.get(SNAPSHOT_DELETED_KEY, [])
+ self._raw_data = changed if isinstance(changed, dict) else {}
+ if isinstance(deleted, list):
+ self._deleted_keys = {key for key in deleted if isinstance(key, str)}
+ else:
+ self._deleted_keys = set()
+
+ base_opset = raw_table_payload.get(SNAPSHOT_BASE_OPSET_KEY)
+ if isinstance(base_opset, int):
+ base_file_name = _replace_opset_token(self._file_name, base_opset)
+ base_zip_name = _replace_opset_token(self._zip_path.name, base_opset)
+ base_zip_path = resolve_rule_zip_path(base_zip_name)
+ base_columns_file_name = (
+ _replace_opset_token(self._columns_file_name, base_opset)
+ if self._columns_file_name
+ else None
+ )
+ self._base_tables = LazyDomainTables(
+ base_zip_path,
+ base_file_name,
+ columns_file_name=base_columns_file_name,
+ )
+ else:
+ logger.warning(
+ "Delta table snapshot %s in %s missing integer %s; "
+ "base fallback disabled for this table.",
+ self._file_name,
+ self._zip_path,
+ SNAPSHOT_BASE_OPSET_KEY,
+ )
+ return
+
+ self._raw_data = raw_table_payload
def __getitem__(self, key: str) -> pd.DataFrame:
"""Get table for operator, loading from zip if needed."""
if key not in self._loaded_tables:
self._ensure_loaded()
- if key not in self._raw_data:
+ if key in self._raw_data:
+ self._loaded_tables[key] = _sanitize_df(build_table_df(self._raw_data[key]))
+ del self._raw_data[key]
+ elif key in self._deleted_keys:
+ raise KeyError(f"Operator '{key}' not found in tables")
+ elif self._base_tables and key in self._base_tables:
+ self._loaded_tables[key] = self._base_tables[key]
+ else:
raise KeyError(f"Operator '{key}' not found in tables")
- self._loaded_tables[key] = _sanitize_df(build_table_df(self._raw_data[key]))
- del self._raw_data[key]
return self._loaded_tables[key]
def __contains__(self, key: str) -> bool:
@@ -227,7 +369,13 @@ def __contains__(self, key: str) -> bool:
if key in self._loaded_tables:
return True
self._ensure_loaded()
- return key in self._raw_data
+ if key in self._raw_data:
+ return True
+ if key in self._deleted_keys:
+ return False
+ if self._base_tables is not None:
+ return key in self._base_tables
+ return False
def get(self, key: str, default: pd.DataFrame | None = None) -> pd.DataFrame | None:
"""Get table for operator with default fallback."""
@@ -253,6 +401,12 @@ def get_columns(self, key: str) -> list[str] | None:
return list(self._raw_columns[key])
self._ensure_loaded()
+ if key in self._deleted_keys:
+ return None
+ if self._base_tables is not None:
+ base_columns = self._base_tables.get_columns(key)
+ if base_columns is not None:
+ return base_columns
raw_table = self._raw_data.get(key)
if isinstance(raw_table, dict):
@@ -293,8 +447,6 @@ def _ensure_loaded(self) -> None:
self._load()
def _load(self) -> None:
- import zipfile
-
if not self._zip_path.exists():
if self._set_error_on_missing:
self[EG_RULE_ERROR_KEY] = "rules_zip_not_found"
@@ -307,17 +459,17 @@ def _load(self) -> None:
)
return
- with zipfile.ZipFile(self._zip_path, "r") as zf:
- if self._rule_file not in zf.namelist():
- if self._set_error_on_missing:
- self[EG_RULE_ERROR_KEY] = "negative_rule_file_not_found"
- self[EG_RULE_DEBUG_DETAILS_KEY] = str(self._rule_file)
- logger.warning(f"Negative rule file not found: {self._rule_file}")
- else:
- logger.debug(f"Negative rule file not found: {self._rule_file}")
- return
+ raw_payload = _read_json_from_zip(self._zip_path, self._rule_file)
+ if raw_payload is None:
+ if self._set_error_on_missing:
+ self[EG_RULE_ERROR_KEY] = "negative_rule_file_not_found"
+ self[EG_RULE_DEBUG_DETAILS_KEY] = str(self._rule_file)
+ logger.warning(f"Negative rule file not found: {self._rule_file}")
+ else:
+ logger.debug(f"Negative rule file not found: {self._rule_file}")
+ return
- raw: dict[str, Any] = json.loads(zf.read(self._rule_file).decode("utf-8"))
+ raw = _expand_snapshot_payload(raw_payload, self._zip_path, self._rule_file)
filtered: dict[str, Any] = {
key: value
diff --git a/src/winml/modelkit/analyze/rules/runtime_check_rules/README.md b/src/winml/modelkit/analyze/rules/runtime_check_rules/README.md
index e72a217a3..46193d98d 100644
--- a/src/winml/modelkit/analyze/rules/runtime_check_rules/README.md
+++ b/src/winml/modelkit/analyze/rules/runtime_check_rules/README.md
@@ -26,10 +26,40 @@ Copy all `*.zip` files from [`gim-home/ModelKitArtifacts/op_check_results/rules/
Set `MODELKIT_RULES_DIR` to one or more directories containing runtime rule zip files.
-- Windows (PowerShell): `$env:MODELKIT_RULES_DIR="D:\\rules;E:\\more_rules"`
+Important: relative paths are resolved from `src/winml/modelkit/analyze/utils/` (the directory of `rule_loader.py`), not from your current terminal working directory.
+
+- Windows (PowerShell, user-level absolute path): `[Environment]::SetEnvironmentVariable("MODELKIT_RULES_DIR", "C:\*path*\rules_zip", "User")`
+- Windows (PowerShell, user-level repo-relative path): `[Environment]::SetEnvironmentVariable("MODELKIT_RULES_DIR", "..\..\..\..\..\..\ModelKitArtifacts\rules_zip", "User")`
Multiple directories are supported using `os.pathsep` (`;` on Windows, `:` on Unix-like systems).
+### Option 4: Expand rule zips via CLI command
+
+You can materialize delta snapshots to full payloads in-place with:
+
+```bash
+winml expand_rules
+```
+
+This command reads all entries from `MODELKIT_RULES_DIR`, resolves each one via
+`_resolve_env_rules_dir_entry`, and rewrites the zips in place in every existing
+directory that contains matching zip files.
+
+After a folder is successfully expanded (and has at least one matching zip),
+an empty marker file named `expanded` is created in that folder.
+
+You can also override the directory entry explicitly with `--rules-dir-entry`:
+
+```bash
+winml expand_rules --rules-dir-entry C:\path\to\rules_zip
+```
+
+Multiple explicit entries are supported:
+
+```bash
+winml expand_rules --rules-dir-entry C:\path\a --rules-dir-entry C:\path\b
+```
+
## Rule zip lookup order
The analyzer searches zip files in this order:
diff --git a/src/winml/modelkit/analyze/runtime_checker/result_processor.py b/src/winml/modelkit/analyze/runtime_checker/result_processor.py
index a39bd89c8..43df61b3f 100644
--- a/src/winml/modelkit/analyze/runtime_checker/result_processor.py
+++ b/src/winml/modelkit/analyze/runtime_checker/result_processor.py
@@ -23,6 +23,64 @@
from ..utils.rule_loader import get_runtime_rules_search_dirs
+# Snapshot metadata keys used in generated rule artifacts.
+SNAPSHOT_TYPE_KEY = "__snapshot_type__"  # set to SNAPSHOT_TYPE_DELTA in delta snapshots; absent from full snapshots
+SNAPSHOT_TYPE_DELTA = "delta_v1"
+SNAPSHOT_BASE_OPSET_KEY = "__base_opset__"  # opset of the previous snapshot the delta is relative to
+SNAPSHOT_CURRENT_OPSET_KEY = "__current_opset__"  # opset the delta snapshot describes
+SNAPSHOT_CHANGED_KEY = "__changed__"  # key-sorted dict of new or modified entries
+SNAPSHOT_DELETED_KEY = "__deleted__"  # sorted list of keys present only in the previous snapshot
+
+
+def _sorted_dict_by_key(payload: dict[str, Any]) -> dict[str, Any]:
+ """Return a new dict with the items of ``payload`` ordered by key (top level only), for stable JSON output."""
+ return dict(sorted(payload.items()))
+
+
+def _build_snapshot_payload(
+ current_payload: dict[str, Any],
+ current_opset: int,
+ previous_payload: dict[str, Any] | None,
+ previous_opset: int | None,
+) -> dict[str, Any]:
+ """Build either a full snapshot (first version) or a delta snapshot.
+
+ When ``previous_payload`` or ``previous_opset`` is None, the result is ``current_payload`` sorted by key, i.e. the plain-dict format with no snapshot metadata.
+ Otherwise the result is a ``delta_v1`` dict holding both opsets, the entries that are new or compare unequal to the previous ones (``__changed__``), and the sorted keys that exist only in the previous payload (``__deleted__``).
+ """
+ if previous_payload is None or previous_opset is None:  # nothing to diff against: emit a full snapshot
+ return _sorted_dict_by_key(current_payload)
+
+ changed = {
+ op_name: value
+ for op_name, value in current_payload.items()
+ if op_name not in previous_payload or previous_payload[op_name] != value
+ }
+ deleted = sorted(op_name for op_name in previous_payload if op_name not in current_payload)
+
+ return {
+ SNAPSHOT_TYPE_KEY: SNAPSHOT_TYPE_DELTA,
+ SNAPSHOT_BASE_OPSET_KEY: previous_opset,
+ SNAPSHOT_CURRENT_OPSET_KEY: current_opset,
+ SNAPSHOT_CHANGED_KEY: _sorted_dict_by_key(changed),
+ SNAPSHOT_DELETED_KEY: deleted,
+ }
+
+
+def _is_delta_snapshot_payload(payload: Any) -> bool:  # True only for a dict whose __snapshot_type__ equals "delta_v1"
+ return isinstance(payload, dict) and payload.get(SNAPSHOT_TYPE_KEY) == SNAPSHOT_TYPE_DELTA
+
+
+def _can_append_merge(existing_payload: Any, new_payload: Any) -> bool:
+ """Return True when both payloads are plain (non-delta) dicts, i.e. when an append-mode shallow dict merge is safe."""
+ return (
+ isinstance(existing_payload, dict)
+ and isinstance(new_payload, dict)
+ and not _is_delta_snapshot_payload(existing_payload)  # a delta holds metadata keys, not operator entries
+ and not _is_delta_snapshot_payload(new_payload)
+ )
+
+
def _get_input_constraint_types(
check_results: list[dict[str, Any]],
) -> dict[str, str]:
@@ -552,18 +610,12 @@ def get_opset_version_range(op_name: str, start_opset_version: int, op_domain: s
import traceback
parser = argparse.ArgumentParser(
- description="Process runtime checker results and generate negative rules"
+ description=(
+ "Process runtime checker results and generate negative rules for "
+ "ai.onnx opset 12-22 and com.microsoft opset 1"
+ )
)
parser.add_argument("input_dir", type=str, help="Input directory containing JSON result files")
- parser.add_argument(
- "--opset_version", type=int, required=True, help="Opset version for the ONNX operators"
- )
- parser.add_argument(
- "--opset_domain",
- type=str,
- required=True,
- help="Opset domain for the ONNX operators (e.g., 'ai.onnx', 'com.microsoft')",
- )
parser.add_argument(
"--output-dir", type=str, help="Output directory for negative rules (defaults to input_dir)"
)
@@ -589,18 +641,13 @@ def get_opset_version_range(op_name: str, start_opset_version: int, op_domain: s
"(../rules/runtime_check_rules).",
)
parser.add_argument(
- "-range",
- "--opset_range_ref_op",
+ "--domains",
type=str,
- default=None,
- help="Reference operator name or end opset version number. "
- "When a number N is provided, processes all opset versions in "
- "[--opset_version, N] (inclusive). "
- "When an operator name is provided, computes the range of opset versions "
- "that share the same since_version for this op, starting from --opset_version. "
- "Example: --opset_range_ref_op 12 --opset_version 11 processes versions 11-12. "
- "Example: --opset_range_ref_op Slice --opset_version 11 processes versions 11-12 "
- "since Slice has since_versions 1, 10, 11, 13.",
+ default="ai.onnx,com.microsoft",
+ help=(
+ "Comma-separated domains to process from defaults (ai.onnx,com.microsoft). "
+ "Invalid or unsupported values are ignored."
+ ),
)
args = parser.parse_args()
@@ -608,286 +655,363 @@ def get_opset_version_range(op_name: str, start_opset_version: int, op_domain: s
output_dir = Path(args.output_dir) if args.output_dir else input_dir
output_dir.mkdir(parents=True, exist_ok=True)
- # Normalize the opset_domain (ai.onnx -> empty string for ONNX standard)
- target_domain = "" if args.opset_domain == "ai.onnx" else args.opset_domain
- domain_str_for_filename = args.opset_domain # Keep original for filename matching
-
json_files = list(input_dir.glob("*.json"))
if not json_files:
print(f"No JSON files found in {input_dir}")
exit(1)
- # Extract unique (op_name, ep_name, device, is_qdq)
- # combinations from filenames for the target domain
- # Filename format: ____opset[_qdq].json
import re
- op_info_set: set[tuple[str, str, str, bool]] = set()
- for json_file in json_files:
- is_qdq = json_file.stem.endswith("_qdq")
- # Remove opset suffix to get base info
- opset_match = re.search(r"_opset(\d+)(?:_qdq)?$", json_file.stem)
- if opset_match:
- filename_without_opset = json_file.stem[: opset_match.start()]
- parts = filename_without_opset.split("_")
- if len(parts) == 4:
- op_name, ep_name, device, file_domain = parts[:4]
- # Only include operators from the target domain
- if file_domain == domain_str_for_filename:
- op_info_set.add((op_name, ep_name, device, is_qdq))
-
- print(f"Found {len(op_info_set)} unique operators to process for domain '{args.opset_domain}'")
-
- # Determine which opset versions to process
- if args.opset_range_ref_op:
- if args.opset_range_ref_op.isdigit():
- end_opset = int(args.opset_range_ref_op)
- opset_versions_to_process = list(range(args.opset_version, end_opset + 1))
- print(f"Numeric range: will process opset versions {opset_versions_to_process}")
+ domain_plans: dict[str, list[int]] = {
+ "ai.onnx": list(range(12, 23)),
+ "com.microsoft": [1],
+ }
+
+ requested_domains = [part.strip() for part in args.domains.split(",") if part.strip()]
+ if not requested_domains:
+ requested_domains = ["ai.onnx", "com.microsoft"]
+
+ domains_to_process: list[str] = []
+ for requested in requested_domains:
+ normalized = requested.lower()
+ if normalized == "ai.onnx":
+ mapped_domain = "ai.onnx"
+ elif normalized == "com.microsoft":
+ mapped_domain = "com.microsoft"
else:
- opset_versions_to_process = get_opset_version_range(
- args.opset_range_ref_op, args.opset_version, target_domain
- )
print(
- f"Reference op '{args.opset_range_ref_op}' "
- f"with opset_version {args.opset_version}: "
- f"will process opset versions {opset_versions_to_process}"
+ f"Ignoring unsupported domain '{requested}'. "
+ "Supported values: ai.onnx, com.microsoft"
)
- else:
- opset_versions_to_process = [args.opset_version]
-
- qdq_generator = None
- if any(is_qdq for _, _, _, is_qdq in op_info_set):
- from ...pattern.op_input_gen.qdq_gen import QDQGenerator
-
- qdq_generator = QDQGenerator(1, ONNXDomain.COM_MICROSOFT)
-
- for current_opset_version in opset_versions_to_process:
- if len(opset_versions_to_process) > 1:
- print(f"\n{'=' * 60}")
- print(f"Processing opset version {current_opset_version}")
- print(f"{'=' * 60}")
-
- # Group results by (EP, device, domain, opset, is_qdq)
- results_by_ep_domain_opset: dict[tuple[str, str, str, int, bool], dict[str, Any]] = {}
- tables_by_ep_domain_opset: dict[tuple[str, str, str, int, bool], dict[str, Any]] = {}
- table_columns_by_ep_domain_opset: dict[
- tuple[str, str, str, int, bool], dict[str, list[str]]
- ] = {}
+ continue
- for op_name, ep_name, device, is_qdq in sorted(op_info_set):
- # Get the since_version for this operator based on
- # the current opset_version. Handle Op and Pattern.
- # TODO: build a since_version list for
- # PatternSchemas based on since_version of
- # included ops
- try:
- since_version = get_op_since_version(op_name, current_opset_version, target_domain)
- except SchemaError:
- since_version = current_opset_version
-
- # Build the expected filename with since_version
- qdq_suffix = "_qdq" if is_qdq else ""
- expected_filename = (
- f"{op_name}_{ep_name}_{device}"
- f"_{domain_str_for_filename}"
- f"_opset{since_version}{qdq_suffix}.json"
- )
- json_file = input_dir / expected_filename
- print(f"Processing {expected_filename}...", end=" ")
+ if mapped_domain not in domains_to_process:
+ domains_to_process.append(mapped_domain)
- if not json_file.exists():
- print(f"{Fore.YELLOW}SKIPPED: File not found. {Style.RESET_ALL}")
- continue
+ if not domains_to_process:
+ print("No valid domains selected to process.")
+ exit(1)
- if json_file.stat().st_size == 0:
- print(f"{Fore.YELLOW}SKIPPED: Empty JSON file. {Style.RESET_ALL}")
- continue
+ for domain_str_for_filename in domains_to_process:
+ target_domain = "" if domain_str_for_filename == "ai.onnx" else domain_str_for_filename
+ opset_versions_to_process = domain_plans[domain_str_for_filename]
+
+ # Extract unique (op_name, ep_name, device, is_qdq)
+ # combinations from filenames for the target domain
+ # Filename format: <op>_<ep>_<device>_<domain>_opset<N>[_qdq].json
+ op_info_set: set[tuple[str, str, str, bool]] = set()
+ for json_file in json_files:
+ is_qdq = json_file.stem.endswith("_qdq")
+ # Remove opset suffix to get base info
+ opset_match = re.search(r"_opset(\d+)(?:_qdq)?$", json_file.stem)
+ if opset_match:
+ filename_without_opset = json_file.stem[: opset_match.start()]
+ parts = filename_without_opset.split("_")
+ if len(parts) == 4:
+ op_name, ep_name, device, file_domain = parts[:4]
+ if file_domain == domain_str_for_filename:
+ op_info_set.add((op_name, ep_name, device, is_qdq))
- try:
- with open(json_file, encoding="utf-8") as f: # noqa: PTH123
- data = json.load(f)
+ print(
+ f"Found {len(op_info_set)} unique operators to process "
+ f"for domain '{domain_str_for_filename}'"
+ )
- op_domain, op_name, ep_name, device, opset_version, is_qdq = _parse_filename(
- json_file.stem
+ if not op_info_set:
+ print(f"No operators found for domain '{domain_str_for_filename}', skipping.")
+ continue
+
+ qdq_generator = None
+ if any(is_qdq for _, _, _, is_qdq in op_info_set):
+ from ...pattern.op_input_gen.qdq_gen import QDQGenerator
+
+ qdq_generator = QDQGenerator(1, ONNXDomain.COM_MICROSOFT)
+
+ # Keep previous full payload per (ep, device, domain, is_qdq) for delta generation.
+ previous_negative_rules_payloads: dict[
+ tuple[str, str, str, bool], tuple[int, dict[str, Any]]
+ ] = {}
+ previous_tables_payloads: dict[tuple[str, str, str, bool], tuple[int, dict[str, Any]]] = {}
+ previous_columns_payloads: dict[tuple[str, str, str, bool], tuple[int, dict[str, Any]]] = {}
+
+ for current_opset_version in opset_versions_to_process:
+ if len(opset_versions_to_process) > 1:
+ print(f"\n{'=' * 60}")
+ print(f"Processing domain {domain_str_for_filename}, opset {current_opset_version}")
+ print(f"{'=' * 60}")
+
+ # Group results by (EP, device, domain, opset, is_qdq)
+ results_by_ep_domain_opset: dict[tuple[str, str, str, int, bool], dict[str, Any]] = {}
+ tables_by_ep_domain_opset: dict[tuple[str, str, str, int, bool], dict[str, Any]] = {}
+ table_columns_by_ep_domain_opset: dict[
+ tuple[str, str, str, int, bool], dict[str, list[str]]
+ ] = {}
+
+ for op_name, ep_name, device, is_qdq in sorted(op_info_set):
+ # Get the since_version for this operator based on
+ # the current opset_version. Handle Op and Pattern.
+ # TODO: build a since_version list for
+ # PatternSchemas based on since_version of
+ # included ops
+ try:
+ since_version = get_op_since_version(
+ op_name,
+ current_opset_version,
+ target_domain,
+ )
+ except SchemaError:
+ since_version = current_opset_version
+
+ # Build the expected filename with since_version
+ qdq_suffix = "_qdq" if is_qdq else ""
+ expected_filename = (
+ f"{op_name}_{ep_name}_{device}"
+ f"_{domain_str_for_filename}"
+ f"_opset{since_version}{qdq_suffix}.json"
)
+ json_file = input_dir / expected_filename
+ print(f"Processing {expected_filename}...", end=" ")
- check_results = data.get("check_results", [])
+ if not json_file.exists():
+ print(f"{Fore.YELLOW}SKIPPED: File not found. {Style.RESET_ALL}")
+ continue
- if not check_results:
- print(f"{Fore.RED}Error: No check_results found, skipping{Style.RESET_ALL}")
+ if json_file.stat().st_size == 0:
+ print(f"{Fore.YELLOW}SKIPPED: Empty JSON file. {Style.RESET_ALL}")
continue
- # Build negative rules and get DataFrame
- domain = ONNXDomain.from_str(op_domain)
try:
- schema = domain.get_op_schema(op_name, opset_version)
- input_generator = get_runtime_checker_op(op_name, domain=op_domain)(
- schema, qdq_generator=qdq_generator if is_qdq else None
+ with open(json_file, encoding="utf-8") as f: # noqa: PTH123
+ data = json.load(f)
+
+ op_domain, op_name, ep_name, device, opset_version, is_qdq = _parse_filename(
+ json_file.stem
)
- except SchemaError:
- # pattern case
- # TODO: if a pattern depends on multiple
- # domains, the filename currently contains
- # only AI_ONNX; need to recover all domains
- domain_versions = {
- op_domain: opset_version,
- ONNXDomain.COM_MICROSOFT: 1, # safeguard
- }
- input_generator = get_pattern_input_generator(op_name)(domain_versions)
-
- op_negative_rules, df = build_op_query_negative_rules_and_table(
- check_results,
- input_generator,
- use_qdq=is_qdq,
- op_version=opset_version,
- device=device,
- ep_name=ep_name,
- op_domain=op_domain,
+
+ check_results = data.get("check_results", [])
+
+ if not check_results:
+ print(f"{Fore.RED}Error: No check_results found, skipping{Style.RESET_ALL}")
+ continue
+
+ # Build negative rules and get DataFrame
+ domain = ONNXDomain.from_str(op_domain)
+ try:
+ schema = domain.get_op_schema(op_name, opset_version)
+ input_generator = get_runtime_checker_op(op_name, domain=op_domain)(
+ schema, qdq_generator=qdq_generator if is_qdq else None
+ )
+ except SchemaError:
+ # pattern case
+ # TODO: if a pattern depends on multiple
+ # domains, the filename currently contains
+ # only AI_ONNX; need to recover all domains
+ domain_versions = {
+ op_domain: opset_version,
+ ONNXDomain.COM_MICROSOFT: 1, # safeguard
+ }
+ input_generator = get_pattern_input_generator(op_name)(domain_versions)
+
+ op_negative_rules, df = build_op_query_negative_rules_and_table(
+ check_results,
+ input_generator,
+ use_qdq=is_qdq,
+ op_version=opset_version,
+ device=device,
+ ep_name=ep_name,
+ op_domain=op_domain,
+ )
+
+ # Group by (EP, domain, current opset_version, is_qdq)
+ key = (ep_name, device, target_domain, current_opset_version, is_qdq)
+ if key not in results_by_ep_domain_opset:
+ results_by_ep_domain_opset[key] = {}
+ tables_by_ep_domain_opset[key] = {}
+ table_columns_by_ep_domain_opset[key] = {}
+
+ results_by_ep_domain_opset[key][op_name] = op_negative_rules
+
+ # Convert DataFrame to JSON-serializable format
+ tables_by_ep_domain_opset[key][op_name] = df.to_dict()
+ table_columns_by_ep_domain_opset[key][op_name] = [
+ col_name
+ for col_name in df.columns.to_list()
+ if col_name != "compile_run_success"
+ ]
+
+ print(f"OK ({len(check_results)} results)")
+
+ except Exception as e:
+ print(f"{Fore.RED}ERROR: {e}{Style.RESET_ALL}")
+ traceback.print_exc()
+ sys.exit(1)
+
+ zip_group = {}
+ # Save negative rules
+ for (
+ ep_name,
+ device,
+ op_domain,
+ opset_version,
+ is_qdq,
+ ), op_results in results_by_ep_domain_opset.items():
+ # Create domain-specific filename
+ domain_str = op_domain if op_domain else "ai.onnx"
+ qdq_suffix = "_qdq" if is_qdq else ""
+ output_file = output_dir / (
+ f"{ep_name}_{device}_{domain_str}"
+ f"_opset{opset_version}"
+ f"_negative_rules{qdq_suffix}.json"
)
- # Group by (EP, domain, current opset_version, is_qdq)
- key = (ep_name, device, target_domain, current_opset_version, is_qdq)
- if key not in results_by_ep_domain_opset:
- results_by_ep_domain_opset[key] = {}
- tables_by_ep_domain_opset[key] = {}
- table_columns_by_ep_domain_opset[key] = {}
-
- results_by_ep_domain_opset[key][op_name] = op_negative_rules
-
- # Convert DataFrame to JSON-serializable format
- tables_by_ep_domain_opset[key][op_name] = df.to_dict()
- table_columns_by_ep_domain_opset[key][op_name] = [
- col_name
- for col_name in df.columns.to_list()
- if col_name != "compile_run_success"
- ]
-
- print(f"OK ({len(check_results)} results)")
-
- except Exception as e:
- print(f"{Fore.RED}ERROR: {e}{Style.RESET_ALL}")
- traceback.print_exc()
- sys.exit(1)
-
- zip_group = {}
- # Save negative rules
- for (
- ep_name,
- device,
- op_domain,
- opset_version,
- is_qdq,
- ), op_results in results_by_ep_domain_opset.items():
- # Create domain-specific filename
- domain_str = op_domain if op_domain else "ai.onnx"
- qdq_suffix = "_qdq" if is_qdq else ""
- output_file = output_dir / (
- f"{ep_name}_{device}_{domain_str}"
- f"_opset{opset_version}"
- f"_negative_rules{qdq_suffix}.json"
- )
+ snapshot_key = (ep_name, device, op_domain, is_qdq)
+ previous_snapshot = previous_negative_rules_payloads.get(snapshot_key)
+ snapshot_payload = _build_snapshot_payload(
+ op_results,
+ opset_version,
+ previous_snapshot[1] if previous_snapshot else None,
+ previous_snapshot[0] if previous_snapshot else None,
+ )
- with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123
- json.dump(dict(sorted(op_results.items())), f, indent=2)
-
- print(f"\nSaved {len(op_results)} operators to {output_file}")
- zip_group.setdefault(f"{ep_name}_{device}", []).append(output_file)
-
- # Save tables
- for (
- ep_name,
- device,
- op_domain,
- opset_version,
- is_qdq,
- ), op_tables in tables_by_ep_domain_opset.items():
- # Create domain-specific filename
- domain_str = op_domain if op_domain else "ai.onnx"
- qdq_suffix = "_qdq" if is_qdq else ""
- output_file = (
- output_dir
- / f"{ep_name}_{device}_{domain_str}_opset{opset_version}_tables{qdq_suffix}.json"
- )
+ with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123
+ json.dump(snapshot_payload, f, indent=2)
+
+ previous_negative_rules_payloads[snapshot_key] = (opset_version, op_results)
+
+ print(f"\nSaved {len(op_results)} operators to {output_file}")
+ zip_group.setdefault(f"{ep_name}_{device}", []).append(output_file)
+
+ # Save tables
+ for (
+ ep_name,
+ device,
+ op_domain,
+ opset_version,
+ is_qdq,
+ ), op_tables in tables_by_ep_domain_opset.items():
+ # Create domain-specific filename
+ domain_str = op_domain if op_domain else "ai.onnx"
+ qdq_suffix = "_qdq" if is_qdq else ""
+ output_file = (
+ output_dir
+ / (
+ f"{ep_name}_{device}_{domain_str}_opset{opset_version}_tables"
+ f"{qdq_suffix}.json"
+ )
+ )
- with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123
- json.dump(dict(sorted(op_tables.items())), f, indent=2)
-
- print(f"Saved {len(op_tables)} operator tables to {output_file}")
- zip_group.setdefault(f"{ep_name}_{device}", []).append(output_file)
-
- # Save table column names
- for (
- ep_name,
- device,
- op_domain,
- opset_version,
- is_qdq,
- ), op_columns in table_columns_by_ep_domain_opset.items():
- domain_str = op_domain if op_domain else "ai.onnx"
- qdq_suffix = "_qdq" if is_qdq else ""
- output_file = output_dir / (
- f"{ep_name}_{device}_{domain_str}"
- f"_opset{opset_version}_table_columns{qdq_suffix}.json"
- )
+ snapshot_key = (ep_name, device, op_domain, is_qdq)
+ previous_snapshot = previous_tables_payloads.get(snapshot_key)
+ snapshot_payload = _build_snapshot_payload(
+ op_tables,
+ opset_version,
+ previous_snapshot[1] if previous_snapshot else None,
+ previous_snapshot[0] if previous_snapshot else None,
+ )
- with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123
- json.dump(dict(sorted(op_columns.items())), f, indent=2)
+ with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123
+ json.dump(snapshot_payload, f, indent=2)
+
+ previous_tables_payloads[snapshot_key] = (opset_version, op_tables)
+
+ print(f"Saved {len(op_tables)} operator tables to {output_file}")
+ zip_group.setdefault(f"{ep_name}_{device}", []).append(output_file)
+
+ # Save table column names
+ for (
+ ep_name,
+ device,
+ op_domain,
+ opset_version,
+ is_qdq,
+ ), op_columns in table_columns_by_ep_domain_opset.items():
+ domain_str = op_domain if op_domain else "ai.onnx"
+ qdq_suffix = "_qdq" if is_qdq else ""
+ output_file = output_dir / (
+ f"{ep_name}_{device}_{domain_str}"
+ f"_opset{opset_version}_table_columns{qdq_suffix}.json"
+ )
- print(f"Saved {len(op_columns)} operator table column sets to {output_file}")
- zip_group.setdefault(f"{ep_name}_{device}", []).append(output_file)
+ snapshot_key = (ep_name, device, op_domain, is_qdq)
+ previous_snapshot = previous_columns_payloads.get(snapshot_key)
+ snapshot_payload = _build_snapshot_payload(
+ op_columns,
+ opset_version,
+ previous_snapshot[1] if previous_snapshot else None,
+ previous_snapshot[0] if previous_snapshot else None,
+ )
- print(
- f"\nProcessing complete! Generated "
- f"{len(results_by_ep_domain_opset)} "
- f"negative rule file(s) "
- f"and {len(tables_by_ep_domain_opset)} table file(s), "
- f"plus {len(table_columns_by_ep_domain_opset)} table-column file(s)."
- )
+ with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123
+ json.dump(snapshot_payload, f, indent=2)
- if args.update_zip:
- rules_dir = (
- Path(args.rules_dir) if args.rules_dir else get_runtime_rules_search_dirs()[0]
+ previous_columns_payloads[snapshot_key] = (opset_version, op_columns)
+
+ print(f"Saved {len(op_columns)} operator table column sets to {output_file}")
+ zip_group.setdefault(f"{ep_name}_{device}", []).append(output_file)
+
+ print(
+ f"\nDomain '{domain_str_for_filename}' processing complete! Generated "
+ f"{len(results_by_ep_domain_opset)} "
+ f"negative rule file(s) "
+ f"and {len(tables_by_ep_domain_opset)} table file(s), "
+ f"plus {len(table_columns_by_ep_domain_opset)} table-column file(s)."
)
- for group_name, file_list in zip_group.items():
- rule_zip_path = (
- rules_dir
- / f"{group_name}_{domain_str_for_filename}_opset{current_opset_version}.zip"
- )
- # In append mode, load existing zip entries to preserve files not being updated
- existing_content: dict[str, bytes] = {}
- if args.append and rule_zip_path.exists():
- with zipfile.ZipFile(rule_zip_path, mode="r") as existing_zf:
- for name in existing_zf.namelist():
- existing_content[name] = existing_zf.read(name)
-
- new_arcnames = {Path(f).name for f in file_list}
-
- with zipfile.ZipFile(
- rule_zip_path, mode="w", compression=zipfile.ZIP_DEFLATED
- ) as rule_zf:
- # Keep existing entries not covered by the new output
- for name, data in existing_content.items():
- if name not in new_arcnames:
- rule_zf.writestr(name, data)
-
- for filename in file_list:
- arcname = Path(filename).name
- if args.append and arcname in existing_content:
- # Merge: old dict updated with new dict, then sort
- old_dict = json.loads(existing_content[arcname])
- with open(filename, encoding="utf-8") as f: # noqa: PTH123
- new_dict = json.load(f)
- merged = dict(sorted({**old_dict, **new_dict}.items()))
- rule_zf.writestr(arcname, json.dumps(merged, indent=2))
- else:
- rule_zf.write(filename, arcname=arcname)
-
- print(
- f"Rule zip file {group_name}"
- f"_{domain_str_for_filename}"
- f"_opset{current_opset_version}.zip "
- f"updated with {len(file_list)} files."
+ if args.update_zip:
+ rules_dir = (
+ Path(args.rules_dir) if args.rules_dir else get_runtime_rules_search_dirs()[0]
)
+ rules_dir.mkdir(parents=True, exist_ok=True)
+ for group_name, file_list in zip_group.items():
+ rule_zip_path = (
+ rules_dir
+ / f"{group_name}_{domain_str_for_filename}_opset{current_opset_version}.zip"
+ )
+
+ # In append mode, load existing zip entries to preserve files not being updated
+ existing_content: dict[str, bytes] = {}
+ if args.append and rule_zip_path.exists():
+ with zipfile.ZipFile(rule_zip_path, mode="r") as existing_zf:
+ for name in existing_zf.namelist():
+ existing_content[name] = existing_zf.read(name)
+
+ new_arcnames = {Path(f).name for f in file_list}
+
+ with zipfile.ZipFile(
+ rule_zip_path, mode="w", compression=zipfile.ZIP_DEFLATED
+ ) as rule_zf:
+ # Keep existing entries not covered by the new output
+ for name, data in existing_content.items():
+ if name not in new_arcnames:
+ rule_zf.writestr(name, data)
+
+ for filename in file_list:
+ arcname = Path(filename).name
+ if args.append and arcname in existing_content:
+ # Legacy append merge is only valid for plain full snapshots.
+ with open(filename, encoding="utf-8") as f: # noqa: PTH123
+ new_text = f.read()
+ try:
+ old_payload = json.loads(
+ existing_content[arcname].decode("utf-8")
+ )
+ new_payload = json.loads(new_text)
+ except Exception:
+ rule_zf.writestr(arcname, new_text)
+ continue
+
+ if _can_append_merge(old_payload, new_payload):
+ merged = dict(sorted({**old_payload, **new_payload}.items()))
+ rule_zf.writestr(arcname, json.dumps(merged, indent=2))
+ else:
+ rule_zf.writestr(arcname, new_text)
+ else:
+ rule_zf.write(filename, arcname=arcname)
+
+ print(
+ f"Rule zip file {group_name}"
+ f"_{domain_str_for_filename}"
+ f"_opset{current_opset_version}.zip "
+ f"updated with {len(file_list)} files."
+ )
diff --git a/src/winml/modelkit/analyze/runtime_checker/runner.py b/src/winml/modelkit/analyze/runtime_checker/runner.py
index 4edb3486a..f5c53b442 100644
--- a/src/winml/modelkit/analyze/runtime_checker/runner.py
+++ b/src/winml/modelkit/analyze/runtime_checker/runner.py
@@ -194,25 +194,20 @@ def _join_process(proc: Any, timeout: float | None = None) -> None:
"""Best-effort process join that never raises."""
try:
proc.join(timeout=timeout)
- except Exception as e:
- print(f"Warning: failed to join process during executor shutdown: {e}", file=sys.stderr)
+ except Exception:
+ # Intentionally suppress cleanup-time join errors to preserve
+ # resilient shutdown semantics ("never raises").
+ pass
@staticmethod
def _kill_process(proc: Any) -> None:
"""Best-effort process kill that never raises."""
try:
proc.kill()
- except Exception as e:
- # Keep cleanup non-fatal, but surface the failure for diagnostics.
- print(f"Warning: failed to kill worker process: {e}", file=sys.stderr)
-
- @staticmethod
- def _close_process(proc: Any) -> None:
- """Best-effort process close that never raises."""
- try:
- proc.close()
- except Exception as ex:
- print(f"Warning: failed to close process: {ex}", file=sys.stderr)
+ except Exception:
+ # Intentionally ignored: process may already be dead or handle invalid
+ # during teardown, and cleanup must remain best-effort/non-fatal.
+ pass
def _shutdown_executor_two_phase(
self,
@@ -227,8 +222,8 @@ def _shutdown_executor_two_phase(
try:
executor.shutdown(wait=False, cancel_futures=cancel_futures)
except Exception as exc:
- # Best-effort shutdown: keep cleanup flow non-raising, but surface failure.
- print(f"Executor shutdown failed during cleanup: {exc}", file=sys.stderr)
+ # Best-effort shutdown: ignore failures here, but surface context for debugging.
+ print(f"[ResilientRunner] executor.shutdown failed: {exc}", file=sys.stderr)
timeout = (
self._GRACEFUL_SHUTDOWN_TIMEOUT_SEC
@@ -251,8 +246,16 @@ def _shutdown_executor_two_phase(
for proc in survivors:
self._join_process(proc, timeout=self._FORCED_KILL_JOIN_TIMEOUT_SEC)
- for proc in lingering:
- self._close_process(proc)
+ # Do not close multiprocessing.Process handles manually here. The
+ # ProcessPoolExecutor management thread may still call proc.join(), and
+ # prematurely closing handles can trigger WinError 6 on Windows.
+ try:
+ executor.shutdown(wait=True, cancel_futures=cancel_futures)
+ except Exception as exc:
+ print(
+ f"[ResilientRunner] executor.shutdown(wait=True) failed: {exc}",
+ file=sys.stderr,
+ )
def run(self, fn: Callable[[Any, Any], Any] | None = None, *args: Any) -> dict[str, Any]:
"""Execute the function on a single input with automatic retry on failure.
diff --git a/src/winml/modelkit/analyze/utils/rule_expander.py b/src/winml/modelkit/analyze/utils/rule_expander.py
new file mode 100644
index 000000000..e3ca5c165
--- /dev/null
+++ b/src/winml/modelkit/analyze/utils/rule_expander.py
@@ -0,0 +1,310 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+"""Expand runtime rule zip snapshots into full JSON payloads."""
+
+from __future__ import annotations
+
+import json
+import re
+import tempfile
+import zipfile
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any
+
+
+SNAPSHOT_TYPE_KEY = "__snapshot_type__"
+SNAPSHOT_TYPE_DELTA = "delta_v1"
+SNAPSHOT_BASE_OPSET_KEY = "__base_opset__"
+SNAPSHOT_CHANGED_KEY = "__changed__"
+SNAPSHOT_DELETED_KEY = "__deleted__"
+EXPANDED_MARKER_FILE = "expanded"
+
+_OPSET_TOKEN_PATTERN = re.compile(r"_opset\d+")
+
+
+@dataclass
+class ExpandSummary:
+    """Counters describing one ``expand_rules_zip_dir`` run over a rules directory."""
+
+    zip_files_processed: int  # number of zip files visited
+    zip_files_with_delta: int  # zips in which at least one delta entry was materialized
+    json_entries_processed: int  # total ``.json`` entries seen across all zips
+    delta_entries_materialized: int  # total delta entries rewritten as full payloads
+    output_mode: str  # "in-place (<dir>)" or "copied to <dir>"
+    per_zip: list[tuple[str, int, int]]  # (zip name, json entries, materialized deltas)
+
+
+def _replace_opset_token(name: str, new_opset: int) -> str:
+    """Return ``name`` with its leftmost ``_opset<digits>`` token set to ``new_opset``."""
+    return _OPSET_TOKEN_PATTERN.sub(f"_opset{new_opset}", name, count=1)
+
+
+def _is_delta_snapshot(payload: Any) -> bool:
+    """Tell whether ``payload`` is a dict whose snapshot-type tag marks a delta snapshot."""
+    return isinstance(payload, dict) and payload.get(SNAPSHOT_TYPE_KEY) == SNAPSHOT_TYPE_DELTA
+
+
+def _sorted_payload(payload: dict[str, Any]) -> dict[str, Any]:
+    """Rebuild ``payload`` with its top-level keys in ascending order (stable output)."""
+    return {name: payload[name] for name in sorted(payload)}
+
+
+def _apply_delta(base_payload: dict[str, Any], delta_payload: dict[str, Any]) -> dict[str, Any]:
+    """Overlay a delta snapshot on its base payload and return the key-sorted result.
+
+    Entries under the changed key are added or replaced first, then every name
+    listed under the deleted key is dropped. ``base_payload`` is left unmodified.
+    Raises ``TypeError`` if either field, or a deleted name, has the wrong type.
+    """
+    upserts = delta_payload.get(SNAPSHOT_CHANGED_KEY, {})
+    if not isinstance(upserts, dict):
+        raise TypeError(f"Delta snapshot field '{SNAPSHOT_CHANGED_KEY}' must be a dict.")
+    removals = delta_payload.get(SNAPSHOT_DELETED_KEY, [])
+    if not isinstance(removals, list):
+        raise TypeError(f"Delta snapshot field '{SNAPSHOT_DELETED_KEY}' must be a list.")
+    for name in removals:
+        if not isinstance(name, str):
+            raise TypeError(
+                f"Delta snapshot field '{SNAPSHOT_DELETED_KEY}' contains non-string key: {name!r}"
+            )
+    dropped = frozenset(removals)
+    overlaid = {**base_payload, **upserts}
+    return _sorted_payload({k: v for k, v in overlaid.items() if k not in dropped})
+
+
+class SnapshotExpander:
+    """Resolve snapshot entries in a rules directory to full payloads, caching each result."""
+
+    def __init__(self, rules_dir: Path) -> None:
+        self.rules_dir = rules_dir  # directory in which every zip name is looked up
+        self._cache: dict[tuple[str, str], dict[str, Any]] = {}  # (zip, entry) -> resolved
+
+    def _read_json_payload(self, zip_name: str, entry_name: str) -> dict[str, Any]:
+        zip_path = self.rules_dir / zip_name
+        if not zip_path.exists():
+            raise FileNotFoundError(f"Base zip not found: {zip_path}")
+        # A member missing from the archive surfaces as KeyError; re-raise it as not-found.
+        with zipfile.ZipFile(zip_path, "r") as zf:
+            try:
+                raw = zf.read(entry_name)
+            except KeyError as exc:
+                raise FileNotFoundError(
+                    f"Base entry '{entry_name}' not found in zip '{zip_name}'"
+                ) from exc
+
+        try:
+            payload = json.loads(raw.decode("utf-8"))
+        except Exception as exc:
+            raise ValueError(f"Failed to parse JSON entry '{entry_name}' in '{zip_name}'") from exc
+        # Only a top-level JSON object is accepted as a payload.
+        if not isinstance(payload, dict):
+            raise TypeError(
+                f"JSON entry '{entry_name}' in '{zip_name}' must be a dict, got {type(payload)}"
+            )
+        return payload
+
+    def resolve_payload(
+        self,
+        zip_name: str,
+        entry_name: str,
+        stack: set[tuple[str, str]] | None = None,
+    ) -> dict[str, Any]:
+        """Return the full payload for an entry, applying its delta chain from the base opset."""
+        token = (zip_name, entry_name)
+        if token in self._cache:
+            return self._cache[token]  # the shared cached object, not a copy
+        # stack holds the entries currently being resolved and is used to detect cycles.
+        if stack is None:
+            stack = set()
+        if token in stack:
+            raise ValueError(f"Detected cyclic snapshot dependency at {zip_name}::{entry_name}")
+
+        stack.add(token)
+        payload = self._read_json_payload(zip_name, entry_name)
+
+        if _is_delta_snapshot(payload):
+            base_opset = payload.get(SNAPSHOT_BASE_OPSET_KEY)
+            if not isinstance(base_opset, int):
+                raise ValueError(
+                    f"Delta snapshot in {zip_name}::{entry_name} is missing integer "
+                    f"'{SNAPSHOT_BASE_OPSET_KEY}'"
+                )
+            # The base lives at the same zip/entry names with the opset token swapped.
+            base_zip_name = _replace_opset_token(zip_name, base_opset)
+            base_entry_name = _replace_opset_token(entry_name, base_opset)
+            if base_zip_name == zip_name and base_entry_name == entry_name:
+                raise ValueError(
+                    "Could not derive base snapshot location for "
+                    f"{zip_name}::{entry_name} (opset token not replaced)."
+                )
+
+            base_payload = self.resolve_payload(base_zip_name, base_entry_name, stack)
+            resolved = _apply_delta(base_payload, payload)
+        else:
+            resolved = _sorted_payload(payload)
+        # On success the entry leaves the in-progress set and its result is cached.
+        stack.remove(token)
+        self._cache[token] = resolved
+        return resolved
+
+
+def _expand_single_zip(
+    zip_path: Path,
+    dest_path: Path,
+    expander: SnapshotExpander,
+) -> tuple[int, int]:
+    """Copy one rules zip to ``dest_path``, materializing its delta JSON entries.
+
+    Non-JSON members and JSON members that are not delta snapshots are copied
+    byte-for-byte. Raises ``ValueError`` when a ``.json`` member cannot be parsed.
+
+    Returns:
+        Tuple of ``(json_entry_count, materialized_delta_count)``.
+    """
+    seen_json = expanded = 0
+    archive_name = zip_path.name
+    with zipfile.ZipFile(zip_path, "r") as reader, zipfile.ZipFile(
+        dest_path, "w", compression=zipfile.ZIP_DEFLATED
+    ) as writer:
+        for member in reader.infolist():
+            member_name = member.filename
+            data = reader.read(member_name)
+            if not member_name.endswith(".json"):
+                writer.writestr(member, data)
+                continue
+
+            seen_json += 1
+            try:
+                parsed = json.loads(data.decode("utf-8"))
+            except Exception as exc:
+                raise ValueError(
+                    f"Failed to parse JSON entry '{member.filename}' in '{zip_path.name}'"
+                ) from exc
+            if _is_delta_snapshot(parsed):
+                full_payload = expander.resolve_payload(archive_name, member_name)
+                data = (json.dumps(full_payload, indent=2) + "\n").encode("utf-8")
+                expanded += 1
+            writer.writestr(member, data)
+
+    return seen_json, expanded
+
+
+def expand_rules_zip_dir(
+    rules_dir: Path,
+    *,
+    output_dir: Path | None = None,
+    glob_pattern: str = "*.zip",
+    marker_filename: str = EXPANDED_MARKER_FILE,
+) -> ExpandSummary:
+    """Expand delta snapshots in rule zips.
+
+    Args:
+        rules_dir: Directory containing runtime rule zip files.
+        output_dir: Optional output directory. If omitted, or if it resolves to
+            ``rules_dir`` itself, files are rewritten in place.
+        glob_pattern: Zip filename pattern to process.
+        marker_filename: Empty marker file name created after successful expand.
+
+    Returns:
+        ExpandSummary with per-zip stats.
+
+    Raises:
+        FileNotFoundError: If ``rules_dir`` is missing or is not a directory.
+    """
+    rules_dir = rules_dir.resolve()
+    if not rules_dir.is_dir():
+        raise FileNotFoundError(f"Rules directory not found: {rules_dir}")
+
+    # Copying a zip onto itself would truncate it while it is still being read,
+    # so an output directory that resolves to ``rules_dir`` means "in place".
+    if output_dir is not None:
+        output_dir = output_dir.resolve()
+        if output_dir == rules_dir:
+            output_dir = None
+    output_mode = f"in-place ({rules_dir})" if output_dir is None else f"copied to {output_dir}"
+
+    # Ignore stale temp artifacts from prior interrupted in-place runs.
+    zip_files = [
+        path
+        for path in sorted(rules_dir.glob(glob_pattern), key=lambda p: p.name)
+        if ".materialized." not in path.name
+    ]
+    if not zip_files:
+        # Nothing to expand: no output directory and no marker file are created.
+        return ExpandSummary(
+            zip_files_processed=0,
+            zip_files_with_delta=0,
+            json_entries_processed=0,
+            delta_entries_materialized=0,
+            output_mode=output_mode,
+            per_zip=[],
+        )
+
+    if output_dir is not None:
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+    # One expander for the whole run so resolved base payloads are cached.
+    expander = SnapshotExpander(rules_dir)
+    per_zip: list[tuple[str, int, int]] = []
+
+    for zip_path in zip_files:
+        if output_dir is None:
+            # Build the replacement next to the original, then swap it in.
+            with tempfile.NamedTemporaryFile(
+                mode="wb",
+                suffix=".zip",
+                prefix=f"{zip_path.stem}.materialized.",
+                dir=str(zip_path.parent),
+                delete=False,
+            ) as tmp:
+                tmp_path = Path(tmp.name)
+            try:
+                json_count, materialized_count = _expand_single_zip(
+                    zip_path,
+                    tmp_path,
+                    expander,
+                )
+                tmp_path.replace(zip_path)
+            finally:
+                # Still present only if expansion failed before the swap.
+                if tmp_path.exists():
+                    tmp_path.unlink()
+        else:
+            dest = output_dir / zip_path.name
+            json_count, materialized_count = _expand_single_zip(zip_path, dest, expander)
+        per_zip.append((zip_path.name, json_count, materialized_count))
+
+    # The marker is written only after every zip has been processed without error.
+    target_dir = output_dir if output_dir is not None else rules_dir
+    (target_dir / marker_filename).touch(exist_ok=True)
+
+    return ExpandSummary(
+        zip_files_processed=len(zip_files),
+        zip_files_with_delta=sum(1 for _, _, count in per_zip if count > 0),
+        json_entries_processed=sum(count for _, count, _ in per_zip),
+        delta_entries_materialized=sum(count for _, _, count in per_zip),
+        output_mode=output_mode,
+        per_zip=per_zip,
+    )
+
+
+# Backward-compatible aliases for existing imports.
+MaterializeSummary = ExpandSummary
+SnapshotMaterializer = SnapshotExpander
+
+
+def materialize_rules_zip_dir(
+    rules_dir: Path,
+    *,
+    output_dir: Path | None = None,
+    glob_pattern: str = "*.zip",
+) -> ExpandSummary:
+    """Expand rule zips under the legacy API name.
+
+    Thin wrapper kept for older callers: forwards every argument unchanged to
+    ``expand_rules_zip_dir`` and leaves the marker file name at its default.
+    """
+    return expand_rules_zip_dir(rules_dir, output_dir=output_dir, glob_pattern=glob_pattern)
diff --git a/src/winml/modelkit/analyze/utils/rule_loader.py b/src/winml/modelkit/analyze/utils/rule_loader.py
index e42029a0d..0079a15c3 100644
--- a/src/winml/modelkit/analyze/utils/rule_loader.py
+++ b/src/winml/modelkit/analyze/utils/rule_loader.py
@@ -20,18 +20,87 @@
#: Use ``os.pathsep`` (`;` on Windows, `:` on Unix) to separate multiple paths.
MODELKIT_RULES_DIR_ENV = "MODELKIT_RULES_DIR"
+# Directory containing this module file. Relative env-var entries are resolved from here.
+_RULE_LOADER_DIR: Path = Path(__file__).resolve().parent
+
# Default runtime_check_rules directory (relative to the analyze package).
_DEFAULT_RUNTIME_RULES_DIR: Path = (
Path(__file__).resolve().parent.parent / "rules" / "runtime_check_rules"
)
+# Track directories already auto-checked in this process to avoid repeated scans/expands.
+_EXPAND_CHECKED_DIRS: set[str] = set()
+
+
+def _resolve_env_rules_dir_entry(entry: str) -> Path:
+    """Turn one ``MODELKIT_RULES_DIR`` entry into a resolved absolute path.
+
+    A leading ``~`` is expanded first. An entry that is already absolute is kept;
+    anything else is anchored at the directory that contains this module.
+    """
+    expanded = Path(entry).expanduser()
+    # Relative entries hang off the module directory, not the working directory.
+    anchored = expanded if expanded.is_absolute() else _RULE_LOADER_DIR / expanded
+    return anchored.resolve()
+
+
+def _has_non_temp_zip_files(rules_dir: Path, glob_pattern: str = "*.zip") -> bool:
+    """Report whether any matching file lacks the ``.materialized.`` temp-name tag."""
+    for candidate in rules_dir.glob(glob_pattern):
+        if candidate.is_file() and ".materialized." not in candidate.name:
+            return True
+    return False
+
+
+def _ensure_rules_dir_expanded_once(rules_dir: Path) -> None:
+    """Auto-expand rule zips once if marker is missing.
+
+    Behavior:
+    1. Skip when this directory was already checked in the current process.
+    2. Skip when directory does not exist.
+    3. Skip when marker file already exists.
+    4. If directory has zip files and marker is missing, run in-place expand.
+
+    Expansion failures are logged with a traceback and are not raised.
+    """
+    resolved_dir = rules_dir.resolve()
+    # normcase lowercases on Windows only; elsewhere case-distinct directories stay distinct.
+    dir_key = os.path.normcase(str(resolved_dir))
+    if dir_key in _EXPAND_CHECKED_DIRS:
+        return
+    _EXPAND_CHECKED_DIRS.add(dir_key)
+
+    if not resolved_dir.is_dir():
+        return
+
+    try:
+        from .rule_expander import EXPANDED_MARKER_FILE, expand_rules_zip_dir
+
+        if (resolved_dir / EXPANDED_MARKER_FILE).exists():
+            return
+        if not _has_non_temp_zip_files(resolved_dir):
+            return
+        logger.info(
+            "!!! [RULES INIT] One-time runtime rules initialization is required for %s; "
+            "initializing now (may take up to 30 minutes).",
+            resolved_dir,
+        )
+        expand_rules_zip_dir(resolved_dir)
+    except Exception:
+        logger.exception(
+            "Failed to auto-expand runtime rule zips in %s; "
+            "please check the zip files and expand manually if needed.",
+            resolved_dir,
+        )
+
def get_runtime_rules_search_dirs() -> list[Path]:
"""Return ordered list of directories to search for runtime check rule zips.
The search order is:
- 1. Any extra directories listed in the :data:`MODELKIT_RULES_DIR` env var
- (separated by ``os.pathsep``).
+ 1. Any extra directories listed in the :data:`MODELKIT_RULES_DIR` env var
+ (separated by ``os.pathsep``). Absolute paths are used directly;
+ relative paths are resolved relative to this module file directory.
2. Default embedded directory (``src/winml/modelkit/analyze/rules/runtime_check_rules/``)
Returns:
@@ -43,7 +112,7 @@ def get_runtime_rules_search_dirs() -> list[Path]:
for entry in env_val.split(os.pathsep):
entry = entry.strip()
if entry:
- dirs.append(Path(entry).resolve())
+ dirs.append(_resolve_env_rules_dir_entry(entry))
dirs.append(_DEFAULT_RUNTIME_RULES_DIR)
return dirs
@@ -62,6 +131,8 @@ def resolve_rule_zip_path(zip_filename: str) -> Path:
Resolved ``Path`` to the zip file.
"""
for search_dir in get_runtime_rules_search_dirs():
+ # Disabled by default: one-time rules initialization can be expensive.
+ # _ensure_rules_dir_expanded_once(search_dir)
candidate = search_dir / zip_filename
if candidate.exists():
return candidate
diff --git a/src/winml/modelkit/commands/expand_rules.py b/src/winml/modelkit/commands/expand_rules.py
new file mode 100644
index 000000000..158a572ea
--- /dev/null
+++ b/src/winml/modelkit/commands/expand_rules.py
@@ -0,0 +1,141 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+r"""Expand runtime rule zips in-place for faster loading.
+
+Resolves rules directories via ``_resolve_env_rules_dir_entry`` and, when
+each directory exists and contains zip files, rewrites them in-place to full
+payloads (no delta snapshot recursion at load time).
+
+Usage:
+ winml expand_rules
+    winml expand_rules --rules-dir-entry C:\path\to\rules_zip
+"""
+
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING
+
+import click
+
+from ..analyze.utils.rule_expander import expand_rules_zip_dir
+from ..analyze.utils.rule_loader import MODELKIT_RULES_DIR_ENV, _resolve_env_rules_dir_entry
+
+
+if TYPE_CHECKING:
+ from pathlib import Path
+
+
+def _entries_from_env() -> list[str]:
+    """Return the non-blank, whitespace-trimmed entries of MODELKIT_RULES_DIR."""
+    raw_value = os.environ.get(MODELKIT_RULES_DIR_ENV, "")
+    # Splitting an unset or blank value yields only blanks, which are dropped below.
+    trimmed = (piece.strip() for piece in raw_value.strip().split(os.pathsep))
+    return [piece for piece in trimmed if piece]
+
+
+def _resolve_entries(entries: list[str]) -> list[Path]:
+    """Resolve entries to unique paths while preserving first-seen order.
+
+    Duplicates are detected on the ``os.path.normcase`` form of each resolved
+    path, so case is folded only on Windows; elsewhere two directories whose
+    names differ just by case are both kept.
+    """
+    unique: dict[str, Path] = {}
+    for entry in entries:
+        resolved = _resolve_env_rules_dir_entry(entry)
+        # setdefault keeps the first path recorded for a key.
+        unique.setdefault(os.path.normcase(str(resolved)), resolved)
+
+    return list(unique.values())
+
+
+@click.command("expand_rules")
+@click.option(
+    "--rules-dir-entry",
+    "rules_dir_entries",
+    type=str,
+    multiple=True,
+    help=(
+        "Optional rule directory entry. May be repeated. "
+        "If omitted, uses all entries from MODELKIT_RULES_DIR. "
+        "Each entry is resolved by rule_loader._resolve_env_rules_dir_entry."
+    ),
+)
+@click.option(
+    "--glob",
+    "glob_pattern",
+    type=str,
+    default="*.zip",
+    show_default=True,
+    help="Zip filename glob to process.",
+)
+def expand_rules(rules_dir_entries: tuple[str, ...], glob_pattern: str) -> None:
+    """Expand runtime rules zip files in-place when directories and zips exist."""
+    entries = list(rules_dir_entries) if rules_dir_entries else _entries_from_env()
+    # Command-line entries win; the environment variable is only the fallback.
+    if not entries:
+        click.echo(
+            f"{MODELKIT_RULES_DIR_ENV} is not set (or empty) "
+            "and no --rules-dir-entry provided, skip."
+        )
+        return
+
+    rules_dirs = _resolve_entries(entries)
+    if not rules_dirs:
+        click.echo("No resolvable rules directories found, skip.")
+        return
+    # Running totals across every directory that actually gets expanded.
+    grand_zip_count = 0
+    grand_delta_zip_count = 0
+    grand_json_count = 0
+    grand_delta_count = 0
+
+    for rules_dir in rules_dirs:
+        if not rules_dir.exists() or not rules_dir.is_dir():
+            click.echo(f"Rules directory does not exist, skip: {rules_dir}")
+            continue
+        # Pre-check the glob so a directory without matching zips is reported and skipped.
+        matched = [
+            path
+            for path in sorted(rules_dir.glob(glob_pattern), key=lambda p: p.name)
+            if ".materialized." not in path.name
+        ]
+        if not matched:
+            click.echo(f"No zip files matched '{glob_pattern}' in {rules_dir}, skip.")
+            continue
+
+        click.echo(f"Expanding {len(matched)} zip(s) in: {rules_dir}")
+        # output_dir=None makes the expander rewrite the zips in place.
+        summary = expand_rules_zip_dir(
+            rules_dir,
+            output_dir=None,
+            glob_pattern=glob_pattern,
+        )
+
+        for zip_name, json_count, materialized_count in summary.per_zip:
+            click.echo(
+                f"[{zip_name}] json_entries={json_count}, "
+                f"materialized_delta_entries={materialized_count}"
+            )
+
+        click.echo("\nDone.")
+        click.echo(f"  zip_files_processed: {summary.zip_files_processed}")
+        click.echo(f"  zip_files_with_delta: {summary.zip_files_with_delta}")
+        click.echo(f"  json_entries_processed: {summary.json_entries_processed}")
+        click.echo(f"  delta_entries_materialized: {summary.delta_entries_materialized}")
+        click.echo(f"  output_mode: {summary.output_mode}")
+
+        grand_zip_count += summary.zip_files_processed
+        grand_delta_zip_count += summary.zip_files_with_delta
+        grand_json_count += summary.json_entries_processed
+        grand_delta_count += summary.delta_entries_materialized
+    # The aggregate is printed only when more than one directory was requested.
+    if grand_zip_count > 0 and len(rules_dirs) > 1:
+        click.echo("\nAggregate:")
+        click.echo(f"  zip_files_processed: {grand_zip_count}")
+        click.echo(f"  zip_files_with_delta: {grand_delta_zip_count}")
+        click.echo(f"  json_entries_processed: {grand_json_count}")
+        click.echo(f"  delta_entries_materialized: {grand_delta_count}")
diff --git a/tests/unit/analyze/core/test_lazy_domain_tables.py b/tests/unit/analyze/core/test_lazy_domain_tables.py
index 46934d9cd..e37c4788e 100644
--- a/tests/unit/analyze/core/test_lazy_domain_tables.py
+++ b/tests/unit/analyze/core/test_lazy_domain_tables.py
@@ -439,3 +439,111 @@ def test_list_values_are_made_hashable(self, tmp_path: Path):
# make_hashable converts lists to tuples
assert isinstance(value, tuple)
assert value == (1, 2, 3)
+
+
+class TestDeltaSnapshots:
+    """Test delta snapshot chaining for rules/tables/columns."""
+
+    def test_lazy_neg_rules_supports_delta_chain(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ):
+        base_zip = tmp_path / "EP_CPU_ai.onnx_opset12.zip"
+        delta_zip = tmp_path / "EP_CPU_ai.onnx_opset13.zip"
+        base_file = "EP_CPU_ai.onnx_opset12_negative_rules.json"
+        delta_file = "EP_CPU_ai.onnx_opset13_negative_rules.json"
+        # opset12 is a full snapshot; opset13 is a delta that adds Mul and deletes Add.
+        base_rules = {
+            "Conv": _make_op_rule("Conv"),
+            "Add": _make_op_rule("Add"),
+        }
+        delta_rules = {
+            "__snapshot_type__": "delta_v1",
+            "__base_opset__": 12,
+            "__current_opset__": 13,
+            "__changed__": {
+                "Mul": _make_op_rule("Mul"),
+            },
+            "__deleted__": ["Add"],
+        }
+        # Each snapshot is stored in its own per-opset zip.
+        with zipfile.ZipFile(base_zip, "w") as zf:
+            zf.writestr(base_file, json.dumps(base_rules))
+        with zipfile.ZipFile(delta_zip, "w") as zf:
+            zf.writestr(delta_file, json.dumps(delta_rules))
+        # Presumably lets the loader find the opset12 base zip in tmp_path - TODO confirm.
+        monkeypatch.setenv("MODELKIT_RULES_DIR", str(tmp_path))
+
+        rules = _LazyNegRules(delta_zip, delta_file, REGISTERED_PATTERNS)
+        keys = set(rules.keys())
+        assert "Conv" in keys
+        assert "Mul" in keys
+        assert "Add" not in keys
+
+    def test_lazy_domain_tables_supports_delta_chain(
+        self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+    ):
+        base_zip = tmp_path / "EP_CPU_ai.onnx_opset12.zip"
+        delta_zip = tmp_path / "EP_CPU_ai.onnx_opset13.zip"
+        base_tables_file = "EP_CPU_ai.onnx_opset12_tables.json"
+        delta_tables_file = "EP_CPU_ai.onnx_opset13_tables.json"
+        base_columns_file = "EP_CPU_ai.onnx_opset12_table_columns.json"
+        delta_columns_file = "EP_CPU_ai.onnx_opset13_table_columns.json"
+        # Tables and their column lists are separate entries with parallel delta chains.
+        base_tables = {
+            "Conv": RAW_DATA["Conv"],
+            "Add": RAW_DATA["Add"],
+        }
+        base_columns = {
+            "Conv": RAW_COLUMNS["Conv"],
+            "Add": RAW_COLUMNS["Add"],
+        }
+        # json.dumps below turns the int keys into strings and the tuples into lists.
+        delta_tables = {
+            "__snapshot_type__": "delta_v1",
+            "__base_opset__": 12,
+            "__current_opset__": 13,
+            "__changed__": {
+                "Mul": {
+                    "A_shape": {0: (1, 3, 224, 224)},
+                    "B_shape": {0: (1, 3, 224, 224)},
+                    "compile_run_success": {0: (True, True)},
+                }
+            },
+            "__deleted__": ["Add"],
+        }
+        delta_columns = {
+            "__snapshot_type__": "delta_v1",
+            "__base_opset__": 12,
+            "__current_opset__": 13,
+            "__changed__": {
+                "Mul": ["A_shape", "B_shape"],
+            },
+            "__deleted__": ["Add"],
+        }
+
+        with zipfile.ZipFile(base_zip, "w") as zf:
+            zf.writestr(base_tables_file, json.dumps(base_tables))
+            zf.writestr(base_columns_file, json.dumps(base_columns))
+        with zipfile.ZipFile(delta_zip, "w") as zf:
+            zf.writestr(delta_tables_file, json.dumps(delta_tables))
+            zf.writestr(delta_columns_file, json.dumps(delta_columns))
+        # Presumably lets the loader find the opset12 base zip in tmp_path - TODO confirm.
+        monkeypatch.setenv("MODELKIT_RULES_DIR", str(tmp_path))
+
+        tables = LazyDomainTables(
+            delta_zip,
+            delta_tables_file,
+            columns_file_name=delta_columns_file,
+        )
+
+        assert "Conv" in tables  # inherited from base
+        assert "Mul" in tables  # changed in delta
+        assert "Add" not in tables  # deleted in delta
+
+        conv_df = tables["Conv"]
+        mul_df = tables["Mul"]
+        assert isinstance(conv_df, pd.DataFrame)
+        assert isinstance(mul_df, pd.DataFrame)
+        assert tables.get_columns("Conv") == RAW_COLUMNS["Conv"]
+        assert tables.get_columns("Mul") == ["A_shape", "B_shape"]
+        assert tables.get_columns("Add") is None
diff --git a/tests/unit/analyze/models/test_rule_expander.py b/tests/unit/analyze/models/test_rule_expander.py
new file mode 100644
index 000000000..112b1c4bf
--- /dev/null
+++ b/tests/unit/analyze/models/test_rule_expander.py
@@ -0,0 +1,98 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import json
+import zipfile
+from pathlib import Path
+
+from winml.modelkit.analyze.utils.rule_expander import EXPANDED_MARKER_FILE, expand_rules_zip_dir
+
+
+def _write_json_zip(zip_path: Path, entry_name: str, payload: dict) -> None:
+    with zipfile.ZipFile(zip_path, mode="w") as archive:  # new zip holding one JSON entry
+        archive.writestr(entry_name, json.dumps(payload))
+
+
+def test_expand_rules_zip_dir_in_place(tmp_path: Path) -> None:
+    """Check in-place expansion of a base zip plus one ``delta_v1`` zip.
+
+    The delta zip must come back as a full payload and the marker file must appear.
+    """
+    rules_dir = tmp_path / "rules"
+    rules_dir.mkdir(parents=True, exist_ok=True)
+
+    # Opset 12 holds the full payload; opset 13 is a delta that adds Mul and deletes Add.
+    full_payload = {"Conv": {"v": 1}, "Add": {"v": 2}}
+    delta_payload = {
+        "__snapshot_type__": "delta_v1",
+        "__base_opset__": 12,
+        "__current_opset__": 13,
+        "__changed__": {"Mul": {"v": 3}},
+        "__deleted__": ["Add"],
+    }
+    opset12_zip = rules_dir / "EP_CPU_ai.onnx_opset12.zip"
+    opset13_zip = rules_dir / "EP_CPU_ai.onnx_opset13.zip"
+    opset13_entry = "EP_CPU_ai.onnx_opset13_negative_rules.json"
+    _write_json_zip(opset12_zip, "EP_CPU_ai.onnx_opset12_negative_rules.json", full_payload)
+    _write_json_zip(opset13_zip, opset13_entry, delta_payload)
+
+    result = expand_rules_zip_dir(rules_dir)
+
+    # Both zips are counted, but only the opset 13 one contains a delta entry.
+    assert result.zip_files_processed == 2
+    assert result.zip_files_with_delta == 1
+    assert result.delta_entries_materialized == 1
+    assert result.output_mode.startswith("in-place")
+
+    # Re-read the delta zip from disk after the expansion.
+    with zipfile.ZipFile(opset13_zip, mode="r") as archive:
+        raw_entry = archive.read(opset13_entry)
+    materialized = json.loads(raw_entry.decode("utf-8"))
+
+    # Snapshot metadata is gone; Conv is inherited, Mul is added, Add is removed.
+    assert "__snapshot_type__" not in materialized
+    assert sorted(materialized.keys()) == ["Conv", "Mul"]
+
+    # A regular file named EXPANDED_MARKER_FILE must now exist in the rules directory.
+    marker_path = rules_dir / EXPANDED_MARKER_FILE
+    assert marker_path.exists()
+    assert marker_path.is_file()
+
+
+def test_expand_rules_zip_dir_ignores_temp_materialized_zips(tmp_path: Path) -> None:
+    """Check that the ``.materialized.abcd.zip`` file is neither counted nor listed."""
+    rules_dir = tmp_path / "rules"
+    rules_dir.mkdir(parents=True, exist_ok=True)
+
+    kept_zip = rules_dir / "EP_CPU_ai.onnx_opset12.zip"
+    stale_zip = rules_dir / "EP_CPU_ai.onnx_opset13.materialized.abcd.zip"
+    stale_payload = {
+        "__snapshot_type__": "delta_v1",
+        "__base_opset__": 12,
+        "__current_opset__": 13,
+        "__changed__": {"Mul": {"v": 1}},
+        "__deleted__": [],
+    }
+
+    _write_json_zip(kept_zip, "EP_CPU_ai.onnx_opset12_negative_rules.json", {"Conv": {"v": 1}})
+    _write_json_zip(stale_zip, "EP_CPU_ai.onnx_opset13_negative_rules.json", stale_payload)
+
+    result = expand_rules_zip_dir(rules_dir)
+
+    # Only the regular zip may be counted and listed; the marker file must still exist.
+    assert result.zip_files_processed == 1
+    processed_names = [entry[0] for entry in result.per_zip]
+    assert processed_names == [kept_zip.name]
+    assert (rules_dir / EXPANDED_MARKER_FILE).exists()
+
+
+def test_expand_rules_zip_dir_no_zip_does_not_create_marker(tmp_path: Path) -> None:
+    """Check that an empty rules directory yields zero processed zips and no marker file."""
+    empty_dir = tmp_path / "rules"
+    empty_dir.mkdir(parents=True, exist_ok=True)
+    result = expand_rules_zip_dir(empty_dir)
+
+    assert result.zip_files_processed == 0
+    assert (empty_dir / EXPANDED_MARKER_FILE).exists() is False
diff --git a/tests/unit/analyze/models/test_rule_loader.py b/tests/unit/analyze/models/test_rule_loader.py
index 8d41bf055..174c7f1c0 100644
--- a/tests/unit/analyze/models/test_rule_loader.py
+++ b/tests/unit/analyze/models/test_rule_loader.py
@@ -23,7 +23,9 @@
from winml.modelkit.analyze import IHVType, RuleLoader
from winml.modelkit.analyze.utils import get_runtime_rules_search_dirs, resolve_rule_zip_path
-from winml.modelkit.analyze.utils.rule_loader import _DEFAULT_RUNTIME_RULES_DIR
+from winml.modelkit.analyze.utils import rule_expander as rule_expander_module
+from winml.modelkit.analyze.utils import rule_loader as rule_loader_module
+from winml.modelkit.analyze.utils.rule_loader import _DEFAULT_RUNTIME_RULES_DIR, _RULE_LOADER_DIR
class TestRuleLoaderBasicLoading:
@@ -468,6 +470,17 @@ def test_env_var_adds_dirs(self, monkeypatch):
assert dirs[1] == Path("/extra/path2").resolve()
assert dirs[2].name == "runtime_check_rules"
+    def test_env_var_relative_path_resolved_from_module_dir(self, monkeypatch):
+        """A relative MODELKIT_RULES_DIR entry is resolved against ``_RULE_LOADER_DIR``."""
+        rel_dir = "custom/rules"
+        monkeypatch.setenv("MODELKIT_RULES_DIR", rel_dir)
+
+        search_dirs = get_runtime_rules_search_dirs()
+        # Expect exactly two entries: the env-provided directory, then the default one.
+        assert len(search_dirs) == 2
+        assert search_dirs[0] == (_RULE_LOADER_DIR / rel_dir).resolve()
+        assert search_dirs[1] == _DEFAULT_RUNTIME_RULES_DIR
+
def test_env_var_empty_ignored(self, monkeypatch):
"""Empty MODELKIT_RULES_DIR is treated as unset."""
monkeypatch.setenv("MODELKIT_RULES_DIR", " ")
@@ -501,3 +514,71 @@ def test_resolve_prefers_env_over_default(self, monkeypatch):
result = resolve_rule_zip_path(zip_name)
assert result == Path(tmpdir).resolve() / zip_name
+
+    def test_resolve_auto_expand_disabled_by_default(self, monkeypatch):
+        """Resolving a zip path must not invoke ``expand_rules_zip_dir``.
+
+        The expander is replaced by a recording stub that has to stay uncalled.
+        """
+        zip_name = "QNN_NPU_ai_onnx_opset13.zip"
+        expand_calls: list[Path] = []
+
+        def _record_expand(
+            rules_dir: Path,
+            *,
+            output_dir: Path | None = None,
+            glob_pattern: str = "*.zip",
+            marker_filename: str = "expanded",
+        ):
+            # Stand-in for the expander: only remembers which directory it was given.
+            del output_dir, glob_pattern, marker_filename
+            expand_calls.append(rules_dir.resolve())
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            tmp_rules = Path(tmpdir)
+            tmp_rules.joinpath(zip_name).write_bytes(b"PK")
+            monkeypatch.setenv("MODELKIT_RULES_DIR", tmpdir)
+            # Swap in a fresh, empty _EXPAND_CHECKED_DIRS for the duration of the test.
+            monkeypatch.setattr(rule_loader_module, "_EXPAND_CHECKED_DIRS", set())
+            # NOTE(review): intercepts only lookups made via rule_expander_module -- confirm.
+            monkeypatch.setattr(rule_expander_module, "expand_rules_zip_dir", _record_expand)
+
+            resolved = resolve_rule_zip_path(zip_name)
+
+            assert resolved == tmp_rules.resolve() / zip_name
+            assert expand_calls == []
+
+    def test_resolve_skips_auto_expand_when_marker_exists(self, monkeypatch):
+        """Resolving must not invoke the expander when the marker file is present.
+
+        The rules directory is seeded with ``EXPANDED_MARKER_FILE`` before resolving.
+        """
+        zip_name = "QNN_NPU_ai_onnx_opset13.zip"
+        expander_ran = False
+
+        def _flag_expand(
+            rules_dir: Path,
+            *,
+            output_dir: Path | None = None,
+            glob_pattern: str = "*.zip",
+            marker_filename: str = "expanded",
+        ):
+            # Stand-in for the expander: flips the flag if it is ever called.
+            nonlocal expander_ran
+            del rules_dir, output_dir, glob_pattern, marker_filename
+            expander_ran = True
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            tmp_rules = Path(tmpdir)
+            tmp_rules.joinpath(zip_name).write_bytes(b"PK")
+            tmp_rules.joinpath(rule_expander_module.EXPANDED_MARKER_FILE).touch()
+            monkeypatch.setenv("MODELKIT_RULES_DIR", tmpdir)
+            # Swap in a fresh, empty _EXPAND_CHECKED_DIRS for the duration of the test.
+            monkeypatch.setattr(rule_loader_module, "_EXPAND_CHECKED_DIRS", set())
+            # NOTE(review): intercepts only lookups made via rule_expander_module -- confirm.
+            monkeypatch.setattr(rule_expander_module, "expand_rules_zip_dir", _flag_expand)
+
+            resolved = resolve_rule_zip_path(zip_name)
+
+            assert resolved == tmp_rules.resolve() / zip_name
+            assert expander_ran is False
diff --git a/tests/unit/analyze/runtime_checker/test_runner.py b/tests/unit/analyze/runtime_checker/test_runner.py
index 14a884e9c..41f3ea6b1 100644
--- a/tests/unit/analyze/runtime_checker/test_runner.py
+++ b/tests/unit/analyze/runtime_checker/test_runner.py
@@ -115,7 +115,7 @@ def _fake_new_executor(self) -> _FakeExecutor:
runner = ResilientRunner()
runner._shutdown_executor_two_phase(cancel_futures=True, graceful_timeout_sec=0.0)
- assert executor.shutdown_calls == [(False, True)]
+ assert executor.shutdown_calls == [(False, True), (True, True)]
assert worker.killed is True
- assert worker.closed is True
+ assert worker.closed is False
assert worker.join_calls == [runner._FORCED_KILL_JOIN_TIMEOUT_SEC]
diff --git a/tests/unit/commands/test_expand_rules_cli.py b/tests/unit/commands/test_expand_rules_cli.py
new file mode 100644
index 000000000..1b552bae8
--- /dev/null
+++ b/tests/unit/commands/test_expand_rules_cli.py
@@ -0,0 +1,85 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import json
+import zipfile
+from pathlib import Path
+
+from click.testing import CliRunner
+
+from winml.modelkit.analyze.utils.rule_expander import EXPANDED_MARKER_FILE
+from winml.modelkit.cli import main
+
+
+def _write_json_zip(zip_path: Path, entry_name: str, payload: dict) -> None:
+    with zipfile.ZipFile(zip_path, mode="w") as archive:  # new zip holding one JSON entry
+        archive.writestr(entry_name, json.dumps(payload))
+
+
+def test_expand_rules_command_expands_in_place_from_env(
+    tmp_path: Path, monkeypatch
+) -> None:
+    """Check that ``expand_rules`` expands the directory named by MODELKIT_RULES_DIR.
+
+    The delta zip must be rewritten in place as a full payload.
+    """
+    rules_dir = tmp_path / "rules_zip"
+    rules_dir.mkdir(parents=True, exist_ok=True)
+
+    # Opset 12 holds the full payload; opset 13 is a delta that adds Mul and deletes Add.
+    full_payload = {"Conv": {"v": 1}, "Add": {"v": 2}}
+    delta_payload = {
+        "__snapshot_type__": "delta_v1",
+        "__base_opset__": 12,
+        "__current_opset__": 13,
+        "__changed__": {"Mul": {"v": 3}},
+        "__deleted__": ["Add"],
+    }
+    opset12_zip = rules_dir / "EP_CPU_ai.onnx_opset12.zip"
+    opset13_zip = rules_dir / "EP_CPU_ai.onnx_opset13.zip"
+    opset13_entry = "EP_CPU_ai.onnx_opset13_negative_rules.json"
+    _write_json_zip(opset12_zip, "EP_CPU_ai.onnx_opset12_negative_rules.json", full_payload)
+    _write_json_zip(opset13_zip, opset13_entry, delta_payload)
+
+    # The command gets no arguments; the rules directory comes from the environment.
+    monkeypatch.setenv("MODELKIT_RULES_DIR", str(rules_dir))
+    cli_result = CliRunner().invoke(main, ["expand_rules"])
+
+    assert cli_result.exit_code == 0, cli_result.output
+    assert "zip_files_processed: 2" in cli_result.output
+    assert (rules_dir / EXPANDED_MARKER_FILE).exists()
+
+    # Re-read the delta zip from disk after the command has run.
+    with zipfile.ZipFile(opset13_zip, mode="r") as archive:
+        raw_entry = archive.read(opset13_entry)
+    materialized = json.loads(raw_entry.decode("utf-8"))
+
+    # Snapshot metadata is gone; Conv is inherited, Mul is added, Add is removed.
+    assert "__snapshot_type__" not in materialized
+    assert sorted(materialized.keys()) == ["Conv", "Mul"]
+
+
+def test_expand_rules_command_skips_when_dir_missing(tmp_path: Path, monkeypatch) -> None:
+    """Check that ``expand_rules`` exits 0 with a skip message for a missing rules directory."""
+    # This path is never created, so MODELKIT_RULES_DIR names a nonexistent directory.
+    missing = tmp_path / "missing_rules_zip"
+    monkeypatch.setenv("MODELKIT_RULES_DIR", str(missing))
+    runner = CliRunner()
+
+    result = runner.invoke(main, ["expand_rules"])
+
+    # Attach the captured CLI output so a non-zero exit can be diagnosed from the report.
+    assert result.exit_code == 0, result.output
+    assert "does not exist, skip" in result.output
+
+
+def test_expand_rules_command_skips_when_env_unset(monkeypatch) -> None:
+    """Check that ``expand_rules`` exits 0 and reports that MODELKIT_RULES_DIR is not set."""
+    # raising=False: the variable may already be absent from the environment.
+    monkeypatch.delenv("MODELKIT_RULES_DIR", raising=False)
+
+    result = CliRunner().invoke(main, ["expand_rules"])
+    assert result.exit_code == 0, result.output  # show the CLI output if the exit code is wrong
+    assert "MODELKIT_RULES_DIR is not set" in result.output