diff --git a/.pipelines/modelkit-official-build.yml b/.pipelines/modelkit-official-build.yml index 6cd8243f4..e975a838e 100644 --- a/.pipelines/modelkit-official-build.yml +++ b/.pipelines/modelkit-official-build.yml @@ -72,7 +72,7 @@ extends: displayName: 'Check Python version' - powershell: | - $rulesDir = "$(Build.SourcesDirectory)\ModelKitArtifacts\op_check_results\rules" + $rulesDir = "$(Build.SourcesDirectory)\ModelKitArtifacts\rules_zip" $destDir = "$(Build.SourcesDirectory)\src\winml\modelkit\analyze\rules\runtime_check_rules" $outDir = "$(ob_outputDirectory)" New-Item -ItemType Directory -Path $outDir -Force | Out-Null diff --git a/scripts/materialize_rules_zip.py b/scripts/materialize_rules_zip.py new file mode 100644 index 000000000..6f08487db --- /dev/null +++ b/scripts/materialize_rules_zip.py @@ -0,0 +1,94 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Materialize runtime rule zip snapshots into full JSON payloads. + +This script resolves ``delta_v1`` snapshot chains and rewrites each JSON payload +as a full dictionary without snapshot metadata keys. 
+ +Usage: + uv run python scripts/materialize_rules_zip.py --rules-dir + uv run python scripts/materialize_rules_zip.py --rules-dir --output-dir +""" + +from __future__ import annotations + +import argparse +from pathlib import Path +import sys + + +def _load_materializer(): + """Load materializer utility from package, with src fallback in repo mode.""" + try: + from winml.modelkit.analyze.utils.rule_expander import expand_rules_zip_dir + + return expand_rules_zip_dir + except ModuleNotFoundError: + repo_root = Path(__file__).resolve().parent.parent + src_path = repo_root / "src" + if str(src_path) not in sys.path: + sys.path.insert(0, str(src_path)) + + from winml.modelkit.analyze.utils.rule_expander import expand_rules_zip_dir + + return expand_rules_zip_dir + + +def main() -> None: + parser = argparse.ArgumentParser( + description=( + "Materialize delta snapshots in runtime rule zips into full JSON payloads " + "(remove baseline dependencies)." + ) + ) + parser.add_argument( + "--rules-dir", + type=Path, + required=True, + help="Directory containing runtime rule zip files.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help=( + "Output directory for materialized zips. " + "When omitted, zips are overwritten in-place." 
+ ), + ) + parser.add_argument( + "--glob", + type=str, + default="*.zip", + help="Filename glob used to select zip files (default: *.zip).", + ) + args = parser.parse_args() + expand_rules_zip_dir = _load_materializer() + summary = expand_rules_zip_dir( + args.rules_dir, + output_dir=args.output_dir, + glob_pattern=args.glob, + ) + + if not summary.per_zip: + print(f"No zip files matched '{args.glob}' in {args.rules_dir.resolve()}") + return + + for zip_name, json_count, materialized_count in summary.per_zip: + print( + f"[{zip_name}] json_entries={json_count}, " + f"materialized_delta_entries={materialized_count}" + ) + + print("\nDone.") + print(f" zip_files_processed: {summary.zip_files_processed}") + print(f" zip_files_with_delta: {summary.zip_files_with_delta}") + print(f" json_entries_processed: {summary.json_entries_processed}") + print(f" delta_entries_materialized: {summary.delta_entries_materialized}") + print(f" output_mode: {summary.output_mode}") + + +if __name__ == "__main__": + main() diff --git a/src/winml/modelkit/analyze/core/runtime_checker_query.py b/src/winml/modelkit/analyze/core/runtime_checker_query.py index a54df2674..8ac93183b 100644 --- a/src/winml/modelkit/analyze/core/runtime_checker_query.py +++ b/src/winml/modelkit/analyze/core/runtime_checker_query.py @@ -9,6 +9,7 @@ import json import logging +import re from pathlib import Path from typing import TYPE_CHECKING, Any @@ -67,6 +68,14 @@ EG_RULE_DEBUG_DETAILS_KEY = "__debug_details" EG_RULE_ERROR_KEY = "__error" +# Snapshot metadata keys used by runtime-check rule artifacts. 
# Snapshot metadata keys used by runtime-check rule artifacts.
SNAPSHOT_TYPE_KEY = "__snapshot_type__"
SNAPSHOT_TYPE_DELTA = "delta_v1"
SNAPSHOT_BASE_OPSET_KEY = "__base_opset__"
SNAPSHOT_CURRENT_OPSET_KEY = "__current_opset__"
SNAPSHOT_CHANGED_KEY = "__changed__"
SNAPSHOT_DELETED_KEY = "__deleted__"

# Matches the ``_opset<N>`` token embedded in rule zip / JSON entry names.
# Compiled once here instead of on every call.
_OPSET_TOKEN_RE = re.compile(r"_opset\d+")


def _replace_opset_token(name: str, new_opset: int) -> str:
    """Replace the first ``_opset<N>`` token in a file name.

    Only the first occurrence is rewritten; a name that contains no such
    token is returned unchanged.
    """
    return _OPSET_TOKEN_RE.sub(f"_opset{new_opset}", name, count=1)


def _read_json_from_zip(zip_path: Path, file_name: str) -> dict[str, Any] | None:
    """Read a JSON object from a zip entry.

    Args:
        zip_path: Path of the zip archive to open.
        file_name: Name of the entry inside the archive.

    Returns:
        The decoded JSON object, or ``None`` when the archive does not exist,
        the entry is absent, the archive/entry cannot be read or decoded, or
        the decoded payload is not a dict.
    """
    if not zip_path.exists():
        return None

    import zipfile

    try:
        with zipfile.ZipFile(zip_path, "r") as zf:
            try:
                raw_bytes = zf.read(file_name)
            except KeyError:
                # ZipFile.read raises KeyError for a name that is not in the
                # archive; a missing entry is an expected, silent outcome.
                return None
        raw = json.loads(raw_bytes.decode("utf-8"))
    except Exception as e:
        # Deliberately best-effort: any unreadable archive or malformed entry
        # is reported at debug level and treated as "not available".
        logger.debug("Failed to read %s from %s: %s", file_name, zip_path, e)
        return None

    if not isinstance(raw, dict):
        logger.debug("Unexpected non-dict payload in %s from %s", file_name, zip_path)
        return None
    return raw


def _apply_snapshot_delta(
    base_payload: dict[str, Any], delta_payload: dict[str, Any]
) -> dict[str, Any]:
    """Apply changed/deleted entries from delta payload onto base payload.

    The base payload is shallow-copied, so the caller's dict is not modified.
    Changed entries are applied first and deletions second, so a key listed in
    both ends up removed. Malformed ``__changed__`` / ``__deleted__`` sections
    (wrong container type, non-string deleted keys) are ignored.
    """
    merged = dict(base_payload)

    changed = delta_payload.get(SNAPSHOT_CHANGED_KEY, {})
    if isinstance(changed, dict):
        merged.update(changed)

    deleted = delta_payload.get(SNAPSHOT_DELETED_KEY, [])
    if isinstance(deleted, list):
        for key in deleted:
            if isinstance(key, str):
                merged.pop(key, None)

    return merged


def _expand_snapshot_payload(
    payload: dict[str, Any],
    zip_path: Path,
    file_name: str,
    visited: set[tuple[str, str]] | None = None,
) -> dict[str, Any]:
    """Expand a snapshot payload to full materialized data.

    Full snapshots are returned as-is (the same object).
    Delta snapshots are recursively applied on top of their base opset snapshot.

    Args:
        payload: Decoded JSON payload read from ``file_name`` in ``zip_path``.
        zip_path: Zip archive the payload came from; its name is used to derive
            the base archive name.
        file_name: Entry name of the payload; used to derive the base entry name.
        visited: ``(zip path, entry name)`` pairs already seen in this chain,
            used to detect cycles. Callers normally leave this as ``None``.

    Returns:
        The materialized dictionary. When the base cannot be determined or
        loaded (missing/non-integer base opset, cyclic chain, base entry not
        readable) a warning is logged and the delta is applied to an empty base.
    """
    if payload.get(SNAPSHOT_TYPE_KEY) != SNAPSHOT_TYPE_DELTA:
        return payload

    base_opset = payload.get(SNAPSHOT_BASE_OPSET_KEY)
    # bool is a subclass of int; a JSON `true`/`false` here must not be turned
    # into a bogus "_opsetTrue" lookup, so it is rejected like any non-integer.
    if isinstance(base_opset, bool) or not isinstance(base_opset, int):
        logger.warning(
            "Delta snapshot %s in %s missing integer %s; "
            "applying delta on empty base.",
            file_name,
            zip_path,
            SNAPSHOT_BASE_OPSET_KEY,
        )
        return _apply_snapshot_delta({}, payload)

    token = (str(zip_path), file_name)
    if visited is None:
        visited = set()
    if token in visited:
        logger.warning(
            "Detected cyclic snapshot chain while loading %s from %s; "
            "applying delta on empty base.",
            file_name,
            zip_path,
        )
        return _apply_snapshot_delta({}, payload)

    visited.add(token)

    # The base snapshot lives under the same names with the opset token swapped.
    base_file_name = _replace_opset_token(file_name, base_opset)
    base_zip_name = _replace_opset_token(zip_path.name, base_opset)
    base_zip_path = resolve_rule_zip_path(base_zip_name)

    base_raw = _read_json_from_zip(base_zip_path, base_file_name)
    if base_raw is None:
        logger.warning(
            "Base snapshot %s not found in %s (from %s); "
            "applying delta on empty base.",
            base_file_name,
            base_zip_path,
            file_name,
        )
        return _apply_snapshot_delta({}, payload)

    # The base may itself be a delta; resolve it before layering this delta on top.
    base_full = _expand_snapshot_payload(base_raw, base_zip_path, base_file_name, visited)
    return _apply_snapshot_delta(base_full, payload)
@@ -126,22 +239,12 @@ def _load_table_columns_file(zip_path: Path, columns_file_name: str | None) -> d if not columns_file_name or not zip_path.exists(): return {} - try: - import zipfile - - with zipfile.ZipFile(zip_path, "r") as zf: - if columns_file_name not in zf.namelist(): - return {} - raw_columns = json.loads(zf.read(columns_file_name).decode("utf-8")) - except Exception as e: - logger.debug( - "Failed to load column file %s from %s: %s", - columns_file_name, - zip_path, - e, - ) + raw_columns = _read_json_from_zip(zip_path, columns_file_name) + if raw_columns is None: return {} + raw_columns = _expand_snapshot_payload(raw_columns, zip_path, columns_file_name) + if not isinstance(raw_columns, dict): return {} @@ -178,6 +281,8 @@ def __init__( for op_name, columns in (preloaded_columns or {}).items() } self._loaded_tables: dict[str, pd.DataFrame] = {} + self._base_tables: LazyDomainTables | None = None + self._deleted_keys: set[str] = set() self._loaded: bool = False self._columns_loaded: bool = bool(self._raw_columns) or self._columns_file_name is None @@ -193,7 +298,6 @@ def _ensure_loaded(self) -> None: if self._loaded: return self._loaded = True - import zipfile if not self._zip_path.exists(): logger.warning( @@ -203,23 +307,61 @@ def _ensure_loaded(self) -> None: self._zip_path, ) return - try: - with zipfile.ZipFile(self._zip_path, "r") as zf: - if self._file_name in zf.namelist(): - self._raw_data = json.loads(zf.read(self._file_name).decode("utf-8")) - else: - logger.debug(f"Table file not found in zip: {self._file_name}") - except Exception as e: - logger.debug(f"Failed to load table file {self._file_name}: {e}") + + raw_table_payload = _read_json_from_zip(self._zip_path, self._file_name) + if raw_table_payload is None: + logger.debug(f"Table file not found in zip: {self._file_name}") + return + + if raw_table_payload.get(SNAPSHOT_TYPE_KEY) == SNAPSHOT_TYPE_DELTA: + changed = raw_table_payload.get(SNAPSHOT_CHANGED_KEY, {}) + deleted = 
raw_table_payload.get(SNAPSHOT_DELETED_KEY, []) + self._raw_data = changed if isinstance(changed, dict) else {} + if isinstance(deleted, list): + self._deleted_keys = {key for key in deleted if isinstance(key, str)} + else: + self._deleted_keys = set() + + base_opset = raw_table_payload.get(SNAPSHOT_BASE_OPSET_KEY) + if isinstance(base_opset, int): + base_file_name = _replace_opset_token(self._file_name, base_opset) + base_zip_name = _replace_opset_token(self._zip_path.name, base_opset) + base_zip_path = resolve_rule_zip_path(base_zip_name) + base_columns_file_name = ( + _replace_opset_token(self._columns_file_name, base_opset) + if self._columns_file_name + else None + ) + self._base_tables = LazyDomainTables( + base_zip_path, + base_file_name, + columns_file_name=base_columns_file_name, + ) + else: + logger.warning( + "Delta table snapshot %s in %s missing integer %s; " + "base fallback disabled for this table.", + self._file_name, + self._zip_path, + SNAPSHOT_BASE_OPSET_KEY, + ) + return + + self._raw_data = raw_table_payload def __getitem__(self, key: str) -> pd.DataFrame: """Get table for operator, loading from zip if needed.""" if key not in self._loaded_tables: self._ensure_loaded() - if key not in self._raw_data: + if key in self._raw_data: + self._loaded_tables[key] = _sanitize_df(build_table_df(self._raw_data[key])) + del self._raw_data[key] + elif key in self._deleted_keys: + raise KeyError(f"Operator '{key}' not found in tables") + elif self._base_tables and key in self._base_tables: + self._loaded_tables[key] = self._base_tables[key] + else: raise KeyError(f"Operator '{key}' not found in tables") - self._loaded_tables[key] = _sanitize_df(build_table_df(self._raw_data[key])) - del self._raw_data[key] return self._loaded_tables[key] def __contains__(self, key: str) -> bool: @@ -227,7 +369,13 @@ def __contains__(self, key: str) -> bool: if key in self._loaded_tables: return True self._ensure_loaded() - return key in self._raw_data + if key in 
self._raw_data: + return True + if key in self._deleted_keys: + return False + if self._base_tables is not None: + return key in self._base_tables + return False def get(self, key: str, default: pd.DataFrame | None = None) -> pd.DataFrame | None: """Get table for operator with default fallback.""" @@ -253,6 +401,12 @@ def get_columns(self, key: str) -> list[str] | None: return list(self._raw_columns[key]) self._ensure_loaded() + if key in self._deleted_keys: + return None + if self._base_tables is not None: + base_columns = self._base_tables.get_columns(key) + if base_columns is not None: + return base_columns raw_table = self._raw_data.get(key) if isinstance(raw_table, dict): @@ -293,8 +447,6 @@ def _ensure_loaded(self) -> None: self._load() def _load(self) -> None: - import zipfile - if not self._zip_path.exists(): if self._set_error_on_missing: self[EG_RULE_ERROR_KEY] = "rules_zip_not_found" @@ -307,17 +459,17 @@ def _load(self) -> None: ) return - with zipfile.ZipFile(self._zip_path, "r") as zf: - if self._rule_file not in zf.namelist(): - if self._set_error_on_missing: - self[EG_RULE_ERROR_KEY] = "negative_rule_file_not_found" - self[EG_RULE_DEBUG_DETAILS_KEY] = str(self._rule_file) - logger.warning(f"Negative rule file not found: {self._rule_file}") - else: - logger.debug(f"Negative rule file not found: {self._rule_file}") - return + raw_payload = _read_json_from_zip(self._zip_path, self._rule_file) + if raw_payload is None: + if self._set_error_on_missing: + self[EG_RULE_ERROR_KEY] = "negative_rule_file_not_found" + self[EG_RULE_DEBUG_DETAILS_KEY] = str(self._rule_file) + logger.warning(f"Negative rule file not found: {self._rule_file}") + else: + logger.debug(f"Negative rule file not found: {self._rule_file}") + return - raw: dict[str, Any] = json.loads(zf.read(self._rule_file).decode("utf-8")) + raw = _expand_snapshot_payload(raw_payload, self._zip_path, self._rule_file) filtered: dict[str, Any] = { key: value diff --git 
a/src/winml/modelkit/analyze/rules/runtime_check_rules/README.md b/src/winml/modelkit/analyze/rules/runtime_check_rules/README.md index e72a217a3..46193d98d 100644 --- a/src/winml/modelkit/analyze/rules/runtime_check_rules/README.md +++ b/src/winml/modelkit/analyze/rules/runtime_check_rules/README.md @@ -26,10 +26,40 @@ Copy all `*.zip` files from [`gim-home/ModelKitArtifacts/op_check_results/rules/ Set `MODELKIT_RULES_DIR` to one or more directories containing runtime rule zip files. -- Windows (PowerShell): `$env:MODELKIT_RULES_DIR="D:\\rules;E:\\more_rules"` +Important: relative paths are resolved from `src/winml/modelkit/analyze/utils/` (the directory of `rule_loader.py`), not from your current terminal working directory. + +- Windows (PowerShell, user-level absolute path): `[Environment]::SetEnvironmentVariable("MODELKIT_RULES_DIR", "C:\*path*\rules_zip", "User")` +- Windows (PowerShell, user-level repo-relative path): `[Environment]::SetEnvironmentVariable("MODELKIT_RULES_DIR", "..\..\..\..\..\..\ModelKitArtifacts\rules_zip", "User")` Multiple directories are supported using `os.pathsep` (`;` on Windows, `:` on Unix-like systems). +### Option 4: Expand rule zips via CLI command + +You can materialize delta snapshots to full payloads in-place with: + +```bash +winml expand_rules +``` + +This command reads all entries from `MODELKIT_RULES_DIR`, resolves each via +`_resolve_env_rules_dir_entry`, and performs in-place rewrite for each existing +directory that contains matching zip files. + +After a folder is successfully expanded (and has at least one matching zip), +an empty marker file named `expanded` is created in that folder. 
# Snapshot metadata keys used in generated rule artifacts.
SNAPSHOT_TYPE_KEY = "__snapshot_type__"
SNAPSHOT_TYPE_DELTA = "delta_v1"
SNAPSHOT_BASE_OPSET_KEY = "__base_opset__"
SNAPSHOT_CURRENT_OPSET_KEY = "__current_opset__"
SNAPSHOT_CHANGED_KEY = "__changed__"
SNAPSHOT_DELETED_KEY = "__deleted__"


def _sorted_dict_by_key(payload: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow key-sorted dict for stable JSON output."""
    return dict(sorted(payload.items()))


def _payload_values_equal(previous_value: Any, current_value: Any) -> bool:
    """Return whether two payload values are equal once serialized to JSON.

    Plain ``==`` is tried first. It reports a difference for values that hold
    distinct NaN floats (NaN never equals NaN), even though both serialize to
    the same JSON text, so unequal values are compared again by their
    canonical JSON form. Values that cannot be serialized are treated as
    different.
    """
    if previous_value == current_value:
        return True
    try:
        return json.dumps(previous_value, sort_keys=True) == json.dumps(
            current_value, sort_keys=True
        )
    except (TypeError, ValueError):
        return False


def _build_snapshot_payload(
    current_payload: dict[str, Any],
    current_opset: int,
    previous_payload: dict[str, Any] | None,
    previous_opset: int | None,
) -> dict[str, Any]:
    """Build either a full snapshot (first version) or a delta snapshot.

    Full snapshots keep backward compatibility with existing plain-dict format.
    Delta snapshots store only changed/deleted operators relative to the previous opset.

    Args:
        current_payload: Full per-operator payload for ``current_opset``.
        current_opset: Opset version the payload belongs to.
        previous_payload: Full payload of the previous opset, or ``None`` when
            there is no previous snapshot.
        previous_opset: Opset version of ``previous_payload``, or ``None``.

    Returns:
        The key-sorted ``current_payload`` when either previous argument is
        ``None``; otherwise a delta dict carrying the snapshot metadata keys,
        the changed/added operators (key-sorted) and the sorted list of
        operators that no longer exist.
    """
    if previous_payload is None or previous_opset is None:
        return _sorted_dict_by_key(current_payload)

    # An operator is "changed" when it is new or its value differs from the
    # previous opset; NaN-only differences are not counted as changes.
    changed = {
        op_name: value
        for op_name, value in current_payload.items()
        if op_name not in previous_payload
        or not _payload_values_equal(previous_payload[op_name], value)
    }
    deleted = sorted(op_name for op_name in previous_payload if op_name not in current_payload)

    return {
        SNAPSHOT_TYPE_KEY: SNAPSHOT_TYPE_DELTA,
        SNAPSHOT_BASE_OPSET_KEY: previous_opset,
        SNAPSHOT_CURRENT_OPSET_KEY: current_opset,
        SNAPSHOT_CHANGED_KEY: _sorted_dict_by_key(changed),
        SNAPSHOT_DELETED_KEY: deleted,
    }


def _is_delta_snapshot_payload(payload: Any) -> bool:
    """Whether ``payload`` is a dict tagged as a ``delta_v1`` snapshot."""
    return isinstance(payload, dict) and payload.get(SNAPSHOT_TYPE_KEY) == SNAPSHOT_TYPE_DELTA


def _can_append_merge(existing_payload: Any, new_payload: Any) -> bool:
    """Whether append-mode shallow dict merge is safe for these payloads.

    Both payloads must be plain dicts and neither may be a delta snapshot.
    """
    return (
        isinstance(existing_payload, dict)
        and isinstance(new_payload, dict)
        and not _is_delta_snapshot_payload(existing_payload)
        and not _is_delta_snapshot_payload(new_payload)
    )
for negative rules (defaults to input_dir)" ) @@ -589,18 +641,13 @@ def get_opset_version_range(op_name: str, start_opset_version: int, op_domain: s "(../rules/runtime_check_rules).", ) parser.add_argument( - "-range", - "--opset_range_ref_op", + "--domains", type=str, - default=None, - help="Reference operator name or end opset version number. " - "When a number N is provided, processes all opset versions in " - "[--opset_version, N] (inclusive). " - "When an operator name is provided, computes the range of opset versions " - "that share the same since_version for this op, starting from --opset_version. " - "Example: --opset_range_ref_op 12 --opset_version 11 processes versions 11-12. " - "Example: --opset_range_ref_op Slice --opset_version 11 processes versions 11-12 " - "since Slice has since_versions 1, 10, 11, 13.", + default="ai.onnx,com.microsoft", + help=( + "Comma-separated domains to process from defaults (ai.onnx,com.microsoft). " + "Invalid or unsupported values are ignored." 
+ ), ) args = parser.parse_args() @@ -608,286 +655,363 @@ def get_opset_version_range(op_name: str, start_opset_version: int, op_domain: s output_dir = Path(args.output_dir) if args.output_dir else input_dir output_dir.mkdir(parents=True, exist_ok=True) - # Normalize the opset_domain (ai.onnx -> empty string for ONNX standard) - target_domain = "" if args.opset_domain == "ai.onnx" else args.opset_domain - domain_str_for_filename = args.opset_domain # Keep original for filename matching - json_files = list(input_dir.glob("*.json")) if not json_files: print(f"No JSON files found in {input_dir}") exit(1) - # Extract unique (op_name, ep_name, device, is_qdq) - # combinations from filenames for the target domain - # Filename format: ____opset[_qdq].json import re - op_info_set: set[tuple[str, str, str, bool]] = set() - for json_file in json_files: - is_qdq = json_file.stem.endswith("_qdq") - # Remove opset suffix to get base info - opset_match = re.search(r"_opset(\d+)(?:_qdq)?$", json_file.stem) - if opset_match: - filename_without_opset = json_file.stem[: opset_match.start()] - parts = filename_without_opset.split("_") - if len(parts) == 4: - op_name, ep_name, device, file_domain = parts[:4] - # Only include operators from the target domain - if file_domain == domain_str_for_filename: - op_info_set.add((op_name, ep_name, device, is_qdq)) - - print(f"Found {len(op_info_set)} unique operators to process for domain '{args.opset_domain}'") - - # Determine which opset versions to process - if args.opset_range_ref_op: - if args.opset_range_ref_op.isdigit(): - end_opset = int(args.opset_range_ref_op) - opset_versions_to_process = list(range(args.opset_version, end_opset + 1)) - print(f"Numeric range: will process opset versions {opset_versions_to_process}") + domain_plans: dict[str, list[int]] = { + "ai.onnx": list(range(12, 23)), + "com.microsoft": [1], + } + + requested_domains = [part.strip() for part in args.domains.split(",") if part.strip()] + if not requested_domains: 
+ requested_domains = ["ai.onnx", "com.microsoft"] + + domains_to_process: list[str] = [] + for requested in requested_domains: + normalized = requested.lower() + if normalized == "ai.onnx": + mapped_domain = "ai.onnx" + elif normalized == "com.microsoft": + mapped_domain = "com.microsoft" else: - opset_versions_to_process = get_opset_version_range( - args.opset_range_ref_op, args.opset_version, target_domain - ) print( - f"Reference op '{args.opset_range_ref_op}' " - f"with opset_version {args.opset_version}: " - f"will process opset versions {opset_versions_to_process}" + f"Ignoring unsupported domain '{requested}'. " + "Supported values: ai.onnx, com.microsoft" ) - else: - opset_versions_to_process = [args.opset_version] - - qdq_generator = None - if any(is_qdq for _, _, _, is_qdq in op_info_set): - from ...pattern.op_input_gen.qdq_gen import QDQGenerator - - qdq_generator = QDQGenerator(1, ONNXDomain.COM_MICROSOFT) - - for current_opset_version in opset_versions_to_process: - if len(opset_versions_to_process) > 1: - print(f"\n{'=' * 60}") - print(f"Processing opset version {current_opset_version}") - print(f"{'=' * 60}") - - # Group results by (EP, device, domain, opset, is_qdq) - results_by_ep_domain_opset: dict[tuple[str, str, str, int, bool], dict[str, Any]] = {} - tables_by_ep_domain_opset: dict[tuple[str, str, str, int, bool], dict[str, Any]] = {} - table_columns_by_ep_domain_opset: dict[ - tuple[str, str, str, int, bool], dict[str, list[str]] - ] = {} + continue - for op_name, ep_name, device, is_qdq in sorted(op_info_set): - # Get the since_version for this operator based on - # the current opset_version. Handle Op and Pattern. 
- # TODO: build a since_version list for - # PatternSchemas based on since_version of - # included ops - try: - since_version = get_op_since_version(op_name, current_opset_version, target_domain) - except SchemaError: - since_version = current_opset_version - - # Build the expected filename with since_version - qdq_suffix = "_qdq" if is_qdq else "" - expected_filename = ( - f"{op_name}_{ep_name}_{device}" - f"_{domain_str_for_filename}" - f"_opset{since_version}{qdq_suffix}.json" - ) - json_file = input_dir / expected_filename - print(f"Processing {expected_filename}...", end=" ") + if mapped_domain not in domains_to_process: + domains_to_process.append(mapped_domain) - if not json_file.exists(): - print(f"{Fore.YELLOW}SKIPPED: File not found. {Style.RESET_ALL}") - continue + if not domains_to_process: + print("No valid domains selected to process.") + exit(1) - if json_file.stat().st_size == 0: - print(f"{Fore.YELLOW}SKIPPED: Empty JSON file. {Style.RESET_ALL}") - continue + for domain_str_for_filename in domains_to_process: + target_domain = "" if domain_str_for_filename == "ai.onnx" else domain_str_for_filename + opset_versions_to_process = domain_plans[domain_str_for_filename] + + # Extract unique (op_name, ep_name, device, is_qdq) + # combinations from filenames for the target domain + # Filename format: ____opset[_qdq].json + op_info_set: set[tuple[str, str, str, bool]] = set() + for json_file in json_files: + is_qdq = json_file.stem.endswith("_qdq") + # Remove opset suffix to get base info + opset_match = re.search(r"_opset(\d+)(?:_qdq)?$", json_file.stem) + if opset_match: + filename_without_opset = json_file.stem[: opset_match.start()] + parts = filename_without_opset.split("_") + if len(parts) == 4: + op_name, ep_name, device, file_domain = parts[:4] + if file_domain == domain_str_for_filename: + op_info_set.add((op_name, ep_name, device, is_qdq)) - try: - with open(json_file, encoding="utf-8") as f: # noqa: PTH123 - data = json.load(f) + print( + f"Found 
{len(op_info_set)} unique operators to process " + f"for domain '{domain_str_for_filename}'" + ) - op_domain, op_name, ep_name, device, opset_version, is_qdq = _parse_filename( - json_file.stem + if not op_info_set: + print(f"No operators found for domain '{domain_str_for_filename}', skipping.") + continue + + qdq_generator = None + if any(is_qdq for _, _, _, is_qdq in op_info_set): + from ...pattern.op_input_gen.qdq_gen import QDQGenerator + + qdq_generator = QDQGenerator(1, ONNXDomain.COM_MICROSOFT) + + # Keep previous full payload per (ep, device, domain, is_qdq) for delta generation. + previous_negative_rules_payloads: dict[ + tuple[str, str, str, bool], tuple[int, dict[str, Any]] + ] = {} + previous_tables_payloads: dict[tuple[str, str, str, bool], tuple[int, dict[str, Any]]] = {} + previous_columns_payloads: dict[tuple[str, str, str, bool], tuple[int, dict[str, Any]]] = {} + + for current_opset_version in opset_versions_to_process: + if len(opset_versions_to_process) > 1: + print(f"\n{'=' * 60}") + print(f"Processing domain {domain_str_for_filename}, opset {current_opset_version}") + print(f"{'=' * 60}") + + # Group results by (EP, device, domain, opset, is_qdq) + results_by_ep_domain_opset: dict[tuple[str, str, str, int, bool], dict[str, Any]] = {} + tables_by_ep_domain_opset: dict[tuple[str, str, str, int, bool], dict[str, Any]] = {} + table_columns_by_ep_domain_opset: dict[ + tuple[str, str, str, int, bool], dict[str, list[str]] + ] = {} + + for op_name, ep_name, device, is_qdq in sorted(op_info_set): + # Get the since_version for this operator based on + # the current opset_version. Handle Op and Pattern. 
+ # TODO: build a since_version list for + # PatternSchemas based on since_version of + # included ops + try: + since_version = get_op_since_version( + op_name, + current_opset_version, + target_domain, + ) + except SchemaError: + since_version = current_opset_version + + # Build the expected filename with since_version + qdq_suffix = "_qdq" if is_qdq else "" + expected_filename = ( + f"{op_name}_{ep_name}_{device}" + f"_{domain_str_for_filename}" + f"_opset{since_version}{qdq_suffix}.json" ) + json_file = input_dir / expected_filename + print(f"Processing {expected_filename}...", end=" ") - check_results = data.get("check_results", []) + if not json_file.exists(): + print(f"{Fore.YELLOW}SKIPPED: File not found. {Style.RESET_ALL}") + continue - if not check_results: - print(f"{Fore.RED}Error: No check_results found, skipping{Style.RESET_ALL}") + if json_file.stat().st_size == 0: + print(f"{Fore.YELLOW}SKIPPED: Empty JSON file. {Style.RESET_ALL}") continue - # Build negative rules and get DataFrame - domain = ONNXDomain.from_str(op_domain) try: - schema = domain.get_op_schema(op_name, opset_version) - input_generator = get_runtime_checker_op(op_name, domain=op_domain)( - schema, qdq_generator=qdq_generator if is_qdq else None + with open(json_file, encoding="utf-8") as f: # noqa: PTH123 + data = json.load(f) + + op_domain, op_name, ep_name, device, opset_version, is_qdq = _parse_filename( + json_file.stem ) - except SchemaError: - # pattern case - # TODO: if a pattern depends on multiple - # domains, the filename currently contains - # only AI_ONNX; need to recover all domains - domain_versions = { - op_domain: opset_version, - ONNXDomain.COM_MICROSOFT: 1, # safeguard - } - input_generator = get_pattern_input_generator(op_name)(domain_versions) - - op_negative_rules, df = build_op_query_negative_rules_and_table( - check_results, - input_generator, - use_qdq=is_qdq, - op_version=opset_version, - device=device, - ep_name=ep_name, - op_domain=op_domain, + + 
check_results = data.get("check_results", []) + + if not check_results: + print(f"{Fore.RED}Error: No check_results found, skipping{Style.RESET_ALL}") + continue + + # Build negative rules and get DataFrame + domain = ONNXDomain.from_str(op_domain) + try: + schema = domain.get_op_schema(op_name, opset_version) + input_generator = get_runtime_checker_op(op_name, domain=op_domain)( + schema, qdq_generator=qdq_generator if is_qdq else None + ) + except SchemaError: + # pattern case + # TODO: if a pattern depends on multiple + # domains, the filename currently contains + # only AI_ONNX; need to recover all domains + domain_versions = { + op_domain: opset_version, + ONNXDomain.COM_MICROSOFT: 1, # safeguard + } + input_generator = get_pattern_input_generator(op_name)(domain_versions) + + op_negative_rules, df = build_op_query_negative_rules_and_table( + check_results, + input_generator, + use_qdq=is_qdq, + op_version=opset_version, + device=device, + ep_name=ep_name, + op_domain=op_domain, + ) + + # Group by (EP, domain, current opset_version, is_qdq) + key = (ep_name, device, target_domain, current_opset_version, is_qdq) + if key not in results_by_ep_domain_opset: + results_by_ep_domain_opset[key] = {} + tables_by_ep_domain_opset[key] = {} + table_columns_by_ep_domain_opset[key] = {} + + results_by_ep_domain_opset[key][op_name] = op_negative_rules + + # Convert DataFrame to JSON-serializable format + tables_by_ep_domain_opset[key][op_name] = df.to_dict() + table_columns_by_ep_domain_opset[key][op_name] = [ + col_name + for col_name in df.columns.to_list() + if col_name != "compile_run_success" + ] + + print(f"OK ({len(check_results)} results)") + + except Exception as e: + print(f"{Fore.RED}ERROR: {e}{Style.RESET_ALL}") + traceback.print_exc() + sys.exit(1) + + zip_group = {} + # Save negative rules + for ( + ep_name, + device, + op_domain, + opset_version, + is_qdq, + ), op_results in results_by_ep_domain_opset.items(): + # Create domain-specific filename + domain_str 
= op_domain if op_domain else "ai.onnx" + qdq_suffix = "_qdq" if is_qdq else "" + output_file = output_dir / ( + f"{ep_name}_{device}_{domain_str}" + f"_opset{opset_version}" + f"_negative_rules{qdq_suffix}.json" ) - # Group by (EP, domain, current opset_version, is_qdq) - key = (ep_name, device, target_domain, current_opset_version, is_qdq) - if key not in results_by_ep_domain_opset: - results_by_ep_domain_opset[key] = {} - tables_by_ep_domain_opset[key] = {} - table_columns_by_ep_domain_opset[key] = {} - - results_by_ep_domain_opset[key][op_name] = op_negative_rules - - # Convert DataFrame to JSON-serializable format - tables_by_ep_domain_opset[key][op_name] = df.to_dict() - table_columns_by_ep_domain_opset[key][op_name] = [ - col_name - for col_name in df.columns.to_list() - if col_name != "compile_run_success" - ] - - print(f"OK ({len(check_results)} results)") - - except Exception as e: - print(f"{Fore.RED}ERROR: {e}{Style.RESET_ALL}") - traceback.print_exc() - sys.exit(1) - - zip_group = {} - # Save negative rules - for ( - ep_name, - device, - op_domain, - opset_version, - is_qdq, - ), op_results in results_by_ep_domain_opset.items(): - # Create domain-specific filename - domain_str = op_domain if op_domain else "ai.onnx" - qdq_suffix = "_qdq" if is_qdq else "" - output_file = output_dir / ( - f"{ep_name}_{device}_{domain_str}" - f"_opset{opset_version}" - f"_negative_rules{qdq_suffix}.json" - ) + snapshot_key = (ep_name, device, op_domain, is_qdq) + previous_snapshot = previous_negative_rules_payloads.get(snapshot_key) + snapshot_payload = _build_snapshot_payload( + op_results, + opset_version, + previous_snapshot[1] if previous_snapshot else None, + previous_snapshot[0] if previous_snapshot else None, + ) - with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123 - json.dump(dict(sorted(op_results.items())), f, indent=2) - - print(f"\nSaved {len(op_results)} operators to {output_file}") - zip_group.setdefault(f"{ep_name}_{device}", 
[]).append(output_file) - - # Save tables - for ( - ep_name, - device, - op_domain, - opset_version, - is_qdq, - ), op_tables in tables_by_ep_domain_opset.items(): - # Create domain-specific filename - domain_str = op_domain if op_domain else "ai.onnx" - qdq_suffix = "_qdq" if is_qdq else "" - output_file = ( - output_dir - / f"{ep_name}_{device}_{domain_str}_opset{opset_version}_tables{qdq_suffix}.json" - ) + with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123 + json.dump(snapshot_payload, f, indent=2) + + previous_negative_rules_payloads[snapshot_key] = (opset_version, op_results) + + print(f"\nSaved {len(op_results)} operators to {output_file}") + zip_group.setdefault(f"{ep_name}_{device}", []).append(output_file) + + # Save tables + for ( + ep_name, + device, + op_domain, + opset_version, + is_qdq, + ), op_tables in tables_by_ep_domain_opset.items(): + # Create domain-specific filename + domain_str = op_domain if op_domain else "ai.onnx" + qdq_suffix = "_qdq" if is_qdq else "" + output_file = ( + output_dir + / ( + f"{ep_name}_{device}_{domain_str}_opset{opset_version}_tables" + f"{qdq_suffix}.json" + ) + ) - with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123 - json.dump(dict(sorted(op_tables.items())), f, indent=2) - - print(f"Saved {len(op_tables)} operator tables to {output_file}") - zip_group.setdefault(f"{ep_name}_{device}", []).append(output_file) - - # Save table column names - for ( - ep_name, - device, - op_domain, - opset_version, - is_qdq, - ), op_columns in table_columns_by_ep_domain_opset.items(): - domain_str = op_domain if op_domain else "ai.onnx" - qdq_suffix = "_qdq" if is_qdq else "" - output_file = output_dir / ( - f"{ep_name}_{device}_{domain_str}" - f"_opset{opset_version}_table_columns{qdq_suffix}.json" - ) + snapshot_key = (ep_name, device, op_domain, is_qdq) + previous_snapshot = previous_tables_payloads.get(snapshot_key) + snapshot_payload = _build_snapshot_payload( + op_tables, 
+ opset_version, + previous_snapshot[1] if previous_snapshot else None, + previous_snapshot[0] if previous_snapshot else None, + ) - with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123 - json.dump(dict(sorted(op_columns.items())), f, indent=2) + with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123 + json.dump(snapshot_payload, f, indent=2) + + previous_tables_payloads[snapshot_key] = (opset_version, op_tables) + + print(f"Saved {len(op_tables)} operator tables to {output_file}") + zip_group.setdefault(f"{ep_name}_{device}", []).append(output_file) + + # Save table column names + for ( + ep_name, + device, + op_domain, + opset_version, + is_qdq, + ), op_columns in table_columns_by_ep_domain_opset.items(): + domain_str = op_domain if op_domain else "ai.onnx" + qdq_suffix = "_qdq" if is_qdq else "" + output_file = output_dir / ( + f"{ep_name}_{device}_{domain_str}" + f"_opset{opset_version}_table_columns{qdq_suffix}.json" + ) - print(f"Saved {len(op_columns)} operator table column sets to {output_file}") - zip_group.setdefault(f"{ep_name}_{device}", []).append(output_file) + snapshot_key = (ep_name, device, op_domain, is_qdq) + previous_snapshot = previous_columns_payloads.get(snapshot_key) + snapshot_payload = _build_snapshot_payload( + op_columns, + opset_version, + previous_snapshot[1] if previous_snapshot else None, + previous_snapshot[0] if previous_snapshot else None, + ) - print( - f"\nProcessing complete! Generated " - f"{len(results_by_ep_domain_opset)} " - f"negative rule file(s) " - f"and {len(tables_by_ep_domain_opset)} table file(s), " - f"plus {len(table_columns_by_ep_domain_opset)} table-column file(s)." 
- ) + with open(output_file, "w", encoding="utf-8", newline="\n") as f: # noqa: PTH123 + json.dump(snapshot_payload, f, indent=2) - if args.update_zip: - rules_dir = ( - Path(args.rules_dir) if args.rules_dir else get_runtime_rules_search_dirs()[0] + previous_columns_payloads[snapshot_key] = (opset_version, op_columns) + + print(f"Saved {len(op_columns)} operator table column sets to {output_file}") + zip_group.setdefault(f"{ep_name}_{device}", []).append(output_file) + + print( + f"\nDomain '{domain_str_for_filename}' processing complete! Generated " + f"{len(results_by_ep_domain_opset)} " + f"negative rule file(s) " + f"and {len(tables_by_ep_domain_opset)} table file(s), " + f"plus {len(table_columns_by_ep_domain_opset)} table-column file(s)." ) - for group_name, file_list in zip_group.items(): - rule_zip_path = ( - rules_dir - / f"{group_name}_{domain_str_for_filename}_opset{current_opset_version}.zip" - ) - # In append mode, load existing zip entries to preserve files not being updated - existing_content: dict[str, bytes] = {} - if args.append and rule_zip_path.exists(): - with zipfile.ZipFile(rule_zip_path, mode="r") as existing_zf: - for name in existing_zf.namelist(): - existing_content[name] = existing_zf.read(name) - - new_arcnames = {Path(f).name for f in file_list} - - with zipfile.ZipFile( - rule_zip_path, mode="w", compression=zipfile.ZIP_DEFLATED - ) as rule_zf: - # Keep existing entries not covered by the new output - for name, data in existing_content.items(): - if name not in new_arcnames: - rule_zf.writestr(name, data) - - for filename in file_list: - arcname = Path(filename).name - if args.append and arcname in existing_content: - # Merge: old dict updated with new dict, then sort - old_dict = json.loads(existing_content[arcname]) - with open(filename, encoding="utf-8") as f: # noqa: PTH123 - new_dict = json.load(f) - merged = dict(sorted({**old_dict, **new_dict}.items())) - rule_zf.writestr(arcname, json.dumps(merged, indent=2)) - else: - 
rule_zf.write(filename, arcname=arcname) - - print( - f"Rule zip file {group_name}" - f"_{domain_str_for_filename}" - f"_opset{current_opset_version}.zip " - f"updated with {len(file_list)} files." + if args.update_zip: + rules_dir = ( + Path(args.rules_dir) if args.rules_dir else get_runtime_rules_search_dirs()[0] ) + rules_dir.mkdir(parents=True, exist_ok=True) + for group_name, file_list in zip_group.items(): + rule_zip_path = ( + rules_dir + / f"{group_name}_{domain_str_for_filename}_opset{current_opset_version}.zip" + ) + + # In append mode, load existing zip entries to preserve files not being updated + existing_content: dict[str, bytes] = {} + if args.append and rule_zip_path.exists(): + with zipfile.ZipFile(rule_zip_path, mode="r") as existing_zf: + for name in existing_zf.namelist(): + existing_content[name] = existing_zf.read(name) + + new_arcnames = {Path(f).name for f in file_list} + + with zipfile.ZipFile( + rule_zip_path, mode="w", compression=zipfile.ZIP_DEFLATED + ) as rule_zf: + # Keep existing entries not covered by the new output + for name, data in existing_content.items(): + if name not in new_arcnames: + rule_zf.writestr(name, data) + + for filename in file_list: + arcname = Path(filename).name + if args.append and arcname in existing_content: + # Legacy append merge is only valid for plain full snapshots. 
+ with open(filename, encoding="utf-8") as f: # noqa: PTH123 + new_text = f.read() + try: + old_payload = json.loads( + existing_content[arcname].decode("utf-8") + ) + new_payload = json.loads(new_text) + except Exception: + rule_zf.writestr(arcname, new_text) + continue + + if _can_append_merge(old_payload, new_payload): + merged = dict(sorted({**old_payload, **new_payload}.items())) + rule_zf.writestr(arcname, json.dumps(merged, indent=2)) + else: + rule_zf.writestr(arcname, new_text) + else: + rule_zf.write(filename, arcname=arcname) + + print( + f"Rule zip file {group_name}" + f"_{domain_str_for_filename}" + f"_opset{current_opset_version}.zip " + f"updated with {len(file_list)} files." + ) diff --git a/src/winml/modelkit/analyze/runtime_checker/runner.py b/src/winml/modelkit/analyze/runtime_checker/runner.py index 4edb3486a..f5c53b442 100644 --- a/src/winml/modelkit/analyze/runtime_checker/runner.py +++ b/src/winml/modelkit/analyze/runtime_checker/runner.py @@ -194,25 +194,20 @@ def _join_process(proc: Any, timeout: float | None = None) -> None: """Best-effort process join that never raises.""" try: proc.join(timeout=timeout) - except Exception as e: - print(f"Warning: failed to join process during executor shutdown: {e}", file=sys.stderr) + except Exception: + # Intentionally suppress cleanup-time join errors to preserve + # resilient shutdown semantics ("never raises"). + pass @staticmethod def _kill_process(proc: Any) -> None: """Best-effort process kill that never raises.""" try: proc.kill() - except Exception as e: - # Keep cleanup non-fatal, but surface the failure for diagnostics. 
- print(f"Warning: failed to kill worker process: {e}", file=sys.stderr) - - @staticmethod - def _close_process(proc: Any) -> None: - """Best-effort process close that never raises.""" - try: - proc.close() - except Exception as ex: - print(f"Warning: failed to close process: {ex}", file=sys.stderr) + except Exception: + # Intentionally ignored: process may already be dead or handle invalid + # during teardown, and cleanup must remain best-effort/non-fatal. + pass def _shutdown_executor_two_phase( self, @@ -227,8 +222,8 @@ def _shutdown_executor_two_phase( try: executor.shutdown(wait=False, cancel_futures=cancel_futures) except Exception as exc: - # Best-effort shutdown: keep cleanup flow non-raising, but surface failure. - print(f"Executor shutdown failed during cleanup: {exc}", file=sys.stderr) + # Best-effort shutdown: ignore failures here, but surface context for debugging. + print(f"[ResilientRunner] executor.shutdown failed: {exc}", file=sys.stderr) timeout = ( self._GRACEFUL_SHUTDOWN_TIMEOUT_SEC @@ -251,8 +246,16 @@ def _shutdown_executor_two_phase( for proc in survivors: self._join_process(proc, timeout=self._FORCED_KILL_JOIN_TIMEOUT_SEC) - for proc in lingering: - self._close_process(proc) + # Do not close multiprocessing.Process handles manually here. The + # ProcessPoolExecutor management thread may still call proc.join(), and + # prematurely closing handles can trigger WinError 6 on Windows. + try: + executor.shutdown(wait=True, cancel_futures=cancel_futures) + except Exception as exc: + print( + f"[ResilientRunner] executor.shutdown(wait=True) failed: {exc}", + file=sys.stderr, + ) def run(self, fn: Callable[[Any, Any], Any] | None = None, *args: Any) -> dict[str, Any]: """Execute the function on a single input with automatic retry on failure. 
diff --git a/src/winml/modelkit/analyze/utils/rule_expander.py b/src/winml/modelkit/analyze/utils/rule_expander.py new file mode 100644 index 000000000..e3ca5c165 --- /dev/null +++ b/src/winml/modelkit/analyze/utils/rule_expander.py @@ -0,0 +1,310 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Expand runtime rule zip snapshots into full JSON payloads.""" + +from __future__ import annotations + +import json +import re +import tempfile +import zipfile +from dataclasses import dataclass +from pathlib import Path +from typing import Any + + +SNAPSHOT_TYPE_KEY = "__snapshot_type__" +SNAPSHOT_TYPE_DELTA = "delta_v1" +SNAPSHOT_BASE_OPSET_KEY = "__base_opset__" +SNAPSHOT_CHANGED_KEY = "__changed__" +SNAPSHOT_DELETED_KEY = "__deleted__" +EXPANDED_MARKER_FILE = "expanded" + +_OPSET_TOKEN_PATTERN = re.compile(r"_opset\d+") + + +@dataclass +class ExpandSummary: + """Summary of an expand run.""" + + zip_files_processed: int + zip_files_with_delta: int + json_entries_processed: int + delta_entries_materialized: int + output_mode: str + per_zip: list[tuple[str, int, int]] + + +def _replace_opset_token(name: str, new_opset: int) -> str: + """Replace the first ``_opset`` token in a name.""" + return _OPSET_TOKEN_PATTERN.sub(f"_opset{new_opset}", name, count=1) + + +def _is_delta_snapshot(payload: Any) -> bool: + """Check whether payload is a delta snapshot dict.""" + return isinstance(payload, dict) and payload.get(SNAPSHOT_TYPE_KEY) == SNAPSHOT_TYPE_DELTA + + +def _sorted_payload(payload: dict[str, Any]) -> dict[str, Any]: + """Return payload sorted by top-level keys for stable output.""" + return dict(sorted(payload.items(), key=lambda item: item[0])) + + +def _apply_delta(base_payload: dict[str, Any], delta_payload: dict[str, Any]) -> dict[str, Any]: + """Apply 
changed/deleted entries from delta payload onto base payload.""" + changed = delta_payload.get(SNAPSHOT_CHANGED_KEY, {}) + if not isinstance(changed, dict): + raise TypeError(f"Delta snapshot field '{SNAPSHOT_CHANGED_KEY}' must be a dict.") + + deleted = delta_payload.get(SNAPSHOT_DELETED_KEY, []) + if not isinstance(deleted, list): + raise TypeError(f"Delta snapshot field '{SNAPSHOT_DELETED_KEY}' must be a list.") + + merged = dict(base_payload) + merged.update(changed) + + for key in deleted: + if not isinstance(key, str): + raise TypeError( + f"Delta snapshot field '{SNAPSHOT_DELETED_KEY}' contains non-string key: {key!r}" + ) + merged.pop(key, None) + + return _sorted_payload(merged) + + +class SnapshotExpander: + """Resolve and expand snapshot payloads stored in runtime rules zip files.""" + + def __init__(self, rules_dir: Path) -> None: + self.rules_dir = rules_dir + self._cache: dict[tuple[str, str], dict[str, Any]] = {} + + def _read_json_payload(self, zip_name: str, entry_name: str) -> dict[str, Any]: + zip_path = self.rules_dir / zip_name + if not zip_path.exists(): + raise FileNotFoundError(f"Base zip not found: {zip_path}") + + with zipfile.ZipFile(zip_path, "r") as zf: + try: + raw = zf.read(entry_name) + except KeyError as exc: + raise FileNotFoundError( + f"Base entry '{entry_name}' not found in zip '{zip_name}'" + ) from exc + + try: + payload = json.loads(raw.decode("utf-8")) + except Exception as exc: + raise ValueError(f"Failed to parse JSON entry '{entry_name}' in '{zip_name}'") from exc + + if not isinstance(payload, dict): + raise TypeError( + f"JSON entry '{entry_name}' in '{zip_name}' must be a dict, got {type(payload)}" + ) + return payload + + def resolve_payload( + self, + zip_name: str, + entry_name: str, + stack: set[tuple[str, str]] | None = None, + ) -> dict[str, Any]: + """Resolve an entry payload to full expanded contents.""" + token = (zip_name, entry_name) + if token in self._cache: + return self._cache[token] + + if stack is 
None: + stack = set() + if token in stack: + raise ValueError(f"Detected cyclic snapshot dependency at {zip_name}::{entry_name}") + + stack.add(token) + payload = self._read_json_payload(zip_name, entry_name) + + if _is_delta_snapshot(payload): + base_opset = payload.get(SNAPSHOT_BASE_OPSET_KEY) + if not isinstance(base_opset, int): + raise ValueError( + f"Delta snapshot in {zip_name}::{entry_name} is missing integer " + f"'{SNAPSHOT_BASE_OPSET_KEY}'" + ) + + base_zip_name = _replace_opset_token(zip_name, base_opset) + base_entry_name = _replace_opset_token(entry_name, base_opset) + if base_zip_name == zip_name and base_entry_name == entry_name: + raise ValueError( + "Could not derive base snapshot location for " + f"{zip_name}::{entry_name} (opset token not replaced)." + ) + + base_payload = self.resolve_payload(base_zip_name, base_entry_name, stack) + resolved = _apply_delta(base_payload, payload) + else: + resolved = _sorted_payload(payload) + + stack.remove(token) + self._cache[token] = resolved + return resolved + + +def _expand_single_zip( + zip_path: Path, + dest_path: Path, + expander: SnapshotExpander, +) -> tuple[int, int]: + """Expand all delta JSON entries in one zip. + + Returns: + Tuple of ``(json_entry_count, materialized_delta_count)``. 
+ """ + json_entries = 0 + materialized_entries = 0 + + with zipfile.ZipFile(zip_path, "r") as src, zipfile.ZipFile( + dest_path, + "w", + compression=zipfile.ZIP_DEFLATED, + ) as dst: + for entry in src.infolist(): + raw = src.read(entry.filename) + out_raw = raw + + if entry.filename.endswith(".json"): + json_entries += 1 + try: + payload = json.loads(raw.decode("utf-8")) + except Exception as exc: + raise ValueError( + f"Failed to parse JSON entry '{entry.filename}' in '{zip_path.name}'" + ) from exc + + if _is_delta_snapshot(payload): + materialized = expander.resolve_payload(zip_path.name, entry.filename) + out_raw = (json.dumps(materialized, indent=2) + "\n").encode("utf-8") + materialized_entries += 1 + + dst.writestr(entry, out_raw) + + return json_entries, materialized_entries + + +def expand_rules_zip_dir( + rules_dir: Path, + *, + output_dir: Path | None = None, + glob_pattern: str = "*.zip", + marker_filename: str = EXPANDED_MARKER_FILE, +) -> ExpandSummary: + """Expand delta snapshots in rule zips. + + Args: + rules_dir: Directory containing runtime rule zip files. + output_dir: Optional output directory. If omitted, files are rewritten in place. + glob_pattern: Zip filename pattern to process. + marker_filename: Empty marker file name created after successful expand. + + Returns: + ExpandSummary with per-zip stats. + """ + rules_dir = rules_dir.resolve() + if not rules_dir.exists() or not rules_dir.is_dir(): + raise FileNotFoundError(f"Rules directory not found: {rules_dir}") + + # Ignore stale temp artifacts from prior interrupted in-place runs. + zip_files = [ + path + for path in sorted(rules_dir.glob(glob_pattern), key=lambda p: p.name) + if ".materialized." 
not in path.name + ] + if not zip_files: + return ExpandSummary( + zip_files_processed=0, + zip_files_with_delta=0, + json_entries_processed=0, + delta_entries_materialized=0, + output_mode=( + f"in-place ({rules_dir})" if output_dir is None else f"copied to {output_dir}" + ), + per_zip=[], + ) + + output_mode = "" + if output_dir is not None: + output_dir = output_dir.resolve() + output_dir.mkdir(parents=True, exist_ok=True) + output_mode = f"copied to {output_dir}" + else: + output_mode = f"in-place ({rules_dir})" + + expander = SnapshotExpander(rules_dir) + + total_json = 0 + total_materialized = 0 + changed_zip_count = 0 + per_zip: list[tuple[str, int, int]] = [] + + for zip_path in zip_files: + if output_dir is None: + with tempfile.NamedTemporaryFile( + mode="wb", + suffix=".zip", + prefix=f"{zip_path.stem}.materialized.", + dir=str(zip_path.parent), + delete=False, + ) as tmp: + tmp_path = Path(tmp.name) + try: + json_count, materialized_count = _expand_single_zip( + zip_path, + tmp_path, + expander, + ) + tmp_path.replace(zip_path) + finally: + if tmp_path.exists(): + tmp_path.unlink() + else: + dest = output_dir / zip_path.name + json_count, materialized_count = _expand_single_zip(zip_path, dest, expander) + + per_zip.append((zip_path.name, json_count, materialized_count)) + total_json += json_count + total_materialized += materialized_count + if materialized_count > 0: + changed_zip_count += 1 + + target_dir = output_dir if output_dir is not None else rules_dir + marker_path = target_dir / marker_filename + marker_path.touch(exist_ok=True) + + return ExpandSummary( + zip_files_processed=len(zip_files), + zip_files_with_delta=changed_zip_count, + json_entries_processed=total_json, + delta_entries_materialized=total_materialized, + output_mode=output_mode, + per_zip=per_zip, + ) + + +# Backward-compatible aliases for existing imports. 
+MaterializeSummary = ExpandSummary +SnapshotMaterializer = SnapshotExpander + + +def materialize_rules_zip_dir( + rules_dir: Path, + *, + output_dir: Path | None = None, + glob_pattern: str = "*.zip", +) -> ExpandSummary: + """Backward-compatible wrapper for previous API name.""" + return expand_rules_zip_dir( + rules_dir, + output_dir=output_dir, + glob_pattern=glob_pattern, + ) diff --git a/src/winml/modelkit/analyze/utils/rule_loader.py b/src/winml/modelkit/analyze/utils/rule_loader.py index e42029a0d..0079a15c3 100644 --- a/src/winml/modelkit/analyze/utils/rule_loader.py +++ b/src/winml/modelkit/analyze/utils/rule_loader.py @@ -20,18 +20,87 @@ #: Use ``os.pathsep`` (`;` on Windows, `:` on Unix) to separate multiple paths. MODELKIT_RULES_DIR_ENV = "MODELKIT_RULES_DIR" +# Directory containing this module file. Relative env-var entries are resolved from here. +_RULE_LOADER_DIR: Path = Path(__file__).resolve().parent + # Default runtime_check_rules directory (relative to the analyze package). _DEFAULT_RUNTIME_RULES_DIR: Path = ( Path(__file__).resolve().parent.parent / "rules" / "runtime_check_rules" ) +# Track directories already auto-checked in this process to avoid repeated scans/expands. +_EXPAND_CHECKED_DIRS: set[str] = set() + + +def _resolve_env_rules_dir_entry(entry: str) -> Path: + """Resolve a MODELKIT_RULES_DIR entry into an absolute directory path. + + Absolute paths are used directly. Relative paths are interpreted relative + to this module file's directory. + """ + entry_path = Path(entry).expanduser() + if entry_path.is_absolute(): + return entry_path.resolve() + return (_RULE_LOADER_DIR / entry_path).resolve() + + +def _has_non_temp_zip_files(rules_dir: Path, glob_pattern: str = "*.zip") -> bool: + """Return whether the directory contains at least one non-temp zip.""" + return any( + path.is_file() and ".materialized." 
not in path.name + for path in rules_dir.glob(glob_pattern) + ) + + +def _ensure_rules_dir_expanded_once(rules_dir: Path) -> None: + """Auto-expand rule zips once if marker is missing. + + Behavior: + 1. Skip when directory does not exist. + 2. Skip when marker file already exists. + 3. If directory has zip files and marker is missing, run in-place expand. + """ + resolved_dir = rules_dir.resolve() + dir_key = str(resolved_dir).casefold() + if dir_key in _EXPAND_CHECKED_DIRS: + return + + _EXPAND_CHECKED_DIRS.add(dir_key) + + if not resolved_dir.exists() or not resolved_dir.is_dir(): + return + + try: + from .rule_expander import EXPANDED_MARKER_FILE, expand_rules_zip_dir + + marker_path = resolved_dir / EXPANDED_MARKER_FILE + if marker_path.exists(): + return + + if not _has_non_temp_zip_files(resolved_dir): + return + + logger.info( + "!!! [RULES INIT] One-time runtime rules initialization is required for %s; " + "initializing now (may take up to 30 minutes).", + resolved_dir, + ) + expand_rules_zip_dir(resolved_dir) + except Exception: + logger.exception( + "Failed to auto-expand runtime rule zips in %s; " + "please check the zip files and expand manually if needed.", + resolved_dir, + ) + def get_runtime_rules_search_dirs() -> list[Path]: """Return ordered list of directories to search for runtime check rule zips. The search order is: - 1. Any extra directories listed in the :data:`MODELKIT_RULES_DIR` env var - (separated by ``os.pathsep``). + 1. Any extra directories listed in the :data:`MODELKIT_RULES_DIR` env var + (separated by ``os.pathsep``). Absolute paths are used directly; + relative paths are resolved relative to this module file directory. 2. 
Default embedded directory (``src/winml/modelkit/analyze/rules/runtime_check_rules/``) Returns: @@ -43,7 +112,7 @@ def get_runtime_rules_search_dirs() -> list[Path]: for entry in env_val.split(os.pathsep): entry = entry.strip() if entry: - dirs.append(Path(entry).resolve()) + dirs.append(_resolve_env_rules_dir_entry(entry)) dirs.append(_DEFAULT_RUNTIME_RULES_DIR) return dirs @@ -62,6 +131,8 @@ def resolve_rule_zip_path(zip_filename: str) -> Path: Resolved ``Path`` to the zip file. """ for search_dir in get_runtime_rules_search_dirs(): + # Disabled by default: one-time rules initialization can be expensive. + # _ensure_rules_dir_expanded_once(search_dir) candidate = search_dir / zip_filename if candidate.exists(): return candidate diff --git a/src/winml/modelkit/commands/expand_rules.py b/src/winml/modelkit/commands/expand_rules.py new file mode 100644 index 000000000..158a572ea --- /dev/null +++ b/src/winml/modelkit/commands/expand_rules.py @@ -0,0 +1,141 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +r"""Expand runtime rule zips in-place for faster loading. + +Resolves rules directories via ``_resolve_env_rules_dir_entry`` and, when +each directory exists and contains zip files, rewrites them in-place to full +payloads (no delta snapshot recursion at load time). 
+ +Usage: + winml expand_rules + winml expand_rules --rules-dir-entry C:\\path\\to\\rules_zip +""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import click + +from ..analyze.utils.rule_expander import expand_rules_zip_dir +from ..analyze.utils.rule_loader import MODELKIT_RULES_DIR_ENV, _resolve_env_rules_dir_entry + + +if TYPE_CHECKING: + from pathlib import Path + + +def _entries_from_env() -> list[str]: + """Read all non-empty MODELKIT_RULES_DIR entries from environment.""" + env_val = os.environ.get(MODELKIT_RULES_DIR_ENV, "").strip() + if not env_val: + return [] + return [entry.strip() for entry in env_val.split(os.pathsep) if entry.strip()] + + +def _resolve_entries(entries: list[str]) -> list[Path]: + """Resolve entries to unique paths while preserving order.""" + resolved_dirs: list[Path] = [] + seen: set[str] = set() + + for entry in entries: + resolved = _resolve_env_rules_dir_entry(entry) + key = str(resolved).casefold() + if key in seen: + continue + seen.add(key) + resolved_dirs.append(resolved) + + return resolved_dirs + + +@click.command("expand_rules") +@click.option( + "--rules-dir-entry", + "rules_dir_entries", + type=str, + multiple=True, + help=( + "Optional rule directory entry. May be repeated. " + "If omitted, uses all entries from MODELKIT_RULES_DIR. " + "Each entry is resolved by rule_loader._resolve_env_rules_dir_entry." + ), +) +@click.option( + "--glob", + "glob_pattern", + type=str, + default="*.zip", + show_default=True, + help="Zip filename glob to process.", +) +def expand_rules(rules_dir_entries: tuple[str, ...], glob_pattern: str) -> None: + """Expand runtime rules zip files in-place when directories and zips exist.""" + entries = list(rules_dir_entries) if rules_dir_entries else _entries_from_env() + + if not entries: + click.echo( + f"{MODELKIT_RULES_DIR_ENV} is not set (or empty) " + "and no --rules-dir-entry provided, skip." 
+ ) + return + + rules_dirs = _resolve_entries(entries) + if not rules_dirs: + click.echo("No resolvable rules directories found, skip.") + return + + grand_zip_count = 0 + grand_delta_zip_count = 0 + grand_json_count = 0 + grand_delta_count = 0 + + for rules_dir in rules_dirs: + if not rules_dir.exists() or not rules_dir.is_dir(): + click.echo(f"Rules directory does not exist, skip: {rules_dir}") + continue + + matched = [ + path + for path in sorted(rules_dir.glob(glob_pattern), key=lambda p: p.name) + if ".materialized." not in path.name + ] + if not matched: + click.echo(f"No zip files matched '{glob_pattern}' in {rules_dir}, skip.") + continue + + click.echo(f"Expanding {len(matched)} zip(s) in: {rules_dir}") + + summary = expand_rules_zip_dir( + rules_dir, + output_dir=None, + glob_pattern=glob_pattern, + ) + + for zip_name, json_count, materialized_count in summary.per_zip: + click.echo( + f"[{zip_name}] json_entries={json_count}, " + f"materialized_delta_entries={materialized_count}" + ) + + click.echo("\nDone.") + click.echo(f" zip_files_processed: {summary.zip_files_processed}") + click.echo(f" zip_files_with_delta: {summary.zip_files_with_delta}") + click.echo(f" json_entries_processed: {summary.json_entries_processed}") + click.echo(f" delta_entries_materialized: {summary.delta_entries_materialized}") + click.echo(f" output_mode: {summary.output_mode}") + + grand_zip_count += summary.zip_files_processed + grand_delta_zip_count += summary.zip_files_with_delta + grand_json_count += summary.json_entries_processed + grand_delta_count += summary.delta_entries_materialized + + if grand_zip_count > 0 and len(rules_dirs) > 1: + click.echo("\nAggregate:") + click.echo(f" zip_files_processed: {grand_zip_count}") + click.echo(f" zip_files_with_delta: {grand_delta_zip_count}") + click.echo(f" json_entries_processed: {grand_json_count}") + click.echo(f" delta_entries_materialized: {grand_delta_count}") diff --git a/tests/unit/analyze/core/test_lazy_domain_tables.py 
b/tests/unit/analyze/core/test_lazy_domain_tables.py index 46934d9cd..e37c4788e 100644 --- a/tests/unit/analyze/core/test_lazy_domain_tables.py +++ b/tests/unit/analyze/core/test_lazy_domain_tables.py @@ -439,3 +439,111 @@ def test_list_values_are_made_hashable(self, tmp_path: Path): # make_hashable converts lists to tuples assert isinstance(value, tuple) assert value == (1, 2, 3) + + +class TestDeltaSnapshots: + """Test delta snapshot chaining for rules/tables/columns.""" + + def test_lazy_neg_rules_supports_delta_chain( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ): + base_zip = tmp_path / "EP_CPU_ai.onnx_opset12.zip" + delta_zip = tmp_path / "EP_CPU_ai.onnx_opset13.zip" + base_file = "EP_CPU_ai.onnx_opset12_negative_rules.json" + delta_file = "EP_CPU_ai.onnx_opset13_negative_rules.json" + + base_rules = { + "Conv": _make_op_rule("Conv"), + "Add": _make_op_rule("Add"), + } + delta_rules = { + "__snapshot_type__": "delta_v1", + "__base_opset__": 12, + "__current_opset__": 13, + "__changed__": { + "Mul": _make_op_rule("Mul"), + }, + "__deleted__": ["Add"], + } + + with zipfile.ZipFile(base_zip, "w") as zf: + zf.writestr(base_file, json.dumps(base_rules)) + with zipfile.ZipFile(delta_zip, "w") as zf: + zf.writestr(delta_file, json.dumps(delta_rules)) + + monkeypatch.setenv("MODELKIT_RULES_DIR", str(tmp_path)) + + rules = _LazyNegRules(delta_zip, delta_file, REGISTERED_PATTERNS) + keys = set(rules.keys()) + assert "Conv" in keys + assert "Mul" in keys + assert "Add" not in keys + + def test_lazy_domain_tables_supports_delta_chain( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ): + base_zip = tmp_path / "EP_CPU_ai.onnx_opset12.zip" + delta_zip = tmp_path / "EP_CPU_ai.onnx_opset13.zip" + base_tables_file = "EP_CPU_ai.onnx_opset12_tables.json" + delta_tables_file = "EP_CPU_ai.onnx_opset13_tables.json" + base_columns_file = "EP_CPU_ai.onnx_opset12_table_columns.json" + delta_columns_file = "EP_CPU_ai.onnx_opset13_table_columns.json" + + 
base_tables = { + "Conv": RAW_DATA["Conv"], + "Add": RAW_DATA["Add"], + } + base_columns = { + "Conv": RAW_COLUMNS["Conv"], + "Add": RAW_COLUMNS["Add"], + } + + delta_tables = { + "__snapshot_type__": "delta_v1", + "__base_opset__": 12, + "__current_opset__": 13, + "__changed__": { + "Mul": { + "A_shape": {0: (1, 3, 224, 224)}, + "B_shape": {0: (1, 3, 224, 224)}, + "compile_run_success": {0: (True, True)}, + } + }, + "__deleted__": ["Add"], + } + delta_columns = { + "__snapshot_type__": "delta_v1", + "__base_opset__": 12, + "__current_opset__": 13, + "__changed__": { + "Mul": ["A_shape", "B_shape"], + }, + "__deleted__": ["Add"], + } + + with zipfile.ZipFile(base_zip, "w") as zf: + zf.writestr(base_tables_file, json.dumps(base_tables)) + zf.writestr(base_columns_file, json.dumps(base_columns)) + with zipfile.ZipFile(delta_zip, "w") as zf: + zf.writestr(delta_tables_file, json.dumps(delta_tables)) + zf.writestr(delta_columns_file, json.dumps(delta_columns)) + + monkeypatch.setenv("MODELKIT_RULES_DIR", str(tmp_path)) + + tables = LazyDomainTables( + delta_zip, + delta_tables_file, + columns_file_name=delta_columns_file, + ) + + assert "Conv" in tables # inherited from base + assert "Mul" in tables # changed in delta + assert "Add" not in tables # deleted in delta + + conv_df = tables["Conv"] + mul_df = tables["Mul"] + assert isinstance(conv_df, pd.DataFrame) + assert isinstance(mul_df, pd.DataFrame) + assert tables.get_columns("Conv") == RAW_COLUMNS["Conv"] + assert tables.get_columns("Mul") == ["A_shape", "B_shape"] + assert tables.get_columns("Add") is None diff --git a/tests/unit/analyze/models/test_rule_expander.py b/tests/unit/analyze/models/test_rule_expander.py new file mode 100644 index 000000000..112b1c4bf --- /dev/null +++ b/tests/unit/analyze/models/test_rule_expander.py @@ -0,0 +1,98 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import json +import zipfile +from pathlib import Path + +from winml.modelkit.analyze.utils.rule_expander import EXPANDED_MARKER_FILE, expand_rules_zip_dir + + +def _write_json_zip(zip_path: Path, entry_name: str, payload: dict) -> None: + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr(entry_name, json.dumps(payload)) + + +def test_expand_rules_zip_dir_in_place(tmp_path: Path) -> None: + rules_dir = tmp_path / "rules" + rules_dir.mkdir(parents=True, exist_ok=True) + + base_zip = rules_dir / "EP_CPU_ai.onnx_opset12.zip" + delta_zip = rules_dir / "EP_CPU_ai.onnx_opset13.zip" + base_entry = "EP_CPU_ai.onnx_opset12_negative_rules.json" + delta_entry = "EP_CPU_ai.onnx_opset13_negative_rules.json" + + _write_json_zip( + base_zip, + base_entry, + { + "Conv": {"v": 1}, + "Add": {"v": 2}, + }, + ) + _write_json_zip( + delta_zip, + delta_entry, + { + "__snapshot_type__": "delta_v1", + "__base_opset__": 12, + "__current_opset__": 13, + "__changed__": {"Mul": {"v": 3}}, + "__deleted__": ["Add"], + }, + ) + + summary = expand_rules_zip_dir(rules_dir) + + assert summary.zip_files_processed == 2 + assert summary.zip_files_with_delta == 1 + assert summary.delta_entries_materialized == 1 + assert summary.output_mode.startswith("in-place") + + with zipfile.ZipFile(delta_zip, "r") as zf: + payload = json.loads(zf.read(delta_entry).decode("utf-8")) + + assert "__snapshot_type__" not in payload + assert sorted(payload.keys()) == ["Conv", "Mul"] + marker = rules_dir / EXPANDED_MARKER_FILE + assert marker.exists() + assert marker.is_file() + + +def test_expand_rules_zip_dir_ignores_temp_materialized_zips(tmp_path: Path) -> None: + rules_dir = tmp_path / "rules" + rules_dir.mkdir(parents=True, exist_ok=True) + + real_zip = rules_dir / "EP_CPU_ai.onnx_opset12.zip" + temp_zip = rules_dir / "EP_CPU_ai.onnx_opset13.materialized.abcd.zip" + + _write_json_zip(real_zip, 
"EP_CPU_ai.onnx_opset12_negative_rules.json", {"Conv": {"v": 1}}) + _write_json_zip( + temp_zip, + "EP_CPU_ai.onnx_opset13_negative_rules.json", + { + "__snapshot_type__": "delta_v1", + "__base_opset__": 12, + "__current_opset__": 13, + "__changed__": {"Mul": {"v": 1}}, + "__deleted__": [], + }, + ) + + summary = expand_rules_zip_dir(rules_dir) + + assert summary.zip_files_processed == 1 + assert [item[0] for item in summary.per_zip] == [real_zip.name] + assert (rules_dir / EXPANDED_MARKER_FILE).exists() + + +def test_expand_rules_zip_dir_no_zip_does_not_create_marker(tmp_path: Path) -> None: + rules_dir = tmp_path / "rules" + rules_dir.mkdir(parents=True, exist_ok=True) + + summary = expand_rules_zip_dir(rules_dir) + + assert summary.zip_files_processed == 0 + assert not (rules_dir / EXPANDED_MARKER_FILE).exists() diff --git a/tests/unit/analyze/models/test_rule_loader.py b/tests/unit/analyze/models/test_rule_loader.py index 8d41bf055..174c7f1c0 100644 --- a/tests/unit/analyze/models/test_rule_loader.py +++ b/tests/unit/analyze/models/test_rule_loader.py @@ -23,7 +23,9 @@ from winml.modelkit.analyze import IHVType, RuleLoader from winml.modelkit.analyze.utils import get_runtime_rules_search_dirs, resolve_rule_zip_path -from winml.modelkit.analyze.utils.rule_loader import _DEFAULT_RUNTIME_RULES_DIR +from winml.modelkit.analyze.utils import rule_expander as rule_expander_module +from winml.modelkit.analyze.utils import rule_loader as rule_loader_module +from winml.modelkit.analyze.utils.rule_loader import _DEFAULT_RUNTIME_RULES_DIR, _RULE_LOADER_DIR class TestRuleLoaderBasicLoading: @@ -468,6 +470,17 @@ def test_env_var_adds_dirs(self, monkeypatch): assert dirs[1] == Path("/extra/path2").resolve() assert dirs[2].name == "runtime_check_rules" + def test_env_var_relative_path_resolved_from_module_dir(self, monkeypatch): + """Relative MODELKIT_RULES_DIR entries are resolved from rule_loader.py dir.""" + relative_entry = "custom/rules" + 
monkeypatch.setenv("MODELKIT_RULES_DIR", relative_entry) + + dirs = get_runtime_rules_search_dirs() + + assert len(dirs) == 2 + assert dirs[0] == (_RULE_LOADER_DIR / relative_entry).resolve() + assert dirs[1] == _DEFAULT_RUNTIME_RULES_DIR + def test_env_var_empty_ignored(self, monkeypatch): """Empty MODELKIT_RULES_DIR is treated as unset.""" monkeypatch.setenv("MODELKIT_RULES_DIR", " ") @@ -501,3 +514,71 @@ def test_resolve_prefers_env_over_default(self, monkeypatch): result = resolve_rule_zip_path(zip_name) assert result == Path(tmpdir).resolve() / zip_name + + def test_resolve_auto_expand_disabled_by_default(self, monkeypatch): + """Auto-expand is not triggered by resolve when the call is disabled.""" + zip_name = "QNN_NPU_ai_onnx_opset13.zip" + + with tempfile.TemporaryDirectory() as tmpdir: + rules_dir = Path(tmpdir) + (rules_dir / zip_name).write_bytes(b"PK") + monkeypatch.setenv("MODELKIT_RULES_DIR", tmpdir) + monkeypatch.setattr(rule_loader_module, "_EXPAND_CHECKED_DIRS", set()) + + calls: list[Path] = [] + + def _fake_expand_rules_zip_dir( + rules_dir: Path, + *, + output_dir: Path | None = None, + glob_pattern: str = "*.zip", + marker_filename: str = "expanded", + ): + del output_dir, glob_pattern, marker_filename + calls.append(rules_dir.resolve()) + + monkeypatch.setattr( + rule_expander_module, + "expand_rules_zip_dir", + _fake_expand_rules_zip_dir, + ) + + result = resolve_rule_zip_path(zip_name) + + assert result == rules_dir.resolve() / zip_name + assert calls == [] + + def test_resolve_skips_auto_expand_when_marker_exists(self, monkeypatch): + """Auto-expand is skipped when expanded marker already exists.""" + zip_name = "QNN_NPU_ai_onnx_opset13.zip" + + with tempfile.TemporaryDirectory() as tmpdir: + rules_dir = Path(tmpdir) + (rules_dir / zip_name).write_bytes(b"PK") + (rules_dir / rule_expander_module.EXPANDED_MARKER_FILE).touch() + monkeypatch.setenv("MODELKIT_RULES_DIR", tmpdir) + monkeypatch.setattr(rule_loader_module, "_EXPAND_CHECKED_DIRS", 
set()) + + called = False + + def _fake_expand_rules_zip_dir( + rules_dir: Path, + *, + output_dir: Path | None = None, + glob_pattern: str = "*.zip", + marker_filename: str = "expanded", + ): + del rules_dir, output_dir, glob_pattern, marker_filename + nonlocal called + called = True + + monkeypatch.setattr( + rule_expander_module, + "expand_rules_zip_dir", + _fake_expand_rules_zip_dir, + ) + + result = resolve_rule_zip_path(zip_name) + + assert result == rules_dir.resolve() / zip_name + assert called is False diff --git a/tests/unit/analyze/runtime_checker/test_runner.py b/tests/unit/analyze/runtime_checker/test_runner.py index 14a884e9c..41f3ea6b1 100644 --- a/tests/unit/analyze/runtime_checker/test_runner.py +++ b/tests/unit/analyze/runtime_checker/test_runner.py @@ -115,7 +115,7 @@ def _fake_new_executor(self) -> _FakeExecutor: runner = ResilientRunner() runner._shutdown_executor_two_phase(cancel_futures=True, graceful_timeout_sec=0.0) - assert executor.shutdown_calls == [(False, True)] + assert executor.shutdown_calls == [(False, True), (True, True)] assert worker.killed is True - assert worker.closed is True + assert worker.closed is False assert worker.join_calls == [runner._FORCED_KILL_JOIN_TIMEOUT_SEC] diff --git a/tests/unit/commands/test_expand_rules_cli.py b/tests/unit/commands/test_expand_rules_cli.py new file mode 100644 index 000000000..1b552bae8 --- /dev/null +++ b/tests/unit/commands/test_expand_rules_cli.py @@ -0,0 +1,85 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +import json +import zipfile +from pathlib import Path + +from click.testing import CliRunner + +from winml.modelkit.analyze.utils.rule_expander import EXPANDED_MARKER_FILE +from winml.modelkit.cli import main + + +def _write_json_zip(zip_path: Path, entry_name: str, payload: dict) -> None: + with zipfile.ZipFile(zip_path, "w") as zf: + zf.writestr(entry_name, json.dumps(payload)) + + +def test_expand_rules_command_expands_in_place_from_env( + tmp_path: Path, monkeypatch +) -> None: + rules_dir = tmp_path / "rules_zip" + rules_dir.mkdir(parents=True, exist_ok=True) + + base_zip = rules_dir / "EP_CPU_ai.onnx_opset12.zip" + delta_zip = rules_dir / "EP_CPU_ai.onnx_opset13.zip" + base_entry = "EP_CPU_ai.onnx_opset12_negative_rules.json" + delta_entry = "EP_CPU_ai.onnx_opset13_negative_rules.json" + + _write_json_zip(base_zip, base_entry, {"Conv": {"v": 1}, "Add": {"v": 2}}) + _write_json_zip( + delta_zip, + delta_entry, + { + "__snapshot_type__": "delta_v1", + "__base_opset__": 12, + "__current_opset__": 13, + "__changed__": {"Mul": {"v": 3}}, + "__deleted__": ["Add"], + }, + ) + + monkeypatch.setenv("MODELKIT_RULES_DIR", str(rules_dir)) + + runner = CliRunner() + result = runner.invoke( + main, + ["expand_rules"], + ) + + assert result.exit_code == 0, result.output + assert "zip_files_processed: 2" in result.output + assert (rules_dir / EXPANDED_MARKER_FILE).exists() + + with zipfile.ZipFile(delta_zip, "r") as zf: + payload = json.loads(zf.read(delta_entry).decode("utf-8")) + + assert "__snapshot_type__" not in payload + assert sorted(payload.keys()) == ["Conv", "Mul"] + + +def test_expand_rules_command_skips_when_dir_missing(tmp_path: Path, monkeypatch) -> None: + missing = tmp_path / "missing_rules_zip" + monkeypatch.setenv("MODELKIT_RULES_DIR", str(missing)) + runner = CliRunner() + + result = runner.invoke( + main, + ["expand_rules"], + ) + + assert result.exit_code == 0 + assert "does 
not exist, skip" in result.output + + +def test_expand_rules_command_skips_when_env_unset(monkeypatch) -> None: + monkeypatch.delenv("MODELKIT_RULES_DIR", raising=False) + runner = CliRunner() + + result = runner.invoke(main, ["expand_rules"]) + + assert result.exit_code == 0 + assert "MODELKIT_RULES_DIR is not set" in result.output