Commit 00ec01f

Add BodyPix Body Segmentation implementation (#891)
* Add BodyPix Body Segmentation implementation
* Update channel diff constant
* Remove dependency on body-pix
1 parent db14bc6 commit 00ec01f

48 files changed (+9338, -10 lines)

body-segmentation/README.md

Lines changed: 288 additions & 5 deletions
body-segmentation/images/bokeh.gif (1.51 MB)

body-segmentation/images/drawMask.jpg (237 KB)

[Additional binary image files (337 KB, 383 KB, 189 KB, 194 KB) are not named in this view.]
Lines changed: 154 additions & 0 deletions
# BodyPix

Body Segmentation - BodyPix wraps the [BodyPix](https://github.com/tensorflow/tfjs-models/tree/master/body-pix) JS solution within the familiar
TFJS API.

This model can be used to segment an image into pixels that are and are not part of a person, and into
pixels that belong to each of twenty-four body parts. It works for multiple people in an input image or video.

--------------------------------------------------------------------------------

## Table of Contents

1. [Installation](#installation)
2. [Usage](#usage)

## Installation

To use BodyPix:

Via script tags:

```html
<!-- Require the peer dependencies of body-segmentation. -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-core"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-converter"></script>

<!-- You must explicitly require a TF.js backend if you're not using the TF.js union bundle. -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-backend-webgl"></script>

<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/body-segmentation"></script>
```

Via npm:
```sh
yarn add @tensorflow-models/body-segmentation
yarn add @tensorflow/tfjs-core @tensorflow/tfjs-converter
yarn add @tensorflow/tfjs-backend-webgl
```

-----------------------------------------------------------------------
## Usage

If you are using the Body Segmentation API via npm, you need to import the libraries first.

### Import the libraries

```javascript
import * as bodySegmentation from '@tensorflow-models/body-segmentation';
import '@tensorflow/tfjs-core';
import '@tensorflow/tfjs-converter';
// Register WebGL backend.
import '@tensorflow/tfjs-backend-webgl';
```

### Create a detector

Pass in `bodySegmentation.SupportedModels.BodyPix` from the
`bodySegmentation.SupportedModels` enum along with an optional `segmenterConfig` to the
`createSegmenter` method to load and initialize the model.

**By default**, BodyPix loads a MobileNetV1 architecture with a **`0.75`** multiplier. This is recommended for computers with mid-range/lower-end GPUs. A model with a **`0.50`** multiplier is recommended for mobile. The ResNet architecture is recommended for computers with even more powerful GPUs.
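
For example, a minimal sketch that loads the model with those defaults (no `segmenterConfig`), as the test in this commit does:

```javascript
// Loads BodyPix with its defaults: MobileNetV1 architecture, 0.75 multiplier.
const segmenter = await bodySegmentation.createSegmenter(
    bodySegmentation.SupportedModels.BodyPix);
```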

`segmenterConfig` is an object that defines BodyPix specific configurations for `BodyPixModelConfig`:

* **architecture** - Can be either `MobileNetV1` or `ResNet50`. It determines which BodyPix architecture to load.

* **outputStride** - Can be one of `8`, `16`, `32` (strides `16` and `32` are supported for the ResNet architecture; strides `8` and `16` are supported for the MobileNetV1 architecture). It specifies the output stride of the BodyPix model. The smaller the value, the larger the output resolution, and the more accurate the model at the cost of speed. ***A larger value results in a smaller model and faster prediction time but lower accuracy***.

* **multiplier** - Can be one of `1.0`, `0.75`, or `0.50` (the value is used *only* by the MobileNetV1 architecture and not by the ResNet architecture). It is the float multiplier for the depth (number of channels) of all convolution ops. The larger the value, the larger the size of the layers, and the more accurate the model at the cost of speed. ***A smaller value results in a smaller model and faster prediction time but lower accuracy***.

* **quantBytes** - This argument controls the bytes used for weight quantization. The available options are:

  - `4`. 4 bytes per float (no quantization). Leads to highest accuracy and original model size.
  - `2`. 2 bytes per float. Leads to slightly lower accuracy and 2x model size reduction.
  - `1`. 1 byte per float. Leads to lower accuracy and 4x model size reduction.

The following table contains the corresponding BodyPix 2.0 model checkpoint sizes (without gzip) when using different quantization bytes:

| Architecture       | quantBytes=4 | quantBytes=2 | quantBytes=1 |
| ------------------ |:------------:|:------------:|:------------:|
| ResNet50           | ~90MB        | ~45MB        | ~22MB        |
| MobileNetV1 (1.00) | ~13MB        | ~6MB         | ~3MB         |
| MobileNetV1 (0.75) | ~5MB         | ~2MB         | ~1MB         |
| MobileNetV1 (0.50) | ~2MB         | ~1MB         | ~0.6MB       |

* **modelUrl** - An optional string that specifies a custom URL for the model. This is useful for local development or for countries that don't have access to the models hosted on GCP.

```javascript
const model = bodySegmentation.SupportedModels.BodyPix;
const segmenterConfig = {
  architecture: 'ResNet50',
  outputStride: 32,
  quantBytes: 2
};
const segmenter = await bodySegmentation.createSegmenter(model, segmenterConfig);
```
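
For mobile, a sketch following the recommendations above might look like the following. The specific `outputStride` and `quantBytes` values here are illustrative trade-offs, not requirements:

```javascript
// A lighter-weight configuration for mobile devices: MobileNetV1 with the
// recommended 0.50 multiplier; stride 16 is supported for MobileNetV1.
const mobileSegmenterConfig = {
  architecture: 'MobileNetV1',
  outputStride: 16,
  multiplier: 0.50,
  quantBytes: 2  // ~1MB checkpoint per the table above.
};
const mobileSegmenter = await bodySegmentation.createSegmenter(
    bodySegmentation.SupportedModels.BodyPix, mobileSegmenterConfig);
```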

### Run inference

Now you can use the segmenter to segment people. The `segmentPeople` method
accepts both image and video in many formats, including:
`HTMLVideoElement`, `HTMLImageElement`, `HTMLCanvasElement`, `ImageData`, `Tensor3D`. If you want more
options, you can pass in a second `segmentationConfig` parameter.

`segmentationConfig` is an object that defines BodyPix specific configurations for `BodyPixSegmentationConfig`:

* **multiSegmentation** - Required. If set to true, each person is segmented in a separate output; otherwise all people are segmented together in one segmentation.
* **segmentBodyParts** - Required. If set to true, 24 body parts are segmented in the output; otherwise only foreground/background binary segmentation is performed.
* **flipHorizontal** - Defaults to false. Whether the segmentation & pose should be flipped/mirrored horizontally. This should be set to true for videos where the video is by default flipped horizontally (i.e. a webcam), and you want the segmentation & pose to be returned in the proper orientation.
* **internalResolution** - Defaults to `medium`. The internal resolution percentage that the input is resized to before inference. The larger the `internalResolution`, the more accurate the model at the cost of slower prediction times. Available values are `low`, `medium`, `high`, `full`, or a percentage value between 0 and 1. The values `low`, `medium`, `high`, and `full` map to 0.25, 0.5, 0.75, and 1.0 correspondingly.
* **segmentationThreshold** - Defaults to 0.7. Must be between 0 and 1. For each pixel, the model estimates a score between 0 and 1 that indicates how confident it is that part of a person is displayed in that pixel. This *segmentationThreshold* is used to convert these values to binary 0 or 1s by determining the minimum value a pixel's score must have to be considered part of a person. In essence, a higher value creates a tighter crop around a person, but may result in some pixels that are part of a person being excluded from the returned segmentation mask.
* **maxDetections** - Defaults to 10. For pose estimation, the maximum number of person poses to detect per image.
* **scoreThreshold** - Defaults to 0.3. For pose estimation, only return individual person detections that have a root part score greater than or equal to this value.
* **nmsRadius** - Defaults to 20. For pose estimation, the non-maximum suppression part distance in pixels. It needs to be strictly positive. Two parts suppress each other if they are less than `nmsRadius` pixels away.

If **multiSegmentation** is set to true, then the following additional parameters can be adjusted:

* **minKeypointScore** - Defaults to 0.3. Keypoints above this score are used for matching and assigning a segmentation mask to each person.
* **refineSteps** - Defaults to 10. The number of refinement steps used when assigning the individual person segmentations. It needs to be strictly positive. Larger values increase accuracy at the cost of slower inference.

The following code snippet demonstrates how to run the model inference:

```javascript
const segmentationConfig = {multiSegmentation: true, segmentBodyParts: false};
const people = await segmenter.segmentPeople(image, segmentationConfig);
```

When `multiSegmentation` is set to false, the returned `people` array contains a single element in which all the people segmented in the image are combined. When `multiSegmentation` is set to true, the length of the array equals the number of detected people, with each segmentation containing one person.
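
Because `segmentPeople` also accepts an `HTMLVideoElement`, a common pattern is to re-run it on every frame. A minimal sketch, assuming a `video` element that is already playing; the `renderResult` name is a hypothetical placeholder for your own drawing code:

```javascript
const segmentationConfig = {multiSegmentation: false, segmentBodyParts: true};

async function renderLoop() {
  // Re-segment the current video frame on each animation frame.
  const people = await segmenter.segmentPeople(video, segmentationConfig);
  renderResult(people);  // Hypothetical: draw the masks however you like.
  requestAnimationFrame(renderLoop);
}

renderLoop();
```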

When `segmentBodyParts` is set to false, the only label returned by the `maskValueToLabel` function is 'person'. When `segmentBodyParts` is set to true, the `maskValueToLabel` function returns one of the body parts defined by BodyPix, where the mapping of mask values to labels is as follows:

| Part Id | Part Name             | Part Id | Part Name              |
|---------|-----------------------|---------|------------------------|
| 0       | left_face             | 12      | torso_front            |
| 1       | right_face            | 13      | torso_back             |
| 2       | left_upper_arm_front  | 14      | left_upper_leg_front   |
| 3       | left_upper_arm_back   | 15      | left_upper_leg_back    |
| 4       | right_upper_arm_front | 16      | right_upper_leg_front  |
| 5       | right_upper_arm_back  | 17      | right_upper_leg_back   |
| 6       | left_lower_arm_front  | 18      | left_lower_leg_front   |
| 7       | left_lower_arm_back   | 19      | left_lower_leg_back    |
| 8       | right_lower_arm_front | 20      | right_lower_leg_front  |
| 9       | right_lower_arm_back  | 21      | right_lower_leg_back   |
| 10      | left_hand             | 22      | left_foot              |
| 11      | right_hand            | 23      | right_foot             |

Please refer to the Body Segmentation API
[README](https://github.com/tensorflow/tfjs-models/blob/master/body-segmentation/README.md#how-to-run-it)
for details on the structure of the returned `people` array.
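
To visualize the result, the package's render utilities (exercised by the test file in this commit) can convert a segmentation into a drawable mask. A hedged sketch, assuming `toBinaryMask` and `drawMask` are re-exported from the package root as in the main Body Segmentation API:

```javascript
// Convert the segmentation into a black-and-white ImageData mask
// (white foreground, black background), then composite it over the image.
const foreground = {r: 255, g: 255, b: 255, a: 255};
const background = {r: 0, g: 0, b: 0, a: 255};
const binaryMask = await bodySegmentation.toBinaryMask(people, foreground, background);

const canvas = document.getElementById('canvas');
const opacity = 0.7;       // How strongly the mask is blended over the image.
const maskBlurAmount = 3;  // Pixels to blur the mask edge by.
await bodySegmentation.drawMask(canvas, image, binaryMask, opacity, maskBlurAmount);
```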
Lines changed: 149 additions & 0 deletions
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

// tslint:disable-next-line: no-imports-from-dist
import {BROWSER_ENVS, describeWithFlags} from '@tensorflow/tfjs-core/dist/jasmine_util';

import * as bodySegmentation from '../index';
import {Mask} from '../shared/calculators/interfaces/common_interfaces';
import {toImageDataLossy, toTensorLossy} from '../shared/calculators/mask_util';
import * as renderUtil from '../shared/calculators/render_util';
import {loadImage} from '../shared/test_util';
import {BodyPixSegmentationConfig} from './types';

// Measured in channels: the maximum number of mismatched channels tolerated
// when comparing an actual image against an expected image.
const DIFF_IMAGE = 30;

class CanvasImageSourceMask implements Mask {
  constructor(private mask: CanvasImageSource) {}

  async toCanvasImageSource() {
    return this.mask;
  }

  async toImageData() {
    return toImageDataLossy(this.mask);
  }

  async toTensor() {
    return toTensorLossy(this.mask);
  }

  getUnderlyingType() {
    return 'canvasimagesource' as const;
  }
}

async function getSegmentation(
    image: HTMLImageElement, config: BodyPixSegmentationConfig) {
  const segmenter = await bodySegmentation.createSegmenter(
      bodySegmentation.SupportedModels.BodyPix);

  const segmentations = await segmenter.segmentPeople(image, config);
  return Promise.all(segmentations.map(async segmentation => {
    return {
      maskValueToLabel: segmentation.maskValueToLabel,
      // Convert to canvas image source to apply alpha-premultiplication.
      mask: new CanvasImageSourceMask(
          await segmentation.mask.toCanvasImageSource())
    };
  }));
}

async function getBinaryMask(
    image: HTMLImageElement, expectedNumSegmentations?: number) {
  const segmentation = await getSegmentation(image, {
    multiSegmentation: expectedNumSegmentations != null,
    segmentBodyParts: false
  });

  if (expectedNumSegmentations != null) {
    expect(segmentation.length).toBe(expectedNumSegmentations);
  }

  const binaryMask = await renderUtil.toBinaryMask(
      segmentation, {r: 255, g: 255, b: 255, a: 255},
      {r: 0, g: 0, b: 0, a: 255});
  return binaryMask;
}

async function getColoredMask(
    image: HTMLImageElement, expectedNumSegmentations?: number) {
  const segmentation = await getSegmentation(image, {
    multiSegmentation: expectedNumSegmentations != null,
    segmentBodyParts: true
  });

  if (expectedNumSegmentations != null) {
    expect(segmentation.length).toBe(expectedNumSegmentations);
  }

  const coloredMask = await renderUtil.toColoredMask(
      segmentation, bodySegmentation.bodyPixMaskValueToRainbowColor,
      {r: 255, g: 255, b: 255, a: 255});

  return coloredMask;
}

const WIDTH = 1049;
const HEIGHT = 861;

async function expectImage(actual: ImageData, imageName: string) {
  const expectedImage = await loadImage(imageName, WIDTH, HEIGHT)
                            .then(async image => toImageDataLossy(image));
  const mismatchedChannels = actual.data.reduce(
      (mismatched, channel, i) =>
          mismatched + +(channel !== expectedImage.data[i]),
      0);
  expect(mismatchedChannels).toBeLessThanOrEqual(DIFF_IMAGE);
}

describeWithFlags('renderUtil', BROWSER_ENVS, () => {
  let image: HTMLImageElement;
  let timeout: number;

  beforeAll(async () => {
    timeout = jasmine.DEFAULT_TIMEOUT_INTERVAL;
    jasmine.DEFAULT_TIMEOUT_INTERVAL = 120000;  // 2mins

    image = await loadImage('shared/three_people.jpg', WIDTH, HEIGHT);
  });

  afterAll(() => {
    jasmine.DEFAULT_TIMEOUT_INTERVAL = timeout;
  });

  it('Single Segmentation + No body parts.', async () => {
    const binaryMask = await getBinaryMask(image);
    await expectImage(binaryMask, 'shared/three_people_binary_mask.png');
  });

  it('Multi Segmentation + No body parts.', async () => {
    const binaryMask = await getBinaryMask(image, 3);
    await expectImage(binaryMask, 'shared/three_people_binary_mask.png');
  });

  it('Single Segmentation + Body parts.', async () => {
    const coloredMask = await getColoredMask(image);
    await expectImage(coloredMask, 'shared/three_people_colored_mask.png');
  });

  it('Multi Segmentation + Body parts.', async () => {
    const coloredMask = await getColoredMask(image, 3);
    await expectImage(coloredMask, 'shared/three_people_colored_mask.png');
  });
});
Lines changed: 37 additions & 0 deletions
/**
 * @license
 * Copyright 2021 Google LLC. All Rights Reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */
import {Color} from '../shared/calculators/interfaces/common_interfaces';
import {assertMaskValue} from '../shared/calculators/mask_util';

const RAINBOW_PART_COLORS: Array<[number, number, number]> = [
  [110, 64, 170], [143, 61, 178], [178, 60, 178], [210, 62, 167],
  [238, 67, 149], [255, 78, 125], [255, 94, 99], [255, 115, 75],
  [255, 140, 56], [239, 167, 47], [217, 194, 49], [194, 219, 64],
  [175, 240, 91], [135, 245, 87], [96, 247, 96], [64, 243, 115],
  [40, 234, 141], [28, 219, 169], [26, 199, 194], [33, 176, 213],
  [47, 150, 224], [65, 125, 224], [84, 101, 214], [99, 81, 195]
];

export function bodyPixMaskValueToRainbowColor(maskValue: number): Color {
  assertMaskValue(maskValue);
  if (maskValue < RAINBOW_PART_COLORS.length) {
    const [r, g, b] = RAINBOW_PART_COLORS[maskValue];
    return {r, g, b, a: 255};
  }
  throw new Error(
      `Mask value must be in range [0, ${RAINBOW_PART_COLORS.length})`);
}
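
This helper maps each of the 24 BodyPix part ids to a fixed rainbow color; the test above passes it to `renderUtil.toColoredMask`. A minimal usage sketch, assuming `toColoredMask` is re-exported from the package root like the other render utilities, and that `people` comes from `segmenter.segmentPeople` with `segmentBodyParts: true`:

```javascript
// Color each body-part mask value with the rainbow palette; background
// pixels are filled with the white background color.
const coloredMask = await bodySegmentation.toColoredMask(
    people, bodySegmentation.bodyPixMaskValueToRainbowColor,
    {r: 255, g: 255, b: 255, a: 255});
```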
