# honetMANO.py
import os
import pickle
from collections import OrderedDict
from typing import Dict

import torch
import torch.nn as nn
from manotorch.utils.rodrigues import rodrigues

from anakin.datasets.hoquery import Queries
from anakin.models.mano import ManoAdaptor
from anakin.utils.builder import MODEL, build_backbone, build_head
from anakin.utils.logger import logger
from anakin.utils.misc import CONST, enable_lower_param
from anakin.utils.netutils import recurse_freeze
from anakin.utils.transform import batch_persp_proj2d


@MODEL.register_module
class HoNet(nn.Module):

    @enable_lower_param
    def __init__(self, **cfg):
        super(HoNet, self).__init__()

        if cfg["BACKBONE"]["PRETRAINED"] and cfg["PRETRAINED"]:
            logger.warning(
                f"{type(self).__name__}'s backbone {cfg['BACKBONE']['TYPE']} "
                f"weights will be overwritten by {cfg['PRETRAINED']}"
            )

        self.inp_res = cfg["DATA_PRESET"]["IMAGE_SIZE"]
        self.feature_dim = cfg["HEAD"]["INPUT_DIM"]
        self.center_idx = cfg["DATA_PRESET"]["CENTER_IDX"]
        logger.info(f"{type(self).__name__} uses center_idx {self.center_idx}")

        self.base_net = build_backbone(cfg["BACKBONE"])  # ResNet18
        self.mano_branch = build_head(cfg["HEAD"], default_args=cfg["DATA_PRESET"])
        self.obj_trans_factor = cfg["OBJ_TRANS_FACTOR"]
        self.obj_scale_factor = cfg["OBJ_SCALE_FACTOR"]
        self.mano_transhead = HoNet.TransHead(inp_dim=self.feature_dim, out_dim=3)
        self.obj_transhead = HoNet.TransHead(inp_dim=self.feature_dim, out_dim=6)
        self.proj2d_func = batch_persp_proj2d

        if cfg.get("MANO_FHB_ADAPTOR", False):
            mano_fhb_adaptor_dir = cfg["MANO_FHB_ADAPTOR_DIR"]  # assets/hasson20_assets/mano
            adaptor_path = os.path.join(mano_fhb_adaptor_dir, f"fhb_skel_centeridx{self.center_idx}.pkl")
            with open(adaptor_path, "rb") as p_f:
                exp_data = pickle.load(p_f)
            self.register_buffer("fhb_shape", torch.Tensor(exp_data["shape"]))
            self.adaptor = ManoAdaptor(self.mano_branch.mano_layer, adaptor_path)
            recurse_freeze(self.adaptor)
        else:
            self.adaptor = None

        self.init_weights(pretrained=cfg["PRETRAINED"])

    class TransHead(nn.Module):
        """Two-layer MLP that regresses scale + 2D translation (out_dim=3),
        or scale + 2D translation + axis-angle rotation (out_dim=6)."""

        def __init__(self, inp_dim: int, out_dim: int):
            super().__init__()
            if out_dim != 3 and out_dim != 6:
                logger.error(f"Unrecognized TransHead out dim: {out_dim}")
                raise ValueError(f"Unrecognized TransHead out dim: {out_dim}")
            base_neurons = [inp_dim, inp_dim // 2]
            layers = []
            for inp_neurons, out_neurons in zip(base_neurons[:-1], base_neurons[1:]):
                layers.append(nn.Linear(inp_neurons, out_neurons))
                layers.append(nn.ReLU())
            self.decoder = nn.Sequential(*layers)
            self.final_layer = nn.Linear(base_neurons[-1], out_dim)

        def forward(self, inp):
            decoded = self.decoder(inp)
            out = self.final_layer(decoded)
            return out

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # ! remap STATE_DICT keys from the pretrained HASSON[CVPR2020] model to this HoNet
        keys_to_remove = []
        keys_to_insert = {}
        for key in state_dict.keys():
            if "mano_layer_left" in key:
                # the left-hand MANO layer is unused here: drop its weights
                keys_to_remove.append(key)
            elif "mano_layer_right" in key:
                keys_to_remove.append(key)
                new_key = key.replace("mano_layer_right", "mano_layer")
                keys_to_insert[new_key] = state_dict[key]
            elif "scaletrans_branch_obj" in key:
                keys_to_remove.append(key)
                new_key = key.replace("scaletrans_branch_obj", "obj_transhead")
                keys_to_insert[new_key] = state_dict[key]
            elif "scaletrans_branch." in key:
                keys_to_remove.append(key)
                new_key = key.replace("scaletrans_branch", "mano_transhead")
                keys_to_insert[new_key] = state_dict[key]
        if len(keys_to_insert) or len(keys_to_remove):
            logger.warning("remapping STATE_DICT from pretrained model of HASSON[CVPR2020]")
        state_dict.update(keys_to_insert)
        for key in keys_to_remove:
            state_dict.pop(key)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    @staticmethod
    def recover_3d_proj(
        objpoints3d: torch.Tensor,
        camintr: torch.Tensor,
        est_scale,
        est_trans,
        input_res,
        off_z=0.4,
    ):
        """
        Given estimated centered points, camera intrinsics, and the scale and translation
        predicted in pixel space, compute the point coordinates in the camera coordinate system.
        """
        # Estimate scale and trans between 3D and 2D
        focal = camintr[:, :1, :1]
        batch_size = objpoints3d.shape[0]
        focal = focal.view(batch_size, 1)
        est_scale = est_scale.view(batch_size, 1)
        est_trans = est_trans.view(batch_size, 2)
        # est_scale is homogeneous to object scale change in pixels
        est_Z0 = focal * est_scale + off_z
        cam_centers = camintr[:, :2, 2]
        img_centers = (cam_centers.new(input_res) / 2).view(1, 2).repeat(batch_size, 1)
        est_XY0 = (est_trans + img_centers - cam_centers) * est_Z0 / focal
        est_c3d = torch.cat([est_XY0, est_Z0], -1).unsqueeze(1)  # TENSOR (B, 1, 3)
        recons3d = est_c3d + objpoints3d
        return recons3d, est_c3d
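
    # Added note on the recovery above: the root depth is modeled as
    #     Z0 = focal * est_scale + off_z,
    # and the root's X/Y follow from inverting the pinhole projection,
    #     XY0 = (est_trans + img_centers - cam_centers) * Z0 / focal.
    # For example, with a hypothetical focal length of 500 px, est_scale = 1e-4,
    # and the default off_z = 0.4, the root sits at Z0 = 0.45 m from the camera.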

    def recover_mano(self, feature: torch.Tensor, samples: Dict):
        # ============= Get hand joints & verts, centered >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        mano_results = self.mano_branch(feature)
        if self.adaptor:
            # HASSON[CVPR2020] MANO adaptor for the FHAB dataset
            verts = mano_results["hand_verts_3d"]
            adapt_joints, _ = self.adaptor(verts)
            adapt_joints = adapt_joints.transpose(1, 2)
            mano_results["joints_3d"] = adapt_joints - adapt_joints[:, self.center_idx].unsqueeze(1)
            mano_results["hand_verts_3d"] = verts - adapt_joints[:, self.center_idx].unsqueeze(1)

        # ============= Recover hand position in camera coordinates >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
        scaletrans = self.mano_transhead(feature)  # TENSOR (B, 3)
        trans = scaletrans[:, 1:]  # TENSOR (B, 2)
        scale = scaletrans[:, :1]  # TENSOR (B, 1)
        final_trans = trans.unsqueeze(1) * self.obj_trans_factor
        final_scale = scale.view(scale.shape[0], 1, 1) * self.obj_scale_factor
        height, width = tuple(samples[Queries.IMAGE].shape[2:])
        cam_intr = samples[Queries.CAM_INTR]  # TENSOR (B, 3, 3)
        joints_3d_abs, root_joint = HoNet.recover_3d_proj(
            mano_results["joints_3d"], cam_intr, final_scale, final_trans, input_res=(width, height)
        )
        hand_verts_3d_abs = mano_results["hand_verts_3d"] + root_joint
        joints_2d = self.proj2d_func(joints_3d_abs, cam_intr)
        hand_verts_2d = self.proj2d_func(hand_verts_3d_abs, cam_intr)

        # * @Xinyu: mano_results["recov_joints3d"] = mano_results["joints3d"] + mano_results["hand_center3d"]
        mano_results["joints_2d"] = joints_2d
        mano_results["root_joint"] = root_joint  # ===== To PICR =====
        mano_results["joints_3d_abs"] = joints_3d_abs  # ===== To PICR =====
        mano_results["hand_verts_3d_abs"] = hand_verts_3d_abs  # ===== To PICR =====
        mano_results["hand_verts_2d"] = hand_verts_2d
        mano_results["hand_pred_trans"] = trans
        mano_results["hand_pred_scale"] = scale
        mano_results["hand_trans"] = final_trans
        mano_results["hand_scale"] = final_scale
        return mano_results

    def recover_object(self, feature: torch.Tensor, samples: Dict):
        """
        Compute object vertex and corner positions in camera coordinates by predicting
        object translation, scale, and rotation, then recovering 3D positions given
        the known object model.
        """
        scaletrans_obj = self.obj_transhead(feature)  # TENSOR (B, 6)
        batch_size = scaletrans_obj.shape[0]
        scale = scaletrans_obj[:, :1]
        trans = scaletrans_obj[:, 1:3]
        rotaxisang = scaletrans_obj[:, 3:]  # axis-angle, TENSOR (B, 3)

        rotmat = rodrigues(rotaxisang).view(rotaxisang.shape[0], 3, 3)
        obj_verts_can = samples[Queries.OBJ_VERTS_CAN]
        obj_verts_ = rotmat.bmm(obj_verts_can.float().transpose(1, 2)).transpose(1, 2)

        final_trans = trans.unsqueeze(1) * self.obj_trans_factor
        final_scale = scale.view(batch_size, 1, 1) * self.obj_scale_factor
        height, width = tuple(samples[Queries.IMAGE].shape[2:])
        cam_intr = samples[Queries.CAM_INTR]
        obj_verts_3d_abs, obj_center = HoNet.recover_3d_proj(
            obj_verts_, cam_intr, final_scale, final_trans, input_res=(width, height)
        )
        # Recover 2D positions given camera intrinsic parameters and object vertex
        # coordinates in the camera coordinate frame
        obj_verts_2d = self.proj2d_func(obj_verts_3d_abs, cam_intr)

        if Queries.CORNERS_3D in samples:
            corners_can = samples[Queries.CORNERS_CAN]
            obj_corners_ = rotmat.bmm(corners_can.float().transpose(1, 2)).transpose(1, 2)
            corners_3d_abs = obj_corners_ + obj_center
            corners_2d = self.proj2d_func(corners_3d_abs, cam_intr)
        else:
            obj_corners_ = None
            corners_3d_abs = None
            corners_2d = None

        # @Xinyu: obj_results["recov_obj_verts3d"] = \
        #     obj_results["rotaxisang"] @ OBJ_CAN_VERTS + obj_results["obj_center3d"]
        obj_results = {
            "obj_center": obj_center,  # ===== To PICR =====
            "obj_verts_3d_abs": obj_verts_3d_abs,  # ===== To PICR =====
            "corners_3d_abs": corners_3d_abs,
            "obj_pred_scale": scale,
            "obj_pred_trans": trans,
            "obj_rot": rotaxisang,  # ===== To PICR =====
            "obj_scale": final_scale,
            "obj_trans": final_trans,
            "corners_2d": corners_2d,
            "obj_verts_2d": obj_verts_2d,
            # TODO for MSSD
            "box_rot_rotmat": rotmat,
            "boxroot_3d_abs": obj_center,
        }
        return obj_results

    def forward(self, samples: Dict):
        features = self.base_net(image=samples["image"])
        mano_results = self.recover_mano(features["res_layer4_mean"], samples)
        obj_results = self.recover_object(features["res_layer4_mean"], samples)

        # ⬇ make corners_3d and obj_verts_3d root-relative
        if obj_results["corners_3d_abs"] is not None:
            obj_results["corners_3d"] = obj_results["corners_3d_abs"] - mano_results["root_joint"]
        else:
            obj_results["corners_3d"] = None  # CORNERS_3D was not provided in samples
        obj_results["obj_verts_3d"] = obj_results["obj_verts_3d_abs"] - mano_results["root_joint"]

        results = {**mano_results, **obj_results}
        return results

    def init_weights(self, pretrained=""):
        if pretrained == "":
            logger.warning(f"=> Init {type(self).__name__} weights in backbone and head")
            # add init for other modules here if needed
        elif os.path.isfile(pretrained):
            logger.info(f"=> Loading {type(self).__name__} pretrained model from: {pretrained}")
            checkpoint = torch.load(pretrained)
            if isinstance(checkpoint, OrderedDict):
                state_dict = checkpoint
            elif isinstance(checkpoint, dict) and "state_dict" in checkpoint:
                state_dict_old = checkpoint["state_dict"]
                state_dict = OrderedDict()
                for key in state_dict_old.keys():
                    if key.startswith("module."):
                        # strip the "module." prefix left by nn.DataParallel / DDP wrappers
                        state_dict[key[7:]] = state_dict_old[key]
                    else:
                        state_dict[key] = state_dict_old[key]
            else:
                logger.error(f"=> No state_dict found in checkpoint file {pretrained}")
                raise RuntimeError(f"No state_dict found in checkpoint file {pretrained}")
            self.load_state_dict(state_dict, strict=True)
        else:
            logger.error(f"=> No {type(self).__name__} checkpoint file found at {pretrained}")
            raise FileNotFoundError(pretrained)
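

if __name__ == "__main__":
    # Minimal smoke test of the pure-tensor recover_3d_proj() above -- an added
    # sketch, not part of the original training pipeline. It assumes the anakin /
    # manotorch imports at the top of this file resolve; the focal length, image
    # size, and scale value below are hypothetical placeholders.
    batch_size, n_pts = 2, 21
    centered_pts = torch.randn(batch_size, n_pts, 3) * 0.05  # ~5 cm spread around the root
    cam_intr = torch.eye(3).unsqueeze(0).repeat(batch_size, 1, 1)
    cam_intr[:, 0, 0] = cam_intr[:, 1, 1] = 500.0  # focal length in pixels
    cam_intr[:, 0, 2] = cam_intr[:, 1, 2] = 128.0  # principal point at the image center
    est_scale = torch.full((batch_size, 1), 1e-4)
    est_trans = torch.zeros(batch_size, 2)
    recons3d, root = HoNet.recover_3d_proj(
        centered_pts, cam_intr, est_scale, est_trans, input_res=(256, 256)
    )
    # with these inputs the root depth is Z0 = 500 * 1e-4 + 0.4 = 0.45 m
    print(recons3d.shape, root.shape)  # torch.Size([2, 21, 3]) torch.Size([2, 1, 3])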