-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathmodel.py
35 lines (28 loc) · 1.09 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch.nn as nn
from anakin.postprocess.iknet import utils
class IKNet(nn.Module):
    """MLP that regresses 16 quaternions from flattened joint coordinates.

    The network is a stack of ``Linear -> BatchNorm1d -> ReLU`` blocks, one per
    entry of ``hidden_size_pose``. A final ``Linear`` then projects to
    ``16 * 4`` values, which ``forward`` reshapes into 16 quaternions.

    NOTE(review): going by the class/module name, this is an inverse-kinematics
    net that maps 3D joint positions (21 by default) to per-joint rotations.
    Confirm against callers.
    """

    # Number of quaternions predicted per sample (4 components each).
    _NUM_QUATS = 16

    def __init__(
        self,
        njoints=21,
        hidden_size_pose=(256, 512, 1024, 1024, 512, 256),
    ):
        """Build the layer stack.

        Args:
            njoints: Number of input joints. The first layer takes
                ``3 * njoints`` input features.
            hidden_size_pose: Widths of the hidden layers, in order. Any
                sequence of ints (list or tuple) is accepted, and it is never
                mutated. If it is empty, the network is a single ``Linear``
                layer.
        """
        super().__init__()
        self.njoints = njoints
        in_neurons = 3 * njoints
        out_neurons = self._NUM_QUATS * 4  # 4 components per quaternion
        # Build a fresh list so the caller's sequence (or the default) is never aliased.
        neurons = [in_neurons, *hidden_size_pose]
        invk_layers = []
        # Keep the order Linear, BatchNorm1d, ReLU per hidden layer, then the
        # final Linear. The nn.Sequential indices (and hence the state_dict
        # keys) and the parameter-initialisation order must not change, or
        # previously saved checkpoints will no longer load.
        for inps, outs in zip(neurons[:-1], neurons[1:]):
            invk_layers.extend(
                (nn.Linear(inps, outs), nn.BatchNorm1d(outs), nn.ReLU())
            )
        invk_layers.append(nn.Linear(neurons[-1], out_neurons))
        self.invk_layers = nn.Sequential(*invk_layers)

    def forward(self, joint):
        """Predict rotations from joint coordinates.

        Args:
            joint: Tensor whose elements are reshaped to
                ``(-1, njoints * 3)``. Its element count must be a multiple of
                ``njoints * 3``. It is presumably ``(B, njoints, 3)`` joint
                positions (TODO confirm).

        Returns:
            A tuple ``(so3, quat)``:
              * ``so3``: the result of ``utils.quaternion_to_angle_axis``
                reshaped to ``(-1, 16 * 3)``. This presumably holds 3
                axis-angle values per quaternion (helper not visible here).
              * ``quat``: ``(B, 16, 4)`` network output after
                ``utils.normalize_quaternion``. It is presumably
                unit-normalised; the component order (wxyz vs xyzw) is not
                visible here.
        """
        flat = joint.contiguous().view(-1, self.njoints * 3)
        quat = self.invk_layers(flat).view(-1, self._NUM_QUATS, 4)
        quat = utils.normalize_quaternion(quat)
        # Make the helper's output contiguous so the .view below is always legal.
        so3 = utils.quaternion_to_angle_axis(quat).contiguous()
        so3 = so3.view(-1, self._NUM_QUATS * 3)
        return so3, quat