Skip to content

Commit 78736cd

Browse files
committed
Lenet5 and ResNet18 example
1 parent 81d0692 commit 78736cd

File tree

3 files changed

+275
-0
lines changed

3 files changed

+275
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import torch
2+
from torch.utils.data import DataLoader
3+
from torchvision import datasets
4+
from torchvision import transforms
5+
from torch import nn, optim
6+
7+
from lenet5 import Lenet5
8+
from resnet import ResNet18
9+
10+
def main():
11+
batchsz = 32
12+
13+
cifar_train = datasets.CIFAR10('cifar', True, transform=transforms.Compose([
14+
transforms.Resize((32, 32)),
15+
transforms.ToTensor()
16+
]), download=True)
17+
cifar_train = DataLoader(cifar_train, batch_size=batchsz, shuffle=True)
18+
19+
cifar_test = datasets.CIFAR10('cifar', False, transform=transforms.Compose([
20+
transforms.Resize((32, 32)),
21+
transforms.ToTensor()
22+
]), download=True)
23+
cifar_test = DataLoader(cifar_test, batch_size=batchsz, shuffle=True)
24+
25+
26+
x, label = iter(cifar_train).next()
27+
print('x:', x.shape, 'label:', label.shape)
28+
29+
device = torch.device('cuda')
30+
# model = Lenet5().to(device)
31+
model = ResNet18().to(device)
32+
33+
criteon = nn.CrossEntropyLoss().to(device)
34+
optimizer = optim.Adam(model.parameters(), lr=1e-3)
35+
print(model)
36+
37+
for epoch in range(1000):
38+
39+
model.train()
40+
for batchidx, (x, label) in enumerate(cifar_train):
41+
# [b, 3, 32, 32]
42+
# [b]
43+
x, label = x.to(device), label.to(device)
44+
45+
logits = model(x)
46+
# logits: [b, 10]
47+
# label: [b]
48+
# loss: tensor scalar
49+
loss = criteon(logits, label)
50+
51+
# backprop
52+
optimizer.zero_grad()
53+
loss.backward()
54+
optimizer.step()
55+
56+
57+
#
58+
print(epoch, 'loss:', loss.item())
59+
60+
61+
model.eval()
62+
with torch.no_grad():
63+
# test
64+
total_correct = 0
65+
total_num = 0
66+
for x, label in cifar_test:
67+
# [b, 3, 32, 32]
68+
# [b]
69+
x, label = x.to(device), label.to(device)
70+
71+
# [b, 10]
72+
logits = model(x)
73+
# [b]
74+
pred = logits.argmax(dim=1)
75+
# [b] vs [b] => scalar tensor
76+
correct = torch.eq(pred, label).float().sum().item()
77+
total_correct += correct
78+
total_num += x.size(0)
79+
# print(correct)
80+
81+
acc = total_correct / total_num
82+
print(epoch, 'acc:', acc)
83+
84+
85+
86+
if __name__ == '__main__':
87+
main()
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import torch
2+
from torch import nn
3+
from torch.nn import functional as F
4+
5+
6+
7+
8+
class Lenet5(nn.Module):
9+
"""
10+
for cifar10 dataset.
11+
"""
12+
def __init__(self):
13+
super(Lenet5, self).__init__()
14+
15+
self.conv_unit = nn.Sequential(
16+
# x: [b, 3, 32, 32] => [b, 6, ]
17+
nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
18+
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
19+
#
20+
nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
21+
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
22+
#
23+
)
24+
# flatten
25+
# fc unit
26+
self.fc_unit = nn.Sequential(
27+
nn.Linear(16*5*5, 32),
28+
nn.ReLU(),
29+
# nn.Linear(120, 84),
30+
# nn.ReLU(),
31+
nn.Linear(32, 10)
32+
)
33+
34+
35+
# [b, 3, 32, 32]
36+
tmp = torch.randn(2, 3, 32, 32)
37+
out = self.conv_unit(tmp)
38+
# [b, 16, 5, 5]
39+
print('conv out:', out.shape)
40+
41+
# # use Cross Entropy Loss
42+
# self.criteon = nn.CrossEntropyLoss()
43+
44+
45+
46+
def forward(self, x):
47+
"""
48+
49+
:param x: [b, 3, 32, 32]
50+
:return:
51+
"""
52+
batchsz = x.size(0)
53+
# [b, 3, 32, 32] => [b, 16, 5, 5]
54+
x = self.conv_unit(x)
55+
# [b, 16, 5, 5] => [b, 16*5*5]
56+
x = x.view(batchsz, 16*5*5)
57+
# [b, 16*5*5] => [b, 10]
58+
logits = self.fc_unit(x)
59+
60+
# # [b, 10]
61+
# pred = F.softmax(logits, dim=1)
62+
# loss = self.criteon(logits, y)
63+
64+
return logits
65+
66+
def main():
67+
68+
net = Lenet5()
69+
70+
tmp = torch.randn(2, 3, 32, 32)
71+
out = net(tmp)
72+
print('lenet out:', out.shape)
73+
74+
75+
76+
77+
if __name__ == '__main__':
78+
main()
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import torch
2+
from torch import nn
3+
from torch.nn import functional as F
4+
5+
6+
7+
class ResBlk(nn.Module):
8+
"""
9+
resnet block
10+
"""
11+
12+
def __init__(self, ch_in, ch_out):
13+
"""
14+
15+
:param ch_in:
16+
:param ch_out:
17+
"""
18+
super(ResBlk, self).__init__()
19+
20+
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1)
21+
self.bn1 = nn.BatchNorm2d(ch_out)
22+
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
23+
self.bn2 = nn.BatchNorm2d(ch_out)
24+
25+
self.extra = nn.Sequential()
26+
if ch_out != ch_in:
27+
# [b, ch_in, h, w] => [b, ch_out, h, w]
28+
self.extra = nn.Sequential(
29+
nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1),
30+
nn.BatchNorm2d(ch_out)
31+
)
32+
33+
34+
def forward(self, x):
35+
"""
36+
37+
:param x: [b, ch, h, w]
38+
:return:
39+
"""
40+
out = F.relu(self.bn1(self.conv1(x)))
41+
out = self.bn2(self.conv2(out))
42+
# short cut.
43+
# extra module: [b, ch_in, h, w] => [b, ch_out, h, w]
44+
# element-wise add:
45+
out = self.extra(x) + out
46+
47+
return out
48+
49+
50+
51+
52+
class ResNet18(nn.Module):
53+
54+
def __init__(self):
55+
super(ResNet18, self).__init__()
56+
57+
self.conv1 = nn.Sequential(
58+
nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
59+
nn.BatchNorm2d(16)
60+
)
61+
# followed 4 blocks
62+
# [b, 64, h, w] => [b, 128, h ,w]
63+
self.blk1 = ResBlk(16, 16)
64+
# [b, 128, h, w] => [b, 256, h, w]
65+
self.blk2 = ResBlk(16, 32)
66+
# # [b, 256, h, w] => [b, 512, h, w]
67+
# self.blk3 = ResBlk(128, 256)
68+
# # [b, 512, h, w] => [b, 1024, h, w]
69+
# self.blk4 = ResBlk(256, 512)
70+
71+
self.outlayer = nn.Linear(32*32*32, 10)
72+
73+
def forward(self, x):
74+
"""
75+
76+
:param x:
77+
:return:
78+
"""
79+
x = F.relu(self.conv1(x))
80+
81+
# [b, 64, h, w] => [b, 1024, h, w]
82+
x = self.blk1(x)
83+
x = self.blk2(x)
84+
# x = self.blk3(x)
85+
# x = self.blk4(x)
86+
87+
# print(x.shape)
88+
x = x.view(x.size(0), -1)
89+
x = self.outlayer(x)
90+
91+
92+
return x
93+
94+
95+
96+
def main():
97+
blk = ResBlk(64, 128)
98+
tmp = torch.randn(2, 64,32, 32)
99+
out = blk(tmp)
100+
print('blkk', out.shape)
101+
102+
103+
model = ResNet18()
104+
tmp = torch.randn(2, 3, 32, 32)
105+
out = model(tmp)
106+
print('resnet:', out.shape)
107+
108+
109+
if __name__ == '__main__':
110+
main()

0 commit comments

Comments
 (0)