Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 243 additions & 6 deletions src/winml/modelkit/analyze/core/runtime_checker_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,19 +1141,23 @@ def __init__(
dynamic axis indices.
"""
self.dynamic_axis_strict_mode = dynamic_axis_strict_mode
self.model_proto: onnx.ModelProto = model_proto
self.model_path = Path(model_path) if model_path is not None else None
# Try shape inference: standard ONNX first, then symbolic (onnxruntime)
try:
# Standard ONNX shape inference — uses temp file for models
# with external data (avoids silent empty-graph result).
self.model_proto = infer_onnx_shapes(model_proto)
inferred_model = infer_onnx_shapes(model_proto)
self.model_proto = inferred_model if inferred_model is not None else model_proto

# Then try to enhance with symbolic shape inference
# if available which supports Microsoft domain
try:
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

self.model_proto = SymbolicShapeInference.infer_shapes(self.model_proto)
symbolic_model = SymbolicShapeInference.infer_shapes(self.model_proto)
if symbolic_model is not None:
self.model_proto = symbolic_model
except Exception as e:
# If symbolic shape inference fails, continue with standard inference result
logger.debug(
Expand Down Expand Up @@ -1357,6 +1361,196 @@ def _get_ep_checker(self) -> EPChecker:
)
return self._ep_checker

@staticmethod
def _clone_node_proto(node: onnx.NodeProto) -> onnx.NodeProto:
"""Clone a node proto so extracted test models do not reuse graph objects."""
cloned = onnx.NodeProto()
cloned.CopyFrom(node)
return cloned

def _find_producer_node(self, tensor_name: str) -> onnx.NodeProto | None:
"""Return the node that produces a tensor, if any."""
if not tensor_name:
return None

for candidate in self.model_proto.graph.node:
if tensor_name in candidate.output:
return candidate
return None

def _find_consumer_nodes(self, tensor_name: str) -> list[onnx.NodeProto]:
"""Return nodes that consume a tensor."""
if not tensor_name:
return []

return [
candidate for candidate in self.model_proto.graph.node if tensor_name in candidate.input
]

def _build_opset_imports(
self,
nodes: list[onnx.NodeProto],
fallback_op_domain: ONNXDomain,
fallback_opset_version: int,
) -> list[onnx.OperatorSetIdProto]:
"""Build opset imports for an extracted runtime-test model."""
opset_imports: list[onnx.OperatorSetIdProto] = []
added_domains: set[str] = set()
saw_non_default_domain = False

def add_domain(domain_str: str, version: int) -> None:
canonical_domain = "" if domain_str in {"", ONNXDomain.AI_ONNX.value} else domain_str
if canonical_domain in added_domains:
return

added_domains.add(canonical_domain)
effective_version = max(version, 7) if canonical_domain == "" else version
opset_imports.append(onnx.helper.make_opsetid(canonical_domain, effective_version))

for included_node in nodes:
raw_domain = included_node.domain or ""
try:
node_domain = ONNXDomain.from_str(raw_domain)
add_domain(node_domain.schema_domain, self.opset_versions.get(node_domain, 1))
saw_non_default_domain = saw_non_default_domain or node_domain != ONNXDomain.AI_ONNX
except ValueError:
add_domain(raw_domain, 1)
saw_non_default_domain = saw_non_default_domain or bool(raw_domain)

if not opset_imports:
add_domain(fallback_op_domain.schema_domain, fallback_opset_version)
saw_non_default_domain = fallback_op_domain != ONNXDomain.AI_ONNX

if saw_non_default_domain and "" not in added_domains:
default_opset = self.opset_versions.get(ONNXDomain.AI_ONNX, 17)
add_domain("", default_opset)

return opset_imports

def _build_runtime_test_model(
self,
node: onnx.NodeProto,
op_domain: ONNXDomain,
opset_version: int,
include_adjacent_qdq: bool = False,
) -> onnx.ModelProto:
"""Build the model used for local EP fallback and failed-node artifacts.

For QDQ operators, include the adjacent DequantizeLinear and QuantizeLinear
nodes so the local test model preserves the same quantized context.
"""
if not include_adjacent_qdq:
return self._build_single_node_model(node, op_domain, opset_version)

graph_inputs: list[onnx.ValueInfoProto] = []
graph_initializers: list[onnx.TensorProto] = []
graph_outputs: list[onnx.ValueInfoProto] = []
pre_nodes: list[onnx.NodeProto] = []
post_nodes: list[onnx.NodeProto] = []
seen_inputs: set[str] = set()
seen_initializers: set[str] = set()
seen_outputs: set[str] = set()
seen_pre_nodes: set[str] = set()
seen_post_nodes: set[str] = set()

def add_graph_source(name: str) -> None:
if not name:
return

if name in self.initializers:
if name not in seen_initializers:
graph_initializers.append(self.initializers[name])
seen_initializers.add(name)
return

if name in self.constants:
if name not in seen_initializers:
graph_initializers.append(self.constants[name])
seen_initializers.add(name)
return

vi = self.valueinfo.get(name)
if vi is None:
raise ValueError(f"Tensor '{name}' not found in valueinfo or initializers")
if name not in seen_inputs:
graph_inputs.append(vi)
seen_inputs.add(name)

def add_graph_output(name: str) -> None:
if not name or name in seen_outputs:
return

vi = self.valueinfo.get(name)
if vi is not None:
graph_outputs.append(vi)
else:
graph_outputs.append(
onnx.helper.make_tensor_value_info(name, onnx.TensorProto.UNDEFINED, None)
)
seen_outputs.add(name)

for inp_name in node.input:
if not inp_name:
continue

producer = self._find_producer_node(inp_name)
if producer is not None and producer.op_type == "DequantizeLinear":
producer_key = producer.name or "|".join(producer.output)
if producer_key not in seen_pre_nodes:
pre_nodes.append(self._clone_node_proto(producer))
seen_pre_nodes.add(producer_key)
for producer_input in producer.input:
add_graph_source(producer_input)
continue

add_graph_source(inp_name)

for out_name in node.output:
if not out_name:
continue

quantize_consumers = [
consumer
for consumer in self._find_consumer_nodes(out_name)
if consumer.op_type == "QuantizeLinear"
and consumer.input
and consumer.input[0] == out_name
]
if quantize_consumers:
for consumer in quantize_consumers:
consumer_key = consumer.name or "|".join(consumer.output)
if consumer_key not in seen_post_nodes:
post_nodes.append(self._clone_node_proto(consumer))
seen_post_nodes.add(consumer_key)
for consumer_input in consumer.input[1:]:
add_graph_source(consumer_input)
for consumer_output in consumer.output:
add_graph_output(consumer_output)
continue

add_graph_output(out_name)

nodes = [*pre_nodes, self._clone_node_proto(node), *post_nodes]
graph = onnx.helper.make_graph(
nodes,
f"runtime_test_{node.op_type}",
graph_inputs,
graph_outputs,
initializer=graph_initializers,
)

model = onnx.helper.make_model(
graph,
opset_imports=self._build_opset_imports(nodes, op_domain, opset_version),
)

try:
model = infer_onnx_shapes(model)
except Exception as e:
logger.debug("Shape inference failed for runtime-test model: %s", e)

return model

def _build_single_node_model(
self, node: onnx.NodeProto, op_domain: ONNXDomain, opset_version: int
) -> onnx.ModelProto:
Expand Down Expand Up @@ -1454,6 +1648,33 @@ def _build_single_node_model(

return model

def _generate_model_inputs(self, model: onnx.ModelProto) -> dict[str, np.ndarray]:
"""Generate dummy input data for a runtime-test model."""
input_feed: dict[str, np.ndarray] = {}
default_dim_size = 2 # Replace dynamic/unknown dims with this size
initializer_names = {initializer.name for initializer in model.graph.initializer}

for graph_input in model.graph.input:
if graph_input.name in initializer_names:
continue

shape, dtype_str = shape_and_dtype_from_valueinfo(graph_input)
if dtype_str is None:
raise ValueError(f"Input '{graph_input.name}' has no dtype information")

np_dtype = SupportedONNXType.from_annotation(dtype_str).np_type

if shape is None:
concrete_shape = (default_dim_size,)
else:
concrete_shape = tuple(
dim if isinstance(dim, int) and dim > 0 else default_dim_size for dim in shape
)

input_feed[graph_input.name] = np.zeros(concrete_shape, dtype=np_dtype)

return input_feed

def _generate_node_inputs(self, node: onnx.NodeProto) -> dict[str, np.ndarray]:
"""Generate dummy input data for a single-node model.

Expand Down Expand Up @@ -1531,6 +1752,7 @@ def _try_local_ep_check(
pattern_match: PatternMatchResult,
node_tags: list[NodeTag],
fallback_reason: str,
include_adjacent_qdq: bool = False,
save_node_types: set[str] | None = None,
conditions: Any | None = None,
) -> PatternRuntime | None:
Expand Down Expand Up @@ -1572,11 +1794,16 @@ def _try_local_ep_check(
)

try:
model = self._build_single_node_model(node, op_domain, opset_version)
input_feed = self._generate_node_inputs(node)
model = self._build_runtime_test_model(
node,
op_domain,
opset_version,
include_adjacent_qdq=include_adjacent_qdq,
)
input_feed = self._generate_model_inputs(model)
except Exception as e:
logger.debug(
"Failed to build single-node model for local EP check on %s (%s): %s",
"Failed to build runtime-test model for local EP check on %s (%s): %s",
node.name,
node.op_type,
e,
Expand Down Expand Up @@ -1808,6 +2035,7 @@ def _maybe_save_failed_node_result(
opset_version: int,
result: RuntimeTestResult,
cache_key: Any,
include_adjacent_qdq: bool = False,
save_node_types: set[str] | None = None,
) -> None:
"""Save unsupported or partial node models without re-running result computation."""
Expand All @@ -1820,7 +2048,12 @@ def _maybe_save_failed_node_result(
if not (is_unsupported or is_partial):
return

node_model = self._build_single_node_model(node, op_domain, opset_version)
node_model = self._build_runtime_test_model(
node,
op_domain,
opset_version,
include_adjacent_qdq=include_adjacent_qdq,
)
self._save_failed_node(
node,
node_model,
Expand Down Expand Up @@ -2079,6 +2312,7 @@ def get_pattern_id(is_qdq):
pattern_match,
node_tags,
fallback_reason,
include_adjacent_qdq=is_qdq,
save_node_types=save_node_types,
# conditions not available when domain/op
# rules are missing
Expand Down Expand Up @@ -2187,6 +2421,7 @@ def get_pattern_id(is_qdq):
pattern_match,
node_tags,
fallback_reason,
include_adjacent_qdq=is_qdq,
conditions=cache_key,
)
if local_result is not None:
Expand Down Expand Up @@ -2223,6 +2458,7 @@ def get_pattern_id(is_qdq):
pattern_match,
node_tags,
fallback_reason,
include_adjacent_qdq=is_qdq,
conditions=None,
)
if local_result is not None:
Expand Down Expand Up @@ -2335,6 +2571,7 @@ def get_pattern_id(is_qdq):
opset_version,
result,
cache_key,
include_adjacent_qdq=is_qdq,
save_node_types=save_node_types,
)

Expand Down
Loading
Loading