@@ -214,10 +214,19 @@ def clone(fake_mode, func, input, memory_format=None):
     out = torch.ops.aten._to_copy(input.to("meta"), memory_format=memory_format)
     return FakeTensor(fake_mode, out, out_device)

-@register_op_impl(lambda func: torch.Tag.dynamic_output_shape in func.tags)  # type: ignore[attr-defined]
+# index.Tensor is data-dependent in only some conditions
+@register_op_impl(lambda func: torch.Tag.dynamic_output_shape in func.tags  # type: ignore[attr-defined]
+                  and func != aten.index.Tensor)
 def data_dep_op(fake_mode, func, *args, **kwargs):
     raise DynamicOutputShapeException(func)

+# Bool indices get expanded as masks
+# See: IndexingUtils.h:expandTensors
+def check_no_bool_index_tensors(func, self, indices):
+    for index in indices:
+        if index is not None and index.dtype in (torch.bool, torch.uint8):
+            raise DynamicOutputShapeException(func)
+
 # Meta tensors give you the ability to run PyTorch code without having to
 # actually do computation through tensors allocated on a `meta` device.
 # Because the device is `meta`, meta tensors do not model device propagation.
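Why bool and uint8 indices are singled out: a mask index is expanded to integer indices (effectively `mask.nonzero()`, per IndexingUtils.h:expandTensors), so the result's shape depends on the mask's values rather than on its metadata, which is exactly the data a fake tensor does not have. A minimal eager-mode illustration, not part of this diff:

import torch

x = torch.arange(6)

# Integer indices: the output shape equals the index tensor's shape,
# so it is computable from metadata alone.
idx = torch.tensor([0, 3])
print(x[idx].shape)  # torch.Size([2])

# Boolean mask: expanded to integer indices under the hood
# (see IndexingUtils.h:expandTensors), so the output shape depends
# on how many entries are True, i.e. on the tensor's data.
mask = torch.tensor([True, False, True, True, False, False])
print(x[mask].shape)         # torch.Size([3])
print(mask.nonzero().shape)  # torch.Size([3, 1])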
@@ -437,6 +446,8 @@ def check_non_fake_tensor(x):
         if run_impl_check(func):
             return op_impl(self, func, *args, **kwargs)

+        if func == aten.index.Tensor:
+            check_no_bool_index_tensors(func, *args, **kwargs)

         self.in_kernel_invocation = True
         try:
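To see the new guard end to end, here is a hedged sketch (not from the PR) of indexing under FakeTensorMode; the import path and context-manager usage are assumed from current torch._subclasses and have varied across versions:

import torch
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode

with FakeTensorMode():
    # Tensors created inside the mode are fake: metadata only, no real storage.
    x = torch.randn(4, 4)

    # Integer-tensor indices: the output shape follows from metadata, so this is fine.
    out = x[torch.tensor([0, 2])]
    print(out.shape)  # torch.Size([2, 4])

    # Bool indices: check_no_bool_index_tensors raises before any kernel runs,
    # since the output shape would depend on the mask's (nonexistent) data.
    try:
        x[torch.tensor([True, False, True, False])]
    except DynamicOutputShapeException as exc:
        print(f"rejected: {exc}")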