-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdraw_confusion_matrix.py
93 lines (82 loc) · 1.91 KB
/
draw_confusion_matrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# Display names for each class, in class-index order: entry i names row and
# column i of the confusion matrix.
# NOTE(review): these 60 names appear to be the NTU RGB+D 60 action classes --
# confirm the order matches the label encoding used when all_labels.npy was saved.
CLASS_LABELS = (
    "drink water",
    "eat meal",
    "brush teeth",
    "brush hair",
    "drop",
    "pick up",
    "throw",
    "sit down",
    "stand up",
    "clapping",
    "reading",
    "writing",
    "tear up paper",
    "put on jacket",
    "take off jacket",
    "put on a shoe",
    "take off a shoe",
    "put on glasses",
    "take off glasses",
    "put on a hat/cap",
    "take off a hat/cap",
    "cheer up",
    "hand waving",
    "kicking something",
    "reach into pocket",
    "hopping",
    "jump up",
    "phone call",
    "play with phone/tablet",
    "type on a keyboard",
    "point to something",
    "taking a selfie",
    "check time (from watch)",
    "rub two hands",
    "nod head/bow",
    "shake head",
    "wipe face",
    "salute",
    "put palms together",
    "cross hands in front",
    "sneeze/cough",
    "staggering",
    "falling down",
    "headache",
    "chest pain",
    "back pain",
    "neck pain",
    "nausea/vomiting",
    "fan self",
    "punch/slap",
    "kicking",
    "pushing",
    "pat on back",
    "point finger",
    "hugging",
    "giving object",
    "touch pocket",
    "shaking hands",
    "walking towards",
    "walking apart",
)


def main(
    pred_path="all_pred.npy",
    labels_path="all_labels.npy",
    output_path="confusion_matrix.png",
    class_labels=CLASS_LABELS,
    dpi=800,
):
    """Plot, save and show the confusion matrix of saved predictions.

    Loads predicted and true labels from ``.npy`` files, computes the
    confusion matrix with scikit-learn and renders it as a heat map with one
    named tick per class. The figure is written to *output_path* and then
    displayed with ``plt.show()``.

    Args:
        pred_path: ``.npy`` file holding the predicted labels. Presumably a 1-D
            array of integer class indices, one per sample -- confirm with
            whatever code produced it.
        labels_path: ``.npy`` file holding the ground-truth labels, same layout.
        output_path: Destination file for the rendered figure.
        class_labels: Display names; entry ``i`` names class index ``i``.
        dpi: Resolution of the saved image. The figure is 15 in square, so the
            default of 800 yields a 12000 x 12000 px image.
    """
    y_pred = np.load(pred_path)
    y_true = np.load(labels_path)
    num_classes = len(class_labels)

    # Pin the label set so the matrix is always num_classes x num_classes and
    # row/column i lines up with class_labels[i]. Without `labels=`, sklearn
    # only uses labels that actually occur in the data, so a class missing
    # from both arrays would shrink the matrix and shift every later tick name.
    # NOTE(review): assumes labels are 0-based indices 0..num_classes-1.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(num_classes))

    fig, ax = plt.subplots(figsize=(15, 15))
    image = ax.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
    ax.set_title("Confusion Matrix")
    fig.colorbar(image, ax=ax)

    # One major tick per class, at the cell centres.
    ticks = np.arange(num_classes)
    ax.set_xticks(ticks)
    ax.set_xticklabels(class_labels, rotation=90)
    ax.set_yticks(ticks)
    ax.set_yticklabels(class_labels)
    ax.set_xlabel("Predicted Labels")
    ax.set_ylabel("True Labels")

    # Draw the grid on minor ticks at the cell boundaries (half-integer
    # positions). A grid on the major ticks would cut through cell centres.
    boundaries = np.arange(num_classes + 1) - 0.5
    ax.set_xticks(boundaries, minor=True)
    ax.set_yticks(boundaries, minor=True)
    ax.grid(which="minor", linestyle="--", linewidth=0.5, alpha=0.7)
    ax.tick_params(which="minor", length=0)

    # bbox_inches="tight" stops the long rotated class names from being clipped.
    fig.savefig(output_path, dpi=dpi, bbox_inches="tight")
    plt.show()
# Entry point: render the plot only when run as a script, never on import.
if __name__ == '__main__':
    main()