Commit 434eb16

ezyang authored and pytorchmergebot committed
Correctly restore pybind11 error_already_set (pytorch#93238)
We would handle py::error_already_set correctly from pybind11 bindings, but not from our regular TH bindings, which meant that anything from an inner pybind11 function call was getting unconditionally transformed into a RuntimeError. Not too many cases where we do this, but PySymNodeImpl was one of them.

To test this, I need to raise a non-RuntimeError from a function which is invoked from pybind11 and then propagated to a non-pybind11 call site. I introduce GuardOnDataDependentSymNode for expressly this purpose (this is how I discovered the bug anyway.)

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#93238
Approved by: https://github.com/Skylion007, https://github.com/albanD
1 parent 3e4d0e8 · commit 434eb16

File tree

5 files changed: +22 −7
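Concretely, here is a minimal sketch of the behavior this commit fixes, mirroring the new test_data_dependent_guard test added below (it assumes a build containing this commit and an installed sympy):

    from torch.fx.experimental.symbolic_shapes import (
        ShapeEnv, GuardOnDataDependentSymNode,
    )

    shape_env = ShapeEnv()
    s0 = shape_env.create_unbacked_symint()

    try:
        bool(s0 == 0)  # forces a guard on a data-dependent (unbacked) value
    except GuardOnDataDependentSymNode as e:
        # Before this commit, the TH binding layer caught the pybind11
        # error and re-raised it as a bare RuntimeError; now the original
        # exception type survives the C++ round trip.
        print(type(e).__name__)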

test/test_autograd.py (+2 −2)

@@ -9768,7 +9768,7 @@ def test_backward_out_of_context(self):
         out = (a**2).sum()

         msg = "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context"
-        with self.assertRaisesRegex(RuntimeError, msg):
+        with self.assertRaisesRegex(AssertionError, msg):
             out.backward()

         # Different context
@@ -9777,7 +9777,7 @@ def test_backward_out_of_context(self):
         out = (a**2).sum()

         with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
-            with self.assertRaisesRegex(RuntimeError, msg):
+            with self.assertRaisesRegex(AssertionError, msg):
                 out.backward()

     def test_disallow_nesting(self):
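The test change above follows from the same mechanism: an AssertionError raised from Python code invoked inside the C++ autograd engine now surfaces unchanged instead of being rewrapped. A hypothetical sketch using the public saved_tensors_hooks API (the failing unpack hook is illustrative, not part of the test):

    import torch

    def unpack_and_fail(saved):
        # Hypothetical hook: raise a non-RuntimeError from Python code
        # that the C++ autograd engine calls during backward.
        raise AssertionError("refusing to unpack saved tensor")

    a = torch.randn(3, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(lambda t: t, unpack_and_fail):
        out = (a ** 2).sum()

    try:
        out.backward()
    except AssertionError:
        # With this commit the AssertionError survives the trip through
        # C++; previously it could come back as a bare RuntimeError.
        print("original exception type preserved")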

test/test_dynamic_shapes.py (+8 −1)

@@ -18,7 +18,8 @@
 from torch.utils._pytree import tree_map
 from torch.fx.experimental import symbolic_shapes
 from torch.fx.experimental.proxy_tensor import make_fx
-from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, sym_sqrt, sym_int, to_node
+from torch.fx.experimental.symbolic_shapes import ShapeEnv, sym_float, guard_int, SymNode, \
+    sym_sqrt, sym_int, to_node, GuardOnDataDependentSymNode
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch import SymInt

@@ -388,6 +389,12 @@ def test_int_conversion(self):
         a0 = create_symint(shape_env, 2)
         self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0))

+    @skipIfNoSympy
+    def test_data_dependent_guard(self):
+        shape_env = ShapeEnv()
+        s0 = shape_env.create_unbacked_symint()
+        self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))
+
     @skipIfNoSympy
     def test_non_overlapping_and_dense(self):
         shape_env = ShapeEnv()
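Because the new exception is a distinct subclass, callers can catch a data-dependent guard specifically and fall back, without also swallowing unrelated RuntimeErrors. A sketch of one such hypothetical caller (again assuming sympy is installed):

    from torch.fx.experimental.symbolic_shapes import (
        ShapeEnv, GuardOnDataDependentSymNode,
    )

    def sym_eq_zero(sym):
        # Hypothetical helper: answer True/False when the value is known,
        # or None when deciding would guard on a data-dependent value.
        try:
            return bool(sym == 0)
        except GuardOnDataDependentSymNode:
            return None

    shape_env = ShapeEnv()
    print(sym_eq_zero(shape_env.create_unbacked_symint()))  # None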

test/test_python_dispatch.py (+2 −2)

@@ -1579,7 +1579,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

         err_msg = "no implementation found for 'torch.ops.aten.sym_stride'"
         e = StridesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
-        with self.assertRaisesRegex(RuntimeError, err_msg):
+        with self.assertRaisesRegex(TypeError, err_msg):
             e.stride()

         e = StridesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
@@ -1631,7 +1631,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

         err_msg = "no implementation found for 'torch.ops.aten.sym_size'"
         e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
-        with self.assertRaisesRegex(RuntimeError, err_msg):
+        with self.assertRaisesRegex(TypeError, err_msg):
             e.size()

         e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
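These assertions tighten because the "no implementation found" error is a TypeError raised from pybind11 code, which the TH bindings previously rewrapped as RuntimeError. A condensed sketch in the spirit of the test's SizesNotImplemented helper (the wrapper-subclass plumbing here is illustrative, not the test's exact helper):

    import torch

    class SizesDeclined(torch.Tensor):
        # Illustrative wrapper subclass: route size queries to
        # __torch_dispatch__ and decline them with NotImplemented.
        @staticmethod
        def __new__(cls, elem):
            return torch.Tensor._make_wrapper_subclass(
                cls, elem.shape, dtype=elem.dtype,
                dispatch_sizes_strides_policy="sizes",
            )

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            return NotImplemented

    t = SizesDeclined(torch.randn(3, 3))
    try:
        t.size()
    except TypeError as e:  # surfaced as RuntimeError before this commit
        print(e)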

torch/csrc/Exceptions.h (+4)

@@ -68,6 +68,10 @@ static inline void PyErr_SetString(PyObject* type, const std::string& message) {
     e.restore(); \
     retstmnt; \
   } \
+  catch (py::error_already_set & e) { \
+    e.restore(); \
+    retstmnt; \
+  } \
   _CATCH_GENERIC_ERROR(IndexError, PyExc_IndexError, retstmnt) \
   _CATCH_GENERIC_ERROR(ValueError, PyExc_ValueError, retstmnt) \
   _CATCH_GENERIC_ERROR(TypeError, PyExc_TypeError, retstmnt) \
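Ordering is the point of this hunk: the py::error_already_set clause has to run before the generic translators and the final catch-all, or the pending Python error would be converted like any other C++ exception. A rough Python analogy of the macro's control flow (names are illustrative, not PyTorch API):

    class ErrorAlreadySet(Exception):
        # Stand-in for py::error_already_set: wraps a Python exception
        # that is already "pending" on the C++ side.
        def __init__(self, exc):
            self.exc = exc

    def end_handle_th_errors(fn):
        # Rough analogy of END_HANDLE_TH_ERRORS: restore or translate
        # known exceptions in order; anything else becomes RuntimeError.
        def wrapper(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except ErrorAlreadySet as e:
                raise e.exc                 # new clause: restore the error as-is
            except (IndexError, ValueError, TypeError):
                raise                       # the _CATCH_GENERIC_ERROR translators
            except Exception as e:
                raise RuntimeError(str(e))  # the catch-all that ate types
        return wrapper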

torch/fx/experimental/symbolic_shapes.py (+6 −2)

@@ -21,6 +21,9 @@

 log = logging.getLogger(__name__)

+class GuardOnDataDependentSymNode(RuntimeError):
+    pass
+
 try:
     import sympy  # type: ignore[import]
     from sympy.printing.precedence import precedence  # type: ignore[import] # noqa: F401
@@ -1064,9 +1067,10 @@ def _make_data_dependent_error(self, expr):
             f"Data dependent variable '{s}' allocated at:\n{s.stack}"
             for s in expr.free_symbols
         )
-        return RuntimeError(
+        return GuardOnDataDependentSymNode(
             f"\n\n{accesses}\n"
-            "RuntimeError: It appears that you're trying to get a value out of symbolic int/float "
+            "GuardOnDataDependentSymNode: It appears that you're trying to get "
+            "a value out of symbolic int/float "
             "whose value is data-dependent (and thus we do not know the true value.) "
             f"The expression we were trying to evaluate is {expr}. "
             "Scroll up to see where each of these data-dependent accesses originally occurred."
