Skip to content

Commit c8f5c19

Browse files
tugsbayasgalanpytorchmergebot
authored andcommitted
Fix bug in dynamic shapes multiply (pytorch#90336)
Pull Request resolved: pytorch#90336 Approved by: https://github.com/ezyang
1 parent 2cf7032 commit c8f5c19

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

test/dynamo/test_repros.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2127,6 +2127,18 @@ def compiled_fn(x):
21272127
for buffer_ref, buffer_test in zip(m_ref.buffers(), m_test.buffers()):
21282128
self.assertTrue(same(buffer_ref, buffer_test))
21292129

2130+
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
2131+
def test_dynamic_shapes_right_side(self):
2132+
def f(x):
2133+
return torch.ones(5 * x.shape[0])
2134+
2135+
inp = torch.randn(6, 5)
2136+
2137+
gm, _ = torch._dynamo.export(
2138+
f, torch.randn(4, 5), aten_graph=True, tracing_mode="symbolic"
2139+
)
2140+
self.assertEqual(gm(inp).shape, f(inp).shape)
2141+
21302142

21312143
if __name__ == "__main__":
21322144
from torch._dynamo.test_case import run_tests

torch/_dynamo/variables/builtin.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,11 @@ def call_mul(self, tx, a, b):
576576
return b.__class__(
577577
items=b.items * a.as_python_constant(), mutable_local=MutableLocal()
578578
).add_options(self, a, b)
579+
# TODO this doesn't generalize in other builtin operators.
580+
elif isinstance(a, variables.ConstantVariable) and isinstance(
581+
b, DynamicShapeVariable
582+
):
583+
return b.call_method(tx, "__rmul__", [a], {})
579584
else:
580585
return a.call_method(tx, "__mul__", [b], {})
581586

0 commit comments

Comments
 (0)