From 07f01c05309a4d80c20b9c718f48f8c5166dbb09 Mon Sep 17 00:00:00 2001 From: Chao Zhang Date: Wed, 15 Apr 2026 17:26:13 +0800 Subject: [PATCH] add node cache --- .../analyze/core/runtime_checker_query.py | 36 ++++++++++++++++--- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/src/winml/modelkit/analyze/core/runtime_checker_query.py b/src/winml/modelkit/analyze/core/runtime_checker_query.py index fcdea498e..07d4e13eb 100644 --- a/src/winml/modelkit/analyze/core/runtime_checker_query.py +++ b/src/winml/modelkit/analyze/core/runtime_checker_query.py @@ -1026,6 +1026,8 @@ def __init__( self._failed_nodes_logged: set[Any] = set() # Cache of nodes that have been run locally for quick lookup self._local_run_nodes: dict[Any, RuntimeTestResult] = {} + # Cache of node results keyed by hashable table_filter_conditions + self._node_result_cache: dict[Any, PatternRuntime] = {} # Register per-domain lazy rule objects — no file I/O occurs here. registered_patterns = set(get_registered_pattern_input_generators()) @@ -1842,6 +1844,17 @@ def get_pattern_id(is_qdq): f"op {node.op_type} (domain: {op_domain})", ) + # Check cache using table_filter_conditions as key + cache_key = make_hashable(table_filter_conditions) + if cache_key in self._node_result_cache: + cached = self._node_result_cache[cache_key] + return PatternRuntime( + pattern_id=pattern_id, + result=cached.result, + alternatives=self.alternatives, + pattern_match=pattern_match, + ) + # Phase 3: Check if op exists in target rules if op_domain not in target_neg_rules or node.op_type not in target_neg_rules[op_domain]: domain_rules = target_neg_rules.get(op_domain, {}) @@ -1866,6 +1879,7 @@ def get_pattern_id(is_qdq): conditions=None, ) if local_result is not None: + self._node_result_cache[cache_key] = local_result return local_result result = RuntimeTestResult( @@ -1882,12 +1896,14 @@ def get_pattern_id(is_qdq): node_tags=node_tags, ) - return PatternRuntime( + pattern_runtime = PatternRuntime( pattern_id=pattern_id, result=result, alternatives=self.alternatives, pattern_match=pattern_match, ) + self._node_result_cache[cache_key] = pattern_runtime + return pattern_runtime # Phase 4: Apply negative rules and table matching op_neg_rules = target_neg_rules[op_domain][node.op_type] @@ -1978,6 +1994,7 @@ def get_pattern_id(is_qdq): conditions=make_hashable(table_filter_conditions), ) if local_result is not None: + self._node_result_cache[cache_key] = local_result return local_result result = RuntimeTestResult( @@ -1990,12 +2007,14 @@ def get_pattern_id(is_qdq): debug_details=debug_details, ) - return PatternRuntime( + pattern_runtime = PatternRuntime( pattern_id=pattern_id, result=result, alternatives=self.alternatives, pattern_match=pattern_match, ) + self._node_result_cache[cache_key] = pattern_runtime + return pattern_runtime else: # no table data if run_unknown_op: fallback_reason = self._get_domain_fallback_reason( @@ -2011,6 +2030,7 @@ def get_pattern_id(is_qdq): conditions=None, ) if local_result is not None: + self._node_result_cache[cache_key] = local_result return local_result table_source = "qdq" if is_qdq else "non_qdq" @@ -2061,12 +2081,14 @@ def get_pattern_id(is_qdq): }, ) - return PatternRuntime( + pattern_runtime = PatternRuntime( pattern_id=pattern_id, result=result, alternatives=self.alternatives, pattern_match=pattern_match, ) + self._node_result_cache[cache_key] = pattern_runtime + return pattern_runtime except (OpOptionalInputSupportError, OpLackOfRequiredInformationError) as e: exception_type = type(e).__name__ logger.error( @@ -2094,12 +2116,14 @@ def get_pattern_id(is_qdq): }, ) - return PatternRuntime( + pattern_runtime = PatternRuntime( pattern_id=pattern_id, result=result, alternatives=self.alternatives, pattern_match=pattern_match, ) + self._node_result_cache[cache_key] = pattern_runtime + return pattern_runtime result = RuntimeTestResult( compile=compile_result, @@ -2118,12 +2142,14 @@ def get_pattern_id(is_qdq): save_node_types=save_node_types, ) - return PatternRuntime( + pattern_runtime = PatternRuntime( pattern_id=pattern_id, result=result, alternatives=self.alternatives, pattern_match=pattern_match, ) + self._node_result_cache[cache_key] = pattern_runtime + return pattern_runtime def run_for_subgraph( self,