1
+ import torch
2
+ import torch .nn as nn
3
+ import torch .optim as optim
4
+ import torchvision
5
+ import torchvision .transforms as transforms
6
+
7
+
8
+ # 数据预处理
9
+ transform = transforms .Compose (
10
+ [transforms .ToTensor (), transforms .Normalize ((0.5 ,), (0.5 ,))]
11
+ )
12
+
13
+ # 加载训练集和测试集
14
+ trainset = torchvision .datasets .MNIST (root = './data' , train = True ,
15
+ download = True , transform = transform )
16
+ trainloader = torch .utils .data .DataLoader (trainset , batch_size = 64 ,
17
+ shuffle = True , num_workers = 0 )
18
+
19
+ testset = torchvision .datasets .MNIST (root = './data' , train = False ,
20
+ download = True , transform = transform )
21
+ testloader = torch .utils .data .DataLoader (testset , batch_size = 64 ,
22
+ shuffle = False , num_workers = 0 )
23
+
24
+
25
+ class ZeroConvBatchNorm (nn .Module ):
26
+ def __init__ (self , in_channels , out_channels ):
27
+ super (ZeroConvBatchNorm , self ).__init__ ()
28
+
29
+ # Define a zero convolutional layer with batch normalization
30
+ self .conv = nn .Conv2d (in_channels , out_channels , kernel_size = 3 , stride = 1 , padding = 1 )
31
+ self .conv .weight .data .fill_ (0 ) # Initialize weights with zeros
32
+ self .bn = nn .BatchNorm2d (out_channels )
33
+
34
+ def forward (self , x ):
35
+ out = self .conv (x ) # Apply the zero convolution operation
36
+ out = self .bn (out ) # Apply batch normalization
37
+ return out
38
+
39
+ class IdentityMappingModule (nn .Module ):
40
+ def __init__ (self , in_channels , hidden_channels ):
41
+ super (IdentityMappingModule , self ).__init__ ()
42
+
43
+ layers = []
44
+ for _ in range (6 ): # Create 6 pairs of zero conv and batch norm layers
45
+ zero_conv_bn = ZeroConvBatchNorm (hidden_channels , hidden_channels )
46
+ layers .append (zero_conv_bn )
47
+ self .identity_module = nn .Sequential (* layers )
48
+
49
+ def forward (self , x ):
50
+ x = x .unsqueeze (- 1 ).unsqueeze (- 1 )
51
+ out = self .identity_module (x ) # Apply the sequence of zero conv and batch norm layers
52
+ return out + x # Implement skip connection by adding input tensor to output tensor
53
+
54
+ # 定义结合了IdentityMappingModule的神经网络模型
55
+ class SimpleNetWithIdentityModule (nn .Module ):
56
+ def __init__ (self ):
57
+ super (SimpleNetWithIdentityModule , self ).__init__ ()
58
+
59
+ self .fc1 = nn .Linear (28 * 28 , 128 )
60
+ self .fc2 = nn .Linear (128 , 64 )
61
+ self .identity_module = IdentityMappingModule (64 , 64 ) # 使用IdentityMappingModule
62
+ self .fc3 = nn .Linear (64 , 10 )
63
+
64
+ def forward (self , x ):
65
+ x = x .view (- 1 , 28 * 28 )
66
+ x = torch .relu (self .fc1 (x ))
67
+ x = torch .relu (self .fc2 (x ))
68
+ x = self .identity_module (x ) # 使用IdentityMappingModule
69
+ x = x .view (x .size (0 ), - 1 ) # 将 x 展平为 [batch_size, 64]
70
+ x = torch .relu (x )
71
+ x = self .fc3 (x )
72
+ return x
73
+
74
+
75
+ # 初始化网络、损失函数和优化器
76
+ # 初始化网络、损失函数和优化器
77
+ net = SimpleNetWithIdentityModule ()
78
+
79
+ # 加载之前训练过的全连接层的权重,不改变其权重
80
+ PATH = "IdentityMappingModule/model_weights.pth"
81
+ pretrained_dict = torch .load (PATH )
82
+ model_dict = net .state_dict ()
83
+ # 过滤出只需要的权重
84
+ pretrained_dict = {k : v for k , v in pretrained_dict .items () if k in model_dict }
85
+ # 更新当前模型的权重
86
+ model_dict .update (pretrained_dict )
87
+ net .load_state_dict (model_dict )
88
+
89
+ # 冻结全连接层的权重,使其不参与训练
90
+ for param in net .fc1 .parameters ():
91
+ param .requires_grad = False
92
+ for param in net .fc2 .parameters ():
93
+ param .requires_grad = False
94
+ for param in net .fc3 .parameters ():
95
+ param .requires_grad = False
96
+
97
+ # 定义只优化恒等映射层的优化器
98
+ criterion = nn .CrossEntropyLoss ()
99
+ #optimizer = optim.SGD(net.identity_module.parameters(), lr=0.001, momentum=0.9)
100
+ optimizer = optim .Adam (net .identity_module .parameters (), lr = 0.001 ) # 使用Adam优化器
101
+
102
+ # net = SimpleNetWithIdentityModule()
103
+ # optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
104
+
105
+
106
+ # 训练网络
107
+ for epoch in range (1 ):
108
+ running_loss = 0.0
109
+ for i , data in enumerate (trainloader , 0 ):
110
+ inputs , labels = data
111
+ optimizer .zero_grad ()
112
+ outputs = net (inputs )
113
+ loss = criterion (outputs , labels )
114
+ loss .backward ()
115
+ optimizer .step ()
116
+ running_loss += loss .item ()
117
+ if i % 100 == 99 :
118
+ print (f'[{ epoch + 1 } , { i + 1 :5d} ] loss: { running_loss / 100 :.3f} ' )
119
+ running_loss = 0.0
120
+
121
+ print ('Finished Training' )
122
+
123
+
124
+ # # 保存训练后的权重
125
+
126
+ # torch.save(net.state_dict(), PATH)
127
+ # print("Model weights saved to", PATH)
128
+
129
+
130
+ # # 加载模型权重并进行推断
131
+ # net = SimpleNetWithIdentityModule()
132
+ # net.load_state_dict(torch.load(PATH), strict = False)
133
+ # net.eval()
134
+
135
+
136
+ # 测试网络
137
+ correct = 0
138
+ total = 0
139
+ with torch .no_grad ():
140
+ for data in testloader :
141
+ images , labels = data
142
+ outputs = net (images )
143
+ _ , predicted = torch .max (outputs .data , 1 )
144
+ total += labels .size (0 )
145
+ correct += (predicted == labels ).sum ().item ()
146
+
147
+ print (f'Accuracy of the network on the 10000 test images: { 100 * correct / total :.2f} %' )
0 commit comments