
Commit 89d1588

Ruibo Zhang authored and committed
Identity Mapping test in MNIST
1 parent 3fbc477 commit 89d1588

File tree

3 files changed: +228 −0 lines changed

MNIST_2.py

+81
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Data preprocessing
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Load the training and test sets
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=0)

# Define a simple neural network model
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # flatten each 28x28 image to a 784-dim vector
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize the network, loss function, and optimizer
net = SimpleNet()
criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(net.parameters(), lr=0.001)  # use the Adam optimizer


# Train the network
for epoch in range(1):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')


# Save the model weights after training finishes
PATH = "IdentityMappingModule/model_weights.pth"
torch.save(net.state_dict(), PATH)

print("Model weights saved to", PATH)

# Test the network
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')
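
A quick usage sketch, not part of the commit: assuming MNIST_2.py has been run once so that IdentityMappingModule/model_weights.pth exists (torch.save requires the IdentityMappingModule/ directory to exist beforehand), the saved state_dict can be restored into a fresh SimpleNet and spot-checked on one test batch. It reuses SimpleNet and testloader from the script above.

restored = SimpleNet()
restored.load_state_dict(torch.load("IdentityMappingModule/model_weights.pth"))
restored.eval()  # no dropout or batch norm here, but eval mode is good practice
with torch.no_grad():
    images, labels = next(iter(testloader))     # one batch of test images
    predicted = restored(images).argmax(dim=1)  # class with the highest logit
print("accuracy on one test batch:", (predicted == labels).float().mean().item())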

MNIST_IM.py

+147
@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


# Data preprocessing
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

# Load the training and test sets
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
                                         shuffle=False, num_workers=0)


class ZeroConvBatchNorm(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ZeroConvBatchNorm, self).__init__()

        # Define a zero convolutional layer with batch normalization
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv.weight.data.fill_(0)  # initialize the weights with zeros
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = self.conv(x)  # apply the zero convolution
        out = self.bn(out)  # apply batch normalization
        return out

class IdentityMappingModule(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(IdentityMappingModule, self).__init__()

        layers = []
        for _ in range(6):  # create 6 pairs of zero conv and batch norm layers
            zero_conv_bn = ZeroConvBatchNorm(hidden_channels, hidden_channels)
            layers.append(zero_conv_bn)
        self.identity_module = nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(-1).unsqueeze(-1)  # [batch, channels] -> [batch, channels, 1, 1]
        out = self.identity_module(x)  # apply the sequence of zero conv and batch norm layers
        return out + x  # skip connection: add the input tensor to the output tensor

# Define the network model that combines the IdentityMappingModule
class SimpleNetWithIdentityModule(nn.Module):
    def __init__(self):
        super(SimpleNetWithIdentityModule, self).__init__()

        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.identity_module = IdentityMappingModule(64, 64)  # use the IdentityMappingModule
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.identity_module(x)  # use the IdentityMappingModule
        x = x.view(x.size(0), -1)  # flatten x back to [batch_size, 64]
        x = torch.relu(x)
        x = self.fc3(x)
        return x


# Initialize the network, loss function, and optimizer
net = SimpleNetWithIdentityModule()

# Load the previously trained fully connected weights without changing them
PATH = "IdentityMappingModule/model_weights.pth"
pretrained_dict = torch.load(PATH)
model_dict = net.state_dict()
# Keep only the weights that are needed
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# Update the current model's weights
model_dict.update(pretrained_dict)
net.load_state_dict(model_dict)

# Freeze the fully connected layers so they do not take part in training
for param in net.fc1.parameters():
    param.requires_grad = False
for param in net.fc2.parameters():
    param.requires_grad = False
for param in net.fc3.parameters():
    param.requires_grad = False

# Define an optimizer that only updates the identity mapping layers
criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(net.identity_module.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(net.identity_module.parameters(), lr=0.001)  # use the Adam optimizer

# net = SimpleNetWithIdentityModule()
# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)


# Train the network
for epoch in range(1):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 99:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 100:.3f}')
            running_loss = 0.0

print('Finished Training')


# # Save the trained weights
# torch.save(net.state_dict(), PATH)
# print("Model weights saved to", PATH)


# # Load the model weights and run inference
# net = SimpleNetWithIdentityModule()
# net.load_state_dict(torch.load(PATH), strict=False)
# net.eval()


# Test the network
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')
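
A sanity-check sketch, not part of the commit, for the two assumptions this script relies on: at initialization the zero-convolution stack should act as an identity mapping (in training mode each ZeroConvBatchNorm outputs a per-channel constant that BatchNorm normalizes to zero, so the skip connection returns the input unchanged), and after freezing fc1/fc2/fc3 only the identity_module parameters should remain trainable. It reuses IdentityMappingModule and the frozen net defined above; in evaluation mode the freshly initialized running statistics would not give an exact identity, which is why the first check runs in training mode.

# Identity mapping at initialization
module = IdentityMappingModule(64, 64)
module.train()  # BatchNorm normalizes over the batch here
x = torch.randn(8, 64)
with torch.no_grad():
    y = module(x)  # shape [8, 64, 1, 1]
print("identity at init:", torch.allclose(y.view(8, 64), x))

# Only the identity mapping layers stay trainable after freezing
trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in net.parameters() if not p.requires_grad)
print(f"trainable parameters: {trainable}, frozen parameters: {frozen}")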

model_weights.pth

429 KB (binary file not shown)
