diff --git a/src/winml/modelkit/analyze/core/runtime_checker_query.py b/src/winml/modelkit/analyze/core/runtime_checker_query.py index 954f24ffb..4f85b0e7f 100644 --- a/src/winml/modelkit/analyze/core/runtime_checker_query.py +++ b/src/winml/modelkit/analyze/core/runtime_checker_query.py @@ -1141,19 +1141,23 @@ def __init__( dynamic axis indices. """ self.dynamic_axis_strict_mode = dynamic_axis_strict_mode + self.model_proto: onnx.ModelProto = model_proto self.model_path = Path(model_path) if model_path is not None else None # Try shape inference: standard ONNX first, then symbolic (onnxruntime) try: # Standard ONNX shape inference — uses temp file for models # with external data (avoids silent empty-graph result). - self.model_proto = infer_onnx_shapes(model_proto) + inferred_model = infer_onnx_shapes(model_proto) + self.model_proto = inferred_model if inferred_model is not None else model_proto # Then try to enhance with symbolic shape inference # if available which supports Microsoft domain try: from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - self.model_proto = SymbolicShapeInference.infer_shapes(self.model_proto) + symbolic_model = SymbolicShapeInference.infer_shapes(self.model_proto) + if symbolic_model is not None: + self.model_proto = symbolic_model except Exception as e: # If symbolic shape inference fails, continue with standard inference result logger.debug( @@ -1357,6 +1361,196 @@ def _get_ep_checker(self) -> EPChecker: ) return self._ep_checker + @staticmethod + def _clone_node_proto(node: onnx.NodeProto) -> onnx.NodeProto: + """Clone a node proto so extracted test models do not reuse graph objects.""" + cloned = onnx.NodeProto() + cloned.CopyFrom(node) + return cloned + + def _find_producer_node(self, tensor_name: str) -> onnx.NodeProto | None: + """Return the node that produces a tensor, if any.""" + if not tensor_name: + return None + + for candidate in self.model_proto.graph.node: + if tensor_name in candidate.output: + return candidate + return None + + def _find_consumer_nodes(self, tensor_name: str) -> list[onnx.NodeProto]: + """Return nodes that consume a tensor.""" + if not tensor_name: + return [] + + return [ + candidate for candidate in self.model_proto.graph.node if tensor_name in candidate.input + ] + + def _build_opset_imports( + self, + nodes: list[onnx.NodeProto], + fallback_op_domain: ONNXDomain, + fallback_opset_version: int, + ) -> list[onnx.OperatorSetIdProto]: + """Build opset imports for an extracted runtime-test model.""" + opset_imports: list[onnx.OperatorSetIdProto] = [] + added_domains: set[str] = set() + saw_non_default_domain = False + + def add_domain(domain_str: str, version: int) -> None: + canonical_domain = "" if domain_str in {"", ONNXDomain.AI_ONNX.value} else domain_str + if canonical_domain in added_domains: + return + + added_domains.add(canonical_domain) + effective_version = max(version, 7) if canonical_domain == "" else version + opset_imports.append(onnx.helper.make_opsetid(canonical_domain, effective_version)) + + for included_node in nodes: + raw_domain = included_node.domain or "" + try: + node_domain = ONNXDomain.from_str(raw_domain) + add_domain(node_domain.schema_domain, self.opset_versions.get(node_domain, 1)) + saw_non_default_domain = saw_non_default_domain or node_domain != ONNXDomain.AI_ONNX + except ValueError: + add_domain(raw_domain, 1) + saw_non_default_domain = saw_non_default_domain or bool(raw_domain) + + if not opset_imports: + add_domain(fallback_op_domain.schema_domain, fallback_opset_version) + saw_non_default_domain = fallback_op_domain != ONNXDomain.AI_ONNX + + if saw_non_default_domain and "" not in added_domains: + default_opset = self.opset_versions.get(ONNXDomain.AI_ONNX, 17) + add_domain("", default_opset) + + return opset_imports + + def _build_runtime_test_model( + self, + node: onnx.NodeProto, + op_domain: ONNXDomain, + opset_version: int, + include_adjacent_qdq: bool = False, + ) -> onnx.ModelProto: + """Build the model used for local EP fallback and failed-node artifacts. + + For QDQ operators, include the adjacent DequantizeLinear and QuantizeLinear + nodes so the local test model preserves the same quantized context. + """ + if not include_adjacent_qdq: + return self._build_single_node_model(node, op_domain, opset_version) + + graph_inputs: list[onnx.ValueInfoProto] = [] + graph_initializers: list[onnx.TensorProto] = [] + graph_outputs: list[onnx.ValueInfoProto] = [] + pre_nodes: list[onnx.NodeProto] = [] + post_nodes: list[onnx.NodeProto] = [] + seen_inputs: set[str] = set() + seen_initializers: set[str] = set() + seen_outputs: set[str] = set() + seen_pre_nodes: set[str] = set() + seen_post_nodes: set[str] = set() + + def add_graph_source(name: str) -> None: + if not name: + return + + if name in self.initializers: + if name not in seen_initializers: + graph_initializers.append(self.initializers[name]) + seen_initializers.add(name) + return + + if name in self.constants: + if name not in seen_initializers: + graph_initializers.append(self.constants[name]) + seen_initializers.add(name) + return + + vi = self.valueinfo.get(name) + if vi is None: + raise ValueError(f"Tensor '{name}' not found in valueinfo or initializers") + if name not in seen_inputs: + graph_inputs.append(vi) + seen_inputs.add(name) + + def add_graph_output(name: str) -> None: + if not name or name in seen_outputs: + return + + vi = self.valueinfo.get(name) + if vi is not None: + graph_outputs.append(vi) + else: + graph_outputs.append( + onnx.helper.make_tensor_value_info(name, onnx.TensorProto.UNDEFINED, None) + ) + seen_outputs.add(name) + + for inp_name in node.input: + if not inp_name: + continue + + producer = self._find_producer_node(inp_name) + if producer is not None and producer.op_type == "DequantizeLinear": + producer_key = producer.name or "|".join(producer.output) + if producer_key not in seen_pre_nodes: + pre_nodes.append(self._clone_node_proto(producer)) + seen_pre_nodes.add(producer_key) + for producer_input in producer.input: + add_graph_source(producer_input) + continue + + add_graph_source(inp_name) + + for out_name in node.output: + if not out_name: + continue + + quantize_consumers = [ + consumer + for consumer in self._find_consumer_nodes(out_name) + if consumer.op_type == "QuantizeLinear" + and consumer.input + and consumer.input[0] == out_name + ] + if quantize_consumers: + for consumer in quantize_consumers: + consumer_key = consumer.name or "|".join(consumer.output) + if consumer_key not in seen_post_nodes: + post_nodes.append(self._clone_node_proto(consumer)) + seen_post_nodes.add(consumer_key) + for consumer_input in consumer.input[1:]: + add_graph_source(consumer_input) + for consumer_output in consumer.output: + add_graph_output(consumer_output) + continue + + add_graph_output(out_name) + + nodes = [*pre_nodes, self._clone_node_proto(node), *post_nodes] + graph = onnx.helper.make_graph( + nodes, + f"runtime_test_{node.op_type}", + graph_inputs, + graph_outputs, + initializer=graph_initializers, + ) + + model = onnx.helper.make_model( + graph, + opset_imports=self._build_opset_imports(nodes, op_domain, opset_version), + ) + + try: + model = infer_onnx_shapes(model) + except Exception as e: + logger.debug("Shape inference failed for runtime-test model: %s", e) + + return model + def _build_single_node_model( self, node: onnx.NodeProto, op_domain: ONNXDomain, opset_version: int ) -> onnx.ModelProto: @@ -1454,6 +1648,33 @@ def _build_single_node_model( return model + def _generate_model_inputs(self, model: onnx.ModelProto) -> dict[str, np.ndarray]: + """Generate dummy input data for a runtime-test model.""" + input_feed: dict[str, np.ndarray] = {} + default_dim_size = 2 # Replace dynamic/unknown dims with this size + initializer_names = {initializer.name for initializer in model.graph.initializer} + + for graph_input in model.graph.input: + if graph_input.name in initializer_names: + continue + + shape, dtype_str = shape_and_dtype_from_valueinfo(graph_input) + if dtype_str is None: + raise ValueError(f"Input '{graph_input.name}' has no dtype information") + + np_dtype = SupportedONNXType.from_annotation(dtype_str).np_type + + if shape is None: + concrete_shape = (default_dim_size,) + else: + concrete_shape = tuple( + dim if isinstance(dim, int) and dim > 0 else default_dim_size for dim in shape + ) + + input_feed[graph_input.name] = np.zeros(concrete_shape, dtype=np_dtype) + + return input_feed + def _generate_node_inputs(self, node: onnx.NodeProto) -> dict[str, np.ndarray]: """Generate dummy input data for a single-node model. @@ -1531,6 +1752,7 @@ def _try_local_ep_check( pattern_match: PatternMatchResult, node_tags: list[NodeTag], fallback_reason: str, + include_adjacent_qdq: bool = False, save_node_types: set[str] | None = None, conditions: Any | None = None, ) -> PatternRuntime | None: @@ -1572,11 +1794,16 @@ def _try_local_ep_check( ) try: - model = self._build_single_node_model(node, op_domain, opset_version) - input_feed = self._generate_node_inputs(node) + model = self._build_runtime_test_model( + node, + op_domain, + opset_version, + include_adjacent_qdq=include_adjacent_qdq, + ) + input_feed = self._generate_model_inputs(model) except Exception as e: logger.debug( - "Failed to build single-node model for local EP check on %s (%s): %s", + "Failed to build runtime-test model for local EP check on %s (%s): %s", node.name, node.op_type, e, @@ -1808,6 +2035,7 @@ def _maybe_save_failed_node_result( opset_version: int, result: RuntimeTestResult, cache_key: Any, + include_adjacent_qdq: bool = False, save_node_types: set[str] | None = None, ) -> None: """Save unsupported or partial node models without re-running result computation.""" @@ -1820,7 +2048,12 @@ def _maybe_save_failed_node_result( if not (is_unsupported or is_partial): return - node_model = self._build_single_node_model(node, op_domain, opset_version) + node_model = self._build_runtime_test_model( + node, + op_domain, + opset_version, + include_adjacent_qdq=include_adjacent_qdq, + ) self._save_failed_node( node, node_model, @@ -2079,6 +2312,7 @@ def get_pattern_id(is_qdq): pattern_match, node_tags, fallback_reason, + include_adjacent_qdq=is_qdq, save_node_types=save_node_types, # conditions not available when domain/op # rules are missing @@ -2187,6 +2421,7 @@ def get_pattern_id(is_qdq): pattern_match, node_tags, fallback_reason, + include_adjacent_qdq=is_qdq, conditions=cache_key, ) if local_result is not None: @@ -2223,6 +2458,7 @@ def get_pattern_id(is_qdq): pattern_match, node_tags, fallback_reason, + include_adjacent_qdq=is_qdq, conditions=None, ) if local_result is not None: @@ -2335,6 +2571,7 @@ def get_pattern_id(is_qdq): opset_version, result, cache_key, + include_adjacent_qdq=is_qdq, save_node_types=save_node_types, ) diff --git a/tests/unit/analyze/core/test_qdq.py b/tests/unit/analyze/core/test_qdq.py index ec9f8029d..8edef6c89 100644 --- a/tests/unit/analyze/core/test_qdq.py +++ b/tests/unit/analyze/core/test_qdq.py @@ -13,6 +13,7 @@ - _collect_qdq_types functionality via RuntimeCheckerQuery """ +import numpy as np import pytest from onnx import TensorProto, helper @@ -362,6 +363,56 @@ def test_dq_weight_vs_activation(self) -> None: assert "dq_out" in query.input_to_dq_type assert isinstance(query.input_to_dq_type["dq_out"], QDQTypeInfo) + def test_runtime_test_model_preserves_qdq_context(self) -> None: + """QDQ local fallback models keep the adjacent DQ and Q nodes.""" + query = RuntimeCheckerQuery( + model_proto=self._make_qdq_model(), + ep_name="QNNExecutionProvider", + device_type="NPU", + ) + relu_node = next(node for node in query.model_proto.graph.node if node.name == "relu_node") + + runtime_test_model = query._build_runtime_test_model( + relu_node, + ONNXDomain.AI_ONNX, + 17, + include_adjacent_qdq=True, + ) + + assert [node.op_type for node in runtime_test_model.graph.node] == [ + "DequantizeLinear", + "Relu", + "QuantizeLinear", + ] + assert [graph_input.name for graph_input in runtime_test_model.graph.input] == ["x"] + assert [graph_output.name for graph_output in runtime_test_model.graph.output] == ["y"] + assert {initializer.name for initializer in runtime_test_model.graph.initializer} == { + "dq_scale", + "q_scale", + "zp", + } + + def test_runtime_test_model_inputs_use_quantized_graph_input(self) -> None: + """Local QDQ fallback feeds the extracted model's graph inputs, not the inner op input.""" + query = RuntimeCheckerQuery( + model_proto=self._make_qdq_model(), + ep_name="QNNExecutionProvider", + device_type="NPU", + ) + relu_node = next(node for node in query.model_proto.graph.node if node.name == "relu_node") + + runtime_test_model = query._build_runtime_test_model( + relu_node, + ONNXDomain.AI_ONNX, + 17, + include_adjacent_qdq=True, + ) + input_feed = query._generate_model_inputs(runtime_test_model) + + assert list(input_feed) == ["x"] + assert input_feed["x"].shape == (1, 3, 4, 4) + assert input_feed["x"].dtype == np.int8 + class TestIterShouldQDQCombinations: """Unit tests for _iter_should_qdq_combinations.