Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions src/winml/modelkit/analyze/core/runtime_checker_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,6 +1026,8 @@ def __init__(
self._failed_nodes_logged: set[Any] = set()
# Cache of nodes that have been run locally for quick lookup
self._local_run_nodes: dict[Any, RuntimeTestResult] = {}
# Cache of node results keyed by hashable table_filter_conditions
self._node_result_cache: dict[Any, PatternRuntime] = {}

# Register per-domain lazy rule objects — no file I/O occurs here.
registered_patterns = set(get_registered_pattern_input_generators())
Expand Down Expand Up @@ -1842,6 +1844,17 @@ def get_pattern_id(is_qdq):
f"op {node.op_type} (domain: {op_domain})",
)

# Check cache using table_filter_conditions as key
cache_key = make_hashable(table_filter_conditions)
if cache_key in self._node_result_cache:
cached = self._node_result_cache[cache_key]
return PatternRuntime(
pattern_id=pattern_id,
result=cached.result,
alternatives=self.alternatives,
pattern_match=pattern_match,
)

# Phase 3: Check if op exists in target rules
if op_domain not in target_neg_rules or node.op_type not in target_neg_rules[op_domain]:
domain_rules = target_neg_rules.get(op_domain, {})
Expand All @@ -1866,6 +1879,7 @@ def get_pattern_id(is_qdq):
conditions=None,
)
if local_result is not None:
self._node_result_cache[cache_key] = local_result
return local_result

result = RuntimeTestResult(
Expand All @@ -1882,12 +1896,14 @@ def get_pattern_id(is_qdq):
node_tags=node_tags,
)

return PatternRuntime(
pattern_runtime = PatternRuntime(
pattern_id=pattern_id,
result=result,
alternatives=self.alternatives,
pattern_match=pattern_match,
)
self._node_result_cache[cache_key] = pattern_runtime
return pattern_runtime

# Phase 4: Apply negative rules and table matching
op_neg_rules = target_neg_rules[op_domain][node.op_type]
Expand Down Expand Up @@ -1978,6 +1994,7 @@ def get_pattern_id(is_qdq):
conditions=make_hashable(table_filter_conditions),
)
if local_result is not None:
self._node_result_cache[cache_key] = local_result
return local_result

result = RuntimeTestResult(
Expand All @@ -1990,12 +2007,14 @@ def get_pattern_id(is_qdq):
debug_details=debug_details,
)

return PatternRuntime(
pattern_runtime = PatternRuntime(
pattern_id=pattern_id,
result=result,
alternatives=self.alternatives,
pattern_match=pattern_match,
)
self._node_result_cache[cache_key] = pattern_runtime
return pattern_runtime
else: # no table data
if run_unknown_op:
fallback_reason = self._get_domain_fallback_reason(
Expand All @@ -2011,6 +2030,7 @@ def get_pattern_id(is_qdq):
conditions=None,
)
if local_result is not None:
self._node_result_cache[cache_key] = local_result
return local_result

table_source = "qdq" if is_qdq else "non_qdq"
Expand Down Expand Up @@ -2061,12 +2081,14 @@ def get_pattern_id(is_qdq):
},
)

return PatternRuntime(
pattern_runtime = PatternRuntime(
pattern_id=pattern_id,
result=result,
alternatives=self.alternatives,
pattern_match=pattern_match,
)
self._node_result_cache[cache_key] = pattern_runtime
return pattern_runtime
except (OpOptionalInputSupportError, OpLackOfRequiredInformationError) as e:
exception_type = type(e).__name__
logger.error(
Expand Down Expand Up @@ -2094,12 +2116,14 @@ def get_pattern_id(is_qdq):
},
)

return PatternRuntime(
pattern_runtime = PatternRuntime(
pattern_id=pattern_id,
result=result,
alternatives=self.alternatives,
pattern_match=pattern_match,
)
self._node_result_cache[cache_key] = pattern_runtime
return pattern_runtime

result = RuntimeTestResult(
compile=compile_result,
Expand All @@ -2118,12 +2142,14 @@ def get_pattern_id(is_qdq):
save_node_types=save_node_types,
)

return PatternRuntime(
pattern_runtime = PatternRuntime(
pattern_id=pattern_id,
result=result,
alternatives=self.alternatives,
pattern_match=pattern_match,
)
self._node_result_cache[cache_key] = pattern_runtime
return pattern_runtime

def run_for_subgraph(
self,
Expand Down
Loading