
Commit 3c4f8c9

Add a model for facial expression recognition (#100)

1 parent a85f1ea

File tree

13 files changed, +488 -14 lines changed

README.md

Lines changed: 5 additions & 0 deletions

```diff
@@ -19,6 +19,7 @@ Guidelines:
 | ------------------------------------------------------- | ----------------------------- | ---------- | -------------- | ------------ | --------------- | ------------ | ----------- |
 | [YuNet](./models/face_detection_yunet)                  | Face Detection                | 160x120    | 1.45           | 6.22         | 12.18           | 4.04         | 86.69       |
 | [SFace](./models/face_recognition_sface)                | Face Recognition              | 112x112    | 8.65           | 99.20        | 24.88           | 46.25        | ---         |
+| [FER](./models/facial_expression_recognition/)          | Facial Expression Recognition | 112x112    | 4.43           | 49.86        | 31.07           | 108.53\*     | ---         |
 | [LPD-YuNet](./models/license_plate_detection_yunet/)    | License Plate Detection       | 320x240    | ---            | 168.03       | 56.12           | 29.53        | ---         |
 | [YOLOX](./models/object_detection_yolox/)               | Object Detection              | 640x640    | 176.68         | 1496.70      | 388.95          | 420.98       | ---         |
 | [NanoDet](./models/object_detection_nanodet/)           | Object Detection              | 416x416    | 157.91         | 220.36       | 64.94           | 116.64       | ---         |
@@ -62,6 +63,10 @@ Some examples are listed below. You can find more in the directory of each model

 ![largest selfie](./models/face_detection_yunet/examples/largest_selfie.jpg)

+### Facial Expression Recognition with [Progressive Teacher](./models/facial_expression_recognition/)
+
+![fer demo](./models/facial_expression_recognition/examples/selfie.jpg)
+
 ### Human Segmentation with [PP-HumanSeg](./models/human_segmentation_pphumanseg/)

 ![messi](./models/human_segmentation_pphumanseg/examples/messi.jpg)
```
Lines changed: 16 additions & 0 deletions (new benchmark config)

```yaml
Benchmark:
  name: "Facial Expression Recognition Benchmark"
  type: "Recognition"
  data:
    path: "benchmark/data/facial_expression_recognition/fer_evaluation"
    files: ["RAF_test_0_61.jpg", "RAF_test_0_30.jpg", "RAF_test_6_25.jpg"]
  metric:  # 'sizes' is omitted since this model requires input of fixed size
    warmup: 30
    repeat: 10
    reduction: "median"
  backend: "default"
  target: "cpu"

Model:
  name: "FacialExpressionRecog"
  modelPath: "models/facial_expression_recognition/facial_expression_recognition_mobilefacenet_2022july.onnx"
```

benchmark/download_data.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -173,6 +173,10 @@ def get_confirm_token(response): # in case of large files
         url='https://drive.google.com/u/0/uc?id=1BRIozREIzqkm_aMQ581j93oWoS-6TLST&export=download',
         sha='03892b9036c58d9400255ff73858caeec1f46609',
         filename='face_recognition.zip'),
+    facial_expression_recognition=Downloader(name='facial_expression_recognition',
+        url='https://drive.google.com/u/0/uc?id=13ZE0Pz302z1AQmBmYGuowkTiEXVLyFFZ&export=download',
+        sha='8f757559820c8eaa1b1e0065f9c3bbbd4f49efe2',
+        filename='facial_expression_recognition.zip'),
     text=Downloader(name='text',
         url='https://drive.google.com/u/0/uc?id=1lTQdZUau7ujHBqp0P6M1kccnnJgO-dRj&export=download',
         sha='a40cf095ceb77159ddd2a5902f3b4329696dd866',
```
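The `sha` field pins each archive to a SHA-1 checksum. A self-contained sketch of the verification this implies (the real script additionally handles Google Drive confirm tokens for large files; the URL below is a placeholder, not the real Drive link):

```python
import hashlib
import urllib.request

def sha1_of(path, chunk_size=1 << 20):
    """Stream the file and return its SHA-1 hex digest."""
    h = hashlib.sha1()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

url = 'https://example.com/facial_expression_recognition.zip'  # placeholder URL
expected = '8f757559820c8eaa1b1e0065f9c3bbbd4f49efe2'

urllib.request.urlretrieve(url, 'facial_expression_recognition.zip')
assert sha1_of('facial_expression_recognition.zip') == expected, 'checksum mismatch'
```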

benchmark/utils/dataloaders/recognition.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -16,7 +16,10 @@ def __init__(self, **kwargs):
     def _load_label(self):
         labels = dict.fromkeys(self._files, None)
         for filename in self._files:
-            labels[filename] = np.loadtxt(os.path.join(self._path, '{}.txt'.format(filename[:-4])), ndmin=2)
+            if os.path.exists(os.path.join(self._path, '{}.txt'.format(filename[:-4]))):
+                labels[filename] = np.loadtxt(os.path.join(self._path, '{}.txt'.format(filename[:-4])), ndmin=2)
+            else:
+                labels[filename] = None
         return labels

     def __iter__(self):
```
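This makes the per-image label files optional: the FER evaluation images listed in the config above ship without sidecar `.txt` files, so a missing label now falls through to `None` instead of crashing `np.loadtxt`. The same pattern in isolation (the helper name is ours, not the repo's):

```python
import os
import numpy as np

def load_optional_label(root, image_name):
    """Return the label matrix for image_name, or None when the
    sidecar .txt file does not exist (e.g. the FER benchmark images)."""
    txt_path = os.path.join(root, os.path.splitext(image_name)[0] + '.txt')
    return np.loadtxt(txt_path, ndmin=2) if os.path.exists(txt_path) else None
```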

benchmark/utils/metrics/recognition.py

Lines changed: 11 additions & 3 deletions

```diff
@@ -12,12 +12,20 @@ def forward(self, model, *args, **kwargs):
         img, bboxes = args

         self._timer.reset()
-        for idx, bbox in enumerate(bboxes):
+        if bboxes is not None:
+            for idx, bbox in enumerate(bboxes):
+                for _ in range(self._warmup):
+                    model.infer(img, bbox)
+                for _ in range(self._repeat):
+                    self._timer.start()
+                    model.infer(img, bbox)
+                    self._timer.stop()
+        else:
             for _ in range(self._warmup):
-                model.infer(img, bbox)
+                model.infer(img, None)
             for _ in range(self._repeat):
                 self._timer.start()
-                model.infer(img, bbox)
+                model.infer(img, None)
                 self._timer.stop()

         return self._getResult()
```
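The warmup/repeat/median pattern implemented here is worth seeing in isolation. A minimal, self-contained sketch of the same timing logic (the function and names are ours, not the repo's `Timer` class):

```python
import time
from statistics import median

def benchmark(fn, warmup=30, repeat=10):
    """Run fn() the way the metric above does: discard warmup runs,
    then report the median of the timed repeats, in milliseconds."""
    for _ in range(warmup):        # untimed warmup runs
        fn()
    samples = []
    for _ in range(repeat):        # timed runs
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1000.0)
    return median(samples)

# Per-face timing when boxes exist, whole-image timing otherwise,
# mirroring the branch added above:
#   benchmark(lambda: model.infer(img, bbox))   # bboxes is not None
#   benchmark(lambda: model.infer(img, None))   # bboxes is None
```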

models/__init__.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -14,6 +14,7 @@
 from .license_plate_detection_yunet.lpd_yunet import LPD_YuNet
 from .object_detection_nanodet.nanodet import NanoDet
 from .object_detection_yolox.yolox import YoloX
+from .facial_expression_recognition.facial_fer_model import FacialExpressionRecog

 class Registery:
     def __init__(self, name):
@@ -43,4 +44,4 @@ def register(self, item):
 MODELS.register(LPD_YuNet)
 MODELS.register(NanoDet)
 MODELS.register(YoloX)
-
+MODELS.register(FacialExpressionRecog)
```
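For context, the `Registery` class this file defines is a simple name-to-class map, so a benchmark config naming `"FacialExpressionRecog"` can be resolved to the class registered above. A minimal sketch of the pattern (illustrative, not the repo's exact implementation):

```python
class Registry:
    """Map class names to classes so a config string like
    'FacialExpressionRecog' resolves to a constructor."""
    def __init__(self, name):
        self._name = name
        self._dict = {}

    def register(self, item):
        self._dict[item.__name__] = item

    def get(self, key):
        return self._dict[key]

MODELS = Registry('Models')

class FacialExpressionRecog:  # stand-in for the real model class
    pass

MODELS.register(FacialExpressionRecog)
model_cls = MODELS.get('FacialExpressionRecog')  # resolved from a config name
```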
Lines changed: 40 additions & 0 deletions (new model README)

# Progressive Teacher

Progressive Teacher: [Boosting Facial Expression Recognition by A Semi-Supervised Progressive Teacher](https://scholar.google.com/citations?view_op=view_citation&hl=zh-CN&user=OCwcfAwAAAAJ&citation_for_view=OCwcfAwAAAAJ:u5HHmVD_uO8C)

Note:

- Progressive Teacher is contributed by [Jing Jiang](https://scholar.google.com/citations?user=OCwcfAwAAAAJ&hl=zh-CN).
- [MobileFaceNet](https://link.springer.com/chapter/10.1007/978-3-319-97909-0_46) is used as the backbone, and the model is able to classify seven basic facial expressions (angry, disgust, fearful, happy, neutral, sad, surprised).
- [facial_expression_recognition_mobilefacenet_2022july.onnx](https://github.com/opencv/opencv_zoo/raw/master/models/facial_expression_recognition/facial_expression_recognition_mobilefacenet_2022july.onnx) is implemented thanks to [Chengrui Wang](https://github.com/opencv).

Results of accuracy evaluation on [RAF-DB](http://whdeng.cn/RAF/model1.html):

| Models              | Accuracy |
| ------------------- | -------- |
| Progressive Teacher | 88.27%   |

## Demo

***NOTE***: This demo uses [../face_detection_yunet](../face_detection_yunet) as the face detector, which supports 5-landmark detection for now (2021sep).

Run the following command to try the demo:

```shell
# recognize the facial expression on images
python demo.py --input /path/to/image
```

### Example outputs

Note: Zoom in to see the recognized facial expression in the top-left corner of each face box.

![fer demo](./examples/selfie.jpg)

## License

All files in this directory are licensed under the [Apache 2.0 License](./LICENSE).

## Reference

- https://ieeexplore.ieee.org/abstract/document/9629313
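Beyond the command-line demo, the model class can be driven directly with the same calls the demo script below makes. A minimal sketch (paths are placeholders; we assume, from the demo's `np.concatenate` usage, that `infer` returns a length-1 array per face):

```python
import sys
import cv2 as cv

from facial_fer_model import FacialExpressionRecog

sys.path.append('../face_detection_yunet')
from yunet import YuNet

img = cv.imread('/path/to/image')  # replace with a real path

detector = YuNet(modelPath='../face_detection_yunet/face_detection_yunet_2022mar.onnx')
detector.setInputSize([img.shape[1], img.shape[0]])
fer = FacialExpressionRecog(modelPath='./facial_expression_recognition_mobilefacenet_2022july.onnx')

dets = detector.infer(img)  # one row per face: bbox, 5 landmarks, score
if dets is not None:
    for face in dets:
        expr = fer.infer(img, face[:-1])  # drop the score column
        print(FacialExpressionRecog.getDesc(expr[0]))
```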
Lines changed: 131 additions & 0 deletions (the demo script referenced above)

```python
import sys
import argparse
import copy
import datetime

import numpy as np
import cv2 as cv

from facial_fer_model import FacialExpressionRecog

sys.path.append('../face_detection_yunet')
from yunet import YuNet


def str2bool(v):
    if v.lower() in ['on', 'yes', 'true', 'y', 't']:
        return True
    elif v.lower() in ['off', 'no', 'false', 'n', 'f']:
        return False
    else:
        raise NotImplementedError


backends = [cv.dnn.DNN_BACKEND_OPENCV, cv.dnn.DNN_BACKEND_CUDA]
targets = [cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_CUDA, cv.dnn.DNN_TARGET_CUDA_FP16]
help_msg_backends = "Choose one of the computation backends: {:d}: OpenCV implementation (default); {:d}: CUDA"
help_msg_targets = "Choose one of the target computation devices: {:d}: CPU (default); {:d}: CUDA; {:d}: CUDA fp16"
try:
    backends += [cv.dnn.DNN_BACKEND_TIMVX]
    targets += [cv.dnn.DNN_TARGET_NPU]
    help_msg_backends += "; {:d}: TIMVX"
    help_msg_targets += "; {:d}: NPU"
except AttributeError:
    print('This version of OpenCV does not support TIM-VX and NPU. Visit https://github.com/opencv/opencv/wiki/TIM-VX-Backend-For-Running-OpenCV-On-NPU for more information.')

parser = argparse.ArgumentParser(description='Facial Expression Recognition')
parser.add_argument('--input', '-i', type=str, help='Path to the input image. Omit for using default camera.')
parser.add_argument('--model', '-fm', type=str, default='./facial_expression_recognition_mobilefacenet_2022july.onnx', help='Path to the facial expression recognition model.')
parser.add_argument('--backend', '-b', type=int, default=backends[0], help=help_msg_backends.format(*backends))
parser.add_argument('--target', '-t', type=int, default=targets[0], help=help_msg_targets.format(*targets))
parser.add_argument('--save', '-s', type=str2bool, default=False, help='Set true to save results. This flag is invalid when using camera.')
parser.add_argument('--vis', '-v', type=str2bool, default=True, help='Set true to open a window for result visualization. This flag is invalid when using camera.')
args = parser.parse_args()


def visualize(image, det_res, fer_res, box_color=(0, 255, 0), text_color=(0, 0, 255)):
    print('%s %3d faces detected.' % (datetime.datetime.now(), len(det_res)))

    output = image.copy()
    landmark_color = [
        (255, 0, 0),    # right eye
        (0, 0, 255),    # left eye
        (0, 255, 0),    # nose tip
        (255, 0, 255),  # right mouth corner
        (0, 255, 255)   # left mouth corner
    ]

    for ind, (det, fer_type) in enumerate(zip(det_res, fer_res)):
        bbox = det[0:4].astype(np.int32)
        fer_type = FacialExpressionRecog.getDesc(fer_type)
        print("Face %2d: %d %d %d %d %s." % (ind, bbox[0], bbox[1], bbox[0]+bbox[2], bbox[1]+bbox[3], fer_type))
        cv.rectangle(output, (bbox[0], bbox[1]), (bbox[0]+bbox[2], bbox[1]+bbox[3]), box_color, 2)
        cv.putText(output, fer_type, (bbox[0], bbox[1]+12), cv.FONT_HERSHEY_DUPLEX, 0.5, text_color)
        landmarks = det[4:14].astype(np.int32).reshape((5, 2))
        for idx, landmark in enumerate(landmarks):
            cv.circle(output, landmark, 2, landmark_color[idx], 2)
    return output


def process(detect_model, fer_model, frame):
    h, w, _ = frame.shape
    detect_model.setInputSize([w, h])
    dets = detect_model.infer(frame)

    if dets is None:
        return False, None, None

    fer_res = np.zeros(0, dtype=np.int8)
    for face_points in dets:
        fer_res = np.concatenate((fer_res, fer_model.infer(frame, face_points[:-1])), axis=0)
    return True, dets, fer_res


if __name__ == '__main__':
    detect_model = YuNet(modelPath='../face_detection_yunet/face_detection_yunet_2022mar.onnx')

    fer_model = FacialExpressionRecog(modelPath=args.model,
                                      backendId=args.backend,
                                      targetId=args.target)

    # If input is an image, run once and optionally save/show the result
    if args.input is not None:
        image = cv.imread(args.input)

        # Get detection and FER results
        status, dets, fer_res = process(detect_model, fer_model, image)

        if status:
            # Draw results on the input image
            image = visualize(image, dets, fer_res)

            # Save results
            if args.save:
                cv.imwrite('result.jpg', image)
                print('Results saved to result.jpg\n')

            # Visualize results in a new window
            if args.vis:
                cv.namedWindow(args.input, cv.WINDOW_AUTOSIZE)
                cv.imshow(args.input, image)
                cv.waitKey(0)
    else:  # Omit input to use the default camera
        deviceId = 0
        cap = cv.VideoCapture(deviceId)

        while cv.waitKey(1) < 0:
            hasFrame, frame = cap.read()
            if not hasFrame:
                print('No frames grabbed!')
                break

            # Get detection and FER results
            status, dets, fer_res = process(detect_model, fer_model, frame)

            if status:
                # Draw results on the frame
                frame = visualize(frame, dets, fer_res)

            # Visualize results in a new window
            cv.imshow('FER Demo', frame)
```
Lines changed: 3 additions & 0 deletions (Git LFS pointer)

```
version https://git-lfs.github.com/spec/v1
oid sha256:541597ca330e0e3babe883d0fa6ab121b0e3da65c9cc099c05ff274b3106a658
size 1340132
```
Lines changed: 3 additions & 0 deletions (Git LFS pointer)

```
version https://git-lfs.github.com/spec/v1
oid sha256:4f61307602fc089ce20488a31d4e4614e3c9753a7d6c41578c854858b183e1a9
size 4791892
```
