Add object detection and image segmentation models to the task API (#705)

* Add more models

* update

* update

* update

* test

* test

* test

* test

* test

* test

* test

* test

* test

* test

* update

* update

* address comments
jinjingforever authored May 13, 2021
1 parent 59bd40b commit ab97e06
Showing 25 changed files with 7,353 additions and 106 deletions.
19 changes: 11 additions & 8 deletions coco-ssd/src/index.ts
@@ -24,6 +24,7 @@ const BASE_PATH = 'https://storage.googleapis.com/tfjs-models/savedmodel/';

export {version} from './version';

/** @docinline */
export type ObjectDetectionBaseModel =
'mobilenet_v1'|'mobilenet_v2'|'lite_mobilenet_v2';

@@ -35,17 +36,19 @@ export interface DetectedObject {

/**
* Coco-ssd model loading is configurable using the following config dictionary.
*
* `base`: ObjectDetectionBaseModel. It determines wich PoseNet architecture
* to load. The supported architectures are: 'mobilenet_v1', 'mobilenet_v2' and
* 'lite_mobilenet_v2'. It is default to 'lite_mobilenet_v2'.
*
* `modelUrl`: An optional string that specifies custom url of the model. This
* is useful for area/countries that don't have access to the model hosted on
* GCP.
*/
export interface ModelConfig {
/**
* It determines which object detection architecture to load. The supported
* architectures are: 'mobilenet_v1', 'mobilenet_v2' and 'lite_mobilenet_v2'.
* Defaults to 'lite_mobilenet_v2'.
*/
base?: ObjectDetectionBaseModel;
/**
*
* An optional string that specifies a custom URL for the model. This is useful
* for areas/countries that don't have access to the model hosted on GCP.
*/
modelUrl?: string;
}
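
For reference, a minimal sketch of passing this config to the package's `load()` function — the self-hosted URL below is a placeholder, and the snippet assumes an async context:

```js
import * as cocoSsd from '@tensorflow-models/coco-ssd';

// Omitting `base` loads the default 'lite_mobilenet_v2' variant.
const defaultModel = await cocoSsd.load();
const mobilenetV2 = await cocoSsd.load({base: 'mobilenet_v2'});

// Point `modelUrl` at a self-hosted copy of the model when the
// GCP-hosted weights are not reachable (placeholder URL).
const selfHosted = await cocoSsd.load({
  modelUrl: 'https://example.com/models/coco-ssd/model.json',
});
```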

24 changes: 12 additions & 12 deletions coco-ssd/yarn.lock
@@ -64,23 +64,23 @@
estree-walker "^1.0.1"
picomatch "^2.2.2"

"@tensorflow/tfjs-backend-cpu@^3.0.0-rc.1":
version "3.3.0"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-3.3.0.tgz#aa0a3ed2c6237a6e0c169678c5bd4b5a88766b1c"
integrity sha512-DLctv+PUZni26kQW1hq8jwQQ8u+GGc/p764WQIC4/IDagGtfGAUW1mHzWcTxtni2l4re1VrwE41ogWLhv4sGHg==
"@tensorflow/tfjs-backend-cpu@^3.3.0":
version "3.6.0"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-3.6.0.tgz#4e64a7cf1c33b203f71f8f77cd7b0ac1ef25a871"
integrity sha512-ZpAs17hPdKXadbtNjAsymYUILe8V7+pY4fYo8j25nfDTW/HfBpyAwsHPbMcA/n5zyJ7ZJtGKFcCUv1sl24KL1Q==
dependencies:
"@types/seedrandom" "2.4.27"
seedrandom "2.4.3"

"@tensorflow/tfjs-converter@^3.0.0-rc.1":
version "3.3.0"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-3.3.0.tgz#d9f2ffd0fbdbb47c07d5fd7c3e5dc180cff317aa"
integrity sha512-k57wN4yelePhmO9orcT/wzGMIuyedrMpVtg0FhxpV6BQu0+TZ/ti3W4Kb97GWJsoHKXMoing9SnioKfVnBW6hw==
"@tensorflow/tfjs-converter@^3.3.0":
version "3.6.0"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-3.6.0.tgz#32b3ff31b47e29630a82e30fbe01708facad7fd6"
integrity sha512-9MtatbTSvo3gpEulYI6+byTA3OeXSMT2lzyGAegXO9nMxsvjR01zBvlZ5SmsNyecNh6fMSzdL2+cCdQfQtsIBg==

"@tensorflow/tfjs-core@^3.0.0-rc.1":
version "3.3.0"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-3.3.0.tgz#3d26bd03cb58e0ecf46c96d118c39c4a90b7f5ed"
integrity sha512-6G+LcCiQBl4Kza5mDbWbf8QSWBTW3l7SDjGhQzMO1ITtQatHzxkuHGHcJ4CTUJvNA0JmKf4QJWOvlFqEmxwyLQ==
"@tensorflow/tfjs-core@^3.3.0":
version "3.6.0"
resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-3.6.0.tgz#6b4d8175790bdff78868eabe6adc6442eb4dc276"
integrity sha512-bb2c3zwK4SgXZRvkTiC7EhCpWbCGp0GMd+1/3Vo2/Z54jiLB/h3sXIgHQrTNiWwhKPtst/xxA+MsslFlvD0A5w==
dependencies:
"@types/offscreencanvas" "~2019.3.0"
"@types/seedrandom" "2.4.27"
6,063 changes: 6,063 additions & 0 deletions deeplab/demo/yarn.lock

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion deeplab/src/index.ts
@@ -22,7 +22,14 @@ import {DeepLabInput, DeepLabOutput, ModelArchitecture, ModelConfig, PredictionC
import {getColormap, getLabels, getURL, toInputTensor, toSegmentationImage} from './utils';

export {version} from './version';
export {getColormap, getLabels, getURL, toSegmentationImage};
export {
getColormap,
getLabels,
getURL,
ModelConfig,
PredictionConfig,
toSegmentationImage
};

/**
* Initializes the DeepLab model and returns a `SemanticSegmentation` object.
75 changes: 35 additions & 40 deletions deeplab/src/types.ts
@@ -34,67 +34,62 @@ export interface Legend {
[name: string]: Color;
}

/*
The model supports quantization to 1 and 2 bytes, leaving 4 for the
non-quantized variant.
*/
/**
* The model supports quantization to 1 and 2 bytes, leaving 4 for the
* non-quantized variant.
*
* @docinline
*/
export type QuantizationBytes = 1|2|4;
/*
Three types of pre-trained weights are available, trained on Pascal, Cityscapes
and ADE20K datasets. Each dataset has its own colormap and labelling scheme.
*/
/**
* Three types of pre-trained weights are available, trained on Pascal,
* Cityscapes and ADE20K datasets. Each dataset has its own colormap and
* labelling scheme.
*
* @docinline
*/
export type ModelArchitecture = 'pascal'|'cityscapes'|'ade20k';

export type DeepLabInput =
|ImageData|HTMLImageElement|HTMLCanvasElement|HTMLVideoElement|tf.Tensor3D;

/*
* The model can be configured with any of the following attributes:
*
* * quantizationBytes (optional) :: `QuantizationBytes`
*
* The degree to which weights are quantized (either 1, 2 or 4).
* Setting this attribute to 1 or 2 will load the model with int32 and
* float32 compressed to 1 or 2 bytes respectively.
* Set it to 4 to disable quantization.
*
* * base (optional) :: `ModelArchitecture`
*
* The type of model to load (either `pascal`, `cityscapes` or `ade20k`).
*
* * modelUrl (optional) :: `string`
*
* The URL from which to load the TF.js GraphModel JSON.
* Inferred from `base` and `quantizationBytes` if undefined.
*/
export interface ModelConfig {
/**
* The degree to which weights are quantized (either 1, 2 or 4).
* Setting this attribute to 1 or 2 will load the model with int32 and
* float32 compressed to 1 or 2 bytes respectively.
* Set it to 4 to disable quantization.
*/
quantizationBytes?: QuantizationBytes;
/**
* The type of model to load (either `pascal`, `cityscapes` or `ade20k`).
*/
base?: ModelArchitecture;
/**
*
* The URL from which to load the TF.js GraphModel JSON.
* Inferred from `base` and `quantizationBytes` if undefined.
*/
modelUrl?: string;
}
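
A minimal sketch of loading the model with this config, assuming the package's top-level `load()` accepts a `ModelConfig` and an async context:

```js
import * as deeplab from '@tensorflow-models/deeplab';

// Pascal weights quantized to 2 bytes; with `modelUrl` left undefined,
// the URL is inferred from `base` and `quantizationBytes`.
const model = await deeplab.load({base: 'pascal', quantizationBytes: 2});
```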

/*
*
* Segmentation can be fine-tuned with three parameters:
*
* - **canvas** (optional) :: `HTMLCanvasElement`
*
* The canvas where to draw the output
*
* - **colormap** (optional) :: `[number, number, number][]`
*
* The array of RGB colors corresponding to labels
*
* - **labels** (optional) :: `string[]`
*
* The array of names corresponding to labels
*
* By [default](./src/index.ts#L81), `colormap` and `labels` are set
* according to the `base` model attribute passed during initialization.
*/
export interface PredictionConfig {
/** The canvas on which to draw the output. */
canvas?: HTMLCanvasElement;
/** The array of RGB colors corresponding to labels. */
colormap?: Color[];
/**
* The array of names corresponding to labels.
*
* By [default](./src/index.ts#L81), `colormap` and `labels` are set
* according to the `base` model attribute passed during initialization.
*/
labels?: string[];
}
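
And a sketch of fine-tuning a segmentation call with this config, assuming `segment()` accepts a `PredictionConfig` as its second argument and that `model` was loaded as above:

```js
const image = document.querySelector('img');
const canvas = document.querySelector('canvas');

// `colormap` and `labels` are omitted, so they fall back to the scheme
// implied by the `base` model chosen at load time.
const result = await model.segment(image, {canvas});
console.log(result.legend);  // label name -> color mapping
```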

127 changes: 124 additions & 3 deletions tasks/README.md
@@ -9,7 +9,7 @@ for JS developers without ML knowledge. It has the following features:
- **Easy-to-discover models**

Models from different runtime systems (e.g. [TFJS][tfjs], [TFLite][tflite],
[MediaPipe][mediapipe], etc) are grouped by popular ML tasks, such as.
[MediaPipe][mediapipe], etc) are grouped by popular ML tasks, such as
sentiment detection, image classification, pose detection, etc.

- **Clean and powerful APIs**
@@ -28,7 +28,128 @@

The following table summarizes all the supported tasks and their models:

(TODO)
<table>
<thead>
<tr>
<th>Task</th>
<th>Model</th>
<th>Supported runtimes · Docs · Resources</th>
</tr>
</thead>
<tbody>
<!-- Image classification -->
<tr>
<td rowspan="2">
<b>Image Classification</b>
<br>
Classify images into predefined classes.
<br>
<a href="https://codepen.io/jinjingforever/pen/VwPOePq">Demo</a>
</td>
<td>Mobilenet</td>
<td>
<div>
<span><code>TFJS  </code></span>
<span>·</span>
<a href="#">API doc</a>
</div>
<div>
<span><code>TFLite</code></span>
<span>·</span>
<a href="#">API doc</a>
</div>
</td>
</tr>
<tr>
<td>Custom model</td>
<td>
<div>
<span><code>TFLite</code></span>
<span>·</span>
<a href="#">API doc</a>
<span>·</span>
<a href="https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_classifier#model_compatibility_requirements">Model requirements</a>
<span>·</span>
<a href="https://tfhub.dev/tensorflow/collections/lite/task-library/image-classifier/1">Model collection</a>
</div>
</td>
</tr>
<!-- Object detection -->
<tr>
<td rowspan="2">
<b>Object Detection</b>
<br>
Localize and identify multiple objects in a single image.
<br>
<a href="https://codepen.io/jinjingforever/pen/PopPPXo">Demo</a>
</td>
<td>COCO-SSD</td>
<td>
<div>
<span><code>TFJS  </code></span>
<span>·</span>
<a href="#">API doc</a>
</div>
<div>
<span><code>TFLite</code></span>
<span>·</span>
<a href="#">API doc</a>
</div>
</td>
</tr>
<tr>
<td>Custom model</td>
<td>
<div>
<span><code>TFLite</code></span>
<span>·</span>
<a href="#">API doc</a>
<span>·</span>
<a href="https://www.tensorflow.org/lite/inference_with_metadata/task_library/object_detector#model_compatibility_requirements">Model requirements</a>
<span>·</span>
<a href="https://tfhub.dev/tensorflow/collections/lite/task-library/object-detector/1">Model collection</a>
</div>
</td>
</tr>
<!-- Image Segmentation -->
<tr>
<td rowspan="2">
<b>Image Segmentation</b>
<br>
Predict the associated class for each pixel of an image.
<br>
<a href="https://codepen.io/jinjingforever/pen/yLMYVJw">Demo</a>
</td>
<td>Deeplab</td>
<td>
<div>
<span><code>TFJS  </code></span>
<span>·</span>
<a href="#">API doc</a>
</div>
<div>
<span><code>TFLite</code></span>
<span>·</span>
<a href="#">API doc</a>
</div>
</td>
</tr>
<tr>
<td>Custom model</td>
<td>
<div>
<span><code>TFLite</code></span>
<span>·</span>
<a href="#">API doc</a>
<span>·</span>
<a href="https://www.tensorflow.org/lite/inference_with_metadata/task_library/image_segmenter#model_compatibility_requirements">Model requirements</a>
<span>·</span>
<a href="https://tfhub.dev/tensorflow/collections/lite/task-library/image-segmenter/1">Model collection</a>
</div>
</td>
</tr>
</tbody>
</table>
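
As a quick illustration of the new rows, a hedged sketch of loading the object detection and image segmentation models through the task API (loader names follow the model index added in this change; `tfTask` is the package import used in the examples below):

```js
// Object detection with the TFJS COCO-SSD backend.
const detector = await tfTask.ObjectDetection.CocoSsd.TFJS.load();
const objects = await detector.predict(document.querySelector('img'));

// Image segmentation with the TFLite DeepLab backend.
const segmenter = await tfTask.ImageSegmentation.Deeplab.TFLite.load();
const segmentation = await segmenter.predict(document.querySelector('img'));
```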

(The initial version only supports the web browser environment. NodeJS support is
coming soon)
@@ -78,7 +199,7 @@ const model3 = await tfTask.ImageClassification.CustomModel.TFLite.load({
Since all these models are for the `Image Classification` task, they will have
the same task model type: [`ImageClassifier`][image classifier interface] in
this case. Each task model's `predict` inference method has a unique and
easy-to-use API interface. For example, in `ImageClassiier`, the method takes an
easy-to-use API interface. For example, in `ImageClassifier`, the method takes an
image-like element and returns the predicted classes:
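
A minimal sketch of such a call, reusing `model3` from above (result fields vary by task, so the output is just logged here):

```js
const img = document.querySelector('img');

// Any of the ImageClassifier instances loaded above works the same way.
const result = await model3.predict(img);
console.log(result);  // the predicted classes for `img`
```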

4 changes: 3 additions & 1 deletion tasks/package.json
@@ -27,7 +27,9 @@
"@tensorflow/tfjs-converter": "^3.5.0",
"@tensorflow/tfjs-core": "^3.5.0",
"@tensorflow/tfjs-tflite": "0.0.1-alpha.3",
"@tensorflow-models/mobilenet": "^2.1.0",
"@tensorflow-models/mobilenet": "link:../mobilenet",
"@tensorflow-models/coco-ssd": "link:../coco-ssd",
"@tensorflow-models/deeplab": "link:../deeplab",
"@types/jasmine": "~3.6.9",
"clang-format": "~1.5.0",
"jasmine": "~3.7.0",
26 changes: 26 additions & 0 deletions tasks/src/tasks/all_tasks.ts
@@ -20,6 +20,12 @@ import {Runtime, Task} from './common';
import {imageClassificationCustomModelTfliteLoader} from './image_classification/custom_model_tflite';
import {mobilenetTfjsLoader} from './image_classification/mobilenet_tfjs';
import {mobilenetTfliteLoader} from './image_classification/mobilenet_tflite';
import {imageSegmenterCustomModelTfliteLoader} from './image_segmentation/custom_model_tflite';
import {deeplabTfjsLoader} from './image_segmentation/deeplab_tfjs';
import {deeplabTfliteLoader} from './image_segmentation/deeplab_tflite';
import {cocoSsdTfjsLoader} from './object_detection/cocossd_tfjs';
import {cocoSsdTfliteLoader} from './object_detection/cocossd_tflite';
import {objectDetectorCustomModelTfliteLoader} from './object_detection/custom_model_tflite';

/**
* The main model index.
@@ -48,11 +54,31 @@ const modelIndex = {
[Runtime.TFLITE]: imageClassificationCustomModelTfliteLoader,
},
},
[Task.OBJECT_DETECTION]: {
CocoSsd: {
[Runtime.TFJS]: cocoSsdTfjsLoader,
[Runtime.TFLITE]: cocoSsdTfliteLoader,
},
CustomModel: {
[Runtime.TFLITE]: objectDetectorCustomModelTfliteLoader,
},
},
[Task.IMAGE_SEGMENTATION]: {
Deeplab: {
[Runtime.TFJS]: deeplabTfjsLoader,
[Runtime.TFLITE]: deeplabTfliteLoader,
},
CustomModel: {
[Runtime.TFLITE]: imageSegmenterCustomModelTfliteLoader,
},
},
};

// Export each task individually.

export const ImageClassification = modelIndex[Task.IMAGE_CLASSIFICATION];
export const ObjectDetection = modelIndex[Task.OBJECT_DETECTION];
export const ImageSegmentation = modelIndex[Task.IMAGE_SEGMENTATION];

/**
* Filter model loaders by runtimes.