Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/winml/modelkit/analyze/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,10 @@ def get_optimization_config(self, ep: str | None = None) -> WinMLOptimizationCon
continue

if action_item.optimization_options:
optim_options.update(action_item.optimization_options)
# Normalize kebab-case keys to snake_case (python_name)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider to update the original information config? Then we could get rid of this transform

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add both config and here? Adding here is to prevent wrong config...

# so they match the capability system's python_name format.
for key, value in action_item.optimization_options.items():
optim_options[key.replace("-", "_")] = value

# Create and return config from collected options
return WinMLOptimizationConfig(**optim_options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
{
"type": "GraphOptimization",
"optimization_options": {
"highdimRTR-lowdimRTR": true
"highdimRTR_lowdimRTR": true
}
}
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
{
"type": "GraphOptimization",
"optimization_options": {
"attention-expandedattention": true
"attention_expandedattention": true
}
}
],
Expand Down
10 changes: 10 additions & 0 deletions src/winml/modelkit/pattern/transpose_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,11 @@ def check_skeleton_result(
return None
return result

@property
def pattern_id(self) -> str:
"""Return pattern ID matching the information rule configuration."""
return f"SUBGRAPH/{type(self).__name__}"

def get_schema(self) -> PatternSchema:
"""Return the schema definition for ReshapeTransposeReshape pattern.

Expand Down Expand Up @@ -609,6 +614,11 @@ def get_internal_constants_and_attributes(

return internal_constants, internal_attributes

@property
def pattern_id(self) -> str:
"""Return pattern ID matching the information rule configuration."""
return f"SUBGRAPH/{type(self).__name__}"

def get_schema(self) -> PatternSchema:
"""Return the schema definition for ReshapeTransposeReshapeLowDim pattern."""
return _RESHAPE_TRANSPOSE_RESHAPE_SCHEMA
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/analyze/core/test_unified_pattern_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_load_default_config(self):
"SUBGRAPH/GeluPattern", # Shared by Gelu1-4
"SUBGRAPH/GemmPattern", # MatMulAdd
"SUBGRAPH/LayerNormalizationPattern", # Shared by Pow and Mul variants
"SUBGRAPH/ReshapeTransposeReshapePattern",
"SUBGRAPH/ReshapeTransposeReshapeOverlyHighDimPattern",
}
assert skeleton_pattern_ids == expected_ids, f"Pattern IDs mismatch: {skeleton_pattern_ids}"

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/analyze/models/test_information.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def test_default_information_json_has_reshape_transpose_reshape_entry(self):
action_items = entry["actions"][0]["action_items"]
assert len(action_items) == 1
assert action_items[0]["type"] == "GraphOptimization"
assert action_items[0]["optimization_options"] == {"highdimRTR-lowdimRTR": True}
assert action_items[0]["optimization_options"] == {"highdimRTR_lowdimRTR": True}

def test_qc_information_json_has_transpose_attention_entry(self):
"""Test that qc_information.json has the QC-specific TransposeAttentionPattern entry."""
Expand All @@ -408,7 +408,7 @@ def test_qc_information_json_has_transpose_attention_entry(self):
action_items = entry["actions"][0]["action_items"]
assert len(action_items) == 1
assert action_items[0]["type"] == "GraphOptimization"
assert action_items[0]["optimization_options"] == {"attention-expandedattention": True}
assert action_items[0]["optimization_options"] == {"attention_expandedattention": True}

def test_default_information_json_does_not_have_transpose_attention_entry(self):
"""Test that TransposeAttentionPattern is NOT in default_information.json (QC-specific)."""
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/analyze/pattern/test_transpose_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,29 @@ def test_unmerged_6d_rtr_is_matched(self) -> None:
assert len(results) == 1, (
"Unmerged 6-D RTR must match ReshapeTransposeReshapeOverlyHighDimPattern"
)


class TestRTRPatternIdAlignment:
"""Verify pattern_id values match the rule configuration expectations."""

def test_overly_high_dim_pattern_id(self) -> None:
"""ReshapeTransposeReshapeOverlyHighDimPattern must have distinct pattern_id."""
pattern = ReshapeTransposeReshapeOverlyHighDimPattern()
assert pattern.pattern_id == "SUBGRAPH/ReshapeTransposeReshapeOverlyHighDimPattern"

def test_low_dim_pattern_id(self) -> None:
"""ReshapeTransposeReshapeLowDimPattern must have distinct pattern_id."""
pattern = ReshapeTransposeReshapeLowDimPattern()
assert pattern.pattern_id == "SUBGRAPH/ReshapeTransposeReshapeLowDimPattern"

def test_high_dim_pattern_id_differs_from_schema_name(self) -> None:
"""HighDim pattern_id must NOT fall back to the shared schema name."""
pattern = ReshapeTransposeReshapeOverlyHighDimPattern()
schema_based = f"SUBGRAPH/{pattern.get_schema().name}"
assert pattern.pattern_id != schema_based

def test_low_dim_pattern_id_differs_from_schema_name(self) -> None:
"""LowDim pattern_id must NOT fall back to the shared schema name."""
pattern = ReshapeTransposeReshapeLowDimPattern()
schema_based = f"SUBGRAPH/{pattern.get_schema().name}"
assert pattern.pattern_id != schema_based
63 changes: 63 additions & 0 deletions tests/unit/analyze/test_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,69 @@ def test_get_optimization_config_custom_option(self, mock_output: AnalysisOutput
# Should accept any custom option
assert config.get("custom_fusion", False) is True

def test_get_optimization_config_normalizes_kebab_case(
self, mock_output: AnalysisOutput
) -> None:
"""Test get_optimization_config normalizes kebab-case keys to snake_case."""
rtr_action = Action(
pattern_from_id="SUBGRAPH/ReshapeTransposeReshapeOverlyHighDimPattern",
pattern_to_id="SUBGRAPH/ReshapeTransposeReshapeLowDimPattern",
details="RTR optimization",
action_items=[
ActionItem(
type="GraphOptimization",
optimization_options={"highdimRTR-lowdimRTR": True},
)
],
)
mock_output.results[0].information = [
Information(
pattern_id="SUBGRAPH/ReshapeTransposeReshapeOverlyHighDimPattern",
explanation="RTR pattern detected",
actions=[rtr_action],
)
]

result = AnalysisResult(output=mock_output)
config = result.get_optimization_config()

# Kebab-case key should be normalized to underscore
assert config.get("highdimRTR_lowdimRTR", False) is True
# Original kebab-case key should NOT be present
assert "highdimRTR-lowdimRTR" not in config

def test_get_optimization_config_mixed_kebab_and_snake(
self, mock_output: AnalysisOutput
) -> None:
"""Test get_optimization_config handles mix of kebab-case and snake_case keys."""
action = Action(
pattern_from_id="SUBGRAPH/TestPattern",
pattern_to_id="OP/Test",
details="Mixed key test",
action_items=[
ActionItem(
type="GraphOptimization",
optimization_options={
"already_snake": True,
"kebab-style-key": True,
},
)
],
)
mock_output.results[0].information = [
Information(
pattern_id="SUBGRAPH/TestPattern",
explanation="Test",
actions=[action],
)
]

result = AnalysisResult(output=mock_output)
config = result.get_optimization_config()

assert config.get("already_snake", False) is True
assert config.get("kebab_style_key", False) is True


class TestONNXStaticAnalyzer:
"""Tests for ONNXStaticAnalyzer."""
Expand Down
Loading