Skip to content

Commit a9a9a22

Browse files
committed
Create pyotrch_mnmg.py
1 parent 858d36c commit a9a9a22

File tree

1 file changed

+194
-0
lines changed

1 file changed

+194
-0
lines changed

pytorch_benchmarks/pyotrch_mnmg.py

+194
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
"""
2+
(MNMC) Multiple Nodes Multi-GPU Cards Training
3+
with DistributedDataParallel and torch.distributed.launch
4+
Try to compare with [snsc.py, snmc_dp.py & mnmc_ddp_mp.py] and find out the differences.
5+
"""
6+
7+
import os
8+
9+
import torch
10+
import torch.distributed as dist
11+
import torch.nn as nn
12+
import torchvision
13+
import torchvision.transforms as transforms
14+
from torch.nn.parallel import DistributedDataParallel as DDP
15+
16+
BATCH_SIZE = 256
17+
EPOCHS = 5
18+
19+
20+
if __name__ == "__main__":

    # 0. set up distributed device
    # torch.distributed.launch / torchrun export RANK (global, across all
    # nodes) and LOCAL_RANK (per-node) into the environment.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    # Bind this process to its node-local GPU. Using local_rank directly is
    # the canonical form; `rank % device_count()` is only coincidentally
    # equal to local_rank on homogeneous nodes and can pick the wrong GPU
    # otherwise.
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    device = torch.device("cuda", local_rank)

    print(f"[init] == local rank: {local_rank}, global rank: {rank} ==")

    # 1. define network (ResNet-18, 10 classes for CIFAR-10, no pretrain)
    net = torchvision.models.resnet18(pretrained=False, num_classes=10)
    net = net.to(device)
    # DistributedDataParallel: gradients are all-reduced across processes
    # during backward(); each process keeps a full model replica.
    net = DDP(net, device_ids=[local_rank], output_device=local_rank)

    # 2. define dataloader
    trainset = torchvision.datasets.CIFAR10(
        root="./data",
        train=True,
        download=False,  # assumes the dataset already exists on every node
        transform=transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(
                    (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
                ),
            ]
        ),
    )
    # DistributedSampler shards the *dataset* across processes: each process
    # still draws batches of BATCH_SIZE from its own shard, so the number of
    # steps per epoch shrinks with world_size while the effective global
    # batch is BATCH_SIZE * world_size.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        trainset,
        shuffle=True,
    )
    train_loader = torch.utils.data.DataLoader(
        trainset,
        batch_size=BATCH_SIZE,
        num_workers=4,
        pin_memory=True,
        sampler=train_sampler,
    )

    # 3. define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    # lr is scaled (0.01 * 2) for the enlarged effective global batch.
    optimizer = torch.optim.SGD(
        net.parameters(),
        lr=0.01 * 2,
        momentum=0.9,
        weight_decay=0.0001,
        nesterov=True,
    )

    if rank == 0:
        print(" ======= Training ======= \n")

    # 4. start to train
    net.train()
    for ep in range(1, EPOCHS + 1):
        train_loss = correct = total = 0
        # set_epoch reseeds the sampler so every epoch gets a fresh shuffle
        # that is still consistent across processes.
        train_loader.sampler.set_epoch(ep)

        for idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)

            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # NOTE: the loss/accuracy tracked below are rank-local (not
            # all-reduced); rank 0 prints its own shard's statistics.
            train_loss += loss.item()
            total += targets.size(0)
            correct += torch.eq(outputs.argmax(dim=1), targets).sum().item()

            if rank == 0 and ((idx + 1) % 25 == 0 or (idx + 1) == len(train_loader)):
                print(
                    " == step: [{:3}/{}] [{}/{}] | loss: {:.3f} | acc: {:6.3f}%".format(
                        idx + 1,
                        len(train_loader),
                        ep,
                        EPOCHS,
                        train_loss / (idx + 1),
                        100.0 * correct / total,
                    )
                )
    if rank == 0:
        print("\n ======= Training Finished ======= \n")

    # Tear down the process group explicitly to avoid NCCL resource
    # warnings / hangs at interpreter exit.
    dist.destroy_process_group()
115+
"""
116+
usage:
117+
>>> python -m torch.distributed.launch --help
118+
example: 1 node, 4 GPUs per node (4 GPUs)
119+
>>> python -m torch.distributed.launch \
120+
--nproc_per_node=4 \
121+
--nnodes=1 \
122+
--node_rank=0 \
123+
--master_addr=localhost \
124+
--master_port=22222 \
125+
mnmc_ddp_launch.py
126+
[init] == local rank: 3, global rank: 3 ==
127+
[init] == local rank: 1, global rank: 1 ==
128+
[init] == local rank: 0, global rank: 0 ==
129+
[init] == local rank: 2, global rank: 2 ==
130+
======= Training =======
131+
== step: [ 25/49] [0/5] | loss: 1.980 | acc: 27.953%
132+
== step: [ 49/49] [0/5] | loss: 1.806 | acc: 33.816%
133+
== step: [ 25/49] [1/5] | loss: 1.464 | acc: 47.391%
134+
== step: [ 49/49] [1/5] | loss: 1.420 | acc: 48.448%
135+
== step: [ 25/49] [2/5] | loss: 1.300 | acc: 52.469%
136+
== step: [ 49/49] [2/5] | loss: 1.274 | acc: 53.648%
137+
== step: [ 25/49] [3/5] | loss: 1.201 | acc: 56.547%
138+
== step: [ 49/49] [3/5] | loss: 1.185 | acc: 57.360%
139+
== step: [ 25/49] [4/5] | loss: 1.129 | acc: 59.531%
140+
== step: [ 49/49] [4/5] | loss: 1.117 | acc: 59.800%
141+
======= Training Finished =======
142+
example: 1 node, 2 tasks, 4 GPUs per task (8 GPUs)
143+
>>> CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch \
144+
--nproc_per_node=4 \
145+
--nnodes=2 \
146+
--node_rank=0 \
147+
--master_addr="10.198.189.10" \
148+
--master_port=22222 \
149+
mnmc_ddp_launch.py
150+
>>> CUDA_VISIBLE_DEVICES=4,5,6,7 python -m torch.distributed.launch \
151+
--nproc_per_node=4 \
152+
--nnodes=2 \
153+
--node_rank=1 \
154+
--master_addr="10.198.189.10" \
155+
--master_port=22222 \
156+
mnmc_ddp_launch.py
157+
======= Training =======
158+
== step: [ 25/25] [0/5] | loss: 1.932 | acc: 29.088%
159+
== step: [ 25/25] [1/5] | loss: 1.546 | acc: 43.088%
160+
== step: [ 25/25] [2/5] | loss: 1.424 | acc: 48.032%
161+
== step: [ 25/25] [3/5] | loss: 1.335 | acc: 51.440%
162+
== step: [ 25/25] [4/5] | loss: 1.243 | acc: 54.672%
163+
======= Training Finished =======
164+
example: 2 nodes, 8 GPUs per node (16 GPUs)
165+
>>> python -m torch.distributed.launch \
166+
--nproc_per_node=8 \
167+
--nnodes=2 \
168+
--node_rank=0 \
169+
--master_addr="10.198.189.10" \
170+
--master_port=22222 \
171+
mnmc_ddp_launch.py
172+
>>> python -m torch.distributed.launch \
173+
--nproc_per_node=8 \
174+
--nnodes=2 \
175+
--node_rank=1 \
176+
--master_addr="10.198.189.10" \
177+
--master_port=22222 \
178+
mnmc_ddp_launch.py
179+
[init] == local rank: 5, global rank: 5 ==
180+
[init] == local rank: 3, global rank: 3 ==
181+
[init] == local rank: 2, global rank: 2 ==
182+
[init] == local rank: 4, global rank: 4 ==
183+
[init] == local rank: 0, global rank: 0 ==
184+
[init] == local rank: 6, global rank: 6 ==
185+
[init] == local rank: 7, global rank: 7 ==
186+
[init] == local rank: 1, global rank: 1 ==
187+
======= Training =======
188+
== step: [ 13/13] [0/5] | loss: 2.056 | acc: 23.776%
189+
== step: [ 13/13] [1/5] | loss: 1.688 | acc: 36.736%
190+
== step: [ 13/13] [2/5] | loss: 1.508 | acc: 44.544%
191+
== step: [ 13/13] [3/5] | loss: 1.462 | acc: 45.472%
192+
== step: [ 13/13] [4/5] | loss: 1.357 | acc: 49.344%
193+
======= Training Finished =======
194+
"""

0 commit comments

Comments
 (0)