forked from chainer/chainercv
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcaffe2npz.py
57 lines (40 loc) · 1.43 KB
/
caffe2npz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import re
import chainer
from chainer import Link
import chainer.links.caffe.caffe_function as caffe
"""
Please download a weight from here.
http://www.robots.ox.ac.uk/%7Evgg/software/very_deep/caffe/VGG_ILSVRC_16_layers.caffemodel
"""
def rename(name):
m = re.match(r'conv(\d+)_(\d+)$', name)
if m:
i, j = map(int, m.groups())
return 'conv{:d}_{:d}/conv'.format(i, j)
return name
class VGGCaffeFunction(caffe.CaffeFunction):
    """Caffe model loader that re-registers VGG conv links under new names.

    Every child link assigned while ``init_scope`` is active is stored under
    the name returned by ``rename``, so ``convI_J`` becomes ``convI_J/conv``.
    The input-channel order of ``conv1_1`` is also reversed so that the
    converted weights take RGB input instead of BGR.
    """

    def __init__(self, model_path):
        """Announce and then load ``model_path`` through ``CaffeFunction``.

        Args:
            model_path (str): Path to the ``.caffemodel`` file.
        """
        print('loading weights from {:s} ... '.format(model_path))
        super(VGGCaffeFunction, self).__init__(model_path)

    def __setattr__(self, name, value):
        """Intercept child-link registration to rename (and fix up) layers."""
        is_child_link = self.within_init_scope and isinstance(value, Link)
        if not is_child_link:
            # Plain attributes, and links assigned outside init_scope, are
            # stored under their original name.
            super(VGGCaffeFunction, self).__setattr__(name, value)
            return

        target = rename(name)
        if target != 'conv1_1/conv':
            print('{:s} -> {:s}'.format(name, target))
        else:
            # The Caffe weights expect BGR input. Reversing axis 1 of W makes
            # the first layer accept RGB. This assumes the link is a
            # Convolution2D whose W is (out, in, kh, kw), and that W is a
            # NumPy array, whose assignment copes with the overlapping
            # source/destination. TODO confirm if ever run on another
            # backend.
            value.W.array[:, ::-1] = value.W.array
            print('{:s} -> {:s} (BGR -> RGB)'.format(name, target))
        super(VGGCaffeFunction, self).__setattr__(target, value)
def main():
    """Convert a Caffe VGG-16 model into an npz file with renamed conv links.

    Command-line arguments:
        caffemodel: Path to the input ``.caffemodel`` file.
        output: Path of the ``.npz`` file to write.
    """
    parser = argparse.ArgumentParser(
        description='Convert Caffe VGG-16 weights into a Chainer npz file.')
    # Help strings make ``--help`` explain the two required positionals.
    parser.add_argument('caffemodel',
                        help='path to the input .caffemodel file')
    parser.add_argument('output',
                        help='path of the .npz file to write')
    args = parser.parse_args()

    model = VGGCaffeFunction(args.caffemodel)
    chainer.serializers.save_npz(args.output, model)


if __name__ == '__main__':
    main()