diff --git a/mmdeploy/mmcv/ops/nms.py b/mmdeploy/mmcv/ops/nms.py
index dc22c55380..6e71b7cc22 100644
--- a/mmdeploy/mmcv/ops/nms.py
+++ b/mmdeploy/mmcv/ops/nms.py
@@ -610,7 +610,6 @@ def multiclass_nms__torchscript(boxes: Tensor,
 
     Use batched_nms from torchvision instead of custom nms.
     """
-    assert not output_index, 'output_index is not supported on this backend.'
     # TODO: simplify inference for non-batch model
     from torchvision.ops import batched_nms
     batch_size = scores.shape[0]
@@ -618,11 +617,12 @@ def multiclass_nms__torchscript(boxes: Tensor,
     num_classes = scores.shape[2]
     box_per_cls = len(boxes.shape) == 4
     scores = torch.where(scores > score_threshold, scores, scores.new_zeros(1))
-
+    pre_topk_inds = None  # pre-topk indices, used to recover input box ids
     # pre-topk
     if pre_top_k > 0:
         max_scores, _ = scores.max(-1)
         _, topk_inds = max_scores.topk(pre_top_k)
+        pre_topk_inds = topk_inds
         batch_inds = torch.arange(batch_size).view(-1, 1).long()
         boxes = boxes[batch_inds, topk_inds, ...]
         scores = scores[batch_inds, topk_inds, :]
@@ -646,10 +646,14 @@ def multiclass_nms__torchscript(boxes: Tensor,
 
     keeps = torch.cat(keeps)
     scores = scores.permute(0, 2, 1)
-    dets, labels = _select_nms_index(
-        scores, boxes, keeps, batch_size, keep_top_k=keep_top_k)
-
-    return dets, labels
+    return _select_nms_index(
+        scores,
+        boxes,
+        keeps,
+        batch_size,
+        keep_top_k=keep_top_k,
+        pre_inds=pre_topk_inds,
+        output_index=output_index)
 
 
 class AscendBatchNMSOp(torch.autograd.Function):
diff --git a/mmdeploy/pytorch/functions/multi_head_attention_forward.py b/mmdeploy/pytorch/functions/multi_head_attention_forward.py
index 50f3f6c1ea..97d9ad327a 100644
--- a/mmdeploy/pytorch/functions/multi_head_attention_forward.py
+++ b/mmdeploy/pytorch/functions/multi_head_attention_forward.py
@@ -53,3 +53,34 @@ def _scaled_dot_product_attention__tensorrt(q: Tensor,
                                             **kwargs) -> Tuple[Tensor, Tensor]:
     """Rewrite for custom ops."""
     return ScaledDotProductAttentionTRT.apply(q, k, v, attn_mask)
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    func_name='torch.nn.functional.scaled_dot_product_attention',
+    backend=Backend.DEFAULT.value)
+def scaled_dot_product_attention__default(query,
+                                          key,
+                                          value,
+                                          attn_mask=None,
+                                          dropout_p=0.,
+                                          scale=None,
+                                          is_causal=False):
+    """Rewrite to export to ONNX on torch>=2.0.0."""
+    # `scale` multiplies the attention logits; default is 1/sqrt(d_k).
+    scale = query.size(-1)**-0.5 if scale is None else scale
+    if is_causal and attn_mask is None:
+        # Build a lower-triangular causal mask when none is supplied.
+        attn_mask = torch.ones(
+            query.size(-2), key.size(-2), dtype=torch.bool,
+            device=query.device).tril(diagonal=0)
+    if attn_mask is not None and attn_mask.dtype == torch.bool:
+        # Convert a boolean mask into an additive float mask.
+        float_mask = torch.zeros_like(attn_mask, dtype=query.dtype)
+        attn_mask = float_mask.masked_fill(~attn_mask, -float('inf'))
+
+    attn_weight = query @ key.transpose(-2, -1) * scale
+    if attn_mask is not None:
+        attn_weight += attn_mask
+    attn_weight = torch.softmax(attn_weight, dim=-1)
+    attn_weight = torch.dropout(attn_weight, dropout_p, True)
+    return attn_weight @ value