trainval_individual.py
import sys
import os
import importlib
import numpy as np
import torch
import torch.nn as nn
from core.classificationnet import ClassificationNet
from core.utils import build_backbone_info, build_imagedataloaders
from core.utils import save_checkpoint, load_checkpoint, save_val_record
from core.utils import build_optimizers, build_schedulers
from algorithms import train_epoch, test_epoch, ModelCheckpoint
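
# Usage: python trainval_individual.py <config_name>
# The script imports configs/<config_name>.py, reads its `args` object, and
# trains and validates a single classification model (see the hedged config
# sketch at the end of this file).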


def run_trainval(model, train_type, dataset, max_epoch, device, checkpoint_dir,
                 train_loader, val_loader, optimizers, schedulers, save_opt):
    """Train the model for max_epoch epochs, validating after each epoch."""
    title_str = '== TRAINVAL {} on {} =='.format(train_type, dataset)
    bound_str = '=' * len(title_str)
    print(bound_str + '\n' + title_str + '\n' + bound_str)
    print('Checkpoint Directory: {}'.format(checkpoint_dir))
    output_dir, inner_chkpt = os.path.split(checkpoint_dir)
    # Track the best validation accuracy (initialized to -1, below any real
    # value) and save checkpoints according to save_opt ('best' or 'last')
    model_checkpoint = ModelCheckpoint(-1, checkpoint_dir, save_opt, max_epoch)
    for epoch_idx in range(1, max_epoch + 1):
        train_loss, train_acc = train_epoch(
            model, device, train_loader, optimizers, epoch_idx)
        val_loss, val_acc = test_epoch(
            model, device, val_loader, epoch_idx)
        model_checkpoint(val_acc, epoch_idx, model)
        schedulers.step()
    return


def main():
    # ---------------------------------
    # Loading the config
    # ---------------------------------
    # The config module name comes from the command line, e.g.
    # `python trainval_individual.py myconfig` loads configs/myconfig.py
    config_module = importlib.import_module('configs.' + sys.argv[1])
    args = config_module.args
    print(args)
    # ---------------------------------
    # General settings
    # ---------------------------------
    device = 'cuda'
    # Seed every RNG and disable cudnn autotuning for reproducible runs
    torch.manual_seed(args.rng_seed)
    torch.cuda.manual_seed(args.rng_seed)
    torch.cuda.manual_seed_all(args.rng_seed)
    np.random.seed(args.rng_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    assert args.train_type in ['baseline', 'finetune']
    assert args.save_opt in ['best', 'last']
    # ---------------------------------
    # Dataset settings
    # ---------------------------------
    image_size = args.image_size
    batch_size = args.batch_size
    padding = args.padding
    transform_name = args.transform_name
    # ---------------------------------
    # Optimizer and Scheduler settings
    # ---------------------------------
    param_types = args.param_types
    max_epoch = args.max_epoch
    optimizer_infos = args.optimizer_infos
    scheduler_infos = args.scheduler_infos
    # ---------------------------------
    # Backbone settings
    # ---------------------------------
    backbone_info = build_backbone_info(args.backbone, 'standard', image_size)
    # ---------------------------------
    # Method settings
    # ---------------------------------
    experiment_dir = 'CHECKPOINTS/Individual/{}/{}/{}'.format(
        args.exp_name, args.backbone, args.dataset)
    if args.pretrain != '':
        assert args.train_type != 'baseline', 'Cannot use pretrain in baseline train_type'
        print('Load from the pretrained model!')
        model, _ = load_checkpoint(args.pretrain)
    else:
        assert args.train_type != 'finetune', 'Cannot use finetune train_type without pretrain'
        model = ClassificationNet(backbone_info, args.num_classes)
    # ---------------------------------
    # Build the parallel model
    # ---------------------------------
    model = nn.DataParallel(model.to(device))
    # ---------------------------------
    # Run trainval or evaluate
    # ---------------------------------
    # Build the train and validation dataloaders
    train_loader, val_loader = build_imagedataloaders(
        'trainval', os.path.join(args.exp_name, args.dataset), transform_name,
        image_size, batch_size, padding, args.save_opt, args.workers)
    # Get the checkpoint directory name
    inner_chkpt = args.train_type + args.chkpt_postfix
    checkpoint_dir = os.path.join(experiment_dir, inner_chkpt)
    # Get the optimizers and schedulers; DataParallel wraps the network, so
    # the underlying model is reached through model.module
    optimizers = build_optimizers(model.module, param_types, optimizer_infos)
    schedulers = build_schedulers(optimizers, scheduler_infos)
    # Run training and validation loops
    run_trainval(
        model, args.train_type, args.dataset, max_epoch, device, checkpoint_dir,
        train_loader, val_loader, optimizers, schedulers, args.save_opt)
    return


if __name__ == '__main__':
    main()
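
# A minimal sketch of what a config module might look like, assuming `args`
# is a simple namespace. Only the attribute names are taken from what this
# script actually reads; every value below is hypothetical, and the structure
# of optimizer_infos/scheduler_infos is left elided because it is not
# determined by this file:
#
#   # configs/example_baseline.py (hypothetical)
#   from argparse import Namespace
#   args = Namespace(
#       rng_seed=0, train_type='baseline', save_opt='best',
#       image_size=224, batch_size=64, padding=4, transform_name='standard',
#       param_types=['backbone', 'classifier'], max_epoch=100,
#       optimizer_infos=..., scheduler_infos=...,
#       backbone='resnet18', exp_name='exp0', dataset='cifar10',
#       pretrain='', num_classes=10, chkpt_postfix='', workers=4)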