|
| 1 | +# Object Detection (coco-ssd) |
| 2 | + |
| 3 | +Object detection model aims to localize and identify multiple objects in a single image. |
| 4 | + |
| 5 | +This model is a TensorFlow.js port of the SSD-COCO model. For more information about Tensorflow object detection API, check out this readme in |
| 6 | +[tensorflow/object_detection](https://github.com/tensorflow/models/blob/master/research/object_detection/README.md). |
| 7 | + |
| 8 | +This model detects objects defined in the COCO dataset, which is a large-scale object detection, segmentation, and captioning dataset, you can find more information [here](http://cocodataset.org/#home). The model is capable of detecting [90 classes of objects](./src/classes.ts). SSD stands for Single Shot MultiBox Detection. |
| 9 | + |
| 10 | +This TensorFlow.js model does not require you to know about machine learning. |
| 11 | +It can take as input any browser-based image elements (`<img>`, `<video>`, `<canvas>` |
| 12 | +elements, for example) and returns an array of most bounding boxes with class name and confidence level. |
| 13 | + |
| 14 | +## Usage |
| 15 | + |
| 16 | +There are two main ways to get this model in your JavaScript project: via script tags or by installing it from NPM and using a build tool like Parcel, WebPack, or Rollup. |
| 17 | + |
| 18 | +### via Script Tag |
| 19 | + |
| 20 | +```html |
| 21 | +<!-- Load TensorFlow.js. This is required to use object detection model. --> |
| 22 | +< script src= "https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]"> </ script> |
| 23 | +<!-- Load the object detection model. --> |
| 24 | +< script src= "https://cdn.jsdelivr.net/npm/@tensorflow-models/[email protected]"> </ script> |
| 25 | + |
| 26 | +<!-- Replace this with your image. Make sure CORS settings allow reading the image! --> |
| 27 | +<img id="img" src="cat.jpg"/> |
| 28 | + |
| 29 | +<!-- Place your code in the script tag below. You can also use an external .js file --> |
| 30 | +<script> |
| 31 | + // Notice there is no 'import' statement. 'objectDetection' and 'tf' is |
| 32 | + // available on the index-page because of the script tag above. |
| 33 | +
|
| 34 | + const img = document.getElementById('img'); |
| 35 | +
|
| 36 | + // Load the model. |
| 37 | + objectDetection.load().then(model => { |
| 38 | + // Classify the image. |
| 39 | + model.detect(img).then(predictions => { |
| 40 | + console.log('Predictions: ', predictions); |
| 41 | + }); |
| 42 | + }); |
| 43 | +</script> |
| 44 | +``` |
| 45 | + |
| 46 | +### via NPM |
| 47 | + |
| 48 | +```js |
| 49 | +// Note: you do not need to import @tensorflow/tfjs here. |
| 50 | + |
| 51 | +import * as objectDetection from '@tensorflow-models/object-detection'; |
| 52 | + |
| 53 | +const img = document.getElementById('img'); |
| 54 | + |
| 55 | +// Load the model. |
| 56 | +const model = await objectDetection.load(); |
| 57 | + |
| 58 | +// Classify the image. |
| 59 | +const predictions = await model.detect(img); |
| 60 | + |
| 61 | +console.log('Predictions: '); |
| 62 | +console.log(predictions); |
| 63 | +``` |
| 64 | + |
| 65 | +You can also take a look at the [demo app](./demo). |
| 66 | + |
| 67 | +## API |
| 68 | + |
| 69 | +#### Loading the model |
| 70 | +`object-detection` is the module name, which is automatically included when you use the `<script src>` method. When using ES6 imports, object-detection is the module. |
| 71 | + |
| 72 | +```ts |
| 73 | +objectDetection.load( |
| 74 | + base?: 'ssd_mobilenet_v1' | 'ssd_mobilenet_v2' | 'ssdlite_mobilenet_v2' |
| 75 | +) |
| 76 | +``` |
| 77 | + |
| 78 | +Args: |
| 79 | + **base:** Controls the base cnn model, can be 'ssd_mobilenet_v1', 'ssd_mobilenet_v2' or 'ssdlite_mobilenet_v2'. Defaults to 'ssdlite_mobilenet_v2'. |
| 80 | + ssdlite_mobilenet_v2 is smallest in size, and fastest in inference speed. |
| 81 | + ssdlite_mobilenet_v2 has the highest classification accuracy. |
| 82 | + |
| 83 | +Returns a `model` object. |
| 84 | + |
| 85 | +#### Detecting the objects |
| 86 | + |
| 87 | +You can detect objects with the model without needing to create a Tensor. |
| 88 | +`model.detect` takes an input image element and returns an array of bounding boxes with class name and confidence level. |
| 89 | + |
| 90 | +This method exists on the model that is loaded from `objectDetection.load`. |
| 91 | + |
| 92 | +```ts |
| 93 | +model.detect( |
| 94 | + img: tf.Tensor3D | ImageData | HTMLImageElement | |
| 95 | + HTMLCanvasElement | HTMLVideoElement, maxDetectionSize: number |
| 96 | +) |
| 97 | +``` |
| 98 | + |
| 99 | +Args: |
| 100 | + |
| 101 | +- **img:** A Tensor or an image element to make a detection on. |
| 102 | +- **maxNumBoxes:** The maximum number of bounding boxes of detected objects. There can be multiple objects of the same class, but at different locations. Defaults to 20. |
| 103 | + |
| 104 | +Returns an array of classes and probabilities that looks like: |
| 105 | + |
| 106 | +```js |
| 107 | +[{ |
| 108 | + bbox: [x, y, width, height], |
| 109 | + class: "person", |
| 110 | + score: 0.8380282521247864 |
| 111 | +}, { |
| 112 | + bbox: [x, y, width, height], |
| 113 | + class: "kite", |
| 114 | + score: 0.74644153267145157 |
| 115 | +}] |
| 116 | +``` |
| 117 | + |
| 118 | +### Technical details for advance users |
| 119 | + |
| 120 | +This model is based on the TensorFlow object detection API, you can download the original models from [here](https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md#coco-trained-models). We applied following optimizations to improve the performance for browser execution: |
| 121 | + |
| 122 | + 1. Removed the post process graph from the original model. |
| 123 | + 2. Used single class NonMaxSuppression instead of original multiple classes NonMaxSuppression for faster speed with similar accuracy. |
| 124 | + 3. Executes NonMaxSuppression operations on CPU backend instead of WebGL to avoid delays on the texture downloads. |
| 125 | + |
| 126 | +Here is the converter command for removing the post process graph. |
| 127 | + |
| 128 | +```sh |
| 129 | +tensorflowjs_converter --input_format=tf_saved_model \ |
| 130 | + --output_node_names='Postprocessor/ExpandDims_1,Postprocessor/Slice' \ |
| 131 | + --saved_model_tags=serve \ |
| 132 | + ./saved_model \ |
| 133 | + ./web_model |
| 134 | +``` |
0 commit comments