-
Notifications
You must be signed in to change notification settings - Fork 643
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a3c2fae
commit 50e9014
Showing
7 changed files
with
87 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,7 @@ | |
0: 'batch' | ||
} | ||
}) | ||
|
||
codebase_config = dict( | ||
export_postprocess=False # do not export get_simcc_maximum | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
from .post_processing import get_simcc_maximum | ||
|
||
__all__ = ['get_simcc_maximum'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import torch | ||
|
||
|
||
def get_simcc_maximum(simcc_x: torch.Tensor, | ||
simcc_y: torch.Tensor) -> torch.Tensor: | ||
"""Get maximum response location and value from simcc representations. | ||
rewrite to support `torch.Tensor` input type. | ||
Args: | ||
simcc_x (torch.Tensor): x-axis SimCC in shape (N, K, Wx) | ||
simcc_y (torch.Tensor): y-axis SimCC in shape (N, K, Wy) | ||
Returns: | ||
tuple: | ||
- locs (torch.Tensor): locations of maximum heatmap responses in shape | ||
(N, K, 2) | ||
- vals (torch.Tensor): values of maximum heatmap responses in shape | ||
(N, K) | ||
""" | ||
N, K, _ = simcc_x.shape | ||
simcc_x = simcc_x.flatten(0, 1) | ||
simcc_y = simcc_y.flatten(0, 1) | ||
x_locs = simcc_x.argmax(dim=1, keepdim=True) | ||
y_locs = simcc_y.argmax(dim=1, keepdim=True) | ||
locs = torch.cat((x_locs, y_locs), dim=1).to(torch.float32) | ||
max_val_x, _ = simcc_x.max(dim=1, keepdim=True) | ||
max_val_y, _ = simcc_y.max(dim=1, keepdim=True) | ||
vals, _ = torch.cat([max_val_x, max_val_y], dim=1).min(dim=1) | ||
locs = locs.reshape(N, K, 2) | ||
vals = vals.reshape(N, K) | ||
return locs, vals | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from . import mspn_head, yolox_pose_head # noqa: F401,F403 | ||
from . import mspn_head, simcc_head, yolox_pose_head # noqa: F401,F403 | ||
|
||
__all__ = ['mspn_head', 'yolox_pose_head'] | ||
__all__ = ['mspn_head', 'yolox_pose_head', 'simcc_head'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmdeploy.codebase.mmpose.codecs import get_simcc_maximum | ||
from mmdeploy.core import FUNCTION_REWRITER | ||
from mmdeploy.utils import get_codebase_config | ||
|
||
|
||
@FUNCTION_REWRITER.register_rewriter('mmpose.models.heads.RTMCCHead.forward') | ||
@FUNCTION_REWRITER.register_rewriter('mmpose.models.heads.SimCCHead.forward') | ||
def simcc_head__forward(self, feats): | ||
"""Rewrite `forward` of SimCCHead for default backend. | ||
Args: | ||
feats (tuple[Tensor]): Input features. | ||
Returns: | ||
key-points (torch.Tensor): Output keypoints in | ||
shape of (N, K, 3) | ||
""" | ||
ctx = FUNCTION_REWRITER.get_context() | ||
simcc_x, simcc_y = ctx.origin_func(self, feats) | ||
codebase_cfg = get_codebase_config(ctx.cfg) | ||
export_postprocess = codebase_cfg.get('export_postprocess', False) | ||
if not export_postprocess: | ||
return simcc_x, simcc_y | ||
assert self.decoder.use_dark is False, \ | ||
'Do not support SimCCLabel with use_dark=True' | ||
pts, scores = get_simcc_maximum(simcc_x, simcc_y) | ||
pts /= self.decoder.simcc_split_ratio | ||
return pts, scores | ||