@@ -12916,6 +12916,7 @@ def forward(self, x, y):
     @testing.expectedFailureCppSerDes  # TODO: When we deserialize we somehow hardcode sympy.lower to 2
     @testing.expectedFailureSerDerNonStrict
     @testing.expectedFailureSerDer
+    @torch.fx.experimental._config.patch(backed_size_oblivious=True)
     def test_baddbmm(self):
         class M(torch.nn.Module):
             def __init__(self):
@@ -12934,7 +12935,6 @@ def forward(self, x):
         x2 = torch.randn(64, 1, 64, dtype=torch.float16)
         m = M()
 
-        torch.fx.experimental._config.backed_size_oblivious = True
         ep = export(m, (x2,), dynamic_shapes=({1: Dim("batch")},))
 
         self.assertTrue(torch.allclose(m(x2), ep.module()(x2)))
@@ -13400,6 +13400,7 @@ def forward(self, x):
         self.assertTrue(torch.allclose(comp_mod(inp1), mod(inp1)))
         self.assertTrue(torch.allclose(comp_mod(inp2), mod(inp2)))
 
+    @torch.fx.experimental._config.patch(backed_size_oblivious=True)
     def test_repeat_interleave(self):
         class M(torch.nn.Module):
             def forward(self, values, batch_sizes):
@@ -13411,7 +13412,6 @@ def forward(self, values, batch_sizes):
                 )
 
         inp = (torch.randint(0, 10, (1, 3)), torch.randint(0, 10, (1,)))
-        torch.fx.experimental._config.backed_size_oblivious = True
        ep = torch.export.export(
            M(), inp, dynamic_shapes=({0: Dim("dim")}, {0: Dim("dim")})
        )
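Note: the change replaces direct assignment to `torch.fx.experimental._config.backed_size_oblivious` (which leaks the setting into every test that runs afterwards) with the `_config.patch(...)` decorator, which restores the previous value when the test exits. A minimal sketch of the pattern, assuming the flag defaults to False; the `fx_config` alias and `test_something` name are illustrative, not from the diff:

```python
import torch.fx.experimental._config as fx_config

# As a decorator: backed_size_oblivious is flipped on only for the
# duration of the decorated function, then restored automatically,
# even if the function raises.
@fx_config.patch(backed_size_oblivious=True)
def test_something():
    assert fx_config.backed_size_oblivious is True

# The same patch should also work as a context manager for narrower
# scoping inside a test body.
def run_scoped():
    with fx_config.patch(backed_size_oblivious=True):
        pass  # export / compile code that needs size-oblivious reasoning
    # Assumption: the default value is False, so it is restored here.
    assert fx_config.backed_size_oblivious is False
```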