Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 13 additions & 27 deletions src/winml/modelkit/analyze/core/runtime_checker_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,9 +776,7 @@ def update_conditions_(
if inp_name in initializers:
init = initializers[inp_name]
arr = numpy_helper.to_array(init)
update_conditions_(
conditions, input_name, is_variadic, True, arr.shape, make_hashable(arr)
)
update_conditions_(conditions, input_name, is_variadic, True, arr.shape, arr)
conditions[f"{input_name}_is_none"] = False

# Add type_vars info for initializers
Expand All @@ -794,9 +792,7 @@ def update_conditions_(
# Handle Constant node inputs
const_tensor = constants[inp_name]
arr = numpy_helper.to_array(const_tensor)
update_conditions_(
conditions, input_name, is_variadic, True, arr.shape, make_hashable(arr)
)
update_conditions_(conditions, input_name, is_variadic, True, arr.shape, arr)
conditions[f"{input_name}_is_none"] = False

# Add type_vars info for constants
Expand Down Expand Up @@ -923,9 +919,7 @@ def _compute_dynamic_axes(shape: tuple | None, is_constant: bool) -> tuple[int,
conditions[f"{input_name}_is_fixed_shape"] = len(dyn_axes) == 0
conditions[f"{input_name}_dynamic_axes"] = dyn_axes
conditions[f"{input_name}_shape"] = info.shape
conditions[f"{input_name}_value"] = (
make_hashable(info.value) if info.value is not None else None
)
conditions[f"{input_name}_value"] = info.value
conditions[f"{input_name}_is_none"] = False

# Attributes (with attr_ prefix)
Expand Down Expand Up @@ -1605,7 +1599,7 @@ def _maybe_save_failed_node_result(
op_domain: ONNXDomain,
opset_version: int,
result: RuntimeTestResult,
table_filter_conditions: dict[str, Any],
cache_key: Any,
save_node_types: set[str] | None = None,
) -> None:
"""Save unsupported or partial node models without re-running result computation."""
Expand All @@ -1622,7 +1616,7 @@ def _maybe_save_failed_node_result(
self._save_failed_node(
node,
node_model,
make_hashable(table_filter_conditions),
cache_key,
name_suffix="unsupported" if is_unsupported else "partial",
)

Expand Down Expand Up @@ -1839,14 +1833,16 @@ def get_pattern_id(is_qdq):
op_columns = domain_tables.get_columns(node.op_type)
if op_columns is not None:
table_filter_conditions = _build_table_filter_conditions(
conditions,
table_filter_conditions,
op_columns,
infinite_properties,
f"op {node.op_type} (domain: {op_domain})",
)

# Check cache using table_filter_conditions as key
cache_key = make_hashable(table_filter_conditions)
# Values are already hashable from get_query_conditions_for_node;
# just convert the dict to a tuple of sorted items.
cache_key = tuple(sorted(table_filter_conditions.items()))
if cache_key in self._node_result_cache:
cached = self._node_result_cache[cache_key]
return PatternRuntime(
Expand Down Expand Up @@ -1912,10 +1908,10 @@ def get_pattern_id(is_qdq):

try:
compile_result, compile_reason = self._check_negative_rules(
op_neg_rules, conditions, node, "compile"
op_neg_rules, table_filter_conditions, node, "compile"
)
run_result, run_reason = self._check_negative_rules(
op_neg_rules, conditions, node, "run"
op_neg_rules, table_filter_conditions, node, "run"
)
reason = compile_reason + run_reason

Expand All @@ -1932,16 +1928,6 @@ def get_pattern_id(is_qdq):
)
table_file = str(getattr(domain_tables, "_file_name", ""))
table_df = domain_tables[node.op_type]
if op_columns is None:
op_columns = (
domain_tables.get_columns(node.op_type) or table_df.columns.to_list()
)
table_filter_conditions = _build_table_filter_conditions(
conditions,
op_columns,
infinite_properties,
f"op {node.op_type} (domain: {op_domain})",
)

ret = query_table_exact_match(table_df, table_filter_conditions)
if not ret.empty:
Expand Down Expand Up @@ -1992,7 +1978,7 @@ def get_pattern_id(is_qdq):
pattern_match,
node_tags,
fallback_reason,
conditions=make_hashable(table_filter_conditions),
conditions=cache_key,
)
if local_result is not None:
self._node_result_cache[cache_key] = local_result
Expand Down Expand Up @@ -2139,7 +2125,7 @@ def get_pattern_id(is_qdq):
op_domain,
opset_version,
result,
table_filter_conditions,
cache_key,
save_node_types=save_node_types,
)

Expand Down
3 changes: 3 additions & 0 deletions src/winml/modelkit/pattern/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def make_hashable(value: Any, replace_float_with_dummy: bool = True) -> Any:
- Floats -> DUMMY_FLOAT
- Lists/Tuples -> Tuple of processed elements
- Dicts -> Tuple of sorted (key, processed_value) items
- ndarrays -> Tuple of processed elements (converted via tolist())
- Others -> Original value
"""
# Fast path: type identity checks avoid isinstance MRO traversal
Expand All @@ -248,6 +249,8 @@ def make_hashable(value: Any, replace_float_with_dummy: bool = True) -> Any:
return tuple(
sorted([(k, make_hashable(v, replace_float_with_dummy)) for k, v in value.items()])
)
if isinstance(value, np.ndarray):
return make_hashable(value.tolist(), replace_float_with_dummy)
if isinstance(value, np.floating):
return DUMMY_FLOAT if replace_float_with_dummy else float(value)
return value
Expand Down
Loading