forked from mattragoza/LiGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcaffe_util.py
150 lines (122 loc) · 4.9 KB
/
caffe_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import os
import contextlib
import tempfile
from caffe.proto import caffe_pb2
from google.protobuf import text_format, message
import caffe
def read_prototxt(param, prototxt_file):
    '''
    Parse the text-format protobuf file prototxt_file into param.

    The file contents are passed to text_format.Merge together with
    param, so param is modified in place. Returns None.
    '''
    with open(prototxt_file) as in_file:
        contents = in_file.read()
        text_format.Merge(contents, param)
def from_prototxt_str(param_type, prototxt_str):
    '''
    Build a new param_type message from a prototxt string.

    A fresh instance is created with param_type(), then filled by
    text_format.Merge from prototxt_str, and returned.
    '''
    result = param_type()
    text_format.Merge(prototxt_str, result)
    return result
def from_prototxt(param_type, prototxt_file):
    '''
    Build a new param_type message from a prototxt file.

    A fresh instance is created with param_type(), filled from
    prototxt_file with read_prototxt, and returned.
    '''
    result = param_type()
    read_prototxt(result, prototxt_file)
    return result
def write_prototxt(param, prototxt_file):
    '''
    Write str(param) to prototxt_file, replacing any existing contents.

    For protobuf messages, str() gives the text (prototxt) format.
    '''
    with open(prototxt_file, mode='w') as out_file:
        text = str(param)
        out_file.write(text)
@contextlib.contextmanager
def temp_prototxt(param):
    '''
    Context manager that writes param to a temporary prototxt file.

    Yields the path of the temporary file, which contains str(param)
    as written by write_prototxt. The file is removed when the
    with-block exits, including when it exits via an exception or
    when writing the file fails.
    '''
    fd, prototxt_file = tempfile.mkstemp()
    # mkstemp hands back an open OS-level descriptor; write_prototxt
    # reopens the file by path, so close this one to avoid leaking it.
    os.close(fd)
    try:
        write_prototxt(param, prototxt_file)
        yield prototxt_file
    finally:
        # always clean up, even if the write or the with-body raised
        os.remove(prototxt_file)
def update_param(param_, *args, **kwargs):
    '''
    Recursively write values into a protobuf message or repeated field.

    Keyword arguments name attributes (fields) of param_. Positional
    arguments are written to param_[0], param_[1], ... in order. When
    the index reaches len(param_), the container is grown by one: via
    param_.add() if it has that method, otherwise by param_.append(value),
    in which case nothing further is written for that value.

    Each value is dispatched on its type:
        list             -> recurse into the target with the list as *args
        dict             -> recurse into the target with the dict as **kwargs
        message.Message  -> target.CopyFrom(value)
        anything else    -> plain assignment to the field / index
    Elements of param_ at indices >= len(args) are not touched.
    Returns None; param_ is modified in place.
    '''
    def assign(value, get_target, set_target):
        # The target is fetched lazily, only by the branches that need it.
        if isinstance(value, list):
            update_param(get_target(), *value)
        elif isinstance(value, dict):
            update_param(get_target(), **value)
        elif isinstance(value, message.Message):
            get_target().CopyFrom(value)
        else:
            set_target(value)

    for key, value in kwargs.items():
        assign(
            value,
            lambda key=key: getattr(param_, key),
            lambda new, key=key: setattr(param_, key, new),
        )

    for i, value in enumerate(args):
        if i == len(param_):
            try:
                param_.add()  # containers with add() get a new empty element
            except AttributeError:
                # presumably a repeated scalar field: it takes the value
                # directly, so there is nothing left to assign
                param_.append(value)
                continue
        assign(
            value,
            lambda i=i: param_[i],
            lambda new, i=i: param_.__setitem__(i, new),
        )
def _layer_in_phase(layer_param, phase):
    '''
    Return whether layer_param should be treated as part of phase.

    phase=None matches every layer. Otherwise only the phase value of
    each rule is compared (stage/level conditions are ignored):
    a layer with include rules matches if ANY include rule has this
    phase; a layer with no include rules matches unless an exclude
    rule has this phase. This is meant to mirror Caffe's layer
    inclusion rules, where a layer with no rules is in every phase.
    '''
    if phase is None:
        return True
    if len(layer_param.include) > 0:
        return any(rule.phase == phase for rule in layer_param.include)
    return not any(rule.phase == phase for rule in layer_param.exclude)


def _molgrid_data_params(net_param, phase=None):
    '''Yield molgrid_data_param of each MolGridData layer matching phase, in layer order.'''
    for layer_param in net_param.layer:
        if layer_param.type == 'MolGridData' and _layer_in_phase(layer_param, phase):
            yield layer_param.molgrid_data_param


def set_molgrid_data_source(net_param, data_file, data_root, phase=None):
    '''
    Point MolGridData layers at a data source.

    Sets source=data_file and root_folder=data_root on the
    molgrid_data_param of every MolGridData layer in net_param that
    matches phase (every MolGridData layer if phase is None).
    Modifies net_param in place and returns None.
    '''
    for data_param in _molgrid_data_params(net_param, phase):
        data_param.source = data_file
        data_param.root_folder = data_root


def get_molgrid_data_param(net_param, phase=None):
    '''
    Return the molgrid_data_param of the first MolGridData layer in
    net_param that matches phase (the first MolGridData layer if phase
    is None), or None if there is no such layer.
    '''
    return next(_molgrid_data_params(net_param, phase), None)
# can't inherit from protobuf message, so just add methods to the generated classes
# Runs at import time. For every class in caffe_pb2's namespace that
# subclasses message.Message, attach the helper functions defined above
# as methods, and re-export the class from this module's namespace under
# its caffe_pb2 name (which is how NetParameter / SolverParameter are
# available unqualified further down this file).
# Note: the loop variables `name` and `cls` remain bound at module level
# after the loop finishes.
for name, cls in caffe_pb2.__dict__.items():
    if isinstance(cls, type) and issubclass(cls, message.Message):
        # alternate constructors: Cls.from_prototxt(path), Cls.from_prototxt_str(text)
        cls.from_prototxt = classmethod(from_prototxt)
        cls.from_prototxt_str = classmethod(from_prototxt_str)
        # instance methods: msg.to_prototxt(path), msg.temp_prototxt(), msg.update(...)
        cls.to_prototxt = write_prototxt
        cls.temp_prototxt = temp_prototxt
        cls.update = update_param
        # Cls(*args, **kwargs) now fills fields via update_param.
        # NOTE(review): this replaces the generated __init__ outright and
        # update_param never calls the original; presumably this relies on
        # the protobuf backend initializing instances outside __init__ —
        # confirm against the protobuf implementation in use.
        cls.__init__ = update_param
        globals()[name] = cls
        # MolGridData helpers only make sense on network definitions
        if issubclass(cls, caffe_pb2.NetParameter):
            cls.set_molgrid_data_source = set_molgrid_data_source
            cls.get_molgrid_data_param = get_molgrid_data_param
class Net(caffe.Net):
    '''
    caffe.Net with alternate constructors from protobuf params / net
    specs, plus helpers that summarize the size of the network.
    '''
    @classmethod
    def from_param(cls, net_param=None, weights_file=None, phase=-1, **kwargs):
        '''
        Construct a net from a NetParameter message.

        Args:
            net_param: NetParameter to build from; a new empty
                NetParameter is used if None. kwargs are applied to
                this object in place, so the caller's message is
                modified.
            weights_file: passed to the constructor as `weights`.
            phase: passed to the constructor unchanged.
            **kwargs: field updates applied with net_param.update().
        Returns:
            An instance of cls.
        '''
        if net_param is None:
            net_param = NetParameter()
        net_param.update(**kwargs)
        # The constructor takes a model file path, so the param is
        # written to a temporary prototxt that is removed afterwards.
        with net_param.temp_prototxt() as model_file:
            return cls(network_file=model_file, weights=weights_file, phase=phase)

    @classmethod
    def from_spec(cls, net_spec, *args, **kwargs):
        '''
        Construct a net from net_spec.to_proto(); remaining arguments
        are forwarded to from_param.
        '''
        # Dispatch through cls (not Net) so subclasses build themselves.
        return cls.from_param(net_spec.to_proto(), *args, **kwargs)

    def get_n_params(self):
        '''Return the total number of elements across all parameter blobs.'''
        return sum(
            param_blob.data.size
            for param_blobs in self.params.values()
            for param_blob in param_blobs
        )

    def get_n_activs(self):
        '''Return the total number of elements across all activation blobs.'''
        return sum(activ_blob.data.size for activ_blob in self.blobs.values())

    def get_approx_size(self):
        '''
        Return 2 * (parameter elements + activation elements) * 4.

        NOTE(review): presumably a byte estimate assuming a data and a
        diff buffer (x2) of 4-byte float32 values — confirm.
        '''
        return 2*(self.get_n_params() + self.get_n_activs())*4

    def get_min_width(self):
        '''
        Return the smallest width among blobs whose name contains
        '_latent_mean', where width is the number of elements per index
        along the blob's first axis (data.size // shape[0]).
        Returns float('inf') if no blob name matches.
        '''
        min_width = float('inf')
        for blob_name, activ_blob in self.blobs.items():
            if '_latent_mean' in blob_name:
                width = activ_blob.data.size // activ_blob.shape[0]
                min_width = min(min_width, width)
        return min_width
class Solver(caffe._caffe.Solver):
    '''Provides from_param for building a caffe solver from a SolverParameter.'''
    @classmethod
    def from_param(cls, solver_param=None, **kwargs):
        '''
        Construct a caffe solver from a SolverParameter message.

        Args:
            solver_param: SolverParameter to build from; a new empty
                SolverParameter is used if None. kwargs are applied to
                this object in place, so the caller's message is
                modified.
            **kwargs: field updates applied with solver_param.update().
        Returns:
            An instance of the caffe attribute named
            '{solver_param.type}Solver'. Note this is that caffe class,
            not cls.
        Raises:
            AttributeError: if caffe has no such solver class.
        '''
        if solver_param is None:
            solver_param = SolverParameter()
        solver_param.update(**kwargs)
        # Resolve the solver class first so an unknown type fails
        # before any temporary file is written.
        solver_class = getattr(caffe, '{}Solver'.format(solver_param.type))
        # The solver constructor takes a file path, so the param is
        # written to a temporary prototxt that is removed afterwards.
        with solver_param.temp_prototxt() as solver_file:
            return solver_class(solver_file)