This repository was archived by the owner on Jan 29, 2024. It is now read-only.

Commit 049983e

Add base model.

1 parent 863e879 · commit 049983e

2 files changed: +165 −3 lines changed


models/baseline.py (+72 −3)
@@ -7,13 +7,82 @@

-------------------------------------------------
"""

import torch
import torch.nn as nn
from torchvision.models import resnet50, resnet101

from models.blocks import weights_init_kaiming, weights_init_classifier

# arch type
FACTORY = {
    'resnet50': resnet50,
    'resnet101': resnet101
}


class Baseline(nn.Module):
    def __init__(self, num_classes, arch='resnet50', stride=1):
        super(Baseline, self).__init__()
        self.num_classes = num_classes
        self.arch = arch
        self.stride = stride

        # backbone
        if arch not in FACTORY:
            raise KeyError("Unknown arch: {}".format(arch))
        resnet = FACTORY[arch](pretrained=True)
        # last-stride trick: with stride=1 the final stage keeps a larger
        # feature map (24x8 instead of 12x4 for 384x128 inputs)
        if stride == 1:
            resnet.layer4[0].downsample[0].stride = (1, 1)
            resnet.layer4[0].conv2.stride = (1, 1)

        self.backbone = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,  # res_conv2
            resnet.layer2,  # res_conv3
            resnet.layer3,  # res_conv4
            resnet.layer4   # res_conv5
        )

        self.gap = nn.AdaptiveAvgPool2d((1, 1))

        self.bottleneck = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.5)
        )
        self.bottleneck.apply(weights_init_kaiming)
        self.classifier = nn.Linear(512, self.num_classes)
        self.classifier.apply(weights_init_classifier)

    def forward(self, x):
        global_feat = self.backbone(x)
        global_feat = self.gap(global_feat)  # (b, 2048, 1, 1)
        global_feat = global_feat.view(global_feat.shape[0], -1)  # (b, 2048)

        if self.training:
            feat = self.bottleneck(global_feat)
            cls_score = self.classifier(feat)
            return [global_feat], [cls_score]
        else:
            return global_feat


if __name__ == '__main__':
    data = torch.rand(32, 3, 384, 128)

    model = Baseline(num_classes=751, arch='resnet101', stride=1)
    out = model(data)
    print(out[0][0].shape)
    print(out[1][0].shape)

    model = Baseline(num_classes=751, arch='resnet101', stride=2)
    out = model(data)
    print(out[0][0].shape)
    print(out[1][0].shape)

    print('Done.')
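Not part of the commit — a minimal inference sketch (assuming models/baseline.py above is importable and the pretrained-weight download succeeds; the gallery tensor is a hypothetical placeholder) showing how the eval-mode branch of forward() yields a 2048-d embedding for retrieval:

    import torch
    import torch.nn.functional as F

    from models.baseline import Baseline

    model = Baseline(num_classes=751, arch='resnet50', stride=1)
    model.eval()  # eval mode: forward() returns the pooled feature, not logits

    with torch.no_grad():
        query = torch.rand(4, 3, 384, 128)   # person crops, H=384, W=128
        q_feat = model(query)                # (4, 2048)
        gallery = torch.rand(100, 2048)      # placeholder gallery embeddings
        # cosine similarity between each query and every gallery entry
        sims = F.cosine_similarity(q_feat.unsqueeze(1),
                                   gallery.unsqueeze(0), dim=-1)  # (4, 100)
        ranks = sims.argsort(dim=1, descending=True)  # retrieval order per query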

models/blocks.py (+93, new file)

@@ -0,0 +1,93 @@
"""
-------------------------------------------------
    File Name:      blocks.py
    Author:         Zhonghao Huang
    Date:           2019/9/10
    Description:
-------------------------------------------------
"""

import torch.nn as nn


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    elif classname.find('BatchNorm') != -1:
        if m.affine:
            nn.init.constant_(m.weight, 1.0)
            nn.init.constant_(m.bias, 0.0)


def weights_init_classifier(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        nn.init.normal_(m.weight, std=0.001)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)


# Defines the new fc layer and classification layer
# |--Linear--|--bn--|--relu--|--Linear--|
class ClassBlock(nn.Module):
    def __init__(self, input_dim, class_num, linear=True, num_bottleneck=512,
                 bn=True, relu=False, drop=0.5, return_feat=False):
        super(ClassBlock, self).__init__()
        self.return_feat = return_feat
        add_block = []
        if linear:
            add_block += [nn.Linear(input_dim, num_bottleneck)]
        else:
            num_bottleneck = input_dim
        if bn:
            add_block += [nn.BatchNorm1d(num_bottleneck)]
        if relu:
            add_block += [nn.LeakyReLU(0.1)]
        if drop > 0:
            add_block += [nn.Dropout(p=drop)]
        add_block = nn.Sequential(*add_block)
        add_block.apply(weights_init_kaiming)

        fc = nn.Linear(num_bottleneck, class_num)
        fc.apply(weights_init_classifier)

        self.add_block = add_block
        self.fc = fc

    def forward(self, x):
        x = self.add_block(x)
        if self.return_feat:
            feat = x
            cls = self.fc(x)
            return cls, feat
        else:
            cls = self.fc(x)
            return cls


class BNNeck(nn.Module):
    def __init__(self, in_planes, num_classes):
        super(BNNeck, self).__init__()
        self.in_planes = in_planes
        self.num_classes = num_classes
        self.bn = nn.BatchNorm1d(self.in_planes)
        self.bn.bias.requires_grad_(False)  # no shift
        self.classifier = nn.Linear(self.in_planes, self.num_classes)

        self.bn.apply(weights_init_kaiming)
        self.classifier.apply(weights_init_classifier)

    def forward(self, x):
        feat = self.bn(x)
        cls = self.classifier(feat)

        if self.training:
            return cls, x     # raw pre-BN feature feeds the metric loss
        else:
            return cls, feat  # BN-normalized feature is used for retrieval
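Not part of the commit — a short usage sketch (assuming models/blocks.py above is importable) contrasting the two heads: ClassBlock classifies a dropout-regularized 512-d bottleneck, while BNNeck returns the pre-BN feature during training (for a triplet-style metric loss) and the BN-normalized feature at test time (for retrieval):

    import torch

    from models.blocks import ClassBlock, BNNeck

    feat = torch.rand(32, 2048)  # pooled backbone features, batch of 32

    head = ClassBlock(input_dim=2048, class_num=751, relu=True, return_feat=True)
    logits, emb = head(feat)      # logits: (32, 751), emb: (32, 512)

    neck = BNNeck(in_planes=2048, num_classes=751)
    neck.train()
    logits, raw = neck(feat)      # raw pre-BN feature -> metric loss
    neck.eval()
    logits, bn_feat = neck(feat)  # BN-normalized feature -> cosine retrieval
    print(logits.shape, bn_feat.shape)  # (32, 751) and (32, 2048)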
