diff --git a/src/winml/modelkit/analyze/pattern/check_patterns.py b/src/winml/modelkit/analyze/pattern/check_patterns.py index 3c0cb767a..05c98d131 100644 --- a/src/winml/modelkit/analyze/pattern/check_patterns.py +++ b/src/winml/modelkit/analyze/pattern/check_patterns.py @@ -15,13 +15,10 @@ python -m winml.modelkit.analyze.pattern.check_patterns --all_patterns """ -import json from pathlib import Path from typing import Any -import onnx import onnxruntime as ort -from google.protobuf import json_format from ... import winml from ...onnx import ONNXDomain @@ -33,107 +30,24 @@ from ...sysinfo import SysInfo from ...utils import constants from ..runtime_checker.ep_checker import EPChecker +from ..utils import CheckResultWriter winml.register_execution_providers(ort=True) -class CheckResultWriter: - """Writer for test results that supports continuation from existing files.""" - - def __init__( - self, - file_path: str | Path, - sys_info: dict[str, Any], - save_per_cases: int = 20, - continue_from_existing: bool = False, - ) -> None: - """Initialize the writer. - - Args: - file_path: Path to the output JSON file - sys_info: System information dictionary (constant during run) - save_per_cases: Number of results to accumulate before saving to file - continue_from_existing: If True, read existing file and continue from there. - If False, start fresh (ignore existing file). - """ - self.file_path = Path(file_path) - self.sys_info = sys_info - self.save_per_cases = save_per_cases - self.skip_cases = 0 - self.results: list[dict[str, Any]] = [] - self.pending_count = 0 - - # Only read existing file if continuing - if continue_from_existing and self.file_path.exists(): - with self.file_path.open("r", encoding="utf-8") as f: - data = json.load(f) - if "check_results" in data: - self.results = data["check_results"] - self.skip_cases = len(self.results) - print( - f"Found existing file with " - f"{self.skip_cases} test cases. " - f"Will continue from there." - ) - - def get_skip_cases(self) -> int: - """Get the number of cases to skip when continuing.""" - return self.skip_cases - - def append_result(self, result: dict[str, Any]) -> None: - """Append a test result and save to file periodically. - - Args: - result: Test result dictionary - """ - self.results.append(result) - self.pending_count += 1 - - # Save to file once per save_per_cases results - if self.pending_count >= self.save_per_cases: - self._save() - self.pending_count = 0 - - def __enter__(self) -> "CheckResultWriter": - """Enter context manager.""" - return self - - def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - """Exit context manager and flush pending results.""" - self.flush() - - def flush(self) -> None: - """Force save any pending results to file.""" - if self.pending_count > 0: - self._save() - self.pending_count = 0 - - def _save(self) -> None: - """Save results to file.""" - output_data = { - "check_results": self.results, - "sys_info": self.sys_info, - } - - def json_default(obj: Any) -> Any: - if isinstance(obj, onnx.TensorProto): - return json.loads(json_format.MessageToJson(obj)) - raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") - - with self.file_path.open("w", encoding="utf-8", newline="\n") as f: - json.dump(output_data, f, indent=2, default=json_default) - print(f"Saved {len(self.results)} results to {self.file_path}") - - def check_patterns( ep_checker: EPChecker, patterns: list[str], validate_inputs: bool = False, output_dir: str | Path = ".", n_cases: int | None = None, - continue_from_existing: bool = False, save_failed_model: bool = False, + rerun_failed: bool = False, + delta_only: bool = False, + dry_run: bool = False, + not_run_start_id: int = 1, + case_index: str | list[str] | None = None, opset_mapping: dict[str, int] | None = None, ) -> dict[str, dict[str, Any]]: """Run patterns on execution provider and return results. @@ -145,9 +59,12 @@ def check_patterns( output_dir: Output directory for test results JSON files (default: current directory) n_cases: If not None, only run the first n_cases test cases for each pattern. If n_cases is greater than total cases, run all cases. - continue_from_existing: If True, continue from existing result file by skipping - already completed test cases. save_failed_model: If True, save the model for compile failed test cases. + rerun_failed: If True, rerun failed cases (compile or run failed). + delta_only: If True, only run new test cases not in existing results. + dry_run: If True, skip compile/run execution and emit check_result with reason "not_run". + not_run_start_id: Initial id used for not_run placeholder reasons (not_run_). + case_index: Optional hashed signature(s) to filter to specific test cases. opset_mapping: Required dict mapping domain strings to opset versions, e.g., {"ai.onnx": 17, "com.microsoft": 1}. Used for ONNX model generation. @@ -169,6 +86,9 @@ def check_patterns( # Store results for all patterns all_results: dict[str, dict[str, Any]] = {} + if opset_mapping is None: + raise ValueError("opset_mapping must be provided for pattern model generation") + # Convert opset_mapping to ONNXDomain keys if provided domain_versions = { ONNXDomain.from_str(domain): version for domain, version in opset_mapping.items() @@ -214,14 +134,14 @@ def check_patterns( with CheckResultWriter( output_path, sys_info, - continue_from_existing=continue_from_existing, + save_per_cases=None if dry_run else 20, + rerun_failed=rerun_failed, + delta_only=delta_only, + not_run_start_id=not_run_start_id, + filter_case_index=case_index, ) as writer: - skip_cases = writer.get_skip_cases() - # Run tests on execution provider print(f"Running {pattern_name} tests on {ep_checker.ep_name}...") - if skip_cases > 0: - print(f"Continuing from existing results, skipping {skip_cases} test cases") if n_cases is not None: print(f"Limiting to first {n_cases} test cases") @@ -229,13 +149,39 @@ def check_patterns( ep_checker, capture_output=True, n_cases=n_cases, - skip_cases=skip_cases, + skip_cases=0, save_failed_model=save_failed_model, + skip_signature_fn=writer.should_skip_case, + yield_skipped=True, + dry_run=dry_run, ) - # Process results using writer + # Process results in generator order - reuse existing or run new + run_count = 0 + reused_count = 0 + skipped_count = 0 for result in check_results_iter: - writer.append_result(result) + if result.get("_skipped"): + skipped_count += 1 + if writer.reuse_existing_result(result): + reused_count += 1 + else: + writer.append_result(result) + run_count += 1 + + dropped_count = writer.get_dropped_count() + duplicate_skipped = writer.get_duplicate_skipped_count() + print( + f"Ran {run_count} test cases, reused " + f"{reused_count} existing cases, " + f"dropped {dropped_count} obsolete " + f"cases, duplicates skipped " + f"{duplicate_skipped}, skipped " + f"{skipped_count}." + ) + + # Finalize to clear unused signatures before final flush + writer.finalize() check_results = writer.results @@ -284,7 +230,7 @@ def get_ep_checker(ep_name: str, device: str) -> EPChecker: ValueError: If the execution provider name is not supported. """ device_type = constants.DEVICE_TO_DEVICE_TYPE[device] - ep_name_to_checker: dict[str, type[EPChecker]] = { + ep_name_to_checker: dict[str, Any] = { "QNNExecutionProvider": QNNNPUChecker, "OpenVINOExecutionProvider": OpenVINONPUChecker, # Add other EPChecker subclasses here as needed @@ -298,8 +244,8 @@ def get_ep_checker(ep_name: str, device: str) -> EPChecker: return ep_name_to_checker[ep_name](device_type=device_type) -def parse_and_check() -> None: - """Main entry point for command-line execution.""" +def build_parser(): + """Build argument parser for check_patterns-style commands.""" import argparse parser = argparse.ArgumentParser(description="Test ONNX patterns on execution provider") @@ -341,12 +287,33 @@ def parse_and_check() -> None: choices=["CPU", "GPU", "NPU"], help="Target device type (CPU, GPU, NPU).", ) - parser.add_argument( + + opset_group = parser.add_mutually_exclusive_group(required=True) + opset_group.add_argument( "--opset_mapping", type=str, nargs="+", - required=True, - help=("Domain:version pairs for ONNX opset versions, e.g., ai.onnx:17 com.microsoft:1"), + help=( + "Domain:version pairs for ONNX opset versions, " + "e.g., ai.onnx:17 com.microsoft:1" + ), + ) + opset_group.add_argument( + "--opset_version", + type=int, + help=( + "ONNX opset version to use together with --opset_domain. " + "If used without --opset_mapping, com.microsoft:1 is added automatically." + ), + ) + parser.add_argument( + "--opset_domain", + type=str, + default=ONNXDomain.AI_ONNX.value, + help=( + "ONNX opset domain to use with --opset_version " + f"(default: {ONNXDomain.AI_ONNX.value})" + ), ) parser.add_argument( "--validate_inputs", @@ -365,31 +332,104 @@ def parse_and_check() -> None: default=None, help="Limit number of test cases per pattern (default: run all cases)", ) + + mode_group = parser.add_mutually_exclusive_group() + mode_group.add_argument( + "--rerun_failed", + action="store_true", + help=( + "Rerun only failed cases (compile failed or run failed). " + "Mutually exclusive with --delta_only and --case_index." + ), + ) + mode_group.add_argument( + "--delta_only", + action="store_true", + help=( + "Only run new test cases that do not exist in the existing results file. " + "Mutually exclusive with --rerun_failed and --case_index." + ), + ) + mode_group.add_argument( + "--case_index", + type=str, + nargs="+", + default=None, + help=( + "Only process cases matching these case_index hashes. " + "Mutually exclusive with --rerun_failed and --delta_only." + ), + ) + parser.add_argument( - "--continue", + "--dry_run", action="store_true", - dest="continue_from_existing", - help="Continue from existing result file by skipping already completed test cases", + help="Skip compile/run execution and emit check_result with reason 'not_run'", + ) + parser.add_argument( + "--not_run_start_id", + type=int, + default=1, + help="Initial id used for not_run placeholder reasons (not_run_) (default: 1)", ) parser.add_argument( "--save_failed_model", action="store_true", help="Save the model for compile failed test cases", ) - args = parser.parse_args() + return parser + - # Parse opset_mapping from "domain:version" pairs - opset_mapping = None +def _parse_opset_mapping(args: Any) -> dict[str, int]: + """Parse opset mapping from CLI args. + + Supports either: + - --opset_mapping domain:version [domain:version ...] + - --opset_version + optional --opset_domain + """ if args.opset_mapping: - opset_mapping = {} + opset_mapping: dict[str, int] = {} for pair in args.opset_mapping: - domain, version = pair.split(":") - opset_mapping[domain] = int(version) + if ":" not in pair: + raise ValueError( + "Invalid --opset_mapping value " + f"'{pair}'. Expected format: domain:version" + ) + domain, version_text = pair.split(":", 1) + if not domain: + raise ValueError(f"Invalid --opset_mapping value '{pair}': empty domain") + try: + opset_mapping[domain] = int(version_text) + except ValueError as exc: + raise ValueError( + "Invalid --opset_mapping value " + f"'{pair}'. Version must be an integer" + ) from exc + return opset_mapping + + if args.opset_version is None: + raise ValueError("Either --opset_mapping or --opset_version must be provided") + + opset_mapping = {args.opset_domain: int(args.opset_version)} + + # Keep compatibility with existing pattern generators that expect this domain. + if args.opset_domain != ONNXDomain.COM_MICROSOFT.value: + opset_mapping.setdefault(ONNXDomain.COM_MICROSOFT.value, 1) + + return opset_mapping + + +def run_from_args(args: Any) -> None: + """Run check_patterns from parsed CLI args.""" + available_patterns = get_registered_pattern_input_generators() # Determine which patterns to test patterns_to_check = available_patterns if args.all_patterns else args.patterns ep_checker = get_ep_checker(args.ep, device=args.device) + # Parse opset mapping from either mapping pairs or opset_domain/opset_version + opset_mapping = _parse_opset_mapping(args) + # Run the tests check_patterns( ep_checker, @@ -397,11 +437,22 @@ def parse_and_check() -> None: validate_inputs=args.validate_inputs, output_dir=args.output_dir, n_cases=args.n_cases, - continue_from_existing=args.continue_from_existing, save_failed_model=args.save_failed_model, + rerun_failed=args.rerun_failed, + delta_only=args.delta_only, + dry_run=args.dry_run, + not_run_start_id=args.not_run_start_id, + case_index=args.case_index, opset_mapping=opset_mapping, ) +def parse_and_check() -> None: + """Main entry point for command-line execution.""" + parser = build_parser() + args = parser.parse_args() + run_from_args(args) + + if __name__ == "__main__": parse_and_check() diff --git a/src/winml/modelkit/analyze/runtime_checker/case_runner.py b/src/winml/modelkit/analyze/runtime_checker/case_runner.py deleted file mode 100644 index 5039df804..000000000 --- a/src/winml/modelkit/analyze/runtime_checker/case_runner.py +++ /dev/null @@ -1,245 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -from __future__ import annotations - -import re -import time -from typing import TYPE_CHECKING, Any - - -if TYPE_CHECKING: - from collections.abc import Callable - -import numpy as np -import onnx - -from ...onnx import ONNXDomain -from ...pattern.op_input_gen import get_runtime_checker_op -from ...pattern.op_input_gen.op_input_gen import ( - InputShapeConstraint, - OpInputGenerator, - model_from_b64, -) -from .check_ops import get_ep_checker -from .runner import ResilientRunner - - -_FAILED_TO_FREE_LIBRARY_TOKEN = "failed to free library" # noqa: S105 -_DURATION_US_RE = re.compile(r"\(\s*\d+\s*us\)") -_TIMESTAMP_RE = re.compile(r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}(?:\.\d+)?") - - -class NeedRestartError(RuntimeError): - """Raised when a fatal provider error requires a machine restart.""" - - -def _strip_us_durations(text: str) -> str: - return _DURATION_US_RE.sub("", text) - - -def _strip_timestamps(text: str) -> str: - return _TIMESTAMP_RE.sub("", text) - - -def _clean_result_payload(result: dict[str, Any]) -> None: - if not isinstance(result, dict): - return - - def _clean_string(val: str) -> str: - return _strip_timestamps(_strip_us_durations(val)) - - for key in ("stdout", "stderr"): - val = result.get(key) - if isinstance(val, str): - result[key] = _clean_string(val) - - res_payload = result.get("result") - if isinstance(res_payload, dict): - reason = res_payload.get("reason") - if isinstance(reason, str): - res_payload["reason"] = _clean_string(reason) - - -def _contains_failed_to_free(record: dict[str, Any]) -> bool: - token = _FAILED_TO_FREE_LIBRARY_TOKEN - - def _has_token(value: object) -> bool: - return isinstance(value, str) and token in value.lower() - - if not isinstance(record, dict): - return False - - result_payload = record.get("result") - if isinstance(result_payload, dict) and _has_token(result_payload.get("reason")): - return True - - return _has_token(record.get("stdout")) or _has_token(record.get("stderr")) - - -def _raise_if_fatal(record: dict[str, Any], stage: str) -> None: - if _contains_failed_to_free(record): - raise NeedRestartError(f"Fatal ep error during {stage}; restart recommended") - - -def _constraint_to_value(constraint: Any, type_annotation: str) -> Any: - if constraint is None: - return None - - if not isinstance(constraint, dict): - return constraint - - constraint_type = constraint.get("type") - if constraint_type == "shape": - shape = constraint.get("shape", []) - min_max = constraint.get("min_max") - return InputShapeConstraint(shape, min_max=min_max).get_value(type_annotation) - - if constraint_type == "value": - value = constraint.get("value") - dtype = constraint.get("dtype") - if dtype is not None and isinstance(value, list): - return np.asarray(value, dtype=dtype) - if isinstance(value, list): - return np.asarray(value) - return value - - if constraint_type == "variadic": - return [ - _constraint_to_value(element, type_annotation) - for element in constraint.get("elements", []) - ] - - return constraint - - -def _build_kwargs_from_case(generator: OpInputGenerator, case: dict[str, Any]) -> dict[str, Any]: - type_var_comb = case.get("type_vars", {}) - applied_type_annotations = { - name: generator._apply_type_var_combination(type_annotation, type_var_comb) - for name, type_annotation in generator.type_annotations.items() - } - - attrs = case.get("attrs", {}) - input_constraints = case.get("input_constraints", {}) - - applied_input_comb: dict[str, Any] = {} - for input_name, constraint in input_constraints.items(): - if input_name in attrs: - continue - type_annotation = applied_type_annotations.get(input_name, "") - applied_input_comb[input_name] = _constraint_to_value(constraint, type_annotation) - - kwargs = {**attrs, **applied_input_comb} - - if generator.op_variadic_input_name is not None and generator.op_variadic_input_name in kwargs: - variadic_input = kwargs.pop(generator.op_variadic_input_name) - for idx, tensor in enumerate(variadic_input): - kwargs[f"{generator.op_variadic_input_name}__{idx}"] = tensor - variadic_keys = [ - f"{generator.op_variadic_input_name}__{i}" for i in range(len(variadic_input)) - ] - normalized_key_order = ( - generator.op_input_names + variadic_keys + generator.op_attribute_names - ) - else: - normalized_key_order = generator.op_input_names + generator.op_attribute_names - - kwargs = {k: kwargs[k] for k in normalized_key_order if k in kwargs} - return generator.filter_kwargs_by_opset(kwargs) - - -def _normalize_qdq_types(case: dict[str, Any]) -> dict[str, Any] | None: - qdq_types = case.get("qdq_types") - if not isinstance(qdq_types, dict): - return None - return qdq_types - - -class RunCaseRunner: - """Execute runtime checker cases; op configuration is provided per call.""" - - def __init__(self) -> None: - # Reuse a single child-process runner per instance to avoid per-case spawn cost. - self._runner: ResilientRunner | None = ResilientRunner(capture_output=True, timeout_sec=60) - - def close(self) -> None: - """Shut down the underlying child-process runner and release resources.""" - try: - if self._runner is not None: - self._runner.shutdown() - except Exception: - pass - - def _ensure_runner(self) -> ResilientRunner: - if self._runner is None: - self._runner = ResilientRunner(capture_output=True, timeout_sec=60) - return self._runner - - def run_case_check_result( - self, - case: dict[str, Any], - op_name: str, - op_domain: str, - opset_version: int, - ep_name: str, - device: str, - timing_hook: Callable[[dict[str, float]], None] | None = None, - ) -> dict[str, Any]: - """Run a single case and return compile/run check results.""" - domain = ONNXDomain.from_str(op_domain) - schema = domain.get_op_schema(op_name, opset_version) - generator_cls = get_runtime_checker_op(op_name, domain=op_domain) - ep_checker = get_ep_checker(ep_name, device) - - generator: OpInputGenerator = generator_cls(schema) - runner = self._ensure_runner() - - t0 = time.perf_counter() - - kwargs = _build_kwargs_from_case(generator, case) - qdq_types = _normalize_qdq_types(case) - model_bytes_b64 = case.get("model_bytes_b64") - if not isinstance(model_bytes_b64, str): - raise TypeError("Case is missing model_bytes_b64 payload") - model_bytes = model_from_b64(model_bytes_b64) - - # Validate that the decoded bytes form a valid ONNX model before dispatching. - try: - onnx.load_from_string(model_bytes) - # some known issue with onnx.checker throwing like - # "Field 'shape' of 'type' is required but missing" - # onnx.checker.check_model(onnx_model) - except Exception as exc: - raise ValueError("Decoded model_bytes_b64 is not a valid ONNX model") from exc - - input_kwargs = {k: v for k, v in kwargs.items() if generator._is_input_key(k)} - ep_checker_inputs = generator.create_input_dict(input_kwargs, qdq_types=qdq_types) - t_build_done = time.perf_counter() - - compile_result = runner.run(ep_checker.check_compile, model_bytes, ep_checker_inputs) - _clean_result_payload(compile_result) - _raise_if_fatal(compile_result, "compile") - t_compile_done = time.perf_counter() - - run_result = runner.run(ep_checker.check_run, model_bytes, ep_checker_inputs) - _clean_result_payload(run_result) - _raise_if_fatal(run_result, "run") - t_run_done = time.perf_counter() - - if timing_hook is not None: - timing_hook( - { - "build_ms": (t_build_done - t0) * 1000, - "compile_ms": (t_compile_done - t_build_done) * 1000, - "run_ms": (t_run_done - t_compile_done) * 1000, - "total_ms": (t_run_done - t0) * 1000, - } - ) - - return { - "compile": compile_result, - "run": run_result, - } diff --git a/src/winml/modelkit/analyze/runtime_checker/result_processor.py b/src/winml/modelkit/analyze/runtime_checker/result_processor.py index 6f17fb85e..a39bd89c8 100644 --- a/src/winml/modelkit/analyze/runtime_checker/result_processor.py +++ b/src/winml/modelkit/analyze/runtime_checker/result_processor.py @@ -701,6 +701,10 @@ def get_opset_version_range(op_name: str, start_opset_version: int, op_domain: s print(f"{Fore.YELLOW}SKIPPED: File not found. {Style.RESET_ALL}") continue + if json_file.stat().st_size == 0: + print(f"{Fore.YELLOW}SKIPPED: Empty JSON file. {Style.RESET_ALL}") + continue + try: with open(json_file, encoding="utf-8") as f: # noqa: PTH123 data = json.load(f) diff --git a/src/winml/modelkit/analyze/runtime_checker/runner.py b/src/winml/modelkit/analyze/runtime_checker/runner.py index bc5a6a6c0..4edb3486a 100644 --- a/src/winml/modelkit/analyze/runtime_checker/runner.py +++ b/src/winml/modelkit/analyze/runtime_checker/runner.py @@ -6,6 +6,7 @@ import concurrent.futures as cf import multiprocessing as mp import sys +import time from collections.abc import Callable from pathlib import Path from typing import Any @@ -160,11 +161,99 @@ def __init__( self.timeout_sec = timeout_sec self.ctx = mp.get_context("spawn") # avoid fork-related instability self.executor = self._new_executor() + self._executor_needs_recreate = False + + _GRACEFUL_SHUTDOWN_TIMEOUT_SEC = 2.0 + _FORCED_KILL_JOIN_TIMEOUT_SEC = 0.2 def _new_executor(self) -> cf.ProcessPoolExecutor: """Create a new single-worker process pool executor.""" return cf.ProcessPoolExecutor(max_workers=1, mp_context=self.ctx) + @staticmethod + def _snapshot_processes(executor: cf.ProcessPoolExecutor) -> list[Any]: + """Best-effort snapshot of worker process handles from the executor.""" + try: + processes = getattr(executor, "_processes", None) + if not processes: + return [] + return [proc for proc in processes.values() if proc is not None] + except Exception: + return [] + + @staticmethod + def _is_process_alive(proc: Any) -> bool: + """Check process liveness without propagating process-state errors.""" + try: + return bool(proc.is_alive()) + except Exception: + return False + + @staticmethod + def _join_process(proc: Any, timeout: float | None = None) -> None: + """Best-effort process join that never raises.""" + try: + proc.join(timeout=timeout) + except Exception as e: + print(f"Warning: failed to join process during executor shutdown: {e}", file=sys.stderr) + + @staticmethod + def _kill_process(proc: Any) -> None: + """Best-effort process kill that never raises.""" + try: + proc.kill() + except Exception as e: + # Keep cleanup non-fatal, but surface the failure for diagnostics. + print(f"Warning: failed to kill worker process: {e}", file=sys.stderr) + + @staticmethod + def _close_process(proc: Any) -> None: + """Best-effort process close that never raises.""" + try: + proc.close() + except Exception as ex: + print(f"Warning: failed to close process: {ex}", file=sys.stderr) + + def _shutdown_executor_two_phase( + self, + *, + cancel_futures: bool, + graceful_timeout_sec: float | None = None, + ) -> None: + """Shutdown executor with graceful wait, then force-kill lingering workers.""" + executor = self.executor + lingering = self._snapshot_processes(executor) + + try: + executor.shutdown(wait=False, cancel_futures=cancel_futures) + except Exception as exc: + # Best-effort shutdown: keep cleanup flow non-raising, but surface failure. + print(f"Executor shutdown failed during cleanup: {exc}", file=sys.stderr) + + timeout = ( + self._GRACEFUL_SHUTDOWN_TIMEOUT_SEC + if graceful_timeout_sec is None + else max(0.0, graceful_timeout_sec) + ) + deadline = time.monotonic() + timeout + + for proc in lingering: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + self._join_process(proc, timeout=remaining) + + survivors = [proc for proc in lingering if self._is_process_alive(proc)] + + for proc in survivors: + self._kill_process(proc) + + for proc in survivors: + self._join_process(proc, timeout=self._FORCED_KILL_JOIN_TIMEOUT_SEC) + + for proc in lingering: + self._close_process(proc) + def run(self, fn: Callable[[Any, Any], Any] | None = None, *args: Any) -> dict[str, Any]: """Execute the function on a single input with automatic retry on failure. @@ -187,6 +276,11 @@ def run(self, fn: Callable[[Any, Any], Any] | None = None, *args: Any) -> dict[s "stdout": None, "stderr": None, } + + if self._executor_needs_recreate: + self.executor = self._new_executor() + self._executor_needs_recreate = False + attempts = 0 while True: attempts += 1 @@ -196,9 +290,16 @@ def run(self, fn: Callable[[Any, Any], Any] | None = None, *args: Any) -> dict[s try: return future.result(timeout=self.timeout_sec) except Exception as e: - self.executor.shutdown(wait=False, cancel_futures=True) - self.executor = self._new_executor() + try: + future.cancel() + except Exception: + # Best-effort cleanup: ignore cancel failures so retry flow can continue. + pass + + self._shutdown_executor_two_phase(cancel_futures=True) + if attempts >= self.max_retries: + self._executor_needs_recreate = True # TODO: capture stdout/stderr on timeout/crashed inputs return { "result": { @@ -208,11 +309,12 @@ def run(self, fn: Callable[[Any, Any], Any] | None = None, *args: Any) -> dict[s "stdout": None, "stderr": None, } + self.executor = self._new_executor() continue def shutdown(self) -> None: """Shut down the executor cleanly.""" - self.executor.shutdown(wait=True, cancel_futures=False) + self._shutdown_executor_two_phase(cancel_futures=False) def __enter__(self) -> "ResilientRunner": """Support context manager protocol.""" diff --git a/src/winml/modelkit/pattern/conv2d_inplace_linear_patterns.py b/src/winml/modelkit/pattern/conv2d_inplace_linear_patterns.py index a495f9e7c..49844c415 100644 --- a/src/winml/modelkit/pattern/conv2d_inplace_linear_patterns.py +++ b/src/winml/modelkit/pattern/conv2d_inplace_linear_patterns.py @@ -383,7 +383,7 @@ class Conv2DInplaceLinear4DPatternInputGenerator( """Input generator for 4D Conv2DInplaceLinear pattern.""" pattern = Conv2DInplaceLinear4DPattern() - registration_name = "Conv2DInplaceLinear4D" + registration_name = "Conv2DInplaceLinear4DPattern" def _get_a_shapes(self) -> list[tuple[int, ...]]: """4D NHWC shapes.""" @@ -398,7 +398,7 @@ class Conv2DInplaceLinear3DPatternInputGenerator( """Input generator for 3D Conv2DInplaceLinear pattern.""" pattern = Conv2DInplaceLinear3DPattern() - registration_name = "Conv2DInplaceLinear3D" + registration_name = "Conv2DInplaceLinear3DPattern" def _get_a_shapes(self) -> list[tuple[int, ...]]: """3D shapes (batch, seq, features).""" @@ -413,7 +413,7 @@ class Conv2DInplaceLinear2DPatternInputGenerator( """Input generator for 2D Conv2DInplaceLinear pattern.""" pattern = Conv2DInplaceLinear2DPattern() - registration_name = "Conv2DInplaceLinear2D" + registration_name = "Conv2DInplaceLinear2DPattern" def _get_a_shapes(self) -> list[tuple[int, ...]]: """2D shapes.""" diff --git a/tests/unit/analyze/runtime_checker/test_runner.py b/tests/unit/analyze/runtime_checker/test_runner.py new file mode 100644 index 000000000..14a884e9c --- /dev/null +++ b/tests/unit/analyze/runtime_checker/test_runner.py @@ -0,0 +1,121 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Unit tests for ResilientRunner process lifecycle and recovery behavior.""" + +from __future__ import annotations + +from typing import Any + +from winml.modelkit.analyze.runtime_checker.runner import ResilientRunner + + +class _FakeFuture: + def __init__(self, outcome: Any): + self._outcome = outcome + self.cancel_called = False + + def result(self, timeout: float | None = None) -> dict[str, Any]: + _ = timeout + if isinstance(self._outcome, Exception): + raise self._outcome + return self._outcome + + def cancel(self) -> bool: + self.cancel_called = True + return True + + +class _FakeProcess: + def __init__(self) -> None: + self.alive = True + self.killed = False + self.closed = False + self.join_calls: list[float | None] = [] + + def is_alive(self) -> bool: + return self.alive + + def join(self, timeout: float | None = None) -> None: + self.join_calls.append(timeout) + + def kill(self) -> None: + self.killed = True + self.alive = False + + def close(self) -> None: + self.closed = True + + +class _FakeExecutor: + def __init__(self, futures: list[_FakeFuture], processes: list[_FakeProcess] | None = None): + self._futures = futures + self._processes = dict(enumerate(processes or [])) + self.shutdown_calls: list[tuple[bool, bool]] = [] + self.submit_calls = 0 + + def submit(self, fn, *args): + _ = fn + _ = args + self.submit_calls += 1 + return self._futures.pop(0) + + def shutdown(self, *, wait: bool, cancel_futures: bool) -> None: + self.shutdown_calls.append((wait, cancel_futures)) + + +class TestResilientRunner: + def test_run_recreates_executor_lazily_after_terminal_failure(self, monkeypatch): + failure_future = _FakeFuture(TimeoutError("timed out")) + success_payload = { + "result": {"success": True, "reason": None}, + "stdout": None, + "stderr": None, + } + success_future = _FakeFuture(success_payload) + + first_executor = _FakeExecutor([failure_future]) + second_executor = _FakeExecutor([success_future]) + created_executors = [first_executor, second_executor] + create_calls: list[int] = [] + + def _fake_new_executor(self) -> _FakeExecutor: + _ = self + create_calls.append(1) + return created_executors[len(create_calls) - 1] + + monkeypatch.setattr(ResilientRunner, "_new_executor", _fake_new_executor) + + runner = ResilientRunner(max_retries=1, timeout_sec=0.001) + + first_result = runner.run(lambda: None) + assert first_result["result"]["success"] is False + assert runner._executor_needs_recreate is True + assert runner.executor is first_executor + assert len(create_calls) == 1 + + second_result = runner.run(lambda: None) + assert second_result == success_payload + assert runner._executor_needs_recreate is False + assert runner.executor is second_executor + assert len(create_calls) == 2 + + def test_shutdown_executor_two_phase_kills_survivors(self, monkeypatch): + worker = _FakeProcess() + executor = _FakeExecutor([], processes=[worker]) + + def _fake_new_executor(self) -> _FakeExecutor: + _ = self + return executor + + monkeypatch.setattr(ResilientRunner, "_new_executor", _fake_new_executor) + + runner = ResilientRunner() + runner._shutdown_executor_two_phase(cancel_futures=True, graceful_timeout_sec=0.0) + + assert executor.shutdown_calls == [(False, True)] + assert worker.killed is True + assert worker.closed is True + assert worker.join_calls == [runner._FORCED_KILL_JOIN_TIMEOUT_SEC]