|
19 | 19 | import * as tfconv from '@tensorflow/tfjs-converter';
|
20 | 20 | import * as tf from '@tensorflow/tfjs-core';
|
21 | 21 |
|
| 22 | +import {BaseModel} from './base_model'; |
22 | 23 | import {decodeOnlyPartSegmentation, decodePartSegmentation, toMaskTensor} from './decode_part_map';
|
23 |
| -import {MobileNet, MobileNetMultiplier} from './mobilenet'; |
| 24 | +import {MobileNet} from './mobilenet'; |
24 | 25 | import {decodePersonInstanceMasks, decodePersonInstancePartMasks} from './multi_person/decode_instance_masks';
|
25 | 26 | import {decodeMultiplePoses} from './multi_person/decode_multiple_poses';
|
26 | 27 | import {ResNet} from './resnet';
|
27 | 28 | import {mobileNetSavedModel, resNet50SavedModel} from './saved_models';
|
28 |
| -import {decodeSinglePose} from './sinlge_person/decode_single_pose'; |
| 29 | +import {decodeSinglePose} from './single_person/decode_single_pose'; |
29 | 30 | import {BodyPixArchitecture, BodyPixInput, BodyPixInternalResolution, BodyPixMultiplier, BodyPixOutputStride, BodyPixQuantBytes, Padding, PartSegmentation, PersonSegmentation} from './types';
|
30 | 31 | import {getInputSize, padAndResizeTo, scaleAndCropToInputTensorShape, scaleAndFlipPoses, toTensorBuffers3D, toValidInternalResolutionNumber} from './util';
|
31 | 32 |
|
32 |
| - |
33 | 33 | const APPLY_SIGMOID_ACTIVATION = true;
|
34 | 34 |
|
35 |
| -/** |
36 |
| - * BodyPix supports using various convolution neural network models |
37 |
| - * (e.g. ResNet and MobileNetV1) as its underlying base model. |
38 |
| - * The following BaseModel interface defines a unified interface for |
39 |
| - * creating such BodyPix base models. Currently both MobileNet (in |
40 |
| - * ./mobilenet.ts) and ResNet (in ./resnet.ts) implements the BaseModel |
41 |
| - * interface. New base models that conform to the BaseModel interface can be |
42 |
| - * added to BodyPix. |
43 |
| - */ |
44 |
| -export interface BaseModel { |
45 |
| - // The output stride of the base model. |
46 |
| - readonly outputStride: BodyPixOutputStride; |
47 |
| - |
48 |
| - /** |
49 |
| - * Predicts intermediate Tensor representations. |
50 |
| - * |
51 |
| - * @param input The input RGB image of the base model. |
52 |
| - * A Tensor of shape: [`inputResolution`, `inputResolution`, 3]. |
53 |
| - * |
54 |
| - * @return A dictionary of base model's intermediate predictions. |
55 |
| - * The returned dictionary should contains the following elements: |
56 |
| - * - heatmapScores: A Tensor3D that represents the keypoint heatmap scores. |
57 |
| - * - offsets: A Tensor3D that represents the offsets. |
58 |
| - * - displacementFwd: A Tensor3D that represents the forward displacement. |
59 |
| - * - displacementBwd: A Tensor3D that represents the backward displacement. |
60 |
| - * - segmentation: A Tensor3D that represents the segmentation of all people. |
61 |
| - * - longOffsets: A Tensor3D that represents the long offsets used for |
62 |
| - * instance grouping. |
63 |
| - * - partHeatmaps: A Tensor3D that represents the body part segmentation. |
64 |
| - */ |
65 |
| - predict(input: tf.Tensor3D): {[key: string]: tf.Tensor3D}; |
66 |
| - /** |
67 |
| - * Releases the CPU and GPU memory allocated by the model. |
68 |
| - */ |
69 |
| - dispose(): void; |
70 |
| -} |
71 |
| - |
72 | 35 | /**
|
73 | 36 | * BodyPix model loading is configurable using the following config dictionary.
|
74 | 37 | *
|
@@ -101,7 +64,7 @@ export interface BaseModel {
|
101 | 64 | export interface ModelConfig {
|
102 | 65 | architecture: BodyPixArchitecture;
|
103 | 66 | outputStride: BodyPixOutputStride;
|
104 |
| - multiplier?: MobileNetMultiplier; |
| 67 | + multiplier?: BodyPixMultiplier; |
105 | 68 | modelUrl?: string;
|
106 | 69 | quantBytes?: BodyPixQuantBytes;
|
107 | 70 | }
|
@@ -602,15 +565,15 @@ export class BodyPix {
|
602 | 565 | };
|
603 | 566 | });
|
604 | 567 |
|
605 |
| - const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] = |
| 568 | + const [scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf] = |
606 | 569 | await toTensorBuffers3D([
|
607 | 570 | heatmapScoresRaw, offsetsRaw, displacementFwdRaw, displacementBwdRaw
|
608 | 571 | ]);
|
609 | 572 |
|
610 |
| - let poses = await decodeMultiplePoses( |
611 |
| - scoresBuffer, offsetsBuffer, displacementsFwdBuffer, |
612 |
| - displacementsBwdBuffer, this.baseModel.outputStride, |
613 |
| - config.maxDetections, config.scoreThreshold, config.nmsRadius); |
| 573 | + let poses = decodeMultiplePoses( |
| 574 | + scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf, |
| 575 | + this.baseModel.outputStride, config.maxDetections, |
| 576 | + config.scoreThreshold, config.nmsRadius); |
614 | 577 |
|
615 | 578 | poses = scaleAndFlipPoses(
|
616 | 579 | poses, [height, width],
|
@@ -849,15 +812,15 @@ export class BodyPix {
|
849 | 812 | };
|
850 | 813 | });
|
851 | 814 |
|
852 |
| - const [scoresBuffer, offsetsBuffer, displacementsFwdBuffer, displacementsBwdBuffer] = |
| 815 | + const [scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf] = |
853 | 816 | await toTensorBuffers3D([
|
854 | 817 | heatmapScoresRaw, offsetsRaw, displacementFwdRaw, displacementBwdRaw
|
855 | 818 | ]);
|
856 | 819 |
|
857 |
| - let poses = await decodeMultiplePoses( |
858 |
| - scoresBuffer, offsetsBuffer, displacementsFwdBuffer, |
859 |
| - displacementsBwdBuffer, this.baseModel.outputStride, |
860 |
| - config.maxDetections, config.scoreThreshold, config.nmsRadius); |
| 820 | + let poses = decodeMultiplePoses( |
| 821 | + scoresBuf, offsetsBuf, displacementsFwdBuf, displacementsBwdBuf, |
| 822 | + this.baseModel.outputStride, config.maxDetections, |
| 823 | + config.scoreThreshold, config.nmsRadius); |
861 | 824 |
|
862 | 825 | poses = scaleAndFlipPoses(
|
863 | 826 | poses, [height, width],
|
|
0 commit comments