From 004e436e85f57a4f86fad98a2def1779adbb516e Mon Sep 17 00:00:00 2001 From: wep21 Date: Wed, 7 Jun 2023 02:44:16 +0900 Subject: [PATCH] override predict method Signed-off-by: wep21 --- .../TransFusion/transfusion/transfusion.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/projects/TransFusion/transfusion/transfusion.py b/projects/TransFusion/transfusion/transfusion.py index b5efd606b1..ed3e429026 100644 --- a/projects/TransFusion/transfusion/transfusion.py +++ b/projects/TransFusion/transfusion/transfusion.py @@ -1,5 +1,10 @@ +from typing import Dict, List, Optional + +from torch import Tensor + from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector from mmdet3d.registry import MODELS +from mmdet3d.structures import Det3DDataSample @MODELS.register_module() @@ -23,3 +28,42 @@ def init_weights(self): if self.with_img_neck: for param in self.img_neck.parameters(): param.requires_grad = False + + def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]], + batch_data_samples: List[Det3DDataSample], + **kwargs) -> List[Det3DDataSample]: + """Forward of testing. + + Args: + batch_inputs_dict (dict): The model input dict which include + 'points' keys. + + - points (list[torch.Tensor]): Point cloud of each sample. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance_3d`. + + Returns: + list[:obj:`Det3DDataSample`]: Detection results of the + input sample. Each Det3DDataSample usually contain + 'pred_instances_3d'. And the ``pred_instances_3d`` usually + contains following keys. + + - scores_3d (Tensor): Classification scores, has a shape + (num_instances, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bbox_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes, + contains a tensor with shape (num_instances, 7). + """ + batch_input_metas = [item.metainfo for item in batch_data_samples] + img_feats, pts_feats = self.extract_feat(batch_inputs_dict, batch_input_metas) + + if pts_feats and self.with_pts_bbox: + outputs = self.pts_bbox_head.predict(pts_feats, batch_input_metas) + else: + outputs = None + + res = self.add_pred_to_datasample(batch_data_samples, outputs) + + return res \ No newline at end of file