Skip to content

Commit b3e4229

Browse files
yanboliangpytorchmergebot
authored andcommitted
[Dynamo] Support out variants of ops mutate the tensors out of the function frame (pytorch#93177)
Fixes pytorch#93136 Pull Request resolved: pytorch#93177 Approved by: https://github.com/jansel
1 parent 129f136 commit b3e4229

File tree

3 files changed

+44
-4
lines changed

3 files changed

+44
-4
lines changed

test/dynamo/test_dynamic_shapes.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ def make_dynamic_cls(cls):
5555
# Cannot call sizes() on tensor with symbolic sizes/strides
5656
)
5757

58+
unittest.expectedFailure(
59+
DynamicShapesReproTests.test_sort_out2_dynamic_shapes
60+
# Cannot call sizes() on tensor with symbolic sizes/strides
61+
)
62+
5863
# DynamicShapesExportTests
5964
unittest.expectedFailure(
6065
DynamicShapesExportTests.test_export_with_constant_list_nonzero_dynamic_shapes

test/dynamo/test_repros.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,24 @@ def fn():
15311531
opt_fn = torch._dynamo.optimize("eager")(fn)
15321532
opt_fn()
15331533

1534+
def test_sort_out2(self):
1535+
class MyModule(torch.nn.Module):
1536+
def __init__(self):
1537+
super().__init__()
1538+
self.register_buffer("sorted", torch.ones(4, 4))
1539+
self.register_buffer("indices", torch.ones(4, 4, dtype=torch.long))
1540+
1541+
def forward(self, x):
1542+
torch.sort(x, out=(self.sorted, self.indices))
1543+
return (x + 1, self.sorted, self.indices)
1544+
1545+
x = torch.randn(4, 4)
1546+
m = MyModule()
1547+
ref = m(x)
1548+
opt_m = torch._dynamo.optimize("eager")(m)
1549+
res = opt_m(x)
1550+
self.assertTrue(same(ref, res))
1551+
15341552
def test_sigmoid_out(self):
15351553

15361554
dtype = torch.float32
@@ -1546,6 +1564,23 @@ def fn():
15461564
opt_fn = torch._dynamo.optimize("eager")(fn)
15471565
opt_fn()
15481566

1567+
def test_sigmoid_out2(self):
1568+
class MyModule(torch.nn.Module):
1569+
def __init__(self):
1570+
super().__init__()
1571+
self.register_buffer("base", torch.ones(4, 4))
1572+
1573+
def forward(self, x):
1574+
torch.sigmoid(x, out=self.base)
1575+
return x + self.base
1576+
1577+
x = torch.randn(4, 4)
1578+
m = MyModule()
1579+
ref = m(x)
1580+
opt_m = torch._dynamo.optimize("eager")(m)
1581+
res = opt_m(x)
1582+
self.assertTrue(same(ref, res))
1583+
15491584
def test_slice_into_list_mutable(self):
15501585
class Mod(torch.nn.Module):
15511586
def forward(self, listy):

torch/_dynamo/variables/torch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -491,13 +491,13 @@ def get_state_from_generator():
491491
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
492492
]
493493
for idx, name in enumerate(output_tensor_names):
494-
assert name in tx.symbolic_locals
495-
tx.symbolic_locals[name] = tensor_variable.items[idx]
494+
if name in tx.symbolic_locals:
495+
tx.symbolic_locals[name] = tensor_variable.items[idx]
496496
elif isinstance(tensor_variable, TensorVariable):
497497
assert isinstance(kwargs["out"], TensorVariable)
498498
name = tx.find_symbolic_locals_name(kwargs["out"])
499-
assert name in tx.symbolic_locals
500-
tx.symbolic_locals[name] = tensor_variable
499+
if name in tx.symbolic_locals:
500+
tx.symbolic_locals[name] = tensor_variable
501501
else:
502502
unimplemented(f"out variant of {type(kwargs['out'])}")
503503

0 commit comments

Comments
 (0)