-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
75 lines (61 loc) · 2.24 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import pickle as pkl
from random import shuffle
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
from dataset import RSDataset
# Training hyper-parameters.
batch_size = 16
num_epochs = 100
# Which network package to train. The same name is also used as the
# checkpoint directory by save_checkpoint and the resume logic.
net = "politically_correct"
# Data location handed to RSDataset as its first argument.
master_list_location = "/home/connor1995/train"

# Import the model factory, preprocessing transform and loss factory for the
# selected network.
if net == "politically_correct":
    from politically_correct import network, preprocess, loss
else:
    # Fail fast: otherwise an unknown ``net`` only surfaces later as a
    # NameError on ``network`` / ``preprocess`` / ``loss``.
    raise ValueError("Unknown network: {!r}".format(net))
def save_checkpoint(state, filename='checkpoint', directory=None):
    """Serialize a training checkpoint with ``torch.save``.

    The file is written as ``<directory>/<filename>_<epoch>.pth.tar``, so one
    file is kept per epoch.

    Args:
        state: Dict to serialize; must contain an ``"epoch"`` key, whose value
            is embedded in the file name.
        filename: Stem of the checkpoint file name.
        directory: Directory to write into. Defaults to the module-level
            ``net`` name, which is where the resume logic looks for files.

    Returns:
        The path of the file that was written.
    """
    if directory is None:
        directory = net
    path = os.path.join(directory, "{}_{}.pth.tar".format(filename, state["epoch"]))
    torch.save(state, path)
    return path
def _latest_checkpoint_epoch(directory):
    """Return the highest epoch among ``checkpoint_<epoch>.pth.tar`` files in *directory*.

    Only names following the scheme that ``save_checkpoint`` produces with its
    default ``filename`` are considered; any other file is ignored. Returns
    ``None`` when no matching checkpoint exists.
    """
    epochs = []
    for name in os.listdir(directory):
        stem, dot, extension = name.partition(".")
        prefix, _, tag = stem.partition("_")
        if dot and extension == "pth.tar" and prefix == "checkpoint" and tag.isdecimal():
            epochs.append(int(tag))
    return max(epochs) if epochs else None


# Ensure the checkpoint directory exists, then resume from the newest
# checkpoint in it, if any. An empty directory, or one holding unrelated
# files, simply means "start from scratch".
os.makedirs(net, exist_ok=True)
latest_epoch = _latest_checkpoint_epoch(net)
restore = latest_epoch is not None
resume_path = (
    os.path.join(net, "checkpoint_{}.pth.tar".format(latest_epoch)) if restore else None
)
debug = True
if __name__ == "__main__":
    # Move the model to the GPU *before* building the optimizer. Then the
    # optimizer, and any state restored into it below, is on the same device
    # as the parameters. Resuming otherwise leaves Adam state on the CPU.
    model = network()
    model.cuda()
    optimizer = torch.optim.Adam(model.parameters())
    criterion = loss()

    train_data = RSDataset(master_list_location, grey=True, transform=preprocess)
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=2)
    print("Number of parameters: ", sum(param.numel() for param in model.parameters()))

    start_epoch = 0
    if restore:
        print("=> loading checkpoint '{}'".format(resume_path))
        checkpoint = torch.load(resume_path)
        # The stored epoch is the number of completed epochs (see the
        # save_checkpoint call below), i.e. the next epoch index to run.
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch']))

    for epoch in range(start_epoch, num_epochs):
        for batch in train_dataloader:
            model.zero_grad()
            images = Variable(batch["image"]).cuda()
            labels = Variable(batch["labels"]).cuda()
            out = model(images)
            # NOTE(review): labels are passed before the predictions; confirm
            # this matches politically_correct.loss (torch.nn losses take
            # (input, target)).
            batch_loss = criterion(labels, out)
            batch_loss.backward()
            optimizer.step()
        # One checkpoint per epoch; 'epoch' records how many epochs are done.
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': net,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        })