Skip to content

Commit bab35eb

Browse files
exclamaforte authored and pytorchmergebot committed
fix intermediate debug information with cpp_wrapper (pytorch#145527)
Summary: before fix, code like: ```cpp aoti_torch_print_tensor_handle(buf0, "after_launch - triton_poi_fused_randn_0 - buf0"); aoti_torch_print_tensor_handle(buf1, "after_launch - triton_poi_fused_randn_0 - buf1"); printf("[ after_launch - triton_poi_fused_randn_0 - 0: %ld ]", 0); printf(" "); printf("[ after_launch - triton_poi_fused_randn_0 - 1228800L: %ld ]", 1228800L); printf(" "); ``` was generated, which is a syntax error. Test Plan: New unit test. Pull Request resolved: pytorch#145527 Approved by: https://github.com/desertfire
1 parent 6818945 commit bab35eb

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

test/inductor/test_gpu_cpp_wrapper.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,26 @@ class GpuWrapperTemplate:
5454
class TestGpuWrapper(InductorTestCase):
5555
device = GPU_TYPE
5656

57+
def test_aoti_debug_printer_works_on_constants(self):
58+
batch_size = 32
59+
seq_length = 50
60+
hidden_size = 768
61+
62+
def test_fn():
63+
inp = torch.randn(batch_size, seq_length, hidden_size, device=self.device)
64+
weight = torch.randn(hidden_size, hidden_size, device=self.device)
65+
matmul_output = inp @ weight
66+
torch.nn.LayerNorm(hidden_size, device=self.device)(matmul_output)
67+
return True
68+
69+
comp = torch.compile(
70+
options={
71+
"cpp_wrapper": True,
72+
"aot_inductor.debug_intermediate_value_printer": "2",
73+
}
74+
)(test_fn)
75+
comp()
76+
5777

5878
class DynamicShapesGpuWrapperGpuTests(InductorTestCase):
5979
device = GPU_TYPE

torch/_inductor/codegen/debug_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def codegen_intermediate_tensor_value_print(
267267
),
268268
):
269269
V.graph.wrapper_code.writeline(
270-
f'printf("[ {launch_prefix} - {kernel_name} - {arg}: %ld ]", {arg}); printf("\\n");'
270+
f'printf("[ {launch_prefix} - {kernel_name} - {arg}: %ld ]", {arg}); printf("\\\\n");'
271271
)
272272
else:
273273
if arg_signatures is None and self.kernel_type == "cpp" or "extern":

0 commit comments

Comments
 (0)