-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_p_net.py
71 lines (50 loc) · 2.47 KB
/
train_p_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2022/4/7 16:01
# @Author : guoyankai
# @Email : [email protected]
# @File : train_p_net.py
# @software: PyCharm
import argparse
import os
import sys
import config
from core.imagedb import ImageDB
from train_tools import train_pnet
def str2bool(value):
    """Convert a command-line string to a bool for use as an argparse ``type``.

    ``type=bool`` is a trap in argparse: ``bool("False")`` is ``True`` because
    every non-empty string is truthy. This converter recognises the usual
    spellings instead.

    Args:
        value: The raw argument string. A ``bool`` is returned unchanged, so
            calling this directly with an already-boolean value is safe.

    Returns:
        ``True`` for "true"/"t"/"yes"/"y"/"1"/"on" and ``False`` for
        "false"/"f"/"no"/"n"/"0"/"off" (case-insensitive, surrounding
        whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling;
            argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value (true/false), got %r' % (value,))


def parse_args(argv=None):
    """Parse the command-line options for PNet training.

    Defaults are read from the project ``config`` module.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, exactly as before.

    Returns:
        argparse.Namespace with attributes ``annotation_file``,
        ``model_store_path``, ``end_epoch``, ``frequent``, ``lr``,
        ``batch_size``, ``use_cuda`` and ``prefix_path``.
    """
    parser = argparse.ArgumentParser(description='Train PNet',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--anno_file', dest='annotation_file',
                        default=os.path.join(config.ANNO_STORE_DIR, config.PNET_TRAIN_IMGLIST_FILENAME),
                        help='training data annotation file', type=str)
    # Help text (Chinese): "storage path for the trained model".
    parser.add_argument('--model_path', dest='model_store_path', help='训练模型存储路径',
                        default=config.MODEL_STORE_DIR, type=str)
    parser.add_argument('--end_epoch', dest='end_epoch', help='end epoch of training',
                        default=config.END_EPOCH, type=int)
    parser.add_argument('--frequent', dest='frequent', help='frequency of logging',
                        default=200, type=int)
    parser.add_argument('--lr', dest='lr', help='learning rate',
                        default=config.TRAIN_LR, type=float)
    # Help text (Chinese): "PNet training batch size".
    parser.add_argument('--batch_size', dest='batch_size', help='训练Pnet批次大小',
                        default=config.TRAIN_BATCH_SIZE, type=int)
    # type=bool would map any non-empty string (including "False") to True,
    # so the value is parsed explicitly. argparse only applies `type` to
    # string defaults, so a boolean config.USE_CUDA passes through untouched.
    parser.add_argument('--gpu', dest='use_cuda', help='train with gpu',
                        default=config.USE_CUDA, type=str2bool)
    # A non-empty dest is required for the value to be readable as
    # args.prefix_path; dest='' stores it under an unusable empty name.
    parser.add_argument('--prefix_path', dest='prefix_path',
                        help='training data annotation images prefix root path',
                        type=str)
    args = parser.parse_args(argv)
    return args
def train_net(annotation_file, model_store_path,
              end_epoch=16, lr=0.01, batch_size=128, use_cuda=False):
    """Load the PNet annotation database and start training on it.

    Args:
        annotation_file: Annotation list file, handed to ``ImageDB``.
        model_store_path: Where trained models are stored; forwarded to
            ``train_pnet`` unchanged.
        end_epoch: Last training epoch, forwarded unchanged.
        lr: Learning rate, forwarded to ``train_pnet`` as ``base_lr``.
        batch_size: Training batch size, forwarded unchanged.
        use_cuda: Forwarded to ``train_pnet``; presumably selects GPU
            training (the CLI help calls it "train with gpu").
    """
    db = ImageDB(annotation_file)
    training_imdb = db.load_imdb()

    # Flip-based data augmentation; intentionally disabled for now.
    # training_imdb = db.append_flipped_images(training_imdb)

    train_pnet(
        model_store_path=model_store_path,
        end_epoch=end_epoch,
        imdb=training_imdb,
        batch_size=batch_size,
        base_lr=lr,
        use_cuda=use_cuda,
    )
if __name__ == "__main__":
    # Script entry point: parse CLI options, then train PNet with them.
    args = parse_args()
    # Show which annotation list this run trains on.
    print(args.annotation_file)
    train_net(args.annotation_file, args.model_store_path,
              end_epoch=args.end_epoch,
              lr=args.lr, batch_size=args.batch_size, use_cuda=args.use_cuda)