Commit 4d073af

Revert "[inductor][dynamo] Include operator name in size/stride/alignment assertion (pytorch#152353)"
This reverts commit 725bbb6. Reverted pytorch#152353 on behalf of https://github.com/jeanschmidt because it seems to have broken a few internal tests. @jansel, could you help the author get this PR merged? ([comment](pytorch#152353 (comment)))
1 parent 741539a commit 4d073af

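Note on the reverted change: pytorch#152353 had Inductor's generated wrapper pass the originating operator's name into its size/stride and alignment assertions, so a failing check could name the op whose fake (meta) kernel was wrong. Judging from the FileCheck expectations deleted below (buffer name, shape, and op name are taken from those tests and are illustrative), the emitted assertions differ roughly as follows:

Before this revert, per the reverted PR:
    assert_size_stride(buf2, (16, 32), (32, 1), 'torch.ops.test.foo.default')
    assert_alignment(buf2, 16, 'torch.ops.test.foo.default')
After this revert:
    assert_size_stride(buf2, (16, 32), (32, 1))
    assert_alignment(buf2, 16)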
7 files changed, 20 additions and 170 deletions

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ add_loop_inductor,compile_time_instruction_count,29370000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44130000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44480000000,0.025

test/distributed/test_functional_api.py

Lines changed: 0 additions & 7 deletions
@@ -715,13 +715,6 @@ def run_with_backward():

         _, codes = run_and_get_code(run_with_backward)
         for code in codes:
-            assert_keywords = ["assert_size_stride", "assert_alignment"]
-            filtered_lines = [
-                line
-                for line in code.splitlines()
-                if not any(assert_key in line for assert_key in assert_keywords)
-            ]
-            code = "\n".join(filtered_lines)
             FileCheck().check_count(
                 "_c10d_functional.all_to_all_single.default", 1, exactly=True
             ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 0 additions & 8 deletions
@@ -231,14 +231,6 @@ def _test_code_common(
             torch.compile(mod, fullgraph=True, dynamic=check_dynamic),
             *clone_inputs,
         )
-        assert_keywords = ["assert_size_stride", "assert_alignment"]
-        filtered_lines = [
-            line
-            for line in source_code.splitlines()
-            if not any(assert_key in line for assert_key in assert_keywords)
-        ]
-        source_code = "\n".join(filtered_lines)
-
         for op in include_ops:
             self.assertIn(op, source_code)
             if num_include_ops is not None:

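Both test changes above delete the same scrubbing step: with operator names embedded in the assertions, the generated code contained extra tokens that could interfere with exact-count FileCheck matching, so the tests stripped assertion lines before checking. The deleted logic is shown standalone below with an illustrative source_code string; the real tests operate on Inductor's actual generated wrapper code.

# The filtering performed by the deleted test code; `source_code` here is a
# made-up stand-in for Inductor's generated wrapper source.
source_code = """\
buf0 = empty_strided_cuda((16, 32), (32, 1), torch.float32)
assert_size_stride(buf0, (16, 32), (32, 1), 'torch.ops.test.foo.default')
assert_alignment(buf0, 16, 'torch.ops.test.foo.default')
return (buf0,)"""

assert_keywords = ["assert_size_stride", "assert_alignment"]
filtered_lines = [
    line
    for line in source_code.splitlines()
    if not any(assert_key in line for assert_key in assert_keywords)
]
source_code = "\n".join(filtered_lines)
print(source_code)  # only the non-assert lines remain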
test/inductor/test_torchinductor.py

Lines changed: 7 additions & 85 deletions
@@ -30,7 +30,6 @@
 import torch._dynamo.config as dynamo_config
 import torch._inductor.aoti_eager
 import torch.nn as nn
-from torch._C._dynamo.guards import assert_alignment, assert_size_stride
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.debug_utils import aot_graph_input_parser
 from torch._dynamo.device_interface import get_interface_for_device
@@ -1411,10 +1410,9 @@ def fn(a, b):
         )
         _, code = run_and_get_code(fn, x, y)
         code = " ".join(code)
-        if config.cpp_wrapper:
-            self.assertEqual(code.count("view_dtype"), 3)
-        else:
-            self.assertEqual(code.count("aten.view"), 9)
+        self.assertEqual(
+            code.count("view_dtype" if config.cpp_wrapper else "aten.view"), 3
+        )

     def test_add_complex5(self):
         def fn(a, b, alpha):
@@ -11884,82 +11882,6 @@ def fn(x):
             check_lowp=False,
         )

-    @requires_gpu()
-    @skip_if_not_triton
-    @skip_if_cpp_wrapper("skip cpp_wrapper tests")
-    @config.patch(implicit_fallbacks=True)
-    def test_generated_code_has_size_stride_assert(self):
-        def foo(x):
-            return 3 * x
-
-        def foo_meta(x):
-            return torch.empty_like(x)
-
-        define_custom_op_for_test("foo", foo, foo_meta)
-
-        def fn(x):
-            a = torch.nn.functional.relu(x)
-            b = torch.ops.test.foo(a)
-            return b
-
-        a = torch.randn((16, 32), device=self.device)
-
-        _, code = run_and_get_code(
-            torch.compile(fn),
-            a,
-        )
-        if not is_dynamic_shape_enabled():
-            FileCheck().check(
-                "assert_size_stride(buf2, (16, 32), (32, 1), 'torch.ops.test.foo.default')"
-            ).run(code[0])
-
-    @requires_gpu()
-    @skip_if_not_triton
-    @skip_if_cpp_wrapper("skip cpp_wrapper tests")
-    @config.patch(implicit_fallbacks=True)
-    def test_generated_code_has_alignment_assert(self):
-        def foo(x):
-            return 3 * x
-
-        def foo_meta(x):
-            return torch.empty_like(x)
-
-        define_custom_op_for_test("foo", foo, foo_meta)
-
-        def fn(x):
-            a = torch.nn.functional.relu(x)
-            b = torch.ops.test.foo(a)
-            return b
-
-        a = torch.randn((16, 32), device=self.device)
-
-        _, code = run_and_get_code(
-            torch.compile(fn),
-            a,
-        )
-        if not is_dynamic_shape_enabled():
-            FileCheck().check(
-                "assert_alignment(buf2, 16, 'torch.ops.test.foo.default')"
-            ).run(code[0])
-
-    def test_assert_size_stride_op_name_pass(self):
-        tensor = torch.empty((16, 32))
-        assert_size_stride(tensor, (16, 32), (32, 1), "torch.ops.dummy.op_name")
-
-    def test_assert_size_stride_op_name_fail(self):
-        tensor = torch.empty((16, 32))
-        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
-            assert_size_stride(tensor, (32, 64), (32, 1), "torch.ops.dummy.op_name")
-
-    def test_assert_alignment_op_name_pass(self):
-        tensor = torch.empty((16, 32))
-        assert_alignment(tensor, 16, "torch.ops.dummy.op_name")
-
-    def test_assert_alignment_op_name_fail(self):
-        tensor = torch.empty((16, 32))
-        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
-            assert_alignment(tensor, 0, "torch.ops.dummy.op_name")
-
     @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
     @torch._inductor.config.patch(implicit_fallbacks=True)
     def test_custom_op_unbacked_symints(self):
@@ -13092,12 +13014,12 @@ def f(x):
         code = run_and_get_triton_code(f, x)

         if is_dynamic_shape_enabled():
-            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1)").check(
-                "assert_size_stride(buf2, (s77, s27), (s27, 1)"
+            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
+                "assert_size_stride(buf2, (s77, s27), (s27, 1))"
             ).run(code)
         else:
-            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1)").check(
-                "assert_size_stride(buf2, (16, 32), (32, 1)"
+            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(
+                "assert_size_stride(buf2, (16, 32), (32, 1))"
             ).run(code)

     @requires_cuda

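For reference, the deleted tests above follow the standard pattern for inspecting Inductor output: compile the function, capture the emitted wrapper source with run_and_get_code, and match fragments with FileCheck. A minimal sketch of that pattern, independent of the reverted op-name feature (the device and the checked fragment are illustrative, and size asserts are assumed to be enabled, which is the default):

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck


def fn(x):
    return torch.nn.functional.relu(x) * 3


x = torch.randn(16, 32)  # CPU for simplicity; the deleted tests used self.device
_, code = run_and_get_code(torch.compile(fn), x)

# `code` is a list of generated wrapper sources. Generated Python wrappers
# normally validate input sizes/strides, so an assert_size_stride call should
# appear; exact buffer names and shapes depend on the graph.
FileCheck().check("assert_size_stride(").run(code[0])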
torch/_C/_dynamo/guards.pyi

Lines changed: 0 additions & 6 deletions
@@ -176,12 +176,6 @@ def assert_size_stride(
     item: torch.Tensor,
     size: torch.types._size,
     stride: torch.types._size,
-    op_name: str | None = None,
-): ...
-def assert_alignment(
-    item: torch.Tensor,
-    alignment: int,
-    op_name: str | None = None,
 ): ...
 def check_obj_id(obj: object, expected: int) -> bool: ...
 def check_type_id(obj: object, expected: int) -> bool: ...

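After the revert the stub again declares only assert_size_stride(item, size, stride); the assert_alignment entry added by the PR is dropped along with both op_name parameters (the runtime assert_alignment binding predates the PR and remains, as the guards.cpp diff below shows). A hedged usage sketch mirroring the deleted test_assert_*_op_name_* tests, minus the op-name argument:

import torch
from torch._C._dynamo.guards import assert_alignment, assert_size_stride

tensor = torch.empty((16, 32))

# Passing checks: a fresh contiguous (16, 32) tensor has strides (32, 1) and
# storage_offset 0, so it is trivially 16-byte aligned.
assert_size_stride(tensor, (16, 32), (32, 1))
assert_alignment(tensor, 16)

# Failing check: a size mismatch raises AssertionError; after this revert the
# message describes the mismatch but no longer names an operator.
try:
    assert_size_stride(tensor, (32, 64), (32, 1))
except AssertionError as err:
    print(err)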
torch/_inductor/ir.py

Lines changed: 4 additions & 20 deletions
@@ -5772,42 +5772,26 @@ def codegen_kwargs(self, skip_out=False):  # type: ignore[no-untyped-def]
         ]
         return kwargs

-    def get_op_name(self) -> str:
-        if self.fx_node is not None:
-            target = self.fx_node.target
-            op_namespace = getattr(target, "__module__", "unknown_namespace")
-            op_namespace = op_namespace.replace("._ops.", ".ops.")
-            op_namespace = op_namespace.rsplit(".", 1)[0]
-            op_name = f"{op_namespace}.{target}"
-        else:
-            op_name = "unknown_op"
-        return op_name
-
     def codegen_size_asserts(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         if config.size_asserts and not V.graph.cpp_wrapper:
             # comparing strides for 0 size tensor is tricky. Ignore them for now.
             if sympy_product(self.get_size()) == 0:
                 return
             size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
             stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
-            op_name = self.get_op_name()
+
             wrapper.writeline(
-                f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})"
+                f"assert_size_stride({self.get_name()}, {size}, {stride})"
             )

     def codegen_alignment_asserts(self, wrapper) -> None:  # type: ignore[no-untyped-def]
         if config.alignment_asserts and not V.graph.cpp_wrapper:
             name = self.get_name()
             aligned = name not in V.graph.unaligned_buffers
-            op_name = self.get_op_name()
             if aligned:
-                wrapper.writeline(
-                    f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})"
-                )
+                wrapper.writeline(f"assert_alignment({name}, {GPU_ALIGN_BYTES})")
             else:
-                wrapper.writeline(
-                    f"# buffer {name} (op: {op_name}) is assumed to be not aligned"
-                )
+                wrapper.writeline(f"# buffer {name} is assumed to be not aligned")

     def get_group_stride(self):  # type: ignore[no-untyped-def]
         """

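Net effect of the ir.py change on codegen: the get_op_name helper, which derived a name such as torch.ops.test.foo.default from the node's FX target, is gone, and the two codegen methods go back to emitting the plain three-argument size/stride assert and two-argument alignment assert. A self-contained sketch of the restored emission, using a stand-in wrapper object rather than Inductor's real wrapper class:

# Stand-in for Inductor's wrapper code generator; only writeline is mimicked.
GPU_ALIGN_BYTES = 16  # assumption: matches the constant ir.py interpolates


class FakeWrapper:
    def __init__(self):
        self.lines = []

    def writeline(self, line):
        self.lines.append(line)


def codegen_size_asserts(wrapper, name, size, stride):
    # Post-revert form: no operator name in the emitted call.
    wrapper.writeline(f"assert_size_stride({name}, {size}, {stride})")


def codegen_alignment_asserts(wrapper, name, aligned=True):
    if aligned:
        wrapper.writeline(f"assert_alignment({name}, {GPU_ALIGN_BYTES})")
    else:
        wrapper.writeline(f"# buffer {name} is assumed to be not aligned")


w = FakeWrapper()
codegen_size_asserts(w, "buf2", (16, 32), (32, 1))
codegen_alignment_asserts(w, "buf2")
print("\n".join(w.lines))
# assert_size_stride(buf2, (16, 32), (32, 1))
# assert_alignment(buf2, 16)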
torch/csrc/dynamo/guards.cpp

Lines changed: 8 additions & 43 deletions
@@ -844,38 +844,21 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
   PyObject* item = nullptr;
   PyObject* size = nullptr;
   PyObject* stride = nullptr;
-  const char* op_name = nullptr;
-
-  if (!PyArg_ParseTuple(args, "OOO|s", &item, &size, &stride, &op_name)) {
+  if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
     return nullptr;
   }
   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
-    std::stringstream msg;
-    msg << "expected Tensor()";
-    if (op_name) {
-      msg << " for op: " << op_name;
-    }
-    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
+    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
     return nullptr;
   }
   if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
-    std::stringstream msg;
-    msg << "expected tuple()";
-    if (op_name) {
-      msg << " for op: " << op_name;
-    }
-    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
+    PyErr_SetString(PyExc_TypeError, "expected tuple()");
     return nullptr;
   }
   at::Tensor tensor = THPVariable_Unpack(item);
   int64_t ndim = tensor.ndimension();
   if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
-    std::stringstream msg;
-    msg << "wrong number of dimensions" << ndim;
-    if (op_name) {
-      msg << " for op: " << op_name;
-    }
-    PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
+    PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
     return nullptr;
   }

@@ -904,9 +887,6 @@ static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
   }

   if (num_errors) {
-    if (op_name) {
-      msg << "\nError in op: " << op_name;
-    }
     msg << "\nThis error most often comes from a incorrect fake (aka meta) kernel for a custom op.";
     msg << "\nUse torch.library.opcheck to test your custom op.";
     msg << "\nSee https://pytorch.org/docs/stable/library.html#torch.library.opcheck";
@@ -924,27 +904,15 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) {
   */
   PyObject* item = nullptr;
   unsigned long alignment = 0;
-  const char* op_name = nullptr;
-
-  if (!PyArg_ParseTuple(args, "Ok|s", &item, &alignment, &op_name)) {
+  if (!PyArg_ParseTuple(args, "Ok", &item, &alignment)) {
     return nullptr;
   }
   if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
-    std::stringstream msg;
-    msg << "expected Tensor()";
-    if (op_name) {
-      msg << " for op: " << op_name;
-    }
-    PyErr_SetString(PyExc_TypeError, msg.str().c_str());
+    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
     return nullptr;
   }
   if (alignment == 0) {
-    std::stringstream msg;
-    msg << "alignment cannot be 0";
-    if (op_name) {
-      msg << " in op: " << op_name;
-    }
-    PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
+    PyErr_SetString(PyExc_AssertionError, "alignment can not be 0");
     return nullptr;
   }

@@ -954,10 +922,7 @@ static PyObject* assert_alignment(PyObject* dummy, PyObject* args) {
   size_t itemsize = tensor.itemsize();
   if (storage_offset * itemsize % alignment != 0) {
     std::stringstream msg;
-    if (op_name) {
-      msg << "\nError in op: " << op_name;
-    }
-    msg << "\nExpect the tensor to be " << alignment
+    msg << "Expect the tensor to be " << alignment
         << " bytes aligned. Fail due to storage_offset=" << storage_offset
         << " itemsize=" << itemsize;
     PyErr_SetString(PyExc_AssertionError, msg.str().c_str());

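On the C++ side the revert is mechanical: the PyArg_ParseTuple format strings go from "OOO|s" and "Ok|s" back to "OOO" and "Ok", dropping the optional trailing op-name string, and every branch that appended the op name to an error message is removed. A sketch of the observable Python-level effect (the exact TypeError wording comes from CPython's argument parser and may differ):

import torch
from torch._C._dynamo.guards import assert_size_stride

t = torch.empty((16, 32))

# Still accepted after the revert: the original three-argument form.
assert_size_stride(t, (16, 32), (32, 1))

# With "|s" removed from the parse format, a fourth (op-name) argument no
# longer parses and should raise TypeError at the argument-parsing stage.
try:
    assert_size_stride(t, (16, 32), (32, 1), "torch.ops.dummy.op_name")
except TypeError as err:
    print(err)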