Commit 32d2b29

c-p-i-o and svekars authored
[doc] ddp multigpu tutorial - small updates (#3141)
* [doc] ddp multigpu tutorial - small updates

  The main fix is to change `diff` blocks into `python` blocks so that the user can easily copy/paste the data to run parts of the tutorial. Fix author's link.

---------

Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent f08670d commit 32d2b29

File tree

1 file changed: +66 -48 lines changed


beginner_source/ddp_series_multigpu.rst

Lines changed: 66 additions & 48 deletions
@@ -9,7 +9,7 @@
 Multi GPU training with DDP
 ===========================
 
-Authors: `Suraj Subramanian <https://github.com/suraj813>`__
+Authors: `Suraj Subramanian <https://github.com/subramen>`__
 
 .. grid:: 2
 
@@ -19,13 +19,13 @@ Authors: `Suraj Subramanian <https://github.com/suraj813>`__
       - How to migrate a single-GPU training script to multi-GPU via DDP
       - Setting up the distributed process group
       - Saving and loading models in a distributed setup
-
+
 .. grid:: 1
 
    .. grid-item::
 
      :octicon:`code-square;1.0em;` View the code used in this tutorial on `GitHub <https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu.py>`__
-
+
    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites
 
@@ -45,11 +45,11 @@ In the `previous tutorial <ddp_series_theory.html>`__, we got a high-level overv
 In this tutorial, we start with a single-GPU training script and migrate that to running it on 4 GPUs on a single node.
 Along the way, we will talk through important concepts in distributed training while implementing them in our code.
 
-.. note::
+.. note::
    If your model contains any ``BatchNorm`` layers, it needs to be converted to ``SyncBatchNorm`` to sync the running stats of ``BatchNorm``
    layers across replicas.
 
-   Use the helper function
+   Use the helper function
    `torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) <https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#torch.nn.SyncBatchNorm.convert_sync_batchnorm>`__ to convert all ``BatchNorm`` layers in the model to ``SyncBatchNorm``.
 
 
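For readers following along, a minimal sketch of the conversion described in that note might look like the following; the toy model and the ``gpu_id`` variable are illustrative assumptions, not code from this commit.

.. code-block:: python

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP

    # Toy model containing a BatchNorm layer (illustrative only).
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, kernel_size=3),
        torch.nn.BatchNorm2d(16),
        torch.nn.ReLU(),
    )

    # Convert every BatchNorm layer to SyncBatchNorm so running stats are
    # synchronized across replicas, then move the model to this process's GPU
    # and wrap it in DDP as usual.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.to(gpu_id)   # gpu_id: this process's device index (assumed defined)
    model = DDP(model, device_ids=[gpu_id])
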
@@ -58,27 +58,27 @@ Diff for `single_gpu.py <https://github.com/pytorch/examples/blob/main/distribut
 These are the changes you typically make to a single-GPU training script to enable DDP.
 
 Imports
-~~~~~~~
+-------
 - ``torch.multiprocessing`` is a PyTorch wrapper around Python's native
   multiprocessing
 - The distributed process group contains all the processes that can
   communicate and synchronize with each other.
 
-.. code-block:: diff
+.. code-block:: python
 
-    import torch
-    import torch.nn.functional as F
-    from utils import MyTrainDataset
+    import torch
+    import torch.nn.functional as F
+    from utils import MyTrainDataset
 
-    + import torch.multiprocessing as mp
-    + from torch.utils.data.distributed import DistributedSampler
-    + from torch.nn.parallel import DistributedDataParallel as DDP
-    + from torch.distributed import init_process_group, destroy_process_group
-    + import os
+    import torch.multiprocessing as mp
+    from torch.utils.data.distributed import DistributedSampler
+    from torch.nn.parallel import DistributedDataParallel as DDP
+    from torch.distributed import init_process_group, destroy_process_group
+    import os
 
 
 Constructing the process group
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+------------------------------
 
 - First, before initializing the group process, call `set_device <https://pytorch.org/docs/stable/generated/torch.cuda.set_device.html?highlight=set_device#torch.cuda.set_device>`__,
   which sets the default GPU for each process. This is important to prevent hangs or excessive memory utilization on `GPU:0`
@@ -90,66 +90,66 @@ Constructing the process group
 - Read more about `choosing a DDP
   backend <https://pytorch.org/docs/stable/distributed.html#which-backend-to-use>`__
 
-.. code-block:: diff
+.. code-block:: python
 
-    + def ddp_setup(rank: int, world_size: int):
-    +     """
-    +     Args:
-    +         rank: Unique identifier of each process
-    +         world_size: Total number of processes
-    +     """
-    +     os.environ["MASTER_ADDR"] = "localhost"
-    +     os.environ["MASTER_PORT"] = "12355"
-    +     torch.cuda.set_device(rank)
-    +     init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    def ddp_setup(rank: int, world_size: int):
+        """
+        Args:
+            rank: Unique identifier of each process
+            world_size: Total number of processes
+        """
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+        torch.cuda.set_device(rank)
+        init_process_group(backend="nccl", rank=rank, world_size=world_size)
 
 
 
 Constructing the DDP model
-~~~~~~~~~~~~~~~~~~~~~~~~~~
+--------------------------
 
-.. code-block:: diff
+.. code-block:: python
 
-    - self.model = model.to(gpu_id)
-    + self.model = DDP(model, device_ids=[gpu_id])
+    self.model = DDP(model, device_ids=[gpu_id])
 
 Distributing input data
-~~~~~~~~~~~~~~~~~~~~~~~
+-----------------------
 
 - `DistributedSampler <https://pytorch.org/docs/stable/data.html?highlight=distributedsampler#torch.utils.data.distributed.DistributedSampler>`__
   chunks the input data across all distributed processes.
+- The `DataLoader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`__ combines a dataset and a
+  sampler, and provides an iterable over the given dataset.
 - Each process will receive an input batch of 32 samples; the effective
   batch size is ``32 * nprocs``, or 128 when using 4 GPUs.
 
-.. code-block:: diff
+.. code-block:: python
 
     train_data = torch.utils.data.DataLoader(
         dataset=train_dataset,
         batch_size=32,
-        - shuffle=True,
-        + shuffle=False,
-        + sampler=DistributedSampler(train_dataset),
+        shuffle=False, # We don't shuffle
+        sampler=DistributedSampler(train_dataset), # Use the Distributed Sampler here.
     )
 
-- Calling the ``set_epoch()`` method on the ``DistributedSampler`` at the beginning of each epoch is necessary to make shuffling work
+- Calling the ``set_epoch()`` method on the ``DistributedSampler`` at the beginning of each epoch is necessary to make shuffling work
   properly across multiple epochs. Otherwise, the same ordering will be used in each epoch.
 
-.. code-block:: diff
+.. code-block:: python
 
     def _run_epoch(self, epoch):
         b_sz = len(next(iter(self.train_data))[0])
-        + self.train_data.sampler.set_epoch(epoch)
+        self.train_data.sampler.set_epoch(epoch) # call this additional line at every epoch
         for source, targets in self.train_data:
             ...
             self._run_batch(source, targets)
 
 
 Saving model checkpoints
-~~~~~~~~~~~~~~~~~~~~~~~~
-- We only need to save model checkpoints from one process. Without this
+------------------------
+- We only need to save model checkpoints from one process. Without this
   condition, each process would save its copy of the identical mode. Read
   more on saving and loading models with
-  DDP `here <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints>`__
+  DDP `here <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#save-and-load-checkpoints>`__
 
 .. code-block:: diff
 
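The body of that checkpoint code block falls outside this hunk. As a rough sketch of the idea (the attribute and file names follow the tutorial's style but are assumptions, not lines from this commit), saving only from the ``rank:0`` process might look like:

.. code-block:: python

    def _save_checkpoint(self, epoch):
        # Unwrap the DDP-wrapped model before saving so the checkpoint can be
        # loaded later without DDP.
        ckp = self.model.module.state_dict()
        torch.save(ckp, "checkpoint.pt")

    def train(self, max_epochs: int):
        for epoch in range(max_epochs):
            self._run_epoch(epoch)
            # Only rank 0 writes the checkpoint; the other ranks skip it.
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)
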
@@ -164,18 +164,18 @@ Saving model checkpoints
 .. warning::
    `Collective calls <https://pytorch.org/docs/stable/distributed.html#collective-functions>`__ are functions that run on all the distributed processes,
    and they are used to gather certain states or values to a specific process. Collective calls require all ranks to run the collective code.
-   In this example, `_save_checkpoint` should not have any collective calls because it is only run on the ``rank:0`` process.
+   In this example, `_save_checkpoint` should not have any collective calls because it is only run on the ``rank:0`` process.
    If you need to make any collective calls, it should be before the ``if self.gpu_id == 0`` check.
 
 
 Running the distributed training job
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+------------------------------------
 
 - Include new arguments ``rank`` (replacing ``device``) and
   ``world_size``.
 - ``rank`` is auto-allocated by DDP when calling
   `mp.spawn <https://pytorch.org/docs/stable/multiprocessing.html#spawning-subprocesses>`__.
-- ``world_size`` is the number of processes across the training job. For GPU training,
+- ``world_size`` is the number of processes across the training job. For GPU training,
   this corresponds to the number of GPUs in use, and each process works on a dedicated GPU.
 
 .. code-block:: diff
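To make the collective-calls warning above concrete, here is a hedged sketch of where a collective call would have to sit relative to the rank check; the ``dist.all_reduce`` call and the ``epoch_loss`` tensor are illustrative assumptions, not part of this commit.

.. code-block:: python

    import torch.distributed as dist

    # Every rank must execute the collective call ...
    epoch_loss = torch.tensor([loss_value], device=self.gpu_id)  # illustrative tensor
    dist.all_reduce(epoch_loss, op=dist.ReduceOp.SUM)            # runs on ALL ranks

    # ... and only afterwards do we branch on the rank for the checkpoint.
    if self.gpu_id == 0 and epoch % self.save_every == 0:
        self._save_checkpoint(epoch)  # rank 0 only; no collectives inside
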
@@ -189,7 +189,7 @@ Running the distributed training job
     +  trainer = Trainer(model, train_data, optimizer, rank, save_every)
        trainer.train(total_epochs)
     +  destroy_process_group()
-
+
     if __name__ == "__main__":
        import sys
        total_epochs = int(sys.argv[1])
@@ -199,13 +199,31 @@ Running the distributed training job
     +  world_size = torch.cuda.device_count()
     +  mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
 
+Here's what the code looks like:
+
+.. code-block:: python
+
+    def main(rank, world_size, total_epochs, save_every):
+        ddp_setup(rank, world_size)
+        dataset, model, optimizer = load_train_objs()
+        train_data = prepare_dataloader(dataset, batch_size=32)
+        trainer = Trainer(model, train_data, optimizer, rank, save_every)
+        trainer.train(total_epochs)
+        destroy_process_group()
+
+    if __name__ == "__main__":
+        import sys
+        total_epochs = int(sys.argv[1])
+        save_every = int(sys.argv[2])
+        world_size = torch.cuda.device_count()
+        mp.spawn(main, args=(world_size, total_epochs, save_every,), nprocs=world_size)
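As a side note on how ``rank`` gets its value: ``mp.spawn`` calls the target function once per process and passes the process index as the first positional argument. A minimal, self-contained sketch of that behavior (not part of this commit):

.. code-block:: python

    import torch.multiprocessing as mp

    def demo(rank: int, world_size: int):
        # rank is injected by mp.spawn; only the remaining arguments come from `args`.
        print(f"Hello from rank {rank} of {world_size}")

    if __name__ == "__main__":
        world_size = 4  # assumed process count for this sketch
        mp.spawn(demo, args=(world_size,), nprocs=world_size)
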
 
 
 Further Reading
 ---------------
 
 - `Fault Tolerant distributed training <ddp_series_fault_tolerance.html>`__ (next tutorial in this series)
 - `Intro to DDP <ddp_series_theory.html>`__ (previous tutorial in this series)
-- `Getting Started with DDP <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
+- `Getting Started with DDP <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
 - `Process Group
-  initialization <https://pytorch.org/docs/stable/distributed.html#tcp-initialization>`__
+  Initialization <https://pytorch.org/docs/stable/distributed.html#tcp-initialization>`__
