
Commit 16f9d1b

zou3519 authored and pytorchmergebot committed
[torch.func] Add migration guide from functorch (pytorch#91811)
Test Plan:
- view preview

Future:
- still need to figure out the make_fx situation

Pull Request resolved: pytorch#91811
Approved by: https://github.com/albanD
1 parent 89f1ad0 commit 16f9d1b

2 files changed (+208, −0 lines)


docs/source/func.migrating.rst

+207
@@ -0,0 +1,207 @@
Migrating from functorch to torch.func
======================================

torch.func, previously known as "functorch", provides
`JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

functorch started as an out-of-tree library over at
the `pytorch/functorch <https://github.com/pytorch/functorch>`_ repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.

As the final step of the upstream process, we've decided to migrate from being a top-level
package (``functorch``) to being part of PyTorch, reflecting how the function transforms are
integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating
``import functorch`` and ask that users migrate to the newest APIs, which we
will maintain going forward. ``import functorch`` will be kept around to maintain
backwards compatibility for a couple of releases.

function transforms
-------------------

The following APIs are drop-in replacements for these
`functorch APIs <https://pytorch.org/functorch/1.13/functorch.html>`_
and are fully backwards compatible.

============================== =======================================
functorch API                  PyTorch API (as of PyTorch 2.0)
============================== =======================================
functorch.vmap                 :func:`torch.vmap` or :func:`torch.func.vmap`
functorch.grad                 :func:`torch.func.grad`
functorch.vjp                  :func:`torch.func.vjp`
functorch.jvp                  :func:`torch.func.jvp`
functorch.jacrev               :func:`torch.func.jacrev`
functorch.jacfwd               :func:`torch.func.jacfwd`
functorch.hessian              :func:`torch.func.hessian`
functorch.functionalize       :func:`torch.func.functionalize`
============================== =======================================

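For instance, moving off a deprecated call is usually just a namespace change. Below is a
minimal sketch (the toy function ``f`` and its inputs are made up for illustration)::

    import torch

    def f(x):
        return x.sin().sum()

    x = torch.randn(3)

    # Previously: functorch.grad(f)(x) and functorch.vmap(functorch.grad(f))(...)
    # As of PyTorch 2.0, the same transforms live under torch.func / torch:
    grad_x = torch.func.grad(f)(x)
    per_sample_grads = torch.vmap(torch.func.grad(f))(torch.randn(8, 3))
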
Furthermore, if you are using torch.autograd.functional APIs, please try out
the :mod:`torch.func` equivalents instead. :mod:`torch.func` function
transforms are more composable and more performant in many cases.

=========================================== =======================================
torch.autograd.functional API               torch.func API (as of PyTorch 2.0)
=========================================== =======================================
:func:`torch.autograd.functional.vjp`       :func:`torch.func.grad` or :func:`torch.func.vjp`
:func:`torch.autograd.functional.jvp`       :func:`torch.func.jvp`
:func:`torch.autograd.functional.jacobian`  :func:`torch.func.jacrev` or :func:`torch.func.jacfwd`
:func:`torch.autograd.functional.hessian`   :func:`torch.func.hessian`
=========================================== =======================================

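As a sketch of what such a swap looks like (the toy function below is only for illustration)::

    import torch

    def f(x):
        return x.sin()

    x = torch.randn(3)

    # torch.autograd.functional API
    jac_old = torch.autograd.functional.jacobian(f, x)

    # torch.func equivalent (reverse-mode; use jacfwd for forward-mode)
    jac_new = torch.func.jacrev(f)(x)

    assert torch.allclose(jac_old, jac_new)
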
NN module utilities
-------------------

We've changed the APIs to apply function transforms over NN modules to make them
fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.

functorch.make_functional
^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`torch.func.functional_call` is the replacement for
`functorch.make_functional <https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional>`_
and
`functorch.make_functional_with_buffers <https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers>`_.
However, it is not a drop-in replacement.

If you're in a hurry, you can use
`helper functions in this gist <https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf>`_
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
We recommend using :func:`torch.func.functional_call` directly because it is a more explicit
and flexible API.

Concretely, functorch.make_functional returns a functional module and parameters;
the functional module accepts the parameters and the model's inputs as arguments.
:func:`torch.func.functional_call` instead calls the forward pass of an existing
module using new parameters, buffers, and inputs.

Here's an example of how to compute gradients of a model's parameters using functorch
vs :mod:`torch.func`::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)

    def compute_loss(params, inputs, targets):
        prediction = fmodel(params, inputs)
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = functorch.grad(compute_loss)(params, inputs, targets)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())

    def compute_loss(params, inputs, targets):
        prediction = torch.func.functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = torch.func.grad(compute_loss)(params, inputs, targets)

And here's an example of how to compute jacobians of model parameters::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)
    jacobians = functorch.jacrev(fmodel)(params, inputs)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    from torch.func import jacrev, functional_call
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())
    # jacrev computes jacobians of argnums=0 by default.
    # We set it to 1 to compute jacobians of params.
    jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

Note that, for memory consumption, it is important to carry around only a single
copy of your parameters. ``model.named_parameters()`` does not copy the parameters.
If your model training updates the parameters in-place, then the ``nn.Module`` that
is your model holds the single copy of the parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of the parameters: the one in the
dictionary and the one in ``model``. In this case, you should stop ``model`` from
holding memory by converting it to the meta device via ``model.to('meta')``.

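Here is one way that might look in practice; the learning rate and the out-of-place
SGD-style update below are illustrative assumptions, not part of the original example::

    import torch
    from torch.func import functional_call, grad

    model = torch.nn.Linear(3, 3)
    # Keep the only real copy of the parameters in this dictionary.
    params = {k: v.detach().clone() for k, v in model.named_parameters()}
    # Free the copy held by the module itself.
    model = model.to('meta')

    inputs, targets = torch.randn(64, 3), torch.randn(64, 3)

    def compute_loss(params, inputs, targets):
        prediction = functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    # Out-of-place update: builds new parameter tensors instead of mutating.
    grads = grad(compute_loss)(params, inputs, targets)
    params = {k: p - 0.1 * grads[k] for k, p in params.items()}
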
functorch.combine_state_for_ensemble
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please use :func:`torch.func.stack_module_state` instead of
`functorch.combine_state_for_ensemble <https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html>`_.
:func:`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters and
one of stacked buffers, that can then be used with :func:`torch.vmap` and :func:`torch.func.functional_call`
for ensembling.

For example, here is how to ensemble over a very simple model::

    import torch
    num_models = 5
    batch_size = 64
    in_features, out_features = 3, 3
    models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
    data = torch.randn(batch_size, 3)

    # ---------------
    # using functorch
    # ---------------
    import functorch
    fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
    output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import copy

    # Construct a version of the model with no memory by putting the Tensors on
    # the meta device.
    base_model = copy.deepcopy(models[0])
    base_model.to('meta')

    params, buffers = torch.func.stack_module_state(models)

    # It is possible to vmap directly over torch.func.functional_call,
    # but wrapping it in a function makes it clearer what is going on.
    def call_single_model(params, buffers, data):
        return torch.func.functional_call(base_model, (params, buffers), (data,))

    output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

functorch.compile
-----------------

We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you were using functorch.compile,
please use :func:`torch.compile` instead.
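For reference, a minimal sketch of the :func:`torch.compile` entry point, used as a decorator
(the toy function is only for illustration)::

    import torch

    @torch.compile
    def f(x):
        return torch.sin(x) + torch.cos(x)

    out = f(torch.randn(10))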

docs/source/func.rst

+1
@@ -52,3 +52,4 @@ Read More
   func.whirlwind_tour
   func.api
   func.ux_limitations
+  func.migrating
