Skip to content

Commit 25c5e24

Browse files
Juntian777 authored and facebook-github-bot committed
Enable E2E testing for the numerical discrepancy framework for XNNPACK. (#12723)
Summary: This PR introduces an end-to-end test framework for ExecuTorch's XNNPACK backend. It adds utilities to generate ETRecord and ETDump files with debug buffers for models, enabling numerical gap checks between runtime and AOT outputs. The PR also includes a test for the Vision Transformer (ViT) model to verify numeric gap thresholds. Additionally, it adds necessary build targets and runtime support for the new event tracer feature. This improves testing and debugging capabilities for ExecuTorch's XNNPACK backend. Differential Revision: D78380933
1 parent 6c4f934 commit 25c5e24

File tree

4 files changed

+325
-0
lines changed

4 files changed

+325
-0
lines changed

devtools/inspector/_inspector.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1401,6 +1401,13 @@ def calculate_numeric_gap(self, distance: str = "MSE"):
14011401
runtime_intermediate_outputs, runtime_debug_handle_to_op_names = (
14021402
self._get_runtime_intermediate_outputs_and_op_names()
14031403
)
1404+
if (
1405+
len(aot_intermediate_outputs) == 0
1406+
or len(runtime_debug_handle_to_op_names) == 0
1407+
):
1408+
raise ValueError(
1409+
"Inspector Events' debug_data is not populated properly which is required for calculating numerical gap"
1410+
)
14041411
mapping = map_runtime_aot_intermediate_outputs(
14051412
aot_intermediate_outputs, runtime_intermediate_outputs
14061413
)

devtools/tests/xnnpack/TARGETS

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
2+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
3+
load("@fbsource//tools/target_determinator/macros:ci.bzl", "ci")
4+
5+
oncall("executorch")
6+
7+
# Library target wrapping xnnpack_test_utils.py; the torchvision_vit_test
# target in this file lists it as a dependency.
python_library(
8+
name = "xnnpack_test_utils",
9+
srcs = [
10+
"xnnpack_test_utils.py",
11+
],
12+
deps = [
13+
"//caffe2:torch",
14+
"//executorch/devtools/bundled_program:config",
15+
"//executorch/devtools/bundled_program:core",
16+
"//executorch/devtools/bundled_program/serialize:lib",
17+
"//executorch/devtools:lib",
18+
"//executorch/exir:lib",
19+
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
20+
"//executorch/backends/xnnpack/utils:xnnpack_utils",
21+
"//executorch/extension/pybindings:portable_lib",
22+
],
23+
)
24+
25+
26+
# Unit test target wrapping torchvision_vit_test.py. The `labels` entry
# attaches the executorch.event_tracer_enabled=true buckconfig for CI
# (presumably so the runtime is built with the event tracer; confirm against
# the ci.bzl macros).
python_unittest(
27+
name = "torchvision_vit_test",
28+
srcs = [
29+
"torchvision_vit_test.py",
30+
],
31+
# You still need to pass `-c executorch.event_tracer_enabled:true`
32+
# if you want to manually invoke buck.
33+
labels = ci.labels(
34+
ci.buckconfig("executorch.event_tracer_enabled", "true"),
35+
),
36+
deps = [
37+
"//executorch/devtools/tests/xnnpack:xnnpack_test_utils",
38+
"//executorch/exir/fb:bento_deps",
39+
"//executorch/extension/fb/ptez:lib",
40+
"//fair_infra/data/iopath/iopath:iopath",
41+
],
42+
)
devtools/tests/xnnpack/torchvision_vit_test.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import os
9+
import unittest
10+
11+
import torch
12+
13+
from executorch.devtools.tests.xnnpack.xnnpack_test_utils import (
14+
check_disturbance,
15+
check_numeric_gap,
16+
generate_etrecord_and_etdump,
17+
)
18+
19+
from torchvision import models
20+
21+
22+
class TestViTModel(unittest.TestCase):
    """Numeric gap checks for torchvision's ViT-B/16 using the XNNPACK test utilities."""

    def setUp(self):
        """Load the pretrained model in eval mode and build one fixed input batch."""
        vit = models.vision_transformer.vit_b_16(weights="IMAGENET1K_V1")
        self.model = vit.eval()
        # Draw the input from a privately seeded generator: the assertions
        # below compare against fixed thresholds, so the input must be the
        # same on every run, and the global RNG state is left untouched.
        generator = torch.Generator().manual_seed(0)
        self.model_inputs = (torch.randn(1, 3, 224, 224, generator=generator),)

    def test_numeric_gap(self):
        """The artifacts are written and the largest MSE gap stays below the threshold."""
        etrecord_path, etdump_path, debug_buffer_path = generate_etrecord_and_etdump(
            self.model,
            self.model_inputs,
        )

        # Check if the output files exist
        self.assertTrue(
            os.path.exists(etrecord_path), f"ETRecord not found: {etrecord_path}"
        )
        self.assertTrue(os.path.exists(etdump_path), f"ETDump not found: {etdump_path}")
        self.assertTrue(
            os.path.exists(debug_buffer_path),
            f"Debug buffer not found: {debug_buffer_path}",
        )

        metric = "MSE"
        max_allowed_gap = 1e-6
        is_within_threshold, max_gap = check_numeric_gap(
            etdump_path,
            etrecord_path,
            debug_buffer_path,
            metric=metric,
            max_allowed_gap=max_allowed_gap,
        )

        # Check if the numeric gap is within threshold
        self.assertTrue(
            is_within_threshold,
            f"Numeric gap {max_gap} exceeds allowed threshold {max_allowed_gap}",
        )

    def test_numeric_gap_with_disturbance(self):
        """A deliberately injected disturbance shows up at the expected row."""
        # Check if we can detect the first numeric gap directly affected by the disturbance
        etrecord_path, etdump_path, debug_buffer_path = generate_etrecord_and_etdump(
            self.model,
            self.model_inputs,
            disturb=True,
        )

        metric = "MSE"
        max_allowed_gap = 1e-6
        disturbance_threshold = 1e-3
        disturbance_row = 1
        is_within_thresholds = check_disturbance(
            etdump_path,
            etrecord_path,
            debug_buffer_path,
            metric=metric,
            row=disturbance_row,
            max_allowed_gap=max_allowed_gap,
            disturbance_threshold=disturbance_threshold,
        )

        self.assertTrue(
            is_within_thresholds,
            f"Expected a {metric} gap above {disturbance_threshold} at row "
            f"{disturbance_row} and gaps below {max_allowed_gap} before it",
        )
devtools/tests/xnnpack/xnnpack_test_utils.py

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import copy
9+
import os
10+
import tempfile
11+
12+
import uuid
13+
14+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
15+
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
16+
17+
from executorch.devtools import BundledProgram, generate_etrecord
18+
from executorch.devtools.bundled_program.config import MethodTestCase, MethodTestSuite
19+
from executorch.exir import to_edge
20+
21+
from executorch.extension.pybindings.portable_lib import (
22+
_load_for_executorch_from_buffer, # @manual
23+
)
24+
from torch.export import export
25+
26+
27+
def _generate_new_paths():
    """
    Create a fresh temporary directory and build three artifact paths inside it.

    Only the directory is created; the files themselves are not. This helper
    never removes the directory, so it is created with a recognizable prefix
    to make leftover test artifacts easy to attribute and clean up.

    Returns:
        Tuple of (etrecord_path, etdump_path, debug_buffer_path)
    """
    temp_dir = tempfile.mkdtemp(prefix="executorch_xnnpack_test_")

    # Use uuid to generate unique filenames
    etrecord_path = os.path.join(temp_dir, f"etrecord_{uuid.uuid4().hex}.bin")
    etdump_path = os.path.join(temp_dir, f"etdump_{uuid.uuid4().hex}.etdp")
    debug_buffer_path = os.path.join(
        temp_dir, f"debug_buffer_{uuid.uuid4().hex}.bin"
    )
    return etrecord_path, etdump_path, debug_buffer_path
38+
39+
40+
def generate_etrecord_and_etdump(
    model,
    model_inputs,
    debug_buffer_size=1024 * 1024 * 1024,
    method_name="forward",
    num_test_cases=2,
    disturb=False,
):
    """
    Helper to generate ETRecord and ETDump (with debug buffer) for a model.

    The model is exported, lowered to edge dialect with the XNNPACK edge
    compile config, delegated through XnnpackPartitioner and run once via the
    portable pybindings with ETDump enabled. A deep copy of the edge program
    taken before delegation is what gets stored in the ETRecord.

    Args:
        model: Model to export; ``getattr(model, method_name)`` is also called
            eagerly to produce the bundled expected outputs.
        model_inputs: Example inputs used for export, for the bundled test
            cases and for the runtime execution.
        debug_buffer_size: Forwarded to the runtime loader as
            ``debug_buffer_size``.
        method_name: Name of the method to bundle and execute.
        num_test_cases: Number of bundled test cases; every case reuses
            ``model_inputs``.
        disturb: When True, the node named "aten_add_tensor" in the edge
            program copy recorded in the ETRecord is retargeted to
            ``aten.sub.Tensor``. The program executed at runtime is not
            modified.

    Returns:
        Tuple of (etrecord_path, etdump_path, debug_buffer_path)
    """

    etrecord_path, etdump_path, debug_buffer_path = _generate_new_paths()

    aten_model = export(model, model_inputs, strict=True)

    edge_compile_config = get_xnnpack_edge_compile_config()

    edge_program_manager = to_edge(aten_model, compile_config=edge_compile_config)

    # Snapshot the edge program before delegation; this copy goes into the
    # ETRecord and is the only thing the disturbance touches.
    edge_program_manager_copy = copy.deepcopy(edge_program_manager)

    # Apply the disturbance if the flag is set
    if disturb:
        import torch

        for exported_program in edge_program_manager_copy._edge_programs.values():
            for module in exported_program.graph_module.modules():
                if not isinstance(module, torch.fx.GraphModule):
                    continue
                for node in module.graph.nodes:
                    if node.op == "call_function" and node.name == "aten_add_tensor":
                        node.target = torch.ops.aten.sub.Tensor
                # Finish every graph edit first, then regenerate the module's
                # code, so the compiled forward reflects the final graph.
                module.graph.eliminate_dead_code()
                module.recompile()

    edge_program_manager = edge_program_manager.to_backend(XnnpackPartitioner())

    et_program_manager = edge_program_manager.to_executorch()

    method_graphs = {method_name: export(model, model_inputs, strict=True)}
    inputs = [list(model_inputs) for _ in range(num_test_cases)]
    method_test_suites = [
        MethodTestSuite(
            method_name=method_name,
            test_cases=[
                MethodTestCase(
                    inputs=inp, expected_outputs=getattr(model, method_name)(*inp)
                )
                for inp in inputs
            ],
        )
    ]
    executorch_program = (
        to_edge(method_graphs, compile_config=edge_compile_config)
        .to_backend(XnnpackPartitioner())
        .to_executorch()
    )
    bundled_program = BundledProgram(executorch_program, method_test_suites)

    # Generate ETRecord
    generate_etrecord(etrecord_path, edge_program_manager_copy, bundled_program)

    # Generate ETDump and debug buffer
    buff = et_program_manager.buffer
    executorch_module = _load_for_executorch_from_buffer(
        buff,
        enable_etdump=True,
        debug_buffer_size=debug_buffer_size,
    )
    executorch_module.run_method(method_name, tuple(model_inputs))
    executorch_module.write_etdump_result_to_file(etdump_path, debug_buffer_path)

    return etrecord_path, etdump_path, debug_buffer_path
117+
118+
119+
from typing import Tuple
120+
121+
import pandas as pd
122+
from executorch.devtools import Inspector
123+
124+
125+
def check_numeric_gap(
    etdump_path: str,
    etrecord_path: str,
    debug_buffer_path: str,
    metric: str,
    max_allowed_gap: float,
) -> Tuple[bool, float]:
    """
    Create an Inspector and check if the maximum numeric gap for a given metric is less than the allowed threshold.
    Args:
        etdump_path: Path to the ETDump file.
        etrecord_path: Path to the ETRecord file.
        debug_buffer_path: Path to the debug buffer file.
        metric: The metric name to calculate the numeric gap for (e.g., "MSE").
        max_allowed_gap: The maximum allowed gap threshold.
    Returns:
        A tuple (is_within_threshold, max_gap) where:
        - is_within_threshold (bool): True if max gap < max_allowed_gap, else False.
        - max_gap (float): The maximum gap value found.
    """
    inspector = Inspector(
        etdump_path=etdump_path,
        etrecord=etrecord_path,
        debug_buffer_path=debug_buffer_path,
    )
    df: pd.DataFrame = inspector.calculate_numeric_gap(metric)
    # A "gap" cell may be a list of values or a single value; reduce each
    # cell to one number before taking the column-wide maximum.
    per_row_max = df["gap"].apply(lambda x: max(x) if isinstance(x, list) else x)
    # Convert to builtin types so the result matches the declared
    # Tuple[bool, float] instead of leaking numpy scalars.
    max_gap = float(per_row_max.max())
    is_within_threshold = bool(max_gap < max_allowed_gap)
    return is_within_threshold, max_gap
154+
155+
156+
def check_disturbance(
    etdump_path: str,
    etrecord_path: str,
    debug_buffer_path: str,
    metric: str,
    row: int,
    max_allowed_gap: float,
    disturbance_threshold: float,
) -> bool:
    """
    Check if the given row in the DataFrame has a gap greater than the disturbance threshold.

    Args:
        etdump_path: Path to the ETDump file.
        etrecord_path: Path to the ETRecord file.
        debug_buffer_path: Path to the debug buffer file.
        metric: The metric name to calculate the numeric gap for (e.g., "MSE").
        disturbance_threshold: The threshold to detect a disturbance.
        max_allowed_gap: The maximum allowed gap threshold before the disturbance(row).
        row: The row number to check for a disturbance.

    Returns:
        True if the gap at `row` is greater than `disturbance_threshold` and
        the largest gap in the rows before it is less than `max_allowed_gap`
        (taken as 0 when `row` is 0); False otherwise.
    """
    inspector = Inspector(
        etdump_path=etdump_path,
        etrecord=etrecord_path,
        debug_buffer_path=debug_buffer_path,
    )
    df: pd.DataFrame = inspector.calculate_numeric_gap(metric)

    def _cell_max(gap):
        # A "gap" cell may be a list of values or a single value; accept both
        # instead of failing with TypeError on max(<scalar>).
        return max(gap) if isinstance(gap, list) else gap

    # Get the maximum gap for the given row
    disturbance_row_gap = _cell_max(df.loc[row, "gap"])
    # Get the maximum gap for the rows before the given row
    if row > 0:
        # .loc slicing is label-based and includes its end label, hence row - 1.
        before_disturbance_row_gap = max(df.loc[: row - 1, "gap"].apply(_cell_max))
    else:
        before_disturbance_row_gap = 0

    # bool() keeps the declared return type even when the comparisons
    # produce numpy booleans.
    return bool(
        disturbance_row_gap > disturbance_threshold
        and before_disturbance_row_gap < max_allowed_gap
    )

0 commit comments

Comments
 (0)