Skip to content

Commit a6b783e

Browse files
eellisonpytorchmergebot
authored andcommitted
Refine conditions under which index.Tensor has a dynamic shape
Pull Request resolved: pytorch#79809 Approved by: https://github.com/davidberard98
1 parent e675c33 commit a6b783e

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

torch/_subclasses/fake_tensor.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,19 @@ def clone(fake_mode, func, input, memory_format=None):
214214
out = torch.ops.aten._to_copy(input.to("meta"), memory_format=memory_format)
215215
return FakeTensor(fake_mode, out, out_device)
216216

217-
@register_op_impl(lambda func: torch.Tag.dynamic_output_shape in func.tags) # type: ignore[attr-defined]
217+
# index.Tensor data-dependent in only some conditions
218+
@register_op_impl(lambda func: torch.Tag.dynamic_output_shape in func.tags # type: ignore[attr-defined]
219+
and func != aten.index.Tensor)
218220
def data_dep_op(fake_mode, func, *args, **kwargs):
219221
raise DynamicOutputShapeException(func)
220222

223+
# Bool Indices get Expanded as Masks
224+
# See: IndexingUtils.h:expandTensors
225+
def check_no_bool_index_tensors(func, self, indices):
226+
for index in indices:
227+
if index is not None and index.dtype in (torch.bool, torch.uint8):
228+
raise DynamicOutputShapeException(func)
229+
221230
# Meta tensors give you the ability to run PyTorch code without having to
222231
# actually do computation through tensors allocated on a `meta` device.
223232
# Because the device is `meta`, meta tensors do not model device propagation.
@@ -437,6 +446,8 @@ def check_non_fake_tensor(x):
437446
if run_impl_check(func):
438447
return op_impl(self, func, *args, **kwargs)
439448

449+
if func == aten.index.Tensor:
450+
check_no_bool_index_tensors(func, *args, **kwargs)
440451

441452
self.in_kernel_invocation = True
442453
try:

0 commit comments

Comments
 (0)