-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
191 lines (159 loc) · 6.23 KB
/
models.py
File metadata and controls
191 lines (159 loc) · 6.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import os
import torch
from torch import nn, load, hub
from torchvision import models
import helpers.manipulation as h_manipulation
from enum import Enum
class ModelTypes(Enum):
RESNET = 0
ALEXNET = 1
DENSENET = 2
EFFICIENTNET = 3
GOOGLENET = 4
MOBILENET = 5
SQUEEZENET = 6
class LayerTypes(Enum):
CONVOLUTIONAL = nn.Conv2d
LINEAR = nn.Linear
class TransformTypes(Enum):
PAD = "Pad"
JITTER = "Jitter"
RANDOM_SCALE = "Random Scale"
RANDOM_ROTATE = "Random Rotate"
AD_JITTER = "Additional Jitter"
_hook_activations = None
# Hook function
def _get_activations():
"""
Used when registering forward hooks to get layer activations
:return:
"""
def hook(model, input, output):
global _hook_activations
_hook_activations = output
return hook
def _get_activation_shape():
"""
Gets the activation
:return: A Tuple of the size of the activation
"""
return _hook_activations.squeeze().size()
def setup_model(model):
"""
Takes in a model type and creates a standard and robust model. Currently
only setups up for the CIFAR10 dataset. Loads the respective standard and
robust checkpoints saved in the models directory. Raises a ValueError if
model type is not valid.
:param model: Expects a enum of type ModelTypes to specified model to setup
:return: model
"""
curr_dir = os.path.dirname(__file__) + "/../"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
match model:
case ModelTypes.RESNET:
base = models.resnet18(
weights=models.ResNet18_Weights.IMAGENET1K_V1)
base.fc = nn.Linear(base.fc.in_features, 10)
base.load_state_dict(
load(curr_dir + "models/resnet_standard_cifar10.pt", map_location=device))
model = base
case ModelTypes.ALEXNET:
base = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
base.classifier[6] = nn.Linear(base.classifier[6].in_features, 10)
base.load_state_dict(
load(curr_dir + "models/alexnet_standard_cifar10.pt", map_location=device))
model = base
case ModelTypes.DENSENET:
base = models.densenet121(
weights=models.DenseNet121_Weights.IMAGENET1K_V1)
base.classifier = nn.Linear(base.classifier.in_features, 10)
base.load_state_dict(
load(curr_dir + "models/densenet_standard_cifar10.pt", map_location=device))
model = base
case ModelTypes.EFFICIENTNET:
base = hub.load('NVIDIA/DeepLearningExamples:torchhub',
'nvidia_efficientnet_b0', pretrained=True)
base.classifier.fc = nn.Linear(base.classifier.fc.in_features, 10)
base.load_state_dict(
load(curr_dir + "models/efficientnet_standard_cifar10.pt", map_location=device))
model = base
case ModelTypes.GOOGLENET:
base = models.googlenet(
weights=models.GoogLeNet_Weights.IMAGENET1K_V1)
base.fc = nn.Linear(base.fc.in_features, 10)
base.load_state_dict(
load(curr_dir + "models/googlenet_standard_cifar10.pt", map_location=device))
model = base
case ModelTypes.MOBILENET:
base = models.mobilenet_v2(
weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
base.classifier[1] = nn.Linear(base.classifier[1].in_features, 10)
base.load_state_dict(
load(curr_dir + "models/mobilenet_standard_cifar10.pt", map_location=device))
model = base
case ModelTypes.SQUEEZENET:
base = models.squeezenet1_0(
weights=models.SqueezeNet1_0_Weights.IMAGENET1K_V1)
base.classifier[1] = nn.Conv2d(
512, 10, kernel_size=(1, 1), stride=(1, 1))
base.load_state_dict(
load(curr_dir + "models/squeezenet_standard_cifar10.pt", map_location=device))
model = base
case _:
raise ValueError("Unknown model choice")
return model
def get_layer_by_name(model, layer_name):
"""
Gets a layer given a name and model. Raises ValueError if layer is not
found in the model.
:param model: Model to look for layer in
:param layer_name: Layer name to search
:return: Layer found
"""
current_layer = model
layer_names = layer_name.split('_') # Split into sub layers
for name in layer_names:
if isinstance(current_layer, nn.Module):
current_layer = getattr(current_layer, name, None)
else:
raise ValueError(f"Layer '{layer_name}' not found in the model.")
if current_layer is None:
raise ValueError(f"Layer '{layer_name}' not found in the model.")
return current_layer
def get_feature_map_sizes(model, layers, img=None):
"""
Gets the feature map sizes, used to dynamically limit values on the
interface.
:param img: Image to use for forward pass
:param model: Model to pass image through
:param layers: Layers to grab feature map sizes from
:return: Feature map sizes
"""
feature_map_sizes = [None] * len(layers)
if img is None:
# TODO Remove this and just generates a blank image of 227 by 227
img = h_manipulation.create_random_image((227, 227),
h_manipulation.DatasetNormalizations.CIFAR10_MEAN.value,
h_manipulation.DatasetNormalizations.CIFAR10_STD.value).clone().unsqueeze(0)
else:
img = img.unsqueeze(0)
# Use GPU if possible
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:
print("Using GPU for generation")
model = model.cuda()
img = img.cuda()
else:
print("Using CPU for generation")
model = model.eval()
# Fake forward pass for activations
index = 0
for layer in layers:
if isinstance(layer, nn.Conv2d):
hook = layer.register_forward_hook(_get_activations())
model(img)
# Activations will have feature map sizes
feature_map_sizes[index] = _get_activation_shape()
hook.remove()
index += 1
return feature_map_sizes