
Commit 388b245

soulitzer authored and pytorchmergebot committed
Expose autograd.graph.Node as an abstract base class (pytorch#91475)
This PR:

- registers all of the codegened Nodes to the torch._C._functions module; this is where special nodes like AccumulateGrad are already registered
- creates an autograd.graph.Node abstract base class that all of the newly registered nodes subclass from. We make the subclassing happen by implementing the ``__subclasshook__`` method
- enables static type checking to work and also enables Sphinx to generate documentation for Node and its methods
- handles both the custom Function and codegened cases

Pull Request resolved: pytorch#91475
Approved by: https://github.com/albanD
1 parent 0157e2e commit 388b245
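
A quick illustration of the behavior this enables (a minimal sketch, assuming a PyTorch build that includes this change; tensor values are arbitrary): both codegen'd grad_fn nodes and custom autograd.Function nodes now pass isinstance/issubclass checks against torch.autograd.graph.Node, even though Node never appears in their MRO.

    import torch
    from torch.autograd.graph import Node

    a = torch.tensor([1.0, 2.0], requires_grad=True)
    b = a.exp()

    # Codegen'd node (ExpBackward0) is recognized as a virtual subclass of Node.
    print(isinstance(b.grad_fn, Node))    # True
    print(Node in type(b.grad_fn).mro())  # False: subclassing comes from __subclasshook__

    class Double(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x * 2

        @staticmethod
        def backward(ctx, grad_out):
            return grad_out * 2

    # Custom Function backward nodes (BackwardCFunction) are treated as Nodes too.
    out = Double.apply(a)
    print(isinstance(out.grad_fn, Node))  # True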

9 files changed, +198 -12 lines changed


docs/source/autograd.rst

+31 -3

@@ -260,12 +260,40 @@ Anomaly detection
 .. autoclass:: set_detect_anomaly
 
 
-Saved tensors default hooks
-^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Autograd graph
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Autograd exposes methods that allow one to inspect the graph and interpose behavior during
+the backward pass.
+
+The ``grad_fn`` attribute of a :class:`torch.Tensor` holds a :class:`torch.autograd.graph.Node`
+if the tensor is the output of a operation that was recorded by autograd (i.e., grad_mode is
+enabled and at least one of the inputs required gradients), or ``None`` otherwise.
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    graph.Node.name
+    graph.Node.metadata
+    graph.Node.next_functions
+    graph.Node.register_hook
+    graph.Node.register_prehook
 
 Some operations need intermediary results to be saved during the forward pass
 in order to execute the backward pass.
-You can define how these saved tensors should be packed / unpacked using hooks.
+These intermediary results are saved as attributes on the ``grad_fn`` and can be accessed.
+For example::
+
+    >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
+    >>> b = a.exp()
+    >>> print(isinstance(b.grad_fn, torch.autograd.graph.Node))
+    True
+    >>> print(dir(b.grad_fn))
+    ['__call__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '_raw_saved_result', '_register_hook_dict', '_saved_result', 'metadata', 'name', 'next_functions', 'register_hook', 'register_prehook', 'requires_grad']
+    >>> print(torch.allclose(b.grad_fn._saved_result, b))
+    True
+
+You can also define how these saved tensors should be packed / unpacked using hooks.
 A common application is to trade compute for memory by saving those intermediary results
 to disk or to CPU instead of leaving them on the GPU. This is especially useful if you
 notice your model fits on GPU during evaluation, but not training.
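
Beyond the doctest above, the new ``next_functions`` and ``name`` APIs make it straightforward to walk the recorded graph; a minimal sketch (the ``walk_graph`` helper is illustrative, not part of the API):

    import torch
    from torch.autograd.graph import Node

    def walk_graph(node, depth=0):
        # Recursively print the backward graph reachable from `node`.
        if node is None:
            return
        print("  " * depth + node.name())
        for next_node, _input_nr in node.next_functions:
            walk_graph(next_node, depth + 1)

    a = torch.rand(3, requires_grad=True)
    b = (a * 2).sum()
    assert isinstance(b.grad_fn, Node)
    walk_graph(b.grad_fn)
    # Prints something like:
    #   SumBackward0
    #     MulBackward0
    #       AccumulateGrad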

test/test_autograd.py

+45

@@ -6051,6 +6051,51 @@ def backward(ctx, g):
         self.assertEqual(y.grad_fn.saved_tensors, ())
         self.assertEqual(y.grad_fn._raw_saved_tensors, ())
 
+    def test_autograd_node_isinstance(self):
+        # Node is a "virtual" base class of codegen'd nodes. This means that
+        # isinstance and issubclass are overridden, but mro is unchanged
+        Node = torch.autograd.graph.Node
+
+        a = torch.rand(3, 3, requires_grad=True)
+        b = a.exp()
+
+        # Some nodes have codegened registrations to the torch._C._function module
+        self.assertIsInstance(b.grad_fn, Node)
+        self.assertTrue(issubclass(type(b.grad_fn), Node))
+        self.assertTrue(Node not in type(b.grad_fn).mro())
+
+        # Other nodes have manual registrations to the torch._C._function module
+        self.assertNotIsInstance(torch._C._functions.AccumulateGrad, Node)
+        self.assertTrue(issubclass(torch._C._functions.AccumulateGrad, Node))
+        self.assertIsInstance(b.grad_fn.next_functions[0][0], Node)
+        self.assertTrue(issubclass(torch._C._functions.DelayedError, Node))
+
+        # Special cases
+        self.assertNotIsInstance(None, Node)
+        self.assertNotIsInstance(1, Node)
+        self.assertNotIsInstance(Node, Node)
+        self.assertTrue(issubclass(Node, Node))
+
+        # Custom function case
+        self.assertTrue(issubclass(torch.autograd.function.BackwardCFunction, Node))
+
+        class Func(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                self.assertIsInstance(ctx, Node)
+                return x
+
+            @staticmethod
+            def backward(ctx, x):
+                self.assertIsInstance(ctx, Node)
+                return x
+
+        out = Func.apply(a)
+        self.assertIsInstance(out.grad_fn, Node)
+        self.assertTrue(issubclass(type(out.grad_fn), Node))
+        self.assertTrue(Node not in type(out.grad_fn).mro())
+        out.sum().backward()
+
     def test_autograd_views_codegen(self):
         # This is not necessarily the absolute correct behavior, but this is the current
         # one. This test is here to make sure that any change to this behavior is detected
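
The test leans on Python's standard virtual-subclass machinery; here is a minimal, PyTorch-free sketch of the same pattern (the registry and class names are made up for illustration), showing why isinstance/issubclass succeed while the MRO stays untouched:

    import abc

    # Toy registry standing in for torch._C._functions.
    _REGISTRY = {}

    class VirtualBase(abc.ABC):
        @classmethod
        def __subclasshook__(cls, C):
            # Any class registered under its own name counts as a subclass.
            if cls is VirtualBase and _REGISTRY.get(C.__name__) is C:
                return True
            return NotImplemented

    class Registered:
        pass

    _REGISTRY["Registered"] = Registered

    print(issubclass(Registered, VirtualBase))    # True, via __subclasshook__
    print(isinstance(Registered(), VirtualBase))  # True
    print(VirtualBase in Registered.__mro__)      # False: the MRO is unchanged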

tools/autograd/gen_autograd_functions.py

+4 -3

@@ -128,7 +128,7 @@
 PY_FUNCTION_DEFINITION = CodeTemplate(
     """\
 static PyTypeObject ${op}Class;
-addClass<${op}>(${op}Class, "${op}", ${op}_properties);
+addClass<${op}>(module, ${op}Class, "${op}", ${op}_properties);
 """
 )
 
@@ -432,11 +432,12 @@ def gen_autograd_functions_python(
         "generated_comment": "@"
         + f"generated from {fm.template_dir_for_comments()}/python_functions.h",
         "shard_forward_declare": [
-            f"void initialize_autogenerated_functions_{i}();"
+            f"void initialize_autogenerated_functions_{i}(PyObject* module);"
             for i in range(num_shards)
         ],
         "shard_call": [
-            f"initialize_autogenerated_functions_{i}();" for i in range(num_shards)
+            f"initialize_autogenerated_functions_{i}(module);"
+            for i in range(num_shards)
         ],
     },
 )
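
The effect of threading ``module`` through the codegen (together with the ``PyModule_AddObject`` call added to ``addClass`` in the template below) is that each generated node class is exposed on ``torch._C._functions`` under its own name, which is exactly the lookup that ``Node.__subclasshook__`` performs. A minimal check, assuming a build with this change:

    import torch

    b = torch.rand(2, requires_grad=True).exp()
    node_type = type(b.grad_fn)  # e.g. ExpBackward0

    # The codegen'd class should now be reachable from torch._C._functions.
    exposed = getattr(torch._C._functions, node_type.__name__, None)
    print(node_type.__name__, exposed is node_type)  # e.g. ExpBackward0 True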

tools/autograd/templates/python_functions.cpp

+3 -2

@@ -19,17 +19,18 @@
 namespace torch { namespace autograd { namespace generated {
 
 template<typename C>
-static void addClass(PyTypeObject& type, const char* name,
+static void addClass(PyObject* module, PyTypeObject& type, const char* name,
   PyGetSetDef* function_properties=NULL, PyMethodDef* function_methods=NULL)
 {
   _initFunctionPyTypeObject(type, name, function_properties, function_methods);
   Py_INCREF(&type);
+  PyModule_AddObject(module, name, (PyObject*)&type);
   registerCppFunction(typeid(C), &type);
 }
 
 ${py_function_props_and_getters}
 
-void initialize_autogenerated_functions${shard_id}() {
+void initialize_autogenerated_functions${shard_id}(PyObject* module) {
   ${py_function_initializers}
 }
 
tools/autograd/templates/python_functions.h

+3 -1

@@ -1,5 +1,7 @@
 #pragma once
 
+#include <Python.h>
+
 // ${generated_comment}
 
 // Python bindings for automatically generated autograd functions
@@ -8,7 +10,7 @@ namespace torch { namespace autograd { namespace generated {
 
 ${shard_forward_declare}
 
-inline void initialize_autogenerated_functions() {
+inline void initialize_autogenerated_functions(PyObject* module) {
 ${shard_call}
 }
 

torch/_C/__init__.pyi.in

+3 -2

@@ -3,6 +3,7 @@
 import torch
 from torch.package import PackageExporter
 from torch import Tensor
+from torch.autograd.graph import Node as _Node
 from enum import Enum
 from pathlib import Path
 from typing import (
@@ -1178,7 +1179,7 @@ class _TensorBase(metaclass=_TensorMeta):
     _version: _int
     _base: Optional[Tensor]
     _cdata: _int
-    grad_fn: Any
+    grad_fn: _Node
     _grad_fn: Any
     _grad: Optional[Tensor]
     grad: Optional[Tensor]
@@ -1542,7 +1543,7 @@ def _activate_cuda_trace() -> None: ...
 
 # Defined in torch/csrc/Module.cpp
 def _current_graph_task_id() -> _int: ...
-def _current_autograd_node() -> Any: ...
+def _current_autograd_node() -> _Node: ...
 
 class _OutOfMemoryError:
     pass
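
With ``grad_fn`` typed as a ``Node`` in the stubs, code that inspects it can be checked statically; a small sketch (the ``describe`` helper is hypothetical, and ``Optional`` is used because ``grad_fn`` is ``None`` at runtime for leaf tensors):

    from typing import Optional

    import torch
    from torch.autograd.graph import Node

    def describe(node: Optional[Node]) -> str:
        # Node's methods (name, next_functions, ...) are now visible to type checkers.
        if node is None:
            return "<leaf or non-differentiable>"
        return f"{node.name()} with {len(node.next_functions)} inputs"

    t = torch.rand(2, requires_grad=True)
    print(describe((t * 3).grad_fn))  # e.g. MulBackward0 with 2 inputs
    print(describe(t.grad_fn))        # <leaf or non-differentiable>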

torch/autograd/graph.py

+107

@@ -5,15 +5,122 @@
 from torch.utils._python_dispatch import TorchDispatchMode
 from collections import defaultdict
 import weakref
+import abc
 
 __all__ = [
     "saved_tensors_hooks",
     "save_on_cpu",
     "disable_saved_tensors_hooks",
     "register_multi_grad_hook",
     "allow_mutation_on_saved_tensors",
+    "Node",
 ]
 
+class Node(abc.ABC):
+    @abc.abstractmethod
+    def name(self) -> str:
+        r"""Returns the name.
+
+        Example::
+
+            >>> import torch
+            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
+            >>> b = a.clone()
+            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
+            >>> print(b.grad_fn.name())
+            CloneBackward0
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def next_functions(self) -> Tuple[Tuple[Optional['Node'], int], ...]:
+        ...
+
+    @abc.abstractmethod
+    def metadata(self) -> dict:
+        r"""Returns the metadata."""
+        ...
+
+    @abc.abstractmethod
+    def _register_hook_dict(self, tensor: torch.Tensor) -> None:
+        ...
+
+    @abc.abstractmethod
+    def register_hook(self, fn: Callable[..., Any]) -> RemovableHandle:
+        r"""Registers a backward hook.
+
+        The hook will be called every time a gradient with respect to the
+        Node is computed. The hook should have the following signature::
+
+            hook(grad_inputs: Tuple[Tensor], grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None
+
+
+        The hook should not modify its argument, but it can optionally return
+        a new gradient which will be used in place of :attr:`grad_outputs`.
+
+        This function returns a handle with a method ``handle.remove()``
+        that removes the hook from the module.
+
+        Example::
+
+            >>> import torch
+            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
+            >>> b = a.clone()
+            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
+            >>> handle = b.grad_fn.register_hook(lambda gI, gO: (gO[0] * 2,))
+            >>> b.sum().backward(retain_graph=True)
+            >>> print(a.grad)
+            tensor([2., 2., 2.])
+            >>> handle.remove() # Removes the hook
+            >>> a.grad = None
+            >>> b.sum().backward(retain_graph=True)
+            >>> print(a.grad)
+            tensor([1., 1., 1.])
+        """
+        ...
+
+    @abc.abstractmethod
+    def register_prehook(self, fn: Callable[..., Any]) -> RemovableHandle:
+        r"""Registers a backward pre-hook.
+
+        The hook will be called every time a gradient with respect to the
+        Node is computed. The hook should have the following signature::
+
+            hook(grad_outputs: Tuple[Tensor]) -> Tuple[Tensor] or None
+
+
+        The hook should not modify its argument, but it can optionally return
+        a new gradient which will be used in place of :attr:`grad_outputs`.
+
+        This function returns a handle with a method ``handle.remove()``
+        that removes the hook from the module.
+
+        Example::
+
+            >>> a = torch.tensor([0., 0., 0.], requires_grad=True)
+            >>> b = a.clone()
+            >>> assert isinstance(b.grad_fn, torch.autograd.graph.Node)
+            >>> handle = b.grad_fn.register_prehook(lambda gI: (gI[0] * 2,))
+            >>> b.sum().backward(retain_graph=True)
+            >>> print(a.grad)
+            tensor([2., 2., 2.])
+            >>> handle.remove()
+            >>> a.grad = None
+            >>> b.sum().backward(retain_graph=True)
+            >>> print(a.grad)
+            tensor([1., 1., 1.])
+        """
+        ...
+
+    @classmethod
+    def __subclasshook__(cls, C):
+        if cls is Node:
+            if ((C is not None and C is getattr(torch._C._functions, C.__name__, None))
+                    or issubclass(C, torch.autograd.function.BackwardCFunction)):
+                return True
+        return NotImplemented
+
 class saved_tensors_hooks():
     """Context-manager that sets a pair of pack / unpack hooks for saved tensors.
 
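
The two hook methods documented above compose; a minimal sketch (assuming a build with this change) where a pre-hook rescales the incoming gradient before the node runs and a post-hook merely observes:

    import torch

    a = torch.ones(3, requires_grad=True)
    b = a * 3  # grad_fn is MulBackward0

    # Pre-hook: sees (and here replaces) grad_outputs before the node executes.
    pre = b.grad_fn.register_prehook(lambda grad_outputs: (grad_outputs[0] * 2,))
    # Post-hook: sees grad_inputs and grad_outputs after the node executes.
    post = b.grad_fn.register_hook(
        lambda grad_inputs, grad_outputs: print("observed grad_outputs:", grad_outputs)
    )

    b.sum().backward()
    print(a.grad)  # tensor([6., 6., 6.]): 3 from the mul, doubled by the pre-hook
    pre.remove()
    post.remove()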

torch/csrc/autograd/functions/init.cpp

+1 -1

@@ -156,7 +156,7 @@ void THPAutograd_initFunctions() {
   static PyTypeObject CopySlicesClass;
   addClass<CopySlices, NoCtor>(module, CopySlicesClass, "CopySlices");
 
-  generated::initialize_autogenerated_functions();
+  generated::initialize_autogenerated_functions(module);
 
   auto c_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
   if (!c_module)

torch/distributed/fsdp/_runtime_utils.py

+1

@@ -1224,6 +1224,7 @@ def _register_post_backward_hooks(
             "register the post-backward hook",
         )
         acc_grad = temp_flat_param.grad_fn.next_functions[0][0]
+        assert acc_grad is not None
        hook_handle = acc_grad.register_hook(
             functools.partial(_post_backward_hook, state, handle)
         )
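
For context, ``next_functions`` entries are typed ``Optional[Node]``, hence the ``assert`` to narrow the type before ``register_hook``. The same pattern works outside FSDP for getting at a leaf tensor's ``AccumulateGrad`` node; a minimal sketch (the lambda is illustrative):

    import torch

    param = torch.rand(3, requires_grad=True)
    tmp = param.expand_as(param)  # cheap op whose grad_fn points at AccumulateGrad

    acc_grad = tmp.grad_fn.next_functions[0][0]  # AccumulateGrad for `param`
    assert acc_grad is not None  # next_functions entries are Optional[Node]

    handle = acc_grad.register_hook(
        lambda grad_inputs, grad_outputs: print("about to accumulate into param.grad")
    )
    (param * 2).sum().backward()  # hook fires, then param.grad is populated
    handle.remove()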

0 commit comments
