Skip to content

Commit aa5db26

Browse files
Ruibo ZhangRuibo Zhang
Ruibo Zhang
authored and
Ruibo Zhang
committed
first commit
0 parents  commit aa5db26

File tree

3 files changed

+135
-0
lines changed

3 files changed

+135
-0
lines changed

README.md

Whitespace-only changes.

ResNet.py

+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
class block(nn.Module):
5+
def __init__(self, in_channels, out_channels, identity_downsample=None, stride=1):
6+
super(block, self).__init__()
7+
self.expansion = 4
8+
self.cov1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
9+
self.bn1 = nn.BatchNorm2d(out_channels)
10+
self.cov2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
11+
self.bn2 = nn.BatchNorm2d(out_channels)
12+
self.cov3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1, padding=0)
13+
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
14+
self.relu = nn.ReLU()
15+
self.identity_downsample = identity_downsample
16+
17+
def forward(self, x):
18+
identity = x
19+
20+
x = self.relu(self.bn1(self.cov1(x)))
21+
x = self.relu(self.bn2(self.cov2(x)))
22+
x = self.bn3(self.cov3(x))
23+
24+
if self.identity_downsample is not None:
25+
identity = self.identity_downsample(identity)
26+
27+
x += identity
28+
x = self.relu(x)
29+
return x
30+
31+
32+
class ResNet(nn.Module):
33+
def __init__(self, block, layers, image_channels, num_classes):
34+
super(ResNet, self).__init__()
35+
self.in_channels = 64
36+
self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
37+
self.bn1 = nn.BatchNorm2d(64)
38+
self.relu = nn.ReLU()
39+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
40+
41+
self.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)
42+
self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)
43+
self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)
44+
self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)
45+
46+
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
47+
self.fc = nn.Linear(512*4, num_classes)
48+
49+
def forward(self, x):
50+
x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
51+
x = self.layer4(self.layer3(self.layer2(self.layer1(x))))
52+
x = self.avgpool(x)
53+
x = x.reshape(x.shape[0], -1)
54+
x = self.fc(x)
55+
return x
56+
57+
def _make_layer(self, block, num_residual_blocks, out_channels, stride):
58+
identity_downsample = None
59+
layers = []
60+
61+
if stride != 1 or self.in_channels != out_channels * 4:
62+
identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels, out_channels * 4, kernel_size=1,
63+
stride=stride),
64+
nn.BatchNorm2d(out_channels * 4))
65+
66+
layers.append(block(self.in_channels, out_channels, identity_downsample, stride))
67+
self.in_channels = out_channels * 4
68+
69+
for i in range(num_residual_blocks - 1):
70+
layers.append(block(self.in_channels, out_channels))
71+
72+
return nn.Sequential(*layers)
73+
74+
75+
def ResNet50(img_channels=3, num_classes=1000):
76+
return ResNet(block, [3, 4, 6, 3], img_channels, num_classes)
77+
78+
79+
def ResNet101(img_channels=3, num_classes=1000):
80+
return ResNet(block, [3, 4, 23, 3], img_channels, num_classes)
81+
82+
83+
def ResNet152(img_channels=3, num_classes=1000):
84+
return ResNet(block, [3, 8, 36, 3], img_channels, num_classes)
85+
86+
def test():
87+
net = ResNet152()
88+
x = torch.randn(2, 3, 224, 224)
89+
y = net(x)
90+
print(y.shape)
91+
92+
test()

model.py

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
class ZeroConvBatchNorm(nn.Module):
5+
def __init__(self, in_channels, out_channels):
6+
super(ZeroConvBatchNorm, self).__init__()
7+
8+
# Define zero convolutional layer and batch normalization
9+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
10+
self.conv.weight.data.fill_(0) # Set all weights to zero
11+
self.bn = nn.BatchNorm2d(out_channels)
12+
13+
def forward(self, x):
14+
out = self.conv(x)
15+
out = self.bn(out)
16+
return out
17+
18+
class IdentityMappingModule(nn.Module):
19+
def __init__(self, in_channels, hidden_channels):
20+
super(IdentityMappingModule, self).__init__()
21+
22+
layers = []
23+
for _ in range(6): # Create 6 pairs of zero conv and batch norm layers
24+
zero_conv_bn = ZeroConvBatchNorm(hidden_channels, hidden_channels)
25+
layers.append(zero_conv_bn)
26+
self.identity_module = nn.Sequential(*layers)
27+
28+
def forward(self, x):
29+
out = self.identity_module(x)
30+
return out + x # Skip connection
31+
32+
# Example usage
33+
input_channels = 3
34+
hidden_channels = 3
35+
input_tensor = torch.randn(8, input_channels, 32, 32) # Example input tensor with batch size 8
36+
37+
identity_module = IdentityMappingModule(input_channels, hidden_channels)
38+
output_tensor = identity_module(input_tensor)
39+
40+
print("Input tensor shape:", input_tensor)
41+
print("Output tensor shape:", output_tensor)
42+
print(input_tensor == output_tensor)
43+

0 commit comments

Comments
 (0)