Skip to content

Commit c4e8d62

Browse files
janselpytorchmergebot
authored andcommitted
Improve getitem syntax for TensorType (pytorch#84555)
Allows `TensorType[Dyn, 3, Dyn]` instead of the prior `TensorType[(Dyn, 3, Dyn)]`. Pull Request resolved: pytorch#84555 Approved by: https://github.com/jamesr66a
1 parent fa99b7b commit c4e8d62

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

test/fx/test_gradual_type.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,16 @@ def test_annotations(self):
4444
where n is the corresoinding node in the resulting graph.
4545
"""
4646
class M(torch.nn.Module):
47-
def forward(self, x: TensorType((1, 2, 3, Dyn)), y: Dyn):
48-
return torch.add(x, y)
47+
def forward(self,
48+
x: TensorType((1, 2, 3, Dyn)),
49+
y: Dyn,
50+
z: TensorType[Dyn, 3, Dyn]):
51+
return torch.add(x, y) + z
4952

5053
module = M()
5154
symbolic_traced: torch.fx.GraphModule = symbolic_trace(module)
5255

53-
expected_ph_types = [TensorType((1, 2, 3, Dyn)), Dyn]
56+
expected_ph_types = [TensorType((1, 2, 3, Dyn)), Dyn, TensorType((Dyn, 3, Dyn))]
5457
expected_iter = iter(expected_ph_types)
5558

5659
for n in symbolic_traced.graph.nodes:

torch/fx/tensor_type.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def __eq__(self, other):
2828

2929
@staticmethod
3030
def __class_getitem__(*args):
31-
return TensorType(args[0])
31+
if len(args) == 1 and isinstance(args[0], tuple):
32+
args = args[0]
33+
return TensorType(tuple(args))
3234

3335

3436
class _DynType:

0 commit comments

Comments
 (0)