Commit 0b9f56a, committed Jun 12, 2025
Shorten first section, add suggestions, and start ResNet example (WIP)
Still a work in progress, but I significantly reduced the first section and added some helpful images for the computational graph. I also added links for most terms. The WIP section with ResNet I still have to debug. I'm not sure my method for retaining the intermediate gradients is valid. See discussion on pull request.
1 parent 031ba22 commit 0b9f56a

File tree: 3 files changed, +385 -253 lines (the tutorial file below plus two new images, 41.3 KB and 45.6 KB)

‎advanced_source/visualizing_gradients_tutorial.py

Lines changed: 385 additions & 253 deletions
@@ -2,60 +2,75 @@
22
Visualizing Gradients
33
=====================
44
5-
**Author**: `Justin Silver <https://github.com/j-silv>`_
6-
7-
By performance and efficiency reasons, PyTorch does not save the
8-
intermediate gradients when running back-propagation. To visualize the
9-
gradients of these internal layer tensor, we have to explicitly tell
10-
PyTorch to retain those values with the ``retain_grad`` parameter.
5+
**Author:** `Justin Silver <https://github.com/j-silv>`__
6+
7+
When training neural networks with PyTorch, it’s possible to ignore some
8+
of the library’s internal mechanisms. For example, running
9+
backpropagation requires a simple call to ``backward()``. This tutorial
10+
dives into how those gradients are calculated and stored in two
11+
different kinds of PyTorch tensors: leaf vs. non-leaf. It will also
12+
cover how we can extract and visualize gradients at any layer in the
13+
network’s computational graph. By inspecting how information flows from
14+
the end of the network to the parameters we want to optimize, we can
15+
debug issues that occur during training such as `vanishing or exploding
16+
gradients <https://arxiv.org/abs/1211.5063>`__.
1117
1218
By the end of this tutorial, you will be able to:
1319
14-
- Visualize gradients after backward propagation in a neural network
15-
- Differentiate between *leaf* and *non-leaf* tensors
16-
- Know when to use\ ``retain_grad`` vs. ``require_grad``
20+
- Differentiate leaf vs. non-leaf tensors
21+
- Know when to use ``requires_grad`` vs. ``retain_grad``
22+
- Visualize gradients after backpropagation in a neural network
1723
18-
"""
24+
We will start off with a simple network to understand how PyTorch
25+
calculates and stores gradients, and then build on this knowledge to
26+
visualize the gradient flow of a `ResNet
27+
model <https://docs.pytorch.org/vision/2.0/models/resnet.html>`__.
1928
29+
Before starting, it is recommended to have a solid understanding of
30+
`tensors and how to manipulate
31+
them <https://docs.pytorch.org/tutorials/beginner/basics/tensorqs_tutorial.html>`__.
32+
A basic knowledge of `how autograd
33+
works <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html>`__
34+
would also be useful.
2035
21-
######################################################################
22-
# Introduction
23-
# ------------
24-
#
25-
# When training neural networks with PyTorch, it is easy to disregard the
26-
# internal mechanisms of the PyTorch library. For example, to run
27-
# back-propagation the API requires a single call to ``loss.backward()``.
28-
# This tutorial will dive into how exactly those gradients are calculated
29-
# and stored in two different kinds of PyTorch tensors: *leaf*, and
30-
# *non-leaf*. It will also cover how we can extract and visualize
31-
# gradients at any neuron in the computational graph. Some important
32-
# barriers to efficient neural network training are vanishing/exploding
33-
# gradients, which lead to slow training progress and/or broken
34-
# optimization pipelines. Thus, it is important to understand how
35-
# information flows from one end of the network, through the computational
36-
# graph, and finally to the parameters we want to optimize.
37-
#
36+
"""
3837

3938

4039
######################################################################
4140
# Setup
4241
# -----
4342
#
44-
# First, make sure PyTorch is installed and then import the necessary
45-
# libraries
43+
# First, make sure `PyTorch is
44+
# installed <https://pytorch.org/get-started/locally/>`__ and then import
45+
# the necessary libraries.
4646
#
4747

4848
import torch
49+
import torchvision
50+
from torchvision.models import resnet18
4951
import torch.nn as nn
52+
import torch.optim as optim
5053
import torch.nn.functional as F
54+
import matplotlib.pyplot as plt
5155

5256

5357
######################################################################
54-
# Next, we will instantiate an extremely simple network so that we can
55-
# focus on the gradients. This will be an affine layer followed by a ReLU
56-
# activation. Note that the ``requires_grad=True`` is necessary for the
57-
# parameters (``W`` and ``b``) so that PyTorch tracks operations involving
58-
# those tensors. We’ll discuss more about this attribute shortly.
58+
# Next, we will instantiate a simple network so that we can focus on the
59+
# gradients. This will be an affine layer, followed by a ReLU activation,
60+
# and ending with an MSE loss between the prediction and label tensors.
61+
#
62+
# .. math::
63+
#
64+
# \mathbf{y}_{\text{pred}} = \text{ReLU}(\mathbf{x} \mathbf{W} + \mathbf{b})
65+
#
66+
# .. math::
67+
#
68+
# L = \text{MSE}(\mathbf{y}_{\text{pred}}, \mathbf{y})
69+
#
70+
# Note that the ``requires_grad=True`` is necessary for the parameters
71+
# (``W`` and ``b``) so that PyTorch tracks operations involving those
72+
# tensors. We’ll discuss more about this in a future
73+
# `section <#requires-grad>`__.
5974
#
6075

6176
# tensor setup
@@ -70,183 +85,126 @@
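# NOTE: the diff elides most of the tensor setup here. For reference, a
# minimal sketch that is consistent with the names used below (the exact
# shapes and values are assumptions, not taken from the original file):
x = torch.randn(1, 10)                      # input
y = torch.randn(1, 5)                       # label
W = torch.randn(10, 5, requires_grad=True)  # weight parameter
b = torch.randn(1, 5, requires_grad=True)   # bias parameter
z = (x @ W) + b                             # affine layer
y_pred = F.relu(z)                          # ReLU activation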
7085
loss = F.mse_loss(y_pred, y) # scalar loss
7186

7287

73-
######################################################################
74-
# Before we perform back-propagation on this network, we need to know the
75-
# difference between *leaf* and *non-leaf* nodes. This is important
76-
# because the distinction affects how gradients are calculated and stored.
77-
#
78-
79-
8088
######################################################################
8189
# Leaf vs. non-leaf tensors
8290
# -------------------------
8391
#
84-
# The backbone for PyTorch Autograd is a dynamic computational graph which
85-
# keeps a record of input tensor data, all subsequent operations on those
86-
# tensors, and finally the resulting new tensors. It is a directed acyclic
87-
# graph (DAG) which can be used to compute gradients along every node all
88-
# the way from the roots (output tensors) to the leaves (input tensors)
89-
# using the chain rule from calculus.
92+
# After running the forward pass, PyTorch autograd has built up a `dynamic
93+
# computational
94+
# graph <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#computational-graph>`__
95+
# which is shown below. This is a `Directed Acyclic Graph
96+
# (DAG) <https://en.wikipedia.org/wiki/Directed_acyclic_graph>`__ which
97+
# keeps a record of input tensors (leaf nodes), all subsequent operations
98+
# on those tensors, and the intermediate/output tensors (non-leaf nodes).
99+
# The graph is used to compute gradients for each tensor starting from the
100+
# graph roots (outputs) to the leaves (inputs) using the `chain
101+
# rule <https://en.wikipedia.org/wiki/Chain_rule>`__ from calculus:
90102
#
91-
# In the context of a generic DAG then, a *leaf* is simply a node which is
92-
# at the input (beginning) of the graph, and *non-leaf* nodes are
93-
# everything else.
103+
# .. math::
94104
#
95-
# To start the generation of the computational graph which can be used for
96-
# gradient calculation, we need to pass in the ``requires_grad=True``
97-
# parameter to the tensor constructors. That is because by default,
98-
# PyTorch is not tracking gradients on any created tensors. To verify
99-
# this, try removing the parameter above and then run back-propagation:
105+
# \mathbf{y} = \mathbf{f}_k\bigl(\mathbf{f}_{k-1}(\dots \mathbf{f}_1(\mathbf{x}) \dots)\bigr)
100106
#
101-
# ::
107+
# .. math::
102108
#
103-
# >>> loss.backward()
104-
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
109+
# \frac{\partial \mathbf{y}}{\partial \mathbf{x}} =
110+
# \frac{\partial \mathbf{f}_k}{\partial \mathbf{f}_{k-1}} \cdot
111+
# \frac{\partial \mathbf{f}_{k-1}}{\partial \mathbf{f}_{k-2}} \cdot
112+
# \cdots \cdot
113+
# \frac{\partial \mathbf{f}_1}{\partial \mathbf{x}}
105114
#
106-
# This runtime error is telling us that the tensor is not tracking
107-
# gradients and has no associated gradient function. Thus, it cannot
108-
# back-propagate to the leaf tensors and calculate the gradients for each
109-
# node.
115+
# .. figure:: /_static/img/visualizing_gradients_tutorial/comp-graph-1.png
116+
# :alt: Computational graph after forward pass
110117
#
111-
# From the above discussion, we can see that ``x``, ``W``, ``b``, and
112-
# ``y`` are leaf tensors, whereas ``z``, ``y_pred``, and ``loss`` are
113-
# non-leaf tensors. We can verify this with the class attribute
114-
# ``is_leaf()``:
118+
# Computational graph after forward pass
115119
#
116120

117-
# prints all True because new tensors are leafs by convention
118-
print(f"{x.is_leaf=}")
119-
print(f"{W.is_leaf=}")
120-
print(f"{b.is_leaf=}")
121-
print(f"{y.is_leaf=}")
122-
123-
# prints all False because tensors are the result of an operation
124-
# with at least one tensor having requires_grad=True
125-
print(f"{z.is_leaf=}")
126-
print(f"{y_pred.is_leaf=}")
127-
print(f"{loss.is_leaf=}")
128-
129121

130122
######################################################################
131-
# The distinction between leaf and non-leaf is important, because that
132-
# attribute determines whether the tensor’s gradient will be stored in the
133-
# ``grad`` property after the backward pass, and thus be usable for
134-
# gradient descent optimization. We’ll cover this some more in the
135-
# following section.
136-
#
137-
# Also note that by convention, when the user creates a new tensor,
138-
# PyTorch automatically makes it a leaf node. This is the case even though
139-
# is no computational graph associated with the tensor. For example:
123+
# PyTorch considers a node to be a *leaf* if it is not the result of a
124+
# tensor operation with at least one input having ``requires_grad=True``
125+
# (e.g. ``x``, ``W``, ``b``, and ``y``), and everything else to be
126+
# *non-leaf* (e.g. ``z``, ``y_pred``, and ``loss``). You can verify this
127+
# programmatically by probing the ``is_leaf`` attribute of the tensors:
140128
#
141129

142-
a = torch.tensor([1.0, 5.0, 2.0])
143-
a.is_leaf
130+
# prints True because newly created tensors are leaves by convention
131+
print(f"{x.is_leaf=}")
132+
133+
# prints False because tensor is the result of an operation with at
134+
# least one input having requires_grad=True
135+
print(f"{z.is_leaf=}")
144136

145137

146138
######################################################################
147-
# Now that we understand what makes a tensor a leaf vs. non-leaf, the
148-
# second piece of the puzzle is knowing when PyTorch calculates and stores
149-
# gradients for the tensors in its computational graph.
139+
# The distinction between leaf and non-leaf determines whether the
140+
# tensor’s gradient will be stored in the ``grad`` property after the
141+
# backward pass, and thus be usable for gradient descent optimization.
142+
# We’ll cover this some more in the `following section <#retain-grad>`__.
143+
#
144+
# Let’s now investigate how PyTorch calculates and stores gradients for
145+
# the tensors in its computational graph.
150146
#
151147

152148

153149
######################################################################
154150
# ``requires_grad``
155-
# =================
151+
# -----------------
156152
#
157-
# To tell PyTorch to explicitly start tracking gradients, when we create
158-
# the tensor, we can pass in the parameter ``requires_grad=True`` to the
159-
# class constructor (by default it is ``False``). This tells PyTorch to
160-
# treat the tensor as a leaf tensor, and all the subsequent operations
161-
# will generate results which also need to require the gradient for
162-
# back-propagation to work. This is because the backward pass uses the
163-
# chain rule from calculus, where intermediate gradients ‘flow’ backward
164-
# through the network.
153+
# To start the generation of the computational graph which can be used for
154+
# gradient calculation, we need to pass in the ``requires_grad=True``
155+
# parameter to a tensor constructor. By default, the value is ``False``,
156+
# and thus PyTorch does not track gradients on any created tensors. To
157+
# verify this, try not setting ``requires_grad``, re-run the forward pass,
158+
# and then run backpropagation. You will see:
165159
#
166-
# We already did this for the parameters we want to optimize, so we’re
167-
# good. If you need to change the property though, you can call
168-
# ``requires_grad_()`` on the tensor to change it (notice the ``_``
169-
# suffix).
160+
# ::
170161
#
171-
# Similar to the analysis above, we can sanity-check which nodes in our
172-
# network have to calculate the gradient for back-propagation to work.
162+
# >>> loss.backward()
163+
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
164+
#
165+
# PyTorch is telling us that because the tensor is not tracking gradients,
166+
# autograd can’t backpropagate to any leaf tensors. If you need to change
167+
# the property, you can call ``requires_grad_()`` on the tensor (notice
168+
# the ’_’ suffix).
169+
#
170+
# We can sanity-check which nodes require gradient calculation, just like
171+
# we did above with the ``is_leaf`` attribute:
173172
#
174173

175-
# prints all False because tensors are leaf nodes
176-
print(f"{x.requires_grad=}")
177-
print(f"{y.requires_grad=}")
178-
179-
# prints all True because requires_grad=True in constructor
180-
print(f"{W.requires_grad=}")
181-
print(f"{b.requires_grad=}")
182-
183-
# prints all True because tensors are non-leaf nodes
184-
print(f"{z.requires_grad=}")
185-
print(f"{y_pred.requires_grad=}")
186-
print(f"{loss.requires_grad=}")
174+
print(f"{x.requires_grad=}") # prints False because requires_grad=False by default
175+
print(f"{W.requires_grad=}") # prints True because we set requires_grad=True in constructor
176+
print(f"{z.requires_grad=}") # prints True because tensor is a non-leaf node
187177

188178

189179
######################################################################
190-
# A useful heuristic to remember is that whenever a tensor is a non-leaf,
191-
# it **has** to have ``requires_grad=True``, otherwise back-propagation
192-
# would fail. If the tensor is a leaf, then it will only have
180+
# It’s useful to remember that by definition a non-leaf tensor has
181+
# ``requires_grad=True``. Backpropagation would fail if this wasn’t the
182+
# case. If the tensor is a leaf, then it will only have
193183
# ``requires_grad=True`` if it was specifically set by the user. Another
194184
# way to phrase this is that if at least one of the inputs to the tensor
195185
# requires the gradient, then it will require the gradient as well.
196186
#
197-
# There are two exceptions to the above guideline:
198-
#
199-
# 1. Using ``nn.Module`` and ``nn.Parameter``
200-
# 2. `Locally disabling gradient computation with context
201-
# managers <https://docs.pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation>`__
187+
# There are two exceptions to this rule:
202188
#
203-
# For the first case, if you subclass the ``nn.Module`` base class, then
204-
# by default all of the parameters of that module will have
205-
# ``requires_grad`` automatically set to ``True``. e.g.:
206-
#
207-
208-
class Model(nn.Module):
209-
def __init__(self) -> None:
210-
super().__init__()
211-
self.conv1 = nn.Conv2d(1, 20, 5)
212-
self.conv2 = nn.Conv2d(20, 20, 5)
213-
214-
def forward(self, x):
215-
x = F.relu(self.conv1(x))
216-
return F.relu(self.conv2(x))
217-
218-
m = Model()
219-
220-
for name, param in m.named_parameters():
221-
print(name, param.requires_grad)
222-
223-
224-
######################################################################
225-
# For the second case, if you wrap one of the gradient context managers
226-
# around a tensor, then computations behave as if none of the inputs
227-
# require grad.
189+
# 1. Any ``nn.Module`` that has ``nn.Parameter`` will have
190+
# ``requires_grad=True`` for its parameters (see
191+
# `here <https://docs.pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#creating-models>`__)
192+
# 2. Locally disabling gradient computation with context managers (see
193+
# `here <https://docs.pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation>`__)
228194
#
229195

230-
z = (x @ W) + b # same as before
231-
232-
with torch.no_grad(): # could also use torch.inference_mode()
233-
z2 = (x @ W) + b
234-
235-
print(f"{z.requires_grad=}")
236-
print(f"{z2.requires_grad=}")
237-
238196

239197
######################################################################
240198
# In summary, ``requires_grad`` tells autograd which tensors need to have
241-
# their gradients calculated for back-propagation to work. This is
199+
# their gradients calculated for backpropagation to work. This is
242200
# different from which gradients have to be stored inside the tensor,
243201
# which is the topic of the next section.
244202
#
245203

246204

247205
######################################################################
248-
# Back-propagation
249-
# ----------------
206+
# ``retain_grad``
207+
# ---------------
250208
#
251209
# To actually perform optimization (e.g. SGD, Adam, etc.), we need to run
252210
# the backward pass so that we can extract the gradients.
@@ -257,8 +215,9 @@ def forward(self, x):
257215

258216
######################################################################
259217
# This single function call populated the ``grad`` property of all leaf
260-
# tensors which had their ``requires_grad=True``. The ``grad`` is the
261-
# gradient of the loss with respect to the tensor we are probing.
218+
# tensors which had ``requires_grad=True``. The ``grad`` is the gradient
219+
# of the loss with respect to the tensor we are probing. Before running
220+
# ``backward()``, this attribute is set to ``None``.
262221
#
263222

264223
print(f"{W.grad=}")
@@ -270,15 +229,15 @@ def forward(self, x):
270229
# check the remaining leaf nodes:
271230
#
272231

232+
# prints all None because requires_grad=False
273233
print(f"{x.grad=}")
274-
print(f"{y.grad=}")
234+
print(f"{y.grad=}")
275235

276236

277237
######################################################################
278-
# Interestingly, these gradients haven’t been populated into the ``grad``
279-
# property and they default to ``None``. This is expected behavior though
280-
# because we did not explicitly tell PyTorch to calculate gradient with
281-
# the ``requires_grad`` parameter.
238+
# The gradients for these tensors haven’t been populated because we did
239+
# not explicitly tell PyTorch to calculate their gradient
240+
# (``requires_grad=False``).
282241
#
283242
# Let’s now look at an intermediate non-leaf node:
284243
#
@@ -288,139 +247,312 @@ def forward(self, x):
288247

289248
######################################################################
290249
# We also get ``None`` for the gradient, but now PyTorch warns us that a
291-
# non-leaf node’s ``grad`` attribute is being accessed. It might come as a
292-
# surprise that we can’t access the gradient for intermediate tensors in
293-
# the computational graph, since they **have** to calculate the gradient
294-
# for back-propagation to work. PyTorch errs on the side of performance
295-
# and assumes that you don’t need to access intermediate gradients if
296-
# you’re trying to optimize leaf tensors. To change this behavior, we can
297-
# use the ``retain_grad()`` function.
250+
# non-leaf node’s ``grad`` attribute is being accessed. Although autograd
251+
# has to calculate intermediate gradients for backpropagation to work, it
252+
# assumes you don’t need to access the values afterwards. To change this
253+
# behavior, we can use the ``retain_grad()`` function on a tensor. This
254+
# tells the autograd engine to populate that tensor’s ``grad`` after
255+
# calling ``backward()``.
298256
#
299257

258+
# we have to re-run the forward pass
259+
z = (x @ W) + b
260+
y_pred = F.relu(z)
261+
loss = F.mse_loss(y_pred, y)
262+
263+
# tell PyTorch to store the gradients after backward()
264+
z.retain_grad()
265+
y_pred.retain_grad()
266+
loss.retain_grad()
267+
268+
# have to zero out gradients otherwise they would accumulate
269+
W.grad = None
270+
b.grad = None
271+
272+
# backpropagation
273+
loss.backward()
274+
275+
# print gradients for all tensors that have requires_grad=True
276+
print(f"{W.grad=}")
277+
print(f"{b.grad=}")
278+
print(f"{z.grad=}")
279+
print(f"{y_pred.grad=}")
280+
print(f"{loss.grad=}")
281+
300282

301283
######################################################################
302-
# ``retain_grad``
303-
# ---------------
284+
# We get the same result for ``W.grad`` as before. Also note that because
285+
# the loss is scalar, the gradient of the loss with respect to itself is
286+
# simply ``1.0``.
304287
#
305-
# When we call ``retain_grad()`` on a tensor, this signals to the autograd
306-
# engine that we want to have that tensor’s ``grad`` populated after
307-
# calling ``backward()``.
288+
# If we look at the state of the computational graph now, we see that the
289+
# ``retains_grad`` attribute has changed for the intermediate tensors. By
290+
# convention, this attribute will print ``False`` for any leaf node, even
291+
# if it requires its gradient.
308292
#
309-
# We can verify that PyTorch is not storing gradients for non-leaf tensors
310-
# by accessing the ``retains_grad`` flag:
293+
# .. figure:: /_static/img/visualizing_gradients_tutorial/comp-graph-2.png
294+
# :alt: Computational graph after backward pass
311295
#
296+
# Computational graph after backward pass
297+
#
298+
312299

313-
# Prints all False because we didn't tell PyTorch to store gradients with `retain_grad()`
314-
print(f"{z.retains_grad=}")
315-
print(f"{y_pred.retains_grad=}")
316-
print(f"{loss.retains_grad=}")
300+
######################################################################
301+
# If you call ``retain_grad()`` on a leaf node that already has
# ``requires_grad=True``, it results in a no-op.
302+
# If we call ``retain_grad()`` on a node that has ``requires_grad=False``,
303+
# PyTorch actually throws an error, since it can’t store the gradient if
304+
# it is never calculated.
305+
#
306+
# ::
307+
#
308+
# >>> x.retain_grad()
309+
# RuntimeError: can't retain_grad on Tensor that has requires_grad=False
310+
#
311+
# In summary, using ``retain_grad()`` and ``retains_grad`` only makes sense
312+
# for non-leaf nodes, since the ``grad`` attribute will already be
313+
# populated for leaf tensors that have ``requires_grad=True``. By default,
314+
# these non-leaf nodes do not retain (store) their gradient after
315+
# backpropagation. We can change that by rerunning the forward pass,
316+
# telling PyTorch to store the gradients, and then performing
317+
# backpropagation.
318+
#
319+
# The following table can be used as a cheat-sheet which summarizes the
320+
# above discussions. The following scenarios are the only ones that are
321+
# valid for PyTorch tensors.
322+
#
323+
#
324+
#
325+
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
326+
# | ``is_leaf`` | ``requires_grad`` | ``retains_grad`` | ``requires_grad_()`` | ``retain_grad()`` |
327+
# +================+========================+========================+===================================================+=====================================+
328+
# | ``True`` | ``False`` | ``False`` | sets ``requires_grad`` to ``True`` or ``False`` | no-op |
329+
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
330+
# | ``True`` | ``True`` | ``False`` | sets ``requires_grad`` to ``True`` or ``False`` | no-op |
331+
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
332+
# | ``False`` | ``True`` | ``False`` | no-op | sets ``retains_grad`` to ``True`` |
333+
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
334+
# | ``False`` | ``True`` | ``True`` | no-op | no-op |
335+
# +----------------+------------------------+------------------------+---------------------------------------------------+-------------------------------------+
336+
#
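# As a quick sanity check of the table, we can print all three attributes for
# a few of the tensors used in this tutorial (each corresponds to one of the
# rows above):
#

for name, t in [("x", x), ("W", W), ("z", z)]:
    print(f"{name}: is_leaf={t.is_leaf}, requires_grad={t.requires_grad}, retains_grad={t.retains_grad}")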
317337

318338

319339
######################################################################
320-
# We can also check the other leaf tensors, but note that by convention,
321-
# this attribute will print ``False`` for any leaf node, even if that
322-
# tensor was set to require its gradient. This is true even if you call
323-
# ``retain_grad()`` on a leaf node that has ``requires_grad=True``, which
324-
# results in a no-op.
340+
# (work-in-progress) Real-world example with ResNet
341+
# -------------------------------------------------
342+
#
343+
# Let’s move on from the toy example above and study a realistic network:
344+
# `ResNet <https://docs.pytorch.org/vision/2.0/models/resnet.html>`__.
345+
#
346+
# To illustrate the importance of gradient visualization, we will
347+
# instantiate two versions of ResNet: one without batch normalization
348+
# (``BatchNorm``), and one with it. `Batch
349+
# normalization <https://arxiv.org/abs/1502.03167>`__ is an extremely
350+
# effective technique to resolve the vanishing/exploding gradients issue,
351+
# and we will be verifying that experimentally.
352+
#
353+
# We first initialize the model without ``BatchNorm``, following the
354+
# `documentation <https://docs.pytorch.org/vision/2.0/models/generated/torchvision.models.resnet18.html>`__.
325355
#
326356

327-
# Prints all False because these are leaf tensors
328-
print(f"{x.retains_grad=}")
329-
print(f"{y.retains_grad=}")
330-
print(f"{b.retains_grad=}")
331-
print(f"{W.retains_grad=}")
357+
# set up dummy data
358+
x = torch.randn(1, 3, 224, 224)
359+
y = torch.randn(1, 1000)
332360

333-
W.retain_grad()
334-
print(f"{W.retains_grad=}") # still False
361+
# init model
362+
# model = resnet18(norm_layer=nn.Identity)
363+
model = resnet18()
364+
model.train()
365+
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
335366

336367

337368
######################################################################
338-
# If we try calling ``retain_grad()`` on a node that has
339-
# ``require_grad=False``, PyTorch actually throws an error.
369+
# Because we are using an ``nn.Module`` instead of individual tensors for
# our forward pass, we need to adapt our method for accessing the
371+
# intermediate gradients. This is done by `registering a
372+
# hook <https://www.digitalocean.com/community/tutorials/pytorch-hooks-gradient-clipping-debugging>`__.
340373
#
341-
# ::
374+
# Note that using backward pass hooks to probe an intermediate node’s
375+
# gradient is preferred over using ``retain_grad()``. It avoids the memory
376+
# retention overhead if gradients aren’t needed after backpropagation. It
377+
# also lets you modify and/or clamp gradients during the backward pass, so
378+
# they don’t vanish or explode.
342379
#
343-
# >>> x.retain_grad()
344-
# RuntimeError: can't retain_grad on Tensor that has requires_grad=False
380+
# The following code defines our forward pass hook (notice the call to
381+
# ``retain_grad()``) and also collects the names of all parameters and layers.
345382
#
346383

384+
def hook_forward(module, args, output):
    output.retain_grad()  # store the gradient in the output tensor

    # outputs and layers are global variables
    outputs.append((layers[module], output))

def get_all_layers(model, hook_fn):
    """Returns dict where keys are leaf modules and values are layer names"""
    layers = dict()
    for name, layer in model.named_modules():
        if not any(layer.children()):
            # keep only leaf modules, skipping Sequential and other wrapper modules
            layers[layer] = name
            layer.register_forward_hook(hook_fn)  # hook_forward
    return layers

def get_all_params(model):
    """Returns list of (name, parameter) tuples with requires_grad=True, excluding bias terms"""
    params = []
    for name, param in model.named_parameters():
        if param.requires_grad and "bias" not in name:
            params.append((name, param))
    return params

# register hooks
layers = get_all_layers(model, hook_forward)

# get parameter gradients
params = get_all_params(model)
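

######################################################################
# As an aside, here is a minimal sketch of the backward-hook alternative
# mentioned above. It records the average absolute gradient flowing out of
# each leaf module without retaining the full gradient tensors
# (``hook_backward`` and ``grad_stats`` are illustrative names, not part of
# the original code):
#

grad_stats = {}  # layer name -> average absolute gradient

def hook_backward(module, grad_input, grad_output):
    # grad_output holds the gradients of the loss w.r.t. this module's outputs
    grad_stats[layers[module]] = grad_output[0].abs().mean().item()

# registering it would look something like this (left commented out, since this
# tutorial uses the forward-hook approach defined above):
# for module in layers:
#     module.register_full_backward_hook(hook_backward)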
413+
347414

348415
######################################################################
349-
# In summary, using ``retain_grad()`` and ``retains_grad`` only make sense
350-
# for non-leaf nodes, since the ``grad`` attribute has to be populated for
351-
# leaf tensors that have ``requires_grad=True``. By default, these
352-
# non-leaf nodes do not retain (store) their gradient after
353-
# back-propagation.
416+
# Let’s check a few of the layers and parameters to make sure things are
417+
# as expected:
354418
#
355-
# We can change that by rerunning the forward pass, telling PyTorch to
356-
# store the gradients, and then performing back-propagation.
419+
420+
num_layers = 5
421+
print("<--------Params-------->")
422+
for name, param in params[0:num_layers]:
423+
print(name, param.shape)
424+
425+
count = 0
426+
print("<--------Layers-------->")
427+
for layer in layers.values():
428+
print(layer)
429+
count += 1
430+
if count >= num_layers:
431+
break
432+
433+
434+
######################################################################
435+
# Now let’s run a forward pass and verify that the ``outputs`` list was
# populated by our forward hooks.
357437
#
358438

359-
# forward pass
360-
z = (x @ W) + b
361-
y_pred = F.relu(z)
362-
loss = F.mse_loss(y_pred, y)
439+
outputs = [] # list with layer name, output tensor tuple
440+
optimizer.zero_grad()
441+
y_pred = model(x)
442+
loss = F.mse_loss(y_pred, y)
363443

364-
# tell PyTorch to store the gradients after backward()
365-
z.retain_grad()
366-
y_pred.retain_grad()
367-
loss.retain_grad()
444+
print("<--------Outputs-------->")
445+
for name, output in outputs[0:num_layers]:
446+
print(name, output.shape)
368447

369-
# have to zero out gradients otherwise they would accumulate
370-
W.grad = None
371-
b.grad = None
372448

373-
# back-propagation
449+
######################################################################
450+
# Everything looks good so far, so let’s call ``backward()``, populate the
451+
# ``grad`` values for all intermediate tensors, and get the average
452+
# gradient for each layer.
453+
#
454+
374455
loss.backward()
375456

376-
# print gradients for all tensors that have requires_grad=True
377-
print(f"{W.grad=}")
378-
print(f"{b.grad=}")
379-
print(f"{z.grad=}")
380-
print(f"{y_pred.grad=}")
381-
print(f"{loss.grad=}")
457+
def get_grads():
458+
layer_idx = []
459+
avg_grads = []
460+
print("<--------Grads-------->")
461+
for idx, (name, output) in enumerate(outputs[0:-2]):
462+
if output.grad is not None:
463+
avg_grad = output.grad.abs().mean()
464+
if idx < num_layers:
465+
print(name, avg_grad)
466+
avg_grads.append(avg_grad)
467+
layer_idx.append(idx)
468+
return layer_idx, avg_grads
469+
470+
layer_idx, avg_grads = get_grads()
382471

383472

384473
######################################################################
385-
# Note we get the same result for ``W.grad`` as before. Also note that
386-
# because the loss is scalar, the gradient of the loss with respect to
387-
# itself is simply ``1.0``.
474+
# Now that we have the average gradients stored in ``avg_grads``, we can plot them
475+
# and see how the average gradient values change as a function of the
476+
# network depth.
388477
#
389478

479+
def plot_grads(layer_idx, avg_grads):
480+
plt.plot(layer_idx, avg_grads)
481+
plt.xlabel("Layer depth")
482+
plt.ylabel("Average gradient")
483+
plt.title("Gradient flow")
484+
plt.grid(True)
485+
486+
plot_grads(layer_idx, avg_grads)
487+
390488

391489
######################################################################
392-
# (work-in-progress) Real-world example - visualizing gradient flow
393-
# -----------------------------------------------------------------
394-
#
395-
# We used a toy example above, but let’s now apply the concepts we learned
396-
# to the visualization of intermediate gradients in a more powerful neural
397-
# network: ResNet.
490+
# Upon initialization, this is not very interesting. Let’s try running
# gradient descent for several epochs and then see how the values change.
398493
#
399494

495+
epochs = 20
496+
497+
for epoch in range(epochs):
498+
outputs = [] # list with layer name, output tensor tuple
499+
optimizer.zero_grad()
500+
y_pred = model(x)
501+
loss = F.mse_loss(y_pred, y)
502+
loss.backward()
503+
optimizer.step()
504+
505+
layer_idx, avg_grads = get_grads()
506+
plot_grads(layer_idx, avg_grads)
507+
508+
509+
######################################################################
510+
# Still not very interesting… it is surprising that the gradients don’t
# accumulate. The intermediate output tensors are probably just recreated
# whenever the forward pass is rerun, and thus their gradients don’t
# accumulate. Let’s see whether that’s also the case for the parameters,
# which are leaf tensors.
514+
#
515+
516+
def get_param_grads():
517+
layer_idx = []
518+
avg_grads = []
519+
print("<--------Params-------->")
520+
for idx, (name, param) in enumerate(params):
521+
if param.grad is not None:
522+
avg_grad = param.grad.abs().mean()
523+
if idx < num_layers:
524+
print(name, avg_grad)
525+
avg_grads.append(avg_grad)
526+
layer_idx.append(idx)
527+
return layer_idx, avg_grads
528+
529+
layer_idx, avg_grads = get_param_grads()
530+
531+
532+
plot_grads(layer_idx, avg_grads)
533+
400534

401535
######################################################################
402536
# (work-in-progress) Conclusion
403537
# -----------------------------
404538
#
405-
# This table can be used as a cheat-sheet which summarizes the above
406-
# discussions. The following scenarios are the only ones that are valid
407-
# for PyTorch tensors.
539+
# If you would like to learn more about how PyTorch’s autograd system
540+
# works, please visit the `references <#references>`__ below. If you have
541+
# any feedback for this tutorial (improvements, typo fixes, etc.) then
542+
# please use the `PyTorch Forums <https://discuss.pytorch.org/>`__ and/or
543+
# the `issue tracker <https://github.com/pytorch/tutorials/issues>`__ to
544+
# reach out.
408545
#
409-
# ============ ================== ================ =================================== =============================
410-
# ``is_leaf`` ``requires_grad`` ``retains_grad`` ``require_grad()`` ``retain_grad()``
411-
# ============ ================== ================ =================================== =============================
412-
# True False False sets ``require_grad`` to True/False no-op
413-
# True True False sets ``require_grad`` to True/False no-op
414-
# False True False no-op sets ``retains_grad`` to True
415-
# False True True no-op no-op
416-
# ============ ================== ================ =================================== =============================
417546

418547

419548
######################################################################
420549
# References
421550
# ----------
422551
#
423-
# https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial
424-
#
425-
# https://docs.pytorch.org/docs/stable/notes/autograd.html#setting-requires-grad
552+
# - `A Gentle Introduction to
553+
# torch.autograd <https://docs.pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html>`__
554+
# - `Automatic Differentiation with
555+
# torch.autograd <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial>`__
556+
# - `Autograd
557+
# mechanics <https://docs.pytorch.org/docs/stable/notes/autograd.html>`__
426558
#
