Skip to content

Commit 01ab409

Browse files
committed
Update code
1 parent 1805032 commit 01ab409

File tree

2 files changed

+97
-88
lines changed

2 files changed

+97
-88
lines changed

predict.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torchvision
77
from torchvision import datasets, models, transforms
88
from PIL import Image
9-
from trash_cnn_pytorch import train_model
9+
# from trash_cnn_pytorch import train_model
1010

1111
def createModel(num_classes=6, w_drop=True):
1212

@@ -43,7 +43,7 @@ def setup(model_dir, model_class):
4343
checkpoint = torch.load(model_dir, map_location=device)
4444

4545
# Build model structure and optimizer
46-
predictor = model_class()
46+
predictor = model_class(w_drop=False)
4747
opt = optim.SGD(predictor.parameters(), lr=0.001, momentum=0.9)
4848

4949
# Load model weights and optimizer states
@@ -76,30 +76,30 @@ def predict(model, img, transform, epoch, classes=['cardboard', 'glass', 'metal'
7676
return classes[idx[0]], preds[idx[0]].item(), preds
7777

7878
# TODO: Finish converting train_model() to retrain purpose
79-
def retrain(model, opt, imgs, transform, start_epoch=0, classes=['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']):
80-
'''
81-
Continue training on minibatch of new observations
82-
'''
83-
if len(imgs) < 40:
84-
85-
print("Not enough training data")
86-
return
79+
# def retrain(model, opt, imgs, transform, start_epoch=0, classes=['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']):
80+
# '''
81+
# Continue training on minibatch of new observations
82+
# '''
83+
# if len(imgs) < 40:
84+
85+
# print("Not enough training data")
86+
# return
8787

88-
criterion = nn.CrossEntropyLoss()
88+
# criterion = nn.CrossEntropyLoss()
8989

90-
# Decay LR by a factor of 0.1 every 7 epochs
91-
scheduler = lr_scheduler.StepLR(opt, step_size=5, gamma=0.1)
90+
# # Decay LR by a factor of 0.1 every 7 epochs
91+
# scheduler = lr_scheduler.StepLR(opt, step_size=5, gamma=0.1)
9292

93-
new_model_ft, best_acc, loss = train_model(model, criterion, opt, scheduler, start_epoch, num_epochs=5)
93+
# new_model_ft, best_acc, loss = train_model(model, criterion, opt, scheduler, start_epoch, num_epochs=5)
9494

95-
checkpoint = {
96-
'epoch': start_epoch + 5,
97-
'model': createModel(),
98-
'model_state_dict': new_model_ft.state_dict(),
99-
'optimizer_state_dict': opt.state_dict()
100-
}
95+
# checkpoint = {
96+
# 'epoch': start_epoch + 5,
97+
# 'model': createModel(),
98+
# 'model_state_dict': new_model_ft.state_dict(),
99+
# 'optimizer_state_dict': opt.state_dict()
100+
# }
101101

102-
torch.save(checkpoint, 'garbage-classification/models_resnext101_32x8d_acc: {:g} loss: {:g}'.format(best_acc, loss))
102+
# torch.save(checkpoint, 'garbage-classification/models_resnext101_32x8d_acc: {:g} loss: {:g}'.format(best_acc, loss))
103103

104104
if __name__ == "__main__":
105105

trash_cnn_pytorch.py

Lines changed: 76 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,6 @@
1212
import os
1313
import copy
1414

15-
plt.ion()
16-
17-
# Data augmentation and normalization for training
18-
# Just normalization for validation
19-
data_transforms = {
20-
'train': transforms.Compose([
21-
transforms.RandomResizedCrop(224),
22-
transforms.RandomHorizontalFlip(),
23-
transforms.ToTensor(),
24-
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
25-
]),
26-
'val': transforms.Compose([
27-
transforms.Resize(256),
28-
transforms.CenterCrop(224),
29-
transforms.ToTensor(),
30-
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
31-
]),
32-
}
33-
34-
data_dir = 'garbage-classification/Garbage classification'
35-
36-
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
37-
data_transforms[x])
38-
for x in ['train', 'val']}
39-
40-
print("Train classes: {}".format(image_datasets['train'].classes))
41-
42-
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
43-
shuffle=True, num_workers=4)
44-
for x in ['train', 'val']}
45-
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
46-
print("Dataset size: {}".format(dataset_sizes))
47-
48-
class_names = image_datasets['train'].classes
49-
50-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
5115

5216
def imshow(inp, title=None):
5317
"""Imshow for Tensor."""
@@ -62,24 +26,21 @@ def imshow(inp, title=None):
6226
plt.pause(0.001) # pause a bit so that plots are updated
6327

6428

65-
# Get a batch of training data
66-
inputs, classes = next(iter(dataloaders['train']))
6729

68-
# Make a grid from batch
69-
out = torchvision.utils.make_grid(inputs)
7030

71-
imshow(out, title=[class_names[x] for x in classes])
72-
73-
def createModel(num_classes=6):
31+
def createModel(num_classes=6, w_drop=True):
7432

7533
model_ft = models.resnext101_32x8d(pretrained=True)
7634
num_ftrs = model_ft.fc.in_features
77-
# model_ft.fc = nn.Linear(num_ftrs, num_classes)
7835

79-
model_ft.fc = nn.Sequential(
80-
nn.Dropout(0.5),
81-
nn.Linear(num_ftrs, num_classes)
82-
)
36+
if not w_drop:
37+
model_ft.fc = nn.Linear(num_ftrs, num_classes)
38+
39+
else:
40+
model_ft.fc = nn.Sequential(
41+
nn.Dropout(0.5),
42+
nn.Linear(num_ftrs, num_classes)
43+
)
8344

8445
return model_ft
8546

@@ -179,29 +140,77 @@ def visualize_model(model, num_images=6):
179140
return
180141
model.train(mode=was_training)
181142

182-
model_ft = createModel()
143+
if __name__ == "__main__":
144+
145+
plt.ion()
146+
147+
# Data augmentation and normalization for training
148+
# Just normalization for validation
149+
data_transforms = {
150+
'train': transforms.Compose([
151+
transforms.RandomResizedCrop(224),
152+
transforms.RandomHorizontalFlip(),
153+
transforms.ToTensor(),
154+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
155+
]),
156+
'val': transforms.Compose([
157+
transforms.Resize(256),
158+
transforms.CenterCrop(224),
159+
transforms.ToTensor(),
160+
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
161+
]),
162+
}
163+
164+
data_dir = 'garbage-classification/Garbage classification'
165+
166+
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
167+
data_transforms[x])
168+
for x in ['train', 'val']}
169+
170+
print("Train classes: {}".format(image_datasets['train'].classes))
171+
172+
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
173+
shuffle=True, num_workers=4)
174+
for x in ['train', 'val']}
175+
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
176+
print("Dataset size: {}".format(dataset_sizes))
177+
178+
class_names = image_datasets['train'].classes
179+
180+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
181+
182+
183+
# Get a batch of training data
184+
inputs, classes = next(iter(dataloaders['train']))
185+
186+
# Make a grid from batch
187+
out = torchvision.utils.make_grid(inputs)
188+
189+
imshow(out, title=[class_names[x] for x in classes])
190+
191+
model_ft = createModel()
183192

184-
model_ft = model_ft.to(device)
193+
model_ft = model_ft.to(device)
185194

186-
criterion = nn.CrossEntropyLoss()
195+
criterion = nn.CrossEntropyLoss()
187196

188-
# Observe that all parameters are being optimized
189-
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
190-
# optimizer_ft = optim.Adam(mode]+-[p0o98u3w` qa]\'l_ft.parameters(), lr=0.005)
197+
# Observe that all parameters are being optimized
198+
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
199+
# optimizer_ft = optim.Adam(mode]+-[p0o98u3w` qa]\'l_ft.parameters(), lr=0.005)
191200

192-
# Decay LR by a factor of 0.1 every 7 epochs
193-
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
201+
# Decay LR by a factor of 0.1 every 7 epochs
202+
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
194203

195-
num_epochs = 30
196-
start_epoch = 0
197-
model_ft, best_acc, loss = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, start_epoch=start_epoch,
198-
num_epochs= num_epochs)
204+
num_epochs = 30
205+
start_epoch = 0
206+
model_ft, best_acc, loss = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, start_epoch=start_epoch,
207+
num_epochs= num_epochs)
199208

200-
checkpoint = {
201-
'epoch': start_epoch + num_epochs,
202-
'model': createModel(),
203-
'model_state_dict': model_ft.state_dict(),
204-
'optimizer_state_dict': optimizer_ft.state_dict()
205-
}
209+
checkpoint = {
210+
'epoch': start_epoch + num_epochs,
211+
'model': createModel(),
212+
'model_state_dict': model_ft.state_dict(),
213+
'optimizer_state_dict': optimizer_ft.state_dict()
214+
}
206215

207-
torch.save(checkpoint, 'garbage-classification/models_resnext101_32x8d_acc: {:g} loss: {:g}'.format(best_acc, loss))
216+
torch.save(checkpoint, 'garbage-classification/models_resnext101_32x8d_acc: {:g} loss: {:g}'.format(best_acc, loss))

0 commit comments

Comments
 (0)