Skip to content

Commit

Permalink
[pose-detection]Use f16 model. (#708)
Browse files Browse the repository at this point in the history
FEATURE
  • Loading branch information
lina128 authored May 13, 2021
1 parent fde8eab commit 59bd40b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pose-detection/src/blazepose_tfjs/blazepose_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {expectArraysClose} from '@tensorflow/tfjs-core/dist/test_util';
import * as poseDetection from '../index';
import {getXYPerFrame, KARMA_SERVER, loadImage, loadVideo} from '../test_util';

const EPSILON_IMAGE = 18;
const EPSILON_IMAGE = 19;
const EPSILON_VIDEO = 15;

// ref:
Expand Down
8 changes: 4 additions & 4 deletions pose-detection/src/blazepose_tfjs/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import {BlazePoseTfjsModelConfig} from './types';

export const DEFAULT_BLAZEPOSE_DETECTOR_MODEL_URL =
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/detector/heatmap/model.json';
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/detector/f16/model.json';
export const DEFAULT_BLAZEPOSE_LANDMARK_MODEL_URL_FULL =
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/full/model.json';
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/full-f16/model.json';
export const DEFAULT_BLAZEPOSE_LANDMARK_MODEL_URL_LITE =
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/lite/model.json';
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/lite-f16/model.json';
export const DEFAULT_BLAZEPOSE_LANDMARK_MODEL_URL_HEAVY =
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/heavy/model.json';
'https://storage.googleapis.com/tfjs-models/savedmodel/blazepose/landmark/heavy-f16/model.json';
export const BLAZEPOSE_DETECTOR_ANCHOR_CONFIGURATION = {
reduceBoxesInLowestlayer: false,
interpolatedScaleAspectRatio: 1.0,
Expand Down
26 changes: 13 additions & 13 deletions pose-detection/src/blazepose_tfjs/detector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -324,18 +324,18 @@ export class BlazePoseTfjsDetector extends BasePoseDetector {
// Output[3]: This tensor (shape: [1, 195]) represents 39 5-d keypoints.
// The first 33 refer to the keypoints. The final 6 key points refer to
// the alignment points from the detector model and the hands.)
// Output [0]: This tensor (shape: [1, 1]) represents the confidence
// Output [4]: This tensor (shape: [1, 1]) represents the confidence
// score.
// Output [2]: This tensor (shape: [1, 64, 64, 39]) represents heatmap for
// Output [1]: This tensor (shape: [1, 64, 64, 39]) represents heatmap for
// the 39 landmarks.
// Lite model:
// Output[1]: This tensor (shape: [1, 195]) represents 39 5-d keypoints.
// Output[2]: This tensor (shape: [1, 1]) represents the confidence score.
// Output[4]: This tensor (shape: [1, 64, 64, 39]) represents heatmap for
// Output[4]: This tensor (shape: [1, 195]) represents 39 5-d keypoints.
// Output[3]: This tensor (shape: [1, 1]) represents the confidence score.
// Output[1]: This tensor (shape: [1, 64, 64, 39]) represents heatmap for
// the 39 landmarks.
// Heavy model:
// Output[3]: This tensor (shape: [1, 195]) represents 39 5-d keypoints.
// Output[2]: This tensor (shape: [1, 1]) represents the confidence score.
// Output[1]: This tensor (shape: [1, 1]) represents the confidence score.
// Output[4]: This tensor (shape: [1, 64, 64, 39]) represents heatmap for
// the 39 landmarks.
const landmarkResult =
Expand All @@ -345,18 +345,18 @@ export class BlazePoseTfjsDetector extends BasePoseDetector {

switch (this.modelType) {
case 'lite':
landmarkTensor = landmarkResult[1] as tf.Tensor2D;
poseFlagTensor = landmarkResult[2] as tf.Tensor2D;
heatmapTensor = landmarkResult[4] as tf.Tensor4D;
landmarkTensor = landmarkResult[3] as tf.Tensor2D;
poseFlagTensor = landmarkResult[4] as tf.Tensor2D;
heatmapTensor = landmarkResult[1] as tf.Tensor4D;
break;
case 'full':
landmarkTensor = landmarkResult[3] as tf.Tensor2D;
poseFlagTensor = landmarkResult[0] as tf.Tensor2D;
heatmapTensor = landmarkResult[2] as tf.Tensor4D;
landmarkTensor = landmarkResult[4] as tf.Tensor2D;
poseFlagTensor = landmarkResult[3] as tf.Tensor2D;
heatmapTensor = landmarkResult[1] as tf.Tensor4D;
break;
case 'heavy':
landmarkTensor = landmarkResult[3] as tf.Tensor2D;
poseFlagTensor = landmarkResult[2] as tf.Tensor2D;
poseFlagTensor = landmarkResult[1] as tf.Tensor2D;
heatmapTensor = landmarkResult[4] as tf.Tensor4D;
break;
default:
Expand Down

0 comments on commit 59bd40b

Please sign in to comment.