Skip to content

Commit ec2461b

Browse files
ezyangpytorchmergebot
authored andcommitted
Remove proxy tensor's check for data dependent output (pytorch#93265)
We'll rely on the underlying fake tensor to raise an error in these cases. We only raise the error if there is an input to the data dependent operation that is a real tensor (and thus we are at risk of accidentally burning in real values) Signed-off-by: Edward Z. Yang <[email protected]> Pull Request resolved: pytorch#93265 Approved by: https://github.com/albanD
1 parent d7a3f21 commit ec2461b

File tree

2 files changed

+22
-13
lines changed

2 files changed

+22
-13
lines changed

test/test_proxy_tensor.py

+15-9
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def test_f():
433433
torch.zeros(3), torch.zeros(3)
434434
)
435435

436-
if self.tracing_mode == "symbolic":
436+
if self.tracing_mode != "real":
437437
self.assertRaises(DataDependentOutputException, test_f)
438438
else:
439439
self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
@@ -464,21 +464,27 @@ def f():
464464
blowup = val.repeat(1000)
465465
return bool(blowup.sum().item() == 2)
466466

467-
self.assertRaisesRegex(
468-
RuntimeError, "data-dependent",
469-
lambda: make_fx(f, tracing_mode=self.tracing_mode)()
470-
)
467+
def test_f():
468+
make_fx(f, tracing_mode=self.tracing_mode)()
469+
470+
if self.tracing_mode == "fake":
471+
self.assertRaises(DataDependentOutputException, test_f)
472+
else:
473+
self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
471474

472475
def test_constant_random(self):
473476
def f():
474477
val = torch.tensor([2.0])
475478
val.normal_()
476479
return bool(val.item() == 2.1)
477480

478-
self.assertRaisesRegex(
479-
RuntimeError, "data-dependent",
480-
lambda: make_fx(f, tracing_mode=self.tracing_mode)()
481-
)
481+
def test_f():
482+
make_fx(f, tracing_mode=self.tracing_mode)()
483+
484+
if self.tracing_mode == "fake":
485+
self.assertRaises(DataDependentOutputException, test_f)
486+
else:
487+
self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)
482488

483489
def test_decomposition_interpreter(self):
484490
def fn(x):

torch/fx/experimental/proxy_tensor.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -275,12 +275,15 @@ def can_handle_tensor(x):
275275
)
276276
with maybe_disable_fake_tensor_mode():
277277
return func(*const_args, **const_kwargs)
278-
# For symbolic tracing, we return a SymInt/SymFloat and try to
279-
# get further in the trace
280-
if proxy_mode.tracing_mode != "symbolic":
278+
# If any of the Tensor inputs are "real" (not FakeTensor), we may
279+
# incorrectly burn in constants by allowing this access. Raise
280+
# an error in this case
281+
if pytree.tree_all_only(torch.Tensor, lambda t: not isinstance(t, FakeTensor), (args, kwargs)):
281282
raise RuntimeError(
282283
f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! "
283-
"It's likely that this is caused by data-dependent control flow or similar."
284+
"It's likely that this is caused by data-dependent control flow or similar. "
285+
"It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' "
286+
"in your make_fx call."
284287
)
285288
proxy_args, proxy_kwargs = pytree.tree_map_only(
286289
(SymInt, SymFloat, SymBool),

0 commit comments

Comments
 (0)