Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 36 additions & 28 deletions src/winml/modelkit/export/htp/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,23 @@ class HTPConfig:
"Embedding",
]

# Default export statistics structure
# empty_tags: CARDINAL RULE: Must be 0, default to max int to catch violations
# Default export statistics structure.
# Initialised before each export run and returned as a copy at the end.
DEFAULT_EXPORT_STATS: ClassVar[dict[str, Any]] = {
# Seconds elapsed from export() entry to final stat collection.
"export_time": 0.0,
# Number of named hierarchy modules discovered during tracing.
"hierarchy_modules": 0,
# Total ONNX graph nodes in the exported model.
"onnx_nodes": 0,
# Nodes that received a hierarchy_tag attribute (0 when embed_hierarchy_attributes=False).
"tagged_nodes": 0,
# CARDINAL RULE: tags with an empty/whitespace value must never exist.
# Sentinel sys.maxsize ensures any non-zero value is immediately visible as a violation.
"empty_tags": sys.maxsize,
# Percentage of onnx_nodes that were tagged (0.0 when embed_hierarchy_attributes=False).
"coverage_percentage": 0.0,
# Exporter strategy identifier, written into report metadata.
"strategy": STRATEGY_NAME,
}

Expand Down Expand Up @@ -295,20 +303,10 @@ def export(
export_time = time.time() - start_time
self._export_stats["export_time"] = export_time
self._export_stats["hierarchy_modules"] = len(self._hierarchy_data)
self._export_stats["onnx_nodes"] = len(onnx_model.graph.node)
self._export_stats["tagged_nodes"] = len(self._tagged_nodes)

# Calculate empty tags (should be 0 with our implementation)
empty_tag_count = sum(
1 for tag in self._tagged_nodes.values() if not tag or not tag.strip()
)
self._export_stats["empty_tags"] = empty_tag_count

# Calculate coverage percentage
total_nodes = len(onnx_model.graph.node)
tagged_nodes = len(self._tagged_nodes)
coverage = (tagged_nodes / total_nodes * 100.0) if total_nodes > 0 else 0.0
self._export_stats["coverage_percentage"] = coverage
self._export_stats["onnx_nodes"] = total_nodes

self._update_tag_stats(total_nodes)

# Update monitor with actual export time
monitor.data.export_time = export_time
Expand Down Expand Up @@ -493,6 +491,27 @@ def _get_optimum_patcher(model: nn.Module, task: str | None) -> Any:
)
return contextlib.nullcontext()

def _update_tag_stats(self, total_nodes: int) -> None:
"""Update tagged_nodes, empty_tags, and coverage_percentage in export stats.

Centralises the embed-aware calculation so _apply_hierarchy_tags and
the final stats block in export() always stay in sync.
All three stats are gated on embed_hierarchy_attributes: when hierarchy
embedding is disabled none of the tags are written to the model, so
all counts are reported as 0.
"""
if self.embed_hierarchy_attributes:
embedded_count = len(self._tagged_nodes)
empty_tags = sum(1 for tag in self._tagged_nodes.values() if not tag or not tag.strip())
else:
embedded_count = 0
empty_tags = 0
self._export_stats["tagged_nodes"] = embedded_count
self._export_stats["empty_tags"] = empty_tags
self._export_stats["coverage_percentage"] = (
embedded_count / total_nodes * 100.0 if total_nodes > 0 else 0.0
)

def _initialize_node_tagger(self, enable_operation_fallback: bool) -> None:
"""Create node tagger internally."""
self._node_tagger = create_node_tagger_from_hierarchy(
Expand All @@ -510,20 +529,9 @@ def _apply_hierarchy_tags(self, onnx_model: onnx.ModelProto) -> None:
self._tagging_stats = stats

# Update export stats
self._export_stats["onnx_nodes"] = len(onnx_model.graph.node)
self._export_stats["tagged_nodes"] = len(self._tagged_nodes)

# Calculate empty tags (should be 0 with our implementation)
empty_tag_count = sum(
1 for tag in self._tagged_nodes.values() if not tag or not tag.strip()
)
self._export_stats["empty_tags"] = empty_tag_count

# Calculate coverage percentage
total_nodes = len(onnx_model.graph.node)
tagged_nodes = len(self._tagged_nodes)
coverage = (tagged_nodes / total_nodes * 100.0) if total_nodes > 0 else 0.0
self._export_stats["coverage_percentage"] = coverage
self._export_stats["onnx_nodes"] = total_nodes
self._update_tag_stats(total_nodes)

def _embed_graph_metadata(
self, onnx_model: onnx.ModelProto, export_config: WinMLExportConfig
Expand Down
53 changes: 53 additions & 0 deletions tests/unit/export/test_htp_exporter_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Tests for HTPExporter export statistics correctness."""

from __future__ import annotations

from unittest.mock import MagicMock

from winml.modelkit.export.htp import HTPExporter


class TestHTPExporterTaggedNodesStats:
"""tagged_nodes, empty_tags, and coverage must be 0 when embed_hierarchy_attributes=False."""

def test_all_stats_zero_when_hierarchy_disabled(self) -> None:
exporter = HTPExporter(embed_hierarchy_attributes=False)
exporter._node_tagger = MagicMock()
exporter._node_tagger.tag_all_nodes.return_value = {
"node1": "/Model/Layer1",
"node2": "/Model/Layer2",
"node3": "/Model/Layer3",
}
exporter._node_tagger.get_tagging_statistics.return_value = {}

mock_model = MagicMock()
mock_model.graph.node = [MagicMock() for _ in range(5)]

exporter._apply_hierarchy_tags(mock_model)

assert exporter._export_stats["tagged_nodes"] == 0
assert exporter._export_stats["coverage_percentage"] == 0.0
assert exporter._export_stats["empty_tags"] == 0

def test_stats_populated_when_hierarchy_enabled(self) -> None:
"""Control: stats are populated normally when embedding is enabled."""
exporter = HTPExporter(embed_hierarchy_attributes=True)
exporter._node_tagger = MagicMock()
exporter._node_tagger.tag_all_nodes.return_value = {
"n1": "/t1",
"n2": "/t2",
}
exporter._node_tagger.get_tagging_statistics.return_value = {}

mock_model = MagicMock()
mock_model.graph.node = [MagicMock() for _ in range(4)]

exporter._apply_hierarchy_tags(mock_model)

assert exporter._export_stats["tagged_nodes"] == 2
assert exporter._export_stats["coverage_percentage"] == 50.0
assert exporter._export_stats["empty_tags"] == 0
Loading