-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
110 lines (92 loc) · 3.94 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import argparse
import os

import torch
from torch.utils.data import DataLoader

from model.dataset import Dataset
from model.yolo import YOLOv2
from tools.loss import YOLOLoss
from tools.fit import fit
from tools.utils import get_bound_boxes
from tools.mAP import mean_average_precision
data = './data/obj.data'


def _next_config_value(stream, path):
    """Return the third whitespace-separated field of the next line of *stream*.

    Lines are expected to look like ``key = value``; only the value field is
    kept. Raises ValueError (rather than a bare IndexError) when the line has
    fewer than three fields, e.g. because the file is shorter than expected.
    """
    line = stream.readline()
    fields = line.split()
    if len(fields) < 3:
        raise ValueError(f'{path}: expected a "key = value" line, got {line!r}')
    return fields[2]


# The config is read positionally: the key names are ignored, so the ORDER of
# the first seven lines in the file must match the assignments below.
with open(data, 'r', encoding='utf-8') as f:
    classes = int(_next_config_value(f, data))
    data_train = _next_config_value(f, data)
    data_test = _next_config_value(f, data)
    data_label = _next_config_value(f, data)
    backup = _next_config_value(f, data)
    file_format = _next_config_value(f, data)
    # Only the exact string 'True' enables conversion.
    convert_to_yolo = _next_config_value(f, data) == 'True'
# Command-line options. Both weight paths are None unless supplied; the two
# switches are False unless passed on the command line.
parser = argparse.ArgumentParser()
parser.add_argument('--darknet_weights', type=str,
                    help='Path to Darknet19 weight file')
parser.add_argument('--yolo_weights', type=str,
                    help='Path to YOLOv2 weight file')
parser.add_argument('--batch_size', type=int, default=4,
                    help='Batch size')
parser.add_argument('--epochs', type=int, default=100,
                    help='Total epochs')
parser.add_argument('--learning_rate', type=float, default=5e-5,
                    help='Learning rate')
parser.add_argument('--weight_decay', type=float, default=2e-4,
                    help='Weight decay')
parser.add_argument('--multiscale_off', action='store_true',
                    help='Disable multi-scale training')
parser.add_argument('--verbose', action='store_true',
                    help='Show all losses and resolution changes')
args = parser.parse_args()
# Five anchor-box priors, one [a, b] pair per anchor (presumably width/height
# in grid-relative units -- TODO confirm against Dataset / YOLOLoss).
anchors = [
    [0.775, 0.774152],
    [0.598437, 0.689189],
    [0.234375, 0.320291],
    [0.45625, 0.9],
    [0.449219, 0.660934],
]
# Train and validation datasets share every option except the image
# directory and the split name passed as type_dataset.
_dataset_kwargs = dict(
    labels_dir=data_label,
    anchors=anchors,
    num_classes=classes,
    file_format=file_format,
    convert_to_yolo=convert_to_yolo,
)
train_dataset = Dataset(data_dir=data_train, type_dataset='train', **_dataset_kwargs)
val_dataset = Dataset(data_dir=data_test, type_dataset='validation', **_dataset_kwargs)
# Sanity-check the Dataset contract on one sample before training starts:
# an item must be a dict with exactly the keys 'image' and 'target', both
# torch tensors. The sample is fetched once because every train_dataset[0]
# call re-runs Dataset.__getitem__.
_sample = train_dataset[0]
assert isinstance(_sample, dict), f'expected a dict sample, got {type(_sample).__name__}'
assert set(_sample) == {'image', 'target'}, f'unexpected sample keys: {list(_sample)}'
assert isinstance(_sample['image'], torch.Tensor), 'sample["image"] must be a torch.Tensor'
assert isinstance(_sample['target'], torch.Tensor), 'sample["target"] must be a torch.Tensor'
del _sample  # don't keep the checked sample alive for the whole run
# Training batches are reshuffled each epoch; validation keeps a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Executable device:', device)

# darknet_weights is forwarded only when given, so YOLOv2 keeps its own default
# otherwise. num_anchors follows the anchors list rather than a literal 5, so
# the model head stays consistent with YOLOLoss if the priors are edited.
_model_kwargs = {'num_anchors': len(anchors), 'num_classes': classes, 'device': device}
if args.darknet_weights is not None:
    _model_kwargs['darknet_weights'] = args.darknet_weights
model = YOLOv2(**_model_kwargs).to(device)

if args.yolo_weights is not None:
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(args.yolo_weights, map_location=device))

loss = YOLOLoss(anchors=anchors).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)
# The learning rate is multiplied by gamma=0.99 on every scheduler.step().
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.99)
# Run training. The training dataset is handed to fit() only when multi-scale
# training is enabled; None is passed when --multiscale_off is set
# (presumably fit() skips resolution changes then -- confirm in tools/fit.py).
fit(model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=loss,
    epochs=args.epochs,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    train_dataset=train_dataset if not args.multiscale_off else None,
    backup=backup,
    verbose=args.verbose)

# Save the final weights. os.path.join works whether or not the configured
# backup directory ends with a path separator.
torch.save(model.state_dict(), os.path.join(backup, f'yolov2_{args.epochs}.pt'))
# Report mAP at IoU 0.5, first on the training split and then on validation.
# threshold=0.3 is presumably the box-confidence cutoff used by get_bound_boxes.
for split_name, split_loader in (('Train', train_dataloader), ('Validation', val_dataloader)):
    pred_boxes, true_boxes = get_bound_boxes(split_loader, model, anchors, iou_threshold=0.5, threshold=0.3)
    mAP = mean_average_precision(pred_boxes, true_boxes, classes=classes, iou_threshold=0.5)
    print(f'{split_name} mAP: {mAP}')