Skip to content

Commit c010280

Browse files
authored
[pose-detection]Add model comparison to demo. (#660)
1 parent 38d242a commit c010280

File tree

4 files changed

+109
-74
lines changed

4 files changed

+109
-74
lines changed

pose-detection/demo/src/camera.js

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ export class Camera {
104104
*/
105105
drawKeypoints(keypoints) {
106106
const keypointInd =
107-
posedetection.util.getKeypointIndexBySide(params.STATE.model.model);
107+
posedetection.util.getKeypointIndexBySide(params.STATE.model);
108108
this.ctx.fillStyle = 'White';
109109
this.ctx.strokeStyle = 'White';
110110
this.ctx.lineWidth = params.DEFAULT_LINE_WIDTH;
@@ -127,7 +127,7 @@ export class Camera {
127127
drawKeypoint(keypoint) {
128128
// If score is null, just show the keypoint.
129129
const score = keypoint.score != null ? keypoint.score : 1;
130-
const scoreThreshold = params.STATE.model.scoreThreshold || 0;
130+
const scoreThreshold = params.STATE.modelConfig.scoreThreshold || 0;
131131

132132
if (score >= scoreThreshold) {
133133
const circle = new Path2D();
@@ -146,22 +146,23 @@ export class Camera {
146146
this.ctx.strokeStyle = 'White';
147147
this.ctx.lineWidth = params.DEFAULT_LINE_WIDTH;
148148

149-
posedetection.util.getAdjacentPairs(params.STATE.model.model)
150-
.forEach(([i, j]) => {
151-
const kp1 = keypoints[i];
152-
const kp2 = keypoints[j];
153-
154-
// If score is null, just show the keypoint.
155-
const score1 = kp1.score != null ? kp1.score : 1;
156-
const score2 = kp2.score != null ? kp2.score : 1;
157-
const scoreThreshold = params.STATE.model.scoreThreshold || 0;
158-
159-
if (score1 >= scoreThreshold && score2 >= scoreThreshold) {
160-
this.ctx.beginPath();
161-
this.ctx.moveTo(kp1.x, kp1.y);
162-
this.ctx.lineTo(kp2.x, kp2.y);
163-
this.ctx.stroke();
164-
}
165-
});
149+
posedetection.util.getAdjacentPairs(params.STATE.model).forEach(([
150+
i, j
151+
]) => {
152+
const kp1 = keypoints[i];
153+
const kp2 = keypoints[j];
154+
155+
// If score is null, just show the keypoint.
156+
const score1 = kp1.score != null ? kp1.score : 1;
157+
const score2 = kp2.score != null ? kp2.score : 1;
158+
const scoreThreshold = params.STATE.modelConfig.scoreThreshold || 0;
159+
160+
if (score1 >= scoreThreshold && score2 >= scoreThreshold) {
161+
this.ctx.beginPath();
162+
this.ctx.moveTo(kp1.x, kp1.y);
163+
this.ctx.lineTo(kp2.x, kp2.y);
164+
this.ctx.stroke();
165+
}
166+
});
166167
}
167168
}

pose-detection/demo/src/index.js

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ import {setupStats} from './stats_panel';
2828
let detector, camera, stats;
2929

3030
async function createDetector() {
31-
switch (STATE.model.model) {
31+
switch (STATE.model) {
3232
case posedetection.SupportedModels.PoseNet:
33-
return posedetection.createDetector(STATE.model.model, {
33+
return posedetection.createDetector(STATE.model, {
3434
quantBytes: 4,
3535
architecture: 'MobileNetV1',
3636
outputStride: 16,
@@ -39,12 +39,12 @@ async function createDetector() {
3939
});
4040
case posedetection.SupportedModels.MediapipeBlazeposeUpperBody:
4141
case posedetection.SupportedModels.MediapipeBlazeposeFullBody:
42-
return posedetection.createDetector(STATE.model.model, {quantBytes: 4});
42+
return posedetection.createDetector(STATE.model, {quantBytes: 4});
4343
case posedetection.SupportedModels.MoveNet:
44-
const modelType = STATE.model.type == 'lightning' ?
44+
const modelType = STATE.modelConfig.type == 'lightning' ?
4545
posedetection.movenet.modelType.SINGLEPOSE_LIGHTNING :
4646
posedetection.movenet.modelType.SINGLEPOSE_THUNDER;
47-
return posedetection.createDetector(STATE.model.model, {modelType});
47+
return posedetection.createDetector(STATE.model, {modelType});
4848
}
4949
}
5050

@@ -55,9 +55,9 @@ async function checkGuiUpdate() {
5555
STATE.changeToSizeOption = null;
5656
}
5757

58-
if (STATE.changeToModel) {
58+
if (STATE.changeToModel != null) {
5959
detector.dispose();
60-
detector = await createDetector(STATE.model.model);
60+
detector = await createDetector(STATE.model);
6161
STATE.changeToModel = null;
6262
}
6363
}

pose-detection/demo/src/option_panel.js

Lines changed: 81 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
* =============================================================================
1616
*/
1717
import * as posedetection from '@tensorflow-models/pose-detection';
18-
import {type} from 'os';
1918

2019
import * as params from './params';
2120

@@ -39,77 +38,111 @@ export function setupDatGui(urlParams) {
3938
const modelFolder = gui.addFolder('Model');
4039

4140
const model = urlParams.get('model');
42-
const type = urlParams.get('type');
41+
let type = urlParams.get('type');
42+
43+
let modelConfigFolder;
44+
4345
switch (model) {
4446
case 'posenet':
45-
addPoseNetControllers(modelFolder);
47+
params.STATE.model = posedetection.SupportedModels.PoseNet;
4648
break;
4749
case 'movenet':
48-
addMoveNetControllers(modelFolder, type);
50+
params.STATE.model = posedetection.SupportedModels.MoveNet;
51+
if (type !== 'lightning' && type !== 'thunder') {
52+
// Nulify invalid value.
53+
type = null;
54+
}
4955
break;
5056
case 'blazepose':
51-
addBlazePoseControllers(modelFolder, type);
57+
params.STATE.model = type === 'upperbody' ?
58+
posedetection.SupportedModels.MediapipeBlazeposeUpperBody :
59+
posedetection.SupportedModels.MediapipeBlazeposeFullBody;
60+
if (type !== 'fullbody' && type !== 'upperbody') {
61+
// Nulify invalid value.
62+
type = null;
63+
}
5264
break;
5365
default:
5466
alert(`${urlParams.get('model')}`);
5567
break;
5668
}
5769

58-
modelFolder.open();
70+
const modelController = modelFolder.add(
71+
params.STATE, 'model', Object.values(posedetection.SupportedModels));
5972

60-
return gui;
61-
}
73+
modelController.onChange(model => {
74+
params.STATE.changeToModel = model;
6275

63-
// The MoveNet model config folder contains options for MoveNet config
64-
// settings.
65-
function addMoveNetControllers(modelFolder, type) {
66-
params.STATE.model = {
67-
model: posedetection.SupportedModels.MoveNet,
68-
...params.MOVENET_CONFIG
69-
};
70-
71-
params.STATE.model.type =
72-
type !== 'thunder' && type !== 'lightning' ? 'thunder' : type;
73-
74-
const typeController =
75-
modelFolder.add(params.STATE.model, 'type', ['thunder', 'lightning']);
76-
typeController.onChange(type => {
77-
params.STATE.changeToModel = type;
76+
// We don't pass in type, so that it will use default type when switching
77+
// models.
78+
modelConfigFolder = updateModelConfigFolder(gui, model, modelConfigFolder);
79+
80+
modelConfigFolder.open();
7881
});
7982

80-
modelFolder.add(params.STATE.model, 'scoreThreshold', 0, 1);
81-
}
83+
modelFolder.open();
8284

83-
// The Blazepose model config folder contains options for Blazepose config
84-
// settings.
85-
function addBlazePoseControllers(modelFolder, type) {
86-
params.STATE.model = {...params.BLAZEPOSE_CONFIG};
85+
// For initialization, pass in type from url.
86+
modelConfigFolder =
87+
updateModelConfigFolder(gui, params.STATE.model, modelConfigFolder, type);
88+
89+
modelConfigFolder.open();
90+
91+
return gui;
92+
}
8793

88-
params.STATE.model.model = type === 'upperbody' ?
89-
posedetection.SupportedModels.MediapipeBlazeposeUpperBody :
90-
posedetection.SupportedModels.MediapipeBlazeposeFullBody;
94+
function updateModelConfigFolder(gui, model, modelConfigFolder, type) {
95+
if (modelConfigFolder != null) {
96+
gui.removeFolder(modelConfigFolder);
97+
}
9198

92-
params.STATE.model.type = type === 'upperbody' ? 'upperbody' : 'fullbody';
99+
const newModelConfigFolder = gui.addFolder('Model Config');
93100

94-
const typeController =
95-
modelFolder.add(params.STATE.model, 'type', ['fullbody', 'upperbody']);
96-
typeController.onChange(type => {
97-
params.STATE.changeToModel = type;
98-
params.STATE.model.model = type === 'upperbody' ?
99-
posedetection.SupportedModels.MediapipeBlazeposeUpperBody :
100-
posedetection.SupportedModels.MediapipeBlazeposeFullBody;
101-
})
101+
switch (model) {
102+
case posedetection.SupportedModels.PoseNet:
103+
addPoseNetControllers(newModelConfigFolder);
104+
break;
105+
case posedetection.SupportedModels.MoveNet:
106+
addMoveNetControllers(newModelConfigFolder, type);
107+
break;
108+
case posedetection.SupportedModels.MediapipeBlazeposeUpperBody:
109+
case posedetection.SupportedModels.MediapipeBlazeposeFullBody:
110+
addBlazePoseControllers(newModelConfigFolder);
111+
break;
112+
default:
113+
alert(`Model ${model} is not supported.`);
114+
}
102115

103-
modelFolder.add(params.STATE.model, 'scoreThreshold', 0, 1);
116+
return newModelConfigFolder;
104117
}
105118

106119
// The PoseNet model config folder contains options for PoseNet config
107120
// settings.
108-
function addPoseNetControllers(modelFolder) {
109-
params.STATE.model = {
110-
model: posedetection.SupportedModels.PoseNet,
111-
...params.POSENET_CONFIG
112-
};
121+
function addPoseNetControllers(modelConfigFolder) {
122+
params.STATE.modelConfig = {...params.POSENET_CONFIG};
123+
modelConfigFolder.add(params.STATE.modelConfig, 'scoreThreshold', 0, 1);
124+
}
125+
126+
// The MoveNet model config folder contains options for MoveNet config
127+
// settings.
128+
function addMoveNetControllers(modelConfigFolder, type) {
129+
params.STATE.modelConfig = {...params.MOVENET_CONFIG};
130+
params.STATE.modelConfig.type = type != null ? type : 'thunder';
131+
132+
const typeController = modelConfigFolder.add(
133+
params.STATE.modelConfig, 'type', ['thunder', 'lightning']);
134+
typeController.onChange(_ => {
135+
// Set changeToModel to non-null, so that we don't render any result when
136+
// changeToModel is non-null.
137+
params.STATE.changeToModel = params.STATE.model;
138+
});
113139

114-
modelFolder.add(params.STATE.model, 'scoreThreshold', 0, 1);
140+
modelConfigFolder.add(params.STATE.modelConfig, 'scoreThreshold', 0, 1);
141+
}
142+
143+
// The Blazepose model config folder contains options for Blazepose config
144+
// settings.
145+
function addBlazePoseControllers(modelConfigFolder) {
146+
params.STATE.modelConfig = {...params.BLAZEPOSE_CONFIG};
147+
modelConfigFolder.add(params.STATE.modelConfig, 'scoreThreshold', 0, 1);
115148
}

pose-detection/demo/src/params.js

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,6 @@ export const POSENET_CONFIG = {
3333
scoreThreshold: 0.5
3434
};
3535
export const MOVENET_CONFIG = {
36+
type: 'lightning',
3637
scoreThreshold: 0.3
3738
};

0 commit comments

Comments
 (0)