-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpose_metrics.py
More file actions
75 lines (58 loc) · 2.61 KB
/
pose_metrics.py
File metadata and controls
75 lines (58 loc) · 2.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
def calculate_add(
    pred_t: np.ndarray,
    pred_R: np.ndarray,
    target_t: np.ndarray,
    target_R: np.ndarray,
    model_points: np.ndarray,
    diameter: float,
    threshold: float = 0.1,
):
    """Calculate the Average Distance of Model Points (ADD) metric.

    A prediction is considered correct if the average distance between the
    predicted and the target 3D points is smaller than
    ``threshold * diameter``.

    Poses follow the convention ``x_cam = R @ x_model + t``. The model points
    are stored as rows, so they are transformed as ``model_points @ R.T + t``.

    Args:
        pred_t (np.ndarray): Predicted translation vector(s), shape ``(3,)``
            or ``(num_detections, 3)``.
        pred_R (np.ndarray): Predicted rotation matrix/matrices, shape
            ``(3, 3)`` or ``(num_detections, 3, 3)``.
        target_t (np.ndarray): Target translation vector(s), same layout as
            ``pred_t``.
        target_R (np.ndarray): Target rotation matrix/matrices, same layout
            as ``pred_R``.
        model_points (np.ndarray): 3D points of the object model, shape
            ``(num_points, 3)``.
        diameter (float): Diameter of the object model in [mm].
        threshold (float): Threshold for the ADD metric (Default: 0.1).

    Returns:
        add_metric (np.ndarray): Average point distance per prediction, shape
            ``(num_detections,)`` (a scalar for a single unbatched pose).
        add_true_positive (np.ndarray): Boolean per prediction,
            ``add_metric < diameter * threshold``.
    """
    pred_t = np.asarray(pred_t)
    pred_R = np.asarray(pred_R)
    target_t = np.asarray(target_t)
    target_R = np.asarray(target_R)
    model_points = np.asarray(model_points)

    # For row-vector points, (R @ p)^T == p^T @ R^T, so transpose the last two
    # axes of the (possibly stacked) rotation matrices before multiplying.
    # Broadcasting (num_points, 3) @ (num_detections, 3, 3) yields
    # (num_detections, num_points, 3); the translation gets a points axis.
    pred_points = model_points @ np.swapaxes(pred_R, -1, -2) + pred_t[..., np.newaxis, :]
    targ_points = model_points @ np.swapaxes(target_R, -1, -2) + target_t[..., np.newaxis, :]

    # Euclidean distance per model point, then the mean over the points.
    point_distance = np.linalg.norm(targ_points - pred_points, axis=-1)
    add_metric = np.mean(point_distance, axis=-1)
    add_true_positive = add_metric < diameter * threshold
    return add_metric, add_true_positive
def calculate_translation_error(pred_t: np.ndarray, target_t: np.ndarray):
    """Return the Euclidean distance between target and predicted translations.

    The norm is taken over the last axis, so a ``(num_detections, 3)`` input
    yields one distance per detection.
    """
    offset = target_t - pred_t
    return np.linalg.norm(offset, axis=-1)
def calculate_depth_error(pred_t: np.ndarray, target_t: np.ndarray):
    """Return the signed error of the third (z) translation component.

    The result is ``target z - predicted z``, taken from index 2 of the last
    axis. It is positive when the predicted z is smaller than the target z.
    """
    z_target = target_t[..., 2]
    z_pred = pred_t[..., 2]
    return z_target - z_pred
def calculate_xy_error(pred_t: np.ndarray, target_t: np.ndarray):
    """Return the Euclidean distance of the first two (x, y) translation components.

    Only indices 0 and 1 of the last axis are used, so the z component is
    ignored.
    """
    xy_target = target_t[..., :2]
    xy_pred = pred_t[..., :2]
    return np.linalg.norm(xy_target - xy_pred, axis=-1)
def calculate_rotation_error(pred_R: np.ndarray, target_R: np.ndarray):
    """Calculate the angular error between two rotation matrices, in degrees.

    The error is the angle of the relative rotation ``pred_R @ target_R.T``,
    recovered from its trace as ``arccos((trace - 1) / 2)``.

    Adapted from https://github.com/ethnhe/PVN3D/blob/master/pvn3d/lib/utils/evaluation_utils.py#L316

    Args:
        pred_R (np.ndarray): Predicted rotation matrix, shape ``(3, 3)``, or a
            stack of them, shape ``(num_detections, 3, 3)``.
        target_R (np.ndarray): Target rotation matrix/matrices with the same
            layout as ``pred_R``.

    Returns:
        np.ndarray: Angular error in degrees, shape ``(num_detections,)`` for
        stacked input (a scalar for a single pair of matrices).
    """
    pred_R = np.asarray(pred_R)
    target_R = np.asarray(target_R)
    # trace(A @ B.T) == sum(A * B), so all traces are computed at once without
    # materialising the relative rotation matrices.
    trace = np.sum(pred_R * target_R, axis=(-2, -1))
    # Floating-point round-off can push the cosine slightly outside [-1, 1],
    # where arccos would return NaN, so clamp it first.
    cos_angle = np.clip((trace - 1.0) / 2.0, -1.0, 1.0)
    return np.rad2deg(np.arccos(cos_angle))