- Visualize gradients after backpropagation in a neural network

We will start off with a simple network to understand how PyTorch
- calculates and stores gradients, and then build on this knowledge to
- visualize the gradient flow of a `ResNet
- model <https://docs.pytorch.org/vision/2.0/models/resnet.html>`__.
+ calculates and stores gradients. Building on this knowledge, we will
+ then visualize the gradient flow of a more complicated model and see the
+ effect that `batch normalization <https://arxiv.org/abs/1502.03167>`__
+ has on the gradient distribution.

Before starting, it is recommended to have a solid understanding of
`tensors and how to manipulate
#

import torch
- import torchvision
- from torchvision.models import resnet18
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# 2. Locally disabling gradient computation with context managers (see
#    `here <https://docs.pytorch.org/docs/stable/notes/autograd.html#locally-disabling-gradient-computation>`__);
#    a brief sketch of this follows below
#
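######################################################################
# For instance, a minimal illustrative sketch of point 2 above, using a
# context manager to skip graph construction (the tensors here are
# throwaway placeholders, not part of the surrounding example):

a = torch.ones(2, 3, requires_grad=True)
with torch.no_grad():
    b = a * 2  # no autograd history is recorded inside the context
print(b.requires_grad)  # prints: False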
-
-
- ######################################################################
# In summary, ``requires_grad`` tells autograd which tensors need to have
# their gradients calculated for backpropagation to work. This is
# different from which gradients have to be stored inside the tensor,


######################################################################
- # (work-in-progress) Real world example with ResNet
- # -------------------------------------------------
+ # Real world example with BatchNorm
+ # ---------------------------------
#
- # Let’s move on from the toy example above and study a realistic network:
- # `ResNet <https://docs.pytorch.org/vision/2.0/models/resnet.html>`__.
+ # Let’s move on from the toy example above and study a more realistic
+ # network. We’ll be creating a network intended for the MNIST dataset,
+ # similar to the architecture described by the `batch normalization
+ # paper <https://arxiv.org/abs/1502.03167>`__.
#
# To illustrate the importance of gradient visualization, we will
- # instantiate two versions of ResNet: one without batch normalization
- # (``BatchNorm``), and one with it. `Batch
- # normalization <https://arxiv.org/abs/1502.03167>`__ is an extremely
+ # instantiate one version of the network with batch normalization
+ # (BatchNorm), and one without it. Batch normalization is an extremely
# effective technique to resolve the vanishing/exploding gradients issue,
# and we will be verifying that experimentally.
#
- # We first initiate the models without ``BatchNorm`` following the
- # `documentation <https://docs.pytorch.org/vision/2.0/models/generated/torchvision.models.resnet18.html>`__.
+ # The model we will use has a specified number of repeating
+ # fully-connected layers which alternate between ``nn.Linear``,
+ # ``norm_layer``, and ``nn.Sigmoid``. If we apply batch normalization,
+ # then ``norm_layer`` will use
+ # `BatchNorm1d <https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html>`__,
+ # otherwise it will use the identity transformation
+ # `Identity <https://docs.pytorch.org/docs/stable/generated/torch.nn.Identity.html>`__.
+ #
+
+ def fc_layer(in_size, out_size, norm_layer):
+     """Return a stack of linear->norm->sigmoid layers"""
+     return nn.Sequential(nn.Linear(in_size, out_size), norm_layer(out_size), nn.Sigmoid())
+
+ class Net(nn.Module):
+     """Define a network that has num_layers of linear->norm->sigmoid transformations"""
+     def __init__(self, in_size=28 * 28, hidden_size=128,
+                  out_size=10, num_layers=3, batchnorm=False):
+         super().__init__()
+         if batchnorm is False:
+             norm_layer = nn.Identity
+         else:
+             norm_layer = nn.BatchNorm1d
+
+         layers = []
+         layers.append(fc_layer(in_size, hidden_size, norm_layer))
+
+         for i in range(num_layers - 1):
+             layers.append(fc_layer(hidden_size, hidden_size, norm_layer))
+
+         layers.append(nn.Linear(hidden_size, out_size))
+
+         self.layers = nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = torch.flatten(x, 1)
+         return self.layers(x)
+
+
+ ######################################################################
+ # Next we set up some dummy data, instantiate two versions of the model,
+ # and initialize the optimizers.
#

# set up dummy data
- x = torch.randn(1, 3, 224, 224)
- y = torch.randn(1, 1000)
+ x = torch.randn(10, 28, 28)
+ y = torch.randint(10, (10,))

# init model
- # model = resnet18(norm_layer=nn.Identity)
- model = resnet18()
- model.train()
- optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+ model_bn = Net(batchnorm=True, num_layers=3)
+ model_nobn = Net(batchnorm=False, num_layers=3)
+
+ model_bn.train()
+ model_nobn.train()
+
+ optimizer_bn = optim.SGD(model_bn.parameters(), lr=0.01, momentum=0.9)
+ optimizer_nobn = optim.SGD(model_nobn.parameters(), lr=0.01, momentum=0.9)
+
+
+
+ ######################################################################
+ # We can verify that batch normalization is only being applied to one of
+ # the models by probing one of the internal layers:
+ #
+
+ print(model_bn.layers[0])
+ print(model_nobn.layers[0])


######################################################################
# Because we are using a ``nn.Module`` instead of individual tensors for
- # our forward pass, we need another adopt our method to access the
- # intermediate gradients. This is done by `registering a
+ # our forward pass, we need another method to access the intermediate
+ # gradients. This is done by `registering a
# hook <https://www.digitalocean.com/community/tutorials/pytorch-hooks-gradient-clipping-debugging>`__.
#
- # Note that using backward pass hooks to probe an intermediate nodes
- # gradient is preferred over using ``retain_grad()``. It avoids the memory
- # retention overhead if gradients aren’t needed after backpropagation. It
- # also lets you modify and/or clamp gradients during the backward pass, so
- # they don’t vanish or explode.
+ # .. warning::
+ #
+ #    Note that using backward pass hooks to probe an intermediate node's gradient is preferred over using ``retain_grad()``.
+ #    It avoids the memory retention overhead if gradients aren't needed after backpropagation.
+ #    It also lets you modify and/or clamp gradients during the backward pass, so they don't vanish
+ #    or explode (a minimal sketch of this appears after this note).
+ #    However, if in-place operations are performed, you cannot use the backward pass hook
+ #    since it wraps the forward pass with views instead of the actual tensors. For more information
+ #    please refer to https://github.com/pytorch/pytorch/issues/61519.
#
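######################################################################
# As an aside, here is a minimal, purely illustrative sketch of clamping
# gradients with a full backward hook; the helper and tensors below are
# hypothetical and are not used in the rest of this tutorial.

def clamp_grads(module, grad_input, grad_output):
    """Clamp the gradients flowing into the module's inputs to [-1, 1]"""
    return tuple(g.clamp(-1.0, 1.0) if g is not None else g for g in grad_input)

inp = torch.randn(3, 4, requires_grad=True)
lin = nn.Linear(4, 2)
lin.register_full_backward_hook(clamp_grads)
lin(inp).sum().backward()  # inp.grad now holds the clamped gradients

######################################################################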
# The following code defines our forward pass hook (notice the call to
- # ``retain_grad()``) and also collects names of all parameters and layers.
+ # ``retain_grad()``) and also gathers descriptive names for the network’s
+ # layers.
#

- def hook_forward(module, args, output):
-     output.retain_grad()  # store gradient in ouput tensors
+ def hook_forward_wrapper(module_name, outputs):
+     """Python function closure so we can pass args"""
+     def hook_forward(module, args, output):
+         """Hook for forward pass which retains gradients and saves intermediate tensors"""
+         output.retain_grad()
+         outputs.append((module_name, output))
+     return hook_forward

-     # grads and layers are global variables
-     outputs.append((layers[module], output))
+ def get_all_layers(model, hook_fn):
+     """Register forward pass hook to all outputs in model

- def get_all_layers(layer, hook_fn):
-     """Returns dict where keys are children modules and values are layer names"""
+     Returns layers, a dict with keys as layer/module and values as layer/module names,
+     e.g.: layers[nn.Linear] = "layers.0.0"
+
+     Returns outputs, a list of tuples with module name and tensor output, e.g.:
+     outputs[0] == ("layers.0.0", torch.Tensor(...))
+
+     The layer name is passed to a forward hook which will eventually go into a tuple
+     """
    layers = dict()
+     outputs = []
    for name, layer in model.named_modules():
        if any(layer.children()) is False:
            # skip Sequential and/or wrapper modules
            layers[layer] = name
-             layer.register_forward_hook(hook_fn)  # hook_forward
-     return layers
+             layer.register_forward_hook(hook_forward_wrapper(name, outputs))
+     return layers, outputs

- def get_all_params(model):
-     """return list of all leaf tensors with requires_grad=True and which are not bias terms"""
-     params = []
-     for name, param in model.named_parameters():
-         if param.requires_grad and "bias" not in name:
-             params.append((name, param))
-     return params
-
- # register hooks
- layers = get_all_layers(model, hook_forward)
-
- # get parameter gradients
- params = get_all_params(model)
-
-
- ######################################################################
- # Let’s check a few of the layers and parameters to make sure things are
- # as expected:
- #
-
- num_layers = 5
- print("<--------Params-------->")
- for name, param in params[0:num_layers]:
-     print(name, param.shape)
-
- count = 0
- print("<--------Layers-------->")
- for layer in layers.values():
-     print(layer)
-     count += 1
-     if count >= num_layers:
-         break
+ # register hooks
+ layers_bn, outputs_bn = get_all_layers(model_bn, hook_forward_wrapper)
+ layers_nobn, outputs_nobn = get_all_layers(model_nobn, hook_forward_wrapper)


######################################################################
- # Now let’s run a forward pass and verify our output tensor values were
- # populated.
+ # Now let’s train the models for a few epochs:
#

- outputs = []  # list with layer name, output tensor tuple
- optimizer.zero_grad()
- y_pred = model(x)
- loss = F.mse_loss(y_pred, y)
+ epochs = 10

- print("<--------Outputs-------->")
- for name, output in outputs[0:num_layers]:
-     print(name, output.shape)
+ for epoch in range(epochs):
+
+     # important to clear, because we append to
+     # outputs every time we do a forward pass
+     outputs_bn.clear()
+     outputs_nobn.clear()
+
+     optimizer_bn.zero_grad()
+     optimizer_nobn.zero_grad()
+
+     y_pred_bn = model_bn(x)
+     y_pred_nobn = model_nobn(x)
+
+     loss_bn = F.cross_entropy(y_pred_bn, y)
+     loss_nobn = F.cross_entropy(y_pred_nobn, y)
+
+     loss_bn.backward()
+     loss_nobn.backward()
+
+     optimizer_bn.step()
+     optimizer_nobn.step()


######################################################################
- # Everything looks good so far, so let’s call ``backward()``, populate the
- # ``grad`` values for all intermediate tensors, and get the average
- # gradient for each layer.
+ # After running the forward and backward pass, the ``grad`` values for all
+ # the intermediate tensors should be present in ``outputs_bn`` and
+ # ``outputs_nobn``. We reduce the gradient matrix to a single number (mean
+ # absolute value) so that we can compare the two models.
#

- loss.backward()
-
- def get_grads():
+ def get_grads(outputs):
    layer_idx = []
    avg_grads = []
-     print("<--------Grads-------->")
-     for idx, (name, output) in enumerate(outputs[0:-2]):
+     for idx, (name, output) in enumerate(outputs):
        if output.grad is not None:
            avg_grad = output.grad.abs().mean()
-             if idx < num_layers:
-                 print(name, avg_grad)
            avg_grads.append(avg_grad)
            layer_idx.append(idx)
    return layer_idx, avg_grads

- layer_idx, avg_grads = get_grads()
+ layer_idx_bn, avg_grads_bn = get_grads(outputs_bn)
+ layer_idx_nobn, avg_grads_nobn = get_grads(outputs_nobn)


######################################################################
- # Now that we have all our gradients stored in ``grads``, we can plot them
- # and see how the average gradient values change as a function of the
- # network depth.
+ # Now that we have all our gradients stored in ``avg_grads_bn`` and
+ # ``avg_grads_nobn``, we can plot them and see how the average gradient
+ # values change as a function of the network depth. We see that when we
+ # don’t have batch normalization, the gradient values in the intermediate
+ # layers fall to zero very quickly. The batch normalization model, however,
+ # maintains non-zero gradients in its intermediate layers.
#

- def plot_grads(layer_idx, avg_grads):
-     plt.plot(layer_idx, avg_grads)
-     plt.xlabel("Layer depth")
-     plt.ylabel("Average gradient")
-     plt.title("Gradient flow")
-     plt.grid(True)
-
- plot_grads(layer_idx, avg_grads)
+ fig, ax = plt.subplots()
+ ax.plot(layer_idx_bn, avg_grads_bn, label="With BatchNorm", marker="o")
+ ax.plot(layer_idx_nobn, avg_grads_nobn, label="Without BatchNorm", marker="x")
+ ax.set_xlabel("Layer depth")
+ ax.set_ylabel("Average gradient")
+ ax.set_title("Gradient flow")
+ ax.grid(True)
+ ax.legend()
+ plt.show()
539
488
540
489
541
######################################################################
490
- # Upon initialization, this is not very interesting. Let’s try running for
491
- # several epochs, use gradient descent, and then see how the values
492
- # change.
542
+ # Conclusion
543
+ # ----------
493
544
#
-
- epochs = 20
-
- for epoch in range(epochs):
-     outputs = []  # list with layer name, output tensor tuple
-     optimizer.zero_grad()
-     y_pred = model(x)
-     loss = F.mse_loss(y_pred, y)
-     loss.backward()
-     optimizer.step()
-
- layer_idx, avg_grads = get_grads()
- plot_grads(layer_idx, avg_grads)
-
-
- ######################################################################
- # Still not very interesting… surprised that the gradients don’t
- # accumulate. Let’s check the leaf tensors… those tensors are probably
- # just recreated whenever I rerun the forward pass, and thus they don’t
- # accumulate. Let’s see if that’s the case with the parameters.
- #
-
- def get_param_grads():
-     layer_idx = []
-     avg_grads = []
-     print("<--------Params-------->")
-     for idx, (name, param) in enumerate(params):
-         if param.grad is not None:
-             avg_grad = param.grad.abs().mean()
-             if idx < num_layers:
-                 print(name, avg_grad)
-             avg_grads.append(avg_grad)
-             layer_idx.append(idx)
-     return layer_idx, avg_grads
-
- layer_idx, avg_grads = get_param_grads()
-
-
- plot_grads(layer_idx, avg_grads)
-
-
- ######################################################################
- # (work-in-progress) Conclusion
- # -----------------------------
+ # In this tutorial, we covered when and how PyTorch computes gradients for
+ # leaf and non-leaf tensors. By using ``retain_grad``, we can access the
+ # gradients of intermediate tensors within autograd’s computational graph.
+ # Building upon this, we then demonstrated how to visualize the gradient
+ # flow through a neural network wrapped in an ``nn.Module`` class. We
+ # qualitatively showed how batch normalization helps to alleviate the
+ # vanishing gradient issue which occurs with deep neural networks.
#
# If you would like to learn more about how PyTorch’s autograd system
# works, please visit the `references <#references>`__ below. If you have
@@ -545,6 +559,20 @@ def get_param_grads():
#


+ ######################################################################
+ # (Optional) Additional exercises
+ # -------------------------------
+ #
+ # - Try increasing the number of layers (``num_layers``) in our model and
+ #   see what effect this has on the gradient flow graph
+ # - How would you adapt the code to visualize average activations instead
+ #   of average gradients? (*Hint: in the ``get_grads()`` function we have
+ #   access to the raw tensor output*)
+ # - What are some other methods to deal with vanishing and exploding
+ #   gradients? One possibility is sketched below.
+ #
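######################################################################
# For the last exercise above, one common option is gradient clipping.
# A minimal, illustrative sketch (reusing the no-BatchNorm model from this
# tutorial; the exact threshold is an arbitrary choice):

optimizer_nobn.zero_grad()
loss = F.cross_entropy(model_nobn(x), y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model_nobn.parameters(), max_norm=1.0)  # rescale gradients so their total norm is at most 1
optimizer_nobn.step()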
+
+
######################################################################
# References
# ----------
@@ -555,4 +583,11 @@ def get_param_grads():
#   torch.autograd <https://docs.pytorch.org/tutorials/beginner/basics/autogradqs_tutorial>`__
# - `Autograd
#   mechanics <https://docs.pytorch.org/docs/stable/notes/autograd.html>`__
+ # - `Batch Normalization: Accelerating Deep Network Training by Reducing
+ #   Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__
+ #
+
+
+ ######################################################################
+ #
#