Commit 98b78aa

zou3519 authored and pytorchmergebot committed
[autograd.Function] setup_context always appears on the Function (pytorch#92312)
Previously, we used the existence of setup_context to decide whether forward should take a ctx object. To be consistent with all the other staticmethods (which always exist on the autograd.Function), this PR changes the check so that whether the user has overridden setup_context decides whether forward should take a ctx object.

Fixes pytorch#91451

Test Plan: existing tests

Pull Request resolved: pytorch#92312
Approved by: https://github.com/albanD, https://github.com/soulitzer
1 parent 00fe63d commit 98b78aa
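
To make the behavior change concrete, here is a minimal sketch of the two styles the commit message contrasts (assuming PyTorch 2.0+; ScaleCombined and ScaleSeparate are illustrative names, not part of this PR). After this change, dispatch keys off whether setup_context is overridden rather than whether the attribute merely exists.

import torch
from torch.autograd import Function

class ScaleCombined(Function):
    # Style 1 ("combined"): setup_context is NOT overridden, so forward
    # receives ctx and is responsible for saving state itself.
    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        return x * scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None

class ScaleSeparate(Function):
    # Style 2 (separate, PyTorch 2.0+): forward does NOT receive ctx;
    # overriding setup_context is what signals this style.
    @staticmethod
    def forward(x, scale):
        return x * scale

    @staticmethod
    def setup_context(ctx, inputs, output):
        _, scale = inputs
        ctx.scale = scale

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.scale, None

x = torch.randn(3, requires_grad=True)
ScaleCombined.apply(x, 2.0).sum().backward()
y = torch.randn(3, requires_grad=True)
ScaleSeparate.apply(y, 2.0).sum().backward()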

5 files changed: +80 additions, -30 deletions

docs/source/notes/extending.func.rst

+2 -2

@@ -25,7 +25,7 @@ This guide assumes you are familiar with :ref:`extending-autograd`,
 which explains how to use :class:`torch.autograd.Function`.
 
 :class:`torch.autograd.Function` can either have a :meth:`~Function.forward` that accepts a ctx object,
-or it can have separate :meth:`~Function.forward` (that does not accept ``ctx``) and a ``setup_context``
+or it can have separate :meth:`~Function.forward` (that does not accept ``ctx``) and a :meth:`~Function.setup_context`
 staticmethod that modifies the ``ctx`` object.
 
 Only the latter is supported with function transforms:
@@ -52,7 +52,7 @@ Depending on the transform,
 
 In order for the :class:`torch.autograd.Function` to be arbitrarily composable with function
 transforms, we recommend that all other staticmethods other than :meth:`~Function.forward` and
-``setup_context`` must be transformable: that is, they must consist of only PyTorch
+:meth:`~Function.setup_context` must be transformable: that is, they must consist of only PyTorch
 operators or call other :class:`torch.autograd.Function` (that may call into C++/CUDA/etc).
 
 Let's go over some examples of common use cases.
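
As a concrete illustration of the doc text above, here is a small sketch (my own example, assuming PyTorch 2.0+ with torch.func available; MyExp is not from the docs) of a Function written in the separate forward/setup_context style being composed with a function transform.

import torch
from torch.autograd import Function

class MyExp(Function):
    @staticmethod
    def forward(x):
        return torch.exp(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Save the output; backward only needs exp(x).
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

x = torch.randn(3)
# grad composes because forward and setup_context are separate and transformable.
g = torch.func.grad(lambda t: MyExp.apply(t).sum())(x)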

docs/source/notes/extending.rst

+14 -14

@@ -52,7 +52,7 @@ How to use
 ^^^^^^^^^^
 Take the following steps:
 1. Subclass :class:`~Function` and implement the :meth:`~Function.forward`,
-   (optional) ``setup_context`` and
+   (optional) :meth:`~Function.setup_context` and
    :meth:`~Function.backward` methods.
 2. Call the proper methods on the `ctx` argument.
 3. Declare whether your function supports
@@ -73,12 +73,12 @@ Take the following steps:
   tensors if there are multiple outputs. Also, please refer to the
   docs of :class:`Function` to find descriptions of useful methods that can be
   called only from :meth:`~Function.forward`.
-- ``setup_context`` (optional). One can either write a "combined" :meth:`~Function.forward` that
+- :meth:`~Function.setup_context` (optional). One can either write a "combined" :meth:`~Function.forward` that
   accepts a ``ctx`` object or (as of PyTorch 2.0) a separate :meth:`~Function.forward` that does
-  not accept ``ctx`` and a ``setup_context`` method where the ``ctx`` modification happens.
-  The :meth:`~Function.forward` should have the compute and ``setup_context`` should
+  not accept ``ctx`` and a :meth:`~Function.setup_context` method where the ``ctx`` modification happens.
+  The :meth:`~Function.forward` should have the compute and :meth:`~Function.setup_context` should
   only be responsible for the ``ctx`` modification (and not have any compute).
-  In general the separate :meth:`~Function.forward` and ``setup_context`` is closer to how
+  In general the separate :meth:`~Function.forward` and :meth:`~Function.setup_context` is closer to how
   PyTorch native operations work and therefore more composable with various PyTorch subsystems.
   See :ref:`combining-forward-context` for more details.
 - :meth:`~Function.backward` (or :meth:`~Function.vjp`) defines the gradient formula.
@@ -234,7 +234,7 @@ And here, we optimize the above example by calling set_materialize_grads(False):
         return grad_output * ctx.constant, None
 
 If you need any "intermediate" Tensors computed in :meth:`~Function.forward` to be saved,
-either they must be returned as outputs, or combine ``forward`` and ``setup_context``
+either they must be returned as outputs, or combine ``forward`` and :meth:`~Function.setup_context`
 (see :ref:`combining-forward-context`).
 Note that this means if you want gradients to flow through those intermediate values, you
 need to define the gradient formula for them (see also
@@ -300,25 +300,25 @@ can use the ``gradgradcheck`` function from the same package to check higher ord
 
 .. _combining-forward-context:
 
-Combined or separate :meth:`~Function.forward` and ``setup_context``
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Combined or separate :meth:`~Function.forward` and :meth:`~Function.setup_context`
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 There are two main ways to define :class:`~Function`. Either:
 
-- define a :meth:`~Function.forward` that combines the forward compute logic with ``setup_context``
-- (as of PyTorch 2.0) define a separate :meth:`~Function.forward` and ``setup_context``.
+- define a :meth:`~Function.forward` that combines the forward compute logic with :meth:`~Function.setup_context`
+- (as of PyTorch 2.0) define a separate :meth:`~Function.forward` and :meth:`~Function.setup_context`
 
-We recommend the second option (separate :meth:`~Function.forward` and ``setup_context``)
+We recommend the second option (separate :meth:`~Function.forward` and :meth:`~Function.setup_context`)
 because that is closer to how PyTorch native operations are implemented and it composes
 with :mod:`torch.func` transforms. However, we plan to support both approaches going forward;
-combining :meth:`~Function.forward` with ``setup_context``: leads to more flexibility since
+combining :meth:`~Function.forward` with :meth:`~Function.setup_context`: leads to more flexibility since
 you are able to save intermediates without returning them as output.
 
 Please see the previous section for how to define :class:`~Function` with separate
-:meth:`~Function.forward` and ``setup_context``.
+:meth:`~Function.forward` and :meth:`~Function.setup_context`.
 
 Here is an example of how to define a :class:`Function` with combined :meth:`~Function.forward` and
-``setup_context``::
+:meth:`~Function.setup_context`::
 
     class LinearFunction(Function):
         @staticmethod
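
The LinearFunction example referenced above is truncated in this hunk. As a stand-in, here is a hedged sketch (my own example, not the docs' LinearFunction) of the combined style, where forward takes ctx and can stash an intermediate without returning it as an output.

import torch
from torch.autograd import Function

class CubeThenSin(Function):
    @staticmethod
    def forward(ctx, x):
        cubed = x ** 3                   # intermediate needed by backward
        ctx.save_for_backward(x, cubed)  # saved without being returned
        return torch.sin(cubed)

    @staticmethod
    def backward(ctx, grad_output):
        x, cubed = ctx.saved_tensors
        # d/dx sin(x^3) = cos(x^3) * 3x^2
        return grad_output * torch.cos(cubed) * 3 * x ** 2

x = torch.randn(4, requires_grad=True)
CubeThenSin.apply(x).sum().backward()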

test/functorch/test_eager_transforms.py

+1 -1

@@ -3037,7 +3037,7 @@ def backward(ctx, gy):
 
         x = torch.randn(3, device=device)
         transform = getattr(functorch, transform)
-        with self.assertRaisesRegex(RuntimeError, 'must have a setup_context'):
+        with self.assertRaisesRegex(RuntimeError, 'must override the setup_context'):
            transform(MySin.apply)(x)
 
    @parametrize('transform', [
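
The updated assertion above checks the new error text. The following sketch (illustrative only, using torch.func.grad rather than the test's parametrized functorch transforms) shows the situation it guards: a combined-style Function used under a transform without overriding setup_context.

import torch
from torch.autograd import Function

class CombinedSin(Function):
    # Combined style: setup_context is not overridden.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * torch.cos(x)

x = torch.randn(3)
try:
    torch.func.grad(lambda t: CombinedSin.apply(t).sum())(x)
except RuntimeError as e:
    # Expected to mention "must override the setup_context" after this PR.
    print(e)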

torch/autograd/function.py

+25 -9

@@ -7,7 +7,7 @@
 import functools
 import warnings
 from collections import OrderedDict
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Tuple
 from torch._functorch.autograd_function import custom_function_call
 
 __all__ = ["FunctionCtx", "BackwardCFunction", "FunctionMeta", "Function", "once_differentiable", "traceable",
@@ -323,8 +323,8 @@ def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
                     pass
 
         - The forward no longer accepts a ctx argument.
-        - Instead, you must also define a setup_context staticmethod to handle setting up the
-          ``ctx`` object.
+        - Instead, you must also override the :meth:`torch.autograd.Function.setup_context`
+          staticmethod to handle setting up the ``ctx`` object.
           ``output`` is the output of the forward, ``inputs`` are a Tuple of inputs
           to the forward.
         - See :ref:`extending-autograd` for more details
@@ -340,6 +340,23 @@ def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
         raise NotImplementedError("You must implement the forward function for custom"
                                   " autograd.Function.")
 
+    @staticmethod
+    def setup_context(ctx: Any, inputs: Tuple[Any], output: Any) -> Any:
+        r"""There are two ways to define the forward pass of an autograd.Function.
+
+        Either:
+
+        1. Override forward with the signature forward(ctx, *args, **kwargs).
+           ``setup_context`` is not overridden. Setting up the ctx for backward
+           happens inside the ``forward``.
+        2. Override forward with the signature forward(*args, **kwargs) and
+           override ``setup_context``. Setting up the ctx for backward happens
+           inside ``setup_context`` (as opposed to inside the ``forward``)
+
+        See :meth:`torch.autograd.Function.forward` and :ref:`extending-autograd` for more details.
+        """
+        raise NotImplementedError("setup_context is not implemented.")
+
     @staticmethod
     def backward(ctx: Any, *grad_outputs: Any) -> Any:
         r"""Defines a formula for differentiating the operation with backward mode
@@ -490,13 +507,12 @@ def apply(cls, *args, **kwargs):
             args = _functorch.utils.unwrap_dead_wrappers(args)
             return super().apply(*args, **kwargs)
 
-        if not hasattr(cls, 'setup_context'):
-            # TODO: link documentation in error message
-            # https://github.com/pytorch/pytorch/issues/90224
+        if cls.setup_context == _SingleLevelFunction.setup_context:
             raise RuntimeError(
-                'In order to use an autograd.Function with functorch transforms ',
-                '(vmap, grad, jvp, jacrev, ...), it must have a setup_context ',
-                'staticmethod.')
+                'In order to use an autograd.Function with functorch transforms '
+                '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
+                'staticmethod. For more details, please see '
+                'https://pytorch.org/docs/master/notes/extending.func.html')
 
         return custom_function_call(cls, *args, **kwargs)
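
Because the base class now defines setup_context, hasattr is always true, so the apply() check above compares the subclass attribute against the base class attribute instead. A generic Python sketch (illustrative class names, not PyTorch code) of that override-detection pattern:

class Base:
    @staticmethod
    def setup_context(ctx, inputs, output):
        raise NotImplementedError

class NotOverridden(Base):
    pass

class Overridden(Base):
    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.data = inputs

# hasattr no longer distinguishes the two cases...
assert hasattr(NotOverridden, "setup_context")
assert hasattr(Overridden, "setup_context")
# ...but comparing against the base attribute does, mirroring
# `cls.setup_context == _SingleLevelFunction.setup_context` above.
assert NotOverridden.setup_context == Base.setup_context
assert Overridden.setup_context != Base.setup_context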

torch/csrc/autograd/python_function.cpp

+38 -4

@@ -871,6 +871,31 @@ THPObjectPtr make_ctx_input_output_tuple(
 
 } // namespace
 
+static PyObject* THPFunction_setup_context = nullptr;
+
+static PyObject* get_base_setup_context() {
+  if (THPFunction_setup_context != nullptr) {
+    return THPFunction_setup_context;
+  }
+
+  auto module = THPObjectPtr(PyImport_ImportModule("torch.autograd.function"));
+  if (!module)
+    return nullptr;
+
+  auto function =
+      THPObjectPtr(PyObject_GetAttrString(module, "_SingleLevelFunction"));
+  if (!function)
+    return nullptr;
+
+  // setup_context gets "leaked" - we return a new reference and hold onto it
+  // forever.
+  auto setup_context = PyObject_GetAttrString(function, "setup_context");
+  if (!setup_context)
+    return nullptr;
+  THPFunction_setup_context = setup_context;
+  return THPFunction_setup_context;
+}
+
 PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
   HANDLE_TH_ERRORS
 
@@ -920,10 +945,19 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
   ctx->needs_input_grad = input_info.needs_input_grad.release();
   ctx->is_variable_input = std::move(input_info.is_variable_input);
 
-  // autograd.Function may optionally contain a setup_context staticmethod.
+  // autograd.Function may optionally override a setup_context staticmethod.
   // In this case, autograd.Function.forward does NOT accept a ctx object.
-  bool has_separate_setup_context_fn =
-      PyObject_HasAttrString(cls, "setup_context");
+  // Determine if this is the case.
+  auto cls_setup_context =
+      THPObjectPtr(PyObject_GetAttrString(cls, "setup_context"));
+  if (!cls_setup_context) {
+    return nullptr;
+  }
+  auto orig_setup_context = get_base_setup_context();
+  if (!orig_setup_context) {
+    return nullptr;
+  }
+  auto overridden_setup_context = cls_setup_context.get() != orig_setup_context;
 
   auto num_args = PyTuple_GET_SIZE(inputs);
 
@@ -935,7 +969,7 @@ PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
     THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
     if (!forward_fn)
       return nullptr;
-    if (has_separate_setup_context_fn) {
+    if (overridden_setup_context) {
      // call forward followed by setup_context
      output = PyObject_CallObject(forward_fn, unpacked_input.input_tuple);
      if (!output) {
