 import torch._dynamo.config as dynamo_config
 import torch._inductor.aoti_eager
 import torch.nn as nn
-from torch._C._dynamo.guards import assert_alignment, assert_size_stride
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.debug_utils import aot_graph_input_parser
 from torch._dynamo.device_interface import get_interface_for_device
@@ -1411,10 +1410,9 @@ def fn(a, b):
         )
         _, code = run_and_get_code(fn, x, y)
         code = " ".join(code)
-        if config.cpp_wrapper:
-            self.assertEqual(code.count("view_dtype"), 3)
-        else:
-            self.assertEqual(code.count("aten.view"), 9)
+        self.assertEqual(
+            code.count("view_dtype" if config.cpp_wrapper else "aten.view"), 3
+        )
 
     def test_add_complex5(self):
         def fn(a, b, alpha):
@@ -11884,82 +11882,6 @@ def fn(x):
             check_lowp=False,
         )
 
-    @requires_gpu()
-    @skip_if_not_triton
-    @skip_if_cpp_wrapper("skip cpp_wrapper tests")
-    @config.patch(implicit_fallbacks=True)
-    def test_generated_code_has_size_stride_assert(self):
-        def foo(x):
-            return 3 * x
-
-        def foo_meta(x):
-            return torch.empty_like(x)
-
-        define_custom_op_for_test("foo", foo, foo_meta)
-
-        def fn(x):
-            a = torch.nn.functional.relu(x)
-            b = torch.ops.test.foo(a)
-            return b
-
-        a = torch.randn((16, 32), device=self.device)
-
-        _, code = run_and_get_code(
-            torch.compile(fn),
-            a,
-        )
-        if not is_dynamic_shape_enabled():
-            FileCheck().check(
-                "assert_size_stride(buf2, (16, 32), (32, 1), 'torch.ops.test.foo.default')"
-            ).run(code[0])
-
-    @requires_gpu()
-    @skip_if_not_triton
-    @skip_if_cpp_wrapper("skip cpp_wrapper tests")
-    @config.patch(implicit_fallbacks=True)
-    def test_generated_code_has_alignment_assert(self):
-        def foo(x):
-            return 3 * x
-
-        def foo_meta(x):
-            return torch.empty_like(x)
-
-        define_custom_op_for_test("foo", foo, foo_meta)
-
-        def fn(x):
-            a = torch.nn.functional.relu(x)
-            b = torch.ops.test.foo(a)
-            return b
-
-        a = torch.randn((16, 32), device=self.device)
-
-        _, code = run_and_get_code(
-            torch.compile(fn),
-            a,
-        )
-        if not is_dynamic_shape_enabled():
-            FileCheck().check(
-                "assert_alignment(buf2, 16, 'torch.ops.test.foo.default')"
-            ).run(code[0])
-
-    def test_assert_size_stride_op_name_pass(self):
-        tensor = torch.empty((16, 32))
-        assert_size_stride(tensor, (16, 32), (32, 1), "torch.ops.dummy.op_name")
-
-    def test_assert_size_stride_op_name_fail(self):
-        tensor = torch.empty((16, 32))
-        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
-            assert_size_stride(tensor, (32, 64), (32, 1), "torch.ops.dummy.op_name")
-
-    def test_assert_alignment_op_name_pass(self):
-        tensor = torch.empty((16, 32))
-        assert_alignment(tensor, 16, "torch.ops.dummy.op_name")
-
-    def test_assert_alignment_op_name_fail(self):
-        tensor = torch.empty((16, 32))
-        with self.assertRaisesRegex(AssertionError, "torch.ops.dummy.op_name"):
-            assert_alignment(tensor, 0, "torch.ops.dummy.op_name")
-
     @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
     @torch._inductor.config.patch(implicit_fallbacks=True)
     def test_custom_op_unbacked_symints(self):
@@ -13092,12 +13014,12 @@ def f(x):
         code = run_and_get_triton_code(f, x)
 
         if is_dynamic_shape_enabled():
-            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1)").check(
-                "assert_size_stride(buf2, (s77, s27), (s27, 1)"
+            FileCheck().check("assert_size_stride(buf1, (s77, s27), (s27, 1))").check(
+                "assert_size_stride(buf2, (s77, s27), (s27, 1))"
             ).run(code)
         else:
-            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1)").check(
-                "assert_size_stride(buf2, (16, 32), (32, 1)"
+            FileCheck().check("assert_size_stride(buf1, (16, 32), (32, 1))").check(
+                "assert_size_stride(buf2, (16, 32), (32, 1))"
             ).run(code)
 
     @requires_cuda
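Note: assert_size_stride is the runtime guard that Inductor emits into generated wrapper code to validate the sizes and strides of fallback-kernel outputs; it is what the FileCheck patterns above match against. A minimal sketch of exercising it directly, assuming the private torch._C._dynamo.guards binding that this test file imports from:

    import torch
    from torch._C._dynamo.guards import assert_size_stride

    t = torch.empty((16, 32))  # contiguous, so strides are (32, 1)
    assert_size_stride(t, (16, 32), (32, 1))  # matches: returns silently
    # A mismatch (e.g. expecting sizes (32, 64)) raises AssertionError.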