Skip to content
Open

V1.2 #34

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file removed closed-loop/.DS_Store
Binary file not shown.
107 changes: 25 additions & 82 deletions projects/mmdet3d_plugin/GenAD/GenAD_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ def forward(self,
gt_labels_3d=None,
gt_attr_labels=None,
ego_fut_trajs=None,
gt_bboxes_3d=None,
):
"""Forward function.
Args:
Expand Down Expand Up @@ -755,12 +756,12 @@ def forward(self,
# generator for planning & motion
current_states = torch.cat((motion_hs.permute(1, 0, 2), ca_motion_query.reshape(batch_size, -1, self.embed_dims)), dim=2)
distribution_comp = {}
# states = torch.randn((2, 1, 64, 200, 200), device=motion_hs.device)
# future_distribution_inputs = torch.randn((2, 5, 6, 200, 200), device=motion_hs.device)
noise = None

if self.training:
future_distribution_inputs = self.get_future_labels(gt_labels_3d, gt_attr_labels,
ego_fut_trajs, motion_hs.device)
ego_fut_trajs, outputs_classes,
outputs_coords, gt_bboxes_3d)
else:
future_distribution_inputs = None

Expand Down Expand Up @@ -798,12 +799,7 @@ def forward(self,
agent_trajs = torch.stack(motion_fut_trajs_list, dim=3)
agent_trajs = agent_trajs.reshape(batch_size, 1, self.agent_dim, self.fut_mode, -1)

# future_hs = future_states_hs[:, :, 0:self.agent_dim * self.fut_mode, :].reshape(
# batch_size, self.agent_dim, self.fut_mode, -1)
# current_hs = current_states[:, 0:self.agent_dim * self.fut_mode, :].reshape(
# batch_size, self.agent_dim, self.fut_mode, -1)
#
# motion_cls_hs = torch.cat((future_hs, current_hs), dim=-1)

motion_cls_hs = torch.cat((future_states_hs[:, :, 0:self.agent_dim * self.fut_mode, :].
reshape(batch_size, self.agent_dim, self.fut_mode, -1),
current_states[:, 0:self.agent_dim * self.fut_mode, :].
Expand All @@ -820,8 +816,6 @@ def forward(self,
outputs_coords = torch.stack(outputs_coords)
outputs_trajs = agent_trajs.permute(1, 0, 2, 3, 4)
outputs_trajs_classes = torch.stack(outputs_trajs_classes)
# outputs_trajs = outputs_trajs.repeat(outputs_coords.shape[0], 1, 1, 1, 1)
# outputs_trajs_classes = outputs_trajs_classes.repeat(outputs_coords.shape[0], 1, 1, 1)

outs = {
'bev_embed': bev_embed,
Expand Down Expand Up @@ -1987,73 +1981,6 @@ def distribution_forward(self, present_features, future_distribution_inputs=None

return sample, output_distribution

def get_future_labels(self, gt_labels_3d, gt_attr_labels, ego_fut_trajs, device):

    """Build ground-truth future trajectories for agents and the ego vehicle.

    Args:
        gt_labels_3d (list[Tensor]): per-sample agent class labels.
        gt_attr_labels (list[Tensor]): per-sample agent attribute rows; the
            first ``fut_ts * 2`` entries of a row are the future (x, y)
            offsets and entries ``fut_ts*2 : fut_ts*3`` the validity mask.
        ego_fut_trajs (Tensor): ego future trajectory, [B, 1, fut_ts, 2].
        device: device on which padding tensors are allocated.
    Returns:
        Tensor: [B, agent_dim * 6 + 1, fut_ts * 2] flattened GT trajectories
        (agent targets tiled over the 6 future modes, ego appended last).
    """

    agent_dim = 300
    # Vehicle-like classes are collapsed onto 'car' (index 0).
    veh_list = [0, 1, 3, 4]
    mapped_class_names = [
        'car', 'truck', 'construction_vehicle', 'bus',
        'trailer', 'barrier', 'motorcycle', 'bicycle',
        'pedestrian', 'traffic_cone'
    ]
    ignore_list = ['construction_vehicle', 'barrier',
                   'traffic_cone', 'motorcycle', 'bicycle']

    batch_size = len(gt_labels_3d)

    gt_fut_trajs_bz_list = []

    for bz in range(batch_size):
        gt_fut_trajs_list = []
        gt_label = gt_labels_3d[bz]
        gt_attr_label = gt_attr_labels[bz]
        for i in range(gt_label.shape[0]):
            gt_label[i] = 0 if gt_label[i] in veh_list else gt_label[i]
            box_name = mapped_class_names[gt_label[i]]
            if box_name in ignore_list:
                continue
            # Mask marks which future timesteps are actually observed.
            gt_fut_masks = gt_attr_label[i][self.fut_ts * 2:self.fut_ts * 3]
            num_valid_ts = sum(gt_fut_masks == 1)
            gt_fut_traj = gt_attr_label[i][:self.fut_ts * 2].reshape(-1, 2)
            gt_fut_traj = gt_fut_traj[:num_valid_ts]
            if gt_fut_traj.shape[0] < self.fut_ts:
                # Zero-pad missing timesteps; this also covers the
                # fully-unobserved case (num_valid_ts == 0).
                gt_fut_traj = torch.cat(
                    (gt_fut_traj, torch.zeros([self.fut_ts - gt_fut_traj.shape[0], 2], device=device)), 0)
            gt_fut_trajs_list.append(gt_fut_traj)

        num_agents = len(gt_fut_trajs_list)
        # BUGFIX: the former `len(...) != 0 & len(...) < agent_dim` parsed as
        # `len != (0 & len) < agent_dim` because `&` binds tighter than the
        # comparisons, so the upper-bound check was silently dropped and a
        # batch with exactly agent_dim agents lost all of its targets.
        if num_agents == 0:
            gt_fut_trajs = torch.zeros([agent_dim, self.fut_ts, 2], device=device)
        elif num_agents < agent_dim:
            gt_fut_trajs = torch.cat(
                (torch.stack(gt_fut_trajs_list),
                 torch.zeros([agent_dim - num_agents, self.fut_ts, 2], device=device)), 0)
        else:
            # Slots full (or overfull): keep the first agent_dim targets
            # instead of zeroing everything out.
            gt_fut_trajs = torch.stack(gt_fut_trajs_list[:agent_dim])

        gt_fut_trajs_bz_list.append(gt_fut_trajs)

    if len(gt_fut_trajs_bz_list) != 0:
        # Tile agent targets across the 6 future modes, then append ego.
        gt_trajs = torch.cat((torch.stack(gt_fut_trajs_bz_list).repeat(1, 6, 1, 1), ego_fut_trajs), dim=1)
    else:
        gt_trajs = ego_fut_trajs

    # [bz, a, t*2]
    return gt_trajs.reshape(batch_size, gt_trajs.shape[1], -1)

def future_states_predict(self, batch_size, sample, hidden_states, current_states):
"""get_future_label.
Args:
Expand Down Expand Up @@ -2082,8 +2009,24 @@ def future_states_predict(self, batch_size, sample, hidden_states, current_state

return states_hs, future_states_hs





def get_future_labels(self, gt_labels_3d, gt_attr_labels, ego_fut_trajs, outputs_classes, outputs_coords, gt_bboxes_3d):
    """Assemble GT future-trajectory targets for agents and the ego vehicle.

    Runs target assignment against the last decoder layer's predictions and
    concatenates the matched per-agent trajectory targets (tiled over the 6
    future modes) with the flattened ego future trajectory.

    Args:
        gt_labels_3d (list[Tensor]): per-sample agent class labels.
        gt_attr_labels (list[Tensor]): per-sample agent attribute labels.
        ego_fut_trajs (Tensor): ego future trajectory; flattened to [bz, 1, fut_ts*2].
        outputs_classes (list[Tensor]): per-layer classification predictions.
        outputs_coords (list[Tensor]): per-layer box predictions.
        gt_bboxes_3d (list): per-sample GT boxes (gravity_center / tensor attrs).

    Returns:
        Tensor: concatenated agent + ego targets, [bz, num_agents*6 + 1, fut_ts*2].
    """
    # Only the final decoder layer's predictions drive target assignment.
    last_cls_scores = torch.stack(outputs_classes)[-1]
    last_bbox_preds = torch.stack(outputs_coords)[-1]
    cls_scores_list = [s for s in last_cls_scores]
    bbox_preds_list = [p for p in last_bbox_preds]

    device = gt_labels_3d[0].device
    # Convert boxes to the (center, remaining dims) tensor layout get_targets expects.
    gt_bboxes_list = []
    for gt_bboxes in gt_bboxes_3d:
        boxes_tensor = torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]), dim=1)
        gt_bboxes_list.append(boxes_tensor.to(device))

    cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
                                       gt_bboxes_list, gt_labels_3d,
                                       gt_attr_labels, None)
    # Index 4 of the target tuple holds the per-sample trajectory targets.
    traj_targets_list = cls_reg_targets[4]

    agent_targets_gt = torch.stack(traj_targets_list)
    # Tile each matched agent target across the 6 future-mode hypotheses.
    agent_targets_gt = agent_targets_gt.repeat_interleave(6, dim=1)
    ego_target_gt = ego_fut_trajs.flatten(2)
    return torch.cat([agent_targets_gt, ego_target_gt], dim=1)