12
12
import os
13
13
import copy
14
14
15
- plt .ion ()
16
-
17
- # Data augmentation and normalization for training
18
- # Just normalization for validation
19
- data_transforms = {
20
- 'train' : transforms .Compose ([
21
- transforms .RandomResizedCrop (224 ),
22
- transforms .RandomHorizontalFlip (),
23
- transforms .ToTensor (),
24
- transforms .Normalize ([0.485 , 0.456 , 0.406 ], [0.229 , 0.224 , 0.225 ])
25
- ]),
26
- 'val' : transforms .Compose ([
27
- transforms .Resize (256 ),
28
- transforms .CenterCrop (224 ),
29
- transforms .ToTensor (),
30
- transforms .Normalize ([0.485 , 0.456 , 0.406 ], [0.229 , 0.224 , 0.225 ])
31
- ]),
32
- }
33
-
34
- data_dir = 'garbage-classification/Garbage classification'
35
-
36
- image_datasets = {x : datasets .ImageFolder (os .path .join (data_dir , x ),
37
- data_transforms [x ])
38
- for x in ['train' , 'val' ]}
39
-
40
- print ("Train classes: {}" .format (image_datasets ['train' ].classes ))
41
-
42
- dataloaders = {x : torch .utils .data .DataLoader (image_datasets [x ], batch_size = 4 ,
43
- shuffle = True , num_workers = 4 )
44
- for x in ['train' , 'val' ]}
45
- dataset_sizes = {x : len (image_datasets [x ]) for x in ['train' , 'val' ]}
46
- print ("Dataset size: {}" .format (dataset_sizes ))
47
-
48
- class_names = image_datasets ['train' ].classes
49
-
50
- device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
51
15
52
16
def imshow (inp , title = None ):
53
17
"""Imshow for Tensor."""
@@ -62,24 +26,21 @@ def imshow(inp, title=None):
62
26
plt .pause (0.001 ) # pause a bit so that plots are updated
63
27
64
28
65
- # Get a batch of training data
66
- inputs , classes = next (iter (dataloaders ['train' ]))
67
29
68
- # Make a grid from batch
69
- out = torchvision .utils .make_grid (inputs )
70
30
71
- imshow (out , title = [class_names [x ] for x in classes ])
72
-
73
def createModel(num_classes=6, w_drop=True):
    """Build a pretrained ResNeXt-101 32x8d with a fresh classification head.

    Args:
        num_classes: number of output classes produced by the new final layer.
        w_drop: when True (default) the head is Dropout(0.5) followed by a
            Linear layer; when False it is a bare Linear layer.

    Returns:
        The torchvision model with its ``fc`` layer replaced.
    """
    net = models.resnext101_32x8d(pretrained=True)
    in_features = net.fc.in_features

    if w_drop:
        # Regularised head: dropout before the final projection.
        net.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, num_classes),
        )
    else:
        net.fc = nn.Linear(in_features, num_classes)

    return net
85
46
@@ -179,29 +140,77 @@ def visualize_model(model, num_images=6):
179
140
return
180
141
model .train (mode = was_training )
181
142
182
if __name__ == "__main__":

    plt.ion()

    # ImageNet channel statistics; both splits normalise identically.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Data augmentation and normalization for training;
    # just normalization for validation.
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    data_dir = 'garbage-classification/Garbage classification'

    # One ImageFolder per split, rooted at <data_dir>/<split>.
    image_datasets = {
        split: datasets.ImageFolder(os.path.join(data_dir, split),
                                    data_transforms[split])
        for split in ['train', 'val']
    }

    print("Train classes: {}".format(image_datasets['train'].classes))

    dataloaders = {
        split: torch.utils.data.DataLoader(image_datasets[split],
                                           batch_size=4,
                                           shuffle=True,
                                           num_workers=4)
        for split in ['train', 'val']
    }
    dataset_sizes = {split: len(image_datasets[split])
                     for split in ['train', 'val']}
    print("Dataset size: {}".format(dataset_sizes))

    class_names = image_datasets['train'].classes

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Sanity check: visualise one augmented training batch as a grid.
    inputs, classes = next(iter(dataloaders['train']))
    out = torchvision.utils.make_grid(inputs)
    imshow(out, title=[class_names[x] for x in classes])

    model_ft = createModel()
    model_ft = model_ft.to(device)

    criterion = nn.CrossEntropyLoss()

    # Observe that all parameters are being optimized (full fine-tuning).
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs.
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

    num_epochs = 30
    start_epoch = 0
    model_ft, best_acc, loss = train_model(model_ft,
                                           criterion,
                                           optimizer_ft,
                                           exp_lr_scheduler,
                                           start_epoch=start_epoch,
                                           num_epochs=num_epochs)

    # NOTE(review): the checkpoint embeds a freshly-built model object in
    # addition to the state dicts — loaders presumably rely on that key.
    checkpoint = {
        'epoch': start_epoch + num_epochs,
        'model': createModel(),
        'model_state_dict': model_ft.state_dict(),
        'optimizer_state_dict': optimizer_ft.state_dict(),
    }

    torch.save(checkpoint, 'garbage-classification/models_resnext101_32x8d_acc: {:g} loss: {:g}'.format(best_acc, loss))
0 commit comments