add TSN dygraph model #4817
base: release/1.8
@@ -0,0 +1,70 @@
# TSN Video Classification Model

This directory contains the TSM video classification model implemented with PaddlePaddle dygraph mode.

> Review comment: TSM -> TSN

---

## Contents

- [Model Overview](#model-overview)
- [Data Preparation](#data-preparation)
- [Training](#training)
- [Evaluation](#evaluation)
- [References](#references)

## Model Overview

Temporal Segment Network (TSN) is a classic 2D-CNN-based approach to video classification. It addresses long-range action recognition by sampling video frames sparsely instead of densely, which captures global video information while removing redundancy and lowering the computational cost. The per-frame features are averaged into a single video-level feature that is used for classification. The model implemented here is the single-stream RGB variant of TSN with a ResNet-50 backbone.

> Review comment: ResNet-50 -> ResNet50

For details, see the ECCV 2016 paper [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859).
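The sparse sampling scheme described above can be sketched in plain Python (a minimal illustration of the idea, not the repository's actual data-reader code; the function name is invented):

```python
import random


def sample_segment_indices(num_frames, seg_num=3, training=True):
    """TSN-style sparse sampling: pick one frame index per segment.

    The video is split into `seg_num` equal segments. During training a
    random frame is drawn from each segment; at test time the segment
    center is used, making evaluation deterministic.
    """
    seg_len = max(num_frames // seg_num, 1)
    indices = []
    for s in range(seg_num):
        start = s * seg_len
        if training:
            indices.append(start + random.randrange(seg_len))
        else:
            indices.append(start + seg_len // 2)
    return indices
```

The sampled frames are passed through the backbone independently and their features averaged, which is what keeps the cost constant regardless of video length.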
## Data Preparation

TSN is trained on the Kinetics-400 action recognition dataset released by DeepMind. For download and preparation instructions, see the [data guide](./data/dataset/ucf101/README.md).

> Review comment: The text says the Kinetics-400 dataset is used, but the data guide link points to ucf101?

## Training

Once the data is prepared, training can be launched in either of two ways.

1. Multi-GPU training

```bash
bash multi-gpus-run.sh train ./configs/tsn.yaml
```

The GPUs used for multi-GPU training can be configured as follows:

- First, edit `num_gpus` in `./configs/tsn.yaml` (defaults to 4, i.e. train on 4 GPUs).

> Review comment: It is unnecessary to keep

- Second, edit `export CUDA_VISIBLE_DEVICES=0,1,2,3` and `--selected_gpus=0,1,2,3` in `multi-gpus-run.sh` (defaults to training on GPUs 0, 1, 2 and 3).

2. Single-GPU training

```bash
bash run.sh train ./configs/tsn.yaml
```

The GPU used for single-GPU training can be configured as follows:

- First, set `num_gpus=1` in `./configs/tsn.yaml` (train on a single GPU).

> Review comment: same

- First, edit `export CUDA_VISIBLE_DEVICES=0` in `run.sh` (train on GPU 0).

> Review comment: Two bullets both start with "First"; also, does batch_size need adjusting for single-GPU training?

## Evaluation

The model can be evaluated with:

```bash
bash run.sh eval ./configs/tsn-test.yaml ./weights/final.pdparams
```

- When evaluating with `run.sh`, set the `weights` parameter in the script to the checkpoint to be evaluated.

- `./tsn-test.yaml` is the config file used during evaluation; `./weights/final.pdparams` is the model file saved after training finishes.

> Review comment: use underscores instead of dashes in file names

- The evaluation prints accuracy metrics such as TOP1_ACC and TOP5_ACC directly to the log.
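The reported top-1/top-5 metrics follow the standard top-k accuracy definition, which can be sketched as follows (a minimal plain-Python illustration; `topk_accuracy` is an invented name, not the repository's metric code):

```python
def topk_accuracy(scores, labels, k=1):
    """Fraction of samples whose true label is among the k highest scores.

    scores: list of per-class score lists, one entry per sample
    labels: list of ground-truth class indices
    """
    hits = 0
    for sample_scores, label in zip(scores, labels):
        # indices of the k largest scores for this sample
        topk = sorted(range(len(sample_scores)),
                      key=lambda c: sample_scores[c],
                      reverse=True)[:k]
        hits += label in topk
    return hits / len(labels)
```

Top-5 accuracy is always at least top-1 accuracy, since the top-1 prediction is contained in the top-5 set.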
With the following settings, evaluation accuracy on the UCF101 validation set is:

| | seg_num | Top-1 | Top-5 |
| :------: | :----------: | :----: | :----: |
| Pytorch TSN | 3 | 83.88% | 96.78% |
| Paddle TSN (static graph) | 3 | 84.00% | 97.38% |
| Paddle TSN (dygraph) | 3 | 84.27% | 97.27% |

## References

- [Temporal Segment Networks: Towards Good Practices for Deep Action Recognition](https://arxiv.org/abs/1608.00859), Limin Wang, Yuanjun Xiong, Zhe Wang, Yu Qiao, Dahua Lin, Xiaoou Tang, Luc Van Gool
@@ -0,0 +1,174 @@
import argparse
import os
import glob
import fnmatch
import random


def parse_directory(path,
                    key_func=lambda x: x[-11:],
                    rgb_prefix='img_',
                    level=1):
    """
    Parse directories holding extracted frames from standard benchmarks.
    """
    print('parse frames under folder {}'.format(path))
    if level == 1:
        frame_folders = glob.glob(os.path.join(path, '*'))
    elif level == 2:
        frame_folders = glob.glob(os.path.join(path, '*', '*'))
    else:
        raise ValueError('level can be only 1 or 2')

    def count_files(directory, prefix_list):
        lst = os.listdir(directory)
        cnt_list = [len(fnmatch.filter(lst, x + '*')) for x in prefix_list]
        return cnt_list

    # count RGB frames in each video folder
    frame_dict = {}
    for i, f in enumerate(frame_folders):
        # prefix_list must be an iterable of prefixes, hence the 1-tuple
        all_cnt = count_files(f, (rgb_prefix, ))
        k = key_func(f)

        if i % 200 == 0:
            print('{} videos parsed'.format(i))

        # (folder, rgb count, flow count); no optical flow is extracted
        # here, so the RGB count is recorded in both slots
        frame_dict[k] = (f, all_cnt[0], all_cnt[0])

    print('frame folder analysis done')
    return frame_dict


def build_split_list(split, frame_info, shuffle=False):
    def build_set_list(set_list):
        rgb_list = list()
        for item in set_list:
            if item[0] not in frame_info:
                continue
            elif frame_info[item[0]][1] > 0:
                rgb_cnt = frame_info[item[0]][1]
                rgb_list.append('{} {} {}\n'.format(item[0], rgb_cnt, item[1]))
            else:
                rgb_list.append('{} {}\n'.format(item[0], item[1]))
        if shuffle:
            random.shuffle(rgb_list)
        return rgb_list

    train_rgb_list = build_set_list(split[0])
    test_rgb_list = build_set_list(split[1])
    return (train_rgb_list, test_rgb_list)


def parse_ucf101_splits(level):
    class_ind = [x.strip().split() for x in open('./annotations/classInd.txt')]
    class_mapping = {x[1]: int(x[0]) - 1 for x in class_ind}

    def line2rec(line):
        items = line.strip().split(' ')
        vid = items[0].split('.')[0]
        vid = '/'.join(vid.split('/')[-level:])
        label = class_mapping[items[0].split('/')[0]]
        return vid, label

    splits = []
    for i in range(1, 4):
        train_list = [
            line2rec(x)
            for x in open('./annotations/trainlist{:02d}.txt'.format(i))
        ]
        test_list = [
            line2rec(x)
            for x in open('./annotations/testlist{:02d}.txt'.format(i))
        ]
        splits.append((train_list, test_list))
    return splits


def parse_args():
    parser = argparse.ArgumentParser(description='Build file list')
    parser.add_argument(
        'frame_path', type=str, help='root directory for the frames')
    parser.add_argument('--rgb_prefix', type=str, default='img_')
    parser.add_argument('--num_split', type=int, default=3)
    parser.add_argument(
        '--subset', type=str, default='train',
        choices=['train', 'val', 'test'])
    parser.add_argument('--level', type=int, default=2, choices=[1, 2])
    parser.add_argument(
        '--format',
        type=str,
        default='rawframes',
        choices=['rawframes', 'videos'])
    parser.add_argument('--out_list_path', type=str, default='./')
    parser.add_argument('--shuffle', action='store_true', default=False)
    args = parser.parse_args()

    return args


def main():
    args = parse_args()

    if args.level == 2:

        def key_func(x):
            return '/'.join(x.split('/')[-2:])
    else:

        def key_func(x):
            return x.split('/')[-1]

    if args.format == 'rawframes':
        frame_info = parse_directory(
            args.frame_path,
            key_func=key_func,
            rgb_prefix=args.rgb_prefix,
            level=args.level)
    elif args.format == 'videos':
        if args.level == 1:
            video_list = glob.glob(os.path.join(args.frame_path, '*'))
        elif args.level == 2:
            video_list = glob.glob(os.path.join(args.frame_path, '*', '*'))
        frame_info = {
            os.path.relpath(x.split('.')[0], args.frame_path): (x, -1, -1)
            for x in video_list
        }

    split_tp = parse_ucf101_splits(args.level)
    assert len(split_tp) == args.num_split

    out_path = args.out_list_path
    if len(split_tp) > 1:
        for i, split in enumerate(split_tp):
            lists = build_split_list(split, frame_info, shuffle=args.shuffle)
            filename = 'ucf101_train_split_{}_{}.txt'.format(i + 1, args.format)
            with open(os.path.join(out_path, filename), 'w') as f:
                f.writelines(lists[0])
            filename = 'ucf101_val_split_{}_{}.txt'.format(i + 1, args.format)
            with open(os.path.join(out_path, filename), 'w') as f:
                f.writelines(lists[1])
    else:
        lists = build_split_list(split_tp[0], frame_info, shuffle=args.shuffle)
        # there is no --dataset option, so the dataset name is fixed here
        filename = 'ucf101_{}_list_{}.txt'.format(args.subset, args.format)
        # lists holds (train, test); map both val and test to the test split
        ind = 0 if args.subset == 'train' else 1
        with open(os.path.join(out_path, filename), 'w') as f:
            f.writelines(lists[ind])


if __name__ == "__main__":
    main()
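Each line the script above writes has the form `<video> <frame_count> <label>`, or just `<video> <label>` when no frame count is available. A downstream reader might parse such a line as follows (a sketch; `parse_list_line` is an invented helper, not part of this repo):

```python
def parse_list_line(line):
    """Split one file-list line into (video, frame_count, label).

    Returns -1 for the frame count when the line has only two fields.
    """
    parts = line.strip().split()
    if len(parts) == 3:
        video, cnt, label = parts
        return video, int(cnt), int(label)
    video, label = parts
    return video, -1, int(label)
```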
@@ -0,0 +1,106 @@
import argparse
import sys
import os
import os.path as osp
import glob
from multiprocessing import Pool
import cv2


def dump_frames(vid_item):
    full_path, vid_path, vid_id = vid_item
    vid_name = vid_path.split('.')[0]
    out_full_path = osp.join(args.out_dir, vid_name)
    try:
        os.mkdir(out_full_path)
    except OSError:
        pass
    vr = cv2.VideoCapture(full_path)
    videolen = int(vr.get(cv2.CAP_PROP_FRAME_COUNT))
    for i in range(videolen):
        ret, frame = vr.read()
        if not ret:
            continue
        # OpenCV decodes to BGR and cv2.imwrite expects BGR,
        # so the frame can be written as-is
        if frame is not None:
            cv2.imwrite('{}/img_{:05d}.jpg'.format(out_full_path, i + 1),
                        frame)
        else:
            print('[Warning] length inconsistent! '
                  'Early stop with {} out of {} frames'.format(i + 1,
                                                               videolen))
            break
    print('{} done with {} frames'.format(vid_name, videolen))
    sys.stdout.flush()
    return True


def parse_args():
    parser = argparse.ArgumentParser(description='extract raw frames')
    parser.add_argument('src_dir', type=str)
    parser.add_argument('out_dir', type=str)
    parser.add_argument('--level', type=int, choices=[1, 2], default=2)
    parser.add_argument('--num_worker', type=int, default=8)
    parser.add_argument(
        "--out_format",
        type=str,
        default='dir',
        choices=['dir', 'zip'],
        help='output format')
    parser.add_argument(
        "--ext",
        type=str,
        default='avi',
        choices=['avi', 'mp4'],
        help='video file extensions')
    parser.add_argument(
        "--new_width", type=int, default=0, help='resize image width')
    parser.add_argument(
        "--new_height", type=int, default=0, help='resize image height')
    parser.add_argument(
        "--resume",
        action='store_true',
        default=False,
        help='resume frame extraction instead of overwriting')
    args = parser.parse_args()

    return args


if __name__ == '__main__':
    args = parse_args()
    if not osp.isdir(args.out_dir):
        print('Creating folder: {}'.format(args.out_dir))
        os.makedirs(args.out_dir)
    if args.level == 2:
        classes = os.listdir(args.src_dir)
        for classname in classes:
            new_dir = osp.join(args.out_dir, classname)
            if not osp.isdir(new_dir):
                print('Creating folder: {}'.format(new_dir))
                os.makedirs(new_dir)

    print('Reading videos from folder: ', args.src_dir)
    print('Extension of videos: ', args.ext)
    if args.level == 2:
        fullpath_list = glob.glob(args.src_dir + '/*/*.' + args.ext)
        done_fullpath_list = glob.glob(args.out_dir + '/*/*')
    elif args.level == 1:
        fullpath_list = glob.glob(args.src_dir + '/*.' + args.ext)
        done_fullpath_list = glob.glob(args.out_dir + '/*')
    print('Total number of videos found: ', len(fullpath_list))
    if args.resume:
        fullpath_list = set(fullpath_list).difference(set(done_fullpath_list))
        fullpath_list = list(fullpath_list)
        print('Resuming. number of videos to be done: ', len(fullpath_list))

    if args.level == 2:
        # keep the trailing "class/video.ext" part of each path
        vid_list = list(
            map(lambda p: '/'.join(p.split('/')[-2:]), fullpath_list))
    elif args.level == 1:
        vid_list = list(map(lambda p: p.split('/')[-1], fullpath_list))

    pool = Pool(args.num_worker)
    pool.map(dump_frames, zip(fullpath_list, vid_list, range(len(vid_list))))
    pool.close()
    pool.join()
@@ -0,0 +1,26 @@
# extract_rawframes_opencv.py

## Usage

### For the kinetics400 dataset

Run the script with the command
`python extract_rawframes_opencv.py ./video/ ./rawframes/ --level 2 --ext mp4`

### Arguments

`./video/` : path of the directory containing the videos

`./rawframes` : output directory for the extracted frames

`--level 1 or 2` :

level 1 means the videos are stored as

    ------ video
    |------ xajhljklk.mp4
    |------ jjkjlljjk.mp4
    ....

level 2 means the videos are stored as

    ------ video
    |------ class1
    |-------- xajhljklk.mp4
    |-------- jjkjlljjk.mp4
    ....

`--ext` : extension of the video files (`avi` or `mp4`).

> Review comment: README.md VS REAMDE.md
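The two storage layouts above correspond directly to the glob patterns the extraction script builds internally; a minimal sketch (the helper name `list_videos` is invented for illustration):

```python
import glob
import os


def list_videos(src_dir, level, ext='mp4'):
    """Collect video paths for the two supported directory layouts."""
    if level == 1:
        # videos sit directly under src_dir
        pattern = os.path.join(src_dir, '*.' + ext)
    elif level == 2:
        # videos sit one class-folder deep
        pattern = os.path.join(src_dir, '*', '*.' + ext)
    else:
        raise ValueError('level can be only 1 or 2')
    return sorted(glob.glob(pattern))
```

Choosing the wrong `--level` for a given layout simply yields an empty match list, which is why the script prints the number of videos found before starting.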