Migrating from functorch to torch.func
======================================

torch.func, previously known as "functorch", provides
`JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

functorch started as an out-of-tree library over at
the `pytorch/functorch <https://github.com/pytorch/functorch>`_ repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.

As the final step of the upstreaming process, we've decided to migrate from being
a top-level package (``functorch``) to being a part of PyTorch to reflect how the
function transforms are integrated directly into PyTorch core. As of PyTorch 2.0,
we are deprecating ``import functorch`` and ask that users migrate to the newest
APIs, which we will maintain going forward. ``import functorch`` will be kept
around to maintain backwards compatibility for a couple of releases.

function transforms
-------------------

The following APIs are drop-in replacements for the corresponding
`functorch APIs <https://pytorch.org/functorch/1.13/functorch.html>`_.
They are fully backwards compatible; see the example after the table.


============================== =======================================
functorch API                  PyTorch API (as of PyTorch 2.0)
============================== =======================================
functorch.vmap                 :func:`torch.vmap` or :func:`torch.func.vmap`
functorch.grad                 :func:`torch.func.grad`
functorch.vjp                  :func:`torch.func.vjp`
functorch.jvp                  :func:`torch.func.jvp`
functorch.jacrev               :func:`torch.func.jacrev`
functorch.jacfwd               :func:`torch.func.jacfwd`
functorch.hessian              :func:`torch.func.hessian`
functorch.functionalize        :func:`torch.func.functionalize`
============================== =======================================
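
For example, only the import changes for these transforms, and they compose with
each other the same way they did in functorch. The function ``f`` below is made
up for illustration::

    import torch
    from torch.func import grad, vmap

    def f(x):
        return x.sin().sum()

    x = torch.randn(3)

    # previously: functorch.grad(f)(x)
    df = grad(f)(x)

    # transforms still compose, e.g. per-sample gradients over a batch of inputs
    per_sample_grads = vmap(grad(f))(torch.randn(5, 3))
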

Furthermore, if you are using torch.autograd.functional APIs, please try out
the :mod:`torch.func` equivalents instead. :mod:`torch.func` function
transforms are more composable and more performant in many cases.

=========================================== =======================================
torch.autograd.functional API               torch.func API (as of PyTorch 2.0)
=========================================== =======================================
:func:`torch.autograd.functional.vjp`       :func:`torch.func.grad` or :func:`torch.func.vjp`
:func:`torch.autograd.functional.jvp`       :func:`torch.func.jvp`
:func:`torch.autograd.functional.jacobian`  :func:`torch.func.jacrev` or :func:`torch.func.jacfwd`
:func:`torch.autograd.functional.hessian`   :func:`torch.func.hessian`
=========================================== =======================================
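
For instance, a call to :func:`torch.autograd.functional.jacobian` maps onto
:func:`torch.func.jacrev` as follows (the function ``exp_reducer`` is made up
for illustration)::

    import torch

    def exp_reducer(x):
        return x.exp().sum(dim=1)

    x = torch.rand(4, 4)

    # torch.autograd.functional API
    jac = torch.autograd.functional.jacobian(exp_reducer, x)

    # torch.func API (as of PyTorch 2.0)
    jac = torch.func.jacrev(exp_reducer)(x)
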

NN module utilities
-------------------

We've changed the APIs for applying function transforms over NN modules to make them
fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.

functorch.make_functional
^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`torch.func.functional_call` is the replacement for
`functorch.make_functional <https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional>`_
and
`functorch.make_functional_with_buffers <https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers>`_.
However, it is not a drop-in replacement.

If you're in a hurry, you can use
`helper functions in this gist <https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf>`_
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
We recommend using :func:`torch.func.functional_call` directly because it is a more explicit
and flexible API.

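
For reference, here is a rough, hypothetical emulation of ``functorch.make_functional``
built on top of :func:`torch.func.functional_call`. Unlike ``functorch.make_functional``,
it carries the parameters as a dictionary rather than a tuple of Tensors::

    import torch

    def make_functional(module):
        # Hypothetical helper for illustration only; not the gist's implementation.
        params = dict(module.named_parameters())

        def fmodel(params, *args, **kwargs):
            return torch.func.functional_call(module, params, args, kwargs)

        return fmodel, params

    # usage: fmodel, params = make_functional(model); output = fmodel(params, inputs)
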
Concretely, functorch.make_functional returns a functional module and parameters.
The functional module accepts the parameters and the inputs to the model as arguments.
:func:`torch.func.functional_call` allows one to call the forward pass of an existing
module using replacement parameters and buffers, along with the inputs.

Here's an example of how to compute gradients of a model's parameters using functorch
vs :mod:`torch.func`::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)

    def compute_loss(params, inputs, targets):
        prediction = fmodel(params, inputs)
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = functorch.grad(compute_loss)(params, inputs, targets)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())

    def compute_loss(params, inputs, targets):
        prediction = torch.func.functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = torch.func.grad(compute_loss)(params, inputs, targets)

And here's an example of how to compute jacobians of model parameters::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)
    jacobians = functorch.jacrev(fmodel)(params, inputs)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    from torch.func import jacrev, functional_call
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())
    # jacrev computes jacobians of argnums=0 by default.
    # We set it to 1 to compute jacobians of params.
    jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

Note that, to keep memory consumption down, it is important to only carry
around a single copy of your parameters. ``model.named_parameters()`` does not copy
the parameters. If your training loop updates the model's parameters
in-place, then the ``nn.Module`` that is your model holds the only copy of the
parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of parameters: the one in the
dictionary and the one in the ``model``. In this case, you should change
``model`` to not hold memory by converting it to the meta device via
``model.to('meta')``.

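
Here is a rough sketch of that out-of-place pattern, assuming a simple MSE loss and
a hand-written SGD step (both made up for illustration)::

    import torch

    model = torch.nn.Linear(3, 3)
    # Keep the only "real" copy of the parameters in this dictionary.
    params = {name: p.detach().clone() for name, p in model.named_parameters()}
    # Free the memory held by the module itself.
    model.to('meta')

    def compute_loss(params, inputs, targets):
        prediction = torch.func.functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    inputs, targets = torch.randn(64, 3), torch.randn(64, 3)
    grads = torch.func.grad(compute_loss)(params, inputs, targets)

    # Out-of-place update: build a new dictionary instead of mutating the module.
    params = {name: p - 1e-3 * grads[name] for name, p in params.items()}
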
functorch.combine_state_for_ensemble
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please use :func:`torch.func.stack_module_state` instead of
`functorch.combine_state_for_ensemble <https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html>`_.
:func:`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters and
one of stacked buffers, which can then be used with :func:`torch.vmap` and :func:`torch.func.functional_call`
for ensembling.

For example, here is how to ensemble over a very simple model::

    import torch
    num_models = 5
    batch_size = 64
    in_features, out_features = 3, 3
    models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]
    data = torch.randn(batch_size, 3)

    # ---------------
    # using functorch
    # ---------------
    import functorch
    fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
    output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import copy

    # Construct a version of the model with no memory by putting the Tensors on
    # the meta device.
    base_model = copy.deepcopy(models[0])
    base_model.to('meta')

    params, buffers = torch.func.stack_module_state(models)

    # It is possible to vmap directly over torch.func.functional_call,
    # but wrapping it in a function makes it clearer what is going on.
    def call_single_model(params, buffers, data):
        return torch.func.functional_call(base_model, (params, buffers), (data,))

    output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)


functorch.compile
-----------------

We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you are a user, please use
:func:`torch.compile` instead.
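
For example, compiling a model now looks like this (the model below is made up
for illustration)::

    import torch

    model = torch.nn.Linear(3, 3)
    compiled_model = torch.compile(model)

    output = compiled_model(torch.randn(64, 3))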