From aeac6b34318fe7aedf5515c72aefa8a7d9937d5b Mon Sep 17 00:00:00 2001 From: huzhenhong <455879568@qq.com> Date: Wed, 29 Nov 2023 18:20:42 +0800 Subject: [PATCH 1/5] Support macOS onnxruntime --- mmdeploy/backend/onnxruntime/init_plugins.py | 2 ++ setup.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mmdeploy/backend/onnxruntime/init_plugins.py b/mmdeploy/backend/onnxruntime/init_plugins.py index fd0d850fe5..3897be22a9 100644 --- a/mmdeploy/backend/onnxruntime/init_plugins.py +++ b/mmdeploy/backend/onnxruntime/init_plugins.py @@ -13,6 +13,7 @@ def get_ops_path() -> str: candidates = [ '../../lib/libmmdeploy_onnxruntime_ops.so', '../../lib/mmdeploy_onnxruntime_ops.dll', + '../../lib/libmmdeploy_onnxruntime_ops.dylib', ] return get_file_path(os.path.dirname(__file__), candidates) @@ -26,5 +27,6 @@ def get_lib_path() -> str: candidates = [ '../../lib/libonnxruntime.so*', '../../lib/onnxruntime.dll', + '../../lib/libonnxruntime*.dylib', ] return get_file_path(os.path.dirname(__file__), candidates) diff --git a/setup.py b/setup.py index ddd853648d..aa71fd2898 100644 --- a/setup.py +++ b/setup.py @@ -138,7 +138,9 @@ def get_extensions(): # environment, the compiler will choose the appropriate compiler # to compile those cpp files, so there is no need to add the # argument - if platform.system() != 'Windows': + if platform.system() == 'Darwin': + extra_compile_args['cxx'] = ['-std=c++17'] + elif platform.system() != 'Windows': extra_compile_args['cxx'] = ['-std=c++14'] include_dirs = [] From 6a6314aeb2e4b42e3de61a886818bc0f799afd85 Mon Sep 17 00:00:00 2001 From: huzhenhong <455879568@qq.com> Date: Mon, 11 Dec 2023 21:02:23 +0800 Subject: [PATCH 2/5] Format cpp and cuda code --- .clang-format | 403 +- csrc/mmdeploy/apis/c/mmdeploy/classifier.cpp | 196 +- csrc/mmdeploy/apis/c/mmdeploy/classifier.h | 233 +- csrc/mmdeploy/apis/c/mmdeploy/common.cpp | 205 +- csrc/mmdeploy/apis/c/mmdeploy/common.h | 247 +- .../apis/c/mmdeploy/common_internal.h | 215 +- csrc/mmdeploy/apis/c/mmdeploy/detector.cpp | 208 +- csrc/mmdeploy/apis/c/mmdeploy/detector.h | 231 +- csrc/mmdeploy/apis/c/mmdeploy/executor.cpp | 350 +- csrc/mmdeploy/apis/c/mmdeploy/executor.h | 208 +- .../apis/c/mmdeploy/executor_internal.h | 66 +- csrc/mmdeploy/apis/c/mmdeploy/handle.h | 83 +- csrc/mmdeploy/apis/c/mmdeploy/model.cpp | 61 +- csrc/mmdeploy/apis/c/mmdeploy/model.h | 53 +- csrc/mmdeploy/apis/c/mmdeploy/pipeline.cpp | 137 +- csrc/mmdeploy/apis/c/mmdeploy/pipeline.h | 92 +- .../apis/c/mmdeploy/pose_detector.cpp | 301 +- csrc/mmdeploy/apis/c/mmdeploy/pose_detector.h | 208 +- .../mmdeploy/apis/c/mmdeploy/pose_tracker.cpp | 272 +- csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h | 273 +- csrc/mmdeploy/apis/c/mmdeploy/restorer.cpp | 177 +- csrc/mmdeploy/apis/c/mmdeploy/restorer.h | 132 +- .../apis/c/mmdeploy/rotated_detector.cpp | 200 +- .../apis/c/mmdeploy/rotated_detector.h | 235 +- csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp | 189 +- csrc/mmdeploy/apis/c/mmdeploy/segmentor.h | 165 +- .../apis/c/mmdeploy/text_detector.cpp | 281 +- csrc/mmdeploy/apis/c/mmdeploy/text_detector.h | 272 +- .../apis/c/mmdeploy/text_recognizer.cpp | 359 +- .../apis/c/mmdeploy/text_recognizer.h | 288 +- .../apis/c/mmdeploy/video_recognizer.cpp | 262 +- .../apis/c/mmdeploy/video_recognizer.h | 237 +- .../mmdeploy/apis/cxx/mmdeploy/classifier.hpp | 136 +- csrc/mmdeploy/apis/cxx/mmdeploy/common.hpp | 603 +- csrc/mmdeploy/apis/cxx/mmdeploy/detector.hpp | 136 +- csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp | 127 +-
.../apis/cxx/mmdeploy/pose_detector.hpp | 155 +- .../apis/cxx/mmdeploy/pose_tracker.hpp | 304 +- csrc/mmdeploy/apis/cxx/mmdeploy/restorer.hpp | 127 +- .../apis/cxx/mmdeploy/rotated_detector.hpp | 138 +- csrc/mmdeploy/apis/cxx/mmdeploy/segmentor.hpp | 133 +- .../apis/cxx/mmdeploy/text_detector.hpp | 138 +- .../apis/cxx/mmdeploy/text_recognizer.hpp | 161 +- .../apis/cxx/mmdeploy/video_recognizer.hpp | 170 +- csrc/mmdeploy/apis/java/native/common.h | 81 +- .../apis/java/native/mmdeploy_Classifier.cpp | 46 +- .../apis/java/native/mmdeploy_Classifier.h | 50 +- .../apis/java/native/mmdeploy_Context.cpp | 63 +- .../apis/java/native/mmdeploy_Context.h | 49 +- .../apis/java/native/mmdeploy_Detector.cpp | 44 +- .../apis/java/native/mmdeploy_Detector.h | 50 +- .../apis/java/native/mmdeploy_Device.cpp | 29 +- .../apis/java/native/mmdeploy_Device.h | 37 +- .../apis/java/native/mmdeploy_Model.cpp | 29 +- .../apis/java/native/mmdeploy_Model.h | 37 +- .../java/native/mmdeploy_PoseDetector.cpp | 45 +- .../apis/java/native/mmdeploy_PoseDetector.h | 51 +- .../apis/java/native/mmdeploy_PoseTracker.cpp | 261 +- .../apis/java/native/mmdeploy_PoseTracker.h | 86 +- .../apis/java/native/mmdeploy_Profiler.cpp | 29 +- .../apis/java/native/mmdeploy_Profiler.h | 37 +- .../apis/java/native/mmdeploy_Restorer.cpp | 44 +- .../apis/java/native/mmdeploy_Restorer.h | 49 +- .../java/native/mmdeploy_RotatedDetector.cpp | 45 +- .../java/native/mmdeploy_RotatedDetector.h | 51 +- .../apis/java/native/mmdeploy_Scheduler.cpp | 21 +- .../apis/java/native/mmdeploy_Scheduler.h | 49 +- .../apis/java/native/mmdeploy_Segmentor.cpp | 44 +- .../apis/java/native/mmdeploy_Segmentor.h | 50 +- .../java/native/mmdeploy_TextDetector.cpp | 45 +- .../apis/java/native/mmdeploy_TextDetector.h | 51 +- .../java/native/mmdeploy_TextRecognizer.cpp | 56 +- .../java/native/mmdeploy_TextRecognizer.h | 65 +- csrc/mmdeploy/apis/python/classifier.cpp | 122 +- csrc/mmdeploy/apis/python/common.cpp | 334 +- csrc/mmdeploy/apis/python/common.h | 27 +- csrc/mmdeploy/apis/python/detector.cpp | 161 +- csrc/mmdeploy/apis/python/executor.cpp | 69 +- csrc/mmdeploy/apis/python/internal.cpp | 99 +- csrc/mmdeploy/apis/python/pipeline.cpp | 60 +- csrc/mmdeploy/apis/python/pose_detector.cpp | 235 +- csrc/mmdeploy/apis/python/pose_tracker.cpp | 314 +- csrc/mmdeploy/apis/python/restorer.cpp | 120 +- .../mmdeploy/apis/python/rotated_detector.cpp | 141 +- csrc/mmdeploy/apis/python/segmentor.cpp | 145 +- csrc/mmdeploy/apis/python/text_detector.cpp | 131 +- csrc/mmdeploy/apis/python/text_recognizer.cpp | 162 +- .../mmdeploy/apis/python/video_recognizer.cpp | 165 +- csrc/mmdeploy/archive/json_archive.h | 434 +- csrc/mmdeploy/archive/value_archive.h | 288 +- .../common_cuda_helper.cuh | 132 +- .../modulated_deform_conv_cpu.h | 129 +- .../modulated_deform_conv_cuda.cuh | 218 +- .../backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp | 4300 ++++---- .../backend_ops/ncnn/onnx2ncnn/fuse_pass.h | 148 +- .../backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp | 6078 ++++++----- .../ncnn/onnx2ncnn/shape_inference.cpp | 300 +- .../ncnn/onnx2ncnn/shape_inference.h | 5 +- .../backend_ops/ncnn/onnx2ncnn/utils.h | 761 +- .../ops/constantofshape/constantofshape.cpp | 95 +- .../ops/constantofshape/constantofshape.h | 21 +- .../backend_ops/ncnn/ops/expand/expand.cpp | 742 +- .../backend_ops/ncnn/ops/expand/expand.h | 15 +- .../backend_ops/ncnn/ops/gather/gather.cpp | 301 +- .../backend_ops/ncnn/ops/gather/gather.h | 21 +- .../backend_ops/ncnn/ops/ncnn_ops_definer.h | 30 +- .../ncnn/ops/ncnn_ops_register.cpp | 48 +- 
.../backend_ops/ncnn/ops/ncnn_ops_register.h | 2 +- .../backend_ops/ncnn/ops/shape/shape.cpp | 85 +- .../backend_ops/ncnn/ops/shape/shape.h | 15 +- .../ncnn/ops/tensorslice/tensorslice.cpp | 418 +- .../ncnn/ops/tensorslice/tensorslice.h | 27 +- .../backend_ops/ncnn/ops/topk/topk.cpp | 1914 ++-- .../mmdeploy/backend_ops/ncnn/ops/topk/topk.h | 27 +- .../backend_ops/ncnn/pyncnn_ext/ncnn_ext.cpp | 12 +- .../onnxruntime/common/onnxruntime_register.h | 7 +- .../onnxruntime/common/ort_utils.cpp | 12 +- .../onnxruntime/common/ort_utils.h | 59 +- .../onnxruntime/grid_sample/grid_sample.cpp | 592 +- .../onnxruntime/grid_sample/grid_sample.h | 90 +- .../modulated_deform_conv.cpp | 389 +- .../modulated_deform_conv.h | 119 +- .../onnxruntime/nms_match/nms_match.cpp | 233 +- .../onnxruntime/nms_match/nms_match.h | 80 +- .../onnxruntime/nms_rotated/nms_rotated.cpp | 736 +- .../onnxruntime/nms_rotated/nms_rotated.h | 84 +- .../onnxruntime/onnxruntime_register.cpp | 41 +- .../roi_align_rotated/roi_align_rotated.cpp | 458 +- .../roi_align_rotated/roi_align_rotated.h | 108 +- .../tensorrt/batched_nms/trt_batched_nms.cpp | 511 +- .../tensorrt/batched_nms/trt_batched_nms.hpp | 94 +- .../trt_batched_rotated_nms.cpp | 509 +- .../trt_batched_rotated_nms.hpp | 86 +- .../trt_bicubic_interpolate.cpp | 390 +- .../trt_bicubic_interpolate.hpp | 89 +- .../trt_bicubic_interpolate_kernel.cu | 277 +- .../trt_bicubic_interpolate_kernel.hpp | 6 +- .../tensorrt/common/common_cuda_helper.hpp | 87 +- .../common/nms/batched_nms_kernel.hpp | 10 +- .../tensorrt/common/nms/cub_helper.h | 20 +- .../backend_ops/tensorrt/common/nms/kernel.h | 78 +- .../tensorrt/common/trt_plugin_base.hpp | 125 +- .../tensorrt/common/trt_plugin_helper.hpp | 270 +- .../tensorrt/common/trt_serialize.hpp | 160 +- .../tensorrt/common_impl/nms/allClassNMS.cu | 463 +- .../common_impl/nms/allClassRotatedNMS.cu | 915 +- .../common_impl/nms/batched_nms_kernel.cpp | 224 +- .../common_impl/nms/gatherNMSOutputs.cu | 279 +- .../tensorrt/common_impl/nms/kernel.cu | 141 +- .../tensorrt/common_impl/nms/permuteData.cu | 110 +- .../common_impl/nms/sortScoresPerClass.cu | 248 +- .../common_impl/nms/sortScoresPerImage.cu | 129 +- .../tensorrt/common_impl/trt_cuda_helper.cu | 150 +- .../tensorrt/deform_conv/trt_deform_conv.cpp | 526 +- .../tensorrt/deform_conv/trt_deform_conv.hpp | 131 +- .../deform_conv/trt_deform_conv_kernel.cu | 196 +- .../deform_conv/trt_deform_conv_kernel.cuh | 204 +- .../deform_conv/trt_deform_conv_kernel.hpp | 16 +- .../tensorrt/gather_topk/gather_topk.cpp | 309 +- .../tensorrt/gather_topk/gather_topk.hpp | 100 +- .../gather_topk/gather_topk_kernel.cu | 55 +- .../gather_topk/gather_topk_kernel.hpp | 6 +- .../tensorrt/grid_priors/trt_grid_priors.cpp | 327 +- .../tensorrt/grid_priors/trt_grid_priors.hpp | 82 +- .../grid_priors/trt_grid_priors_kernel.cu | 63 +- .../grid_priors/trt_grid_priors_kernel.hpp | 5 +- .../grid_sampler/trt_grid_sampler.cpp | 419 +- .../grid_sampler/trt_grid_sampler.hpp | 94 +- .../grid_sampler/trt_grid_sampler_kernel.cu | 718 +- .../grid_sampler/trt_grid_sampler_kernel.hpp | 19 +- .../instance_norm/trt_instance_norm.cpp | 436 +- .../instance_norm/trt_instance_norm.hpp | 100 +- .../trt_modulated_deform_conv.cpp | 673 +- .../trt_modulated_deform_conv.hpp | 133 +- .../trt_modulated_deform_conv_kernel.cu | 314 +- .../trt_modulated_deform_conv_kernel.hpp | 32 +- .../trt_multi_level_roi_align.cpp | 474 +- .../trt_multi_level_roi_align.hpp | 102 +- .../trt_multi_level_roi_align_kernel.cu | 373 +- 
.../trt_multi_level_roi_align_kernel.hpp | 8 +- .../trt_multi_level_rotated_roi_align.cpp | 494 +- .../trt_multi_level_rotated_roi_align.hpp | 103 +- ...rt_multi_level_rotated_roi_align_kernel.cu | 348 +- ...t_multi_level_rotated_roi_align_kernel.hpp | 8 +- .../trt_ms_deform_attn.cpp | 364 +- .../trt_ms_deform_attn.hpp | 89 +- .../trt_ms_deform_attn_kernel.cu | 121 +- .../trt_ms_deform_attn_kernel.cuh | 456 +- .../trt_ms_deform_attn_kernel.hpp | 9 +- .../tensorrt/roi_align/trt_roi_align.cpp | 515 +- .../tensorrt/roi_align/trt_roi_align.hpp | 97 +- .../roi_align/trt_roi_align_kernel.cu | 203 +- .../roi_align/trt_roi_align_kernel.hpp | 9 +- .../scaled_dot_product_attention.cpp | 389 +- .../scaled_dot_product_attention.hpp | 95 +- .../scaled_dot_product_attention_kernel.cu | 144 +- .../scaled_dot_product_attention_kernel.hpp | 10 +- .../tensorrt/scatternd/trt_scatternd.cpp | 331 +- .../tensorrt/scatternd/trt_scatternd.hpp | 100 +- .../scatternd/trt_scatternd_kernel.cu | 114 +- .../scatternd/trt_scatternd_kernel.hpp | 6 +- .../backend_ops/torchscript/ops/bind.cpp | 19 +- .../ops/coreml_nms/coreml_nms_cpu.cpp | 52 +- .../modulated_deform_conv_cpu.cpp | 167 +- .../modulated_deform_conv_cuda.cu | 149 +- .../torchscript/optimizer/bind.cpp | 61 +- .../optimizer/ir/subgraph_matcher.cpp | 643 +- .../optimizer/ir/subgraph_matcher.h | 64 +- .../torchscript/optimizer/optimizer.cpp | 78 +- .../torchscript/optimizer/optimizer.h | 9 +- .../onnx/common_subgraph_elimination.cpp | 282 +- .../passes/onnx/common_subgraph_elimination.h | 24 +- .../passes/onnx/flatten_cls_head.cpp | 190 +- .../optimizer/passes/onnx/flatten_cls_head.h | 12 +- .../passes/onnx/fuse_select_assign.cpp | 287 +- .../passes/onnx/fuse_select_assign.h | 18 +- .../passes/onnx/merge_shape_concate.cpp | 234 +- .../passes/onnx/merge_shape_concate.h | 12 +- .../optimizer/passes/onnx/onnx_peephole.cpp | 158 +- .../optimizer/passes/onnx/onnx_peephole.h | 12 +- .../torchscript/optimizer/passes/onnx/utils.h | 24 +- csrc/mmdeploy/codebase/common.h | 128 +- csrc/mmdeploy/codebase/mmaction/base_head.cpp | 109 +- .../codebase/mmaction/format_shape.cpp | 243 +- .../mmdeploy/codebase/mmaction/format_shape.h | 29 +- csrc/mmdeploy/codebase/mmaction/mmaction.cpp | 5 +- csrc/mmdeploy/codebase/mmaction/mmaction.h | 18 +- csrc/mmdeploy/codebase/mmcls/linear_cls.cpp | 218 +- csrc/mmdeploy/codebase/mmcls/mmcls.cpp | 5 +- csrc/mmdeploy/codebase/mmcls/mmcls.h | 18 +- .../codebase/mmcls/multi_label_linear_cls.cpp | 88 +- .../codebase/mmdet/base_dense_head.cpp | 190 +- .../mmdeploy/codebase/mmdet/base_dense_head.h | 27 +- .../codebase/mmdet/instance_segmentation.cpp | 395 +- csrc/mmdeploy/codebase/mmdet/mmdet.cpp | 5 +- csrc/mmdeploy/codebase/mmdet/mmdet.h | 30 +- .../codebase/mmdet/object_detection.cpp | 350 +- .../codebase/mmdet/object_detection.h | 33 +- csrc/mmdeploy/codebase/mmdet/rtmdet_head.cpp | 365 +- csrc/mmdeploy/codebase/mmdet/rtmdet_head.h | 40 +- csrc/mmdeploy/codebase/mmdet/utils.cpp | 152 +- csrc/mmdeploy/codebase/mmdet/utils.h | 33 +- csrc/mmdeploy/codebase/mmdet/yolo_head.cpp | 435 +- csrc/mmdeploy/codebase/mmdet/yolo_head.h | 74 +- csrc/mmdeploy/codebase/mmedit/mmedit.cpp | 5 +- csrc/mmdeploy/codebase/mmedit/mmedit.h | 7 +- csrc/mmdeploy/codebase/mmedit/restorer.cpp | 106 +- .../codebase/mmocr/attention_convertor.cpp | 127 +- .../codebase/mmocr/base_convertor.cpp | 298 +- csrc/mmdeploy/codebase/mmocr/base_convertor.h | 56 +- .../codebase/mmocr/contour_expand.cpp | 227 +- csrc/mmdeploy/codebase/mmocr/cpu/dbnet.cpp | 128 +- 
csrc/mmdeploy/codebase/mmocr/cpu/panet.cpp | 98 +- csrc/mmdeploy/codebase/mmocr/cpu/psenet.cpp | 102 +- csrc/mmdeploy/codebase/mmocr/crnn.cpp | 130 +- .../mmocr/cuda/connected_component.cu | 867 +- .../codebase/mmocr/cuda/connected_component.h | 29 +- csrc/mmdeploy/codebase/mmocr/cuda/dbnet.cpp | 135 +- csrc/mmdeploy/codebase/mmocr/cuda/panet.cpp | 185 +- csrc/mmdeploy/codebase/mmocr/cuda/psenet.cpp | 130 +- csrc/mmdeploy/codebase/mmocr/cuda/utils.cu | 211 +- csrc/mmdeploy/codebase/mmocr/cuda/utils.h | 34 +- csrc/mmdeploy/codebase/mmocr/dbnet.cpp | 263 +- csrc/mmdeploy/codebase/mmocr/dbnet.h | 27 +- csrc/mmdeploy/codebase/mmocr/mmocr.cpp | 5 +- csrc/mmdeploy/codebase/mmocr/mmocr.h | 37 +- csrc/mmdeploy/codebase/mmocr/panet.cpp | 231 +- csrc/mmdeploy/codebase/mmocr/panet.h | 55 +- csrc/mmdeploy/codebase/mmocr/pixel_group.cpp | 219 +- csrc/mmdeploy/codebase/mmocr/psenet.cpp | 225 +- csrc/mmdeploy/codebase/mmocr/psenet.h | 45 +- .../codebase/mmocr/rescale_to_height.cpp | 123 +- csrc/mmdeploy/codebase/mmocr/resize_ocr.cpp | 161 +- .../mmocr/short_scale_aspect_jitter.cpp | 167 +- csrc/mmdeploy/codebase/mmocr/warp.cpp | 105 +- .../mmpose/keypoints_from_heatmap.cpp | 716 +- .../mmpose/keypoints_from_regression.cpp | 207 +- csrc/mmdeploy/codebase/mmpose/mmpose.cpp | 5 +- csrc/mmdeploy/codebase/mmpose/mmpose.h | 29 +- .../codebase/mmpose/pose_tracker/common.h | 80 +- .../codebase/mmpose/pose_tracker/pipeline.cpp | 207 +- .../mmpose/pose_tracker/pose_tracker.cpp | 792 +- .../mmpose/pose_tracker/pose_tracker.h | 133 +- .../mmpose/pose_tracker/smoothing_filter.cpp | 93 +- .../mmpose/pose_tracker/smoothing_filter.h | 98 +- .../codebase/mmpose/pose_tracker/track.cpp | 124 +- .../codebase/mmpose/pose_tracker/track.h | 125 +- .../mmpose/pose_tracker/tracking_filter.cpp | 410 +- .../mmpose/pose_tracker/tracking_filter.h | 58 +- .../codebase/mmpose/pose_tracker/utils.cpp | 257 +- .../codebase/mmpose/pose_tracker/utils.h | 154 +- csrc/mmdeploy/codebase/mmpose/simcc_label.cpp | 225 +- .../codebase/mmpose/topdown_affine.cpp | 280 +- .../mmpose/topdown_get_bbox_center_scale.cpp | 93 +- csrc/mmdeploy/codebase/mmrotate/mmrotate.cpp | 5 +- csrc/mmdeploy/codebase/mmrotate/mmrotate.h | 31 +- .../mmrotate/oriented_object_detection.cpp | 206 +- csrc/mmdeploy/codebase/mmseg/mmseg.cpp | 5 +- csrc/mmdeploy/codebase/mmseg/mmseg.h | 22 +- csrc/mmdeploy/codebase/mmseg/segment.cpp | 256 +- csrc/mmdeploy/core/archive.h | 220 +- csrc/mmdeploy/core/device.h | 724 +- csrc/mmdeploy/core/device_impl.cpp | 755 +- csrc/mmdeploy/core/device_impl.h | 295 +- csrc/mmdeploy/core/graph.cpp | 264 +- csrc/mmdeploy/core/graph.h | 129 +- csrc/mmdeploy/core/logger.cpp | 118 +- csrc/mmdeploy/core/logger.h | 75 +- csrc/mmdeploy/core/macro.h | 44 +- csrc/mmdeploy/core/mat.cpp | 139 +- csrc/mmdeploy/core/mat.h | 189 +- csrc/mmdeploy/core/model.cpp | 130 +- csrc/mmdeploy/core/model.h | 191 +- csrc/mmdeploy/core/model_impl.h | 80 +- csrc/mmdeploy/core/module.cpp | 5 +- csrc/mmdeploy/core/module.h | 16 +- csrc/mmdeploy/core/mpl/detected.h | 90 +- csrc/mmdeploy/core/mpl/iterator.h | 11 +- csrc/mmdeploy/core/mpl/priority_tag.h | 15 +- csrc/mmdeploy/core/mpl/span.h | 295 +- csrc/mmdeploy/core/mpl/static_any.h | 948 +- csrc/mmdeploy/core/mpl/structure.h | 502 +- csrc/mmdeploy/core/mpl/type_traits.h | 69 +- csrc/mmdeploy/core/net.cpp | 5 +- csrc/mmdeploy/core/net.h | 28 +- csrc/mmdeploy/core/operator.cpp | 427 +- csrc/mmdeploy/core/operator.h | 213 +- csrc/mmdeploy/core/profiler.cpp | 175 +- csrc/mmdeploy/core/profiler.h | 143 +- 
csrc/mmdeploy/core/registry.cpp | 150 +- csrc/mmdeploy/core/registry.h | 465 +- csrc/mmdeploy/core/serialization.h | 615 +- csrc/mmdeploy/core/status_code.cpp | 76 +- csrc/mmdeploy/core/status_code.h | 258 +- csrc/mmdeploy/core/tensor.cpp | 445 +- csrc/mmdeploy/core/tensor.h | 164 +- csrc/mmdeploy/core/types.h | 85 +- csrc/mmdeploy/core/utils/device_utils.cpp | 81 +- csrc/mmdeploy/core/utils/device_utils.h | 37 +- csrc/mmdeploy/core/utils/filesystem.h | 4 +- csrc/mmdeploy/core/utils/formatter.cpp | 8 +- csrc/mmdeploy/core/utils/formatter.h | 196 +- csrc/mmdeploy/core/utils/source_location.h | 43 +- csrc/mmdeploy/core/utils/stacktrace.cpp | 108 +- csrc/mmdeploy/core/utils/stacktrace.h | 32 +- csrc/mmdeploy/core/value.h | 2553 +++-- csrc/mmdeploy/device/acl/acl_device.cpp | 17 +- csrc/mmdeploy/device/cpu/cpu_device.cpp | 926 +- csrc/mmdeploy/device/cpu/cpu_device.h | 187 +- csrc/mmdeploy/device/cuda/buddy_allocator.h | 355 +- csrc/mmdeploy/device/cuda/cuda_device.cpp | 1119 ++- csrc/mmdeploy/device/cuda/cuda_device.h | 305 +- csrc/mmdeploy/device/cuda/default_allocator.h | 106 +- csrc/mmdeploy/device/cuda/linear_allocator.h | 123 +- csrc/mmdeploy/device/device_allocator.h | 735 +- csrc/mmdeploy/execution/bulk.h | 236 +- csrc/mmdeploy/execution/closure.h | 153 +- csrc/mmdeploy/execution/concepts.h | 276 +- csrc/mmdeploy/execution/dynamic_batch.h | 97 +- csrc/mmdeploy/execution/ensure_started.h | 342 +- csrc/mmdeploy/execution/execute.h | 45 +- csrc/mmdeploy/execution/expand.h | 112 +- csrc/mmdeploy/execution/just.h | 124 +- csrc/mmdeploy/execution/let_value.h | 295 +- csrc/mmdeploy/execution/on.h | 239 +- csrc/mmdeploy/execution/run_loop.h | 339 +- csrc/mmdeploy/execution/schedule_from.h | 270 +- .../schedulers/dynamic_batch_scheduler.h | 556 +- .../execution/schedulers/inlined_scheduler.h | 142 +- .../execution/schedulers/intrusive_queue.h | 186 +- csrc/mmdeploy/execution/schedulers/registry.h | 9 +- .../execution/schedulers/schedulers.cpp | 201 +- .../schedulers/single_thread_context.h | 104 +- .../execution/schedulers/static_thread_pool.h | 731 +- .../schedulers/timed_single_thread_context.h | 417 +- csrc/mmdeploy/execution/split.h | 351 +- csrc/mmdeploy/execution/start_detached.h | 66 +- csrc/mmdeploy/execution/submit.h | 103 +- csrc/mmdeploy/execution/sync_wait.h | 148 +- csrc/mmdeploy/execution/tag_invoke.h | 152 +- csrc/mmdeploy/execution/then.h | 197 +- csrc/mmdeploy/execution/transfer.h | 76 +- csrc/mmdeploy/execution/transfer_just.h | 45 +- csrc/mmdeploy/execution/type_erased.h | 1007 +- csrc/mmdeploy/execution/type_traits.h | 133 +- csrc/mmdeploy/execution/utility.h | 129 +- csrc/mmdeploy/execution/when_all.h | 358 +- csrc/mmdeploy/execution/when_all_value.h | 176 +- csrc/mmdeploy/experimental/module_adapter.h | 229 +- csrc/mmdeploy/graph/common.h | 103 +- csrc/mmdeploy/graph/cond.cpp | 240 +- csrc/mmdeploy/graph/cond.h | 45 +- csrc/mmdeploy/graph/flattened.h | 74 +- csrc/mmdeploy/graph/inference.cpp | 145 +- csrc/mmdeploy/graph/inference.h | 22 +- csrc/mmdeploy/graph/pipeline.cpp | 22 +- csrc/mmdeploy/graph/pipeline.h | 16 +- csrc/mmdeploy/graph/static_router.cpp | 434 +- csrc/mmdeploy/graph/static_router.h | 82 +- csrc/mmdeploy/graph/task.cpp | 131 +- csrc/mmdeploy/graph/task.h | 51 +- csrc/mmdeploy/model/directory_model_impl.cpp | 112 +- csrc/mmdeploy/model/zip_model_impl.cpp | 258 +- csrc/mmdeploy/net/acl/acl_net.cpp | 1334 +-- csrc/mmdeploy/net/acl/acl_net.h | 103 +- csrc/mmdeploy/net/coreml/coreml_net.h | 55 +- csrc/mmdeploy/net/ncnn/ncnn_net.cpp | 309 +- 
csrc/mmdeploy/net/ncnn/ncnn_net.h | 49 +- csrc/mmdeploy/net/net_module.cpp | 623 +- csrc/mmdeploy/net/net_module.h | 30 +- csrc/mmdeploy/net/openvino/openvino_net.cpp | 522 +- csrc/mmdeploy/net/openvino/openvino_net.h | 46 +- csrc/mmdeploy/net/ort/ort_net.cpp | 405 +- csrc/mmdeploy/net/ort/ort_net.h | 40 +- csrc/mmdeploy/net/ppl/ppl_net.cpp | 714 +- csrc/mmdeploy/net/ppl/ppl_net.h | 62 +- csrc/mmdeploy/net/rknn/rknn_net.cpp | 570 +- csrc/mmdeploy/net/rknn/rknn_net.h | 46 +- csrc/mmdeploy/net/snpe/snpe_net.cpp | 511 +- csrc/mmdeploy/net/snpe/snpe_net.h | 75 +- csrc/mmdeploy/net/torchscript/torch_net.cpp | 444 +- csrc/mmdeploy/net/torchscript/torch_net.h | 48 +- csrc/mmdeploy/net/trt/trt_net.cpp | 462 +- csrc/mmdeploy/net/trt/trt_net.h | 138 +- csrc/mmdeploy/net/tvm/tvm_net.cpp | 575 +- csrc/mmdeploy/net/tvm/tvm_net.h | 56 +- csrc/mmdeploy/operation/cpu/crop.cpp | 27 +- .../operation/cpu/crop_resize_pad.cpp | 31 +- csrc/mmdeploy/operation/cpu/cvtcolor.cpp | 26 +- csrc/mmdeploy/operation/cpu/flip.cpp | 31 +- csrc/mmdeploy/operation/cpu/hwc2chw.cpp | 46 +- csrc/mmdeploy/operation/cpu/normalize.cpp | 49 +- csrc/mmdeploy/operation/cpu/pad.cpp | 65 +- csrc/mmdeploy/operation/cpu/permute.cpp | 167 +- csrc/mmdeploy/operation/cpu/resize.cpp | 37 +- csrc/mmdeploy/operation/cpu/to_float.cpp | 73 +- csrc/mmdeploy/operation/cpu/warp_affine.cpp | 40 +- csrc/mmdeploy/operation/cuda/cast.cu | 44 +- csrc/mmdeploy/operation/cuda/crop.cpp | 118 +- csrc/mmdeploy/operation/cuda/crop.cu | 97 +- .../operation/cuda/crop_resize_pad.cpp | 175 +- csrc/mmdeploy/operation/cuda/cvtcolor.cpp | 249 +- csrc/mmdeploy/operation/cuda/flip.cpp | 119 +- csrc/mmdeploy/operation/cuda/hwc2chw.cpp | 79 +- csrc/mmdeploy/operation/cuda/normalize.cpp | 130 +- csrc/mmdeploy/operation/cuda/normalize.cu | 103 +- csrc/mmdeploy/operation/cuda/pad.cpp | 166 +- csrc/mmdeploy/operation/cuda/permute.cpp | 156 +- csrc/mmdeploy/operation/cuda/permute.cu | 86 +- csrc/mmdeploy/operation/cuda/permute.h | 27 +- csrc/mmdeploy/operation/cuda/resize.cpp | 187 +- csrc/mmdeploy/operation/cuda/to_float.cpp | 65 +- csrc/mmdeploy/operation/cuda/transpose.cu | 76 +- csrc/mmdeploy/operation/cuda/warp_affine.cpp | 237 +- csrc/mmdeploy/operation/dummy/operations.cpp | 177 +- csrc/mmdeploy/operation/managed.h | 399 +- csrc/mmdeploy/operation/operation.cpp | 66 +- csrc/mmdeploy/operation/operation.h | 237 +- csrc/mmdeploy/operation/vision.cpp | 25 +- csrc/mmdeploy/operation/vision.h | 189 +- .../preprocess/elena/elena_registry.cpp | 49 +- .../preprocess/elena/elena_registry.h | 60 +- csrc/mmdeploy/preprocess/elena/fused.cpp | 266 +- .../preprocess/transform/center_crop.cpp | 174 +- .../mmdeploy/preprocess/transform/collect.cpp | 125 +- .../mmdeploy/preprocess/transform/compose.cpp | 175 +- .../transform/default_format_bundle.cpp | 122 +- .../preprocess/transform/image2tensor.cpp | 85 +- .../preprocess/transform/letter_resize.cpp | 292 +- csrc/mmdeploy/preprocess/transform/lift.cpp | 53 +- csrc/mmdeploy/preprocess/transform/load.cpp | 192 +- .../preprocess/transform/normalize.cpp | 261 +- csrc/mmdeploy/preprocess/transform/pad.cpp | 307 +- csrc/mmdeploy/preprocess/transform/resize.cpp | 278 +- .../preprocess/transform/ten_crop.cpp | 165 +- .../preprocess/transform/three_crop.cpp | 180 +- csrc/mmdeploy/preprocess/transform/tracer.cpp | 127 +- csrc/mmdeploy/preprocess/transform/tracer.h | 173 +- .../preprocess/transform/transform.cpp | 39 +- .../mmdeploy/preprocess/transform/transform.h | 33 +- csrc/mmdeploy/preprocess/transform_module.cpp | 94 +- 
csrc/mmdeploy/utils/dlpack/dlpack_utils.cpp | 366 +- csrc/mmdeploy/utils/dlpack/dlpack_utils.h | 8 +- csrc/mmdeploy/utils/opencv/opencv_utils.cpp | 685 +- csrc/mmdeploy/utils/opencv/opencv_utils.h | 311 +- demo/csrc/c/batch_image_classification.cpp | 173 +- demo/csrc/c/batch_object_detection.cpp | 275 +- demo/csrc/c/det_cls.cpp | 121 +- demo/csrc/c/det_pose.cpp | 232 +- demo/csrc/c/image_classification.cpp | 93 +- demo/csrc/c/image_restorer.cpp | 90 +- demo/csrc/c/image_segmentation.cpp | 141 +- demo/csrc/c/object_detection.cpp | 158 +- demo/csrc/c/ocr.cpp | 136 +- demo/csrc/c/pose_detection.cpp | 93 +- demo/csrc/c/rotated_object_detection.cpp | 119 +- demo/csrc/c/video_recognition.cpp | 207 +- demo/csrc/cpp/classifier.cxx | 86 +- demo/csrc/cpp/det_pose.cxx | 95 +- demo/csrc/cpp/detector.cxx | 85 +- demo/csrc/cpp/pose_detector.cxx | 77 +- demo/csrc/cpp/pose_tracker.cxx | 88 +- demo/csrc/cpp/pose_tracker_params.h | 44 +- demo/csrc/cpp/restorer.cxx | 77 +- demo/csrc/cpp/rotated_detector.cxx | 85 +- demo/csrc/cpp/segmentor.cxx | 84 +- demo/csrc/cpp/text_det_recog.cxx | 57 +- demo/csrc/cpp/text_ocr.cxx | 81 +- demo/csrc/cpp/utils/argparse.h | 528 +- demo/csrc/cpp/utils/mediaio.h | 920 +- demo/csrc/cpp/utils/palette.h | 245 +- demo/csrc/cpp/utils/skeleton.h | 413 +- demo/csrc/cpp/utils/visualize.h | 498 +- demo/csrc/cpp/video_cls.cxx | 124 +- tests/test_csrc/archive/test_json_archive.cpp | 85 +- .../test_csrc/archive/test_value_archive.cpp | 185 +- tests/test_csrc/capi/test_classifier.cpp | 94 +- tests/test_csrc/capi/test_detector.cpp | 103 +- tests/test_csrc/capi/test_model.cpp | 40 +- tests/test_csrc/capi/test_restorer.cpp | 90 +- tests/test_csrc/capi/test_segmentor.cpp | 92 +- tests/test_csrc/capi/test_text_detector.cpp | 98 +- tests/test_csrc/capi/test_text_recognizer.cpp | 210 +- tests/test_csrc/core/test_execution.cpp | 800 +- tests/test_csrc/core/test_mat.cpp | 169 +- tests/test_csrc/core/test_module_adapter.cpp | 49 +- tests/test_csrc/core/test_registry.cpp | 165 +- tests/test_csrc/core/test_span.cpp | 155 +- tests/test_csrc/core/test_status_code.cpp | 66 +- tests/test_csrc/core/test_value.cpp | 581 +- tests/test_csrc/device/test_cpu_device.cpp | 52 +- tests/test_csrc/device/test_cuda_device.cpp | 56 +- tests/test_csrc/device/test_opencl_device.cpp | 56 +- tests/test_csrc/graph/test_cond.cpp | 86 +- .../test_csrc/model/test_directory_model.cpp | 47 +- tests/test_csrc/model/test_model.cpp | 74 +- tests/test_csrc/model/test_zip_model.cpp | 70 +- tests/test_csrc/net/test_ncnn_net.cpp | 31 +- tests/test_csrc/net/test_openvino_net.cpp | 31 +- tests/test_csrc/net/test_ort_net.cpp | 31 +- tests/test_csrc/net/test_ppl_net.cpp | 25 +- tests/test_csrc/net/test_trt_net.cpp | 31 +- tests/test_csrc/preprocess/test_collect.cpp | 172 +- tests/test_csrc/preprocess/test_compose.cpp | 62 +- tests/test_csrc/preprocess/test_crop.cpp | 180 +- .../preprocess/test_default_format_bundle.cpp | 90 +- .../preprocess/test_image2tensor.cpp | 96 +- tests/test_csrc/preprocess/test_load.cpp | 114 +- tests/test_csrc/preprocess/test_normalize.cpp | 151 +- tests/test_csrc/preprocess/test_pad.cpp | 184 +- tests/test_csrc/preprocess/test_permute.cpp | 181 +- tests/test_csrc/preprocess/test_resize.cpp | 527 +- tests/test_csrc/preprocess/test_utils.cpp | 121 +- tests/test_csrc/preprocess/test_utils.h | 28 +- tests/test_csrc/test_resource.h | 265 +- third_party/clipper/clipper.cpp | 8871 +++++++++-------- third_party/clipper/clipper.hpp | 823 +- third_party/concurrentqueue/concurrentqueue.h | 7634 +++++++------- 
third_party/dlpack/dlpack.h | 374 +- 559 files changed, 72407 insertions(+), 59911 deletions(-) mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.cpp mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.h mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.cpp mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.h mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.h mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_definer.h mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.cpp mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.h mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.cpp mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.h mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.h mode change 100755 => 100644 csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/ncnn_ext.cpp diff --git a/.clang-format b/.clang-format index c7370bb66a..018938c588 100644 --- a/.clang-format +++ b/.clang-format @@ -1,156 +1,255 @@ ---- -Language: Cpp -# BasedOnStyle: Google -AccessModifierOffset: -1 -AlignAfterOpenBracket: Align -AlignConsecutiveMacros: false -AlignConsecutiveAssignments: false -AlignConsecutiveDeclarations: false -AlignEscapedNewlines: Left -AlignOperands: true -AlignTrailingComments: true -AllowAllArgumentsOnNextLine: true -AllowAllConstructorInitializersOnNextLine: true -AllowAllParametersOfDeclarationOnNextLine: true -AllowShortBlocksOnASingleLine: false -AllowShortCaseLabelsOnASingleLine: false -AllowShortFunctionsOnASingleLine: All -AllowShortLambdasOnASingleLine: All -AllowShortIfStatementsOnASingleLine: WithoutElse -AllowShortLoopsOnASingleLine: true -AlwaysBreakAfterDefinitionReturnType: None -AlwaysBreakAfterReturnType: None -AlwaysBreakBeforeMultilineStrings: true -AlwaysBreakTemplateDeclarations: Yes -BinPackArguments: true -BinPackParameters: true -BraceWrapping: - AfterCaseLabel: false - AfterClass: false - AfterControlStatement: false - AfterEnum: false - AfterFunction: false - AfterNamespace: false - AfterObjCDeclaration: false - AfterStruct: false - AfterUnion: false - AfterExternBlock: false - BeforeCatch: false - BeforeElse: false - IndentBraces: false +# reference from https://clang.llvm.org/docs/ClangFormatStyleOptions.html + +# Disable formatting +DisableFormat: false + +# Base formatting style +BasedOnStyle: LLVM + +# Language: None, Cpp, Java, JavaScript, ObjC, Proto, TableGen, TextProto +Language: Cpp + +# Standard: Cpp03, Cpp11, Auto +Standard: Cpp11 + +# Tab width +TabWidth: 4 + +# Use tab characters: Never, ForIndentation, ForContinuationAndIndentation, Always +UseTab: Never + +# Offset of access specifiers (public, private, etc.) +AccessModifierOffset: -2 + +# Indent width +IndentWidth: 4 + +# Indent width of constructor initializer lists +ConstructorInitializerIndentWidth: 4 + +# Minimum indent width of continuation lines +ContinuationIndentWidth: 4 + +# Indent case labels +IndentCaseLabels: true + +# When a function's return type wraps, indent the function name in declarations and definitions +IndentWrappedFunctionNames: true + +# Namespace indentation: None, Inner (indent the contents of nested namespaces), All +NamespaceIndentation: All + +# Preprocessor directive indentation: None, AfterHash, BeforeHash +IndentPPDirectives: BeforeHash + +# Alignment after open brackets (round, angle and square): Align, DontAlign, AlwaysBreak (always break after the open bracket) +AlignAfterOpenBracket: Align + +# Align the equals signs of consecutive assignments +#AlignConsecutiveAssignments: AcrossEmptyLinesAndComments +AlignConsecutiveAssignments: AcrossComments + +# Align the variable names of consecutive declarations +AlignConsecutiveDeclarations:
AcrossEmptyLinesAndComments +#AlignConsecutiveDeclarations: AcrossComments + +#AlignEscapedNewlines: Right + +# Left-align the backslashes of escaped newlines (backslash line continuations) +#AlignEscapedNewlinesLeft: true + +# Horizontally align the operands of binary and ternary expressions +AlignOperands: true + +# Align consecutive trailing comments +AlignTrailingComments: true + +# Pointer and reference alignment: Left, Right, Middle +PointerAlignment: Left + +# Derive the most commonly used pointer and reference alignment +DerivePointerAlignment: false + +# Allow all parameters of a function declaration to be put onto the next line +AllowAllParametersOfDeclarationOnNextLine: false + +# false means call arguments are either all on the same line or each on their own line +BinPackArguments: false + +# false means all parameters are either on the same line or each on their own line +BinPackParameters: false + +# Allow all arguments of a call to be put onto the next line, even if BinPackParameters is false +AllowAllArgumentsOnNextLine: false + +# Allow short blocks on a single line +AllowShortBlocksOnASingleLine: true + +# Allow short case labels on a single line +AllowShortCaseLabelsOnASingleLine: true + +# Allow short functions on a single line: None, InlineOnly (defined in a class), Empty (empty functions), Inline (defined in a class, or empty), All +AllowShortFunctionsOnASingleLine: Empty + +# Allow short if statements to stay on a single line +AllowShortIfStatementsOnASingleLine: true + +# Allow short loops to stay on a single line +AllowShortLoopsOnASingleLine: true + +# Always break after the return type of definitions (deprecated) +AlwaysBreakAfterDefinitionReturnType: None + +# Always break after the return type: None, All, TopLevel (top-level functions, excluding member functions), +# AllDefinitions (all definitions, excluding declarations), TopLevelDefinitions (definitions of all top-level functions) +AlwaysBreakAfterReturnType: None + +# Always break before multiline string literals +AlwaysBreakBeforeMultilineStrings: false + +# Always break after template declarations +AlwaysBreakTemplateDeclarations: true + +# Constructor initializer lists are either all on the same line or each on their own line +ConstructorInitializerAllOnOneLineOrOnePerLine: false + +# Break constructor initializer lists before the commas and colons, aligning the initializers +BreakConstructorInitializers: BeforeComma + +# Automatically detect whether calls and definitions are formatted with one argument per line (experimental) +ExperimentalAutoDetectBinPacking: true + +# Remove spaces after the { and before the } of C++11 braced list initialization +Cpp11BracedListStyle: true + +# Brace wrapping; only takes effect when BreakBeforeBraces is set to Custom +BraceWrapping: + # After class definitions + AfterClass: true + # After control statements + AfterControlStatement: true + # After enum definitions + AfterEnum: true + # After function definitions + AfterFunction: true + # After namespace definitions + AfterNamespace: true + # After ObjC declarations + AfterObjCDeclaration: true + # After struct definitions + AfterStruct: true + # After union definitions + AfterUnion: true + AfterExternBlock: true + # Before catch + BeforeCatch: true + # Before else + BeforeElse: true + # Indent braces + IndentBraces: false SplitEmptyFunction: true SplitEmptyRecord: true SplitEmptyNamespace: true -BreakBeforeBinaryOperators: None -BreakBeforeBraces: Attach -BreakBeforeInheritanceComma: false -BreakInheritanceList: BeforeColon -BreakBeforeTernaryOperators: true -BreakConstructorInitializersBeforeComma: false -BreakConstructorInitializers: BeforeColon -BreakAfterJavaFieldAnnotations: false -BreakStringLiterals: true -ColumnLimit: 100 -CommentPragmas: '^ IWYU pragma:' + +# Break before binary operators: None (break after operators), NonAssignment (break before non-assignment operators), All (break before operators) +BreakBeforeBinaryOperators: None + +# Break before braces: Attach (always attach braces to the surrounding context), Linux (like Attach, except for functions, namespaces and class definitions), +# Mozilla (like Attach, except for enums, functions and record definitions), Stroustrup (like Attach, except for function definitions, catch and else), +# Allman (always break before braces), GNU (always break before braces, with extra indentation for braces of control statements), WebKit (break before functions), Custom +# Note: statement blocks are also treated as functions here +BreakBeforeBraces: Allman + +# Break before ternary operators +BreakBeforeTernaryOperators: false + +# Break string literals +BreakStringLiterals: false + +# Character limit per line; 0 means no limit +ColumnLimit: 0 + +# Penalty for breaking around an assignment +PenaltyBreakAssignment: 100 + +# Penalty for breaking a function call after call( +PenaltyBreakBeforeFirstCallParameter: 100 + +# Penalty for introducing a line break inside a comment +PenaltyBreakComment: 100 + +# Penalty for breaking before the first << +PenaltyBreakFirstLessLess: 100 + +# Penalty for introducing a line break inside a string literal +PenaltyBreakString: 100 + +# Penalty for each character outside the line length limit +PenaltyExcessCharacter: 100 + +# Penalty for putting a function's return type on its own line +PenaltyReturnTypeOnItsOwnLine: 100 + +# Add a space after C-style casts +SpaceAfterCStyleCast: false + +# Add a space after the template keyword
+SpaceAfterTemplateKeyword: false + +# Add a space before assignment operators +SpaceBeforeAssignmentOperators: true + +# Add a space before opening parentheses: Never, ControlStatements, Always +SpaceBeforeParens: ControlStatements + +# Number of spaces added before trailing comments (only applies to //) +SpacesBeforeTrailingComments: 2 + +# Add spaces after the < and before the > of angle brackets +SpacesInAngles: false + +# Add spaces inside container literals (ObjC and JavaScript arrays, dictionaries, etc.) +SpacesInContainerLiterals: false + +# Add spaces inside the parentheses of C-style casts +SpacesInCStyleCastParentheses: false + +# Add spaces after the ( and before the ) of parentheses +SpacesInParentheses: false + +# Add a space inside empty parentheses +SpaceInEmptyParentheses: false + +# Add spaces after the [ and before the ]; lambda expressions and unsized array declarations are not affected +SpacesInSquareBrackets: false + +# Penalty for each character of indented whitespace +PenaltyIndentedWhitespace: 10 + +# Regular expression describing comments with special meaning that should not be split into lines or otherwise changed +CommentPragmas: '^ IWYU pragma:' + +# Compact consecutive namespace declarations CompactNamespaces: false -ConstructorInitializerAllOnOneLineOrOnePerLine: true -ConstructorInitializerIndentWidth: 4 -ContinuationIndentWidth: 4 -Cpp11BracedListStyle: true -DerivePointerAlignment: true -DisableFormat: false -ExperimentalAutoDetectBinPacking: false -FixNamespaceComments: true -ForEachMacros: - - foreach - - Q_FOREACH - - BOOST_FOREACH -IncludeBlocks: Regroup -IncludeCategories: - - Regex: '^<ext/.*\.h>' - Priority: 2 - - Regex: '^<.*\.h>' - Priority: 1 - - Regex: '^<.*' - Priority: 2 - - Regex: '.*' - Priority: 3 -IncludeIsMainRegex: '([-_](test|unittest))?$' -IndentCaseLabels: true -IndentPPDirectives: None -IndentWidth: 2 -IndentWrappedFunctionNames: false -JavaScriptQuotes: Leave -JavaScriptWrapImports: true -KeepEmptyLinesAtTheStartOfBlocks: false -MacroBlockBegin: '' -MacroBlockEnd: '' -MaxEmptyLinesToKeep: 1 -NamespaceIndentation: None -ObjCBinPackProtocolList: Never -ObjCBlockIndentWidth: 2 -ObjCSpaceAfterProperty: false -ObjCSpaceBeforeProtocolList: true -PenaltyBreakAssignment: 2 -PenaltyBreakBeforeFirstCallParameter: 1 -PenaltyBreakComment: 300 -PenaltyBreakFirstLessLess: 120 -PenaltyBreakString: 1000 -PenaltyBreakTemplateDeclaration: 10 -PenaltyExcessCharacter: 1000000 -PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Left -RawStringFormats: - - Language: Cpp - Delimiters: - - cc - - CC - - cpp - - Cpp - - CPP - - 'c++' - - 'C++' - CanonicalDelimiter: '' - BasedOnStyle: google - - Language: TextProto - Delimiters: - - pb - - PB - - proto - - PROTO - EnclosingFunctions: - - EqualsProto - - EquivToProto - - PARSE_PARTIAL_TEXT_PROTO - - PARSE_TEST_PROTO - - PARSE_TEXT_PROTO - - ParseTextOrDie - - ParseTextProtoOrDie - CanonicalDelimiter: '' - BasedOnStyle: google -ReflowComments: true -SortIncludes: true -SortUsingDeclarations: true -SpaceAfterCStyleCast: false -SpaceAfterLogicalNot: false -SpaceAfterTemplateKeyword: true -SpaceBeforeAssignmentOperators: true -SpaceBeforeCpp11BracedList: false -SpaceBeforeCtorInitializerColon: true -SpaceBeforeInheritanceColon: true -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SpaceInEmptyParentheses: false -SpacesBeforeTrailingComments: 2 -SpacesInAngles: false -SpacesInContainerLiterals: true -SpacesInCStyleCastParentheses: false -SpacesInParentheses: false -SpacesInSquareBrackets: false -Standard: Auto -StatementMacros: - - Q_UNUSED - - QT_REQUIRE_VERSION -TabWidth: 8 -UseTab: Never -...
+ +# Keep empty lines at the start of blocks +KeepEmptyLinesAtTheStartOfBlocks: false + +# Maximum number of consecutive empty lines +MaxEmptyLinesToKeep: 2 + +# Allow re-flowing comments +ReflowComments: true + +# Allow sorting of #includes +SortIncludes: false + +# Sort #includes: an #include that matches some regex gets the corresponding priority, and unmatched ones default to INT_MAX (lower priorities sort first); +# negative priorities can be defined to keep certain #includes always at the very front +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + - Regex: '^(<|"(gtest|isl|json)/)' + Priority: 3 + - Regex: '.*' + Priority: 1 \ No newline at end of file diff --git a/csrc/mmdeploy/apis/c/mmdeploy/classifier.cpp b/csrc/mmdeploy/apis/c/mmdeploy/classifier.cpp index 3eec4ef90b..9faf47f349 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/classifier.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/classifier.cpp @@ -16,118 +16,132 @@ using namespace mmdeploy; using namespace std; -int mmdeploy_classifier_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_classifier_t* classifier) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_classifier_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_classifier_t* classifier) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_classifier_create_v2(model, context, classifier); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_classifier_create_v2(model, context, classifier); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_classifier_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_classifier_t* classifier) { - mmdeploy_model_t model{}; +int mmdeploy_classifier_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_classifier_t* classifier) +{ + mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_classifier_create(model, device_name, device_id, classifier); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_classifier_create(model, device_name, device_id, classifier); - mmdeploy_model_destroy(model); - return ec; } -int mmdeploy_classifier_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_classifier_t* classifier) { - return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)classifier); +int mmdeploy_classifier_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_classifier_t* classifier) +{ + return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)classifier); } -int mmdeploy_classifier_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* value) { - return mmdeploy_common_create_input(mats, mat_count, value); +int mmdeploy_classifier_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value) +{ + return mmdeploy_common_create_input(mats, mat_count, value); } -int mmdeploy_classifier_apply(mmdeploy_classifier_t classifier, const mmdeploy_mat_t* mats, - int mat_count, mmdeploy_classification_t** results, - int** result_count) { - wrapped<mmdeploy_value_t> input; - if (auto ec = mmdeploy_classifier_create_input(mats, mat_count, input.ptr())) { - return ec; - } - wrapped<mmdeploy_value_t> output; - if (auto ec = mmdeploy_classifier_apply_v2(classifier, input, output.ptr())) { - return ec; - }
if (auto ec = mmdeploy_classifier_get_result(output, results, result_count)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int mmdeploy_classifier_apply(mmdeploy_classifier_t classifier, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_classification_t** results, int** result_count) +{ + wrapped<mmdeploy_value_t> input; + if (auto ec = mmdeploy_classifier_create_input(mats, mat_count, input.ptr())) + { + return ec; + } + wrapped<mmdeploy_value_t> output; + if (auto ec = mmdeploy_classifier_apply_v2(classifier, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_classifier_get_result(output, results, result_count)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -int mmdeploy_classifier_apply_v2(mmdeploy_classifier_t classifier, mmdeploy_value_t input, - mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)classifier, input, output); +int mmdeploy_classifier_apply_v2(mmdeploy_classifier_t classifier, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)classifier, input, output); } -int mmdeploy_classifier_apply_async(mmdeploy_classifier_t classifier, mmdeploy_sender_t input, - mmdeploy_sender_t* output) { - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)classifier, input, output); +int mmdeploy_classifier_apply_async(mmdeploy_classifier_t classifier, mmdeploy_sender_t input, mmdeploy_sender_t* output) +{ + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)classifier, input, output); } -int mmdeploy_classifier_get_result(mmdeploy_value_t output, mmdeploy_classification_t** results, - int** result_count) { - if (!output || !results || !result_count) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - Value& value = Cast(output)->front(); - - auto classify_outputs = from_value<vector<mmcls::Labels>>(value); - - vector<int> _result_count; - _result_count.reserve(classify_outputs.size()); - - for (const auto& cls_output : classify_outputs) { - _result_count.push_back((int)cls_output.size()); +int mmdeploy_classifier_get_result(mmdeploy_value_t output, mmdeploy_classification_t** results, int** result_count) +{ + if (!output || !results || !result_count) + { + return MMDEPLOY_E_INVALID_ARG; } - - auto total = std::accumulate(begin(_result_count), end(_result_count), 0); - - std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{}); - std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); - - std::unique_ptr<mmdeploy_classification_t[]> result_data( - new mmdeploy_classification_t[total]{}); - auto result_ptr = result_data.get(); - for (const auto& cls_output : classify_outputs) { - for (const auto& label : cls_output) { - result_ptr->label_id = label.label_id; - result_ptr->score = label.score; - ++result_ptr; - } + try + { + Value& value = Cast(output)->front(); + + auto classify_outputs = from_value<vector<mmcls::Labels>>(value); + + vector<int> _result_count; + _result_count.reserve(classify_outputs.size()); + + for (const auto& cls_output : classify_outputs) + { + _result_count.push_back((int)cls_output.size()); + } + + auto total = std::accumulate(begin(_result_count), end(_result_count), 0); + + std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{}); + std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); + + std::unique_ptr<mmdeploy_classification_t[]> result_data( + new mmdeploy_classification_t[total]{}); + auto result_ptr = result_data.get(); + for (const auto& cls_output : classify_outputs) + { + for (const auto& label : cls_output) + { + result_ptr->label_id = label.label_id; + result_ptr->score = label.score; + ++result_ptr; + } + } + + 
*result_count = result_count_data.release(); + *results = result_data.release(); + + return MMDEPLOY_SUCCESS; } - - *result_count = result_count_data.release(); - *results = result_data.release(); - - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -void mmdeploy_classifier_release_result(mmdeploy_classification_t* results, const int* result_count, - int count) { - delete[] results; - delete[] result_count; +void mmdeploy_classifier_release_result(mmdeploy_classification_t* results, const int* result_count, int count) +{ + delete[] results; + delete[] result_count; } -void mmdeploy_classifier_destroy(mmdeploy_classifier_t classifier) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)classifier); +void mmdeploy_classifier_destroy(mmdeploy_classifier_t classifier) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)classifier); } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/classifier.h b/csrc/mmdeploy/apis/c/mmdeploy/classifier.h index 54e9d0215b..1681cf7fae 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/classifier.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/classifier.h @@ -13,124 +13,125 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_classification_t { - int label_id; - float score; -} mmdeploy_classification_t; - -typedef struct mmdeploy_classifier* mmdeploy_classifier_t; - -/** - * @brief Create classifier's handle - * @param[in] model an instance of mmclassification sdk model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] classifier instance of a classifier, which must be destroyed - * by \ref mmdeploy_classifier_destroy - * @return status of creating classifier's handle - */ -MMDEPLOY_API int mmdeploy_classifier_create(mmdeploy_model_t model, const char* device_name, - int device_id, mmdeploy_classifier_t* classifier); - -/** - * @brief Create classifier's handle - * @param[in] model_path path of mmclassification sdk model exported by mmdeploy model converter - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] classifier instance of a classifier, which must be destroyed - * by \ref mmdeploy_classifier_destroy - * @return status of creating classifier's handle - */ -MMDEPLOY_API int mmdeploy_classifier_create_by_path(const char* model_path, const char* device_name, - int device_id, - mmdeploy_classifier_t* classifier); - -/** - * @brief Use classifier created by \ref mmdeploy_classifier_create_by_path to get label - * information of each image in a batch - * @param[in] classifier classifier's handle created by \ref mmdeploy_classifier_create_by_path - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @param[out] results a linear buffer to save classification results of each - * image, which must be freed by \ref mmdeploy_classifier_release_result - * @param[out] result_count a linear buffer with length being \p mat_count to save the number of - * classification results of each image. 
It must be released by \ref - * mmdeploy_classifier_release_result - * @return status of inference - */ -MMDEPLOY_API int mmdeploy_classifier_apply(mmdeploy_classifier_t classifier, - const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_classification_t** results, int** result_count); - -/** - * @brief Release the inference result buffer created \ref mmdeploy_classifier_apply - * @param[in] results classification results buffer - * @param[in] result_count \p results size buffer - * @param[in] count length of \p result_count - */ -MMDEPLOY_API void mmdeploy_classifier_release_result(mmdeploy_classification_t* results, - const int* result_count, int count); - -/** - * @brief Destroy classifier's handle - * @param[in] classifier classifier's handle created by \ref mmdeploy_classifier_create_by_path - */ -MMDEPLOY_API void mmdeploy_classifier_destroy(mmdeploy_classifier_t classifier); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -/** - * @brief Same as \ref mmdeploy_classifier_create, but allows to control execution context of tasks - * via context - */ -MMDEPLOY_API int mmdeploy_classifier_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_classifier_t* classifier); - -/** - * @brief Pack classifier inputs into mmdeploy_value_t - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @param[out] value the packed value - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_classifier_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* value); - -/** - * @brief Same as \ref mmdeploy_classifier_apply, but input and output are packed in \ref - * mmdeploy_value_t. - */ -MMDEPLOY_API int mmdeploy_classifier_apply_v2(mmdeploy_classifier_t classifier, - mmdeploy_value_t input, mmdeploy_value_t* output); - -/** - * @brief Apply classifier asynchronously - * @param[in] classifier handle of the classifier - * @param[in] input input sender that will be consumed by the operation - * @param[out] output output sender - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_classifier_apply_async(mmdeploy_classifier_t classifier, - mmdeploy_sender_t input, - mmdeploy_sender_t* output); - -/** - * - * @param[in] output output obtained by applying a classifier - * @param[out] results a linear buffer containing classification results of each image, released by - * \ref mmdeploy_classifier_release_result - * @param[out] result_count a linear buffer containing the number of results for each input image, - * released by \ref mmdeploy_classifier_release_result - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_classifier_get_result(mmdeploy_value_t output, - mmdeploy_classification_t** results, - int** result_count); + typedef struct mmdeploy_classification_t + { + int label_id; + float score; + } mmdeploy_classification_t; + + typedef struct mmdeploy_classifier* mmdeploy_classifier_t; + + /** + * @brief Create classifier's handle + * @param[in] model an instance of mmclassification sdk model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. 
+ * @param[out] classifier instance of a classifier, which must be destroyed + * by \ref mmdeploy_classifier_destroy + * @return status of creating classifier's handle + */ + MMDEPLOY_API int mmdeploy_classifier_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_classifier_t* classifier); + + /** + * @brief Create classifier's handle + * @param[in] model_path path of mmclassification sdk model exported by mmdeploy model converter + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] classifier instance of a classifier, which must be destroyed + * by \ref mmdeploy_classifier_destroy + * @return status of creating classifier's handle + */ + MMDEPLOY_API int mmdeploy_classifier_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_classifier_t* classifier); + + /** + * @brief Use classifier created by \ref mmdeploy_classifier_create_by_path to get label + * information of each image in a batch + * @param[in] classifier classifier's handle created by \ref mmdeploy_classifier_create_by_path + * @param[in] mats a batch of images + * @param[in] mat_count number of images in the batch + * @param[out] results a linear buffer to save classification results of each + * image, which must be freed by \ref mmdeploy_classifier_release_result + * @param[out] result_count a linear buffer with length being \p mat_count to save the number of + * classification results of each image. It must be released by \ref + * mmdeploy_classifier_release_result + * @return status of inference + */ + MMDEPLOY_API int mmdeploy_classifier_apply(mmdeploy_classifier_t classifier, + const mmdeploy_mat_t* mats, + int mat_count, + mmdeploy_classification_t** results, + int** result_count); + + /** + * @brief Release the inference result buffer created \ref mmdeploy_classifier_apply + * @param[in] results classification results buffer + * @param[in] result_count \p results size buffer + * @param[in] count length of \p result_count + */ + MMDEPLOY_API void mmdeploy_classifier_release_result(mmdeploy_classification_t* results, + const int* result_count, + int count); + + /** + * @brief Destroy classifier's handle + * @param[in] classifier classifier's handle created by \ref mmdeploy_classifier_create_by_path + */ + MMDEPLOY_API void mmdeploy_classifier_destroy(mmdeploy_classifier_t classifier); + + /****************************************************************************** + * Experimental asynchronous APIs */ + + /** + * @brief Same as \ref mmdeploy_classifier_create, but allows to control execution context of tasks + * via context + */ + MMDEPLOY_API int mmdeploy_classifier_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_classifier_t* classifier); + + /** + * @brief Pack classifier inputs into mmdeploy_value_t + * @param[in] mats a batch of images + * @param[in] mat_count number of images in the batch + * @param[out] value the packed value + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_classifier_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value); + + /** + * @brief Same as \ref mmdeploy_classifier_apply, but input and output are packed in \ref + * mmdeploy_value_t. 
+ */ + MMDEPLOY_API int mmdeploy_classifier_apply_v2(mmdeploy_classifier_t classifier, + mmdeploy_value_t input, + mmdeploy_value_t* output); + + /** + * @brief Apply classifier asynchronously + * @param[in] classifier handle of the classifier + * @param[in] input input sender that will be consumed by the operation + * @param[out] output output sender + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_classifier_apply_async(mmdeploy_classifier_t classifier, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); + + /** + * + * @param[in] output output obtained by applying a classifier + * @param[out] results a linear buffer containing classification results of each image, released by + * \ref mmdeploy_classifier_release_result + * @param[out] result_count a linear buffer containing the number of results for each input image, + * released by \ref mmdeploy_classifier_release_result + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_classifier_get_result(mmdeploy_value_t output, + mmdeploy_classification_t** results, + int** result_count); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/common.cpp b/csrc/mmdeploy/apis/c/mmdeploy/common.cpp index e00cc3f1cf..fff83da181 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/common.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/common.cpp @@ -5,111 +5,142 @@ #include "mmdeploy/core/profiler.h" #include "mmdeploy/executor_internal.h" -mmdeploy_value_t mmdeploy_value_copy(mmdeploy_value_t value) { - if (!value) { - return nullptr; - } - return Guard([&] { return Take(Value(*Cast(value))); }); +mmdeploy_value_t mmdeploy_value_copy(mmdeploy_value_t value) +{ + if (!value) + { + return nullptr; + } + return Guard([&] + { return Take(Value(*Cast(value))); }); } -void mmdeploy_value_destroy(mmdeploy_value_t value) { delete Cast(value); } +void mmdeploy_value_destroy(mmdeploy_value_t value) +{ + delete Cast(value); +} -int mmdeploy_context_create(mmdeploy_context_t* context) { - *context = (mmdeploy_context_t) new Value; - return 0; +int mmdeploy_context_create(mmdeploy_context_t* context) +{ + *context = (mmdeploy_context_t) new Value; + return 0; } -int mmdeploy_context_create_by_device(const char* device_name, int device_id, - mmdeploy_context_t* context) { - mmdeploy_device_t device{}; - int ec = MMDEPLOY_SUCCESS; - mmdeploy_context_t _context{}; - ec = mmdeploy_context_create(&_context); - if (ec != MMDEPLOY_SUCCESS) { - return ec; - } - ec = mmdeploy_device_create(device_name, device_id, &device); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_context_create_by_device(const char* device_name, int device_id, mmdeploy_context_t* context) +{ + mmdeploy_device_t device{}; + int ec = MMDEPLOY_SUCCESS; + mmdeploy_context_t _context{}; + ec = mmdeploy_context_create(&_context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_device_create(device_name, device_id, &device); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_context_add(_context, MMDEPLOY_TYPE_DEVICE, nullptr, device); + mmdeploy_device_destroy(device); + if (ec == MMDEPLOY_SUCCESS) + { + *context = _context; + } return ec; - } - ec = mmdeploy_context_add(_context, MMDEPLOY_TYPE_DEVICE, nullptr, device); - mmdeploy_device_destroy(device); - if (ec == MMDEPLOY_SUCCESS) { - *context = _context; - } - return ec; } -void mmdeploy_context_destroy(mmdeploy_context_t context) { delete Cast(context); } +void mmdeploy_context_destroy(mmdeploy_context_t context) +{ + delete Cast(context); +} -int mmdeploy_common_create_input(const 
mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* value) { - if (mat_count && mats == nullptr) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - auto input = std::make_unique(Value{Value::kArray}); - for (int i = 0; i < mat_count; ++i) { - input->front().push_back({{"ori_img", Cast(mats[i])}}); +int mmdeploy_common_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value) +{ + if (mat_count && mats == nullptr) + { + return MMDEPLOY_E_INVALID_ARG; } - *value = Cast(input.release()); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_SUCCESS; + try + { + auto input = std::make_unique(Value{Value::kArray}); + for (int i = 0; i < mat_count; ++i) + { + input->front().push_back({{"ori_img", Cast(mats[i])}}); + } + *value = Cast(input.release()); + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_SUCCESS; } -int mmdeploy_device_create(const char* device_name, int device_id, mmdeploy_device_t* device) { - Device tmp(device_name, device_id); - if (tmp.platform_id() == -1) { - MMDEPLOY_ERROR("Device \"{}\" not found", device_name); - return MMDEPLOY_E_INVALID_ARG; - } - *device = (mmdeploy_device_t) new Device(tmp); - return MMDEPLOY_SUCCESS; +int mmdeploy_device_create(const char* device_name, int device_id, mmdeploy_device_t* device) +{ + Device tmp(device_name, device_id); + if (tmp.platform_id() == -1) + { + MMDEPLOY_ERROR("Device \"{}\" not found", device_name); + return MMDEPLOY_E_INVALID_ARG; + } + *device = (mmdeploy_device_t) new Device(tmp); + return MMDEPLOY_SUCCESS; } -void mmdeploy_device_destroy(mmdeploy_device_t device) { delete (Device*)device; } - -int mmdeploy_profiler_create(const char* path, mmdeploy_profiler_t* profiler) { - *profiler = (mmdeploy_profiler_t) new profiler::Profiler(path); - return MMDEPLOY_SUCCESS; +void mmdeploy_device_destroy(mmdeploy_device_t device) +{ + delete (Device*)device; } -void mmdeploy_profiler_destroy(mmdeploy_profiler_t profiler) { - if (profiler) { - auto p = (profiler::Profiler*)profiler; - p->Release(); - delete p; - } +int mmdeploy_profiler_create(const char* path, mmdeploy_profiler_t* profiler) +{ + *profiler = (mmdeploy_profiler_t) new profiler::Profiler(path); + return MMDEPLOY_SUCCESS; } -int mmdeploy_context_add(mmdeploy_context_t context, mmdeploy_context_type_t type, const char* name, - const void* object) { - auto& ctx = *Cast(context); - switch (type) { - case MMDEPLOY_TYPE_DEVICE: { - const auto& device = *(Device*)object; - ctx["device"] = device; - ctx["stream"] = Stream(device); - break; +void mmdeploy_profiler_destroy(mmdeploy_profiler_t profiler) +{ + if (profiler) + { + auto p = (profiler::Profiler*)profiler; + p->Release(); + delete p; } - case MMDEPLOY_TYPE_SCHEDULER: - ctx["scheduler"][name] = *Cast((const mmdeploy_scheduler_t)object); - break; - case MMDEPLOY_TYPE_MODEL: - ctx["model"][name] = *Cast((const mmdeploy_model_t)object); - break; - case MMDEPLOY_TYPE_PROFILER: { - const auto& profiler = *(profiler::Profiler*)object; - profiler::Scope* root(profiler.scope()); - ctx["scope"] = root; - break; +} + +int mmdeploy_context_add(mmdeploy_context_t context, mmdeploy_context_type_t type, const char* name, const void* object) +{ + auto& ctx = *Cast(context); + switch (type) + { + case MMDEPLOY_TYPE_DEVICE: + { + const auto& device = 
*(Device*)object; + ctx["device"] = device; + ctx["stream"] = Stream(device); + break; + } + case MMDEPLOY_TYPE_SCHEDULER: + ctx["scheduler"][name] = *Cast((const mmdeploy_scheduler_t)object); + break; + case MMDEPLOY_TYPE_MODEL: + ctx["model"][name] = *Cast((const mmdeploy_model_t)object); + break; + case MMDEPLOY_TYPE_PROFILER: + { + const auto& profiler = *(profiler::Profiler*)object; + profiler::Scope* root(profiler.scope()); + ctx["scope"] = root; + break; + } + default: + return MMDEPLOY_E_NOT_SUPPORTED; } - default: - return MMDEPLOY_E_NOT_SUPPORTED; - } - return 0; + return 0; } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/common.h b/csrc/mmdeploy/apis/c/mmdeploy/common.h index c665134cbf..26b92973ca 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/common.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/common.h @@ -6,19 +6,19 @@ #include // NOLINT #ifndef MMDEPLOY_EXPORT -#ifdef _MSC_VER -#define MMDEPLOY_EXPORT __declspec(dllexport) -#else -#define MMDEPLOY_EXPORT __attribute__((visibility("default"))) -#endif + #ifdef _MSC_VER + #define MMDEPLOY_EXPORT __declspec(dllexport) + #else + #define MMDEPLOY_EXPORT __attribute__((visibility("default"))) + #endif #endif #ifndef MMDEPLOY_API -#ifdef MMDEPLOY_API_EXPORTS -#define MMDEPLOY_API MMDEPLOY_EXPORT -#else -#define MMDEPLOY_API -#endif + #ifdef MMDEPLOY_API_EXPORTS + #define MMDEPLOY_API MMDEPLOY_EXPORT + #else + #define MMDEPLOY_API + #endif #endif // clang-format off @@ -54,136 +54,137 @@ typedef enum mmdeploy_status_t { // clang-format on -typedef struct mmdeploy_device* mmdeploy_device_t; +typedef struct mmdeploy_device* mmdeploy_device_t; typedef struct mmdeploy_profiler* mmdeploy_profiler_t; -typedef struct mmdeploy_mat_t { - uint8_t* data; - int height; - int width; - int channel; - mmdeploy_pixel_format_t format; - mmdeploy_data_type_t type; - mmdeploy_device_t device; +typedef struct mmdeploy_mat_t +{ + uint8_t* data; + int height; + int width; + int channel; + mmdeploy_pixel_format_t format; + mmdeploy_data_type_t type; + mmdeploy_device_t device; } mmdeploy_mat_t; -typedef struct mmdeploy_rect_t { - float left; - float top; - float right; - float bottom; +typedef struct mmdeploy_rect_t +{ + float left; + float top; + float right; + float bottom; } mmdeploy_rect_t; -typedef struct mmdeploy_point_t { - float x; - float y; +typedef struct mmdeploy_point_t +{ + float x; + float y; } mmdeploy_point_t; -typedef struct mmdeploy_value* mmdeploy_value_t; +typedef struct mmdeploy_value* mmdeploy_value_t; typedef struct mmdeploy_context* mmdeploy_context_t; -typedef enum mmdeploy_context_type_t { - MMDEPLOY_TYPE_DEVICE = 0, - MMDEPLOY_TYPE_STREAM = 1, - MMDEPLOY_TYPE_MODEL = 2, - MMDEPLOY_TYPE_SCHEDULER = 3, - MMDEPLOY_TYPE_MAT = 4, - MMDEPLOY_TYPE_PROFILER = 5, +typedef enum mmdeploy_context_type_t +{ + MMDEPLOY_TYPE_DEVICE = 0, + MMDEPLOY_TYPE_STREAM = 1, + MMDEPLOY_TYPE_MODEL = 2, + MMDEPLOY_TYPE_SCHEDULER = 3, + MMDEPLOY_TYPE_MAT = 4, + MMDEPLOY_TYPE_PROFILER = 5, } mmdeploy_context_type_t; #if __cplusplus -extern "C" { +extern "C" +{ #endif -/** - * Copy value - * @param value - * @return - */ -MMDEPLOY_API mmdeploy_value_t mmdeploy_value_copy(mmdeploy_value_t value); - -/** - * Destroy value - * @param value - */ -MMDEPLOY_API void mmdeploy_value_destroy(mmdeploy_value_t value); - -/** - * Create device handle - * @param device_name - * @param device_id - * @param device - * @return - */ -MMDEPLOY_API int mmdeploy_device_create(const char* device_name, int device_id, - mmdeploy_device_t* device); - -/** - * Destroy device handle - * @param 
device - */ -MMDEPLOY_API void mmdeploy_device_destroy(mmdeploy_device_t device); - -/** - * Create profiler - * @param path path to save the profile data - * @param profiler handle for profiler, should be added to context and deleted by - * mmdeploy_profiler_destroy - * @return status of create - */ -MMDEPLOY_API int mmdeploy_profiler_create(const char* path, mmdeploy_profiler_t* profiler); - -/** - * Destroy profiler handle - * @param profiler handle for profiler, profile data will be written to disk after this call - */ -MMDEPLOY_API void mmdeploy_profiler_destroy(mmdeploy_profiler_t profiler); - -/** - * Create context - * @param context - * @return - */ -MMDEPLOY_API int mmdeploy_context_create(mmdeploy_context_t* context); - -/** - * Create context - * @param device_name - * @param device_id - * @param context - * @return - */ -MMDEPLOY_API int mmdeploy_context_create_by_device(const char* device_name, int device_id, - mmdeploy_context_t* context); - -/** - * Destroy context - * @param context - */ -MMDEPLOY_API void mmdeploy_context_destroy(mmdeploy_context_t context); - -/** - * Add context object - * @param context - * @param type - * @param name - * @param object - * @return - */ -MMDEPLOY_API int mmdeploy_context_add(mmdeploy_context_t context, mmdeploy_context_type_t type, - const char* name, const void* object); - -/** - * Create input value from array of mats - * @param mats - * @param mat_count - * @param value - * @return - */ -MMDEPLOY_API int mmdeploy_common_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* value); + /** + * Copy value + * @param value + * @return + */ + MMDEPLOY_API mmdeploy_value_t mmdeploy_value_copy(mmdeploy_value_t value); + + /** + * Destroy value + * @param value + */ + MMDEPLOY_API void mmdeploy_value_destroy(mmdeploy_value_t value); + + /** + * Create device handle + * @param device_name + * @param device_id + * @param device + * @return + */ + MMDEPLOY_API int mmdeploy_device_create(const char* device_name, int device_id, mmdeploy_device_t* device); + + /** + * Destroy device handle + * @param device + */ + MMDEPLOY_API void mmdeploy_device_destroy(mmdeploy_device_t device); + + /** + * Create profiler + * @param path path to save the profile data + * @param profiler handle for profiler, should be added to context and deleted by + * mmdeploy_profiler_destroy + * @return status of create + */ + MMDEPLOY_API int mmdeploy_profiler_create(const char* path, mmdeploy_profiler_t* profiler); + + /** + * Destroy profiler handle + * @param profiler handle for profiler, profile data will be written to disk after this call + */ + MMDEPLOY_API void mmdeploy_profiler_destroy(mmdeploy_profiler_t profiler); + + /** + * Create context + * @param context + * @return + */ + MMDEPLOY_API int mmdeploy_context_create(mmdeploy_context_t* context); + + /** + * Create context + * @param device_name + * @param device_id + * @param context + * @return + */ + MMDEPLOY_API int mmdeploy_context_create_by_device(const char* device_name, int device_id, mmdeploy_context_t* context); + + /** + * Destroy context + * @param context + */ + MMDEPLOY_API void mmdeploy_context_destroy(mmdeploy_context_t context); + + /** + * Add context object + * @param context + * @param type + * @param name + * @param object + * @return + */ + MMDEPLOY_API int mmdeploy_context_add(mmdeploy_context_t context, mmdeploy_context_type_t type, const char* name, const void* object); + + /** + * Create input value from array of mats + * @param mats + * @param mat_count + * @param 
value + * @return + */ + MMDEPLOY_API int mmdeploy_common_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value); #if __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/common_internal.h b/csrc/mmdeploy/apis/c/mmdeploy/common_internal.h index a1ddecb54d..6beb2f6b5e 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/common_internal.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/common_internal.h @@ -12,93 +12,152 @@ using namespace mmdeploy; -namespace { - -inline mmdeploy_value_t Cast(Value* s) { return reinterpret_cast(s); } - -inline Value* Cast(mmdeploy_value_t s) { return reinterpret_cast(s); } - -inline Value Take(mmdeploy_value_t v) { - auto value = std::move(*Cast(v)); - mmdeploy_value_destroy(v); - return value; -} - -inline Value* Cast(mmdeploy_context_t c) { return reinterpret_cast(c); } - -inline mmdeploy_value_t Take(Value v) { - return Cast(new Value(std::move(v))); // NOLINT -} - -inline mmdeploy_pipeline_t Cast(AsyncHandle* pipeline) { - return reinterpret_cast(pipeline); -} - -inline AsyncHandle* Cast(mmdeploy_pipeline_t pipeline) { - return reinterpret_cast(pipeline); -} - -inline mmdeploy_model_t Cast(Model* model) { return reinterpret_cast(model); } - -inline Model* Cast(mmdeploy_model_t model) { return reinterpret_cast(model); } - -inline Mat Cast(const mmdeploy_mat_t& mat) { - return Mat{mat.height, mat.width, PixelFormat(mat.format), - DataType(mat.type), mat.data, mat.device ? *(const Device*)mat.device : Device{0}}; -} - -template -std::invoke_result_t Guard(F f) { - try { - return f(); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return nullptr; -} - -template -class wrapped {}; - -template -class wrapped> { - public: - wrapped() noexcept : v_(nullptr) {} - explicit wrapped(T v) noexcept : v_(v) {} - - void reset() { - if (v_) { - delete Cast(v_); - v_ = nullptr; +namespace +{ + + inline mmdeploy_value_t Cast(Value* s) + { + return reinterpret_cast(s); + } + + inline Value* Cast(mmdeploy_value_t s) + { + return reinterpret_cast(s); + } + + inline Value Take(mmdeploy_value_t v) + { + auto value = std::move(*Cast(v)); + mmdeploy_value_destroy(v); + return value; + } + + inline Value* Cast(mmdeploy_context_t c) + { + return reinterpret_cast(c); } - } - ~wrapped() { reset(); } + inline mmdeploy_value_t Take(Value v) + { + return Cast(new Value(std::move(v))); // NOLINT + } - wrapped(const wrapped&) = delete; - wrapped& operator=(const wrapped&) = delete; + inline mmdeploy_pipeline_t Cast(AsyncHandle* pipeline) + { + return reinterpret_cast(pipeline); + } - wrapped(wrapped&& other) noexcept : v_(other.release()) {} - wrapped& operator=(wrapped&& other) noexcept { - reset(); - v_ = other.release(); - return *this; - } + inline AsyncHandle* Cast(mmdeploy_pipeline_t pipeline) + { + return reinterpret_cast(pipeline); + } - T release() noexcept { return std::exchange(v_, nullptr); } + inline mmdeploy_model_t Cast(Model* model) + { + return reinterpret_cast(model); + } - auto operator*() { return Cast(v_); } - auto operator-> () { return Cast(v_); } + inline Model* Cast(mmdeploy_model_t model) + { + return reinterpret_cast(model); + } - T* ptr() noexcept { return &v_; } + inline Mat Cast(const mmdeploy_mat_t& mat) + { + return Mat{mat.height, mat.width, PixelFormat(mat.format), DataType(mat.type), mat.data, mat.device ? 
*(const Device*)mat.device : Device{0}}; + } - operator T() const noexcept { return v_; } // NOLINT + template + std::invoke_result_t Guard(F f) + { + try + { + return f(); + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return nullptr; + } - private: - T v_; -}; + template + class wrapped + { + }; + + template + class wrapped> + { + public: + wrapped() noexcept + : v_(nullptr) + { + } + explicit wrapped(T v) noexcept + : v_(v) + { + } + + void reset() + { + if (v_) + { + delete Cast(v_); + v_ = nullptr; + } + } + + ~wrapped() + { + reset(); + } + + wrapped(const wrapped&) = delete; + wrapped& operator=(const wrapped&) = delete; + + wrapped(wrapped&& other) noexcept + : v_(other.release()) + { + } + wrapped& operator=(wrapped&& other) noexcept + { + reset(); + v_ = other.release(); + return *this; + } + + T release() noexcept + { + return std::exchange(v_, nullptr); + } + + auto operator*() + { + return Cast(v_); + } + auto operator->() + { + return Cast(v_); + } + + T* ptr() noexcept + { + return &v_; + } + + operator T() const noexcept + { + return v_; + } // NOLINT + + private: + T v_; + }; } // namespace diff --git a/csrc/mmdeploy/apis/c/mmdeploy/detector.cpp b/csrc/mmdeploy/apis/c/mmdeploy/detector.cpp index aadf92fb62..30ea52fcab 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/detector.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/detector.cpp @@ -24,126 +24,142 @@ using ResultType = mmdeploy::Structure, // std::vector>; // -int mmdeploy_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_detector_t* detector) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_detector_t* detector) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_detector_create_v2(model, context, detector); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_detector_create_v2(model, context, detector); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_detector_t* detector) { - return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); +int mmdeploy_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_detector_t* detector) +{ + return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); } -int mmdeploy_detector_create_by_path(const char* model_path, const char* device_name, int device_id, - mmdeploy_detector_t* detector) { - mmdeploy_model_t model{}; +int mmdeploy_detector_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_detector_t* detector) +{ + mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_detector_create(model, device_name, device_id, detector); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_detector_create(model, device_name, device_id, detector); - mmdeploy_model_destroy(model); - return ec; } -int 
mmdeploy_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* input) { - return mmdeploy_common_create_input(mats, mat_count, input); +int mmdeploy_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input) +{ + return mmdeploy_common_create_input(mats, mat_count, input); } -int mmdeploy_detector_apply(mmdeploy_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_detection_t** results, int** result_count) { - wrapped input; - if (auto ec = mmdeploy_detector_create_input(mats, mat_count, input.ptr())) { - return ec; - } - wrapped output; - if (auto ec = mmdeploy_detector_apply_v2(detector, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_detector_get_result(output, results, result_count)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int mmdeploy_detector_apply(mmdeploy_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_detection_t** results, int** result_count) +{ + wrapped input; + if (auto ec = mmdeploy_detector_create_input(mats, mat_count, input.ptr())) + { + return ec; + } + wrapped output; + if (auto ec = mmdeploy_detector_apply_v2(detector, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_detector_get_result(output, results, result_count)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -int mmdeploy_detector_apply_v2(mmdeploy_detector_t detector, mmdeploy_value_t input, - mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); +int mmdeploy_detector_apply_v2(mmdeploy_detector_t detector, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); } -int mmdeploy_detector_apply_async(mmdeploy_detector_t detector, mmdeploy_sender_t input, - mmdeploy_sender_t* output) { - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); +int mmdeploy_detector_apply_async(mmdeploy_detector_t detector, mmdeploy_sender_t input, mmdeploy_sender_t* output) +{ + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); } -int mmdeploy_detector_get_result(mmdeploy_value_t output, mmdeploy_detection_t** results, - int** result_count) { - if (!output || !results || !result_count) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - Value& value = Cast(output)->front(); - auto detector_outputs = from_value>(value); - - vector _result_count(detector_outputs.size()); - size_t total = 0; - for (size_t i = 0; i < detector_outputs.size(); ++i) { - _result_count[i] = static_cast(detector_outputs[i].size()); - total += detector_outputs[i].size(); +int mmdeploy_detector_get_result(mmdeploy_value_t output, mmdeploy_detection_t** results, int** result_count) +{ + if (!output || !results || !result_count) + { + return MMDEPLOY_E_INVALID_ARG; } + try + { + Value& value = Cast(output)->front(); + auto detector_outputs = from_value>(value); + + vector _result_count(detector_outputs.size()); + size_t total = 0; + for (size_t i = 0; i < detector_outputs.size(); ++i) + { + _result_count[i] = static_cast(detector_outputs[i].size()); + total += detector_outputs[i].size(); + } - ResultType r({total, 1, 1, 1}); - auto [result_data, result_count_vec, masks, buffers] = r.pointers(); - - auto result_ptr = result_data; - - for (const auto& det_output : detector_outputs) { - for (const auto& detection : det_output) { - result_ptr->label_id = detection.label_id; - result_ptr->score = detection.score; - const auto& 
bbox = detection.bbox; - result_ptr->bbox = {bbox[0], bbox[1], bbox[2], bbox[3]}; - auto mask_byte_size = detection.mask.byte_size(); - if (mask_byte_size) { - auto& mask = detection.mask; - result_ptr->mask = &masks->emplace_back(); - buffers->push_back(mask.buffer()); - result_ptr->mask->data = mask.data(); - result_ptr->mask->width = mask.width(); - result_ptr->mask->height = mask.height(); + ResultType r({total, 1, 1, 1}); + auto [result_data, result_count_vec, masks, buffers] = r.pointers(); + + auto result_ptr = result_data; + + for (const auto& det_output : detector_outputs) + { + for (const auto& detection : det_output) + { + result_ptr->label_id = detection.label_id; + result_ptr->score = detection.score; + const auto& bbox = detection.bbox; + result_ptr->bbox = {bbox[0], bbox[1], bbox[2], bbox[3]}; + auto mask_byte_size = detection.mask.byte_size(); + if (mask_byte_size) + { + auto& mask = detection.mask; + result_ptr->mask = &masks->emplace_back(); + buffers->push_back(mask.buffer()); + result_ptr->mask->data = mask.data(); + result_ptr->mask->width = mask.width(); + result_ptr->mask->height = mask.height(); + } + ++result_ptr; + } } - ++result_ptr; - } - } - *result_count_vec = std::move(_result_count); - *result_count = result_count_vec->data(); - *results = result_data; - r.release(); + *result_count_vec = std::move(_result_count); + *result_count = result_count_vec->data(); + *results = result_data; + r.release(); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) 
+ { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -void mmdeploy_detector_release_result(mmdeploy_detection_t* results, const int* result_count, - int count) { - auto num_dets = std::accumulate(result_count, result_count + count, 0); - ResultType deleter({static_cast(num_dets), 1, 1, 1}, results); +void mmdeploy_detector_release_result(mmdeploy_detection_t* results, const int* result_count, int count) +{ + auto num_dets = std::accumulate(result_count, result_count + count, 0); + ResultType deleter({static_cast(num_dets), 1, 1, 1}, results); } -void mmdeploy_detector_destroy(mmdeploy_detector_t detector) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector); +void mmdeploy_detector_destroy(mmdeploy_detector_t detector) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector); } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/detector.h b/csrc/mmdeploy/apis/c/mmdeploy/detector.h index 5c5ba2f356..713214ca4f 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/detector.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/detector.h @@ -13,124 +13,123 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_instance_mask_t { - char* data; - int height; - int width; -} mmdeploy_instance_mask_t; - -typedef struct mmdeploy_detection_t { - int label_id; - float score; - mmdeploy_rect_t bbox; - mmdeploy_instance_mask_t* mask; -} mmdeploy_detection_t; - -typedef struct mmdeploy_detector* mmdeploy_detector_t; - -/** - * @brief Create detector's handle - * @param[in] model an instance of mmdetection sdk model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] detector instance of a detector - * @return status of creating detector's handle - */ -MMDEPLOY_API int mmdeploy_detector_create(mmdeploy_model_t model, const char* device_name, - int device_id, mmdeploy_detector_t* detector); - -/** - * @brief Create detector's handle - * @param[in] model_path path of mmdetection sdk model exported by mmdeploy model converter - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] detector instance of a detector - * @return status of creating detector's handle - */ -MMDEPLOY_API int mmdeploy_detector_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_detector_t* detector); - -/** - * @brief Apply detector to batch images and get their inference results - * @param[in] detector detector's handle created by \ref mmdeploy_detector_create_by_path - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @param[out] results a linear buffer to save detection results of each image. It must be released - * by \ref mmdeploy_detector_release_result - * @param[out] result_count a linear buffer with length being \p mat_count to save the number of - * detection results of each image. 
And it must be released by \ref - * mmdeploy_detector_release_result - * @return status of inference - */ -MMDEPLOY_API int mmdeploy_detector_apply(mmdeploy_detector_t detector, const mmdeploy_mat_t* mats, - int mat_count, mmdeploy_detection_t** results, - int** result_count); - -/** @brief Release the inference result buffer created by \ref mmdeploy_detector_apply - * @param[in] results detection results buffer - * @param[in] result_count \p results size buffer - * @param[in] count length of \p result_count - */ -MMDEPLOY_API void mmdeploy_detector_release_result(mmdeploy_detection_t* results, - const int* result_count, int count); - -/** - * @brief Destroy detector's handle - * @param[in] detector detector's handle created by \ref mmdeploy_detector_create_by_path - */ -MMDEPLOY_API void mmdeploy_detector_destroy(mmdeploy_detector_t detector); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -/** - * @brief Same as \ref mmdeploy_detector_create, but allows to control execution context of tasks - * via context - */ -MMDEPLOY_API int mmdeploy_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_detector_t* detector); - -/** - * @brief Pack detector inputs into mmdeploy_value_t - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @return the created value - */ -MMDEPLOY_API int mmdeploy_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* input); - -/** - * @brief Same as \ref mmdeploy_detector_apply, but input and output are packed in \ref - * mmdeploy_value_t. - */ -MMDEPLOY_API int mmdeploy_detector_apply_v2(mmdeploy_detector_t detector, mmdeploy_value_t input, - mmdeploy_value_t* output); - -/** - * @brief Apply detector asynchronously - * @param[in] detector handle to the detector - * @param[in] input input sender - * @return output sender - */ -MMDEPLOY_API int mmdeploy_detector_apply_async(mmdeploy_detector_t detector, - mmdeploy_sender_t input, mmdeploy_sender_t* output); - -/** - * @brief Unpack detector output from a mmdeploy_value_t - * @param[in] output output obtained by applying a detector - * @param[out] results a linear buffer to save detection results of each image. It must be released - * by \ref mmdeploy_detector_release_result - * @param[out] result_count a linear buffer with length number of input images to save the number of - * detection results of each image. Must be released by \ref - * mmdeploy_detector_release_result - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_detector_get_result(mmdeploy_value_t output, - mmdeploy_detection_t** results, int** result_count); + typedef struct mmdeploy_instance_mask_t + { + char* data; + int height; + int width; + } mmdeploy_instance_mask_t; + + typedef struct mmdeploy_detection_t + { + int label_id; + float score; + mmdeploy_rect_t bbox; + mmdeploy_instance_mask_t* mask; + } mmdeploy_detection_t; + + typedef struct mmdeploy_detector* mmdeploy_detector_t; + + /** + * @brief Create detector's handle + * @param[in] model an instance of mmdetection sdk model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. 
+ * @param[out] detector instance of a detector
+ * @return status of creating detector's handle
+ */
+ MMDEPLOY_API int mmdeploy_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_detector_t* detector);
+
+ /**
+ * @brief Create detector's handle
+ * @param[in] model_path path of mmdetection sdk model exported by mmdeploy model converter
+ * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+ * @param[in] device_id id of device.
+ * @param[out] detector instance of a detector
+ * @return status of creating detector's handle
+ */
+ MMDEPLOY_API int mmdeploy_detector_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_detector_t* detector);
+
+ /**
+ * @brief Apply detector to batch images and get their inference results
+ * @param[in] detector detector's handle created by \ref mmdeploy_detector_create_by_path
+ * @param[in] mats a batch of images
+ * @param[in] mat_count number of images in the batch
+ * @param[out] results a linear buffer to save detection results of each image. It must be released
+ * by \ref mmdeploy_detector_release_result
+ * @param[out] result_count a linear buffer with length being \p mat_count to save the number of
+ * detection results of each image. It must be released by \ref
+ * mmdeploy_detector_release_result
+ * @return status of inference
+ */
+ MMDEPLOY_API int mmdeploy_detector_apply(mmdeploy_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_detection_t** results, int** result_count);
+
+ /** @brief Release the inference result buffer created by \ref mmdeploy_detector_apply
+ * @param[in] results detection results buffer
+ * @param[in] result_count \p results size buffer
+ * @param[in] count length of \p result_count
+ */
+ MMDEPLOY_API void mmdeploy_detector_release_result(mmdeploy_detection_t* results,
+ const int* result_count,
+ int count);
+
+ /**
+ * @brief Destroy detector's handle
+ * @param[in] detector detector's handle created by \ref mmdeploy_detector_create_by_path
+ */
+ MMDEPLOY_API void mmdeploy_detector_destroy(mmdeploy_detector_t detector);
+
+ /******************************************************************************
+ * Experimental asynchronous APIs */
+
+ /**
+ * @brief Same as \ref mmdeploy_detector_create, but allows controlling the execution context of
+ * tasks via context
+ */
+ MMDEPLOY_API int mmdeploy_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_detector_t* detector);
+
+ /**
+ * @brief Pack detector inputs into mmdeploy_value_t
+ * @param[in] mats a batch of images
+ * @param[in] mat_count number of images in the batch
+ * @param[out] input the packed value
+ * @return status of the operation
+ */
+ MMDEPLOY_API int mmdeploy_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input);
+
+ /**
+ * @brief Same as \ref mmdeploy_detector_apply, but input and output are packed in \ref
+ * mmdeploy_value_t.
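+ * A minimal sketch (assuming `detector`, `mats` and `mat_count` are already set up;
+ * error handling omitted):
+ * @code
+ *   mmdeploy_value_t input = NULL;
+ *   mmdeploy_value_t output = NULL;
+ *   mmdeploy_detection_t* dets = NULL;
+ *   int* det_count = NULL;
+ *   mmdeploy_detector_create_input(mats, mat_count, &input);
+ *   mmdeploy_detector_apply_v2(detector, input, &output);
+ *   mmdeploy_detector_get_result(output, &dets, &det_count);
+ *   // dets[i].mask may be NULL when the model predicts no instance mask
+ *   mmdeploy_detector_release_result(dets, det_count, mat_count);
+ *   mmdeploy_value_destroy(input);
+ *   mmdeploy_value_destroy(output);
+ * @endcode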
+ */ + MMDEPLOY_API int mmdeploy_detector_apply_v2(mmdeploy_detector_t detector, mmdeploy_value_t input, mmdeploy_value_t* output); + + /** + * @brief Apply detector asynchronously + * @param[in] detector handle to the detector + * @param[in] input input sender + * @return output sender + */ + MMDEPLOY_API int mmdeploy_detector_apply_async(mmdeploy_detector_t detector, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); + + /** + * @brief Unpack detector output from a mmdeploy_value_t + * @param[in] output output obtained by applying a detector + * @param[out] results a linear buffer to save detection results of each image. It must be released + * by \ref mmdeploy_detector_release_result + * @param[out] result_count a linear buffer with length number of input images to save the number of + * detection results of each image. Must be released by \ref + * mmdeploy_detector_release_result + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_detector_get_result(mmdeploy_value_t output, + mmdeploy_detection_t** results, + int** result_count); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/executor.cpp b/csrc/mmdeploy/apis/c/mmdeploy/executor.cpp index 2fdfb9091f..e73ffe0606 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/executor.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/executor.cpp @@ -9,199 +9,261 @@ using namespace mmdeploy; -namespace { +namespace +{ -mmdeploy_scheduler_t CreateScheduler(const char* type, const Value& config = Value()) { - try { - auto creator = gRegistry().Get(type); - if (!creator) { - MMDEPLOY_ERROR("Creator for {} not found. Available schedulers: {}", type, - gRegistry().List()); - return nullptr; + mmdeploy_scheduler_t CreateScheduler(const char* type, const Value& config = Value()) + { + try + { + auto creator = gRegistry().Get(type); + if (!creator) + { + MMDEPLOY_ERROR("Creator for {} not found. 
Available schedulers: {}", type, gRegistry().List()); + return nullptr; + } + return Cast(new SchedulerType(creator->Create(config))); + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("failed to create Scheduler: {} ({}), config: {}", type, e.what(), config); + return nullptr; + } } - return Cast(new SchedulerType(creator->Create(config))); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("failed to create Scheduler: {} ({}), config: {}", type, e.what(), config); - return nullptr; - } -} } // namespace -mmdeploy_sender_t mmdeploy_sender_copy(mmdeploy_sender_t input) { - if (!input) { - return nullptr; - } - return Take(SenderType(*Cast(input))); +mmdeploy_sender_t mmdeploy_sender_copy(mmdeploy_sender_t input) +{ + if (!input) + { + return nullptr; + } + return Take(SenderType(*Cast(input))); } -int mmdeploy_sender_destroy(mmdeploy_sender_t sender) { - delete Cast(sender); - return 0; +int mmdeploy_sender_destroy(mmdeploy_sender_t sender) +{ + delete Cast(sender); + return 0; } -mmdeploy_scheduler_t mmdeploy_executor_inline() { return CreateScheduler("Inline"); } +mmdeploy_scheduler_t mmdeploy_executor_inline() +{ + return CreateScheduler("Inline"); +} -mmdeploy_scheduler_t mmdeploy_executor_system_pool() { - // create a thread pool context and hold its shared handle - static auto scheduler = *Cast(CreateScheduler("ThreadPool")); - // return a copy of the handle to the thread pool - return Cast(new SchedulerType(scheduler)); +mmdeploy_scheduler_t mmdeploy_executor_system_pool() +{ + // create a thread pool context and hold its shared handle + static auto scheduler = *Cast(CreateScheduler("ThreadPool")); + // return a copy of the handle to the thread pool + return Cast(new SchedulerType(scheduler)); } -mmdeploy_scheduler_t mmdeploy_executor_create_thread_pool(int num_threads) { - return CreateScheduler("ThreadPool", {{"num_threads", num_threads}}); +mmdeploy_scheduler_t mmdeploy_executor_create_thread_pool(int num_threads) +{ + return CreateScheduler("ThreadPool", {{"num_threads", num_threads}}); } -mmdeploy_scheduler_t mmdeploy_executor_create_thread() { return CreateScheduler("SingleThread"); } +mmdeploy_scheduler_t mmdeploy_executor_create_thread() +{ + return CreateScheduler("SingleThread"); +} mmdeploy_scheduler_t mmdeploy_executor_dynamic_batch(mmdeploy_scheduler_t scheduler, - int max_batch_size, int timeout) { - if (!scheduler) { - return nullptr; - } - return CreateScheduler( - "DynamicBatch", - {{"scheduler", *Cast(scheduler)}, {"max_batch_size", max_batch_size}, {"timeout", timeout}}); + int max_batch_size, + int timeout) +{ + if (!scheduler) + { + return nullptr; + } + return CreateScheduler( + "DynamicBatch", + {{"scheduler", *Cast(scheduler)}, {"max_batch_size", max_batch_size}, {"timeout", timeout}}); } -int mmdeploy_scheduler_destroy(mmdeploy_scheduler_t scheduler) { - delete Cast(scheduler); - return 0; +int mmdeploy_scheduler_destroy(mmdeploy_scheduler_t scheduler) +{ + delete Cast(scheduler); + return 0; } -mmdeploy_sender_t mmdeploy_executor_just(mmdeploy_value_t value) { - if (value) { - return Guard([&] { return Take(Just(*Cast(value))); }); - } else { - return Take(Just(Value())); - } +mmdeploy_sender_t mmdeploy_executor_just(mmdeploy_value_t value) +{ + if (value) + { + return Guard([&] + { return Take(Just(*Cast(value))); }); + } + else + { + return Take(Just(Value())); + } } -mmdeploy_sender_t mmdeploy_executor_schedule(mmdeploy_scheduler_t scheduler) { - if (!scheduler) { - return nullptr; - } - return Guard([&] { return 
Take(Then(Schedule(*Cast(scheduler)), [] { return Value(); })); }); +mmdeploy_sender_t mmdeploy_executor_schedule(mmdeploy_scheduler_t scheduler) +{ + if (!scheduler) + { + return nullptr; + } + return Guard([&] + { return Take(Then(Schedule(*Cast(scheduler)), [] + { return Value(); })); }); } mmdeploy_sender_t mmdeploy_executor_transfer_just(mmdeploy_scheduler_t scheduler, - mmdeploy_value_t value) { - if (!scheduler || !value) { - return nullptr; - } - return Guard([&] { return Take(TransferJust(*Cast(scheduler), *Cast(value))); }); -} - -mmdeploy_sender_t mmdeploy_executor_transfer(mmdeploy_sender_t input, - mmdeploy_scheduler_t scheduler) { - if (!input || !scheduler) { - return nullptr; - } - return Guard([&] { return Take(Transfer(Take(input), *Cast(scheduler))); }); -} - -mmdeploy_sender_t mmdeploy_executor_on(mmdeploy_scheduler_t scheduler, mmdeploy_sender_t input) { - if (!scheduler || !input) { - return nullptr; - } - return Guard([&] { return Take(On(*Cast(scheduler), Take(input))); }); -} - -mmdeploy_sender_t mmdeploy_executor_then(mmdeploy_sender_t input, mmdeploy_then_fn_t fn, - void* context) { - if (!input || !fn) { - return nullptr; - } - return Guard([&] { - return Take(Then(Take(input), [fn, context](Value args) { + mmdeploy_value_t value) +{ + if (!scheduler || !value) + { + return nullptr; + } + return Guard([&] + { return Take(TransferJust(*Cast(scheduler), *Cast(value))); }); +} + +mmdeploy_sender_t mmdeploy_executor_transfer(mmdeploy_sender_t input, + mmdeploy_scheduler_t scheduler) +{ + if (!input || !scheduler) + { + return nullptr; + } + return Guard([&] + { return Take(Transfer(Take(input), *Cast(scheduler))); }); +} + +mmdeploy_sender_t mmdeploy_executor_on(mmdeploy_scheduler_t scheduler, mmdeploy_sender_t input) +{ + if (!scheduler || !input) + { + return nullptr; + } + return Guard([&] + { return Take(On(*Cast(scheduler), Take(input))); }); +} + +mmdeploy_sender_t mmdeploy_executor_then(mmdeploy_sender_t input, mmdeploy_then_fn_t fn, void* context) +{ + if (!input || !fn) + { + return nullptr; + } + return Guard([&] + { return Take(Then(Take(input), [fn, context](Value args) + { auto out = Cast(fn(Take(std::move(args)), context)); Value ret(std::move(*out)); delete out; - return ret; - })); - }); -} - -mmdeploy_sender_t mmdeploy_executor_let_value(mmdeploy_sender_t input, mmdeploy_let_value_fn_t fn, - void* context) { - if (!input || !fn) { - return nullptr; - } - return Guard([&] { - return Take(LetValue(Take(input), [fn, context](Value& args) { + return ret; })); }); +} + +mmdeploy_sender_t mmdeploy_executor_let_value(mmdeploy_sender_t input, mmdeploy_let_value_fn_t fn, void* context) +{ + if (!input || !fn) + { + return nullptr; + } + return Guard([&] + { return Take(LetValue(Take(input), [fn, context](Value& args) + { auto out = Cast(fn(Cast(&args), context)); SenderType ret(std::move(*out)); delete out; - return ret; - })); - }); + return ret; })); }); } -mmdeploy_sender_t mmdeploy_executor_split(mmdeploy_sender_t input) { - if (!input) { - return nullptr; - } - return Guard([&] { return Take(Split(Take(input))); }); +mmdeploy_sender_t mmdeploy_executor_split(mmdeploy_sender_t input) +{ + if (!input) + { + return nullptr; + } + return Guard([&] + { return Take(Split(Take(input))); }); } -mmdeploy_sender_t mmdeploy_executor_when_all(mmdeploy_sender_t inputs[], int32_t n) { - if (!inputs) { - return nullptr; - } - return Guard([&] { +mmdeploy_sender_t mmdeploy_executor_when_all(mmdeploy_sender_t inputs[], int32_t n) +{ + if (!inputs) + { + return nullptr; 
+ } + return Guard([&] + { std::vector senders; senders.reserve(n); for (int i = 0; i < n; ++i) { senders.emplace_back(Take(inputs[i])); } return Take( - Then(WhenAll(std::move(senders)), [](Value::Array&& v) { return Value(std::move(v)); })); - }); + Then(WhenAll(std::move(senders)), [](Value::Array&& v) { return Value(std::move(v)); })); }); } -mmdeploy_sender_t mmdeploy_executor_ensure_started(mmdeploy_sender_t input) { - if (!input) { - return nullptr; - } - return Guard([&] { return Take(EnsureStarted(Take(input))); }); +mmdeploy_sender_t mmdeploy_executor_ensure_started(mmdeploy_sender_t input) +{ + if (!input) + { + return nullptr; + } + return Guard([&] + { return Take(EnsureStarted(Take(input))); }); } -int mmdeploy_executor_start_detached(mmdeploy_sender_t input) { - if (!input) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - StartDetached(Take(input)); - return 0; - } catch (...) { - } - return MMDEPLOY_E_FAIL; +int mmdeploy_executor_start_detached(mmdeploy_sender_t input) +{ + if (!input) + { + return MMDEPLOY_E_INVALID_ARG; + } + try + { + StartDetached(Take(input)); + return 0; + } + catch (...) + { + } + return MMDEPLOY_E_FAIL; } -mmdeploy_value_t mmdeploy_executor_sync_wait(mmdeploy_sender_t input) { - if (!input) { - return nullptr; - } - return Guard([&] { return Take(std::get(SyncWait(Take(input)))); }); +mmdeploy_value_t mmdeploy_executor_sync_wait(mmdeploy_sender_t input) +{ + if (!input) + { + return nullptr; + } + return Guard([&] + { return Take(std::get(SyncWait(Take(input)))); }); } -int mmdeploy_executor_sync_wait_v2(mmdeploy_sender_t sender, mmdeploy_value_t* value) { - if (!sender) { - return MMDEPLOY_E_INVALID_ARG; - } - auto result = mmdeploy_executor_sync_wait(sender); - if (!result) { - return MMDEPLOY_E_FAIL; - } - if (value) { - *value = result; - } else { - mmdeploy_value_destroy(result); - } - return MMDEPLOY_SUCCESS; +int mmdeploy_executor_sync_wait_v2(mmdeploy_sender_t sender, mmdeploy_value_t* value) +{ + if (!sender) + { + return MMDEPLOY_E_INVALID_ARG; + } + auto result = mmdeploy_executor_sync_wait(sender); + if (!result) + { + return MMDEPLOY_E_FAIL; + } + if (value) + { + *value = result; + } + else + { + mmdeploy_value_destroy(result); + } + return MMDEPLOY_SUCCESS; } -void mmdeploy_executor_execute(mmdeploy_scheduler_t scheduler, void (*fn)(void*), void* context) { - Execute(*Cast(scheduler), [fn, context] { fn(context); }); +void mmdeploy_executor_execute(mmdeploy_scheduler_t scheduler, void (*fn)(void*), void* context) +{ + Execute(*Cast(scheduler), [fn, context] + { fn(context); }); } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/executor.h b/csrc/mmdeploy/apis/c/mmdeploy/executor.h index a2c8ffa387..4b044a6b51 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/executor.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/executor.h @@ -6,133 +6,135 @@ #include "mmdeploy/common.h" #if __cplusplus -extern "C" { +extern "C" +{ #endif -/****************************************************************************** - * Experimental asynchronous APIs */ + /****************************************************************************** + * Experimental asynchronous APIs */ -typedef mmdeploy_value_t (*mmdeploy_then_fn_t)(mmdeploy_value_t, void*); + typedef mmdeploy_value_t (*mmdeploy_then_fn_t)(mmdeploy_value_t, void*); -typedef mmdeploy_value_t (*mmdeploy_then_fn_v2_t)(mmdeploy_value_t*, void*); - -typedef int (*mmdeploy_then_fn_v3_t)(mmdeploy_value_t* input, mmdeploy_value_t* output, void*); + typedef mmdeploy_value_t (*mmdeploy_then_fn_v2_t)(mmdeploy_value_t*, void*); + 
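+ /*
+  * A sketch of a then-callback used with \ref mmdeploy_executor_then below (illustrative
+  * only; the callback takes ownership of the value it receives and must return a value,
+  * whose ownership passes back to the pipeline):
+  *
+  *   mmdeploy_value_t on_done(mmdeploy_value_t value, void* context) {
+  *     // inspect or transform `value` here
+  *     return value;
+  *   }
+  *
+  *   output_sender = mmdeploy_executor_then(input_sender, on_done, NULL);
+  */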
+ typedef int (*mmdeploy_then_fn_v3_t)(mmdeploy_value_t* input, mmdeploy_value_t* output, void*);
+
+ struct mmdeploy_sender;
+ struct mmdeploy_scheduler;
+
+ typedef struct mmdeploy_sender* mmdeploy_sender_t;
+ typedef struct mmdeploy_scheduler* mmdeploy_scheduler_t;
-struct mmdeploy_sender;
-struct mmdeploy_scheduler;
+ typedef mmdeploy_sender_t (*mmdeploy_let_value_fn_t)(mmdeploy_value_t, void*);
-typedef struct mmdeploy_sender* mmdeploy_sender_t;
-typedef struct mmdeploy_scheduler* mmdeploy_scheduler_t;
+ ///////////////////////////////////////////////////////////////////////////////
+ // Scheduler
+ ///////////////////////////////////////////////////////////////////////////////
+ MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_inline();
-typedef mmdeploy_sender_t (*mmdeploy_let_value_fn_t)(mmdeploy_value_t, void*);
+ MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_system_pool();
-///////////////////////////////////////////////////////////////////////////////
-// Scheduler
-///////////////////////////////////////////////////////////////////////////////
-MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_inline();
+ /**
+ * Create a thread pool with the given number of worker threads
+ * @param[in] num_threads
+ * @return the handle to the created thread pool
+ */
+ MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_create_thread_pool(int num_threads);
-MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_system_pool();
+ MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_create_thread();
-/**
- * Create a thread pool with the given number of worker threads
- * @param[in] num_threads
- * @return the handle to the created thread pool
- */
-MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_create_thread_pool(int num_threads);
+ MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_dynamic_batch(mmdeploy_scheduler_t scheduler,
+ int max_batch_size,
+ int timeout);
-MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_create_thread();
+ MMDEPLOY_API int mmdeploy_scheduler_destroy(mmdeploy_scheduler_t scheduler);
-MMDEPLOY_API mmdeploy_scheduler_t mmdeploy_executor_dynamic_batch(mmdeploy_scheduler_t scheduler,
- int max_batch_size, int timeout);
+ ///////////////////////////////////////////////////////////////////////////////
+ // Utilities
+ ///////////////////////////////////////////////////////////////////////////////
-MMDEPLOY_API int mmdeploy_scheduler_destroy(mmdeploy_scheduler_t scheduler);
+ /**
+ * @brief Create a copy of a copyable sender. Only senders created by \ref mmdeploy_executor_split
+ * are copyable for now.
+ * @param[in] input copyable sender
+ * @return the sender created, or nullptr if the sender is not copyable
+ */
+ MMDEPLOY_API mmdeploy_sender_t mmdeploy_sender_copy(mmdeploy_sender_t input);
-///////////////////////////////////////////////////////////////////////////////
-// Utilities
-///////////////////////////////////////////////////////////////////////////////
+ /**
+ * @brief Destroy a sender. Note that all sender adapters consume their input senders; only unused
+ * senders should be destroyed using this function.
+ * @param[in] sender
+ */
+ MMDEPLOY_API int mmdeploy_sender_destroy(mmdeploy_sender_t sender);
-/**
- * @brief Create a copy of a copyable sender. Only senders created by \ref mmdeploy_executor_split
- * is copyable for now.
- * @param[in] input copyable sender, - * @return the sender created, or nullptr if the sender is not copyable - */ -MMDEPLOY_API mmdeploy_sender_t mmdeploy_sender_copy(mmdeploy_sender_t input); + /////////////////////////////////////////////////////////////////////////////// + // Sender factories + /////////////////////////////////////////////////////////////////////////////// -/** - * @brief Destroy a sender, notice that all sender adapters will consume input senders, only unused - * senders should be destroyed using this function. - * @param[in] input - */ -MMDEPLOY_API int mmdeploy_sender_destroy(mmdeploy_sender_t sender); + /** + * @brief Create a sender that sends the provided value + * @param[in] value + * @return created sender + */ + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_just(mmdeploy_value_t value); -/////////////////////////////////////////////////////////////////////////////// -// Sender factories -/////////////////////////////////////////////////////////////////////////////// + /** + * @brief + * @param[in] scheduler + * @return the sender created + */ + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_schedule(mmdeploy_scheduler_t scheduler); -/** - * @brief Create a sender that sends the provided value - * @param[in] value - * @return created sender - */ -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_just(mmdeploy_value_t value); + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_transfer_just(mmdeploy_scheduler_t scheduler, + mmdeploy_value_t value); -/** - * @brief - * @param[in] scheduler - * @return the sender created - */ -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_schedule(mmdeploy_scheduler_t scheduler); + /////////////////////////////////////////////////////////////////////////////// + // Sender adapters + /////////////////////////////////////////////////////////////////////////////// -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_transfer_just(mmdeploy_scheduler_t scheduler, - mmdeploy_value_t value); + /** + * Transfer the execution to the execution agent of the provided scheduler + * @param[in] input + * @param[in] scheduler + * @return the sender created + */ + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_transfer(mmdeploy_sender_t input, + mmdeploy_scheduler_t scheduler); -/////////////////////////////////////////////////////////////////////////////// -// Sender adapters -/////////////////////////////////////////////////////////////////////////////// + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_on(mmdeploy_scheduler_t scheduler, + mmdeploy_sender_t input); -/** - * Transfer the execution to the execution agent of the provided scheduler - * @param[in] input - * @param[in] scheduler - * @return the sender created - */ -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_transfer(mmdeploy_sender_t input, - mmdeploy_scheduler_t scheduler); + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_then(mmdeploy_sender_t input, + mmdeploy_then_fn_t fn, + void* context); -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_on(mmdeploy_scheduler_t scheduler, - mmdeploy_sender_t input); + MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_let_value(mmdeploy_sender_t input, + mmdeploy_let_value_fn_t fn, + void* context); -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_then(mmdeploy_sender_t input, - mmdeploy_then_fn_t fn, void* context); - -MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_let_value(mmdeploy_sender_t input, - mmdeploy_let_value_fn_t fn, - void* context); - -/** - * Convert the input sender into a sender that is copyable via \ref 
mmdeploy_sender_copy. Notice
- * that this function doesn't make the sender multi-shot, it just return a sender that is copyable.
- * @param[in] input
- * @return the sender that is copyable
- */
-MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_split(mmdeploy_sender_t input);
-
-MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_when_all(mmdeploy_sender_t inputs[], int32_t n);
-
-MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_ensure_started(mmdeploy_sender_t input);
-
-///////////////////////////////////////////////////////////////////////////////
-// Sender consumers
-///////////////////////////////////////////////////////////////////////////////
-MMDEPLOY_API int mmdeploy_executor_start_detached(mmdeploy_sender_t input);
-
-MMDEPLOY_API mmdeploy_value_t mmdeploy_executor_sync_wait(mmdeploy_sender_t input);
-
-MMDEPLOY_API int mmdeploy_executor_sync_wait_v2(mmdeploy_sender_t input, mmdeploy_value_t* output);
-
-MMDEPLOY_API void mmdeploy_executor_execute(mmdeploy_scheduler_t scheduler, void (*fn)(void*),
- void* context);
+ /**
+ * Convert the input sender into a sender that is copyable via \ref mmdeploy_sender_copy. Notice
+ * that this function doesn't make the sender multi-shot, it just returns a sender that is copyable.
+ * @param[in] input
+ * @return the sender that is copyable
+ */
+ MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_split(mmdeploy_sender_t input);
+
+ MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_when_all(mmdeploy_sender_t inputs[], int32_t n);
+
+ MMDEPLOY_API mmdeploy_sender_t mmdeploy_executor_ensure_started(mmdeploy_sender_t input);
+
+ ///////////////////////////////////////////////////////////////////////////////
+ // Sender consumers
+ ///////////////////////////////////////////////////////////////////////////////
+ MMDEPLOY_API int mmdeploy_executor_start_detached(mmdeploy_sender_t input);
+
+ MMDEPLOY_API mmdeploy_value_t mmdeploy_executor_sync_wait(mmdeploy_sender_t input);
+
+ MMDEPLOY_API int mmdeploy_executor_sync_wait_v2(mmdeploy_sender_t input, mmdeploy_value_t* output);
+
+ MMDEPLOY_API void mmdeploy_executor_execute(mmdeploy_scheduler_t scheduler, void (*fn)(void*), void* context);

#if __cplusplus
}

diff --git a/csrc/mmdeploy/apis/c/mmdeploy/executor_internal.h b/csrc/mmdeploy/apis/c/mmdeploy/executor_internal.h
index 95f39fe009..0ae8c2a529 100644
--- a/csrc/mmdeploy/apis/c/mmdeploy/executor_internal.h
+++ b/csrc/mmdeploy/apis/c/mmdeploy/executor_internal.h
@@ -8,33 +8,49 @@
 using namespace mmdeploy;

-using SenderType = TypeErasedSender<Value>;
+using SenderType = TypeErasedSender<Value>;
 using SchedulerType = TypeErasedScheduler<Value>;

-namespace {
-
-inline SchedulerType* Cast(mmdeploy_scheduler_t s) { return reinterpret_cast<SchedulerType*>(s); }
-
-inline mmdeploy_scheduler_t Cast(SchedulerType* s) {
- return reinterpret_cast<mmdeploy_scheduler_t>(s);
-}
-
-inline SenderType* Cast(mmdeploy_sender_t s) { return reinterpret_cast<SenderType*>(s); }
-
-inline mmdeploy_sender_t Cast(SenderType* s) { return reinterpret_cast<mmdeploy_sender_t>(s); }
-
-inline SenderType Take(mmdeploy_sender_t s) {
- auto sender = std::move(*Cast(s));
- mmdeploy_sender_destroy(s);
- return sender;
-}
-
-inline mmdeploy_sender_t Take(SenderType s) { return Cast(new SenderType(std::move(s))); }
-
-template , int> = 0>
-inline mmdeploy_sender_t Take(T& s) {
- return Take(SenderType(std::move(s)));
-}
+namespace
+{
+
+ inline SchedulerType* Cast(mmdeploy_scheduler_t s)
+ {
+ return reinterpret_cast<SchedulerType*>(s);
+ }
+
+ inline mmdeploy_scheduler_t Cast(SchedulerType* s)
+ {
+ return reinterpret_cast<mmdeploy_scheduler_t>(s);
+ }
+
+ inline SenderType* Cast(mmdeploy_sender_t s)
+ {
+ return
reinterpret_cast(s); + } + + inline mmdeploy_sender_t Cast(SenderType* s) + { + return reinterpret_cast(s); + } + + inline SenderType Take(mmdeploy_sender_t s) + { + auto sender = std::move(*Cast(s)); + mmdeploy_sender_destroy(s); + return sender; + } + + inline mmdeploy_sender_t Take(SenderType s) + { + return Cast(new SenderType(std::move(s))); + } + + template, int> = 0> + inline mmdeploy_sender_t Take(T& s) + { + return Take(SenderType(std::move(s))); + } } // namespace diff --git a/csrc/mmdeploy/apis/c/mmdeploy/handle.h b/csrc/mmdeploy/apis/c/mmdeploy/handle.h index 006ddaae3d..d2ccde1ef5 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/handle.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/handle.h @@ -11,42 +11,53 @@ #include "mmdeploy/graph/common.h" #include "mmdeploy/graph/static_router.h" -namespace mmdeploy { - -using namespace framework; - -namespace { - -class AsyncHandle { - public: - AsyncHandle(const char* device_name, int device_id, Value config) - : AsyncHandle(SetContext(std::move(config), device_name, device_id)) {} - - explicit AsyncHandle(const Value& config) { - if (auto builder = graph::Builder::CreateFromConfig(config).value()) { - node_ = builder->Build().value(); - } else { - MMDEPLOY_ERROR("failed to find creator for node"); - throw_exception(eEntryNotFound); - } - } - - graph::Sender Process(graph::Sender input) { - return node_->Process(std::move(input)); - } - - private: - static Value SetContext(Value config, const char* device_name, int device_id) { - Device device(device_name, device_id); - Stream stream(device); - config["context"].update({{"device", device}, {"stream", stream}}); - return config; - } - - std::unique_ptr node_; -}; - -} // namespace +namespace mmdeploy +{ + + using namespace framework; + + namespace + { + + class AsyncHandle + { + public: + AsyncHandle(const char* device_name, int device_id, Value config) + : AsyncHandle(SetContext(std::move(config), device_name, device_id)) + { + } + + explicit AsyncHandle(const Value& config) + { + if (auto builder = graph::Builder::CreateFromConfig(config).value()) + { + node_ = builder->Build().value(); + } + else + { + MMDEPLOY_ERROR("failed to find creator for node"); + throw_exception(eEntryNotFound); + } + } + + graph::Sender Process(graph::Sender input) + { + return node_->Process(std::move(input)); + } + + private: + static Value SetContext(Value config, const char* device_name, int device_id) + { + Device device(device_name, device_id); + Stream stream(device); + config["context"].update({{"device", device}, {"stream", stream}}); + return config; + } + + std::unique_ptr node_; + }; + + } // namespace } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/c/mmdeploy/model.cpp b/csrc/mmdeploy/apis/c/mmdeploy/model.cpp index 6d202bce81..08af517522 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/model.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/model.cpp @@ -12,30 +12,45 @@ using namespace mmdeploy; -int mmdeploy_model_create_by_path(const char* path, mmdeploy_model_t* model) { - try { - auto ptr = std::make_unique(path); - *model = reinterpret_cast(ptr.release()); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("failed to create model: {}", e.what()); - } catch (...) 
{ - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; +int mmdeploy_model_create_by_path(const char* path, mmdeploy_model_t* model) +{ + try + { + auto ptr = std::make_unique(path); + *model = reinterpret_cast(ptr.release()); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("failed to create model: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -int mmdeploy_model_create(const void* buffer, int size, mmdeploy_model_t* model) { - try { - auto ptr = std::make_unique(buffer, size); - *model = reinterpret_cast(ptr.release()); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("failed to create model: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; +int mmdeploy_model_create(const void* buffer, int size, mmdeploy_model_t* model) +{ + try + { + auto ptr = std::make_unique(buffer, size); + *model = reinterpret_cast(ptr.release()); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("failed to create model: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -void mmdeploy_model_destroy(mmdeploy_model_t model) { delete reinterpret_cast(model); } +void mmdeploy_model_destroy(mmdeploy_model_t model) +{ + delete reinterpret_cast(model); +} diff --git a/csrc/mmdeploy/apis/c/mmdeploy/model.h b/csrc/mmdeploy/apis/c/mmdeploy/model.h index 394d2902c2..ddea967f1a 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/model.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/model.h @@ -11,34 +11,35 @@ #include "mmdeploy/common.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_model* mmdeploy_model_t; - -/** - * @brief Create SDK Model instance from given model path - * @param[in] path model path - * @param[out] model sdk model instance that must be destroyed by \ref mmdeploy_model_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_model_create_by_path(const char* path, mmdeploy_model_t* model); - -/** - * @brief Create SDK Model instance from memory - * @param[in] buffer a linear buffer contains the model information - * @param[in] size size of \p buffer in bytes - * @param[out] model sdk model instance that must be destroyed by \ref mmdeploy_model_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_model_create(const void* buffer, int size, mmdeploy_model_t* model); - -/** - * @brief Destroy model instance - * @param[in] model sdk model instance created by \ref mmdeploy_model_create_by_path or \ref - * mmdeploy_model_create - */ -MMDEPLOY_API void mmdeploy_model_destroy(mmdeploy_model_t model); + typedef struct mmdeploy_model* mmdeploy_model_t; + + /** + * @brief Create SDK Model instance from given model path + * @param[in] path model path + * @param[out] model sdk model instance that must be destroyed by \ref mmdeploy_model_destroy + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_model_create_by_path(const char* path, mmdeploy_model_t* model); + + /** + * @brief Create SDK Model instance from memory + * @param[in] buffer a linear buffer contains the model information + * @param[in] size size of \p buffer in bytes + * @param[out] model sdk model instance that must be destroyed by \ref mmdeploy_model_destroy + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_model_create(const 
void* buffer, int size, mmdeploy_model_t* model); + + /** + * @brief Destroy model instance + * @param[in] model sdk model instance created by \ref mmdeploy_model_create_by_path or \ref + * mmdeploy_model_create + */ + MMDEPLOY_API void mmdeploy_model_destroy(mmdeploy_model_t model); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pipeline.cpp b/csrc/mmdeploy/apis/c/mmdeploy/pipeline.cpp index a9a02807ee..b0d3d6a220 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pipeline.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/pipeline.cpp @@ -6,73 +6,90 @@ #include "mmdeploy/executor_internal.h" #include "mmdeploy/handle.h" -int mmdeploy_pipeline_create_v3(mmdeploy_value_t config, mmdeploy_context_t context, - mmdeploy_pipeline_t* pipeline) { - try { - auto _config = *Cast(config); - if (context) { - if (!_config.contains("context")) { - _config["context"] = Value::Object(); - } - update(_config["context"].object(), Cast(context)->object(), 2); +int mmdeploy_pipeline_create_v3(mmdeploy_value_t config, mmdeploy_context_t context, mmdeploy_pipeline_t* pipeline) +{ + try + { + auto _config = *Cast(config); + if (context) + { + if (!_config.contains("context")) + { + _config["context"] = Value::Object(); + } + update(_config["context"].object(), Cast(context)->object(), 2); + } + auto _handle = std::make_unique(std::move(_config)); + *pipeline = Cast(_handle.release()); + return MMDEPLOY_SUCCESS; } - auto _handle = std::make_unique(std::move(_config)); - *pipeline = Cast(_handle.release()); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("exception caught: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + catch (const std::exception& e) + { + MMDEPLOY_ERROR("exception caught: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -int mmdeploy_pipeline_create_from_model(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_pipeline_t* pipeline) { - auto config = Cast(model)->ReadConfig("pipeline.json"); - auto _context = *Cast(context); - _context["model"] = *Cast(model); - return mmdeploy_pipeline_create_v3(Cast(&config.value()), (mmdeploy_context_t)&_context, - pipeline); +int mmdeploy_pipeline_create_from_model(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_pipeline_t* pipeline) +{ + auto config = Cast(model)->ReadConfig("pipeline.json"); + auto _context = *Cast(context); + _context["model"] = *Cast(model); + return mmdeploy_pipeline_create_v3(Cast(&config.value()), (mmdeploy_context_t)&_context, pipeline); } -int mmdeploy_pipeline_apply_async(mmdeploy_pipeline_t pipeline, mmdeploy_sender_t input, - mmdeploy_sender_t* output) { - if (!pipeline || !input || !output) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - auto h = Cast(pipeline); - *output = Take(h->Process(Take(input))); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("exception caught: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; +int mmdeploy_pipeline_apply_async(mmdeploy_pipeline_t pipeline, mmdeploy_sender_t input, mmdeploy_sender_t* output) +{ + if (!pipeline || !input || !output) + { + return MMDEPLOY_E_INVALID_ARG; + } + try + { + auto h = Cast(pipeline); + *output = Take(h->Process(Take(input))); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("exception caught: {}", e.what()); + } + catch (...) 
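The model and pipeline entry points above compose in a fixed order: load (or wrap) a model, build a pipeline from it, then release the model handle. A minimal lifecycle sketch, not part of this patch; "model_dir" is a placeholder path, the context is assumed to be created elsewhere, and error handling is reduced to early returns:

#include "mmdeploy/model.h"
#include "mmdeploy/pipeline.h"

/* Build a pipeline from an on-disk SDK model. */
static int build_pipeline(mmdeploy_context_t ctx, mmdeploy_pipeline_t* pipeline)
{
    mmdeploy_model_t model = NULL;
    int ec = mmdeploy_model_create_by_path("model_dir", &model);
    if (ec != MMDEPLOY_SUCCESS)
    {
        return ec;
    }
    /* reads the pipeline.json packaged inside the model, as in
       mmdeploy_pipeline_create_from_model above */
    ec = mmdeploy_pipeline_create_from_model(model, ctx, pipeline);
    mmdeploy_model_destroy(model); /* the pipeline keeps the model via its context,
                                      mirroring the create-then-destroy pattern used
                                      by pose_detector.cpp below */
    return ec;
}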
+ { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -void mmdeploy_pipeline_destroy(mmdeploy_pipeline_t pipeline) { - if (pipeline != nullptr) { - delete Cast(pipeline); - } +void mmdeploy_pipeline_destroy(mmdeploy_pipeline_t pipeline) +{ + if (pipeline != nullptr) + { + delete Cast(pipeline); + } } -int mmdeploy_pipeline_apply(mmdeploy_pipeline_t pipeline, mmdeploy_value_t input, - mmdeploy_value_t* output) { - auto input_sender = mmdeploy_executor_just(input); - if (!input_sender) { - return MMDEPLOY_E_FAIL; - } - mmdeploy_sender_t output_sender{}; - if (auto ec = mmdeploy_pipeline_apply_async(pipeline, input_sender, &output_sender)) { - return ec; - } - auto _output = mmdeploy_executor_sync_wait(output_sender); - if (!_output) { - return MMDEPLOY_E_FAIL; - } - *output = _output; - return MMDEPLOY_SUCCESS; +int mmdeploy_pipeline_apply(mmdeploy_pipeline_t pipeline, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + auto input_sender = mmdeploy_executor_just(input); + if (!input_sender) + { + return MMDEPLOY_E_FAIL; + } + mmdeploy_sender_t output_sender{}; + if (auto ec = mmdeploy_pipeline_apply_async(pipeline, input_sender, &output_sender)) + { + return ec; + } + auto _output = mmdeploy_executor_sync_wait(output_sender); + if (!_output) + { + return MMDEPLOY_E_FAIL; + } + *output = _output; + return MMDEPLOY_SUCCESS; } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pipeline.h b/csrc/mmdeploy/apis/c/mmdeploy/pipeline.h index 55ccf1e67c..faf523863f 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pipeline.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/pipeline.h @@ -8,59 +8,59 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -/****************************************************************************** - * Experimental pipeline APIs */ + /****************************************************************************** + * Experimental pipeline APIs */ -typedef struct mmdeploy_pipeline* mmdeploy_pipeline_t; + typedef struct mmdeploy_pipeline* mmdeploy_pipeline_t; -/** - * Create pipeline - * @param config - * @param context - * @param pipeline - * @return - */ -MMDEPLOY_API int mmdeploy_pipeline_create_v3(mmdeploy_value_t config, mmdeploy_context_t context, - mmdeploy_pipeline_t* pipeline); -/** - * Create pipeline from internal pipeline config of the model - * @param model - * @param context - * @param pipeline - * @return - */ -MMDEPLOY_API int mmdeploy_pipeline_create_from_model(mmdeploy_model_t model, - mmdeploy_context_t context, - mmdeploy_pipeline_t* pipeline); + /** + * Create pipeline + * @param config + * @param context + * @param pipeline + * @return + */ + MMDEPLOY_API int mmdeploy_pipeline_create_v3(mmdeploy_value_t config, mmdeploy_context_t context, mmdeploy_pipeline_t* pipeline); + /** + * Create pipeline from internal pipeline config of the model + * @param model + * @param context + * @param pipeline + * @return + */ + MMDEPLOY_API int mmdeploy_pipeline_create_from_model(mmdeploy_model_t model, + mmdeploy_context_t context, + mmdeploy_pipeline_t* pipeline); -/** - * @brief Apply pipeline - * @param[in] pipeline handle of the pipeline - * @param[in] input input value - * @param[out] output output value - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pipeline_apply(mmdeploy_pipeline_t pipeline, mmdeploy_value_t input, - mmdeploy_value_t* output); + /** + * @brief Apply pipeline + * @param[in] pipeline handle of the pipeline + * @param[in] input input value + * @param[out] output output value + * @return 
status of the operation + */ + MMDEPLOY_API int mmdeploy_pipeline_apply(mmdeploy_pipeline_t pipeline, mmdeploy_value_t input, mmdeploy_value_t* output); -/** - * Apply pipeline asynchronously - * @param pipeline handle of the pipeline - * @param input input sender that will be consumed by the operation - * @param output output sender - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pipeline_apply_async(mmdeploy_pipeline_t pipeline, - mmdeploy_sender_t input, mmdeploy_sender_t* output); + /** + * Apply pipeline asynchronously + * @param pipeline handle of the pipeline + * @param input input sender that will be consumed by the operation + * @param output output sender + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_pipeline_apply_async(mmdeploy_pipeline_t pipeline, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); -/** - * @brief destroy pipeline - * @param[in] pipeline - */ -MMDEPLOY_API void mmdeploy_pipeline_destroy(mmdeploy_pipeline_t pipeline); + /** + * @brief destroy pipeline + * @param[in] pipeline + */ + MMDEPLOY_API void mmdeploy_pipeline_destroy(mmdeploy_pipeline_t pipeline); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.cpp b/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.cpp index 46f9921e62..ee0cc0c564 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.cpp @@ -16,164 +16,197 @@ using namespace std; using namespace mmdeploy; -int mmdeploy_pose_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_pose_detector_t* detector) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_pose_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_pose_detector_t* detector) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_pose_detector_create_v2(model, context, detector); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_pose_detector_create_v2(model, context, detector); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_pose_detector_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_pose_detector_t* detector) { - mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { +int mmdeploy_pose_detector_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_pose_detector_t* detector) +{ + mmdeploy_model_t model{}; + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_pose_detector_create(model, device_name, device_id, detector); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_pose_detector_create(model, device_name, device_id, detector); - mmdeploy_model_destroy(model); - return ec; } -int mmdeploy_pose_detector_apply(mmdeploy_pose_detector_t detector, const mmdeploy_mat_t* mats, - int mat_count, mmdeploy_pose_detection_t** results) { - return mmdeploy_pose_detector_apply_bbox(detector, mats, mat_count, nullptr, nullptr, results); +int mmdeploy_pose_detector_apply(mmdeploy_pose_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_pose_detection_t** results) +{ + return mmdeploy_pose_detector_apply_bbox(detector, mats, 
mat_count, nullptr, nullptr, results); } -int mmdeploy_pose_detector_apply_bbox(mmdeploy_pose_detector_t detector, const mmdeploy_mat_t* mats, - int mat_count, const mmdeploy_rect_t* bboxes, - const int* bbox_count, mmdeploy_pose_detection_t** results) { - wrapped input; - if (auto ec = - mmdeploy_pose_detector_create_input(mats, mat_count, bboxes, bbox_count, input.ptr())) { - return ec; - } - wrapped output; - if (auto ec = mmdeploy_pose_detector_apply_v2(detector, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_pose_detector_get_result(output, results)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int mmdeploy_pose_detector_apply_bbox(mmdeploy_pose_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, const mmdeploy_rect_t* bboxes, const int* bbox_count, mmdeploy_pose_detection_t** results) +{ + wrapped input; + if (auto ec = + mmdeploy_pose_detector_create_input(mats, mat_count, bboxes, bbox_count, input.ptr())) + { + return ec; + } + wrapped output; + if (auto ec = mmdeploy_pose_detector_apply_v2(detector, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_pose_detector_get_result(output, results)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -void mmdeploy_pose_detector_release_result(mmdeploy_pose_detection_t* results, int count) { - if (results == nullptr) { - return; - } - for (int i = 0; i < count; ++i) { - delete[] results[i].point; - delete[] results[i].score; - } - delete[] results; +void mmdeploy_pose_detector_release_result(mmdeploy_pose_detection_t* results, int count) +{ + if (results == nullptr) + { + return; + } + for (int i = 0; i < count; ++i) + { + delete[] results[i].point; + delete[] results[i].score; + } + delete[] results; } -void mmdeploy_pose_detector_destroy(mmdeploy_pose_detector_t detector) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector); +void mmdeploy_pose_detector_destroy(mmdeploy_pose_detector_t detector) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector); } -int mmdeploy_pose_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_pose_detector_t* detector) { - return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); +int mmdeploy_pose_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_pose_detector_t* detector) +{ + return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector); } -int mmdeploy_pose_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - const mmdeploy_rect_t* bboxes, const int* bbox_count, - mmdeploy_value_t* value) { - if (mat_count && mats == nullptr) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - Value::Array input_images; - - auto add_bbox = [&](const Mat& img, const mmdeploy_rect_t* bbox) { - Value::Array b; - if (bbox) { - float width = bbox->right - bbox->left + 1; - float height = bbox->bottom - bbox->top + 1; - b = {bbox->left, bbox->top, width, height, 1.0}; - } else { - b = {0, 0, img.width(), img.height(), 1.0}; - } - input_images.push_back({{"ori_img", img}, {"bbox", std::move(b)}}); - }; - - for (int i = 0; i < mat_count; ++i) { - auto _mat = Cast(mats[i]); - if (bboxes && bbox_count) { - for (int j = 0; j < bbox_count[i]; ++j) { - add_bbox(_mat, bboxes++); - } - } else { // inference whole image - add_bbox(_mat, nullptr); - } +int mmdeploy_pose_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, const mmdeploy_rect_t* bboxes, const int* bbox_count, mmdeploy_value_t* value) +{ + if (mat_count && mats == 
nullptr) + { + return MMDEPLOY_E_INVALID_ARG; } + try + { + Value::Array input_images; + + auto add_bbox = [&](const Mat& img, const mmdeploy_rect_t* bbox) + { + Value::Array b; + if (bbox) + { + float width = bbox->right - bbox->left + 1; + float height = bbox->bottom - bbox->top + 1; + b = {bbox->left, bbox->top, width, height, 1.0}; + } + else + { + b = {0, 0, img.width(), img.height(), 1.0}; + } + input_images.push_back({{"ori_img", img}, {"bbox", std::move(b)}}); + }; + + for (int i = 0; i < mat_count; ++i) + { + auto _mat = Cast(mats[i]); + if (bboxes && bbox_count) + { + for (int j = 0; j < bbox_count[i]; ++j) + { + add_bbox(_mat, bboxes++); + } + } + else + { // inference whole image + add_bbox(_mat, nullptr); + } + } - *value = Take(Value{std::move(input_images)}); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + *value = Take(Value{std::move(input_images)}); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -int mmdeploy_pose_detector_apply_v2(mmdeploy_pose_detector_t detector, mmdeploy_value_t input, - mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); +int mmdeploy_pose_detector_apply_v2(mmdeploy_pose_detector_t detector, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output); } -int mmdeploy_pose_detector_apply_async(mmdeploy_pose_detector_t detector, mmdeploy_sender_t input, - mmdeploy_sender_t* output) { - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); +int mmdeploy_pose_detector_apply_async(mmdeploy_pose_detector_t detector, mmdeploy_sender_t input, mmdeploy_sender_t* output) +{ + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output); } -int mmdeploy_pose_detector_get_result(mmdeploy_value_t output, - mmdeploy_pose_detection_t** results) { - if (!output || !results) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - std::vector detections; - from_value(Cast(output)->front(), detections); - - size_t count = detections.size(); - - auto deleter = [&](mmdeploy_pose_detection_t* p) { - mmdeploy_pose_detector_release_result(p, static_cast(count)); - }; - - std::unique_ptr _results( - new mmdeploy_pose_detection_t[count]{}, deleter); - - size_t result_idx = 0; - for (const auto& bbox_result : detections) { - auto& res = _results[result_idx++]; - auto size = bbox_result.key_points.size(); - - res.point = new mmdeploy_point_t[size]; - res.score = new float[size]; - res.length = static_cast(size); - - for (int k = 0; k < size; k++) { - res.point[k].x = bbox_result.key_points[k].bbox[0]; - res.point[k].y = bbox_result.key_points[k].bbox[1]; - res.score[k] = bbox_result.key_points[k].score; - } +int mmdeploy_pose_detector_get_result(mmdeploy_value_t output, + mmdeploy_pose_detection_t** results) +{ + if (!output || !results) + { + return MMDEPLOY_E_INVALID_ARG; } + try + { + std::vector detections; + from_value(Cast(output)->front(), detections); + + size_t count = detections.size(); + + auto deleter = [&](mmdeploy_pose_detection_t* p) + { + mmdeploy_pose_detector_release_result(p, static_cast(count)); + }; + + std::unique_ptr _results( + new 
mmdeploy_pose_detection_t[count]{}, + deleter); + + size_t result_idx = 0; + for (const auto& bbox_result : detections) + { + auto& res = _results[result_idx++]; + auto size = bbox_result.key_points.size(); + + res.point = new mmdeploy_point_t[size]; + res.score = new float[size]; + res.length = static_cast(size); + + for (int k = 0; k < size; k++) + { + res.point[k].x = bbox_result.key_points[k].bbox[0]; + res.point[k].y = bbox_result.key_points[k].bbox[1]; + res.score[k] = bbox_result.key_points[k].score; + } + } - *results = _results.release(); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + *results = _results.release(); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.h b/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.h index ff0987cee4..6fceb99f72 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/pose_detector.h @@ -13,111 +13,113 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_pose_detection_t { - mmdeploy_point_t* point; ///< keypoint - float* score; ///< keypoint score - int length; ///< number of keypoint -} mmdeploy_pose_detection_t; - -typedef struct mmdeploy_pose_detector* mmdeploy_pose_detector_t; - -/** - * @brief Create a pose detector instance - * @param[in] model an instance of mmpose model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] detector handle of the created pose detector, which must be destroyed - * by \ref mmdeploy_pose_detector_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_pose_detector_create(mmdeploy_model_t model, const char* device_name, - int device_id, mmdeploy_pose_detector_t* detector); - -/** - * @brief Create a pose detector instance - * @param[in] model_path path to pose detection model - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. 
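The functions just shown wire create → apply → get_result → release, and with a null bbox list `apply` yields exactly one `mmdeploy_pose_detection_t` per input image. A hedged end-to-end sketch, not part of the patch; the pixel buffer and its dimensions are caller-supplied placeholders, and the `MMDEPLOY_PIXEL_FORMAT_BGR` / `MMDEPLOY_DATA_TYPE_UINT8` constants are assumed to come from common.h:

#include <stdint.h>
#include <stdio.h>
#include "mmdeploy/pose_detector.h"

static int detect_one_frame(uint8_t* pixels, int width, int height)
{
    mmdeploy_pose_detector_t detector = NULL;
    if (mmdeploy_pose_detector_create_by_path("pose_model_dir", "cpu", 0, &detector) != MMDEPLOY_SUCCESS)
    {
        return -1;
    }
    mmdeploy_mat_t img = {0}; /* field names follow restorer.cpp below */
    img.data    = pixels;     /* H x W x 3, BGR byte order */
    img.height  = height;
    img.width   = width;
    img.channel = 3;
    img.format  = MMDEPLOY_PIXEL_FORMAT_BGR;
    img.type    = MMDEPLOY_DATA_TYPE_UINT8;

    mmdeploy_pose_detection_t* poses = NULL;
    if (mmdeploy_pose_detector_apply(detector, &img, 1, &poses) == MMDEPLOY_SUCCESS)
    {
        for (int k = 0; k < poses[0].length; ++k) /* one result per image: whole-image roi */
        {
            printf("kpt %d: (%.1f, %.1f) score %.2f\n", k, poses[0].point[k].x, poses[0].point[k].y, poses[0].score[k]);
        }
        mmdeploy_pose_detector_release_result(poses, 1);
    }
    mmdeploy_pose_detector_destroy(detector);
    return 0;
}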
- * @param[out] detector handle of the created pose detector, which must be destroyed - * by \ref mmdeploy_pose_detector_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_pose_detector_create_by_path(const char* model_path, - const char* device_name, int device_id, - mmdeploy_pose_detector_t* detector); - -/** - * @brief Apply pose detector to a batch of images with full image roi - * @param[in] detector pose detector's handle created by \ref - * mmdeploy_pose_detector_create_by_path - * @param[in] images a batch of images - * @param[in] count number of images in the batch - * @param[out] results a linear buffer contains the pose result, must be release - * by \ref mmdeploy_pose_detector_release_result - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_pose_detector_apply(mmdeploy_pose_detector_t detector, - const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_pose_detection_t** results); - -/** - * @brief Apply pose detector to a batch of images supplied with bboxes(roi) - * @param[in] detector pose detector's handle created by \ref - * mmdeploy_pose_detector_create_by_path - * @param[in] images a batch of images - * @param[in] image_count number of images in the batch - * @param[in] bboxes bounding boxes(roi) detected by mmdet - * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images - * @param[out] results a linear buffer contains the pose result, which has the same length as \p - * bboxes, must be release by \ref mmdeploy_pose_detector_release_result - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_pose_detector_apply_bbox(mmdeploy_pose_detector_t detector, - const mmdeploy_mat_t* mats, int mat_count, - const mmdeploy_rect_t* bboxes, - const int* bbox_count, - mmdeploy_pose_detection_t** results); - -/** @brief Release result buffer returned by \ref mmdeploy_pose_detector_apply or \ref - * mmdeploy_pose_detector_apply_bbox - * @param[in] results result buffer by pose detector - * @param[in] count length of \p result - */ -MMDEPLOY_API void mmdeploy_pose_detector_release_result(mmdeploy_pose_detection_t* results, - int count); - -/** - * @brief destroy pose_detector - * @param[in] detector handle of pose_detector created by \ref - * mmdeploy_pose_detector_create_by_path or \ref mmdeploy_pose_detector_create - */ -MMDEPLOY_API void mmdeploy_pose_detector_destroy(mmdeploy_pose_detector_t detector); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -MMDEPLOY_API int mmdeploy_pose_detector_create_v2(mmdeploy_model_t model, - mmdeploy_context_t context, - mmdeploy_pose_detector_t* detector); - -MMDEPLOY_API int mmdeploy_pose_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - const mmdeploy_rect_t* bboxes, - const int* bbox_count, - mmdeploy_value_t* value); - -MMDEPLOY_API int mmdeploy_pose_detector_apply_v2(mmdeploy_pose_detector_t detector, - mmdeploy_value_t input, mmdeploy_value_t* output); - -MMDEPLOY_API int mmdeploy_pose_detector_apply_async(mmdeploy_pose_detector_t detector, - mmdeploy_sender_t input, - mmdeploy_sender_t* output); - -MMDEPLOY_API int mmdeploy_pose_detector_get_result(mmdeploy_value_t output, - mmdeploy_pose_detection_t** results); + typedef struct mmdeploy_pose_detection_t + { + mmdeploy_point_t* point; ///< keypoint + float* score; ///< keypoint score + int length; ///< number of keypoint + } mmdeploy_pose_detection_t; + + typedef struct mmdeploy_pose_detector* 
mmdeploy_pose_detector_t;
+
+    /**
+     * @brief Create a pose detector instance
+     * @param[in] model an instance of mmpose model created by
+     * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device.
+     * @param[out] detector handle of the created pose detector, which must be destroyed
+     * by \ref mmdeploy_pose_detector_destroy
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_pose_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_pose_detector_t* detector);
+
+    /**
+     * @brief Create a pose detector instance
+     * @param[in] model_path path to pose detection model
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device.
+     * @param[out] detector handle of the created pose detector, which must be destroyed
+     * by \ref mmdeploy_pose_detector_destroy
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_pose_detector_create_by_path(const char* model_path,
+                                                           const char* device_name,
+                                                           int device_id,
+                                                           mmdeploy_pose_detector_t* detector);
+
+    /**
+     * @brief Apply pose detector to a batch of images with full image roi
+     * @param[in] detector pose detector's handle created by \ref
+     * mmdeploy_pose_detector_create_by_path
+     * @param[in] images a batch of images
+     * @param[in] count number of images in the batch
+     * @param[out] results a linear buffer containing the pose results, must be released
+     * by \ref mmdeploy_pose_detector_release_result
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_pose_detector_apply(mmdeploy_pose_detector_t detector,
+                                                  const mmdeploy_mat_t* mats,
+                                                  int mat_count,
+                                                  mmdeploy_pose_detection_t** results);
+
+    /**
+     * @brief Apply pose detector to a batch of images supplied with bboxes(roi)
+     * @param[in] detector pose detector's handle created by \ref
+     * mmdeploy_pose_detector_create_by_path
+     * @param[in] images a batch of images
+     * @param[in] image_count number of images in the batch
+     * @param[in] bboxes bounding boxes(roi) detected by mmdet
+     * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images
+     * @param[out] results a linear buffer containing the pose results, which has the same length as \p
+     * bboxes, must be released by \ref mmdeploy_pose_detector_release_result
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_pose_detector_apply_bbox(mmdeploy_pose_detector_t detector,
+                                                       const mmdeploy_mat_t* mats,
+                                                       int mat_count,
+                                                       const mmdeploy_rect_t* bboxes,
+                                                       const int* bbox_count,
+                                                       mmdeploy_pose_detection_t** results);
+
+    /** @brief Release result buffer returned by \ref mmdeploy_pose_detector_apply or \ref
+     * mmdeploy_pose_detector_apply_bbox
+     * @param[in] results result buffer by pose detector
+     * @param[in] count length of \p results
+     */
+    MMDEPLOY_API void mmdeploy_pose_detector_release_result(mmdeploy_pose_detection_t* results,
+                                                            int count);
+
+    /**
+     * @brief destroy pose_detector
+     * @param[in] detector handle of pose_detector created by \ref
+     * mmdeploy_pose_detector_create_by_path or \ref mmdeploy_pose_detector_create
+     */
+    MMDEPLOY_API void mmdeploy_pose_detector_destroy(mmdeploy_pose_detector_t detector);
+
+    /******************************************************************************
+     * Experimental asynchronous APIs */
+
+    MMDEPLOY_API int mmdeploy_pose_detector_create_v2(mmdeploy_model_t model,
+    
mmdeploy_context_t context, + mmdeploy_pose_detector_t* detector); + + MMDEPLOY_API int mmdeploy_pose_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, const mmdeploy_rect_t* bboxes, const int* bbox_count, mmdeploy_value_t* value); + + MMDEPLOY_API int mmdeploy_pose_detector_apply_v2(mmdeploy_pose_detector_t detector, + mmdeploy_value_t input, + mmdeploy_value_t* output); + + MMDEPLOY_API int mmdeploy_pose_detector_apply_async(mmdeploy_pose_detector_t detector, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); + + MMDEPLOY_API int mmdeploy_pose_detector_get_result(mmdeploy_value_t output, + mmdeploy_pose_detection_t** results); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.cpp b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.cpp index 113b520c39..d2587b1949 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.cpp @@ -9,18 +9,21 @@ #include "mmdeploy/core/mpl/structure.h" #include "mmdeploy/pipeline.h" -namespace mmdeploy { +namespace mmdeploy +{ -using namespace framework; + using namespace framework; } // namespace mmdeploy using namespace mmdeploy; -namespace { +namespace +{ -Value config_template() { - static const auto json = R"( + Value config_template() + { + static const auto json = R"( { "type": "Pipeline", "input": ["img", "force_det", "state"], @@ -77,149 +80,184 @@ Value config_template() { ] } )"_json; - static const auto config = from_json(json); - return config; -} + static const auto config = from_json(json); + return config; + } } // namespace -int mmdeploy_pose_tracker_default_params(mmdeploy_pose_tracker_param_t* params) { - mmpose::_pose_tracker::SetDefaultParams(*params); - return 0; +int mmdeploy_pose_tracker_default_params(mmdeploy_pose_tracker_param_t* params) +{ + mmpose::_pose_tracker::SetDefaultParams(*params); + return 0; } -int mmdeploy_pose_tracker_create(mmdeploy_model_t det_model, mmdeploy_model_t pose_model, - mmdeploy_context_t context, mmdeploy_pose_tracker_t* pipeline) { - mmdeploy_context_add(context, MMDEPLOY_TYPE_MODEL, "detection", det_model); - mmdeploy_context_add(context, MMDEPLOY_TYPE_MODEL, "pose", pose_model); - auto config = config_template(); - return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)pipeline); +int mmdeploy_pose_tracker_create(mmdeploy_model_t det_model, mmdeploy_model_t pose_model, mmdeploy_context_t context, mmdeploy_pose_tracker_t* pipeline) +{ + mmdeploy_context_add(context, MMDEPLOY_TYPE_MODEL, "detection", det_model); + mmdeploy_context_add(context, MMDEPLOY_TYPE_MODEL, "pose", pose_model); + auto config = config_template(); + return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)pipeline); } -void mmdeploy_pose_tracker_destroy(mmdeploy_pose_tracker_t pipeline) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)pipeline); +void mmdeploy_pose_tracker_destroy(mmdeploy_pose_tracker_t pipeline) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)pipeline); } -int mmdeploy_pose_tracker_create_state(mmdeploy_pose_tracker_t pipeline, +int mmdeploy_pose_tracker_create_state(mmdeploy_pose_tracker_t pipeline, const mmdeploy_pose_tracker_param_t* params, - mmdeploy_pose_tracker_state_t* state) { - try { - auto create_fn = gRegistry().Create("pose_tracker::Create", Value()).value(); - *state = reinterpret_cast(new Value( - create_fn->Process({const_cast(params)}).value()[0])); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: 
{}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + mmdeploy_pose_tracker_state_t* state) +{ + try + { + auto create_fn = gRegistry().Create("pose_tracker::Create", Value()).value(); + *state = reinterpret_cast(new Value( + create_fn->Process({const_cast(params)}).value()[0])); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -void mmdeploy_pose_tracker_destroy_state(mmdeploy_pose_tracker_state_t state) { - delete reinterpret_cast(state); +void mmdeploy_pose_tracker_destroy_state(mmdeploy_pose_tracker_state_t state) +{ + delete reinterpret_cast(state); } int mmdeploy_pose_tracker_create_input(mmdeploy_pose_tracker_state_t* states, - const mmdeploy_mat_t* frames, const int32_t* use_detect, - int batch_size, mmdeploy_value_t* value) { - try { - Value::Array images; - Value::Array use_dets; - Value::Array trackers; - for (int i = 0; i < batch_size; ++i) { - images.push_back({{"ori_img", Cast(frames[i])}}); - use_dets.emplace_back(use_detect ? use_detect[i] : -1); - trackers.push_back(*reinterpret_cast(states[i])); + const mmdeploy_mat_t* frames, + const int32_t* use_detect, + int batch_size, + mmdeploy_value_t* value) +{ + try + { + Value::Array images; + Value::Array use_dets; + Value::Array trackers; + for (int i = 0; i < batch_size; ++i) + { + images.push_back({{"ori_img", Cast(frames[i])}}); + use_dets.emplace_back(use_detect ? use_detect[i] : -1); + trackers.push_back(*reinterpret_cast(states[i])); + } + *value = Take(Value{std::move(images), std::move(use_dets), std::move(trackers)}); + return MMDEPLOY_SUCCESS; } - *value = Take(Value{std::move(images), std::move(use_dets), std::move(trackers)}); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) 
+ { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -using ResultType = mmdeploy::Structure, - std::vector>; +using ResultType = mmdeploy::Structure, std::vector>; -int mmdeploy_pose_tracker_get_result(mmdeploy_value_t output, +int mmdeploy_pose_tracker_get_result(mmdeploy_value_t output, mmdeploy_pose_tracker_target_t** results, - int32_t** result_count) { - if (!output || !results) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - // convert result from Values - std::vector res; - from_value(Cast(output)->front(), res); - - size_t total = 0; - for (const auto& r : res) { - total += r.bboxes.size(); + int32_t** result_count) +{ + if (!output || !results) + { + return MMDEPLOY_E_INVALID_ARG; } + try + { + // convert result from Values + std::vector res; + from_value(Cast(output)->front(), res); - // preserve space for the output structure - ResultType result_type({total, 1, 1}); - auto [result_data, result_cnt, result_holder] = result_type.pointers(); + size_t total = 0; + for (const auto& r : res) + { + total += r.bboxes.size(); + } - auto result_ptr = result_data; + // preserve space for the output structure + ResultType result_type({total, 1, 1}); + auto [result_data, result_cnt, result_holder] = result_type.pointers(); - result_holder->swap(res); + auto result_ptr = result_data; - // build output structure - for (auto& r : *result_holder) { - for (int j = 0; j < r.bboxes.size(); ++j) { - auto& p = *result_ptr++; - p.keypoint_count = static_cast(r.keypoints[j].size()); - p.keypoints = r.keypoints[j].data(); - p.scores = r.scores[j].data(); - p.bbox = r.bboxes[j]; - p.target_id = r.track_ids[j]; - } - result_cnt->push_back(r.bboxes.size()); - // debug info - // p.reserved0 = new std::vector(r.pose_input_bboxes); - // p.reserved1 = new std::vector(r.pose_output_bboxes); - } + result_holder->swap(res); - *results = result_data; - *result_count = result_cnt->data(); - result_type.release(); + // build output structure + for (auto& r : *result_holder) + { + for (int j = 0; j < r.bboxes.size(); ++j) + { + auto& p = *result_ptr++; + p.keypoint_count = static_cast(r.keypoints[j].size()); + p.keypoints = r.keypoints[j].data(); + p.scores = r.scores[j].data(); + p.bbox = r.bboxes[j]; + p.target_id = r.track_ids[j]; + } + result_cnt->push_back(r.bboxes.size()); + // debug info + // p.reserved0 = new std::vector(r.pose_input_bboxes); + // p.reserved1 = new std::vector(r.pose_output_bboxes); + } - return MMDEPLOY_SUCCESS; + *results = result_data; + *result_count = result_cnt->data(); + result_type.release(); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) 
+ { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -int mmdeploy_pose_tracker_apply(mmdeploy_pose_tracker_t pipeline, - mmdeploy_pose_tracker_state_t* states, const mmdeploy_mat_t* frames, - const int32_t* use_detect, int32_t count, - mmdeploy_pose_tracker_target_t** results, int32_t** result_count) { - wrapped input; - if (auto ec = - mmdeploy_pose_tracker_create_input(states, frames, use_detect, count, input.ptr())) { - return ec; - } - wrapped output; - if (auto ec = mmdeploy_pipeline_apply((mmdeploy_pipeline_t)pipeline, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_pose_tracker_get_result(output, results, result_count)) { - return ec; - } - return MMDEPLOY_SUCCESS; +int mmdeploy_pose_tracker_apply(mmdeploy_pose_tracker_t pipeline, + mmdeploy_pose_tracker_state_t* states, + const mmdeploy_mat_t* frames, + const int32_t* use_detect, + int32_t count, + mmdeploy_pose_tracker_target_t** results, + int32_t** result_count) +{ + wrapped input; + if (auto ec = + mmdeploy_pose_tracker_create_input(states, frames, use_detect, count, input.ptr())) + { + return ec; + } + wrapped output; + if (auto ec = mmdeploy_pipeline_apply((mmdeploy_pipeline_t)pipeline, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_pose_tracker_get_result(output, results, result_count)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } void mmdeploy_pose_tracker_release_result(mmdeploy_pose_tracker_target_t* results, - const int32_t* result_count, int count) { - auto total = std::accumulate(result_count, result_count + count, 0); - ResultType deleter({static_cast(total), 1, 1}, results); + const int32_t* result_count, + int count) +{ + auto total = std::accumulate(result_count, result_count + count, 0); + ResultType deleter({static_cast(total), 1, 1}, results); } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h index 4b27fbab8a..c8191b40fa 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/pose_tracker.h @@ -14,142 +14,147 @@ #include "mmdeploy/pose_detector.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_pose_tracker* mmdeploy_pose_tracker_t; -typedef struct mmdeploy_pose_tracker_state* mmdeploy_pose_tracker_state_t; - -typedef struct mmdeploy_pose_tracker_param_t { - // detection interval, default = 1 - int32_t det_interval; - // detection label use for pose estimation, default = 0 - int32_t det_label; - // detection score threshold, default = 0.5 - float det_thr; - // detection minimum bbox size (compute as sqrt(area)), default = -1 - float det_min_bbox_size; - // nms iou threshold for merging detected bboxes and bboxes from tracked targets, default = 0.7 - float det_nms_thr; - - // max number of bboxes used for pose estimation per frame, default = -1 - int32_t pose_max_num_bboxes; - // threshold for visible key-points, default = 0.5 - float pose_kpt_thr; - // min number of key-points for valid poses (-1 indicates ceil(n_kpts/2)), default = -1 - int32_t pose_min_keypoints; - // scale for expanding key-points to bbox, default = 1.25 - float pose_bbox_scale; - // min pose bbox size, tracks with bbox size smaller than the threshold will be dropped, - // default = -1 - float pose_min_bbox_size; - // nms oks/iou threshold for suppressing overlapped poses, useful when multiple pose estimations - // collapse to the same target, default = 0.5 - float pose_nms_thr; - // keypoint sigmas for computing OKS, will use IOU if not set, 
default = nullptr - float* keypoint_sigmas; - // size of keypoint sigma array, must be consistent with the number of key-points, default = 0 - int32_t keypoint_sigmas_size; - - // iou threshold for associating missing tracks, default = 0.4 - float track_iou_thr; - // max number of missing frames before a missing tracks is removed, default = 10 - int32_t track_max_missing; - // track history size, default = 1 - int32_t track_history_size; - - // weight of position for setting covariance matrices of kalman filters, default = 0.05 - float std_weight_position; - // weight of velocity for setting covariance matrices of kalman filters, default = 0.00625 - float std_weight_velocity; - - // params for the one-euro filter for smoothing the outputs - (beta, fc_min, fc_derivative) - // default = (0.007, 1, 1) - float smooth_params[3]; -} mmdeploy_pose_tracker_param_t; - -typedef struct mmdeploy_pose_tracker_target_t { - mmdeploy_point_t* keypoints; // key-points of the target - int32_t keypoint_count; // size of `keypoints` array - float* scores; // scores of each key-point - mmdeploy_rect_t bbox; // estimated bbox from key-points - uint32_t target_id; // target id from internal tracker -} mmdeploy_pose_tracker_target_t; - -/** - * @brief Fill params with default parameters - * @param[in,out] params - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pose_tracker_default_params(mmdeploy_pose_tracker_param_t* params); - -/** - * @brief Create pose tracker pipeline - * @param[in] det_model detection model object, created by \ref mmdeploy_model_create - * @param[in] pose_model pose model object - * @param[in] context context object describing execution environment (device, profiler, etc...), - * created by \ref mmdeploy_context_create - * @param[out] pipeline handle of the created pipeline - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pose_tracker_create(mmdeploy_model_t det_model, - mmdeploy_model_t pose_model, - mmdeploy_context_t context, - mmdeploy_pose_tracker_t* pipeline); - -/** - * @brief Destroy pose tracker pipeline - * @param[in] pipeline - */ -MMDEPLOY_API void mmdeploy_pose_tracker_destroy(mmdeploy_pose_tracker_t pipeline); - -/** - * @brief Create a tracker state handle corresponds to a video stream - * @param[in] pipeline handle of a pose tracker pipeline - * @param[in] params params for creating the tracker state - * @param[out] state handle of the created tracker state - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pose_tracker_create_state(mmdeploy_pose_tracker_t pipeline, - const mmdeploy_pose_tracker_param_t* params, - mmdeploy_pose_tracker_state_t* state); - -/** - * @brief Destroy tracker state - * @param[in] state handle of the tracker state - */ -MMDEPLOY_API void mmdeploy_pose_tracker_destroy_state(mmdeploy_pose_tracker_state_t state); - -/** - * @brief Apply pose tracker pipeline, notice that this function supports batch operation by feeding - * arrays of size \p count to \p states, \p frames and \p use_detect - * @param[in] pipeline handle of a pose tracker pipeline - * @param[in] states tracker states handles, array of size \p count - * @param[in] frames input frames of size \p count - * @param[in] use_detect control the use of detector, array of size \p count - * -1: use params.det_interval, 0: don't use detector, 1: force use detector - * @param[in] count batch size - * @param[out] results a linear buffer contains the tracked targets of input frames. 
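The parameter block above reads as a two-step contract: populate every field via the defaults call first, then override individual members. A sketch under that assumption; `tracker` is a pipeline assumed to be created by mmdeploy_pose_tracker_create, and the sigma table is a placeholder for a real 17-keypoint set:

mmdeploy_pose_tracker_param_t params;
mmdeploy_pose_tracker_default_params(&params); /* fill every field before overriding */
params.det_interval = 5;     /* run the detector on every 5th frame only */
params.pose_kpt_thr = 0.6f;  /* stricter keypoint visibility threshold */

static float sigmas[17] = {0.026f /* , ... one sigma per keypoint ... */};
params.keypoint_sigmas      = sigmas; /* switches pose NMS from IOU to OKS */
params.keypoint_sigmas_size = 17;     /* must match the number of key-points */

mmdeploy_pose_tracker_state_t state = NULL;
int ec = mmdeploy_pose_tracker_create_state(tracker, &params, &state);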
Should be - * released by \ref mmdeploy_pose_tracker_release_result - * @param[out] result_count a linear buffer of size \p count contains the number of tracked - * targets of the frames. Should be released by \ref mmdeploy_pose_tracker_release_result - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_pose_tracker_apply(mmdeploy_pose_tracker_t pipeline, - mmdeploy_pose_tracker_state_t* states, - const mmdeploy_mat_t* frames, - const int32_t* use_detect, int32_t count, - mmdeploy_pose_tracker_target_t** results, - int32_t** result_count); - -/** - * @brief Release result objects - * @param[in] results - * @param[in] result_count - * @param[in] count - */ -MMDEPLOY_API void mmdeploy_pose_tracker_release_result(mmdeploy_pose_tracker_target_t* results, - const int32_t* result_count, int count); + typedef struct mmdeploy_pose_tracker* mmdeploy_pose_tracker_t; + typedef struct mmdeploy_pose_tracker_state* mmdeploy_pose_tracker_state_t; + + typedef struct mmdeploy_pose_tracker_param_t + { + // detection interval, default = 1 + int32_t det_interval; + // detection label use for pose estimation, default = 0 + int32_t det_label; + // detection score threshold, default = 0.5 + float det_thr; + // detection minimum bbox size (compute as sqrt(area)), default = -1 + float det_min_bbox_size; + // nms iou threshold for merging detected bboxes and bboxes from tracked targets, default = 0.7 + float det_nms_thr; + + // max number of bboxes used for pose estimation per frame, default = -1 + int32_t pose_max_num_bboxes; + // threshold for visible key-points, default = 0.5 + float pose_kpt_thr; + // min number of key-points for valid poses (-1 indicates ceil(n_kpts/2)), default = -1 + int32_t pose_min_keypoints; + // scale for expanding key-points to bbox, default = 1.25 + float pose_bbox_scale; + // min pose bbox size, tracks with bbox size smaller than the threshold will be dropped, + // default = -1 + float pose_min_bbox_size; + // nms oks/iou threshold for suppressing overlapped poses, useful when multiple pose estimations + // collapse to the same target, default = 0.5 + float pose_nms_thr; + // keypoint sigmas for computing OKS, will use IOU if not set, default = nullptr + float* keypoint_sigmas; + // size of keypoint sigma array, must be consistent with the number of key-points, default = 0 + int32_t keypoint_sigmas_size; + + // iou threshold for associating missing tracks, default = 0.4 + float track_iou_thr; + // max number of missing frames before a missing tracks is removed, default = 10 + int32_t track_max_missing; + // track history size, default = 1 + int32_t track_history_size; + + // weight of position for setting covariance matrices of kalman filters, default = 0.05 + float std_weight_position; + // weight of velocity for setting covariance matrices of kalman filters, default = 0.00625 + float std_weight_velocity; + + // params for the one-euro filter for smoothing the outputs - (beta, fc_min, fc_derivative) + // default = (0.007, 1, 1) + float smooth_params[3]; + } mmdeploy_pose_tracker_param_t; + + typedef struct mmdeploy_pose_tracker_target_t + { + mmdeploy_point_t* keypoints; // key-points of the target + int32_t keypoint_count; // size of `keypoints` array + float* scores; // scores of each key-point + mmdeploy_rect_t bbox; // estimated bbox from key-points + uint32_t target_id; // target id from internal tracker + } mmdeploy_pose_tracker_target_t; + + /** + * @brief Fill params with default parameters + * @param[in,out] params + * @return status of the operation + */ + 
MMDEPLOY_API int mmdeploy_pose_tracker_default_params(mmdeploy_pose_tracker_param_t* params); + + /** + * @brief Create pose tracker pipeline + * @param[in] det_model detection model object, created by \ref mmdeploy_model_create + * @param[in] pose_model pose model object + * @param[in] context context object describing execution environment (device, profiler, etc...), + * created by \ref mmdeploy_context_create + * @param[out] pipeline handle of the created pipeline + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_pose_tracker_create(mmdeploy_model_t det_model, + mmdeploy_model_t pose_model, + mmdeploy_context_t context, + mmdeploy_pose_tracker_t* pipeline); + + /** + * @brief Destroy pose tracker pipeline + * @param[in] pipeline + */ + MMDEPLOY_API void mmdeploy_pose_tracker_destroy(mmdeploy_pose_tracker_t pipeline); + + /** + * @brief Create a tracker state handle corresponds to a video stream + * @param[in] pipeline handle of a pose tracker pipeline + * @param[in] params params for creating the tracker state + * @param[out] state handle of the created tracker state + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_pose_tracker_create_state(mmdeploy_pose_tracker_t pipeline, + const mmdeploy_pose_tracker_param_t* params, + mmdeploy_pose_tracker_state_t* state); + + /** + * @brief Destroy tracker state + * @param[in] state handle of the tracker state + */ + MMDEPLOY_API void mmdeploy_pose_tracker_destroy_state(mmdeploy_pose_tracker_state_t state); + + /** + * @brief Apply pose tracker pipeline, notice that this function supports batch operation by feeding + * arrays of size \p count to \p states, \p frames and \p use_detect + * @param[in] pipeline handle of a pose tracker pipeline + * @param[in] states tracker states handles, array of size \p count + * @param[in] frames input frames of size \p count + * @param[in] use_detect control the use of detector, array of size \p count + * -1: use params.det_interval, 0: don't use detector, 1: force use detector + * @param[in] count batch size + * @param[out] results a linear buffer contains the tracked targets of input frames. Should be + * released by \ref mmdeploy_pose_tracker_release_result + * @param[out] result_count a linear buffer of size \p count contains the number of tracked + * targets of the frames. 
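Since each call produces a `results` buffer plus a per-frame count, the natural shape is one release per iteration. A sketch of the single-stream case, continuing from the state created above; `grab_frame` and `draw_skeleton` are placeholders for the caller's capture and rendering code:

mmdeploy_mat_t frame;
while (grab_frame(&frame)) /* placeholder capture loop */
{
    mmdeploy_pose_tracker_target_t* targets      = NULL;
    int32_t*                        target_count = NULL;
    /* use_detect == NULL falls back to params.det_interval, per the docs above */
    if (mmdeploy_pose_tracker_apply(tracker, &state, &frame, NULL, 1, &targets, &target_count) != MMDEPLOY_SUCCESS)
    {
        break;
    }
    for (int32_t i = 0; i < target_count[0]; ++i)
    {
        draw_skeleton(&frame, &targets[i]); /* placeholder consumer */
    }
    mmdeploy_pose_tracker_release_result(targets, target_count, 1);
}
mmdeploy_pose_tracker_destroy_state(state);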
Should be released by \ref mmdeploy_pose_tracker_release_result
+     * @return status of the operation
+     */
+    MMDEPLOY_API int mmdeploy_pose_tracker_apply(mmdeploy_pose_tracker_t pipeline,
+                                                 mmdeploy_pose_tracker_state_t* states,
+                                                 const mmdeploy_mat_t* frames,
+                                                 const int32_t* use_detect,
+                                                 int32_t count,
+                                                 mmdeploy_pose_tracker_target_t** results,
+                                                 int32_t** result_count);
+
+    /**
+     * @brief Release result objects
+     * @param[in] results
+     * @param[in] result_count
+     * @param[in] count
+     */
+    MMDEPLOY_API void mmdeploy_pose_tracker_release_result(mmdeploy_pose_tracker_target_t* results,
+                                                           const int32_t* result_count,
+                                                           int count);

 #ifdef __cplusplus
 }
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/restorer.cpp b/csrc/mmdeploy/apis/c/mmdeploy/restorer.cpp
index 9ca2ca65f7..49f8487d12 100644
--- a/csrc/mmdeploy/apis/c/mmdeploy/restorer.cpp
+++ b/csrc/mmdeploy/apis/c/mmdeploy/restorer.cpp
@@ -16,106 +16,121 @@
 using namespace mmdeploy;

 using ResultType = mmdeploy::Structure;

-int mmdeploy_restorer_create(mmdeploy_model_t model, const char* device_name, int device_id,
-                             mmdeploy_restorer_t* restorer) {
-  mmdeploy_context_t context{};
-  auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
-  if (ec != MMDEPLOY_SUCCESS) {
+int mmdeploy_restorer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_restorer_t* restorer)
+{
+    mmdeploy_context_t context{};
+    auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
+    if (ec != MMDEPLOY_SUCCESS)
+    {
+        return ec;
+    }
+    ec = mmdeploy_restorer_create_v2(model, context, restorer);
+    mmdeploy_context_destroy(context);
     return ec;
-  }
-  ec = mmdeploy_restorer_create_v2(model, context, restorer);
-  mmdeploy_context_destroy(context);
-  return ec;
 }

-int mmdeploy_restorer_create_by_path(const char* model_path, const char* device_name, int device_id,
-                                     mmdeploy_restorer_t* restorer) {
-  mmdeploy_model_t model{};
-  if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) {
+int mmdeploy_restorer_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_restorer_t* restorer)
+{
+    mmdeploy_model_t model{};
+    if (auto ec = mmdeploy_model_create_by_path(model_path, &model))
+    {
+        return ec;
+    }
+    auto ec = mmdeploy_restorer_create(model, device_name, device_id, restorer);
+    mmdeploy_model_destroy(model);
     return ec;
-  }
-  auto ec = mmdeploy_restorer_create(model, device_name, device_id, restorer);
-  mmdeploy_model_destroy(model);
-  return ec;
 }

-int mmdeploy_restorer_apply(mmdeploy_restorer_t restorer, const mmdeploy_mat_t* images, int count,
-                            mmdeploy_mat_t** results) {
-  wrapped<mmdeploy_value_t> input;
-  if (auto ec = mmdeploy_restorer_create_input(images, count, input.ptr())) {
-    return ec;
-  }
-  wrapped<mmdeploy_value_t> output;
-  if (auto ec = mmdeploy_restorer_apply_v2(restorer, input, output.ptr())) {
-    return ec;
-  }
-  if (auto ec = mmdeploy_restorer_get_result(output, results)) {
-    return ec;
-  }
-  return MMDEPLOY_SUCCESS;
+int mmdeploy_restorer_apply(mmdeploy_restorer_t restorer, const mmdeploy_mat_t* images, int count, mmdeploy_mat_t** results)
+{
+    wrapped<mmdeploy_value_t> input;
+    if (auto ec = mmdeploy_restorer_create_input(images, count, input.ptr()))
+    {
+        return ec;
+    }
+    wrapped<mmdeploy_value_t> output;
+    if (auto ec = mmdeploy_restorer_apply_v2(restorer, input, output.ptr()))
+    {
+        return ec;
+    }
+    if (auto ec = mmdeploy_restorer_get_result(output, results))
+    {
+        return ec;
+    }
+    return MMDEPLOY_SUCCESS;
 }

-void mmdeploy_restorer_release_result(mmdeploy_mat_t* results, int count) {
-  ResultType deleter{static_cast<size_t>(count), results};
+void mmdeploy_restorer_release_result(mmdeploy_mat_t* results, int count)
+{
+    ResultType deleter{static_cast<size_t>(count), results};
 }

-void mmdeploy_restorer_destroy(mmdeploy_restorer_t restorer) {
-  mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)restorer);
+void mmdeploy_restorer_destroy(mmdeploy_restorer_t restorer)
+{
+    mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)restorer);
 }

-int mmdeploy_restorer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
-                                mmdeploy_restorer_t* restorer) {
-  return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)restorer);
+int mmdeploy_restorer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_restorer_t* restorer)
+{
+    return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)restorer);
 }

-int mmdeploy_restorer_create_input(const mmdeploy_mat_t* mats, int mat_count,
-                                   mmdeploy_value_t* value) {
-  return mmdeploy_common_create_input(mats, mat_count, value);
+int mmdeploy_restorer_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value)
+{
+    return mmdeploy_common_create_input(mats, mat_count, value);
 }

-int mmdeploy_restorer_apply_v2(mmdeploy_restorer_t restorer, mmdeploy_value_t input,
-                               mmdeploy_value_t* output) {
-  return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)restorer, input, output);
+int mmdeploy_restorer_apply_v2(mmdeploy_restorer_t restorer, mmdeploy_value_t input, mmdeploy_value_t* output)
+{
+    return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)restorer, input, output);
 }

-int mmdeploy_restorer_apply_async(mmdeploy_restorer_t restorer, mmdeploy_sender_t input,
-                                  mmdeploy_sender_t* output) {
-  return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)restorer, input, output);
+int mmdeploy_restorer_apply_async(mmdeploy_restorer_t restorer, mmdeploy_sender_t input, mmdeploy_sender_t* output)
+{
+    return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)restorer, input, output);
 }

-int mmdeploy_restorer_get_result(mmdeploy_value_t output, mmdeploy_mat_t** results) {
-  if (!output || !results) {
-    return MMDEPLOY_E_INVALID_ARG;
-  }
-  try {
-    const Value& value = Cast(output)->front();
-
-    auto restorer_output = from_value<std::vector<Mat>>(value);
-    auto count = restorer_output.size();
-
-    ResultType r(count);
-    auto [_results, buffers] = r.pointers();
-
-    for (int i = 0; i < count; ++i) {
-      auto upscale = restorer_output[i];
-      auto& res = _results[i];
-      res.data = upscale.data();
-      buffers[i] = upscale.buffer();
-      res.format = (mmdeploy_pixel_format_t)upscale.pixel_format();
-      res.height = upscale.height();
-      res.width = upscale.width();
-      res.channel = upscale.channel();
-      res.type = (mmdeploy_data_type_t)upscale.type();
+int mmdeploy_restorer_get_result(mmdeploy_value_t output, mmdeploy_mat_t** results)
+{
+    if (!output || !results)
+    {
+        return MMDEPLOY_E_INVALID_ARG;
     }
-
-    *results = _results;
-    r.release();
-
-    return MMDEPLOY_SUCCESS;
-  } catch (const std::exception& e) {
-    MMDEPLOY_ERROR("unhandled exception: {}", e.what());
-  } catch (...) {
-    MMDEPLOY_ERROR("unknown exception caught");
-  }
-  return MMDEPLOY_E_FAIL;
+    try
+    {
+        const Value& value = Cast(output)->front();
+
+        auto restorer_output = from_value<std::vector<Mat>>(value);
+        auto count = restorer_output.size();
+
+        ResultType r(count);
+        auto [_results, buffers] = r.pointers();
+
+        for (int i = 0; i < count; ++i)
+        {
+            auto upscale = restorer_output[i];
+            auto& res = _results[i];
+            res.data = upscale.data();
+            buffers[i] = upscale.buffer();
+            res.format = (mmdeploy_pixel_format_t)upscale.pixel_format();
+            res.height = upscale.height();
+            res.width = upscale.width();
+            res.channel = upscale.channel();
+            res.type = (mmdeploy_data_type_t)upscale.type();
+        }
+
+        *results = _results;
+        r.release();
+
+        return MMDEPLOY_SUCCESS;
+    }
+    catch (const std::exception& e)
+    {
+        MMDEPLOY_ERROR("unhandled exception: {}", e.what());
+    }
+    catch (...)
+    {
+        MMDEPLOY_ERROR("unknown exception caught");
+    }
+    return MMDEPLOY_E_FAIL;
 }
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/restorer.h b/csrc/mmdeploy/apis/c/mmdeploy/restorer.h
index 9ab529850f..5c8533102f 100644
--- a/csrc/mmdeploy/apis/c/mmdeploy/restorer.h
+++ b/csrc/mmdeploy/apis/c/mmdeploy/restorer.h
@@ -13,76 +13,72 @@
 #include "mmdeploy/model.h"

 #ifdef __cplusplus
-extern "C" {
+extern "C"
+{
 #endif

-typedef struct mmdeploy_restorer* mmdeploy_restorer_t;
-
-/**
- * @brief Create a restorer instance
- * @param[in] model an instance of image restoration model created by
- * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
- * @param[in] device_name name of device, such as "cpu", "cuda", etc.
- * @param[in] device_id id of device.
- * @param[out] restorer handle of the created restorer, which must be destroyed
- * by \ref mmdeploy_restorer_destroy
- * @return status code of the operation
- */
-MMDEPLOY_API int mmdeploy_restorer_create(mmdeploy_model_t model, const char* device_name,
-                                          int device_id, mmdeploy_restorer_t* restorer);
-
-/**
- * @brief Create a restorer instance
- * @param[in] model_path path to image restoration model
- * @param[in] device_name name of device, such as "cpu", "cuda", etc.
- * @param[in] device_id id of device.
- * @param[out] restorer handle of the created restorer, which must be destroyed
- * by \ref mmdeploy_restorer_destroy
- * @return status code of the operation
- */
-MMDEPLOY_API int mmdeploy_restorer_create_by_path(const char* model_path, const char* device_name,
-                                                  int device_id, mmdeploy_restorer_t* restorer);
-
-/**
- * @brief Apply restorer to a batch of images
- * @param[in] restorer restorer's handle created by \ref mmdeploy_restorer_create_by_path
- * @param[in] images a batch of images
- * @param[in] count number of images in the batch
- * @param[out] results a linear buffer contains the restored images, must be release
- * by \ref mmdeploy_restorer_release_result
- * @return status code of the operation
- */
-MMDEPLOY_API int mmdeploy_restorer_apply(mmdeploy_restorer_t restorer, const mmdeploy_mat_t* images,
-                                         int count, mmdeploy_mat_t** results);
-
-/** @brief Release result buffer returned by \ref mmdeploy_restorer_apply
- * @param[in] results result buffer by restorer
- * @param[in] count length of \p result
- */
-MMDEPLOY_API void mmdeploy_restorer_release_result(mmdeploy_mat_t* results, int count);
-
-/**
- * @brief destroy restorer
- * @param[in] restorer handle of restorer created by \ref mmdeploy_restorer_create_by_path
- */
-MMDEPLOY_API void mmdeploy_restorer_destroy(mmdeploy_restorer_t restorer);
-
-/******************************************************************************
- * Experimental asynchronous APIs */
-
-MMDEPLOY_API int mmdeploy_restorer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
-                                             mmdeploy_restorer_t* restorer);
-
-MMDEPLOY_API int mmdeploy_restorer_create_input(const mmdeploy_mat_t* mats, int mat_count,
-                                                mmdeploy_value_t* value);
-
-MMDEPLOY_API int mmdeploy_restorer_apply_v2(mmdeploy_restorer_t restorer, mmdeploy_value_t input,
-                                            mmdeploy_value_t* output);
-
-MMDEPLOY_API int mmdeploy_restorer_apply_async(mmdeploy_restorer_t restorer,
-                                               mmdeploy_sender_t input, mmdeploy_sender_t* output);
-
-MMDEPLOY_API int mmdeploy_restorer_get_result(mmdeploy_value_t output, mmdeploy_mat_t** results);
+    typedef struct mmdeploy_restorer* mmdeploy_restorer_t;
+
+    /**
+     * @brief Create a restorer instance
+     * @param[in] model an instance of image restoration model created by
+     * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device.
+     * @param[out] restorer handle of the created restorer, which must be destroyed
+     * by \ref mmdeploy_restorer_destroy
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_restorer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_restorer_t* restorer);
+
+    /**
+     * @brief Create a restorer instance
+     * @param[in] model_path path to image restoration model
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device.
+     * @param[out] restorer handle of the created restorer, which must be destroyed
+     * by \ref mmdeploy_restorer_destroy
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_restorer_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_restorer_t* restorer);
+
+    /**
+     * @brief Apply restorer to a batch of images
+     * @param[in] restorer restorer's handle created by \ref mmdeploy_restorer_create_by_path
+     * @param[in] images a batch of images
+     * @param[in] count number of images in the batch
+     * @param[out] results a linear buffer that contains the restored images. It must be released
+     * by \ref mmdeploy_restorer_release_result
+     * @return status code of the operation
+     */
+    MMDEPLOY_API int mmdeploy_restorer_apply(mmdeploy_restorer_t restorer, const mmdeploy_mat_t* images, int count, mmdeploy_mat_t** results);
+
+    /** @brief Release result buffer returned by \ref mmdeploy_restorer_apply
+     * @param[in] results result buffer returned by the restorer
+     * @param[in] count length of \p results
+     */
+    MMDEPLOY_API void mmdeploy_restorer_release_result(mmdeploy_mat_t* results, int count);
+
+    /**
+     * @brief Destroy restorer
+     * @param[in] restorer handle of restorer created by \ref mmdeploy_restorer_create_by_path
+     */
+    MMDEPLOY_API void mmdeploy_restorer_destroy(mmdeploy_restorer_t restorer);
+
+    /******************************************************************************
+     * Experimental asynchronous APIs */
+
+    MMDEPLOY_API int mmdeploy_restorer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_restorer_t* restorer);
+
+    MMDEPLOY_API int mmdeploy_restorer_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value);
+
+    MMDEPLOY_API int mmdeploy_restorer_apply_v2(mmdeploy_restorer_t restorer, mmdeploy_value_t input, mmdeploy_value_t* output);
+
+    MMDEPLOY_API int mmdeploy_restorer_apply_async(mmdeploy_restorer_t restorer,
+                                                   mmdeploy_sender_t input,
+                                                   mmdeploy_sender_t* output);
+
+    MMDEPLOY_API int mmdeploy_restorer_get_result(mmdeploy_value_t output, mmdeploy_mat_t** results);

 #ifdef __cplusplus
 }
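(The experimental *_v2/*_async entry points above compose with the executor API. A rough sketch of a blocking variant, assuming mmdeploy_executor_just and mmdeploy_executor_sync_wait from executor.h — used the same way by text_detector.cpp later in this patch — and mmdeploy_value_destroy from common.h; the ownership details here are assumptions, not a contract.)

    int run_restorer_async(mmdeploy_restorer_t restorer, const mmdeploy_mat_t* image)
    {
        mmdeploy_value_t input{};
        if (int ec = mmdeploy_restorer_create_input(image, 1, &input))
        {
            return ec;
        }
        // wrap the packed input in a sender and run the pipeline asynchronously
        mmdeploy_sender_t output_sender{};
        if (int ec = mmdeploy_restorer_apply_async(restorer, mmdeploy_executor_just(input), &output_sender))
        {
            return ec;
        }
        mmdeploy_value_t output = mmdeploy_executor_sync_wait(output_sender);  // assumed blocking helper
        mmdeploy_mat_t* results{};
        int ec = mmdeploy_restorer_get_result(output, &results);
        if (ec == MMDEPLOY_SUCCESS)
        {
            mmdeploy_restorer_release_result(results, 1);
        }
        mmdeploy_value_destroy(output);  // assumed cleanup helper
        return ec;
    }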
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.cpp b/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.cpp
index d2172c54b8..04d537a376 100644
--- a/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.cpp
+++ b/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.cpp
@@ -15,124 +15,146 @@
 using namespace std;
 using namespace mmdeploy;

-int mmdeploy_rotated_detector_create(mmdeploy_model_t model, const char* device_name, int device_id,
-                                     mmdeploy_rotated_detector_t* detector) {
-  mmdeploy_context_t context{};
-  auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
-  if (ec != MMDEPLOY_SUCCESS) {
+int mmdeploy_rotated_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_rotated_detector_t* detector)
+{
+    mmdeploy_context_t context{};
+    auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
+    if (ec != MMDEPLOY_SUCCESS)
+    {
+        return ec;
+    }
+    ec = mmdeploy_rotated_detector_create_v2(model, context, detector);
+    mmdeploy_context_destroy(context);
     return ec;
-  }
-  ec = mmdeploy_rotated_detector_create_v2(model, context, detector);
-  mmdeploy_context_destroy(context);
-  return ec;
 }

-int mmdeploy_rotated_detector_create_by_path(const char* model_path, const char* device_name,
-                                             int device_id, mmdeploy_rotated_detector_t* detector) {
-  mmdeploy_model_t model{};
+int mmdeploy_rotated_detector_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_rotated_detector_t* detector)
+{
+    mmdeploy_model_t model{};

-  if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) {
+    if (auto ec = mmdeploy_model_create_by_path(model_path, &model))
+    {
+        return ec;
+    }
+    auto ec = mmdeploy_rotated_detector_create(model, device_name, device_id, detector);
+    mmdeploy_model_destroy(model);
     return ec;
-  }
-  auto ec = mmdeploy_rotated_detector_create(model, device_name, device_id, detector);
-  mmdeploy_model_destroy(model);
-  return ec;
 }

-int mmdeploy_rotated_detector_apply(mmdeploy_rotated_detector_t detector,
-                                    const mmdeploy_mat_t* mats, int mat_count,
-                                    mmdeploy_rotated_detection_t** results, int** result_count) {
-  wrapped<mmdeploy_value_t> input;
-  if (auto ec = mmdeploy_rotated_detector_create_input(mats, mat_count, input.ptr())) {
-    return ec;
-  }
-  wrapped<mmdeploy_value_t> output;
-  if (auto ec = mmdeploy_rotated_detector_apply_v2(detector, input, output.ptr())) {
-    return ec;
-  }
-  if (auto ec = mmdeploy_rotated_detector_get_result(output, results, result_count)) {
-    return ec;
-  }
-  return MMDEPLOY_SUCCESS;
+int mmdeploy_rotated_detector_apply(mmdeploy_rotated_detector_t detector,
+                                    const mmdeploy_mat_t* mats,
+                                    int mat_count,
+                                    mmdeploy_rotated_detection_t** results,
+                                    int** result_count)
+{
+    wrapped<mmdeploy_value_t> input;
+    if (auto ec = mmdeploy_rotated_detector_create_input(mats, mat_count, input.ptr()))
+    {
+        return ec;
+    }
+    wrapped<mmdeploy_value_t> output;
+    if (auto ec = mmdeploy_rotated_detector_apply_v2(detector, input, output.ptr()))
+    {
+        return ec;
+    }
+    if (auto ec = mmdeploy_rotated_detector_get_result(output, results, result_count))
+    {
+        return ec;
+    }
+    return MMDEPLOY_SUCCESS;
 }

 void mmdeploy_rotated_detector_release_result(mmdeploy_rotated_detection_t* results,
-                                              const int* result_count) {
-  delete[] results;
-  delete[] result_count;
+                                              const int* result_count)
+{
+    delete[] results;
+    delete[] result_count;
 }

-void mmdeploy_rotated_detector_destroy(mmdeploy_rotated_detector_t detector) {
-  mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector);
+void mmdeploy_rotated_detector_destroy(mmdeploy_rotated_detector_t detector)
+{
+    mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector);
 }

-int mmdeploy_rotated_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
-                                        mmdeploy_rotated_detector_t* detector) {
-  return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector);
+int mmdeploy_rotated_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_rotated_detector_t* detector)
+{
+    return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector);
 }

-int mmdeploy_rotated_detector_create_input(const mmdeploy_mat_t* mats, int mat_count,
-                                           mmdeploy_value_t* input) {
-  return mmdeploy_common_create_input(mats, mat_count, input);
+int mmdeploy_rotated_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input)
+{
+    return mmdeploy_common_create_input(mats, mat_count, input);
 }

-int mmdeploy_rotated_detector_apply_v2(mmdeploy_rotated_detector_t detector, mmdeploy_value_t input,
-                                       mmdeploy_value_t* output) {
-  return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output);
+int mmdeploy_rotated_detector_apply_v2(mmdeploy_rotated_detector_t detector, mmdeploy_value_t input, mmdeploy_value_t* output)
+{
+    return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output);
 }

 int mmdeploy_rotated_detector_apply_async(mmdeploy_rotated_detector_t detector,
-                                          mmdeploy_sender_t input, mmdeploy_sender_t* output) {
-  return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output);
+                                          mmdeploy_sender_t input,
+                                          mmdeploy_sender_t* output)
+{
+    return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output);
 }

-int mmdeploy_rotated_detector_get_result(mmdeploy_value_t output,
+int mmdeploy_rotated_detector_get_result(mmdeploy_value_t output,
                                          mmdeploy_rotated_detection_t** results,
-                                         int** result_count) {
-  if (!output || !results || !result_count) {
-    return MMDEPLOY_E_INVALID_ARG;
-  }
-
-  try {
-    Value& value = Cast(output)->front();
-    auto detector_outputs = from_value<vector<mmrotate::RotatedDetectorOutput>>(value);
-
-    vector<int> _result_count;
-    _result_count.reserve(detector_outputs.size());
-    for (const auto& det_output : detector_outputs) {
-      _result_count.push_back((int)det_output.detections.size());
+                                         int** result_count)
+{
+    if (!output || !results || !result_count)
+    {
+        return MMDEPLOY_E_INVALID_ARG;
     }

-    auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0);
+    try
+    {
+        Value& value = Cast(output)->front();
+        auto detector_outputs = from_value<vector<mmrotate::RotatedDetectorOutput>>(value);

-    std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{});
-    std::copy(_result_count.begin(), _result_count.end(), result_count_data.get());
-
-    std::unique_ptr<mmdeploy_rotated_detection_t[]> result_data(
-        new mmdeploy_rotated_detection_t[total]{});
-    auto result_ptr = result_data.get();
-
-    for (const auto& det_output : detector_outputs) {
-      for (const auto& detection : det_output.detections) {
-        result_ptr->label_id = detection.label_id;
-        result_ptr->score = detection.score;
-        const auto& rbbox = detection.rbbox;
-        for (int i = 0; i < 5; i++) {
-          result_ptr->rbbox[i] = rbbox[i];
+        vector<int> _result_count;
+        _result_count.reserve(detector_outputs.size());
+        for (const auto& det_output : detector_outputs)
+        {
+            _result_count.push_back((int)det_output.detections.size());
         }
-        ++result_ptr;
-      }
-    }

-    *result_count = result_count_data.release();
-    *results = result_data.release();
+        auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0);
+
+        std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{});
+        std::copy(_result_count.begin(), _result_count.end(), result_count_data.get());
+
+        std::unique_ptr<mmdeploy_rotated_detection_t[]> result_data(
+            new mmdeploy_rotated_detection_t[total]{});
+        auto result_ptr = result_data.get();
+
+        for (const auto& det_output : detector_outputs)
+        {
+            for (const auto& detection : det_output.detections)
+            {
+                result_ptr->label_id = detection.label_id;
+                result_ptr->score = detection.score;
+                const auto& rbbox = detection.rbbox;
+                for (int i = 0; i < 5; i++)
+                {
+                    result_ptr->rbbox[i] = rbbox[i];
+                }
+                ++result_ptr;
+            }
+        }

-    return MMDEPLOY_SUCCESS;
+        *result_count = result_count_data.release();
+        *results = result_data.release();

-  } catch (const std::exception& e) {
-    MMDEPLOY_ERROR("unhandled exception: {}", e.what());
-  } catch (...) {
-    MMDEPLOY_ERROR("unknown exception caught");
-  }
-  return MMDEPLOY_E_FAIL;
+        return MMDEPLOY_SUCCESS;
+    }
+    catch (const std::exception& e)
+    {
+        MMDEPLOY_ERROR("unhandled exception: {}", e.what());
+    }
+    catch (...)
+    {
+        MMDEPLOY_ERROR("unknown exception caught");
+    }
+    return MMDEPLOY_E_FAIL;
 }
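(The flat result buffer plus per-image count array produced by mmdeploy_rotated_detector_get_result above is walked with two nested loops, e.g. — illustrative only:)

    // dets is a single linear buffer; det_count[i] detections belong to image i
    mmdeploy_rotated_detection_t* dets{};
    int* det_count{};
    if (mmdeploy_rotated_detector_apply(detector, mats, mat_count, &dets, &det_count) == MMDEPLOY_SUCCESS)
    {
        const mmdeploy_rotated_detection_t* p = dets;
        for (int i = 0; i < mat_count; ++i)
        {
            for (int j = 0; j < det_count[i]; ++j, ++p)
            {
                // rbbox layout: cx, cy, w, h, angle (see rotated_detector.h below)
                printf("image %d: label=%d score=%.3f angle=%.3f\n", i, p->label_id, p->score, p->rbbox[4]);
            }
        }
        mmdeploy_rotated_detector_release_result(dets, det_count);
    }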
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.h b/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.h
index 35125a74ff..1d745debae 100644
--- a/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.h
+++ b/csrc/mmdeploy/apis/c/mmdeploy/rotated_detector.h
@@ -13,125 +13,126 @@
 #include "mmdeploy/model.h"

 #ifdef __cplusplus
-extern "C" {
+extern "C"
+{
 #endif

-typedef struct mmdeploy_rotated_detection_t {
-  int label_id;
-  float score;
-  float rbbox[5];  // cx, cy, w, h, angle
-} mmdeploy_rotated_detection_t;
-
-typedef struct mmdeploy_rotated_detector* mmdeploy_rotated_detector_t;
-
-/**
- * @brief Create rotated detector's handle
- * @param[in] model an instance of mmrotate sdk model created by
- * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
- * @param[in] device_name name of device, such as "cpu", "cuda", etc.
- * @param[in] device_id id of device.
- * @param[out] detector instance of a rotated detector
- * @return status of creating rotated detector's handle
- */
-MMDEPLOY_API int mmdeploy_rotated_detector_create(mmdeploy_model_t model, const char* device_name,
-                                                  int device_id,
-                                                  mmdeploy_rotated_detector_t* detector);
-
-/**
- * @brief Create rotated detector's handle
- * @param[in] model_path path of mmrotate sdk model exported by mmdeploy model converter
- * @param[in] device_name name of device, such as "cpu", "cuda", etc.
- * @param[in] device_id id of device.
- * @param[out] detector instance of a rotated detector
- * @return status of creating rotated detector's handle
- */
-MMDEPLOY_API int mmdeploy_rotated_detector_create_by_path(const char* model_path,
-                                                          const char* device_name, int device_id,
-                                                          mmdeploy_rotated_detector_t* detector);
-
-/**
- * @brief Apply rotated detector to batch images and get their inference results
- * @param[in] detector rotated detector's handle created by \ref
- * mmdeploy_rotated_detector_create_by_path
- * @param[in] mats a batch of images
- * @param[in] mat_count number of images in the batch
- * @param[out] results a linear buffer to save detection results of each image. It must be released
- * by \ref mmdeploy_rotated_detector_release_result
- * @param[out] result_count a linear buffer with length being \p mat_count to save the number of
- * detection results of each image. And it must be released by \ref
- * mmdeploy_rotated_detector_release_result
- * @return status of inference
- */
-MMDEPLOY_API int mmdeploy_rotated_detector_apply(mmdeploy_rotated_detector_t detector,
-                                                 const mmdeploy_mat_t* mats, int mat_count,
-                                                 mmdeploy_rotated_detection_t** results,
-                                                 int** result_count);
-
-/** @brief Release the inference result buffer created by \ref mmdeploy_rotated_detector_apply
- * @param[in] results rotated detection results buffer
- * @param[in] result_count \p results size buffer
- */
-MMDEPLOY_API void mmdeploy_rotated_detector_release_result(mmdeploy_rotated_detection_t* results,
-                                                           const int* result_count);
-
-/**
- * @brief Destroy rotated detector's handle
- * @param[in] detector rotated detector's handle created by \ref
- * mmdeploy_rotated_detector_create_by_path or by \ref mmdeploy_rotated_detector_create
- */
-MMDEPLOY_API void mmdeploy_rotated_detector_destroy(mmdeploy_rotated_detector_t detector);
-
-/******************************************************************************
- * Experimental asynchronous APIs */
-
-/**
- * @brief Same as \ref mmdeploy_detector_create, but allows to control execution context of tasks
- * via context
- */
-MMDEPLOY_API int mmdeploy_rotated_detector_create_v2(mmdeploy_model_t model,
-                                                     mmdeploy_context_t context,
-                                                     mmdeploy_rotated_detector_t* detector);
-
-/**
- * @brief Pack rotated detector inputs into mmdeploy_value_t
- * @param[in] mats a batch of images
- * @param[in] mat_count number of images in the batch
- * @return the created value
- */
-MMDEPLOY_API int mmdeploy_rotated_detector_create_input(const mmdeploy_mat_t* mats, int mat_count,
-                                                        mmdeploy_value_t* input);
-
-/**
- * @brief Same as \ref mmdeploy_rotated_detector_apply, but input and output are packed in \ref
- * mmdeploy_value_t.
- */
-MMDEPLOY_API int mmdeploy_rotated_detector_apply_v2(mmdeploy_rotated_detector_t detector,
-                                                    mmdeploy_value_t input,
-                                                    mmdeploy_value_t* output);
-
-/**
- * @brief Apply rotated detector asynchronously
- * @param[in] detector handle to the detector
- * @param[in] input input sender
- * @return output sender
- */
-MMDEPLOY_API int mmdeploy_rotated_detector_apply_async(mmdeploy_rotated_detector_t detector,
-                                                       mmdeploy_sender_t input,
-                                                       mmdeploy_sender_t* output);
-
-/**
- * @brief Unpack rotated detector output from a mmdeploy_value_t
- * @param[in] output output obtained by applying a detector
- * @param[out] results a linear buffer to save detection results of each image. It must be released
- * by \ref mmdeploy_detector_release_result
- * @param[out] result_count a linear buffer with length number of input images to save the number of
- * detection results of each image. Must be released by \ref
- * mmdeploy_detector_release_result
- * @return status of the operation
- */
-MMDEPLOY_API int mmdeploy_rotated_detector_get_result(mmdeploy_value_t output,
-                                                      mmdeploy_rotated_detection_t** results,
-                                                      int** result_count);
+    typedef struct mmdeploy_rotated_detection_t
+    {
+        int label_id;
+        float score;
+        float rbbox[5];  // cx, cy, w, h, angle
+    } mmdeploy_rotated_detection_t;
+
+    typedef struct mmdeploy_rotated_detector* mmdeploy_rotated_detector_t;
+
+    /**
+     * @brief Create rotated detector's handle
+     * @param[in] model an instance of mmrotate sdk model created by
+     * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device.
+     * @param[out] detector instance of a rotated detector
+     * @return status of creating rotated detector's handle
+     */
+    MMDEPLOY_API int mmdeploy_rotated_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_rotated_detector_t* detector);
+
+    /**
+     * @brief Create rotated detector's handle
+     * @param[in] model_path path of mmrotate sdk model exported by mmdeploy model converter
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device.
+     * @param[out] detector instance of a rotated detector
+     * @return status of creating rotated detector's handle
+     */
+    MMDEPLOY_API int mmdeploy_rotated_detector_create_by_path(const char* model_path,
+                                                              const char* device_name,
+                                                              int device_id,
+                                                              mmdeploy_rotated_detector_t* detector);
+
+    /**
+     * @brief Apply rotated detector to batch images and get their inference results
+     * @param[in] detector rotated detector's handle created by \ref
+     * mmdeploy_rotated_detector_create_by_path
+     * @param[in] mats a batch of images
+     * @param[in] mat_count number of images in the batch
+     * @param[out] results a linear buffer to save detection results of each image. It must be released
+     * by \ref mmdeploy_rotated_detector_release_result
+     * @param[out] result_count a linear buffer with length being \p mat_count to save the number of
+     * detection results of each image. It must be released by \ref
+     * mmdeploy_rotated_detector_release_result
+     * @return status of inference
+     */
+    MMDEPLOY_API int mmdeploy_rotated_detector_apply(mmdeploy_rotated_detector_t detector,
+                                                     const mmdeploy_mat_t* mats,
+                                                     int mat_count,
+                                                     mmdeploy_rotated_detection_t** results,
+                                                     int** result_count);
+
+    /** @brief Release the inference result buffer created by \ref mmdeploy_rotated_detector_apply
+     * @param[in] results rotated detection results buffer
+     * @param[in] result_count \p results size buffer
+     */
+    MMDEPLOY_API void mmdeploy_rotated_detector_release_result(mmdeploy_rotated_detection_t* results,
+                                                               const int* result_count);
+
+    /**
+     * @brief Destroy rotated detector's handle
+     * @param[in] detector rotated detector's handle created by \ref
+     * mmdeploy_rotated_detector_create_by_path or by \ref mmdeploy_rotated_detector_create
+     */
+    MMDEPLOY_API void mmdeploy_rotated_detector_destroy(mmdeploy_rotated_detector_t detector);
+
+    /******************************************************************************
+     * Experimental asynchronous APIs */
+
+    /**
+     * @brief Same as \ref mmdeploy_detector_create, but allows to control execution context of tasks
+     * via context
+     */
+    MMDEPLOY_API int mmdeploy_rotated_detector_create_v2(mmdeploy_model_t model,
+                                                         mmdeploy_context_t context,
+                                                         mmdeploy_rotated_detector_t* detector);
+
+    /**
+     * @brief Pack rotated detector inputs into mmdeploy_value_t
+     * @param[in] mats a batch of images
+     * @param[in] mat_count number of images in the batch
+     * @return status of the operation
+     */
+    MMDEPLOY_API int mmdeploy_rotated_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input);
+
+    /**
+     * @brief Same as \ref mmdeploy_rotated_detector_apply, but input and output are packed in \ref
+     * mmdeploy_value_t.
+     */
+    MMDEPLOY_API int mmdeploy_rotated_detector_apply_v2(mmdeploy_rotated_detector_t detector,
+                                                        mmdeploy_value_t input,
+                                                        mmdeploy_value_t* output);
+
+    /**
+     * @brief Apply rotated detector asynchronously
+     * @param[in] detector handle to the detector
+     * @param[in] input input sender
+     * @return output sender
+     */
+    MMDEPLOY_API int mmdeploy_rotated_detector_apply_async(mmdeploy_rotated_detector_t detector,
+                                                           mmdeploy_sender_t input,
+                                                           mmdeploy_sender_t* output);
+
+    /**
+     * @brief Unpack rotated detector output from a mmdeploy_value_t
+     * @param[in] output output obtained by applying a detector
+     * @param[out] results a linear buffer to save detection results of each image. It must be released
+     * by \ref mmdeploy_rotated_detector_release_result
+     * @param[out] result_count a linear buffer with length equal to the number of input images to save the
+     * number of detection results of each image. Must be released by \ref
+     * mmdeploy_rotated_detector_release_result
+     * @return status of the operation
+     */
+    MMDEPLOY_API int mmdeploy_rotated_detector_get_result(mmdeploy_value_t output,
+                                                          mmdeploy_rotated_detection_t** results,
+                                                          int** result_count);

 #ifdef __cplusplus
 }
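(When several pipelines should share one execution context, the *_create_v2 path above replaces the device-name shortcut. A sketch under stated assumptions: mmdeploy_value_destroy is taken from common.h, and apply_v2 is assumed not to consume its input value.)

    mmdeploy_context_t ctx{};
    if (mmdeploy_context_create_by_device("cuda", 0, &ctx) == MMDEPLOY_SUCCESS)
    {
        mmdeploy_rotated_detector_t detector{};
        if (mmdeploy_rotated_detector_create_v2(model, ctx, &detector) == MMDEPLOY_SUCCESS)
        {
            mmdeploy_value_t input{};
            mmdeploy_value_t output{};
            if (mmdeploy_rotated_detector_create_input(mats, mat_count, &input) == MMDEPLOY_SUCCESS &&
                mmdeploy_rotated_detector_apply_v2(detector, input, &output) == MMDEPLOY_SUCCESS)
            {
                mmdeploy_rotated_detection_t* dets{};
                int* counts{};
                if (mmdeploy_rotated_detector_get_result(output, &dets, &counts) == MMDEPLOY_SUCCESS)
                {
                    mmdeploy_rotated_detector_release_result(dets, counts);
                }
            }
            mmdeploy_value_destroy(input);   // assumed cleanup helpers
            mmdeploy_value_destroy(output);
            mmdeploy_rotated_detector_destroy(detector);
        }
        mmdeploy_context_destroy(ctx);
    }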
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp
index c982df39e5..9ec8ae366c 100644
--- a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp
+++ b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.cpp
@@ -18,111 +18,128 @@
 using namespace mmdeploy;

 using ResultType = mmdeploy::Structure;

-int mmdeploy_segmentor_create(mmdeploy_model_t model, const char* device_name, int device_id,
-                              mmdeploy_segmentor_t* segmentor) {
-  mmdeploy_context_t context{};
-  auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
-  if (ec != MMDEPLOY_SUCCESS) {
+int mmdeploy_segmentor_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_segmentor_t* segmentor)
+{
+    mmdeploy_context_t context{};
+    auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
+    if (ec != MMDEPLOY_SUCCESS)
+    {
+        return ec;
+    }
+    ec = mmdeploy_segmentor_create_v2(model, context, segmentor);
+    mmdeploy_context_destroy(context);
     return ec;
-  }
-  ec = mmdeploy_segmentor_create_v2(model, context, segmentor);
-  mmdeploy_context_destroy(context);
-  return ec;
 }

-int mmdeploy_segmentor_create_by_path(const char* model_path, const char* device_name,
-                                      int device_id, mmdeploy_segmentor_t* segmentor) {
-  mmdeploy_model_t model{};
-  if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) {
+int mmdeploy_segmentor_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_segmentor_t* segmentor)
+{
+    mmdeploy_model_t model{};
+    if (auto ec = mmdeploy_model_create_by_path(model_path, &model))
+    {
+        return ec;
+    }
+    auto ec = mmdeploy_segmentor_create(model, device_name, device_id, segmentor);
+    mmdeploy_model_destroy(model);
     return ec;
-  }
-  auto ec = mmdeploy_segmentor_create(model, device_name, device_id, segmentor);
-  mmdeploy_model_destroy(model);
-  return ec;
 }

-int mmdeploy_segmentor_apply(mmdeploy_segmentor_t segmentor, const mmdeploy_mat_t* mats,
-                             int mat_count, mmdeploy_segmentation_t** results) {
-  wrapped<mmdeploy_value_t> input;
-  if (auto ec = mmdeploy_segmentor_create_input(mats, mat_count, input.ptr())) {
-    return ec;
-  }
-  wrapped<mmdeploy_value_t> output;
-  if (auto ec = mmdeploy_segmentor_apply_v2(segmentor, input, output.ptr())) {
-    return ec;
-  }
-  if (auto ec = mmdeploy_segmentor_get_result(output, results)) {
-    return ec;
-  }
-  return MMDEPLOY_SUCCESS;
+int mmdeploy_segmentor_apply(mmdeploy_segmentor_t segmentor, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_segmentation_t** results)
+{
+    wrapped<mmdeploy_value_t> input;
+    if (auto ec = mmdeploy_segmentor_create_input(mats, mat_count, input.ptr()))
+    {
+        return ec;
+    }
+    wrapped<mmdeploy_value_t> output;
+    if (auto ec = mmdeploy_segmentor_apply_v2(segmentor, input, output.ptr()))
+    {
+        return ec;
+    }
+    if (auto ec = mmdeploy_segmentor_get_result(output, results))
+    {
+        return ec;
+    }
+    return MMDEPLOY_SUCCESS;
 }

-void mmdeploy_segmentor_release_result(mmdeploy_segmentation_t* results, int count) {
-  ResultType deleter(static_cast<size_t>(count), results);
+void mmdeploy_segmentor_release_result(mmdeploy_segmentation_t* results, int count)
+{
+    ResultType deleter(static_cast<size_t>(count), results);
 }

-void mmdeploy_segmentor_destroy(mmdeploy_segmentor_t segmentor) {
-  mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)segmentor);
+void mmdeploy_segmentor_destroy(mmdeploy_segmentor_t segmentor)
+{
+    mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)segmentor);
 }

-int mmdeploy_segmentor_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
-                                 mmdeploy_segmentor_t* segmentor) {
-  return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)segmentor);
+int mmdeploy_segmentor_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_segmentor_t* segmentor)
+{
+    return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)segmentor);
 }

-int mmdeploy_segmentor_create_input(const mmdeploy_mat_t* mats, int mat_count,
-                                    mmdeploy_value_t* value) {
-  return mmdeploy_common_create_input(mats, mat_count, value);
+int mmdeploy_segmentor_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value)
+{
+    return mmdeploy_common_create_input(mats, mat_count, value);
 }

-int mmdeploy_segmentor_apply_v2(mmdeploy_segmentor_t segmentor, mmdeploy_value_t input,
-                                mmdeploy_value_t* output) {
-  return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)segmentor, input, output);
+int mmdeploy_segmentor_apply_v2(mmdeploy_segmentor_t segmentor, mmdeploy_value_t input, mmdeploy_value_t* output)
+{
+    return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)segmentor, input, output);
 }

-int mmdeploy_segmentor_apply_async(mmdeploy_segmentor_t segmentor, mmdeploy_sender_t input,
-                                   mmdeploy_sender_t* output) {
-  return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)segmentor, input, output);
+int mmdeploy_segmentor_apply_async(mmdeploy_segmentor_t segmentor, mmdeploy_sender_t input, mmdeploy_sender_t* output)
+{
+    return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)segmentor, input, output);
 }

-int mmdeploy_segmentor_get_result(mmdeploy_value_t output, mmdeploy_segmentation_t** results) {
-  try {
-    const auto& value = Cast(output)->front();
-    size_t image_count = value.size();
-
-    ResultType r(image_count);
-    auto [results_data, buffers] = r.pointers();
-
-    auto results_ptr = results_data;
-
-    for (auto i = 0; i < image_count; ++i, ++results_ptr) {
-      auto& output_item = value[i];
-      MMDEPLOY_DEBUG("the {}-th item in output: {}", i, output_item);
-      auto segmentor_output = from_value<mmseg::SegmentorOutput>(output_item);
-      results_ptr->height = segmentor_output.height;
-      results_ptr->width = segmentor_output.width;
-      results_ptr->classes = segmentor_output.classes;
-      auto& mask = segmentor_output.mask;
-      auto& score = segmentor_output.score;
-      results_ptr->mask = nullptr;
-      results_ptr->score = nullptr;
-      if (mask.shape().size()) {
-        results_ptr->mask = mask.data<int>();
-        buffers[i] = mask.buffer();
-      } else {
-        results_ptr->score = score.data<float>();
-        buffers[i] = score.buffer();
-      }
+int mmdeploy_segmentor_get_result(mmdeploy_value_t output, mmdeploy_segmentation_t** results)
+{
+    try
+    {
+        const auto& value = Cast(output)->front();
+        size_t image_count = value.size();
+
+        ResultType r(image_count);
+        auto [results_data, buffers] = r.pointers();
+
+        auto results_ptr = results_data;
+
+        for (auto i = 0; i < image_count; ++i, ++results_ptr)
+        {
+            auto& output_item = value[i];
+            MMDEPLOY_DEBUG("the {}-th item in output: {}", i, output_item);
+            auto segmentor_output = from_value<mmseg::SegmentorOutput>(output_item);
+            results_ptr->height = segmentor_output.height;
+            results_ptr->width = segmentor_output.width;
+            results_ptr->classes = segmentor_output.classes;
+            auto& mask = segmentor_output.mask;
+            auto& score = segmentor_output.score;
+            results_ptr->mask = nullptr;
+            results_ptr->score = nullptr;
+            if (mask.shape().size())
+            {
+                results_ptr->mask = mask.data<int>();
+                buffers[i] = mask.buffer();
+            }
+            else
+            {
+                results_ptr->score = score.data<float>();
+                buffers[i] = score.buffer();
+            }
+        }
+
+        *results = results_data;
+        r.release();
+
+        return MMDEPLOY_SUCCESS;
     }
-
-    *results = results_data;
-    r.release();
-
-    return MMDEPLOY_SUCCESS;
-  } catch (const std::exception& e) {
-    MMDEPLOY_ERROR("exception caught: {}", e.what());
-  } catch (...) {
-    MMDEPLOY_ERROR("unknown exception caught");
-  }
-  return MMDEPLOY_E_FAIL;
+    catch (const std::exception& e)
+    {
+        MMDEPLOY_ERROR("exception caught: {}", e.what());
+    }
+    catch (...)
+    {
+        MMDEPLOY_ERROR("unknown exception caught");
+    }
+    return MMDEPLOY_E_FAIL;
 }
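(mmdeploy_segmentor_get_result above leaves exactly one of mask/score non-null per image, so callers branch on which output the model produced — illustrative only:)

    mmdeploy_segmentation_t* seg{};
    if (mmdeploy_segmentor_apply(segmentor, &image, 1, &seg) == MMDEPLOY_SUCCESS)
    {
        const int h = seg[0].height, w = seg[0].width;
        if (seg[0].mask)
        {
            // mask[i * w + j] is the label id of pixel (i, j)
            printf("label at (0,0): %d\n", seg[0].mask[0]);
        }
        else if (seg[0].score)
        {
            // score is CHW: score[k * h * w + i * w + j] scores class k at pixel (i, j)
            int best_k = 0;
            float best = seg[0].score[0];
            for (int k = 1; k < seg[0].classes; ++k)
            {
                if (seg[0].score[k * h * w] > best)
                {
                    best = seg[0].score[k * h * w];
                    best_k = k;
                }
            }
            printf("argmax class at (0,0): %d (%.3f)\n", best_k, best);
        }
        mmdeploy_segmentor_release_result(seg, 1);
    }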
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h
index 65bcfd03f3..8d885a275b 100644
--- a/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h
+++ b/csrc/mmdeploy/apis/c/mmdeploy/segmentor.h
@@ -13,91 +13,90 @@
 #include "mmdeploy/model.h"

 #ifdef __cplusplus
-extern "C" {
+extern "C"
+{
 #endif

-typedef struct mmdeploy_segmentation_t {
-  int height;    ///< height of \p mask that equals to the input image's height
-  int width;     ///< width of \p mask that equals to the input image's width
-  int classes;   ///< the number of labels in \p mask
-  int* mask;     ///< segmentation mask of the input image, in which mask[i * width + j] indicates
-                 ///< the label id of pixel at (i, j), this field might be null
-  float* score;  ///< segmentation score map of the input image in CHW format, in which
-                 ///< score[height * width * k + i * width + j] indicates the score
-                 ///< of class k at pixel (i, j), this field might be null
-} mmdeploy_segmentation_t;
-
-typedef struct mmdeploy_segmentor* mmdeploy_segmentor_t;
-
-/**
- * @brief Create segmentor's handle
- * @param[in] model an instance of mmsegmentation sdk model created by
- * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
- * @param[in] device_name name of device, such as "cpu", "cuda", etc.
- * @param[in] device_id id of device.
- * @param[out] segmentor instance of a segmentor, which must be destroyed
- * by \ref mmdeploy_segmentor_destroy
- * @return status of creating segmentor's handle
- */
-MMDEPLOY_API int mmdeploy_segmentor_create(mmdeploy_model_t model, const char* device_name,
-                                           int device_id, mmdeploy_segmentor_t* segmentor);
-
-/**
- * @brief Create segmentor's handle
- * @param[in] model_path path of mmsegmentation sdk model exported by mmdeploy model converter
- * @param[in] device_name name of device, such as "cpu", "cuda", etc.
- * @param[in] device_id id of device.
- * @param[out] segmentor instance of a segmentor, which must be destroyed
- * by \ref mmdeploy_segmentor_destroy
- * @return status of creating segmentor's handle
- */
-MMDEPLOY_API int mmdeploy_segmentor_create_by_path(const char* model_path, const char* device_name,
-                                                   int device_id, mmdeploy_segmentor_t* segmentor);
-
-/**
- * @brief Apply segmentor to batch images and get their inference results
- * @param[in] segmentor segmentor's handle created by \ref mmdeploy_segmentor_create_by_path or \ref
- * mmdeploy_segmentor_create
- * @param[in] mats a batch of images
- * @param[in] mat_count number of images in the batch
- * @param[out] results a linear buffer of length \p mat_count to save segmentation result of each
- * image. It must be released by \ref mmdeploy_segmentor_release_result
- * @return status of inference
- */
-MMDEPLOY_API int mmdeploy_segmentor_apply(mmdeploy_segmentor_t segmentor,
-                                          const mmdeploy_mat_t* mats, int mat_count,
-                                          mmdeploy_segmentation_t** results);
-
-/**
- * @brief Release result buffer returned by \ref mmdeploy_segmentor_apply
- * @param[in] results result buffer
- * @param[in] count length of \p results
- */
-MMDEPLOY_API void mmdeploy_segmentor_release_result(mmdeploy_segmentation_t* results, int count);
-
-/**
- * @brief Destroy segmentor's handle
- * @param[in] segmentor segmentor's handle created by \ref mmdeploy_segmentor_create_by_path
- */
-MMDEPLOY_API void mmdeploy_segmentor_destroy(mmdeploy_segmentor_t segmentor);
-
-/******************************************************************************
- * Experimental asynchronous APIs */
-
-MMDEPLOY_API int mmdeploy_segmentor_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
-                                              mmdeploy_segmentor_t* segmentor);
-
-MMDEPLOY_API int mmdeploy_segmentor_create_input(const mmdeploy_mat_t* mats, int mat_count,
-                                                 mmdeploy_value_t* value);
-
-MMDEPLOY_API int mmdeploy_segmentor_apply_v2(mmdeploy_segmentor_t segmentor, mmdeploy_value_t input,
-                                             mmdeploy_value_t* output);
-
-MMDEPLOY_API int mmdeploy_segmentor_apply_async(mmdeploy_segmentor_t segmentor,
-                                                mmdeploy_sender_t input, mmdeploy_sender_t* output);
-
-MMDEPLOY_API int mmdeploy_segmentor_get_result(mmdeploy_value_t output,
-                                               mmdeploy_segmentation_t** results);
+    typedef struct mmdeploy_segmentation_t
+    {
+        int height;    ///< height of \p mask that equals to the input image's height
+        int width;     ///< width of \p mask that equals to the input image's width
+        int classes;   ///< the number of labels in \p mask
+        int* mask;     ///< segmentation mask of the input image, in which mask[i * width + j] indicates
+                       ///< the label id of pixel at (i, j), this field might be null
+        float* score;  ///< segmentation score map of the input image in CHW format, in which
+                       ///< score[height * width * k + i * width + j] indicates the score
+                       ///< of class k at pixel (i, j), this field might be null
+    } mmdeploy_segmentation_t;
+
+    typedef struct mmdeploy_segmentor* mmdeploy_segmentor_t;
+
+    /**
+     * @brief Create segmentor's handle
+     * @param[in] model an instance of mmsegmentation sdk model created by
+     * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device.
+     * @param[out] segmentor instance of a segmentor, which must be destroyed
+     * by \ref mmdeploy_segmentor_destroy
+     * @return status of creating segmentor's handle
+     */
+    MMDEPLOY_API int mmdeploy_segmentor_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_segmentor_t* segmentor);
+
+    /**
+     * @brief Create segmentor's handle
+     * @param[in] model_path path of mmsegmentation sdk model exported by mmdeploy model converter
+     * @param[in] device_name name of device, such as "cpu", "cuda", etc.
+     * @param[in] device_id id of device.
+     * @param[out] segmentor instance of a segmentor, which must be destroyed
+     * by \ref mmdeploy_segmentor_destroy
+     * @return status of creating segmentor's handle
+     */
+    MMDEPLOY_API int mmdeploy_segmentor_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_segmentor_t* segmentor);
+
+    /**
+     * @brief Apply segmentor to batch images and get their inference results
+     * @param[in] segmentor segmentor's handle created by \ref mmdeploy_segmentor_create_by_path or \ref
+     * mmdeploy_segmentor_create
+     * @param[in] mats a batch of images
+     * @param[in] mat_count number of images in the batch
+     * @param[out] results a linear buffer of length \p mat_count to save the segmentation result of each
+     * image. It must be released by \ref mmdeploy_segmentor_release_result
+     * @return status of inference
+     */
+    MMDEPLOY_API int mmdeploy_segmentor_apply(mmdeploy_segmentor_t segmentor,
+                                              const mmdeploy_mat_t* mats,
+                                              int mat_count,
+                                              mmdeploy_segmentation_t** results);
+
+    /**
+     * @brief Release result buffer returned by \ref mmdeploy_segmentor_apply
+     * @param[in] results result buffer
+     * @param[in] count length of \p results
+     */
+    MMDEPLOY_API void mmdeploy_segmentor_release_result(mmdeploy_segmentation_t* results, int count);
+
+    /**
+     * @brief Destroy segmentor's handle
+     * @param[in] segmentor segmentor's handle created by \ref mmdeploy_segmentor_create_by_path
+     */
+    MMDEPLOY_API void mmdeploy_segmentor_destroy(mmdeploy_segmentor_t segmentor);
+
+    /******************************************************************************
+     * Experimental asynchronous APIs */
+
+    MMDEPLOY_API int mmdeploy_segmentor_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_segmentor_t* segmentor);
+
+    MMDEPLOY_API int mmdeploy_segmentor_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* value);
+
+    MMDEPLOY_API int mmdeploy_segmentor_apply_v2(mmdeploy_segmentor_t segmentor, mmdeploy_value_t input, mmdeploy_value_t* output);
+
+    MMDEPLOY_API int mmdeploy_segmentor_apply_async(mmdeploy_segmentor_t segmentor,
+                                                    mmdeploy_sender_t input,
+                                                    mmdeploy_sender_t* output);
+
+    MMDEPLOY_API int mmdeploy_segmentor_get_result(mmdeploy_value_t output,
+                                                   mmdeploy_segmentation_t** results);

 #ifdef __cplusplus
 }
diff --git a/csrc/mmdeploy/apis/c/mmdeploy/text_detector.cpp b/csrc/mmdeploy/apis/c/mmdeploy/text_detector.cpp
index 576af07762..44b124187f 100644
--- a/csrc/mmdeploy/apis/c/mmdeploy/text_detector.cpp
+++ b/csrc/mmdeploy/apis/c/mmdeploy/text_detector.cpp
@@ -16,158 +16,186 @@
 using namespace std;
 using namespace mmdeploy;

-int mmdeploy_text_detector_create(mmdeploy_model_t model, const char* device_name, int device_id,
-                                  mmdeploy_text_detector_t* detector) {
-  mmdeploy_context_t context{};
-  auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
-  if (ec != MMDEPLOY_SUCCESS) {
+int mmdeploy_text_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_text_detector_t* detector)
+{
+    mmdeploy_context_t context{};
+    auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context);
+    if (ec != MMDEPLOY_SUCCESS)
+    {
+        return ec;
+    }
+    ec = mmdeploy_text_detector_create_v2(model, context, detector);
+    mmdeploy_context_destroy(context);
     return ec;
-  }
-  ec = mmdeploy_text_detector_create_v2(model, context, detector);
-  mmdeploy_context_destroy(context);
-  return ec;
 }

-int mmdeploy_text_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context,
-                                     mmdeploy_text_detector_t* detector) {
-  return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector);
+int mmdeploy_text_detector_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_text_detector_t* detector)
+{
+    return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)detector);
 }

-int mmdeploy_text_detector_create_by_path(const char* model_path, const char* device_name,
-                                          int device_id, mmdeploy_text_detector_t* detector) {
-  mmdeploy_model_t model{};
-  if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) {
+int mmdeploy_text_detector_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_text_detector_t* detector)
+{
+    mmdeploy_model_t model{};
+    if (auto ec = mmdeploy_model_create_by_path(model_path, &model))
+    {
+        return ec;
+    }
+    auto ec = mmdeploy_text_detector_create(model, device_name, device_id, detector);
+    mmdeploy_model_destroy(model);
    return ec;
-  }
-  auto ec = mmdeploy_text_detector_create(model, device_name, device_id, detector);
-  mmdeploy_model_destroy(model);
-  return ec;
 }

-int mmdeploy_text_detector_create_input(const mmdeploy_mat_t* mats, int mat_count,
-                                        mmdeploy_value_t* input) {
-  return mmdeploy_common_create_input(mats, mat_count, input);
+int mmdeploy_text_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input)
+{
+    return mmdeploy_common_create_input(mats, mat_count, input);
 }

-int mmdeploy_text_detector_apply(mmdeploy_text_detector_t detector, const mmdeploy_mat_t* mats,
-                                 int mat_count, mmdeploy_text_detection_t** results,
-                                 int** result_count) {
-  wrapped<mmdeploy_value_t> input;
-  if (auto ec = mmdeploy_text_detector_create_input(mats, mat_count, input.ptr())) {
-    return ec;
-  }
-  wrapped<mmdeploy_value_t> output;
-  if (auto ec = mmdeploy_text_detector_apply_v2(detector, input, output.ptr())) {
-    return ec;
-  }
-  if (auto ec = mmdeploy_text_detector_get_result(output, results, result_count)) {
-    return ec;
-  }
-  return MMDEPLOY_SUCCESS;
+int mmdeploy_text_detector_apply(mmdeploy_text_detector_t detector, const mmdeploy_mat_t* mats, int mat_count, mmdeploy_text_detection_t** results, int** result_count)
+{
+    wrapped<mmdeploy_value_t> input;
+    if (auto ec = mmdeploy_text_detector_create_input(mats, mat_count, input.ptr()))
+    {
+        return ec;
+    }
+    wrapped<mmdeploy_value_t> output;
+    if (auto ec = mmdeploy_text_detector_apply_v2(detector, input, output.ptr()))
+    {
+        return ec;
+    }
+    if (auto ec = mmdeploy_text_detector_get_result(output, results, result_count))
+    {
+        return ec;
+    }
+    return MMDEPLOY_SUCCESS;
 }

-int mmdeploy_text_detector_apply_v2(mmdeploy_text_detector_t detector, mmdeploy_value_t input,
-                                    mmdeploy_value_t* output) {
-  return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output);
+int mmdeploy_text_detector_apply_v2(mmdeploy_text_detector_t detector, mmdeploy_value_t input, mmdeploy_value_t* output)
+{
+    return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)detector, input, output);
 }

-int mmdeploy_text_detector_apply_async(mmdeploy_text_detector_t detector, mmdeploy_sender_t input,
-                                       mmdeploy_sender_t* output) {
-  return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output);
+int mmdeploy_text_detector_apply_async(mmdeploy_text_detector_t detector, mmdeploy_sender_t input, mmdeploy_sender_t* output)
+{
+    return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)detector, input, output);
 }

-int mmdeploy_text_detector_get_result(mmdeploy_value_t output, mmdeploy_text_detection_t** results,
-                                      int** result_count) {
-  if (!output || !results || !result_count) {
-    return MMDEPLOY_E_INVALID_ARG;
-  }
-  try {
-    Value& value = reinterpret_cast<Value*>(output)->front();
-    auto detector_outputs = from_value<vector<mmocr::TextDetections>>(value);
-
-    vector<int> _result_count;
-    _result_count.reserve(detector_outputs.size());
-    for (const auto& det_output : detector_outputs) {
-      _result_count.push_back((int)det_output.size());
+int mmdeploy_text_detector_get_result(mmdeploy_value_t output, mmdeploy_text_detection_t** results, int** result_count)
+{
+    if (!output || !results || !result_count)
+    {
+        return MMDEPLOY_E_INVALID_ARG;
     }
-
-    auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0);
-
-    std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{});
-    std::copy(_result_count.begin(), _result_count.end(), result_count_data.get());
-
-    std::unique_ptr<mmdeploy_text_detection_t[]> result_data(
-        new mmdeploy_text_detection_t[total]{});
-    auto result_ptr = result_data.get();
-
-    for (const auto& det_output : detector_outputs) {
-      for (auto i = 0; i < det_output.size(); ++i, ++result_ptr) {
-        result_ptr->score = det_output[i].score;
-        auto& bbox = det_output[i].bbox;
-        for (auto j = 0; j < bbox.size(); j += 2) {
-          result_ptr->bbox[j / 2].x = bbox[j];
-          result_ptr->bbox[j / 2].y = bbox[j + 1];
+    try
+    {
+        Value& value = reinterpret_cast<Value*>(output)->front();
+        auto detector_outputs = from_value<vector<mmocr::TextDetections>>(value);
+
+        vector<int> _result_count;
+        _result_count.reserve(detector_outputs.size());
+        for (const auto& det_output : detector_outputs)
+        {
+            _result_count.push_back((int)det_output.size());
         }
-      }
-    }

-    *result_count = result_count_data.release();
-    *results = result_data.release();
+        auto total = std::accumulate(_result_count.begin(), _result_count.end(), 0);
+
+        std::unique_ptr<int[]> result_count_data(new int[_result_count.size()]{});
+        std::copy(_result_count.begin(), _result_count.end(), result_count_data.get());
+
+        std::unique_ptr<mmdeploy_text_detection_t[]> result_data(
+            new mmdeploy_text_detection_t[total]{});
+        auto result_ptr = result_data.get();
+
+        for (const auto& det_output : detector_outputs)
+        {
+            for (auto i = 0; i < det_output.size(); ++i, ++result_ptr)
+            {
+                result_ptr->score = det_output[i].score;
+                auto& bbox = det_output[i].bbox;
+                for (auto j = 0; j < bbox.size(); j += 2)
+                {
+                    result_ptr->bbox[j / 2].x = bbox[j];
+                    result_ptr->bbox[j / 2].y = bbox[j + 1];
+                }
+            }
+        }

-    return MMDEPLOY_SUCCESS;
+        *result_count = result_count_data.release();
+        *results = result_data.release();

-  } catch (const std::exception& e) {
-    MMDEPLOY_ERROR("unhandled exception: {}", e.what());
-  } catch (...) {
-    MMDEPLOY_ERROR("unknown exception caught");
-  }
-  return 0;
+        return MMDEPLOY_SUCCESS;
+    }
+    catch (const std::exception& e)
+    {
+        MMDEPLOY_ERROR("unhandled exception: {}", e.what());
+    }
+    catch (...)
+    {
+        MMDEPLOY_ERROR("unknown exception caught");
+    }
+    return 0;
 }

 void mmdeploy_text_detector_release_result(mmdeploy_text_detection_t* results,
-                                           const int* result_count, int count) {
-  delete[] results;
-  delete[] result_count;
+                                           const int* result_count,
+                                           int count)
+{
+    delete[] results;
+    delete[] result_count;
 }

-void mmdeploy_text_detector_destroy(mmdeploy_text_detector_t detector) {
-  mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector);
+void mmdeploy_text_detector_destroy(mmdeploy_text_detector_t detector)
+{
+    mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)detector);
 }

-int mmdeploy_text_detector_apply_async_v2(mmdeploy_text_detector_t detector,
-                                          const mmdeploy_mat_t* imgs, int img_count,
-                                          mmdeploy_text_detector_continue_t cont, void* context,
-                                          mmdeploy_sender_t* output) {
-  mmdeploy_sender_t result_sender{};
-  if (auto ec = mmdeploy_text_detector_apply_async_v3(detector, imgs, img_count, &result_sender)) {
-    return ec;
-  }
-  if (auto ec = mmdeploy_text_detector_continue_async(result_sender, cont, context, output)) {
-    return ec;
-  }
-  return MMDEPLOY_SUCCESS;
+int mmdeploy_text_detector_apply_async_v2(mmdeploy_text_detector_t detector,
+                                          const mmdeploy_mat_t* imgs,
+                                          int img_count,
+                                          mmdeploy_text_detector_continue_t cont,
+                                          void* context,
+                                          mmdeploy_sender_t* output)
+{
+    mmdeploy_sender_t result_sender{};
+    if (auto ec = mmdeploy_text_detector_apply_async_v3(detector, imgs, img_count, &result_sender))
+    {
+        return ec;
+    }
+    if (auto ec = mmdeploy_text_detector_continue_async(result_sender, cont, context, output))
+    {
+        return ec;
+    }
+    return MMDEPLOY_SUCCESS;
 }

 int mmdeploy_text_detector_apply_async_v3(mmdeploy_text_detector_t detector,
-                                          const mmdeploy_mat_t* imgs, int img_count,
-                                          mmdeploy_sender_t* output) {
-  wrapped<mmdeploy_value_t> input_val;
-  if (auto ec = mmdeploy_text_detector_create_input(imgs, img_count, input_val.ptr())) {
-    return ec;
-  }
-  mmdeploy_sender_t input_sndr = mmdeploy_executor_just(input_val);
-  if (auto ec = mmdeploy_text_detector_apply_async(detector, input_sndr, output)) {
-    return ec;
-  }
-  return MMDEPLOY_SUCCESS;
+                                          const mmdeploy_mat_t* imgs,
+                                          int img_count,
+                                          mmdeploy_sender_t* output)
+{
+    wrapped<mmdeploy_value_t> input_val;
+    if (auto ec = mmdeploy_text_detector_create_input(imgs, img_count, input_val.ptr()))
+    {
+        return ec;
+    }
+    mmdeploy_sender_t input_sndr = mmdeploy_executor_just(input_val);
+    if (auto ec = mmdeploy_text_detector_apply_async(detector, input_sndr, output))
+    {
+        return ec;
+    }
+    return MMDEPLOY_SUCCESS;
 }

-int mmdeploy_text_detector_continue_async(mmdeploy_sender_t input,
-                                          mmdeploy_text_detector_continue_t cont, void* context,
-                                          mmdeploy_sender_t* output) {
-  auto sender = Guard([&] {
-    return Take(
-        LetValue(Take(input), [fn = cont, context](Value& value) -> TypeErasedSender {
+int mmdeploy_text_detector_continue_async(mmdeploy_sender_t input,
+                                          mmdeploy_text_detector_continue_t cont,
+                                          void* context,
+                                          mmdeploy_sender_t* output)
+{
+    auto sender = Guard([&]
+                        { return Take(
+                              LetValue(Take(input), [fn = cont, context](Value& value) -> TypeErasedSender
+                                       {
           mmdeploy_text_detection_t* results{};
           int* result_count{};
           if (auto ec = mmdeploy_text_detector_get_result(Cast(&value), &results, &result_count)) {
@@ -178,12 +206,11 @@ int mmdeploy_text_detector_continue_async(mmdeploy_sender_t input,
           if (auto ec = fn(results, result_count, context, &output); ec || !output) {
             return Just(Value());
           }
-          return Take(output);
-        }));
-  });
-  if (sender) {
-    *output = sender;
-    return MMDEPLOY_SUCCESS;
-  }
-  return MMDEPLOY_E_FAIL;
+          return Take(output); })); });
+    if (sender)
+    {
+        *output = sender;
+        return MMDEPLOY_SUCCESS;
+    }
+    return MMDEPLOY_E_FAIL;
 }
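(The continuation API just reformatted threads a user callback into the sender chain: get_result hands the detections to the callback and, per the LetValue body above, returning non-zero or leaving *output empty makes the chain continue with an empty value. A schematic callback; the context payload is a hypothetical image-count pointer chosen for this sketch:)

    static int on_text_detections(mmdeploy_text_detection_t* results, int* result_count,
                                  void* context, mmdeploy_sender_t* output)
    {
        int img_count = *(int*)context;  // hypothetical: the caller passes the image count via context
        const mmdeploy_text_detection_t* p = results;
        for (int i = 0; i < img_count; ++i)
        {
            for (int j = 0; j < result_count[i]; ++j, ++p)
            {
                printf("image %d: score=%.3f first vertex=(%.1f, %.1f)\n",
                       i, p->score, p->bbox[0].x, p->bbox[0].y);
            }
        }
        // ownership of the buffers transferred to this callback, so release them here
        mmdeploy_text_detector_release_result(results, result_count, img_count);
        // leave *output empty: the LetValue body above then continues with Just(Value())
        return MMDEPLOY_SUCCESS;
    }

    // wired up via:
    //   mmdeploy_sender_t sender{};
    //   mmdeploy_text_detector_apply_async_v2(detector, imgs, img_count, on_text_detections, &img_count, &sender);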
It must be released by \ref mmdeploy_detector_release_result - * @return status of inference - */ -MMDEPLOY_API int mmdeploy_text_detector_apply(mmdeploy_text_detector_t detector, - const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_text_detection_t** results, - int** result_count); - -/** @brief Release the inference result buffer returned by \ref mmdeploy_text_detector_apply - * @param[in] results text detection result buffer - * @param[in] result_count \p results size buffer - * @param[in] count the length of buffer \p result_count - */ -MMDEPLOY_API void mmdeploy_text_detector_release_result(mmdeploy_text_detection_t* results, - const int* result_count, int count); - -/** - * @brief Destroy text-detector's handle - * @param[in] detector text-detector's handle created by \ref mmdeploy_text_detector_create_by_path - * or \ref mmdeploy_text_detector_create - */ -MMDEPLOY_API void mmdeploy_text_detector_destroy(mmdeploy_text_detector_t detector); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -/** - * @brief Same as \ref mmdeploy_text_detector_create, but allows to control execution context of - * tasks via context - */ -MMDEPLOY_API int mmdeploy_text_detector_create_v2(mmdeploy_model_t model, - mmdeploy_context_t context, - mmdeploy_text_detector_t* detector); - -/** - * @brief Pack text-detector inputs into mmdeploy_value_t - * @param[in] mats a batch of images - * @param[in] mat_count number of images in the batch - * @return the created value - */ -MMDEPLOY_API int mmdeploy_text_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, - mmdeploy_value_t* input); - -/** - * @brief Same as \ref mmdeploy_text_detector_apply, but input and output are packed in \ref - * mmdeploy_value_t. - */ -MMDEPLOY_API int mmdeploy_text_detector_apply_v2(mmdeploy_text_detector_t detector, - mmdeploy_value_t input, mmdeploy_value_t* output); - -/** - * @brief Apply text-detector asynchronously - * @param[in] detector handle to the detector - * @param[in] input input sender that will be consumed by the operation - * @return output sender - */ -MMDEPLOY_API int mmdeploy_text_detector_apply_async(mmdeploy_text_detector_t detector, - mmdeploy_sender_t input, - mmdeploy_sender_t* output); - -/** - * @brief Unpack detector output from a mmdeploy_value_t - * @param[in] output output sender returned by applying a detector - * @param[out] results a linear buffer to save detection results of each image. It must be - * released by \ref mmdeploy_text_detector_release_result - * @param[out] result_count a linear buffer with length number of input images to save the - * number of detection results of each image. 
Must be released by \ref - * mmdeploy_text_detector_release_result - * @return status of the operation - */ -MMDEPLOY_API -int mmdeploy_text_detector_get_result(mmdeploy_value_t output, mmdeploy_text_detection_t** results, - int** result_count); - -typedef int (*mmdeploy_text_detector_continue_t)(mmdeploy_text_detection_t* results, - int* result_count, void* context, - mmdeploy_sender_t* output); - -// MMDEPLOY_API int mmdeploy_text_detector_apply_async_v2(mm_handle_t handle, const mm_mat_t* imgs, -// int img_count, -// mmdeploy_text_detector_continuation_t -// cont, void* context, mmdeploy_sender_t* -// output); - -MMDEPLOY_API int mmdeploy_text_detector_apply_async_v3(mmdeploy_text_detector_t detector, - const mmdeploy_mat_t* imgs, int img_count, - mmdeploy_sender_t* output); - -MMDEPLOY_API int mmdeploy_text_detector_continue_async(mmdeploy_sender_t input, - mmdeploy_text_detector_continue_t cont, - void* context, mmdeploy_sender_t* output); + typedef struct mmdeploy_text_detection_t + { + mmdeploy_point_t bbox[4]; ///< a text bounding box of which the vertex are in clock-wise + float score; + } mmdeploy_text_detection_t; + + typedef struct mmdeploy_text_detector* mmdeploy_text_detector_t; + + /** + * @brief Create text-detector's handle + * @param[in] model an instance of mmocr text detection model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] detector instance of a text-detector, which must be destroyed + * by \ref mmdeploy_text_detector_destroy + * @return status of creating text-detector's handle + */ + MMDEPLOY_API int mmdeploy_text_detector_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_text_detector_t* detector); + + /** + * @brief Create text-detector's handle + * @param[in] model_path path to text detection model + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device + * @param[out] detector instance of a text-detector, which must be destroyed + * by \ref mmdeploy_text_detector_destroy + * @return status of creating text-detector's handle + */ + MMDEPLOY_API int mmdeploy_text_detector_create_by_path(const char* model_path, + const char* device_name, + int device_id, + mmdeploy_text_detector_t* detector); + + /** + * @brief Apply text-detector to batch images and get their inference results + * @param[in] detector text-detector's handle created by \ref mmdeploy_text_detector_create_by_path + * @param[in] mats a batch of images + * @param[in] mat_count number of images in the batch + * @param[out] results a linear buffer to save text detection results of each + * image. It must be released by calling \ref mmdeploy_text_detector_release_result + * @param[out] result_count a linear buffer of length \p mat_count to save the number of detection + * results of each image. 
It must be released by \ref mmdeploy_detector_release_result + * @return status of inference + */ + MMDEPLOY_API int mmdeploy_text_detector_apply(mmdeploy_text_detector_t detector, + const mmdeploy_mat_t* mats, + int mat_count, + mmdeploy_text_detection_t** results, + int** result_count); + + /** @brief Release the inference result buffer returned by \ref mmdeploy_text_detector_apply + * @param[in] results text detection result buffer + * @param[in] result_count \p results size buffer + * @param[in] count the length of buffer \p result_count + */ + MMDEPLOY_API void mmdeploy_text_detector_release_result(mmdeploy_text_detection_t* results, + const int* result_count, + int count); + + /** + * @brief Destroy text-detector's handle + * @param[in] detector text-detector's handle created by \ref mmdeploy_text_detector_create_by_path + * or \ref mmdeploy_text_detector_create + */ + MMDEPLOY_API void mmdeploy_text_detector_destroy(mmdeploy_text_detector_t detector); + + /****************************************************************************** + * Experimental asynchronous APIs */ + + /** + * @brief Same as \ref mmdeploy_text_detector_create, but allows to control execution context of + * tasks via context + */ + MMDEPLOY_API int mmdeploy_text_detector_create_v2(mmdeploy_model_t model, + mmdeploy_context_t context, + mmdeploy_text_detector_t* detector); + + /** + * @brief Pack text-detector inputs into mmdeploy_value_t + * @param[in] mats a batch of images + * @param[in] mat_count number of images in the batch + * @return the created value + */ + MMDEPLOY_API int mmdeploy_text_detector_create_input(const mmdeploy_mat_t* mats, int mat_count, mmdeploy_value_t* input); + + /** + * @brief Same as \ref mmdeploy_text_detector_apply, but input and output are packed in \ref + * mmdeploy_value_t. + */ + MMDEPLOY_API int mmdeploy_text_detector_apply_v2(mmdeploy_text_detector_t detector, + mmdeploy_value_t input, + mmdeploy_value_t* output); + + /** + * @brief Apply text-detector asynchronously + * @param[in] detector handle to the detector + * @param[in] input input sender that will be consumed by the operation + * @return output sender + */ + MMDEPLOY_API int mmdeploy_text_detector_apply_async(mmdeploy_text_detector_t detector, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); + + /** + * @brief Unpack detector output from a mmdeploy_value_t + * @param[in] output output sender returned by applying a detector + * @param[out] results a linear buffer to save detection results of each image. It must be + * released by \ref mmdeploy_text_detector_release_result + * @param[out] result_count a linear buffer with length number of input images to save the + * number of detection results of each image. 
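A minimal synchronous walkthrough of the text-detector API documented above; "dbnet" and "cpu" are placeholders for a real model directory and device name:

#include <cstdio>
#include "mmdeploy/text_detector.h"

int run_text_detection(const mmdeploy_mat_t* imgs, int img_count)
{
    mmdeploy_text_detector_t detector{};
    int ec = mmdeploy_text_detector_create_by_path("dbnet", "cpu", 0, &detector);
    if (ec != MMDEPLOY_SUCCESS)
    {
        return ec;
    }
    mmdeploy_text_detection_t* dets{};
    int*                       det_count{};
    ec = mmdeploy_text_detector_apply(detector, imgs, img_count, &dets, &det_count);
    if (ec == MMDEPLOY_SUCCESS)
    {
        // `dets` is one linear buffer; det_count[i] entries belong to image i
        const mmdeploy_text_detection_t* p = dets;
        for (int i = 0; i < img_count; ++i)
        {
            for (int j = 0; j < det_count[i]; ++j, ++p)
            {
                std::printf("image %d box %d score %.3f\n", i, j, p->score);
            }
        }
        mmdeploy_text_detector_release_result(dets, det_count, img_count);
    }
    mmdeploy_text_detector_destroy(detector);
    return ec;
}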
Must be released by \ref + * mmdeploy_text_detector_release_result + * @return status of the operation + */ + MMDEPLOY_API + int mmdeploy_text_detector_get_result(mmdeploy_value_t output, mmdeploy_text_detection_t** results, int** result_count); + + typedef int (*mmdeploy_text_detector_continue_t)(mmdeploy_text_detection_t* results, + int* result_count, + void* context, + mmdeploy_sender_t* output); + + // MMDEPLOY_API int mmdeploy_text_detector_apply_async_v2(mm_handle_t handle, const mm_mat_t* imgs, + // int img_count, + // mmdeploy_text_detector_continuation_t + // cont, void* context, mmdeploy_sender_t* + // output); + + MMDEPLOY_API int mmdeploy_text_detector_apply_async_v3(mmdeploy_text_detector_t detector, + const mmdeploy_mat_t* imgs, + int img_count, + mmdeploy_sender_t* output); + + MMDEPLOY_API int mmdeploy_text_detector_continue_async(mmdeploy_sender_t input, + mmdeploy_text_detector_continue_t cont, + void* context, + mmdeploy_sender_t* output); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.cpp b/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.cpp index 3c8cfbb5c6..4c94666add 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.cpp @@ -19,10 +19,12 @@ using namespace mmdeploy; -namespace { +namespace +{ -Value config_template(const Model& model) { - // clang-format off + Value config_template(const Model& model) + { + // clang-format off return { {"type", "Pipeline"}, {"input", {"imgs", "bboxes"}}, @@ -44,194 +46,238 @@ Value config_template(const Model& model) { }, {"output", "texts"}, }; - // clang-format on -} + // clang-format on + } } // namespace -int mmdeploy_text_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_text_recognizer_t* recognizer) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_text_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_text_recognizer_t* recognizer) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_text_recognizer_create_v2(model, context, recognizer); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_text_recognizer_create_v2(model, context, recognizer); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_text_recognizer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_text_recognizer_t* recognizer) { - auto config = config_template(*Cast(model)); - return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)recognizer); +int mmdeploy_text_recognizer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_text_recognizer_t* recognizer) +{ + auto config = config_template(*Cast(model)); + return mmdeploy_pipeline_create_v3(Cast(&config), context, (mmdeploy_pipeline_t*)recognizer); } -int mmdeploy_text_recognizer_create_by_path(const char* model_path, const char* device_name, - int device_id, mmdeploy_text_recognizer_t* recognizer) { - mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { +int mmdeploy_text_recognizer_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_text_recognizer_t* recognizer) +{ + mmdeploy_model_t model{}; + if (auto ec = 
mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = mmdeploy_text_recognizer_create(model, device_name, device_id, recognizer); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_text_recognizer_create(model, device_name, device_id, recognizer); - mmdeploy_model_destroy(model); - return ec; } -int mmdeploy_text_recognizer_apply(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* images, int count, - mmdeploy_text_recognition_t** results) { - return mmdeploy_text_recognizer_apply_bbox(recognizer, images, count, nullptr, nullptr, results); +int mmdeploy_text_recognizer_apply(mmdeploy_text_recognizer_t recognizer, + const mmdeploy_mat_t* images, + int count, + mmdeploy_text_recognition_t** results) +{ + return mmdeploy_text_recognizer_apply_bbox(recognizer, images, count, nullptr, nullptr, results); } -int mmdeploy_text_recognizer_create_input(const mmdeploy_mat_t* images, int image_count, - const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, mmdeploy_value_t* output) { - if (image_count && images == nullptr) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - Value::Array input_images; - Value::Array input_bboxes; - - auto add_bbox = [&](Mat img, const mmdeploy_text_detection_t* det) { - if (det) { - const auto& b = det->bbox; - Value::Array bbox{b[0].x, b[0].y, b[1].x, b[1].y, b[2].x, b[2].y, b[3].x, b[3].y}; - input_bboxes.push_back({{"bbox", std::move(bbox)}}); - } else { - input_bboxes.push_back(nullptr); - } - input_images.push_back({{"ori_img", img}}); - }; - - for (int i = 0; i < image_count; ++i) { - auto _mat = Cast(images[i]); - if (bboxes && bbox_count) { - for (int j = 0; j < bbox_count[i]; ++j) { - add_bbox(_mat, bboxes++); - } - } else { // inference with whole image - add_bbox(_mat, nullptr); - } +int mmdeploy_text_recognizer_create_input(const mmdeploy_mat_t* images, int image_count, const mmdeploy_text_detection_t* bboxes, const int* bbox_count, mmdeploy_value_t* output) +{ + if (image_count && images == nullptr) + { + return MMDEPLOY_E_INVALID_ARG; } + try + { + Value::Array input_images; + Value::Array input_bboxes; - *output = Take(Value{std::move(input_images), std::move(input_bboxes)}); - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("exception caught: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + auto add_bbox = [&](Mat img, const mmdeploy_text_detection_t* det) + { + if (det) + { + const auto& b = det->bbox; + Value::Array bbox{b[0].x, b[0].y, b[1].x, b[1].y, b[2].x, b[2].y, b[3].x, b[3].y}; + input_bboxes.push_back({{"bbox", std::move(bbox)}}); + } + else + { + input_bboxes.push_back(nullptr); + } + input_images.push_back({{"ori_img", img}}); + }; + + for (int i = 0; i < image_count; ++i) + { + auto _mat = Cast(images[i]); + if (bboxes && bbox_count) + { + for (int j = 0; j < bbox_count[i]; ++j) + { + add_bbox(_mat, bboxes++); + } + } + else + { // inference with whole image + add_bbox(_mat, nullptr); + } + } + + *output = Take(Value{std::move(input_images), std::move(input_bboxes)}); + return MMDEPLOY_SUCCESS; + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("exception caught: {}", e.what()); + } + catch (...) 
+ { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } -int mmdeploy_text_recognizer_apply_bbox(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* images, int image_count, +int mmdeploy_text_recognizer_apply_bbox(mmdeploy_text_recognizer_t recognizer, + const mmdeploy_mat_t* images, + int image_count, const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, - mmdeploy_text_recognition_t** results) { - wrapped input; - if (auto ec = mmdeploy_text_recognizer_create_input(images, image_count, bboxes, bbox_count, - input.ptr())) { - return ec; - } - wrapped output; - if (auto ec = mmdeploy_text_recognizer_apply_v2(recognizer, input, output.ptr())) { - return ec; - } - if (auto ec = mmdeploy_text_recognizer_get_result(output, results)) { - return ec; - } - return MMDEPLOY_SUCCESS; + const int* bbox_count, + mmdeploy_text_recognition_t** results) +{ + wrapped input; + if (auto ec = mmdeploy_text_recognizer_create_input(images, image_count, bboxes, bbox_count, input.ptr())) + { + return ec; + } + wrapped output; + if (auto ec = mmdeploy_text_recognizer_apply_v2(recognizer, input, output.ptr())) + { + return ec; + } + if (auto ec = mmdeploy_text_recognizer_get_result(output, results)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -int mmdeploy_text_recognizer_apply_v2(mmdeploy_text_recognizer_t recognizer, mmdeploy_value_t input, - mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)recognizer, input, output); +int mmdeploy_text_recognizer_apply_v2(mmdeploy_text_recognizer_t recognizer, mmdeploy_value_t input, mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)recognizer, input, output); } int mmdeploy_text_recognizer_apply_async(mmdeploy_text_recognizer_t recognizer, - mmdeploy_sender_t input, mmdeploy_sender_t* output) { - return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)recognizer, input, output); + mmdeploy_sender_t input, + mmdeploy_sender_t* output) +{ + return mmdeploy_pipeline_apply_async((mmdeploy_pipeline_t)recognizer, input, output); } -MMDEPLOY_API int mmdeploy_text_recognizer_get_result(mmdeploy_value_t output, - mmdeploy_text_recognition_t** results) { - if (!output || !results) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - std::vector recognitions; - from_value(Cast(output)->front(), recognitions); +MMDEPLOY_API int mmdeploy_text_recognizer_get_result(mmdeploy_value_t output, + mmdeploy_text_recognition_t** results) +{ + if (!output || !results) + { + return MMDEPLOY_E_INVALID_ARG; + } + try + { + std::vector recognitions; + from_value(Cast(output)->front(), recognitions); - size_t count = recognitions.size(); + size_t count = recognitions.size(); - auto deleter = [&](mmdeploy_text_recognition_t* p) { - mmdeploy_text_recognizer_release_result(p, static_cast(count)); - }; + auto deleter = [&](mmdeploy_text_recognition_t* p) + { + mmdeploy_text_recognizer_release_result(p, static_cast(count)); + }; - std::unique_ptr _results( - new mmdeploy_text_recognition_t[count]{}, deleter); + std::unique_ptr _results( + new mmdeploy_text_recognition_t[count]{}, + deleter); - size_t result_idx = 0; - for (const auto& bbox_result : recognitions) { - auto& res = _results[result_idx++]; + size_t result_idx = 0; + for (const auto& bbox_result : recognitions) + { + auto& res = _results[result_idx++]; - auto& score = bbox_result.score; - res.length = static_cast(score.size()); + auto& score = bbox_result.score; + res.length = static_cast(score.size()); - res.score = new 
float[score.size()]; - std::copy_n(score.data(), score.size(), res.score); + res.score = new float[score.size()]; + std::copy_n(score.data(), score.size(), res.score); - auto text = bbox_result.text; - res.text = new char[text.length() + 1]; - std::copy_n(text.data(), text.length() + 1, res.text); - } + auto text = bbox_result.text; + res.text = new char[text.length() + 1]; + std::copy_n(text.data(), text.length() + 1, res.text); + } - *results = _results.release(); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("exception caught: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_SUCCESS; + *results = _results.release(); + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("exception caught: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_SUCCESS; } -void mmdeploy_text_recognizer_release_result(mmdeploy_text_recognition_t* results, int count) { - for (int i = 0; i < count; ++i) { - delete[] results[i].score; - delete[] results[i].text; - } - delete[] results; +void mmdeploy_text_recognizer_release_result(mmdeploy_text_recognition_t* results, int count) +{ + for (int i = 0; i < count; ++i) + { + delete[] results[i].score; + delete[] results[i].text; + } + delete[] results; } -void mmdeploy_text_recognizer_destroy(mmdeploy_text_recognizer_t recognizer) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)recognizer); +void mmdeploy_text_recognizer_destroy(mmdeploy_text_recognizer_t recognizer) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)recognizer); } -int mmdeploy_text_recognizer_apply_async_v3(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* imgs, int img_count, +int mmdeploy_text_recognizer_apply_async_v3(mmdeploy_text_recognizer_t recognizer, + const mmdeploy_mat_t* imgs, + int img_count, const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, mmdeploy_sender_t* output) { - wrapped input_val; - if (auto ec = mmdeploy_text_recognizer_create_input(imgs, img_count, bboxes, bbox_count, - input_val.ptr())) { - return ec; - } - mmdeploy_sender_t input_sndr = mmdeploy_executor_just(input_val); - if (auto ec = mmdeploy_text_recognizer_apply_async(recognizer, input_sndr, output)) { - return ec; - } - return MMDEPLOY_SUCCESS; + const int* bbox_count, + mmdeploy_sender_t* output) +{ + wrapped input_val; + if (auto ec = mmdeploy_text_recognizer_create_input(imgs, img_count, bboxes, bbox_count, input_val.ptr())) + { + return ec; + } + mmdeploy_sender_t input_sndr = mmdeploy_executor_just(input_val); + if (auto ec = mmdeploy_text_recognizer_apply_async(recognizer, input_sndr, output)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } -int mmdeploy_text_recognizer_continue_async(mmdeploy_sender_t input, - mmdeploy_text_recognizer_continue_t cont, void* context, - mmdeploy_sender_t* output) { - auto sender = Guard([&] { - return Take( - LetValue(Take(input), [fn = cont, context](Value& value) -> TypeErasedSender { +int mmdeploy_text_recognizer_continue_async(mmdeploy_sender_t input, + mmdeploy_text_recognizer_continue_t cont, + void* context, + mmdeploy_sender_t* output) +{ + auto sender = Guard([&] + { return Take( + LetValue(Take(input), [fn = cont, context](Value& value) -> TypeErasedSender + { mmdeploy_text_recognition_t* results{}; if (auto ec = mmdeploy_text_recognizer_get_result(Cast(&value), &results)) { return Just(Value()); @@ -241,12 +287,11 @@ int mmdeploy_text_recognizer_continue_async(mmdeploy_sender_t input, if (auto ec = fn(results, 
context, &output); ec || !output) { return Just(Value()); } - return Take(output); - })); - }); - if (sender) { - *output = sender; - return MMDEPLOY_SUCCESS; - } - return MMDEPLOY_E_FAIL; + return Take(output); })); }); + if (sender) + { + *output = sender; + return MMDEPLOY_SUCCESS; + } + return MMDEPLOY_E_FAIL; } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.h b/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.h index 6c18928242..f20c878028 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/text_recognizer.h @@ -13,149 +13,155 @@ #include "mmdeploy/text_detector.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_text_recognition_t { - char* text; - float* score; - int length; -} mmdeploy_text_recognition_t; - -typedef struct mmdeploy_text_recognizer* mmdeploy_text_recognizer_t; - -/** - * @brief Create a text recognizer instance - * @param[in] model an instance of mmocr text recognition model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] recognizer handle of the created text recognizer, which must be destroyed - * by \ref mmdeploy_text_recognizer_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_text_recognizer_create(mmdeploy_model_t model, const char* device_name, - int device_id, - mmdeploy_text_recognizer_t* recognizer); - -/** - * @brief Create a text recognizer instance - * @param[in] model_path path to text recognition model - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] recognizer handle of the created text recognizer, which must be destroyed - * by \ref mmdeploy_text_recognizer_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_text_recognizer_create_by_path(const char* model_path, - const char* device_name, int device_id, - mmdeploy_text_recognizer_t* recognizer); - -/** - * @brief Apply text recognizer to a batch of text images - * @param[in] recognizer text recognizer's handle created by \ref - * mmdeploy_text_recognizer_create_by_path - * @param[in] images a batch of text images - * @param[in] count number of images in the batch - * @param[out] results a linear buffer contains the recognized text, must be release - * by \ref mmdeploy_text_recognizer_release_result - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_text_recognizer_apply(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* images, int count, - mmdeploy_text_recognition_t** results); - -/** - * @brief Apply text recognizer to a batch of images supplied with text bboxes - * @param[in] recognizer text recognizer's handle created by \ref - * mmdeploy_text_recognizer_create_by_path - * @param[in] images a batch of text images - * @param[in] image_count number of images in the batch - * @param[in] bboxes bounding boxes detected by text detector - * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images - * @param[out] results a linear buffer contains the recognized text, which has the same length as \p - * bboxes, must be release by \ref mmdeploy_text_recognizer_release_result - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_text_recognizer_apply_bbox(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* images, int 
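A sketch of the experimental async path for the recognizer, mirroring apply_async_v3 above. mmdeploy_executor_sync_wait and mmdeploy_value_destroy are assumed from executor.h/common.h; neither appears in this hunk:

#include <numeric>
#include "mmdeploy/text_recognizer.h"

int recognize_async(mmdeploy_text_recognizer_t       recognizer,
                    const mmdeploy_mat_t*            imgs,
                    int                              img_count,
                    const mmdeploy_text_detection_t* bboxes,
                    const int*                       bbox_count)
{
    mmdeploy_sender_t sender{};
    if (int ec = mmdeploy_text_recognizer_apply_async_v3(recognizer, imgs, img_count, bboxes, bbox_count, &sender))
    {
        return ec;
    }
    // mmdeploy_executor_sync_wait blocks until the sender completes and
    // yields the packed output; assumed API from executor.h, not this hunk
    mmdeploy_value_t output = mmdeploy_executor_sync_wait(sender);
    mmdeploy_text_recognition_t* texts{};
    int ec = mmdeploy_text_recognizer_get_result(output, &texts);
    mmdeploy_value_destroy(output);  // assumed from common.h
    if (ec == MMDEPLOY_SUCCESS)
    {
        // one recognition per input bbox
        int total = std::accumulate(bbox_count, bbox_count + img_count, 0);
        mmdeploy_text_recognizer_release_result(texts, total);
    }
    return ec;
}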
image_count, - const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, - mmdeploy_text_recognition_t** results); - -/** @brief Release result buffer returned by \ref mmdeploy_text_recognizer_apply or \ref - * mmdeploy_text_recognizer_apply_bbox - * @param[in] results result buffer by text recognizer - * @param[in] count length of \p result - */ -MMDEPLOY_API void mmdeploy_text_recognizer_release_result(mmdeploy_text_recognition_t* results, - int count); - -/** - * @brief destroy text recognizer - * @param[in] recognizer handle of text recognizer created by \ref - * mmdeploy_text_recognizer_create_by_path or \ref mmdeploy_text_recognizer_create - */ -MMDEPLOY_API void mmdeploy_text_recognizer_destroy(mmdeploy_text_recognizer_t recognizer); - -/****************************************************************************** - * Experimental asynchronous APIs */ - -/** - * @brief Same as \ref mmdeploy_text_recognizer_create, but allows to control execution context of - * tasks via context - */ -MMDEPLOY_API int mmdeploy_text_recognizer_create_v2(mmdeploy_model_t model, - mmdeploy_context_t context, - mmdeploy_text_recognizer_t* recognizer); - -/** - * @brief Pack text-recognizer inputs into mmdeploy_value_t - * @param[in] images a batch of images - * @param[in] image_count number of images in the batch - * @param[in] bboxes bounding boxes detected by text detector - * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images - * @return value created - */ -MMDEPLOY_API int mmdeploy_text_recognizer_create_input(const mmdeploy_mat_t* images, - int image_count, - const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, - mmdeploy_value_t* output); - -MMDEPLOY_API int mmdeploy_text_recognizer_apply_v2(mmdeploy_text_recognizer_t recognizer, - mmdeploy_value_t input, - mmdeploy_value_t* output); - -/** - * @brief Same as \ref mmdeploy_text_recognizer_apply_bbox, but input and output are packed in \ref - * mmdeploy_value_t. - */ -MMDEPLOY_API int mmdeploy_text_recognizer_apply_async(mmdeploy_text_recognizer_t recognizer, - mmdeploy_sender_t input, - mmdeploy_sender_t* output); - -typedef int (*mmdeploy_text_recognizer_continue_t)(mmdeploy_text_recognition_t* results, - void* context, mmdeploy_sender_t* output); - -MMDEPLOY_API int mmdeploy_text_recognizer_apply_async_v3(mmdeploy_text_recognizer_t recognizer, - const mmdeploy_mat_t* imgs, int img_count, - const mmdeploy_text_detection_t* bboxes, - const int* bbox_count, - mmdeploy_sender_t* output); - -MMDEPLOY_API int mmdeploy_text_recognizer_continue_async(mmdeploy_sender_t input, - mmdeploy_text_recognizer_continue_t cont, - void* context, mmdeploy_sender_t* output); - -/** - * @brief Unpack text-recognizer output from a mmdeploy_value_t - * @param[in] output - * @param[out] results - * @return status of the operation - */ -MMDEPLOY_API int mmdeploy_text_recognizer_get_result(mmdeploy_value_t output, - mmdeploy_text_recognition_t** results); + typedef struct mmdeploy_text_recognition_t + { + char* text; + float* score; + int length; + } mmdeploy_text_recognition_t; + + typedef struct mmdeploy_text_recognizer* mmdeploy_text_recognizer_t; + + /** + * @brief Create a text recognizer instance + * @param[in] model an instance of mmocr text recognition model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. 
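The two modules chain directly: detector output feeds mmdeploy_text_recognizer_apply_bbox one recognition per box. A sketch with placeholder model directories; passing nullptr for bboxes and bbox_count instead recognizes each whole image, per the create_input fallback shown earlier:

#include <cstdio>
#include <numeric>
#include "mmdeploy/text_detector.h"
#include "mmdeploy/text_recognizer.h"

int run_ocr(const mmdeploy_mat_t* imgs, int img_count)
{
    mmdeploy_text_detector_t detector{};
    if (int ec = mmdeploy_text_detector_create_by_path("dbnet", "cpu", 0, &detector))
    {
        return ec;
    }
    mmdeploy_text_recognizer_t recognizer{};
    if (int ec = mmdeploy_text_recognizer_create_by_path("crnn", "cpu", 0, &recognizer))
    {
        mmdeploy_text_detector_destroy(detector);
        return ec;
    }
    mmdeploy_text_detection_t* boxes{};
    int*                       box_count{};
    int ec = mmdeploy_text_detector_apply(detector, imgs, img_count, &boxes, &box_count);
    if (ec == MMDEPLOY_SUCCESS)
    {
        mmdeploy_text_recognition_t* texts{};
        ec = mmdeploy_text_recognizer_apply_bbox(recognizer, imgs, img_count, boxes, box_count, &texts);
        if (ec == MMDEPLOY_SUCCESS)
        {
            int total = std::accumulate(box_count, box_count + img_count, 0);
            for (int i = 0; i < total; ++i)
            {
                std::printf("%s\n", texts[i].text);  // per-char scores in texts[i].score[0..length)
            }
            mmdeploy_text_recognizer_release_result(texts, total);
        }
        mmdeploy_text_detector_release_result(boxes, box_count, img_count);
    }
    mmdeploy_text_recognizer_destroy(recognizer);
    mmdeploy_text_detector_destroy(detector);
    return ec;
}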
+ * @param[out] recognizer handle of the created text recognizer, which must be destroyed + * by \ref mmdeploy_text_recognizer_destroy + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_text_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_text_recognizer_t* recognizer); + + /** + * @brief Create a text recognizer instance + * @param[in] model_path path to text recognition model + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] recognizer handle of the created text recognizer, which must be destroyed + * by \ref mmdeploy_text_recognizer_destroy + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_text_recognizer_create_by_path(const char* model_path, + const char* device_name, + int device_id, + mmdeploy_text_recognizer_t* recognizer); + + /** + * @brief Apply text recognizer to a batch of text images + * @param[in] recognizer text recognizer's handle created by \ref + * mmdeploy_text_recognizer_create_by_path + * @param[in] images a batch of text images + * @param[in] count number of images in the batch + * @param[out] results a linear buffer contains the recognized text, must be release + * by \ref mmdeploy_text_recognizer_release_result + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_text_recognizer_apply(mmdeploy_text_recognizer_t recognizer, + const mmdeploy_mat_t* images, + int count, + mmdeploy_text_recognition_t** results); + + /** + * @brief Apply text recognizer to a batch of images supplied with text bboxes + * @param[in] recognizer text recognizer's handle created by \ref + * mmdeploy_text_recognizer_create_by_path + * @param[in] images a batch of text images + * @param[in] image_count number of images in the batch + * @param[in] bboxes bounding boxes detected by text detector + * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images + * @param[out] results a linear buffer contains the recognized text, which has the same length as \p + * bboxes, must be release by \ref mmdeploy_text_recognizer_release_result + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_text_recognizer_apply_bbox(mmdeploy_text_recognizer_t recognizer, + const mmdeploy_mat_t* images, + int image_count, + const mmdeploy_text_detection_t* bboxes, + const int* bbox_count, + mmdeploy_text_recognition_t** results); + + /** @brief Release result buffer returned by \ref mmdeploy_text_recognizer_apply or \ref + * mmdeploy_text_recognizer_apply_bbox + * @param[in] results result buffer by text recognizer + * @param[in] count length of \p result + */ + MMDEPLOY_API void mmdeploy_text_recognizer_release_result(mmdeploy_text_recognition_t* results, + int count); + + /** + * @brief destroy text recognizer + * @param[in] recognizer handle of text recognizer created by \ref + * mmdeploy_text_recognizer_create_by_path or \ref mmdeploy_text_recognizer_create + */ + MMDEPLOY_API void mmdeploy_text_recognizer_destroy(mmdeploy_text_recognizer_t recognizer); + + /****************************************************************************** + * Experimental asynchronous APIs */ + + /** + * @brief Same as \ref mmdeploy_text_recognizer_create, but allows to control execution context of + * tasks via context + */ + MMDEPLOY_API int mmdeploy_text_recognizer_create_v2(mmdeploy_model_t model, + mmdeploy_context_t context, + mmdeploy_text_recognizer_t* recognizer); + + /** + * @brief Pack 
text-recognizer inputs into mmdeploy_value_t + * @param[in] images a batch of images + * @param[in] image_count number of images in the batch + * @param[in] bboxes bounding boxes detected by text detector + * @param[in] bbox_count number of bboxes of each \p images, must be same length as \p images + * @return value created + */ + MMDEPLOY_API int mmdeploy_text_recognizer_create_input(const mmdeploy_mat_t* images, + int image_count, + const mmdeploy_text_detection_t* bboxes, + const int* bbox_count, + mmdeploy_value_t* output); + + MMDEPLOY_API int mmdeploy_text_recognizer_apply_v2(mmdeploy_text_recognizer_t recognizer, + mmdeploy_value_t input, + mmdeploy_value_t* output); + + /** + * @brief Same as \ref mmdeploy_text_recognizer_apply_bbox, but input and output are packed in \ref + * mmdeploy_value_t. + */ + MMDEPLOY_API int mmdeploy_text_recognizer_apply_async(mmdeploy_text_recognizer_t recognizer, + mmdeploy_sender_t input, + mmdeploy_sender_t* output); + + typedef int (*mmdeploy_text_recognizer_continue_t)(mmdeploy_text_recognition_t* results, + void* context, + mmdeploy_sender_t* output); + + MMDEPLOY_API int mmdeploy_text_recognizer_apply_async_v3(mmdeploy_text_recognizer_t recognizer, + const mmdeploy_mat_t* imgs, + int img_count, + const mmdeploy_text_detection_t* bboxes, + const int* bbox_count, + mmdeploy_sender_t* output); + + MMDEPLOY_API int mmdeploy_text_recognizer_continue_async(mmdeploy_sender_t input, + mmdeploy_text_recognizer_continue_t cont, + void* context, + mmdeploy_sender_t* output); + + /** + * @brief Unpack text-recognizer output from a mmdeploy_value_t + * @param[in] output + * @param[out] results + * @return status of the operation + */ + MMDEPLOY_API int mmdeploy_text_recognizer_get_result(mmdeploy_value_t output, + mmdeploy_text_recognition_t** results); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.cpp b/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.cpp index de71e57842..3f0ab3c305 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.cpp +++ b/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.cpp @@ -20,146 +20,178 @@ using namespace mmdeploy; -int mmdeploy_video_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id, - mmdeploy_video_recognizer_t* recognizer) { - mmdeploy_context_t context{}; - auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); - if (ec != MMDEPLOY_SUCCESS) { +int mmdeploy_video_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_video_recognizer_t* recognizer) +{ + mmdeploy_context_t context{}; + auto ec = mmdeploy_context_create_by_device(device_name, device_id, &context); + if (ec != MMDEPLOY_SUCCESS) + { + return ec; + } + ec = mmdeploy_video_recognizer_create_v2(model, context, recognizer); + mmdeploy_context_destroy(context); return ec; - } - ec = mmdeploy_video_recognizer_create_v2(model, context, recognizer); - mmdeploy_context_destroy(context); - return ec; } -int mmdeploy_video_recognizer_create_by_path(const char* model_path, const char* device_name, - int device_id, - mmdeploy_video_recognizer_t* recognizer) { - mmdeploy_model_t model{}; +int mmdeploy_video_recognizer_create_by_path(const char* model_path, const char* device_name, int device_id, mmdeploy_video_recognizer_t* recognizer) +{ + mmdeploy_model_t model{}; - if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) { + if (auto ec = mmdeploy_model_create_by_path(model_path, &model)) + { + return ec; + } + auto ec = 
mmdeploy_video_recognizer_create(model, device_name, device_id, recognizer); + mmdeploy_model_destroy(model); return ec; - } - auto ec = mmdeploy_video_recognizer_create(model, device_name, device_id, recognizer); - mmdeploy_model_destroy(model); - return ec; } -int mmdeploy_video_recognizer_apply(mmdeploy_video_recognizer_t recognizer, - const mmdeploy_mat_t* images, - const mmdeploy_video_sample_info_t* video_info, int video_count, - mmdeploy_video_recognition_t** results, int** result_count) { - wrapped input; - if (auto ec = - mmdeploy_video_recognizer_create_input(images, video_info, video_count, input.ptr())) { - return ec; - } +int mmdeploy_video_recognizer_apply(mmdeploy_video_recognizer_t recognizer, + const mmdeploy_mat_t* images, + const mmdeploy_video_sample_info_t* video_info, + int video_count, + mmdeploy_video_recognition_t** results, + int** result_count) +{ + wrapped input; + if (auto ec = + mmdeploy_video_recognizer_create_input(images, video_info, video_count, input.ptr())) + { + return ec; + } - wrapped output; - if (auto ec = mmdeploy_video_recognizer_apply_v2(recognizer, input, output.ptr())) { - return ec; - } + wrapped output; + if (auto ec = mmdeploy_video_recognizer_apply_v2(recognizer, input, output.ptr())) + { + return ec; + } - if (auto ec = mmdeploy_video_recognizer_get_result(output, results, result_count)) { - return ec; - } - return MMDEPLOY_SUCCESS; + if (auto ec = mmdeploy_video_recognizer_get_result(output, results, result_count)) + { + return ec; + } + return MMDEPLOY_SUCCESS; } void mmdeploy_video_recognizer_release_result(mmdeploy_video_recognition_t* results, - int* result_count, int video_count) { - delete[] results; - delete[] result_count; + int* result_count, + int video_count) +{ + delete[] results; + delete[] result_count; } -void mmdeploy_video_recognizer_destroy(mmdeploy_video_recognizer_t recognizer) { - mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)recognizer); +void mmdeploy_video_recognizer_destroy(mmdeploy_video_recognizer_t recognizer) +{ + mmdeploy_pipeline_destroy((mmdeploy_pipeline_t)recognizer); } -int mmdeploy_video_recognizer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, - mmdeploy_video_recognizer_t* recognizer) { - return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)recognizer); +int mmdeploy_video_recognizer_create_v2(mmdeploy_model_t model, mmdeploy_context_t context, mmdeploy_video_recognizer_t* recognizer) +{ + return mmdeploy_pipeline_create_from_model(model, context, (mmdeploy_pipeline_t*)recognizer); } -int mmdeploy_video_recognizer_create_input(const mmdeploy_mat_t* images, +int mmdeploy_video_recognizer_create_input(const mmdeploy_mat_t* images, const mmdeploy_video_sample_info_t* video_info, - int video_count, mmdeploy_value_t* value) { - if (video_count && (images == nullptr || video_info == nullptr)) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - auto input = std::make_unique(Value{Value::kArray}); - auto sample = std::make_unique(Value::kArray); - for (int i = 0; i < video_count; ++i) { - int clip_len = video_info[i].clip_len; - int num_clips = video_info[i].num_clips; - int n_mat = clip_len * num_clips; - for (int j = 0; j < n_mat; j++) { - mmdeploy::Mat _mat{images[j].height, - images[j].width, - PixelFormat(images[j].format), - DataType(images[j].type), - images[j].data, - images[j].device ? 
*(const Device*)(images[j].device) : Device{0}}; - sample->push_back({{"ori_img", _mat}, {"clip_len", clip_len}, {"num_clips", num_clips}}); - } - input->front().push_back(std::move(*sample.release())); + int video_count, + mmdeploy_value_t* value) +{ + if (video_count && (images == nullptr || video_info == nullptr)) + { + return MMDEPLOY_E_INVALID_ARG; + } + try + { + auto input = std::make_unique(Value{Value::kArray}); + auto sample = std::make_unique(Value::kArray); + for (int i = 0; i < video_count; ++i) + { + int clip_len = video_info[i].clip_len; + int num_clips = video_info[i].num_clips; + int n_mat = clip_len * num_clips; + for (int j = 0; j < n_mat; j++) + { + mmdeploy::Mat _mat{images[j].height, + images[j].width, + PixelFormat(images[j].format), + DataType(images[j].type), + images[j].data, + images[j].device ? *(const Device*)(images[j].device) : Device{0}}; + sample->push_back({{"ori_img", _mat}, {"clip_len", clip_len}, {"num_clips", num_clips}}); + } + input->front().push_back(std::move(*sample.release())); + } + *value = Cast(input.release()); + } + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); } - *value = Cast(input.release()); - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_SUCCESS; + return MMDEPLOY_SUCCESS; } int mmdeploy_video_recognizer_apply_v2(mmdeploy_video_recognizer_t recognizer, - mmdeploy_value_t input, mmdeploy_value_t* output) { - return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)recognizer, input, output); + mmdeploy_value_t input, + mmdeploy_value_t* output) +{ + return mmdeploy_pipeline_apply((mmdeploy_pipeline_t)recognizer, input, output); } -int mmdeploy_video_recognizer_get_result(mmdeploy_value_t output, +int mmdeploy_video_recognizer_get_result(mmdeploy_value_t output, mmdeploy_video_recognition_t** results, - int** result_count) { - if (!output || !results || !result_count) { - return MMDEPLOY_E_INVALID_ARG; - } - try { - Value& value = Cast(output)->front(); - - auto classify_outputs = from_value>(value); - - std::vector _result_count; - _result_count.reserve(classify_outputs.size()); - - for (const auto& cls_output : classify_outputs) { - _result_count.push_back((int)cls_output.size()); + int** result_count) +{ + if (!output || !results || !result_count) + { + return MMDEPLOY_E_INVALID_ARG; } - - auto total = std::accumulate(begin(_result_count), end(_result_count), 0); - - std::unique_ptr result_count_data(new int[_result_count.size()]{}); - std::copy(_result_count.begin(), _result_count.end(), result_count_data.get()); - - std::unique_ptr result_data( - new mmdeploy_video_recognition_t[total]{}); - auto result_ptr = result_data.get(); - for (const auto& cls_output : classify_outputs) { - for (const auto& label : cls_output) { - result_ptr->label_id = label.label_id; - result_ptr->score = label.score; - ++result_ptr; - } + try + { + Value& value = Cast(output)->front(); + + auto classify_outputs = from_value>(value); + + std::vector _result_count; + _result_count.reserve(classify_outputs.size()); + + for (const auto& cls_output : classify_outputs) + { + _result_count.push_back((int)cls_output.size()); + } + + auto total = std::accumulate(begin(_result_count), end(_result_count), 0); + + std::unique_ptr result_count_data(new int[_result_count.size()]{}); + std::copy(_result_count.begin(), 
_result_count.end(), result_count_data.get()); + + std::unique_ptr result_data( + new mmdeploy_video_recognition_t[total]{}); + auto result_ptr = result_data.get(); + for (const auto& cls_output : classify_outputs) + { + for (const auto& label : cls_output) + { + result_ptr->label_id = label.label_id; + result_ptr->score = label.score; + ++result_ptr; + } + } + + *result_count = result_count_data.release(); + *results = result_data.release(); + + return MMDEPLOY_SUCCESS; } - - *result_count = result_count_data.release(); - *results = result_data.release(); - - return MMDEPLOY_SUCCESS; - } catch (const std::exception& e) { - MMDEPLOY_ERROR("unhandled exception: {}", e.what()); - } catch (...) { - MMDEPLOY_ERROR("unknown exception caught"); - } - return MMDEPLOY_E_FAIL; + catch (const std::exception& e) + { + MMDEPLOY_ERROR("unhandled exception: {}", e.what()); + } + catch (...) + { + MMDEPLOY_ERROR("unknown exception caught"); + } + return MMDEPLOY_E_FAIL; } diff --git a/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.h b/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.h index e98b2bd07e..6893170e7d 100644 --- a/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.h +++ b/csrc/mmdeploy/apis/c/mmdeploy/video_recognizer.h @@ -13,124 +13,129 @@ #include "mmdeploy/model.h" #ifdef __cplusplus -extern "C" { +extern "C" +{ #endif -typedef struct mmdeploy_video_recognition_t { - int label_id; - float score; -} mmdeploy_video_recognition_t; - -typedef struct mmdeploy_video_sample_info_t { - int clip_len; - int num_clips; -} mmdeploy_video_sample_info_t; - -typedef struct mmdeploy_video_recognizer* mmdeploy_video_recognizer_t; - -/** - * @brief Create video recognizer's handle - * @param[in] model an instance of mmaction sdk model created by - * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] recognizer handle of the created video recognizer, which must be destroyed - * by \ref mmdeploy_video_recognizer_destroy - * @return status of creating video recognizer's handle - */ -MMDEPLOY_API int mmdeploy_video_recognizer_create(mmdeploy_model_t model, const char* device_name, - int device_id, - mmdeploy_video_recognizer_t* recognizer); - -/** - * @brief Create a video recognizer instance - * @param[in] model_path path to video recognition model - * @param[in] device_name name of device, such as "cpu", "cuda", etc. - * @param[in] device_id id of device. - * @param[out] recognizer handle of the created video recognizer, which must be destroyed - * by \ref mmdeploy_video_recognizer_destroy - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_video_recognizer_create_by_path(const char* model_path, - const char* device_name, int device_id, - mmdeploy_video_recognizer_t* recognizer); - -/** - * @brief Apply video recognizer to a batch of videos - * @param[in] recognizer video recognizer's handle created by \ref - * mmdeploy_video_recognizer_create_by_path - * @param[in] images a batch of videos - * @param[in] video_info video information of each video - * @param[in] video_count number of videos - * @param[out] results a linear buffer contains the recognized video, must be release - * by \ref mmdeploy_video_recognizer_release_result - * @param[out] result_count a linear buffer with length being \p video_count to save the number of - * recognition results of each video. 
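A single-video sketch of the recognizer API above. Per the create_input loop, the frame buffer must hold clip_len * num_clips consecutive mats for the sample; "tsn" and the sample sizes are placeholders:

#include <cstdio>
#include "mmdeploy/video_recognizer.h"

int run_video_recognition(const mmdeploy_mat_t* frames)
{
    mmdeploy_video_recognizer_t recognizer{};
    if (int ec = mmdeploy_video_recognizer_create_by_path("tsn", "cpu", 0, &recognizer))
    {
        return ec;
    }
    // clip_len = 8, num_clips = 1 -> `frames` must hold 8 consecutive mats
    mmdeploy_video_sample_info_t  info{8, 1};
    mmdeploy_video_recognition_t* results{};
    int*                          result_count{};
    int ec = mmdeploy_video_recognizer_apply(recognizer, frames, &info, 1, &results, &result_count);
    if (ec == MMDEPLOY_SUCCESS)
    {
        for (int i = 0; i < result_count[0]; ++i)
        {
            std::printf("label %d score %.3f\n", results[i].label_id, results[i].score);
        }
        mmdeploy_video_recognizer_release_result(results, result_count, 1);
    }
    mmdeploy_video_recognizer_destroy(recognizer);
    return ec;
}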
It must be released by \ref - * mmdeploy_video_recognizer_release_result - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_video_recognizer_apply(mmdeploy_video_recognizer_t recognizer, - const mmdeploy_mat_t* images, - const mmdeploy_video_sample_info_t* video_info, - int video_count, - mmdeploy_video_recognition_t** results, - int** result_count); - -/** @brief Release result buffer returned by \ref mmdeploy_video_recognizer_apply - * @param[in] results result buffer by video recognizer - * @param[in] result_count \p results size buffer - * @param[in] video_count length of \p result_count - */ -MMDEPLOY_API void mmdeploy_video_recognizer_release_result(mmdeploy_video_recognition_t* results, - int* result_count, int video_count); - -/** - * @brief destroy video recognizer - * @param[in] recognizer handle of video recognizer created by \ref - * mmdeploy_video_recognizer_create_by_path or \ref mmdeploy_video_recognizer_create - */ -MMDEPLOY_API void mmdeploy_video_recognizer_destroy(mmdeploy_video_recognizer_t recognizer); - -/** - * @brief Same as \ref mmdeploy_video_recognizer_create, but allows to control execution context of - * tasks via context - */ -MMDEPLOY_API int mmdeploy_video_recognizer_create_v2(mmdeploy_model_t model, - mmdeploy_context_t context, - mmdeploy_video_recognizer_t* recognizer); - -/** - * @brief Pack video recognizer inputs into mmdeploy_value_t - * @param[in] images a batch of videos - * @param[in] video_info video information of each video - * @param[in] video_count number of videos in the batch - * @param[out] value created value - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_video_recognizer_create_input( - const mmdeploy_mat_t* images, const mmdeploy_video_sample_info_t* video_info, int video_count, - mmdeploy_value_t* value); - -/** - * @brief Apply video recognizer to a batch of videos - * @param[in] input packed input - * @param[out] output inference output - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_video_recognizer_apply_v2(mmdeploy_video_recognizer_t recognizer, - mmdeploy_value_t input, - mmdeploy_value_t* output); - -/** - * @brief Apply video recognizer to a batch of videos - * @param[in] output inference output - * @param[out] results structured output - * @param[out] result_count number of each videos - * @return status code of the operation - */ -MMDEPLOY_API int mmdeploy_video_recognizer_get_result(mmdeploy_value_t output, - mmdeploy_video_recognition_t** results, - int** result_count); + typedef struct mmdeploy_video_recognition_t + { + int label_id; + float score; + } mmdeploy_video_recognition_t; + + typedef struct mmdeploy_video_sample_info_t + { + int clip_len; + int num_clips; + } mmdeploy_video_sample_info_t; + + typedef struct mmdeploy_video_recognizer* mmdeploy_video_recognizer_t; + + /** + * @brief Create video recognizer's handle + * @param[in] model an instance of mmaction sdk model created by + * \ref mmdeploy_model_create_by_path or \ref mmdeploy_model_create in \ref model.h + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. 
+ * @param[out] recognizer handle of the created video recognizer, which must be destroyed + * by \ref mmdeploy_video_recognizer_destroy + * @return status of creating video recognizer's handle + */ + MMDEPLOY_API int mmdeploy_video_recognizer_create(mmdeploy_model_t model, const char* device_name, int device_id, mmdeploy_video_recognizer_t* recognizer); + + /** + * @brief Create a video recognizer instance + * @param[in] model_path path to video recognition model + * @param[in] device_name name of device, such as "cpu", "cuda", etc. + * @param[in] device_id id of device. + * @param[out] recognizer handle of the created video recognizer, which must be destroyed + * by \ref mmdeploy_video_recognizer_destroy + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_video_recognizer_create_by_path(const char* model_path, + const char* device_name, + int device_id, + mmdeploy_video_recognizer_t* recognizer); + + /** + * @brief Apply video recognizer to a batch of videos + * @param[in] recognizer video recognizer's handle created by \ref + * mmdeploy_video_recognizer_create_by_path + * @param[in] images a batch of videos + * @param[in] video_info video information of each video + * @param[in] video_count number of videos + * @param[out] results a linear buffer contains the recognized video, must be release + * by \ref mmdeploy_video_recognizer_release_result + * @param[out] result_count a linear buffer with length being \p video_count to save the number of + * recognition results of each video. It must be released by \ref + * mmdeploy_video_recognizer_release_result + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_video_recognizer_apply(mmdeploy_video_recognizer_t recognizer, + const mmdeploy_mat_t* images, + const mmdeploy_video_sample_info_t* video_info, + int video_count, + mmdeploy_video_recognition_t** results, + int** result_count); + + /** @brief Release result buffer returned by \ref mmdeploy_video_recognizer_apply + * @param[in] results result buffer by video recognizer + * @param[in] result_count \p results size buffer + * @param[in] video_count length of \p result_count + */ + MMDEPLOY_API void mmdeploy_video_recognizer_release_result(mmdeploy_video_recognition_t* results, + int* result_count, + int video_count); + + /** + * @brief destroy video recognizer + * @param[in] recognizer handle of video recognizer created by \ref + * mmdeploy_video_recognizer_create_by_path or \ref mmdeploy_video_recognizer_create + */ + MMDEPLOY_API void mmdeploy_video_recognizer_destroy(mmdeploy_video_recognizer_t recognizer); + + /** + * @brief Same as \ref mmdeploy_video_recognizer_create, but allows to control execution context of + * tasks via context + */ + MMDEPLOY_API int mmdeploy_video_recognizer_create_v2(mmdeploy_model_t model, + mmdeploy_context_t context, + mmdeploy_video_recognizer_t* recognizer); + + /** + * @brief Pack video recognizer inputs into mmdeploy_value_t + * @param[in] images a batch of videos + * @param[in] video_info video information of each video + * @param[in] video_count number of videos in the batch + * @param[out] value created value + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_video_recognizer_create_input( + const mmdeploy_mat_t* images, + const mmdeploy_video_sample_info_t* video_info, + int video_count, + mmdeploy_value_t* value); + + /** + * @brief Apply video recognizer to a batch of videos + * @param[in] input packed input + * @param[out] output inference output + * @return status code of the 
operation + */ + MMDEPLOY_API int mmdeploy_video_recognizer_apply_v2(mmdeploy_video_recognizer_t recognizer, + mmdeploy_value_t input, + mmdeploy_value_t* output); + + /** + * @brief Apply video recognizer to a batch of videos + * @param[in] output inference output + * @param[out] results structured output + * @param[out] result_count number of each videos + * @return status code of the operation + */ + MMDEPLOY_API int mmdeploy_video_recognizer_get_result(mmdeploy_value_t output, + mmdeploy_video_recognition_t** results, + int** result_count); #ifdef __cplusplus } diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/classifier.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/classifier.hpp index 1d9880fb7d..bf4772bcfb 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/classifier.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/classifier.hpp @@ -6,68 +6,80 @@ #include "mmdeploy/classifier.h" #include "mmdeploy/common.hpp" -namespace mmdeploy { - -namespace cxx { - -using Classification = mmdeploy_classification_t; - -class Classifier : public NonMovable { - public: - Classifier(const Model& model, const Context& context) { - auto ec = mmdeploy_classifier_create_v2(model, context, &classifier_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - - ~Classifier() { - if (classifier_) { - mmdeploy_classifier_destroy(classifier_); - classifier_ = {}; - } - } - - using Result = Result_; - - std::vector Apply(Span images) { - if (images.empty()) { - return {}; - } - - Classification* results{}; - int* result_count{}; - auto ec = mmdeploy_classifier_apply(classifier_, reinterpret(images.data()), - static_cast(images.size()), &results, &result_count); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - std::vector rets; - rets.reserve(images.size()); - - std::shared_ptr data(results, [result_count, count = images.size()](auto p) { - mmdeploy_classifier_release_result(p, result_count, count); - }); - - size_t offset = 0; - for (size_t i = 0; i < images.size(); ++i) { - offset += rets.emplace_back(offset, result_count[i], data).size(); - } - - return rets; - } - - Result Apply(const Mat& img) { return Apply(Span{img})[0]; } - - private: - mmdeploy_classifier_t classifier_{}; -}; - -} // namespace cxx - -using cxx::Classification; -using cxx::Classifier; +namespace mmdeploy +{ + + namespace cxx + { + + using Classification = mmdeploy_classification_t; + + class Classifier : public NonMovable + { + public: + Classifier(const Model& model, const Context& context) + { + auto ec = mmdeploy_classifier_create_v2(model, context, &classifier_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + + ~Classifier() + { + if (classifier_) + { + mmdeploy_classifier_destroy(classifier_); + classifier_ = {}; + } + } + + using Result = Result_; + + std::vector Apply(Span images) + { + if (images.empty()) + { + return {}; + } + + Classification* results{}; + int* result_count{}; + auto ec = mmdeploy_classifier_apply(classifier_, reinterpret(images.data()), static_cast(images.size()), &results, &result_count); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::vector rets; + rets.reserve(images.size()); + + std::shared_ptr data(results, [result_count, count = images.size()](auto p) + { mmdeploy_classifier_release_result(p, result_count, count); }); + + size_t offset = 0; + for (size_t i = 0; i < images.size(); ++i) + { + offset += rets.emplace_back(offset, result_count[i], data).size(); + } + + return rets; + } + + Result Apply(const Mat& 
img) + { + return Apply(Span{img})[0]; + } + + private: + mmdeploy_classifier_t classifier_{}; + }; + + } // namespace cxx + + using cxx::Classification; + using cxx::Classifier; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/common.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/common.hpp index 610c3a8b9e..a7547aa7c7 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/common.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/common.hpp @@ -16,253 +16,378 @@ #include "mmdeploy/model.h" #ifndef MMDEPLOY_CXX_USE_OPENCV -#define MMDEPLOY_CXX_USE_OPENCV 1 + #define MMDEPLOY_CXX_USE_OPENCV 1 #endif #if MMDEPLOY_CXX_USE_OPENCV -#include "opencv2/core/core.hpp" + #include "opencv2/core/core.hpp" #endif -namespace mmdeploy { - -namespace cxx { - -using Rect = mmdeploy_rect_t; - -template -class UniqueHandle : public NonCopyable { - public: - UniqueHandle() = default; - explicit UniqueHandle(T handle) : handle_(handle) {} - - // derived class must destroy the object and reset `handle_` - ~UniqueHandle() { assert(handle_ == nullptr); } - - UniqueHandle(UniqueHandle&& o) noexcept : handle_(std::exchange(o.handle_, nullptr)) {} - UniqueHandle& operator=(UniqueHandle&& o) noexcept { - if (this != &o) { - handle_ = std::exchange(o.handle_, nullptr); - } - return *this; - } - - explicit operator T() const noexcept { return handle_; } - T operator->() const noexcept { return handle_; } - - protected: - T handle_{}; -}; - -class Model { - public: - explicit Model(const char* path) { - mmdeploy_model_t model{}; - auto ec = mmdeploy_model_create_by_path(path, &model); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - model_.reset(model, [](auto p) { mmdeploy_model_destroy(p); }); - } - - explicit Model(const std::string& path) : Model(path.c_str()) {} - - Model(const void* buffer, size_t size) { - mmdeploy_model_t model{}; - auto ec = mmdeploy_model_create(buffer, static_cast(size), &model); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - model_.reset(model, [](auto p) { mmdeploy_model_destroy(p); }); - } - - operator mmdeploy_model_t() const noexcept { return model_.get(); } - - private: - std::shared_ptr model_{}; -}; - -class Device { - public: - explicit Device(std::string name, int index = 0) : name_(std::move(name)), index_(index) { - mmdeploy_device_t device{}; - auto ec = mmdeploy_device_create(name_.c_str(), index, &device); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - device_.reset(device, [](auto p) { mmdeploy_device_destroy(p); }); - } - - const char* name() const noexcept { return name_.c_str(); } - int index() const noexcept { return index_; } - - operator mmdeploy_device_t() const noexcept { return device_.get(); } - - private: - std::string name_; - int index_; - std::shared_ptr device_; -}; - -class Profiler { - public: - explicit Profiler(std::string_view path) : path_(path) { - mmdeploy_profiler_t profiler{}; - auto ec = mmdeploy_profiler_create(path_.c_str(), &profiler); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - profiler_.reset(profiler, [](auto p) { mmdeploy_profiler_destroy(p); }); - }; - - operator mmdeploy_profiler_t() const noexcept { return profiler_.get(); } - - private: - std::string path_; - std::shared_ptr profiler_; -}; - -class Mat { - public: - Mat() : desc_{} {} - - Mat(int height, int width, int channels, mmdeploy_pixel_format_t format, - mmdeploy_data_type_t type, uint8_t* data, mmdeploy_device_t device = nullptr) - : desc_{data, height, width, channels, format, 
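A sketch of the C++ Classifier wrapper in caller code, using the cv::Mat interop defined in common.hpp below. The label_id/score fields follow mmdeploy_classification_t from classifier.h, which is outside this hunk, and the model path is a placeholder:

#include <cstdio>
#include "mmdeploy/classifier.hpp"
#include "opencv2/imgcodecs.hpp"

int main()
{
    mmdeploy::cxx::Model      model("resnet_model_dir");  // placeholder path
    mmdeploy::cxx::Device     device("cpu", 0);
    mmdeploy::cxx::Context    context(device);
    mmdeploy::cxx::Classifier classifier(model, context);  // throws on failure

    cv::Mat img    = cv::imread("demo.jpg");  // 8-bit BGR -> implicit cxx::Mat
    auto    result = classifier.Apply(img);   // single-image overload
    for (const auto& cls : result)
    {
        std::printf("label %d score %.3f\n", cls.label_id, cls.score);
    }
    return 0;
}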
type, device} {} - - Mat(const mmdeploy_mat_t& desc) : desc_(desc) {} // NOLINT - - const mmdeploy_mat_t& desc() const noexcept { return desc_; } +namespace mmdeploy +{ + + namespace cxx + { + + using Rect = mmdeploy_rect_t; + + template + class UniqueHandle : public NonCopyable + { + public: + UniqueHandle() = default; + explicit UniqueHandle(T handle) + : handle_(handle) + { + } + + // derived class must destroy the object and reset `handle_` + ~UniqueHandle() + { + assert(handle_ == nullptr); + } + + UniqueHandle(UniqueHandle&& o) noexcept + : handle_(std::exchange(o.handle_, nullptr)) + { + } + UniqueHandle& operator=(UniqueHandle&& o) noexcept + { + if (this != &o) + { + handle_ = std::exchange(o.handle_, nullptr); + } + return *this; + } + + explicit operator T() const noexcept + { + return handle_; + } + T operator->() const noexcept + { + return handle_; + } + + protected: + T handle_{}; + }; + + class Model + { + public: + explicit Model(const char* path) + { + mmdeploy_model_t model{}; + auto ec = mmdeploy_model_create_by_path(path, &model); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + model_.reset(model, [](auto p) + { mmdeploy_model_destroy(p); }); + } + + explicit Model(const std::string& path) + : Model(path.c_str()) + { + } + + Model(const void* buffer, size_t size) + { + mmdeploy_model_t model{}; + auto ec = mmdeploy_model_create(buffer, static_cast(size), &model); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + model_.reset(model, [](auto p) + { mmdeploy_model_destroy(p); }); + } + + operator mmdeploy_model_t() const noexcept + { + return model_.get(); + } + + private: + std::shared_ptr model_{}; + }; + + class Device + { + public: + explicit Device(std::string name, int index = 0) + : name_(std::move(name)) + , index_(index) + { + mmdeploy_device_t device{}; + auto ec = mmdeploy_device_create(name_.c_str(), index, &device); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + device_.reset(device, [](auto p) + { mmdeploy_device_destroy(p); }); + } + + const char* name() const noexcept + { + return name_.c_str(); + } + int index() const noexcept + { + return index_; + } + + operator mmdeploy_device_t() const noexcept + { + return device_.get(); + } + + private: + std::string name_; + int index_; + std::shared_ptr device_; + }; + + class Profiler + { + public: + explicit Profiler(std::string_view path) + : path_(path) + { + mmdeploy_profiler_t profiler{}; + auto ec = mmdeploy_profiler_create(path_.c_str(), &profiler); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + profiler_.reset(profiler, [](auto p) + { mmdeploy_profiler_destroy(p); }); + }; + + operator mmdeploy_profiler_t() const noexcept + { + return profiler_.get(); + } + + private: + std::string path_; + std::shared_ptr profiler_; + }; + + class Mat + { + public: + Mat() + : desc_{} + { + } + + Mat(int height, int width, int channels, mmdeploy_pixel_format_t format, mmdeploy_data_type_t type, uint8_t* data, mmdeploy_device_t device = nullptr) + : desc_{data, height, width, channels, format, type, device} + { + } + + Mat(const mmdeploy_mat_t& desc) + : desc_(desc) + { + } // NOLINT + + const mmdeploy_mat_t& desc() const noexcept + { + return desc_; + } #if MMDEPLOY_CXX_USE_OPENCV - Mat(const cv::Mat& mat, mmdeploy_pixel_format_t pixel_format) - : desc_{mat.data, mat.rows, mat.cols, mat.channels(), pixel_format, GetCvType(mat.depth())} { - if (pixel_format == MMDEPLOY_PIXEL_FORMAT_COUNT) { - 
throw_exception(eNotSupported); - } - if (desc_.type == MMDEPLOY_DATA_TYPE_COUNT) { - throw_exception(eNotSupported); - } - } - Mat(const cv::Mat& mat) : Mat(mat, GetCvFormat(mat.channels())) {} - - static mmdeploy_data_type_t GetCvType(int depth) { - switch (depth) { - case CV_8U: - return MMDEPLOY_DATA_TYPE_UINT8; - case CV_32F: - return MMDEPLOY_DATA_TYPE_FLOAT; - default: - return MMDEPLOY_DATA_TYPE_COUNT; - } - } - static mmdeploy_pixel_format_t GetCvFormat(int channels) { - switch (channels) { - case 1: - return MMDEPLOY_PIXEL_FORMAT_GRAYSCALE; - case 3: - return MMDEPLOY_PIXEL_FORMAT_BGR; - case 4: - return MMDEPLOY_PIXEL_FORMAT_BGRA; - default: - return MMDEPLOY_PIXEL_FORMAT_COUNT; - } - } + Mat(const cv::Mat& mat, mmdeploy_pixel_format_t pixel_format) + : desc_{mat.data, mat.rows, mat.cols, mat.channels(), pixel_format, GetCvType(mat.depth())} + { + if (pixel_format == MMDEPLOY_PIXEL_FORMAT_COUNT) + { + throw_exception(eNotSupported); + } + if (desc_.type == MMDEPLOY_DATA_TYPE_COUNT) + { + throw_exception(eNotSupported); + } + } + Mat(const cv::Mat& mat) + : Mat(mat, GetCvFormat(mat.channels())) + { + } + + static mmdeploy_data_type_t GetCvType(int depth) + { + switch (depth) + { + case CV_8U: + return MMDEPLOY_DATA_TYPE_UINT8; + case CV_32F: + return MMDEPLOY_DATA_TYPE_FLOAT; + default: + return MMDEPLOY_DATA_TYPE_COUNT; + } + } + static mmdeploy_pixel_format_t GetCvFormat(int channels) + { + switch (channels) + { + case 1: + return MMDEPLOY_PIXEL_FORMAT_GRAYSCALE; + case 3: + return MMDEPLOY_PIXEL_FORMAT_BGR; + case 4: + return MMDEPLOY_PIXEL_FORMAT_BGRA; + default: + return MMDEPLOY_PIXEL_FORMAT_COUNT; + } + } #endif - private: - mmdeploy_mat_t desc_; -}; - -template -class Result_ { - public: - using value_type = T; - using size_type = size_t; - using difference_type = ptrdiff_t; - using reference = T&; - using const_reference = const T&; - using pointer = T*; - using const_pointer = const T*; - using iterator = T*; - using const_iterator = T*; - - Result_(size_t offset, size_t size, std::shared_ptr data) - : offset_(offset), size_(size), data_(std::move(data)) {} - - T& operator[](size_t index) const noexcept { return *(data_.get() + offset_ + index); } - size_t size() const noexcept { return size_; } - T* begin() const noexcept { return data_.get() + offset_; } - T* end() const noexcept { return begin() + size_; } - - T* operator->() const noexcept { return data_.get(); } - T& operator*() const noexcept { return *data_; } - - private: - size_t offset_; - size_t size_; - std::shared_ptr data_; -}; - -inline const mmdeploy_mat_t* reinterpret(const Mat* p) { - return reinterpret_cast(p); -} - -class Scheduler { - public: - explicit Scheduler(mmdeploy_scheduler_t scheduler) { - scheduler_.reset(scheduler, [](auto p) { mmdeploy_scheduler_destroy(p); }); - } - - static Scheduler ThreadPool(int num_threads) { - return Scheduler(mmdeploy_executor_create_thread_pool(num_threads)); - } - static Scheduler Thread() { return Scheduler(mmdeploy_executor_create_thread()); } - - operator mmdeploy_scheduler_t() const noexcept { return scheduler_.get(); } - - private: - std::shared_ptr scheduler_; -}; - -class Context { - public: - Context() { - mmdeploy_context_t context{}; - mmdeploy_context_create(&context); - context_.reset(context, [](auto p) { mmdeploy_context_destroy(p); }); - } - /* implicit */ Context(const Device& device) : Context() { Add(device); } - - void Add(const std::string& name, const Scheduler& scheduler) { - mmdeploy_context_add(*this, MMDEPLOY_TYPE_SCHEDULER, 
name.c_str(), scheduler); - } - - void Add(const std::string& name, const Model& model) { - mmdeploy_context_add(*this, MMDEPLOY_TYPE_MODEL, name.c_str(), model); - } - - void Add(const Device& device) { - mmdeploy_context_add(*this, MMDEPLOY_TYPE_DEVICE, nullptr, device); - } - - void Add(const Profiler& profiler) { - mmdeploy_context_add(*this, MMDEPLOY_TYPE_PROFILER, nullptr, profiler); - } - - operator mmdeploy_context_t() const noexcept { return context_.get(); } - - private: - std::shared_ptr context_; -}; - -} // namespace cxx - -using cxx::Context; -using cxx::Device; -using cxx::Mat; -using cxx::Model; -using cxx::Profiler; -using cxx::Rect; -using cxx::Scheduler; + private: + mmdeploy_mat_t desc_; + }; + + template + class Result_ + { + public: + using value_type = T; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = T&; + using const_reference = const T&; + using pointer = T*; + using const_pointer = const T*; + using iterator = T*; + using const_iterator = T*; + + Result_(size_t offset, size_t size, std::shared_ptr data) + : offset_(offset) + , size_(size) + , data_(std::move(data)) + { + } + + T& operator[](size_t index) const noexcept + { + return *(data_.get() + offset_ + index); + } + size_t size() const noexcept + { + return size_; + } + T* begin() const noexcept + { + return data_.get() + offset_; + } + T* end() const noexcept + { + return begin() + size_; + } + + T* operator->() const noexcept + { + return data_.get(); + } + T& operator*() const noexcept + { + return *data_; + } + + private: + size_t offset_; + size_t size_; + std::shared_ptr data_; + }; + + inline const mmdeploy_mat_t* reinterpret(const Mat* p) + { + return reinterpret_cast(p); + } + + class Scheduler + { + public: + explicit Scheduler(mmdeploy_scheduler_t scheduler) + { + scheduler_.reset(scheduler, [](auto p) + { mmdeploy_scheduler_destroy(p); }); + } + + static Scheduler ThreadPool(int num_threads) + { + return Scheduler(mmdeploy_executor_create_thread_pool(num_threads)); + } + static Scheduler Thread() + { + return Scheduler(mmdeploy_executor_create_thread()); + } + + operator mmdeploy_scheduler_t() const noexcept + { + return scheduler_.get(); + } + + private: + std::shared_ptr scheduler_; + }; + + class Context + { + public: + Context() + { + mmdeploy_context_t context{}; + mmdeploy_context_create(&context); + context_.reset(context, [](auto p) + { mmdeploy_context_destroy(p); }); + } + /* implicit */ Context(const Device& device) + : Context() + { + Add(device); + } + + void Add(const std::string& name, const Scheduler& scheduler) + { + mmdeploy_context_add(*this, MMDEPLOY_TYPE_SCHEDULER, name.c_str(), scheduler); + } + + void Add(const std::string& name, const Model& model) + { + mmdeploy_context_add(*this, MMDEPLOY_TYPE_MODEL, name.c_str(), model); + } + + void Add(const Device& device) + { + mmdeploy_context_add(*this, MMDEPLOY_TYPE_DEVICE, nullptr, device); + } + + void Add(const Profiler& profiler) + { + mmdeploy_context_add(*this, MMDEPLOY_TYPE_PROFILER, nullptr, profiler); + } + + operator mmdeploy_context_t() const noexcept + { + return context_.get(); + } + + private: + std::shared_ptr context_; + }; + + } // namespace cxx + + using cxx::Context; + using cxx::Device; + using cxx::Mat; + using cxx::Model; + using cxx::Profiler; + using cxx::Rect; + using cxx::Scheduler; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/detector.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/detector.hpp index 847505bbe7..6f38a20d90 100644 --- 
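For reference, here is a minimal usage sketch of the reformatted common.hpp primitives (Model, Device, Context, Mat) together with the Classifier wrapper shown earlier. It is an editor's illustration, not part of the patch: the model directory "./mmcls_model" and the image "demo.jpg" are hypothetical placeholders, and the OpenCV interop assumes MMDEPLOY_CXX_USE_OPENCV=1 (the default set above).

// Usage sketch (illustrative; paths are hypothetical).
#include <cstdio>
#include "opencv2/imgcodecs.hpp"
#include "mmdeploy/classifier.hpp"

int main()
{
    mmdeploy::Model      model("./mmcls_model");      // hypothetical converted model dir
    mmdeploy::Device     device("cpu");               // or mmdeploy::Device("cuda", 0)
    mmdeploy::Context    context(device);             // implicit Device -> Context conversion
    mmdeploy::Classifier classifier(model, context);  // throws on failure

    cv::Mat bgr = cv::imread("demo.jpg");  // hypothetical input image
    mmdeploy::Mat img(bgr);                // pixel format deduced from channels (3 -> BGR)

    auto result = classifier.Apply(img);  // Result_ over mmdeploy_classification_t
    for (const auto& cls : result)
    {
        std::printf("label=%d score=%.3f\n", cls.label_id, cls.score);
    }
    return 0;
}

Note the design choice visible in the wrapper: each Result_ shares ownership of the C-API result buffer through the shared_ptr deleter, so the classifications stay valid even after other results from the same batch are discarded.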
a/csrc/mmdeploy/apis/cxx/mmdeploy/detector.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/detector.hpp @@ -6,68 +6,80 @@ #include "mmdeploy/common.hpp" #include "mmdeploy/detector.h" -namespace mmdeploy { - -namespace cxx { - -using Detection = mmdeploy_detection_t; - -class Detector : public NonMovable { - public: - Detector(const Model& model, const Context& context) { - auto ec = mmdeploy_detector_create_v2(model, context, &detector_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - - ~Detector() { - if (detector_) { - mmdeploy_detector_destroy(detector_); - detector_ = {}; - } - } - - using Result = Result_; - - std::vector Apply(Span images) { - if (images.empty()) { - return {}; - } - - Detection* results{}; - int* result_count{}; - auto ec = mmdeploy_detector_apply(detector_, reinterpret(images.data()), - static_cast(images.size()), &results, &result_count); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - std::shared_ptr data(results, [result_count, count = images.size()](auto p) { - mmdeploy_detector_release_result(p, result_count, count); - }); - - std::vector rets; - rets.reserve(images.size()); - - size_t offset = 0; - for (size_t i = 0; i < images.size(); ++i) { - offset += rets.emplace_back(offset, result_count[i], data).size(); - } - - return rets; - } - - Result Apply(const Mat& image) { return Apply(Span{image})[0]; } - - private: - mmdeploy_detector_t detector_{}; -}; - -} // namespace cxx - -using cxx::Detection; -using cxx::Detector; +namespace mmdeploy +{ + + namespace cxx + { + + using Detection = mmdeploy_detection_t; + + class Detector : public NonMovable + { + public: + Detector(const Model& model, const Context& context) + { + auto ec = mmdeploy_detector_create_v2(model, context, &detector_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + + ~Detector() + { + if (detector_) + { + mmdeploy_detector_destroy(detector_); + detector_ = {}; + } + } + + using Result = Result_; + + std::vector Apply(Span images) + { + if (images.empty()) + { + return {}; + } + + Detection* results{}; + int* result_count{}; + auto ec = mmdeploy_detector_apply(detector_, reinterpret(images.data()), static_cast(images.size()), &results, &result_count); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::shared_ptr data(results, [result_count, count = images.size()](auto p) + { mmdeploy_detector_release_result(p, result_count, count); }); + + std::vector rets; + rets.reserve(images.size()); + + size_t offset = 0; + for (size_t i = 0; i < images.size(); ++i) + { + offset += rets.emplace_back(offset, result_count[i], data).size(); + } + + return rets; + } + + Result Apply(const Mat& image) + { + return Apply(Span{image})[0]; + } + + private: + mmdeploy_detector_t detector_{}; + }; + + } // namespace cxx + + using cxx::Detection; + using cxx::Detector; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp index e20ec6a224..c5f07f56af 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/pipeline.hpp @@ -7,72 +7,87 @@ #include "mmdeploy/core/value.h" #include "mmdeploy/pipeline.h" -namespace mmdeploy { +namespace mmdeploy +{ -namespace cxx { + namespace cxx + { -class Pipeline : public NonMovable { - public: - Pipeline(const Value& config, const Context& context) { - mmdeploy_pipeline_t pipeline{}; - auto ec = mmdeploy_pipeline_create_v3((mmdeploy_value_t)&config, 
context, &pipeline); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - pipeline_ = pipeline; - } + class Pipeline : public NonMovable + { + public: + Pipeline(const Value& config, const Context& context) + { + mmdeploy_pipeline_t pipeline{}; + auto ec = mmdeploy_pipeline_create_v3((mmdeploy_value_t)&config, context, &pipeline); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + pipeline_ = pipeline; + } - ~Pipeline() { - if (pipeline_) { - mmdeploy_pipeline_destroy(pipeline_); - pipeline_ = nullptr; - } - } + ~Pipeline() + { + if (pipeline_) + { + mmdeploy_pipeline_destroy(pipeline_); + pipeline_ = nullptr; + } + } - Value Apply(const Value& inputs) { - mmdeploy_value_t tmp{}; - auto ec = mmdeploy_pipeline_apply(pipeline_, (mmdeploy_value_t)&inputs, &tmp); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - Value output = std::move(*(Value*)tmp); - mmdeploy_value_destroy(tmp); - return output; - } + Value Apply(const Value& inputs) + { + mmdeploy_value_t tmp{}; + auto ec = mmdeploy_pipeline_apply(pipeline_, (mmdeploy_value_t)&inputs, &tmp); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + Value output = std::move(*(Value*)tmp); + mmdeploy_value_destroy(tmp); + return output; + } - Value Apply(Span images) { - if (images.empty()) { - return {}; - } - mmdeploy_value_t inputs{}; - auto ec = mmdeploy_common_create_input(reinterpret(images.data()), - static_cast(images.size()), &inputs); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - auto outputs = Apply(*reinterpret_cast(inputs)); - mmdeploy_value_destroy(inputs); + Value Apply(Span images) + { + if (images.empty()) + { + return {}; + } + mmdeploy_value_t inputs{}; + auto ec = mmdeploy_common_create_input(reinterpret(images.data()), + static_cast(images.size()), + &inputs); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + auto outputs = Apply(*reinterpret_cast(inputs)); + mmdeploy_value_destroy(inputs); - return outputs; - } + return outputs; + } - Value Apply(const Mat& image) { - auto outputs = Apply(Span{image}); - Value::Array rets; - rets.reserve(outputs.size()); - for (auto& output : outputs) { - rets.push_back(std::move(output[0])); - } - return rets; - } + Value Apply(const Mat& image) + { + auto outputs = Apply(Span{image}); + Value::Array rets; + rets.reserve(outputs.size()); + for (auto& output : outputs) + { + rets.push_back(std::move(output[0])); + } + return rets; + } - private: - mmdeploy_pipeline_t pipeline_{}; -}; + private: + mmdeploy_pipeline_t pipeline_{}; + }; -} // namespace cxx + } // namespace cxx -using cxx::Pipeline; + using cxx::Pipeline; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/pose_detector.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/pose_detector.hpp index 7432a417fc..6a157f5228 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/pose_detector.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/pose_detector.hpp @@ -6,79 +6,88 @@ #include "mmdeploy/common.hpp" #include "mmdeploy/pose_detector.h" -namespace mmdeploy { - -namespace cxx { - -using PoseDetection = mmdeploy_pose_detection_t; - -class PoseDetector : public NonMovable { - public: - PoseDetector(const Model& model, const Context& context) { - auto ec = mmdeploy_pose_detector_create_v2(model, context, &detector_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - - ~PoseDetector() { - if (detector_) { - mmdeploy_pose_detector_destroy(detector_); - detector_ = {}; - } - } - - 
using Result = Result_; - - std::vector Apply(Span images, Span bboxes, - Span bbox_count) { - if (images.empty()) { - return {}; - } - - const mmdeploy_rect_t* p_bboxes{}; - const int* p_bbox_count{}; - - if (!bboxes.empty()) { - p_bboxes = bboxes.data(); - p_bbox_count = bbox_count.data(); - } - - PoseDetection* results{}; - auto ec = mmdeploy_pose_detector_apply_bbox(detector_, reinterpret(images.data()), - static_cast(images.size()), p_bboxes, - p_bbox_count, &results); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - std::shared_ptr data(results, [count = images.size()](auto p) { - mmdeploy_pose_detector_release_result(p, count); - }); - - std::vector rets; - rets.reserve(images.size()); - - size_t offset = 0; - for (size_t i = 0; i < images.size(); ++i) { - offset += rets.emplace_back(offset, bboxes.empty() ? 1 : bbox_count[i], data).size(); - } - - return rets; - } - - Result Apply(const Mat& image, Span bboxes = {}) { - return Apply(Span{image}, bboxes, {static_cast(bboxes.size())})[0]; - } - - private: - mmdeploy_pose_detector_t detector_{}; -}; - -} // namespace cxx - -using cxx::PoseDetection; -using cxx::PoseDetector; +namespace mmdeploy +{ + + namespace cxx + { + + using PoseDetection = mmdeploy_pose_detection_t; + + class PoseDetector : public NonMovable + { + public: + PoseDetector(const Model& model, const Context& context) + { + auto ec = mmdeploy_pose_detector_create_v2(model, context, &detector_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + + ~PoseDetector() + { + if (detector_) + { + mmdeploy_pose_detector_destroy(detector_); + detector_ = {}; + } + } + + using Result = Result_; + + std::vector Apply(Span images, Span bboxes, Span bbox_count) + { + if (images.empty()) + { + return {}; + } + + const mmdeploy_rect_t* p_bboxes{}; + const int* p_bbox_count{}; + + if (!bboxes.empty()) + { + p_bboxes = bboxes.data(); + p_bbox_count = bbox_count.data(); + } + + PoseDetection* results{}; + auto ec = mmdeploy_pose_detector_apply_bbox(detector_, reinterpret(images.data()), static_cast(images.size()), p_bboxes, p_bbox_count, &results); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::shared_ptr data(results, [count = images.size()](auto p) + { mmdeploy_pose_detector_release_result(p, count); }); + + std::vector rets; + rets.reserve(images.size()); + + size_t offset = 0; + for (size_t i = 0; i < images.size(); ++i) + { + offset += rets.emplace_back(offset, bboxes.empty() ? 
1 : bbox_count[i], data).size(); + } + + return rets; + } + + Result Apply(const Mat& image, Span bboxes = {}) + { + return Apply(Span{image}, bboxes, {static_cast(bboxes.size())})[0]; + } + + private: + mmdeploy_pose_detector_t detector_{}; + }; + + } // namespace cxx + + using cxx::PoseDetection; + using cxx::PoseDetector; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/pose_tracker.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/pose_tracker.hpp index 077ec75700..e1e330ce05 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/pose_tracker.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/pose_tracker.hpp @@ -6,145 +6,171 @@ #include "mmdeploy/common.hpp" #include "mmdeploy/pose_tracker.h" -namespace mmdeploy { - -namespace cxx { - -class PoseTracker : public UniqueHandle { - public: - using Result = Result_; - class State; - class Params; - - public: - /** - * @brief Create pose tracker pipeline - * @param detect object detection model - * @param pose pose estimation model - * @param context execution context - */ - PoseTracker(const Model& detect, const Model& pose, const Context& context) { - auto ec = mmdeploy_pose_tracker_create(detect, pose, context, &handle_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - ~PoseTracker() { - if (handle_) { - mmdeploy_pose_tracker_destroy(handle_); - handle_ = {}; - } - } - PoseTracker(PoseTracker&&) noexcept = default; - - /** - * @brief Create a tracker state corresponds to a video stream - * @param params params for creating the tracker state - * @return created tracker state - */ - State CreateState(const Params& params); - - /** - * @brief Apply pose tracker pipeline - * @param state tracker state - * @param frame input video frame - * @param detect control the use of detector - * -1: use params.det_interval, 0: don't use detector, 1: force use detector - * @return - */ - Result Apply(State& state, const Mat& frame, int detect = -1); - - /** - * @brief batched version of Apply - * @param states - * @param frames - * @param detects - * @return - */ - std::vector Apply(const Span& states, const Span& frames, - const Span& detects = {}); - - public: - /** - * see \ref mmdeploy/pose_tracker.h for detail - */ - class Params : public UniqueHandle { - public: - explicit Params() { - handle_ = new mmdeploy_pose_tracker_param_t{}; - mmdeploy_pose_tracker_default_params(handle_); - } - ~Params() { - if (handle_) { - delete handle_; - handle_ = {}; - } - } - }; - - class State : public UniqueHandle { - public: - explicit State(mmdeploy_pose_tracker_t pipeline, const mmdeploy_pose_tracker_param_t* params) { - auto ec = mmdeploy_pose_tracker_create_state(pipeline, params, &handle_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - ~State() { - if (handle_) { - mmdeploy_pose_tracker_destroy_state(handle_); - handle_ = {}; - } - } - State(State&&) noexcept = default; - }; -}; - -inline PoseTracker::State PoseTracker::CreateState(const PoseTracker::Params& params) { - return State(handle_, static_cast(params)); -} - -inline std::vector PoseTracker::Apply(const Span& states, - const Span& frames, - const Span& detects) { - if (frames.empty()) { - return {}; - } - mmdeploy_pose_tracker_target_t* results{}; - int32_t* result_count{}; - - auto ec = mmdeploy_pose_tracker_apply( - handle_, reinterpret_cast(states.data()), - reinterpret(frames.data()), detects.data(), static_cast(frames.size()), &results, - &result_count); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - 
std::shared_ptr data( - results, [result_count, count = frames.size()](auto p) { - mmdeploy_pose_tracker_release_result(p, result_count, count); - }); - - std::vector rets; - rets.reserve(frames.size()); - - size_t offset = 0; - for (size_t i = 0; i < frames.size(); ++i) { - offset += rets.emplace_back(offset, result_count[i], data).size(); - } - - return rets; -} - -inline PoseTracker::Result PoseTracker::Apply(PoseTracker::State& state, const Mat& frame, - int32_t detect) { - return Apply(Span(&state, 1), Span{frame}, Span{detect})[0]; -} - -} // namespace cxx - -using cxx::PoseTracker; +namespace mmdeploy +{ + + namespace cxx + { + + class PoseTracker : public UniqueHandle + { + public: + using Result = Result_; + class State; + class Params; + + public: + /** + * @brief Create pose tracker pipeline + * @param detect object detection model + * @param pose pose estimation model + * @param context execution context + */ + PoseTracker(const Model& detect, const Model& pose, const Context& context) + { + auto ec = mmdeploy_pose_tracker_create(detect, pose, context, &handle_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + ~PoseTracker() + { + if (handle_) + { + mmdeploy_pose_tracker_destroy(handle_); + handle_ = {}; + } + } + PoseTracker(PoseTracker&&) noexcept = default; + + /** + * @brief Create a tracker state corresponds to a video stream + * @param params params for creating the tracker state + * @return created tracker state + */ + State CreateState(const Params& params); + + /** + * @brief Apply pose tracker pipeline + * @param state tracker state + * @param frame input video frame + * @param detect control the use of detector + * -1: use params.det_interval, 0: don't use detector, 1: force use detector + * @return + */ + Result Apply(State& state, const Mat& frame, int detect = -1); + + /** + * @brief batched version of Apply + * @param states + * @param frames + * @param detects + * @return + */ + std::vector Apply(const Span& states, const Span& frames, const Span& detects = {}); + + public: + /** + * see \ref mmdeploy/pose_tracker.h for detail + */ + class Params : public UniqueHandle + { + public: + explicit Params() + { + handle_ = new mmdeploy_pose_tracker_param_t{}; + mmdeploy_pose_tracker_default_params(handle_); + } + ~Params() + { + if (handle_) + { + delete handle_; + handle_ = {}; + } + } + }; + + class State : public UniqueHandle + { + public: + explicit State(mmdeploy_pose_tracker_t pipeline, const mmdeploy_pose_tracker_param_t* params) + { + auto ec = mmdeploy_pose_tracker_create_state(pipeline, params, &handle_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + ~State() + { + if (handle_) + { + mmdeploy_pose_tracker_destroy_state(handle_); + handle_ = {}; + } + } + State(State&&) noexcept = default; + }; + }; + + inline PoseTracker::State PoseTracker::CreateState(const PoseTracker::Params& params) + { + return State(handle_, static_cast(params)); + } + + inline std::vector PoseTracker::Apply(const Span& states, + const Span& frames, + const Span& detects) + { + if (frames.empty()) + { + return {}; + } + mmdeploy_pose_tracker_target_t* results{}; + int32_t* result_count{}; + + auto ec = mmdeploy_pose_tracker_apply( + handle_, + reinterpret_cast(states.data()), + reinterpret(frames.data()), + detects.data(), + static_cast(frames.size()), + &results, + &result_count); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::shared_ptr data( + results, + [result_count, count = 
frames.size()](auto p) + { + mmdeploy_pose_tracker_release_result(p, result_count, count); + }); + + std::vector rets; + rets.reserve(frames.size()); + + size_t offset = 0; + for (size_t i = 0; i < frames.size(); ++i) + { + offset += rets.emplace_back(offset, result_count[i], data).size(); + } + + return rets; + } + + inline PoseTracker::Result PoseTracker::Apply(PoseTracker::State& state, const Mat& frame, int32_t detect) + { + return Apply(Span(&state, 1), Span{frame}, Span{detect})[0]; + } + + } // namespace cxx + + using cxx::PoseTracker; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/restorer.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/restorer.hpp index 671c5c2d0c..dcf9ab75af 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/restorer.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/restorer.hpp @@ -6,62 +6,77 @@ #include "mmdeploy/common.hpp" #include "mmdeploy/restorer.h" -namespace mmdeploy { - -namespace cxx { - -class Restorer : public NonMovable { - public: - Restorer(const Model& model, const Context& context) { - auto ec = mmdeploy_restorer_create_v2(model, context, &restorer_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - - ~Restorer() { - if (restorer_) { - mmdeploy_restorer_destroy(restorer_); - restorer_ = {}; - } - } - - using Result = Result_; - - std::vector Apply(Span images) { - if (images.empty()) { - return {}; - } - - mmdeploy_mat_t* results{}; - auto ec = mmdeploy_restorer_apply(restorer_, reinterpret(images.data()), - static_cast(images.size()), &results); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - std::vector rets; - rets.reserve(images.size()); - - std::shared_ptr data( - results, [count = images.size()](auto p) { mmdeploy_restorer_release_result(p, count); }); - - for (size_t i = 0; i < images.size(); ++i) { - rets.emplace_back(i, 1, data); - } - - return rets; - } - - Result Apply(const Mat& image) { return Apply(Span{image})[0]; } - - private: - mmdeploy_restorer_t restorer_{}; -}; - -} // namespace cxx - -using cxx::Restorer; +namespace mmdeploy +{ + + namespace cxx + { + + class Restorer : public NonMovable + { + public: + Restorer(const Model& model, const Context& context) + { + auto ec = mmdeploy_restorer_create_v2(model, context, &restorer_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + + ~Restorer() + { + if (restorer_) + { + mmdeploy_restorer_destroy(restorer_); + restorer_ = {}; + } + } + + using Result = Result_; + + std::vector Apply(Span images) + { + if (images.empty()) + { + return {}; + } + + mmdeploy_mat_t* results{}; + auto ec = mmdeploy_restorer_apply(restorer_, reinterpret(images.data()), static_cast(images.size()), &results); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::vector rets; + rets.reserve(images.size()); + + std::shared_ptr data( + results, + [count = images.size()](auto p) + { mmdeploy_restorer_release_result(p, count); }); + + for (size_t i = 0; i < images.size(); ++i) + { + rets.emplace_back(i, 1, data); + } + + return rets; + } + + Result Apply(const Mat& image) + { + return Apply(Span{image})[0]; + } + + private: + mmdeploy_restorer_t restorer_{}; + }; + + } // namespace cxx + + using cxx::Restorer; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/rotated_detector.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/rotated_detector.hpp index fa065b0f0c..5a224f6fa5 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/rotated_detector.hpp +++ 
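The PoseTracker wrapper above pairs a detection model with a pose model and keeps per-stream tracking state. A short tracking-loop sketch follows; it is illustrative only, with "./detect_model", "./pose_model", and "input.mp4" as hypothetical placeholders:

// Tracking-loop sketch (illustrative; not part of this patch).
#include <cstdio>
#include "opencv2/videoio.hpp"
#include "mmdeploy/pose_tracker.hpp"

int main()
{
    mmdeploy::Model       det("./detect_model");  // hypothetical detection model
    mmdeploy::Model       pose("./pose_model");   // hypothetical pose model
    mmdeploy::Context     ctx(mmdeploy::Device("cpu"));
    mmdeploy::PoseTracker tracker(det, pose, ctx);

    mmdeploy::PoseTracker::Params params;      // defaults via mmdeploy_pose_tracker_default_params
    auto state = tracker.CreateState(params);  // one state per video stream

    cv::VideoCapture cap("input.mp4");  // hypothetical video source
    for (cv::Mat frame; cap.read(frame);)
    {
        // detect = -1 (the default) lets the pipeline fall back to params->det_interval
        auto targets = tracker.Apply(state, frame);
        for (const auto& t : targets)
        {
            std::printf("target %u: %d keypoints\n", t.target_id, t.keypoint_count);
        }
    }
    return 0;
}

Reusing one State across frames is what carries the track identities forward; creating a fresh State per frame would reduce the tracker to per-frame detection.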
b/csrc/mmdeploy/apis/cxx/mmdeploy/rotated_detector.hpp @@ -6,69 +6,81 @@ #include "mmdeploy/common.hpp" #include "mmdeploy/rotated_detector.h" -namespace mmdeploy { - -namespace cxx { - -using RotatedDetection = mmdeploy_rotated_detection_t; - -class RotatedDetector : public NonMovable { - public: - RotatedDetector(const Model& model, const Context& context) { - auto ec = mmdeploy_rotated_detector_create_v2(model, context, &detector_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - - ~RotatedDetector() { - if (detector_) { - mmdeploy_rotated_detector_destroy(detector_); - detector_ = {}; - } - } - - using Result = Result_; - - std::vector Apply(Span images) { - if (images.empty()) { - return {}; - } - - RotatedDetection* results{}; - int* result_count{}; - auto ec = - mmdeploy_rotated_detector_apply(detector_, reinterpret(images.data()), - static_cast(images.size()), &results, &result_count); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - std::shared_ptr data(results, [result_count](auto p) { - mmdeploy_rotated_detector_release_result(p, result_count); - }); - - std::vector rets; - rets.reserve(images.size()); - - size_t offset = 0; - for (size_t i = 0; i < images.size(); ++i) { - offset += rets.emplace_back(offset, result_count[i], data).size(); - } - - return rets; - } - - Result Apply(const Mat& image) { return Apply(Span{image})[0]; } - - private: - mmdeploy_rotated_detector_t detector_{}; -}; - -} // namespace cxx - -using cxx::RotatedDetection; -using cxx::RotatedDetector; +namespace mmdeploy +{ + + namespace cxx + { + + using RotatedDetection = mmdeploy_rotated_detection_t; + + class RotatedDetector : public NonMovable + { + public: + RotatedDetector(const Model& model, const Context& context) + { + auto ec = mmdeploy_rotated_detector_create_v2(model, context, &detector_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + + ~RotatedDetector() + { + if (detector_) + { + mmdeploy_rotated_detector_destroy(detector_); + detector_ = {}; + } + } + + using Result = Result_; + + std::vector Apply(Span images) + { + if (images.empty()) + { + return {}; + } + + RotatedDetection* results{}; + int* result_count{}; + auto ec = + mmdeploy_rotated_detector_apply(detector_, reinterpret(images.data()), static_cast(images.size()), &results, &result_count); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::shared_ptr data(results, [result_count](auto p) + { mmdeploy_rotated_detector_release_result(p, result_count); }); + + std::vector rets; + rets.reserve(images.size()); + + size_t offset = 0; + for (size_t i = 0; i < images.size(); ++i) + { + offset += rets.emplace_back(offset, result_count[i], data).size(); + } + + return rets; + } + + Result Apply(const Mat& image) + { + return Apply(Span{image})[0]; + } + + private: + mmdeploy_rotated_detector_t detector_{}; + }; + + } // namespace cxx + + using cxx::RotatedDetection; + using cxx::RotatedDetector; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/segmentor.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/segmentor.hpp index fe53023d1c..7ad98a91bb 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/segmentor.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/segmentor.hpp @@ -6,65 +6,80 @@ #include "mmdeploy/common.hpp" #include "mmdeploy/segmentor.h" -namespace mmdeploy { - -namespace cxx { - -using Segmentation = mmdeploy_segmentation_t; - -class Segmentor : public NonMovable { - public: - Segmentor(const Model& model, const Context& 
context) { - auto ec = mmdeploy_segmentor_create_v2(model, context, &segmentor_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - - ~Segmentor() { - if (segmentor_) { - mmdeploy_segmentor_destroy(segmentor_); - segmentor_ = {}; - } - } - - using Result = Result_; - - std::vector Apply(Span images) { - if (images.empty()) { - return {}; - } - - Segmentation* results{}; - auto ec = mmdeploy_segmentor_apply(segmentor_, reinterpret(images.data()), - static_cast(images.size()), &results); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - std::vector rets; - rets.reserve(images.size()); - - std::shared_ptr data( - results, [count = images.size()](auto p) { mmdeploy_segmentor_release_result(p, count); }); - - for (size_t i = 0; i < images.size(); ++i) { - rets.emplace_back(i, 1, data); - } - - return rets; - } - - Result Apply(const Mat& image) { return Apply(Span{image})[0]; } - - private: - mmdeploy_segmentor_t segmentor_{}; -}; - -} // namespace cxx - -using cxx::Segmentation; -using cxx::Segmentor; +namespace mmdeploy +{ + + namespace cxx + { + + using Segmentation = mmdeploy_segmentation_t; + + class Segmentor : public NonMovable + { + public: + Segmentor(const Model& model, const Context& context) + { + auto ec = mmdeploy_segmentor_create_v2(model, context, &segmentor_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + + ~Segmentor() + { + if (segmentor_) + { + mmdeploy_segmentor_destroy(segmentor_); + segmentor_ = {}; + } + } + + using Result = Result_; + + std::vector Apply(Span images) + { + if (images.empty()) + { + return {}; + } + + Segmentation* results{}; + auto ec = mmdeploy_segmentor_apply(segmentor_, reinterpret(images.data()), static_cast(images.size()), &results); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::vector rets; + rets.reserve(images.size()); + + std::shared_ptr data( + results, + [count = images.size()](auto p) + { mmdeploy_segmentor_release_result(p, count); }); + + for (size_t i = 0; i < images.size(); ++i) + { + rets.emplace_back(i, 1, data); + } + + return rets; + } + + Result Apply(const Mat& image) + { + return Apply(Span{image})[0]; + } + + private: + mmdeploy_segmentor_t segmentor_{}; + }; + + } // namespace cxx + + using cxx::Segmentation; + using cxx::Segmentor; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/text_detector.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/text_detector.hpp index d848715405..56f2f02f18 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/text_detector.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/text_detector.hpp @@ -6,69 +6,81 @@ #include "mmdeploy/common.hpp" #include "mmdeploy/text_detector.h" -namespace mmdeploy { - -namespace cxx { - -using TextDetection = mmdeploy_text_detection_t; - -class TextDetector : public NonMovable { - public: - TextDetector(const Model& model, const Context& context) { - auto ec = mmdeploy_text_detector_create_v2(model, context, &detector_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - - ~TextDetector() { - if (detector_) { - mmdeploy_text_detector_destroy(detector_); - detector_ = {}; - } - } - - using Result = Result_; - - std::vector Apply(Span images) { - if (images.empty()) { - return {}; - } - - TextDetection* results{}; - int* result_count{}; - auto ec = - mmdeploy_text_detector_apply(detector_, reinterpret(images.data()), - static_cast(images.size()), &results, &result_count); - if (ec != MMDEPLOY_SUCCESS) { - 
throw_exception(static_cast(ec)); - } - - std::shared_ptr data(results, [result_count, count = images.size()](auto p) { - mmdeploy_text_detector_release_result(p, result_count, count); - }); - - std::vector rets; - rets.reserve(images.size()); - - size_t offset = 0; - for (size_t i = 0; i < images.size(); ++i) { - offset += rets.emplace_back(offset, result_count[i], data).size(); - } - - return rets; - } - - Result Apply(const Mat& image) { return Apply(Span{image})[0]; } - - private: - mmdeploy_text_detector_t detector_{}; -}; - -} // namespace cxx - -using cxx::TextDetection; -using cxx::TextDetector; +namespace mmdeploy +{ + + namespace cxx + { + + using TextDetection = mmdeploy_text_detection_t; + + class TextDetector : public NonMovable + { + public: + TextDetector(const Model& model, const Context& context) + { + auto ec = mmdeploy_text_detector_create_v2(model, context, &detector_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + + ~TextDetector() + { + if (detector_) + { + mmdeploy_text_detector_destroy(detector_); + detector_ = {}; + } + } + + using Result = Result_; + + std::vector Apply(Span images) + { + if (images.empty()) + { + return {}; + } + + TextDetection* results{}; + int* result_count{}; + auto ec = + mmdeploy_text_detector_apply(detector_, reinterpret(images.data()), static_cast(images.size()), &results, &result_count); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::shared_ptr data(results, [result_count, count = images.size()](auto p) + { mmdeploy_text_detector_release_result(p, result_count, count); }); + + std::vector rets; + rets.reserve(images.size()); + + size_t offset = 0; + for (size_t i = 0; i < images.size(); ++i) + { + offset += rets.emplace_back(offset, result_count[i], data).size(); + } + + return rets; + } + + Result Apply(const Mat& image) + { + return Apply(Span{image})[0]; + } + + private: + mmdeploy_text_detector_t detector_{}; + }; + + } // namespace cxx + + using cxx::TextDetection; + using cxx::TextDetector; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/text_recognizer.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/text_recognizer.hpp index eba8ea3902..31c741e2ee 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/text_recognizer.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/text_recognizer.hpp @@ -9,82 +9,91 @@ #include "mmdeploy/text_detector.hpp" #include "mmdeploy/text_recognizer.h" -namespace mmdeploy { - -namespace cxx { - -using TextRecognition = mmdeploy_text_recognition_t; - -class TextRecognizer : public NonMovable { - public: - TextRecognizer(const Model& model, const Context& context) { - auto ec = mmdeploy_text_recognizer_create_v2(model, context, &recognizer_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - - ~TextRecognizer() { - if (recognizer_) { - mmdeploy_text_recognizer_destroy(recognizer_); - recognizer_ = {}; - } - } - - using Result = Result_; - - std::vector Apply(Span images, Span bboxes, - Span bbox_count) { - if (images.empty()) { - return {}; - } - - const TextDetection* p_bboxes{}; - const int* p_bbox_count{}; - - auto n_total_bboxes = static_cast(images.size()); - - if (!bboxes.empty()) { - p_bboxes = bboxes.data(); - p_bbox_count = bbox_count.data(); - n_total_bboxes = std::accumulate(bbox_count.begin(), bbox_count.end(), 0); - } - - TextRecognition* results{}; - auto ec = mmdeploy_text_recognizer_apply_bbox(recognizer_, reinterpret(images.data()), - static_cast(images.size()), p_bboxes, - p_bbox_count, &results); - 
if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - std::shared_ptr data(results, [count = n_total_bboxes](auto p) { - mmdeploy_text_recognizer_release_result(p, count); - }); - - std::vector rets; - rets.reserve(images.size()); - - size_t offset = 0; - for (size_t i = 0; i < images.size(); ++i) { - offset += rets.emplace_back(offset, bboxes.empty() ? 1 : bbox_count[i], data).size(); - } - - return rets; - } - - Result Apply(const Mat& image, Span bboxes = {}) { - return Apply(Span{image}, bboxes, {static_cast(bboxes.size())})[0]; - } - - private: - mmdeploy_text_recognizer_t recognizer_{}; -}; - -} // namespace cxx - -using cxx::TextRecognition; -using cxx::TextRecognizer; +namespace mmdeploy +{ + + namespace cxx + { + + using TextRecognition = mmdeploy_text_recognition_t; + + class TextRecognizer : public NonMovable + { + public: + TextRecognizer(const Model& model, const Context& context) + { + auto ec = mmdeploy_text_recognizer_create_v2(model, context, &recognizer_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + + ~TextRecognizer() + { + if (recognizer_) + { + mmdeploy_text_recognizer_destroy(recognizer_); + recognizer_ = {}; + } + } + + using Result = Result_; + + std::vector Apply(Span images, Span bboxes, Span bbox_count) + { + if (images.empty()) + { + return {}; + } + + const TextDetection* p_bboxes{}; + const int* p_bbox_count{}; + + auto n_total_bboxes = static_cast(images.size()); + + if (!bboxes.empty()) + { + p_bboxes = bboxes.data(); + p_bbox_count = bbox_count.data(); + n_total_bboxes = std::accumulate(bbox_count.begin(), bbox_count.end(), 0); + } + + TextRecognition* results{}; + auto ec = mmdeploy_text_recognizer_apply_bbox(recognizer_, reinterpret(images.data()), static_cast(images.size()), p_bboxes, p_bbox_count, &results); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::shared_ptr data(results, [count = n_total_bboxes](auto p) + { mmdeploy_text_recognizer_release_result(p, count); }); + + std::vector rets; + rets.reserve(images.size()); + + size_t offset = 0; + for (size_t i = 0; i < images.size(); ++i) + { + offset += rets.emplace_back(offset, bboxes.empty() ? 
1 : bbox_count[i], data).size(); + } + + return rets; + } + + Result Apply(const Mat& image, Span bboxes = {}) + { + return Apply(Span{image}, bboxes, {static_cast(bboxes.size())})[0]; + } + + private: + mmdeploy_text_recognizer_t recognizer_{}; + }; + + } // namespace cxx + + using cxx::TextRecognition; + using cxx::TextRecognizer; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/cxx/mmdeploy/video_recognizer.hpp b/csrc/mmdeploy/apis/cxx/mmdeploy/video_recognizer.hpp index 583b28dd59..ed3569e242 100644 --- a/csrc/mmdeploy/apis/cxx/mmdeploy/video_recognizer.hpp +++ b/csrc/mmdeploy/apis/cxx/mmdeploy/video_recognizer.hpp @@ -6,85 +6,97 @@ #include "mmdeploy/common.hpp" #include "mmdeploy/video_recognizer.h" -namespace mmdeploy { - -namespace cxx { - -using VideoRecognition = mmdeploy_video_recognition_t; -using VideoSampleInfo = mmdeploy_video_sample_info_t; - -class VideoRecognizer : public NonMovable { - public: - VideoRecognizer(const Model& model, const Context& context) { - auto ec = mmdeploy_video_recognizer_create_v2(model, context, &recognizer_); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - } - - ~VideoRecognizer() { - if (recognizer_) { - mmdeploy_video_recognizer_destroy(recognizer_); - recognizer_ = {}; - } - } - - using Result = Result_; - - std::vector Apply(Span> videos, - Span infos) { - if (videos.empty()) { - return {}; - } - - int video_count = videos.size(); - - VideoRecognition* results{}; - int* result_count{}; - std::vector images; - std::vector video_info; - for (int i = 0; i < videos.size(); i++) { - for (auto& mat : videos[i]) { - images.push_back(mat); - } - video_info.push_back(infos[i]); - } - - auto ec = - mmdeploy_video_recognizer_apply(recognizer_, reinterpret(images.data()), video_info.data(), - video_count, &results, &result_count); - if (ec != MMDEPLOY_SUCCESS) { - throw_exception(static_cast(ec)); - } - - std::vector rets; - rets.reserve(video_count); - - std::shared_ptr data(results, [result_count, count = video_count](auto p) { - mmdeploy_video_recognizer_release_result(p, result_count, count); - }); - - size_t offset = 0; - for (size_t i = 0; i < video_count; ++i) { - offset += rets.emplace_back(offset, result_count[i], data).size(); - } - - return rets; - } - - Result Apply(const std::vector& video, const VideoSampleInfo info) { - return Apply(Span{video}, Span{info})[0]; - } - - private: - mmdeploy_video_recognizer_t recognizer_{}; -}; - -} // namespace cxx - -using cxx::VideoRecognition; -using cxx::VideoRecognizer; -using cxx::VideoSampleInfo; +namespace mmdeploy +{ + + namespace cxx + { + + using VideoRecognition = mmdeploy_video_recognition_t; + using VideoSampleInfo = mmdeploy_video_sample_info_t; + + class VideoRecognizer : public NonMovable + { + public: + VideoRecognizer(const Model& model, const Context& context) + { + auto ec = mmdeploy_video_recognizer_create_v2(model, context, &recognizer_); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + } + + ~VideoRecognizer() + { + if (recognizer_) + { + mmdeploy_video_recognizer_destroy(recognizer_); + recognizer_ = {}; + } + } + + using Result = Result_; + + std::vector Apply(Span> videos, + Span infos) + { + if (videos.empty()) + { + return {}; + } + + int video_count = videos.size(); + + VideoRecognition* results{}; + int* result_count{}; + std::vector images; + std::vector video_info; + for (int i = 0; i < videos.size(); i++) + { + for (auto& mat : videos[i]) + { + images.push_back(mat); + } + video_info.push_back(infos[i]); + } + + 
auto ec = + mmdeploy_video_recognizer_apply(recognizer_, reinterpret(images.data()), video_info.data(), video_count, &results, &result_count); + if (ec != MMDEPLOY_SUCCESS) + { + throw_exception(static_cast(ec)); + } + + std::vector rets; + rets.reserve(video_count); + + std::shared_ptr data(results, [result_count, count = video_count](auto p) + { mmdeploy_video_recognizer_release_result(p, result_count, count); }); + + size_t offset = 0; + for (size_t i = 0; i < video_count; ++i) + { + offset += rets.emplace_back(offset, result_count[i], data).size(); + } + + return rets; + } + + Result Apply(const std::vector& video, const VideoSampleInfo info) + { + return Apply(Span{video}, Span{info})[0]; + } + + private: + mmdeploy_video_recognizer_t recognizer_{}; + }; + + } // namespace cxx + + using cxx::VideoRecognition; + using cxx::VideoRecognizer; + using cxx::VideoSampleInfo; } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/java/native/common.h b/csrc/mmdeploy/apis/java/native/common.h index ba2601e5f1..045dc02a35 100644 --- a/csrc/mmdeploy/apis/java/native/common.h +++ b/csrc/mmdeploy/apis/java/native/common.h @@ -10,45 +10,48 @@ #include "mmdeploy/core/logger.h" #include "mmdeploy/core/utils/formatter.h" -template -static auto With(JNIEnv *env, jobjectArray imgs, F f) noexcept { - auto mat_clazz = env->FindClass("mmdeploy/Mat"); - auto shape_field = env->GetFieldID(mat_clazz, "shape", "[I"); - auto format_field = env->GetFieldID(mat_clazz, "format", "I"); - auto type_field = env->GetFieldID(mat_clazz, "type", "I"); - auto data_field = env->GetFieldID(mat_clazz, "data", "[B"); - auto num = env->GetArrayLength(imgs); - std::vector mats; - std::vector datum; - - mats.reserve(num); - datum.reserve(num); - - for (int i = 0; i < num; ++i) { - auto obj = env->GetObjectArrayElement(imgs, i); - auto shape_obj = env->GetObjectField(obj, shape_field); - auto shape = env->GetIntArrayElements((jintArray)shape_obj, nullptr); - auto format = env->GetIntField(obj, format_field); - auto type = env->GetIntField(obj, type_field); - auto &mat = mats.emplace_back(); - mat.height = shape[0]; - mat.width = shape[1]; - mat.channel = shape[2]; - env->ReleaseIntArrayElements((jintArray)shape_obj, shape, JNI_ABORT); - mat.format = (mmdeploy_pixel_format_t)format; - mat.type = (mmdeploy_data_type_t)type; - auto data_obj = env->GetObjectField(obj, data_field); - mat.data = (uint8_t *)env->GetByteArrayElements((jbyteArray)data_obj, nullptr); - datum.push_back((jbyteArray)data_obj); - } - - auto ret = f(mats.data(), mats.size()); // ! 
f must not throw - - for (int i = 0; i < num; ++i) { - env->ReleaseByteArrayElements(datum[i], (jbyte *)mats[i].data, JNI_ABORT); - } - - return ret; +template +static auto With(JNIEnv* env, jobjectArray imgs, F f) noexcept +{ + auto mat_clazz = env->FindClass("mmdeploy/Mat"); + auto shape_field = env->GetFieldID(mat_clazz, "shape", "[I"); + auto format_field = env->GetFieldID(mat_clazz, "format", "I"); + auto type_field = env->GetFieldID(mat_clazz, "type", "I"); + auto data_field = env->GetFieldID(mat_clazz, "data", "[B"); + auto num = env->GetArrayLength(imgs); + std::vector mats; + std::vector datum; + + mats.reserve(num); + datum.reserve(num); + + for (int i = 0; i < num; ++i) + { + auto obj = env->GetObjectArrayElement(imgs, i); + auto shape_obj = env->GetObjectField(obj, shape_field); + auto shape = env->GetIntArrayElements((jintArray)shape_obj, nullptr); + auto format = env->GetIntField(obj, format_field); + auto type = env->GetIntField(obj, type_field); + auto& mat = mats.emplace_back(); + mat.height = shape[0]; + mat.width = shape[1]; + mat.channel = shape[2]; + env->ReleaseIntArrayElements((jintArray)shape_obj, shape, JNI_ABORT); + mat.format = (mmdeploy_pixel_format_t)format; + mat.type = (mmdeploy_data_type_t)type; + auto data_obj = env->GetObjectField(obj, data_field); + mat.data = (uint8_t*)env->GetByteArrayElements((jbyteArray)data_obj, nullptr); + datum.push_back((jbyteArray)data_obj); + } + + auto ret = f(mats.data(), mats.size()); // ! f must not throw + + for (int i = 0; i < num; ++i) + { + env->ReleaseByteArrayElements(datum[i], (jbyte*)mats[i].data, JNI_ABORT); + } + + return ret; } #endif // MMDEPLOY_CSRC_APIS_JAVA_NATIVE_COMMON_H_ diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.cpp index 2a3309361e..6664a65289 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.cpp @@ -6,30 +6,33 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Classifier_create(JNIEnv *env, jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_classifier_t classifier{}; - auto ec = - mmdeploy_classifier_create_by_path(model_path, device_name, (int)device_id, &classifier); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create classifier, code = {}", ec); - return -1; - } - return (jlong)classifier; +jlong Java_mmdeploy_Classifier_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_classifier_t classifier{}; + auto ec = + mmdeploy_classifier_create_by_path(model_path, device_name, (int)device_id, &classifier); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create classifier, code = {}", ec); + return -1; + } + return (jlong)classifier; } -void Java_mmdeploy_Classifier_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_Classifier_destroy"); - mmdeploy_classifier_destroy((mmdeploy_classifier_t)handle); +void Java_mmdeploy_Classifier_destroy(JNIEnv*, jobject, 
jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Classifier_destroy"); + mmdeploy_classifier_destroy((mmdeploy_classifier_t)handle); } -jobjectArray Java_mmdeploy_Classifier_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images, jintArray counts) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_Classifier_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images, jintArray counts) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_classification_t *results{}; int *result_count{}; auto ec = mmdeploy_classifier_apply((mmdeploy_classifier_t)handle, imgs, size, &results, @@ -55,6 +58,5 @@ jobjectArray Java_mmdeploy_Classifier_apply(JNIEnv *env, jobject thiz, jlong han } env->ReleaseIntArrayElements(counts, counts_array, 0); mmdeploy_classifier_release_result(results, result_count, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.h index 16a06b5fba..84adf58aa3 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Classifier.h @@ -3,33 +3,33 @@ /* Header for class mmdeploy_Classifier */ #ifndef _Included_mmdeploy_Classifier -#define _Included_mmdeploy_Classifier -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Classifier - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Classifier_create(JNIEnv *, jobject, jstring, jstring, jint); + #define _Included_mmdeploy_Classifier + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Classifier + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Classifier_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_Classifier - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Classifier_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Classifier + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Classifier_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_Classifier - * Method: apply - * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/Classifier/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Classifier_apply(JNIEnv *, jobject, jlong, - jobjectArray, jintArray); + /* + * Class: mmdeploy_Classifier + * Method: apply + * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/Classifier/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Classifier_apply(JNIEnv*, jobject, jlong, jobjectArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Context.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Context.cpp index dbd401724e..e875a66ead 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Context.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Context.cpp @@ -8,36 +8,43 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Context_create(JNIEnv *env, jobject) { - mmdeploy_context_t context{}; - mmdeploy_context_create(&context); - return (jlong)context; +jlong Java_mmdeploy_Context_create(JNIEnv* env, jobject) +{ + mmdeploy_context_t context{}; + mmdeploy_context_create(&context); + return (jlong)context; } -jint Java_mmdeploy_Context_add(JNIEnv 
*env, jobject, jlong context_, jint contextType, jstring name, - jlong handle) { - auto object_name = env->GetStringUTFChars(name, nullptr); - if ((int)contextType == MMDEPLOY_TYPE_SCHEDULER) { - mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, - object_name, (mmdeploy_scheduler_t)handle); - } else if ((int)contextType == MMDEPLOY_TYPE_MODEL) { - mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, - object_name, (mmdeploy_model_t)handle); - } else if ((int)contextType == MMDEPLOY_TYPE_DEVICE) { - mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, - nullptr, (mmdeploy_device_t)handle); - } else if ((int)contextType == MMDEPLOY_TYPE_PROFILER) { - mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, - nullptr, (mmdeploy_profiler_t)handle); - } else { - MMDEPLOY_ERROR("wrong context type, got {}", (int)contextType); - return MMDEPLOY_E_NOT_SUPPORTED; - } - env->ReleaseStringUTFChars(name, object_name); - return 0; +jint Java_mmdeploy_Context_add(JNIEnv* env, jobject, jlong context_, jint contextType, jstring name, jlong handle) +{ + auto object_name = env->GetStringUTFChars(name, nullptr); + if ((int)contextType == MMDEPLOY_TYPE_SCHEDULER) + { + mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, object_name, (mmdeploy_scheduler_t)handle); + } + else if ((int)contextType == MMDEPLOY_TYPE_MODEL) + { + mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, object_name, (mmdeploy_model_t)handle); + } + else if ((int)contextType == MMDEPLOY_TYPE_DEVICE) + { + mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, nullptr, (mmdeploy_device_t)handle); + } + else if ((int)contextType == MMDEPLOY_TYPE_PROFILER) + { + mmdeploy_context_add((mmdeploy_context_t)context_, (mmdeploy_context_type_t)contextType, nullptr, (mmdeploy_profiler_t)handle); + } + else + { + MMDEPLOY_ERROR("wrong context type, got {}", (int)contextType); + return MMDEPLOY_E_NOT_SUPPORTED; + } + env->ReleaseStringUTFChars(name, object_name); + return 0; } -void Java_mmdeploy_Context_destroy(JNIEnv *, jobject, jlong context_) { - MMDEPLOY_DEBUG("Java_mmdeploy_Context_destroy"); - mmdeploy_context_destroy((mmdeploy_context_t)context_); +void Java_mmdeploy_Context_destroy(JNIEnv*, jobject, jlong context_) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Context_destroy"); + mmdeploy_context_destroy((mmdeploy_context_t)context_); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Context.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Context.h index 42df819580..00e24065c6 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Context.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Context.h @@ -3,32 +3,33 @@ /* Header for class mmdeploy_Context */ #ifndef _Included_mmdeploy_Context -#define _Included_mmdeploy_Context -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Context - * Method: create - * Signature: ()J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Context_create(JNIEnv *, jobject); + #define _Included_mmdeploy_Context + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Context + * Method: create + * Signature: ()J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Context_create(JNIEnv*, jobject); -/* - * Class: mmdeploy_Context - * Method: add - * Signature: (JILjava/lang/String;J)I - */ -JNIEXPORT jint JNICALL Java_mmdeploy_Context_add(JNIEnv *, jobject, jlong, 
jint, jstring, jlong); + /* + * Class: mmdeploy_Context + * Method: add + * Signature: (JILjava/lang/String;J)I + */ + JNIEXPORT jint JNICALL Java_mmdeploy_Context_add(JNIEnv*, jobject, jlong, jint, jstring, jlong); -/* - * Class: mmdeploy_Context - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Context_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Context + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Context_destroy(JNIEnv*, jobject, jlong); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.cpp index c03ff1a1ff..6e8a32dac7 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.cpp @@ -6,29 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Detector_create(JNIEnv *env, jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_detector_t detector{}; - auto ec = mmdeploy_detector_create_by_path(model_path, device_name, (int)device_id, &detector); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create detector, code = {}", ec); - return -1; - } - return (jlong)detector; +jlong Java_mmdeploy_Detector_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_detector_t detector{}; + auto ec = mmdeploy_detector_create_by_path(model_path, device_name, (int)device_id, &detector); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create detector, code = {}", ec); + return -1; + } + return (jlong)detector; } -void Java_mmdeploy_Detector_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_Detector_destroy"); // maybe use info? - mmdeploy_detector_destroy((mmdeploy_detector_t)handle); +void Java_mmdeploy_Detector_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Detector_destroy"); // maybe use info? 
+ mmdeploy_detector_destroy((mmdeploy_detector_t)handle); } -jobjectArray Java_mmdeploy_Detector_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images, jintArray counts) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_Detector_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images, jintArray counts) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_detection_t *results{}; int *result_count{}; auto ec = @@ -79,6 +82,5 @@ jobjectArray Java_mmdeploy_Detector_apply(JNIEnv *env, jobject thiz, jlong handl } env->ReleaseIntArrayElements(counts, counts_array, 0); mmdeploy_detector_release_result(results, result_count, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.h index 41e711d15a..578643efc8 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Detector.h @@ -3,33 +3,33 @@ /* Header for class mmdeploy_Detector */ #ifndef _Included_mmdeploy_Detector -#define _Included_mmdeploy_Detector -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Detector - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Detector_create(JNIEnv *, jobject, jstring, jstring, jint); + #define _Included_mmdeploy_Detector + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Detector + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Detector_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_Detector - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Detector_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Detector + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Detector_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_Detector - * Method: apply - * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/Detector/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Detector_apply(JNIEnv *, jobject, jlong, jobjectArray, - jintArray); + /* + * Class: mmdeploy_Detector + * Method: apply + * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/Detector/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Detector_apply(JNIEnv*, jobject, jlong, jobjectArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Device.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Device.cpp index 8dbec9285b..8160210ed5 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Device.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Device.cpp @@ -6,19 +6,22 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Device_create(JNIEnv *env, jobject, jstring name, jint index) { - auto device_name = env->GetStringUTFChars(name, nullptr); - mmdeploy_device_t device{}; - auto ec = mmdeploy_device_create(device_name, (int)index, &device); - env->ReleaseStringUTFChars(name, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create device, code = {}", ec); - return -1; - } - return (jlong)device; +jlong Java_mmdeploy_Device_create(JNIEnv* env, jobject, jstring name, jint index) +{ + auto device_name = env->GetStringUTFChars(name, nullptr); + 
mmdeploy_device_t device{}; + auto ec = mmdeploy_device_create(device_name, (int)index, &device); + env->ReleaseStringUTFChars(name, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create device, code = {}", ec); + return -1; + } + return (jlong)device; } -void Java_mmdeploy_Device_destroy(JNIEnv *, jobject, jlong device_) { - MMDEPLOY_DEBUG("Java_mmdeploy_Device_destroy"); - mmdeploy_device_destroy((mmdeploy_device_t)device_); +void Java_mmdeploy_Device_destroy(JNIEnv*, jobject, jlong device_) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Device_destroy"); + mmdeploy_device_destroy((mmdeploy_device_t)device_); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Device.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Device.h index 7d7ee9dee7..e751d0f781 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Device.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Device.h @@ -3,25 +3,26 @@ /* Header for class mmdeploy_Device */ #ifndef _Included_mmdeploy_Device -#define _Included_mmdeploy_Device -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Device - * Method: create - * Signature: (Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Device_create(JNIEnv *, jobject, jstring, jint); + #define _Included_mmdeploy_Device + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Device + * Method: create + * Signature: (Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Device_create(JNIEnv*, jobject, jstring, jint); -/* - * Class: mmdeploy_Device - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Device_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Device + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Device_destroy(JNIEnv*, jobject, jlong); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Model.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Model.cpp index 2bbc9a6920..821b1e988e 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Model.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Model.cpp @@ -6,19 +6,22 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Model_create(JNIEnv *env, jobject, jstring path) { - auto model_path = env->GetStringUTFChars(path, nullptr); - mmdeploy_model_t model{}; - auto ec = mmdeploy_model_create_by_path(model_path, &model); - env->ReleaseStringUTFChars(path, model_path); - if (ec) { - MMDEPLOY_ERROR("failed to create model, code = {}", ec); - return -1; - } - return (jlong)model; +jlong Java_mmdeploy_Model_create(JNIEnv* env, jobject, jstring path) +{ + auto model_path = env->GetStringUTFChars(path, nullptr); + mmdeploy_model_t model{}; + auto ec = mmdeploy_model_create_by_path(model_path, &model); + env->ReleaseStringUTFChars(path, model_path); + if (ec) + { + MMDEPLOY_ERROR("failed to create model, code = {}", ec); + return -1; + } + return (jlong)model; } -void Java_mmdeploy_Model_destroy(JNIEnv *, jobject, jlong model_) { - MMDEPLOY_DEBUG("Java_mmdeploy_Model_destroy"); - mmdeploy_model_destroy((mmdeploy_model_t)model_); +void Java_mmdeploy_Model_destroy(JNIEnv*, jobject, jlong model_) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Model_destroy"); + mmdeploy_model_destroy((mmdeploy_model_t)model_); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Model.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Model.h index 11e23a1a81..9fc714c259 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Model.h +++ 
b/csrc/mmdeploy/apis/java/native/mmdeploy_Model.h @@ -3,25 +3,26 @@ /* Header for class mmdeploy_Model */ #ifndef _Included_mmdeploy_Model -#define _Included_mmdeploy_Model -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Model - * Method: create - * Signature: (Ljava/lang/String;)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Model_create(JNIEnv *, jobject, jstring); + #define _Included_mmdeploy_Model + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Model + * Method: create + * Signature: (Ljava/lang/String;)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Model_create(JNIEnv*, jobject, jstring); -/* - * Class: mmdeploy_Model - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Model_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Model + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Model_destroy(JNIEnv*, jobject, jlong); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.cpp index 4956555a6e..aac54574a0 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.cpp @@ -6,30 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_PoseDetector_create(JNIEnv *env, jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_pose_detector_t pose_estimator{}; - auto ec = mmdeploy_pose_detector_create_by_path(model_path, device_name, (int)device_id, - &pose_estimator); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create pose estimator, code = {}", ec); - return -1; - } - return (jlong)pose_estimator; +jlong Java_mmdeploy_PoseDetector_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_pose_detector_t pose_estimator{}; + auto ec = mmdeploy_pose_detector_create_by_path(model_path, device_name, (int)device_id, &pose_estimator); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create pose estimator, code = {}", ec); + return -1; + } + return (jlong)pose_estimator; } -void Java_mmdeploy_PoseDetector_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_PoseDetector_destroy"); - mmdeploy_pose_detector_destroy((mmdeploy_pose_detector_t)handle); +void Java_mmdeploy_PoseDetector_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_PoseDetector_destroy"); + mmdeploy_pose_detector_destroy((mmdeploy_pose_detector_t)handle); } -jobjectArray Java_mmdeploy_PoseDetector_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_PoseDetector_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_pose_detection_t *results{}; auto ec = 
mmdeploy_pose_detector_apply((mmdeploy_pose_detector_t)handle, imgs, size, &results); if (ec) { @@ -55,6 +57,5 @@ jobjectArray Java_mmdeploy_PoseDetector_apply(JNIEnv *env, jobject thiz, jlong h env->SetObjectArrayElement(array, i, res); } mmdeploy_pose_detector_release_result(results, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.h b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.h index a50b7fd821..87c70ac0a6 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseDetector.h @@ -3,34 +3,33 @@ /* Header for class mmdeploy_PoseDetector */ #ifndef _Included_mmdeploy_PoseDetector -#define _Included_mmdeploy_PoseDetector -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_PoseDetector - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_PoseDetector_create(JNIEnv *, jobject, jstring, jstring, - jint); + #define _Included_mmdeploy_PoseDetector + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_PoseDetector + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_PoseDetector_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_PoseDetector - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_PoseDetector_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_PoseDetector + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_PoseDetector_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_PoseDetector - * Method: apply - * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/PoseDetector/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_PoseDetector_apply(JNIEnv *, jobject, jlong, - jobjectArray); + /* + * Class: mmdeploy_PoseDetector + * Method: apply + * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/PoseDetector/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_PoseDetector_apply(JNIEnv*, jobject, jlong, jobjectArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.cpp index c0d1685729..61fd42eb07 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.cpp @@ -6,143 +6,161 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_PoseTracker_create(JNIEnv *env, jobject, jlong detModel, jlong poseModel, - jlong context) { - mmdeploy_pose_tracker_t pose_tracker{}; - auto ec = mmdeploy_pose_tracker_create((mmdeploy_model_t)detModel, (mmdeploy_model_t)poseModel, - (mmdeploy_context_t)context, &pose_tracker); - if (ec) { - MMDEPLOY_ERROR("failed to create pose tracker, code = {}", ec); - return -1; - } - return (jlong)pose_tracker; +jlong Java_mmdeploy_PoseTracker_create(JNIEnv* env, jobject, jlong detModel, jlong poseModel, jlong context) +{ + mmdeploy_pose_tracker_t pose_tracker{}; + auto ec = mmdeploy_pose_tracker_create((mmdeploy_model_t)detModel, (mmdeploy_model_t)poseModel, (mmdeploy_context_t)context, &pose_tracker); + if (ec) + { + MMDEPLOY_ERROR("failed to create pose tracker, code = {}", ec); + return -1; + } + return (jlong)pose_tracker; } -void Java_mmdeploy_PoseTracker_destroy(JNIEnv *, jobject, jlong handle) { - 
MMDEPLOY_DEBUG("Java_mmdeploy_PoseTracker_destroy"); - mmdeploy_pose_tracker_destroy((mmdeploy_pose_tracker_t)handle); +void Java_mmdeploy_PoseTracker_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_PoseTracker_destroy"); + mmdeploy_pose_tracker_destroy((mmdeploy_pose_tracker_t)handle); } -jobject param_cpp_to_java(JNIEnv *env, mmdeploy_pose_tracker_param_t *params) { - auto param_cls = env->FindClass("mmdeploy/PoseTracker$Params"); - auto param_ctor = env->GetMethodID(param_cls, "", "(IIFFFIFIFFF[FIFIIFF[F)V"); +jobject param_cpp_to_java(JNIEnv* env, mmdeploy_pose_tracker_param_t* params) +{ + auto param_cls = env->FindClass("mmdeploy/PoseTracker$Params"); + auto param_ctor = env->GetMethodID(param_cls, "", "(IIFFFIFIFFF[FIFIIFF[F)V"); - jfloatArray keypointSigmas = env->NewFloatArray(params->keypoint_sigmas_size); - env->SetFloatArrayRegion(keypointSigmas, 0, params->keypoint_sigmas_size, - (jfloat *)params->keypoint_sigmas); - jfloatArray smoothParams = env->NewFloatArray(3); - env->SetFloatArrayRegion(smoothParams, 0, 3, (jfloat *)params->smooth_params); + jfloatArray keypointSigmas = env->NewFloatArray(params->keypoint_sigmas_size); + env->SetFloatArrayRegion(keypointSigmas, 0, params->keypoint_sigmas_size, (jfloat*)params->keypoint_sigmas); + jfloatArray smoothParams = env->NewFloatArray(3); + env->SetFloatArrayRegion(smoothParams, 0, 3, (jfloat*)params->smooth_params); - auto param = env->NewObject( - param_cls, param_ctor, (jint)params->det_interval, (jint)params->det_label, - (jfloat)params->det_thr, (jfloat)params->det_min_bbox_size, (jfloat)params->det_nms_thr, - (jint)params->pose_max_num_bboxes, (jfloat)params->pose_kpt_thr, - (jint)params->pose_min_keypoints, (jfloat)params->pose_bbox_scale, - (jfloat)params->pose_min_bbox_size, (jfloat)params->pose_nms_thr, keypointSigmas, - (jint)params->keypoint_sigmas_size, (jfloat)params->track_iou_thr, - (jint)params->track_max_missing, (jint)params->track_history_size, - (jfloat)params->std_weight_position, (jfloat)params->std_weight_velocity, smoothParams); - return param; + auto param = env->NewObject( + param_cls, + param_ctor, + (jint)params->det_interval, + (jint)params->det_label, + (jfloat)params->det_thr, + (jfloat)params->det_min_bbox_size, + (jfloat)params->det_nms_thr, + (jint)params->pose_max_num_bboxes, + (jfloat)params->pose_kpt_thr, + (jint)params->pose_min_keypoints, + (jfloat)params->pose_bbox_scale, + (jfloat)params->pose_min_bbox_size, + (jfloat)params->pose_nms_thr, + keypointSigmas, + (jint)params->keypoint_sigmas_size, + (jfloat)params->track_iou_thr, + (jint)params->track_max_missing, + (jint)params->track_history_size, + (jfloat)params->std_weight_position, + (jfloat)params->std_weight_velocity, + smoothParams); + return param; } -void param_java_to_cpp(JNIEnv *env, mmdeploy_pose_tracker_param_t *params, jobject customParam) { - auto param_cls = env->FindClass("mmdeploy/PoseTracker$Params"); - auto param_ctor = env->GetMethodID(param_cls, "", "(IIFFFIFIFFF[FIFIIFF[F)V"); +void param_java_to_cpp(JNIEnv* env, mmdeploy_pose_tracker_param_t* params, jobject customParam) +{ + auto param_cls = env->FindClass("mmdeploy/PoseTracker$Params"); + auto param_ctor = env->GetMethodID(param_cls, "", "(IIFFFIFIFFF[FIFIIFF[F)V"); - jfieldID fieldID_detInterval = env->GetFieldID(param_cls, "detInterval", "I"); - jint detInterval = env->GetIntField(customParam, fieldID_detInterval); - params->det_interval = (int)detInterval; - jfieldID fieldID_detLabel = env->GetFieldID(param_cls, "detLabel", "I"); - 
jint detLabel = env->GetIntField(customParam, fieldID_detLabel); - params->det_label = (int)detLabel; - jfieldID fieldID_detThr = env->GetFieldID(param_cls, "detThr", "F"); - jfloat detThr = env->GetFloatField(customParam, fieldID_detThr); - params->det_thr = (float)detThr; - jfieldID fieldID_detMinBboxSize = env->GetFieldID(param_cls, "detMinBboxSize", "F"); - jfloat detMinBboxSize = env->GetFloatField(customParam, fieldID_detMinBboxSize); - params->det_min_bbox_size = (float)detMinBboxSize; - jfieldID fieldID_detNmsThr = env->GetFieldID(param_cls, "detNmsThr", "F"); - jfloat detNmsThr = env->GetFloatField(customParam, fieldID_detNmsThr); - params->det_nms_thr = (float)detNmsThr; - jfieldID fieldID_poseMaxNumBboxes = env->GetFieldID(param_cls, "poseMaxNumBboxes", "I"); - jint poseMaxNumBboxes = env->GetIntField(customParam, fieldID_poseMaxNumBboxes); - params->pose_max_num_bboxes = (int)poseMaxNumBboxes; - jfieldID fieldID_poseKptThr = env->GetFieldID(param_cls, "poseKptThr", "F"); - jfloat poseKptThr = env->GetFloatField(customParam, fieldID_poseKptThr); - params->pose_kpt_thr = (float)poseKptThr; - jfieldID fieldID_poseMinKeypoints = env->GetFieldID(param_cls, "poseMinKeypoints", "I"); - jint poseMinKeypoints = env->GetIntField(customParam, fieldID_poseMinKeypoints); - params->pose_min_keypoints = (int)poseMinKeypoints; - jfieldID fieldID_poseBboxScale = env->GetFieldID(param_cls, "poseBboxScale", "F"); - jfloat poseBboxScale = env->GetFloatField(customParam, fieldID_poseBboxScale); - params->pose_bbox_scale = (float)poseBboxScale; - jfieldID fieldID_poseMinBboxSize = env->GetFieldID(param_cls, "poseMinBboxSize", "F"); - jfloat poseMinBboxSize = env->GetFloatField(customParam, fieldID_poseMinBboxSize); - params->pose_min_bbox_size = (float)poseMinBboxSize; - jfieldID fieldID_poseNmsThr = env->GetFieldID(param_cls, "poseNmsThr", "F"); - jfloat poseNmsThr = env->GetFloatField(customParam, fieldID_poseNmsThr); - params->pose_nms_thr = (float)poseNmsThr; - jfieldID fieldID_keypointSigmas = env->GetFieldID(param_cls, "keypointSigmas", "[F"); - auto keypointSigmasObj = env->GetObjectField(customParam, fieldID_keypointSigmas); - float *keypointSigmas = - (float *)env->GetFloatArrayElements((jfloatArray)keypointSigmasObj, nullptr); - params->keypoint_sigmas = keypointSigmas; - env->ReleaseFloatArrayElements((jfloatArray)keypointSigmasObj, keypointSigmas, JNI_ABORT); - jfieldID fieldID_keypointSigmasSize = env->GetFieldID(param_cls, "keypointSigmasSize", "I"); - jint keypointSigmasSize = env->GetIntField(customParam, fieldID_keypointSigmasSize); - params->keypoint_sigmas_size = keypointSigmasSize; - jfieldID fieldID_trackIouThr = env->GetFieldID(param_cls, "trackIouThr", "F"); - jfloat trackIouThr = env->GetFloatField(customParam, fieldID_trackIouThr); - params->track_iou_thr = trackIouThr; - jfieldID fieldID_trackMaxMissing = env->GetFieldID(param_cls, "trackMaxMissing", "I"); - jint trackMaxMissing = env->GetIntField(customParam, fieldID_trackMaxMissing); - params->track_max_missing = trackMaxMissing; - jfieldID fieldID_trackHistorySize = env->GetFieldID(param_cls, "trackHistorySize", "I"); - jint trackHistorySize = env->GetIntField(customParam, fieldID_trackHistorySize); - params->track_history_size = trackHistorySize; - jfieldID fieldID_stdWeightPosition = env->GetFieldID(param_cls, "stdWeightPosition", "F"); - jfloat stdWeightPosition = env->GetFloatField(customParam, fieldID_stdWeightPosition); - params->std_weight_position = stdWeightPosition; - jfieldID fieldID_stdWeightVelocity = 
env->GetFieldID(param_cls, "stdWeightVelocity", "F"); - jfloat stdWeightVelocity = env->GetFloatField(customParam, fieldID_stdWeightVelocity); - params->std_weight_velocity = stdWeightVelocity; - jfieldID fieldID_smoothParams = env->GetFieldID(param_cls, "smoothParams", "[F"); - auto smoothParamsObj = env->GetObjectField(customParam, fieldID_smoothParams); - float *smoothParams = (float *)env->GetFloatArrayElements((jfloatArray)smoothParamsObj, nullptr); - params->smooth_params[0] = smoothParams[0]; - params->smooth_params[1] = smoothParams[1]; - params->smooth_params[2] = smoothParams[2]; - env->ReleaseFloatArrayElements((jfloatArray)smoothParamsObj, smoothParams, JNI_ABORT); + jfieldID fieldID_detInterval = env->GetFieldID(param_cls, "detInterval", "I"); + jint detInterval = env->GetIntField(customParam, fieldID_detInterval); + params->det_interval = (int)detInterval; + jfieldID fieldID_detLabel = env->GetFieldID(param_cls, "detLabel", "I"); + jint detLabel = env->GetIntField(customParam, fieldID_detLabel); + params->det_label = (int)detLabel; + jfieldID fieldID_detThr = env->GetFieldID(param_cls, "detThr", "F"); + jfloat detThr = env->GetFloatField(customParam, fieldID_detThr); + params->det_thr = (float)detThr; + jfieldID fieldID_detMinBboxSize = env->GetFieldID(param_cls, "detMinBboxSize", "F"); + jfloat detMinBboxSize = env->GetFloatField(customParam, fieldID_detMinBboxSize); + params->det_min_bbox_size = (float)detMinBboxSize; + jfieldID fieldID_detNmsThr = env->GetFieldID(param_cls, "detNmsThr", "F"); + jfloat detNmsThr = env->GetFloatField(customParam, fieldID_detNmsThr); + params->det_nms_thr = (float)detNmsThr; + jfieldID fieldID_poseMaxNumBboxes = env->GetFieldID(param_cls, "poseMaxNumBboxes", "I"); + jint poseMaxNumBboxes = env->GetIntField(customParam, fieldID_poseMaxNumBboxes); + params->pose_max_num_bboxes = (int)poseMaxNumBboxes; + jfieldID fieldID_poseKptThr = env->GetFieldID(param_cls, "poseKptThr", "F"); + jfloat poseKptThr = env->GetFloatField(customParam, fieldID_poseKptThr); + params->pose_kpt_thr = (float)poseKptThr; + jfieldID fieldID_poseMinKeypoints = env->GetFieldID(param_cls, "poseMinKeypoints", "I"); + jint poseMinKeypoints = env->GetIntField(customParam, fieldID_poseMinKeypoints); + params->pose_min_keypoints = (int)poseMinKeypoints; + jfieldID fieldID_poseBboxScale = env->GetFieldID(param_cls, "poseBboxScale", "F"); + jfloat poseBboxScale = env->GetFloatField(customParam, fieldID_poseBboxScale); + params->pose_bbox_scale = (float)poseBboxScale; + jfieldID fieldID_poseMinBboxSize = env->GetFieldID(param_cls, "poseMinBboxSize", "F"); + jfloat poseMinBboxSize = env->GetFloatField(customParam, fieldID_poseMinBboxSize); + params->pose_min_bbox_size = (float)poseMinBboxSize; + jfieldID fieldID_poseNmsThr = env->GetFieldID(param_cls, "poseNmsThr", "F"); + jfloat poseNmsThr = env->GetFloatField(customParam, fieldID_poseNmsThr); + params->pose_nms_thr = (float)poseNmsThr; + jfieldID fieldID_keypointSigmas = env->GetFieldID(param_cls, "keypointSigmas", "[F"); + auto keypointSigmasObj = env->GetObjectField(customParam, fieldID_keypointSigmas); + float* keypointSigmas = + (float*)env->GetFloatArrayElements((jfloatArray)keypointSigmasObj, nullptr); + params->keypoint_sigmas = keypointSigmas; + env->ReleaseFloatArrayElements((jfloatArray)keypointSigmasObj, keypointSigmas, JNI_ABORT); + jfieldID fieldID_keypointSigmasSize = env->GetFieldID(param_cls, "keypointSigmasSize", "I"); + jint keypointSigmasSize = env->GetIntField(customParam, fieldID_keypointSigmasSize); + 
params->keypoint_sigmas_size = keypointSigmasSize; + jfieldID fieldID_trackIouThr = env->GetFieldID(param_cls, "trackIouThr", "F"); + jfloat trackIouThr = env->GetFloatField(customParam, fieldID_trackIouThr); + params->track_iou_thr = trackIouThr; + jfieldID fieldID_trackMaxMissing = env->GetFieldID(param_cls, "trackMaxMissing", "I"); + jint trackMaxMissing = env->GetIntField(customParam, fieldID_trackMaxMissing); + params->track_max_missing = trackMaxMissing; + jfieldID fieldID_trackHistorySize = env->GetFieldID(param_cls, "trackHistorySize", "I"); + jint trackHistorySize = env->GetIntField(customParam, fieldID_trackHistorySize); + params->track_history_size = trackHistorySize; + jfieldID fieldID_stdWeightPosition = env->GetFieldID(param_cls, "stdWeightPosition", "F"); + jfloat stdWeightPosition = env->GetFloatField(customParam, fieldID_stdWeightPosition); + params->std_weight_position = stdWeightPosition; + jfieldID fieldID_stdWeightVelocity = env->GetFieldID(param_cls, "stdWeightVelocity", "F"); + jfloat stdWeightVelocity = env->GetFloatField(customParam, fieldID_stdWeightVelocity); + params->std_weight_velocity = stdWeightVelocity; + jfieldID fieldID_smoothParams = env->GetFieldID(param_cls, "smoothParams", "[F"); + auto smoothParamsObj = env->GetObjectField(customParam, fieldID_smoothParams); + float* smoothParams = (float*)env->GetFloatArrayElements((jfloatArray)smoothParamsObj, nullptr); + params->smooth_params[0] = smoothParams[0]; + params->smooth_params[1] = smoothParams[1]; + params->smooth_params[2] = smoothParams[2]; + env->ReleaseFloatArrayElements((jfloatArray)smoothParamsObj, smoothParams, JNI_ABORT); } -jobject Java_mmdeploy_PoseTracker_setDefaultParams(JNIEnv *env, jobject) { - mmdeploy_pose_tracker_param_t params{}; - mmdeploy_pose_tracker_default_params(¶ms); - return param_cpp_to_java(env, ¶ms); +jobject Java_mmdeploy_PoseTracker_setDefaultParams(JNIEnv* env, jobject) +{ + mmdeploy_pose_tracker_param_t params{}; + mmdeploy_pose_tracker_default_params(¶ms); + return param_cpp_to_java(env, ¶ms); } -jlong Java_mmdeploy_PoseTracker_createState(JNIEnv *env, jobject, jlong pipeline, - jobject paramsObject) { - mmdeploy_pose_tracker_state_t state{}; - mmdeploy_pose_tracker_param_t params{}; - param_java_to_cpp(env, ¶ms, paramsObject); - auto ec = mmdeploy_pose_tracker_create_state((mmdeploy_pose_tracker_t)pipeline, ¶ms, &state); - if (ec) { - MMDEPLOY_ERROR("failed to create pose tracker state, code = {}", ec); - return -1; - } - return (jlong)state; +jlong Java_mmdeploy_PoseTracker_createState(JNIEnv* env, jobject, jlong pipeline, jobject paramsObject) +{ + mmdeploy_pose_tracker_state_t state{}; + mmdeploy_pose_tracker_param_t params{}; + param_java_to_cpp(env, ¶ms, paramsObject); + auto ec = mmdeploy_pose_tracker_create_state((mmdeploy_pose_tracker_t)pipeline, ¶ms, &state); + if (ec) + { + MMDEPLOY_ERROR("failed to create pose tracker state, code = {}", ec); + return -1; + } + return (jlong)state; } -void Java_mmdeploy_PoseTracker_destroyState(JNIEnv *, jobject, jlong state) { - MMDEPLOY_DEBUG("Java_mmdeploy_PoseTracker_destroy"); - mmdeploy_pose_tracker_destroy_state((mmdeploy_pose_tracker_state_t)state); +void Java_mmdeploy_PoseTracker_destroyState(JNIEnv*, jobject, jlong state) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_PoseTracker_destroy"); + mmdeploy_pose_tracker_destroy_state((mmdeploy_pose_tracker_state_t)state); } -jobjectArray Java_mmdeploy_PoseTracker_apply(JNIEnv *env, jobject thiz, jlong handle, - jlongArray states, jobjectArray frames, - jintArray detects, jintArray 
counts) { - return With(env, frames, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_PoseTracker_apply(JNIEnv* env, jobject thiz, jlong handle, jlongArray states, jobjectArray frames, jintArray detects, jintArray counts) +{ + return With(env, frames, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_pose_tracker_target_t *results{}; int *result_count{}; auto states_array = env->GetLongArrayElements(states, nullptr); @@ -189,6 +207,5 @@ jobjectArray Java_mmdeploy_PoseTracker_apply(JNIEnv *env, jobject thiz, jlong ha env->ReleaseLongArrayElements(states, states_array, 0); env->ReleaseIntArrayElements(detects, detects_array, 0); mmdeploy_pose_tracker_release_result(results, result_count, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.h b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.h index 8e8d3905c8..1de79b1eaa 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_PoseTracker.h @@ -3,54 +3,54 @@ /* Header for class mmdeploy_PoseTracker */ #ifndef _Included_mmdeploy_PoseTracker -#define _Included_mmdeploy_PoseTracker -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_PoseTracker - * Method: create - * Signature: (JJJ)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_PoseTracker_create(JNIEnv *, jobject, jlong, jlong, jlong); + #define _Included_mmdeploy_PoseTracker + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_PoseTracker + * Method: create + * Signature: (JJJ)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_PoseTracker_create(JNIEnv*, jobject, jlong, jlong, jlong); -/* - * Class: mmdeploy_PoseTracker - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_PoseTracker_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_PoseTracker + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_PoseTracker_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_PoseTracker - * Method: createState - * Signature: (JLmmdeploy/PoseTracker/Params;)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_PoseTracker_createState(JNIEnv *, jobject, jlong, jobject); + /* + * Class: mmdeploy_PoseTracker + * Method: createState + * Signature: (JLmmdeploy/PoseTracker/Params;)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_PoseTracker_createState(JNIEnv*, jobject, jlong, jobject); -/* - * Class: mmdeploy_PoseTracker - * Method: destroyState - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_PoseTracker_destroyState(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_PoseTracker + * Method: destroyState + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_PoseTracker_destroyState(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_PoseTracker - * Method: setDefaultParams - * Signature: ()Lmmdeploy/PoseTracker/Params; - */ -JNIEXPORT jobject JNICALL Java_mmdeploy_PoseTracker_setDefaultParams(JNIEnv *, jobject); + /* + * Class: mmdeploy_PoseTracker + * Method: setDefaultParams + * Signature: ()Lmmdeploy/PoseTracker/Params; + */ + JNIEXPORT jobject JNICALL Java_mmdeploy_PoseTracker_setDefaultParams(JNIEnv*, jobject); -/* - * Class: mmdeploy_PoseTracker - * Method: apply - * Signature: (J[J[Lmmdeploy/Mat;[I[I)[Lmmdeploy/PoseTracker/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_PoseTracker_apply(JNIEnv *, jobject, jlong, jlongArray, - jobjectArray, jintArray, jintArray); + /* + * Class: mmdeploy_PoseTracker 
+ * Method: apply + * Signature: (J[J[Lmmdeploy/Mat;[I[I)[Lmmdeploy/PoseTracker/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_PoseTracker_apply(JNIEnv*, jobject, jlong, jlongArray, jobjectArray, jintArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.cpp index 2c63233c5c..2ff419ec7a 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.cpp @@ -6,19 +6,22 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Profiler_create(JNIEnv *env, jobject, jstring path) { - auto profiler_path = env->GetStringUTFChars(path, nullptr); - mmdeploy_profiler_t profiler{}; - auto ec = mmdeploy_profiler_create(profiler_path, &profiler); - env->ReleaseStringUTFChars(path, profiler_path); - if (ec) { - MMDEPLOY_ERROR("failed to create profiler, code = {}", ec); - return -1; - } - return (jlong)profiler; +jlong Java_mmdeploy_Profiler_create(JNIEnv* env, jobject, jstring path) +{ + auto profiler_path = env->GetStringUTFChars(path, nullptr); + mmdeploy_profiler_t profiler{}; + auto ec = mmdeploy_profiler_create(profiler_path, &profiler); + env->ReleaseStringUTFChars(path, profiler_path); + if (ec) + { + MMDEPLOY_ERROR("failed to create profiler, code = {}", ec); + return -1; + } + return (jlong)profiler; } -void Java_mmdeploy_Profiler_destroy(JNIEnv *, jobject, jlong profiler_) { - MMDEPLOY_DEBUG("Java_mmdeploy_Profiler_destroy"); - mmdeploy_profiler_destroy((mmdeploy_profiler_t)profiler_); +void Java_mmdeploy_Profiler_destroy(JNIEnv*, jobject, jlong profiler_) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Profiler_destroy"); + mmdeploy_profiler_destroy((mmdeploy_profiler_t)profiler_); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.h index 2bcdbc42cc..9e829ad38c 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Profiler.h @@ -3,25 +3,26 @@ /* Header for class mmdeploy_Profiler */ #ifndef _Included_mmdeploy_Profiler -#define _Included_mmdeploy_Profiler -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Profiler - * Method: create - * Signature: (Ljava/lang/String;)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Profiler_create(JNIEnv *, jobject, jstring); + #define _Included_mmdeploy_Profiler + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Profiler + * Method: create + * Signature: (Ljava/lang/String;)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Profiler_create(JNIEnv*, jobject, jstring); -/* - * Class: mmdeploy_Profiler - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Profiler_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Profiler + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Profiler_destroy(JNIEnv*, jobject, jlong); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.cpp index f124d5edae..abc630afa6 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.cpp @@ -6,29 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Restorer_create(JNIEnv *env, 
jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_restorer_t restorer{}; - auto ec = mmdeploy_restorer_create_by_path(model_path, device_name, (int)device_id, &restorer); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create restorer, code = {}", ec); - return -1; - } - return (jlong)restorer; +jlong Java_mmdeploy_Restorer_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_restorer_t restorer{}; + auto ec = mmdeploy_restorer_create_by_path(model_path, device_name, (int)device_id, &restorer); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create restorer, code = {}", ec); + return -1; + } + return (jlong)restorer; } -void Java_mmdeploy_Restorer_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_Restorer_destroy"); - mmdeploy_restorer_destroy((mmdeploy_restorer_t)handle); +void Java_mmdeploy_Restorer_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Restorer_destroy"); + mmdeploy_restorer_destroy((mmdeploy_restorer_t)handle); } -jobjectArray Java_mmdeploy_Restorer_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_Restorer_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_mat_t *results{}; auto ec = mmdeploy_restorer_apply((mmdeploy_restorer_t)handle, imgs, size, &results); if (ec) { @@ -68,6 +71,5 @@ jobjectArray Java_mmdeploy_Restorer_apply(JNIEnv *env, jobject thiz, jlong handl current_result++; } mmdeploy_restorer_release_result(results, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.h index 78b09787fe..7a4aec079b 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Restorer.h @@ -3,32 +3,33 @@ /* Header for class mmdeploy_Restorer */ #ifndef _Included_mmdeploy_Restorer -#define _Included_mmdeploy_Restorer -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Restorer - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Restorer_create(JNIEnv *, jobject, jstring, jstring, jint); + #define _Included_mmdeploy_Restorer + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Restorer + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Restorer_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_Restorer - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Restorer_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Restorer + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Restorer_destroy(JNIEnv*, jobject, jlong); -/* - * Class: 
mmdeploy_Restorer - * Method: apply - * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/Restorer/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Restorer_apply(JNIEnv *, jobject, jlong, jobjectArray); + /* + * Class: mmdeploy_Restorer + * Method: apply + * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/Restorer/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Restorer_apply(JNIEnv*, jobject, jlong, jobjectArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.cpp index 3872e7e158..9b34659aa5 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.cpp @@ -6,30 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_RotatedDetector_create(JNIEnv *env, jobject, jstring modelPath, - jstring deviceName, jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_rotated_detector_t rotated_detector{}; - auto ec = mmdeploy_rotated_detector_create_by_path(model_path, device_name, (int)device_id, - &rotated_detector); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create rotated detector, code = {}", ec); - return -1; - } - return (jlong)rotated_detector; +jlong Java_mmdeploy_RotatedDetector_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_rotated_detector_t rotated_detector{}; + auto ec = mmdeploy_rotated_detector_create_by_path(model_path, device_name, (int)device_id, &rotated_detector); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create rotated detector, code = {}", ec); + return -1; + } + return (jlong)rotated_detector; } -void Java_mmdeploy_RotatedDetector_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_RotatedDetector_destroy"); - mmdeploy_rotated_detector_destroy((mmdeploy_rotated_detector_t)handle); +void Java_mmdeploy_RotatedDetector_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_RotatedDetector_destroy"); + mmdeploy_rotated_detector_destroy((mmdeploy_rotated_detector_t)handle); } -jobjectArray Java_mmdeploy_RotatedDetector_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images, jintArray counts) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_RotatedDetector_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images, jintArray counts) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_rotated_detection_t *results{}; int *result_count{}; auto ec = mmdeploy_rotated_detector_apply((mmdeploy_rotated_detector_t)handle, imgs, size, @@ -56,6 +58,5 @@ jobjectArray Java_mmdeploy_RotatedDetector_apply(JNIEnv *env, jobject thiz, jlon } env->ReleaseIntArrayElements(counts, counts_array, 0); mmdeploy_rotated_detector_release_result(results, result_count); - return array; - }); + return array; }); } diff --git 
a/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.h b/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.h index 6de527ec40..7327b791ea 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_RotatedDetector.h @@ -3,34 +3,33 @@ /* Header for class mmdeploy_RotatedDetector */ #ifndef _Included_mmdeploy_RotatedDetector -#define _Included_mmdeploy_RotatedDetector -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_RotatedDetector - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_RotatedDetector_create(JNIEnv *, jobject, jstring, jstring, - jint); + #define _Included_mmdeploy_RotatedDetector + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_RotatedDetector + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_RotatedDetector_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_RotatedDetector - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_RotatedDetector_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_RotatedDetector + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_RotatedDetector_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_RotatedDetector - * Method: apply - * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/RotatedDetector/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_RotatedDetector_apply(JNIEnv *, jobject, jlong, - jobjectArray, jintArray); + /* + * Class: mmdeploy_RotatedDetector + * Method: apply + * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/RotatedDetector/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_RotatedDetector_apply(JNIEnv*, jobject, jlong, jobjectArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.cpp index 2c1f1c42c0..3ab391c44d 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.cpp @@ -7,17 +7,20 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Scheduler_createThreadPool(JNIEnv *env, jobject, jint numThreads) { - mmdeploy_scheduler_t scheduler = mmdeploy_executor_create_thread_pool((int)numThreads); - return (jlong)scheduler; +jlong Java_mmdeploy_Scheduler_createThreadPool(JNIEnv* env, jobject, jint numThreads) +{ + mmdeploy_scheduler_t scheduler = mmdeploy_executor_create_thread_pool((int)numThreads); + return (jlong)scheduler; } -jlong Java_mmdeploy_Scheduler_createThread(JNIEnv *env, jobject) { - mmdeploy_scheduler_t scheduler = mmdeploy_executor_create_thread(); - return (jlong)scheduler; +jlong Java_mmdeploy_Scheduler_createThread(JNIEnv* env, jobject) +{ + mmdeploy_scheduler_t scheduler = mmdeploy_executor_create_thread(); + return (jlong)scheduler; } -void Java_mmdeploy_Scheduler_destroy(JNIEnv *, jobject, jlong scheduler_) { - MMDEPLOY_DEBUG("Java_mmdeploy_Scheduler_destroy"); - mmdeploy_scheduler_destroy((mmdeploy_scheduler_t)scheduler_); +void Java_mmdeploy_Scheduler_destroy(JNIEnv*, jobject, jlong scheduler_) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Scheduler_destroy"); + mmdeploy_scheduler_destroy((mmdeploy_scheduler_t)scheduler_); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.h 
b/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.h index 363015cf95..8774db0fc7 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Scheduler.h @@ -3,32 +3,33 @@ /* Header for class mmdeploy_Scheduler */ #ifndef _Included_mmdeploy_Scheduler -#define _Included_mmdeploy_Scheduler -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Scheduler - * Method: createThreadPool - * Signature: (I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Scheduler_createThreadPool(JNIEnv *, jclass, jint); + #define _Included_mmdeploy_Scheduler + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Scheduler + * Method: createThreadPool + * Signature: (I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Scheduler_createThreadPool(JNIEnv*, jclass, jint); -/* - * Class: mmdeploy_Scheduler - * Method: createThread - * Signature: ()J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Scheduler_createThread(JNIEnv *, jclass); + /* + * Class: mmdeploy_Scheduler + * Method: createThread + * Signature: ()J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Scheduler_createThread(JNIEnv*, jclass); -/* - * Class: mmdeploy_Scheduler - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Scheduler_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Scheduler + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Scheduler_destroy(JNIEnv*, jobject, jlong); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.cpp index 12df31a49e..8942041c8c 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.cpp @@ -6,29 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_Segmentor_create(JNIEnv *env, jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_segmentor_t segmentor{}; - auto ec = mmdeploy_segmentor_create_by_path(model_path, device_name, (int)device_id, &segmentor); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create segmentor, code = {}", ec); - return -1; - } - return (jlong)segmentor; +jlong Java_mmdeploy_Segmentor_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_segmentor_t segmentor{}; + auto ec = mmdeploy_segmentor_create_by_path(model_path, device_name, (int)device_id, &segmentor); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create segmentor, code = {}", ec); + return -1; + } + return (jlong)segmentor; } -void Java_mmdeploy_Segmentor_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_Segmentor_destroy"); - mmdeploy_segmentor_destroy((mmdeploy_segmentor_t)handle); +void Java_mmdeploy_Segmentor_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_Segmentor_destroy"); + mmdeploy_segmentor_destroy((mmdeploy_segmentor_t)handle); } -jobjectArray 
Java_mmdeploy_Segmentor_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_Segmentor_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_segmentation_t *results{}; auto ec = mmdeploy_segmentor_apply((mmdeploy_segmentor_t)handle, imgs, size, &results); if (ec) { @@ -65,6 +68,5 @@ jobjectArray Java_mmdeploy_Segmentor_apply(JNIEnv *env, jobject thiz, jlong hand env->SetObjectArrayElement(array, i, res); } mmdeploy_segmentor_release_result(results, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.h b/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.h index afdf157bec..ec42c52dd5 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_Segmentor.h @@ -3,33 +3,33 @@ /* Header for class mmdeploy_Segmentor */ #ifndef _Included_mmdeploy_Segmentor -#define _Included_mmdeploy_Segmentor -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_Segmentor - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_Segmentor_create(JNIEnv *, jobject, jstring, jstring, jint); + #define _Included_mmdeploy_Segmentor + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_Segmentor + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_Segmentor_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_Segmentor - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_Segmentor_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_Segmentor + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_Segmentor_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_Segmentor - * Method: apply - * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/Segmentor/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Segmentor_apply(JNIEnv *, jobject, jlong, - jobjectArray); + /* + * Class: mmdeploy_Segmentor + * Method: apply + * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/Segmentor/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_Segmentor_apply(JNIEnv*, jobject, jlong, jobjectArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.cpp index 943d1e625b..adc1abe5cd 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.cpp @@ -6,30 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_TextDetector_create(JNIEnv *env, jobject, jstring modelPath, jstring deviceName, - jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_text_detector_t text_detector{}; - auto ec = mmdeploy_text_detector_create_by_path(model_path, device_name, (int)device_id, - &text_detector); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create text_detector, code = {}", ec); - return -1; - } 
- return (jlong)text_detector; +jlong Java_mmdeploy_TextDetector_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_text_detector_t text_detector{}; + auto ec = mmdeploy_text_detector_create_by_path(model_path, device_name, (int)device_id, &text_detector); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create text_detector, code = {}", ec); + return -1; + } + return (jlong)text_detector; } -void Java_mmdeploy_TextDetector_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_TextDetector_destroy"); - mmdeploy_text_detector_destroy((mmdeploy_text_detector_t)handle); +void Java_mmdeploy_TextDetector_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_TextDetector_destroy"); + mmdeploy_text_detector_destroy((mmdeploy_text_detector_t)handle); } -jobjectArray Java_mmdeploy_TextDetector_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images, jintArray counts) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_TextDetector_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images, jintArray counts) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_text_detection_t *results{}; int *result_count{}; auto ec = mmdeploy_text_detector_apply((mmdeploy_text_detector_t)handle, imgs, size, &results, @@ -61,6 +63,5 @@ jobjectArray Java_mmdeploy_TextDetector_apply(JNIEnv *env, jobject thiz, jlong h } env->ReleaseIntArrayElements(counts, counts_array, 0); mmdeploy_text_detector_release_result(results, result_count, size); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.h b/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.h index dc5574f77b..6a5df47924 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_TextDetector.h @@ -3,34 +3,33 @@ /* Header for class mmdeploy_TextDetector */ #ifndef _Included_mmdeploy_TextDetector -#define _Included_mmdeploy_TextDetector -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_TextDetector - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_TextDetector_create(JNIEnv *, jobject, jstring, jstring, - jint); + #define _Included_mmdeploy_TextDetector + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_TextDetector + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_TextDetector_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_TextDetector - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_TextDetector_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_TextDetector + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_TextDetector_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_TextDetector - * Method: apply - * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/TextDetector/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextDetector_apply(JNIEnv *, jobject, jlong, - jobjectArray, jintArray); + /* + * Class: 
mmdeploy_TextDetector + * Method: apply + * Signature: (J[Lmmdeploy/Mat;[I)[Lmmdeploy/TextDetector/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextDetector_apply(JNIEnv*, jobject, jlong, jobjectArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.cpp b/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.cpp index 06987fb623..607b7c2ee8 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.cpp +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.cpp @@ -6,30 +6,32 @@ #include "mmdeploy/apis/java/native/common.h" #include "mmdeploy/core/logger.h" -jlong Java_mmdeploy_TextRecognizer_create(JNIEnv *env, jobject, jstring modelPath, - jstring deviceName, jint device_id) { - auto model_path = env->GetStringUTFChars(modelPath, nullptr); - auto device_name = env->GetStringUTFChars(deviceName, nullptr); - mmdeploy_text_recognizer_t text_recognizer{}; - auto ec = mmdeploy_text_recognizer_create_by_path(model_path, device_name, (int)device_id, - &text_recognizer); - env->ReleaseStringUTFChars(modelPath, model_path); - env->ReleaseStringUTFChars(deviceName, device_name); - if (ec) { - MMDEPLOY_ERROR("failed to create text recognizer, code = {}", ec); - return -1; - } - return (jlong)text_recognizer; +jlong Java_mmdeploy_TextRecognizer_create(JNIEnv* env, jobject, jstring modelPath, jstring deviceName, jint device_id) +{ + auto model_path = env->GetStringUTFChars(modelPath, nullptr); + auto device_name = env->GetStringUTFChars(deviceName, nullptr); + mmdeploy_text_recognizer_t text_recognizer{}; + auto ec = mmdeploy_text_recognizer_create_by_path(model_path, device_name, (int)device_id, &text_recognizer); + env->ReleaseStringUTFChars(modelPath, model_path); + env->ReleaseStringUTFChars(deviceName, device_name); + if (ec) + { + MMDEPLOY_ERROR("failed to create text recognizer, code = {}", ec); + return -1; + } + return (jlong)text_recognizer; } -void Java_mmdeploy_TextRecognizer_destroy(JNIEnv *, jobject, jlong handle) { - MMDEPLOY_DEBUG("Java_mmdeploy_TextRecognizer_destroy"); // maybe use info? - mmdeploy_text_recognizer_destroy((mmdeploy_text_recognizer_t)handle); +void Java_mmdeploy_TextRecognizer_destroy(JNIEnv*, jobject, jlong handle) +{ + MMDEPLOY_DEBUG("Java_mmdeploy_TextRecognizer_destroy"); // maybe use info? 
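// Every create() wrapper above follows the same JNI contract: each
// GetStringUTFChars is paired with a ReleaseStringUTFChars on every path,
// including the error path, before the function returns. A condensed sketch
// of that pattern (create_by_path is a placeholder standing in for the
// various mmdeploy_*_create_by_path functions, not a real symbol):
#include <jni.h>

int create_by_path(const char* model, const char* device, int id, void** out);  // placeholder

jlong create_native_handle(JNIEnv* env, jstring modelPath, jstring deviceName, jint device_id)
{
    auto  model_path  = env->GetStringUTFChars(modelPath, nullptr);
    auto  device_name = env->GetStringUTFChars(deviceName, nullptr);
    void* handle{};
    auto  ec = create_by_path(model_path, device_name, (int)device_id, &handle);
    env->ReleaseStringUTFChars(modelPath, model_path);  // released before either return
    env->ReleaseStringUTFChars(deviceName, device_name);
    return ec ? -1 : (jlong)handle;
}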
+ mmdeploy_text_recognizer_destroy((mmdeploy_text_recognizer_t)handle); } -jobjectArray Java_mmdeploy_TextRecognizer_apply(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray { +jobjectArray Java_mmdeploy_TextRecognizer_apply(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) -> jobjectArray + { mmdeploy_text_recognition_t *results{}; auto ec = mmdeploy_text_recognizer_apply((mmdeploy_text_recognizer_t)handle, imgs, size, &results); @@ -51,13 +53,12 @@ jobjectArray Java_mmdeploy_TextRecognizer_apply(JNIEnv *env, jobject thiz, jlong env->SetObjectArrayElement(array, i, res); } mmdeploy_text_recognizer_release_result(results, size); - return array; - }); + return array; }); } -jobjectArray Java_mmdeploy_TextRecognizer_applyBbox(JNIEnv *env, jobject thiz, jlong handle, - jobjectArray images, jobjectArray bboxes, - jintArray bbox_count) { - return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) { +jobjectArray Java_mmdeploy_TextRecognizer_applyBbox(JNIEnv* env, jobject thiz, jlong handle, jobjectArray images, jobjectArray bboxes, jintArray bbox_count) +{ + return With(env, images, [&](const mmdeploy_mat_t imgs[], int size) + { mmdeploy_text_recognition_t *recog_results{}; auto *det_results = new mmdeploy_text_detection_t[env->GetArrayLength(bboxes)]; int *det_result_count = new int[env->GetArrayLength(bbox_count)]; @@ -100,6 +101,5 @@ jobjectArray Java_mmdeploy_TextRecognizer_applyBbox(JNIEnv *env, jobject thiz, j } mmdeploy_text_recognizer_release_result(recog_results, size); mmdeploy_text_detector_release_result(det_results, det_result_count, 1); - return array; - }); + return array; }); } diff --git a/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.h b/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.h index 721c17f2b6..13ed048b7e 100644 --- a/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.h +++ b/csrc/mmdeploy/apis/java/native/mmdeploy_TextRecognizer.h @@ -3,43 +3,40 @@ /* Header for class mmdeploy_TextRecognizer */ #ifndef _Included_mmdeploy_TextRecognizer -#define _Included_mmdeploy_TextRecognizer -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: mmdeploy_TextRecognizer - * Method: create - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_mmdeploy_TextRecognizer_create(JNIEnv *, jobject, jstring, jstring, - jint); + #define _Included_mmdeploy_TextRecognizer + #ifdef __cplusplus +extern "C" +{ + #endif + /* + * Class: mmdeploy_TextRecognizer + * Method: create + * Signature: (Ljava/lang/String;Ljava/lang/String;I)J + */ + JNIEXPORT jlong JNICALL Java_mmdeploy_TextRecognizer_create(JNIEnv*, jobject, jstring, jstring, jint); -/* - * Class: mmdeploy_TextRecognizer - * Method: destroy - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_mmdeploy_TextRecognizer_destroy(JNIEnv *, jobject, jlong); + /* + * Class: mmdeploy_TextRecognizer + * Method: destroy + * Signature: (J)V + */ + JNIEXPORT void JNICALL Java_mmdeploy_TextRecognizer_destroy(JNIEnv*, jobject, jlong); -/* - * Class: mmdeploy_TextRecognizer - * Method: apply - * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/TextRecognizer/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextRecognizer_apply(JNIEnv *, jobject, jlong, - jobjectArray); + /* + * Class: mmdeploy_TextRecognizer + * Method: apply + * Signature: (J[Lmmdeploy/Mat;)[Lmmdeploy/TextRecognizer/Result; + */ + 
JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextRecognizer_apply(JNIEnv*, jobject, jlong, jobjectArray); -/* - * Class: mmdeploy_TextRecognizer - * Method: applyBbox - * Signature: (J[Lmmdeploy/Mat;[Lmmdeploy/TextDetector/Result;[I)[Lmmdeploy/TextRecognizer/Result; - */ -JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextRecognizer_applyBbox(JNIEnv *, jobject, jlong, - jobjectArray, jobjectArray, - jintArray); + /* + * Class: mmdeploy_TextRecognizer + * Method: applyBbox + * Signature: (J[Lmmdeploy/Mat;[Lmmdeploy/TextDetector/Result;[I)[Lmmdeploy/TextRecognizer/Result; + */ + JNIEXPORT jobjectArray JNICALL Java_mmdeploy_TextRecognizer_applyBbox(JNIEnv*, jobject, jlong, jobjectArray, jobjectArray, jintArray); -#ifdef __cplusplus + #ifdef __cplusplus } -#endif + #endif #endif diff --git a/csrc/mmdeploy/apis/python/classifier.cpp b/csrc/mmdeploy/apis/python/classifier.cpp index 9916909c86..983b3357b5 100644 --- a/csrc/mmdeploy/apis/python/classifier.cpp +++ b/csrc/mmdeploy/apis/python/classifier.cpp @@ -4,64 +4,76 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyClassifier { - public: - PyClassifier(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_classifier_create_by_path(model_path, device_name, device_id, &classifier_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create classifier"); - } - } - ~PyClassifier() { - mmdeploy_classifier_destroy(classifier_); - classifier_ = {}; - } + class PyClassifier + { + public: + PyClassifier(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_classifier_create_by_path(model_path, device_name, device_id, &classifier_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create classifier"); + } + } + ~PyClassifier() + { + mmdeploy_classifier_destroy(classifier_); + classifier_ = {}; + } - std::vector>> Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_classification_t* results{}; - int* result_count{}; - auto status = mmdeploy_classifier_apply(classifier_, mats.data(), (int)mats.size(), &results, - &result_count); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply classifier, code: " + std::to_string(status)); - } - auto output = std::vector>>{}; - output.reserve(mats.size()); - auto result_ptr = results; - for (int i = 0; i < mats.size(); ++i) { - std::vector> label_score; - for (int j = 0; j < result_count[i]; ++j) { - label_score.emplace_back(result_ptr[j].label_id, result_ptr[j].score); - } - output.push_back(std::move(label_score)); - result_ptr += result_count[i]; - } - mmdeploy_classifier_release_result(results, result_count, (int)mats.size()); - return output; - } + std::vector>> Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_classification_t* results{}; + int* result_count{}; + auto status = mmdeploy_classifier_apply(classifier_, mats.data(), (int)mats.size(), &results, &result_count); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply classifier, code: " + std::to_string(status)); + } + auto output = std::vector>>{}; + output.reserve(mats.size()); + auto result_ptr = results; + for (int i = 0; i < mats.size(); ++i) + { + std::vector> label_score; 
+ for (int j = 0; j < result_count[i]; ++j) + { + label_score.emplace_back(result_ptr[j].label_id, result_ptr[j].score); + } + output.push_back(std::move(label_score)); + result_ptr += result_count[i]; + } + mmdeploy_classifier_release_result(results, result_count, (int)mats.size()); + return output; + } - private: - mmdeploy_classifier_t classifier_{}; -}; + private: + mmdeploy_classifier_t classifier_{}; + }; -static PythonBindingRegisterer register_classifier{[](py::module& m) { - py::class_(m, "Classifier") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyClassifier* self, const PyImage& img) { return self->Apply(std::vector{img})[0]; }) - .def("batch", &PyClassifier::Apply); -}}; + static PythonBindingRegisterer register_classifier{[](py::module& m) + { + py::class_(m, "Classifier") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyClassifier* self, const PyImage& img) + { return self->Apply(std::vector{img})[0]; }) + .def("batch", &PyClassifier::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/common.cpp b/csrc/mmdeploy/apis/python/common.cpp index de4e1adf0a..72ed22089a 100644 --- a/csrc/mmdeploy/apis/python/common.cpp +++ b/csrc/mmdeploy/apis/python/common.cpp @@ -7,166 +7,214 @@ #include "mmdeploy/core/utils/formatter.h" #include "pybind11/numpy.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -std::vector& gPythonBindings() { - static std::vector v; - return v; -} - -mmdeploy_mat_t GetMat(const PyImage& img) { - auto info = img.request(); - if (info.ndim != 3) { - fprintf(stderr, "info.ndim = %d\n", (int)info.ndim); - throw std::runtime_error("continuous uint8 HWC array expected"); - } - auto channels = (int)info.shape[2]; - mmdeploy_mat_t mat{}; - if (channels == 1) { - mat.format = MMDEPLOY_PIXEL_FORMAT_GRAYSCALE; - } else if (channels == 3) { - mat.format = MMDEPLOY_PIXEL_FORMAT_BGR; - } else { - throw std::runtime_error("images of 1 or 3 channels are supported"); - } - mat.height = (int)info.shape[0]; - mat.width = (int)info.shape[1]; - mat.channel = channels; - mat.type = MMDEPLOY_DATA_TYPE_UINT8; - mat.data = (uint8_t*)info.ptr; - return mat; -} + std::vector& gPythonBindings() + { + static std::vector v; + return v; + } -py::object ToPyObject(const Value& value) { - switch (value.type()) { - case ValueType::kNull: - return py::none(); - case ValueType::kBool: - return py::bool_(value.get()); - case ValueType::kInt: - return py::int_(value.get()); - case ValueType::kUInt: - return py::int_(value.get()); - case ValueType::kFloat: - return py::float_(value.get()); - case ValueType::kString: - return py::str(value.get()); - case ValueType::kArray: { - py::list list; - for (const auto& x : value) { - list.append(ToPyObject(x)); - } - return list; + mmdeploy_mat_t GetMat(const PyImage& img) + { + auto info = img.request(); + if (info.ndim != 3) + { + fprintf(stderr, "info.ndim = %d\n", (int)info.ndim); + throw std::runtime_error("continuous uint8 HWC array expected"); + } + auto channels = (int)info.shape[2]; + mmdeploy_mat_t mat{}; + if (channels == 1) + { + mat.format = MMDEPLOY_PIXEL_FORMAT_GRAYSCALE; + } + else if 
(channels == 3) + { + mat.format = MMDEPLOY_PIXEL_FORMAT_BGR; + } + else + { + throw std::runtime_error("images of 1 or 3 channels are supported"); + } + mat.height = (int)info.shape[0]; + mat.width = (int)info.shape[1]; + mat.channel = channels; + mat.type = MMDEPLOY_DATA_TYPE_UINT8; + mat.data = (uint8_t*)info.ptr; + return mat; } - case ValueType::kObject: { - py::dict dict; - for (auto it = value.begin(); it != value.end(); ++it) { - dict[it.key().c_str()] = ToPyObject(*it); - } - return dict; + + py::object ToPyObject(const Value& value) + { + switch (value.type()) + { + case ValueType::kNull: + return py::none(); + case ValueType::kBool: + return py::bool_(value.get()); + case ValueType::kInt: + return py::int_(value.get()); + case ValueType::kUInt: + return py::int_(value.get()); + case ValueType::kFloat: + return py::float_(value.get()); + case ValueType::kString: + return py::str(value.get()); + case ValueType::kArray: + { + py::list list; + for (const auto& x : value) + { + list.append(ToPyObject(x)); + } + return list; + } + case ValueType::kObject: + { + py::dict dict; + for (auto it = value.begin(); it != value.end(); ++it) + { + dict[it.key().c_str()] = ToPyObject(*it); + } + return dict; + } + case ValueType::kAny: + return py::str(""); + default: + return py::str(""); + } } - case ValueType::kAny: - return py::str(""); - default: - return py::str(""); - } -} -std::optional _to_value_internal(const void* object, mmdeploy_context_type_t type); + std::optional _to_value_internal(const void* object, mmdeploy_context_type_t type); -Value FromPyObject(const py::object& obj) { - if (py::isinstance(obj)) { - return nullptr; - } else if (py::isinstance(obj)) { - return obj.cast(); - } else if (py::isinstance(obj)) { - return obj.cast(); - } else if (py::isinstance(obj)) { - return obj.cast(); - } else if (py::isinstance(obj)) { - return obj.cast(); - } else if (py::isinstance(obj) || py::isinstance(obj)) { - py::list src(obj); - Value::Array dst; - dst.reserve(src.size()); - for (const auto& item : src) { - dst.push_back(FromPyObject(py::reinterpret_borrow(item))); + Value FromPyObject(const py::object& obj) + { + if (py::isinstance(obj)) + { + return nullptr; + } + else if (py::isinstance(obj)) + { + return obj.cast(); + } + else if (py::isinstance(obj)) + { + return obj.cast(); + } + else if (py::isinstance(obj)) + { + return obj.cast(); + } + else if (py::isinstance(obj)) + { + return obj.cast(); + } + else if (py::isinstance(obj) || py::isinstance(obj)) + { + py::list src(obj); + Value::Array dst; + dst.reserve(src.size()); + for (const auto& item : src) + { + dst.push_back(FromPyObject(py::reinterpret_borrow(item))); + } + return dst; + } + else if (py::isinstance(obj)) + { + py::dict src(obj); + Value::Object dst; + for (const auto& item : src) + { + dst.emplace(item.first.cast(), + FromPyObject(py::reinterpret_borrow(item.second))); + } + return dst; + } + else if (py::isinstance(obj)) + { + const auto& array = obj.cast(); + return *_to_value_internal(&array, MMDEPLOY_TYPE_MAT); + } + else if (py::isinstance(obj)) + { + const auto& model = + *reinterpret_cast(static_cast(obj.cast())); + return model; + } + else + { + std::stringstream ss; + ss << obj.get_type(); + MMDEPLOY_ERROR("unsupported Python object type: {}", ss.str()); + return nullptr; + } + return nullptr; } - return dst; - } else if (py::isinstance(obj)) { - py::dict src(obj); - Value::Object dst; - for (const auto& item : src) { - dst.emplace(item.first.cast(), - FromPyObject(py::reinterpret_borrow(item.second))); 
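// ToPyObject and FromPyObject form a recursive bridge between mmdeploy::Value
// and native Python objects: None/bool/int/float/str map to the scalar
// ValueType cases, list/tuple to Value::Array, dict to Value::Object, and
// numpy arrays are routed through _to_value_internal as MMDEPLOY_TYPE_MAT.
// A small round-trip helper that exercises the mapping (illustrative only;
// assumes the caller holds the GIL, as every pybind11 entry point here does):
py::object round_trip(const py::dict& config)
{
    Value v = FromPyObject(config);  // dict -> Value::Object, recursively
    return ToPyObject(v);            // Value::Object -> py::dict
}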
- } - return dst; - } else if (py::isinstance(obj)) { - const auto& array = obj.cast(); - return *_to_value_internal(&array, MMDEPLOY_TYPE_MAT); - } else if (py::isinstance(obj)) { - const auto& model = - *reinterpret_cast(static_cast(obj.cast())); - return model; - } else { - std::stringstream ss; - ss << obj.get_type(); - MMDEPLOY_ERROR("unsupported Python object type: {}", ss.str()); - return nullptr; - } - return nullptr; -} -std::pair parse_device(const std::string& device) { - auto pos = device.find(':'); - if (pos == std::string::npos) { - return {device, 0}; // logic for index -1 is not ready on some devices - } - auto name = device.substr(0, pos); - auto index = std::stoi(device.substr(pos + 1)); - return {name, index}; -} + std::pair parse_device(const std::string& device) + { + auto pos = device.find(':'); + if (pos == std::string::npos) + { + return {device, 0}; // logic for index -1 is not ready on some devices + } + auto name = device.substr(0, pos); + auto index = std::stoi(device.substr(pos + 1)); + return {name, index}; + } -static PythonBindingRegisterer register_model{[](py::module& m) { - py::class_(m, "Model") - .def(py::init([](const py::str& path) { + static PythonBindingRegisterer register_model{[](py::module& m) + { + py::class_(m, "Model") + .def(py::init([](const py::str& path) + { MMDEPLOY_DEBUG("py::init([](const py::str& path)"); - return Model(path.cast().c_str()); - })) - .def(py::init([](const py::bytes& buffer) { + return Model(path.cast().c_str()); })) + .def(py::init([](const py::bytes& buffer) + { MMDEPLOY_DEBUG("py::init([](const py::bytes& buffer)"); py::buffer_info info(py::buffer(buffer).request()); - return Model(info.ptr, info.size); - })); -}}; + return Model(info.ptr, info.size); })); + }}; -static PythonBindingRegisterer register_device{[](py::module& m) { - py::class_(m, "Device") - .def(py::init([](const std::string& device) { + static PythonBindingRegisterer register_device{[](py::module& m) + { + py::class_(m, "Device") + .def(py::init([](const std::string& device) + { auto [name, index] = parse_device(device); - return Device(name, index); - })) - .def(py::init([](const std::string& name, int index) { return Device(name, index); })); -}}; + return Device(name, index); })) + .def(py::init([](const std::string& name, int index) + { return Device(name, index); })); + }}; -static PythonBindingRegisterer register_context{[](py::module& m) { - py::class_(m, "Context") - .def(py::init([](const Device& device) { return Context(device); })) - .def("add", [](Context* self, const std::string& name, const Scheduler& sched) { - self->Add(name, sched); - }); -}}; + static PythonBindingRegisterer register_context{[](py::module& m) + { + py::class_(m, "Context") + .def(py::init([](const Device& device) + { return Context(device); })) + .def("add", [](Context* self, const std::string& name, const Scheduler& sched) + { self->Add(name, sched); }); + }}; -static PythonBindingRegisterer register_scheduler{[](py::module& m) { - py::class_(m, "Scheduler") - .def_static("thread_pool", [](int n_workers) { return Scheduler::ThreadPool(n_workers); }) - .def_static("thread", [] { return Scheduler::Thread(); }); -}}; + static PythonBindingRegisterer register_scheduler{[](py::module& m) + { + py::class_(m, "Scheduler") + .def_static("thread_pool", [](int n_workers) + { return Scheduler::ThreadPool(n_workers); }) + .def_static("thread", [] + { return Scheduler::Thread(); }); + }}; } // namespace mmdeploy::python -PYBIND11_MODULE(mmdeploy_runtime, m) { - for (const auto& 
f : mmdeploy::python::gPythonBindings()) { - f(m); - } +PYBIND11_MODULE(mmdeploy_runtime, m) +{ + for (const auto& f : mmdeploy::python::gPythonBindings()) + { + f(m); + } } diff --git a/csrc/mmdeploy/apis/python/common.h b/csrc/mmdeploy/apis/python/common.h index 5b1ca96b74..e50ed76007 100644 --- a/csrc/mmdeploy/apis/python/common.h +++ b/csrc/mmdeploy/apis/python/common.h @@ -13,24 +13,27 @@ namespace py = pybind11; -namespace mmdeploy::python { +namespace mmdeploy::python +{ -using PyImage = py::array_t; + using PyImage = py::array_t; -std::vector& gPythonBindings(); + std::vector& gPythonBindings(); -mmdeploy_mat_t GetMat(const PyImage& img); + mmdeploy_mat_t GetMat(const PyImage& img); -py::object ToPyObject(const Value& value); + py::object ToPyObject(const Value& value); -Value FromPyObject(const py::object& obj); + Value FromPyObject(const py::object& obj); -class PythonBindingRegisterer { - public: - explicit PythonBindingRegisterer(void (*register_fn)(py::module& m)) { - gPythonBindings().push_back(register_fn); - } -}; + class PythonBindingRegisterer + { + public: + explicit PythonBindingRegisterer(void (*register_fn)(py::module& m)) + { + gPythonBindings().push_back(register_fn); + } + }; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/detector.cpp b/csrc/mmdeploy/apis/python/detector.cpp index 057a92ab00..137998f6b7 100644 --- a/csrc/mmdeploy/apis/python/detector.cpp +++ b/csrc/mmdeploy/apis/python/detector.cpp @@ -4,82 +4,97 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyDetector { - public: - PyDetector(const char* model_path, const char* device_name, int device_id) { - auto status = mmdeploy_detector_create_by_path(model_path, device_name, device_id, &detector_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create detector"); - } - } - py::list Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_detection_t* detection{}; - int* result_count{}; - auto status = mmdeploy_detector_apply(detector_, mats.data(), (int)mats.size(), &detection, - &result_count); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply detector, code: " + std::to_string(status)); - } - using Sptr = std::shared_ptr; - Sptr holder(detection, [result_count, n = mats.size()](auto p) { - mmdeploy_detector_release_result(p, result_count, n); - }); - auto output = py::list{}; - auto result = detection; - for (int i = 0; i < mats.size(); ++i) { - auto bboxes = py::array_t({result_count[i], 5}); - auto labels = py::array_t(result_count[i]); - auto masks = std::vector(); - masks.reserve(result_count[i]); - for (int j = 0; j < result_count[i]; ++j, ++result) { - auto bbox = bboxes.mutable_data(j); - bbox[0] = result->bbox.left; - bbox[1] = result->bbox.top; - bbox[2] = result->bbox.right; - bbox[3] = result->bbox.bottom; - bbox[4] = result->score; - labels.mutable_at(j) = result->label_id; - if (result->mask) { - masks.emplace_back(std::array{result->mask->height, result->mask->width}, // shape - reinterpret_cast(result->mask->data), // data - py::capsule(new Sptr(holder), // handle - [](void* p) { delete reinterpret_cast(p); })); - } else { - masks.emplace_back(); + class PyDetector + { + public: + PyDetector(const char* model_path, const char* device_name, int device_id) + { + auto status = mmdeploy_detector_create_by_path(model_path, device_name, device_id, 
&detector_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create detector"); + } + } + py::list Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_detection_t* detection{}; + int* result_count{}; + auto status = mmdeploy_detector_apply(detector_, mats.data(), (int)mats.size(), &detection, &result_count); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply detector, code: " + std::to_string(status)); + } + using Sptr = std::shared_ptr; + Sptr holder(detection, [result_count, n = mats.size()](auto p) + { mmdeploy_detector_release_result(p, result_count, n); }); + auto output = py::list{}; + auto result = detection; + for (int i = 0; i < mats.size(); ++i) + { + auto bboxes = py::array_t({result_count[i], 5}); + auto labels = py::array_t(result_count[i]); + auto masks = std::vector(); + masks.reserve(result_count[i]); + for (int j = 0; j < result_count[i]; ++j, ++result) + { + auto bbox = bboxes.mutable_data(j); + bbox[0] = result->bbox.left; + bbox[1] = result->bbox.top; + bbox[2] = result->bbox.right; + bbox[3] = result->bbox.bottom; + bbox[4] = result->score; + labels.mutable_at(j) = result->label_id; + if (result->mask) + { + masks.emplace_back(std::array{result->mask->height, result->mask->width}, // shape + reinterpret_cast(result->mask->data), // data + py::capsule(new Sptr(holder), // handle + [](void* p) + { delete reinterpret_cast(p); })); + } + else + { + masks.emplace_back(); + } + } + output.append(py::make_tuple(std::move(bboxes), std::move(labels), std::move(masks))); + } + return output; + } + ~PyDetector() + { + mmdeploy_detector_destroy(detector_); + detector_ = {}; } - } - output.append(py::make_tuple(std::move(bboxes), std::move(labels), std::move(masks))); - } - return output; - } - ~PyDetector() { - mmdeploy_detector_destroy(detector_); - detector_ = {}; - } - private: - mmdeploy_detector_t detector_{}; -}; + private: + mmdeploy_detector_t detector_{}; + }; -static PythonBindingRegisterer register_detector{[](py::module& m) { - py::class_(m, "Detector") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyDetector* self, const PyImage& img) -> py::tuple { - return self->Apply(std::vector{img})[0]; - }) - .def("batch", &PyDetector::Apply); -}}; + static PythonBindingRegisterer register_detector{[](py::module& m) + { + py::class_(m, "Detector") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyDetector* self, const PyImage& img) -> py::tuple + { + return self->Apply(std::vector{img})[0]; + }) + .def("batch", &PyDetector::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/executor.cpp b/csrc/mmdeploy/apis/python/executor.cpp index eaa5c1144b..489985f232 100644 --- a/csrc/mmdeploy/apis/python/executor.cpp +++ b/csrc/mmdeploy/apis/python/executor.cpp @@ -8,39 +8,48 @@ #include "mmdeploy/execution/schedulers/single_thread_context.h" #include "mmdeploy/execution/schedulers/static_thread_pool.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ 
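// detector.cpp above hands numpy the mask memory without copying: the C
// results live in a shared_ptr whose deleter calls
// mmdeploy_detector_release_result, and every py::array carries a py::capsule
// owning one extra reference so the buffer outlives the wrapper object. A
// stripped-down sketch of that idiom (the function name and parameters are
// illustrative; assumes the pybind11 headers pulled in via common.h):
#include <cstdint>
#include <memory>

py::array_t<uint8_t> wrap_without_copy(uint8_t* data, int height, int width, std::shared_ptr<void> holder)
{
    using Sptr = std::shared_ptr<void>;
    return py::array_t<uint8_t>({height, width},  // shape
                                data,             // borrowed buffer, not copied
                                py::capsule(new Sptr(std::move(holder)),
                                            [](void* p)
                                            { delete reinterpret_cast<Sptr*>(p); }));
}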
-struct PySender { - TypeErasedSender sender_; - - explicit PySender(TypeErasedSender sender) : sender_(std::move(sender)) {} - - struct gil_guarded_deleter { - void operator()(py::object* p) const { - py::gil_scoped_acquire _; - delete p; - } - }; - using object_ptr = std::unique_ptr; - - py::object __await__() { - auto future = py::module::import("concurrent.futures").attr("Future")(); + struct PySender { - py::gil_scoped_release _; - StartDetached(std::move(sender_) | - Then([future = object_ptr{new py::object(future)}](const Value& value) mutable { + TypeErasedSender sender_; + + explicit PySender(TypeErasedSender sender) + : sender_(std::move(sender)) + { + } + + struct gil_guarded_deleter + { + void operator()(py::object* p) const + { + py::gil_scoped_acquire _; + delete p; + } + }; + using object_ptr = std::unique_ptr; + + py::object __await__() + { + auto future = py::module::import("concurrent.futures").attr("Future")(); + { + py::gil_scoped_release _; + StartDetached(std::move(sender_) | + Then([future = object_ptr{new py::object(future)}](const Value& value) mutable + { py::gil_scoped_acquire _; future->attr("set_result")(ToPyObject(value)); - delete future.release(); - })); - } - return py::module::import("asyncio").attr("wrap_future")(future).attr("__await__")(); - } -}; - -static PythonBindingRegisterer register_sender{[](py::module& m) { - py::class_>(m, "PySender") - .def("__await__", &PySender::__await__); -}}; + delete future.release(); })); + } + return py::module::import("asyncio").attr("wrap_future")(future).attr("__await__")(); + } + }; + + static PythonBindingRegisterer register_sender{[](py::module& m) + { + py::class_>(m, "PySender") + .def("__await__", &PySender::__await__); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/internal.cpp b/csrc/mmdeploy/apis/python/internal.cpp index 7373c1f184..8c38f5a7ce 100644 --- a/csrc/mmdeploy/apis/python/internal.cpp +++ b/csrc/mmdeploy/apis/python/internal.cpp @@ -9,49 +9,60 @@ #include "mmdeploy/core/model.h" #include "mmdeploy/core/value.h" -namespace mmdeploy { - -namespace python { - -framework::Mat _get_mat(const PyImage& img) { - auto info = img.request(); - if (info.ndim != 3) { - fprintf(stderr, "info.ndim = %d\n", (int)info.ndim); - throw std::runtime_error("continuous uint8 HWC array expected"); - } - auto channels = (int)info.shape[2]; - PixelFormat format; - if (channels == 1) { - format = PixelFormat::kGRAYSCALE; - } else if (channels == 3) { - format = PixelFormat::kBGR; - } else { - throw std::runtime_error("images of 1 or 3 channels are supported"); - } - - return { - (int)info.shape[0], // height - (int)info.shape[1], // width - format, // format - DataType::kINT8, // type - std::shared_ptr(info.ptr, [](void*) {}), // data - framework::Device(0), // device - }; -} - -std::optional _to_value_internal(const void* object, mmdeploy_context_type_t type) { - switch (type) { - case MMDEPLOY_TYPE_MODEL: - return Value(*(const framework::Model*)object); - case MMDEPLOY_TYPE_DEVICE: - return Value(*(const framework::Device*)object); - case MMDEPLOY_TYPE_MAT: - return _get_mat(*(const py::array*)object); - default: - return std::nullopt; - } -} - -} // namespace python +namespace mmdeploy +{ + + namespace python + { + + framework::Mat _get_mat(const PyImage& img) + { + auto info = img.request(); + if (info.ndim != 3) + { + fprintf(stderr, "info.ndim = %d\n", (int)info.ndim); + throw std::runtime_error("continuous uint8 HWC array expected"); + } + auto channels = (int)info.shape[2]; + 
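// Both GetMat (common.cpp) and _get_mat here enforce the same input
// contract: a 3-dimensional, contiguous uint8 HWC array with 1 or 3
// channels; anything else throws. A caller-side guard that mirrors those
// checks (the helper name is illustrative, not part of the patch):
inline bool is_supported_image(const py::buffer_info& info)
{
    return info.ndim == 3 && (info.shape[2] == 1 || info.shape[2] == 3);
}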
PixelFormat format; + if (channels == 1) + { + format = PixelFormat::kGRAYSCALE; + } + else if (channels == 3) + { + format = PixelFormat::kBGR; + } + else + { + throw std::runtime_error("images of 1 or 3 channels are supported"); + } + + return { + (int)info.shape[0], // height + (int)info.shape[1], // width + format, // format + DataType::kINT8, // type + std::shared_ptr(info.ptr, [](void*) {}), // data + framework::Device(0), // device + }; + } + + std::optional _to_value_internal(const void* object, mmdeploy_context_type_t type) + { + switch (type) + { + case MMDEPLOY_TYPE_MODEL: + return Value(*(const framework::Model*)object); + case MMDEPLOY_TYPE_DEVICE: + return Value(*(const framework::Device*)object); + case MMDEPLOY_TYPE_MAT: + return _get_mat(*(const py::array*)object); + default: + return std::nullopt; + } + } + + } // namespace python } // namespace mmdeploy diff --git a/csrc/mmdeploy/apis/python/pipeline.cpp b/csrc/mmdeploy/apis/python/pipeline.cpp index e3e6237e44..114bce2095 100644 --- a/csrc/mmdeploy/apis/python/pipeline.cpp +++ b/csrc/mmdeploy/apis/python/pipeline.cpp @@ -7,41 +7,47 @@ #include "mmdeploy/core/logger.h" #include "mmdeploy/core/utils/formatter.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -using namespace std::literals; + using namespace std::literals; -static PythonBindingRegisterer register_pipeline{[](py::module& m) { - py::class_(m, "Pipeline") - .def(py::init([](const py::object& config, const Context& context) { + static PythonBindingRegisterer register_pipeline{[](py::module& m) + { + py::class_(m, "Pipeline") + .def(py::init([](const py::object& config, const Context& context) + { auto _config = FromPyObject(config); - return std::make_unique(_config, context); - })) - .def("__call__", - [](Pipeline* pipeline, const py::args& args) { - auto inputs = FromPyObject(args); - for (auto& input : inputs) { - input = Value::Array{std::move(input)}; - } - auto outputs = pipeline->Apply(inputs); - for (auto& output : outputs) { - output = std::move(output[0]); - } - py::tuple rets(outputs.size()); - for (int i = 0; i < outputs.size(); ++i) { - rets[i] = ToPyObject(outputs[i]); - } - return rets; - }) - .def("batch", [](Pipeline* pipeline, const py::args& args) { + return std::make_unique(_config, context); })) + .def("__call__", + [](Pipeline* pipeline, const py::args& args) + { + auto inputs = FromPyObject(args); + for (auto& input : inputs) + { + input = Value::Array{std::move(input)}; + } + auto outputs = pipeline->Apply(inputs); + for (auto& output : outputs) + { + output = std::move(output[0]); + } + py::tuple rets(outputs.size()); + for (int i = 0; i < outputs.size(); ++i) + { + rets[i] = ToPyObject(outputs[i]); + } + return rets; + }) + .def("batch", [](Pipeline* pipeline, const py::args& args) + { auto inputs = FromPyObject(args); auto outputs = pipeline->Apply(inputs); py::tuple rets(outputs.size()); for (int i = 0; i < outputs.size(); ++i) { rets[i] = ToPyObject(outputs[i]); } - return rets; - }); -}}; + return rets; }); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/pose_detector.cpp b/csrc/mmdeploy/apis/python/pose_detector.cpp index f9d99eaf14..b6dc96560a 100644 --- a/csrc/mmdeploy/apis/python/pose_detector.cpp +++ b/csrc/mmdeploy/apis/python/pose_detector.cpp @@ -7,122 +7,143 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -using Rect = std::array; + using Rect = std::array; -class PyPoseDetector { - public: - PyPoseDetector(const char* model_path, const char* 
device_name, int device_id) { - auto status = - mmdeploy_pose_detector_create_by_path(model_path, device_name, device_id, &detector_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create pose_detector"); - } - } - py::list Apply(const std::vector& imgs, const std::vector>& bboxes) { - if (imgs.size() == 0 && bboxes.size() == 0) { - return py::list{}; - } - if (bboxes.size() != 0 && bboxes.size() != imgs.size()) { - std::ostringstream os; - os << "imgs length not equal with vboxes [" << imgs.size() << " vs " << bboxes.size() << "]"; - throw std::invalid_argument(os.str()); - } + class PyPoseDetector + { + public: + PyPoseDetector(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_pose_detector_create_by_path(model_path, device_name, device_id, &detector_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create pose_detector"); + } + } + py::list Apply(const std::vector& imgs, const std::vector>& bboxes) + { + if (imgs.size() == 0 && bboxes.size() == 0) + { + return py::list{}; + } + if (bboxes.size() != 0 && bboxes.size() != imgs.size()) + { + std::ostringstream os; + os << "imgs length not equal with vboxes [" << imgs.size() << " vs " << bboxes.size() << "]"; + throw std::invalid_argument(os.str()); + } - std::vector mats; - std::vector boxes; - std::vector bbox_count; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } + std::vector mats; + std::vector boxes; + std::vector bbox_count; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } - for (auto _boxes : bboxes) { - for (auto _box : _boxes) { - mmdeploy_rect_t box = {_box[0], _box[1], _box[2], _box[3]}; - boxes.push_back(box); - } - bbox_count.push_back(_boxes.size()); - } + for (auto _boxes : bboxes) + { + for (auto _box : _boxes) + { + mmdeploy_rect_t box = {_box[0], _box[1], _box[2], _box[3]}; + boxes.push_back(box); + } + bbox_count.push_back(_boxes.size()); + } - // full image - if (bboxes.size() == 0) { - for (int i = 0; i < mats.size(); i++) { - mmdeploy_rect_t box = {0.f, 0.f, mats[i].width - 1.f, mats[i].height - 1.f}; - boxes.push_back(box); - bbox_count.push_back(1); - } - } + // full image + if (bboxes.size() == 0) + { + for (int i = 0; i < mats.size(); i++) + { + mmdeploy_rect_t box = {0.f, 0.f, mats[i].width - 1.f, mats[i].height - 1.f}; + boxes.push_back(box); + bbox_count.push_back(1); + } + } - mmdeploy_pose_detection_t* detection{}; - auto status = mmdeploy_pose_detector_apply_bbox(detector_, mats.data(), (int)mats.size(), - boxes.data(), bbox_count.data(), &detection); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply pose_detector, code: " + std::to_string(status)); - } + mmdeploy_pose_detection_t* detection{}; + auto status = mmdeploy_pose_detector_apply_bbox(detector_, mats.data(), (int)mats.size(), boxes.data(), bbox_count.data(), &detection); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply pose_detector, code: " + std::to_string(status)); + } - auto output = py::list{}; - auto result = detection; - for (int i = 0; i < mats.size(); i++) { - int n_point = result->length; - auto pred = py::array_t({bbox_count[i], n_point, 3}); - auto dst = pred.mutable_data(); - for (int j = 0; j < bbox_count[i]; j++) { - for (int k = 0; k < n_point; k++) { - dst[0] = result->point[k].x; - dst[1] = result->point[k].y; - dst[2] = 
result->score[k]; - dst += 3; - } - result++; - } - output.append(std::move(pred)); - } + auto output = py::list{}; + auto result = detection; + for (int i = 0; i < mats.size(); i++) + { + int n_point = result->length; + auto pred = py::array_t({bbox_count[i], n_point, 3}); + auto dst = pred.mutable_data(); + for (int j = 0; j < bbox_count[i]; j++) + { + for (int k = 0; k < n_point; k++) + { + dst[0] = result->point[k].x; + dst[1] = result->point[k].y; + dst[2] = result->score[k]; + dst += 3; + } + result++; + } + output.append(std::move(pred)); + } - int total = std::accumulate(bbox_count.begin(), bbox_count.end(), 0); - mmdeploy_pose_detector_release_result(detection, total); - return output; - } - ~PyPoseDetector() { - mmdeploy_pose_detector_destroy(detector_); - detector_ = {}; - } + int total = std::accumulate(bbox_count.begin(), bbox_count.end(), 0); + mmdeploy_pose_detector_release_result(detection, total); + return output; + } + ~PyPoseDetector() + { + mmdeploy_pose_detector_destroy(detector_); + detector_ = {}; + } - private: - mmdeploy_pose_detector_t detector_{}; -}; + private: + mmdeploy_pose_detector_t detector_{}; + }; -static PythonBindingRegisterer register_pose_detector{[](py::module& m) { - py::class_(m, "PoseDetector") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyPoseDetector* self, const PyImage& img) -> py::array { - return self->Apply({img}, {})[0]; - }) - .def( - "__call__", - [](PyPoseDetector* self, const PyImage& img, const Rect& box) -> py::array { - std::vector> bboxes; - bboxes.push_back({box}); - return self->Apply({img}, bboxes)[0]; - }, - py::arg("img"), py::arg("box")) - .def( - "__call__", - [](PyPoseDetector* self, const PyImage& img, - const std::vector& bboxes) -> py::array { - std::vector> _bboxes; - _bboxes.push_back(bboxes); - return self->Apply({img}, _bboxes)[0]; - }, - py::arg("img"), py::arg("bboxes")) - .def("batch", &PyPoseDetector::Apply, py::arg("imgs"), - py::arg("bboxes") = std::vector>()); -}}; + static PythonBindingRegisterer register_pose_detector{[](py::module& m) + { + py::class_(m, "PoseDetector") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyPoseDetector* self, const PyImage& img) -> py::array + { + return self->Apply({img}, {})[0]; + }) + .def( + "__call__", + [](PyPoseDetector* self, const PyImage& img, const Rect& box) -> py::array + { + std::vector> bboxes; + bboxes.push_back({box}); + return self->Apply({img}, bboxes)[0]; + }, + py::arg("img"), + py::arg("box")) + .def( + "__call__", + [](PyPoseDetector* self, const PyImage& img, const std::vector& bboxes) -> py::array + { + std::vector> _bboxes; + _bboxes.push_back(bboxes); + return self->Apply({img}, _bboxes)[0]; + }, + py::arg("img"), + py::arg("bboxes")) + .def("batch", &PyPoseDetector::Apply, py::arg("imgs"), py::arg("bboxes") = std::vector>()); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/pose_tracker.cpp b/csrc/mmdeploy/apis/python/pose_tracker.cpp index 035ce3cdd1..c14f2450e8 100644 --- a/csrc/mmdeploy/apis/python/pose_tracker.cpp +++ b/csrc/mmdeploy/apis/python/pose_tracker.cpp @@ -5,146 +5,200 @@ #include "common.h" #include 
"mmdeploy/common.hpp" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -namespace { + namespace + { -std::vector Apply(mmdeploy::PoseTracker* self, - const std::vector& _states, - const std::vector& _frames, std::vector detect) { - std::vector tmp; - for (const auto& s : _states) { - tmp.push_back(static_cast(*s)); - } - mmdeploy::Span states(reinterpret_cast(tmp.data()), tmp.size()); - std::vector frames; - for (const auto& f : _frames) { - frames.emplace_back(GetMat(f)); - } - if (detect.empty()) { - detect.resize(frames.size(), -1); - } - assert(states.size() == frames.size()); - assert(states.size() == detect.size()); - auto results = self->Apply(states, frames, detect); - std::vector batch_ret; - batch_ret.reserve(frames.size()); - for (const auto& rs : results) { - py::array_t keypoints( - {static_cast(rs.size()), rs.size() > 0 ? rs[0].keypoint_count : 0, 3}); - py::array_t bboxes({static_cast(rs.size()), 4}); - py::array_t track_ids(static_cast(rs.size())); - auto kpts_ptr = keypoints.mutable_data(); - auto bbox_ptr = bboxes.mutable_data(); - auto track_id_ptr = track_ids.mutable_data(); - for (const auto& r : rs) { - for (int i = 0; i < r.keypoint_count; ++i) { - kpts_ptr[0] = r.keypoints[i].x; - kpts_ptr[1] = r.keypoints[i].y; - kpts_ptr[2] = r.scores[i]; - kpts_ptr += 3; - } - { - auto tmp_bbox = (std::array&)r.bbox; - bbox_ptr[0] = tmp_bbox[0]; - bbox_ptr[1] = tmp_bbox[1]; - bbox_ptr[2] = tmp_bbox[2]; - bbox_ptr[3] = tmp_bbox[3]; - bbox_ptr += 4; - } - *track_id_ptr++ = r.target_id; - } - batch_ret.push_back( - py::make_tuple(std::move(keypoints), std::move(bboxes), std::move(track_ids))); - } - return batch_ret; -} + std::vector Apply(mmdeploy::PoseTracker* self, + const std::vector& _states, + const std::vector& _frames, + std::vector detect) + { + std::vector tmp; + for (const auto& s : _states) + { + tmp.push_back(static_cast(*s)); + } + mmdeploy::Span states(reinterpret_cast(tmp.data()), tmp.size()); + std::vector frames; + for (const auto& f : _frames) + { + frames.emplace_back(GetMat(f)); + } + if (detect.empty()) + { + detect.resize(frames.size(), -1); + } + assert(states.size() == frames.size()); + assert(states.size() == detect.size()); + auto results = self->Apply(states, frames, detect); + std::vector batch_ret; + batch_ret.reserve(frames.size()); + for (const auto& rs : results) + { + py::array_t keypoints( + {static_cast(rs.size()), rs.size() > 0 ? 
rs[0].keypoint_count : 0, 3}); + py::array_t bboxes({static_cast(rs.size()), 4}); + py::array_t track_ids(static_cast(rs.size())); + auto kpts_ptr = keypoints.mutable_data(); + auto bbox_ptr = bboxes.mutable_data(); + auto track_id_ptr = track_ids.mutable_data(); + for (const auto& r : rs) + { + for (int i = 0; i < r.keypoint_count; ++i) + { + kpts_ptr[0] = r.keypoints[i].x; + kpts_ptr[1] = r.keypoints[i].y; + kpts_ptr[2] = r.scores[i]; + kpts_ptr += 3; + } + { + auto tmp_bbox = (std::array&)r.bbox; + bbox_ptr[0] = tmp_bbox[0]; + bbox_ptr[1] = tmp_bbox[1]; + bbox_ptr[2] = tmp_bbox[2]; + bbox_ptr[3] = tmp_bbox[3]; + bbox_ptr += 4; + } + *track_id_ptr++ = r.target_id; + } + batch_ret.push_back( + py::make_tuple(std::move(keypoints), std::move(bboxes), std::move(track_ids))); + } + return batch_ret; + } -template -void Copy(const py::handle& h, T (&a)[N]) { - auto array = h.cast>(); - assert(array.size() == N); - auto data = array.data(); - for (int i = 0; i < N; ++i) { - a[i] = data[i]; - } -} + template + void Copy(const py::handle& h, T (&a)[N]) + { + auto array = h.cast>(); + assert(array.size() == N); + auto data = array.data(); + for (int i = 0; i < N; ++i) + { + a[i] = data[i]; + } + } -void Parse(const py::dict& dict, PoseTracker::Params& params, py::array_t& sigmas) { - for (const auto& [_name, value] : dict) { - auto name = _name.cast(); - if (name == "det_interval") { - params->det_interval = value.cast(); - } else if (name == "det_label") { - params->det_label = value.cast(); - } else if (name == "det_thr") { - params->det_thr = value.cast(); - } else if (name == "det_min_bbox_size") { - params->det_min_bbox_size = value.cast(); - } else if (name == "det_nms_thr") { - params->det_nms_thr = value.cast(); - } else if (name == "pose_max_num_bboxes") { - params->pose_max_num_bboxes = value.cast(); - } else if (name == "pose_min_keypoints") { - params->pose_min_keypoints = value.cast(); - } else if (name == "pose_min_bbox_size") { - params->pose_min_bbox_size = value.cast(); - } else if (name == "pose_nms_thr") { - params->pose_nms_thr = value.cast(); - } else if (name == "track_kpt_thr") { - params->pose_kpt_thr = value.cast(); - } else if (name == "track_iou_thr") { - params->track_iou_thr = value.cast(); - } else if (name == "pose_bbox_scale") { - params->pose_bbox_scale = value.cast(); - } else if (name == "track_max_missing") { - params->track_max_missing = value.cast(); - } else if (name == "track_history_size") { - params->track_history_size = value.cast(); - } else if (name == "keypoint_sigmas") { - sigmas = value.cast>(); - params->keypoint_sigmas = const_cast(sigmas.data()); - params->keypoint_sigmas_size = sigmas.size(); - } else if (name == "std_weight_position") { - params->std_weight_position = value.cast(); - } else if (name == "std_weight_velocity") { - params->std_weight_velocity = value.cast(); - } else if (name == "smooth_params") { - Copy(value, params->smooth_params); - } else { - MMDEPLOY_ERROR("unused argument: {}", name); - } - } -} + void Parse(const py::dict& dict, PoseTracker::Params& params, py::array_t& sigmas) + { + for (const auto& [_name, value] : dict) + { + auto name = _name.cast(); + if (name == "det_interval") + { + params->det_interval = value.cast(); + } + else if (name == "det_label") + { + params->det_label = value.cast(); + } + else if (name == "det_thr") + { + params->det_thr = value.cast(); + } + else if (name == "det_min_bbox_size") + { + params->det_min_bbox_size = value.cast(); + } + else if (name == "det_nms_thr") + { + params->det_nms_thr 
= value.cast(); + } + else if (name == "pose_max_num_bboxes") + { + params->pose_max_num_bboxes = value.cast(); + } + else if (name == "pose_min_keypoints") + { + params->pose_min_keypoints = value.cast(); + } + else if (name == "pose_min_bbox_size") + { + params->pose_min_bbox_size = value.cast(); + } + else if (name == "pose_nms_thr") + { + params->pose_nms_thr = value.cast(); + } + else if (name == "track_kpt_thr") + { + params->pose_kpt_thr = value.cast(); + } + else if (name == "track_iou_thr") + { + params->track_iou_thr = value.cast(); + } + else if (name == "pose_bbox_scale") + { + params->pose_bbox_scale = value.cast(); + } + else if (name == "track_max_missing") + { + params->track_max_missing = value.cast(); + } + else if (name == "track_history_size") + { + params->track_history_size = value.cast(); + } + else if (name == "keypoint_sigmas") + { + sigmas = value.cast>(); + params->keypoint_sigmas = const_cast(sigmas.data()); + params->keypoint_sigmas_size = sigmas.size(); + } + else if (name == "std_weight_position") + { + params->std_weight_position = value.cast(); + } + else if (name == "std_weight_velocity") + { + params->std_weight_velocity = value.cast(); + } + else if (name == "smooth_params") + { + Copy(value, params->smooth_params); + } + else + { + MMDEPLOY_ERROR("unused argument: {}", name); + } + } + } -} // namespace + } // namespace -static PythonBindingRegisterer register_pose_tracker{[](py::module& m) { - py::class_(m, "PoseTracker.State"); - py::class_(m, "PoseTracker") - .def(py::init([](const char* det_model_path, const char* pose_model_path, - const char* device_name, int device_id) { - return mmdeploy::PoseTracker( - mmdeploy::Model(det_model_path), mmdeploy::Model(pose_model_path), - mmdeploy::Context(mmdeploy::Device(device_name, device_id))); - }), - py::arg("det_model"), py::arg("pose_model"), py::arg("device_name"), - py::arg("device_id") = 0) - .def( - "__call__", - [](mmdeploy::PoseTracker* self, mmdeploy::PoseTracker::State* state, const PyImage& img, - int detect) { return Apply(self, {state}, {img}, {detect})[0]; }, - py::arg("state"), py::arg("frame"), py::arg("detect") = -1) - .def("batch", &Apply, py::arg("states"), py::arg("frames"), - py::arg("detects") = std::vector{}) - .def("create_state", [](mmdeploy::PoseTracker* self, const py::kwargs& kwargs) { + static PythonBindingRegisterer register_pose_tracker{[](py::module& m) + { + py::class_(m, "PoseTracker.State"); + py::class_(m, "PoseTracker") + .def(py::init([](const char* det_model_path, const char* pose_model_path, const char* device_name, int device_id) + { return mmdeploy::PoseTracker( + mmdeploy::Model(det_model_path), + mmdeploy::Model(pose_model_path), + mmdeploy::Context(mmdeploy::Device(device_name, device_id))); }), + py::arg("det_model"), + py::arg("pose_model"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def( + "__call__", + [](mmdeploy::PoseTracker* self, mmdeploy::PoseTracker::State* state, const PyImage& img, int detect) + { return Apply(self, {state}, {img}, {detect})[0]; }, + py::arg("state"), + py::arg("frame"), + py::arg("detect") = -1) + .def("batch", &Apply, py::arg("states"), py::arg("frames"), py::arg("detects") = std::vector{}) + .def("create_state", [](mmdeploy::PoseTracker* self, const py::kwargs& kwargs) + { PoseTracker::Params params; py::array_t sigmas; if (kwargs) { Parse(kwargs, params, sigmas); } - return self->CreateState(params); - }); -}}; + return self->CreateState(params); }); + }}; } // namespace mmdeploy::python diff --git 
a/csrc/mmdeploy/apis/python/restorer.cpp b/csrc/mmdeploy/apis/python/restorer.cpp index 771af2a6c4..ddd4c0a8ff 100644 --- a/csrc/mmdeploy/apis/python/restorer.cpp +++ b/csrc/mmdeploy/apis/python/restorer.cpp @@ -4,63 +4,77 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyRestorer { - public: - PyRestorer(const char* model_path, const char* device_name, int device_id) { - auto status = mmdeploy_restorer_create_by_path(model_path, device_name, device_id, &restorer_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create restorer"); - } - } - ~PyRestorer() { - mmdeploy_restorer_destroy(restorer_); - restorer_ = {}; - } + class PyRestorer + { + public: + PyRestorer(const char* model_path, const char* device_name, int device_id) + { + auto status = mmdeploy_restorer_create_by_path(model_path, device_name, device_id, &restorer_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create restorer"); + } + } + ~PyRestorer() + { + mmdeploy_restorer_destroy(restorer_); + restorer_ = {}; + } - std::vector Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_mat_t* results{}; - auto status = mmdeploy_restorer_apply(restorer_, mats.data(), (int)mats.size(), &results); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply restorer, code: " + std::to_string(status)); - } - using Sptr = std::shared_ptr; - Sptr holder(results, [n = mats.size()](auto p) { mmdeploy_restorer_release_result(p, n); }); + std::vector Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_mat_t* results{}; + auto status = mmdeploy_restorer_apply(restorer_, mats.data(), (int)mats.size(), &results); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply restorer, code: " + std::to_string(status)); + } + using Sptr = std::shared_ptr; + Sptr holder(results, [n = mats.size()](auto p) + { mmdeploy_restorer_release_result(p, n); }); - std::vector rets(mats.size()); - for (int i = 0; i < mats.size(); ++i) { - rets[i] = { - {results[i].height, results[i].width, results[i].channel}, // shape - results[i].data, // data - py::capsule(new Sptr(holder), // handle - [](void* p) { delete reinterpret_cast(p); }) // - }; - } - return rets; - } + std::vector rets(mats.size()); + for (int i = 0; i < mats.size(); ++i) + { + rets[i] = { + {results[i].height, results[i].width, results[i].channel}, // shape + results[i].data, // data + py::capsule(new Sptr(holder), // handle + [](void* p) + { delete reinterpret_cast(p); }) // + }; + } + return rets; + } - private: - mmdeploy_restorer_t restorer_{}; -}; + private: + mmdeploy_restorer_t restorer_{}; + }; -static PythonBindingRegisterer register_restorer{[](py::module& m) { - py::class_(m, "Restorer") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyRestorer* self, const PyImage& img) -> py::array { - return self->Apply(std::vector{img})[0]; - }) - .def("batch", &PyRestorer::Apply); -}}; + static PythonBindingRegisterer register_restorer{[](py::module& m) + { + py::class_(m, "Restorer") + 
.def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyRestorer* self, const PyImage& img) -> py::array + { + return self->Apply(std::vector{img})[0]; + }) + .def("batch", &PyRestorer::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/rotated_detector.cpp b/csrc/mmdeploy/apis/python/rotated_detector.cpp index bc760b04e4..148b31fa6e 100644 --- a/csrc/mmdeploy/apis/python/rotated_detector.cpp +++ b/csrc/mmdeploy/apis/python/rotated_detector.cpp @@ -4,74 +4,87 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyRotatedDetector { - public: - PyRotatedDetector(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_rotated_detector_create_by_path(model_path, device_name, device_id, &detector_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create rotated detector"); - } - } - py::list Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } + class PyRotatedDetector + { + public: + PyRotatedDetector(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_rotated_detector_create_by_path(model_path, device_name, device_id, &detector_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create rotated detector"); + } + } + py::list Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } - mmdeploy_rotated_detection_t* rbboxes{}; - int* res_count{}; - auto status = mmdeploy_rotated_detector_apply(detector_, mats.data(), (int)mats.size(), - &rbboxes, &res_count); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply rotated detector, code: " + std::to_string(status)); - } - auto output = py::list{}; - auto result = rbboxes; - auto counts = res_count; - for (int i = 0; i < mats.size(); i++) { - auto _dets = py::array_t({*counts, 6}); - auto _labels = py::array_t({*counts}); - auto dets = _dets.mutable_data(); - auto labels = _labels.mutable_data(); - for (int j = 0; j < *counts; j++) { - for (int k = 0; k < 5; k++) { - *dets++ = result->rbbox[k]; + mmdeploy_rotated_detection_t* rbboxes{}; + int* res_count{}; + auto status = mmdeploy_rotated_detector_apply(detector_, mats.data(), (int)mats.size(), &rbboxes, &res_count); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply rotated detector, code: " + std::to_string(status)); + } + auto output = py::list{}; + auto result = rbboxes; + auto counts = res_count; + for (int i = 0; i < mats.size(); i++) + { + auto _dets = py::array_t({*counts, 6}); + auto _labels = py::array_t({*counts}); + auto dets = _dets.mutable_data(); + auto labels = _labels.mutable_data(); + for (int j = 0; j < *counts; j++) + { + for (int k = 0; k < 5; k++) + { + *dets++ = result->rbbox[k]; + } + *dets++ = result->score; + *labels++ = result->label_id; + result++; + } + counts++; + output.append(py::make_tuple(std::move(_dets), std::move(_labels))); + } + mmdeploy_rotated_detector_release_result(rbboxes, res_count); + return output; + } + ~PyRotatedDetector() + { + mmdeploy_rotated_detector_destroy(detector_); 
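The restorer binding above returns zero-copy numpy arrays over buffers owned by the C API: a `std::shared_ptr` with a custom deleter is parked inside a `py::capsule`, so `mmdeploy_restorer_release_result` runs only once the last array is garbage-collected. A minimal sketch of that ownership idiom, with hypothetical `make_result` / `release_result` stand-ins for the C API:

    #include <memory>

    #include <pybind11/numpy.h>

    namespace py = pybind11;

    // Hypothetical stand-ins for a C API that allocates and frees a result buffer.
    inline float* make_result(int n) { return new float[n](); }
    inline void   release_result(float* p) { delete[] p; }

    // Wrap the C buffer as a zero-copy numpy array: the capsule owns one extra
    // shared_ptr reference, so the buffer is released only after the array
    // viewing it has been destroyed.
    inline py::array_t<float> wrap_result(int n)
    {
        using Holder = std::shared_ptr<float>;
        Holder holder(make_result(n), [](float* p) { release_result(p); });
        py::capsule base(new Holder(holder),
                         [](void* p) { delete reinterpret_cast<Holder*>(p); });
        return py::array_t<float>({n}, holder.get(), base);
    }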
+ detector_ = {}; } - *dets++ = result->score; - *labels++ = result->label_id; - result++; - } - counts++; - output.append(py::make_tuple(std::move(_dets), std::move(_labels))); - } - mmdeploy_rotated_detector_release_result(rbboxes, res_count); - return output; - } - ~PyRotatedDetector() { - mmdeploy_rotated_detector_destroy(detector_); - detector_ = {}; - } - private: - mmdeploy_rotated_detector_t detector_{}; -}; + private: + mmdeploy_rotated_detector_t detector_{}; + }; -static PythonBindingRegisterer register_rotated_detector{[](py::module& m) { - py::class_(m, "RotatedDetector") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyRotatedDetector* self, const PyImage& img) -> py::tuple { - return self->Apply(std::vector{img})[0]; - }) - .def("batch", &PyRotatedDetector::Apply); -}}; + static PythonBindingRegisterer register_rotated_detector{[](py::module& m) + { + py::class_(m, "RotatedDetector") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyRotatedDetector* self, const PyImage& img) -> py::tuple + { + return self->Apply(std::vector{img})[0]; + }) + .def("batch", &PyRotatedDetector::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/segmentor.cpp b/csrc/mmdeploy/apis/python/segmentor.cpp index 940972ab61..9e1db508c7 100644 --- a/csrc/mmdeploy/apis/python/segmentor.cpp +++ b/csrc/mmdeploy/apis/python/segmentor.cpp @@ -4,74 +4,91 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PySegmentor { - public: - PySegmentor(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_segmentor_create_by_path(model_path, device_name, device_id, &segmentor_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create segmentor"); - } - } - ~PySegmentor() { - mmdeploy_segmentor_destroy(segmentor_); - segmentor_ = {}; - } + class PySegmentor + { + public: + PySegmentor(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_segmentor_create_by_path(model_path, device_name, device_id, &segmentor_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create segmentor"); + } + } + ~PySegmentor() + { + mmdeploy_segmentor_destroy(segmentor_); + segmentor_ = {}; + } - std::vector Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_segmentation_t* segm{}; - auto status = mmdeploy_segmentor_apply(segmentor_, mats.data(), (int)mats.size(), &segm); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply segmentor, code: " + std::to_string(status)); - } - using Sptr = std::shared_ptr; - Sptr holder(segm, [n = mats.size()](auto p) { mmdeploy_segmentor_release_result(p, n); }); + std::vector Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_segmentation_t* segm{}; + auto status = mmdeploy_segmentor_apply(segmentor_, mats.data(), 
(int)mats.size(), &segm); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply segmentor, code: " + std::to_string(status)); + } + using Sptr = std::shared_ptr; + Sptr holder(segm, [n = mats.size()](auto p) + { mmdeploy_segmentor_release_result(p, n); }); - std::vector rets(mats.size()); - for (size_t i = 0; i < mats.size(); ++i) { - if (segm[i].mask != nullptr) { - rets[i] = { - {segm[i].height, segm[i].width}, // shape - segm[i].mask, // mask - py::capsule(new Sptr(holder), // handle - [](void* p) { delete reinterpret_cast(p); }) // - }; - } - if (segm[i].score != nullptr) { - rets[i] = { - {segm[i].classes, segm[i].height, segm[i].width}, // shape - segm[i].score, // score - py::capsule(new Sptr(holder), // handle - [](void* p) { delete reinterpret_cast(p); }) // - }; - } - } - return rets; - } + std::vector rets(mats.size()); + for (size_t i = 0; i < mats.size(); ++i) + { + if (segm[i].mask != nullptr) + { + rets[i] = { + {segm[i].height, segm[i].width}, // shape + segm[i].mask, // mask + py::capsule(new Sptr(holder), // handle + [](void* p) + { delete reinterpret_cast(p); }) // + }; + } + if (segm[i].score != nullptr) + { + rets[i] = { + {segm[i].classes, segm[i].height, segm[i].width}, // shape + segm[i].score, // score + py::capsule(new Sptr(holder), // handle + [](void* p) + { delete reinterpret_cast(p); }) // + }; + } + } + return rets; + } - private: - mmdeploy_segmentor_t segmentor_{}; -}; + private: + mmdeploy_segmentor_t segmentor_{}; + }; -static PythonBindingRegisterer register_segmentor{[](py::module& m) { - py::class_(m, "Segmentor") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PySegmentor* self, const PyImage& img) -> py::array { - return self->Apply(std::vector{img})[0]; - }) - .def("batch", &PySegmentor::Apply); -}}; + static PythonBindingRegisterer register_segmentor{[](py::module& m) + { + py::class_(m, "Segmentor") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PySegmentor* self, const PyImage& img) -> py::array + { + return self->Apply(std::vector{img})[0]; + }) + .def("batch", &PySegmentor::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/text_detector.cpp b/csrc/mmdeploy/apis/python/text_detector.cpp index 19762d08ec..1326588a1f 100644 --- a/csrc/mmdeploy/apis/python/text_detector.cpp +++ b/csrc/mmdeploy/apis/python/text_detector.cpp @@ -4,68 +4,81 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyTextDetector { - public: - PyTextDetector(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_text_detector_create_by_path(model_path, device_name, device_id, &detector_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create text_detector"); - } - } - std::vector> Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_text_detection_t* detection{}; - int* result_count{}; - auto status = mmdeploy_text_detector_apply(detector_, mats.data(), (int)mats.size(), &detection, 
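Two result-marshalling styles appear above: the segmentor aliases the C buffers directly (capsule-held, zero-copy), while the rotated detector copies each record into a freshly allocated `(N, 6)` array of five box parameters plus score, with labels in a parallel `(N,)` array. A compact sketch of the copying variant, using a hypothetical mirror of `mmdeploy_rotated_detection_t`:

    #include <vector>

    #include <pybind11/numpy.h>

    namespace py = pybind11;

    // Hypothetical mirror of mmdeploy_rotated_detection_t:
    // five box parameters plus a score and a label id.
    struct RotatedDet
    {
        float rbbox[5];
        float score;
        int   label_id;
    };

    inline py::tuple pack(const std::vector<RotatedDet>& dets)
    {
        auto boxes  = py::array_t<float>({(int)dets.size(), 6});
        auto labels = py::array_t<int>({(int)dets.size()});
        auto* b = boxes.mutable_data();
        auto* l = labels.mutable_data();
        for (const auto& d : dets)
        {
            for (float v : d.rbbox)  // the five box parameters
            {
                *b++ = v;
            }
            *b++ = d.score;   // sixth column
            *l++ = d.label_id;
        }
        return py::make_tuple(boxes, labels);
    }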
- &result_count); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply text_detector, code: " + std::to_string(status)); - } - auto output = std::vector>{}; - auto result = detection; - for (int i = 0; i < mats.size(); ++i) { - auto bboxes = py::array_t({result_count[i], 9}); - for (int j = 0; j < result_count[i]; ++j, ++result) { - auto data = bboxes.mutable_data(j); - for (const auto& p : result->bbox) { - *data++ = p.x; - *data++ = p.y; + class PyTextDetector + { + public: + PyTextDetector(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_text_detector_create_by_path(model_path, device_name, device_id, &detector_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create text_detector"); + } + } + std::vector> Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_text_detection_t* detection{}; + int* result_count{}; + auto status = mmdeploy_text_detector_apply(detector_, mats.data(), (int)mats.size(), &detection, &result_count); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply text_detector, code: " + std::to_string(status)); + } + auto output = std::vector>{}; + auto result = detection; + for (int i = 0; i < mats.size(); ++i) + { + auto bboxes = py::array_t({result_count[i], 9}); + for (int j = 0; j < result_count[i]; ++j, ++result) + { + auto data = bboxes.mutable_data(j); + for (const auto& p : result->bbox) + { + *data++ = p.x; + *data++ = p.y; + } + *data++ = result->score; + } + output.push_back(std::move(bboxes)); + } + mmdeploy_text_detector_release_result(detection, result_count, (int)mats.size()); + return output; + } + ~PyTextDetector() + { + mmdeploy_text_detector_destroy(detector_); + detector_ = {}; } - *data++ = result->score; - } - output.push_back(std::move(bboxes)); - } - mmdeploy_text_detector_release_result(detection, result_count, (int)mats.size()); - return output; - } - ~PyTextDetector() { - mmdeploy_text_detector_destroy(detector_); - detector_ = {}; - } - private: - mmdeploy_text_detector_t detector_{}; -}; + private: + mmdeploy_text_detector_t detector_{}; + }; -static PythonBindingRegisterer register_text_detector{[](py::module& m) { - py::class_(m, "TextDetector") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyTextDetector* self, const PyImage& img) -> py::array { - return self->Apply(std::vector{img})[0]; - }) - .def("batch", &PyTextDetector::Apply); -}}; + static PythonBindingRegisterer register_text_detector{[](py::module& m) + { + py::class_(m, "TextDetector") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyTextDetector* self, const PyImage& img) -> py::array + { + return self->Apply(std::vector{img})[0]; + }) + .def("batch", &PyTextDetector::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/text_recognizer.cpp b/csrc/mmdeploy/apis/python/text_recognizer.cpp index 317f55103a..1b3bc92af8 100644 --- a/csrc/mmdeploy/apis/python/text_recognizer.cpp +++ 
b/csrc/mmdeploy/apis/python/text_recognizer.cpp @@ -4,79 +4,99 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyTextRecognizer { - public: - PyTextRecognizer(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_text_recognizer_create_by_path(model_path, device_name, device_id, &recognizer_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create text_recognizer"); - } - } - std::vector>> Apply(const std::vector& imgs) { - std::vector mats; - mats.reserve(imgs.size()); - for (const auto& img : imgs) { - auto mat = GetMat(img); - mats.push_back(mat); - } - mmdeploy_text_recognition_t* results{}; - auto status = - mmdeploy_text_recognizer_apply(recognizer_, mats.data(), (int)mats.size(), &results); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply text_recognizer, code: " + std::to_string(status)); - } - auto output = std::vector>>{}; - for (int i = 0; i < mats.size(); ++i) { - std::vector score(results[i].score, results[i].score + results[i].length); - output.emplace_back(results[i].text, std::move(score)); - } - mmdeploy_text_recognizer_release_result(results, (int)mats.size()); - return output; - } - std::vector>> Apply(const PyImage& img, - const std::vector& bboxes) { - if (bboxes.size() * sizeof(float) % sizeof(mmdeploy_text_detection_t)) { - throw std::invalid_argument("bboxes is not a list of 'mmdeploy_text_detection_t'"); - } - auto mat = GetMat(img); - int bbox_count = bboxes.size() * sizeof(float) / sizeof(mmdeploy_text_detection_t); - mmdeploy_text_recognition_t* results{}; - auto status = mmdeploy_text_recognizer_apply_bbox( - recognizer_, &mat, 1, (mmdeploy_text_detection_t*)bboxes.data(), &bbox_count, &results); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply text_recognizer, code: " + std::to_string(status)); - } - auto output = std::vector>>{}; - for (int i = 0; i < bbox_count; ++i) { - std::vector score(results[i].score, results[i].score + results[i].length); - output.emplace_back(results[i].text, std::move(score)); - } - mmdeploy_text_recognizer_release_result(results, bbox_count); - return output; - } - ~PyTextRecognizer() { - mmdeploy_text_recognizer_destroy(recognizer_); - recognizer_ = {}; - } + class PyTextRecognizer + { + public: + PyTextRecognizer(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_text_recognizer_create_by_path(model_path, device_name, device_id, &recognizer_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create text_recognizer"); + } + } + std::vector>> Apply(const std::vector& imgs) + { + std::vector mats; + mats.reserve(imgs.size()); + for (const auto& img : imgs) + { + auto mat = GetMat(img); + mats.push_back(mat); + } + mmdeploy_text_recognition_t* results{}; + auto status = + mmdeploy_text_recognizer_apply(recognizer_, mats.data(), (int)mats.size(), &results); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply text_recognizer, code: " + std::to_string(status)); + } + auto output = std::vector>>{}; + for (int i = 0; i < mats.size(); ++i) + { + std::vector score(results[i].score, results[i].score + results[i].length); + output.emplace_back(results[i].text, std::move(score)); + } + mmdeploy_text_recognizer_release_result(results, (int)mats.size()); + return output; + } + std::vector>> Apply(const PyImage& img, + const std::vector& bboxes) + { + if (bboxes.size() * 
sizeof(float) % sizeof(mmdeploy_text_detection_t)) + { + throw std::invalid_argument("bboxes is not a list of 'mmdeploy_text_detection_t'"); + } + auto mat = GetMat(img); + int bbox_count = bboxes.size() * sizeof(float) / sizeof(mmdeploy_text_detection_t); + mmdeploy_text_recognition_t* results{}; + auto status = mmdeploy_text_recognizer_apply_bbox( + recognizer_, + &mat, + 1, + (mmdeploy_text_detection_t*)bboxes.data(), + &bbox_count, + &results); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply text_recognizer, code: " + std::to_string(status)); + } + auto output = std::vector>>{}; + for (int i = 0; i < bbox_count; ++i) + { + std::vector score(results[i].score, results[i].score + results[i].length); + output.emplace_back(results[i].text, std::move(score)); + } + mmdeploy_text_recognizer_release_result(results, bbox_count); + return output; + } + ~PyTextRecognizer() + { + mmdeploy_text_recognizer_destroy(recognizer_); + recognizer_ = {}; + } - private: - mmdeploy_text_recognizer_t recognizer_{}; -}; + private: + mmdeploy_text_recognizer_t recognizer_{}; + }; -static PythonBindingRegisterer register_text_recognizer{[](py::module& m) { - py::class_(m, "TextRecognizer") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", [](PyTextRecognizer* self, - const PyImage& img) { return self->Apply(std::vector{img})[0]; }) - .def("__call__", [](PyTextRecognizer* self, const PyImage& img, - const std::vector& bboxes) { return self->Apply(img, bboxes); }) - .def("batch", py::overload_cast&>(&PyTextRecognizer::Apply)); -}}; + static PythonBindingRegisterer register_text_recognizer{[](py::module& m) + { + py::class_(m, "TextRecognizer") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", [](PyTextRecognizer* self, const PyImage& img) + { return self->Apply(std::vector{img})[0]; }) + .def("__call__", [](PyTextRecognizer* self, const PyImage& img, const std::vector& bboxes) + { return self->Apply(img, bboxes); }) + .def("batch", py::overload_cast&>(&PyTextRecognizer::Apply)); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/apis/python/video_recognizer.cpp b/csrc/mmdeploy/apis/python/video_recognizer.cpp index 7c70337e51..ac2e691be3 100644 --- a/csrc/mmdeploy/apis/python/video_recognizer.cpp +++ b/csrc/mmdeploy/apis/python/video_recognizer.cpp @@ -4,85 +4,102 @@ #include "common.h" -namespace mmdeploy::python { +namespace mmdeploy::python +{ -class PyVideoRecognizer { - public: - PyVideoRecognizer(const char* model_path, const char* device_name, int device_id) { - auto status = - mmdeploy_video_recognizer_create_by_path(model_path, device_name, device_id, &recognizer_); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to create video_recognizer"); - } - } - std::vector>> Apply( - const std::vector>& imgs, const std::vector>& info) { - if (info.size() != imgs.size()) { - throw std::invalid_argument("the length of info is not equal with imgs"); - } - for (int i = 0; i < info.size(); i++) { - if (imgs[i].size() != info[i].first * info[i].second) { - throw std::invalid_argument("invalid info"); - } - } - int total = 0; - for (int i = 0; i < imgs.size(); 
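The bbox overload above takes the detections as a flat float list and reinterprets it as `mmdeploy_text_detection_t` records. The modulus test is sound because each record is four corner points plus one score, i.e. nine floats, matching the `(N, 9)` arrays the text detector produces. A self-contained illustration with stand-in structs:

    #include <cstddef>
    #include <stdexcept>
    #include <vector>

    // Stand-ins for mmdeploy_point_t / mmdeploy_text_detection_t: four corner
    // points (eight floats) plus a score, i.e. nine floats per detection.
    struct Point
    {
        float x, y;
    };
    struct TextDetection
    {
        Point bbox[4];
        float score;
    };
    static_assert(sizeof(TextDetection) == 9 * sizeof(float), "unexpected padding");

    // A flat float list is a valid detection array only if its byte size is a
    // whole multiple of the record size; the quotient is the detection count.
    inline int count_detections(const std::vector<float>& flat)
    {
        if (flat.size() * sizeof(float) % sizeof(TextDetection))
        {
            throw std::invalid_argument("not a whole number of detections");
        }
        return static_cast<int>(flat.size() * sizeof(float) / sizeof(TextDetection));
    }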
i++) { - total += imgs[i].size(); - } - std::vector clips; - std::vector clip_info; - clips.reserve(total); - clip_info.reserve(total); - for (int i = 0; i < imgs.size(); i++) { - for (const auto& img : imgs[i]) { - auto mat = GetMat(img); - clips.push_back(mat); - } - clip_info.push_back({info[i].first, info[i].second}); - } + class PyVideoRecognizer + { + public: + PyVideoRecognizer(const char* model_path, const char* device_name, int device_id) + { + auto status = + mmdeploy_video_recognizer_create_by_path(model_path, device_name, device_id, &recognizer_); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to create video_recognizer"); + } + } + std::vector>> Apply( + const std::vector>& imgs, + const std::vector>& info) + { + if (info.size() != imgs.size()) + { + throw std::invalid_argument("the length of info is not equal with imgs"); + } + for (int i = 0; i < info.size(); i++) + { + if (imgs[i].size() != info[i].first * info[i].second) + { + throw std::invalid_argument("invalid info"); + } + } + int total = 0; + for (int i = 0; i < imgs.size(); i++) + { + total += imgs[i].size(); + } + std::vector clips; + std::vector clip_info; + clips.reserve(total); + clip_info.reserve(total); + for (int i = 0; i < imgs.size(); i++) + { + for (const auto& img : imgs[i]) + { + auto mat = GetMat(img); + clips.push_back(mat); + } + clip_info.push_back({info[i].first, info[i].second}); + } - mmdeploy_video_recognition_t* results{}; - int* result_count{}; - auto status = mmdeploy_video_recognizer_apply(recognizer_, clips.data(), clip_info.data(), 1, - &results, &result_count); - if (status != MMDEPLOY_SUCCESS) { - throw std::runtime_error("failed to apply video_recognizer, code: " + std::to_string(status)); - } + mmdeploy_video_recognition_t* results{}; + int* result_count{}; + auto status = mmdeploy_video_recognizer_apply(recognizer_, clips.data(), clip_info.data(), 1, &results, &result_count); + if (status != MMDEPLOY_SUCCESS) + { + throw std::runtime_error("failed to apply video_recognizer, code: " + std::to_string(status)); + } - auto output = std::vector>>{}; - output.reserve(imgs.size()); - auto result_ptr = results; - for (int i = 0; i < imgs.size(); ++i) { - std::vector> label_score; - for (int j = 0; j < result_count[i]; ++j) { - label_score.emplace_back(result_ptr[j].label_id, result_ptr[j].score); - } - output.push_back(std::move(label_score)); - result_ptr += result_count[i]; - } - mmdeploy_video_recognizer_release_result(results, result_count, (int)imgs.size()); - return output; - } + auto output = std::vector>>{}; + output.reserve(imgs.size()); + auto result_ptr = results; + for (int i = 0; i < imgs.size(); ++i) + { + std::vector> label_score; + for (int j = 0; j < result_count[i]; ++j) + { + label_score.emplace_back(result_ptr[j].label_id, result_ptr[j].score); + } + output.push_back(std::move(label_score)); + result_ptr += result_count[i]; + } + mmdeploy_video_recognizer_release_result(results, result_count, (int)imgs.size()); + return output; + } - ~PyVideoRecognizer() { - mmdeploy_video_recognizer_destroy(recognizer_); - recognizer_ = {}; - } + ~PyVideoRecognizer() + { + mmdeploy_video_recognizer_destroy(recognizer_); + recognizer_ = {}; + } - private: - mmdeploy_video_recognizer_t recognizer_{}; -}; + private: + mmdeploy_video_recognizer_t recognizer_{}; + }; -static PythonBindingRegisterer register_video_recognizer{[](py::module& m) { - py::class_(m, "VideoRecognizer") - .def(py::init([](const char* model_path, const char* device_name, int device_id) { - return 
std::make_unique(model_path, device_name, device_id); - }), - py::arg("model_path"), py::arg("device_name"), py::arg("device_id") = 0) - .def("__call__", - [](PyVideoRecognizer* self, const std::vector& imgs, - const std::pair& info) { return self->Apply({imgs}, {info})[0]; }) - .def("batch", &PyVideoRecognizer::Apply); -}}; + static PythonBindingRegisterer register_video_recognizer{[](py::module& m) + { + py::class_(m, "VideoRecognizer") + .def(py::init([](const char* model_path, const char* device_name, int device_id) + { return std::make_unique(model_path, device_name, device_id); }), + py::arg("model_path"), + py::arg("device_name"), + py::arg("device_id") = 0) + .def("__call__", + [](PyVideoRecognizer* self, const std::vector& imgs, const std::pair& info) + { return self->Apply({imgs}, {info})[0]; }) + .def("batch", &PyVideoRecognizer::Apply); + }}; } // namespace mmdeploy::python diff --git a/csrc/mmdeploy/archive/json_archive.h b/csrc/mmdeploy/archive/json_archive.h index 2803ee22b2..cf03005856 100644 --- a/csrc/mmdeploy/archive/json_archive.h +++ b/csrc/mmdeploy/archive/json_archive.h @@ -7,207 +7,247 @@ #include "mmdeploy/core/archive.h" #include "mmdeploy/core/value.h" -namespace mmdeploy { - -namespace detail { - -template -nlohmann::json to_json_impl(T&& val); - -inline nlohmann::json value_to_json(const Value& value) { - switch (value.type()) { - case ValueType::kNull: - return {}; - case ValueType::kBool: - return value.get(); - case ValueType::kInt: - return value.get(); - case ValueType::kUInt: - return value.get(); - case ValueType::kFloat: - return value.get(); - case ValueType::kString: - return value.get(); - case ValueType::kArray: { - nlohmann::json json = nlohmann::json::value_t::array; - for (const auto& x : value) { - json.push_back(value_to_json(x)); - } - return json; +namespace mmdeploy +{ + + namespace detail + { + + template + nlohmann::json to_json_impl(T&& val); + + inline nlohmann::json value_to_json(const Value& value) + { + switch (value.type()) + { + case ValueType::kNull: + return {}; + case ValueType::kBool: + return value.get(); + case ValueType::kInt: + return value.get(); + case ValueType::kUInt: + return value.get(); + case ValueType::kFloat: + return value.get(); + case ValueType::kString: + return value.get(); + case ValueType::kArray: + { + nlohmann::json json = nlohmann::json::value_t::array; + for (const auto& x : value) + { + json.push_back(value_to_json(x)); + } + return json; + } + case ValueType::kObject: + { + nlohmann::json json = nlohmann::json::value_t::object; + for (auto it = value.begin(); it != value.end(); ++it) + { + auto key = it.key(); + json[key] = value_to_json(*it); + } + return json; + } + case ValueType::kAny: + return ""; + default: + return ""; + } + } + + } // namespace detail + + template>, int> = 0> + nlohmann::json to_json(T&& val) + { + return detail::to_json_impl(std::forward(val)); } - case ValueType::kObject: { - nlohmann::json json = nlohmann::json::value_t::object; - for (auto it = value.begin(); it != value.end(); ++it) { - auto key = it.key(); - json[key] = value_to_json(*it); - } - return json; + + inline nlohmann::json to_json(const Value& value) + { + return detail::value_to_json(value); + } + + // save to JSON + class JsonOutputArchive : public OutputArchive + { + public: + explicit JsonOutputArchive(nlohmann::json& data) + : data_(data) + { + } + + void init(...) 
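The video-recognizer binding above flattens all frames into a single `clips` vector with one `clip_info` entry per video, after checking that video `i` supplies exactly `info[i].first * info[i].second` frames. A small sketch of that invariant (the pair member names are hypothetical):

    #include <stdexcept>
    #include <utility>
    #include <vector>

    // Each video must supply exactly clip_len * num_clips frames, where
    // info[i] = {clip_len, num_clips}.
    template <typename Frame>
    void check_clip_info(const std::vector<std::vector<Frame>>& videos,
                         const std::vector<std::pair<int, int>>& info)
    {
        if (info.size() != videos.size())
        {
            throw std::invalid_argument("one {clip_len, num_clips} entry per video expected");
        }
        for (size_t i = 0; i < videos.size(); ++i)
        {
            if (videos[i].size() != static_cast<size_t>(info[i].first) * info[i].second)
            {
                throw std::invalid_argument("frame count does not match clip_len * num_clips");
            }
        }
    }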
{} + + template + void named_value(const std::string& name, T&& val) + { + data_[name] = to_json(std::forward(val)); + } + + template + void item(T&& val) + { + data_.push_back(to_json(std::forward(val))); + } + + template, std::enable_if_t, std::is_same, std::is_same, std::is_same>, int> = 0> + void native(T&& val) + { + data_ = std::forward(val); + } + + private: + nlohmann::json& data_; + }; + + namespace detail + { + + template + inline nlohmann::json to_json_impl(T&& val) + { + nlohmann::json json; + JsonOutputArchive archive(json); + archive(std::forward(val)); + return json; + } + + } // namespace detail + + namespace detail + { + + inline Value json_to_value(const nlohmann::json& json) + { + using value_t = nlohmann::json::value_t; + switch (json.type()) + { + case value_t::null: + return {}; + case value_t::boolean: + return json.get(); + case value_t::number_integer: + return json.get(); + case value_t::number_unsigned: + return json.get(); + case value_t::number_float: + return json.get(); + case value_t::string: + return json.get(); + case value_t::array: + { + Value value = ValueType::kArray; + for (const auto& x : json) + { + value.push_back(json_to_value(x)); + } + return value; + } + case value_t::object: + { + Value value = ValueType::kObject; + for (const auto& proxy : json.items()) + { + value[proxy.key()] = json_to_value(proxy.value()); + } + return value; + } + default: + MMDEPLOY_ERROR("unsupported json type: {}", json.type_name()); + return {}; + } + } + + template + void from_json_impl(const nlohmann::json& json, T&& val); + + } // namespace detail + + template>, int> = 0> + void from_json(const nlohmann::json& json, T&& val) + { + detail::from_json_impl(json, std::forward(val)); } - case ValueType::kAny: - return ""; - default: - return ""; - } -} - -} // namespace detail - -template >, int> = 0> -nlohmann::json to_json(T&& val) { - return detail::to_json_impl(std::forward(val)); -} - -inline nlohmann::json to_json(const Value& value) { return detail::value_to_json(value); } - -// save to JSON -class JsonOutputArchive : public OutputArchive { - public: - explicit JsonOutputArchive(nlohmann::json& data) : data_(data) {} - - void init(...) 
{} - - template - void named_value(const std::string& name, T&& val) { - data_[name] = to_json(std::forward(val)); - } - - template - void item(T&& val) { - data_.push_back(to_json(std::forward(val))); - } - - template , - std::enable_if_t< - std::disjunction_v, std::is_same, - std::is_same, std::is_same>, - int> = 0> - void native(T&& val) { - data_ = std::forward(val); - } - - private: - nlohmann::json& data_; -}; - -namespace detail { - -template -inline nlohmann::json to_json_impl(T&& val) { - nlohmann::json json; - JsonOutputArchive archive(json); - archive(std::forward(val)); - return json; -} - -} // namespace detail - -namespace detail { - -inline Value json_to_value(const nlohmann::json& json) { - using value_t = nlohmann::json::value_t; - switch (json.type()) { - case value_t::null: - return {}; - case value_t::boolean: - return json.get(); - case value_t::number_integer: - return json.get(); - case value_t::number_unsigned: - return json.get(); - case value_t::number_float: - return json.get(); - case value_t::string: - return json.get(); - case value_t::array: { - Value value = ValueType::kArray; - for (const auto& x : json) { - value.push_back(json_to_value(x)); - } - return value; + + inline void from_json(const nlohmann::json& json, Value& val) + { + val = detail::json_to_value(json); } - case value_t::object: { - Value value = ValueType::kObject; - for (const auto& proxy : json.items()) { - value[proxy.key()] = json_to_value(proxy.value()); - } - return value; + + template + T from_json(const nlohmann::json& json); + + // load from JSON + class JsonInputArchive : public InputArchive + { + public: + explicit JsonInputArchive(const nlohmann::json& data) + : data_(data) + { + } + + template + void init(SizeType& size) + { + size = static_cast(data_.size()); + iter_ = data_.begin(); + } + + template + void named_value(std::string& name, T& val) + { + name = iter_.key(); + from_json(*iter_++, std::forward(val)); + } + + template + void named_value(const std::string& name, T&& val) + { + from_json(data_[name], std::forward(val)); + } + + template + void item(T&& val) + { + from_json(*iter_++, std::forward(val)); + } + + template + void native(T&& val) + { + data_.get_to(val); + } + + private: + const nlohmann::json& data_; + nlohmann::json::const_iterator iter_; + }; + + namespace detail + { + + template + inline void from_json_impl(const nlohmann::json& json, T&& val) + { + JsonInputArchive archive(json); + archive(std::forward(val)); + } + + } // namespace detail + + template + inline T from_json(const nlohmann::json& json) + { + T val{}; + from_json(json, val); + return val; } - default: - MMDEPLOY_ERROR("unsupported json type: {}", json.type_name()); - return {}; - } -} - -template -void from_json_impl(const nlohmann::json& json, T&& val); - -} // namespace detail - -template >, int> = 0> -void from_json(const nlohmann::json& json, T&& val) { - detail::from_json_impl(json, std::forward(val)); -} - -inline void from_json(const nlohmann::json& json, Value& val) { val = detail::json_to_value(json); } - -template -T from_json(const nlohmann::json& json); - -// load from JSON -class JsonInputArchive : public InputArchive { - public: - explicit JsonInputArchive(const nlohmann::json& data) : data_(data) {} - - template - void init(SizeType& size) { - size = static_cast(data_.size()); - iter_ = data_.begin(); - } - - template - void named_value(std::string& name, T& val) { - name = iter_.key(); - from_json(*iter_++, std::forward(val)); - } - - template - void named_value(const 
std::string& name, T&& val) { - from_json(data_[name], std::forward(val)); - } - - template - void item(T&& val) { - from_json(*iter_++, std::forward(val)); - } - - template - void native(T&& val) { - data_.get_to(val); - } - - private: - const nlohmann::json& data_; - nlohmann::json::const_iterator iter_; -}; - -namespace detail { - -template -inline void from_json_impl(const nlohmann::json& json, T&& val) { - JsonInputArchive archive(json); - archive(std::forward(val)); -} - -} // namespace detail - -template -inline T from_json(const nlohmann::json& json) { - T val{}; - from_json(json, val); - return val; -} - -void from_json(const nlohmann::json& json, Value& val); + + void from_json(const nlohmann::json& json, Value& val); } // namespace mmdeploy diff --git a/csrc/mmdeploy/archive/value_archive.h b/csrc/mmdeploy/archive/value_archive.h index 2f559c1a10..f3245f0dfc 100644 --- a/csrc/mmdeploy/archive/value_archive.h +++ b/csrc/mmdeploy/archive/value_archive.h @@ -6,131 +6,169 @@ #include "mmdeploy/core/archive.h" #include "mmdeploy/core/value.h" -namespace mmdeploy { - -template -Value to_value(T&& val); - -// save to Value -class ValueOutputArchive : public OutputArchive { - public: - explicit ValueOutputArchive(Value& data) : data_(data) {} - - template - void init(array_tag) { - data_ = ValueType::kArray; - } - - template - void init(object_tag) { - data_ = ValueType::kObject; - } - - template - void named_value(const std::string& name, T&& val) { - data_[name] = to_value(std::forward(val)); - } - - template - void item(T&& val) { - data_.push_back(to_value(std::forward(val))); - } - - template , int> = 0> - void native(T&& val) { - data_ = std::forward(val); - }; - - private: - Value& data_; -}; - -template -inline Value to_value(T&& val) { - Value value; - ValueOutputArchive archive(value); - archive(std::forward(val)); - return value; -} - -// fast path -inline Value to_value(const Value& v) { return v; } -inline Value to_value(Value&& v) { return std::move(v); } - -template -void from_value(const Value& value, T&& x); - -template -T from_value(const Value& value); - -// load from Value -class ValueInputArchive : public InputArchive { - public: - explicit ValueInputArchive(const Value& data) : data_(data) {} - - template - void init(SizeType& size) { - size = static_cast(data_.size()); - iter_ = data_.begin(); - } - - template - void named_value(std::string& name, T& val) { - name = iter_.key(); - from_value(*iter_, std::forward(val)); - ++iter_; - } - - template - void named_value(const std::string& name, T&& val) { - from_value(data_[name], std::forward(val)); - } - - template - void item(T&& val) { - from_value(*iter_, std::forward(val)); - ++iter_; - } - - template - void native(T&& val) { - data_.get_to(val); - } - - template - void value(T&& value) {} - - private: - const Value& data_; - Value::const_iterator iter_; -}; - -template -void from_value(const Value& value, T&& x) { - ValueInputArchive archive(value); - archive(std::forward(x)); -} - -// Required to avoid Value::Pointer being unwrapped by Value::get_to() -inline void from_value(const Value& value, Value& x) { x = value; } - -template -inline T from_value(const Value& value) { - T x{}; - from_value(value, x); - return x; -} - -namespace detail { - -inline void load(ValueInputArchive& archive, Value& v) { archive.native(v); } - -template , Value>::value, bool> = true> -inline void save(ValueOutputArchive& archive, T&& v) { - archive.native(std::forward(v)); -} - -} // namespace detail +namespace mmdeploy +{ + + 
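Taken together, the declarations above give `json_archive.h` a symmetric pair of entry points between `nlohmann::json` and `mmdeploy::Value`. A usage sketch, assuming only what the header itself declares:

    #include "mmdeploy/archive/json_archive.h"

    // Round-trip an mmdeploy::Value through nlohmann::json with the to_json /
    // from_json overloads declared above; j2 reconstructs the same document.
    inline void round_trip_example()
    {
        nlohmann::json j = nlohmann::json::parse(R"({"thr": 0.3, "labels": [1, 2, 3]})");
        mmdeploy::Value v;
        mmdeploy::from_json(j, v);                 // json -> Value
        nlohmann::json j2 = mmdeploy::to_json(v);  // Value -> json
    }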
template + Value to_value(T&& val); + + // save to Value + class ValueOutputArchive : public OutputArchive + { + public: + explicit ValueOutputArchive(Value& data) + : data_(data) + { + } + + template + void init(array_tag) + { + data_ = ValueType::kArray; + } + + template + void init(object_tag) + { + data_ = ValueType::kObject; + } + + template + void named_value(const std::string& name, T&& val) + { + data_[name] = to_value(std::forward(val)); + } + + template + void item(T&& val) + { + data_.push_back(to_value(std::forward(val))); + } + + template, int> = 0> + void native(T&& val) + { + data_ = std::forward(val); + }; + + private: + Value& data_; + }; + + template + inline Value to_value(T&& val) + { + Value value; + ValueOutputArchive archive(value); + archive(std::forward(val)); + return value; + } + + // fast path + inline Value to_value(const Value& v) + { + return v; + } + inline Value to_value(Value&& v) + { + return std::move(v); + } + + template + void from_value(const Value& value, T&& x); + + template + T from_value(const Value& value); + + // load from Value + class ValueInputArchive : public InputArchive + { + public: + explicit ValueInputArchive(const Value& data) + : data_(data) + { + } + + template + void init(SizeType& size) + { + size = static_cast(data_.size()); + iter_ = data_.begin(); + } + + template + void named_value(std::string& name, T& val) + { + name = iter_.key(); + from_value(*iter_, std::forward(val)); + ++iter_; + } + + template + void named_value(const std::string& name, T&& val) + { + from_value(data_[name], std::forward(val)); + } + + template + void item(T&& val) + { + from_value(*iter_, std::forward(val)); + ++iter_; + } + + template + void native(T&& val) + { + data_.get_to(val); + } + + template + void value(T&& value) + { + } + + private: + const Value& data_; + Value::const_iterator iter_; + }; + + template + void from_value(const Value& value, T&& x) + { + ValueInputArchive archive(value); + archive(std::forward(x)); + } + + // Required to avoid Value::Pointer being unwrapped by Value::get_to() + inline void from_value(const Value& value, Value& x) + { + x = value; + } + + template + inline T from_value(const Value& value) + { + T x{}; + from_value(value, x); + return x; + } + + namespace detail + { + + inline void load(ValueInputArchive& archive, Value& v) + { + archive.native(v); + } + + template, Value>::value, bool> = true> + inline void save(ValueOutputArchive& archive, T&& v) + { + archive.native(std::forward(v)); + } + + } // namespace detail } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/common_cuda_helper.cuh b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/common_cuda_helper.cuh index 02c57c62e6..d1b3195669 100644 --- a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/common_cuda_helper.cuh +++ b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/common_cuda_helper.cuh @@ -8,25 +8,27 @@ #include #define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x) #define THREADS_PER_BLOCK 512 #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) -inline int GET_BLOCKS(const int N) { - int optimal_block_num = DIVUP(N, THREADS_PER_BLOCK); - int max_block_num = 4096; - return std::min(optimal_block_num, max_block_num); +inline int GET_BLOCKS(const int N) +{ + int optimal_block_num = DIVUP(N, THREADS_PER_BLOCK); + int max_block_num = 
4096; + return std::min(optimal_block_num, max_block_num); } -#define cudaCheckError() \ - { \ - cudaError_t e = cudaGetLastError(); \ - if (e != cudaSuccess) { \ - printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ - exit(0); \ - } \ - } +#define cudaCheckError() \ + { \ + cudaError_t e = cudaGetLastError(); \ + if (e != cudaSuccess) \ + { \ + printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(0); \ + } \ + } /** * Returns a view of the original tensor with its dimensions permuted. @@ -38,57 +40,59 @@ inline int GET_BLOCKS(const int N) { * @param[in] src_dim dim of src tensor * @param[in] stream cuda stream handle */ -template -void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size, int* permute, int src_dim, - cudaStream_t stream = 0); - -template -cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, const scalar_t* alpha, - const scalar_t* A, int lda, const scalar_t* B, int ldb, - const scalar_t* beta, scalar_t* C, int ldc); - -template -__device__ scalar_t bilinear_interpolate(const scalar_t* input, const int height, const int width, - scalar_t y, scalar_t x) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) return 0; - - if (y <= 0) y = 0; - if (x <= 0) x = 0; - - int y_low = (int)y; - int x_low = (int)x; - int y_high; - int x_high; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (scalar_t)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (scalar_t)x_low; - } else { - x_high = x_low + 1; - } - - scalar_t ly = y - y_low; - scalar_t lx = x - x_low; - scalar_t hy = 1. - ly, hx = 1. - lx; - // do bilinear interpolation - scalar_t v1 = input[y_low * width + x_low]; - scalar_t v2 = input[y_low * width + x_high]; - scalar_t v3 = input[y_high * width + x_low]; - scalar_t v4 = input[y_high * width + x_high]; - scalar_t w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - - return val; +template +void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size, int* permute, int src_dim, cudaStream_t stream = 0); + +template +cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const scalar_t* alpha, const scalar_t* A, int lda, const scalar_t* B, int ldb, const scalar_t* beta, scalar_t* C, int ldc); + +template +__device__ scalar_t bilinear_interpolate(const scalar_t* input, const int height, const int width, scalar_t y, scalar_t x) +{ + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) return 0; + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) + { + y_high = y_low = height - 1; + y = (scalar_t)y_low; + } + else + { + y_high = y_low + 1; + } + + if (x_low >= width - 1) + { + x_high = x_low = width - 1; + x = (scalar_t)x_low; + } + else + { + x_high = x_low + 1; + } + + scalar_t ly = y - y_low; + scalar_t lx = x - x_low; + scalar_t hy = 1. - ly, hx = 1. 
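A usage sketch for the launch helpers above: a grid-stride elementwise kernel (a hypothetical ReLU) dispatched with `GET_BLOCKS` / `THREADS_PER_BLOCK` and verified with `cudaCheckError()`:

    #include <cuda_runtime.h>

    #include "common_cuda_helper.cuh"

    // Grid-stride loop: every thread strides over the range, so any n is
    // covered even when the grid is clamped to max_block_num blocks.
    __global__ void relu_kernel(const float* in, float* out, int n)
    {
        CUDA_1D_KERNEL_LOOP(i, n)
        {
            out[i] = in[i] > 0.f ? in[i] : 0.f;
        }
    }

    void relu(const float* in, float* out, int n, cudaStream_t stream)
    {
        relu_kernel<<<GET_BLOCKS(n), THREADS_PER_BLOCK, 0, stream>>>(in, out, n);
        cudaCheckError();
    }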
- lx; + // do bilinear interpolation + scalar_t v1 = input[y_low * width + x_low]; + scalar_t v2 = input[y_low * width + x_high]; + scalar_t v3 = input[y_high * width + x_low]; + scalar_t v4 = input[y_high * width + x_high]; + scalar_t w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; } #endif // COMMON_CUDA_HELPER diff --git a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cpu.h b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cpu.h index a37e243109..4bd17cd0d3 100644 --- a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cpu.h +++ b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cpu.h @@ -1,82 +1,83 @@ #include #include -template -T bilinear_interpolate_2d(const T *src, const int64_t src_h, const int64_t src_w, const T h, - const T w) { - if (h <= -1 || src_h <= h || w <= -1 || src_w <= w) { - return 0; - } +template +T bilinear_interpolate_2d(const T* src, const int64_t src_h, const int64_t src_w, const T h, const T w) +{ + if (h <= -1 || src_h <= h || w <= -1 || src_w <= w) + { + return 0; + } - int64_t h_low = floor(h); - int64_t w_low = floor(w); - int64_t h_high = h_low + 1; - int64_t w_high = w_low + 1; + int64_t h_low = floor(h); + int64_t w_low = floor(w); + int64_t h_high = h_low + 1; + int64_t w_high = w_low + 1; - T lh = h - h_low; - T lw = w - w_low; - T hh = 1 - lh; - T hw = 1 - lw; + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh; + T hw = 1 - lw; - T v1 = 0; - if (h_low >= 0 && w_low >= 0) v1 = src[h_low * src_w + w_low]; - T v2 = 0; - if (h_low >= 0 && w_high <= src_w - 1) v2 = src[h_low * src_w + w_high]; - T v3 = 0; - if (h_high <= src_h - 1 && w_low >= 0) v3 = src[h_high * src_w + w_low]; - T v4 = 0; - if (h_high <= src_h - 1 && w_high <= src_w - 1) v4 = src[h_high * src_w + w_high]; + T v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = src[h_low * src_w + w_low]; + T v2 = 0; + if (h_low >= 0 && w_high <= src_w - 1) v2 = src[h_low * src_w + w_high]; + T v3 = 0; + if (h_high <= src_h - 1 && w_low >= 0) v3 = src[h_high * src_w + w_low]; + T v4 = 0; + if (h_high <= src_h - 1 && w_high <= src_w - 1) v4 = src[h_high * src_w + w_high]; - T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; } // output: (channels * kernel_h * kernel_w, dst_h * dst_w) -template -void deformable_im2col_2d(const T *input, const T *offset, const T *mask, const int64_t src_h, - const int64_t src_w, const int64_t kernel_h, const int64_t kernel_w, - const int64_t pad_h, const int64_t pad_w, const int64_t stride_h, - const int64_t stride_w, const int64_t dilation_h, - const int64_t dilation_w, const int64_t channels, - const int64_t offset_groups, const int64_t dst_h, const int64_t dst_w, - const bool use_mask, T *columns) { - const int64_t workload = channels * dst_h * dst_w; - for (int64_t index = 0; index != workload; ++index) { - const int64_t ow = index % dst_w; - const int64_t oh = (index / dst_w) % dst_h; - const int64_t ic = index / (dst_w * dst_h); - const int64_t oc = ic * kernel_h * kernel_w; +template +void deformable_im2col_2d(const T* input, const T* offset, const T* mask, const int64_t src_h, const int64_t src_w, const int64_t kernel_h, const int64_t kernel_w, const int64_t pad_h, const int64_t 
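A quick numeric check of the weights used in `bilinear_interpolate` above: for fractional offsets `(ly, lx)` inside a 2x2 patch, the four weights `hy*hx`, `hy*lx`, `ly*hx`, `ly*lx` always sum to one, so sampling the center of a patch must give the mean of its corners:

    #include <cstdio>

    int main()
    {
        float v1 = 1.f, v2 = 2.f, v3 = 3.f, v4 = 4.f;  // corners (0,0) (0,1) (1,0) (1,1)
        float ly = 0.5f, lx = 0.5f;                    // sample the patch center
        float hy = 1.f - ly, hx = 1.f - lx;
        float val = hy * hx * v1 + hy * lx * v2 + ly * hx * v3 + ly * lx * v4;
        std::printf("%f\n", val);  // prints 2.500000, the mean of the corners
        return 0;
    }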
pad_w, const int64_t stride_h, const int64_t stride_w, const int64_t dilation_h, const int64_t dilation_w, const int64_t channels, const int64_t offset_groups, const int64_t dst_h, const int64_t dst_w, const bool use_mask, T* columns) +{ + const int64_t workload = channels * dst_h * dst_w; + for (int64_t index = 0; index != workload; ++index) + { + const int64_t ow = index % dst_w; + const int64_t oh = (index / dst_w) % dst_h; + const int64_t ic = index / (dst_w * dst_h); + const int64_t oc = ic * kernel_h * kernel_w; - int64_t c_per_offset_grp = channels / offset_groups; - const int64_t grp_idx = ic / c_per_offset_grp; + int64_t c_per_offset_grp = channels / offset_groups; + const int64_t grp_idx = ic / c_per_offset_grp; - auto columns_ptr = columns + (oc * (dst_h * dst_w) + oh * dst_w + ow); - auto input_ptr = input + ic * (src_h * src_w); - auto offset_ptr = offset + grp_idx * 2 * kernel_h * kernel_w * dst_h * dst_w; - auto mask_ptr = mask; - if (use_mask) { - mask_ptr += grp_idx * kernel_h * kernel_w * dst_h * dst_w; - } + auto columns_ptr = columns + (oc * (dst_h * dst_w) + oh * dst_w + ow); + auto input_ptr = input + ic * (src_h * src_w); + auto offset_ptr = offset + grp_idx * 2 * kernel_h * kernel_w * dst_h * dst_w; + auto mask_ptr = mask; + if (use_mask) + { + mask_ptr += grp_idx * kernel_h * kernel_w * dst_h * dst_w; + } - for (int64_t kh = 0; kh < kernel_h; ++kh) { - for (int64_t kw = 0; kw < kernel_w; ++kw) { - const int64_t mask_idx = kh * kernel_w + kw; - const int64_t offset_idx = 2 * mask_idx; + for (int64_t kh = 0; kh < kernel_h; ++kh) + { + for (int64_t kw = 0; kw < kernel_w; ++kw) + { + const int64_t mask_idx = kh * kernel_w + kw; + const int64_t offset_idx = 2 * mask_idx; - T mask_value = 1; - if (use_mask) { - mask_value = mask_ptr[mask_idx * (dst_h * dst_w) + oh * dst_w + ow]; - } + T mask_value = 1; + if (use_mask) + { + mask_value = mask_ptr[mask_idx * (dst_h * dst_w) + oh * dst_w + ow]; + } - const T offset_h = offset_ptr[offset_idx * (dst_h * dst_w) + oh * dst_w + ow]; - const T offset_w = offset_ptr[(offset_idx + 1) * (dst_h * dst_w) + oh * dst_w + ow]; - const T ih = (oh * stride_h - pad_h) + kh * dilation_h + offset_h; - const T iw = (ow * stride_w - pad_w) + kw * dilation_w + offset_w; - *columns_ptr = mask_value * bilinear_interpolate_2d(input_ptr, src_h, src_w, ih, iw); - columns_ptr += dst_h * dst_w; - } + const T offset_h = offset_ptr[offset_idx * (dst_h * dst_w) + oh * dst_w + ow]; + const T offset_w = offset_ptr[(offset_idx + 1) * (dst_h * dst_w) + oh * dst_w + ow]; + const T ih = (oh * stride_h - pad_h) + kh * dilation_h + offset_h; + const T iw = (ow * stride_w - pad_w) + kw * dilation_w + offset_w; + *columns_ptr = mask_value * bilinear_interpolate_2d(input_ptr, src_h, src_w, ih, iw); + columns_ptr += dst_h * dst_w; + } + } } - } } diff --git a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cuda.cuh b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cuda.cuh index 43166e7d6b..6051c4762b 100644 --- a/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cuda.cuh +++ b/csrc/mmdeploy/backend_ops/common/modulated_deform_conv/modulated_deform_conv_cuda.cuh @@ -71,110 +71,130 @@ #include "common_cuda_helper.cuh" -template -__device__ float mdcn_im2col_bilinear(const T *input, const int data_width, const int height, - const int width, float h, float w) { - int h_low = floorf(h); - int w_low = floorf(w); - int h_high = h_low + 1; - int w_high = w_low + 1; - - T lh = h - h_low; - T 
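A hypothetical driver for `deformable_im2col_2d` above, showing how the column buffer is sized as `(channels * kernel_h * kernel_w) x (dst_h * dst_w)` before the convolution is finished by an ordinary GEMM against the `(out_channels, channels * kernel_h * kernel_w)` weight matrix; the kernel size, padding, stride, and dilation here are assumed values:

    #include <cstdint>
    #include <vector>

    #include "modulated_deform_conv_cpu.h"

    void deform_conv_sketch(const float* input, const float* offset, const float* mask,
                            int64_t src_h, int64_t src_w, int64_t channels,
                            int64_t dst_h, int64_t dst_w)
    {
        const int64_t kh = 3, kw = 3;  // assumed 3x3 kernel, pad 1, stride 1, dilation 1
        std::vector<float> columns(channels * kh * kw * dst_h * dst_w);
        deformable_im2col_2d<float>(input, offset, mask, src_h, src_w, kh, kw,
                                    /*pad*/ 1, 1, /*stride*/ 1, 1, /*dilation*/ 1, 1,
                                    channels, /*offset_groups*/ 1, dst_h, dst_w,
                                    /*use_mask*/ true, columns.data());
        // columns is now ready for: output = weight * columns (GEMM omitted)
    }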
lw = w - w_low; - T hh = 1 - lh, hw = 1 - lw; - - T v1 = 0; - if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low]; - T v2 = 0; - if (h_low >= 0 && w_high <= width - 1) v2 = input[h_low * data_width + w_high]; - T v3 = 0; - if (h_high <= height - 1 && w_low >= 0) v3 = input[h_high * data_width + w_low]; - T v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) v4 = input[h_high * data_width + w_high]; - - T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return float(val); +template +__device__ float mdcn_im2col_bilinear(const T* input, const int data_width, const int height, const int width, float h, float w) +{ + int h_low = floorf(h); + int w_low = floorf(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + T lh = h - h_low; + T lw = w - w_low; + T hh = 1 - lh, hw = 1 - lw; + + T v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = input[h_low * data_width + w_low]; + T v2 = 0; + if (h_low >= 0 && w_high <= width - 1) v2 = input[h_low * data_width + w_high]; + T v3 = 0; + if (h_high <= height - 1 && w_low >= 0) v3 = input[h_high * data_width + w_low]; + T v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) v4 = input[h_high * data_width + w_high]; + + T w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return float(val); } -template <> -__device__ float mdcn_im2col_bilinear<__half>(const __half *input, const int data_width, - const int height, const int width, float h, float w) { - int h_low = floorf(h); - int w_low = floorf(w); - int h_high = h_low + 1; - int w_high = w_low + 1; - - float lh = h - h_low; - float lw = w - w_low; - float hh = 1 - lh, hw = 1 - lw; - - float v1 = 0; - if (h_low >= 0 && w_low >= 0) v1 = __half2float(input[h_low * data_width + w_low]); - float v2 = 0; - if (h_low >= 0 && w_high <= width - 1) v2 = __half2float(input[h_low * data_width + w_high]); - float v3 = 0; - if (h_high <= height - 1 && w_low >= 0) v3 = __half2float(input[h_high * data_width + w_low]); - float v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - v4 = __half2float(input[h_high * data_width + w_high]); - - float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; +template<> +__device__ float mdcn_im2col_bilinear<__half>(const __half* input, const int data_width, const int height, const int width, float h, float w) +{ + int h_low = floorf(h); + int w_low = floorf(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h - h_low; + float lw = w - w_low; + float hh = 1 - lh, hw = 1 - lw; + + float v1 = 0; + if (h_low >= 0 && w_low >= 0) v1 = __half2float(input[h_low * data_width + w_low]); + float v2 = 0; + if (h_low >= 0 && w_high <= width - 1) v2 = __half2float(input[h_low * data_width + w_high]); + float v3 = 0; + if (h_high <= height - 1 && w_low >= 0) v3 = __half2float(input[h_high * data_width + w_low]); + float v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = __half2float(input[h_high * data_width + w_high]); + + float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; } -template +template __global__ void modulated_deformable_im2col_gpu_kernel( - const int n, const T *data_im, const T *data_offset, const T *data_mask, const int height, - const int width, const int kernel_h, const int kernel_w, const int pad_h, const int pad_w, - const int stride_h, 
const int stride_w, const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, const int batch_size, const int num_channels, - const int deformable_group, const int height_col, const int width_col, T *data_col) { - CUDA_1D_KERNEL_LOOP(index, n) { - // index index of output matrix - const int w_col = index % width_col; - const int h_col = (index / width_col) % height_col; - const int b_col = (index / width_col / height_col) % batch_size; - const int c_im = (index / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; - - // compute deformable group index - const int deformable_group_index = c_im / channel_per_deformable_group; - - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - - T *data_col_ptr = - data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - const T *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; - const T *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * - 2 * kernel_h * kernel_w * height_col * width_col; - - const T *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * - kernel_h * kernel_w * height_col * width_col; - - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = - ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = - ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; - const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; - const T offset_h = data_offset_ptr[data_offset_h_ptr]; - const T offset_w = data_offset_ptr[data_offset_w_ptr]; - const T mask = data_mask_ptr[data_mask_hw_ptr]; - float val = 0.0f; - const float h_im = h_in + i * dilation_h + (float)offset_h; - const float w_im = w_in + j * dilation_w + (float)offset_w; - if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) - val = mdcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); - *data_col_ptr = (T)(val * (float)mask); - data_col_ptr += batch_size * height_col * width_col; - } + const int n, + const T* data_im, + const T* data_offset, + const T* data_mask, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int num_channels, + const int deformable_group, + const int height_col, + const int width_col, + T* data_col) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + T* data_col_ptr = + data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + const T* data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const T* data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * + 2 * kernel_h * kernel_w * 
height_col * width_col;
+
+        const T* data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) *
+                                                 kernel_h * kernel_w * height_col * width_col;
+
+        for (int i = 0; i < kernel_h; ++i)
+        {
+            for (int j = 0; j < kernel_w; ++j)
+            {
+                const int data_offset_h_ptr =
+                    ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+                const int data_offset_w_ptr =
+                    ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+                const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+                const T offset_h = data_offset_ptr[data_offset_h_ptr];
+                const T offset_w = data_offset_ptr[data_offset_w_ptr];
+                const T mask = data_mask_ptr[data_mask_hw_ptr];
+                float val = 0.0f;
+                const float h_im = h_in + i * dilation_h + (float)offset_h;
+                const float w_im = w_in + j * dilation_w + (float)offset_w;
+                if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+                    val = mdcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+                *data_col_ptr = (T)(val * (float)mask);
+                data_col_ptr += batch_size * height_col * width_col;
+            }
+        }
     }
-  }
 }

 #endif  // TRT_MODULATED_DEFORM_CONV_KERNEL_CUH
diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp
index 4d620e4c82..274ba76bca 100644
--- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp
+++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.cpp
@@ -1,355 +1,402 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 #include "fuse_pass.h"

-void fuse_identity(onnx::GraphProto* mutable_graph,
+void fuse_identity(onnx::GraphProto*                         mutable_graph,
                    std::map<std::string, onnx::TensorProto>& weights,
-                   std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                   int& reduced_node_count) {
-  // fuse
-  // identity --> op
-  // to
-  // noop_reducencnn --> op
-  const int node_count = mutable_graph->node_size();
-  for (int i = 0; i < node_count; ++i) {
-    onnx::NodeProto* node = mutable_graph->mutable_node(i);
-    for (int j = 0; j < node->input_size(); ++j) {
-      std::string output_name = node->input(j);
-      onnx::NodeProto* last_node = find_node_by_output_name(mutable_graph, output_name);
-      if (last_node && last_node->op_type() == "Identity") {
-        node->set_input(j, last_node->input(0));
-        node_reference[last_node->output(0)] -= 1;
-        node_reference[last_node->input(0)] += 1;
-        if (node_reference[last_node->output(0)] == 0) {
-          last_node->set_op_type("noop_reducedncnn");
-          node_reference[last_node->input(0)] -= 1;
-          reduced_node_count += 1;
+                   std::map<std::string, int>&               node_reference,
+                   std::set<std::string>&                    blob_names,
+                   int&                                      reduced_node_count)
+{
+    // fuse
+    // identity --> op
+    // to
+    // noop_reducedncnn --> op
+    const int node_count = mutable_graph->node_size();
+    for (int i = 0; i < node_count; ++i)
+    {
+        onnx::NodeProto* node = mutable_graph->mutable_node(i);
+        for (int j = 0; j < node->input_size(); ++j)
+        {
+            std::string output_name = node->input(j);
+            onnx::NodeProto* last_node = find_node_by_output_name(mutable_graph, output_name);
+            if (last_node && last_node->op_type() == "Identity")
+            {
+                node->set_input(j, last_node->input(0));
+                node_reference[last_node->output(0)] -= 1;
+                node_reference[last_node->input(0)] += 1;
+                if (node_reference[last_node->output(0)] == 0)
+                {
+                    last_node->set_op_type("noop_reducedncnn");
+                    node_reference[last_node->input(0)] -= 1;
+                    reduced_node_count += 1;
+                }
+            }
         }
-      }
     }
-  }
 }

-void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
+void fuse_rewrite_gather(onnx::GraphProto*                         mutable_graph,
                          std::map<std::string, onnx::TensorProto>& weights,
-                         std::map<std::string, int>& node_reference,
-                         std::set<std::string>& blob_names,
int& reduced_node_count) { - const int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; ++i) { - onnx::NodeProto* gather = mutable_graph->mutable_node(i); - if (gather->op_type() != "Gather") { - continue; - } - if (weights.find(std::string(gather->input(1))) == weights.end()) { - continue; - } - auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]); - if (indices.size() != 1) { - continue; - } - + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + const int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; ++i) { - // reconstruct node connections - node_reference[gather->input(1)] -= 1; - std::string origin_inp = gather->input(0); - gather->clear_input(); - gather->add_input(origin_inp); - } + onnx::NodeProto* gather = mutable_graph->mutable_node(i); + if (gather->op_type() != "Gather") + { + continue; + } + if (weights.find(std::string(gather->input(1))) == weights.end()) + { + continue; + } + auto indices = get_node_attr_from_input_ai(weights[gather->input(1)]); + if (indices.size() != 1) + { + continue; + } - { - // update axis, starts and ends - int axis = get_node_attr_i(*gather, "axis", 1) - 1; + { + // reconstruct node connections + node_reference[gather->input(1)] -= 1; + std::string origin_inp = gather->input(0); + gather->clear_input(); + gather->add_input(origin_inp); + } + + { + // update axis, starts and ends + int axis = get_node_attr_i(*gather, "axis", 1) - 1; - gather->set_op_type("Crop"); - gather->clear_attribute(); + gather->set_op_type("Crop"); + gather->clear_attribute(); - int indice = indices[0]; - set_node_attr_ai(*gather, "starts", std::vector{indice}); - set_node_attr_ai(*gather, "ends", std::vector{indice + 1}); - set_node_attr_ai(*gather, "axis", std::vector{axis}); + int indice = indices[0]; + set_node_attr_ai(*gather, "starts", std::vector{indice}); + set_node_attr_ai(*gather, "ends", std::vector{indice + 1}); + set_node_attr_ai(*gather, "axis", std::vector{axis}); + } } - } } -void fuse_weight_reshape(onnx::GraphProto* mutable_graph, +void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // weight <= Reshape(weight) - if (node->op_type() == "Reshape") { - // check weight - if (weights.find(node->input(0)) == weights.end()) continue; - - weights[node->output(0)] = weights[node->input(0)]; - - // set weight shape directly - std::vector shape; - if (node->input_size() == 1) { - shape = get_node_attr_ai(*node, "shape"); - } else if (node->input_size() == 2) { - // opset 5 - shape = get_node_attr_from_input_ai(weights[node->input(1)]); - } - - weights[node->output(0)].clear_dims(); - for (int j = 0; j < shape.size(); j++) { - weights[node->output(0)].add_dims(shape[j]); - } - - // reduce - node->set_op_type("noop_reducedncnn"); - - node_reference[node->input(0)] -= 1; - if (node->input_size() == 2) { - node_reference[node->input(1)] -= 1; - } - - reduced_node_count += 1; - i += 1; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // weight <= Reshape(weight) + if (node->op_type() == "Reshape") + { + // check weight + if 
(weights.find(node->input(0)) == weights.end()) continue; + + weights[node->output(0)] = weights[node->input(0)]; + + // set weight shape directly + std::vector shape; + if (node->input_size() == 1) + { + shape = get_node_attr_ai(*node, "shape"); + } + else if (node->input_size() == 2) + { + // opset 5 + shape = get_node_attr_from_input_ai(weights[node->input(1)]); + } + + weights[node->output(0)].clear_dims(); + for (int j = 0; j < shape.size(); j++) + { + weights[node->output(0)].add_dims(shape[j]); + } + + // reduce + node->set_op_type("noop_reducedncnn"); + + node_reference[node->input(0)] -= 1; + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + + reduced_node_count += 1; + i += 1; + } } - } } -void fuse_weight_transpose(onnx::GraphProto* mutable_graph, +void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // weight <= Transpose(weight) - if (node->op_type() == "Transpose") { - // check weight - if (weights.find(node->input(0)) == weights.end()) continue; - - if (weights[node->input(0)].dims_size() != 2) continue; - - // perm = (1, 0) - std::vector perm = get_node_attr_ai(*node, "perm"); - if (perm.size() != 2) continue; - if (perm[0] != 1 || perm[1] != 0) continue; - - weights[node->output(0)] = weights[node->input(0)]; - - // permute weight - { - onnx::TensorProto& B = weights[node->output(0)]; - - const int h = B.dims(0); - const int w = B.dims(1); - - std::vector permuted_data; - permuted_data.reserve((size_t)h * w); - const float* bptr = - B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data(); - - for (int j = 0; j < w; j++) { - for (int k = 0; k < h; k++) { - float vb = bptr[k * w + j]; - permuted_data.push_back(vb); - } - } - - B.set_dims(0, w); - B.set_dims(1, h); - - if (B.has_raw_data()) { - B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float)); - } else { - for (int j = 0; j < (int)permuted_data.size(); j++) B.set_float_data(j, permuted_data[j]); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // weight <= Transpose(weight) + if (node->op_type() == "Transpose") + { + // check weight + if (weights.find(node->input(0)) == weights.end()) continue; + + if (weights[node->input(0)].dims_size() != 2) continue; + + // perm = (1, 0) + std::vector perm = get_node_attr_ai(*node, "perm"); + if (perm.size() != 2) continue; + if (perm[0] != 1 || perm[1] != 0) continue; + + weights[node->output(0)] = weights[node->input(0)]; + + // permute weight + { + onnx::TensorProto& B = weights[node->output(0)]; + + const int h = B.dims(0); + const int w = B.dims(1); + + std::vector permuted_data; + permuted_data.reserve((size_t)h * w); + const float* bptr = + B.has_raw_data() ? 
(const float*)B.raw_data().data() : B.float_data().data(); + + for (int j = 0; j < w; j++) + { + for (int k = 0; k < h; k++) + { + float vb = bptr[k * w + j]; + permuted_data.push_back(vb); + } + } + + B.set_dims(0, w); + B.set_dims(1, h); + + if (B.has_raw_data()) + { + B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float)); + } + else + { + for (int j = 0; j < (int)permuted_data.size(); j++) B.set_float_data(j, permuted_data[j]); + } + } + + // reduce + node->set_op_type("noop_reducedncnn"); + + node_reference[node->input(0)] -= 1; + + reduced_node_count += 1; + i += 1; } - } - - // reduce - node->set_op_type("noop_reducedncnn"); - - node_reference[node->input(0)] -= 1; - - reduced_node_count += 1; - i += 1; } - } } -void fuse_shufflechannel(onnx::GraphProto* mutable_graph, +void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // ShuffleChannel <= Reshape - Transpose - Reshape - // ShuffleChannel <= Reshape - Transpose - Constant - Reshape - if (node->op_type() == "Reshape") { - if (node_reference[node->output(0)] != 1) continue; - - std::vector shape; - if (node->input_size() == 1) { - shape = get_node_attr_ai(*node, "shape"); - } else { - // skip weight reshape - if (weights.find(node->input(1)) == weights.end()) continue; - - shape = get_node_attr_from_input_ai(weights[node->input(1)]); - } - - // 1 groups channels_per_group, height, width - // reverse style = channels_per_group, groups, height * width - if (shape.size() != 5 && shape.size() != 3) continue; - - if (shape.size() == 5 && shape[0] != 1) continue; - - if (i + 2 >= node_count) continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - - if (node3->op_type() == "Constant") { - if (i + 3 >= node_count) continue; - - node3 = mutable_graph->mutable_node(i + 3); - } - - if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; - - if (node_reference[node2->output(0)] != 1) continue; - - // 0 2 1 3 4 - // reverse style = 1 0 2 - std::vector perm = get_node_attr_ai(*node2, "perm"); - if (perm.size() != 5 && perm.size() != 3) continue; - - if (perm.size() == 5 && - (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4)) - continue; - - if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)) continue; - - std::vector shape3; - if (node3->input_size() == 1) { - shape3 = get_node_attr_ai(*node3, "shape"); - } else { - // skip weight reshape - if (weights.find(node3->input(1)) == weights.end()) continue; - - shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]); - } - - // 1, -1, height, width - // reverse style = group, -1, channels_per_group, height, width - if (shape3.size() != 4 && shape3.size() != 5) continue; - - if (shape3.size() == 4 && - (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2]))) - continue; - - if (shape3.size() == 5 && - (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2])) - continue; - - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - - if (node->input_size() == 2) { - node_reference[node->input(1)] -= 1; - } - node_reference[node->output(0)] -= 1; - node_reference[node2->output(0)] -= 1; - if 
(node3->input_size() == 2) { - node_reference[node3->input(1)] -= 1; - } - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - - node3->set_op_type("ShuffleChannel"); - node3->set_input(0, node->input(0)); - - onnx::AttributeProto* attr_group = node3->add_attribute(); - attr_group->set_name("group"); - attr_group->set_i(shape[1]); - - onnx::AttributeProto* attr_reverse = node3->add_attribute(); - attr_reverse->set_name("reverse"); - attr_reverse->set_i(shape.size() == 3); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // ShuffleChannel <= Reshape - Transpose - Reshape + // ShuffleChannel <= Reshape - Transpose - Constant - Reshape + if (node->op_type() == "Reshape") + { + if (node_reference[node->output(0)] != 1) continue; + + std::vector shape; + if (node->input_size() == 1) + { + shape = get_node_attr_ai(*node, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node->input(1)) == weights.end()) continue; + + shape = get_node_attr_from_input_ai(weights[node->input(1)]); + } + + // 1 groups channels_per_group, height, width + // reverse style = channels_per_group, groups, height * width + if (shape.size() != 5 && shape.size() != 3) continue; + + if (shape.size() == 5 && shape[0] != 1) continue; + + if (i + 2 >= node_count) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + + if (node3->op_type() == "Constant") + { + if (i + 3 >= node_count) continue; + + node3 = mutable_graph->mutable_node(i + 3); + } + + if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; + + if (node_reference[node2->output(0)] != 1) continue; + + // 0 2 1 3 4 + // reverse style = 1 0 2 + std::vector perm = get_node_attr_ai(*node2, "perm"); + if (perm.size() != 5 && perm.size() != 3) continue; + + if (perm.size() == 5 && + (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4)) + continue; + + if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)) continue; + + std::vector shape3; + if (node3->input_size() == 1) + { + shape3 = get_node_attr_ai(*node3, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node3->input(1)) == weights.end()) continue; + + shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]); + } + + // 1, -1, height, width + // reverse style = group, -1, channels_per_group, height, width + if (shape3.size() != 4 && shape3.size() != 5) continue; + + if (shape3.size() == 4 && + (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2]))) + continue; + + if (shape3.size() == 5 && + (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2])) + continue; + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + + node3->set_op_type("ShuffleChannel"); + node3->set_input(0, node->input(0)); + + onnx::AttributeProto* attr_group = node3->add_attribute(); + attr_group->set_name("group"); + attr_group->set_i(shape[1]); + 
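+            // Editor's note (illustrative comment, not part of the original change):
+            // with the 5-term pattern Reshape([1, g, c/g, h, w]) - Transpose(0,2,1,3,4)
+            // - Reshape, shape[1] is the group count and the fused layer runs forward;
+            // with the 3-term pattern Reshape([c/g, g, h*w]) - Transpose(1,0,2) - Reshape,
+            // shape[1] is again the group count but the channel shuffle must run in
+            // reverse mode, which is why reverse is derived from shape.size() == 3 below.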
+ onnx::AttributeProto* attr_reverse = node3->add_attribute(); + attr_reverse->set_name("reverse"); + attr_reverse->set_i(shape.size() == 3); - reduced_node_count += 2; - i += 2; + reduced_node_count += 2; + i += 2; + } } - } } -void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, +void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1) - if (node->op_type() == "ShuffleChannel") { - // reverse = 1 - int reverse = get_node_attr_i(*node, "reverse"); - if (reverse != 1) continue; + // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1) + if (node->op_type() == "ShuffleChannel") + { + // reverse = 1 + int reverse = get_node_attr_i(*node, "reverse"); + if (reverse != 1) continue; - if (i + 2 >= node_count) continue; + if (i + 2 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - if (node2->op_type() != "Gather" || node3->op_type() != "Gather") continue; + if (node2->op_type() != "Gather" || node3->op_type() != "Gather") continue; - if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0)) continue; + if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0)) continue; - // axis = 0 - int gather2_axis = get_node_attr_i(*node2, "axis"); - if (gather2_axis != 0) continue; + // axis = 0 + int gather2_axis = get_node_attr_i(*node2, "axis"); + if (gather2_axis != 0) continue; - // indices = 0 - if (weights.find(node2->input(1)) == weights.end()) continue; + // indices = 0 + if (weights.find(node2->input(1)) == weights.end()) continue; - std::vector gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]); - if (gather2_indices.size() != 1 || gather2_indices[0] != 0) continue; + std::vector gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]); + if (gather2_indices.size() != 1 || gather2_indices[0] != 0) continue; - // axis = 0 - int gather3_axis = get_node_attr_i(*node3, "axis"); - if (gather3_axis != 0) continue; + // axis = 0 + int gather3_axis = get_node_attr_i(*node3, "axis"); + if (gather3_axis != 0) continue; - // indices = 1 - if (weights.find(node3->input(1)) == weights.end()) continue; + // indices = 1 + if (weights.find(node3->input(1)) == weights.end()) continue; - std::vector gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]); - if (gather3_indices.size() != 1 || gather3_indices[0] != 1) continue; + std::vector gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]); + if (gather3_indices.size() != 1 || gather3_indices[0] != 1) continue; - // reduce - node2->set_op_type("noop_reducedncnn"); + // reduce + node2->set_op_type("noop_reducedncnn"); - node_reference[node->output(0)] -= 2; - node_reference[node2->input(1)] -= 1; - node_reference[node3->input(1)] -= 1; + node_reference[node->output(0)] -= 
2;
+            node_reference[node2->input(1)] -= 1;
+            node_reference[node3->input(1)] -= 1;

-    node3->set_op_type("Split");
-    node3->clear_input();
-    node3->add_input(node->output(0));
-    node3->add_output(node3->output(0));
-    node3->set_output(0, node2->output(0));
+            node3->set_op_type("Split");
+            node3->clear_input();
+            node3->add_input(node->output(0));
+            node3->add_output(node3->output(0));
+            node3->set_output(0, node2->output(0));

-    node3->clear_attribute();
-    onnx::AttributeProto* attr_axis = node3->add_attribute();
-    attr_axis->set_name("axis");
-    attr_axis->set_i(1);
+            node3->clear_attribute();
+            onnx::AttributeProto* attr_axis = node3->add_attribute();
+            attr_axis->set_name("axis");
+            attr_axis->set_i(1);

-    reduced_node_count += 1;
-    i += 1;
+            reduced_node_count += 1;
+            i += 1;
+        }
     }
-  }
 }

 /**
@@ -369,2034 +416,2209 @@ void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
  * @param blob_names
  * @param reduced_node_count
  */
-void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
+void fuse_conv_reshape(onnx::GraphProto*                         mutable_graph,
                        std::map<std::string, onnx::TensorProto>& weights,
-                       std::map<std::string, int>& node_reference,
-                       std::set<std::string>& blob_names, int& reduced_node_count) {
-  std::map<std::string, std::vector<int>> shape_context;
-  const int node_count = mutable_graph->node_size();
-
-  for (int i = 0; i < node_count; i++) {
-    onnx::NodeProto* conv = mutable_graph->mutable_node(i);
-
-    if (conv->op_type() != "Conv") {
-      continue;
-    }
-
-    if (i + 4 >= node_count) {
-      continue;
-    }
-
-    onnx::NodeProto *shape = nullptr, *slice = nullptr, *concat = nullptr, *reshape = nullptr;
-
-    // match [Shape ... Slice, Concat ... Reshape] from near sequence, skip useless Constant
-    std::vector<std::tuple<std::string, onnx::NodeProto**>> candidates = {
-        {"Shape", &shape}, {"Slice", &slice}, {"Concat", &concat}, {"Reshape", &reshape}};
+                       std::map<std::string, int>&               node_reference,
+                       std::set<std::string>&                    blob_names,
+                       int&                                      reduced_node_count)
+{
+    std::map<std::string, std::vector<int>> shape_context;
+    const int node_count = mutable_graph->node_size();
+
+    for (int i = 0; i < node_count; i++)
+    {
+        onnx::NodeProto* conv = mutable_graph->mutable_node(i);
-    int MAX = std::min(10, node_count - i - 1);
-    int pos_candidate = 0;
+        if (conv->op_type() != "Conv")
+        {
+            continue;
+        }
-    for (int j = 0; j < MAX; ++j) {
-      auto node_ptr = mutable_graph->mutable_node(j + i + 1);
-      if (node_ptr->op_type() == "Constant") {
-        continue;
-      }
-      if (node_ptr->op_type() == std::get<0>(candidates[pos_candidate])) {
-        *(std::get<1>(candidates[pos_candidate])) = node_ptr;
-        pos_candidate++;
-      }
-    }
+        if (i + 4 >= node_count)
+        {
+            continue;
+        }
-    if (pos_candidate != candidates.size()) {
-      // not match the sequence
-      continue;
-    }
+        onnx::NodeProto * shape = nullptr, *slice = nullptr, *concat = nullptr, *reshape = nullptr;
+
+        // match [Shape ... Slice, Concat ...
Reshape] from near sequence, skip useless Constant + std::vector> candidates = { + {"Shape", &shape}, + {"Slice", &slice}, + {"Concat", &concat}, + {"Reshape", &reshape}}; + + int MAX = std::min(10, node_count - i - 1); + int pos_candidate = 0; + + for (int j = 0; j < MAX; ++j) + { + auto node_ptr = mutable_graph->mutable_node(j + i + 1); + if (node_ptr->op_type() == "Constant") + { + continue; + } + if (node_ptr->op_type() == std::get<0>(candidates[pos_candidate])) + { + *(std::get<1>(candidates[pos_candidate])) = node_ptr; + pos_candidate++; + } + } - if (node_reference[conv->output(0)] != 2 || node_reference[shape->output(0)] != 1 || - node_reference[slice->output(0)] != 1 || node_reference[concat->output(0)] != 1 || - node_reference[reshape->output(0)] != 1) { - continue; - } + if (pos_candidate != candidates.size()) + { + // not match the sequence + continue; + } - // check the connections - if (shape->input(0) != conv->output(0) || reshape->input(0) != conv->output(0)) { - continue; - } - if (slice->input(0) != shape->output(0)) { - continue; - } - if (concat->input(0) != slice->output(0)) { - continue; - } - if (reshape->input(0) != conv->output(0) || reshape->input(1) != concat->output(0)) { - continue; - } + if (node_reference[conv->output(0)] != 2 || node_reference[shape->output(0)] != 1 || + node_reference[slice->output(0)] != 1 || node_reference[concat->output(0)] != 1 || + node_reference[reshape->output(0)] != 1) + { + continue; + } - // add reshape attr - auto result = query_shape(mutable_graph, concat, weights, shape_context); - if (!std::get<0>(result)) { - continue; - } - set_node_attr_ai(*reshape, "shape", std::get<1>(result)); + // check the connections + if (shape->input(0) != conv->output(0) || reshape->input(0) != conv->output(0)) + { + continue; + } + if (slice->input(0) != shape->output(0)) + { + continue; + } + if (concat->input(0) != slice->output(0)) + { + continue; + } + if (reshape->input(0) != conv->output(0) || reshape->input(1) != concat->output(0)) + { + continue; + } - // reconstruct graph - { - // remove reference - node_reference[reshape->input(1)] -= 1; - node_reference[concat->input(0)] -= 1; - node_reference[slice->input(0)] -= 1; - node_reference[shape->input(0)] -= 1; - - // remove tensor/blob on edge - blob_names.erase(slice->input(0)); - blob_names.erase(slice->input(1)); - blob_names.erase(slice->input(2)); - blob_names.erase(slice->input(3)); - weights.erase(slice->input(1)); - weights.erase(slice->input(2)); - weights.erase(slice->input(3)); - - blob_names.erase(concat->input(0)); - blob_names.erase(concat->input(1)); - weights.erase(concat->input(1)); - - blob_names.erase(reshape->input(0)); - - // update edge - shape->clear_input(); - reshape->clear_input(); - reshape->add_input(conv->output(0)); - - shape->set_op_type("noop_reducedncnn"); - slice->set_op_type("noop_reducedncnn"); - concat->set_op_type("noop_reducedncnn"); - - reduced_node_count += 3; + // add reshape attr + auto result = query_shape(mutable_graph, concat, weights, shape_context); + if (!std::get<0>(result)) + { + continue; + } + set_node_attr_ai(*reshape, "shape", std::get<1>(result)); + + // reconstruct graph + { + // remove reference + node_reference[reshape->input(1)] -= 1; + node_reference[concat->input(0)] -= 1; + node_reference[slice->input(0)] -= 1; + node_reference[shape->input(0)] -= 1; + + // remove tensor/blob on edge + blob_names.erase(slice->input(0)); + blob_names.erase(slice->input(1)); + blob_names.erase(slice->input(2)); + blob_names.erase(slice->input(3)); 
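+            // Editor's note (illustrative comment, not part of the original change):
+            // inputs 1..3 of this Slice are the starts/ends/axes initializers of the
+            // matched Shape - Slice - Concat chain; once the computed shape has been
+            // baked into the Reshape "shape" attribute above, those tensors are dead,
+            // so they are dropped from both blob_names and weights.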
+ weights.erase(slice->input(1)); + weights.erase(slice->input(2)); + weights.erase(slice->input(3)); + + blob_names.erase(concat->input(0)); + blob_names.erase(concat->input(1)); + weights.erase(concat->input(1)); + + blob_names.erase(reshape->input(0)); + + // update edge + shape->clear_input(); + reshape->clear_input(); + reshape->add_input(conv->output(0)); + + shape->set_op_type("noop_reducedncnn"); + slice->set_op_type("noop_reducedncnn"); + concat->set_op_type("noop_reducedncnn"); + + reduced_node_count += 3; + } + i += 3; } - i += 3; - } } -void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, +void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // Add/Sub/Mul/Div/Min/Max/Pow - if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || - node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || - node->op_type() == "Pow") { - if (weights.find(node->input(1)) == weights.end()) continue; + // Add/Sub/Mul/Div/Min/Max/Pow + if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || + node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || + node->op_type() == "Pow") + { + if (weights.find(node->input(1)) == weights.end()) continue; - const onnx::TensorProto& scalar_b = weights[node->input(1)]; - if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1) continue; + const onnx::TensorProto& scalar_b = weights[node->input(1)]; + if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1) continue; - float b = get_node_attr_from_input(scalar_b); + float b = get_node_attr_from_input(scalar_b); - node_reference[node->input(1)] -= 1; + node_reference[node->input(1)] -= 1; - std::string input = node->input(0); + std::string input = node->input(0); - node->clear_input(); - node->add_input(input); + node->clear_input(); + node->add_input(input); - onnx::AttributeProto* attr_with_scalar = node->add_attribute(); - attr_with_scalar->set_name("with_scalar"); - attr_with_scalar->set_i(1); + onnx::AttributeProto* attr_with_scalar = node->add_attribute(); + attr_with_scalar->set_name("with_scalar"); + attr_with_scalar->set_i(1); - onnx::AttributeProto* attr_b = node->add_attribute(); - attr_b->set_name("b"); - attr_b->set_f(b); + onnx::AttributeProto* attr_b = node->add_attribute(); + attr_b->set_name("b"); + attr_b->set_f(b); + } } - } } -void fuse_hardswish(onnx::GraphProto* mutable_graph, +void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6) - // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6)) - // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6) - // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6)) - // out = x * F.relu6(x + 
3, inplace=True) / 6 - if (node->op_type() == "Add") { - if (node_reference[node->output(0)] != 1) continue; - - if (i + 3 >= node_count) continue; - - if (weights.find(node->input(1)) == weights.end()) continue; - - const onnx::TensorProto& add_three = weights[node->input(1)]; - if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue; - - float constant_add_three = get_node_attr_from_input(add_three); - if (constant_add_three != 3.f) continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); - - if (node4->op_type() == "Constant") { - if (i + 4 >= node_count) continue; - - node4 = mutable_graph->mutable_node(i + 4); - } - - if (node2->op_type() != "Clip" || node3->op_type() != "Mul" || - (node4->op_type() != "Div" && node4->op_type() != "Mul")) - continue; - - if (node_reference[node2->output(0)] != 1) continue; - - float relu6_min; - float relu6_max; - if (node2->input_size() == 1) { - relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX); - relu6_max = get_node_attr_f(*node2, "max", FLT_MAX); - } else { - const onnx::TensorProto& min_tp = weights[node2->input(1)]; - const onnx::TensorProto& max_tp = weights[node2->input(2)]; - - relu6_min = get_node_attr_from_input(min_tp); - relu6_max = get_node_attr_from_input(max_tp); - } - if (relu6_min != 0.f || relu6_max != 6.f) continue; - - if (node_reference[node3->output(0)] != 1) continue; - - if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0)) continue; - - if (weights.find(node4->input(1)) == weights.end()) continue; - - const onnx::TensorProto& div_six = weights[node4->input(1)]; - if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; - - float constant_div_six = get_node_attr_from_input(div_six); - if (node4->op_type() == "Div" && constant_div_six != 6.f) continue; - if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue; - - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - node3->set_op_type("noop_reducedncnn"); - - node_reference[node->input(0)] -= 1; - node_reference[node->input(1)] -= 1; - node_reference[node->output(0)] -= 1; - if (node2->input_size() == 3) { - node_reference[node2->input(1)] -= 1; - node_reference[node2->input(2)] -= 1; - } - node_reference[node2->output(0)] -= 1; - node_reference[node3->output(0)] -= 1; - node_reference[node4->input(1)] -= 1; - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - blob_names.erase(node3->output(0)); - - node4->set_op_type("HardSwish"); - node4->clear_input(); - node4->add_input(node->input(0)); - - onnx::AttributeProto* attr_alpha = node4->add_attribute(); - attr_alpha->set_name("alpha"); - attr_alpha->set_f(1.f / 6.f); - - onnx::AttributeProto* attr_beta = node4->add_attribute(); - attr_beta->set_name("beta"); - attr_beta->set_f(3.f / 6.f); - - reduced_node_count += 3; - i += 3; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6) + // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6)) + // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6) + // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6)) + // out = 
x * F.relu6(x + 3, inplace=True) / 6 + if (node->op_type() == "Add") + { + if (node_reference[node->output(0)] != 1) continue; + + if (i + 3 >= node_count) continue; + + if (weights.find(node->input(1)) == weights.end()) continue; + + const onnx::TensorProto& add_three = weights[node->input(1)]; + if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue; + + float constant_add_three = get_node_attr_from_input(add_three); + if (constant_add_three != 3.f) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); + + if (node4->op_type() == "Constant") + { + if (i + 4 >= node_count) continue; + + node4 = mutable_graph->mutable_node(i + 4); + } + + if (node2->op_type() != "Clip" || node3->op_type() != "Mul" || + (node4->op_type() != "Div" && node4->op_type() != "Mul")) + continue; + + if (node_reference[node2->output(0)] != 1) continue; + + float relu6_min; + float relu6_max; + if (node2->input_size() == 1) + { + relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX); + relu6_max = get_node_attr_f(*node2, "max", FLT_MAX); + } + else + { + const onnx::TensorProto& min_tp = weights[node2->input(1)]; + const onnx::TensorProto& max_tp = weights[node2->input(2)]; + + relu6_min = get_node_attr_from_input(min_tp); + relu6_max = get_node_attr_from_input(max_tp); + } + if (relu6_min != 0.f || relu6_max != 6.f) continue; + + if (node_reference[node3->output(0)] != 1) continue; + + if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0)) continue; + + if (weights.find(node4->input(1)) == weights.end()) continue; + + const onnx::TensorProto& div_six = weights[node4->input(1)]; + if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; + + float constant_div_six = get_node_attr_from_input(div_six); + if (node4->op_type() == "Div" && constant_div_six != 6.f) continue; + if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue; + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + node3->set_op_type("noop_reducedncnn"); + + node_reference[node->input(0)] -= 1; + node_reference[node->input(1)] -= 1; + node_reference[node->output(0)] -= 1; + if (node2->input_size() == 3) + { + node_reference[node2->input(1)] -= 1; + node_reference[node2->input(2)] -= 1; + } + node_reference[node2->output(0)] -= 1; + node_reference[node3->output(0)] -= 1; + node_reference[node4->input(1)] -= 1; + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + blob_names.erase(node3->output(0)); + + node4->set_op_type("HardSwish"); + node4->clear_input(); + node4->add_input(node->input(0)); + + onnx::AttributeProto* attr_alpha = node4->add_attribute(); + attr_alpha->set_name("alpha"); + attr_alpha->set_f(1.f / 6.f); + + onnx::AttributeProto* attr_beta = node4->add_attribute(); + attr_beta->set_name("beta"); + attr_beta->set_f(3.f / 6.f); + + reduced_node_count += 3; + i += 3; + } } - } - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // HardSwish <= HardSigmoid - Mul - // out = x * hsigmoid(x) - if (node->op_type() == "HardSigmoid") { - if (node_reference[node->output(0)] != 1) continue; + // HardSwish <= HardSigmoid - Mul + // out = x * hsigmoid(x) + if (node->op_type() == "HardSigmoid") + { 
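+            // Editor's note (illustrative comment, not part of the original change):
+            // with the ONNX defaults alpha = 0.2 and beta = 0.5 read below, this branch
+            // rewrites
+            //   y = x * clip(alpha * x + beta, 0, 1)
+            // into a single HardSwish layer reusing the same alpha/beta; e.g. x = 1.0
+            // gives 1.0 * clip(0.7, 0, 1) = 0.7 both before and after the fusion.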
+ if (node_reference[node->output(0)] != 1) continue; - float alpha = get_node_attr_f(*node, "alpha", 0.2f); - float beta = get_node_attr_f(*node, "beta", 0.5f); + float alpha = get_node_attr_f(*node, "alpha", 0.2f); + float beta = get_node_attr_f(*node, "beta", 0.5f); - if (i + 1 >= node_count) continue; + if (i + 1 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - if (node2->op_type() != "Mul") continue; + if (node2->op_type() != "Mul") continue; - if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue; + if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue; - // reduce - node->set_op_type("noop_reducedncnn"); + // reduce + node->set_op_type("noop_reducedncnn"); - node_reference[node->input(0)] -= 1; - node_reference[node->output(0)] -= 1; + node_reference[node->input(0)] -= 1; + node_reference[node->output(0)] -= 1; - blob_names.erase(node->output(0)); + blob_names.erase(node->output(0)); - node2->set_op_type("HardSwish"); - node2->clear_input(); - node2->add_input(node->input(0)); + node2->set_op_type("HardSwish"); + node2->clear_input(); + node2->add_input(node->input(0)); - onnx::AttributeProto* attr_alpha = node2->add_attribute(); - attr_alpha->set_name("alpha"); - attr_alpha->set_f(alpha); + onnx::AttributeProto* attr_alpha = node2->add_attribute(); + attr_alpha->set_name("alpha"); + attr_alpha->set_f(alpha); - onnx::AttributeProto* attr_beta = node2->add_attribute(); - attr_beta->set_name("beta"); - attr_beta->set_f(beta); + onnx::AttributeProto* attr_beta = node2->add_attribute(); + attr_beta->set_name("beta"); + attr_beta->set_f(beta); - reduced_node_count += 1; - i += 1; + reduced_node_count += 1; + i += 1; + } } - } } -void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, +void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6) - // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6)) - // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6) - // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6)) - // out = F.relu6(x + 3, inplace=True) / 6 - if (node->op_type() == "Add") { - if (node_reference[node->output(0)] != 1) continue; - - if (i + 2 >= node_count) continue; - - if (weights.find(node->input(1)) == weights.end()) continue; - - const onnx::TensorProto& add_three = weights[node->input(1)]; - if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue; - - float constant_add_three = get_node_attr_from_input(add_three); - if (constant_add_three != 3.f) continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - - if (node3->op_type() == "Constant") { - if (i + 3 >= node_count) continue; - - node3 = mutable_graph->mutable_node(i + 3); - } - - if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul")) - continue; - - if (node_reference[node2->output(0)] != 1) continue; - - float relu6_min; - float relu6_max; - if (node2->input_size() == 1) { - relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX); - relu6_max = get_node_attr_f(*node2, "max", FLT_MAX); - } else { - const 
onnx::TensorProto& min_tp = weights[node2->input(1)]; - const onnx::TensorProto& max_tp = weights[node2->input(2)]; - - relu6_min = get_node_attr_from_input(min_tp); - relu6_max = get_node_attr_from_input(max_tp); - } - if (relu6_min != 0.f || relu6_max != 6.f) continue; - - if (weights.find(node3->input(1)) == weights.end()) continue; - - const onnx::TensorProto& div_six = weights[node3->input(1)]; - if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; - - float constant_div_six = get_node_attr_from_input(div_six); - if (node3->op_type() == "Div" && constant_div_six != 6.f) continue; - if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue; - - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - - node_reference[node->input(1)] -= 1; - node_reference[node->output(0)] -= 1; - if (node2->input_size() == 3) { - node_reference[node2->input(1)] -= 1; - node_reference[node2->input(2)] -= 1; - } - node_reference[node2->output(0)] -= 1; - node_reference[node3->input(1)] -= 1; - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - - node3->set_op_type("HardSigmoid"); - node3->clear_input(); - node3->add_input(node->input(0)); - - onnx::AttributeProto* attr_alpha = node3->add_attribute(); - attr_alpha->set_name("alpha"); - attr_alpha->set_f(1.f / 6.f); - - onnx::AttributeProto* attr_beta = node3->add_attribute(); - attr_beta->set_name("beta"); - attr_beta->set_f(3.f / 6.f); - - reduced_node_count += 2; - i += 2; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6) + // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6)) + // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6) + // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6)) + // out = F.relu6(x + 3, inplace=True) / 6 + if (node->op_type() == "Add") + { + if (node_reference[node->output(0)] != 1) continue; + + if (i + 2 >= node_count) continue; + + if (weights.find(node->input(1)) == weights.end()) continue; + + const onnx::TensorProto& add_three = weights[node->input(1)]; + if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1) continue; + + float constant_add_three = get_node_attr_from_input(add_three); + if (constant_add_three != 3.f) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + + if (node3->op_type() == "Constant") + { + if (i + 3 >= node_count) continue; + + node3 = mutable_graph->mutable_node(i + 3); + } + + if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul")) + continue; + + if (node_reference[node2->output(0)] != 1) continue; + + float relu6_min; + float relu6_max; + if (node2->input_size() == 1) + { + relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX); + relu6_max = get_node_attr_f(*node2, "max", FLT_MAX); + } + else + { + const onnx::TensorProto& min_tp = weights[node2->input(1)]; + const onnx::TensorProto& max_tp = weights[node2->input(2)]; + + relu6_min = get_node_attr_from_input(min_tp); + relu6_max = get_node_attr_from_input(max_tp); + } + if (relu6_min != 0.f || relu6_max != 6.f) continue; + + if (weights.find(node3->input(1)) == weights.end()) continue; + + const onnx::TensorProto& div_six = weights[node3->input(1)]; + if 
(div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1) continue; + + float constant_div_six = get_node_attr_from_input(div_six); + if (node3->op_type() == "Div" && constant_div_six != 6.f) continue; + if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f) continue; + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + + node_reference[node->input(1)] -= 1; + node_reference[node->output(0)] -= 1; + if (node2->input_size() == 3) + { + node_reference[node2->input(1)] -= 1; + node_reference[node2->input(2)] -= 1; + } + node_reference[node2->output(0)] -= 1; + node_reference[node3->input(1)] -= 1; + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + + node3->set_op_type("HardSigmoid"); + node3->clear_input(); + node3->add_input(node->input(0)); + + onnx::AttributeProto* attr_alpha = node3->add_attribute(); + attr_alpha->set_name("alpha"); + attr_alpha->set_f(1.f / 6.f); + + onnx::AttributeProto* attr_beta = node3->add_attribute(); + attr_beta->set_name("beta"); + attr_beta->set_f(3.f / 6.f); + + reduced_node_count += 2; + i += 2; + } } - } } -void fuse_swish(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); +void fuse_swish(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // Swish <= Sigmoid - Mul - // x * torch.sigmoid(x) - if (node->op_type() == "Sigmoid") { - if (node_reference[node->output(0)] != 1) continue; + // Swish <= Sigmoid - Mul + // x * torch.sigmoid(x) + if (node->op_type() == "Sigmoid") + { + if (node_reference[node->output(0)] != 1) continue; - if (i + 1 >= node_count) continue; + if (i + 1 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - if (node2->op_type() != "Mul") continue; + if (node2->op_type() != "Mul") continue; - if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue; + if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)) continue; - // reduce - node->set_op_type("noop_reducedncnn"); + // reduce + node->set_op_type("noop_reducedncnn"); - node_reference[node->input(0)] -= 1; - node_reference[node->output(0)] -= 1; + node_reference[node->input(0)] -= 1; + node_reference[node->output(0)] -= 1; - blob_names.erase(node->output(0)); + blob_names.erase(node->output(0)); - node2->set_op_type("Swish"); - node2->clear_input(); - node2->add_input(node->input(0)); + node2->set_op_type("Swish"); + node2->clear_input(); + node2->add_input(node->input(0)); - reduced_node_count += 1; - i += 1; + reduced_node_count += 1; + i += 1; + } } - } } -void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, +void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + std::map& node_reference, + std::set& blob_names, + int& 
reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze - if (node->op_type() == "Unsqueeze") { - if (node_reference[node->output(0)] != 1) continue; + // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze + if (node->op_type() == "Unsqueeze") + { + if (node_reference[node->output(0)] != 1) continue; - if (i + 2 >= node_count) continue; + if (i + 2 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze") continue; + if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze") continue; - if (node_reference[node2->output(0)] != 1) continue; + if (node_reference[node2->output(0)] != 1) continue; - if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) continue; + if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) continue; - // reduce - node->set_op_type("noop_reducedncnn"); - node3->set_op_type("noop_reducedncnn"); + // reduce + node->set_op_type("noop_reducedncnn"); + node3->set_op_type("noop_reducedncnn"); - node_reference[node->output(0)] -= 1; - node_reference[node2->output(0)] -= 1; + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); - node2->set_input(0, node->input(0)); - node2->set_output(0, node3->output(0)); + node2->set_input(0, node->input(0)); + node2->set_output(0, node3->output(0)); - reduced_node_count += 2; - i += 2; + reduced_node_count += 2; + i += 2; + } } - } } -void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, +void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // PReLU <= Unsqueeze - PReLU - if (node->op_type() == "Unsqueeze") { - // check weight - if (weights.find(node->input(0)) == weights.end()) continue; + // PReLU <= Unsqueeze - PReLU + if (node->op_type() == "Unsqueeze") + { + // check weight + if (weights.find(node->input(0)) == weights.end()) continue; - onnx::TensorProto& B = weights[node->input(0)]; - if (B.dims_size() != 1) continue; + onnx::TensorProto& B = weights[node->input(0)]; + if (B.dims_size() != 1) continue; - if (node_reference[node->output(0)] != 1) continue; + if (node_reference[node->output(0)] != 1) continue; - // axes = (1, 2) - std::vector axes = get_node_attr_ai(*node, "axes"); - if (axes.size() != 2) continue; - if (axes[0] != 1 || axes[1] != 2) continue; + // axes = (1, 2) + std::vector axes = get_node_attr_ai(*node, "axes"); + if (axes.size() != 2) continue; + if (axes[0] != 1 || axes[1] != 2) continue; - if (i + 1 >= 
node_count) continue; + if (i + 1 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - if (node2->op_type() != "PRelu") continue; + if (node2->op_type() != "PRelu") continue; - if (node2->input(1) != node->output(0)) continue; + if (node2->input(1) != node->output(0)) continue; - // reduce - node->set_op_type("noop_reducedncnn"); + // reduce + node->set_op_type("noop_reducedncnn"); - node_reference[node->output(0)] -= 1; + node_reference[node->output(0)] -= 1; - blob_names.erase(node->output(0)); + blob_names.erase(node->output(0)); - node2->set_input(1, node->input(0)); + node2->set_input(1, node->input(0)); - reduced_node_count += 1; - i += 1; + reduced_node_count += 1; + i += 1; + } } - } } -void fuse_normalize(onnx::GraphProto* mutable_graph, +void fuse_normalize(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // Normalize <= X - ReduceL2 - Clip - Expand - Div - // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div - if (node->op_type() == "ReduceL2") { - if (node_reference[node->output(0)] != 1) continue; - - // axes = (1) - std::vector axes = get_node_attr_ai(*node, "axes"); - if (axes.size() != 1) continue; - if (axes[0] != 1) continue; - - if (i + 3 >= node_count) continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); - - bool has_shape_node = node3->op_type() == "Shape"; - onnx::NodeProto* node_shape = 0; - if (has_shape_node) { - if (i + 4 >= node_count) continue; - - node_shape = node3; - node3 = mutable_graph->mutable_node(i + 3); - node4 = mutable_graph->mutable_node(i + 4); - } - - if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div") - continue; - - if (node_reference[node2->output(0)] != 1) continue; - - if (node_reference[node3->output(0)] != 1) continue; - - if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || - node4->input(0) != node->input(0) || node4->input(1) != node3->output(0)) - continue; - - if (has_shape_node) { - if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0)) - continue; - } - - // +eps - float clip_min; - if (node2->input_size() == 1) { - clip_min = get_node_attr_f(*node2, "min", -FLT_MAX); - } else { - const onnx::TensorProto& min_tp = weights[node2->input(1)]; - - clip_min = get_node_attr_from_input(min_tp); - } - - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - if (has_shape_node) { - node_shape->set_op_type("noop_reducedncnn"); - } - node3->set_op_type("noop_reducedncnn"); - - node_reference[node->input(0)] -= has_shape_node ? 
2 : 1; - node_reference[node->output(0)] -= 1; - node_reference[node2->output(0)] -= 1; - if (has_shape_node) { - node_reference[node_shape->output(0)] -= 1; - } - node_reference[node3->output(0)] -= 1; - if (node3->input_size() == 2) { - node_reference[node3->input(1)] -= 1; - } - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - if (has_shape_node) { - blob_names.erase(node_shape->output(0)); - } - blob_names.erase(node3->output(0)); - - node4->set_op_type("Normalize"); - node4->clear_input(); - node4->add_input(node->input(0)); - - onnx::AttributeProto* attr_alpha = node4->add_attribute(); - attr_alpha->set_name("eps"); - attr_alpha->set_f(clip_min); - - reduced_node_count += has_shape_node ? 4 : 3; - i += has_shape_node ? 4 : 3; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // Normalize <= X - ReduceL2 - Clip - Expand - Div + // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div + if (node->op_type() == "ReduceL2") + { + if (node_reference[node->output(0)] != 1) continue; + + // axes = (1) + std::vector axes = get_node_attr_ai(*node, "axes"); + if (axes.size() != 1) continue; + if (axes[0] != 1) continue; + + if (i + 3 >= node_count) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); + + bool has_shape_node = node3->op_type() == "Shape"; + onnx::NodeProto* node_shape = 0; + if (has_shape_node) + { + if (i + 4 >= node_count) continue; + + node_shape = node3; + node3 = mutable_graph->mutable_node(i + 3); + node4 = mutable_graph->mutable_node(i + 4); + } + + if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div") + continue; + + if (node_reference[node2->output(0)] != 1) continue; + + if (node_reference[node3->output(0)] != 1) continue; + + if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || + node4->input(0) != node->input(0) || node4->input(1) != node3->output(0)) + continue; + + if (has_shape_node) + { + if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0)) + continue; + } + + // +eps + float clip_min; + if (node2->input_size() == 1) + { + clip_min = get_node_attr_f(*node2, "min", -FLT_MAX); + } + else + { + const onnx::TensorProto& min_tp = weights[node2->input(1)]; + + clip_min = get_node_attr_from_input(min_tp); + } + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + if (has_shape_node) + { + node_shape->set_op_type("noop_reducedncnn"); + } + node3->set_op_type("noop_reducedncnn"); + + node_reference[node->input(0)] -= has_shape_node ? 
2 : 1; + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; + if (has_shape_node) + { + node_reference[node_shape->output(0)] -= 1; + } + node_reference[node3->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + if (has_shape_node) + { + blob_names.erase(node_shape->output(0)); + } + blob_names.erase(node3->output(0)); + + node4->set_op_type("Normalize"); + node4->clear_input(); + node4->add_input(node->input(0)); + + onnx::AttributeProto* attr_alpha = node4->add_attribute(); + attr_alpha->set_name("eps"); + attr_alpha->set_f(clip_min); + + reduced_node_count += has_shape_node ? 4 : 3; + i += has_shape_node ? 4 : 3; + } } - } } -void fuse_groupnorm(onnx::GraphProto* mutable_graph, +void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add - if (node->op_type() == "Reshape") { - if (node_reference[node->output(0)] != 1) continue; - - std::vector shape; - if (node->input_size() == 1) { - shape = get_node_attr_ai(*node, "shape"); - } else { - // skip weight reshape - if (weights.find(node->input(1)) == weights.end()) continue; - - shape = get_node_attr_from_input_ai(weights[node->input(1)]); - } - - // 0, group, -1 - if (shape.size() != 3) continue; - - if (shape[0] != 0 || shape[2] != -1) continue; - - int groups = shape[1]; - - if (i + 4 >= node_count) continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); - onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); - - if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || - node4->op_type() != "Mul" || node5->op_type() != "Add") - continue; - - if (node_reference[node2->output(0)] != 1) continue; - - if (node_reference[node3->output(0)] != 1) continue; - - if (node_reference[node4->output(0)] != 1) continue; - - if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || - node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0)) - continue; - - // +eps - float eps = get_node_attr_f(*node2, "epsilon", 1e-05f); - - // InstanceNormalization S=1 B=0 - std::vector S = get_node_attr_from_input_af(weights[node2->input(1)]); - std::vector B = get_node_attr_from_input_af(weights[node2->input(2)]); - if ((int)S.size() != groups || (int)B.size() != groups) continue; - - bool instancenorm_affine = false; - for (int j = 0; j < groups; j++) { - if (S[j] != 1.f || B[j] != 0.f) { - instancenorm_affine = true; - break; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add + if (node->op_type() == "Reshape") + { + if (node_reference[node->output(0)] != 1) continue; + + std::vector shape; + if (node->input_size() == 1) + { + shape = get_node_attr_ai(*node, "shape"); + } + else + { + // skip weight reshape + if 
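// --- Illustrative sketch (not part of the converter): the fuse_normalize pass
// above relies on x / max(||x||_2, eps) collapsing into a single Normalize op.
// A minimal standalone check of that identity on a toy channel vector:
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const float eps = 1e-10f;  // plays the role of the Clip "min" input/attribute
    std::vector<float> x = {3.f, 4.f};

    float l2 = 0.f;  // ReduceL2 over the channel axis
    for (float v : x) l2 += v * v;
    l2 = std::sqrt(l2);

    float denom = std::max(l2, eps);  // Clip(min=eps)
    for (float v : x) std::printf("%f ", v / denom);  // final Div: prints 0.6 0.8
    std::printf("\n");
    return 0;
}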
(weights.find(node->input(1)) == weights.end()) continue; + + shape = get_node_attr_from_input_ai(weights[node->input(1)]); + } + + // 0, group, -1 + if (shape.size() != 3) continue; + + if (shape[0] != 0 || shape[2] != -1) continue; + + int groups = shape[1]; + + if (i + 4 >= node_count) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); + onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); + + if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || + node4->op_type() != "Mul" || node5->op_type() != "Add") + continue; + + if (node_reference[node2->output(0)] != 1) continue; + + if (node_reference[node3->output(0)] != 1) continue; + + if (node_reference[node4->output(0)] != 1) continue; + + if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || + node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0)) + continue; + + // +eps + float eps = get_node_attr_f(*node2, "epsilon", 1e-05f); + + // InstanceNormalization S=1 B=0 + std::vector S = get_node_attr_from_input_af(weights[node2->input(1)]); + std::vector B = get_node_attr_from_input_af(weights[node2->input(2)]); + if ((int)S.size() != groups || (int)B.size() != groups) continue; + + bool instancenorm_affine = false; + for (int j = 0; j < groups; j++) + { + if (S[j] != 1.f || B[j] != 0.f) + { + instancenorm_affine = true; + break; + } + } + + if (instancenorm_affine) continue; + + std::vector shape2; + if (node3->input_size() == 1) + { + shape2 = get_node_attr_ai(*node3, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node3->input(1)) == weights.end()) continue; + + shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]); + } + + // 1, channels, w, h + if (shape2.size() != 4) continue; + + if (shape2[0] != 1) continue; + + int channels = shape2[1]; + + // affine + std::vector affine_S = get_node_attr_from_input_af(weights[node4->input(1)]); + std::vector affine_B = get_node_attr_from_input_af(weights[node5->input(1)]); + if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && + affine_B[0] == 0.f) + { + // no affine + } + else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels) + { + // we only allow per-channel affine + continue; + } + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + node3->set_op_type("noop_reducedncnn"); + node4->set_op_type("noop_reducedncnn"); + + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + node_reference[node->output(0)] -= 1; + node_reference[node2->input(1)] -= 1; + node_reference[node2->input(2)] -= 1; + node_reference[node2->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } + node_reference[node3->output(0)] -= 1; + node_reference[node4->output(0)] -= 1; + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + blob_names.erase(node3->output(0)); + blob_names.erase(node4->output(0)); + + std::string affine_scale = node4->input(1); + std::string affine_bias = node5->input(1); + + node5->set_op_type("GroupNorm"); + node5->clear_input(); + node5->add_input(node->input(0)); + node5->add_input(affine_scale); + node5->add_input(affine_bias); + + onnx::AttributeProto* attr_groups = node5->add_attribute(); + attr_groups->set_name("groups"); + attr_groups->set_i(groups); + + 
onnx::AttributeProto* attr_channels = node5->add_attribute(); + attr_channels->set_name("channels"); + attr_channels->set_i(channels); + + onnx::AttributeProto* attr_eps = node5->add_attribute(); + attr_eps->set_name("epsilon"); + attr_eps->set_f(eps); + + onnx::AttributeProto* attr_affine = node5->add_attribute(); + attr_affine->set_name("affine"); + attr_affine->set_i(1); + + reduced_node_count += 4; + i += 4; } - } - - if (instancenorm_affine) continue; - - std::vector shape2; - if (node3->input_size() == 1) { - shape2 = get_node_attr_ai(*node3, "shape"); - } else { - // skip weight reshape - if (weights.find(node3->input(1)) == weights.end()) continue; - - shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]); - } - - // 1, channels, w, h - if (shape2.size() != 4) continue; - - if (shape2[0] != 1) continue; - - int channels = shape2[1]; - - // affine - std::vector affine_S = get_node_attr_from_input_af(weights[node4->input(1)]); - std::vector affine_B = get_node_attr_from_input_af(weights[node5->input(1)]); - if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && - affine_B[0] == 0.f) { - // no affine - } else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels) { - // we only allow per-channel affine - continue; - } - - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - node3->set_op_type("noop_reducedncnn"); - node4->set_op_type("noop_reducedncnn"); - - if (node->input_size() == 2) { - node_reference[node->input(1)] -= 1; - } - node_reference[node->output(0)] -= 1; - node_reference[node2->input(1)] -= 1; - node_reference[node2->input(2)] -= 1; - node_reference[node2->output(0)] -= 1; - if (node3->input_size() == 2) { - node_reference[node3->input(1)] -= 1; - } - node_reference[node3->output(0)] -= 1; - node_reference[node4->output(0)] -= 1; - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - blob_names.erase(node3->output(0)); - blob_names.erase(node4->output(0)); - - std::string affine_scale = node4->input(1); - std::string affine_bias = node5->input(1); - - node5->set_op_type("GroupNorm"); - node5->clear_input(); - node5->add_input(node->input(0)); - node5->add_input(affine_scale); - node5->add_input(affine_bias); - - onnx::AttributeProto* attr_groups = node5->add_attribute(); - attr_groups->set_name("groups"); - attr_groups->set_i(groups); - - onnx::AttributeProto* attr_channels = node5->add_attribute(); - attr_channels->set_name("channels"); - attr_channels->set_i(channels); - - onnx::AttributeProto* attr_eps = node5->add_attribute(); - attr_eps->set_name("epsilon"); - attr_eps->set_f(eps); - - onnx::AttributeProto* attr_affine = node5->add_attribute(); - attr_affine->set_name("affine"); - attr_affine->set_i(1); - - reduced_node_count += 4; - i += 4; } - } } -void fuse_layernorm(onnx::GraphProto* mutable_graph, +void fuse_layernorm(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - - // Mul - Add - if (node->op_type() == "ReduceMean") { - if (node_reference[node->output(0)] != 1) continue; - - std::vector axes = get_node_attr_ai(*node, "axes"); - - // -1 - // -2 -1 - if (axes.size() != 
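// --- Illustrative sketch (toy sizes, not converter code): the fuse_groupnorm
// pass above collapses Reshape(0,G,-1) -> InstanceNormalization(S=1,B=0) ->
// Reshape -> Mul -> Add into one GroupNorm. The arithmetic being matched, for
// one sample, before the per-channel Mul/Add affine pair:
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int channels = 4, groups = 2, hw = 2;
    const float eps = 1e-5f;  // the "epsilon" attribute read from node2
    std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};  // [C, H*W], flattened

    const int per_group = channels / groups * hw;
    for (int g = 0; g < groups; g++)
    {
        float mean = 0.f, var = 0.f;
        for (int j = 0; j < per_group; j++) mean += x[g * per_group + j];
        mean /= per_group;
        for (int j = 0; j < per_group; j++)
        {
            float d = x[g * per_group + j] - mean;
            var += d * d;
        }
        var /= per_group;
        for (int j = 0; j < per_group; j++)
            std::printf("%f ", (x[g * per_group + j] - mean) / std::sqrt(var + eps));
    }
    std::printf("\n");
    return 0;
}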
1 && axes.size() != 2) continue; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - int normed_axes = (int)axes.size(); - if (normed_axes == 1 && axes[0] != -1) continue; - if (normed_axes == 2 && (axes[0] != -2 || axes[1] != -1)) continue; + // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div + // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - + // Mul - Add + if (node->op_type() == "ReduceMean") + { + if (node_reference[node->output(0)] != 1) continue; - if (i + 6 >= node_count) continue; + std::vector axes = get_node_attr_ai(*node, "axes"); - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); - onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); - onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); - onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); + // -1 + // -2 -1 + if (axes.size() != 1 && axes.size() != 2) continue; - if (node2->op_type() != "Sub" || node3->op_type() != "Pow" || - node4->op_type() != "ReduceMean" || node5->op_type() != "Add" || - node6->op_type() != "Sqrt" || node7->op_type() != "Div") - continue; + int normed_axes = (int)axes.size(); + if (normed_axes == 1 && axes[0] != -1) continue; + if (normed_axes == 2 && (axes[0] != -2 || axes[1] != -1)) continue; - if (node_reference[node2->output(0)] != 2) continue; + if (i + 6 >= node_count) continue; - if (node_reference[node3->output(0)] != 1) continue; + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); + onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); + onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); + onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); - if (node_reference[node4->output(0)] != 1) continue; + if (node2->op_type() != "Sub" || node3->op_type() != "Pow" || + node4->op_type() != "ReduceMean" || node5->op_type() != "Add" || + node6->op_type() != "Sqrt" || node7->op_type() != "Div") + continue; - if (node_reference[node5->output(0)] != 1) continue; + if (node_reference[node2->output(0)] != 2) continue; - if (node_reference[node6->output(0)] != 1) continue; + if (node_reference[node3->output(0)] != 1) continue; - if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0) || - node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) || - node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) || - node7->input(0) != node2->output(0) || node7->input(1) != node6->output(0)) - continue; + if (node_reference[node4->output(0)] != 1) continue; - if (weights.find(node3->input(1)) == weights.end()) continue; + if (node_reference[node5->output(0)] != 1) continue; - const onnx::TensorProto& pow_two = weights[node3->input(1)]; - if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1) continue; + if (node_reference[node6->output(0)] != 1) continue; - float constant_pow_two = get_node_attr_from_input(pow_two); - if (constant_pow_two != 2.f) continue; + if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0) || + node3->input(0) != node2->output(0) || node4->input(0) != 
node3->output(0) || + node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) || + node7->input(0) != node2->output(0) || node7->input(1) != node6->output(0)) + continue; - std::vector axes4 = get_node_attr_ai(*node4, "axes"); + if (weights.find(node3->input(1)) == weights.end()) continue; - // -1 - // -2 -1 - if ((int)axes4.size() != normed_axes) continue; + const onnx::TensorProto& pow_two = weights[node3->input(1)]; + if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1) continue; - if (normed_axes == 1 && axes4[0] != -1) continue; - if (normed_axes == 2 && (axes4[0] != -2 || axes4[1] != -1)) continue; + float constant_pow_two = get_node_attr_from_input(pow_two); + if (constant_pow_two != 2.f) continue; - if (weights.find(node5->input(1)) == weights.end()) continue; + std::vector axes4 = get_node_attr_ai(*node4, "axes"); - const onnx::TensorProto& add_eps = weights[node5->input(1)]; - if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1) continue; + // -1 + // -2 -1 + if ((int)axes4.size() != normed_axes) continue; - float eps = get_node_attr_from_input(add_eps); + if (normed_axes == 1 && axes4[0] != -1) continue; + if (normed_axes == 2 && (axes4[0] != -2 || axes4[1] != -1)) continue; - int affine = 0; - while (i + 8 < node_count) { - onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); - onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); + if (weights.find(node5->input(1)) == weights.end()) continue; - if (node8->op_type() != "Mul" || node9->op_type() != "Add") break; + const onnx::TensorProto& add_eps = weights[node5->input(1)]; + if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1) continue; - if (node_reference[node7->output(0)] != 1) break; + float eps = get_node_attr_from_input(add_eps); - if (node_reference[node8->output(0)] != 1) break; + int affine = 0; + while (i + 8 < node_count) + { + onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); + onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); - if (node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0)) break; + if (node8->op_type() != "Mul" || node9->op_type() != "Add") break; - // affine - std::vector affine_S = get_node_attr_from_input_af(weights[node8->input(1)]); - std::vector affine_B = get_node_attr_from_input_af(weights[node9->input(1)]); - if (affine_S.size() != affine_B.size()) break; + if (node_reference[node7->output(0)] != 1) break; - affine = 1; - break; - } + if (node_reference[node8->output(0)] != 1) break; - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - node3->set_op_type("noop_reducedncnn"); - node4->set_op_type("noop_reducedncnn"); - node5->set_op_type("noop_reducedncnn"); - node6->set_op_type("noop_reducedncnn"); + if (node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0)) break; - node_reference[node->input(0)] -= 1; - node_reference[node2->input(0)] -= 1; - node_reference[node2->input(1)] -= 1; - node_reference[node3->input(0)] -= 1; - node_reference[node3->input(1)] -= 1; - node_reference[node4->input(0)] -= 1; - node_reference[node5->input(0)] -= 1; - node_reference[node5->input(1)] -= 1; - node_reference[node6->input(0)] -= 1; - node_reference[node7->input(0)] -= 1; - node_reference[node7->input(1)] -= 1; + // affine + std::vector affine_S = get_node_attr_from_input_af(weights[node8->input(1)]); + std::vector affine_B = get_node_attr_from_input_af(weights[node9->input(1)]); + if (affine_S.size() != 
affine_B.size()) break; - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - blob_names.erase(node3->output(0)); - blob_names.erase(node4->output(0)); - blob_names.erase(node5->output(0)); - blob_names.erase(node6->output(0)); + affine = 1; + break; + } - node_reference[node->input(0)] += 1; + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + node3->set_op_type("noop_reducedncnn"); + node4->set_op_type("noop_reducedncnn"); + node5->set_op_type("noop_reducedncnn"); + node6->set_op_type("noop_reducedncnn"); - if (affine == 0) { - node7->set_op_type("LayerNorm"); - node7->clear_input(); - node7->add_input(node->input(0)); + node_reference[node->input(0)] -= 1; + node_reference[node2->input(0)] -= 1; + node_reference[node2->input(1)] -= 1; + node_reference[node3->input(0)] -= 1; + node_reference[node3->input(1)] -= 1; + node_reference[node4->input(0)] -= 1; + node_reference[node5->input(0)] -= 1; + node_reference[node5->input(1)] -= 1; + node_reference[node6->input(0)] -= 1; + node_reference[node7->input(0)] -= 1; + node_reference[node7->input(1)] -= 1; - onnx::AttributeProto* attr_eps = node7->add_attribute(); - attr_eps->set_name("epsilon"); - attr_eps->set_f(eps); + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + blob_names.erase(node3->output(0)); + blob_names.erase(node4->output(0)); + blob_names.erase(node5->output(0)); + blob_names.erase(node6->output(0)); - onnx::AttributeProto* attr_affine = node7->add_attribute(); - attr_affine->set_name("affine"); - attr_affine->set_i(affine); + node_reference[node->input(0)] += 1; - reduced_node_count += 6; - i += 6; - } else // if (affine == 1) - { - onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); - onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); + if (affine == 0) + { + node7->set_op_type("LayerNorm"); + node7->clear_input(); + node7->add_input(node->input(0)); - node7->set_op_type("noop_reducedncnn"); - node8->set_op_type("noop_reducedncnn"); + onnx::AttributeProto* attr_eps = node7->add_attribute(); + attr_eps->set_name("epsilon"); + attr_eps->set_f(eps); - node_reference[node8->input(0)] -= 1; - node_reference[node9->input(0)] -= 1; + onnx::AttributeProto* attr_affine = node7->add_attribute(); + attr_affine->set_name("affine"); + attr_affine->set_i(affine); - blob_names.erase(node7->output(0)); - blob_names.erase(node8->output(0)); + reduced_node_count += 6; + i += 6; + } + else // if (affine == 1) + { + onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); + onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); - std::string affine_scale = node8->input(1); - std::string affine_bias = node9->input(1); + node7->set_op_type("noop_reducedncnn"); + node8->set_op_type("noop_reducedncnn"); - node9->set_op_type("LayerNorm"); - node9->clear_input(); - node9->add_input(node->input(0)); - node9->add_input(affine_scale); - node9->add_input(affine_bias); - - onnx::AttributeProto* attr_eps = node9->add_attribute(); - attr_eps->set_name("epsilon"); - attr_eps->set_f(eps); - - onnx::AttributeProto* attr_affine = node9->add_attribute(); - attr_affine->set_name("affine"); - attr_affine->set_i(affine); - - reduced_node_count += 8; - i += 8; - } + node_reference[node8->input(0)] -= 1; + node_reference[node9->input(0)] -= 1; + + blob_names.erase(node7->output(0)); + blob_names.erase(node8->output(0)); + + std::string affine_scale = node8->input(1); + std::string affine_bias = node9->input(1); + + node9->set_op_type("LayerNorm"); 
+ node9->clear_input(); + node9->add_input(node->input(0)); + node9->add_input(affine_scale); + node9->add_input(affine_bias); + + onnx::AttributeProto* attr_eps = node9->add_attribute(); + attr_eps->set_name("epsilon"); + attr_eps->set_f(eps); + + onnx::AttributeProto* attr_affine = node9->add_attribute(); + attr_affine->set_name("affine"); + attr_affine->set_i(affine); + + reduced_node_count += 8; + i += 8; + } + } } - } } -void fuse_flatten(onnx::GraphProto* mutable_graph, +void fuse_flatten(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - // - Reshape - if (node->op_type() == "Shape") { - if (node_reference[node->output(0)] != 1) continue; - - if (i + 6 >= node_count) continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); - onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); - onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); - onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); - - if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || - node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze" || - node6->op_type() != "Concat" || node7->op_type() != "Reshape") - continue; - - if (node_reference[node2->output(0)] != 1) continue; - - // if (node_reference[node3->output(0)] != 1) - // continue; - - if (node_reference[node4->output(0)] != 1) continue; - - if (node_reference[node5->output(0)] != 1) continue; - - if (node_reference[node6->output(0)] != 1) continue; - - if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || - node5->input(0) != node3->output(0) || node6->input(0) != node4->output(0) || - node6->input(1) != node5->output(0) || node7->input(0) != node->input(0) || - node7->input(1) != node6->output(0)) - continue; - - // axis = 0 - int gather_axis = get_node_attr_i(*node2, "axis"); - if (gather_axis != 0) continue; - - // indices = 0 - if (weights.find(node2->input(1)) == weights.end()) continue; - - std::vector gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]); - if (gather_indices.size() != 1 || gather_indices[0] != 0) continue; - - // axes = (0) - std::vector unsqueeze_axes = get_node_attr_ai(*node4, "axes"); - if (unsqueeze_axes.size() != 1) continue; - if (unsqueeze_axes[0] != 0) continue; - - // axes = (0) - std::vector unsqueeze2_axes = get_node_attr_ai(*node5, "axes"); - if (unsqueeze2_axes.size() != 1) continue; - if (unsqueeze2_axes[0] != 0) continue; - - // data = -1 - if (weights.find(node5->input(0)) == weights.end()) continue; - - std::vector unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]); - if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1) continue; - - // axis = 0 - int concat_axis = get_node_attr_i(*node6, "axis"); - if (concat_axis != 0) continue; - - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - // node3->set_op_type("noop_reducedncnn"); - node4->set_op_type("noop_reducedncnn"); - node5->set_op_type("noop_reducedncnn"); - node6->set_op_type("noop_reducedncnn"); - - node_reference[node->input(0)] -= 1; - 
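// --- Illustrative sketch (toy values): the seven-node chain matched by
// fuse_layernorm above, ReduceMean - Sub - Pow(2) - ReduceMean - Add(eps) -
// Sqrt - Div, is exactly (x - mean) / sqrt(var + eps), i.e. LayerNorm without
// the optional Mul/Add affine tail:
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const float eps = 1e-5f;  // the scalar the pass reads from the Add weight
    std::vector<float> x = {1.f, 2.f, 3.f, 4.f};

    float mean = 0.f;
    for (float v : x) mean += v;
    mean /= x.size();  // ReduceMean(axes=[-1])

    float var = 0.f;
    for (float v : x) var += (v - mean) * (v - mean);  // Sub + Pow(2)
    var /= x.size();  // second ReduceMean

    float denom = std::sqrt(var + eps);  // Add(eps) + Sqrt
    for (float v : x) std::printf("%f ", (v - mean) / denom);  // final Div
    std::printf("\n");
    return 0;
}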
node_reference[node->output(0)] -= 1; - node_reference[node2->input(1)] -= 1; - node_reference[node2->output(0)] -= 1; - // node_reference[node3->output(0)] -= 1; - node_reference[node4->output(0)] -= 1; - node_reference[node5->input(0)] -= 1; - node_reference[node5->output(0)] -= 1; - node_reference[node6->output(0)] -= 1; - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - // blob_names.erase(node3->output(0)); - blob_names.erase(node4->output(0)); - blob_names.erase(node5->output(0)); - blob_names.erase(node6->output(0)); - - node7->set_op_type("Flatten"); - node7->clear_input(); - node7->add_input(node->input(0)); - - reduced_node_count += 5; - i += 5; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat + // - Reshape + if (node->op_type() == "Shape") + { + if (node_reference[node->output(0)] != 1) continue; + + if (i + 6 >= node_count) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); + onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); + onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); + onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); + + if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || + node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze" || + node6->op_type() != "Concat" || node7->op_type() != "Reshape") + continue; + + if (node_reference[node2->output(0)] != 1) continue; + + // if (node_reference[node3->output(0)] != 1) + // continue; + + if (node_reference[node4->output(0)] != 1) continue; + + if (node_reference[node5->output(0)] != 1) continue; + + if (node_reference[node6->output(0)] != 1) continue; + + if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || + node5->input(0) != node3->output(0) || node6->input(0) != node4->output(0) || + node6->input(1) != node5->output(0) || node7->input(0) != node->input(0) || + node7->input(1) != node6->output(0)) + continue; + + // axis = 0 + int gather_axis = get_node_attr_i(*node2, "axis"); + if (gather_axis != 0) continue; + + // indices = 0 + if (weights.find(node2->input(1)) == weights.end()) continue; + + std::vector gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]); + if (gather_indices.size() != 1 || gather_indices[0] != 0) continue; + + // axes = (0) + std::vector unsqueeze_axes = get_node_attr_ai(*node4, "axes"); + if (unsqueeze_axes.size() != 1) continue; + if (unsqueeze_axes[0] != 0) continue; + + // axes = (0) + std::vector unsqueeze2_axes = get_node_attr_ai(*node5, "axes"); + if (unsqueeze2_axes.size() != 1) continue; + if (unsqueeze2_axes[0] != 0) continue; + + // data = -1 + if (weights.find(node5->input(0)) == weights.end()) continue; + + std::vector unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]); + if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1) continue; + + // axis = 0 + int concat_axis = get_node_attr_i(*node6, "axis"); + if (concat_axis != 0) continue; + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + // node3->set_op_type("noop_reducedncnn"); + node4->set_op_type("noop_reducedncnn"); + 
node5->set_op_type("noop_reducedncnn"); + node6->set_op_type("noop_reducedncnn"); + + node_reference[node->input(0)] -= 1; + node_reference[node->output(0)] -= 1; + node_reference[node2->input(1)] -= 1; + node_reference[node2->output(0)] -= 1; + // node_reference[node3->output(0)] -= 1; + node_reference[node4->output(0)] -= 1; + node_reference[node5->input(0)] -= 1; + node_reference[node5->output(0)] -= 1; + node_reference[node6->output(0)] -= 1; + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + // blob_names.erase(node3->output(0)); + blob_names.erase(node4->output(0)); + blob_names.erase(node5->output(0)); + blob_names.erase(node6->output(0)); + + node7->set_op_type("Flatten"); + node7->clear_input(); + node7->add_input(node->input(0)); + + reduced_node_count += 5; + i += 5; + } } - } } -void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, +void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // PixelShuffle <= Reshape - Transpose - Reshape - // PixelShuffle <= Reshape - Transpose - Constant - Reshape - if (node->op_type() == "Reshape") { - if (node_reference[node->output(0)] != 1) continue; + // PixelShuffle <= Reshape - Transpose - Reshape + // PixelShuffle <= Reshape - Transpose - Constant - Reshape + if (node->op_type() == "Reshape") + { + if (node_reference[node->output(0)] != 1) continue; - std::vector shape; - if (node->input_size() == 1) { - shape = get_node_attr_ai(*node, "shape"); - } else { - // skip weight reshape - if (weights.find(node->input(1)) == weights.end()) continue; + std::vector shape; + if (node->input_size() == 1) + { + shape = get_node_attr_ai(*node, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node->input(1)) == weights.end()) continue; - shape = get_node_attr_from_input_ai(weights[node->input(1)]); - } + shape = get_node_attr_from_input_ai(weights[node->input(1)]); + } - // -1, 3, upscale_factor, upscale_factor, height, width - if (shape.size() != 6) continue; + // -1, 3, upscale_factor, upscale_factor, height, width + if (shape.size() != 6) continue; - if (shape[0] != 1 && shape[0] != -1) continue; + if (shape[0] != 1 && shape[0] != -1) continue; - if (shape[2] != shape[3]) continue; + if (shape[2] != shape[3]) continue; - if (i + 2 >= node_count) continue; + if (i + 2 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - if (node3->op_type() == "Constant") { - if (i + 3 >= node_count) continue; + if (node3->op_type() == "Constant") + { + if (i + 3 >= node_count) continue; - node3 = mutable_graph->mutable_node(i + 3); - } + node3 = mutable_graph->mutable_node(i + 3); + } - if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; + if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; - if (node_reference[node2->output(0)] != 1) continue; + if 
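// --- Illustrative sketch (toy shape): the subgraph matched by fuse_flatten
// above only assembles the target shape {batch, -1} at runtime: Shape ->
// Gather(axis=0, indices=0) picks the batch dim, the Constant -1 is
// unsqueezed, and Concat feeds Reshape. No data moves at all:
#include <cstdio>
#include <vector>

int main()
{
    std::vector<long> dims = {2, 3, 4, 5};  // input shape, NCHW
    long batch = dims[0];                   // Gather(axis=0, indices=0)
    long rest = 1;                          // what the Constant -1 resolves to
    for (size_t d = 1; d < dims.size(); d++) rest *= dims[d];
    std::printf("Flatten: [%ld, %ld]\n", batch, rest);  // [2, 60]
    return 0;
}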
(node_reference[node2->output(0)] != 1) continue; - // 0 1 4 2 5 3 - std::vector perm = get_node_attr_ai(*node2, "perm"); - if (perm.size() != 6) continue; + // 0 1 4 2 5 3 + std::vector perm = get_node_attr_ai(*node2, "perm"); + if (perm.size() != 6) continue; - if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 || - perm[5] != 3) - continue; + if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 || + perm[5] != 3) + continue; - std::vector shape3; - if (node3->input_size() == 1) { - shape3 = get_node_attr_ai(*node3, "shape"); - } else { - // skip weight reshape - if (weights.find(node3->input(1)) == weights.end()) continue; + std::vector shape3; + if (node3->input_size() == 1) + { + shape3 = get_node_attr_ai(*node3, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node3->input(1)) == weights.end()) continue; - shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]); - } + shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]); + } - // -1, 3, height, width - if (shape3.size() != 4) continue; + // -1, 3, height, width + if (shape3.size() != 4) continue; - if (shape3[0] != 1 && shape3[0] != -1) continue; + if (shape3[0] != 1 && shape3[0] != -1) continue; - if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] || - shape3[3] != shape[3] * shape[5]) - continue; + if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] || + shape3[3] != shape[3] * shape[5]) + continue; - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); - if (node->input_size() == 2) { - node_reference[node->input(1)] -= 1; - } - node_reference[node->output(0)] -= 1; - node_reference[node2->output(0)] -= 1; - if (node3->input_size() == 2) { - node_reference[node3->input(1)] -= 1; - } + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); - node3->set_op_type("PixelShuffle"); - node3->set_input(0, node->input(0)); + node3->set_op_type("PixelShuffle"); + node3->set_input(0, node->input(0)); - onnx::AttributeProto* attr_group = node3->add_attribute(); - attr_group->set_name("scale_factor"); - attr_group->set_i(shape[2]); + onnx::AttributeProto* attr_group = node3->add_attribute(); + attr_group->set_name("scale_factor"); + attr_group->set_i(shape[2]); - reduced_node_count += 2; - i += 2; + reduced_node_count += 2; + i += 2; + } } - } } -void fuse_reorg(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, std::set& blob_names, - int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); +void fuse_reorg(onnx::GraphProto* mutable_graph, std::map& weights, std::map& node_reference, std::set& blob_names, int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // PixelShuffle <= Reshape - Transpose - Reshape - // PixelShuffle <= Reshape - Transpose - Constant - Reshape - if (node->op_type() == "Reshape") { - 
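// --- Illustrative sketch (single channel, upscale_factor 2): PixelShuffle is
// literally the matched Reshape(-1,c,r,r,h,w) -> Transpose(0,1,4,2,5,3) ->
// Reshape(-1,c,h*r,w*r). Tracing each output pixel through that permutation:
#include <cstdio>

int main()
{
    const int r = 2, h = 2, w = 2;
    // output (0, oy, ox) reads input channel (oy % r) * r + (ox % r)
    // at spatial position (oy / r, ox / r) -- exactly what the perm encodes
    for (int oy = 0; oy < h * r; oy++)
        for (int ox = 0; ox < w * r; ox++)
            std::printf("out(0,%d,%d) <- in(%d,%d,%d)\n",
                        oy, ox, (oy % r) * r + (ox % r), oy / r, ox / r);
    return 0;
}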
if (node_reference[node->output(0)] != 1) continue; + // PixelShuffle <= Reshape - Transpose - Reshape + // PixelShuffle <= Reshape - Transpose - Constant - Reshape + if (node->op_type() == "Reshape") + { + if (node_reference[node->output(0)] != 1) continue; - std::vector shape; - if (node->input_size() == 1) { - shape = get_node_attr_ai(*node, "shape"); - } else { - // skip weight reshape - if (weights.find(node->input(1)) == weights.end()) continue; + std::vector shape; + if (node->input_size() == 1) + { + shape = get_node_attr_ai(*node, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node->input(1)) == weights.end()) continue; - shape = get_node_attr_from_input_ai(weights[node->input(1)]); - } + shape = get_node_attr_from_input_ai(weights[node->input(1)]); + } - // -1, 3, out_height, block_size, out_width, block_size - if (shape.size() != 6) continue; + // -1, 3, out_height, block_size, out_width, block_size + if (shape.size() != 6) continue; - if (shape[0] != 1 && shape[0] != -1) continue; + if (shape[0] != 1 && shape[0] != -1) continue; - if (shape[3] != shape[5]) continue; + if (shape[3] != shape[5]) continue; - if (i + 2 >= node_count) continue; + if (i + 2 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - if (node3->op_type() == "Constant") { - if (i + 3 >= node_count) continue; + if (node3->op_type() == "Constant") + { + if (i + 3 >= node_count) continue; - node3 = mutable_graph->mutable_node(i + 3); - } + node3 = mutable_graph->mutable_node(i + 3); + } - if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; + if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; - if (node_reference[node2->output(0)] != 1) continue; + if (node_reference[node2->output(0)] != 1) continue; - // 0 1 3 5 2 4 - std::vector perm = get_node_attr_ai(*node2, "perm"); - if (perm.size() != 6) continue; + // 0 1 3 5 2 4 + std::vector perm = get_node_attr_ai(*node2, "perm"); + if (perm.size() != 6) continue; - if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 || - perm[5] != 4) - continue; + if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 || + perm[5] != 4) + continue; - std::vector shape3; - if (node3->input_size() == 1) { - shape3 = get_node_attr_ai(*node3, "shape"); - } else { - // skip weight reshape - if (weights.find(node3->input(1)) == weights.end()) continue; + std::vector shape3; + if (node3->input_size() == 1) + { + shape3 = get_node_attr_ai(*node3, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node3->input(1)) == weights.end()) continue; - shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]); - } + shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]); + } - // -1, out_channels, out_height, out_width - if (shape3.size() != 4) continue; + // -1, out_channels, out_height, out_width + if (shape3.size() != 4) continue; - if (shape3[0] != 1 && shape3[0] != -1) continue; + if (shape3[0] != 1 && shape3[0] != -1) continue; - if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] || - shape3[3] != shape[4]) - continue; + if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] || + shape3[3] != shape[4]) + continue; - // reduce - node->set_op_type("noop_reducedncnn"); - 
node2->set_op_type("noop_reducedncnn"); + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); - if (node->input_size() == 2) { - node_reference[node->input(1)] -= 1; - } - node_reference[node->output(0)] -= 1; - node_reference[node2->output(0)] -= 1; - if (node3->input_size() == 2) { - node_reference[node3->input(1)] -= 1; - } + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); - node3->set_op_type("Reorg"); - node3->set_input(0, node->input(0)); + node3->set_op_type("Reorg"); + node3->set_input(0, node->input(0)); - onnx::AttributeProto* attr_group = node3->add_attribute(); - attr_group->set_name("stride"); - attr_group->set_i(shape[3]); + onnx::AttributeProto* attr_group = node3->add_attribute(); + attr_group->set_name("stride"); + attr_group->set_i(shape[3]); - reduced_node_count += 2; - i += 2; + reduced_node_count += 2; + i += 2; + } } - } } -void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, +void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max - if (node->op_type() == "Expand") { - if (node_reference[node->output(0)] != 1) continue; + // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max + if (node->op_type() == "Expand") + { + if (node_reference[node->output(0)] != 1) continue; - if (i + 1 >= node_count) continue; + if (i + 1 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" && - node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max") - continue; + if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" && + node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max") + continue; - if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0)) continue; + if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0)) continue; - // reduce - node->set_op_type("noop_reducedncnn"); + // reduce + node->set_op_type("noop_reducedncnn"); - node_reference[node->output(0)] -= 1; - if (node->input_size() == 2) { - node_reference[node->input(1)] -= 1; - } + node_reference[node->output(0)] -= 1; + if (node->input_size() == 2) + { + node_reference[node->input(1)] -= 1; + } - blob_names.erase(node->output(0)); + blob_names.erase(node->output(0)); - if (node2->input(0) == node->output(0)) { - node2->set_input(0, node->input(0)); - } else { - node2->set_input(1, node->input(0)); - } + if (node2->input(0) == node->output(0)) + { + node2->set_input(0, node->input(0)); + } + 
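// --- Illustrative sketch (single input channel, stride 2): Reorg
// (space-to-depth) is the matched Reshape(-1,c,oh,s,ow,s) ->
// Transpose(0,1,3,5,2,4) -> Reshape(-1,c*s*s,oh,ow). Mapping each output
// element back to the input:
#include <cstdio>

int main()
{
    const int s = 2, oh = 2, ow = 2;
    for (int oc = 0; oc < s * s; oc++)  // c * s * s output channels, c == 1
        for (int y = 0; y < oh; y++)
            for (int x = 0; x < ow; x++)
                std::printf("out(%d,%d,%d) <- in(0,%d,%d)\n",
                            oc, y, x, y * s + oc / s, x * s + oc % s);
    return 0;
}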
else + { + node2->set_input(1, node->input(0)); + } - reduced_node_count += 1; - i += 1; + reduced_node_count += 1; + i += 1; + } } - } } -void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, +void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose - // or LSTM(bi) <= LSTM(bi) - Transpose Constant - Reshape - Transpose - if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") { - if (node_reference[node->output(0)] != 1) continue; + // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose + // or LSTM(bi) <= LSTM(bi) - Transpose Constant - Reshape - Transpose + if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") + { + if (node_reference[node->output(0)] != 1) continue; - if (i + 2 >= node_count) continue; + if (i + 2 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - // skip if second ops is constant - if (node3->op_type() == "Constant") { - if (i + 3 >= node_count) continue; - node3 = mutable_graph->mutable_node(i + 3); - i += 1; - } + // skip if second ops is constant + if (node3->op_type() == "Constant") + { + if (i + 3 >= node_count) continue; + node3 = mutable_graph->mutable_node(i + 3); + i += 1; + } - if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; + if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape") continue; - if (node_reference[node2->output(0)] != 1) continue; + if (node_reference[node2->output(0)] != 1) continue; - if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) continue; + if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)) continue; - std::string direction = get_node_attr_s(*node, "direction"); - if (direction != "bidirectional") continue; + std::string direction = get_node_attr_s(*node, "direction"); + if (direction != "bidirectional") continue; - // 0 2 1 3 - std::vector perm = get_node_attr_ai(*node2, "perm"); - if (perm.size() != 4) continue; + // 0 2 1 3 + std::vector perm = get_node_attr_ai(*node2, "perm"); + if (perm.size() != 4) continue; - if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3) continue; + if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3) continue; - std::vector shape; - if (node3->input_size() == 1) { - shape = get_node_attr_ai(*node3, "shape"); - } else { - // skip weight reshape - if (weights.find(node3->input(1)) == weights.end()) continue; + std::vector shape; + if (node3->input_size() == 1) + { + shape = get_node_attr_ai(*node3, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node3->input(1)) == weights.end()) continue; - shape = get_node_attr_from_input_ai(weights[node3->input(1)]); - } + shape = get_node_attr_from_input_ai(weights[node3->input(1)]); + } - // 0 0 -1 - if 
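// --- Illustrative sketch (toy shapes): fuse_expand_broadcast above drops the
// Expand because ONNX binary ops broadcast numpy-style on their own, so
// pre-expanding one operand is a no-op. {3,1} + {1,4} gives the same {3,4}
// result whether or not either side is expanded first:
#include <cstdio>

int main()
{
    float a[3] = {1, 2, 3};         // logical shape {3, 1}
    float b[4] = {10, 20, 30, 40};  // logical shape {1, 4}
    for (int i = 0; i < 3; i++)
        for (int j = 0; j < 4; j++)
            std::printf("%g%c", a[i] + b[j], j == 3 ? '\n' : ' ');
    return 0;
}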
(shape.size() != 3) continue; + // 0 0 -1 + if (shape.size() != 3) continue; - if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1) continue; + if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1) continue; - // reduce - node2->set_op_type("noop_reducedncnn"); - node3->set_op_type("noop_reducedncnn"); + // reduce + node2->set_op_type("noop_reducedncnn"); + node3->set_op_type("noop_reducedncnn"); - node_reference[node->output(0)] -= 1; - node_reference[node2->output(0)] -= 1; - if (node3->input_size() == 2) { - node_reference[node3->input(1)] -= 1; - } + node_reference[node->output(0)] -= 1; + node_reference[node2->output(0)] -= 1; + if (node3->input_size() == 2) + { + node_reference[node3->input(1)] -= 1; + } - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); - node->set_output(0, node3->output(0)); + node->set_output(0, node3->output(0)); - reduced_node_count += 2; - i += 2; + reduced_node_count += 2; + i += 2; - if (i + 1 < node_count) { - if (node_reference[node3->output(0)] != 1) continue; + if (i + 1 < node_count) + { + if (node_reference[node3->output(0)] != 1) continue; - onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1); - if (node4->op_type() != "Transpose") continue; + if (node4->op_type() != "Transpose") continue; - if (node4->input(0) != node->output(0)) continue; + if (node4->input(0) != node->output(0)) continue; - // 1 0 2 - std::vector perm4 = get_node_attr_ai(*node4, "perm"); - if (perm4.size() != 3) continue; + // 1 0 2 + std::vector perm4 = get_node_attr_ai(*node4, "perm"); + if (perm4.size() != 3) continue; - if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) continue; + if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) continue; - // reduce - node4->set_op_type("noop_reducedncnn"); + // reduce + node4->set_op_type("noop_reducedncnn"); - node_reference[node->output(0)] -= 1; + node_reference[node->output(0)] -= 1; - blob_names.erase(node->output(0)); + blob_names.erase(node->output(0)); - node->set_output(0, node4->output(0)); + node->set_output(0, node4->output(0)); - reduced_node_count += 1; - i += 1; - } + reduced_node_count += 1; + i += 1; + } + } } - } - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose - if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") { - if (node_reference[node->output(0)] != 1) continue; + // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose + if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN") + { + if (node_reference[node->output(0)] != 1) continue; - if (i + 1 >= node_count) continue; + if (i + 1 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - if (node2->op_type() != "Squeeze") continue; + if (node2->op_type() != "Squeeze") continue; - if (node2->input(0) != node->output(0)) continue; + if (node2->input(0) != node->output(0)) continue; - std::string direction = get_node_attr_s(*node, "direction"); - if (direction == "bidirectional") continue; + std::string direction = get_node_attr_s(*node, "direction"); + if (direction == "bidirectional") continue; - // 1 - std::vector axes = get_node_attr_ai(*node2, 
"axes"); - if (axes.size() != 1) continue; + // 1 + std::vector axes = get_node_attr_ai(*node2, "axes"); + if (axes.size() != 1) continue; - if (axes[0] != 1) continue; + if (axes[0] != 1) continue; - // reduce - node2->set_op_type("noop_reducedncnn"); + // reduce + node2->set_op_type("noop_reducedncnn"); - node_reference[node->output(0)] -= 1; + node_reference[node->output(0)] -= 1; - blob_names.erase(node->output(0)); + blob_names.erase(node->output(0)); - node->set_output(0, node2->output(0)); + node->set_output(0, node2->output(0)); - reduced_node_count += 1; - i += 1; + reduced_node_count += 1; + i += 1; - if (i + 1 < node_count) { - if (node_reference[node2->output(0)] != 1) continue; + if (i + 1 < node_count) + { + if (node_reference[node2->output(0)] != 1) continue; - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1); - if (node3->op_type() != "Transpose") continue; + if (node3->op_type() != "Transpose") continue; - if (node3->input(0) != node->output(0)) continue; + if (node3->input(0) != node->output(0)) continue; - // 1 0 2 - std::vector perm4 = get_node_attr_ai(*node3, "perm"); - if (perm4.size() != 3) continue; + // 1 0 2 + std::vector perm4 = get_node_attr_ai(*node3, "perm"); + if (perm4.size() != 3) continue; - if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) continue; + if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2) continue; - // reduce - node3->set_op_type("noop_reducedncnn"); + // reduce + node3->set_op_type("noop_reducedncnn"); - node_reference[node->output(0)] -= 1; + node_reference[node->output(0)] -= 1; - blob_names.erase(node->output(0)); + blob_names.erase(node->output(0)); - node->set_output(0, node3->output(0)); + node->set_output(0, node3->output(0)); - reduced_node_count += 1; - i += 1; - } + reduced_node_count += 1; + i += 1; + } + } } - } - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); - // LSTM <= Transpose - LSTM - if (node->op_type() == "Transpose") { - if (node_reference[node->output(0)] != 1) continue; + // LSTM <= Transpose - LSTM + if (node->op_type() == "Transpose") + { + if (node_reference[node->output(0)] != 1) continue; - // 1 0 2 - std::vector perm = get_node_attr_ai(*node, "perm"); - if (perm.size() != 3) continue; + // 1 0 2 + std::vector perm = get_node_attr_ai(*node, "perm"); + if (perm.size() != 3) continue; - if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2) continue; + if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2) continue; - if (i + 1 >= node_count) continue; + if (i + 1 >= node_count) continue; - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN") - continue; + if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN") + continue; - if (node2->input(0) != node->output(0)) continue; + if (node2->input(0) != node->output(0)) continue; - // reduce - node->set_op_type("noop_reducedncnn"); + // reduce + node->set_op_type("noop_reducedncnn"); - node_reference[node->output(0)] -= 1; + node_reference[node->output(0)] -= 1; - blob_names.erase(node->output(0)); + blob_names.erase(node->output(0)); - node2->set_input(0, node->input(0)); + node2->set_input(0, node->input(0)); - reduced_node_count += 1; - i += 1; + 
reduced_node_count += 1; + i += 1; + } } - } } -void fuse_multiheadattention(onnx::GraphProto* mutable_graph, +void fuse_multiheadattention(onnx::GraphProto* mutable_graph, std::map& weights, - std::map& node_reference, - std::set& blob_names, int& reduced_node_count) { - int node_count = mutable_graph->node_size(); - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // MultiHeadAttention <= MatMul(q) - Add - // - MatMul(k) - Add - // - MatMul(v) - Add - // - Mul - // - Reshape - Transpose - // - Reshape - Reshape - Transpose - Transpose - // - Gemm - Softmax - Gemm - Transpose - Reshape - - // MatMul - Add - if (node->op_type() == "MatMul") { - if (i + 19 >= node_count) continue; - - if (node_reference[node->output(0)] != 1) continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); - onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); - onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); - onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); - onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); - onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); - onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9); - onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10); - onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11); - onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12); - onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13); - onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14); - onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15); - onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16); - onnx::NodeProto* node18 = mutable_graph->mutable_node(i + 17); - onnx::NodeProto* node19 = mutable_graph->mutable_node(i + 18); - onnx::NodeProto* node20 = mutable_graph->mutable_node(i + 19); - - if (node2->op_type() != "Add" || node3->op_type() != "MatMul" || node4->op_type() != "Add" || - node5->op_type() != "MatMul" || node6->op_type() != "Add" || node7->op_type() != "Mul" || - node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || - node10->op_type() != "Reshape" || node11->op_type() != "Reshape" || - node12->op_type() != "Transpose" || node13->op_type() != "Transpose" || - node14->op_type() != "MatMul" || node15->op_type() != "Softmax" || - node16->op_type() != "MatMul" || node17->op_type() != "Transpose" || - node18->op_type() != "Reshape" || node19->op_type() != "MatMul" || - node20->op_type() != "Add") - continue; - - if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || - node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || - node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || - node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || - node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || - node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || - node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || - node_reference[node16->output(0)] != 1 || node_reference[node17->output(0)] != 1 || - node_reference[node18->output(0)] != 1 || node_reference[node19->output(0)] != 1) - continue; - - if (node2->input(0) != node->output(0) || node4->input(0) != node3->output(0) || - 
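// --- Illustrative sketch (toy sizes): ONNX emits bidirectional LSTM/GRU/RNN
// output as [T, directions, N, H]; the Transpose(0,2,1,3) + Reshape(0,0,-1)
// pair matched above only relayouts it to [T, N, directions*H], which the
// fused node can emit directly:
#include <cstdio>

int main()
{
    const int T = 2, D = 2, N = 1, H = 3;
    float src[T][D][N][H];
    for (int t = 0, v = 0; t < T; t++)
        for (int d = 0; d < D; d++)
            for (int n = 0; n < N; n++)
                for (int h = 0; h < H; h++) src[t][d][n][h] = (float)v++;

    // dst[t][n][d * H + h] = src[t][d][n][h]; print the relayouted rows
    for (int t = 0; t < T; t++)
        for (int n = 0; n < N; n++)
        {
            for (int d = 0; d < D; d++)
                for (int h = 0; h < H; h++) std::printf("%g ", src[t][d][n][h]);
            std::printf("\n");
        }
    return 0;
}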
node6->input(0) != node5->output(0) || node7->input(0) != node2->output(0) || - node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0) || - node10->input(0) != node4->output(0) || node11->input(0) != node6->output(0) || - node12->input(0) != node11->output(0) || node13->input(0) != node10->output(0) || - node14->input(0) != node9->output(0) || node14->input(1) != node13->output(0) || - node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || - node16->input(1) != node12->output(0) || node17->input(0) != node16->output(0) || - node18->input(0) != node17->output(0) || node19->input(0) != node18->output(0) || - node20->input(0) != node19->output(0)) - continue; - - std::vector q_B = get_node_attr_from_input_af(weights[node2->input(1)]); - std::vector k_B = get_node_attr_from_input_af(weights[node4->input(1)]); - std::vector v_B = get_node_attr_from_input_af(weights[node6->input(1)]); - std::vector o_B = get_node_attr_from_input_af(weights[node20->input(1)]); - - if (q_B.size() != k_B.size() || q_B.size() != v_B.size() || q_B.size() != o_B.size()) - continue; - - int embed_dim = q_B.size(); - - // 1 0 2 - std::vector perm9 = get_node_attr_ai(*node9, "perm"); - std::vector perm12 = get_node_attr_ai(*node12, "perm"); - if (perm9.size() != 3 || perm12.size() != 3) continue; - - if (perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2 || perm12[0] != 1 || perm12[1] != 0 || - perm12[2] != 2) - continue; - - // 1 2 0 - std::vector perm13 = get_node_attr_ai(*node13, "perm"); - if (perm13.size() != 3) continue; - - if (perm13[0] != 1 || perm13[1] != 2 || perm13[2] != 0) continue; - - // 1 0 2 - std::vector perm17 = get_node_attr_ai(*node17, "perm"); - if (perm17.size() != 3) continue; - - if (perm17[0] != 1 || perm17[1] != 0 || perm17[2] != 2) continue; - - int softmax_axis = get_node_attr_i(*node15, "axis"); - if (softmax_axis != 2) continue; - - // 1/-1, seqlen * num_heads, embed_dim / num_heads - std::vector shape8; - std::vector shape10; - std::vector shape11; - if (node8->input_size() == 1) { - shape8 = get_node_attr_ai(*node8, "shape"); - } else { - // skip weight reshape - if (weights.find(node8->input(1)) == weights.end()) continue; - - shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]); - } - if (node10->input_size() == 1) { - shape10 = get_node_attr_ai(*node10, "shape"); - } else { - // skip weight reshape - if (weights.find(node10->input(1)) == weights.end()) continue; - - shape10 = get_node_attr_from_input_ai(weights[node10->input(1)]); - } - if (node11->input_size() == 1) { - shape11 = get_node_attr_ai(*node11, "shape"); - } else { - // skip weight reshape - if (weights.find(node11->input(1)) == weights.end()) continue; - - shape11 = get_node_attr_from_input_ai(weights[node11->input(1)]); - } - - if (shape8.size() != 3 || shape10.size() != 3 || shape11.size() != 3) continue; - - if (shape8[1] != shape10[1] || shape8[1] != shape11[1] || shape8[2] != shape10[2] || - shape8[2] != shape11[2]) - continue; - - int num_heads = embed_dim / shape8[2]; - - // 1, seqlen, embed_dim - std::vector shape18; - if (node18->input_size() == 1) { - shape18 = get_node_attr_ai(*node18, "shape"); - } else { - // skip weight reshape - if (weights.find(node18->input(1)) == weights.end()) continue; - - shape18 = get_node_attr_from_input_ai(weights[node18->input(1)]); - } - - if (shape18.size() != 3) continue; - - if (shape18[2] != embed_dim || shape18[1] * num_heads != shape8[1]) continue; - - // reduce - node->set_op_type("noop_reducedncnn"); - 
node2->set_op_type("noop_reducedncnn"); - node3->set_op_type("noop_reducedncnn"); - node4->set_op_type("noop_reducedncnn"); - node5->set_op_type("noop_reducedncnn"); - node6->set_op_type("noop_reducedncnn"); - node7->set_op_type("noop_reducedncnn"); - node8->set_op_type("noop_reducedncnn"); - node9->set_op_type("noop_reducedncnn"); - node10->set_op_type("noop_reducedncnn"); - node11->set_op_type("noop_reducedncnn"); - node12->set_op_type("noop_reducedncnn"); - node13->set_op_type("noop_reducedncnn"); - node14->set_op_type("noop_reducedncnn"); - node15->set_op_type("noop_reducedncnn"); - node16->set_op_type("noop_reducedncnn"); - node17->set_op_type("noop_reducedncnn"); - node18->set_op_type("noop_reducedncnn"); - node19->set_op_type("noop_reducedncnn"); - - node_reference[node2->input(0)] -= 1; - node_reference[node4->input(0)] -= 1; - node_reference[node6->input(0)] -= 1; - node_reference[node7->input(0)] -= 1; - node_reference[node7->input(1)] -= 1; - node_reference[node8->input(0)] -= 1; - if (node8->input_size() == 2) { - node_reference[node8->input(1)] -= 1; - } - node_reference[node9->input(0)] -= 1; - node_reference[node10->input(0)] -= 1; - if (node10->input_size() == 2) { - node_reference[node10->input(1)] -= 1; - } - node_reference[node11->input(0)] -= 1; - if (node11->input_size() == 2) { - node_reference[node11->input(1)] -= 1; - } - node_reference[node12->input(0)] -= 1; - node_reference[node13->input(0)] -= 1; - node_reference[node14->input(0)] -= 1; - node_reference[node14->input(1)] -= 1; - node_reference[node15->input(0)] -= 1; - node_reference[node16->input(0)] -= 1; - node_reference[node16->input(1)] -= 1; - node_reference[node17->input(0)] -= 1; - node_reference[node18->input(0)] -= 1; - if (node18->input_size() == 2) { - node_reference[node18->input(1)] -= 1; - } - node_reference[node19->input(0)] -= 1; - node_reference[node20->input(0)] -= 1; - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - blob_names.erase(node3->output(0)); - blob_names.erase(node4->output(0)); - blob_names.erase(node5->output(0)); - blob_names.erase(node6->output(0)); - blob_names.erase(node7->output(0)); - blob_names.erase(node8->output(0)); - blob_names.erase(node9->output(0)); - blob_names.erase(node10->output(0)); - blob_names.erase(node11->output(0)); - blob_names.erase(node12->output(0)); - blob_names.erase(node13->output(0)); - blob_names.erase(node14->output(0)); - blob_names.erase(node15->output(0)); - blob_names.erase(node16->output(0)); - blob_names.erase(node17->output(0)); - blob_names.erase(node18->output(0)); - blob_names.erase(node19->output(0)); - - std::string qw = node->input(1); - std::string qb = node2->input(1); - std::string kw = node3->input(1); - std::string kb = node4->input(1); - std::string vw = node5->input(1); - std::string vb = node6->input(1); - std::string ow = node19->input(1); - std::string ob = node20->input(1); - - node20->set_op_type("MultiHeadAttention"); - node20->clear_input(); - node20->add_input(node->input(0)); - node20->add_input(node3->input(0)); - node20->add_input(node5->input(0)); - // q - node20->add_input(qw); - node20->add_input(qb); - // k - node20->add_input(kw); - node20->add_input(kb); - // v - node20->add_input(vw); - node20->add_input(vb); - // out linear - node20->add_input(ow); - node20->add_input(ob); - - onnx::AttributeProto* attr_embed_dim = node20->add_attribute(); - attr_embed_dim->set_name("embed_dim"); - attr_embed_dim->set_i(embed_dim); - - onnx::AttributeProto* attr_num_heads = 
node20->add_attribute(); - attr_num_heads->set_name("num_heads"); - attr_num_heads->set_i(num_heads); - - reduced_node_count += 19; - i += 19; + std::map& node_reference, + std::set& blob_names, + int& reduced_node_count) +{ + int node_count = mutable_graph->node_size(); + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // MultiHeadAttention <= MatMul(q) - Add + // - MatMul(k) - Add + // - MatMul(v) - Add + // - Mul + // - Reshape - Transpose + // - Reshape - Reshape - Transpose - Transpose + // - Gemm - Softmax - Gemm - Transpose - Reshape - + // MatMul - Add + if (node->op_type() == "MatMul") + { + if (i + 19 >= node_count) continue; + + if (node_reference[node->output(0)] != 1) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); + onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); + onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); + onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); + onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); + onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); + onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9); + onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10); + onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11); + onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12); + onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13); + onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14); + onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15); + onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16); + onnx::NodeProto* node18 = mutable_graph->mutable_node(i + 17); + onnx::NodeProto* node19 = mutable_graph->mutable_node(i + 18); + onnx::NodeProto* node20 = mutable_graph->mutable_node(i + 19); + + if (node2->op_type() != "Add" || node3->op_type() != "MatMul" || node4->op_type() != "Add" || + node5->op_type() != "MatMul" || node6->op_type() != "Add" || node7->op_type() != "Mul" || + node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || + node10->op_type() != "Reshape" || node11->op_type() != "Reshape" || + node12->op_type() != "Transpose" || node13->op_type() != "Transpose" || + node14->op_type() != "MatMul" || node15->op_type() != "Softmax" || + node16->op_type() != "MatMul" || node17->op_type() != "Transpose" || + node18->op_type() != "Reshape" || node19->op_type() != "MatMul" || + node20->op_type() != "Add") + continue; + + if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || + node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || + node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || + node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || + node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || + node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || + node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || + node_reference[node16->output(0)] != 1 || node_reference[node17->output(0)] != 1 || + node_reference[node18->output(0)] != 1 || node_reference[node19->output(0)] != 1) + continue; + + if (node2->input(0) != node->output(0) || node4->input(0) != node3->output(0) || + node6->input(0) != node5->output(0) || node7->input(0) 
!= node2->output(0) || + node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0) || + node10->input(0) != node4->output(0) || node11->input(0) != node6->output(0) || + node12->input(0) != node11->output(0) || node13->input(0) != node10->output(0) || + node14->input(0) != node9->output(0) || node14->input(1) != node13->output(0) || + node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || + node16->input(1) != node12->output(0) || node17->input(0) != node16->output(0) || + node18->input(0) != node17->output(0) || node19->input(0) != node18->output(0) || + node20->input(0) != node19->output(0)) + continue; + + std::vector q_B = get_node_attr_from_input_af(weights[node2->input(1)]); + std::vector k_B = get_node_attr_from_input_af(weights[node4->input(1)]); + std::vector v_B = get_node_attr_from_input_af(weights[node6->input(1)]); + std::vector o_B = get_node_attr_from_input_af(weights[node20->input(1)]); + + if (q_B.size() != k_B.size() || q_B.size() != v_B.size() || q_B.size() != o_B.size()) + continue; + + int embed_dim = q_B.size(); + + // 1 0 2 + std::vector perm9 = get_node_attr_ai(*node9, "perm"); + std::vector perm12 = get_node_attr_ai(*node12, "perm"); + if (perm9.size() != 3 || perm12.size() != 3) continue; + + if (perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2 || perm12[0] != 1 || perm12[1] != 0 || + perm12[2] != 2) + continue; + + // 1 2 0 + std::vector perm13 = get_node_attr_ai(*node13, "perm"); + if (perm13.size() != 3) continue; + + if (perm13[0] != 1 || perm13[1] != 2 || perm13[2] != 0) continue; + + // 1 0 2 + std::vector perm17 = get_node_attr_ai(*node17, "perm"); + if (perm17.size() != 3) continue; + + if (perm17[0] != 1 || perm17[1] != 0 || perm17[2] != 2) continue; + + int softmax_axis = get_node_attr_i(*node15, "axis"); + if (softmax_axis != 2) continue; + + // 1/-1, seqlen * num_heads, embed_dim / num_heads + std::vector shape8; + std::vector shape10; + std::vector shape11; + if (node8->input_size() == 1) + { + shape8 = get_node_attr_ai(*node8, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node8->input(1)) == weights.end()) continue; + + shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]); + } + if (node10->input_size() == 1) + { + shape10 = get_node_attr_ai(*node10, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node10->input(1)) == weights.end()) continue; + + shape10 = get_node_attr_from_input_ai(weights[node10->input(1)]); + } + if (node11->input_size() == 1) + { + shape11 = get_node_attr_ai(*node11, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node11->input(1)) == weights.end()) continue; + + shape11 = get_node_attr_from_input_ai(weights[node11->input(1)]); + } + + if (shape8.size() != 3 || shape10.size() != 3 || shape11.size() != 3) continue; + + if (shape8[1] != shape10[1] || shape8[1] != shape11[1] || shape8[2] != shape10[2] || + shape8[2] != shape11[2]) + continue; + + int num_heads = embed_dim / shape8[2]; + + // 1, seqlen, embed_dim + std::vector shape18; + if (node18->input_size() == 1) + { + shape18 = get_node_attr_ai(*node18, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node18->input(1)) == weights.end()) continue; + + shape18 = get_node_attr_from_input_ai(weights[node18->input(1)]); + } + + if (shape18.size() != 3) continue; + + if (shape18[2] != embed_dim || shape18[1] * num_heads != shape8[1]) continue; + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + 
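// Illustrative aside: fused nodes are not erased from the protobuf here; they
// are merely retyped to "noop_reducedncnn" and skipped later when the ncnn
// param file is emitted. A hedged sketch of a tally over such neutralized
// nodes (assumes the generated onnx.pb.h is in scope, as elsewhere in this
// file; the helper name is hypothetical):
static int count_reduced_nodes(const onnx::GraphProto& graph)
{
    int n = 0;
    for (int i = 0; i < graph.node_size(); i++)
    {
        if (graph.node(i).op_type() == "noop_reducedncnn") n++;
    }
    return n;
}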
node3->set_op_type("noop_reducedncnn"); + node4->set_op_type("noop_reducedncnn"); + node5->set_op_type("noop_reducedncnn"); + node6->set_op_type("noop_reducedncnn"); + node7->set_op_type("noop_reducedncnn"); + node8->set_op_type("noop_reducedncnn"); + node9->set_op_type("noop_reducedncnn"); + node10->set_op_type("noop_reducedncnn"); + node11->set_op_type("noop_reducedncnn"); + node12->set_op_type("noop_reducedncnn"); + node13->set_op_type("noop_reducedncnn"); + node14->set_op_type("noop_reducedncnn"); + node15->set_op_type("noop_reducedncnn"); + node16->set_op_type("noop_reducedncnn"); + node17->set_op_type("noop_reducedncnn"); + node18->set_op_type("noop_reducedncnn"); + node19->set_op_type("noop_reducedncnn"); + + node_reference[node2->input(0)] -= 1; + node_reference[node4->input(0)] -= 1; + node_reference[node6->input(0)] -= 1; + node_reference[node7->input(0)] -= 1; + node_reference[node7->input(1)] -= 1; + node_reference[node8->input(0)] -= 1; + if (node8->input_size() == 2) + { + node_reference[node8->input(1)] -= 1; + } + node_reference[node9->input(0)] -= 1; + node_reference[node10->input(0)] -= 1; + if (node10->input_size() == 2) + { + node_reference[node10->input(1)] -= 1; + } + node_reference[node11->input(0)] -= 1; + if (node11->input_size() == 2) + { + node_reference[node11->input(1)] -= 1; + } + node_reference[node12->input(0)] -= 1; + node_reference[node13->input(0)] -= 1; + node_reference[node14->input(0)] -= 1; + node_reference[node14->input(1)] -= 1; + node_reference[node15->input(0)] -= 1; + node_reference[node16->input(0)] -= 1; + node_reference[node16->input(1)] -= 1; + node_reference[node17->input(0)] -= 1; + node_reference[node18->input(0)] -= 1; + if (node18->input_size() == 2) + { + node_reference[node18->input(1)] -= 1; + } + node_reference[node19->input(0)] -= 1; + node_reference[node20->input(0)] -= 1; + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + blob_names.erase(node3->output(0)); + blob_names.erase(node4->output(0)); + blob_names.erase(node5->output(0)); + blob_names.erase(node6->output(0)); + blob_names.erase(node7->output(0)); + blob_names.erase(node8->output(0)); + blob_names.erase(node9->output(0)); + blob_names.erase(node10->output(0)); + blob_names.erase(node11->output(0)); + blob_names.erase(node12->output(0)); + blob_names.erase(node13->output(0)); + blob_names.erase(node14->output(0)); + blob_names.erase(node15->output(0)); + blob_names.erase(node16->output(0)); + blob_names.erase(node17->output(0)); + blob_names.erase(node18->output(0)); + blob_names.erase(node19->output(0)); + + std::string qw = node->input(1); + std::string qb = node2->input(1); + std::string kw = node3->input(1); + std::string kb = node4->input(1); + std::string vw = node5->input(1); + std::string vb = node6->input(1); + std::string ow = node19->input(1); + std::string ob = node20->input(1); + + node20->set_op_type("MultiHeadAttention"); + node20->clear_input(); + node20->add_input(node->input(0)); + node20->add_input(node3->input(0)); + node20->add_input(node5->input(0)); + // q + node20->add_input(qw); + node20->add_input(qb); + // k + node20->add_input(kw); + node20->add_input(kb); + // v + node20->add_input(vw); + node20->add_input(vb); + // out linear + node20->add_input(ow); + node20->add_input(ob); + + onnx::AttributeProto* attr_embed_dim = node20->add_attribute(); + attr_embed_dim->set_name("embed_dim"); + attr_embed_dim->set_i(embed_dim); + + onnx::AttributeProto* attr_num_heads = node20->add_attribute(); + 
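// Illustrative aside: after this rewrite the fused node carries eleven
// inputs, while the packed-QKV variant emitted further below uses the
// five-input form {input, qkvw, qkvb, ow, ob}; later code in onnx2ncnn.cpp
// tells the two layouts apart by checking input_size() == 5. The role names
// below are descriptive only, not identifiers used by the converter:
static const char* const kMhaInputRoles[11] = {
    "q_in", "k_in", "v_in", "qw", "qb", "kw", "kb", "vw", "vb", "ow", "ob"};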
attr_num_heads->set_name("num_heads"); + attr_num_heads->set_i(num_heads); + + reduced_node_count += 19; + i += 19; + } } - } - - for (int i = 0; i < node_count; i++) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - // MultiHeadAttention <= MatMul(qkv) - Add - Split - // - Mul - // - Reshape - Transpose - // - Reshape - Reshape - Transpose - Transpose - // - Gemm - Softmax - Gemm - Transpose - Reshape - - // MatMul - Add - if (node->op_type() == "MatMul") { - if (i + 16 >= node_count) continue; - - if (node_reference[node->output(0)] != 1) continue; - - onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); - onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); - onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); - onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); - onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); - onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); - onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); - onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); - onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9); - onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10); - onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11); - onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12); - onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13); - onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14); - onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15); - onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16); - - if (node2->op_type() != "Add" || node3->op_type() != "Split" || node4->op_type() != "Mul" || - node5->op_type() != "Reshape" || node6->op_type() != "Transpose" || - node7->op_type() != "Reshape" || node8->op_type() != "Reshape" || - node9->op_type() != "Transpose" || node10->op_type() != "Transpose" || - node11->op_type() != "MatMul" || node12->op_type() != "Softmax" || - node13->op_type() != "MatMul" || node14->op_type() != "Transpose" || - node15->op_type() != "Reshape" || node16->op_type() != "MatMul" || - node17->op_type() != "Add") - continue; - - if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || - node_reference[node3->output(1)] != 1 || node_reference[node3->output(2)] != 1 || - node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || - node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || - node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || - node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || - node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || - node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || - node_reference[node16->output(0)] != 1) - continue; - - if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || - node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) || - node6->input(0) != node5->output(0) || node7->input(0) != node3->output(1) || - node8->input(0) != node3->output(2) || node9->input(0) != node8->output(0) || - node10->input(0) != node7->output(0) || node11->input(0) != node6->output(0) || - node11->input(1) != node10->output(0) || node12->input(0) != node11->output(0) || - node13->input(0) != node12->output(0) || node13->input(1) != node9->output(0) || - node14->input(0) != node13->output(0) || node15->input(0) != node14->output(0) || - 
node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0)) - continue; - - std::vector qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]); - std::vector o_B = get_node_attr_from_input_af(weights[node17->input(1)]); - - if (qkv_B.size() != o_B.size() * 3) continue; - - int embed_dim = o_B.size(); - - // 1 0 2 - std::vector perm6 = get_node_attr_ai(*node6, "perm"); - std::vector perm9 = get_node_attr_ai(*node9, "perm"); - if (perm6.size() != 3 || perm9.size() != 3) continue; - - if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 || - perm9[2] != 2) - continue; - - // 1 2 0 - std::vector perm10 = get_node_attr_ai(*node10, "perm"); - if (perm10.size() != 3) continue; - - if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0) continue; - - // 1 0 2 - std::vector perm14 = get_node_attr_ai(*node14, "perm"); - if (perm14.size() != 3) continue; - - if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2) continue; - - int softmax_axis = get_node_attr_i(*node12, "axis"); - if (softmax_axis != 2) continue; - - // 1/-1, seqlen * num_heads, embed_dim / num_heads - std::vector shape5; - std::vector shape7; - std::vector shape8; - if (node5->input_size() == 1) { - shape5 = get_node_attr_ai(*node5, "shape"); - } else { - // skip weight reshape - if (weights.find(node5->input(1)) == weights.end()) continue; - - shape5 = get_node_attr_from_input_ai(weights[node5->input(1)]); - } - if (node7->input_size() == 1) { - shape7 = get_node_attr_ai(*node7, "shape"); - } else { - // skip weight reshape - if (weights.find(node7->input(1)) == weights.end()) continue; - - shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]); - } - if (node8->input_size() == 1) { - shape8 = get_node_attr_ai(*node8, "shape"); - } else { - // skip weight reshape - if (weights.find(node8->input(1)) == weights.end()) continue; - - shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]); - } - - if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3) continue; - - if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] || - shape5[2] != shape8[2]) - continue; - - int num_heads = embed_dim / shape5[2]; - - // 1, seqlen, embed_dim - std::vector shape15; - if (node15->input_size() == 1) { - shape15 = get_node_attr_ai(*node15, "shape"); - } else { - // skip weight reshape - if (weights.find(node15->input(1)) == weights.end()) continue; - - shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]); - } - - if (shape15.size() != 3) continue; - - if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1]) continue; - - // reduce - node->set_op_type("noop_reducedncnn"); - node2->set_op_type("noop_reducedncnn"); - node3->set_op_type("noop_reducedncnn"); - node4->set_op_type("noop_reducedncnn"); - node5->set_op_type("noop_reducedncnn"); - node6->set_op_type("noop_reducedncnn"); - node7->set_op_type("noop_reducedncnn"); - node8->set_op_type("noop_reducedncnn"); - node9->set_op_type("noop_reducedncnn"); - node10->set_op_type("noop_reducedncnn"); - node11->set_op_type("noop_reducedncnn"); - node12->set_op_type("noop_reducedncnn"); - node13->set_op_type("noop_reducedncnn"); - node14->set_op_type("noop_reducedncnn"); - node15->set_op_type("noop_reducedncnn"); - node16->set_op_type("noop_reducedncnn"); - - node_reference[node2->input(0)] -= 1; - node_reference[node3->input(0)] -= 1; - node_reference[node4->input(0)] -= 1; - node_reference[node4->input(1)] -= 1; - node_reference[node5->input(0)] -= 1; - if (node5->input_size() 
== 2) { - node_reference[node5->input(1)] -= 1; - } - node_reference[node6->input(0)] -= 1; - node_reference[node7->input(0)] -= 1; - if (node7->input_size() == 2) { - node_reference[node7->input(1)] -= 1; - } - node_reference[node8->input(0)] -= 1; - if (node8->input_size() == 2) { - node_reference[node8->input(1)] -= 1; - } - node_reference[node9->input(0)] -= 1; - node_reference[node10->input(0)] -= 1; - node_reference[node11->input(0)] -= 1; - node_reference[node11->input(1)] -= 1; - node_reference[node12->input(0)] -= 1; - node_reference[node13->input(0)] -= 1; - node_reference[node13->input(1)] -= 1; - node_reference[node14->input(0)] -= 1; - node_reference[node15->input(0)] -= 1; - if (node15->input_size() == 2) { - node_reference[node15->input(1)] -= 1; - } - node_reference[node16->input(0)] -= 1; - node_reference[node17->input(0)] -= 1; - - blob_names.erase(node->output(0)); - blob_names.erase(node2->output(0)); - blob_names.erase(node3->output(0)); - blob_names.erase(node3->output(1)); - blob_names.erase(node3->output(2)); - blob_names.erase(node4->output(0)); - blob_names.erase(node5->output(0)); - blob_names.erase(node6->output(0)); - blob_names.erase(node7->output(0)); - blob_names.erase(node8->output(0)); - blob_names.erase(node9->output(0)); - blob_names.erase(node10->output(0)); - blob_names.erase(node11->output(0)); - blob_names.erase(node12->output(0)); - blob_names.erase(node13->output(0)); - blob_names.erase(node14->output(0)); - blob_names.erase(node15->output(0)); - blob_names.erase(node16->output(0)); - - std::string qkvw = node->input(1); - std::string qkvb = node2->input(1); - std::string ow = node16->input(1); - std::string ob = node17->input(1); - - node17->set_op_type("MultiHeadAttention"); - node17->clear_input(); - node17->add_input(node->input(0)); - // qkv - node17->add_input(qkvw); - node17->add_input(qkvb); - // out linear - node17->add_input(ow); - node17->add_input(ob); - - onnx::AttributeProto* attr_embed_dim = node17->add_attribute(); - attr_embed_dim->set_name("embed_dim"); - attr_embed_dim->set_i(embed_dim); - - onnx::AttributeProto* attr_num_heads = node17->add_attribute(); - attr_num_heads->set_name("num_heads"); - attr_num_heads->set_i(num_heads); - - reduced_node_count += 16; - i += 16; + + for (int i = 0; i < node_count; i++) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + // MultiHeadAttention <= MatMul(qkv) - Add - Split + // - Mul + // - Reshape - Transpose + // - Reshape - Reshape - Transpose - Transpose + // - Gemm - Softmax - Gemm - Transpose - Reshape - + // MatMul - Add + if (node->op_type() == "MatMul") + { + if (i + 16 >= node_count) continue; + + if (node_reference[node->output(0)] != 1) continue; + + onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1); + onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2); + onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3); + onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4); + onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5); + onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6); + onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7); + onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8); + onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9); + onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10); + onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11); + onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12); + onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13); + 
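// Illustrative aside: in this second pattern the Q, K and V projections share
// one MatMul, so the fused bias is the concatenation [q | k | v]; that is why
// the size check below requires qkv_B.size() == o_B.size() * 3. A sketch of
// unpacking that layout (struct and helper names are hypothetical):
#include <vector>

struct QkvBias
{
    std::vector<float> q, k, v;
};

static QkvBias split_qkv_bias(const std::vector<float>& qkv_bias)
{
    const size_t embed_dim = qkv_bias.size() / 3;
    QkvBias b;
    b.q.assign(qkv_bias.begin(), qkv_bias.begin() + embed_dim);
    b.k.assign(qkv_bias.begin() + embed_dim, qkv_bias.begin() + 2 * embed_dim);
    b.v.assign(qkv_bias.begin() + 2 * embed_dim, qkv_bias.end());
    return b;
}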
onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14); + onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15); + onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16); + + if (node2->op_type() != "Add" || node3->op_type() != "Split" || node4->op_type() != "Mul" || + node5->op_type() != "Reshape" || node6->op_type() != "Transpose" || + node7->op_type() != "Reshape" || node8->op_type() != "Reshape" || + node9->op_type() != "Transpose" || node10->op_type() != "Transpose" || + node11->op_type() != "MatMul" || node12->op_type() != "Softmax" || + node13->op_type() != "MatMul" || node14->op_type() != "Transpose" || + node15->op_type() != "Reshape" || node16->op_type() != "MatMul" || + node17->op_type() != "Add") + continue; + + if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || + node_reference[node3->output(1)] != 1 || node_reference[node3->output(2)] != 1 || + node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || + node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || + node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || + node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || + node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || + node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || + node_reference[node16->output(0)] != 1) + continue; + + if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || + node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) || + node6->input(0) != node5->output(0) || node7->input(0) != node3->output(1) || + node8->input(0) != node3->output(2) || node9->input(0) != node8->output(0) || + node10->input(0) != node7->output(0) || node11->input(0) != node6->output(0) || + node11->input(1) != node10->output(0) || node12->input(0) != node11->output(0) || + node13->input(0) != node12->output(0) || node13->input(1) != node9->output(0) || + node14->input(0) != node13->output(0) || node15->input(0) != node14->output(0) || + node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0)) + continue; + + std::vector qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]); + std::vector o_B = get_node_attr_from_input_af(weights[node17->input(1)]); + + if (qkv_B.size() != o_B.size() * 3) continue; + + int embed_dim = o_B.size(); + + // 1 0 2 + std::vector perm6 = get_node_attr_ai(*node6, "perm"); + std::vector perm9 = get_node_attr_ai(*node9, "perm"); + if (perm6.size() != 3 || perm9.size() != 3) continue; + + if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 || + perm9[2] != 2) + continue; + + // 1 2 0 + std::vector perm10 = get_node_attr_ai(*node10, "perm"); + if (perm10.size() != 3) continue; + + if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0) continue; + + // 1 0 2 + std::vector perm14 = get_node_attr_ai(*node14, "perm"); + if (perm14.size() != 3) continue; + + if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2) continue; + + int softmax_axis = get_node_attr_i(*node12, "axis"); + if (softmax_axis != 2) continue; + + // 1/-1, seqlen * num_heads, embed_dim / num_heads + std::vector shape5; + std::vector shape7; + std::vector shape8; + if (node5->input_size() == 1) + { + shape5 = get_node_attr_ai(*node5, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node5->input(1)) == weights.end()) continue; + + shape5 = 
get_node_attr_from_input_ai(weights[node5->input(1)]); + } + if (node7->input_size() == 1) + { + shape7 = get_node_attr_ai(*node7, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node7->input(1)) == weights.end()) continue; + + shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]); + } + if (node8->input_size() == 1) + { + shape8 = get_node_attr_ai(*node8, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node8->input(1)) == weights.end()) continue; + + shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]); + } + + if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3) continue; + + if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] || + shape5[2] != shape8[2]) + continue; + + int num_heads = embed_dim / shape5[2]; + + // 1, seqlen, embed_dim + std::vector shape15; + if (node15->input_size() == 1) + { + shape15 = get_node_attr_ai(*node15, "shape"); + } + else + { + // skip weight reshape + if (weights.find(node15->input(1)) == weights.end()) continue; + + shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]); + } + + if (shape15.size() != 3) continue; + + if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1]) continue; + + // reduce + node->set_op_type("noop_reducedncnn"); + node2->set_op_type("noop_reducedncnn"); + node3->set_op_type("noop_reducedncnn"); + node4->set_op_type("noop_reducedncnn"); + node5->set_op_type("noop_reducedncnn"); + node6->set_op_type("noop_reducedncnn"); + node7->set_op_type("noop_reducedncnn"); + node8->set_op_type("noop_reducedncnn"); + node9->set_op_type("noop_reducedncnn"); + node10->set_op_type("noop_reducedncnn"); + node11->set_op_type("noop_reducedncnn"); + node12->set_op_type("noop_reducedncnn"); + node13->set_op_type("noop_reducedncnn"); + node14->set_op_type("noop_reducedncnn"); + node15->set_op_type("noop_reducedncnn"); + node16->set_op_type("noop_reducedncnn"); + + node_reference[node2->input(0)] -= 1; + node_reference[node3->input(0)] -= 1; + node_reference[node4->input(0)] -= 1; + node_reference[node4->input(1)] -= 1; + node_reference[node5->input(0)] -= 1; + if (node5->input_size() == 2) + { + node_reference[node5->input(1)] -= 1; + } + node_reference[node6->input(0)] -= 1; + node_reference[node7->input(0)] -= 1; + if (node7->input_size() == 2) + { + node_reference[node7->input(1)] -= 1; + } + node_reference[node8->input(0)] -= 1; + if (node8->input_size() == 2) + { + node_reference[node8->input(1)] -= 1; + } + node_reference[node9->input(0)] -= 1; + node_reference[node10->input(0)] -= 1; + node_reference[node11->input(0)] -= 1; + node_reference[node11->input(1)] -= 1; + node_reference[node12->input(0)] -= 1; + node_reference[node13->input(0)] -= 1; + node_reference[node13->input(1)] -= 1; + node_reference[node14->input(0)] -= 1; + node_reference[node15->input(0)] -= 1; + if (node15->input_size() == 2) + { + node_reference[node15->input(1)] -= 1; + } + node_reference[node16->input(0)] -= 1; + node_reference[node17->input(0)] -= 1; + + blob_names.erase(node->output(0)); + blob_names.erase(node2->output(0)); + blob_names.erase(node3->output(0)); + blob_names.erase(node3->output(1)); + blob_names.erase(node3->output(2)); + blob_names.erase(node4->output(0)); + blob_names.erase(node5->output(0)); + blob_names.erase(node6->output(0)); + blob_names.erase(node7->output(0)); + blob_names.erase(node8->output(0)); + blob_names.erase(node9->output(0)); + blob_names.erase(node10->output(0)); + blob_names.erase(node11->output(0)); + 
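// Illustrative aside: every fusion keeps the bookkeeping consistent: each
// consumed input edge gives up one reference, and each blob that became
// internal to the fused op leaves blob_names, otherwise the layer/blob counts
// written to the ncnn param header later would be wrong. A simplified sketch
// of that contract for one node (the real pass decrements optional second
// inputs only when they exist, as the input_size() == 2 checks here show):
#include <map>
#include <set>
#include <string>

static void release_node_edges(const onnx::NodeProto& node,
                               std::map<std::string, int>& node_reference,
                               std::set<std::string>& blob_names)
{
    for (int j = 0; j < node.input_size(); j++) node_reference[node.input(j)] -= 1;
    for (int j = 0; j < node.output_size(); j++) blob_names.erase(node.output(j));
}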
blob_names.erase(node12->output(0));
+            blob_names.erase(node13->output(0));
+            blob_names.erase(node14->output(0));
+            blob_names.erase(node15->output(0));
+            blob_names.erase(node16->output(0));
+
+            std::string qkvw = node->input(1);
+            std::string qkvb = node2->input(1);
+            std::string ow = node16->input(1);
+            std::string ob = node17->input(1);
+
+            node17->set_op_type("MultiHeadAttention");
+            node17->clear_input();
+            node17->add_input(node->input(0));
+            // qkv
+            node17->add_input(qkvw);
+            node17->add_input(qkvb);
+            // out linear
+            node17->add_input(ow);
+            node17->add_input(ob);
+
+            onnx::AttributeProto* attr_embed_dim = node17->add_attribute();
+            attr_embed_dim->set_name("embed_dim");
+            attr_embed_dim->set_i(embed_dim);
+
+            onnx::AttributeProto* attr_num_heads = node17->add_attribute();
+            attr_num_heads->set_name("num_heads");
+            attr_num_heads->set_i(num_heads);
+
+            reduced_node_count += 16;
+            i += 16;
+        }
+    }
-    }
-  }
 }
diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.h b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.h
index 31dc6f5b93..73390cc24d 100644
--- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.h
+++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/fuse_pass.h
@@ -4,30 +4,35 @@
 #include "shape_inference.h"
 #include "utils.h"

-void fuse_identity(onnx::GraphProto* mutable_graph,
+void fuse_identity(onnx::GraphProto* mutable_graph,
                    std::map<std::string, onnx::TensorProto>& weights,
-                   std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                   int& reduced_node_count);
+                   std::map<std::string, int>& node_reference,
+                   std::set<std::string>& blob_names,
+                   int& reduced_node_count);

-void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
+void fuse_rewrite_gather(onnx::GraphProto* mutable_graph,
                          std::map<std::string, onnx::TensorProto>& weights,
-                         std::map<std::string, int>& node_reference,
-                         std::set<std::string>& blob_names, int& reduced_node_count);
+                         std::map<std::string, int>& node_reference,
+                         std::set<std::string>& blob_names,
+                         int& reduced_node_count);

-void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
+void fuse_weight_reshape(onnx::GraphProto* mutable_graph,
                          std::map<std::string, onnx::TensorProto>& weights,
-                         std::map<std::string, int>& node_reference,
-                         std::set<std::string>& blob_names, int& reduced_node_count);
+                         std::map<std::string, int>& node_reference,
+                         std::set<std::string>& blob_names,
+                         int& reduced_node_count);

-void fuse_shufflechannel(onnx::GraphProto* mutable_graph,
+void fuse_shufflechannel(onnx::GraphProto* mutable_graph,
                          std::map<std::string, onnx::TensorProto>& weights,
-                         std::map<std::string, int>& node_reference,
-                         std::set<std::string>& blob_names, int& reduced_node_count);
+                         std::map<std::string, int>& node_reference,
+                         std::set<std::string>& blob_names,
+                         int& reduced_node_count);

-void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
+void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
                                std::map<std::string, onnx::TensorProto>& weights,
-                               std::map<std::string, int>& node_reference,
-                               std::set<std::string>& blob_names, int& reduced_node_count);
+                               std::map<std::string, int>& node_reference,
+                               std::set<std::string>& blob_names,
+                               int& reduced_node_count);

@@ -46,85 +51,96 @@ void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph,
 /**
  * @brief fuse subgraph
  *
  * @param mutable_graph
  * @param weights
  * @param node_reference
  * @param blob_names
  * @param reduced_node_count
  */
-void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
+void fuse_conv_reshape(onnx::GraphProto* mutable_graph,
                        std::map<std::string, onnx::TensorProto>& weights,
-                       std::map<std::string, int>& node_reference,
-                       std::set<std::string>& blob_names, int& reduced_node_count);
+                       std::map<std::string, int>& node_reference,
+                       std::set<std::string>& blob_names,
+                       int& reduced_node_count);

-void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph,
+void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph,
                                std::map<std::string, onnx::TensorProto>& weights,
-                               std::map<std::string, int>& node_reference,
-                               std::set<std::string>& blob_names, int& reduced_node_count);
+                               std::map<std::string, int>& node_reference,
+                               std::set<std::string>& blob_names,
+                               int& reduced_node_count);

-void fuse_hardswish(onnx::GraphProto* mutable_graph,
+void fuse_hardswish(onnx::GraphProto* mutable_graph,
                     std::map<std::string, onnx::TensorProto>& weights,
-                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                    int& reduced_node_count);
+                    std::map<std::string, int>& node_reference,
+                    std::set<std::string>& blob_names,
+                    int& reduced_node_count);

-void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
+void fuse_hardsigmoid(onnx::GraphProto* mutable_graph,
                       std::map<std::string, onnx::TensorProto>& weights,
-                      std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                      int& reduced_node_count);
+                      std::map<std::string, int>& node_reference,
+                      std::set<std::string>& blob_names,
+                      int& reduced_node_count);

-void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph,
+void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph,
                                         std::map<std::string, onnx::TensorProto>& weights,
-                                        std::map<std::string, int>& node_reference,
-                                        std::set<std::string>& blob_names, int& reduced_node_count);
+                                        std::map<std::string, int>& node_reference,
+                                        std::set<std::string>& blob_names,
+                                        int& reduced_node_count);

-void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph,
+void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph,
                           std::map<std::string, onnx::TensorProto>& weights,
-                          std::map<std::string, int>& node_reference,
-                          std::set<std::string>& blob_names, int& reduced_node_count);
+                          std::map<std::string, int>& node_reference,
+                          std::set<std::string>& blob_names,
+                          int& reduced_node_count);

-void fuse_normalize(onnx::GraphProto* mutable_graph,
+void fuse_normalize(onnx::GraphProto* mutable_graph,
                     std::map<std::string, onnx::TensorProto>& weights,
-                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                    int& reduced_node_count);
+                    std::map<std::string, int>& node_reference,
+                    std::set<std::string>& blob_names,
+                    int& reduced_node_count);

-void fuse_groupnorm(onnx::GraphProto* mutable_graph,
+void fuse_groupnorm(onnx::GraphProto* mutable_graph,
                     std::map<std::string, onnx::TensorProto>& weights,
-                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                    int& reduced_node_count);
+                    std::map<std::string, int>& node_reference,
+                    std::set<std::string>& blob_names,
+                    int& reduced_node_count);

-void fuse_layernorm(onnx::GraphProto* mutable_graph,
+void fuse_layernorm(onnx::GraphProto* mutable_graph,
                     std::map<std::string, onnx::TensorProto>& weights,
-                    std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                    int& reduced_node_count);
+                    std::map<std::string, int>& node_reference,
+                    std::set<std::string>& blob_names,
+                    int& reduced_node_count);

-void fuse_flatten(onnx::GraphProto* mutable_graph,
+void fuse_flatten(onnx::GraphProto* mutable_graph,
                   std::map<std::string, onnx::TensorProto>& weights,
-                  std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                  int& reduced_node_count);
+                  std::map<std::string, int>& node_reference,
+                  std::set<std::string>& blob_names,
+                  int& reduced_node_count);

-void fuse_pixelshuffle(onnx::GraphProto* mutable_graph,
+void fuse_pixelshuffle(onnx::GraphProto* mutable_graph,
                        std::map<std::string, onnx::TensorProto>& weights,
-                       std::map<std::string, int>& node_reference,
-                       std::set<std::string>& blob_names, int& reduced_node_count);
+                       std::map<std::string, int>& node_reference,
+                       std::set<std::string>& blob_names,
+                       int& reduced_node_count);

-void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights,
-                std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                int& reduced_node_count);
+void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count);

-void fuse_expand_broadcast(onnx::GraphProto* mutable_graph,
+void fuse_expand_broadcast(onnx::GraphProto* mutable_graph,
                            std::map<std::string, onnx::TensorProto>& weights,
-                           std::map<std::string, int>& node_reference,
-                           std::set<std::string>& blob_names, int& reduced_node_count);
+                           std::map<std::string, int>& node_reference,
+                           std::set<std::string>& blob_names,
+                           int& reduced_node_count);

-void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
+void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph,
                        std::map<std::string, onnx::TensorProto>& weights,
-                       std::map<std::string, int>& node_reference,
-                       std::set<std::string>& blob_names, int& reduced_node_count);
+                       std::map<std::string, int>& node_reference,
+                       std::set<std::string>& blob_names,
+                       int& reduced_node_count);

-void fuse_multiheadattention(onnx::GraphProto* mutable_graph,
+void fuse_multiheadattention(onnx::GraphProto* mutable_graph,
                              std::map<std::string, onnx::TensorProto>& weights,
-                             std::map<std::string, int>& node_reference,
-                             std::set<std::string>& blob_names, int& reduced_node_count);
+                             std::map<std::string, int>& node_reference,
+                             std::set<std::string>& blob_names,
+                             int& reduced_node_count);

-void fuse_weight_transpose(onnx::GraphProto* mutable_graph,
+void fuse_weight_transpose(onnx::GraphProto* mutable_graph,
                            std::map<std::string, onnx::TensorProto>& weights,
-                           std::map<std::string, int>& node_reference,
-                           std::set<std::string>& blob_names, int& reduced_node_count);
+                           std::map<std::string, int>& node_reference,
+                           std::set<std::string>& blob_names,
+                           int& reduced_node_count);

-void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights,
-                std::map<std::string, int>& node_reference, std::set<std::string>& blob_names,
-                int& reduced_node_count);
+void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count);
diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
index ca8cd628ad..bc38599b63 100644
--- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
+++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/onnx2ncnn.cpp
@@ -26,2719 +26,3551 @@
 #include "shape_inference.h"
 #include "utils.h"

-int main(int argc, char** argv) {
-  if (!(argc == 2 || argc == 4)) {
-    fprintf(stderr, "Usage: %s [onnxpb] [ncnnparam] [ncnnbin]\n", argv[0]);
-    return -1;
-  }
-
-  const char* onnxpb = argv[1];
-  const char* ncnn_prototxt = argc == 4 ? argv[2] : "ncnn.param";
-  const char* ncnn_modelbin = argc == 4 ? argv[3] : "ncnn.bin";
-
-  onnx::ModelProto model;
-
-  // load
-  bool s1 = read_proto_from_binary(onnxpb, &model);
-  if (!s1) {
-    fprintf(stderr, "read_proto_from_binary failed\n");
-    return -1;
-  }
-  FILE* pp = fopen(ncnn_prototxt, "wb");
-  FILE* bp = fopen(ncnn_modelbin, "wb");
-  // magic
-  fprintf(pp, "7767517\n");
-  onnx::GraphProto* mutable_graph = model.mutable_graph();
-  int node_count = mutable_graph->node_size();
-
-  // node reference
-  std::map<std::string, int> node_reference;
-
-  // weight node and weight reshape node
-  std::map<std::string, onnx::TensorProto> weights;
-  for (int j = 0; j < mutable_graph->initializer_size(); j++) {
-    const onnx::TensorProto& initializer = mutable_graph->initializer(j);
-
-    // fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(),
-    // initializer.data_type());
-
-    weights[initializer.name()] = initializer;
-  }
-  // topological sort
-  {
-    // name -> producer node index
-    std::set<std::string> producers;
-    for (int j = 0; j < mutable_graph->input_size(); j++) {
-      const std::string& input_name = mutable_graph->input(j).name();
-      producers.insert(input_name);
+int main(int argc, char** argv)
+{
+    if (!(argc == 2 || argc == 4))
+    {
+        fprintf(stderr, "Usage: %s [onnxpb] [ncnnparam] [ncnnbin]\n", argv[0]);
+        return -1;
     }

-    for (int i = 0; i < node_count;) {
-      onnx::NodeProto* node = mutable_graph->mutable_node(i);

+    const char* onnxpb = argv[1];
+    const char* ncnn_prototxt = argc == 4 ? argv[2] : "ncnn.param";
+    const char* ncnn_modelbin = argc == 4 ?
argv[3] : "ncnn.bin"; - bool swapnode = false; - std::string missing_input_name; - for (int j = 0; j < (int)node->input_size(); j++) { - const std::string& input_name = node->input(j); - if (input_name.empty()) continue; + onnx::ModelProto model; - if (producers.find(input_name) == producers.end() && - weights.find(input_name) == weights.end()) { - swapnode = true; - missing_input_name = input_name; - break; - } - } + // load + bool s1 = read_proto_from_binary(onnxpb, &model); + if (!s1) + { + fprintf(stderr, "read_proto_from_binary failed\n"); + return -1; + } + FILE* pp = fopen(ncnn_prototxt, "wb"); + FILE* bp = fopen(ncnn_modelbin, "wb"); + // magic + fprintf(pp, "7767517\n"); + onnx::GraphProto* mutable_graph = model.mutable_graph(); + int node_count = mutable_graph->node_size(); + + // node reference + std::map node_reference; + + // weight node and weight reshape node + std::map weights; + for (int j = 0; j < mutable_graph->initializer_size(); j++) + { + const onnx::TensorProto& initializer = mutable_graph->initializer(j); - if (!swapnode) { - for (int j = 0; j < (int)node->output_size(); j++) { - const std::string& output_name = node->output(j); - if (output_name.empty()) continue; + // fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), + // initializer.data_type()); - producers.insert(output_name); + weights[initializer.name()] = initializer; + } + // topological sort + { + // name -> producer node index + std::set producers; + for (int j = 0; j < mutable_graph->input_size(); j++) + { + const std::string& input_name = mutable_graph->input(j).name(); + producers.insert(input_name); } - i++; - continue; - } - - // find node that produce missing_input_name - int q = i + 1; - for (; q < node_count; q++) { - onnx::NodeProto* nodeq = mutable_graph->mutable_node(q); - bool found = false; - for (int j = 0; j < (int)nodeq->output_size(); j++) { - const std::string& output_name = nodeq->output(j); - if (output_name == missing_input_name) { - found = true; - break; - } - } + for (int i = 0; i < node_count;) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + bool swapnode = false; + std::string missing_input_name; + for (int j = 0; j < (int)node->input_size(); j++) + { + const std::string& input_name = node->input(j); + if (input_name.empty()) continue; + + if (producers.find(input_name) == producers.end() && + weights.find(input_name) == weights.end()) + { + swapnode = true; + missing_input_name = input_name; + break; + } + } - if (found) break; - } + if (!swapnode) + { + for (int j = 0; j < (int)node->output_size(); j++) + { + const std::string& output_name = node->output(j); + if (output_name.empty()) continue; - if (q == node_count) { - fprintf(stderr, "cannot find node produces %s but node %d requires it\n", - missing_input_name.c_str(), i); - return -1; - } - - // fprintf(stderr, "swap %d %d\n", i, q); - // swap this node with q - onnx::NodeProto* nodeq = mutable_graph->mutable_node(q); - onnx::NodeProto tmp = *node; - *node = *nodeq; - *nodeq = tmp; - } - } - // global definition line - // [layer count] [blob count] - std::set blob_names; - for (int i = 0; i < node_count; i++) { - const onnx::NodeProto& node = mutable_graph->node(i); - - const std::string& op = node.op_type(); - - std::string name = node.name(); - if (name.empty()) { - name = node.output(0); - } + producers.insert(output_name); + } - if (op == "Constant") { - onnx::TensorProto tensor = get_node_attr_tensor(node, "value"); - weights[node.output(0)] = tensor; - } + i++; + continue; + } - for 
(int j = 0; j < (int)node.input_size(); j++) { - const std::string& input_name = node.input(j); + // find node that produce missing_input_name + int q = i + 1; + for (; q < node_count; q++) + { + onnx::NodeProto* nodeq = mutable_graph->mutable_node(q); + bool found = false; + for (int j = 0; j < (int)nodeq->output_size(); j++) + { + const std::string& output_name = nodeq->output(j); + if (output_name == missing_input_name) + { + found = true; + break; + } + } + + if (found) break; + } - blob_names.insert(input_name); + if (q == node_count) + { + fprintf(stderr, "cannot find node produces %s but node %d requires it\n", missing_input_name.c_str(), i); + return -1; + } - if (node_reference.find(input_name) == node_reference.end()) { - node_reference[input_name] = 1; - } else { - node_reference[input_name] = node_reference[input_name] + 1; - } + // fprintf(stderr, "swap %d %d\n", i, q); + // swap this node with q + onnx::NodeProto* nodeq = mutable_graph->mutable_node(q); + onnx::NodeProto tmp = *node; + *node = *nodeq; + *nodeq = tmp; + } } + // global definition line + // [layer count] [blob count] + std::set blob_names; + for (int i = 0; i < node_count; i++) + { + const onnx::NodeProto& node = mutable_graph->node(i); - if (op == "Dropout") { - const std::string& output_name = node.output(0); - blob_names.insert(output_name); - node_reference[output_name] = 0; - continue; - } + const std::string& op = node.op_type(); - for (int j = 0; j < (int)node.output_size(); j++) { - const std::string& output_name = node.output(j); + std::string name = node.name(); + if (name.empty()) + { + name = node.output(0); + } - blob_names.insert(output_name); + if (op == "Constant") + { + onnx::TensorProto tensor = get_node_attr_tensor(node, "value"); + weights[node.output(0)] = tensor; + } - node_reference[output_name] = 0; - } - } - // include Input node - int input_node_count = 0; - for (int j = 0; j < mutable_graph->input_size(); j++) { - const std::string& input_name = mutable_graph->input(j).name(); - - // check weight - if (weights.find(input_name) != weights.end()) continue; - - blob_names.insert(input_name); - - input_node_count++; - } - - // for (auto a: node_reference) - // { - // fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second); - // } - - // op chain fusion - int reduced_node_count = 0; - { - fuse_identity(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_conv_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, - reduced_node_count); - fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, - reduced_node_count); - fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, 
reduced_node_count); - fuse_layernorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names, - reduced_node_count); - fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count); - } - // reduce common const weight node_reference - for (int i = 0; i < node_count; i++) { - const onnx::NodeProto& node = mutable_graph->node(i); - - const std::string& op = node.op_type(); - - if (op == "BatchNormalization") { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - node_reference[node.input(3)] -= 1; - node_reference[node.input(4)] -= 1; - } else if (op == "BiasGelu") { - node_reference[node.input(1)] -= 1; - } else if (op == "Clip") { - if (node.input_size() == 3) { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - } - } else if (op == "Conv") { - node_reference[node.input(1)] -= 1; - if (node.input_size() == 3) { - node_reference[node.input(2)] -= 1; - } - } else if (op == "ConvTranspose") { - node_reference[node.input(1)] -= 1; - if (node.input_size() == 3) { - node_reference[node.input(2)] -= 1; - } - } else if (op == "EmbedLayerNormalization") { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - node_reference[node.input(3)] -= 1; - node_reference[node.input(4)] -= 1; - node_reference[node.input(5)] -= 1; - node_reference[node.input(6)] -= 1; - } else if (op == "Gemm") { - float alpha = get_node_attr_f(node, "alpha", 1.f); - float beta = get_node_attr_f(node, "beta", 1.f); - int transA = get_node_attr_i(node, "transA", 0); - int transB = get_node_attr_i(node, "transB", 0); - - if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) { - // InnerProduct-like A * B + C, C is optional. 
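// Illustrative aside: ONNX Gemm computes Y = alpha * A * op(B) + beta * C, so
// the special case above (alpha == beta == 1, transA == 0, transB == 1) is
// exactly Y = A * B^T + C, i.e. a fully connected layer that ncnn can express
// as InnerProduct with B folded into the layer weights. A scalar sketch of
// that computation (function name is illustrative):
#include <vector>

static std::vector<float> inner_product(const std::vector<float>& x,                // [in]
                                        const std::vector<std::vector<float> >& W,  // [out][in]
                                        const std::vector<float>& bias)             // [out], may be empty
{
    std::vector<float> y(W.size());
    for (size_t o = 0; o < W.size(); o++)
    {
        float acc = bias.empty() ? 0.f : bias[o];
        for (size_t i = 0; i < x.size(); i++) acc += x[i] * W[o][i];
        y[o] = acc;
    }
    return y;
}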
- node_reference[node.input(1)] -= 1; - if (node.input_size() == 3) { - node_reference[node.input(2)] -= 1; - } - } - } else if (op == "GroupNorm") { - int affine = get_node_attr_i(node, "affine", 1); - if (affine) { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - } - } else if (op == "GRU") { - for (int j = 1; j < node.input_size(); j++) { - node_reference[node.input(j)] -= 1; - } - } else if (op == "InstanceNormalization") { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - } else if (op == "LayerNorm") { - int affine = get_node_attr_i(node, "affine", 1); - if (affine) { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - } - } else if (op == "LSTM") { - for (int j = 1; j < node.input_size(); j++) { - node_reference[node.input(j)] -= 1; - } - } else if (op == "MatMul") { - if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) { - // InnerProduct - node_reference[node.input(1)] -= 1; - } - } else if (op == "MultiHeadAttention") { - if (node.input_size() == 5) { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - node_reference[node.input(3)] -= 1; - node_reference[node.input(4)] -= 1; - } else { - node_reference[node.input(3)] -= 1; - node_reference[node.input(4)] -= 1; - node_reference[node.input(5)] -= 1; - node_reference[node.input(6)] -= 1; - node_reference[node.input(7)] -= 1; - node_reference[node.input(8)] -= 1; - node_reference[node.input(9)] -= 1; - node_reference[node.input(10)] -= 1; - } - } else if (op == "NonMaxSuppression") { - if (node.input_size() >= 3) { - node_reference[node.input(2)] -= 1; - } - if (node.input_size() >= 4) { - node_reference[node.input(3)] -= 1; - } - if (node.input_size() >= 5) { - node_reference[node.input(4)] -= 1; - } - } else if (op == "Pad") { - if (node.input_size() >= 2) { - node_reference[node.input(1)] -= 1; - } - } else if (op == "PRelu") { - node_reference[node.input(1)] -= 1; - } else if (op == "Reshape") { - if (node.input_size() == 2) { - if (weights[node.input(1)].data_type() != 0) { - node_reference[node.input(1)] -= 1; - } - } - } else if (op == "Resize") { - if (node.input_size() == 2) { - // opset 10 - node_reference[node.input(1)] -= 1; - } else { - // opset 11+ - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - if (node.input_size() >= 4) { - node_reference[node.input(3)] -= 1; - } - } - } else if (op == "RNN") { - for (int j = 1; j < node.input_size(); j++) { - node_reference[node.input(j)] -= 1; - } - } else if (op == "SkipLayerNormalization") { - node_reference[node.input(2)] -= 1; - node_reference[node.input(3)] -= 1; - node_reference[node.input(4)] -= 1; - } else if (op == "Slice") { - if (node.input_size() >= 2) { - node_reference[node.input(1)] -= 1; - node_reference[node.input(2)] -= 1; - if (node.input_size() >= 4) node_reference[node.input(3)] -= 1; - if (node.input_size() >= 5) node_reference[node.input(4)] -= 1; - } - } else if (op == "Upsample") { - if (node.input_size() >= 2) { - node_reference[node.input(1)] -= 1; - } - } else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || - op == "adaptive_max_pool2d") { - if (node.input_size() >= 2) { - node_reference[node.input(1)] -= 1; - } - } - } + for (int j = 0; j < (int)node.input_size(); j++) + { + const std::string& input_name = node.input(j); - // for (auto a: node_reference) - // { - // fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second); - // } + 
blob_names.insert(input_name); - // count all weight node with zero reference - int zero_reference_weight_node_count = 0; - for (std::map::iterator it = weights.begin(); it != weights.end(); - it++) { - const std::string& input_name = it->first; + if (node_reference.find(input_name) == node_reference.end()) + { + node_reference[input_name] = 1; + } + else + { + node_reference[input_name] = node_reference[input_name] + 1; + } + } - int refcount = node_reference[input_name]; - if (refcount == 0) zero_reference_weight_node_count++; - } + if (op == "Dropout") + { + const std::string& output_name = node.output(0); + blob_names.insert(output_name); + node_reference[output_name] = 0; + continue; + } - // we always treat constant node as weight or binaryop_weights - // do not count it twice for layer_count - int constant_node_count_moved_to_weight = 0; - for (int i = 0; i < node_count; i++) { - const onnx::NodeProto& node = mutable_graph->node(i); + for (int j = 0; j < (int)node.output_size(); j++) + { + const std::string& output_name = node.output(j); - const std::string& op = node.op_type(); + blob_names.insert(output_name); - if (op == "Constant") { - constant_node_count_moved_to_weight++; - } - } - - // some op may have anonymous input - // LSTM sequence_lens - blob_names.erase(""); - node_reference.erase(""); - - // remove node_reference entry with reference equals to one - int split_layer_count = 0; - int splitncnn_blob_count = 0; - // split node reference - std::map split_node_reference; - for (std::map::iterator it = node_reference.begin(); it != node_reference.end(); - it++) { - if (it->second > 1) { - split_layer_count++; - splitncnn_blob_count += it->second; - - split_node_reference[it->first] = it->second; + node_reference[output_name] = 0; + } } - } - - fprintf(pp, "%zu %zu\n", - node_count - constant_node_count_moved_to_weight + weights.size() - - zero_reference_weight_node_count - reduced_node_count + input_node_count + - split_layer_count, - blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count); - - int internal_split = 0; - - // place Input at the beginning - for (int j = 0; j < mutable_graph->input_size(); j++) { - const std::string& input_name = mutable_graph->input(j).name(); + // include Input node + int input_node_count = 0; + for (int j = 0; j < mutable_graph->input_size(); j++) + { + const std::string& input_name = mutable_graph->input(j).name(); - // check weight - if (weights.find(input_name) != weights.end()) continue; + // check weight + if (weights.find(input_name) != weights.end()) continue; - fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str()); + blob_names.insert(input_name); - int refcount = node_reference[input_name]; - if (refcount <= 1) { - continue; + input_node_count++; } - char splitname[256]; - sprintf(splitname, "splitncnn_input%d", j); - fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount); - fprintf(pp, " %s", input_name.c_str()); + // for (auto a: node_reference) + // { + // fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second); + // } - for (int k = 0; k < refcount; k++) { - fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k); + // op chain fusion + int reduced_node_count = 0; + { + fuse_identity(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_conv_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + 
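// Illustrative aside: all fuse_* passes share one signature, which is what
// lets this block chain them over the same graph and accumulate
// reduced_node_count; each pattern pass also advances its loop index past the
// nodes it consumed so a fused window is never matched twice. A sketch of the
// common shape (the typedef name is hypothetical):
typedef void (*FusePass)(onnx::GraphProto* mutable_graph,
                         std::map<std::string, onnx::TensorProto>& weights,
                         std::map<std::string, int>& node_reference,
                         std::set<std::string>& blob_names,
                         int& reduced_node_count);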
fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_layernorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names, reduced_node_count); + fuse_rewrite_gather(mutable_graph, weights, node_reference, blob_names, reduced_node_count); } - fprintf(pp, "\n"); - } + // reduce common const weight node_reference + for (int i = 0; i < node_count; i++) + { + const onnx::NodeProto& node = mutable_graph->node(i); - // place MemoryData next - for (std::map::iterator weight_it = weights.begin(); - weight_it != weights.end(); weight_it++) { - const std::string& input_name = weight_it->first; + const std::string& op = node.op_type(); - int refcount = node_reference[input_name]; - if (refcount == 0) { - continue; + if (op == "BatchNormalization") + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + node_reference[node.input(3)] -= 1; + node_reference[node.input(4)] -= 1; + } + else if (op == "BiasGelu") + { + node_reference[node.input(1)] -= 1; + } + else if (op == "Clip") + { + if (node.input_size() == 3) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + } + } + else if (op == "Conv") + { + node_reference[node.input(1)] -= 1; + if (node.input_size() == 3) + { + node_reference[node.input(2)] -= 1; + } + } + else if (op == "ConvTranspose") + { + node_reference[node.input(1)] -= 1; + if (node.input_size() == 3) + { + node_reference[node.input(2)] -= 1; + } + } + else if (op == "EmbedLayerNormalization") + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + node_reference[node.input(3)] -= 1; + node_reference[node.input(4)] -= 1; + node_reference[node.input(5)] -= 1; + node_reference[node.input(6)] -= 1; + } + else if (op == "Gemm") + { + float alpha = get_node_attr_f(node, "alpha", 1.f); + float beta = get_node_attr_f(node, "beta", 1.f); + int transA = get_node_attr_i(node, "transA", 0); + int transB = get_node_attr_i(node, "transB", 0); + + if (alpha == 1.f && beta == 1.f && transA == 0 && 
transB == 1) + { + // InnerProduct-like A * B + C, C is optional. + node_reference[node.input(1)] -= 1; + if (node.input_size() == 3) + { + node_reference[node.input(2)] -= 1; + } + } + } + else if (op == "GroupNorm") + { + int affine = get_node_attr_i(node, "affine", 1); + if (affine) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + } + } + else if (op == "GRU") + { + for (int j = 1; j < node.input_size(); j++) + { + node_reference[node.input(j)] -= 1; + } + } + else if (op == "InstanceNormalization") + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + } + else if (op == "LayerNorm") + { + int affine = get_node_attr_i(node, "affine", 1); + if (affine) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + } + } + else if (op == "LSTM") + { + for (int j = 1; j < node.input_size(); j++) + { + node_reference[node.input(j)] -= 1; + } + } + else if (op == "MatMul") + { + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) + { + // InnerProduct + node_reference[node.input(1)] -= 1; + } + } + else if (op == "MultiHeadAttention") + { + if (node.input_size() == 5) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + node_reference[node.input(3)] -= 1; + node_reference[node.input(4)] -= 1; + } + else + { + node_reference[node.input(3)] -= 1; + node_reference[node.input(4)] -= 1; + node_reference[node.input(5)] -= 1; + node_reference[node.input(6)] -= 1; + node_reference[node.input(7)] -= 1; + node_reference[node.input(8)] -= 1; + node_reference[node.input(9)] -= 1; + node_reference[node.input(10)] -= 1; + } + } + else if (op == "NonMaxSuppression") + { + if (node.input_size() >= 3) + { + node_reference[node.input(2)] -= 1; + } + if (node.input_size() >= 4) + { + node_reference[node.input(3)] -= 1; + } + if (node.input_size() >= 5) + { + node_reference[node.input(4)] -= 1; + } + } + else if (op == "Pad") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + } + } + else if (op == "PRelu") + { + node_reference[node.input(1)] -= 1; + } + else if (op == "Reshape") + { + if (node.input_size() == 2) + { + if (weights[node.input(1)].data_type() != 0) + { + node_reference[node.input(1)] -= 1; + } + } + } + else if (op == "Resize") + { + if (node.input_size() == 2) + { + // opset 10 + node_reference[node.input(1)] -= 1; + } + else + { + // opset 11+ + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + if (node.input_size() >= 4) + { + node_reference[node.input(3)] -= 1; + } + } + } + else if (op == "RNN") + { + for (int j = 1; j < node.input_size(); j++) + { + node_reference[node.input(j)] -= 1; + } + } + else if (op == "SkipLayerNormalization") + { + node_reference[node.input(2)] -= 1; + node_reference[node.input(3)] -= 1; + node_reference[node.input(4)] -= 1; + } + else if (op == "Slice") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + node_reference[node.input(2)] -= 1; + if (node.input_size() >= 4) node_reference[node.input(3)] -= 1; + if (node.input_size() >= 5) node_reference[node.input(4)] -= 1; + } + } + else if (op == "Upsample") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + } + } + else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || + op == "adaptive_max_pool2d") + { + if (node.input_size() >= 2) + { + node_reference[node.input(1)] -= 1; + } + } } - fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", 
input_name.c_str(), input_name.c_str());
-
-    const onnx::TensorProto& M = weights[input_name];
-
-    if (M.dims_size() == 0) {
-      fprintf(pp, " 0=%d", get_tensor_proto_data_size(M));
-    } else if (M.dims_size() == 1) {
-      fprintf(pp, " 0=%d", (int)M.dims(0));
-    } else if (M.dims_size() == 2) {
-      fprintf(pp, " 0=%d", (int)M.dims(1));
-      if (M.dims(0) != 1) {
-        fprintf(pp, " 1=%d", (int)M.dims(0));
-      }
-    } else if (M.dims_size() == 3) {
-      fprintf(pp, " 0=%d", (int)M.dims(2));
-      fprintf(pp, " 1=%d", (int)M.dims(1));
-      if (M.dims(0) != 1) {
-        fprintf(pp, " 2=%d", (int)M.dims(0));
-      }
-    } else if (M.dims_size() == 4) {
-      fprintf(pp, " 0=%d", (int)M.dims(3));
-      fprintf(pp, " 1=%d", (int)M.dims(2));
-      fprintf(pp, " 2=%d", (int)M.dims(1));
-    }
+    // for (auto a: node_reference)
+    // {
+    //     fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
+    // }
-    fprintf(pp, "\n");
-    if (M.data_type() == 1) {
-      fwrite_tensor_proto_data(M, bp);
-    } else if (M.data_type() == 7 || M.data_type() == 6 || M.data_type() == 9 ||
-               M.data_type() == 11) {
-      fwrite_tensor_proto_data_to_float(M, bp);
-    } else {
-      fwrite_tensor_proto_data(M, bp);
-    }
+    // count all weight node with zero reference
+    int zero_reference_weight_node_count = 0;
+    for (std::map<std::string, onnx::TensorProto>::iterator it = weights.begin(); it != weights.end();
+         it++)
+    {
+        const std::string& input_name = it->first;
-    if (refcount <= 1) {
-      continue;
+        int refcount = node_reference[input_name];
+        if (refcount == 0) zero_reference_weight_node_count++;
     }
-    char splitname[256];
-    sprintf(splitname, "splitncnn_%d", internal_split);
-    fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
+    // we always treat constant node as weight or binaryop_weights
+    // do not count it twice for layer_count
+    int constant_node_count_moved_to_weight = 0;
+    for (int i = 0; i < node_count; i++)
+    {
+        const onnx::NodeProto& node = mutable_graph->node(i);
-    fprintf(pp, " %s", input_name.c_str());
+        const std::string& op = node.op_type();
-    for (int k = 0; k < refcount; k++) {
-      fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
+        if (op == "Constant")
+        {
+            constant_node_count_moved_to_weight++;
+        }
     }
-    fprintf(pp, "\n");
-    internal_split++;
-  }
+    // some op may have anonymous input
+    // LSTM sequence_lens
+    blob_names.erase("");
+    node_reference.erase("");
+
+    // remove node_reference entry with reference equals to one
+    int split_layer_count = 0;
+    int splitncnn_blob_count = 0;
+    // split node reference
+    std::map<std::string, int> split_node_reference;
+    for (std::map<std::string, int>::iterator it = node_reference.begin(); it != node_reference.end();
+         it++)
+    {
+        if (it->second > 1)
+        {
+            split_layer_count++;
+            splitncnn_blob_count += it->second;
-  for (int i = 0; i < node_count; i++) {
-    const onnx::NodeProto& node = mutable_graph->node(i);
-    const std::string& op = node.op_type();
+            split_node_reference[it->first] = it->second;
+        }
+    }
-    // fprintf(stderr, "op = %s\n", op.c_str());
+    fprintf(pp, "%zu %zu\n", node_count - constant_node_count_moved_to_weight + weights.size() - zero_reference_weight_node_count - reduced_node_count + input_node_count + split_layer_count, blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count);
-    if (op == "noop_reducedncnn") {
-      continue;
-    }
+    int internal_split = 0;
-    std::string name = node.name();
-    if (name.empty()) {
-      name = node.output(0);
-    }
+    // place Input at the beginning
+    for (int j = 0; j < mutable_graph->input_size(); j++)
+    {
+        const std::string& input_name = mutable_graph->input(j).name();
-    int input_size = node.input_size();
-    int
output_size = node.output_size();
+        // check weight
+        if (weights.find(input_name) != weights.end()) continue;
-    for (int j = 0; j < (int)node.input_size(); j++) {
-      const std::string& input_name = node.input(j);
+        fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());
-      // check weight
-      if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) {
-        input_size--;
-      }
+        int refcount = node_reference[input_name];
+        if (refcount <= 1)
+        {
+            continue;
+        }
-      if (input_name.empty()) {
-        input_size--;
-      }
+        char splitname[256];
+        sprintf(splitname, "splitncnn_input%d", j);
+        fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
+        fprintf(pp, " %s", input_name.c_str());
-      // fprintf(stderr, "  input = %s\n", input_name.c_str());
+        for (int k = 0; k < refcount; k++)
+        {
+            fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
+        }
+        fprintf(pp, "\n");
     }
-    /*
-    for (int j=0; j<(int)node.output_size(); j++)
+
+    // place MemoryData next
+    for (std::map<std::string, onnx::TensorProto>::iterator weight_it = weights.begin();
+         weight_it != weights.end();
+         weight_it++)
     {
-        const std::string& output_name = node.output(j);
-        fprintf(stderr, "  output = %s\n", output_name.c_str());
-    }
-    */
-
-    if (op == "Abs") {
-      fprintf(pp, "%-16s", "UnaryOp");
-    } else if (op == "Acos") {
-      fprintf(pp, "%-16s", "UnaryOp");
-    } else if (op == "Add") {
-      fprintf(pp, "%-16s", "BinaryOp");
-    } else if (op == "ArgMax") {
-      fprintf(pp, "%-16s", "TopK");
-    } else if (op == "Asin") {
-      fprintf(pp, "%-16s", "UnaryOp");
-    } else if (op == "Atan") {
-      fprintf(pp, "%-16s", "UnaryOp");
-    } else if (op == "AveragePool" || op == "MaxPool") {
-      std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
-      if (kernel_shape.size() == 1) {
-        fprintf(pp, "%-16s", "Pooling1D");
-      } else {
-        fprintf(pp, "%-16s", "Pooling");
-      }
-    } else if (op == "BatchNormalization") {
-      fprintf(pp, "%-16s", "BatchNorm");
-    } else if (op == "BiasGelu") {
-      fprintf(pp, "%-16s", "BiasGelu");
-    } else if (op == "Cast") {
-      fprintf(pp, "%-16s", "Noop");
-    } else if (op == "Ceil") {
-      fprintf(pp, "%-16s", "UnaryOp");
-    } else if (op == "Clip") {
-      fprintf(pp, "%-16s", "Clip");
-    } else if (op == "Concat") {
-      fprintf(pp, "%-16s", "Concat");
-    } else if (op == "Constant") {
-      continue;
-    } else if (op == "ConstantOfShape") {
-      fprintf(pp, "%-16s", "ConstantOfShape");
-    } else if (op == "Conv") {
-      std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
-      if (kernel_shape.size() == 1) {
-        fprintf(pp, "%-16s", "Convolution1D");
-      } else {
-        int group = get_node_attr_i(node, "group", 1);
-        if (group > 1) {
-          fprintf(pp, "%-16s", "ConvolutionDepthWise");
-        } else {
-          fprintf(pp, "%-16s", "Convolution");
-        }
-      }
-    } else if (op == "ConvTranspose") {
-      int group = get_node_attr_i(node, "group", 1);
-      if (group > 1) {
-        fprintf(pp, "%-16s", "DeconvolutionDepthWise");
-      } else {
-        fprintf(pp, "%-16s", "Deconvolution");
-      }
-    } else if (op == "Cos") {
-      fprintf(pp, "%-16s", "UnaryOp");
-    } else if (op == "Crop") {
-      fprintf(pp, "%-16s", "Crop");
-    } else if (op == "DepthToSpace") {
-      fprintf(pp, "%-16s", "PixelShuffle");
-    } else if (op == "DetectionOutput") {
-      fprintf(pp, "%-16s", "DetectionOutput");
-    } else if (op == "Div") {
-      fprintf(pp, "%-16s", "BinaryOp");
-    } else if (op == "Dropout") {
-      fprintf(pp, "%-16s", "Dropout");
-      output_size = 1;
-    } else if (op == "Elu") {
-      fprintf(pp, "%-16s", "ELU");
-    } else if (op == "EmbedLayerNormalization") {
-      fprintf(pp, "%-16s", "EmbedLayerNormalization");
-    } else if (op
== "Equal") { - fprintf(pp, "%-16s", "Compare"); - } else if (op == "Exp") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Expand") { - fprintf(pp, "%-16s", "Expand"); - } else if (op == "Flatten") { - fprintf(pp, "%-16s", "Flatten"); - } else if (op == "Floor") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Gather") { - fprintf(pp, "%-16s", "Gather"); - } else if (op == "Gelu") { - fprintf(pp, "%-16s", "GELU"); - } else if (op == "Gemm") { - float alpha = get_node_attr_f(node, "alpha", 1.f); - float beta = get_node_attr_f(node, "beta", 1.f); - int transA = get_node_attr_i(node, "transA", 0); - int transB = get_node_attr_i(node, "transB", 0); - - if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) { - // InnerProduct-like A * B + C - fprintf(pp, "%-16s", "InnerProduct"); - } else { - fprintf(pp, "%-16s", "Gemm"); - } - } else if (op == "GlobalAveragePool") { - fprintf(pp, "%-16s", "Pooling"); - } else if (op == "GlobalMaxPool") { - fprintf(pp, "%-16s", "Pooling"); - } else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || - op == "adaptive_max_pool2d") { - fprintf(pp, "%-16s", "Pooling"); - } else if (op == "GroupNorm") { - fprintf(pp, "%-16s", "GroupNorm"); - } else if (op == "GRU") { - fprintf(pp, "%-16s", "GRU"); - } else if (op == "HardSigmoid") { - fprintf(pp, "%-16s", "HardSigmoid"); - } else if (op == "HardSwish") { - fprintf(pp, "%-16s", "HardSwish"); - } else if (op == "ImageScaler") { - fprintf(pp, "%-16s", "Scale"); - } else if (op == "InstanceNormalization") { - fprintf(pp, "%-16s", "InstanceNorm"); - } else if (op == "LayerNorm") { - fprintf(pp, "%-16s", "LayerNorm"); - } else if (op == "LeakyRelu") { - fprintf(pp, "%-16s", "ReLU"); - } else if (op == "Threshold") { - fprintf(pp, "%-16s", "Threshold"); - } else if (op == "Log") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "LRN") { - fprintf(pp, "%-16s", "LRN"); - } else if (op == "LSTM") { - fprintf(pp, "%-16s", "LSTM"); - } else if (op == "MatMul") { - if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) { - fprintf(pp, "%-16s", "InnerProduct"); - } else { - fprintf(pp, "%-16s", "Gemm"); - } - } else if (op == "Max") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "Min") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "Mul") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "MultiHeadAttention") { - fprintf(pp, "%-16s", "MultiHeadAttention"); - } else if (op == "Neg") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "NonMaxSuppression") { - fprintf(pp, "%-16s", "NonMaxSuppression"); - } else if (op == "Normalize") { - fprintf(pp, "%-16s", "Normalize"); - } else if (op == "Pad") { - fprintf(pp, "%-16s", "Padding"); - } else if (op == "PixelShuffle") { - fprintf(pp, "%-16s", "PixelShuffle"); - } else if (op == "Pow") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "PriorBox") { - fprintf(pp, "%-16s", "PriorBox"); - } else if (op == "PRelu") { - fprintf(pp, "%-16s", "PReLU"); - } else if (op == "Range") { - fprintf(pp, "%-16s", "Range"); - } else if (op == "Reciprocal") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || - op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || - op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp") { - fprintf(pp, "%-16s", "Reduction"); - } else if (op == "Relu") { - fprintf(pp, "%-16s", "ReLU"); - } else if (op == "Reorg") { - fprintf(pp, "%-16s", 
"Reorg"); - } else if (op == "Reshape") { - fprintf(pp, "%-16s", "Reshape"); - } else if (op == "RNN") { - fprintf(pp, "%-16s", "RNN"); - } else if (op == "RDiv") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "RSub") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "RoiAlign") { - fprintf(pp, "%-16s", "ROIAlign"); - } else if (op == "ScatterND") { - fprintf(pp, "%-16s", "ScatterND"); - } else if (op == "Shape") { - fprintf(pp, "%-16s", "Shape"); - } else if (op == "ShuffleChannel") { - fprintf(pp, "%-16s", "ShuffleChannel"); - } else if (op == "Sigmoid") { - fprintf(pp, "%-16s", "Sigmoid"); - } else if (op == "Sin") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "SkipLayerNormalization") { - fprintf(pp, "%-16s", "SkipLayerNormalization"); - } else if (op == "Slice") { - std::vector ends; - std::vector steps; - bool use_crop = true; - - if (node.input_size() == 1) { - ends = get_node_attr_ai(node, "ends"); - steps = get_node_attr_ai(node, "steps"); // TODO - } else { - ends = get_node_attr_from_input_ai(weights[node.input(2)]); - if (node.input_size() >= 5) steps = get_node_attr_from_input_ai(weights[node.input(4)]); - } - - // assert step == 1 - for (int i = 0; i < (int)steps.size(); i++) { - if (steps[i] != 1 && steps[i] < ends[i]) { - use_crop = false; - break; - } - } - - if (use_crop) { - fprintf(pp, "%-16s", "Crop"); - } else { - fprintf(pp, "%-16s", "TensorSlice"); - } - } else if (op == "Softmax") { - fprintf(pp, "%-16s", "Softmax"); - } else if (op == "Softplus") { - fprintf(pp, "%-16s", "Softplus"); - } else if (op == "Split") { - fprintf(pp, "%-16s", "Slice"); - } else if (op == "Sqrt") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Squeeze") { - std::vector axes = get_node_attr_ai(node, "axes"); - // fprintf(stderr, "axes[0]: %d\n",axes[0]); - if (axes[0] == 0) { - fprintf(pp, "%-16s", "Noop"); - } else { - fprintf(pp, "%-16s", "Squeeze"); - } - } else if (op == "Sub") { - fprintf(pp, "%-16s", "BinaryOp"); - } else if (op == "Sum") { - fprintf(pp, "%-16s", "Eltwise"); - } else if (op == "Swish") { - fprintf(pp, "%-16s", "Swish"); - } else if (op == "Tan") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Tanh") { - fprintf(pp, "%-16s", "UnaryOp"); - } else if (op == "Tile") { - fprintf(pp, "%-16s", "TileOnnx"); - } else if (op == "TopK") { - fprintf(pp, "%-16s", "TopK"); - } else if (op == "Transpose") { - fprintf(pp, "%-16s", "Permute"); - } else if (op == "Upsample" || op == "Resize") { - fprintf(pp, "%-16s", "Interp"); - } else if (op == "Unsqueeze") { - std::vector axes = get_node_attr_ai(node, "axes"); - // fprintf(stderr, "axes[0]: %d\n",axes[0]); - if (axes[0] == 0) { - fprintf(pp, "%-16s", "Noop"); - } else { - fprintf(pp, "%-16s", "ExpandDims"); - } - } else if (op == "Where") { - fprintf(pp, "%-16s", "Where"); - } else if (op == "Yolov3DetectionOutput") { - fprintf(pp, "%-16s", "Yolov3DetectionOutput"); - } else { - // TODO - fprintf(stderr, "%s not supported yet!\n", op.c_str()); - fprintf(pp, "%-16s", op.c_str()); - } + const std::string& input_name = weight_it->first; - fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size); + int refcount = node_reference[input_name]; + if (refcount == 0) + { + continue; + } - for (int j = 0; j < (int)node.input_size(); j++) { - std::string input_name = node.input(j); + fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str()); - // check weight - if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) { - continue; - } 
+ const onnx::TensorProto& M = weights[input_name]; - if (input_name.empty()) { - continue; - } + if (M.dims_size() == 0) + { + fprintf(pp, " 0=%d", get_tensor_proto_data_size(M)); + } + else if (M.dims_size() == 1) + { + fprintf(pp, " 0=%d", (int)M.dims(0)); + } + else if (M.dims_size() == 2) + { + fprintf(pp, " 0=%d", (int)M.dims(1)); + if (M.dims(0) != 1) + { + fprintf(pp, " 1=%d", (int)M.dims(0)); + } + } + else if (M.dims_size() == 3) + { + fprintf(pp, " 0=%d", (int)M.dims(2)); + fprintf(pp, " 1=%d", (int)M.dims(1)); + if (M.dims(0) != 1) + { + fprintf(pp, " 2=%d", (int)M.dims(0)); + } + } + else if (M.dims_size() == 4) + { + fprintf(pp, " 0=%d", (int)M.dims(3)); + fprintf(pp, " 1=%d", (int)M.dims(2)); + fprintf(pp, " 2=%d", (int)M.dims(1)); + } - if (split_node_reference.find(input_name) != split_node_reference.end()) { - int refidx = split_node_reference[input_name] - 1; - split_node_reference[input_name] = refidx; + fprintf(pp, "\n"); + if (M.data_type() == 1) + { + fwrite_tensor_proto_data(M, bp); + } + else if (M.data_type() == 7 || M.data_type() == 6 || M.data_type() == 9 || + M.data_type() == 11) + { + fwrite_tensor_proto_data_to_float(M, bp); + } + else + { + fwrite_tensor_proto_data(M, bp); + } - char splitsuffix[256]; - sprintf(splitsuffix, "_splitncnn_%d", refidx); - input_name = input_name + splitsuffix; - } + if (refcount <= 1) + { + continue; + } - fprintf(pp, " %s", input_name.c_str()); - } + char splitname[256]; + sprintf(splitname, "splitncnn_%d", internal_split); + fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount); - for (int j = 0; j < output_size; j++) { - const std::string& output_name = node.output(j); + fprintf(pp, " %s", input_name.c_str()); - fprintf(pp, " %s", output_name.c_str()); - } + for (int k = 0; k < refcount; k++) + { + fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k); + } + fprintf(pp, "\n"); - if (op == "Abs") { - int op_type = 0; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Acos") { - int op_type = 13; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Add") { - int op_type = 0; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "ArgMax") { - int axis = get_node_attr_i(node, "axis"); - int keepdims = get_node_attr_i(node, "keepdims"); - fprintf(pp, " 0=%d", axis - 1); - fprintf(pp, " 3=%d", keepdims); - } else if (op == "Asin") { - int op_type = 12; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Atan") { - int op_type = 14; - fprintf(pp, " 0=%d", op_type); - } else if (op == "AveragePool" || op == "MaxPool") { - std::string auto_pad = get_node_attr_s(node, "auto_pad"); - int ceil_mode = get_node_attr_i(node, "ceil_mode", 0); - std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); - std::vector strides = get_node_attr_ai(node, "strides"); - std::vector pads = get_node_attr_ai(node, "pads"); - - int pool = op == "AveragePool" ? 
1 : 0; - int pad_mode = 1; - - if (auto_pad == "SAME_UPPER") { - pad_mode = 2; - } else if (auto_pad == "SAME_LOWER") { - pad_mode = 3; - } - - if (ceil_mode == 1) { - pad_mode = 0; - } - - fprintf(pp, " 0=%d", pool); - - if (kernel_shape.size() == 1) { - fprintf(pp, " 1=%d", kernel_shape[0]); - } else if (kernel_shape.size() == 2) { - fprintf(pp, " 1=%d", kernel_shape[1]); - fprintf(pp, " 11=%d", kernel_shape[0]); - } - - if (strides.size() == 1) { - fprintf(pp, " 2=%d", strides[0]); - } else if (strides.size() == 2) { - fprintf(pp, " 2=%d", strides[1]); - fprintf(pp, " 12=%d", strides[0]); - } - - if (pads.size() == 1) { - fprintf(pp, " 3=%d", pads[0]); - } else if (pads.size() == 2) { - fprintf(pp, " 3=%d", pads[1]); - fprintf(pp, " 13=%d", pads[0]); - } else if (pads.size() == 4) { - fprintf(pp, " 3=%d", pads[1]); - fprintf(pp, " 13=%d", pads[0]); - fprintf(pp, " 14=%d", pads[3]); - fprintf(pp, " 15=%d", pads[2]); - } - - fprintf(pp, " 5=%d", pad_mode); - - if (op == "AveragePool") { - int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0); - fprintf(pp, " 6=%d", avgpool_count_include_pad); - } - } else if (op == "BatchNormalization") { - float epsilon = get_node_attr_f(node, "epsilon", 1e-5f); - - const onnx::TensorProto& scale = weights[node.input(1)]; - const onnx::TensorProto& B = weights[node.input(2)]; - const onnx::TensorProto& mean = weights[node.input(3)]; - const onnx::TensorProto& var = weights[node.input(4)]; - - int channels = get_tensor_proto_data_size(scale); - - fprintf(pp, " 0=%d", channels); - - fwrite_tensor_proto_data(scale, bp); - fwrite_tensor_proto_data(mean, bp); - // apply epsilon to var - { - const float* v = - var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data(); - - for (int j = 0; j < channels; j++) { - float ve = v[j] + epsilon; - fwrite(&ve, sizeof(float), 1, bp); - } - } - fwrite_tensor_proto_data(B, bp); - } else if (op == "BiasGelu") { - const onnx::TensorProto& B = weights[node.input(1)]; - - fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(B, bp); - } else if (op == "Ceil") { - int op_type = 3; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Clip") { - float min; - float max; - if (node.input_size() == 1) { - min = get_node_attr_f(node, "min", -FLT_MAX); - max = get_node_attr_f(node, "max", FLT_MAX); - } else { - min = weights.find(node.input(1)) != weights.end() - ? get_node_attr_from_input(weights[node.input(1)]) - : -FLT_MAX; - max = weights.find(node.input(2)) != weights.end() - ? get_node_attr_from_input(weights[node.input(2)]) - : FLT_MAX; - } - - fprintf(pp, " 0=%e", min); - fprintf(pp, " 1=%e", max); - } else if (op == "Concat") { - int axis = get_node_attr_i(node, "axis", 1); - fprintf(pp, " 0=%d", axis - 1); - } else if (op == "Constant") { - // never reach here - } else if (op == "ConstantOfShape") { - float value = 0.f; - value = get_node_attr_f(node, "value", 0.f); - fprintf(pp, " 0=%f", value); - - } else if (op == "Conv") { - const onnx::TensorProto& W = weights[node.input(1)]; - - int num_filter = W.dims(0); - int has_bias = node.input_size() == 3 ? 
1 : 0; - - std::string auto_pad = get_node_attr_s(node, "auto_pad"); - std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); - std::vector dilations = get_node_attr_ai(node, "dilations"); - std::vector strides = get_node_attr_ai(node, "strides"); - std::vector pads = get_node_attr_ai(node, "pads"); - int group = get_node_attr_i(node, "group", 1); - - fprintf(pp, " 0=%d", num_filter); - - if (kernel_shape.size() == 1) { - fprintf(pp, " 1=%d", kernel_shape[0]); - } else if (kernel_shape.size() == 2) { - fprintf(pp, " 1=%d", kernel_shape[1]); - fprintf(pp, " 11=%d", kernel_shape[0]); - } - - if (dilations.size() == 1) { - fprintf(pp, " 2=%d", dilations[0]); - } else if (dilations.size() == 2) { - fprintf(pp, " 2=%d", dilations[1]); - fprintf(pp, " 12=%d", dilations[0]); - } - - if (strides.size() == 1) { - fprintf(pp, " 3=%d", strides[0]); - } else if (strides.size() == 2) { - fprintf(pp, " 3=%d", strides[1]); - fprintf(pp, " 13=%d", strides[0]); - } - - if (auto_pad == "SAME_UPPER") { - fprintf(pp, " 4=-233"); - } else if (auto_pad == "SAME_LOWER") { - fprintf(pp, " 4=-234"); - } else { - if (pads.size() == 1) { - fprintf(pp, " 4=%d", pads[0]); - } else if (pads.size() == 2) { - fprintf(pp, " 4=%d", pads[1]); - fprintf(pp, " 14=%d", pads[0]); - } else if (pads.size() == 4) { - fprintf(pp, " 4=%d", pads[1]); - fprintf(pp, " 14=%d", pads[0]); - fprintf(pp, " 15=%d", pads[3]); - fprintf(pp, " 16=%d", pads[2]); - } - } - - fprintf(pp, " 5=%d", has_bias); - - fprintf(pp, " 6=%d", get_tensor_proto_data_size(W)); - - if (group > 1) { - fprintf(pp, " 7=%d", group); - } - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(W, bp); - - if (has_bias) { - const onnx::TensorProto& B = weights[node.input(2)]; - fwrite_tensor_proto_data(B, bp); - } - } else if (op == "ConvTranspose") { - const onnx::TensorProto& W = weights[node.input(1)]; - - int has_bias = node.input_size() == 3 ? 
1 : 0; - - std::string auto_pad = get_node_attr_s(node, "auto_pad"); - std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); - std::vector dilations = get_node_attr_ai(node, "dilations"); - std::vector strides = get_node_attr_ai(node, "strides"); - std::vector output_padding = get_node_attr_ai(node, "output_padding"); - std::vector output_shape = get_node_attr_ai(node, "output_shape"); - std::vector pads = get_node_attr_ai(node, "pads"); - int group = get_node_attr_i(node, "group", 1); - int num_filter = W.dims(1) * group; - - fprintf(pp, " 0=%d", num_filter); - - if (kernel_shape.size() == 1) { - fprintf(pp, " 1=%d", kernel_shape[0]); - } else if (kernel_shape.size() == 2) { - fprintf(pp, " 1=%d", kernel_shape[1]); - fprintf(pp, " 11=%d", kernel_shape[0]); - } - - if (dilations.size() == 1) { - fprintf(pp, " 2=%d", dilations[0]); - } else if (dilations.size() == 2) { - fprintf(pp, " 2=%d", dilations[1]); - fprintf(pp, " 12=%d", dilations[0]); - } - - if (strides.size() == 1) { - fprintf(pp, " 3=%d", strides[0]); - } else if (strides.size() == 2) { - fprintf(pp, " 3=%d", strides[1]); - fprintf(pp, " 13=%d", strides[0]); - } - - if (auto_pad == "SAME_UPPER") { - fprintf(pp, " 4=-233"); - } else if (auto_pad == "SAME_LOWER") { - fprintf(pp, " 4=-234"); - } else { - if (pads.size() == 1) { - fprintf(pp, " 4=%d", pads[0]); - } else if (pads.size() == 2) { - fprintf(pp, " 4=%d", pads[1]); - fprintf(pp, " 14=%d", pads[0]); - } else if (pads.size() == 4) { - fprintf(pp, " 4=%d", pads[1]); - fprintf(pp, " 14=%d", pads[0]); - fprintf(pp, " 15=%d", pads[3]); - fprintf(pp, " 16=%d", pads[2]); - } - } - - if (output_padding.size() == 1) { - fprintf(pp, " 18=%d", output_padding[0]); - } else if (output_padding.size() == 2) { - fprintf(pp, " 18=%d", output_padding[1]); - fprintf(pp, " 19=%d", output_padding[0]); - } - - if (output_shape.size() == 1) { - fprintf(pp, " 20=%d", output_shape[0]); - } else if (output_shape.size() == 2) { - fprintf(pp, " 20=%d", output_shape[1]); - fprintf(pp, " 21=%d", output_shape[0]); - } - - fprintf(pp, " 5=%d", has_bias); - - fprintf(pp, " 6=%d", get_tensor_proto_data_size(W)); - - if (group > 1) { - fprintf(pp, " 7=%d", group); - } - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int maxk = 0; - if (kernel_shape.size() == 2) { - maxk = kernel_shape[1] * kernel_shape[0]; - } else { - maxk = kernel_shape[0] * kernel_shape[0]; - } - int weight_data_size = get_tensor_proto_data_size(W); - const float* weight_data = 0; - if (W.has_raw_data()) { - weight_data = (const float*)W.raw_data().data(); - } else if (W.data_type() == 1) { - weight_data = W.float_data().data(); - } - for (int g = 0; g < group; g++) { - // reorder weight from inch-outch to outch-inch - int num_filter_g = num_filter / group; - int num_input = weight_data_size / maxk / num_filter_g / group; - const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input; - for (int k = 0; k < num_filter_g; k++) { - for (int j = 0; j < num_input; j++) { - fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp); - } - } - } - - if (has_bias) { - const onnx::TensorProto& B = weights[node.input(2)]; - fwrite_tensor_proto_data(B, bp); - } - } else if (op == "Cos") { - int op_type = 10; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Crop") { - auto starts = get_node_attr_ai(node, "starts"); - fprintf(pp, " -23309=%zu", starts.size()); - for (size_t j = 0; j < starts.size(); ++j) { - fprintf(pp, ",%i", starts[j]); - } - auto ends = 
get_node_attr_ai(node, "ends"); - fprintf(pp, " -23310=%zu", ends.size()); - for (size_t j = 0; j < ends.size(); ++j) { - fprintf(pp, ",%i", ends[j]); - } - auto axis = get_node_attr_ai(node, "axis"); - fprintf(pp, " -23311=%zu", axis.size()); - for (size_t j = 0; j < axis.size(); ++j) { - fprintf(pp, ",%i", axis[j]); - } - } else if (op == "DepthToSpace") { - // pixelshuffle - int scale_factor = get_node_attr_i(node, "blocksize", 1); - std::string mode = get_node_attr_s(node, "mode"); - fprintf(pp, " 0=%d", scale_factor); - if (mode == "CRD") { - fprintf(pp, " 1=0"); - } else if (mode == "DCR") { - fprintf(pp, " 1=1"); - } - } else if (op == "DetectionOutput") { - float score_threshold = get_node_attr_f(node, "score_threshold"); - float nms_threshold = get_node_attr_f(node, "nms_threshold"); - int nms_top_k = get_node_attr_i(node, "nms_top_k"); - int keep_top_k = get_node_attr_i(node, "keep_top_k"); - int num_class = get_node_attr_i(node, "num_class"); - std::vector vars = get_node_attr_af(node, "vars"); - fprintf(pp, " 0=%d", num_class); - fprintf(pp, " 1=%f", nms_threshold); - fprintf(pp, " 2=%d", nms_top_k); - fprintf(pp, " 3=%d", keep_top_k); - fprintf(pp, " 4=%f", score_threshold); - fprintf(pp, " 5=%f", vars[0]); - fprintf(pp, " 6=%f", vars[1]); - fprintf(pp, " 7=%f", vars[2]); - fprintf(pp, " 8=%f", vars[3]); - } else if (op == "Div") { - int op_type = 3; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "Dropout") { - // no-op - } else if (op == "Elu") { - float alpha = get_node_attr_f(node, "alpha", 1.f); - fprintf(pp, " 0=%e", alpha); - } else if (op == "EmbedLayerNormalization") { - const onnx::TensorProto& words = weights[node.input(2)]; - const onnx::TensorProto& positions = weights[node.input(3)]; - const onnx::TensorProto& W = weights[node.input(5)]; - const onnx::TensorProto& B = weights[node.input(6)]; - - fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); - fprintf(pp, " 1=%d", get_tensor_proto_data_size(words)); - fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions)); - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(words, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(positions, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(W, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(B, bp); - } else if (op == "Equal") { - int op_type = 0; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Exp") { - int op_type = 7; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Flatten") { - int axis = get_node_attr_i(node, "axis", 1); - if (axis != 1) { - fprintf(stderr, "Unsupported Flatten axis %d!\n", axis); - } - } else if (op == "Floor") { - int op_type = 2; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Gather") { - if (weights[node.input(1)].dims_size() > 1) { - fprintf(stderr, "Unsupported indice dims > 1"); - } - int axis = get_node_attr_i(node, "axis", 1) - 1; - if (axis < 0) { - fprintf(stderr, "Unsupported Gather axis: %d\n", axis + 1); - } - fprintf(pp, " 0=%d", axis); - } else if (op == "Gelu") { - fprintf(pp, " 0=1"); - } else if (op == "Gemm") { - float alpha = get_node_attr_f(node, "alpha", 1.f); - float beta = get_node_attr_f(node, "beta", 1.f); - int transA = get_node_attr_i(node, "transA", 0); - int transB 
= get_node_attr_i(node, "transB", 0); - - if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) { - // InnerProduct-like A * B + C - const onnx::TensorProto& B = weights[node.input(1)]; - // B has transposed. - int num_output = B.dims(0); - fprintf(pp, " 0=%d", num_output); - if (node.input_size() == 3) { - fprintf(pp, " 1=1"); - } else { - fprintf(pp, " 1=0"); - } - fprintf(pp, " 2=%d", get_tensor_proto_data_size(B)); - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - fwrite_tensor_proto_data(B, bp); - if (node.input_size() == 3) { - const onnx::TensorProto& C = weights[node.input(2)]; - fwrite_tensor_proto_data(C, bp); - } - } else { - // gemm - fprintf(pp, " 0=%e", alpha); - fprintf(pp, " 1=%e", beta); - fprintf(pp, " 2=%d", transA); - fprintf(pp, " 3=%d", transB); - } - } else if (op == "GlobalAveragePool") { - int pool = 1; - int global_pool = 1; - - fprintf(pp, " 0=%d", pool); - fprintf(pp, " 4=%d", global_pool); - } else if (op == "GlobalMaxPool") { - int pool = 0; - int global_pool = 1; - - fprintf(pp, " 0=%d", pool); - fprintf(pp, " 4=%d", global_pool); - } else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || - op == "adaptive_max_pool2d") { - int pool = 0; - if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d") { - pool = 1; - } - int adaptive_pooling = 1; - const onnx::TensorProto& out_shape_tp = weights[node.input(1)]; - std::vector out_shape = get_node_attr_from_input_ai(out_shape_tp); - - fprintf(pp, " 0=%d", pool); - fprintf(pp, " 7=%d", adaptive_pooling); - if (out_shape.size() == 1) { - fprintf(pp, " 8=%d", out_shape[0]); - } else if (out_shape.size() == 2) { - // out_w - fprintf(pp, " 8=%d", out_shape[1]); - // out_h - fprintf(pp, " 18=%d", out_shape[0]); - } - } else if (op == "GroupNorm") { - int groups = get_node_attr_i(node, "groups", 1); - int channels = get_node_attr_i(node, "channels", 1); - float eps = get_node_attr_f(node, "epsilon", 1e-5f); - int affine = get_node_attr_i(node, "affine", 1); - - if (affine) { - // discard affine-less S=1 B=0 - std::vector affine_S = get_node_attr_from_input_af(weights[node.input(1)]); - std::vector affine_B = get_node_attr_from_input_af(weights[node.input(2)]); - if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && - affine_B[0] == 0.f) { - affine = 0; - } else { - affine = 0; - { - for (int j = 0; j < channels; j++) { - if (affine_S[j] != 1.f || affine_B[j] != 0.f) { - affine = 1; - break; - } - } - } - } - } - - fprintf(pp, " 0=%d", groups); - fprintf(pp, " 1=%d", channels); - fprintf(pp, " 2=%e", eps); - fprintf(pp, " 3=%d", affine); - if (affine) { - const onnx::TensorProto& scale = weights[node.input(1)]; - const onnx::TensorProto& B = weights[node.input(2)]; - - fwrite_tensor_proto_data(scale, bp); - fwrite_tensor_proto_data(B, bp); - } - } else if (op == "GRU") { - const onnx::TensorProto& W = weights[node.input(1)]; - const onnx::TensorProto& R = weights[node.input(2)]; - const onnx::TensorProto& B = weights[node.input(3)]; - - int hidden_size = get_node_attr_i(node, "hidden_size", 0); - std::string direction = get_node_attr_s(node, "direction"); - - int direction_type = 0; - if (direction == "forward") { - direction_type = 0; - } else if (direction == "reverse") { - direction_type = 1; - } else if (direction == "bidirectional") { - direction_type = 2; - } - - int weight_data_size = get_tensor_proto_data_size(W); - - fprintf(pp, " 0=%d", hidden_size); - fprintf(pp, " 1=%d", weight_data_size); - fprintf(pp, " 2=%d", direction_type); - - int 
num_directions = direction_type == 2 ? 2 : 1; - - int quantize_tag = 0; - - // reorder num_directions-URN-hidden-size to - // num_directions-RUN-hidden-size - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions; - const float* wptr = - W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data(); - - const float* uptr = wptr; - const float* rptr = wptr + weight_data_size_g; - const float* nptr = wptr + weight_data_size_g * 2; - fwrite(rptr, sizeof(float), weight_data_size_g, bp); - fwrite(uptr, sizeof(float), weight_data_size_g, bp); - fwrite(nptr, sizeof(float), weight_data_size_g, bp); - - if (direction_type == 2) { - uptr += weight_data_size_g * 3; - rptr += weight_data_size_g * 3; - nptr += weight_data_size_g * 3; - fwrite(rptr, sizeof(float), weight_data_size_g, bp); - fwrite(uptr, sizeof(float), weight_data_size_g, bp); - fwrite(nptr, sizeof(float), weight_data_size_g, bp); - } - } - - // reduce U and R bias except N - // reorder num_directions-URN-hidden to num_directions-RUN-hidden - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions; - const float* bptr = - B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data(); - const float* wuptr = bptr; - const float* wrptr = bptr + bias_data_size_g; - const float* wnptr = bptr + bias_data_size_g * 2; - const float* buptr = bptr + bias_data_size_g * 3; - const float* brptr = bptr + bias_data_size_g * 4; - const float* bnptr = bptr + bias_data_size_g * 5; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = wrptr[j] + brptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = wuptr[j] + buptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - fwrite(wnptr, sizeof(float), bias_data_size_g, bp); - fwrite(bnptr, sizeof(float), bias_data_size_g, bp); - - if (direction_type == 2) { - wuptr += bias_data_size_g * 6; - wrptr += bias_data_size_g * 6; - wnptr += bias_data_size_g * 6; - buptr += bias_data_size_g * 6; - brptr += bias_data_size_g * 6; - bnptr += bias_data_size_g * 6; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = wrptr[j] + brptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = wuptr[j] + buptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - fwrite(wnptr, sizeof(float), bias_data_size_g, bp); - fwrite(bnptr, sizeof(float), bias_data_size_g, bp); - } - } - - // reorder num_directions-URN-hidden-hidden to - // num_directions-RUN-hidden-hidden - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions; - const float* Rptr = - R.has_raw_data() ? 
(const float*)R.raw_data().data() : R.float_data().data(); - - const float* uptr = Rptr; - const float* rptr = Rptr + weight_data_size_g; - const float* nptr = Rptr + weight_data_size_g * 2; - fwrite(rptr, sizeof(float), weight_data_size_g, bp); - fwrite(uptr, sizeof(float), weight_data_size_g, bp); - fwrite(nptr, sizeof(float), weight_data_size_g, bp); - - if (direction_type == 2) { - uptr += weight_data_size_g * 3; - rptr += weight_data_size_g * 3; - nptr += weight_data_size_g * 3; - fwrite(rptr, sizeof(float), weight_data_size_g, bp); - fwrite(uptr, sizeof(float), weight_data_size_g, bp); - fwrite(nptr, sizeof(float), weight_data_size_g, bp); - } - } - } else if (op == "HardSigmoid") { - float alpha = get_node_attr_f(node, "alpha", 0.2f); - float beta = get_node_attr_f(node, "beta", 0.5f); - - fprintf(pp, " 0=%e", alpha); - fprintf(pp, " 1=%e", beta); - } else if (op == "HardSwish") { - float alpha = get_node_attr_f(node, "alpha", 0.2f); - float beta = get_node_attr_f(node, "beta", 0.5f); - - fprintf(pp, " 0=%e", alpha); - fprintf(pp, " 1=%e", beta); - } else if (op == "ImageScaler") { - std::vector bias = get_node_attr_af(node, "bias"); - float scale = get_node_attr_f(node, "scale", 1.f); - - int channels = (int)bias.size(); - - fprintf(pp, " 0=%d", channels); - fprintf(pp, " 1=1"); - - for (int j = 0; j < channels; j++) { - fwrite(&scale, sizeof(float), 1, bp); - } - fwrite(&bias[0], sizeof(float), channels, bp); - } else if (op == "InstanceNormalization") { - float eps = get_node_attr_f(node, "epsilon", 1e-5f); - - // discard affine-less S=1 B=0 - std::vector affine_S = get_node_attr_from_input_af(weights[node.input(1)]); - std::vector affine_B = get_node_attr_from_input_af(weights[node.input(2)]); - int channels = (int)affine_S.size(); - int affine = 0; - { - for (int j = 0; j < channels; j++) { - if (affine_S[j] != 1.f || affine_B[j] != 0.f) { - affine = 1; - break; - } - } - } - - fprintf(pp, " 0=%d", channels); - fprintf(pp, " 1=%e", eps); - fprintf(pp, " 2=%d", affine); - if (affine) { - const onnx::TensorProto& scale = weights[node.input(1)]; - const onnx::TensorProto& B = weights[node.input(2)]; - - fwrite_tensor_proto_data(scale, bp); - fwrite_tensor_proto_data(B, bp); - } - } else if (op == "LayerNorm") { - float eps = get_node_attr_f(node, "epsilon", 1e-5f); - int affine = get_node_attr_i(node, "affine", 1); - - if (affine) { - // discard affine-less S=1 B=0 - std::vector affine_S = get_node_attr_from_input_af(weights[node.input(1)]); - std::vector affine_B = get_node_attr_from_input_af(weights[node.input(2)]); - int affine_size = (int)affine_S.size(); - affine = 0; - { - for (int j = 0; j < affine_size; j++) { - if (affine_S[j] != 1.f || affine_B[j] != 0.f) { - affine = 1; - break; - } - } - } - - if (affine) { - fprintf(pp, " 0=%d", affine_size); - } - } - - fprintf(pp, " 1=%e", eps); - fprintf(pp, " 2=%d", affine); - - if (affine) { - const onnx::TensorProto& scale = weights[node.input(1)]; - const onnx::TensorProto& B = weights[node.input(2)]; - - fwrite_tensor_proto_data(scale, bp); - fwrite_tensor_proto_data(B, bp); - } - } else if (op == "LeakyRelu") { - float alpha = get_node_attr_f(node, "alpha", 0.01f); - fprintf(pp, " 0=%e", alpha); - } else if (op == "Threshold") { - float threshold = get_node_attr_f(node, "threshold", 0.f); - fprintf(pp, " 0=%e", threshold); - } else if (op == "Log") { - int op_type = 8; - fprintf(pp, " 0=%d", op_type); - } else if (op == "LRN") { - float alpha = get_node_attr_f(node, "alpha", 1.f); - float beta = get_node_attr_f(node, "beta", 
0.5f); - float bias = get_node_attr_f(node, "bias", 1.f); - int size = get_node_attr_i(node, "size", 1); - - int norm_region = 0; - - fprintf(pp, " 0=%d", norm_region); - fprintf(pp, " 1=%d", size); - fprintf(pp, " 2=%e", alpha); - fprintf(pp, " 3=%e", beta); - fprintf(pp, " 4=%e", bias); - } else if (op == "LSTM") { - const onnx::TensorProto& W = weights[node.input(1)]; - const onnx::TensorProto& R = weights[node.input(2)]; - const onnx::TensorProto& B = weights[node.input(3)]; - - int hidden_size = get_node_attr_i(node, "hidden_size", 0); - std::string direction = get_node_attr_s(node, "direction"); - - int direction_type = 0; - if (direction == "forward") { - direction_type = 0; - } else if (direction == "reverse") { - direction_type = 1; - } else if (direction == "bidirectional") { - direction_type = 2; - } - - int weight_data_size = get_tensor_proto_data_size(W); - - fprintf(pp, " 0=%d", hidden_size); - fprintf(pp, " 1=%d", weight_data_size); - fprintf(pp, " 2=%d", direction_type); - - int num_directions = direction_type == 2 ? 2 : 1; - - int quantize_tag = 0; - - // reorder num_directions-IOFG-hidden-size to - // num_directions-IFOG-hidden-size - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions; - const float* wptr = - W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data(); - - const float* iptr = wptr; - const float* optr = wptr + weight_data_size_g; - const float* fptr = wptr + weight_data_size_g * 2; - const float* gptr = wptr + weight_data_size_g * 3; - fwrite(iptr, sizeof(float), weight_data_size_g, bp); - fwrite(fptr, sizeof(float), weight_data_size_g, bp); - fwrite(optr, sizeof(float), weight_data_size_g, bp); - fwrite(gptr, sizeof(float), weight_data_size_g, bp); - - if (direction_type == 2) { - iptr += weight_data_size_g * 4; - optr += weight_data_size_g * 4; - fptr += weight_data_size_g * 4; - gptr += weight_data_size_g * 4; - fwrite(iptr, sizeof(float), weight_data_size_g, bp); - fwrite(fptr, sizeof(float), weight_data_size_g, bp); - fwrite(optr, sizeof(float), weight_data_size_g, bp); - fwrite(gptr, sizeof(float), weight_data_size_g, bp); - } - } - - // reduce xc and hc bias - // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions; - const float* xcbptr = - B.has_raw_data() ? 
(const float*)B.raw_data().data() : B.float_data().data(); - const float* xiptr = xcbptr; - const float* xoptr = xcbptr + bias_data_size_g; - const float* xfptr = xcbptr + bias_data_size_g * 2; - const float* xgptr = xcbptr + bias_data_size_g * 3; - const float* hiptr = xcbptr + bias_data_size_g * 4; - const float* hoptr = xcbptr + bias_data_size_g * 5; - const float* hfptr = xcbptr + bias_data_size_g * 6; - const float* hgptr = xcbptr + bias_data_size_g * 7; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xiptr[j] + hiptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xfptr[j] + hfptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xoptr[j] + hoptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xgptr[j] + hgptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - - if (direction_type == 2) { - xiptr += bias_data_size_g * 8; - xoptr += bias_data_size_g * 8; - xfptr += bias_data_size_g * 8; - xgptr += bias_data_size_g * 8; - hiptr += bias_data_size_g * 8; - hoptr += bias_data_size_g * 8; - hfptr += bias_data_size_g * 8; - hgptr += bias_data_size_g * 8; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xiptr[j] + hiptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xfptr[j] + hfptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xoptr[j] + hoptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xgptr[j] + hgptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - - // reorder num_directions-IOFG-hidden-hidden to - // num_directions-IFOG-hidden-hidden - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions; - const float* rptr = - R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data(); - - const float* iptr = rptr; - const float* optr = rptr + weight_data_size_g; - const float* fptr = rptr + weight_data_size_g * 2; - const float* gptr = rptr + weight_data_size_g * 3; - fwrite(iptr, sizeof(float), weight_data_size_g, bp); - fwrite(fptr, sizeof(float), weight_data_size_g, bp); - fwrite(optr, sizeof(float), weight_data_size_g, bp); - fwrite(gptr, sizeof(float), weight_data_size_g, bp); - - if (direction_type == 2) { - iptr += weight_data_size_g * 4; - optr += weight_data_size_g * 4; - fptr += weight_data_size_g * 4; - gptr += weight_data_size_g * 4; - fwrite(iptr, sizeof(float), weight_data_size_g, bp); - fwrite(fptr, sizeof(float), weight_data_size_g, bp); - fwrite(optr, sizeof(float), weight_data_size_g, bp); - fwrite(gptr, sizeof(float), weight_data_size_g, bp); - } - } - } else if (op == "MatMul") { - if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) { - // InnerProduct - const onnx::TensorProto& B = weights[node.input(1)]; - - int weight_data_size = get_tensor_proto_data_size(B); - - int num_output = B.dims(B.dims_size() - 1); - int num_input = weight_data_size / num_output; - - fprintf(pp, " 0=%d", num_output); - fprintf(pp, " 1=0"); - fprintf(pp, " 2=%d", weight_data_size); - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - // reorder num_input-num_output to num_output-num_input - { - const float* bptr = - B.has_raw_data() ? 
(const float*)B.raw_data().data() : B.float_data().data(); - - for (int j = 0; j < num_output; j++) { - for (int k = 0; k < num_input; k++) { - float vb = bptr[k * num_output + j]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - - // fwrite_tensor_proto_data(B, bp) - } else { - // default matrix multiplication - } - } else if (op == "Max") { - int op_type = 4; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "Min") { - int op_type = 5; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "Mul") { - int op_type = 2; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "MultiHeadAttention") { - int embed_dim = get_node_attr_i(node, "embed_dim", 0); - int num_heads = get_node_attr_i(node, "num_heads", 0); + internal_split++; + } - fprintf(pp, " 0=%d", embed_dim); - fprintf(pp, " 1=%d", num_heads); + for (int i = 0; i < node_count; i++) + { + const onnx::NodeProto& node = mutable_graph->node(i); + const std::string& op = node.op_type(); - if (node.input_size() == 5) { - const onnx::TensorProto& qkvw = weights[node.input(1)]; - const onnx::TensorProto& qkvb = weights[node.input(2)]; - const onnx::TensorProto& ow = weights[node.input(3)]; - const onnx::TensorProto& ob = weights[node.input(4)]; + // fprintf(stderr, "op = %s\n", op.c_str()); - int weight_data_size = get_tensor_proto_data_size(ow); + if (op == "noop_reducedncnn") + { + continue; + } - fprintf(pp, " 2=%d", weight_data_size); + std::string name = node.name(); + if (name.empty()) + { + name = node.output(0); + } - int quantize_tag = 0; + int input_size = node.input_size(); + int output_size = node.output_size(); - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose qw + for (int j = 0; j < (int)node.input_size(); j++) { - const float* wptr = - qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); - const float* bptr = - qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data(); + const std::string& input_name = node.input(j); - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim * 3 + k]; - fwrite(&vb, sizeof(float), 1, bp); + // check weight + if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) + { + input_size--; } - } - fwrite(bptr, sizeof(float), embed_dim, bp); - } + if (input_name.empty()) + { + input_size--; + } - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose kw + // fprintf(stderr, " input = %s\n", input_name.c_str()); + } + /* + for (int j=0; j<(int)node.output_size(); j++) { - const float* wptr = - qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); - const float* bptr = - qkvb.has_raw_data() ? 
(const float*)qkvb.raw_data().data() : qkvb.float_data().data(); - bptr += embed_dim; + const std::string& output_name = node.output(j); + fprintf(stderr, " output = %s\n", output_name.c_str()); + } + */ - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim * 3 + k + embed_dim]; - fwrite(&vb, sizeof(float), 1, bp); + if (op == "Abs") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Acos") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Add") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "ArgMax") + { + fprintf(pp, "%-16s", "TopK"); + } + else if (op == "Asin") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Atan") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "AveragePool" || op == "MaxPool") + { + std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); + if (kernel_shape.size() == 1) + { + fprintf(pp, "%-16s", "Pooling1D"); + } + else + { + fprintf(pp, "%-16s", "Pooling"); } - } - - fwrite(bptr, sizeof(float), embed_dim, bp); } - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose vw + else if (op == "BatchNormalization") { - const float* wptr = - qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); - const float* bptr = - qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data(); - bptr += embed_dim * 2; - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim * 3 + k + embed_dim * 2]; - fwrite(&vb, sizeof(float), 1, bp); + fprintf(pp, "%-16s", "BatchNorm"); + } + else if (op == "BiasGelu") + { + fprintf(pp, "%-16s", "BiasGelu"); + } + else if (op == "Cast") + { + fprintf(pp, "%-16s", "Noop"); + } + else if (op == "Ceil") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Clip") + { + fprintf(pp, "%-16s", "Clip"); + } + else if (op == "Concat") + { + fprintf(pp, "%-16s", "Concat"); + } + else if (op == "Constant") + { + continue; + } + else if (op == "ConstantOfShape") + { + fprintf(pp, "%-16s", "ConstantOfShape"); + } + else if (op == "Conv") + { + std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); + if (kernel_shape.size() == 1) + { + fprintf(pp, "%-16s", "Convolution1D"); + } + else + { + int group = get_node_attr_i(node, "group", 1); + if (group > 1) + { + fprintf(pp, "%-16s", "ConvolutionDepthWise"); + } + else + { + fprintf(pp, "%-16s", "Convolution"); + } } - } - - fwrite(bptr, sizeof(float), embed_dim, bp); } - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose ow + else if (op == "ConvTranspose") { - const float* wptr = - ow.has_raw_data() ? 
(const float*)ow.raw_data().data() : ow.float_data().data(); - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim + k]; - fwrite(&vb, sizeof(float), 1, bp); + int group = get_node_attr_i(node, "group", 1); + if (group > 1) + { + fprintf(pp, "%-16s", "DeconvolutionDepthWise"); + } + else + { + fprintf(pp, "%-16s", "Deconvolution"); } - } } - fwrite_tensor_proto_data(ob, bp); - } else { - const onnx::TensorProto& qw = weights[node.input(3)]; - const onnx::TensorProto& qb = weights[node.input(4)]; - const onnx::TensorProto& kw = weights[node.input(5)]; - const onnx::TensorProto& kb = weights[node.input(6)]; - const onnx::TensorProto& vw = weights[node.input(7)]; - const onnx::TensorProto& vb = weights[node.input(8)]; - const onnx::TensorProto& ow = weights[node.input(9)]; - const onnx::TensorProto& ob = weights[node.input(10)]; - - int weight_data_size = get_tensor_proto_data_size(qw); - - fprintf(pp, " 2=%d", weight_data_size); - - int quantize_tag = 0; - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose qw + else if (op == "Cos") { - const float* wptr = - qw.has_raw_data() ? (const float*)qw.raw_data().data() : qw.float_data().data(); - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim + k]; - fwrite(&vb, sizeof(float), 1, bp); + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Crop") + { + fprintf(pp, "%-16s", "Crop"); + } + else if (op == "DepthToSpace") + { + fprintf(pp, "%-16s", "PixelShuffle"); + } + else if (op == "DetectionOutput") + { + fprintf(pp, "%-16s", "DetectionOutput"); + } + else if (op == "Div") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "Dropout") + { + fprintf(pp, "%-16s", "Dropout"); + output_size = 1; + } + else if (op == "Elu") + { + fprintf(pp, "%-16s", "ELU"); + } + else if (op == "EmbedLayerNormalization") + { + fprintf(pp, "%-16s", "EmbedLayerNormalization"); + } + else if (op == "Equal") + { + fprintf(pp, "%-16s", "Compare"); + } + else if (op == "Exp") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Expand") + { + fprintf(pp, "%-16s", "Expand"); + } + else if (op == "Flatten") + { + fprintf(pp, "%-16s", "Flatten"); + } + else if (op == "Floor") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Gather") + { + fprintf(pp, "%-16s", "Gather"); + } + else if (op == "Gelu") + { + fprintf(pp, "%-16s", "GELU"); + } + else if (op == "Gemm") + { + float alpha = get_node_attr_f(node, "alpha", 1.f); + float beta = get_node_attr_f(node, "beta", 1.f); + int transA = get_node_attr_i(node, "transA", 0); + int transB = get_node_attr_i(node, "transB", 0); + + if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) + { + // InnerProduct-like A * B + C + fprintf(pp, "%-16s", "InnerProduct"); + } + else + { + fprintf(pp, "%-16s", "Gemm"); + } + } + else if (op == "GlobalAveragePool") + { + fprintf(pp, "%-16s", "Pooling"); + } + else if (op == "GlobalMaxPool") + { + fprintf(pp, "%-16s", "Pooling"); + } + else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || + op == "adaptive_max_pool2d") + { + fprintf(pp, "%-16s", "Pooling"); + } + else if (op == "GroupNorm") + { + fprintf(pp, "%-16s", "GroupNorm"); + } + else if (op == "GRU") + { + fprintf(pp, "%-16s", "GRU"); + } + else if (op == "HardSigmoid") + { + fprintf(pp, "%-16s", "HardSigmoid"); + } + else if (op == "HardSwish") + { + fprintf(pp, "%-16s", "HardSwish"); + } + else if (op == "ImageScaler") + { + fprintf(pp, "%-16s", "Scale"); + 
} + else if (op == "InstanceNormalization") + { + fprintf(pp, "%-16s", "InstanceNorm"); + } + else if (op == "LayerNorm") + { + fprintf(pp, "%-16s", "LayerNorm"); + } + else if (op == "LeakyRelu") + { + fprintf(pp, "%-16s", "ReLU"); + } + else if (op == "Threshold") + { + fprintf(pp, "%-16s", "Threshold"); + } + else if (op == "Log") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "LRN") + { + fprintf(pp, "%-16s", "LRN"); + } + else if (op == "LSTM") + { + fprintf(pp, "%-16s", "LSTM"); + } + else if (op == "MatMul") + { + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) + { + fprintf(pp, "%-16s", "InnerProduct"); + } + else + { + fprintf(pp, "%-16s", "Gemm"); + } + } + else if (op == "Max") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "Min") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "Mul") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "MultiHeadAttention") + { + fprintf(pp, "%-16s", "MultiHeadAttention"); + } + else if (op == "Neg") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "NonMaxSuppression") + { + fprintf(pp, "%-16s", "NonMaxSuppression"); + } + else if (op == "Normalize") + { + fprintf(pp, "%-16s", "Normalize"); + } + else if (op == "Pad") + { + fprintf(pp, "%-16s", "Padding"); + } + else if (op == "PixelShuffle") + { + fprintf(pp, "%-16s", "PixelShuffle"); + } + else if (op == "Pow") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "PriorBox") + { + fprintf(pp, "%-16s", "PriorBox"); + } + else if (op == "PRelu") + { + fprintf(pp, "%-16s", "PReLU"); + } + else if (op == "Range") + { + fprintf(pp, "%-16s", "Range"); + } + else if (op == "Reciprocal") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || + op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || + op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp") + { + fprintf(pp, "%-16s", "Reduction"); + } + else if (op == "Relu") + { + fprintf(pp, "%-16s", "ReLU"); + } + else if (op == "Reorg") + { + fprintf(pp, "%-16s", "Reorg"); + } + else if (op == "Reshape") + { + fprintf(pp, "%-16s", "Reshape"); + } + else if (op == "RNN") + { + fprintf(pp, "%-16s", "RNN"); + } + else if (op == "RDiv") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "RSub") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "RoiAlign") + { + fprintf(pp, "%-16s", "ROIAlign"); + } + else if (op == "ScatterND") + { + fprintf(pp, "%-16s", "ScatterND"); + } + else if (op == "Shape") + { + fprintf(pp, "%-16s", "Shape"); + } + else if (op == "ShuffleChannel") + { + fprintf(pp, "%-16s", "ShuffleChannel"); + } + else if (op == "Sigmoid") + { + fprintf(pp, "%-16s", "Sigmoid"); + } + else if (op == "Sin") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "SkipLayerNormalization") + { + fprintf(pp, "%-16s", "SkipLayerNormalization"); + } + else if (op == "Slice") + { + std::vector<int> ends; + std::vector<int> steps; + bool use_crop = true; + + if (node.input_size() == 1) + { + ends = get_node_attr_ai(node, "ends"); + steps = get_node_attr_ai(node, "steps"); // TODO + } + else + { + ends = get_node_attr_from_input_ai(weights[node.input(2)]); + if (node.input_size() >= 5) steps = get_node_attr_from_input_ai(weights[node.input(4)]); + } + + // assert step == 1 + for (int i = 0; i < (int)steps.size(); i++) + { + if (steps[i] != 1 && steps[i] < ends[i]) + { + use_crop = false; + break; + } + } + + if (use_crop) + { + fprintf(pp, "%-16s", "Crop"); + } + else + { + fprintf(pp, "%-16s", "TensorSlice"); + } + } + else if (op == "Softmax") + { + fprintf(pp, "%-16s", "Softmax"); + } + else if (op == "Softplus") + { + fprintf(pp, "%-16s", "Softplus"); + } + else if (op == "Split") + { + fprintf(pp, "%-16s", "Slice"); + } + else if (op == "Sqrt") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Squeeze") + { + std::vector<int> axes = get_node_attr_ai(node, "axes"); + // fprintf(stderr, "axes[0]: %d\n",axes[0]); + if (axes[0] == 0) + { + fprintf(pp, "%-16s", "Noop"); + } + else + { + fprintf(pp, "%-16s", "Squeeze"); + } + } + else if (op == "Sub") + { + fprintf(pp, "%-16s", "BinaryOp"); + } + else if (op == "Sum") + { + fprintf(pp, "%-16s", "Eltwise"); + } + else if (op == "Swish") + { + fprintf(pp, "%-16s", "Swish"); + } + else if (op == "Tan") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Tanh") + { + fprintf(pp, "%-16s", "UnaryOp"); + } + else if (op == "Tile") + { + fprintf(pp, "%-16s", "TileOnnx"); + } + else if (op == "TopK") + { + fprintf(pp, "%-16s", "TopK"); + } + else if (op == "Transpose") + { + fprintf(pp, "%-16s", "Permute"); + } + else if (op == "Upsample" || op == "Resize") + { + fprintf(pp, "%-16s", "Interp"); + } + else if (op == "Unsqueeze") + { + std::vector<int> axes = get_node_attr_ai(node, "axes"); + // fprintf(stderr, "axes[0]: %d\n",axes[0]); + if (axes[0] == 0) + { + fprintf(pp, "%-16s", "Noop"); + } + else + { + fprintf(pp, "%-16s", "ExpandDims"); + } + } + else if (op == "Where") + { + fprintf(pp, "%-16s", "Where"); + } + else if (op == "Yolov3DetectionOutput") + { + fprintf(pp, "%-16s", "Yolov3DetectionOutput"); + } + else + { + // TODO + fprintf(stderr, "%s not supported yet!\n", op.c_str()); + fprintf(pp, "%-16s", op.c_str()); + } + + fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size); + + for (int j = 0; j < (int)node.input_size(); j++) + { + std::string input_name = node.input(j); + + // check weight + if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0) + { + continue; + } + + if (input_name.empty()) + { + continue; + } + + if (split_node_reference.find(input_name) != split_node_reference.end()) + { + int refidx = split_node_reference[input_name] - 1; + split_node_reference[input_name] = refidx; + + char splitsuffix[256]; + sprintf(splitsuffix, "_splitncnn_%d", refidx); + input_name = input_name + splitsuffix; + } + + fprintf(pp, " %s", input_name.c_str()); + } + + for (int j = 0; j < output_size; j++) + { + const std::string& output_name = node.output(j); + + fprintf(pp, " %s", output_name.c_str()); + } + + if (op == "Abs") + { + int op_type = 0; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Acos") + { + int op_type = 13; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Add") + { + int op_type = 0; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "ArgMax") + { + int axis = get_node_attr_i(node, "axis"); + int keepdims = get_node_attr_i(node, "keepdims"); + fprintf(pp, " 0=%d", axis - 1); + fprintf(pp, " 3=%d", keepdims); + } + else if (op == "Asin") + { + int op_type = 12; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Atan") + { + int op_type = 14; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "AveragePool" || op == "MaxPool") + { + std::string auto_pad = 
get_node_attr_s(node, "auto_pad"); + int ceil_mode = get_node_attr_i(node, "ceil_mode", 0); + std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); + std::vector strides = get_node_attr_ai(node, "strides"); + std::vector pads = get_node_attr_ai(node, "pads"); + + int pool = op == "AveragePool" ? 1 : 0; + int pad_mode = 1; + + if (auto_pad == "SAME_UPPER") + { + pad_mode = 2; + } + else if (auto_pad == "SAME_LOWER") + { + pad_mode = 3; + } + + if (ceil_mode == 1) + { + pad_mode = 0; + } + + fprintf(pp, " 0=%d", pool); + + if (kernel_shape.size() == 1) + { + fprintf(pp, " 1=%d", kernel_shape[0]); + } + else if (kernel_shape.size() == 2) + { + fprintf(pp, " 1=%d", kernel_shape[1]); + fprintf(pp, " 11=%d", kernel_shape[0]); + } + + if (strides.size() == 1) + { + fprintf(pp, " 2=%d", strides[0]); + } + else if (strides.size() == 2) + { + fprintf(pp, " 2=%d", strides[1]); + fprintf(pp, " 12=%d", strides[0]); + } + + if (pads.size() == 1) + { + fprintf(pp, " 3=%d", pads[0]); + } + else if (pads.size() == 2) + { + fprintf(pp, " 3=%d", pads[1]); + fprintf(pp, " 13=%d", pads[0]); + } + else if (pads.size() == 4) + { + fprintf(pp, " 3=%d", pads[1]); + fprintf(pp, " 13=%d", pads[0]); + fprintf(pp, " 14=%d", pads[3]); + fprintf(pp, " 15=%d", pads[2]); + } + + fprintf(pp, " 5=%d", pad_mode); + + if (op == "AveragePool") + { + int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0); + fprintf(pp, " 6=%d", avgpool_count_include_pad); + } + } + else if (op == "BatchNormalization") + { + float epsilon = get_node_attr_f(node, "epsilon", 1e-5f); + + const onnx::TensorProto& scale = weights[node.input(1)]; + const onnx::TensorProto& B = weights[node.input(2)]; + const onnx::TensorProto& mean = weights[node.input(3)]; + const onnx::TensorProto& var = weights[node.input(4)]; + + int channels = get_tensor_proto_data_size(scale); + + fprintf(pp, " 0=%d", channels); + + fwrite_tensor_proto_data(scale, bp); + fwrite_tensor_proto_data(mean, bp); + // apply epsilon to var + { + const float* v = + var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data(); + + for (int j = 0; j < channels; j++) + { + float ve = v[j] + epsilon; + fwrite(&ve, sizeof(float), 1, bp); + } + } + fwrite_tensor_proto_data(B, bp); + } + else if (op == "BiasGelu") + { + const onnx::TensorProto& B = weights[node.input(1)]; + + fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(B, bp); + } + else if (op == "Ceil") + { + int op_type = 3; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Clip") + { + float min; + float max; + if (node.input_size() == 1) + { + min = get_node_attr_f(node, "min", -FLT_MAX); + max = get_node_attr_f(node, "max", FLT_MAX); + } + else + { + min = weights.find(node.input(1)) != weights.end() ? get_node_attr_from_input(weights[node.input(1)]) : -FLT_MAX; + max = weights.find(node.input(2)) != weights.end() ? 
get_node_attr_from_input(weights[node.input(2)]) : FLT_MAX; + } + + fprintf(pp, " 0=%e", min); + fprintf(pp, " 1=%e", max); + } + else if (op == "Concat") + { + int axis = get_node_attr_i(node, "axis", 1); + fprintf(pp, " 0=%d", axis - 1); + } + else if (op == "Constant") + { + // never reach here + } + else if (op == "ConstantOfShape") + { + float value = 0.f; + value = get_node_attr_f(node, "value", 0.f); + fprintf(pp, " 0=%f", value); + } + else if (op == "Conv") + { + const onnx::TensorProto& W = weights[node.input(1)]; + + int num_filter = W.dims(0); + int has_bias = node.input_size() == 3 ? 1 : 0; + + std::string auto_pad = get_node_attr_s(node, "auto_pad"); + std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); + std::vector dilations = get_node_attr_ai(node, "dilations"); + std::vector strides = get_node_attr_ai(node, "strides"); + std::vector pads = get_node_attr_ai(node, "pads"); + int group = get_node_attr_i(node, "group", 1); + + fprintf(pp, " 0=%d", num_filter); + + if (kernel_shape.size() == 1) + { + fprintf(pp, " 1=%d", kernel_shape[0]); + } + else if (kernel_shape.size() == 2) + { + fprintf(pp, " 1=%d", kernel_shape[1]); + fprintf(pp, " 11=%d", kernel_shape[0]); + } + + if (dilations.size() == 1) + { + fprintf(pp, " 2=%d", dilations[0]); + } + else if (dilations.size() == 2) + { + fprintf(pp, " 2=%d", dilations[1]); + fprintf(pp, " 12=%d", dilations[0]); + } + + if (strides.size() == 1) + { + fprintf(pp, " 3=%d", strides[0]); + } + else if (strides.size() == 2) + { + fprintf(pp, " 3=%d", strides[1]); + fprintf(pp, " 13=%d", strides[0]); + } + + if (auto_pad == "SAME_UPPER") + { + fprintf(pp, " 4=-233"); + } + else if (auto_pad == "SAME_LOWER") + { + fprintf(pp, " 4=-234"); + } + else + { + if (pads.size() == 1) + { + fprintf(pp, " 4=%d", pads[0]); + } + else if (pads.size() == 2) + { + fprintf(pp, " 4=%d", pads[1]); + fprintf(pp, " 14=%d", pads[0]); + } + else if (pads.size() == 4) + { + fprintf(pp, " 4=%d", pads[1]); + fprintf(pp, " 14=%d", pads[0]); + fprintf(pp, " 15=%d", pads[3]); + fprintf(pp, " 16=%d", pads[2]); + } + } + + fprintf(pp, " 5=%d", has_bias); + + fprintf(pp, " 6=%d", get_tensor_proto_data_size(W)); + + if (group > 1) + { + fprintf(pp, " 7=%d", group); + } + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(W, bp); + + if (has_bias) + { + const onnx::TensorProto& B = weights[node.input(2)]; + fwrite_tensor_proto_data(B, bp); + } + } + else if (op == "ConvTranspose") + { + const onnx::TensorProto& W = weights[node.input(1)]; + + int has_bias = node.input_size() == 3 ? 
1 : 0; + + std::string auto_pad = get_node_attr_s(node, "auto_pad"); + std::vector kernel_shape = get_node_attr_ai(node, "kernel_shape"); + std::vector dilations = get_node_attr_ai(node, "dilations"); + std::vector strides = get_node_attr_ai(node, "strides"); + std::vector output_padding = get_node_attr_ai(node, "output_padding"); + std::vector output_shape = get_node_attr_ai(node, "output_shape"); + std::vector pads = get_node_attr_ai(node, "pads"); + int group = get_node_attr_i(node, "group", 1); + int num_filter = W.dims(1) * group; + + fprintf(pp, " 0=%d", num_filter); + + if (kernel_shape.size() == 1) + { + fprintf(pp, " 1=%d", kernel_shape[0]); + } + else if (kernel_shape.size() == 2) + { + fprintf(pp, " 1=%d", kernel_shape[1]); + fprintf(pp, " 11=%d", kernel_shape[0]); + } + + if (dilations.size() == 1) + { + fprintf(pp, " 2=%d", dilations[0]); + } + else if (dilations.size() == 2) + { + fprintf(pp, " 2=%d", dilations[1]); + fprintf(pp, " 12=%d", dilations[0]); + } + + if (strides.size() == 1) + { + fprintf(pp, " 3=%d", strides[0]); + } + else if (strides.size() == 2) + { + fprintf(pp, " 3=%d", strides[1]); + fprintf(pp, " 13=%d", strides[0]); + } + + if (auto_pad == "SAME_UPPER") + { + fprintf(pp, " 4=-233"); + } + else if (auto_pad == "SAME_LOWER") + { + fprintf(pp, " 4=-234"); + } + else + { + if (pads.size() == 1) + { + fprintf(pp, " 4=%d", pads[0]); + } + else if (pads.size() == 2) + { + fprintf(pp, " 4=%d", pads[1]); + fprintf(pp, " 14=%d", pads[0]); + } + else if (pads.size() == 4) + { + fprintf(pp, " 4=%d", pads[1]); + fprintf(pp, " 14=%d", pads[0]); + fprintf(pp, " 15=%d", pads[3]); + fprintf(pp, " 16=%d", pads[2]); + } + } + + if (output_padding.size() == 1) + { + fprintf(pp, " 18=%d", output_padding[0]); + } + else if (output_padding.size() == 2) + { + fprintf(pp, " 18=%d", output_padding[1]); + fprintf(pp, " 19=%d", output_padding[0]); + } + + if (output_shape.size() == 1) + { + fprintf(pp, " 20=%d", output_shape[0]); + } + else if (output_shape.size() == 2) + { + fprintf(pp, " 20=%d", output_shape[1]); + fprintf(pp, " 21=%d", output_shape[0]); + } + + fprintf(pp, " 5=%d", has_bias); + + fprintf(pp, " 6=%d", get_tensor_proto_data_size(W)); + + if (group > 1) + { + fprintf(pp, " 7=%d", group); + } + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int maxk = 0; + if (kernel_shape.size() == 2) + { + maxk = kernel_shape[1] * kernel_shape[0]; + } + else + { + maxk = kernel_shape[0] * kernel_shape[0]; + } + int weight_data_size = get_tensor_proto_data_size(W); + const float* weight_data = 0; + if (W.has_raw_data()) + { + weight_data = (const float*)W.raw_data().data(); + } + else if (W.data_type() == 1) + { + weight_data = W.float_data().data(); + } + for (int g = 0; g < group; g++) + { + // reorder weight from inch-outch to outch-inch + int num_filter_g = num_filter / group; + int num_input = weight_data_size / maxk / num_filter_g / group; + const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input; + for (int k = 0; k < num_filter_g; k++) + { + for (int j = 0; j < num_input; j++) + { + fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp); + } + } + } + + if (has_bias) + { + const onnx::TensorProto& B = weights[node.input(2)]; + fwrite_tensor_proto_data(B, bp); + } + } + else if (op == "Cos") + { + int op_type = 10; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Crop") + { + auto starts = get_node_attr_ai(node, "starts"); + fprintf(pp, " -23309=%zu", starts.size()); + for (size_t j = 0; j < 
starts.size(); ++j) + { + fprintf(pp, ",%i", starts[j]); + } + auto ends = get_node_attr_ai(node, "ends"); + fprintf(pp, " -23310=%zu", ends.size()); + for (size_t j = 0; j < ends.size(); ++j) + { + fprintf(pp, ",%i", ends[j]); + } + auto axis = get_node_attr_ai(node, "axis"); + fprintf(pp, " -23311=%zu", axis.size()); + for (size_t j = 0; j < axis.size(); ++j) + { + fprintf(pp, ",%i", axis[j]); + } + } + else if (op == "DepthToSpace") + { + // pixelshuffle + int scale_factor = get_node_attr_i(node, "blocksize", 1); + std::string mode = get_node_attr_s(node, "mode"); + fprintf(pp, " 0=%d", scale_factor); + if (mode == "CRD") + { + fprintf(pp, " 1=0"); + } + else if (mode == "DCR") + { + fprintf(pp, " 1=1"); + } + } + else if (op == "DetectionOutput") + { + float score_threshold = get_node_attr_f(node, "score_threshold"); + float nms_threshold = get_node_attr_f(node, "nms_threshold"); + int nms_top_k = get_node_attr_i(node, "nms_top_k"); + int keep_top_k = get_node_attr_i(node, "keep_top_k"); + int num_class = get_node_attr_i(node, "num_class"); + std::vector vars = get_node_attr_af(node, "vars"); + fprintf(pp, " 0=%d", num_class); + fprintf(pp, " 1=%f", nms_threshold); + fprintf(pp, " 2=%d", nms_top_k); + fprintf(pp, " 3=%d", keep_top_k); + fprintf(pp, " 4=%f", score_threshold); + fprintf(pp, " 5=%f", vars[0]); + fprintf(pp, " 6=%f", vars[1]); + fprintf(pp, " 7=%f", vars[2]); + fprintf(pp, " 8=%f", vars[3]); + } + else if (op == "Div") + { + int op_type = 3; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "Dropout") + { + // no-op + } + else if (op == "Elu") + { + float alpha = get_node_attr_f(node, "alpha", 1.f); + fprintf(pp, " 0=%e", alpha); + } + else if (op == "EmbedLayerNormalization") + { + const onnx::TensorProto& words = weights[node.input(2)]; + const onnx::TensorProto& positions = weights[node.input(3)]; + const onnx::TensorProto& W = weights[node.input(5)]; + const onnx::TensorProto& B = weights[node.input(6)]; + + fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); + fprintf(pp, " 1=%d", get_tensor_proto_data_size(words)); + fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions)); + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(words, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(positions, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(W, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(B, bp); + } + else if (op == "Equal") + { + int op_type = 0; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Exp") + { + int op_type = 7; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Flatten") + { + int axis = get_node_attr_i(node, "axis", 1); + if (axis != 1) + { + fprintf(stderr, "Unsupported Flatten axis %d!\n", axis); + } + } + else if (op == "Floor") + { + int op_type = 2; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Gather") + { + if (weights[node.input(1)].dims_size() > 1) + { + fprintf(stderr, "Unsupported indice dims > 1"); + } + int axis = get_node_attr_i(node, "axis", 1) - 1; + if (axis < 0) + { + fprintf(stderr, "Unsupported Gather axis: %d\n", axis + 1); + } + fprintf(pp, " 0=%d", axis); + } + else if (op == "Gelu") + { + fprintf(pp, " 0=1"); + } + else if (op == "Gemm") + { + float alpha = 
get_node_attr_f(node, "alpha", 1.f); + float beta = get_node_attr_f(node, "beta", 1.f); + int transA = get_node_attr_i(node, "transA", 0); + int transB = get_node_attr_i(node, "transB", 0); + + if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1) + { + // InnerProduct-like A * B + C + const onnx::TensorProto& B = weights[node.input(1)]; + // B has transposed. + int num_output = B.dims(0); + fprintf(pp, " 0=%d", num_output); + if (node.input_size() == 3) + { + fprintf(pp, " 1=1"); + } + else + { + fprintf(pp, " 1=0"); + } + fprintf(pp, " 2=%d", get_tensor_proto_data_size(B)); + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + fwrite_tensor_proto_data(B, bp); + if (node.input_size() == 3) + { + const onnx::TensorProto& C = weights[node.input(2)]; + fwrite_tensor_proto_data(C, bp); + } + } + else + { + // gemm + fprintf(pp, " 0=%e", alpha); + fprintf(pp, " 1=%e", beta); + fprintf(pp, " 2=%d", transA); + fprintf(pp, " 3=%d", transB); + } + } + else if (op == "GlobalAveragePool") + { + int pool = 1; + int global_pool = 1; + + fprintf(pp, " 0=%d", pool); + fprintf(pp, " 4=%d", global_pool); + } + else if (op == "GlobalMaxPool") + { + int pool = 0; + int global_pool = 1; + + fprintf(pp, " 0=%d", pool); + fprintf(pp, " 4=%d", global_pool); + } + else if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d" || + op == "adaptive_max_pool2d") + { + int pool = 0; + if (op == "AdaptiveAvgPool2d" || op == "adaptive_avg_pool2d") + { + pool = 1; + } + int adaptive_pooling = 1; + const onnx::TensorProto& out_shape_tp = weights[node.input(1)]; + std::vector out_shape = get_node_attr_from_input_ai(out_shape_tp); + + fprintf(pp, " 0=%d", pool); + fprintf(pp, " 7=%d", adaptive_pooling); + if (out_shape.size() == 1) + { + fprintf(pp, " 8=%d", out_shape[0]); + } + else if (out_shape.size() == 2) + { + // out_w + fprintf(pp, " 8=%d", out_shape[1]); + // out_h + fprintf(pp, " 18=%d", out_shape[0]); + } + } + else if (op == "GroupNorm") + { + int groups = get_node_attr_i(node, "groups", 1); + int channels = get_node_attr_i(node, "channels", 1); + float eps = get_node_attr_f(node, "epsilon", 1e-5f); + int affine = get_node_attr_i(node, "affine", 1); + + if (affine) + { + // discard affine-less S=1 B=0 + std::vector affine_S = get_node_attr_from_input_af(weights[node.input(1)]); + std::vector affine_B = get_node_attr_from_input_af(weights[node.input(2)]); + if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && + affine_B[0] == 0.f) + { + affine = 0; + } + else + { + affine = 0; + { + for (int j = 0; j < channels; j++) + { + if (affine_S[j] != 1.f || affine_B[j] != 0.f) + { + affine = 1; + break; + } + } + } + } + } + + fprintf(pp, " 0=%d", groups); + fprintf(pp, " 1=%d", channels); + fprintf(pp, " 2=%e", eps); + fprintf(pp, " 3=%d", affine); + if (affine) + { + const onnx::TensorProto& scale = weights[node.input(1)]; + const onnx::TensorProto& B = weights[node.input(2)]; + + fwrite_tensor_proto_data(scale, bp); + fwrite_tensor_proto_data(B, bp); + } + } + else if (op == "GRU") + { + const onnx::TensorProto& W = weights[node.input(1)]; + const onnx::TensorProto& R = weights[node.input(2)]; + const onnx::TensorProto& B = weights[node.input(3)]; + + int hidden_size = get_node_attr_i(node, "hidden_size", 0); + std::string direction = get_node_attr_s(node, "direction"); + + int direction_type = 0; + if (direction == "forward") + { + direction_type = 0; + } + else if (direction == "reverse") + { + direction_type = 1; + } + else if (direction == "bidirectional") + { + 
direction_type = 2; + } + + int weight_data_size = get_tensor_proto_data_size(W); + + fprintf(pp, " 0=%d", hidden_size); + fprintf(pp, " 1=%d", weight_data_size); + fprintf(pp, " 2=%d", direction_type); + + int num_directions = direction_type == 2 ? 2 : 1; + + int quantize_tag = 0; + + // reorder num_directions-URN-hidden-size to + // num_directions-RUN-hidden-size + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions; + const float* wptr = + W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data(); + + const float* uptr = wptr; + const float* rptr = wptr + weight_data_size_g; + const float* nptr = wptr + weight_data_size_g * 2; + fwrite(rptr, sizeof(float), weight_data_size_g, bp); + fwrite(uptr, sizeof(float), weight_data_size_g, bp); + fwrite(nptr, sizeof(float), weight_data_size_g, bp); + + if (direction_type == 2) + { + uptr += weight_data_size_g * 3; + rptr += weight_data_size_g * 3; + nptr += weight_data_size_g * 3; + fwrite(rptr, sizeof(float), weight_data_size_g, bp); + fwrite(uptr, sizeof(float), weight_data_size_g, bp); + fwrite(nptr, sizeof(float), weight_data_size_g, bp); + } + } + + // reduce U and R bias except N + // reorder num_directions-URN-hidden to num_directions-RUN-hidden + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions; + const float* bptr = + B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data(); + const float* wuptr = bptr; + const float* wrptr = bptr + bias_data_size_g; + const float* wnptr = bptr + bias_data_size_g * 2; + const float* buptr = bptr + bias_data_size_g * 3; + const float* brptr = bptr + bias_data_size_g * 4; + const float* bnptr = bptr + bias_data_size_g * 5; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = wrptr[j] + brptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = wuptr[j] + buptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + fwrite(wnptr, sizeof(float), bias_data_size_g, bp); + fwrite(bnptr, sizeof(float), bias_data_size_g, bp); + + if (direction_type == 2) + { + wuptr += bias_data_size_g * 6; + wrptr += bias_data_size_g * 6; + wnptr += bias_data_size_g * 6; + buptr += bias_data_size_g * 6; + brptr += bias_data_size_g * 6; + bnptr += bias_data_size_g * 6; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = wrptr[j] + brptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = wuptr[j] + buptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + fwrite(wnptr, sizeof(float), bias_data_size_g, bp); + fwrite(bnptr, sizeof(float), bias_data_size_g, bp); + } + } + + // reorder num_directions-URN-hidden-hidden to + // num_directions-RUN-hidden-hidden + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions; + const float* Rptr = + R.has_raw_data() ? 
(const float*)R.raw_data().data() : R.float_data().data(); + + const float* uptr = Rptr; + const float* rptr = Rptr + weight_data_size_g; + const float* nptr = Rptr + weight_data_size_g * 2; + fwrite(rptr, sizeof(float), weight_data_size_g, bp); + fwrite(uptr, sizeof(float), weight_data_size_g, bp); + fwrite(nptr, sizeof(float), weight_data_size_g, bp); + + if (direction_type == 2) + { + uptr += weight_data_size_g * 3; + rptr += weight_data_size_g * 3; + nptr += weight_data_size_g * 3; + fwrite(rptr, sizeof(float), weight_data_size_g, bp); + fwrite(uptr, sizeof(float), weight_data_size_g, bp); + fwrite(nptr, sizeof(float), weight_data_size_g, bp); + } + } + } + else if (op == "HardSigmoid") + { + float alpha = get_node_attr_f(node, "alpha", 0.2f); + float beta = get_node_attr_f(node, "beta", 0.5f); + + fprintf(pp, " 0=%e", alpha); + fprintf(pp, " 1=%e", beta); + } + else if (op == "HardSwish") + { + float alpha = get_node_attr_f(node, "alpha", 0.2f); + float beta = get_node_attr_f(node, "beta", 0.5f); + + fprintf(pp, " 0=%e", alpha); + fprintf(pp, " 1=%e", beta); + } + else if (op == "ImageScaler") + { + std::vector<float> bias = get_node_attr_af(node, "bias"); + float scale = get_node_attr_f(node, "scale", 1.f); + + int channels = (int)bias.size(); + + fprintf(pp, " 0=%d", channels); + fprintf(pp, " 1=1"); + + for (int j = 0; j < channels; j++) + { + fwrite(&scale, sizeof(float), 1, bp); + } + fwrite(&bias[0], sizeof(float), channels, bp); + } + else if (op == "InstanceNormalization") + { + float eps = get_node_attr_f(node, "epsilon", 1e-5f); + + // discard affine-less S=1 B=0 + std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]); + std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]); + int channels = (int)affine_S.size(); + int affine = 0; + { + for (int j = 0; j < channels; j++) + { + if (affine_S[j] != 1.f || affine_B[j] != 0.f) + { + affine = 1; + break; + } + } + } + + fprintf(pp, " 0=%d", channels); + fprintf(pp, " 1=%e", eps); + fprintf(pp, " 2=%d", affine); + if (affine) + { + const onnx::TensorProto& scale = weights[node.input(1)]; + const onnx::TensorProto& B = weights[node.input(2)]; + + fwrite_tensor_proto_data(scale, bp); + fwrite_tensor_proto_data(B, bp); + } + } + else if (op == "LayerNorm") + { + float eps = get_node_attr_f(node, "epsilon", 1e-5f); + int affine = get_node_attr_i(node, "affine", 1); + + if (affine) + { + // discard affine-less S=1 B=0 + std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]); + std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]); + int affine_size = (int)affine_S.size(); + affine = 0; + { + for (int j = 0; j < affine_size; j++) + { + if (affine_S[j] != 1.f || affine_B[j] != 0.f) + { + affine = 1; + break; + } + } + } + + if (affine) + { + fprintf(pp, " 0=%d", affine_size); + } + } + + fprintf(pp, " 1=%e", eps); + fprintf(pp, " 2=%d", affine); + + if (affine) + { + const onnx::TensorProto& scale = weights[node.input(1)]; + const onnx::TensorProto& B = weights[node.input(2)]; + + fwrite_tensor_proto_data(scale, bp); + fwrite_tensor_proto_data(B, bp); + } + } + else if (op == "LeakyRelu") + { + float alpha = get_node_attr_f(node, "alpha", 0.01f); + fprintf(pp, " 0=%e", alpha); + } + else if (op == "Threshold") + { + float threshold = get_node_attr_f(node, "threshold", 0.f); + fprintf(pp, " 0=%e", threshold); + } + else if (op == "Log") + { + int op_type = 8; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "LRN") + { + float alpha = get_node_attr_f(node, 
"alpha", 1.f); + float beta = get_node_attr_f(node, "beta", 0.5f); + float bias = get_node_attr_f(node, "bias", 1.f); + int size = get_node_attr_i(node, "size", 1); + + int norm_region = 0; + + fprintf(pp, " 0=%d", norm_region); + fprintf(pp, " 1=%d", size); + fprintf(pp, " 2=%e", alpha); + fprintf(pp, " 3=%e", beta); + fprintf(pp, " 4=%e", bias); + } + else if (op == "LSTM") + { + const onnx::TensorProto& W = weights[node.input(1)]; + const onnx::TensorProto& R = weights[node.input(2)]; + const onnx::TensorProto& B = weights[node.input(3)]; + + int hidden_size = get_node_attr_i(node, "hidden_size", 0); + std::string direction = get_node_attr_s(node, "direction"); + + int direction_type = 0; + if (direction == "forward") + { + direction_type = 0; + } + else if (direction == "reverse") + { + direction_type = 1; + } + else if (direction == "bidirectional") + { + direction_type = 2; + } + + int weight_data_size = get_tensor_proto_data_size(W); + + fprintf(pp, " 0=%d", hidden_size); + fprintf(pp, " 1=%d", weight_data_size); + fprintf(pp, " 2=%d", direction_type); + + int num_directions = direction_type == 2 ? 2 : 1; + + int quantize_tag = 0; + + // reorder num_directions-IOFG-hidden-size to + // num_directions-IFOG-hidden-size + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions; + const float* wptr = + W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data(); + + const float* iptr = wptr; + const float* optr = wptr + weight_data_size_g; + const float* fptr = wptr + weight_data_size_g * 2; + const float* gptr = wptr + weight_data_size_g * 3; + fwrite(iptr, sizeof(float), weight_data_size_g, bp); + fwrite(fptr, sizeof(float), weight_data_size_g, bp); + fwrite(optr, sizeof(float), weight_data_size_g, bp); + fwrite(gptr, sizeof(float), weight_data_size_g, bp); + + if (direction_type == 2) + { + iptr += weight_data_size_g * 4; + optr += weight_data_size_g * 4; + fptr += weight_data_size_g * 4; + gptr += weight_data_size_g * 4; + fwrite(iptr, sizeof(float), weight_data_size_g, bp); + fwrite(fptr, sizeof(float), weight_data_size_g, bp); + fwrite(optr, sizeof(float), weight_data_size_g, bp); + fwrite(gptr, sizeof(float), weight_data_size_g, bp); + } + } + + // reduce xc and hc bias + // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions; + const float* xcbptr = + B.has_raw_data() ? 
(const float*)B.raw_data().data() : B.float_data().data(); + const float* xiptr = xcbptr; + const float* xoptr = xcbptr + bias_data_size_g; + const float* xfptr = xcbptr + bias_data_size_g * 2; + const float* xgptr = xcbptr + bias_data_size_g * 3; + const float* hiptr = xcbptr + bias_data_size_g * 4; + const float* hoptr = xcbptr + bias_data_size_g * 5; + const float* hfptr = xcbptr + bias_data_size_g * 6; + const float* hgptr = xcbptr + bias_data_size_g * 7; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xiptr[j] + hiptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xfptr[j] + hfptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xoptr[j] + hoptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xgptr[j] + hgptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + + if (direction_type == 2) + { + xiptr += bias_data_size_g * 8; + xoptr += bias_data_size_g * 8; + xfptr += bias_data_size_g * 8; + xgptr += bias_data_size_g * 8; + hiptr += bias_data_size_g * 8; + hoptr += bias_data_size_g * 8; + hfptr += bias_data_size_g * 8; + hgptr += bias_data_size_g * 8; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xiptr[j] + hiptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xfptr[j] + hfptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xoptr[j] + hoptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xgptr[j] + hgptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + + // reorder num_directions-IOFG-hidden-hidden to + // num_directions-IFOG-hidden-hidden + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions; + const float* rptr = + R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data(); + + const float* iptr = rptr; + const float* optr = rptr + weight_data_size_g; + const float* fptr = rptr + weight_data_size_g * 2; + const float* gptr = rptr + weight_data_size_g * 3; + fwrite(iptr, sizeof(float), weight_data_size_g, bp); + fwrite(fptr, sizeof(float), weight_data_size_g, bp); + fwrite(optr, sizeof(float), weight_data_size_g, bp); + fwrite(gptr, sizeof(float), weight_data_size_g, bp); + + if (direction_type == 2) + { + iptr += weight_data_size_g * 4; + optr += weight_data_size_g * 4; + fptr += weight_data_size_g * 4; + gptr += weight_data_size_g * 4; + fwrite(iptr, sizeof(float), weight_data_size_g, bp); + fwrite(fptr, sizeof(float), weight_data_size_g, bp); + fwrite(optr, sizeof(float), weight_data_size_g, bp); + fwrite(gptr, sizeof(float), weight_data_size_g, bp); + } + } + } + else if (op == "MatMul") + { + if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2) + { + // InnerProduct + const onnx::TensorProto& B = weights[node.input(1)]; + + int weight_data_size = get_tensor_proto_data_size(B); + + int num_output = B.dims(B.dims_size() - 1); + int num_input = weight_data_size / num_output; + + fprintf(pp, " 0=%d", num_output); + fprintf(pp, " 1=0"); + fprintf(pp, " 2=%d", weight_data_size); + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + // reorder num_input-num_output to num_output-num_input + { + const float* bptr = + B.has_raw_data() ? 
(const float*)B.raw_data().data() : B.float_data().data(); + + for (int j = 0; j < num_output; j++) + { + for (int k = 0; k < num_input; k++) + { + float vb = bptr[k * num_output + j]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + + // fwrite_tensor_proto_data(B, bp) + } + else + { + // default matrix multiplication + } + } + else if (op == "Max") + { + int op_type = 4; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "Min") + { + int op_type = 5; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "Mul") + { + int op_type = 2; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "MultiHeadAttention") + { + int embed_dim = get_node_attr_i(node, "embed_dim", 0); + int num_heads = get_node_attr_i(node, "num_heads", 0); + + fprintf(pp, " 0=%d", embed_dim); + fprintf(pp, " 1=%d", num_heads); + + if (node.input_size() == 5) + { + const onnx::TensorProto& qkvw = weights[node.input(1)]; + const onnx::TensorProto& qkvb = weights[node.input(2)]; + const onnx::TensorProto& ow = weights[node.input(3)]; + const onnx::TensorProto& ob = weights[node.input(4)]; + + int weight_data_size = get_tensor_proto_data_size(ow); + + fprintf(pp, " 2=%d", weight_data_size); + + int quantize_tag = 0; + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose qw + { + const float* wptr = + qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); + const float* bptr = + qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim * 3 + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + + fwrite(bptr, sizeof(float), embed_dim, bp); + } + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose kw + { + const float* wptr = + qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); + const float* bptr = + qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data(); + bptr += embed_dim; + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim * 3 + k + embed_dim]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + + fwrite(bptr, sizeof(float), embed_dim, bp); + } + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose vw + { + const float* wptr = + qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data(); + const float* bptr = + qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data(); + bptr += embed_dim * 2; + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim * 3 + k + embed_dim * 2]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + + fwrite(bptr, sizeof(float), embed_dim, bp); + } + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose ow + { + const float* wptr = + ow.has_raw_data() ? 
(const float*)ow.raw_data().data() : ow.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + fwrite_tensor_proto_data(ob, bp); + } + else + { + const onnx::TensorProto& qw = weights[node.input(3)]; + const onnx::TensorProto& qb = weights[node.input(4)]; + const onnx::TensorProto& kw = weights[node.input(5)]; + const onnx::TensorProto& kb = weights[node.input(6)]; + const onnx::TensorProto& vw = weights[node.input(7)]; + const onnx::TensorProto& vb = weights[node.input(8)]; + const onnx::TensorProto& ow = weights[node.input(9)]; + const onnx::TensorProto& ob = weights[node.input(10)]; + + int weight_data_size = get_tensor_proto_data_size(qw); + + fprintf(pp, " 2=%d", weight_data_size); + + int quantize_tag = 0; + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose qw + { + const float* wptr = + qw.has_raw_data() ? (const float*)qw.raw_data().data() : qw.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + fwrite_tensor_proto_data(qb, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose kw + { + const float* wptr = + kw.has_raw_data() ? (const float*)kw.raw_data().data() : kw.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + fwrite_tensor_proto_data(kb, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose vw + { + const float* wptr = + vw.has_raw_data() ? (const float*)vw.raw_data().data() : vw.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + fwrite_tensor_proto_data(vb, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + // transpose ow + { + const float* wptr = + ow.has_raw_data() ? 
(const float*)ow.raw_data().data() : ow.float_data().data(); + + for (int j = 0; j < embed_dim; j++) + { + for (int k = 0; k < embed_dim; k++) + { + float vb = wptr[j * embed_dim + k]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + fwrite_tensor_proto_data(ob, bp); + } + } + else if (op == "Neg") + { + int op_type = 1; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "NonMaxSuppression") + { + int max_dets = 0; + float iou_thre = 0.f; + float score_thre = 0.f; + // fprintf(stderr, "%s\n", node.name().c_str()); + // fprintf(stderr, "node.input_size(): %d\n", node.input_size()); + if (node.input_size() >= 3) + { + // fprintf(stderr, "ok12!\n"); + max_dets = (int)(get_node_attr_from_input(weights[node.input(2)]) + 0.5); + } + if (node.input_size() >= 4) + { + // fprintf(stderr, "iou_thre: %f\n", + // get_node_attr_from_input(weights[node.input(3)])); + iou_thre = get_node_attr_from_input(weights[node.input(3)]); + } + if (node.input_size() >= 5) + { + // fprintf(stderr, "score_thre: %f\n", + // get_node_attr_from_input(weights[node.input(4)])); + score_thre = get_node_attr_from_input(weights[node.input(4)]); + } + fprintf(pp, " 0=%d", max_dets); + fprintf(pp, " 1=%f", iou_thre); + fprintf(pp, " 2=%f", score_thre); + } + else if (op == "Normalize") + { + float eps = get_node_attr_f(node, "eps", 0.f); + int scale_data_size = 1; + + fprintf(pp, " 1=1"); // channel_shared + fprintf(pp, " 2=%e", eps); + fprintf(pp, " 3=%d", scale_data_size); + fprintf(pp, " 9=1"); // TODO hardcode pytorch style + + const float scale_data[1] = {1.f}; + fwrite(scale_data, sizeof(float), 1, bp); + } + else if (op == "Pad") + { + std::string mode = get_node_attr_s(node, "mode"); + float value = get_node_attr_f(node, "value", 0.f); + + std::vector pads; + if (node.input_size() == 1) + { + pads = get_node_attr_ai(node, "pads"); + } + else + { + pads = get_node_attr_from_input_ai(weights[node.input(1)]); + } + int type = 0; + if (mode == "constant") + { + type = 0; + } + else if (mode == "edge") + { + type = 1; + } + else if (mode == "reflect") + { + type = 2; + } + + int pad_size = (int)pads.size(); + int top = 0; + int bottom = 0; + int left = 0; + int right = 0; + int front = 0; + int behind = 0; + if (pad_size == 8) + { + // NCHW + top = pads[2]; + bottom = pads[6]; + left = pads[3]; + right = pads[7]; + front = pads[1]; + behind = pads[5]; + } + else if (pad_size == 6) + { + // NHW + top = pads[1]; + bottom = pads[4]; + left = pads[2]; + right = pads[5]; + } + else + { + // NW + left = pads[1]; + right = pads[3]; + } + + fprintf(pp, " 0=%d", top); + fprintf(pp, " 1=%d", bottom); + fprintf(pp, " 2=%d", left); + fprintf(pp, " 3=%d", right); + fprintf(pp, " 4=%d", type); + fprintf(pp, " 5=%e", value); + fprintf(pp, " 7=%d", front); + fprintf(pp, " 8=%d", behind); + } + else if (op == "Pow") + { + int op_type = 6; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "PriorBox") + { + std::vector min_sizes = get_node_attr_af(node, "min_sizes"); + std::vector max_sizes = get_node_attr_af(node, "max_sizes"); + std::vector aspect_ratios = get_node_attr_af(node, "aspect_ratios"); + fprintf(pp, " -23300=%zu", min_sizes.size()); + for (size_t j = 0; j < min_sizes.size(); ++j) + { + fprintf(pp, ",%f", min_sizes[j]); + } + fprintf(pp, " -23301=%zu", max_sizes.size()); + for (size_t j = 0; j < max_sizes.size(); ++j) + { + 
fprintf(pp, ",%f", max_sizes[j]); + } + fprintf(pp, " -23302=%zu", aspect_ratios.size()); + for (size_t j = 0; j < aspect_ratios.size(); ++j) + { + fprintf(pp, ",%f", aspect_ratios[j]); + } + int image_width = get_node_attr_i(node, "image_width"); + int image_height = get_node_attr_i(node, "image_height"); + float step_width = get_node_attr_f(node, "step_width"); + float step_height = get_node_attr_f(node, "step_height"); + float offset = get_node_attr_f(node, "offset"); + int step_mmdetection = get_node_attr_i(node, "step_mmdetection"); + fprintf(pp, " 9=%d", image_width); + fprintf(pp, " 10=%d", image_height); + fprintf(pp, " 11=%f", step_width); + fprintf(pp, " 12=%f", step_height); + fprintf(pp, " 13=%f", offset); + fprintf(pp, " 14=%d", step_mmdetection); + } + else if (op == "PixelShuffle") + { + int scale_factor = get_node_attr_i(node, "scale_factor", 1); + fprintf(pp, " 0=%d", scale_factor); + } + else if (op == "PRelu") + { + const onnx::TensorProto& slope = weights[node.input(1)]; + + int num_slope = get_tensor_proto_data_size(slope); + + fprintf(pp, " 0=%d", num_slope); + + fwrite_tensor_proto_data(slope, bp); + } + else if (op == "Reciprocal") + { + int op_type = 15; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || + op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || + op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp") + { + int op_type = -233; + if (op == "ReduceSum") + op_type = 0; + else if (op == "ReduceSumSquare") + op_type = 2; + else if (op == "ReduceMean") + op_type = 3; + else if (op == "ReduceMax") + op_type = 4; + else if (op == "ReduceMin") + op_type = 5; + else if (op == "ReduceProd") + op_type = 6; + else if (op == "ReduceL1") + op_type = 7; + else if (op == "ReduceL2") + op_type = 8; + else if (op == "ReduceLogSum") + op_type = 9; + else if (op == "ReduceLogSumExp") + op_type = 10; + fprintf(pp, " 0=%d", op_type); + + std::vector axes = get_node_attr_ai(node, "axes"); + int keepdims = get_node_attr_i(node, "keepdims", 1); + + if (axes.size() > 0) + { + // if axes set, reduce according to axes + fprintf(pp, " 1=%d", 0); + fprintf(pp, " -23303=%zu", axes.size()); + for (size_t j = 0; j < axes.size(); j++) + { + if (axes[j] == 0 || axes[j] > 4 || axes[j] < -3) + fprintf(stderr, "Unsupported reduction axes !\n"); + fprintf(pp, ",%d", axes[j] > 0 ? axes[j] - 1 : axes[j]); + } + } + else + { + // if axes not set, reduce all axes by default + fprintf(pp, " 1=%d", 1); + } + fprintf(pp, " 4=%d", keepdims); + fprintf(pp, " 5=1"); + } + else if (op == "Reorg") + { + int stride = get_node_attr_i(node, "stride", 1); + fprintf(pp, " 0=%d", stride); + } + else if (op == "Reshape") + { + std::vector shape; + + if (node.input_size() == 1) + { + shape = get_node_attr_ai(node, "shape"); + } + else if (weights.find(node.input(1)) != weights.end()) + { + shape = get_node_attr_from_input_ai(weights[node.input(1)]); + } + else + { + fprintf(stderr, "Unsupported reshape weight ! 
\n"); + } + + if (shape.size() == 1) + { + fprintf(pp, " 0=%d", shape[0]); // should never reach here + } + else if (shape.size() == 2) + { + fprintf(pp, " 0=%d", shape[1]); + } + else if (shape.size() == 3) + { + fprintf(pp, " 0=%d", shape[2]); + fprintf(pp, " 1=%d", shape[1]); + } + else if (shape.size() == 4) + { + fprintf(pp, " 0=%d", shape[3]); + fprintf(pp, " 1=%d", shape[2]); + fprintf(pp, " 2=%d", shape[1]); + } + else if (shape.size() == 5) + { + fprintf(pp, " 0=%d", shape[4] * shape[3]); + fprintf(pp, " 1=%d", shape[2]); + fprintf(pp, " 2=%d", shape[1]); + } + } + else if (op == "Resize") + { + std::string mode = get_node_attr_s(node, "mode"); + std::string align = get_node_attr_s(node, "coordinate_transformation_mode"); + + std::vector scales; + std::vector sizes; + if (node.input_size() == 2) + { + // opset 10 + scales = get_node_attr_from_input_af(weights[node.input(1)]); + } + else + { + // opset 11+ + scales = get_node_attr_from_input_af(weights[node.input(2)]); + if (node.input_size() >= 4) + { + sizes = get_node_attr_from_input_ai(weights[node.input(3)]); + } + } + + int resize_type = 1; + if (mode == "nearest") + { + resize_type = 1; + } + else if (mode == "linear") + { + resize_type = 2; + } + else if (mode == "cubic") + { + resize_type = 3; + } + + if (scales.empty() && sizes.empty()) + { + fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n"); + } + + float h_scale = 1.f; + float w_scale = 1.f; + if (scales.size() == 2) + { + w_scale = scales[1]; + } + else if (scales.size() == 3) + { + h_scale = scales[1]; + w_scale = scales[2]; + } + else if (scales.size() == 4) + { + h_scale = scales[2]; + w_scale = scales[3]; + + if (scales[1] != 1.f) fprintf(stderr, "Unsupported Resize scales !\n"); + } + + int output_height = 0; + int output_width = 0; + if (sizes.size() == 2) + { + output_width = sizes[1]; + } + else if (sizes.size() == 3) + { + output_height = sizes[1]; + output_width = sizes[2]; + } + else if (sizes.size() == 4) + { + output_height = sizes[2]; + output_width = sizes[3]; + } + + int align_corner = 0; + if (align == "align_corners") + { + align_corner = 1; + } + + fprintf(pp, " 0=%d", resize_type); + fprintf(pp, " 1=%e", h_scale); + fprintf(pp, " 2=%e", w_scale); + fprintf(pp, " 3=%d", output_height); + fprintf(pp, " 4=%d", output_width); + fprintf(pp, " 6=%d", align_corner); + } + else if (op == "RNN") + { + const onnx::TensorProto& W = weights[node.input(1)]; + const onnx::TensorProto& R = weights[node.input(2)]; + const onnx::TensorProto& B = weights[node.input(3)]; + + int hidden_size = get_node_attr_i(node, "hidden_size", 0); + std::string direction = get_node_attr_s(node, "direction"); + + int direction_type = 0; + if (direction == "forward") + { + direction_type = 0; + } + else if (direction == "reverse") + { + direction_type = 1; + } + else if (direction == "bidirectional") + { + direction_type = 2; + } + + int weight_data_size = get_tensor_proto_data_size(W); + + fprintf(pp, " 0=%d", hidden_size); + fprintf(pp, " 1=%d", weight_data_size); + fprintf(pp, " 2=%d", direction_type); + + int num_directions = direction_type == 2 ? 2 : 1; + + int quantize_tag = 0; + + fwrite(&quantize_tag, sizeof(int), 1, bp); + fwrite_tensor_proto_data(W, bp); + + // reduce xc and hc bias + { + fwrite(&quantize_tag, sizeof(int), 1, bp); + + int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions; + const float* bptr = + B.has_raw_data() ? 
(const float*)B.raw_data().data() : B.float_data().data(); + const float* xiptr = bptr; + const float* hiptr = bptr + bias_data_size_g; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xiptr[j] + hiptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + + if (direction_type == 2) + { + xiptr += bias_data_size_g * 2; + hiptr += bias_data_size_g * 2; + + for (int j = 0; j < bias_data_size_g; j++) + { + float vb = xiptr[j] + hiptr[j]; + fwrite(&vb, sizeof(float), 1, bp); + } + } + } + + fwrite(&quantize_tag, sizeof(int), 1, bp); + fwrite_tensor_proto_data(R, bp); + } + else if (op == "RDiv") + { + int op_type = 8; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "RSub") + { + int op_type = 7; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "RoiAlign") + { + int pooled_width = get_node_attr_i(node, "output_width", 1); + int pooled_height = get_node_attr_i(node, "output_height", 1); + float spatial_scale = get_node_attr_f(node, "spatial_scale", 1.f); + int sampling_ratio = get_node_attr_i(node, "sampling_ratio", 0); + fprintf(pp, " 0=%d", pooled_width); + fprintf(pp, " 1=%d", pooled_height); + fprintf(pp, " 2=%f", spatial_scale); + fprintf(pp, " 3=%d", sampling_ratio); + } + else if (op == "ShuffleChannel") + { + int group = get_node_attr_i(node, "group", 1); + int reverse = get_node_attr_i(node, "reverse", 0); + fprintf(pp, " 0=%d", group); + fprintf(pp, " 1=%d", reverse); + } + else if (op == "Sigmoid") + { + // no param + } + else if (op == "Sin") + { + int op_type = 9; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "SkipLayerNormalization") + { + const onnx::TensorProto& W = weights[node.input(2)]; + const onnx::TensorProto& B = weights[node.input(3)]; + const onnx::TensorProto& B2 = weights[node.input(4)]; + + fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); + + int quantize_tag = 0; + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(W, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(B, bp); + + fwrite(&quantize_tag, sizeof(int), 1, bp); + + fwrite_tensor_proto_data(B2, bp); + } + else if (op == "Slice") + { + bool use_crop = true; + + std::vector starts; + std::vector ends; + std::vector axes; + std::vector steps; + if (node.input_size() == 1) + { + starts = get_node_attr_ai(node, "starts"); + ends = get_node_attr_ai(node, "ends"); + axes = get_node_attr_ai(node, "axes"); + steps = get_node_attr_ai(node, "steps"); // TODO + } + else + { + starts = get_node_attr_from_input_ai(weights[node.input(1)]); + ends = get_node_attr_from_input_ai(weights[node.input(2)]); + if (node.input_size() >= 4) axes = get_node_attr_from_input_ai(weights[node.input(3)]); + if (node.input_size() >= 5) steps = get_node_attr_from_input_ai(weights[node.input(4)]); + } + + // assert step == 1 or step >= ends + for (int i = 0; i < (int)steps.size(); i++) + { + if (steps[i] != 1 && steps[i] < ends[i]) + { + use_crop = false; + fprintf(stderr, "Unsupported slice step ! 
Use custom TensorSlice\n"); + } + } + + if (use_crop) + { + // filter out N-dim axis + if (!axes.empty()) + { + for (int i = 0; i < (int)axes.size(); i++) + { + int axis = axes[i]; + if (axis == 0) + { + starts.erase(starts.begin() + i); + ends.erase(ends.begin() + i); + axes.erase(axes.begin() + i); + break; + } + } + } + + fprintf(pp, " -23309=%d", (int)starts.size()); + for (int i = 0; i < (int)starts.size(); i++) + { + fprintf(pp, ",%d", starts[i]); + } + fprintf(pp, " -23310=%d", (int)ends.size()); + for (int i = 0; i < (int)ends.size(); i++) + { + fprintf(pp, ",%d", ends[i]); + } + if (!axes.empty()) + { + fprintf(pp, " -23311=%d", (int)axes.size()); + for (int i = 0; i < (int)axes.size(); i++) + { + int axis = axes[i]; + if (axis == 0 || axis > 3 || axis < -3) fprintf(stderr, "Unsupported slice axes !\n"); + + if (axis > 0) axis = axis - 1; // -1 for skip N-dim + + fprintf(pp, ",%d", axis); + } + } + } + else + { + fprintf(pp, " -23300=%d", (int)starts.size()); + for (int i = 0; i < (int)starts.size(); i++) + { + fprintf(pp, ",%d", starts[i]); + } + fprintf(pp, " -23301=%d", (int)ends.size()); + for (int i = 0; i < (int)ends.size(); i++) + { + fprintf(pp, ",%d", ends[i]); + } + if (!axes.empty()) + { + fprintf(pp, " -23302=%d", (int)axes.size()); + for (int i = 0; i < (int)axes.size(); i++) + { + int axis = axes[i]; + if (axis > 3 || axis < -3) fprintf(stderr, "Unsupported slice axes !\n"); + fprintf(pp, ",%d", axis); + } + } + if (!steps.empty()) + { + fprintf(pp, " -23303=%d", (int)steps.size()); + for (int i = 0; i < (int)steps.size(); i++) + { + int step = steps[i]; + if (step == 0) fprintf(stderr, "Unsupported slice step ! Unsupported slice step\n"); + fprintf(pp, ",%d", step); + } + } + } + } + else if (op == "Softmax") + { + int axis = get_node_attr_i(node, "axis", 1); + fprintf(pp, " 0=%d", axis - 1); + fprintf(pp, " 1=1"); + } + else if (op == "Split") + { + int axis = get_node_attr_i(node, "axis", 0); + std::vector split = get_node_attr_ai(node, "split"); + if (axis < 1) fprintf(stderr, "Unsupported split axis !\n"); + + fprintf(pp, " -23300=%d", output_size); + if (split.empty()) + { + for (int i = 0; i < output_size; i++) + { + fprintf(pp, ",-233"); + } + } + else + { + for (size_t i = 0; i < split.size() - 1; i++) + { + fprintf(pp, ",%d", split[i]); + } + fprintf(pp, ",-233"); + } + fprintf(pp, " 1=%d", axis - 1); + } + else if (op == "Sqrt") + { + int op_type = 5; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Squeeze") + { + std::vector axes = get_node_attr_ai(node, "axes"); + + if (axes.empty()) + { + fprintf(pp, " 0=1"); + fprintf(pp, " 1=1"); + fprintf(pp, " 2=1"); + } + else + { + bool flag = true; + for (int i = 0; i < (int)axes.size(); i++) + { + if (axes[i] == 0) + { + flag = false; + break; + } + } + if (flag == true) + { + fprintf(pp, " -23303=%zu", axes.size()); + for (int i = 0; i < (int)axes.size(); i++) + { + if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3) + fprintf(stderr, "Unsupported squeeze axes !: %d, %s\n", axes[i], node.name().c_str()); + fprintf(pp, ",%d", axes[i] - 1); + } + } + } + } + else if (op == "Sub") + { + int op_type = 1; + fprintf(pp, " 0=%d", op_type); + + int with_scalar = get_node_attr_i(node, "with_scalar", 0); + float b = get_node_attr_f(node, "b", 0.f); + if (with_scalar) + { + fprintf(pp, " 1=%d", with_scalar); + fprintf(pp, " 2=%e", b); + } + } + else if (op == "Sum") + { + int op_type = 1; + fprintf(pp, " 0=%d", op_type); + } + else if (op == "Swish") + { + // no param + } + else if (op == "Tan") + { + int op_type = 
+    else if (op == "Tanh")
+    {
+        int op_type = 16;
+        fprintf(pp, " 0=%d", op_type);
+    }
+    else if (op == "TopK")
+    {
+        int axis = get_node_attr_i(node, "axis", -1);
+        axis = axis > 0 ? axis - 1 : axis;
+        int largest = get_node_attr_i(node, "largest", 1);
+        int sorted = get_node_attr_i(node, "sorted", 1);
+        fprintf(pp, " 0=%d", axis);
+        fprintf(pp, " 1=%d", largest);
+        fprintf(pp, " 2=%d", sorted);
+    }
+    else if (op == "Transpose")
+    {
+        std::vector<int> perm = get_node_attr_ai(node, "perm");
+
+        if (perm.size() == 3)
+        {
+            if (perm[1] == 1 && perm[2] == 2)
+                fprintf(pp, " 0=0");    // w h
+            else if (perm[1] == 2 && perm[2] == 1)
+                fprintf(pp, " 0=1");    // h w
+            else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2)
+                fprintf(pp, " 0=0");    // w h
+            else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1)
+                fprintf(pp, " 0=1");    // h w
+        }
+        else if (perm.size() == 4)
+        {
+            if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3)
+                fprintf(pp, " 0=0");    // w h c
+            else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2)
+                fprintf(pp, " 0=1");    // h w c
+            else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3)
+                fprintf(pp, " 0=2");    // w c h
+            else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1)
+                fprintf(pp, " 0=3");    // c w h
+            else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2)
+                fprintf(pp, " 0=4");    // h c w
+            else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1)
+                fprintf(pp, " 0=5");    // c h w
+        }
+        else if (perm.size() == 5)
+        {
+            if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4)
+                fprintf(pp, " 0=0");    // wx h c
+            else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2)
+                fprintf(pp, " 0=1");    // h wx c
+            else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4)
+                fprintf(pp, " 0=2");    // wx c h
+            else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1)
+                fprintf(pp, " 0=3");    // c wx h
+            else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2)
+                fprintf(pp, " 0=4");    // h c wx
+            else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1)
+                fprintf(pp, " 0=5");    // c h wx
+            else
+                fprintf(stderr, "Unsupported transpose type !\n");
+        }
+    }
+    else if (op == "Upsample")
+    {
+        std::string mode = get_node_attr_s(node, "mode");
+        std::string align = get_node_attr_s(node, "coordinate_transformation_mode");
+
+        std::vector<float> scales;
+
+        if (node.input_size() == 1)
+        {
+            scales = get_node_attr_af(node, "scales");
+        }
+        else
+        {
+            scales = get_node_attr_from_input_af(weights[node.input(1)]);
+        }
+
+        int resize_type = 1;
+        if (mode == "nearest")
+        {
+            resize_type = 1;
+        }
+        else if (mode == "bilinear" || mode == "linear")
+        {
+            resize_type = 2;
+        }
+        else if (mode == "trilinear")
+        {
+            fprintf(stderr, "Unsupported Upsample mode !\n");
+        }
+
+        float h_scale = 1.f;
+        float w_scale = 1.f;
+        if (scales.size() == 2)
+        {
+            w_scale = scales[1];
+        }
+        else if (scales.size() == 3)
+        {
+            h_scale = scales[1];
+            w_scale = scales[2];
+        }
+        else if (scales.size() == 4)
+        {
+            h_scale = scales[2];
+            w_scale = scales[3];
+
+            if (scales[1] != 1.f) fprintf(stderr, "Unsupported Upsample scales !\n");
+        }
+        else
+        {
+            fprintf(stderr, "Unsupported Upsample scales !\n");
+        }
+
+        int align_corner = 0;
+        if (align == "align_corners")
+        {
+            align_corner = 1;
+        }
+
+        fprintf(pp, " 0=%d", resize_type);
+        fprintf(pp, " 1=%e", h_scale);
+        fprintf(pp, " 2=%e", w_scale);
+        fprintf(pp, " 6=%d", align_corner);
+    }
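// Annotation (editorial, not part of the patch): ncnn drops the batch dimension
// and stores blobs as (w, h, c), which is presumably why the handlers above and
// below shift ONNX axes by one ("axis - 1", "-1 for skip N-dim") and emit 4-D
// shapes in reversed order (shape[3], shape[2], shape[1]).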
< (int)axes.size(); i++) + { + if (axes[i] == 0) + { + flag = false; + break; + } + } + if (flag) + { + fprintf(pp, " -23303=%zu", axes.size()); + for (int i = 0; i < (int)axes.size(); i++) + { + if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4) + fprintf(stderr, "Unsupported unsqueeze axes !: %d, %s\n", axes[i], node.name().c_str()); + fprintf(pp, ",%d", axes[i] - 1); + } + } + } + else if (op == "Yolov3DetectionOutput") + { + int num_class = get_node_attr_i(node, "num_class"); + int num_box = get_node_attr_i(node, "num_box"); + float confidence_threshold = get_node_attr_f(node, "confidence_threshold"); + float nms_threshold = get_node_attr_f(node, "nms_threshold"); + fprintf(pp, " 0=%d", num_class); + fprintf(pp, " 1=%d", num_box); + fprintf(pp, " 2=%e", confidence_threshold); + fprintf(pp, " 3=%e", nms_threshold); + std::vector biases = get_node_attr_af(node, "biases"); + if (biases.size() > 0) + { + fprintf(pp, " -23304=%zu", biases.size()); + for (int i = 0; i < (int)biases.size(); i++) + { + fprintf(pp, ",%e", biases[i]); + } + } + std::vector mask = get_node_attr_af(node, "mask"); + if (mask.size() > 0) + { + fprintf(pp, " -23305=%zu", mask.size()); + for (int i = 0; i < (int)mask.size(); i++) + { + fprintf(pp, ",%e", mask[i]); + } + } + std::vector anchors_scale = get_node_attr_af(node, "anchors_scale"); + if (anchors_scale.size() > 0) + { + fprintf(pp, " -23306=%zu", anchors_scale.size()); + for (int i = 0; i < (int)anchors_scale.size(); i++) + { + fprintf(pp, ",%e", anchors_scale[i]); + } + } + } + else + { + // TODO op specific param + } + + fprintf(pp, "\n"); + for (int j = 0; j < output_size; j++) + { + const std::string& output_name = node.output(j); + if (node_reference.find(output_name) != node_reference.end()) + { + int refcount = node_reference[output_name]; + if (refcount > 1) + { + char splitname[256]; + sprintf(splitname, "splitncnn_%d", internal_split); + fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount); + + fprintf(pp, " %s", output_name.c_str()); + + for (int k = 0; k < refcount; k++) + { + fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k); + } + fprintf(pp, "\n"); + + internal_split++; + } } - } - } - fwrite_tensor_proto_data(qb, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose kw - { - const float* wptr = - kw.has_raw_data() ? (const float*)kw.raw_data().data() : kw.float_data().data(); - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim + k]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - fwrite_tensor_proto_data(kb, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose vw - { - const float* wptr = - vw.has_raw_data() ? (const float*)vw.raw_data().data() : vw.float_data().data(); - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim + k]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - fwrite_tensor_proto_data(vb, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - // transpose ow - { - const float* wptr = - ow.has_raw_data() ? 
(const float*)ow.raw_data().data() : ow.float_data().data(); - - for (int j = 0; j < embed_dim; j++) { - for (int k = 0; k < embed_dim; k++) { - float vb = wptr[j * embed_dim + k]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - fwrite_tensor_proto_data(ob, bp); - } - } else if (op == "Neg") { - int op_type = 1; - fprintf(pp, " 0=%d", op_type); - } else if (op == "NonMaxSuppression") { - int max_dets = 0; - float iou_thre = 0.f; - float score_thre = 0.f; - // fprintf(stderr, "%s\n", node.name().c_str()); - // fprintf(stderr, "node.input_size(): %d\n", node.input_size()); - if (node.input_size() >= 3) { - // fprintf(stderr, "ok12!\n"); - max_dets = (int)(get_node_attr_from_input(weights[node.input(2)]) + 0.5); - } - if (node.input_size() >= 4) { - // fprintf(stderr, "iou_thre: %f\n", - // get_node_attr_from_input(weights[node.input(3)])); - iou_thre = get_node_attr_from_input(weights[node.input(3)]); - } - if (node.input_size() >= 5) { - // fprintf(stderr, "score_thre: %f\n", - // get_node_attr_from_input(weights[node.input(4)])); - score_thre = get_node_attr_from_input(weights[node.input(4)]); - } - fprintf(pp, " 0=%d", max_dets); - fprintf(pp, " 1=%f", iou_thre); - fprintf(pp, " 2=%f", score_thre); - } else if (op == "Normalize") { - float eps = get_node_attr_f(node, "eps", 0.f); - int scale_data_size = 1; - - fprintf(pp, " 1=1"); // channel_shared - fprintf(pp, " 2=%e", eps); - fprintf(pp, " 3=%d", scale_data_size); - fprintf(pp, " 9=1"); // TODO hardcode pytorch style - - const float scale_data[1] = {1.f}; - fwrite(scale_data, sizeof(float), 1, bp); - } else if (op == "Pad") { - std::string mode = get_node_attr_s(node, "mode"); - float value = get_node_attr_f(node, "value", 0.f); - - std::vector pads; - if (node.input_size() == 1) { - pads = get_node_attr_ai(node, "pads"); - } else { - pads = get_node_attr_from_input_ai(weights[node.input(1)]); - } - int type = 0; - if (mode == "constant") { - type = 0; - } else if (mode == "edge") { - type = 1; - } else if (mode == "reflect") { - type = 2; - } - - int pad_size = (int)pads.size(); - int top = 0; - int bottom = 0; - int left = 0; - int right = 0; - int front = 0; - int behind = 0; - if (pad_size == 8) { - // NCHW - top = pads[2]; - bottom = pads[6]; - left = pads[3]; - right = pads[7]; - front = pads[1]; - behind = pads[5]; - } else if (pad_size == 6) { - // NHW - top = pads[1]; - bottom = pads[4]; - left = pads[2]; - right = pads[5]; - } else { - // NW - left = pads[1]; - right = pads[3]; - } - - fprintf(pp, " 0=%d", top); - fprintf(pp, " 1=%d", bottom); - fprintf(pp, " 2=%d", left); - fprintf(pp, " 3=%d", right); - fprintf(pp, " 4=%d", type); - fprintf(pp, " 5=%e", value); - fprintf(pp, " 7=%d", front); - fprintf(pp, " 8=%d", behind); - } else if (op == "Pow") { - int op_type = 6; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "PriorBox") { - std::vector min_sizes = get_node_attr_af(node, "min_sizes"); - std::vector max_sizes = get_node_attr_af(node, "max_sizes"); - std::vector aspect_ratios = get_node_attr_af(node, "aspect_ratios"); - fprintf(pp, " -23300=%zu", min_sizes.size()); - for (size_t j = 0; j < min_sizes.size(); ++j) { - fprintf(pp, ",%f", min_sizes[j]); - } - fprintf(pp, " -23301=%zu", max_sizes.size()); - for (size_t j = 0; j < max_sizes.size(); ++j) { - fprintf(pp, ",%f", max_sizes[j]); - } - fprintf(pp, " -23302=%zu", 
aspect_ratios.size()); - for (size_t j = 0; j < aspect_ratios.size(); ++j) { - fprintf(pp, ",%f", aspect_ratios[j]); - } - int image_width = get_node_attr_i(node, "image_width"); - int image_height = get_node_attr_i(node, "image_height"); - float step_width = get_node_attr_f(node, "step_width"); - float step_height = get_node_attr_f(node, "step_height"); - float offset = get_node_attr_f(node, "offset"); - int step_mmdetection = get_node_attr_i(node, "step_mmdetection"); - fprintf(pp, " 9=%d", image_width); - fprintf(pp, " 10=%d", image_height); - fprintf(pp, " 11=%f", step_width); - fprintf(pp, " 12=%f", step_height); - fprintf(pp, " 13=%f", offset); - fprintf(pp, " 14=%d", step_mmdetection); - } else if (op == "PixelShuffle") { - int scale_factor = get_node_attr_i(node, "scale_factor", 1); - fprintf(pp, " 0=%d", scale_factor); - } else if (op == "PRelu") { - const onnx::TensorProto& slope = weights[node.input(1)]; - - int num_slope = get_tensor_proto_data_size(slope); - - fprintf(pp, " 0=%d", num_slope); - - fwrite_tensor_proto_data(slope, bp); - } else if (op == "Reciprocal") { - int op_type = 15; - fprintf(pp, " 0=%d", op_type); - } else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || - op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || - op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp") { - int op_type = -233; - if (op == "ReduceSum") - op_type = 0; - else if (op == "ReduceSumSquare") - op_type = 2; - else if (op == "ReduceMean") - op_type = 3; - else if (op == "ReduceMax") - op_type = 4; - else if (op == "ReduceMin") - op_type = 5; - else if (op == "ReduceProd") - op_type = 6; - else if (op == "ReduceL1") - op_type = 7; - else if (op == "ReduceL2") - op_type = 8; - else if (op == "ReduceLogSum") - op_type = 9; - else if (op == "ReduceLogSumExp") - op_type = 10; - fprintf(pp, " 0=%d", op_type); - - std::vector axes = get_node_attr_ai(node, "axes"); - int keepdims = get_node_attr_i(node, "keepdims", 1); - - if (axes.size() > 0) { - // if axes set, reduce according to axes - fprintf(pp, " 1=%d", 0); - fprintf(pp, " -23303=%zu", axes.size()); - for (size_t j = 0; j < axes.size(); j++) { - if (axes[j] == 0 || axes[j] > 4 || axes[j] < -3) - fprintf(stderr, "Unsupported reduction axes !\n"); - fprintf(pp, ",%d", axes[j] > 0 ? axes[j] - 1 : axes[j]); - } - } else { - // if axes not set, reduce all axes by default - fprintf(pp, " 1=%d", 1); - } - fprintf(pp, " 4=%d", keepdims); - fprintf(pp, " 5=1"); - } else if (op == "Reorg") { - int stride = get_node_attr_i(node, "stride", 1); - fprintf(pp, " 0=%d", stride); - } else if (op == "Reshape") { - std::vector shape; - - if (node.input_size() == 1) { - shape = get_node_attr_ai(node, "shape"); - } else if (weights.find(node.input(1)) != weights.end()) { - shape = get_node_attr_from_input_ai(weights[node.input(1)]); - } else { - fprintf(stderr, "Unsupported reshape weight ! 
\n"); - } - - if (shape.size() == 1) { - fprintf(pp, " 0=%d", shape[0]); // should never reach here - } else if (shape.size() == 2) { - fprintf(pp, " 0=%d", shape[1]); - } else if (shape.size() == 3) { - fprintf(pp, " 0=%d", shape[2]); - fprintf(pp, " 1=%d", shape[1]); - } else if (shape.size() == 4) { - fprintf(pp, " 0=%d", shape[3]); - fprintf(pp, " 1=%d", shape[2]); - fprintf(pp, " 2=%d", shape[1]); - } else if (shape.size() == 5) { - fprintf(pp, " 0=%d", shape[4] * shape[3]); - fprintf(pp, " 1=%d", shape[2]); - fprintf(pp, " 2=%d", shape[1]); - } - } else if (op == "Resize") { - std::string mode = get_node_attr_s(node, "mode"); - std::string align = get_node_attr_s(node, "coordinate_transformation_mode"); - - std::vector scales; - std::vector sizes; - if (node.input_size() == 2) { - // opset 10 - scales = get_node_attr_from_input_af(weights[node.input(1)]); - } else { - // opset 11+ - scales = get_node_attr_from_input_af(weights[node.input(2)]); - if (node.input_size() >= 4) { - sizes = get_node_attr_from_input_ai(weights[node.input(3)]); - } - } - - int resize_type = 1; - if (mode == "nearest") { - resize_type = 1; - } else if (mode == "linear") { - resize_type = 2; - } else if (mode == "cubic") { - resize_type = 3; - } - - if (scales.empty() && sizes.empty()) { - fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n"); - } - - float h_scale = 1.f; - float w_scale = 1.f; - if (scales.size() == 2) { - w_scale = scales[1]; - } else if (scales.size() == 3) { - h_scale = scales[1]; - w_scale = scales[2]; - } else if (scales.size() == 4) { - h_scale = scales[2]; - w_scale = scales[3]; - - if (scales[1] != 1.f) fprintf(stderr, "Unsupported Resize scales !\n"); - } - - int output_height = 0; - int output_width = 0; - if (sizes.size() == 2) { - output_width = sizes[1]; - } else if (sizes.size() == 3) { - output_height = sizes[1]; - output_width = sizes[2]; - } else if (sizes.size() == 4) { - output_height = sizes[2]; - output_width = sizes[3]; - } - - int align_corner = 0; - if (align == "align_corners") { - align_corner = 1; - } - - fprintf(pp, " 0=%d", resize_type); - fprintf(pp, " 1=%e", h_scale); - fprintf(pp, " 2=%e", w_scale); - fprintf(pp, " 3=%d", output_height); - fprintf(pp, " 4=%d", output_width); - fprintf(pp, " 6=%d", align_corner); - } else if (op == "RNN") { - const onnx::TensorProto& W = weights[node.input(1)]; - const onnx::TensorProto& R = weights[node.input(2)]; - const onnx::TensorProto& B = weights[node.input(3)]; - - int hidden_size = get_node_attr_i(node, "hidden_size", 0); - std::string direction = get_node_attr_s(node, "direction"); - - int direction_type = 0; - if (direction == "forward") { - direction_type = 0; - } else if (direction == "reverse") { - direction_type = 1; - } else if (direction == "bidirectional") { - direction_type = 2; - } - - int weight_data_size = get_tensor_proto_data_size(W); - - fprintf(pp, " 0=%d", hidden_size); - fprintf(pp, " 1=%d", weight_data_size); - fprintf(pp, " 2=%d", direction_type); - - int num_directions = direction_type == 2 ? 2 : 1; - - int quantize_tag = 0; - - fwrite(&quantize_tag, sizeof(int), 1, bp); - fwrite_tensor_proto_data(W, bp); - - // reduce xc and hc bias - { - fwrite(&quantize_tag, sizeof(int), 1, bp); - - int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions; - const float* bptr = - B.has_raw_data() ? 
(const float*)B.raw_data().data() : B.float_data().data(); - const float* xiptr = bptr; - const float* hiptr = bptr + bias_data_size_g; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xiptr[j] + hiptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - - if (direction_type == 2) { - xiptr += bias_data_size_g * 2; - hiptr += bias_data_size_g * 2; - - for (int j = 0; j < bias_data_size_g; j++) { - float vb = xiptr[j] + hiptr[j]; - fwrite(&vb, sizeof(float), 1, bp); - } - } - } - - fwrite(&quantize_tag, sizeof(int), 1, bp); - fwrite_tensor_proto_data(R, bp); - } else if (op == "RDiv") { - int op_type = 8; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "RSub") { - int op_type = 7; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "RoiAlign") { - int pooled_width = get_node_attr_i(node, "output_width", 1); - int pooled_height = get_node_attr_i(node, "output_height", 1); - float spatial_scale = get_node_attr_f(node, "spatial_scale", 1.f); - int sampling_ratio = get_node_attr_i(node, "sampling_ratio", 0); - fprintf(pp, " 0=%d", pooled_width); - fprintf(pp, " 1=%d", pooled_height); - fprintf(pp, " 2=%f", spatial_scale); - fprintf(pp, " 3=%d", sampling_ratio); - } else if (op == "ShuffleChannel") { - int group = get_node_attr_i(node, "group", 1); - int reverse = get_node_attr_i(node, "reverse", 0); - fprintf(pp, " 0=%d", group); - fprintf(pp, " 1=%d", reverse); - } else if (op == "Sigmoid") { - // no param - } else if (op == "Sin") { - int op_type = 9; - fprintf(pp, " 0=%d", op_type); - } else if (op == "SkipLayerNormalization") { - const onnx::TensorProto& W = weights[node.input(2)]; - const onnx::TensorProto& B = weights[node.input(3)]; - const onnx::TensorProto& B2 = weights[node.input(4)]; - - fprintf(pp, " 0=%d", get_tensor_proto_data_size(B)); - - int quantize_tag = 0; - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(W, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(B, bp); - - fwrite(&quantize_tag, sizeof(int), 1, bp); - - fwrite_tensor_proto_data(B2, bp); - } else if (op == "Slice") { - bool use_crop = true; - - std::vector starts; - std::vector ends; - std::vector axes; - std::vector steps; - if (node.input_size() == 1) { - starts = get_node_attr_ai(node, "starts"); - ends = get_node_attr_ai(node, "ends"); - axes = get_node_attr_ai(node, "axes"); - steps = get_node_attr_ai(node, "steps"); // TODO - } else { - starts = get_node_attr_from_input_ai(weights[node.input(1)]); - ends = get_node_attr_from_input_ai(weights[node.input(2)]); - if (node.input_size() >= 4) axes = get_node_attr_from_input_ai(weights[node.input(3)]); - if (node.input_size() >= 5) steps = get_node_attr_from_input_ai(weights[node.input(4)]); - } - - // assert step == 1 or step >= ends - for (int i = 0; i < (int)steps.size(); i++) { - if (steps[i] != 1 && steps[i] < ends[i]) { - use_crop = false; - fprintf(stderr, "Unsupported slice step ! 
Use custom TensorSlice\n"); - } - } - - if (use_crop) { - // filter out N-dim axis - if (!axes.empty()) { - for (int i = 0; i < (int)axes.size(); i++) { - int axis = axes[i]; - if (axis == 0) { - starts.erase(starts.begin() + i); - ends.erase(ends.begin() + i); - axes.erase(axes.begin() + i); - break; - } - } - } - - fprintf(pp, " -23309=%d", (int)starts.size()); - for (int i = 0; i < (int)starts.size(); i++) { - fprintf(pp, ",%d", starts[i]); - } - fprintf(pp, " -23310=%d", (int)ends.size()); - for (int i = 0; i < (int)ends.size(); i++) { - fprintf(pp, ",%d", ends[i]); - } - if (!axes.empty()) { - fprintf(pp, " -23311=%d", (int)axes.size()); - for (int i = 0; i < (int)axes.size(); i++) { - int axis = axes[i]; - if (axis == 0 || axis > 3 || axis < -3) fprintf(stderr, "Unsupported slice axes !\n"); - - if (axis > 0) axis = axis - 1; // -1 for skip N-dim - - fprintf(pp, ",%d", axis); - } - } - } else { - fprintf(pp, " -23300=%d", (int)starts.size()); - for (int i = 0; i < (int)starts.size(); i++) { - fprintf(pp, ",%d", starts[i]); - } - fprintf(pp, " -23301=%d", (int)ends.size()); - for (int i = 0; i < (int)ends.size(); i++) { - fprintf(pp, ",%d", ends[i]); - } - if (!axes.empty()) { - fprintf(pp, " -23302=%d", (int)axes.size()); - for (int i = 0; i < (int)axes.size(); i++) { - int axis = axes[i]; - if (axis > 3 || axis < -3) fprintf(stderr, "Unsupported slice axes !\n"); - fprintf(pp, ",%d", axis); - } - } - if (!steps.empty()) { - fprintf(pp, " -23303=%d", (int)steps.size()); - for (int i = 0; i < (int)steps.size(); i++) { - int step = steps[i]; - if (step == 0) fprintf(stderr, "Unsupported slice step ! Unsupported slice step\n"); - fprintf(pp, ",%d", step); - } - } - } - } else if (op == "Softmax") { - int axis = get_node_attr_i(node, "axis", 1); - fprintf(pp, " 0=%d", axis - 1); - fprintf(pp, " 1=1"); - } else if (op == "Split") { - int axis = get_node_attr_i(node, "axis", 0); - std::vector split = get_node_attr_ai(node, "split"); - if (axis < 1) fprintf(stderr, "Unsupported split axis !\n"); - - fprintf(pp, " -23300=%d", output_size); - if (split.empty()) { - for (int i = 0; i < output_size; i++) { - fprintf(pp, ",-233"); - } - } else { - for (size_t i = 0; i < split.size() - 1; i++) { - fprintf(pp, ",%d", split[i]); - } - fprintf(pp, ",-233"); - } - fprintf(pp, " 1=%d", axis - 1); - } else if (op == "Sqrt") { - int op_type = 5; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Squeeze") { - std::vector axes = get_node_attr_ai(node, "axes"); - - if (axes.empty()) { - fprintf(pp, " 0=1"); - fprintf(pp, " 1=1"); - fprintf(pp, " 2=1"); - } else { - bool flag = true; - for (int i = 0; i < (int)axes.size(); i++) { - if (axes[i] == 0) { - flag = false; - break; - } - } - if (flag == true) { - fprintf(pp, " -23303=%zu", axes.size()); - for (int i = 0; i < (int)axes.size(); i++) { - if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3) - fprintf(stderr, "Unsupported squeeze axes !: %d, %s\n", axes[i], node.name().c_str()); - fprintf(pp, ",%d", axes[i] - 1); - } - } - } - } else if (op == "Sub") { - int op_type = 1; - fprintf(pp, " 0=%d", op_type); - - int with_scalar = get_node_attr_i(node, "with_scalar", 0); - float b = get_node_attr_f(node, "b", 0.f); - if (with_scalar) { - fprintf(pp, " 1=%d", with_scalar); - fprintf(pp, " 2=%e", b); - } - } else if (op == "Sum") { - int op_type = 1; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Swish") { - // no param - } else if (op == "Tan") { - int op_type = 11; - fprintf(pp, " 0=%d", op_type); - } else if (op == "Tanh") { - int op_type = 16; - 
fprintf(pp, " 0=%d", op_type); - } else if (op == "TopK") { - int axis = get_node_attr_i(node, "axis", -1); - axis = axis > 0 ? axis - 1 : axis; - int largest = get_node_attr_i(node, "largest", 1); - int sorted = get_node_attr_i(node, "sorted", 1); - fprintf(pp, " 0=%d", axis); - fprintf(pp, " 1=%d", largest); - fprintf(pp, " 2=%d", sorted); - } else if (op == "Transpose") { - std::vector perm = get_node_attr_ai(node, "perm"); - - if (perm.size() == 3) { - if (perm[1] == 1 && perm[2] == 2) - fprintf(pp, " 0=0"); // w h - else if (perm[1] == 2 && perm[2] == 1) - fprintf(pp, " 0=1"); // h w - else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2) - fprintf(pp, " 0=0"); // w h - else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1) - fprintf(pp, " 0=1"); // h w - } else if (perm.size() == 4) { - if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3) - fprintf(pp, " 0=0"); // w h c - else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2) - fprintf(pp, " 0=1"); // h w c - else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3) - fprintf(pp, " 0=2"); // w c h - else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1) - fprintf(pp, " 0=3"); // c w h - else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2) - fprintf(pp, " 0=4"); // h c w - else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1) - fprintf(pp, " 0=5"); // c h w - } else if (perm.size() == 5) { - if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4) - fprintf(pp, " 0=0"); // wx h c - else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2) - fprintf(pp, " 0=1"); // h wx c - else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4) - fprintf(pp, " 0=2"); // wx c h - else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1) - fprintf(pp, " 0=3"); // c wx h - else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2) - fprintf(pp, " 0=4"); // h c wx - else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1) - fprintf(pp, " 0=5"); // c h wx - else - fprintf(stderr, "Unsupported transpose type !\n"); - } - } else if (op == "Upsample") { - std::string mode = get_node_attr_s(node, "mode"); - std::string align = get_node_attr_s(node, "coordinate_transformation_mode"); - - std::vector scales; - - if (node.input_size() == 1) { - scales = get_node_attr_af(node, "scales"); - } else { - scales = get_node_attr_from_input_af(weights[node.input(1)]); - } - - int resize_type = 1; - if (mode == "nearest") { - resize_type = 1; - } else if (mode == "bilinear" || mode == "linear") { - resize_type = 2; - } else if (mode == "trilinear") { - fprintf(stderr, "Unsupported Upsample mode !\n"); - } - - float h_scale = 1.f; - float w_scale = 1.f; - if (scales.size() == 2) { - w_scale = scales[1]; - } else if (scales.size() == 3) { - h_scale = scales[1]; - w_scale = scales[2]; - } else if (scales.size() == 4) { - h_scale = scales[2]; - w_scale = scales[3]; - - if (scales[1] != 1.f) fprintf(stderr, "Unsupported Upsample scales !\n"); - } else { - fprintf(stderr, "Unsupported Upsample scales !\n"); - } - - int align_corner = 0; - if (align == "align_corners") { - align_corner = 1; - } - - fprintf(pp, " 0=%d", resize_type); - fprintf(pp, " 1=%e", h_scale); - fprintf(pp, " 2=%e", w_scale); - fprintf(pp, " 6=%d", align_corner); - } else if (op == "Unsqueeze") { - std::vector axes = get_node_attr_ai(node, "axes"); - bool flag = true; - for (int i = 0; i < (int)axes.size(); i++) { - if (axes[i] == 0) { - flag = false; - break; - } - } - if (flag) { - fprintf(pp, " -23303=%zu", axes.size()); - for (int 
i = 0; i < (int)axes.size(); i++) { - if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4) - fprintf(stderr, "Unsupported unsqueeze axes !: %d, %s\n", axes[i], node.name().c_str()); - fprintf(pp, ",%d", axes[i] - 1); - } - } - } else if (op == "Yolov3DetectionOutput") { - int num_class = get_node_attr_i(node, "num_class"); - int num_box = get_node_attr_i(node, "num_box"); - float confidence_threshold = get_node_attr_f(node, "confidence_threshold"); - float nms_threshold = get_node_attr_f(node, "nms_threshold"); - fprintf(pp, " 0=%d", num_class); - fprintf(pp, " 1=%d", num_box); - fprintf(pp, " 2=%e", confidence_threshold); - fprintf(pp, " 3=%e", nms_threshold); - std::vector biases = get_node_attr_af(node, "biases"); - if (biases.size() > 0) { - fprintf(pp, " -23304=%zu", biases.size()); - for (int i = 0; i < (int)biases.size(); i++) { - fprintf(pp, ",%e", biases[i]); - } - } - std::vector mask = get_node_attr_af(node, "mask"); - if (mask.size() > 0) { - fprintf(pp, " -23305=%zu", mask.size()); - for (int i = 0; i < (int)mask.size(); i++) { - fprintf(pp, ",%e", mask[i]); - } - } - std::vector anchors_scale = get_node_attr_af(node, "anchors_scale"); - if (anchors_scale.size() > 0) { - fprintf(pp, " -23306=%zu", anchors_scale.size()); - for (int i = 0; i < (int)anchors_scale.size(); i++) { - fprintf(pp, ",%e", anchors_scale[i]); - } - } - } else { - // TODO op specific param - } - - fprintf(pp, "\n"); - for (int j = 0; j < output_size; j++) { - const std::string& output_name = node.output(j); - if (node_reference.find(output_name) != node_reference.end()) { - int refcount = node_reference[output_name]; - if (refcount > 1) { - char splitname[256]; - sprintf(splitname, "splitncnn_%d", internal_split); - fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount); - - fprintf(pp, " %s", output_name.c_str()); - - for (int k = 0; k < refcount; k++) { - fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k); - } - fprintf(pp, "\n"); - - internal_split++; } - } } - } - fclose(pp); - fclose(bp); - fprintf(stderr, "onnx2ncnn finish\n"); - return 0; + fclose(pp); + fclose(bp); + fprintf(stderr, "onnx2ncnn finish\n"); + return 0; } diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp index dd1fe2c4f6..42482ee8b8 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.cpp @@ -14,157 +14,179 @@ * @return std::tuple> */ std::tuple> query_shape( - onnx::GraphProto* mutable_graph, onnx::NodeProto* target, + onnx::GraphProto* mutable_graph, + onnx::NodeProto* target, const std::map& weights, - std::map>& context) { - // emplace all input nodes - const int input_count = mutable_graph->input_size(); - for (int i = 0; i < input_count; i++) { - auto inp = mutable_graph->input(i); - onnx::TypeProto inp_type = inp.type(); - onnx::TensorShapeProto shape_proto = inp_type.tensor_type().shape(); - - auto dim_size = shape_proto.dim_size(); - std::vector shape(dim_size); - for (int index = 0; index < dim_size; ++index) { - shape[index] = shape_proto.dim(index).dim_value(); - } - - context.emplace(inp.name(), shape); - } - - // BFS the tree, `target` as root, onnx::graph inputs and weights as leaf nodes - std::vector serial = {target}; - { - std::set mark_as_appended = {}; - while (true) { - int start = 0, end = serial.size(); - for (int i = start; i < end; ++i) { - auto node_ptr = serial[i]; - auto len = node_ptr->input_size(); - - for (int j = 0; j < 
len; ++j) {
-          std::string name = node_ptr->input(j);
-          if (context.find(name) != context.end()) {
-            // if input founded, skip
-            continue;
-          }
-
-          if (weights.find(name) != weights.end()) {
-            // if founded in weights, extract shape to context
-            auto weight = weights.at(name);
-            std::vector<int> shape;
-            for (auto index = 0; index < weight.dims_size(); ++index) {
-              shape.emplace_back(weight.dims(index));
-            }
-            context.emplace(name, shape);
-            continue;
-          }
-
-          if (mark_as_appended.find(name) != mark_as_appended.end()) {
-            // if mark as appended, skip
-            continue;
-          }
-          // else append it to serialization list
-          auto depend_ptr = find_node_by_output_name(mutable_graph, name);
-          if (depend_ptr == nullptr) {
-            fprintf(stderr, "cannot find %s from graph !\n", name.c_str());
-            return std::make_tuple(false, std::vector<int>{});
-          }
-          mark_as_appended.insert(name);
-          serial.emplace_back(depend_ptr);
+    std::map<std::string, std::vector<int>>& context)
+{
+    // emplace all input nodes
+    const int input_count = mutable_graph->input_size();
+    for (int i = 0; i < input_count; i++)
+    {
+        auto inp = mutable_graph->input(i);
+        onnx::TypeProto inp_type = inp.type();
+        onnx::TensorShapeProto shape_proto = inp_type.tensor_type().shape();
+
+        auto dim_size = shape_proto.dim_size();
+        std::vector<int> shape(dim_size);
+        for (int index = 0; index < dim_size; ++index)
+        {
+            shape[index] = shape_proto.dim(index).dim_value();
        }
-      }
-      if (serial.size() <= end) {
-        // if not new node added, quit
-        break;
-      }
-
-      // update start and end position, continue BFS the tree
-      start = end;
-      end = serial.size();
+        context.emplace(inp.name(), shape);
    }
-    }
-
-    // for each node in serialization list, calculate the output shape
-    {
-      std::reverse(serial.begin(), serial.end());
-      for (auto node : serial) {
-        if (node->op_type() == "Conv") {
-          auto inp = context[node->input(0)];
-          auto weight = context[node->input(1)];
-          assert(inp.size() == 4 and weight.size() == 4);
-
-          int group = get_node_attr_i(*node, "group", 1);
-          assert(group == 1);
-
-          // treat multiple spatial attr as single one
-#define EXTRACT_REPEATED_PARAM(NAME, ATTR, DEFAULT) \
-  int ATTR = DEFAULT; \
-  { \
-    std::vector<int> _vec = get_node_attr_ai(*node, NAME); \
-    if (not _vec.empty()) { \
-      ATTR = _vec[0]; \
-    } \
-  }
-
-          EXTRACT_REPEATED_PARAM("dilations", dilation, 1);
-          EXTRACT_REPEATED_PARAM("pads", pad, 0);
-          EXTRACT_REPEATED_PARAM("strides", stride, 1);
-
-#undef EXTRACT_REPEATED_PARAM
-          int on = inp[0];
-          int oc = weight[0];
-          int oh = (inp[2] + 2 * pad - weight[2]) / stride + 1;
-          int ow = (inp[3] + 2 * pad - weight[3]) / stride + 1;
-          context.emplace(node->output(0), std::vector<int>{on, oc, oh, ow});
-
-        } else if (node->op_type() == "Shape") {
-          auto inp = context[node->input(0)];
-          context.emplace(node->output(0), std::vector<int>{1, inp[1], inp[2], inp[3]});
-
-        } else if (node->op_type() == "Slice") {
-          assert(node->input_size() >= 4);
+    // BFS the tree, `target` as root, onnx::graph inputs and weights as leaf nodes
+    std::vector<onnx::NodeProto*> serial = {target};
+    {
+        std::set<std::string> mark_as_appended = {};
+        while (true)
+        {
+            int start = 0, end = serial.size();
+            for (int i = start; i < end; ++i)
+            {
+                auto node_ptr = serial[i];
+                auto len = node_ptr->input_size();
+
+                for (int j = 0; j < len; ++j)
+                {
+                    std::string name = node_ptr->input(j);
+                    if (context.find(name) != context.end())
+                    {
+                        // if input founded, skip
+                        continue;
+                    }
+
+                    if (weights.find(name) != weights.end())
+                    {
+                        // if founded in weights, extract shape to context
+                        auto weight = weights.at(name);
+                        std::vector<int> shape;
+                        for (auto index = 0; index <
weight.dims_size(); ++index) + { + shape.emplace_back(weight.dims(index)); + } + context.emplace(name, shape); + continue; + } + + if (mark_as_appended.find(name) != mark_as_appended.end()) + { + // if mark as appended, skip + continue; + } + // else append it to serialization list + auto depend_ptr = find_node_by_output_name(mutable_graph, name); + if (depend_ptr == nullptr) + { + fprintf(stderr, "cannot find %s from graph !\n", name.c_str()); + return std::make_tuple(false, std::vector{}); + } + mark_as_appended.insert(name); + serial.emplace_back(depend_ptr); + } + } - auto inp = context[node->input(0)]; - int start = get_node_attr_from_input(weights.at(node->input(1))); - int end = get_node_attr_from_input(weights.at(node->input(2))); - int axes = get_node_attr_from_input(weights.at(node->input(3))); + if (serial.size() <= end) + { + // if not new node added, quit + break; + } - if (axes != 0) { - fprintf(stderr, "Not support axes=%d !\n", axes); - return std::make_tuple(false, std::vector{}); + // update start and end position, continue BFS the tree + start = end; + end = serial.size(); } + } - assert(inp.size() >= end - start); - context.emplace(node->output(0), std::vector{inp.begin() + start, inp.begin() + end}); - - } else if (node->op_type() == "Concat") { - assert(node->input_size() >= 2); - - auto axis = get_node_attr_i(*node, "axis", 0); - if (axis != 0) { - fprintf(stderr, "Not support axes=%d !\n", axis); - return std::make_tuple(false, std::vector{}); - } + // for each node in serialization list, calculate the output shape + { + std::reverse(serial.begin(), serial.end()); + for (auto node : serial) + { + if (node->op_type() == "Conv") + { + auto inp = context[node->input(0)]; + auto weight = context[node->input(1)]; + assert(inp.size() == 4 and weight.size() == 4); + + int group = get_node_attr_i(*node, "group", 1); + assert(group == 1); + + // treat multiple spatial attr as single one +#define EXTRACT_REPEATED_PARAM(NAME, ATTR, DEFAULT) \ + int ATTR = DEFAULT; \ + { \ + std::vector _vec = get_node_attr_ai(*node, NAME); \ + if (not _vec.empty()) \ + { \ + ATTR = _vec[0]; \ + } \ + } - std::vector inp = context[node->input(0)]; - std::vector w_data = get_node_attr_from_input_ai(weights.at(node->input(1))); + EXTRACT_REPEATED_PARAM("dilations", dilation, 1); + EXTRACT_REPEATED_PARAM("pads", pad, 0); + EXTRACT_REPEATED_PARAM("strides", stride, 1); - // concat data on axis 0 - inp.insert(inp.end(), w_data.begin(), w_data.end()); - context.emplace(node->output(0), inp); +#undef EXTRACT_REPEATED_PARAM - } else { - fprintf(stderr, "Unsupported type %s in query_shape !\n", node->op_type().c_str()); - return std::make_tuple(false, std::vector{}); - } + int on = inp[0]; + int oc = weight[0]; + int oh = (inp[2] + 2 * pad - weight[2]) / stride + 1; + int ow = (inp[3] + 2 * pad - weight[3]) / stride + 1; + context.emplace(node->output(0), std::vector{on, oc, oh, ow}); + } + else if (node->op_type() == "Shape") + { + auto inp = context[node->input(0)]; + context.emplace(node->output(0), std::vector{1, inp[1], inp[2], inp[3]}); + } + else if (node->op_type() == "Slice") + { + assert(node->input_size() >= 4); + + auto inp = context[node->input(0)]; + int start = get_node_attr_from_input(weights.at(node->input(1))); + int end = get_node_attr_from_input(weights.at(node->input(2))); + int axes = get_node_attr_from_input(weights.at(node->input(3))); + + if (axes != 0) + { + fprintf(stderr, "Not support axes=%d !\n", axes); + return std::make_tuple(false, std::vector{}); + } + + 
assert(inp.size() >= end - start); + context.emplace(node->output(0), std::vector{inp.begin() + start, inp.begin() + end}); + } + else if (node->op_type() == "Concat") + { + assert(node->input_size() >= 2); + + auto axis = get_node_attr_i(*node, "axis", 0); + if (axis != 0) + { + fprintf(stderr, "Not support axes=%d !\n", axis); + return std::make_tuple(false, std::vector{}); + } + + std::vector inp = context[node->input(0)]; + std::vector w_data = get_node_attr_from_input_ai(weights.at(node->input(1))); + + // concat data on axis 0 + inp.insert(inp.end(), w_data.begin(), w_data.end()); + context.emplace(node->output(0), inp); + } + else + { + fprintf(stderr, "Unsupported type %s in query_shape !\n", node->op_type().c_str()); + return std::make_tuple(false, std::vector{}); + } + } } - } - assert(context.find(target->output(0)) != context.end()); - auto target_shape = context[target->output(0)]; - return std::make_tuple(true, target_shape); + assert(context.find(target->output(0)) != context.end()); + auto target_shape = context[target->output(0)]; + return std::make_tuple(true, target_shape); } diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.h b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.h index fa62ffe9de..e7a29a2cef 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.h +++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/shape_inference.h @@ -14,6 +14,7 @@ * @return std::tuple> */ std::tuple> query_shape( - onnx::GraphProto* mutable_graph, onnx::NodeProto* target, + onnx::GraphProto* mutable_graph, + onnx::NodeProto* target, const std::map& weights, - std::map>& context); + std::map>& context); diff --git a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/utils.h b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/utils.h index 792db0ed34..ab991a52f9 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/utils.h +++ b/csrc/mmdeploy/backend_ops/ncnn/onnx2ncnn/utils.h @@ -21,381 +21,496 @@ * @param name * @return onnx::NodeProto* */ -static onnx::NodeProto* find_node_by_output_name(onnx::GraphProto* mutable_graph, - const std::string& name) { - const int input_count = mutable_graph->node_size(); - for (int i = 0; i < input_count; ++i) { - onnx::NodeProto* node = mutable_graph->mutable_node(i); - - for (int j = 0; j < node->output_size(); ++j) { - auto output = node->output(j); - if (output == name) { - return node; - } +static onnx::NodeProto* find_node_by_output_name(onnx::GraphProto* mutable_graph, + const std::string& name) +{ + const int input_count = mutable_graph->node_size(); + for (int i = 0; i < input_count; ++i) + { + onnx::NodeProto* node = mutable_graph->mutable_node(i); + + for (int j = 0; j < node->output_size(); ++j) + { + auto output = node->output(j); + if (output == name) + { + return node; + } + } } - } - return nullptr; + return nullptr; } -static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message) { - std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary); - if (!fs.is_open()) { - fprintf(stderr, "open failed %s\n", filepath); - return false; - } +static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message) +{ + std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary); + if (!fs.is_open()) + { + fprintf(stderr, "open failed %s\n", filepath); + return false; + } - google::protobuf::io::IstreamInputStream input(&fs); - google::protobuf::io::CodedInputStream codedstr(&input); + google::protobuf::io::IstreamInputStream input(&fs); + 
google::protobuf::io::CodedInputStream codedstr(&input); #if GOOGLE_PROTOBUF_VERSION >= 3011000 - codedstr.SetTotalBytesLimit(INT_MAX); + codedstr.SetTotalBytesLimit(INT_MAX); #else - codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2); + codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2); #endif - bool success = message->ParseFromCodedStream(&codedstr); + bool success = message->ParseFromCodedStream(&codedstr); - fs.close(); + fs.close(); - return success; + return success; } -static std::vector get_node_attr_ai(const onnx::NodeProto& node, const char* key) { - std::vector v; +static std::vector get_node_attr_ai(const onnx::NodeProto& node, const char* key) +{ + std::vector v; + + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + v.resize(attr.ints_size()); + for (int j = 0; j < attr.ints_size(); j++) + { + v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), + (::google::protobuf::int64)INT_MIN); + } + + break; + } + } - for (int i = 0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == key) { - v.resize(attr.ints_size()); - for (int j = 0; j < attr.ints_size(); j++) { - v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), - (::google::protobuf::int64)INT_MIN); - } + return v; +} - break; +static void set_node_attr_ai(onnx::NodeProto& node, const char* key, const std::vector& value) +{ + onnx::AttributeProto* attr_group = node.add_attribute(); + attr_group->set_name(key); + for (auto v : value) + { + attr_group->add_ints(v); } - } - return v; + return; } -static void set_node_attr_ai(onnx::NodeProto& node, const char* key, - const std::vector& value) { - onnx::AttributeProto* attr_group = node.add_attribute(); - attr_group->set_name(key); - for (auto v : value) { - attr_group->add_ints(v); - } +static std::vector get_node_attr_af(const onnx::NodeProto& node, const char* key) +{ + std::vector v; + + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + v.resize(attr.floats_size()); + for (int j = 0; j < attr.floats_size(); j++) + { + v[j] = attr.floats(j); + } + + break; + } + } - return; + return v; } -static std::vector get_node_attr_af(const onnx::NodeProto& node, const char* key) { - std::vector v; +static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0) +{ + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), + (::google::protobuf::int64)INT_MIN); + } + } - for (int i = 0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == key) { - v.resize(attr.floats_size()); - for (int j = 0; j < attr.floats_size(); j++) { - v[j] = attr.floats(j); - } + return def; +} - break; +static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f) +{ + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + return attr.f(); + } } - } - return v; + return def; } -static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0) { - for (int i = 0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == 
key) { - return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), - (::google::protobuf::int64)INT_MIN); +static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, const std::string& def = std::string()) +{ + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + return attr.s(); + } } - } - return def; + return def; } -static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f) { - for (int i = 0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == key) { - return attr.f(); +static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key) +{ + for (int i = 0; i < node.attribute_size(); i++) + { + const onnx::AttributeProto& attr = node.attribute(i); + if (attr.name() == key) + { + return attr.t(); + } } - } - return def; + return onnx::TensorProto(); } -static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, - const std::string& def = std::string()) { - for (int i = 0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == key) { - return attr.s(); +template +static T get_node_attr_from_input(const onnx::TensorProto& tp) +{ + T v = 0.f; + + // float + if (tp.data_type() == 1) + { + const float* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const float*)tp.raw_data().data(); + } + else + { + shape_data = tp.float_data().data(); + } + v = shape_data[0]; + } + // double + else if (tp.data_type() == 11) + { + const double* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const double*)tp.raw_data().data(); + } + else + { + shape_data = tp.double_data().data(); + } + v = shape_data[0]; + } + // int64 + else if (tp.data_type() == 7) + { + const int64_t* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const int64_t*)tp.raw_data().data(); + } + else + { + shape_data = tp.int64_data().data(); + } + v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), + (::google::protobuf::int64)INT_MIN); + } + // int32 + else if (tp.data_type() == 6) + { + const int32_t* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const int32_t*)tp.raw_data().data(); + } + else + { + shape_data = tp.int32_data().data(); + } + v = shape_data[0]; + } + else + { + // fprintf(stderr, "tp.name: %s\n", tp.name().c_str()); + fprintf(stderr, "Unknown data type %d\n", tp.data_type()); + fprintf(stderr, "get_node_attr_from_input\n"); + abort(); } - } - return def; + return v; } -static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key) { - for (int i = 0; i < node.attribute_size(); i++) { - const onnx::AttributeProto& attr = node.attribute(i); - if (attr.name() == key) { - return attr.t(); +static std::vector get_node_attr_from_input_ai(const onnx::TensorProto& tp) +{ + int size = 0; + + std::vector v; + + // int64 + if (tp.data_type() == 7) + { + const int64_t* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const int64_t*)tp.raw_data().data(); + size = (int)(tp.raw_data().size() / 8); + } + else + { + shape_data = tp.int64_data().data(); + size = tp.int64_data_size(); + } + for (int j = 0; j < size; j++) + { + int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), + (::google::protobuf::int64)INT_MIN); + v.push_back(vi); + } + } + // int32 + else if (tp.data_type() == 6) + { 
+ const int32_t* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const int32_t*)tp.raw_data().data(); + size = (int)(tp.raw_data().size() / 4); + } + else + { + shape_data = tp.int32_data().data(); + size = tp.int32_data_size(); + } + for (int j = 0; j < size; j++) + { + v.push_back(shape_data[j]); + } + } + else + { + fprintf(stderr, "Unknown data type %d\n", tp.data_type()); } - } - return onnx::TensorProto(); + return v; } -template -static T get_node_attr_from_input(const onnx::TensorProto& tp) { - T v = 0.f; - - // float - if (tp.data_type() == 1) { - const float* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const float*)tp.raw_data().data(); - } else { - shape_data = tp.float_data().data(); - } - v = shape_data[0]; - } - // double - else if (tp.data_type() == 11) { - const double* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const double*)tp.raw_data().data(); - } else { - shape_data = tp.double_data().data(); - } - v = shape_data[0]; - } - // int64 - else if (tp.data_type() == 7) { - const int64_t* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const int64_t*)tp.raw_data().data(); - } else { - shape_data = tp.int64_data().data(); - } - v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), - (::google::protobuf::int64)INT_MIN); - } - // int32 - else if (tp.data_type() == 6) { - const int32_t* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const int32_t*)tp.raw_data().data(); - } else { - shape_data = tp.int32_data().data(); - } - v = shape_data[0]; - } else { - // fprintf(stderr, "tp.name: %s\n", tp.name().c_str()); - fprintf(stderr, "Unknown data type %d\n", tp.data_type()); - fprintf(stderr, "get_node_attr_from_input\n"); - abort(); - } - - return v; -} +static std::vector get_node_attr_from_input_af(const onnx::TensorProto& tp) +{ + int size = 0; + + std::vector v; + + // float + if (tp.data_type() == 1) + { + const float* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const float*)tp.raw_data().data(); + size = (int)(tp.raw_data().size() / 4); + } + else + { + shape_data = tp.float_data().data(); + size = tp.float_data_size(); + } + for (int j = 0; j < size; j++) + { + v.push_back(shape_data[j]); + } + } + // double + else if (tp.data_type() == 11) + { + const double* shape_data = 0; + if (tp.has_raw_data()) + { + shape_data = (const double*)tp.raw_data().data(); + size = (int)(tp.raw_data().size() / 8); + } + else + { + shape_data = tp.double_data().data(); + size = tp.double_data_size(); + } + for (int j = 0; j < size; j++) + { + v.push_back((float)shape_data[j]); + } + } + else + { + fprintf(stderr, "Unknown data type %d\n", tp.data_type()); + } -static std::vector get_node_attr_from_input_ai(const onnx::TensorProto& tp) { - int size = 0; - - std::vector v; - - // int64 - if (tp.data_type() == 7) { - const int64_t* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const int64_t*)tp.raw_data().data(); - size = (int)(tp.raw_data().size() / 8); - } else { - shape_data = tp.int64_data().data(); - size = tp.int64_data_size(); - } - for (int j = 0; j < size; j++) { - int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), - (::google::protobuf::int64)INT_MIN); - v.push_back(vi); - } - } - // int32 - else if (tp.data_type() == 6) { - const int32_t* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const int32_t*)tp.raw_data().data(); - size = (int)(tp.raw_data().size() / 4); - } else { - shape_data = tp.int32_data().data(); - size = 
tp.int32_data_size(); - } - for (int j = 0; j < size; j++) { - v.push_back(shape_data[j]); - } - } else { - fprintf(stderr, "Unknown data type %d\n", tp.data_type()); - } - - return v; + return v; } -static std::vector get_node_attr_from_input_af(const onnx::TensorProto& tp) { - int size = 0; - - std::vector v; - - // float - if (tp.data_type() == 1) { - const float* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const float*)tp.raw_data().data(); - size = (int)(tp.raw_data().size() / 4); - } else { - shape_data = tp.float_data().data(); - size = tp.float_data_size(); - } - for (int j = 0; j < size; j++) { - v.push_back(shape_data[j]); - } - } - // double - else if (tp.data_type() == 11) { - const double* shape_data = 0; - if (tp.has_raw_data()) { - shape_data = (const double*)tp.raw_data().data(); - size = (int)(tp.raw_data().size() / 8); - } else { - shape_data = tp.double_data().data(); - size = tp.double_data_size(); - } - for (int j = 0; j < size; j++) { - v.push_back((float)shape_data[j]); - } - } else { - fprintf(stderr, "Unknown data type %d\n", tp.data_type()); - } - - return v; -} +static int get_tensor_proto_data_size(const onnx::TensorProto& tp) +{ + if (tp.has_raw_data()) + { + if (tp.data_type() == 1 || tp.data_type() == 6) + { + const std::string& raw_data = tp.raw_data(); + int size = (int)raw_data.size() / 4; + return size; + } + else if (tp.data_type() == 7 || tp.data_type() == 11) + { + const std::string& raw_data = tp.raw_data(); + int size = (int)raw_data.size() / 8; + return size; + } + else if (tp.data_type() == 9) + { + const std::string& raw_data = tp.raw_data(); + return 0; + } + } + else if (tp.data_type() == 1) + { + return tp.float_data_size(); + } + else if (tp.data_type() == 7) + { + return tp.int64_data_size(); + } + else if (tp.data_type() == 6) + { + return tp.int32_data_size(); + } + else if (tp.data_type() == 11) + { + return tp.double_data_size(); + } -static int get_tensor_proto_data_size(const onnx::TensorProto& tp) { - if (tp.has_raw_data()) { - if (tp.data_type() == 1 || tp.data_type() == 6) { - const std::string& raw_data = tp.raw_data(); - int size = (int)raw_data.size() / 4; - return size; - } else if (tp.data_type() == 7 || tp.data_type() == 11) { - const std::string& raw_data = tp.raw_data(); - int size = (int)raw_data.size() / 8; - return size; - } else if (tp.data_type() == 9) { - const std::string& raw_data = tp.raw_data(); - return 0; - } - } else if (tp.data_type() == 1) { - return tp.float_data_size(); - } else if (tp.data_type() == 7) { - return tp.int64_data_size(); - } else if (tp.data_type() == 6) { - return tp.int32_data_size(); - } else if (tp.data_type() == 11) { - return tp.double_data_size(); - } - - return 0; + return 0; } -static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp) { - int size = get_tensor_proto_data_size(tp); +static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp) +{ + int size = get_tensor_proto_data_size(tp); - if (tp.has_raw_data()) { - const std::string& raw_data = tp.raw_data(); - fwrite(raw_data.data(), sizeof(float), size, bp); - } else if (tp.data_type() == 1) { - fwrite(tp.float_data().data(), sizeof(float), size, bp); - } + if (tp.has_raw_data()) + { + const std::string& raw_data = tp.raw_data(); + fwrite(raw_data.data(), sizeof(float), size, bp); + } + else if (tp.data_type() == 1) + { + fwrite(tp.float_data().data(), sizeof(float), size, bp); + } } -static void fwrite_tensor_proto_data_to_float(const onnx::TensorProto& tp, FILE* bp) { - int size = 
get_tensor_proto_data_size(tp); - size_t written_size; - if (tp.has_raw_data()) { - const std::string& raw_data = tp.raw_data(); - if (tp.data_type() == 6) { - int* intdataptr = (int*)raw_data.data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 7) { - int64_t* intdataptr = (int64_t*)raw_data.data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 9) { - bool* intdataptr = (bool*)raw_data.data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 11) { - double* doubledataptr = (double*)raw_data.data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)doubledataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } - } else if (tp.data_type() == 6) { - int* intdataptr = (int*)tp.int32_data().data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 7) { - int64_t* intdataptr = (int64_t*)tp.int64_data().data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 9) { - int* intdataptr = (int*)tp.int64_data().data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)intdataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } else if (tp.data_type() == 11) { - double* doubledataptr = (double*)tp.double_data().data(); - float* floatdataptr = (float*)std::malloc(sizeof(float) * size); - for (int i = 0; i < size; i++) { - floatdataptr[i] = (float)doubledataptr[i]; - } - written_size = fwrite(floatdataptr, sizeof(float), size, bp); - std::free(floatdataptr); - } +static void fwrite_tensor_proto_data_to_float(const onnx::TensorProto& tp, FILE* bp) +{ + int size = get_tensor_proto_data_size(tp); + size_t written_size; + if (tp.has_raw_data()) + { + const std::string& raw_data = tp.raw_data(); + if (tp.data_type() == 6) + { + int* intdataptr = (int*)raw_data.data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 7) + { + int64_t* intdataptr = (int64_t*)raw_data.data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = 
fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 9) + { + bool* intdataptr = (bool*)raw_data.data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 11) + { + double* doubledataptr = (double*)raw_data.data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)doubledataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + } + else if (tp.data_type() == 6) + { + int* intdataptr = (int*)tp.int32_data().data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 7) + { + int64_t* intdataptr = (int64_t*)tp.int64_data().data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 9) + { + int* intdataptr = (int*)tp.int64_data().data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)intdataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } + else if (tp.data_type() == 11) + { + double* doubledataptr = (double*)tp.double_data().data(); + float* floatdataptr = (float*)std::malloc(sizeof(float) * size); + for (int i = 0; i < size; i++) + { + floatdataptr[i] = (float)doubledataptr[i]; + } + written_size = fwrite(floatdataptr, sizeof(float), size, bp); + std::free(floatdataptr); + } } diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.cpp old mode 100755 new mode 100644 index b865db7b25..c347cb97a9 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.cpp @@ -3,51 +3,60 @@ #include "../ncnn_ops_definer.h" -namespace mmdeploy { -using namespace ncnn; -DEFINE_LAYER_CREATOR(ConstantOfShape) -DEFINE_NCNN_OPS(ConstantOfShape, ConstantOfShape) -ConstantOfShape::ConstantOfShape() { - one_blob_only = true; - support_inplace = false; -} +namespace mmdeploy +{ + using namespace ncnn; + DEFINE_LAYER_CREATOR(ConstantOfShape) + DEFINE_NCNN_OPS(ConstantOfShape, ConstantOfShape) + ConstantOfShape::ConstantOfShape() + { + one_blob_only = true; + support_inplace = false; + } -int ConstantOfShape::load_param(const ParamDict& pd) { - val = pd.get(0, 0.f); - return 0; -} + int ConstantOfShape::load_param(const ParamDict& pd) + { + val = pd.get(0, 0.f); + return 0; + } -int ConstantOfShape::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { - int dims = bottom_blob.w - 1; - const float* bottom_ptr = bottom_blob; - const float* shape_ptr = bottom_ptr + 1; + int ConstantOfShape::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const + { + int dims = bottom_blob.w - 1; + const float* bottom_ptr = 
bottom_blob; + const float* shape_ptr = bottom_ptr + 1; - if (dims == 1) { - int w = (int)(shape_ptr[0] + 0.5); - size_t elemsize = sizeof(val); - top_blob.create(w, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - top_blob.fill(val); - return 0; - } else if (dims == 2) { - int h = (int)(shape_ptr[0] + 0.5); - int w = (int)(shape_ptr[1] + 0.5); - size_t elemsize = sizeof(val); - top_blob.create(w, h, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - top_blob.fill(val); - return 0; - } else if (dims == 3) { - int channels = (int)(shape_ptr[0] + 0.5); - int h = (int)(shape_ptr[1] + 0.5); - int w = (int)(shape_ptr[2] + 0.5); - size_t elemsize = sizeof(val); - top_blob.create(w, h, channels, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - top_blob.fill(val); - return 0; - } - return -1; -} + if (dims == 1) + { + int w = (int)(shape_ptr[0] + 0.5); + size_t elemsize = sizeof(val); + top_blob.create(w, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + top_blob.fill(val); + return 0; + } + else if (dims == 2) + { + int h = (int)(shape_ptr[0] + 0.5); + int w = (int)(shape_ptr[1] + 0.5); + size_t elemsize = sizeof(val); + top_blob.create(w, h, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + top_blob.fill(val); + return 0; + } + else if (dims == 3) + { + int channels = (int)(shape_ptr[0] + 0.5); + int h = (int)(shape_ptr[1] + 0.5); + int w = (int)(shape_ptr[2] + 0.5); + size_t elemsize = sizeof(val); + top_blob.create(w, h, channels, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + top_blob.fill(val); + return 0; + } + return -1; + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.h b/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.h old mode 100755 new mode 100644 index b61fb62c09..d068fd3196 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/constantofshape/constantofshape.h @@ -4,20 +4,21 @@ #include "layer.h" -namespace mmdeploy { +namespace mmdeploy +{ -class ConstantOfShape : public ncnn::Layer { - public: - ConstantOfShape(); + class ConstantOfShape : public ncnn::Layer + { + public: + ConstantOfShape(); - virtual int load_param(const ncnn::ParamDict& pd); + virtual int load_param(const ncnn::ParamDict& pd); - virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, - const ncnn::Option& opt) const; + virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, const ncnn::Option& opt) const; - public: - float val; -}; + public: + float val; + }; } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.cpp old mode 100755 new mode 100644 index be3d75a248..c742b91df7 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.cpp @@ -4,330 +4,452 @@ #include "expand.h" #include "../ncnn_ops_definer.h" -namespace mmdeploy { -using namespace ncnn; -DEFINE_LAYER_CREATOR(Expand) -DEFINE_NCNN_OPS(Expand, Expand) -Expand::Expand() { - one_blob_only = false; - support_inplace = false; -} - -int Expand::forward(const std::vector& bottom_blobs, std::vector& top_blobs, - const Option& opt) const { - const Mat& bottom_blob = bottom_blobs[0]; - size_t elemsize = bottom_blob.elemsize; - const Mat& old_shape_blob = bottom_blobs[1]; - const int shape_width = old_shape_blob.w - 1; - Mat 
shape_blob(shape_width, elemsize, opt.workspace_allocator); - memcpy(shape_blob.row(0), old_shape_blob.row(0) + 1, shape_width * elemsize); - Mat& top_blob = top_blobs[0]; - - if (bottom_blob.dims == 1 && shape_blob.w == 1) { - int shape_0 = (int)(shape_blob[0] + 0.5); - if (bottom_blob.w != shape_0 && bottom_blob.w != 1 && shape_0 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d) vs (%d)\n", bottom_blob.w, shape_0); - } else if (bottom_blob.w == shape_0 || shape_0 == 1) { - top_blob.create(bottom_blob.w, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - - for (int i = 0; i < bottom_blob.w; i++) { - top_blob[i] = bottom_blob[i]; - } - } else if (bottom_blob.w == 1) { - top_blob.create(shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - - for (int i = 0; i < shape_0; i++) { - top_blob[i] = bottom_blob[0]; - } - } else { - fprintf(stderr, "error case\n"); - return -100; +namespace mmdeploy +{ + using namespace ncnn; + DEFINE_LAYER_CREATOR(Expand) + DEFINE_NCNN_OPS(Expand, Expand) + Expand::Expand() + { + one_blob_only = false; + support_inplace = false; } - return 0; - } else if (bottom_blob.dims == 1 && shape_blob.w == 2) { - int shape_0 = (int)(shape_blob[0] + 0.5); - int shape_1 = (int)(shape_blob[1] + 0.5); - if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (1, %d) vs (%d, %d)\n", bottom_blob.w, shape_0, - shape_1); - } else if (bottom_blob.w == shape_1 || shape_1 == 1) { - top_blob.create(bottom_blob.w, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int j = 0; j < shape_0; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.row(j)[i] = bottom_blob[i]; - } - } + int Expand::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const + { + const Mat& bottom_blob = bottom_blobs[0]; + size_t elemsize = bottom_blob.elemsize; + const Mat& old_shape_blob = bottom_blobs[1]; + const int shape_width = old_shape_blob.w - 1; + Mat shape_blob(shape_width, elemsize, opt.workspace_allocator); + memcpy(shape_blob.row(0), old_shape_blob.row(0) + 1, shape_width * elemsize); + Mat& top_blob = top_blobs[0]; - } else if (bottom_blob.w == 1) { - top_blob.create(shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; + if (bottom_blob.dims == 1 && shape_blob.w == 1) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + if (bottom_blob.w != shape_0 && bottom_blob.w != 1 && shape_0 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d) vs (%d)\n", bottom_blob.w, shape_0); + } + else if (bottom_blob.w == shape_0 || shape_0 == 1) + { + top_blob.create(bottom_blob.w, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; - for (int j = 0; j < shape_0; j++) { - for (int i = 0; i < shape_1; i++) { - top_blob.row(j)[i] = bottom_blob[0]; - } - } + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob[i] = bottom_blob[i]; + } + } + else if (bottom_blob.w == 1) + { + top_blob.create(shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; - } else { - fprintf(stderr, "error case\n"); - return -100; - } - return 0; - } else if (bottom_blob.dims == 1 && shape_blob.w == 3) { - int shape_0 = (int)(shape_blob[0] + 0.5); - int shape_1 = (int)(shape_blob[1] + 0.5); - int shape_2 = (int)(shape_blob[2] + 0.5); - - if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (1, 1, %d) vs (%d, %d, %d)\n", 
bottom_blob.w, - shape_0, shape_1, shape_2); - } else if (bottom_blob.w == shape_2 || shape_2 == 1) { - top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob[i]; - } - } - } - } else if (bottom_blob.w == 1) { - top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob[0]; - } - } - } - } else { - fprintf(stderr, "error case\n"); - return -100; - } - return 0; - } else if (bottom_blob.dims == 2 && shape_blob.w == 2) { - int shape_0 = (int)(shape_blob[0] + 0.5); - int shape_1 = (int)(shape_blob[1] + 0.5); - if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d)\n", bottom_blob.h, - bottom_blob.w, shape_0, shape_1); - } else if (bottom_blob.h != shape_0 && bottom_blob.h != 1 && shape_0 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d)\n", bottom_blob.h, - bottom_blob.w, shape_0, shape_1); - } else if ((bottom_blob.w == shape_1 || shape_1 == 1) && - (bottom_blob.h == shape_0 || shape_0 == 1)) { - top_blob.create(bottom_blob.w, bottom_blob.h, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.row(j)[i] = bottom_blob.row(j)[i]; - } - } - } else if ((bottom_blob.w == shape_1 || shape_1 == 1) && (bottom_blob.h == 1)) { - top_blob.create(bottom_blob.w, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int j = 0; j < shape_0; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.row(j)[i] = bottom_blob.row(0)[i]; - } - } - } else if ((bottom_blob.w == 1) && (bottom_blob.h == shape_0 || shape_0 == 1)) { - top_blob.create(shape_1, bottom_blob.h, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < shape_1; i++) { - top_blob.row(j)[i] = bottom_blob.row(j)[0]; + for (int i = 0; i < shape_0; i++) + { + top_blob[i] = bottom_blob[0]; + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } - } else if (bottom_blob.h == 1 && bottom_blob.w == 1) { - top_blob.create(shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int j = 0; j < shape_0; j++) { - for (int i = 0; i < shape_1; i++) { - top_blob.row(j)[i] = bottom_blob.row(0)[0]; - } - } - } else { - fprintf(stderr, "error case\n"); - return -100; - } - return 0; - } else if (bottom_blob.dims == 2 && shape_blob.w == 3) { - int shape_0 = (int)(shape_blob[0] + 0.5); - int shape_1 = (int)(shape_blob[1] + 0.5); - int shape_2 = (int)(shape_blob[2] + 0.5); - if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d, %d)\n", bottom_blob.h, - bottom_blob.w, shape_0, shape_1, shape_2); - } else if (bottom_blob.h != shape_1 && bottom_blob.h != 1 && shape_1 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d, %d)\n", bottom_blob.h, - bottom_blob.w, shape_0, shape_1, shape_2); - } else if ((bottom_blob.w == shape_2 || 
shape_2 == 1) && - (bottom_blob.h == shape_1 || shape_1 == 1)) { - top_blob.create(bottom_blob.w, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.row(j)[i]; - } - } - } - } else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1)) { - top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.row(0)[i]; - } - } - } - - } else if ((bottom_blob.w == 1) && (bottom_blob.h == shape_1 || shape_1 == 1)) { - top_blob.create(shape_2, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.row(j)[0]; - } - } - } + else if (bottom_blob.dims == 1 && shape_blob.w == 2) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + int shape_1 = (int)(shape_blob[1] + 0.5); + if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (1, %d) vs (%d, %d)\n", bottom_blob.w, shape_0, shape_1); + } + else if (bottom_blob.w == shape_1 || shape_1 == 1) + { + top_blob.create(bottom_blob.w, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; - } else if (bottom_blob.h == 1 && bottom_blob.w == 1) { - top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.row(0)[0]; - } - } - } - } else { - fprintf(stderr, "error case\n"); - return -100; - } - return 0; - } else if (bottom_blob.dims == 3 && shape_blob.w == 3) { - int shape_0 = (int)(shape_blob[0] + 0.5); - int shape_1 = (int)(shape_blob[1] + 0.5); - int shape_2 = (int)(shape_blob[2] + 0.5); - if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, - bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); - } else if (bottom_blob.h != shape_1 && bottom_blob.h != 1 && shape_1 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, - bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); - } else if (bottom_blob.c != shape_0 && bottom_blob.c != 1 && shape_0 != 1) { - fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, - bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); - } else if ((bottom_blob.w == shape_2 || shape_2 == 1) && - (bottom_blob.h == shape_1 || shape_1 == 1) && - (bottom_blob.c == shape_0 || shape_0 == 1)) { - top_blob.create(bottom_blob.w, bottom_blob.h, bottom_blob.c, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < bottom_blob.c; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[i]; - } - } - } - } else if ((bottom_blob.w == shape_2 || shape_2 == 1) && - 
(bottom_blob.h == shape_1 || shape_1 == 1) && (bottom_blob.c == 1)) { - top_blob.create(bottom_blob.w, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(j)[i]; - } - } - } - - } else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1) && - (bottom_blob.c == shape_0 || shape_0 == 1)) { - top_blob.create(bottom_blob.w, shape_1, bottom_blob.c, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < bottom_blob.c; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[i]; - } - } - } + for (int j = 0; j < shape_0; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.row(j)[i] = bottom_blob[i]; + } + } + } + else if (bottom_blob.w == 1) + { + top_blob.create(shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; - } else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1) && - (bottom_blob.c == 1)) { - top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < bottom_blob.w; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(0)[i]; - } + for (int j = 0; j < shape_0; j++) + { + for (int i = 0; i < shape_1; i++) + { + top_blob.row(j)[i] = bottom_blob[0]; + } + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } + else if (bottom_blob.dims == 1 && shape_blob.w == 3) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + int shape_1 = (int)(shape_blob[1] + 0.5); + int shape_2 = (int)(shape_blob[2] + 0.5); - } else if (bottom_blob.w == 1 && (bottom_blob.h == shape_1 || shape_1 == 1) && - (bottom_blob.c == shape_0 || shape_0 == 1)) { - top_blob.create(shape_2, bottom_blob.h, bottom_blob.c, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < bottom_blob.c; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[0]; - } + if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (1, 1, %d) vs (%d, %d, %d)\n", bottom_blob.w, shape_0, shape_1, shape_2); + } + else if (bottom_blob.w == shape_2 || shape_2 == 1) + { + top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob[i]; + } + } + } + } + else if (bottom_blob.w == 1) + { + top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob[0]; + } + } + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } - } else if (bottom_blob.w == 1 && (bottom_blob.h == shape_1 || shape_1 == 1) && - (bottom_blob.c == 1)) { - top_blob.create(shape_2, bottom_blob.h, 
shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < bottom_blob.h; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(j)[0]; - } + else if (bottom_blob.dims == 2 && shape_blob.w == 2) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + int shape_1 = (int)(shape_blob[1] + 0.5); + if (bottom_blob.w != shape_1 && bottom_blob.w != 1 && shape_1 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d)\n", bottom_blob.h, bottom_blob.w, shape_0, shape_1); + } + else if (bottom_blob.h != shape_0 && bottom_blob.h != 1 && shape_0 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d)\n", bottom_blob.h, bottom_blob.w, shape_0, shape_1); + } + else if ((bottom_blob.w == shape_1 || shape_1 == 1) && + (bottom_blob.h == shape_0 || shape_0 == 1)) + { + top_blob.create(bottom_blob.w, bottom_blob.h, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.row(j)[i] = bottom_blob.row(j)[i]; + } + } + } + else if ((bottom_blob.w == shape_1 || shape_1 == 1) && (bottom_blob.h == 1)) + { + top_blob.create(bottom_blob.w, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int j = 0; j < shape_0; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.row(j)[i] = bottom_blob.row(0)[i]; + } + } + } + else if ((bottom_blob.w == 1) && (bottom_blob.h == shape_0 || shape_0 == 1)) + { + top_blob.create(shape_1, bottom_blob.h, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < shape_1; i++) + { + top_blob.row(j)[i] = bottom_blob.row(j)[0]; + } + } + } + else if (bottom_blob.h == 1 && bottom_blob.w == 1) + { + top_blob.create(shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int j = 0; j < shape_0; j++) + { + for (int i = 0; i < shape_1; i++) + { + top_blob.row(j)[i] = bottom_blob.row(0)[0]; + } + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } - } else if (bottom_blob.w == 1 && bottom_blob.h == 1 && - (bottom_blob.c == shape_0 || shape_0 == 1)) { - top_blob.create(shape_2, shape_1, bottom_blob.c, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < bottom_blob.c; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[0]; - } + else if (bottom_blob.dims == 2 && shape_blob.w == 3) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + int shape_1 = (int)(shape_blob[1] + 0.5); + int shape_2 = (int)(shape_blob[2] + 0.5); + if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d, %d)\n", bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); + } + else if (bottom_blob.h != shape_1 && bottom_blob.h != 1 && shape_1 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d) vs (%d, %d, %d)\n", bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && + (bottom_blob.h == shape_1 || shape_1 == 1)) + { + top_blob.create(bottom_blob.w, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; 
k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.row(j)[i]; + } + } + } + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1)) + { + top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.row(0)[i]; + } + } + } + } + else if ((bottom_blob.w == 1) && (bottom_blob.h == shape_1 || shape_1 == 1)) + { + top_blob.create(shape_2, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.row(j)[0]; + } + } + } + } + else if (bottom_blob.h == 1 && bottom_blob.w == 1) + { + top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.row(0)[0]; + } + } + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } - } else if (bottom_blob.w == 1 && bottom_blob.h == 1 && bottom_blob.c == 1) { - top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); - if (top_blob.empty()) return -100; - for (int k = 0; k < shape_0; k++) { - for (int j = 0; j < shape_1; j++) { - for (int i = 0; i < shape_2; i++) { - top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(0)[0]; - } + else if (bottom_blob.dims == 3 && shape_blob.w == 3) + { + int shape_0 = (int)(shape_blob[0] + 0.5); + int shape_1 = (int)(shape_blob[1] + 0.5); + int shape_2 = (int)(shape_blob[2] + 0.5); + if (bottom_blob.w != shape_2 && bottom_blob.w != 1 && shape_2 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); + } + else if (bottom_blob.h != shape_1 && bottom_blob.h != 1 && shape_1 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); + } + else if (bottom_blob.c != shape_0 && bottom_blob.c != 1 && shape_0 != 1) + { + fprintf(stderr, "The broadcast rule is wrong, (%d, %d, %d) vs (%d, %d, %d)\n", bottom_blob.c, bottom_blob.h, bottom_blob.w, shape_0, shape_1, shape_2); + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && + (bottom_blob.h == shape_1 || shape_1 == 1) && + (bottom_blob.c == shape_0 || shape_0 == 1)) + { + top_blob.create(bottom_blob.w, bottom_blob.h, bottom_blob.c, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < bottom_blob.c; k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[i]; + } + } + } + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && + (bottom_blob.h == shape_1 || shape_1 == 1) && (bottom_blob.c == 1)) + { + top_blob.create(bottom_blob.w, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for 
(int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(j)[i]; + } + } + } + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1) && + (bottom_blob.c == shape_0 || shape_0 == 1)) + { + top_blob.create(bottom_blob.w, shape_1, bottom_blob.c, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < bottom_blob.c; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[i]; + } + } + } + } + else if ((bottom_blob.w == shape_2 || shape_2 == 1) && (bottom_blob.h == 1) && + (bottom_blob.c == 1)) + { + top_blob.create(bottom_blob.w, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < bottom_blob.w; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(0)[i]; + } + } + } + } + else if (bottom_blob.w == 1 && (bottom_blob.h == shape_1 || shape_1 == 1) && + (bottom_blob.c == shape_0 || shape_0 == 1)) + { + top_blob.create(shape_2, bottom_blob.h, bottom_blob.c, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < bottom_blob.c; k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(j)[0]; + } + } + } + } + else if (bottom_blob.w == 1 && (bottom_blob.h == shape_1 || shape_1 == 1) && + (bottom_blob.c == 1)) + { + top_blob.create(shape_2, bottom_blob.h, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < bottom_blob.h; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(j)[0]; + } + } + } + } + else if (bottom_blob.w == 1 && bottom_blob.h == 1 && + (bottom_blob.c == shape_0 || shape_0 == 1)) + { + top_blob.create(shape_2, shape_1, bottom_blob.c, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < bottom_blob.c; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(k).row(0)[0]; + } + } + } + } + else if (bottom_blob.w == 1 && bottom_blob.h == 1 && bottom_blob.c == 1) + { + top_blob.create(shape_2, shape_1, shape_0, elemsize, opt.blob_allocator); + if (top_blob.empty()) return -100; + for (int k = 0; k < shape_0; k++) + { + for (int j = 0; j < shape_1; j++) + { + for (int i = 0; i < shape_2; i++) + { + top_blob.channel(k).row(j)[i] = bottom_blob.channel(0).row(0)[0]; + } + } + } + } + else + { + fprintf(stderr, "error case\n"); + return -100; + } + return 0; } - } - } else { - fprintf(stderr, "error case\n"); - return -100; + fprintf(stderr, "Layer: Expand, bottom_blob.dims: %d, shape_blob.w: %d\n", bottom_blob.dims, shape_blob.w); + return -1; } - return 0; - } - fprintf(stderr, "Layer: Expand, bottom_blob.dims: %d, shape_blob.w: %d\n", bottom_blob.dims, - shape_blob.w); - return -1; -} } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.h b/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.h old mode 100755 new mode 100644 index 3dca54fb0f..a378965d03 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/expand/expand.h @@ -4,15 +4,16 @@ #include "layer.h" -namespace 
mmdeploy { +namespace mmdeploy +{ -class Expand : public ncnn::Layer { - public: - Expand(); + class Expand : public ncnn::Layer + { + public: + Expand(); - virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, - const ncnn::Option& opt) const; -}; + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const ncnn::Option& opt) const; + }; } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.cpp index 4b6bd34630..24ea7f7181 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.cpp @@ -4,157 +4,180 @@ #include "../ncnn_ops_definer.h" #include "assert.h" -namespace mmdeploy { -using namespace ncnn; -DEFINE_LAYER_CREATOR(Gather) -DEFINE_NCNN_OPS(Gather, Gather) -Gather::Gather() { - one_blob_only = false; - support_inplace = false; -} - -int Gather::load_param(const ParamDict &pd) { - axis = pd.get(0, 0); - - return 0; -} - -// Gather only support 1-dim of indices, because the data and indices all has -// implicit batch in ncnn, this will lead to wrong shape to match onnx result. -// When indices dim equals to 1, after eliminating implicit batch, the indices -// dim still be 1. So there is only 1 implicit batch in data, this will make -// the shape match onnx result. -int Gather::forward(const std::vector &bottom_blobs, std::vector &top_blobs, - const Option &opt) const { - const Mat &bottom_blob = bottom_blobs[0]; - const Mat &indices = bottom_blobs[1]; - int dims = bottom_blob.dims; - int indices_dims = indices.dims; - size_t elemsize = bottom_blob.elemsize; - int positive_axis = axis < 0 ? dims + axis : axis; - Mat &top_blob = top_blobs[0]; - assert(indices.dims == 1); - const float *indices_ptr = indices; - - if (dims == 1 && indices_dims == 1) // positive_axis == 0 - { - int w = indices.w; - top_blob.create(w, elemsize, opt.blob_allocator); - if (top_blob.empty()) { - return -100; - } - const float *ptr = bottom_blob; - float *outptr = top_blob; - for (int i = 0; i < w; i++) { - float indice = indices_ptr[i]; - outptr[i] = ptr[(int)(indice + 0.5)]; +namespace mmdeploy +{ + using namespace ncnn; + DEFINE_LAYER_CREATOR(Gather) + DEFINE_NCNN_OPS(Gather, Gather) + Gather::Gather() + { + one_blob_only = false; + support_inplace = false; } - return 0; - } - - if (dims == 2 && positive_axis == 0 && indices_dims == 1) { - int w = bottom_blob.w; - int h = bottom_blob.h; - top_blob.create(w, indices.w, elemsize, opt.blob_allocator); - // w -> w - // h -> indices.w - // h * w -> indices.w * w - if (top_blob.empty()) { - return -100; - } - const float *ptr = bottom_blob; - float *outptr = top_blob; - for (int i = 0; i < indices.w; i++) { - const int selected = (int)(indices_ptr[i] + 0.5); - memcpy(top_blob.row(i), bottom_blob.row(selected), w * elemsize); - } + int Gather::load_param(const ParamDict& pd) + { + axis = pd.get(0, 0); - return 0; - } - - if (dims == 2 && positive_axis == 1 && indices_dims == 1) { - int w = bottom_blob.w; - int h = bottom_blob.h; - top_blob.create(indices.w, h, elemsize, opt.blob_allocator); - // w -> h - // h -> indices.w - // h * w -> indices.w * h - if (top_blob.empty()) { - return -100; - } - const float *ptr = bottom_blob; - float *outptr = top_blob; - for (int j = 0; j < h; j++) { - for (int i = 0; i < indices.w; i++) { - int selected = (int)(indices_ptr[i] + 0.5); - outptr[j * indices.w + i] = ptr[j * w + selected]; - } + return 0; } - return 0; - } - if (dims == 3 
&& positive_axis == 0 && indices_dims == 1) { - int w = bottom_blob.w; - int h = bottom_blob.h; - int channels = bottom_blob.c; - top_blob.create(w, h, indices.w, elemsize, opt.blob_allocator); + // Gather only support 1-dim of indices, because the data and indices all has + // implicit batch in ncnn, this will lead to wrong shape to match onnx result. + // When indices dim equals to 1, after eliminating implicit batch, the indices + // dim still be 1. So there is only 1 implicit batch in data, this will make + // the shape match onnx result. + int Gather::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const + { + const Mat& bottom_blob = bottom_blobs[0]; + const Mat& indices = bottom_blobs[1]; + int dims = bottom_blob.dims; + int indices_dims = indices.dims; + size_t elemsize = bottom_blob.elemsize; + int positive_axis = axis < 0 ? dims + axis : axis; + Mat& top_blob = top_blobs[0]; + assert(indices.dims == 1); + const float* indices_ptr = indices; + + if (dims == 1 && indices_dims == 1) // positive_axis == 0 + { + int w = indices.w; + top_blob.create(w, elemsize, opt.blob_allocator); + if (top_blob.empty()) + { + return -100; + } + const float* ptr = bottom_blob; + float* outptr = top_blob; + for (int i = 0; i < w; i++) + { + float indice = indices_ptr[i]; + outptr[i] = ptr[(int)(indice + 0.5)]; + } + + return 0; + } - if (top_blob.empty()) { - return -100; - } - for (int i = 0; i < indices.w; i++) { - int selected = (int)(indices_ptr[i] + 0.5); - const unsigned char *ptr = bottom_blob.channel(selected); - unsigned char *outptr = top_blob.channel(i); + if (dims == 2 && positive_axis == 0 && indices_dims == 1) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + top_blob.create(w, indices.w, elemsize, opt.blob_allocator); + // w -> w + // h -> indices.w + // h * w -> indices.w * w + if (top_blob.empty()) + { + return -100; + } + const float* ptr = bottom_blob; + float* outptr = top_blob; + for (int i = 0; i < indices.w; i++) + { + const int selected = (int)(indices_ptr[i] + 0.5); + memcpy(top_blob.row(i), bottom_blob.row(selected), w * elemsize); + } + + return 0; + } - memcpy(outptr, ptr, w * h * elemsize); - } - return 0; - } - - if (dims == 3 && positive_axis == 1 && indices_dims == 1) { - int w = bottom_blob.w; - int h = bottom_blob.h; - int channels = bottom_blob.c; - top_blob.create(w, indices.w, channels, elemsize, opt.blob_allocator); -#pragma omp parallel for num_threads(opt.num_threads) - // use parallel programming - for (int i = 0; i < channels; i++) { - float *outptr = top_blob.channel(i); - const float *ptr = bottom_blob.channel(i); - for (int j = 0; j < indices.w; j++) { - int selected = (int)(indices_ptr[j] + 0.5); - for (int k = 0; k < w; k++) { - outptr[j * w + k] = ptr[selected * w + k]; + if (dims == 2 && positive_axis == 1 && indices_dims == 1) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + top_blob.create(indices.w, h, elemsize, opt.blob_allocator); + // w -> h + // h -> indices.w + // h * w -> indices.w * h + if (top_blob.empty()) + { + return -100; + } + const float* ptr = bottom_blob; + float* outptr = top_blob; + for (int j = 0; j < h; j++) + { + for (int i = 0; i < indices.w; i++) + { + int selected = (int)(indices_ptr[i] + 0.5); + outptr[j * indices.w + i] = ptr[j * w + selected]; + } + } + return 0; } - } - } - return 0; - } + if (dims == 3 && positive_axis == 0 && indices_dims == 1) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + top_blob.create(w, h, indices.w, elemsize, 
opt.blob_allocator); + + if (top_blob.empty()) + { + return -100; + } + for (int i = 0; i < indices.w; i++) + { + int selected = (int)(indices_ptr[i] + 0.5); + const unsigned char* ptr = bottom_blob.channel(selected); + unsigned char* outptr = top_blob.channel(i); + + memcpy(outptr, ptr, w * h * elemsize); + } + return 0; + } - if (dims == 3 && positive_axis == 2 && indices_dims == 1) { - int w = bottom_blob.w; - int h = bottom_blob.h; - int channels = bottom_blob.c; - top_blob.create(indices.w, h, channels, elemsize, opt.blob_allocator); + if (dims == 3 && positive_axis == 1 && indices_dims == 1) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + top_blob.create(w, indices.w, channels, elemsize, opt.blob_allocator); #pragma omp parallel for num_threads(opt.num_threads) - // use parallel programming - for (int i = 0; i < channels; i++) { - float *outptr = top_blob.channel(i); - const float *ptr = bottom_blob.channel(i); - for (int j = 0; j < h; j++) { - for (int k = 0; k < indices.w; k++) { - int selected = (int)(indices_ptr[k] + 0.5); - outptr[j * indices.w + k] = ptr[j * w + selected]; + // use parallel programming + for (int i = 0; i < channels; i++) + { + float* outptr = top_blob.channel(i); + const float* ptr = bottom_blob.channel(i); + for (int j = 0; j < indices.w; j++) + { + int selected = (int)(indices_ptr[j] + 0.5); + for (int k = 0; k < w; k++) + { + outptr[j * w + k] = ptr[selected * w + k]; + } + } + } + + return 0; } - } - } - return 0; - } - return 0; -} + if (dims == 3 && positive_axis == 2 && indices_dims == 1) + { + int w = bottom_blob.w; + int h = bottom_blob.h; + int channels = bottom_blob.c; + top_blob.create(indices.w, h, channels, elemsize, opt.blob_allocator); +#pragma omp parallel for num_threads(opt.num_threads) + // use parallel programming + for (int i = 0; i < channels; i++) + { + float* outptr = top_blob.channel(i); + const float* ptr = bottom_blob.channel(i); + for (int j = 0; j < h; j++) + { + for (int k = 0; k < indices.w; k++) + { + int selected = (int)(indices_ptr[k] + 0.5); + outptr[j * indices.w + k] = ptr[j * w + selected]; + } + } + } + return 0; + } + + return 0; + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.h b/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.h old mode 100755 new mode 100644 index af6eb6365e..13d38e4bd0 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/gather/gather.h @@ -4,20 +4,21 @@ #include "layer.h" -namespace mmdeploy { +namespace mmdeploy +{ -class Gather : public ncnn::Layer { - public: - Gather(); + class Gather : public ncnn::Layer + { + public: + Gather(); - virtual int load_param(const ncnn::ParamDict& pd); + virtual int load_param(const ncnn::ParamDict& pd); - virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, - const ncnn::Option& opt) const; + virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const ncnn::Option& opt) const; - public: - int axis; -}; + public: + int axis; + }; } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_definer.h b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_definer.h old mode 100755 new mode 100644 index 509c8c0ce0..bd5d9ca23e --- a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_definer.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_definer.h @@ -7,22 +7,24 @@ #include "layer.h" #include "ncnn_ops_register.h" -namespace mmdeploy { - -class NCNNOpsDefiner { - public: - 
NCNNOpsDefiner(const std::string& ops_name, const ncnn::layer_creator_func& creator_func = 0,
-                 const ncnn::layer_destroyer_func& destroyer_func = 0)
-      : _ops_name(ops_name) {
-    get_mmdeploy_layer_creator()[_ops_name.c_str()] = creator_func;
-  }
-
- private:
-  const std::string _ops_name;
-};
+namespace mmdeploy
+{
+
+    class NCNNOpsDefiner
+    {
+      public:
+        NCNNOpsDefiner(const std::string& ops_name, const ncnn::layer_creator_func& creator_func = 0, const ncnn::layer_destroyer_func& destroyer_func = 0)
+            : _ops_name(ops_name)
+        {
+            get_mmdeploy_layer_creator()[_ops_name.c_str()] = creator_func;
+        }
+
+      private:
+        const std::string _ops_name;
+    };

 #define DEFINE_NCNN_OPS(ops_name, OpsLayer) \
-  static mmdeploy::NCNNOpsDefiner NCNNOpsDefiner##ops_name{#ops_name, OpsLayer##_layer_creator};
+    static mmdeploy::NCNNOpsDefiner NCNNOpsDefiner##ops_name{#ops_name, OpsLayer##_layer_creator};

 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.cpp
old mode 100755
new mode 100644
index 42bc050a1c..85d4f66d04
--- a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.cpp
+++ b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.cpp
@@ -3,32 +3,38 @@
 #include <map>

-std::map<const char*, ncnn::layer_creator_func>& get_mmdeploy_layer_creator() {
-  static std::map<const char*, ncnn::layer_creator_func> _layer_creator_map;
-  return _layer_creator_map;
+std::map<const char*, ncnn::layer_creator_func>& get_mmdeploy_layer_creator()
+{
+    static std::map<const char*, ncnn::layer_creator_func> _layer_creator_map;
+    return _layer_creator_map;
 }

-std::map<const char*, ncnn::layer_destroyer_func>& get_mmdeploy_layer_destroyer() {
-  static std::map<const char*, ncnn::layer_destroyer_func> _layer_destroyer_map;
-  return _layer_destroyer_map;
+std::map<const char*, ncnn::layer_destroyer_func>& get_mmdeploy_layer_destroyer()
+{
+    static std::map<const char*, ncnn::layer_destroyer_func> _layer_destroyer_map;
+    return _layer_destroyer_map;
 }

-int register_mmdeploy_custom_layers(ncnn::Net &net) {
-  auto &layer_creator_map = get_mmdeploy_layer_creator();
-  auto &layer_destroyer_map = get_mmdeploy_layer_destroyer();
+int register_mmdeploy_custom_layers(ncnn::Net& net)
+{
+    auto& layer_creator_map = get_mmdeploy_layer_creator();
+    auto& layer_destroyer_map = get_mmdeploy_layer_destroyer();

-  for (auto const &creator_pair : layer_creator_map) {
-    auto creator_name = creator_pair.first;
-    auto creator_func = creator_pair.second;
+    for (auto const& creator_pair : layer_creator_map)
+    {
+        auto creator_name = creator_pair.first;
+        auto creator_func = creator_pair.second;

-    ncnn::layer_destroyer_func destroyer_func = 0;
-    if (layer_destroyer_map.find(creator_name) != layer_destroyer_map.end()) {
-      destroyer_func = layer_destroyer_map[creator_name];
+        ncnn::layer_destroyer_func destroyer_func = 0;
+        if (layer_destroyer_map.find(creator_name) != layer_destroyer_map.end())
+        {
+            destroyer_func = layer_destroyer_map[creator_name];
+        }
+        int ret = net.register_custom_layer(creator_name, creator_func, destroyer_func);
+        if (0 != ret)
+        {
+            return ret;
+        }
     }
-    int ret = net.register_custom_layer(creator_name, creator_func, destroyer_func);
-    if (0 != ret) {
-      return ret;
-    }
-  }
-  return 0;
+    return 0;
 }
diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.h b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.h
old mode 100755
new mode 100644
index 0d9974f783..b0de664040
--- a/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.h
+++ b/csrc/mmdeploy/backend_ops/ncnn/ops/ncnn_ops_register.h
@@ -11,6 +11,6 @@
 MMDEPLOY_API std::map<const char*, ncnn::layer_creator_func>& get_mmdeploy_layer_creator();
 MMDEPLOY_API std::map<const char*, ncnn::layer_destroyer_func>& get_mmdeploy_layer_destroyer();

-MMDEPLOY_API int register_mmdeploy_custom_layers(ncnn::Net& net);
+MMDEPLOY_API int register_mmdeploy_custom_layers(ncnn::Net& net);
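The `DEFINE_NCNN_OPS` macro above performs static registration: each op's translation unit drops a definer object into the creator map at load time, and `register_mmdeploy_custom_layers` replays that map onto a net. A minimal caller-side sketch, assuming the standard ncnn `Net` API; the model file names are placeholders:

    #include "ncnn_ops_register.h"
    #include "net.h"

    ncnn::Net net;
    // Custom mmdeploy layers (Gather, Shape, TensorSlice, TopK, ...) must be
    // registered before parsing a param file that references them.
    register_mmdeploy_custom_layers(net);
    net.load_param("end2end.param");
    net.load_model("end2end.bin");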
#endif diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.cpp old mode 100755 new mode 100644 index f538eabbac..17ae195659 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.cpp @@ -3,45 +3,56 @@ #include "../ncnn_ops_definer.h" -namespace mmdeploy { -using namespace ncnn; -DEFINE_LAYER_CREATOR(Shape) -DEFINE_NCNN_OPS(Shape, Shape) -Shape::Shape() { - one_blob_only = true; - support_inplace = false; -} +namespace mmdeploy +{ + using namespace ncnn; + DEFINE_LAYER_CREATOR(Shape) + DEFINE_NCNN_OPS(Shape, Shape) + Shape::Shape() + { + one_blob_only = true; + support_inplace = false; + } -int Shape::forward(const Mat &bottom_blob, Mat &top_blob, const Option &opt) const { - int dims = bottom_blob.dims; - int w = bottom_blob.w; - size_t elemsize = sizeof(float); - top_blob.create(dims + 1, elemsize, opt.blob_allocator); - if (top_blob.empty()) { - return -100; - } - float *outptr = top_blob; + int Shape::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const + { + int dims = bottom_blob.dims; + int w = bottom_blob.w; + size_t elemsize = sizeof(float); + top_blob.create(dims + 1, elemsize, opt.blob_allocator); + if (top_blob.empty()) + { + return -100; + } + float* outptr = top_blob; - if (dims == 1) { - outptr[0] = 1.0f; - outptr[1] = w; - } else if (dims == 2) { - int h = bottom_blob.h; - outptr[0] = 1.0f; - outptr[1] = h; - outptr[2] = w; - } else if (dims == 3) { - int h = bottom_blob.h; - int channels = bottom_blob.c; - outptr[0] = 1.0f; - outptr[1] = channels; - outptr[2] = h; - outptr[3] = w; - } else { - fprintf(stdout, "Unsupported dims=%d\n", dims); - } + if (dims == 1) + { + outptr[0] = 1.0f; + outptr[1] = w; + } + else if (dims == 2) + { + int h = bottom_blob.h; + outptr[0] = 1.0f; + outptr[1] = h; + outptr[2] = w; + } + else if (dims == 3) + { + int h = bottom_blob.h; + int channels = bottom_blob.c; + outptr[0] = 1.0f; + outptr[1] = channels; + outptr[2] = h; + outptr[3] = w; + } + else + { + fprintf(stdout, "Unsupported dims=%d\n", dims); + } - return 0; -} + return 0; + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.h b/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.h old mode 100755 new mode 100644 index 863dc77c1d..2330f57ba4 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.h +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/shape/shape.h @@ -4,15 +4,16 @@ #include "layer.h" -namespace mmdeploy { +namespace mmdeploy +{ -class Shape : public ncnn::Layer { - public: - Shape(); + class Shape : public ncnn::Layer + { + public: + Shape(); - virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, - const ncnn::Option& opt) const; -}; + virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, const ncnn::Option& opt) const; + }; } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.cpp index 9f2ced1992..b77c9ce56f 100644 --- a/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.cpp +++ b/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.cpp @@ -5,202 +5,250 @@ #include "../ncnn_ops_definer.h" -namespace mmdeploy { -using namespace ncnn; -DEFINE_LAYER_CREATOR(TensorSlice) -DEFINE_NCNN_OPS(TensorSlice, TensorSlice) -TensorSlice::TensorSlice() { - one_blob_only = true; - support_inplace = false; -} - -int TensorSlice::load_param(const ParamDict& pd) { - starts = 
pd.get(0, Mat()); - ends = pd.get(1, Mat()); - axes = pd.get(2, Mat()); - steps = pd.get(3, Mat()); - if (axes.w == 0) { - axes.create(starts.w); - int* axes_ptr = axes; - for (int i = 0; i < starts.w; i++) { - axes_ptr[i] = i; +namespace mmdeploy +{ + using namespace ncnn; + DEFINE_LAYER_CREATOR(TensorSlice) + DEFINE_NCNN_OPS(TensorSlice, TensorSlice) + TensorSlice::TensorSlice() + { + one_blob_only = true; + support_inplace = false; } - } - if (steps.w == 0) { - steps.create(axes.w); - steps.fill(1); - } - return 0; -} - -static inline int get_shape_by_axes(const Mat& blob, int axes, int dims) { - switch (dims - axes) { - case 0: - return blob.w; - case 1: - return blob.h; - case 2: - return blob.c; - default: - fprintf(stderr, "wrong axes %d!\n", axes); - return -1; - } - return 0; -} -int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { - int dims = bottom_blob.dims; - size_t elemsize = bottom_blob.elemsize; - const int* start_ptr = starts; - const int* end_ptr = ends; - const int* axes_ptr = axes; - const int* step_ptr = steps; - if (starts.w > dims || ends.w > dims) { - fprintf(stderr, "start/end attributes shape error!\n"); - return -100; - } - if (axes.w != 1) { - fprintf(stderr, - "axes.w must be 1 because any of multiaxes slice is regarded as " - "multi-staged onnx slice in pytorch2onnx."); - } - if (dims == 1) { - for (int i = 0; i < axes.w; i++) { - int positive_axis = axes_ptr[i] < 0 ? dims + axes_ptr[i] : axes_ptr[i]; - int step = step_ptr[i]; - std::vector temp_val; - int start = start_ptr[i]; - int end = end_ptr[i]; - int cur = start; - if (step > 0) { - while (cur < end && cur < bottom_blob.w) { - temp_val.push_back(bottom_blob[cur]); - cur += step; + int TensorSlice::load_param(const ParamDict& pd) + { + starts = pd.get(0, Mat()); + ends = pd.get(1, Mat()); + axes = pd.get(2, Mat()); + steps = pd.get(3, Mat()); + if (axes.w == 0) + { + axes.create(starts.w); + int* axes_ptr = axes; + for (int i = 0; i < starts.w; i++) + { + axes_ptr[i] = i; + } } - } else if (step < 0) { - while (cur > end && cur > 0) { - temp_val.push_back(bottom_blob[cur]); - cur += step; + if (steps.w == 0) + { + steps.create(axes.w); + steps.fill(1); } - } else { - fprintf(stderr, "step should not be 0!\n"); - return -100; - } - top_blob.create(temp_val.size(), elemsize, opt.blob_allocator); - for (int i = 0; i < temp_val.size(); i++) { - top_blob[i] = temp_val[i]; - } - } - return 0; - } - if (dims == 2) { - std::vector > active_indice; - std::vector indices; - for (int i = 0; i < bottom_blob.h; i++) { - indices.push_back(i); + return 0; } - active_indice.push_back(indices); - indices.clear(); - for (int i = 0; i < bottom_blob.w; i++) { - indices.push_back(i); - } - active_indice.push_back(indices); - for (int i = 0; i < axes.w; i++) { - int positive_axis = axes_ptr[i] < 0 ? dims + axes_ptr[i] : axes_ptr[i]; - int step = step_ptr[i]; - int start = start_ptr[i]; - int end = end_ptr[i]; - int dim_shape = get_shape_by_axes(bottom_blob, positive_axis, dims); - int dim_shape_test = get_shape_by_axes(bottom_blob, positive_axis, dims - 1); - if (dim_shape < 0) { - return -1; - } - end = end < dim_shape ? 
end : dim_shape; - int cur = start; - std::vector temp_indice; - if (step > 0) { - while (cur < end && cur < dim_shape) { - temp_indice.push_back(cur); - cur += step; - } - } else if (step < 0) { - while (cur > end && cur > 0) { - temp_indice.push_back(cur); - cur += step; - } - } else { - fprintf(stderr, "step should not be 0!\n"); - return -100; - } - active_indice[positive_axis - 1] = temp_indice; - active_indice[positive_axis - 1].resize(temp_indice.size()); - } - top_blob.create((int)active_indice[1].size(), (int)active_indice[0].size(), elemsize, - opt.blob_allocator); - for (int i = 0; i < active_indice[0].size(); i++) { - for (int j = 0; j < active_indice[1].size(); j++) { - top_blob.row(i)[j] = bottom_blob.row(active_indice[0][i])[active_indice[1][j]]; - } - } - return 0; - } - if (dims == 3) { - std::vector > active_indice; - std::vector indices; - for (int i = 0; i < bottom_blob.c; i++) { - indices.push_back(i); - } - active_indice.push_back(indices); - indices.clear(); - for (int i = 0; i < bottom_blob.h; i++) { - indices.push_back(i); - } - active_indice.push_back(indices); - indices.clear(); - for (int i = 0; i < bottom_blob.w; i++) { - indices.push_back(i); + static inline int get_shape_by_axes(const Mat& blob, int axes, int dims) + { + switch (dims - axes) + { + case 0: + return blob.w; + case 1: + return blob.h; + case 2: + return blob.c; + default: + fprintf(stderr, "wrong axes %d!\n", axes); + return -1; + } + return 0; } - active_indice.push_back(indices); - for (int i = 0; i < axes.w; i++) { - int positive_axis = axes_ptr[i] < 0 ? dims + axes_ptr[i] : axes_ptr[i]; - int step = step_ptr[i]; - int start = start_ptr[i]; - int end = end_ptr[i]; - int cur = start; - std::vector temp_indice; - if (step > 0) { - while (cur < end && cur < bottom_blob.w) { - temp_indice.push_back(cur); - cur += step; + int TensorSlice::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const + { + int dims = bottom_blob.dims; + size_t elemsize = bottom_blob.elemsize; + const int* start_ptr = starts; + const int* end_ptr = ends; + const int* axes_ptr = axes; + const int* step_ptr = steps; + if (starts.w > dims || ends.w > dims) + { + fprintf(stderr, "start/end attributes shape error!\n"); + return -100; } - } else if (step < 0) { - while (cur > end && cur > 0) { - temp_indice.push_back(cur); - cur += step; + if (axes.w != 1) + { + fprintf(stderr, + "axes.w must be 1 because any of multiaxes slice is regarded as " + "multi-staged onnx slice in pytorch2onnx."); } - } else { - fprintf(stderr, "step should not be 0!\n"); - return -100; - } - active_indice[positive_axis - 1] = temp_indice; - active_indice[positive_axis - 1].resize(temp_indice.size()); - } - top_blob.create((int)active_indice[2].size(), (int)active_indice[1].size(), - (int)active_indice[0].size(), elemsize, opt.blob_allocator); - for (int i = 0; i < active_indice[0].size(); i++) { - for (int j = 0; j < active_indice[1].size(); j++) { - for (int k = 0; k < active_indice[2].size(); k++) { - top_blob.channel(i).row(j)[k] = bottom_blob.channel(active_indice[0][i]) - .row(active_indice[1][j])[active_indice[2][k]]; + if (dims == 1) + { + for (int i = 0; i < axes.w; i++) + { + int positive_axis = axes_ptr[i] < 0 ? 
+                int step = step_ptr[i];
+                std::vector<float> temp_val;
+                int start = start_ptr[i];
+                int end = end_ptr[i];
+                int cur = start;
+                if (step > 0)
+                {
+                    while (cur < end && cur < bottom_blob.w)
+                    {
+                        temp_val.push_back(bottom_blob[cur]);
+                        cur += step;
+                    }
+                }
+                else if (step < 0)
+                {
+                    while (cur > end && cur > 0)
+                    {
+                        temp_val.push_back(bottom_blob[cur]);
+                        cur += step;
+                    }
+                }
+                else
+                {
+                    fprintf(stderr, "step should not be 0!\n");
+                    return -100;
+                }
+                top_blob.create(temp_val.size(), elemsize, opt.blob_allocator);
+                for (int i = 0; i < temp_val.size(); i++)
+                {
+                    top_blob[i] = temp_val[i];
+                }
+            }
+            return 0;
+        }
+        if (dims == 2)
+        {
+            std::vector<std::vector<int>> active_indice;
+            std::vector<int> indices;
+            for (int i = 0; i < bottom_blob.h; i++)
+            {
+                indices.push_back(i);
+            }
+            active_indice.push_back(indices);
+            indices.clear();
+            for (int i = 0; i < bottom_blob.w; i++)
+            {
+                indices.push_back(i);
+            }
+            active_indice.push_back(indices);
+            for (int i = 0; i < axes.w; i++)
+            {
+                int positive_axis = axes_ptr[i] < 0 ? dims + axes_ptr[i] : axes_ptr[i];
+                int step = step_ptr[i];
+                int start = start_ptr[i];
+                int end = end_ptr[i];
+                int dim_shape = get_shape_by_axes(bottom_blob, positive_axis, dims);
+                int dim_shape_test = get_shape_by_axes(bottom_blob, positive_axis, dims - 1);
+                if (dim_shape < 0)
+                {
+                    return -1;
+                }
+                end = end < dim_shape ? end : dim_shape;
+                int cur = start;
+                std::vector<int> temp_indice;
+                if (step > 0)
+                {
+                    while (cur < end && cur < dim_shape)
+                    {
+                        temp_indice.push_back(cur);
+                        cur += step;
+                    }
+                }
+                else if (step < 0)
+                {
+                    while (cur > end && cur > 0)
+                    {
+                        temp_indice.push_back(cur);
+                        cur += step;
+                    }
+                }
+                else
+                {
+                    fprintf(stderr, "step should not be 0!\n");
+                    return -100;
+                }
+                active_indice[positive_axis - 1] = temp_indice;
+                active_indice[positive_axis - 1].resize(temp_indice.size());
+            }
+            top_blob.create((int)active_indice[1].size(), (int)active_indice[0].size(), elemsize, opt.blob_allocator);
+            for (int i = 0; i < active_indice[0].size(); i++)
+            {
+                for (int j = 0; j < active_indice[1].size(); j++)
+                {
+                    top_blob.row(i)[j] = bottom_blob.row(active_indice[0][i])[active_indice[1][j]];
+                }
+            }
+            return 0;
        }
+        if (dims == 3)
+        {
+            std::vector<std::vector<int>> active_indice;
+            std::vector<int> indices;
+            for (int i = 0; i < bottom_blob.c; i++)
+            {
+                indices.push_back(i);
+            }
+            active_indice.push_back(indices);
+            indices.clear();
+            for (int i = 0; i < bottom_blob.h; i++)
+            {
+                indices.push_back(i);
+            }
+            active_indice.push_back(indices);
+            indices.clear();
+            for (int i = 0; i < bottom_blob.w; i++)
+            {
+                indices.push_back(i);
+            }
+            active_indice.push_back(indices);
+            for (int i = 0; i < axes.w; i++)
+            {
+                int positive_axis = axes_ptr[i] < 0 ? dims + axes_ptr[i] : axes_ptr[i];
+                int step = step_ptr[i];
+
+                int start = start_ptr[i];
+                int end = end_ptr[i];
+                int cur = start;
+                std::vector<int> temp_indice;
+                if (step > 0)
+                {
+                    while (cur < end && cur < bottom_blob.w)
+                    {
+                        temp_indice.push_back(cur);
+                        cur += step;
+                    }
+                }
+                else if (step < 0)
+                {
+                    while (cur > end && cur > 0)
+                    {
+                        temp_indice.push_back(cur);
+                        cur += step;
+                    }
+                }
+                else
+                {
+                    fprintf(stderr, "step should not be 0!\n");
+                    return -100;
+                }
+                active_indice[positive_axis - 1] = temp_indice;
+                active_indice[positive_axis - 1].resize(temp_indice.size());
+            }
+            top_blob.create((int)active_indice[2].size(), (int)active_indice[1].size(), (int)active_indice[0].size(), elemsize, opt.blob_allocator);
+            for (int i = 0; i < active_indice[0].size(); i++)
+            {
+                for (int j = 0; j < active_indice[1].size(); j++)
+                {
+                    for (int k = 0; k < active_indice[2].size(); k++)
+                    {
+                        top_blob.channel(i).row(j)[k] = bottom_blob.channel(active_indice[0][i])
+                                                            .row(active_indice[1][j])[active_indice[2][k]];
+                    }
+                }
+            }
+            return 0;
+        }
+
+        return 0;
+    }
} // namespace mmdeploy
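A quick way to sanity-check the reformatted layer is to drive it directly through the ncnn layer API. The sketch below is illustrative only and not part of the patch: input values and the expected output are hypothetical, and it assumes ncnn's public ParamDict::set(int, const ncnn::Mat&). The param ids follow load_param above: 0 = starts, 1 = ends, 2 = axes, 3 = steps, with axes and steps defaulting to [0..n) and 1 when omitted.

// Illustrative sketch: exercising TensorSlice standalone on a 1-D blob.
#include "tensorslice.h"

#include <cstdio>

int main()
{
    // starts=[1], ends=[7]; axes/steps left empty so load_param
    // defaults them to axes=[0], steps=[1].
    ncnn::Mat starts(1), ends(1);
    ((int*)starts)[0] = 1;
    ((int*)ends)[0] = 7;

    ncnn::ParamDict pd;
    pd.set(0, starts);  // id 0: starts
    pd.set(1, ends);    // id 1: ends

    mmdeploy::TensorSlice layer;
    layer.load_param(pd);

    ncnn::Mat in(8);  // 1-D blob holding [0, 1, ..., 7]
    for (int i = 0; i < 8; i++) in[i] = (float)i;

    ncnn::Mat out;
    ncnn::Option opt;
    layer.forward(in, out, opt);

    for (int i = 0; i < out.w; i++) printf("%.0f ", out[i]);  // expected: 1 2 3 4 5 6
    printf("\n");
    return 0;
}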
diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.h b/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.h
old mode 100755
new mode 100644
index 9164d43335..14342c6f81
--- a/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.h
+++ b/csrc/mmdeploy/backend_ops/ncnn/ops/tensorslice/tensorslice.h
@@ -4,23 +4,24 @@
 #include "layer.h"
 
-namespace mmdeploy {
+namespace mmdeploy
+{
 
-class TensorSlice : public ncnn::Layer {
- public:
-  TensorSlice();
+    class TensorSlice : public ncnn::Layer
+    {
+      public:
+        TensorSlice();
 
-  virtual int load_param(const ncnn::ParamDict& pd);
+        virtual int load_param(const ncnn::ParamDict& pd);
 
-  virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob,
-                      const ncnn::Option& opt) const;
+        virtual int forward(const ncnn::Mat& bottom_blob, ncnn::Mat& top_blob, const ncnn::Option& opt) const;
 
- public:
-  ncnn::Mat starts;
-  ncnn::Mat ends;
-  ncnn::Mat axes;
-  ncnn::Mat steps;
-};
+      public:
+        ncnn::Mat starts;
+        ncnn::Mat ends;
+        ncnn::Mat axes;
+        ncnn::Mat steps;
+    };
 
} // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.cpp b/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.cpp
index f618831568..91235fa476 100644
--- a/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.cpp
+++ b/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.cpp
@@ -6,872 +6,1118 @@
 #include <functional>
 #include "../ncnn_ops_definer.h"
 
-namespace mmdeploy {
-using namespace ncnn;
-DEFINE_LAYER_CREATOR(TopK)
-DEFINE_NCNN_OPS(TopK, TopK)
-
-TopK::TopK() {
-  one_blob_only = false;
-  support_inplace = false;
-}
-int TopK::load_param(const ParamDict& pd) {
-  axis = pd.get(0, -1);
-  largest = pd.get(1, 1);
-  sorted = pd.get(2, 1);
-  keep_dims = pd.get(3, 1);
-
-  return 0;
-}
-int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs,
-                  const Option& opt) const {
-  int dims = bottom_blobs[0].dims;
-  int positive_axis = axis < 0 ? dims + axis : axis;
-  int topk;
-  if (bottom_blobs.size() == 2) {
-    const Mat& topk_blob = bottom_blobs[1];
-    topk = (int)(topk_blob[0] + 0.5);
-  } else if (bottom_blobs.size() == 1) {
-    topk = 1;
-  } else {
-    fprintf(stderr, "topk input blobs should be 1 or 2, but not %ld\n", bottom_blobs.size());
-    return -103;
-  }
-
-  // To do: Cut the top_val_blob after unit test. And we should change them in
-  // param files.
-  // Adaptive outputs. For onnx TopK, we output 2 blobs, for ArgMax, we output
-  // 1 blob.
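One idiom is worth calling out before the long forward implementation that follows: indices are stored negated, as std::make_pair(value, -i), so that sorting the pairs with std::greater orders equal values by ascending original index — the tie-break ONNX TopK specifies. A minimal standalone illustration, independent of ncnn and not part of the patch:

// Illustrative sketch: the pair-sorting trick used throughout TopK::forward.
// With std::greater, pairs compare by value first; on equal values the larger
// -index (i.e. the smaller original index) wins.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

int main()
{
    const float data[] = {3.f, 5.f, 5.f, 1.f};
    std::vector<std::pair<float, int>> vec(4);
    for (int i = 0; i < 4; i++) vec[i] = std::make_pair(data[i], -i);

    const int topk = 2;
    std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(),
                      std::greater<std::pair<float, int>>());

    // Prints "5 at 1" then "5 at 2": the tie keeps ascending index order.
    for (int i = 0; i < topk; i++)
        printf("%g at %d\n", vec[i].first, -vec[i].second);
    return 0;
}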
- Mat& top_val_blob = top_blobs[0]; - Mat& top_ind_blob = top_blobs.size() == 2 ? top_blobs[1] : top_val_blob; - - if (topk > 1) { - // real topk - if (keep_dims == 0) { - fprintf(stderr, "real topk should not reduce dims!\n"); - return -102; +namespace mmdeploy +{ + using namespace ncnn; + DEFINE_LAYER_CREATOR(TopK) + DEFINE_NCNN_OPS(TopK, TopK) + + TopK::TopK() + { + one_blob_only = false; + support_inplace = false; } - if (dims == 1 && positive_axis == 0) { - if (topk > bottom_blobs[0].w) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - const float* ptr = bottom_blobs[0]; - std::vector > vec; - vec.resize(bottom_blobs[0].w); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(ptr[i], -i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(ptr[i], i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - float* valptr = top_val_blob; - float* indptr = top_ind_blob; - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - valptr[i] = vec[i].first; - indptr[i] = abs(vec[i].second); - } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - - // pair comparison - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0][i] > valtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } else if (bottom_blobs[0][i] == valtarget && i <= indtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0][i] < valtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } else if (bottom_blobs[0][i] == valtarget && i <= indtarget) { - valptr[cur] = bottom_blobs[0][i]; - indptr[cur] = i; - cur++; - } - } - } - } + int TopK::load_param(const ParamDict& pd) + { + axis = pd.get(0, -1); + largest = pd.get(1, 1); + sorted = pd.get(2, 1); + keep_dims = pd.get(3, 1); + + return 0; } - if (dims == 2 && positive_axis == 0) { - if (topk > bottom_blobs[0].h) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - for (int col = 0; col < bottom_blobs[0].w; col++) { - std::vector > vec; - vec.resize(bottom_blobs[0].h); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], -i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest 
attribute should be 0 or 1, but not %d\n", largest); - return -100; + int TopK::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const + { + int dims = bottom_blobs[0].dims; + int positive_axis = axis < 0 ? dims + axis : axis; + int topk; + if (bottom_blobs.size() == 2) + { + const Mat& topk_blob = bottom_blobs[1]; + topk = (int)(topk_blob[0] + 0.5); } - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - top_val_blob.row(i)[col] = vec[i].first; - top_ind_blob.row(i)[col] = abs(vec[i].second); - } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].row(i)[col] > valtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } else if (bottom_blobs[0].row(i)[col] == valtarget && i <= indtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].h; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].row(i)[col] < valtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } else if (bottom_blobs[0].row(i)[col] == valtarget && i <= indtarget) { - top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; - top_ind_blob.row(cur)[col] = i; - cur++; - } - } - } - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); - return -100; + else if (bottom_blobs.size() == 1) + { + topk = 1; } - } - } - if (dims == 2 && positive_axis == 1) { - if (topk > bottom_blobs[0].w) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - for (int r = 0; r < bottom_blobs[0].h; r++) { - std::vector > vec; - vec.resize(bottom_blobs[0].w); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], -i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; + else + { + fprintf(stderr, "topk input blobs should be 1 or 2, but not %ld\n", bottom_blobs.size()); + return -103; } - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - top_val_blob.row(r)[i] = vec[i].first; - top_ind_blob.row(r)[i] = abs(vec[i].second); - } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].row(r)[i] > valtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - top_ind_blob.row(r)[cur] = i; - cur++; - } else if (bottom_blobs[0].row(r)[i] == valtarget && i <= indtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - top_ind_blob.row(r)[cur] 
= i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].row(r)[i] < valtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - top_ind_blob.row(r)[cur] = i; - cur++; - } else if (bottom_blobs[0].row(r)[i] == valtarget && i <= indtarget) { - top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; - top_ind_blob.row(r)[cur] = i; - cur++; - } - } - } + // To do: Cut the top_val_blob after unit test. And we should change them in + // param files. + // Adaptive outputs. For onnx TopK, we output 2 blobs, for ArgMax, we output + // 1 blob. + Mat& top_val_blob = top_blobs[0]; + Mat& top_ind_blob = top_blobs.size() == 2 ? top_blobs[1] : top_val_blob; - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); - return -100; - } - } - } - if (dims == 3 && positive_axis == 0) { - if (topk > bottom_blobs[0].c) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - for (int r = 0; r < bottom_blobs[0].h; r++) { - for (int col = 0; col < bottom_blobs[0].w; col++) { - std::vector > vec; - vec.resize(bottom_blobs[0].c); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].c; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(i).row(r)[col], -i); - } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].c; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(i).row(r)[col], i); + if (topk > 1) + { + // real topk + if (keep_dims == 0) + { + fprintf(stderr, "real topk should not reduce dims!\n"); + return -102; } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - top_val_blob.channel(i).row(r)[col] = vec[i].first; - top_ind_blob.channel(i).row(r)[col] = abs(vec[i].second); + if (dims == 1 && positive_axis == 0) + { + if (topk > bottom_blobs[0].w) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + const float* ptr = bottom_blobs[0]; + std::vector> vec; + vec.resize(bottom_blobs[0].w); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(ptr[i], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(ptr[i], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + float* valptr = top_val_blob; + float* indptr = top_ind_blob; + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + valptr[i] = vec[i].first; + indptr[i] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float valtarget = vec[topk - 
1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + + // pair comparison + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0][i] > valtarget) + { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; + cur++; + } + else if (bottom_blobs[0][i] == valtarget && i <= indtarget) + { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; + cur++; + } + } + } + else + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0][i] < valtarget) + { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; + cur++; + } + else if (bottom_blobs[0][i] == valtarget && i <= indtarget) + { + valptr[cur] = bottom_blobs[0][i]; + indptr[cur] = i; + cur++; + } + } + } + } } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].c; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].channel(i).row(r)[col] > valtarget) { - top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; - cur++; - } else if (bottom_blobs[0].channel(i).row(r)[col] == valtarget && i <= indtarget) { - top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].c; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].channel(i).row(r)[col] < valtarget) { - top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; - cur++; - } else if (bottom_blobs[0].channel(i).row(r)[col] == valtarget && i <= indtarget) { - top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; - top_ind_blob.channel(cur).row(r)[col] = i; - cur++; - } - } + if (dims == 2 && positive_axis == 0) + { + if (topk > bottom_blobs[0].h) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int col = 0; col < bottom_blobs[0].w; col++) + { + std::vector> vec; + vec.resize(bottom_blobs[0].h); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].row(i)[col], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + top_val_blob.row(i)[col] = vec[i].first; + top_ind_blob.row(i)[col] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].row(i)[col] > valtarget) + { + top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; + 
top_ind_blob.row(cur)[col] = i; + cur++; + } + else if (bottom_blobs[0].row(i)[col] == valtarget && i <= indtarget) + { + top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; + top_ind_blob.row(cur)[col] = i; + cur++; + } + } + } + else + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].row(i)[col] < valtarget) + { + top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; + top_ind_blob.row(cur)[col] = i; + cur++; + } + else if (bottom_blobs[0].row(i)[col] == valtarget && i <= indtarget) + { + top_val_blob.row(cur)[col] = bottom_blobs[0].row(i)[col]; + top_ind_blob.row(cur)[col] = i; + cur++; + } + } + } + } + else + { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); + return -100; + } + } } + if (dims == 2 && positive_axis == 1) + { + if (topk > bottom_blobs[0].w) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); - return -100; - } - } - } - } - if (dims == 3 && positive_axis == 1) { - if (topk > bottom_blobs[0].h) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - top_ind_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - for (int page = 0; page < bottom_blobs[0].c; page++) { - for (int col = 0; col < bottom_blobs[0].w; col++) { - std::vector > vec; - vec.resize(bottom_blobs[0].h); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(i)[col], -i); + top_ind_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int r = 0; r < bottom_blobs[0].h; r++) + { + std::vector> vec; + vec.resize(bottom_blobs[0].w); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].row(r)[i], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + top_val_blob.row(r)[i] = vec[i].first; + top_ind_blob.row(r)[i] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].row(r)[i] > valtarget) + { + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; + cur++; + } + else if (bottom_blobs[0].row(r)[i] == valtarget && i <= indtarget) + { + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; + cur++; + } + } + } + else + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].row(r)[i] < 
valtarget) + { + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; + cur++; + } + else if (bottom_blobs[0].row(r)[i] == valtarget && i <= indtarget) + { + top_val_blob.row(r)[cur] = bottom_blobs[0].row(r)[i]; + top_ind_blob.row(r)[cur] = i; + cur++; + } + } + } + } + else + { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); + return -100; + } + } } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(i)[col], i); + if (dims == 3 && positive_axis == 0) + { + if (topk > bottom_blobs[0].c) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int r = 0; r < bottom_blobs[0].h; r++) + { + for (int col = 0; col < bottom_blobs[0].w; col++) + { + std::vector> vec; + vec.resize(bottom_blobs[0].c); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].c; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(i).row(r)[col], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].c; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(i).row(r)[col], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + top_val_blob.channel(i).row(r)[col] = vec[i].first; + top_ind_blob.channel(i).row(r)[col] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].c; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].channel(i).row(r)[col] > valtarget) + { + top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } + else if (bottom_blobs[0].channel(i).row(r)[col] == valtarget && i <= indtarget) + { + top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } + } + } + else + { + for (int i = 0; i < bottom_blobs[0].c; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].channel(i).row(r)[col] < valtarget) + { + top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } + else if (bottom_blobs[0].channel(i).row(r)[col] == valtarget && i <= indtarget) + { + top_val_blob.channel(cur).row(r)[col] = bottom_blobs[0].channel(i).row(r)[col]; + top_ind_blob.channel(cur).row(r)[col] = i; + cur++; + } + } + } + } + else + { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); + return -100; + } + } + } } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - - if (sorted == 1) { - for 
(int i = 0; i < topk; i++) { - top_val_blob.channel(page).row(i)[col] = vec[i].first; - top_ind_blob.channel(page).row(i)[col] = abs(vec[i].second); + if (dims == 3 && positive_axis == 1) + { + if (topk > bottom_blobs[0].h) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int page = 0; page < bottom_blobs[0].c; page++) + { + for (int col = 0; col < bottom_blobs[0].w; col++) + { + std::vector> vec; + vec.resize(bottom_blobs[0].h); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(i)[col], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(i)[col], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + top_val_blob.channel(page).row(i)[col] = vec[i].first; + top_ind_blob.channel(page).row(i)[col] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + for (int i = 0; i < bottom_blobs[0].h; i++) + { + if (cur >= topk) break; + if (largest == 1) + { + if (bottom_blobs[0].channel(page).row(i)[col] > valtarget) + { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } + else if (bottom_blobs[0].channel(page).row(i)[col] == valtarget && + i <= indtarget) + { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } + } + else + { + if (bottom_blobs[0].channel(page).row(i)[col] < valtarget) + { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } + else if (bottom_blobs[0].channel(page).row(i)[col] == valtarget && + i <= indtarget) + { + top_val_blob.channel(page).row(cur)[col] = + bottom_blobs[0].channel(page).row(i)[col]; + top_ind_blob.channel(page).row(cur)[col] = i; + cur++; + } + } + } + } + else + { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); + return -100; + } + } + } } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - for (int i = 0; i < bottom_blobs[0].h; i++) { - if (cur >= topk) break; - if (largest == 1) { - if (bottom_blobs[0].channel(page).row(i)[col] > valtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; - cur++; - } else if (bottom_blobs[0].channel(page).row(i)[col] == valtarget && - i <= indtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; - cur++; - } - } else { - if 
(bottom_blobs[0].channel(page).row(i)[col] < valtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; - cur++; - } else if (bottom_blobs[0].channel(page).row(i)[col] == valtarget && - i <= indtarget) { - top_val_blob.channel(page).row(cur)[col] = - bottom_blobs[0].channel(page).row(i)[col]; - top_ind_blob.channel(page).row(cur)[col] = i; - cur++; - } - } + if (dims == 3 && positive_axis == 2) + { + if (topk > bottom_blobs[0].w) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + top_ind_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + + for (int page = 0; page < bottom_blobs[0].c; page++) + { + for (int r = 0; r < bottom_blobs[0].h; r++) + { + std::vector> vec; + vec.resize(bottom_blobs[0].w); + + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(r)[i], -i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::greater>()); + } + else if (largest == 0) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(r)[i], i); + } + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), std::less>()); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + + if (sorted == 1) + { + for (int i = 0; i < topk; i++) + { + top_val_blob.channel(page).row(r)[i] = vec[i].first; + top_ind_blob.channel(page).row(r)[i] = abs(vec[i].second); + } + } + else if (sorted == 0) + { + int cur = 0; + float valtarget = vec[topk - 1].first; + int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); + if (largest == 1) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].channel(page).row(r)[i] > valtarget) + { + top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } + else if (bottom_blobs[0].channel(page).row(r)[i] == valtarget && i <= indtarget) + { + top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } + } + } + else + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + if (cur >= topk) break; + if (bottom_blobs[0].channel(page).row(r)[i] < valtarget) + { + top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } + else if (bottom_blobs[0].channel(page).row(r)[i] == valtarget && i <= indtarget) + { + top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; + top_ind_blob.channel(page).row(r)[cur] = i; + cur++; + } + } + } + } + else + { + fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); + return -100; + } + } + } } - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); - return -100; - } } - } - } - if (dims == 3 && positive_axis == 2) { - if (topk > bottom_blobs[0].w) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_val_blob.empty()) 
return -100; - - top_ind_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - - for (int page = 0; page < bottom_blobs[0].c; page++) { - for (int r = 0; r < bottom_blobs[0].h; r++) { - std::vector > vec; - vec.resize(bottom_blobs[0].w); - - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(r)[i], -i); + else + { + if (topk <= 0) + { + fprintf(stderr, "topk should not <= 0!\n"); + return -102; } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::greater >()); - } else if (largest == 0) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = std::make_pair(bottom_blobs[0].channel(page).row(r)[i], i); + if (dims == 1 && positive_axis == 0) + { + if (topk > bottom_blobs[0].w) + { + fprintf(stderr, "topk should not greater than total items!\n"); + return -100; + } + top_val_blob.create(topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + if (top_blobs.size() == 2) + { + top_ind_blob.create(topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].w); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = ptr[i]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[0] = *index_iter; + if (top_blobs.size() == 2) + indptr[0] = std::distance(vec.begin(), index_iter); + else + valptr[0] = std::distance(vec.begin(), index_iter); // replace with index + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[0] = *index_iter; + if (top_blobs.size() == 2) + indptr[0] = std::distance(vec.begin(), index_iter); + else + valptr[0] = std::distance(vec.begin(), index_iter); // replace with index + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } } - std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), - std::less >()); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - - if (sorted == 1) { - for (int i = 0; i < topk; i++) { - top_val_blob.channel(page).row(r)[i] = vec[i].first; - top_ind_blob.channel(page).row(r)[i] = abs(vec[i].second); + if (dims == 2 && positive_axis == 0) + { + if (keep_dims == 1) + { + top_val_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + else + { + top_val_blob.create(bottom_blobs[0].w, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].w, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].h); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + for (int col = 0; col < bottom_blobs[0].w; col++) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = ptr[i * bottom_blobs[0].w + col]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[col] = *index_iter; + 
if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[col] = *index_iter; + if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + } } - } else if (sorted == 0) { - int cur = 0; - float valtarget = vec[topk - 1].first; - int indtarget = (int)(abs(vec[topk - 1].second) + 0.5); - if (largest == 1) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].channel(page).row(r)[i] > valtarget) { - top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; - } else if (bottom_blobs[0].channel(page).row(r)[i] == valtarget && i <= indtarget) { - top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; - } - } - } else { - for (int i = 0; i < bottom_blobs[0].w; i++) { - if (cur >= topk) break; - if (bottom_blobs[0].channel(page).row(r)[i] < valtarget) { - top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; - } else if (bottom_blobs[0].channel(page).row(r)[i] == valtarget && i <= indtarget) { - top_val_blob.channel(page).row(r)[cur] = bottom_blobs[0].channel(page).row(r)[i]; - top_ind_blob.channel(page).row(r)[cur] = i; - cur++; - } - } + if (dims == 2 && positive_axis == 1) + { + if (keep_dims == 1) + { + top_val_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + else + { + top_val_blob.create(bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].w); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; + + for (int r = 0; r < bottom_blobs[0].h; r++) + { + for (int i = 0; i < bottom_blobs[0].w; i++) + { + vec[i] = ptr[r * bottom_blobs[0].w + i]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[r] = *index_iter; + if (top_blobs.size() == 2) + indptr[r] = std::distance(vec.begin(), index_iter); + else + valptr[r] = std::distance(vec.begin(), index_iter); + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[r] = *index_iter; + if (top_blobs.size() == 2) + indptr[r] = std::distance(vec.begin(), index_iter); + else + valptr[r] = std::distance(vec.begin(), index_iter); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + } } + if (dims == 3 && positive_axis == 0) + { + if (keep_dims == 1) + { + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + 
top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + else + { + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } + } + const float* ptr = bottom_blobs[0]; + std::vector vec; + vec.resize(bottom_blobs[0].c); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; - } else { - fprintf(stderr, "sorted attribute should be 0 or 1, but not %d\n", sorted); - return -100; - } - } - } - } - } else { - if (topk <= 0) { - fprintf(stderr, "topk should not <= 0!\n"); - return -102; - } - if (dims == 1 && positive_axis == 0) { - if (topk > bottom_blobs[0].w) { - fprintf(stderr, "topk should not greater than total items!\n"); - return -100; - } - top_val_blob.create(topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - - if (top_blobs.size() == 2) { - top_ind_blob.create(topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } - - const float* ptr = bottom_blobs[0]; - std::vector vec; - vec.resize(bottom_blobs[0].w); - float* valptr = top_val_blob; - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob; - - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = ptr[i]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[0] = *index_iter; - if (top_blobs.size() == 2) - indptr[0] = std::distance(vec.begin(), index_iter); - else - valptr[0] = std::distance(vec.begin(), index_iter); // replace with index - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), vec.end()); - valptr[0] = *index_iter; - if (top_blobs.size() == 2) - indptr[0] = std::distance(vec.begin(), index_iter); - else - valptr[0] = std::distance(vec.begin(), index_iter); // replace with index - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - } - if (dims == 2 && positive_axis == 0) { - if (keep_dims == 1) { - top_val_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].w, topk, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } + for (int r = 0; r < bottom_blobs[0].h; r++) + { + for (int col = 0; col < bottom_blobs[0].w; col++) + { + for (int i = 0; i < bottom_blobs[0].c; i++) + { + ptr = bottom_blobs[0].channel(i); + vec[i] = ptr[r * bottom_blobs[0].w + col]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[r * top_val_blob.w + col] = *index_iter; + if (top_blobs.size() == 2) + indptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + else + valptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[r * top_val_blob.w + col] = *index_iter; - } else { - top_val_blob.create(bottom_blobs[0].w, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + indptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + else + valptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); + } + else + { + fprintf(stderr, 
"largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + } + } + } + if (dims == 3 && positive_axis == 1) + { + if (keep_dims == 1) + { + top_val_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].w, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } - } - const float* ptr = bottom_blobs[0]; - std::vector vec; - vec.resize(bottom_blobs[0].h); - float* valptr = top_val_blob; - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob; - for (int col = 0; col < bottom_blobs[0].w; col++) { - for (int i = 0; i < bottom_blobs[0].h; i++) { - vec[i] = ptr[i * bottom_blobs[0].w + col]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[col] = *index_iter; - if (top_blobs.size() == 2) - indptr[col] = std::distance(vec.begin(), index_iter); - else - valptr[col] = std::distance(vec.begin(), index_iter); - - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), vec.end()); - valptr[col] = *index_iter; - if (top_blobs.size() == 2) - indptr[col] = std::distance(vec.begin(), index_iter); - else - valptr[col] = std::distance(vec.begin(), index_iter); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - } - } - if (dims == 2 && positive_axis == 1) { - if (keep_dims == 1) { - top_val_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(topk, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } + std::vector vec; + vec.resize(bottom_blobs[0].h); - } else { - top_val_blob.create(bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } - } - - const float* ptr = bottom_blobs[0]; - std::vector vec; - vec.resize(bottom_blobs[0].w); - float* valptr = top_val_blob; - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob; - - for (int r = 0; r < bottom_blobs[0].h; r++) { - for (int i = 0; i < bottom_blobs[0].w; i++) { - vec[i] = ptr[r * bottom_blobs[0].w + i]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[r] = *index_iter; - if (top_blobs.size() == 2) - indptr[r] = std::distance(vec.begin(), index_iter); - else - valptr[r] = std::distance(vec.begin(), index_iter); - - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), vec.end()); - valptr[r] = *index_iter; - if (top_blobs.size() == 2) - indptr[r] = std::distance(vec.begin(), index_iter); - else - valptr[r] = std::distance(vec.begin(), index_iter); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - } - } - if (dims == 3 && positive_axis == 0) { - if (keep_dims == 1) { - top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, topk, 4u, opt.blob_allocator); - if 
(top_ind_blob.empty()) return -100; - } + for (int page = 0; page < bottom_blobs[0].c; page++) + { + const float* ptr = bottom_blobs[0].channel(page); + float* valptr = top_val_blob.channel(page); + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob.channel(page); + for (int col = 0; col < bottom_blobs[0].w; col++) + { + for (int i = 0; i < bottom_blobs[0].h; i++) + { + vec[i] = ptr[i * bottom_blobs[0].w + col]; + } + if (largest == 1) + { + auto index_iter = std::max_element(vec.begin(), vec.end()); + valptr[col] = *index_iter; + if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); + } + else if (largest == 0) + { + auto index_iter = std::min_element(vec.begin(), vec.end()); + valptr[col] = *index_iter; + if (top_blobs.size() == 2) + indptr[col] = std::distance(vec.begin(), index_iter); + else + valptr[col] = std::distance(vec.begin(), index_iter); + } + else + { + fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); + return -100; + } + } + } + } + else + { + top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_val_blob.empty()) return -100; + if (top_blobs.size() == 2) + { + top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].c, 4u, opt.blob_allocator); + if (top_ind_blob.empty()) return -100; + } - } else { - top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].h, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } - } - const float* ptr = bottom_blobs[0]; - std::vector vec; - vec.resize(bottom_blobs[0].c); - float* valptr = top_val_blob; - float* indptr; - if (top_blobs.size() == 2) indptr = top_ind_blob; - - for (int r = 0; r < bottom_blobs[0].h; r++) { - for (int col = 0; col < bottom_blobs[0].w; col++) { - for (int i = 0; i < bottom_blobs[0].c; i++) { - ptr = bottom_blobs[0].channel(i); - vec[i] = ptr[r * bottom_blobs[0].w + col]; - } - if (largest == 1) { - auto index_iter = std::max_element(vec.begin(), vec.end()); - valptr[r * top_val_blob.w + col] = *index_iter; - if (top_blobs.size() == 2) - indptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - else - valptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - - } else if (largest == 0) { - auto index_iter = std::min_element(vec.begin(), vec.end()); - valptr[r * top_val_blob.w + col] = *index_iter; - - if (top_blobs.size() == 2) - indptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - else - valptr[r * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter); - } else { - fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest); - return -100; - } - } - } - } - if (dims == 3 && positive_axis == 1) { - if (keep_dims == 1) { - top_val_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_val_blob.empty()) return -100; - if (top_blobs.size() == 2) { - top_ind_blob.create(bottom_blobs[0].w, topk, bottom_blobs[0].c, 4u, opt.blob_allocator); - if (top_ind_blob.empty()) return -100; - } + std::vector vec; + vec.resize(bottom_blobs[0].h); + float* valptr = top_val_blob; + float* indptr; + if (top_blobs.size() == 2) indptr = top_ind_blob; - std::vector vec; - vec.resize(bottom_blobs[0].h); - - for (int page = 0; page < bottom_blobs[0].c; page++) { - const float* ptr 
= bottom_blobs[0].channel(page);
-        float* valptr = top_val_blob.channel(page);
-        float* indptr;
-        if (top_blobs.size() == 2) indptr = top_ind_blob.channel(page);
-        for (int col = 0; col < bottom_blobs[0].w; col++) {
-          for (int i = 0; i < bottom_blobs[0].h; i++) {
-            vec[i] = ptr[i * bottom_blobs[0].w + col];
-          }
-          if (largest == 1) {
-            auto index_iter = std::max_element(vec.begin(), vec.end());
-            valptr[col] = *index_iter;
-            if (top_blobs.size() == 2)
-              indptr[col] = std::distance(vec.begin(), index_iter);
-            else
-              valptr[col] = std::distance(vec.begin(), index_iter);
-          } else if (largest == 0) {
-            auto index_iter = std::min_element(vec.begin(), vec.end());
-            valptr[col] = *index_iter;
-            if (top_blobs.size() == 2)
-              indptr[col] = std::distance(vec.begin(), index_iter);
-            else
-              valptr[col] = std::distance(vec.begin(), index_iter);
-          } else {
-            fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest);
-            return -100;
-          }
-        }
-      }
-    } else {
-      top_val_blob.create(bottom_blobs[0].w, bottom_blobs[0].c, 4u, opt.blob_allocator);
-      if (top_val_blob.empty()) return -100;
-      if (top_blobs.size() == 2) {
-        top_ind_blob.create(bottom_blobs[0].w, bottom_blobs[0].c, 4u, opt.blob_allocator);
-        if (top_ind_blob.empty()) return -100;
-      }
-
-      std::vector<float> vec;
-      vec.resize(bottom_blobs[0].h);
-      float* valptr = top_val_blob;
-      float* indptr;
-      if (top_blobs.size() == 2) indptr = top_ind_blob;
-
-      for (int page = 0; page < bottom_blobs[0].c; page++) {
-        const float* ptr = bottom_blobs[0].channel(page);
-        for (int col = 0; col < bottom_blobs[0].w; col++) {
-          for (int i = 0; i < bottom_blobs[0].h; i++) {
-            vec[i] = ptr[i * bottom_blobs[0].w + col];
-          }
-          if (largest == 1) {
-            auto index_iter = std::max_element(vec.begin(), vec.end());
-            valptr[page * top_val_blob.w + col] = *index_iter;
-            if (top_blobs.size() == 2)
-              indptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter);
-            else
-              valptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter);
-          } else if (largest == 0) {
-            auto index_iter = std::min_element(vec.begin(), vec.end());
-            valptr[page * top_val_blob.w + col] = *index_iter;
-            if (top_blobs.size() == 2)
-              indptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter);
-            else
-              valptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter);
-          } else {
-            fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest);
-            return -100;
-          }
-        }
-      }
-    }
-  }
+            for (int page = 0; page < bottom_blobs[0].c; page++)
+            {
+                const float* ptr = bottom_blobs[0].channel(page);
+                for (int col = 0; col < bottom_blobs[0].w; col++)
+                {
+                    for (int i = 0; i < bottom_blobs[0].h; i++)
+                    {
+                        vec[i] = ptr[i * bottom_blobs[0].w + col];
+                    }
+                    if (largest == 1)
+                    {
+                        auto index_iter = std::max_element(vec.begin(), vec.end());
+                        valptr[page * top_val_blob.w + col] = *index_iter;
+                        if (top_blobs.size() == 2)
+                            indptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter);
+                        else
+                            valptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter);
+                    }
+                    else if (largest == 0)
+                    {
+                        auto index_iter = std::min_element(vec.begin(), vec.end());
+                        valptr[page * top_val_blob.w + col] = *index_iter;
+                        if (top_blobs.size() == 2)
+                            indptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter);
+                        else
+                            valptr[page * top_ind_blob.w + col] = std::distance(vec.begin(), index_iter);
+                    }
+                    else
+                    {
+                        fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest);
+                        return -100;
+                    }
+                }
+            }
+        }
+    }
-  if (dims == 3 && positive_axis == 2) {
-    if (keep_dims == 1) {
-      top_val_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator);
-      if (top_val_blob.empty()) return -100;
-      if (top_blobs.size() == 2) {
-        top_ind_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator);
-        if (top_ind_blob.empty()) return -100;
-      }
-
-      std::vector<float> vec;
-      vec.resize(bottom_blobs[0].w);
-
-      for (int page = 0; page < bottom_blobs[0].c; page++) {
-        const float* ptr = bottom_blobs[0].channel(page);
-        float* valptr = top_val_blob.channel(page);
-        float* indptr;
-        if (top_blobs.size() == 2) indptr = top_ind_blob.channel(page);
-        for (int r = 0; r < bottom_blobs[0].h; r++) {
-          for (int i = 0; i < bottom_blobs[0].w; i++) {
-            vec[i] = ptr[r * bottom_blobs[0].w + i];
-          }
-          if (largest == 1) {
-            auto index_iter = std::max_element(vec.begin(), vec.end());
-            valptr[r] = *index_iter;
-            if (top_blobs.size() == 2)
-              indptr[r] = std::distance(vec.begin(), index_iter);
-            else
-              valptr[r] = std::distance(vec.begin(), index_iter);
-          } else if (largest == 0) {
-            auto index_iter = std::min_element(vec.begin(), vec.end());
-            valptr[r] = *index_iter;
-            if (top_blobs.size() == 2)
-              indptr[r] = std::distance(vec.begin(), index_iter);
-            else
-              valptr[r] = std::distance(vec.begin(), index_iter);
-          } else {
-            fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest);
-            return -100;
-          }
-        }
-      }
-    } else {
-      top_val_blob.create(bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator);
-      if (top_val_blob.empty()) return -100;
-      if (top_blobs.size() == 2) {
-        top_ind_blob.create(bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator);
-        if (top_ind_blob.empty()) return -100;
-      }
-
-      std::vector<float> vec;
-      vec.resize(bottom_blobs[0].w);
-      float* valptr = top_val_blob;
-      float* indptr;
-      if (top_blobs.size() == 2) indptr = top_ind_blob;
-
-      for (int page = 0; page < bottom_blobs[0].c; page++) {
-        const float* ptr = bottom_blobs[0].channel(page);
-        for (int r = 0; r < bottom_blobs[0].h; r++) {
-          for (int i = 0; i < bottom_blobs[0].w; i++) {
-            vec[i] = ptr[r * bottom_blobs[0].w + i];
-          }
-          if (largest == 1) {
-            auto index_iter = std::max_element(vec.begin(), vec.end());
-            valptr[page * top_val_blob.w + r] = *index_iter;
-            if (top_blobs.size() == 2)
-              indptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter);
-            else
-              valptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter);
-          } else if (largest == 0) {
-            auto index_iter = std::min_element(vec.begin(), vec.end());
-            valptr[page * top_val_blob.w + r] = *index_iter;
-            if (top_blobs.size() == 2)
-              indptr[page * top_val_blob.w + r] = std::distance(vec.begin(), index_iter);
-            else
-              valptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter);
-          } else {
-            fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest);
-            return -100;
-          }
-        }
-      }
-    }
-  }
-  return 0;
-}
+    if (dims == 3 && positive_axis == 2)
+    {
+        if (keep_dims == 1)
+        {
+            top_val_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator);
+            if (top_val_blob.empty()) return -100;
+            if (top_blobs.size() == 2)
+            {
+                top_ind_blob.create(topk, bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator);
+                if (top_ind_blob.empty()) return -100;
+            }
+
+            std::vector<float> vec;
+            vec.resize(bottom_blobs[0].w);
+
+            for (int page = 0; page < bottom_blobs[0].c; page++)
+            {
+                const float* ptr = bottom_blobs[0].channel(page);
+                float* valptr = top_val_blob.channel(page);
+                float* indptr;
+                if (top_blobs.size() == 2) indptr = top_ind_blob.channel(page);
+                for (int r = 0; r < bottom_blobs[0].h; r++)
+                {
+                    for (int i = 0; i < bottom_blobs[0].w; i++)
+                    {
+                        vec[i] = ptr[r * bottom_blobs[0].w + i];
+                    }
+                    if (largest == 1)
+                    {
+                        auto index_iter = std::max_element(vec.begin(), vec.end());
+                        valptr[r] = *index_iter;
+                        if (top_blobs.size() == 2)
+                            indptr[r] = std::distance(vec.begin(), index_iter);
+                        else
+                            valptr[r] = std::distance(vec.begin(), index_iter);
+                    }
+                    else if (largest == 0)
+                    {
+                        auto index_iter = std::min_element(vec.begin(), vec.end());
+                        valptr[r] = *index_iter;
+                        if (top_blobs.size() == 2)
+                            indptr[r] = std::distance(vec.begin(), index_iter);
+                        else
+                            valptr[r] = std::distance(vec.begin(), index_iter);
+                    }
+                    else
+                    {
+                        fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest);
+                        return -100;
+                    }
+                }
+            }
+        }
+        else
+        {
+            top_val_blob.create(bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator);
+            if (top_val_blob.empty()) return -100;
+            if (top_blobs.size() == 2)
+            {
+                top_ind_blob.create(bottom_blobs[0].h, bottom_blobs[0].c, 4u, opt.blob_allocator);
+                if (top_ind_blob.empty()) return -100;
+            }
+
+            std::vector<float> vec;
+            vec.resize(bottom_blobs[0].w);
+            float* valptr = top_val_blob;
+            float* indptr;
+            if (top_blobs.size() == 2) indptr = top_ind_blob;
+
+            for (int page = 0; page < bottom_blobs[0].c; page++)
+            {
+                const float* ptr = bottom_blobs[0].channel(page);
+                for (int r = 0; r < bottom_blobs[0].h; r++)
+                {
+                    for (int i = 0; i < bottom_blobs[0].w; i++)
+                    {
+                        vec[i] = ptr[r * bottom_blobs[0].w + i];
+                    }
+                    if (largest == 1)
+                    {
+                        auto index_iter = std::max_element(vec.begin(), vec.end());
+                        valptr[page * top_val_blob.w + r] = *index_iter;
+                        if (top_blobs.size() == 2)
+                            indptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter);
+                        else
+                            valptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter);
+                    }
+                    else if (largest == 0)
+                    {
+                        auto index_iter = std::min_element(vec.begin(), vec.end());
+                        valptr[page * top_val_blob.w + r] = *index_iter;
+                        if (top_blobs.size() == 2)
+                            indptr[page * top_val_blob.w + r] = std::distance(vec.begin(), index_iter);
+                        else
+                            valptr[page * top_ind_blob.w + r] = std::distance(vec.begin(), index_iter);
+                    }
+                    else
+                    {
+                        fprintf(stderr, "largest attribute should be 0 or 1, but not %d\n", largest);
+                        return -100;
+                    }
+                }
+            }
+        }
+    }
+    return 0;
+}

 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.h b/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.h
index d390fbafcd..e9bbde1297 100644
--- a/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.h
+++ b/csrc/mmdeploy/backend_ops/ncnn/ops/topk/topk.h
@@ -4,21 +4,22 @@

 #include "layer.h"

-namespace mmdeploy {
+namespace mmdeploy
+{

-class TopK : public ncnn::Layer {
- public:
-  TopK();
-  virtual int load_param(const ncnn::ParamDict& pd);
-  virtual int forward(const std::vector<ncnn::Mat>& bottom_blobs, std::vector<ncnn::Mat>& top_blobs,
-                      const ncnn::Option& opt) const;
+    class TopK : public ncnn::Layer
+    {
+      public:
+        TopK();
+        virtual int load_param(const ncnn::ParamDict& pd);
+        virtual int forward(const std::vector<ncnn::Mat>& bottom_blobs, std::vector<ncnn::Mat>& top_blobs, const ncnn::Option& opt) const;

- public:
-  int axis;
-  int largest;
-  int sorted;
-  int keep_dims;
-};
+      public:
+        int axis;
+        int largest;
+        int sorted;
+        int keep_dims;
+    };

 }  // namespace mmdeploy
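[editor's note] The TopK forward pass above reduces along a non-contiguous axis by gathering one column into a scratch vector and then taking max_element/distance. A minimal standalone sketch of that pattern, for readers skimming the reformatted diff (illustrative only, not part of the patch; all names are local to the example):

// argmax along columns of a row-major h x w matrix, TopK-style
#include <algorithm>
#include <cstdio>
#include <iterator>
#include <vector>

int main()
{
    const int h = 3, w = 4;  // rows x cols, row-major
    const float data[h * w] = {0, 9, 2, 3,
                               4, 5, 6, 7,
                               8, 1, 10, 11};
    std::vector<float> vec(h);
    for (int col = 0; col < w; col++)
    {
        for (int i = 0; i < h; i++) vec[i] = data[i * w + col];  // gather column
        auto it = std::max_element(vec.begin(), vec.end());
        std::printf("col %d: max=%g argmax=%td\n", col, *it, std::distance(vec.begin(), it));
    }
    return 0;
}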
diff --git a/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/ncnn_ext.cpp b/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/ncnn_ext.cpp
old mode 100755
new mode 100644
index ac158b9edb..1c8ad70cc7
--- a/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/ncnn_ext.cpp
+++ b/csrc/mmdeploy/backend_ops/ncnn/pyncnn_ext/ncnn_ext.cpp
@@ -4,9 +4,11 @@
 #include "ncnn_ops_register.h"
 #include "net.h"

-PYBIND11_MODULE(ncnn_ext, m) {
-  m.def(
-      "register_mmdeploy_custom_layers",
-      [](ncnn::Net &net) { return register_mmdeploy_custom_layers(net); },
-      "register mmdeploy custom ncnn layers.");
+PYBIND11_MODULE(ncnn_ext, m)
+{
+    m.def(
+        "register_mmdeploy_custom_layers",
+        [](ncnn::Net& net)
+        { return register_mmdeploy_custom_layers(net); },
+        "register mmdeploy custom ncnn layers.");
 }
diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/common/onnxruntime_register.h b/csrc/mmdeploy/backend_ops/onnxruntime/common/onnxruntime_register.h
index 28d2a2b782..1095c28bae 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/common/onnxruntime_register.h
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/common/onnxruntime_register.h
@@ -6,11 +6,12 @@
 #include "mmdeploy/core/macro.h"

 #ifdef __cplusplus
-extern "C" {
+extern "C"
+{
 #endif

-MMDEPLOY_API OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options,
-                                                       const OrtApiBase *api);
+    MMDEPLOY_API OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
+                                                           const OrtApiBase* api);

 #ifdef __cplusplus
 }
diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.cpp
index c604e4b650..da959ec37e 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.cpp
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.cpp
@@ -1,10 +1,12 @@
 // Copyright (c) OpenMMLab. All rights reserved.
 #include "ort_utils.h"

-namespace mmdeploy {
+namespace mmdeploy
+{

-CustomOpsTable& get_mmdeploy_custom_ops() {
-  static CustomOpsTable _custom_ops;
-  return _custom_ops;
-}
+    CustomOpsTable& get_mmdeploy_custom_ops()
+    {
+        static CustomOpsTable _custom_ops;
+        return _custom_ops;
+    }

 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.h b/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.h
index e19c984f86..14d2da3457 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.h
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/common/ort_utils.h
@@ -6,32 +6,39 @@
 #include <unordered_map>
 #include <vector>

-namespace mmdeploy {
-
-typedef std::unordered_map<std::string, std::vector<const OrtCustomOp*>> CustomOpsTable;
-
-struct OrtTensorDimensions : std::vector<int64_t> {
-  OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value) {
-    OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
-    std::vector<int64_t>::operator=(ort.GetTensorShape(info));
-    ort.ReleaseTensorTypeAndShapeInfo(info);
-  }
-};
-
-CustomOpsTable& get_mmdeploy_custom_ops();
-
-template <char const* domain, typename T>
-class OrtOpsRegistry {
- public:
-  OrtOpsRegistry() { get_mmdeploy_custom_ops()[domain].push_back(&instance); }
-
- private:
-  T instance{};
-};
-
-#define REGISTER_ONNXRUNTIME_OPS(domain, name)      \
-  static char __domain_##domain##name[] = #domain; \
-  static OrtOpsRegistry<__domain_##domain##name, name> ort_ops_registry_##domain##name {}
+namespace mmdeploy
+{
+
+    typedef std::unordered_map<std::string, std::vector<const OrtCustomOp*>> CustomOpsTable;
+
+    struct OrtTensorDimensions : std::vector<int64_t>
+    {
+        OrtTensorDimensions(Ort::CustomOpApi ort, const OrtValue* value)
+        {
+            OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
+            std::vector<int64_t>::operator=(ort.GetTensorShape(info));
+            ort.ReleaseTensorTypeAndShapeInfo(info);
+        }
+    };
+
+    CustomOpsTable& get_mmdeploy_custom_ops();
+
+    template <char const* domain, typename T>
+    class OrtOpsRegistry
+    {
+      public:
+        OrtOpsRegistry()
+        {
+            get_mmdeploy_custom_ops()[domain].push_back(&instance);
+        }
+
+      private:
+        T instance{};
+    };
+
+#define REGISTER_ONNXRUNTIME_OPS(domain, name)        \
+    static char __domain_##domain##name[] = #domain;  \
+    static OrtOpsRegistry<__domain_##domain##name, name> ort_ops_registry_##domain##name {}

 }  // namespace mmdeploy

 #endif  // ORT_MMCV_UTILS_H
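[editor's note] REGISTER_ONNXRUNTIME_OPS above relies on static self-registration: a file-scope object whose constructor runs at load time and appends the op to a per-domain table. A simplified sketch of that idiom under C++17 (illustrative only, not part of the patch; FakeOp, ops_table, and kDomain are hypothetical stand-ins for the real OrtCustomOp types):

#include <string>
#include <unordered_map>
#include <vector>

struct FakeOp { const char* name = "grid_sampler"; };

using OpsTable = std::unordered_map<std::string, std::vector<const FakeOp*>>;

OpsTable& ops_table()
{
    static OpsTable table;  // constructed on first use, shared by all registries
    return table;
}

template <char const* Domain, typename T>
struct OpsRegistry
{
    OpsRegistry() { ops_table()[Domain].push_back(&instance); }  // runs before main()
    T instance{};
};

static char kDomain[] = "mmdeploy";
static OpsRegistry<kDomain, FakeOp> registry;  // registers the op at load time

The payoff of this design is that each op's translation unit registers itself; RegisterCustomOps only has to walk the table, and no central list needs editing when a new op is added.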
diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.cpp
index c7fed37d23..27eb677394 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.cpp
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.cpp
@@ -8,287 +8,335 @@
 #include "ort_utils.h"

-namespace mmdeploy {
+namespace mmdeploy
+{

 #define MIN(a, b) (((a) < (b)) ? (a) : (b))
 #define MAX(a, b) (((a) < (b)) ? (b) : (a))
 #define CLIP_COORDINATES(in, out, clip_limit) out = MIN((clip_limit - 1), MAX(in, 0))

-GridSampleKernel::GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info)
-    : ort_(api), info_(info) {
-  align_corners_ = ort_.KernelInfoGetAttribute<int64_t>(info, "align_corners");
-  interpolation_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "interpolation_mode");
-  padding_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "padding_mode");
-
-  allocator_ = Ort::AllocatorWithDefaultOptions();
-}
-
-enum GridSamplerInterpolation { Bilinear = 0, Nearest = 1, Bicubic = 2 };
-enum GridSamplerPadding { Zeros = 0, Border = 1, Reflection = 2 };
-
-template <typename scalar_t>
-static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size, bool align_corners) {
-  if (align_corners) {
-    return ((coord + 1) / 2) * (size - 1);
-  } else {
-    return ((coord + 1) * size - 1) / 2;
-  }
-}
-
-// Clips coordinates to between 0 and clip_limit - 1
-template <typename scalar_t>
-static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit) {
-  return std::min(static_cast<scalar_t>(clip_limit - 1), std::max(in, static_cast<scalar_t>(0)));
-}
-
-// Reflects coordinates until they fall between low and high (inclusive).
-// The bounds are passed as twice their value so that half-integer values
-// can be represented as ints.
-template <typename scalar_t>
-static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low, int64_t twice_high) {
-  if (twice_low == twice_high) {
-    return static_cast<scalar_t>(0);
-  }
-  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
-  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
-  in = std::fabs(in - min);
-  // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
-  scalar_t extra = std::fmod(in, span);
-  int flips = static_cast<int>(std::floor(in / span));
-  if (flips % 2 == 0) {
-    return extra + min;
-  } else {
-    return span - extra + min;
-  }
-}
-
-template <typename scalar_t>
-static inline scalar_t compute_coordinates(scalar_t coord, int64_t size, int64_t padding_mode,
-                                           bool align_corners) {
-  if (padding_mode == GridSamplerPadding::Border) {
-    coord = clip_coordinates(coord, size);
-  } else if (padding_mode == GridSamplerPadding::Reflection) {
-    if (align_corners) {
-      coord = reflect_coordinates(coord, 0, 2 * (size - 1));
-    } else {
-      coord = reflect_coordinates(coord, -1, 2 * size - 1);
-    }
-    coord = clip_coordinates(coord, size);
-  }
-  return coord;
-}
-
-// Computes the pixel source index value for a grid coordinate
-template <typename scalar_t>
-static inline scalar_t grid_sampler_compute_source_index(scalar_t coord, int64_t size,
-                                                         int64_t padding_mode, bool align_corners) {
-  coord = grid_sampler_unnormalize(coord, size, align_corners);
-  coord = compute_coordinates(coord, size, padding_mode, align_corners);
-  return coord;
-}
-
-static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W) {
-  return h >= 0 && h < H && w >= 0 && w < W;
-}
-
-template <typename scalar_t>
-static inline scalar_t get_value_bounded(const scalar_t *data, scalar_t x, scalar_t y, int64_t W,
-                                         int64_t H, int64_t sW, int64_t sH, int64_t padding_mode,
-                                         bool align_corners) {
-  x = compute_coordinates(x, W, padding_mode, align_corners);
-  y = compute_coordinates(y, H, padding_mode, align_corners);
-
-  int64_t ix = static_cast<int64_t>(x);
-  int64_t iy = static_cast<int64_t>(y);
-
-  if (within_bounds_2d(iy, ix, H, W)) {
-    return data[iy * sH + ix * sW];
-  }
-  return static_cast<scalar_t>(0);
-}
-
-template <typename scalar_t>
-static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A) {
-  return ((A + 2) * x - (A + 3)) * x * x + 1;
-}
-
-template <typename scalar_t>
-static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A) {
-  return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
-}
-
-template <typename scalar_t>
-static inline void get_cubic_upsample_coefficients(scalar_t coeffs[4], scalar_t t) {
-  scalar_t A = -0.75;
-
-  scalar_t x1 = t;
-  coeffs[0] = cubic_convolution2(x1 + 1.0, A);
-  coeffs[1] = cubic_convolution1(x1, A);
-
-  // opposite coefficients
-  scalar_t x2 = 1.0 - t;
-  coeffs[2] = cubic_convolution1(x2, A);
-  coeffs[3] = cubic_convolution2(x2 + 1.0, A);
-}
-
-template <typename scalar_t>
-static inline scalar_t cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3,
-                                      scalar_t t) {
-  scalar_t coeffs[4];
-  get_cubic_upsample_coefficients(coeffs, t);
-
-  return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
-}
-
-void GridSampleKernel::Compute(OrtKernelContext *context) {
-  const bool align_corners = align_corners_;
-  const int64_t padding_mode = padding_mode_;
-  const int64_t interpolation_mode = interpolation_mode_;
-
-  const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
-  const float *input_data = reinterpret_cast<const float *>(ort_.GetTensorData<float>(input));
-
-  const OrtValue *grid = ort_.KernelContext_GetInput(context, 1);
-  const float *grid_data = reinterpret_cast<const float *>(ort_.GetTensorData<float>(grid));
-
-  OrtTensorDimensions input_dims(ort_, input);
-  OrtTensorDimensions grid_dims(ort_, grid);
-  int64_t N = input_dims[0];
-  int64_t C = input_dims[1];
-  int64_t inp_H = input_dims[2];
-  int64_t inp_W = input_dims[3];
-  int64_t out_H = grid_dims[1];
-  int64_t out_W = grid_dims[2];
-
-  std::vector<int64_t> output_dims = {N, C, out_H, out_W};
-  OrtValue *output =
-      ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size());
-  float *out_ptr = ort_.GetTensorMutableData<float>(output);
-
-  int64_t inp_sN = input_dims[1] * input_dims[2] * input_dims[3];
-  int64_t inp_sC = input_dims[2] * input_dims[3];
-  int64_t inp_sH = input_dims[3];
-  int64_t inp_sW = 1;
-  int64_t grid_sN = grid_dims[1] * grid_dims[2] * grid_dims[3];
-  int64_t grid_sH = grid_dims[2] * grid_dims[3];
-  int64_t grid_sW = grid_dims[3];
-  int64_t grid_sCoor = 1;
-  int64_t out_sN = output_dims[1] * output_dims[2] * output_dims[3];
-  int64_t out_sC = output_dims[2] * output_dims[3];
-  int64_t out_sH = output_dims[3];
-  int64_t out_sW = 1;
-
-  // loop over each output pixel
-  for (int64_t n = 0; n < N; ++n) {
-    const float *grid_ptr_N = grid_data + n * grid_sN;
-    const float *inp_ptr_N = input_data + n * inp_sN;
-    for (int64_t h = 0; h < out_H; ++h) {
-      for (int64_t w = 0; w < out_W; ++w) {
-        const float *grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
-        float x = *grid_ptr_NHW;
-        float y = grid_ptr_NHW[grid_sCoor];
-
-        float ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
-        float iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);
-
-        if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
-          // get corner pixel values from (x, y)
-          // for 4d, we use north-east-south-west
-          int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
-          int64_t iy_nw = static_cast<int64_t>(std::floor(iy));
-
-          int64_t ix_ne = ix_nw + 1;
-          int64_t iy_ne = iy_nw;
-
-          int64_t ix_sw = ix_nw;
-          int64_t iy_sw = iy_nw + 1;
-
-          int64_t ix_se = ix_nw + 1;
-          int64_t iy_se = iy_nw + 1;
-
-          // get surfaces to each neighbor:
-          float nw = (ix_se - ix) * (iy_se - iy);
-          float ne = (ix - ix_sw) * (iy_sw - iy);
-          float sw = (ix_ne - ix) * (iy - iy_ne);
-          float se = (ix - ix_nw) * (iy - iy_nw);
-
-          // calculate bilinear weighted pixel value and set output pixel
-          const float *inp_ptr_NC = inp_ptr_N;
-          float *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
-          for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
-            auto res = static_cast<float>(0);
-            if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
-              res += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
-            }
-            if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
-              res += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
-            }
-            if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
-              res += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
-            }
-            if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
-              res += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
-            }
-            *out_ptr_NCHW = res;
-          }
-        } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
-          int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
-          int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
-
-          // assign nearest neighbor pixel value to output pixel
-          float *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
-          const float *inp_ptr_NC = inp_ptr_N;
-          for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
-            if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) {
-              *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
-            } else {
-              *out_ptr_NCHW = static_cast<float>(0);
-            }
-          }
-        } else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
-          // grid_sampler_compute_source_index will "clip the value" of idx
-          // depends on the padding,
-          // which would cause calculation to be wrong,
-          // for example x = -0.1 -> ix = 0 for zero padding, but in bicubic ix
-          // = floor(x) = -1
-          // There would be more problem in reflection padding, since the -1 and
-          // +1 direction is not fixed in boundary condition
-          ix = grid_sampler_unnormalize(x, inp_W, align_corners);
-          iy = grid_sampler_unnormalize(y, inp_H, align_corners);
-
-          float ix_nw = std::floor(ix);
-          float iy_nw = std::floor(iy);
-
-          const float tx = ix - ix_nw;
-          const float ty = iy - iy_nw;
-
-          const float *inp_ptr_NC = inp_ptr_N;
-          float *out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
-          for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC) {
-            float coefficients[4];
-
-            // Interpolate 4 values in the x direction
-            for (int64_t i = 0; i < 4; ++i) {
-              coefficients[i] = cubic_interp1d(
-                  get_value_bounded<float>(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H,
-                                           inp_sW, inp_sH, padding_mode, align_corners),
-                  get_value_bounded<float>(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H,
-                                           inp_sW, inp_sH, padding_mode, align_corners),
-                  get_value_bounded<float>(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H,
-                                           inp_sW, inp_sH, padding_mode, align_corners),
-                  get_value_bounded<float>(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H,
-                                           inp_sW, inp_sH, padding_mode, align_corners),
-                  tx);
-            }
-
-            // Interpolate in the y direction
-            *out_ptr_NCHW = cubic_interp1d(coefficients[0], coefficients[1], coefficients[2],
-                                           coefficients[3], ty);
-          }
-        }
-      }
-    }
-  }
-}
-
-REGISTER_ONNXRUNTIME_OPS(mmdeploy, GridSampleOp);
+    GridSampleKernel::GridSampleKernel(const OrtApi& api, const OrtKernelInfo* info)
+        : ort_(api)
+        , info_(info)
+    {
+        align_corners_ = ort_.KernelInfoGetAttribute<int64_t>(info, "align_corners");
+        interpolation_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "interpolation_mode");
+        padding_mode_ = ort_.KernelInfoGetAttribute<int64_t>(info, "padding_mode");
+
+        allocator_ = Ort::AllocatorWithDefaultOptions();
+    }
+
+    enum GridSamplerInterpolation
+    {
+        Bilinear = 0,
+        Nearest = 1,
+        Bicubic = 2
+    };
+    enum GridSamplerPadding
+    {
+        Zeros = 0,
+        Border = 1,
+        Reflection = 2
+    };
+
+    template <typename scalar_t>
+    static inline scalar_t grid_sampler_unnormalize(scalar_t coord, int64_t size, bool align_corners)
+    {
+        if (align_corners)
+        {
+            return ((coord + 1) / 2) * (size - 1);
+        }
+        else
+        {
+            return ((coord + 1) * size - 1) / 2;
+        }
+    }
+
+    // Clips coordinates to between 0 and clip_limit - 1
+    template <typename scalar_t>
+    static inline scalar_t clip_coordinates(scalar_t in, int64_t clip_limit)
+    {
+        return std::min(static_cast<scalar_t>(clip_limit - 1), std::max(in, static_cast<scalar_t>(0)));
+    }
+
+    // Reflects coordinates until they fall between low and high (inclusive).
+    // The bounds are passed as twice their value so that half-integer values
+    // can be represented as ints.
+    template <typename scalar_t>
+    static inline scalar_t reflect_coordinates(scalar_t in, int64_t twice_low, int64_t twice_high)
+    {
+        if (twice_low == twice_high)
+        {
+            return static_cast<scalar_t>(0);
+        }
+        scalar_t min = static_cast<scalar_t>(twice_low) / 2;
+        scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
+        in = std::fabs(in - min);
+        // `fmod` returns same sign as `in`, which is positive after the `fabs` above.
+        scalar_t extra = std::fmod(in, span);
+        int flips = static_cast<int>(std::floor(in / span));
+        if (flips % 2 == 0)
+        {
+            return extra + min;
+        }
+        else
+        {
+            return span - extra + min;
+        }
+    }
+
+    template <typename scalar_t>
+    static inline scalar_t compute_coordinates(scalar_t coord, int64_t size, int64_t padding_mode, bool align_corners)
+    {
+        if (padding_mode == GridSamplerPadding::Border)
+        {
+            coord = clip_coordinates(coord, size);
+        }
+        else if (padding_mode == GridSamplerPadding::Reflection)
+        {
+            if (align_corners)
+            {
+                coord = reflect_coordinates(coord, 0, 2 * (size - 1));
+            }
+            else
+            {
+                coord = reflect_coordinates(coord, -1, 2 * size - 1);
+            }
+            coord = clip_coordinates(coord, size);
+        }
+        return coord;
+    }
+
+    // Computes the pixel source index value for a grid coordinate
+    template <typename scalar_t>
+    static inline scalar_t grid_sampler_compute_source_index(scalar_t coord, int64_t size, int64_t padding_mode, bool align_corners)
+    {
+        coord = grid_sampler_unnormalize(coord, size, align_corners);
+        coord = compute_coordinates(coord, size, padding_mode, align_corners);
+        return coord;
+    }
+
+    static inline bool within_bounds_2d(int64_t h, int64_t w, int64_t H, int64_t W)
+    {
+        return h >= 0 && h < H && w >= 0 && w < W;
+    }
+
+    template <typename scalar_t>
+    static inline scalar_t get_value_bounded(const scalar_t* data, scalar_t x, scalar_t y, int64_t W, int64_t H, int64_t sW, int64_t sH, int64_t padding_mode, bool align_corners)
+    {
+        x = compute_coordinates(x, W, padding_mode, align_corners);
+        y = compute_coordinates(y, H, padding_mode, align_corners);
+
+        int64_t ix = static_cast<int64_t>(x);
+        int64_t iy = static_cast<int64_t>(y);
+
+        if (within_bounds_2d(iy, ix, H, W))
+        {
+            return data[iy * sH + ix * sW];
+        }
+        return static_cast<scalar_t>(0);
+    }
+
+    template <typename scalar_t>
+    static inline scalar_t cubic_convolution1(scalar_t x, scalar_t A)
+    {
+        return ((A + 2) * x - (A + 3)) * x * x + 1;
+    }
+
+    template <typename scalar_t>
+    static inline scalar_t cubic_convolution2(scalar_t x, scalar_t A)
+    {
+        return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A;
+    }
+
+    template <typename scalar_t>
+    static inline void get_cubic_upsample_coefficients(scalar_t coeffs[4], scalar_t t)
+    {
+        scalar_t A = -0.75;
+
+        scalar_t x1 = t;
+        coeffs[0] = cubic_convolution2(x1 + 1.0, A);
+        coeffs[1] = cubic_convolution1(x1, A);
+
+        // opposite coefficients
+        scalar_t x2 = 1.0 - t;
+        coeffs[2] = cubic_convolution1(x2, A);
+        coeffs[3] = cubic_convolution2(x2 + 1.0, A);
+    }
+
+    template <typename scalar_t>
+    static inline scalar_t cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, scalar_t t)
+    {
+        scalar_t coeffs[4];
+        get_cubic_upsample_coefficients(coeffs, t);
+
+        return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
+    }
+
+    void GridSampleKernel::Compute(OrtKernelContext* context)
+    {
+        const bool align_corners = align_corners_;
+        const int64_t padding_mode = padding_mode_;
+        const int64_t interpolation_mode = interpolation_mode_;
+
+        const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
+        const float* input_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(input));
+
+        const OrtValue* grid = ort_.KernelContext_GetInput(context, 1);
+        const float* grid_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(grid));
+
+        OrtTensorDimensions input_dims(ort_, input);
+        OrtTensorDimensions grid_dims(ort_, grid);
+        int64_t N = input_dims[0];
+        int64_t C = input_dims[1];
+        int64_t inp_H = input_dims[2];
+        int64_t inp_W = input_dims[3];
+        int64_t out_H = grid_dims[1];
+        int64_t out_W = grid_dims[2];
+
+        std::vector<int64_t> output_dims = {N, C, out_H, out_W};
+        OrtValue* output =
+            ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size());
+        float* out_ptr = ort_.GetTensorMutableData<float>(output);
+
+        int64_t inp_sN = input_dims[1] * input_dims[2] * input_dims[3];
+        int64_t inp_sC = input_dims[2] * input_dims[3];
+        int64_t inp_sH = input_dims[3];
+        int64_t inp_sW = 1;
+        int64_t grid_sN = grid_dims[1] * grid_dims[2] * grid_dims[3];
+        int64_t grid_sH = grid_dims[2] * grid_dims[3];
+        int64_t grid_sW = grid_dims[3];
+        int64_t grid_sCoor = 1;
+        int64_t out_sN = output_dims[1] * output_dims[2] * output_dims[3];
+        int64_t out_sC = output_dims[2] * output_dims[3];
+        int64_t out_sH = output_dims[3];
+        int64_t out_sW = 1;
+
+        // loop over each output pixel
+        for (int64_t n = 0; n < N; ++n)
+        {
+            const float* grid_ptr_N = grid_data + n * grid_sN;
+            const float* inp_ptr_N = input_data + n * inp_sN;
+            for (int64_t h = 0; h < out_H; ++h)
+            {
+                for (int64_t w = 0; w < out_W; ++w)
+                {
+                    const float* grid_ptr_NHW = grid_ptr_N + h * grid_sH + w * grid_sW;
+                    float x = *grid_ptr_NHW;
+                    float y = grid_ptr_NHW[grid_sCoor];
+
+                    float ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
+                    float iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);
+
+                    if (interpolation_mode == GridSamplerInterpolation::Bilinear)
+                    {
+                        // get corner pixel values from (x, y)
+                        // for 4d, we use north-east-south-west
+                        int64_t ix_nw = static_cast<int64_t>(std::floor(ix));
+                        int64_t iy_nw = static_cast<int64_t>(std::floor(iy));
+
+                        int64_t ix_ne = ix_nw + 1;
+                        int64_t iy_ne = iy_nw;
+
+                        int64_t ix_sw = ix_nw;
+                        int64_t iy_sw = iy_nw + 1;
+
+                        int64_t ix_se = ix_nw + 1;
+                        int64_t iy_se = iy_nw + 1;
+
+                        // get surfaces to each neighbor:
+                        float nw = (ix_se - ix) * (iy_se - iy);
+                        float ne = (ix - ix_sw) * (iy_sw - iy);
+                        float sw = (ix_ne - ix) * (iy - iy_ne);
+                        float se = (ix - ix_nw) * (iy - iy_nw);
+
+                        // calculate bilinear weighted pixel value and set output pixel
+                        const float* inp_ptr_NC = inp_ptr_N;
+                        float* out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
+                        for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC)
+                        {
+                            auto res = static_cast<float>(0);
+                            if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W))
+                            {
+                                res += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw;
+                            }
+                            if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W))
+                            {
+                                res += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne;
+                            }
+                            if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W))
+                            {
+                                res += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw;
+                            }
+                            if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W))
+                            {
+                                res += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se;
+                            }
+                            *out_ptr_NCHW = res;
+                        }
+                    }
+                    else if (interpolation_mode == GridSamplerInterpolation::Nearest)
+                    {
+                        int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
+                        int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
+
+                        // assign nearest neighbor pixel value to output pixel
+                        float* out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
+                        const float* inp_ptr_NC = inp_ptr_N;
+                        for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC)
+                        {
+                            if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W))
+                            {
+                                *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW];
+                            }
+                            else
+                            {
+                                *out_ptr_NCHW = static_cast<float>(0);
+                            }
+                        }
+                    }
+                    else if (interpolation_mode == GridSamplerInterpolation::Bicubic)
+                    {
+                        // grid_sampler_compute_source_index will "clip the value" of idx
+                        // depends on the padding,
+                        // which would cause calculation to be wrong,
+                        // for example x = -0.1 -> ix = 0 for zero padding, but in bicubic ix
+                        // = floor(x) = -1
+                        // There would be more problem in reflection padding, since the -1 and
+                        // +1 direction is not fixed in boundary condition
+                        ix = grid_sampler_unnormalize(x, inp_W, align_corners);
+                        iy = grid_sampler_unnormalize(y, inp_H, align_corners);
+
+                        float ix_nw = std::floor(ix);
+                        float iy_nw = std::floor(iy);
+
+                        const float tx = ix - ix_nw;
+                        const float ty = iy - iy_nw;
+
+                        const float* inp_ptr_NC = inp_ptr_N;
+                        float* out_ptr_NCHW = out_ptr + n * out_sN + h * out_sH + w * out_sW;
+                        for (int64_t c = 0; c < C; ++c, out_ptr_NCHW += out_sC, inp_ptr_NC += inp_sC)
+                        {
+                            float coefficients[4];
+
+                            // Interpolate 4 values in the x direction
+                            for (int64_t i = 0; i < 4; ++i)
+                            {
+                                coefficients[i] = cubic_interp1d(
+                                    get_value_bounded<float>(inp_ptr_NC, ix_nw - 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
+                                    get_value_bounded<float>(inp_ptr_NC, ix_nw + 0, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
+                                    get_value_bounded<float>(inp_ptr_NC, ix_nw + 1, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
+                                    get_value_bounded<float>(inp_ptr_NC, ix_nw + 2, iy_nw - 1 + i, inp_W, inp_H, inp_sW, inp_sH, padding_mode, align_corners),
+                                    tx);
+                            }
+
+                            // Interpolate in the y direction
+                            *out_ptr_NCHW = cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], coefficients[3], ty);
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    REGISTER_ONNXRUNTIME_OPS(mmdeploy, GridSampleOp);

 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.h b/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.h
index 2581b7833e..e6c9fa280f 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.h
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/grid_sample/grid_sample.h
@@ -4,41 +4,59 @@

 #include <onnxruntime_cxx_api.h>

-namespace mmdeploy {
-
-struct GridSampleKernel {
-  GridSampleKernel(const OrtApi &api, const OrtKernelInfo *info);
-
-  void Compute(OrtKernelContext *context);
-
- protected:
-  Ort::CustomOpApi ort_;
-  const OrtKernelInfo *info_;
-  Ort::AllocatorWithDefaultOptions allocator_;
-
-  int64_t align_corners_;
-  int64_t interpolation_mode_;
-  int64_t padding_mode_;
-};
-
-struct GridSampleOp : Ort::CustomOpBase<GridSampleOp, GridSampleKernel> {
-  void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
-    return new GridSampleKernel(api, info);
-  };
-
-  const char *GetName() const { return "grid_sampler"; };
-
-  size_t GetInputTypeCount() const { return 2; };
-  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  };
-
-  size_t GetOutputTypeCount() const { return 1; };
-  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  };
-
-  const char *GetExecutionProviderType() const { return "CPUExecutionProvider"; };
-};
+namespace mmdeploy
+{
+
+    struct GridSampleKernel
+    {
+        GridSampleKernel(const OrtApi& api, const OrtKernelInfo* info);
+
+        void Compute(OrtKernelContext* context);
+
+      protected:
+        Ort::CustomOpApi ort_;
+        const OrtKernelInfo* info_;
+        Ort::AllocatorWithDefaultOptions allocator_;
+
+        int64_t align_corners_;
+        int64_t interpolation_mode_;
+        int64_t padding_mode_;
+    };
+
+    struct GridSampleOp : Ort::CustomOpBase<GridSampleOp, GridSampleKernel>
+    {
+        void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const
+        {
+            return new GridSampleKernel(api, info);
+        };
+
+        const char* GetName() const
+        {
+            return "grid_sampler";
+        };
+
+        size_t GetInputTypeCount() const
+        {
+            return 2;
+        };
+        ONNXTensorElementDataType GetInputType(size_t /*index*/) const
+        {
+            return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+        };
+
+        size_t GetOutputTypeCount() const
+        {
+            return 1;
+        };
+        ONNXTensorElementDataType GetOutputType(size_t /*index*/) const
+        {
+            return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+        };
+
+        const char* GetExecutionProviderType() const
+        {
+            return "CPUExecutionProvider";
+        };
+    };

 }  // namespace mmdeploy

 #endif
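[editor's note] A numeric sketch of the two coordinate steps the grid-sample kernel above performs per output pixel: unnormalize a grid value from [-1, 1] into pixel space, then split the result into the two bilinear weights along one axis. Standalone and illustrative only; the constants are made up for the example:

#include <cmath>
#include <cstdio>

int main()
{
    const bool align_corners = false;
    const long size = 8;    // input width
    const float x = 0.25f;  // normalized grid coordinate in [-1, 1]

    // same formula as grid_sampler_unnormalize
    float ix = align_corners ? ((x + 1) / 2) * (size - 1)
                             : ((x + 1) * size - 1) / 2;  // -> 4.5

    long ix_nw = static_cast<long>(std::floor(ix));
    float tx = ix - ix_nw;  // fractional part, 0.5 here
    std::printf("ix=%.2f  west weight=%.2f  east weight=%.2f\n", ix, 1 - tx, tx);
    return 0;
}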
diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp
index 075c3277bc..320fa8dd45 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.cpp
@@ -8,191 +8,218 @@
 #include "modulated_deform_conv/modulated_deform_conv_cpu.h"
 #include "ort_utils.h"

-namespace mmdeploy {
-
-void parallel_unroll_gemm(const float *A, const float *B, const float *V, const float *H,
-                          const int32_t M, const int32_t N, const int32_t K, const float alpha,
-                          const float beta, float *Y, const int32_t start_row,
-                          const int32_t end_row) {
-  std::vector<float> tmp(N);
-  for (int32_t m = start_row; m < end_row; ++m) {
-    for (int32_t n = 0; n < N; n++) {
-      tmp[n] = 0;
-    }
-    {
-      int32_t remainder = K % 8;  // unroll
-      for (int32_t k = 0; k < K; k += 8) {
-        for (int32_t n = 0; n < N; n++) {
-          tmp[n] += A[m * K + k] * B[k * N + n];
-          tmp[n] += A[m * K + k + 1] * B[k * N + N + n];
-          tmp[n] += A[m * K + k + 2] * B[k * N + 2 * N + n];
-          tmp[n] += A[m * K + k + 3] * B[k * N + 3 * N + n];
-          tmp[n] += A[m * K + k + 4] * B[k * N + 4 * N + n];
-          tmp[n] += A[m * K + k + 5] * B[k * N + 5 * N + n];
-          tmp[n] += A[m * K + k + 6] * B[k * N + 6 * N + n];
-          tmp[n] += A[m * K + k + 7] * B[k * N + 7 * N + n];
-        }
-      }
-      for (int32_t k = K - remainder; k < K; k++) {
-        for (int32_t n = 0; n < N; n++) {
-          tmp[n] += A[m * K + k] * B[k * N + n];
-        }
-      }
-    }
-    for (int32_t n = 0; n < N; n++) {
-      tmp[n] *= alpha;
-      if (V) tmp[n] += beta * V[n];
-      if (H) tmp[n] += beta * H[m * N + n];
-      Y[m * N + n] = tmp[n];
-    }
-  }
-}
-
-void deformable_conv2d_ref_fp32(const float *src, const float *offset, const float *mask,
-                                const float *filter, const float *bias, const int64_t batch,
-                                const int64_t src_c, const int64_t src_h, const int64_t src_w,
-                                const int64_t dst_c, const int64_t dst_h, const int64_t dst_w,
-                                const int64_t group, const int64_t offset_group,
-                                const int64_t channels, const int64_t num_output,
-                                const int64_t kernel_h, const int64_t kernel_w,
-                                const int64_t stride_h, const int64_t stride_w, const int64_t pad_h,
-                                const int64_t pad_w, const int64_t dilation_h,
-                                const int64_t dilation_w, float *columns, float *dst) {
-  const int64_t ic_per_gp = channels / group;
-  const int64_t oc_per_gp = num_output / group;
-  // Set up for launching threads
-  std::size_t num_threads = std::thread::hardware_concurrency();
-  std::vector<std::thread> threads;
-  threads.reserve(num_threads);
-
-  for (int64_t b = 0; b < batch; ++b) {
-    for (int64_t g = 0; g < group; ++g) {
-      deformable_im2col_2d(
-          src + b * src_c * src_h * src_w + g * ic_per_gp * src_h * src_w,
-          offset + b * offset_group * 2 * kernel_h * kernel_w * dst_h * dst_w,
-          mask + b * offset_group * kernel_h * kernel_w * dst_h * dst_w, src_h, src_w, kernel_h,
-          kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, ic_per_gp,
-          offset_group, dst_h, dst_w, mask != nullptr, columns);
-      float *dst_ptr = dst + b * dst_c * dst_h * dst_w + g * oc_per_gp * dst_h * dst_w;
-      if (bias != nullptr) {
-        const float *bias_ptr = bias + g * oc_per_gp;
-        for (int64_t oc = 0; oc < oc_per_gp; ++oc) {
-          for (int64_t hw = 0; hw < dst_h * dst_w; ++hw) {
-            dst_ptr[oc * dst_h * dst_w + hw] = bias_ptr[oc];
-          }
-        }
-      } else {
-        memset(dst_ptr, 0.0f, sizeof(float) * oc_per_gp * dst_h * dst_w);
-      }
-      if (num_threads > 1) {
-        // Calculate values to pass to threads
-        int32_t n_rows = (oc_per_gp + num_threads - 1) / num_threads;
-        int32_t end_row = 0;
-        for (int32_t i = 0; i < num_threads; i++) {
-          auto start_row = i * n_rows;
-          end_row = start_row + n_rows;
-          if (end_row > oc_per_gp) end_row = oc_per_gp;
-          std::thread t(parallel_unroll_gemm,
-                        filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns, nullptr,
-                        dst_ptr, oc_per_gp, dst_h * dst_w, ic_per_gp * kernel_h * kernel_w, 1.0f,
-                        1.0f, dst_ptr, start_row, end_row);
-          threads.emplace_back(std::move(t));
-        }
-        // Wait for all threads to complete
-        for (auto &t : threads) t.join();
-        threads.clear();
-      } else {  // parallel gemm degrade to serial gemm with start_row=0 and end_row= oc_per_gp
-        parallel_unroll_gemm(filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns,
-                             nullptr, dst_ptr, oc_per_gp, dst_h * dst_w,
-                             ic_per_gp * kernel_h * kernel_w, 1.0f, 1.0f, dst_ptr, 0, oc_per_gp);
-      }
-    }
-  }
-}
-
-MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(const OrtApi &api,
-                                                             const OrtKernelInfo *info)
-    : ort_(api), info_(info) {
-  std::vector<int64_t> stride = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
-  stride_height_ = stride[0];
-  stride_width_ = stride[1];
-  std::vector<int64_t> padding = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "padding");
-  padding_height_ = padding[0];
-  padding_width_ = padding[1];
-  std::vector<int64_t> dilation =
-      ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "dilation");
-  dilation_height_ = dilation[0];
-  dilation_width_ = dilation[1];
-  deformable_group_ = ort_.KernelInfoGetAttribute<int64_t>(info, "deform_groups");
-  group_ = ort_.KernelInfoGetAttribute<int64_t>(info, "groups");
-
-  // create allocator
-  allocator_ = Ort::AllocatorWithDefaultOptions();
-}
-
-void MMCVModulatedDeformConvKernel::Compute(OrtKernelContext *context) {
-  const int64_t stride_height = stride_height_;
-  const int64_t stride_width = stride_width_;
-  const int64_t padding_height = padding_height_;
-  const int64_t padding_width = padding_width_;
-  const int64_t dilation_height = dilation_height_;
-  const int64_t dilation_width = dilation_width_;
-  const int64_t deformable_group = deformable_group_;
-  const int64_t group = group_;
-
-  const OrtValue *input = ort_.KernelContext_GetInput(context, 0);
-  const float *input_data = reinterpret_cast<const float *>(ort_.GetTensorData<float>(input));
-
-  const OrtValue *offset = ort_.KernelContext_GetInput(context, 1);
-  const float *offset_data = reinterpret_cast<const float *>(ort_.GetTensorData<float>(offset));
-
-  const OrtValue *mask = ort_.KernelContext_GetInput(context, 2);
-  const float *mask_data = reinterpret_cast<const float *>(ort_.GetTensorData<float>(mask));
-
-  const OrtValue *filter = ort_.KernelContext_GetInput(context, 3);
-  const float *filter_data = reinterpret_cast<const float *>(ort_.GetTensorData<float>(filter));
-
-  const OrtValue *bias = ort_.KernelContext_GetInput(context, 4);
-  const float *bias_data = (bias != nullptr)
-                               ? reinterpret_cast<const float *>(ort_.GetTensorData<float>(bias))
-                               : nullptr;
-  // const float *bias_data = nullptr;
-
-  OrtTensorDimensions input_dims(ort_, input);
-  OrtTensorDimensions filter_dims(ort_, filter);
-
-  int64_t batch = input_dims[0];
-  int64_t channels = input_dims[1];
-  int64_t in_height = input_dims[2];
-  int64_t in_width = input_dims[3];
-  int64_t num_output = filter_dims[0];
-  int64_t kernel_height = filter_dims[2];
-  int64_t kernel_width = filter_dims[3];
-
-  // get output memory
-  int64_t out_height = floor(
-      (in_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) / stride_height +
-      1);
-  int64_t out_width = floor(
-      (in_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) / stride_width + 1);
-
-  std::vector<int64_t> output_dims = {batch, num_output, out_height, out_width};
-  OrtValue *output =
-      ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size());
-  float *out_ptr = ort_.GetTensorMutableData<float>(output);
-
-  // allocate tmp memory
-  int64_t column_len = (channels / group) * kernel_height * kernel_width * out_height * out_width;
-  float *columns = (float *)allocator_.Alloc(sizeof(float) * column_len);
-
-  deformable_conv2d_ref_fp32(input_data, offset_data, mask_data, filter_data, bias_data, batch,
-                             channels, in_height, in_width, num_output, out_height, out_width,
-                             group, deformable_group, channels, num_output, kernel_height,
-                             kernel_width, stride_height, stride_width, padding_height,
-                             padding_width, dilation_height, dilation_width, columns, out_ptr);
-
-  allocator_.Free(columns);
-}
-REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVModulatedDeformConvOp);
-REGISTER_ONNXRUNTIME_OPS(mmcv, MMCVModulatedDeformConvOp);
+namespace mmdeploy
+{
+
+    void parallel_unroll_gemm(const float* A, const float* B, const float* V, const float* H, const int32_t M, const int32_t N, const int32_t K, const float alpha, const float beta, float* Y, const int32_t start_row, const int32_t end_row)
+    {
+        std::vector<float> tmp(N);
+        for (int32_t m = start_row; m < end_row; ++m)
+        {
+            for (int32_t n = 0; n < N; n++)
+            {
+                tmp[n] = 0;
+            }
+            {
+                int32_t remainder = K % 8;  // unroll
+                for (int32_t k = 0; k < K; k += 8)
+                {
+                    for (int32_t n = 0; n < N; n++)
+                    {
+                        tmp[n] += A[m * K + k] * B[k * N + n];
+                        tmp[n] += A[m * K + k + 1] * B[k * N + N + n];
+                        tmp[n] += A[m * K + k + 2] * B[k * N + 2 * N + n];
+                        tmp[n] += A[m * K + k + 3] * B[k * N + 3 * N + n];
+                        tmp[n] += A[m * K + k + 4] * B[k * N + 4 * N + n];
+                        tmp[n] += A[m * K + k + 5] * B[k * N + 5 * N + n];
+                        tmp[n] += A[m * K + k + 6] * B[k * N + 6 * N + n];
+                        tmp[n] += A[m * K + k + 7] * B[k * N + 7 * N + n];
+                    }
+                }
+                for (int32_t k = K - remainder; k < K; k++)
+                {
+                    for (int32_t n = 0; n < N; n++)
+                    {
+                        tmp[n] += A[m * K + k] * B[k * N + n];
+                    }
+                }
+            }
+            for (int32_t n = 0; n < N; n++)
+            {
+                tmp[n] *= alpha;
+                if (V) tmp[n] += beta * V[n];
+                if (H) tmp[n] += beta * H[m * N + n];
+                Y[m * N + n] = tmp[n];
+            }
+        }
+    }
+
+    void deformable_conv2d_ref_fp32(const float* src, const float* offset, const float* mask, const float* filter, const float* bias, const int64_t batch, const int64_t src_c, const int64_t src_h, const int64_t src_w, const int64_t dst_c, const int64_t dst_h, const int64_t dst_w, const int64_t group, const int64_t offset_group, const int64_t channels, const int64_t num_output, const int64_t kernel_h, const int64_t kernel_w, const int64_t stride_h, const int64_t stride_w, const int64_t pad_h, const int64_t pad_w, const int64_t dilation_h, const int64_t dilation_w, float* columns, float* dst)
+    {
+        const int64_t ic_per_gp = channels / group;
+        const int64_t oc_per_gp = num_output / group;
+        // Set up for launching threads
+        std::size_t num_threads = std::thread::hardware_concurrency();
+        std::vector<std::thread> threads;
+        threads.reserve(num_threads);
+
+        for (int64_t b = 0; b < batch; ++b)
+        {
+            for (int64_t g = 0; g < group; ++g)
+            {
+                deformable_im2col_2d(
+                    src + b * src_c * src_h * src_w + g * ic_per_gp * src_h * src_w,
+                    offset + b * offset_group * 2 * kernel_h * kernel_w * dst_h * dst_w,
+                    mask + b * offset_group * kernel_h * kernel_w * dst_h * dst_w,
+                    src_h,
+                    src_w,
+                    kernel_h,
+                    kernel_w,
+                    pad_h,
+                    pad_w,
+                    stride_h,
+                    stride_w,
+                    dilation_h,
+                    dilation_w,
+                    ic_per_gp,
+                    offset_group,
+                    dst_h,
+                    dst_w,
+                    mask != nullptr,
+                    columns);
+                float* dst_ptr = dst + b * dst_c * dst_h * dst_w + g * oc_per_gp * dst_h * dst_w;
+                if (bias != nullptr)
+                {
+                    const float* bias_ptr = bias + g * oc_per_gp;
+                    for (int64_t oc = 0; oc < oc_per_gp; ++oc)
+                    {
+                        for (int64_t hw = 0; hw < dst_h * dst_w; ++hw)
+                        {
+                            dst_ptr[oc * dst_h * dst_w + hw] = bias_ptr[oc];
+                        }
+                    }
+                }
+                else
+                {
+                    memset(dst_ptr, 0.0f, sizeof(float) * oc_per_gp * dst_h * dst_w);
+                }
+                if (num_threads > 1)
+                {
+                    // Calculate values to pass to threads
+                    int32_t n_rows = (oc_per_gp + num_threads - 1) / num_threads;
+                    int32_t end_row = 0;
+                    for (int32_t i = 0; i < num_threads; i++)
+                    {
+                        auto start_row = i * n_rows;
+                        end_row = start_row + n_rows;
+                        if (end_row > oc_per_gp) end_row = oc_per_gp;
+                        std::thread t(parallel_unroll_gemm,
+                                      filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w,
+                                      columns,
+                                      nullptr,
+                                      dst_ptr,
+                                      oc_per_gp,
+                                      dst_h * dst_w,
+                                      ic_per_gp * kernel_h * kernel_w,
+                                      1.0f,
+                                      1.0f,
+                                      dst_ptr,
+                                      start_row,
+                                      end_row);
+                        threads.emplace_back(std::move(t));
+                    }
+                    // Wait for all threads to complete
+                    for (auto& t : threads) t.join();
+                    threads.clear();
+                }
+                else
+                {  // parallel gemm degrade to serial gemm with start_row=0 and end_row= oc_per_gp
+                    parallel_unroll_gemm(filter + g * oc_per_gp * ic_per_gp * kernel_h * kernel_w, columns, nullptr, dst_ptr, oc_per_gp, dst_h * dst_w, ic_per_gp * kernel_h * kernel_w, 1.0f, 1.0f, dst_ptr, 0, oc_per_gp);
+                }
+            }
+        }
+    }
+
+    MMCVModulatedDeformConvKernel::MMCVModulatedDeformConvKernel(const OrtApi& api,
+                                                                 const OrtKernelInfo* info)
+        : ort_(api)
+        , info_(info)
+    {
+        std::vector<int64_t> stride = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "stride");
+        stride_height_ = stride[0];
+        stride_width_ = stride[1];
+        std::vector<int64_t> padding = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "padding");
+        padding_height_ = padding[0];
+        padding_width_ = padding[1];
+        std::vector<int64_t> dilation =
+            ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "dilation");
+        dilation_height_ = dilation[0];
+        dilation_width_ = dilation[1];
+        deformable_group_ = ort_.KernelInfoGetAttribute<int64_t>(info, "deform_groups");
+        group_ = ort_.KernelInfoGetAttribute<int64_t>(info, "groups");
+
+        // create allocator
+        allocator_ = Ort::AllocatorWithDefaultOptions();
+    }
+
+    void MMCVModulatedDeformConvKernel::Compute(OrtKernelContext* context)
+    {
+        const int64_t stride_height = stride_height_;
+        const int64_t stride_width = stride_width_;
+        const int64_t padding_height = padding_height_;
+        const int64_t padding_width = padding_width_;
+        const int64_t dilation_height = dilation_height_;
+        const int64_t dilation_width = dilation_width_;
+        const int64_t deformable_group = deformable_group_;
+        const int64_t group = group_;
+
+        const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
+        const float* input_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(input));
+
+        const OrtValue* offset = ort_.KernelContext_GetInput(context, 1);
+        const float* offset_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(offset));
+
+        const OrtValue* mask = ort_.KernelContext_GetInput(context, 2);
+        const float* mask_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(mask));
+
+        const OrtValue* filter = ort_.KernelContext_GetInput(context, 3);
+        const float* filter_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(filter));
+
+        const OrtValue* bias = ort_.KernelContext_GetInput(context, 4);
+        const float* bias_data = (bias != nullptr) ? reinterpret_cast<const float*>(ort_.GetTensorData<float>(bias)) : nullptr;
+        // const float *bias_data = nullptr;
+
+        OrtTensorDimensions input_dims(ort_, input);
+        OrtTensorDimensions filter_dims(ort_, filter);
+
+        int64_t batch = input_dims[0];
+        int64_t channels = input_dims[1];
+        int64_t in_height = input_dims[2];
+        int64_t in_width = input_dims[3];
+        int64_t num_output = filter_dims[0];
+        int64_t kernel_height = filter_dims[2];
+        int64_t kernel_width = filter_dims[3];
+
+        // get output memory
+        int64_t out_height = floor(
+            (in_height + 2 * padding_height - dilation_height * (kernel_height - 1) - 1) / stride_height +
+            1);
+        int64_t out_width = floor(
+            (in_width + 2 * padding_width - dilation_width * (kernel_width - 1) - 1) / stride_width + 1);
+
+        std::vector<int64_t> output_dims = {batch, num_output, out_height, out_width};
+        OrtValue* output =
+            ort_.KernelContext_GetOutput(context, 0, output_dims.data(), output_dims.size());
+        float* out_ptr = ort_.GetTensorMutableData<float>(output);
+
+        // allocate tmp memory
+        int64_t column_len = (channels / group) * kernel_height * kernel_width * out_height * out_width;
+        float* columns = (float*)allocator_.Alloc(sizeof(float) * column_len);
+
+        deformable_conv2d_ref_fp32(input_data, offset_data, mask_data, filter_data, bias_data, batch, channels, in_height, in_width, num_output, out_height, out_width, group, deformable_group, channels, num_output, kernel_height, kernel_width, stride_height, stride_width, padding_height, padding_width, dilation_height, dilation_width, columns, out_ptr);

+        allocator_.Free(columns);
+    }
+    REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVModulatedDeformConvOp);
+    REGISTER_ONNXRUNTIME_OPS(mmcv, MMCVModulatedDeformConvOp);

 }  // namespace mmdeploy
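[editor's note] The deformable-conv reference above partitions the GEMM by output rows: M rows are split into ceil(M/T) chunks, each thread gets a [start_row, end_row) slice, and the main thread joins them all. A minimal standalone sketch of that scheme (illustrative only, not part of the patch; worker is a stand-in for parallel_unroll_gemm):

#include <algorithm>
#include <cstdio>
#include <thread>
#include <vector>

static void worker(int start_row, int end_row)
{
    std::printf("rows [%d, %d)\n", start_row, end_row);  // placeholder for the per-slice GEMM
}

int main()
{
    const int M = 10;  // e.g. output channels per group
    const int T = std::max(1u, std::thread::hardware_concurrency());
    const int n_rows = (M + T - 1) / T;  // rows per thread, rounded up

    std::vector<std::thread> threads;
    for (int i = 0; i < T; i++)
    {
        int start_row = i * n_rows;
        int end_row = std::min(start_row + n_rows, M);
        if (start_row >= M) break;  // more threads than rows
        threads.emplace_back(worker, start_row, end_row);
    }
    for (auto& t : threads) t.join();
    return 0;
}

Slicing by rows keeps each thread writing a disjoint region of Y, so no locking is needed; the serial fallback in the patch is just the same call with start_row=0 and end_row=M.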
diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.h b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.h
index 772a9c4a88..7ffeb702d3 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.h
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/modulated_deform_conv/modulated_deform_conv.h
@@ -4,55 +4,74 @@

 #include <onnxruntime_cxx_api.h>

-namespace mmdeploy {
-
-struct MMCVModulatedDeformConvKernel {
-  MMCVModulatedDeformConvKernel(const OrtApi &api, const OrtKernelInfo *info);
-
-  void Compute(OrtKernelContext *context);
-
- protected:
-  Ort::CustomOpApi ort_;
-  const OrtKernelInfo *info_;
-  Ort::AllocatorWithDefaultOptions allocator_;
-
-  int64_t stride_height_;
-  int64_t stride_width_;
-  int64_t padding_height_;
-  int64_t padding_width_;
-  int64_t dilation_height_;
-  int64_t dilation_width_;
-  int64_t deformable_group_;
-  int64_t group_;
-};
-
-struct MMCVModulatedDeformConvOp
-    : Ort::CustomOpBase<MMCVModulatedDeformConvOp, MMCVModulatedDeformConvKernel> {
-  void *CreateKernel(const OrtApi &api, const OrtKernelInfo *info) const {
-    return new MMCVModulatedDeformConvKernel(api, info);
-  }
-
-  const char *GetName() const { return "MMCVModulatedDeformConv2d"; };
-
-  size_t GetInputTypeCount() const { return 5; };
-  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  };
-
-  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const {
-    // The last input (index == 4) is optional, which is bias
-    if (index == 4) return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
-
-    return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
-  }
-
-  size_t GetOutputTypeCount() const { return 1; };
-  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  };
-
-  // force cpu
-  const char *GetExecutionProviderType() const { return "CPUExecutionProvider"; };
-};
+namespace mmdeploy
+{
+
+    struct MMCVModulatedDeformConvKernel
+    {
+        MMCVModulatedDeformConvKernel(const OrtApi& api, const OrtKernelInfo* info);
+
+        void Compute(OrtKernelContext* context);
+
+      protected:
+        Ort::CustomOpApi ort_;
+        const OrtKernelInfo* info_;
+        Ort::AllocatorWithDefaultOptions allocator_;
+
+        int64_t stride_height_;
+        int64_t stride_width_;
+        int64_t padding_height_;
+        int64_t padding_width_;
+        int64_t dilation_height_;
+        int64_t dilation_width_;
+        int64_t deformable_group_;
+        int64_t group_;
+    };
+
+    struct MMCVModulatedDeformConvOp
+        : Ort::CustomOpBase<MMCVModulatedDeformConvOp, MMCVModulatedDeformConvKernel>
+    {
+        void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const
+        {
+            return new MMCVModulatedDeformConvKernel(api, info);
+        }
+
+        const char* GetName() const
+        {
+            return "MMCVModulatedDeformConv2d";
+        };
+
+        size_t GetInputTypeCount() const
+        {
+            return 5;
+        };
+        ONNXTensorElementDataType GetInputType(size_t /*index*/) const
+        {
+            return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+        };
+
+        OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t index) const
+        {
+            // The last input (index == 4) is optional, which is bias
+            if (index == 4) return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL;
+
+            return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
+        }
+
+        size_t GetOutputTypeCount() const
+        {
+            return 1;
+        };
+        ONNXTensorElementDataType GetOutputType(size_t /*index*/) const
+        {
+            return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+        };
+
+        // force cpu
+        const char* GetExecutionProviderType() const
+        {
+            return "CPUExecutionProvider";
+        };
+    };

 }  // namespace mmdeploy

 #endif
diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp
index 784be2c987..397bcbf92c 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.cpp
@@ -13,117 +13,132 @@

 #include "ort_utils.h"

-namespace mmdeploy {
-struct Box {
-  float x1, y1, x2, y2;
-};
-
-float nms_match_iou(Box box1, Box box2) {
-  auto inter_x1 = std::max(box1.x1, box2.x1);
-  auto inter_y1 = std::max(box1.y1, box2.y1);
-  auto inter_x2 = std::min(box1.x2, box2.x2);
-  auto inter_y2 = std::min(box1.y2, box2.y2);
-
-  auto eps = 1e-10;
-
-  auto w = std::max(static_cast<float>(0), inter_x2 - inter_x1);
-  auto h = std::max(static_cast<float>(0), inter_y2 - inter_y1);
-
-  auto area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);
-  auto area2 = (box2.x2 - box2.x1) * (box2.y2 - box2.y1);
-  auto inter = w * h;
-  auto ovr = inter / (area1 + area2 - inter + eps);
-  return ovr;
-}
-NMSMatchKernel::NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info)
-    : ort_(api), info_(info) {
-  // create allocator
-  allocator_ = Ort::AllocatorWithDefaultOptions();
-}
-
-void NMSMatchKernel::Compute(OrtKernelContext* context) {
-  const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0);
-  const float* boxes_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(boxes));
-  const OrtValue* scores = ort_.KernelContext_GetInput(context, 1);
-  const float* scores_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(scores));
-  const OrtValue* iou_threshold_ = ort_.KernelContext_GetInput(context, 2);
-  const float iou_threshold_data = ort_.GetTensorData<float>(iou_threshold_)[0];
-  const OrtValue* score_threshold_ = ort_.KernelContext_GetInput(context, 3);
-  const float score_threshold_data = ort_.GetTensorData<float>(score_threshold_)[0];
-
-  OrtTensorDimensions boxes_dim(ort_, boxes);
-  OrtTensorDimensions scores_dim(ort_, scores);
-  // loop over batch
-  int64_t nbatch = boxes_dim[0];
-  int64_t nboxes = boxes_dim[1];
-  int64_t nclass = scores_dim[1];
-  assert(boxes_dim[2] == 4);  //(x1, x2, y1, y2)
-  // alloc some temp memory
-  bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes);
-
-  std::vector<int64_t> res_order;
-  for (int64_t k = 0; k < nbatch; k++) {
-    for (int64_t g = 0; g < nclass; g++) {
-      for (int64_t i = 0; i < nboxes; i++) {
-        select[i] = true;
-      }
-      // scores
-      // k * nboxes * nclass means per batch
-      // g * nboxes means per class
-      // batch = 2 boxes = 3 classes = 4
-      std::vector<float> tmp_sc;
-      // get the class scores
-      for (int i = 0; i < nboxes; i++) {
-        tmp_sc.push_back(scores_data[k * nboxes * nclass + g * nboxes + i]);
-      }
-
-      std::vector<int64_t> order(tmp_sc.size());
-      std::iota(order.begin(), order.end(), 0);
-      std::sort(order.begin(), order.end(),
-                [&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; });
-      for (int64_t _i = 0; _i < nboxes; _i++) {
-        auto i = order[_i];
-        if (select[i] == false) continue;
-        std::vector<int64_t> v_i;
-        for (int64_t _j = _i + 1; _j < nboxes; _j++) {
-          auto j = order[_j];
-          if (select[j] == false) continue;
-          Box vbox1, vbox2;
-          vbox1.x1 = boxes_data[k * nboxes * 4 + i * 4];
-          vbox1.y1 = boxes_data[k * nboxes * 4 + i * 4 + 1];
-          vbox1.x2 = boxes_data[k * nboxes * 4 + i * 4 + 2];
-          vbox1.y2 = boxes_data[k * nboxes * 4 + i * 4 + 3];
-
-          vbox2.x1 = boxes_data[k * nboxes * 4 + j * 4];
-          vbox2.y1 = boxes_data[k * nboxes * 4 + j * 4 + 1];
-          vbox2.x2 = boxes_data[k * nboxes * 4 + j * 4 + 2];
-          vbox2.y2 = boxes_data[k * nboxes * 4 + j * 4 + 3];
-
-          auto ovr = nms_match_iou(vbox1, vbox2);
-          if (ovr >= iou_threshold_data) {
-            select[j] = false;
-            v_i.push_back(j);
-          }
-        }
-        if (tmp_sc[i] > score_threshold_data && v_i.size() != 0) {
-          for (int v_i_idx = 0; v_i_idx < v_i.size(); v_i_idx++) {
-            res_order.push_back(k);
-            res_order.push_back(g);
-            res_order.push_back(i);
-            res_order.push_back(v_i[v_i_idx]);
-          }
-        }
-      }
-    }
-  }
-  std::vector<int64_t> inds_dims({(int64_t)res_order.size() / 4, 4});
-
-  OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size());
-  int64_t* res_data = ort_.GetTensorMutableData<int64_t>(res);
-
-  memcpy(res_data, res_order.data(), sizeof(int64_t) * res_order.size());
-
-  allocator_.Free(select);
-}
-REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSMatchOp);
+namespace mmdeploy
+{
+    struct Box
+    {
+        float x1, y1, x2, y2;
+    };
+
+    float nms_match_iou(Box box1, Box box2)
+    {
+        auto inter_x1 = std::max(box1.x1, box2.x1);
+        auto inter_y1 = std::max(box1.y1, box2.y1);
+        auto inter_x2 = std::min(box1.x2, box2.x2);
+        auto inter_y2 = std::min(box1.y2, box2.y2);
+
+        auto eps = 1e-10;
+
+        auto w = std::max(static_cast<float>(0), inter_x2 - inter_x1);
+        auto h = std::max(static_cast<float>(0), inter_y2 - inter_y1);
+
+        auto area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1);
+        auto area2 = (box2.x2 - box2.x1) * (box2.y2 - box2.y1);
+        auto inter = w * h;
+        auto ovr = inter / (area1 + area2 - inter + eps);
+        return ovr;
+    }
+    NMSMatchKernel::NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info)
+        : ort_(api)
+        , info_(info)
+    {
+        // create allocator
+        allocator_ = Ort::AllocatorWithDefaultOptions();
+    }
+
+    void NMSMatchKernel::Compute(OrtKernelContext* context)
+    {
+        const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0);
+        const float* boxes_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(boxes));
+        const OrtValue* scores = ort_.KernelContext_GetInput(context, 1);
+        const float* scores_data = reinterpret_cast<const float*>(ort_.GetTensorData<float>(scores));
+        const OrtValue* iou_threshold_ = ort_.KernelContext_GetInput(context, 2);
+        const float iou_threshold_data = ort_.GetTensorData<float>(iou_threshold_)[0];
+        const OrtValue* score_threshold_ = ort_.KernelContext_GetInput(context, 3);
+        const float score_threshold_data = ort_.GetTensorData<float>(score_threshold_)[0];
+
+        OrtTensorDimensions boxes_dim(ort_, boxes);
+        OrtTensorDimensions scores_dim(ort_, scores);
+        // loop over batch
+        int64_t nbatch = boxes_dim[0];
+        int64_t nboxes = boxes_dim[1];
+        int64_t nclass = scores_dim[1];
+        assert(boxes_dim[2] == 4);  //(x1, x2, y1, y2)
+        // alloc some temp memory
+        bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes);
+
+        std::vector<int64_t> res_order;
+        for (int64_t k = 0; k < nbatch; k++)
+        {
+            for (int64_t g = 0; g < nclass; g++)
+            {
+                for (int64_t i = 0; i < nboxes; i++)
+                {
+                    select[i] = true;
+                }
+                // scores
+                // k * nboxes * nclass means per batch
+                // g * nboxes means per class
+                // batch = 2 boxes = 3 classes = 4
+                std::vector<float> tmp_sc;
+                // get the class scores
+                for (int i = 0; i < nboxes; i++)
+                {
+                    tmp_sc.push_back(scores_data[k * nboxes * nclass + g * nboxes + i]);
+                }
+
+                std::vector<int64_t> order(tmp_sc.size());
+                std::iota(order.begin(), order.end(), 0);
+                std::sort(order.begin(), order.end(), [&tmp_sc](int64_t id1, int64_t id2)
+                          { return tmp_sc[id1] > tmp_sc[id2]; });
+                for (int64_t _i = 0; _i < nboxes; _i++)
+                {
+                    auto i = order[_i];
+                    if (select[i] == false) continue;
+                    std::vector<int64_t> v_i;
+                    for (int64_t _j = _i + 1; _j < nboxes; _j++)
+                    {
+                        auto j = order[_j];
+                        if (select[j] == false) continue;
+                        Box vbox1, vbox2;
+                        vbox1.x1 = boxes_data[k * nboxes * 4 + i * 4];
+                        vbox1.y1 = boxes_data[k * nboxes * 4 + i * 4 + 1];
+                        vbox1.x2 = boxes_data[k * nboxes * 4 + i * 4 + 2];
+                        vbox1.y2 = boxes_data[k * nboxes * 4 + i * 4 + 3];
+
+                        vbox2.x1 = boxes_data[k * nboxes * 4 + j * 4];
+                        vbox2.y1 = boxes_data[k * nboxes * 4 + j * 4 + 1];
+                        vbox2.x2 = boxes_data[k * nboxes * 4 + j * 4 + 2];
+                        vbox2.y2 = boxes_data[k * nboxes * 4 + j * 4 + 3];
+
+                        auto ovr = nms_match_iou(vbox1, vbox2);
+                        if (ovr >= iou_threshold_data)
+                        {
+                            select[j] = false;
+                            v_i.push_back(j);
+                        }
+                    }
+                    if (tmp_sc[i] > score_threshold_data && v_i.size() != 0)
+                    {
+                        for (int v_i_idx = 0; v_i_idx < v_i.size(); v_i_idx++)
+                        {
+                            res_order.push_back(k);
+                            res_order.push_back(g);
+                            res_order.push_back(i);
+                            res_order.push_back(v_i[v_i_idx]);
+                        }
+                    }
+                }
+            }
+        }
+        std::vector<int64_t> inds_dims({(int64_t)res_order.size() / 4, 4});
+
+        OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size());
+        int64_t* res_data = ort_.GetTensorMutableData<int64_t>(res);

+        memcpy(res_data, res_order.data(), sizeof(int64_t) * res_order.size());

+        allocator_.Free(select);
+    }
+    REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSMatchOp);

 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h
index 57aa94d964..48e0d0dbb0 100644
--- a/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h
+++ b/csrc/mmdeploy/backend_ops/onnxruntime/nms_match/nms_match.h
@@ -10,37 +10,55 @@
 #include <vector>
 #include <onnxruntime_cxx_api.h>

-namespace mmdeploy {
-struct NMSMatchKernel {
-  NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info);
-
-  void Compute(OrtKernelContext* context);
-
- private:
-  Ort::CustomOpApi ort_;
-  const OrtKernelInfo* info_;
-  Ort::AllocatorWithDefaultOptions allocator_;
-};
-
-struct NMSMatchOp : Ort::CustomOpBase<NMSMatchOp, NMSMatchKernel> {
-  void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
-    return new NMSMatchKernel(api, info);
-  }
-  const char* GetName() const { return "NMSMatch"; }
-
-  size_t GetInputTypeCount() const { return 4; }
-  ONNXTensorElementDataType GetInputType(size_t) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
-  }
-
-  size_t GetOutputTypeCount() const { return 1; }
-  ONNXTensorElementDataType GetOutputType(size_t) const {
-    return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
-  }
-
-  // force cpu
-  const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; }
-};
+namespace mmdeploy
+{
+    struct NMSMatchKernel
+    {
+        NMSMatchKernel(const OrtApi& api, const OrtKernelInfo* info);
+
+        void Compute(OrtKernelContext* context);
+
+      private:
+        Ort::CustomOpApi ort_;
+        const OrtKernelInfo* info_;
+        Ort::AllocatorWithDefaultOptions allocator_;
+    };
+
+    struct NMSMatchOp : Ort::CustomOpBase<NMSMatchOp, NMSMatchKernel>
+    {
+        void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const
+        {
+            return new NMSMatchKernel(api, info);
+        }
+        const char* GetName() const
+        {
+            return "NMSMatch";
+        }
+
+        size_t GetInputTypeCount() const
+        {
+            return 4;
+        }
+        ONNXTensorElementDataType GetInputType(size_t) const
+        {
+            return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
+        }
+
+        size_t GetOutputTypeCount() const
+        {
+            return 1;
+        }
+        ONNXTensorElementDataType GetOutputType(size_t) const
+        {
+            return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
+        }
+
+        // force cpu
+        const char* GetExecutionProviderType() const
+        {
+            return "CPUExecutionProvider";
+        }
+    };

 }  // namespace mmdeploy

 #endif  // ONNXRUNTIME_NMS_MATCH_H
== 0.01745329251 - // double theta = box.a * 0.01745329251; - // MODIFIED - double theta = box.a; - float cosTheta2 = (float)cos(theta) * 0.5f; - float sinTheta2 = (float)sin(theta) * 0.5f; - - // y: top --> down; x: left --> right - pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w; - pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w; - pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w; - pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w; - pts[2].x = 2 * box.x_ctr - pts[0].x; - pts[2].y = 2 * box.y_ctr - pts[0].y; - pts[3].x = 2 * box.x_ctr - pts[1].x; - pts[3].y = 2 * box.y_ctr - pts[1].y; -} - -int get_intersection_points(const Point (&pts1)[4], const Point (&pts2)[4], - Point (&intersections)[24]) { - // Line vector - // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] - Point vec1[4], vec2[4]; - for (int i = 0; i < 4; i++) { - vec1[i] = pts1[(i + 1) % 4] - pts1[i]; - vec2[i] = pts2[(i + 1) % 4] - pts2[i]; - } - - // Line test - test all line combos for intersection - int num = 0; // number of intersections - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - // Solve for 2x2 Ax=b - float det = cross_2d(vec2[j], vec1[i]); - - // This takes care of parallel lines - if (fabs(det) <= 1e-14) { - continue; - } - - auto vec12 = pts2[j] - pts1[i]; - - float t1 = cross_2d(vec2[j], vec12) / det; - float t2 = cross_2d(vec1[i], vec12) / det; - - if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) { - intersections[num++] = pts1[i] + vec1[i] * t1; - } - } - } - - // Check for vertices of rect1 inside rect2 - { - const auto& AB = vec2[0]; - const auto& DA = vec2[3]; - auto ABdotAB = dot_2d(AB, AB); - auto ADdotAD = dot_2d(DA, DA); - for (int i = 0; i < 4; i++) { - // assume ABCD is the rectangle, and P is the point to be judged - // P is inside ABCD iff. P's projection on AB lies within AB - // and P's projection on AD lies within AD - - auto AP = pts1[i] - pts2[0]; - - auto APdotAB = dot_2d(AP, AB); - auto APdotAD = -dot_2d(AP, DA); - - if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) { - intersections[num++] = pts1[i]; - } - } - } - - // Reverse the check - check for vertices of rect2 inside rect1 - { - const auto& AB = vec1[0]; - const auto& DA = vec1[3]; - auto ABdotAB = dot_2d(AB, AB); - auto ADdotAD = dot_2d(DA, DA); - for (int i = 0; i < 4; i++) { - auto AP = pts2[i] - pts1[0]; - - auto APdotAB = dot_2d(AP, AB); - auto APdotAD = -dot_2d(AP, DA); - - if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) { - intersections[num++] = pts2[i]; - } +namespace mmdeploy +{ + + namespace + { + struct RotatedBox + { + float x_ctr, y_ctr, w, h, a; + }; + struct Point + { + float x, y; + Point(const float& px = 0, const float& py = 0) + : x(px) + , y(py) + { + } + Point operator+(const Point& p) const + { + return Point(x + p.x, y + p.y); + } + Point& operator+=(const Point& p) + { + x += p.x; + y += p.y; + return *this; + } + Point operator-(const Point& p) const + { + return Point(x - p.x, y - p.y); + } + Point operator*(const float coeff) const + { + return Point(x * coeff, y * coeff); + } + }; + + float dot_2d(const Point& A, const Point& B) + { + return A.x * B.x + A.y * B.y; + } + + float cross_2d(const Point& A, const Point& B) + { + return A.x * B.y - B.x * A.y; + } + } // namespace + + void get_rotated_vertices(const RotatedBox& box, Point (&pts)[4]) + { + // M_PI / 180. 
== 0.01745329251 + // double theta = box.a * 0.01745329251; + // MODIFIED + double theta = box.a; + float cosTheta2 = (float)cos(theta) * 0.5f; + float sinTheta2 = (float)sin(theta) * 0.5f; + + // y: top --> down; x: left --> right + pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w; + pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w; + pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w; + pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w; + pts[2].x = 2 * box.x_ctr - pts[0].x; + pts[2].y = 2 * box.y_ctr - pts[0].y; + pts[3].x = 2 * box.x_ctr - pts[1].x; + pts[3].y = 2 * box.y_ctr - pts[1].y; } - } - - return num; -} - -int convex_hull_graham(const Point (&p)[24], const int& num_in, Point (&q)[24], - bool shift_to_zero = false) { - assert(num_in >= 2); - - // Step 1: - // Find point with minimum y - // if more than 1 points have the same minimum y, - // pick the one with the minimum x. - int t = 0; - for (int i = 1; i < num_in; i++) { - if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) { - t = i; + + int get_intersection_points(const Point (&pts1)[4], const Point (&pts2)[4], Point (&intersections)[24]) + { + // Line vector + // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] + Point vec1[4], vec2[4]; + for (int i = 0; i < 4; i++) + { + vec1[i] = pts1[(i + 1) % 4] - pts1[i]; + vec2[i] = pts2[(i + 1) % 4] - pts2[i]; + } + + // Line test - test all line combos for intersection + int num = 0; // number of intersections + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + // Solve for 2x2 Ax=b + float det = cross_2d(vec2[j], vec1[i]); + + // This takes care of parallel lines + if (fabs(det) <= 1e-14) + { + continue; + } + + auto vec12 = pts2[j] - pts1[i]; + + float t1 = cross_2d(vec2[j], vec12) / det; + float t2 = cross_2d(vec1[i], vec12) / det; + + if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) + { + intersections[num++] = pts1[i] + vec1[i] * t1; + } + } + } + + // Check for vertices of rect1 inside rect2 + { + const auto& AB = vec2[0]; + const auto& DA = vec2[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) + { + // assume ABCD is the rectangle, and P is the point to be judged + // P is inside ABCD iff. 
P's projection on AB lies within AB + // and P's projection on AD lies within AD + + auto AP = pts1[i] - pts2[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) + { + intersections[num++] = pts1[i]; + } + } + } + + // Reverse the check - check for vertices of rect2 inside rect1 + { + const auto& AB = vec1[0]; + const auto& DA = vec1[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) + { + auto AP = pts2[i] - pts1[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) + { + intersections[num++] = pts2[i]; + } + } + } + + return num; } - } - auto& start = p[t]; // starting point - - // Step 2: - // Subtract starting point from every points (for sorting in the next step) - for (int i = 0; i < num_in; i++) { - q[i] = p[i] - start; - } - - // Swap the starting point to position 0 - auto tmp = q[0]; - q[0] = q[t]; - q[t] = tmp; - - // Step 3: - // Sort point 1 ~ num_in according to their relative cross-product values - // (essentially sorting according to angles) - // If the angles are the same, sort according to their distance to origin - float dist[24]; - for (int i = 0; i < num_in; i++) { - dist[i] = dot_2d(q[i], q[i]); - } - - // CPU version - std::sort(q + 1, q + num_in, [](const Point& A, const Point& B) -> bool { + + int convex_hull_graham(const Point (&p)[24], const int& num_in, Point (&q)[24], bool shift_to_zero = false) + { + assert(num_in >= 2); + + // Step 1: + // Find point with minimum y + // if more than 1 points have the same minimum y, + // pick the one with the minimum x. + int t = 0; + for (int i = 1; i < num_in; i++) + { + if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) + { + t = i; + } + } + auto& start = p[t]; // starting point + + // Step 2: + // Subtract starting point from every points (for sorting in the next step) + for (int i = 0; i < num_in; i++) + { + q[i] = p[i] - start; + } + + // Swap the starting point to position 0 + auto tmp = q[0]; + q[0] = q[t]; + q[t] = tmp; + + // Step 3: + // Sort point 1 ~ num_in according to their relative cross-product values + // (essentially sorting according to angles) + // If the angles are the same, sort according to their distance to origin + float dist[24]; + for (int i = 0; i < num_in; i++) + { + dist[i] = dot_2d(q[i], q[i]); + } + + // CPU version + std::sort(q + 1, q + num_in, [](const Point& A, const Point& B) -> bool + { float temp = cross_2d(A, B); if (fabs(temp) < 1e-6) { return dot_2d(A, A) < dot_2d(B, B); } else { return temp > 0; + } }); + // compute distance to origin after sort, since the points are now different. + for (int i = 0; i < num_in; i++) + { + dist[i] = dot_2d(q[i], q[i]); + } + + // Step 4: + // Make sure there are at least 2 points (that don't overlap with each other) + // in the stack + int k; // index of the non-overlapped second point + for (k = 1; k < num_in; k++) + { + if (dist[k] > 1e-8) + { + break; + } + } + if (k == num_in) + { + // We reach the end, which means the convex hull is just one point + q[0] = p[t]; + return 1; + } + q[1] = q[k]; + int m = 2; // 2 points in the stack + // Step 5: + // Finally we can start the scanning process. 
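The pop test in Step 5 leans entirely on the sign convention of cross_2d: positive for a counterclockwise pair, negative for clockwise, zero for collinear. A minimal standalone sketch of that convention (Pt and the main() harness are illustrative stand-ins, not patch code):

#include <cassert>

struct Pt
{
    float x, y;
};

static float cross_2d(const Pt& a, const Pt& b)
{
    return a.x * b.y - b.x * a.y;
}

int main()
{
    // +x rotated onto +y is a counterclockwise quarter turn: positive cross.
    assert(cross_2d(Pt{1, 0}, Pt{0, 1}) > 0);
    // The reversed pair is a clockwise turn: negative cross.
    assert(cross_2d(Pt{0, 1}, Pt{1, 0}) < 0);
    // Collinear vectors give exactly zero; the `>= 0` pop test treats that
    // as non-convex, so duplicated or collinear points never stay on the hull.
    assert(cross_2d(Pt{1, 1}, Pt{2, 2}) == 0.0f);
    return 0;
}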
+ // When a non-convex relationship between the 3 points is found + // (either concave shape or duplicated points), + // we pop the previous point from the stack + // until the 3-point relationship is convex again, or + // until the stack only contains two points + for (int i = k + 1; i < num_in; i++) + { + while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) + { + m--; + } + q[m++] = q[i]; + } + + // Step 6 (Optional): + // In general sense we need the original coordinates, so we + // need to shift the points back (reverting Step 2) + // But if we're only interested in getting the area/perimeter of the shape + // We can simply return. + if (!shift_to_zero) + { + for (int i = 0; i < m; i++) + { + q[i] += start; + } + } + + return m; } - }); - // compute distance to origin after sort, since the points are now different. - for (int i = 0; i < num_in; i++) { - dist[i] = dot_2d(q[i], q[i]); - } - - // Step 4: - // Make sure there are at least 2 points (that don't overlap with each other) - // in the stack - int k; // index of the non-overlapped second point - for (k = 1; k < num_in; k++) { - if (dist[k] > 1e-8) { - break; - } - } - if (k == num_in) { - // We reach the end, which means the convex hull is just one point - q[0] = p[t]; - return 1; - } - q[1] = q[k]; - int m = 2; // 2 points in the stack - // Step 5: - // Finally we can start the scanning process. - // When a non-convex relationship between the 3 points is found - // (either concave shape or duplicated points), - // we pop the previous point from the stack - // until the 3-point relationship is convex again, or - // until the stack only contains two points - for (int i = k + 1; i < num_in; i++) { - while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) { - m--; - } - q[m++] = q[i]; - } - - // Step 6 (Optional): - // In general sense we need the original coordinates, so we - // need to shift the points back (reverting Step 2) - // But if we're only interested in getting the area/perimeter of the shape - // We can simply return. - if (!shift_to_zero) { - for (int i = 0; i < m; i++) { - q[i] += start; - } - } - - return m; -} - -float polygon_area(const Point (&q)[24], const int& m) { - if (m <= 2) { - return 0; - } - - float area = 0; - for (int i = 1; i < m - 1; i++) { - area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); - } - - return area / 2.0; -} - -float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2) { - // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned - // from rotated_rect_intersection_pts - Point intersectPts[24], orderedPts[24]; - - Point pts1[4]; - Point pts2[4]; - get_rotated_vertices(box1, pts1); - get_rotated_vertices(box2, pts2); - - int num = get_intersection_points(pts1, pts2, intersectPts); - - if (num <= 2) { - return 0.0; - } - - // Convex Hull to order the intersection points in clockwise order and find - // the contour area. 
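polygon_area() then reduces that ordered ring to a scalar with a triangle-fan shoelace sum. A self-contained sketch of the same computation (polygon_area_fan and Pt are local stand-ins) checking that a hull-ordered unit square comes out with area 1:

#include <cassert>
#include <cmath>

struct Pt
{
    float x, y;
};

static float cross_2d(const Pt& a, const Pt& b)
{
    return a.x * b.y - b.x * a.y;
}

// Same fan decomposition as polygon_area(): triangles (q[0], q[i], q[i+1]);
// fabs() makes vertex orientation irrelevant because the hull is convex.
static float polygon_area_fan(const Pt* q, int m)
{
    if (m <= 2) return 0.f;
    float area = 0.f;
    for (int i = 1; i < m - 1; i++)
    {
        Pt u{q[i].x - q[0].x, q[i].y - q[0].y};
        Pt v{q[i + 1].x - q[0].x, q[i + 1].y - q[0].y};
        area += std::fabs(cross_2d(u, v));
    }
    return area / 2.0f;
}

int main()
{
    const Pt square[4] = {{0, 0}, {1, 0}, {1, 1}, {0, 1}};
    assert(std::fabs(polygon_area_fan(square, 4) - 1.0f) < 1e-6f);
    return 0;
}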
- int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true); - return polygon_area(orderedPts, num_convex); -} - -NMSRotatedKernel::NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info) - : ort_(api), info_(info) { - iou_threshold_ = ort_.KernelInfoGetAttribute(info, "iou_threshold"); - score_threshold_ = ort_.KernelInfoGetAttribute(info, "score_threshold"); - - // create allocator - allocator_ = Ort::AllocatorWithDefaultOptions(); -} - -void NMSRotatedKernel::Compute(OrtKernelContext* context) { - const float iou_threshold = iou_threshold_; - const float score_threshold = score_threshold_; - - const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0); - const float* boxes_data = reinterpret_cast(ort_.GetTensorData(boxes)); - const OrtValue* scores = ort_.KernelContext_GetInput(context, 1); - const float* scores_data = reinterpret_cast(ort_.GetTensorData(scores)); - - OrtTensorDimensions boxes_dim(ort_, boxes); - OrtTensorDimensions scores_dim(ort_, scores); - - // loop over batch - int64_t nbatch = boxes_dim[0]; - int64_t nboxes = boxes_dim[1]; - int64_t nclass = scores_dim[1]; - assert(boxes_dim[2] == 5); //(cx,cy,w,h,theta) - - // allocate tmp memory - float* tmp_boxes = (float*)allocator_.Alloc(sizeof(float) * nbatch * nboxes * 5); - float* sc = (float*)allocator_.Alloc(sizeof(float) * nbatch * nclass * nboxes); - bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes); - - memcpy(tmp_boxes, boxes_data, sizeof(float) * nbatch * nboxes * 5); - memcpy(sc, scores_data, sizeof(float) * nbatch * nclass * nboxes); - - // std::vector> res_order; - std::vector res_order; - for (int64_t k = 0; k < nbatch; k++) { - for (int64_t g = 0; g < nclass; g++) { - for (int64_t i = 0; i < nboxes; i++) { - select[i] = true; - } - // sort scores - std::vector tmp_sc; - for (int i = 0; i < nboxes; i++) { - tmp_sc.push_back(sc[k * nboxes * nclass + g * nboxes + i]); - } - std::vector order(tmp_sc.size()); - std::iota(order.begin(), order.end(), 0); - std::sort(order.begin(), order.end(), - [&tmp_sc](int64_t id1, int64_t id2) { return tmp_sc[id1] > tmp_sc[id2]; }); - for (int64_t _i = 0; _i < nboxes; _i++) { - if (select[_i] == false) continue; - auto i = order[_i]; - for (int64_t _j = _i + 1; _j < nboxes; _j++) { - if (select[_j] == false) continue; - auto j = order[_j]; - RotatedBox box1, box2; - auto center_shift_x = - (tmp_boxes[k * nboxes * 5 + i * 5] + tmp_boxes[k * nboxes * 5 + j * 5]) / 2.0; - auto center_shift_y = - (tmp_boxes[k * nboxes * 5 + i * 5 + 1] + tmp_boxes[k * nboxes * 5 + j * 5 + 1]) / 2.0; - box1.x_ctr = tmp_boxes[k * nboxes * 5 + i * 5] - center_shift_x; - box1.y_ctr = tmp_boxes[k * nboxes * 5 + i * 5 + 1] - center_shift_y; - box1.w = tmp_boxes[k * nboxes * 5 + i * 5 + 2]; - box1.h = tmp_boxes[k * nboxes * 5 + i * 5 + 3]; - box1.a = tmp_boxes[k * nboxes * 5 + i * 5 + 4]; - box2.x_ctr = tmp_boxes[k * nboxes * 5 + j * 5] - center_shift_x; - box2.y_ctr = tmp_boxes[k * nboxes * 5 + j * 5 + 1] - center_shift_y; - box2.w = tmp_boxes[k * nboxes * 5 + j * 5 + 2]; - box2.h = tmp_boxes[k * nboxes * 5 + j * 5 + 3]; - box2.a = tmp_boxes[k * nboxes * 5 + j * 5 + 4]; - auto area1 = box1.w * box1.h; - auto area2 = box2.w * box2.h; - auto intersection = rotated_boxes_intersection(box1, box2); - float baseS = 1.0; - baseS = (area1 + area2 - intersection); - auto ovr = intersection / baseS; - if (ovr > iou_threshold) select[_j] = false; + + float polygon_area(const Point (&q)[24], const int& m) + { + if (m <= 2) + { + return 0; } - } - for (int i = 0; i < 
nboxes; i++) { - if (select[i] & (tmp_sc[order[i]] > score_threshold)) { - res_order.push_back(k); - res_order.push_back(g); - res_order.push_back(order[i]); + + float area = 0; + for (int i = 1; i < m - 1; i++) + { + area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); } - } - } // class loop - } // batch loop - std::vector inds_dims({(int64_t)res_order.size() / 3, 3}); + return area / 2.0; + } + + float rotated_boxes_intersection(const RotatedBox& box1, const RotatedBox& box2) + { + // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned + // from rotated_rect_intersection_pts + Point intersectPts[24], orderedPts[24]; - OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size()); - int64_t* res_data = ort_.GetTensorMutableData(res); + Point pts1[4]; + Point pts2[4]; + get_rotated_vertices(box1, pts1); + get_rotated_vertices(box2, pts2); - memcpy(res_data, res_order.data(), sizeof(int64_t) * res_order.size()); + int num = get_intersection_points(pts1, pts2, intersectPts); - allocator_.Free(tmp_boxes); - allocator_.Free(sc); - allocator_.Free(select); -} + if (num <= 2) + { + return 0.0; + } + + // Convex Hull to order the intersection points in clockwise order and find + // the contour area. + int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true); + return polygon_area(orderedPts, num_convex); + } + + NMSRotatedKernel::NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info) + : ort_(api) + , info_(info) + { + iou_threshold_ = ort_.KernelInfoGetAttribute(info, "iou_threshold"); + score_threshold_ = ort_.KernelInfoGetAttribute(info, "score_threshold"); + + // create allocator + allocator_ = Ort::AllocatorWithDefaultOptions(); + } + + void NMSRotatedKernel::Compute(OrtKernelContext* context) + { + const float iou_threshold = iou_threshold_; + const float score_threshold = score_threshold_; + + const OrtValue* boxes = ort_.KernelContext_GetInput(context, 0); + const float* boxes_data = reinterpret_cast(ort_.GetTensorData(boxes)); + const OrtValue* scores = ort_.KernelContext_GetInput(context, 1); + const float* scores_data = reinterpret_cast(ort_.GetTensorData(scores)); + + OrtTensorDimensions boxes_dim(ort_, boxes); + OrtTensorDimensions scores_dim(ort_, scores); + + // loop over batch + int64_t nbatch = boxes_dim[0]; + int64_t nboxes = boxes_dim[1]; + int64_t nclass = scores_dim[1]; + assert(boxes_dim[2] == 5); //(cx,cy,w,h,theta) + + // allocate tmp memory + float* tmp_boxes = (float*)allocator_.Alloc(sizeof(float) * nbatch * nboxes * 5); + float* sc = (float*)allocator_.Alloc(sizeof(float) * nbatch * nclass * nboxes); + bool* select = (bool*)allocator_.Alloc(sizeof(bool) * nbatch * nboxes); + + memcpy(tmp_boxes, boxes_data, sizeof(float) * nbatch * nboxes * 5); + memcpy(sc, scores_data, sizeof(float) * nbatch * nclass * nboxes); + + // std::vector> res_order; + std::vector res_order; + for (int64_t k = 0; k < nbatch; k++) + { + for (int64_t g = 0; g < nclass; g++) + { + for (int64_t i = 0; i < nboxes; i++) + { + select[i] = true; + } + // sort scores + std::vector tmp_sc; + for (int i = 0; i < nboxes; i++) + { + tmp_sc.push_back(sc[k * nboxes * nclass + g * nboxes + i]); + } + std::vector order(tmp_sc.size()); + std::iota(order.begin(), order.end(), 0); + std::sort(order.begin(), order.end(), [&tmp_sc](int64_t id1, int64_t id2) + { return tmp_sc[id1] > tmp_sc[id2]; }); + for (int64_t _i = 0; _i < nboxes; _i++) + { + if (select[_i] == false) continue; + auto i = order[_i]; + for (int64_t _j = _i + 1; _j < 
nboxes; _j++) + { + if (select[_j] == false) continue; + auto j = order[_j]; + RotatedBox box1, box2; + auto center_shift_x = + (tmp_boxes[k * nboxes * 5 + i * 5] + tmp_boxes[k * nboxes * 5 + j * 5]) / 2.0; + auto center_shift_y = + (tmp_boxes[k * nboxes * 5 + i * 5 + 1] + tmp_boxes[k * nboxes * 5 + j * 5 + 1]) / 2.0; + box1.x_ctr = tmp_boxes[k * nboxes * 5 + i * 5] - center_shift_x; + box1.y_ctr = tmp_boxes[k * nboxes * 5 + i * 5 + 1] - center_shift_y; + box1.w = tmp_boxes[k * nboxes * 5 + i * 5 + 2]; + box1.h = tmp_boxes[k * nboxes * 5 + i * 5 + 3]; + box1.a = tmp_boxes[k * nboxes * 5 + i * 5 + 4]; + box2.x_ctr = tmp_boxes[k * nboxes * 5 + j * 5] - center_shift_x; + box2.y_ctr = tmp_boxes[k * nboxes * 5 + j * 5 + 1] - center_shift_y; + box2.w = tmp_boxes[k * nboxes * 5 + j * 5 + 2]; + box2.h = tmp_boxes[k * nboxes * 5 + j * 5 + 3]; + box2.a = tmp_boxes[k * nboxes * 5 + j * 5 + 4]; + auto area1 = box1.w * box1.h; + auto area2 = box2.w * box2.h; + auto intersection = rotated_boxes_intersection(box1, box2); + float baseS = 1.0; + baseS = (area1 + area2 - intersection); + auto ovr = intersection / baseS; + if (ovr > iou_threshold) select[_j] = false; + } + } + for (int i = 0; i < nboxes; i++) + { + if (select[i] & (tmp_sc[order[i]] > score_threshold)) + { + res_order.push_back(k); + res_order.push_back(g); + res_order.push_back(order[i]); + } + } + } // class loop + } // batch loop + + std::vector inds_dims({(int64_t)res_order.size() / 3, 3}); + + OrtValue* res = ort_.KernelContext_GetOutput(context, 0, inds_dims.data(), inds_dims.size()); + int64_t* res_data = ort_.GetTensorMutableData(res); + + memcpy(res_data, res_order.data(), sizeof(int64_t) * res_order.size()); + + allocator_.Free(tmp_boxes); + allocator_.Free(sc); + allocator_.Free(select); + } -REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSRotatedOp); + REGISTER_ONNXRUNTIME_OPS(mmdeploy, NMSRotatedOp); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.h b/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.h index 6ed44ce410..3b4aa856a5 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.h +++ b/csrc/mmdeploy/backend_ops/onnxruntime/nms_rotated/nms_rotated.h @@ -10,39 +10,57 @@ #include #include -namespace mmdeploy { -struct NMSRotatedKernel { - NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info); - - void Compute(OrtKernelContext* context); - - private: - Ort::CustomOpApi ort_; - const OrtKernelInfo* info_; - Ort::AllocatorWithDefaultOptions allocator_; - float iou_threshold_; - float score_threshold_; -}; - -struct NMSRotatedOp : Ort::CustomOpBase { - void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { - return new NMSRotatedKernel(api, info); - } - const char* GetName() const { return "NMSRotated"; } - - size_t GetInputTypeCount() const { return 2; } - ONNXTensorElementDataType GetInputType(size_t) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } - - size_t GetOutputTypeCount() const { return 1; } - ONNXTensorElementDataType GetOutputType(size_t) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; - } - - // force cpu - const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; } -}; +namespace mmdeploy +{ + struct NMSRotatedKernel + { + NMSRotatedKernel(const OrtApi& api, const OrtKernelInfo* info); + + void Compute(OrtKernelContext* context); + + private: + Ort::CustomOpApi ort_; + const OrtKernelInfo* info_; + Ort::AllocatorWithDefaultOptions allocator_; + float iou_threshold_; + float 
score_threshold_; + }; + + struct NMSRotatedOp : Ort::CustomOpBase + { + void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const + { + return new NMSRotatedKernel(api, info); + } + const char* GetName() const + { + return "NMSRotated"; + } + + size_t GetInputTypeCount() const + { + return 2; + } + ONNXTensorElementDataType GetInputType(size_t) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + size_t GetOutputTypeCount() const + { + return 1; + } + ONNXTensorElementDataType GetOutputType(size_t) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; + } + + // force cpu + const char* GetExecutionProviderType() const + { + return "CPUExecutionProvider"; + } + }; } // namespace mmdeploy #endif // ONNXRUNTIME_NMS_ROTATED_H diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/onnxruntime_register.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/onnxruntime_register.cpp index f7b9cedff8..1159496843 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/onnxruntime_register.cpp +++ b/csrc/mmdeploy/backend_ops/onnxruntime/onnxruntime_register.cpp @@ -3,25 +3,30 @@ #include "ort_utils.h" -const char *c_MMDeployOpDomain = "mmdeploy"; +const char* c_MMDeployOpDomain = "mmdeploy"; -OrtStatus *ORT_API_CALL RegisterCustomOps(OrtSessionOptions *options, const OrtApiBase *api) { - const OrtApi *kOrtApi = api->GetApi(ORT_API_VERSION); - OrtStatus *status = nullptr; - for (auto &_op_list_pair : mmdeploy::get_mmdeploy_custom_ops()) { - OrtCustomOpDomain *domain = nullptr; - if (auto status = kOrtApi->CreateCustomOpDomain(_op_list_pair.first.c_str(), &domain)) { - return status; +OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) +{ + const OrtApi* kOrtApi = api->GetApi(ORT_API_VERSION); + OrtStatus* status = nullptr; + for (auto& _op_list_pair : mmdeploy::get_mmdeploy_custom_ops()) + { + OrtCustomOpDomain* domain = nullptr; + if (auto status = kOrtApi->CreateCustomOpDomain(_op_list_pair.first.c_str(), &domain)) + { + return status; + } + auto& _op_list = _op_list_pair.second; + for (auto& _op : _op_list) + { + if (auto status = kOrtApi->CustomOpDomain_Add(domain, _op)) + { + return status; + } + } + // TODO: figure out what will return if failed. + status = kOrtApi->AddCustomOpDomain(options, domain); } - auto &_op_list = _op_list_pair.second; - for (auto &_op : _op_list) { - if (auto status = kOrtApi->CustomOpDomain_Add(domain, _op)) { - return status; - } - } - // TODO: figure out what will return if failed. 
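For readers wiring this up: RegisterCustomOps is the hook a client calls (directly when linking the ops library, or indirectly via onnxruntime's custom-ops-library loading APIs) to make the "mmdeploy" domain visible to a session. A minimal consumer sketch assuming direct linking; the model path is a placeholder and the declaration must match how ort_utils.h exports the symbol:

#include <onnxruntime_cxx_api.h>

// Declaration assumed to match the definition above / ort_utils.h.
OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);

int main()
{
    Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "mmdeploy");
    Ort::SessionOptions so;
    // Register the "mmdeploy" custom-op domain (NMSMatch, NMSRotated, ...)
    // before creating a session over a model that references those ops.
    if (OrtStatus* status = RegisterCustomOps(so, OrtGetApiBase()))
    {
        Ort::GetApi().ReleaseStatus(status);  // registration failed
        return 1;
    }
    Ort::Session session(env, "end2end.onnx", so);  // placeholder model path
    return 0;
}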
- status = kOrtApi->AddCustomOpDomain(options, domain); - } - return status; + return status; } diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.cpp b/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.cpp index a8e7023fe1..4fbf6365d0 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.cpp +++ b/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.cpp @@ -5,233 +5,245 @@ #include "ort_utils.h" -namespace mmdeploy { -// implementation taken from Caffe2 -struct PreCalc { - int pos1; - int pos2; - int pos3; - int pos4; - float w1; - float w2; - float w3; - float w4; -}; - -void pre_calc_for_bilinear_interpolate(const int height, const int width, const int pooled_height, - const int pooled_width, const int iy_upper, - const int ix_upper, float roi_start_h, float roi_start_w, - float bin_size_h, float bin_size_w, int roi_bin_grid_h, - int roi_bin_grid_w, float roi_center_h, float roi_center_w, - float cos_theta, float sin_theta, - std::vector &pre_calc) { - int pre_calc_index = 0; - for (int ph = 0; ph < pooled_height; ph++) { - for (int pw = 0; pw < pooled_width; pw++) { - for (int iy = 0; iy < iy_upper; iy++) { - const float yy = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (int ix = 0; ix < ix_upper; ix++) { - const float xx = - roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); - - // Rotate by theta around the center and translate - // In image space, (y, x) is the order for Right Handed System, - // and this is essentially multiplying the point by a rotation matrix - // to rotate it counterclockwise through angle theta. - float y = yy * cos_theta - xx * sin_theta + roi_center_h; - float x = yy * sin_theta + xx * cos_theta + roi_center_w; - // deal with: inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) { - // empty - PreCalc pc; - pc.pos1 = 0; - pc.pos2 = 0; - pc.pos3 = 0; - pc.pos4 = 0; - pc.w1 = 0; - pc.w2 = 0; - pc.w3 = 0; - pc.w4 = 0; - pre_calc[pre_calc_index] = pc; - pre_calc_index += 1; - continue; - } - - if (y < 0) { - y = 0; - } - if (x < 0) { - x = 0; - } - - int y_low = (int)y; - int x_low = (int)x; - int y_high; - int x_high; - - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = (float)y_low; - } else { - y_high = y_low + 1; - } - - if (x_low >= width - 1) { - x_high = x_low = width - 1; - x = (float)x_low; - } else { - x_high = x_low + 1; - } - - float ly = y - y_low; - float lx = x - x_low; - float hy = 1. - ly, hx = 1. 
- lx; - float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; - - // save weights and indices - PreCalc pc; - pc.pos1 = y_low * width + x_low; - pc.pos2 = y_low * width + x_high; - pc.pos3 = y_high * width + x_low; - pc.pos4 = y_high * width + x_high; - pc.w1 = w1; - pc.w2 = w2; - pc.w3 = w3; - pc.w4 = w4; - pre_calc[pre_calc_index] = pc; - - pre_calc_index += 1; +namespace mmdeploy +{ + // implementation taken from Caffe2 + struct PreCalc + { + int pos1; + int pos2; + int pos3; + int pos4; + float w1; + float w2; + float w3; + float w4; + }; + + void pre_calc_for_bilinear_interpolate(const int height, const int width, const int pooled_height, const int pooled_width, const int iy_upper, const int ix_upper, float roi_start_h, float roi_start_w, float bin_size_h, float bin_size_w, int roi_bin_grid_h, int roi_bin_grid_w, float roi_center_h, float roi_center_w, float cos_theta, float sin_theta, std::vector& pre_calc) + { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) + { + for (int pw = 0; pw < pooled_width; pw++) + { + for (int iy = 0; iy < iy_upper; iy++) + { + const float yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < ix_upper; ix++) + { + const float xx = + roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + // Rotate by theta around the center and translate + // In image space, (y, x) is the order for Right Handed System, + // and this is essentially multiplying the point by a rotation matrix + // to rotate it counterclockwise through angle theta. + float y = yy * cos_theta - xx * sin_theta + roi_center_h; + float x = yy * sin_theta + xx * cos_theta + roi_center_w; + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) + { + // empty + PreCalc pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc[pre_calc_index] = pc; + pre_calc_index += 1; + continue; + } + + if (y < 0) + { + y = 0; + } + if (x < 0) + { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) + { + y_high = y_low = height - 1; + y = (float)y_low; + } + else + { + y_high = y_low + 1; + } + + if (x_low >= width - 1) + { + x_high = x_low = width - 1; + x = (float)x_low; + } + else + { + x_high = x_low + 1; + } + + float ly = y - y_low; + float lx = x - x_low; + float hy = 1. - ly, hx = 1. 
- lx; + float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // save weights and indices + PreCalc pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc[pre_calc_index] = pc; + + pre_calc_index += 1; + } + } + } } - } - } - } -} - -void ROIAlignRotatedForwardCPU(const int nthreads, const float *input, const float *rois, - float *output, const float &spatial_scale, const int aligned, - const int clockwise, const int channels, const int height, - const int width, const int pooled_height, const int pooled_width, - const int sampling_ratio) { - int n_rois = nthreads / channels / pooled_width / pooled_height; - // (n, c, ph, pw) is an element in the pooled output - // can be parallelized using omp - // #pragma omp parallel for num_threads(32) - for (int n = 0; n < n_rois; n++) { - int index_n = n * channels * pooled_width * pooled_height; - - const float *current_roi = rois + n * 6; - int roi_batch_ind = current_roi[0]; - - // Do not use rounding; this implementation detail is critical - float offset = aligned ? (float)0.5 : (float)0.0; - float roi_center_w = current_roi[1] * spatial_scale - offset; - float roi_center_h = current_roi[2] * spatial_scale - offset; - float roi_width = current_roi[3] * spatial_scale; - float roi_height = current_roi[4] * spatial_scale; - // float theta = current_roi[5] * M_PI / 180.0; - float theta = current_roi[5]; // Radian angle by default - if (clockwise) { - theta = -theta; } - float cos_theta = cos(theta); - float sin_theta = sin(theta); - if (!aligned) { // for backward-compatibility only - roi_width = std::max(roi_width, (float)1.); - roi_height = std::max(roi_height, (float)1.); - } - - float bin_size_h = static_cast(roi_height) / static_cast(pooled_height); - float bin_size_w = static_cast(roi_width) / static_cast(pooled_width); - - // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = - (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 - int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); - - // We do average (integral) pooling inside a bin - const float count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 - - // we want to precalculate indices and weights shared by all channels, - // this is the key point of optimization - std::vector pre_calc(roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); - - // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). - // Appropriate translation needs to be applied after. 
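That comment is the crux of the sampling scheme: bin offsets are generated relative to the RoI center, rotated, and only then translated. A tiny numeric sketch of the same rotate-then-translate formula (all values made up):

#include <cassert>
#include <cmath>

int main()
{
    // RoI of width 4 and height 2, centered at (10, 20) on the feature map.
    const float roi_center_w = 10.f, roi_center_h = 20.f;
    const float roi_width = 4.f, roi_height = 2.f;
    const float xx = -roi_width / 2.f;   // roi_start_w
    const float yy = -roi_height / 2.f;  // roi_start_h
    const float theta = 0.f;             // no rotation
    // Rotate the center-relative offset, then translate by the RoI center.
    const float y = yy * std::cos(theta) - xx * std::sin(theta) + roi_center_h;
    const float x = yy * std::sin(theta) + xx * std::cos(theta) + roi_center_w;
    // With theta == 0 the first sample lands at the RoI's top-left corner.
    assert(std::fabs(x - 8.f) < 1e-6f && std::fabs(y - 19.f) < 1e-6f);
    return 0;
}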
- float roi_start_h = -roi_height / 2.0; - float roi_start_w = -roi_width / 2.0; - pre_calc_for_bilinear_interpolate(height, width, pooled_height, pooled_width, roi_bin_grid_h, - roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, - bin_size_w, roi_bin_grid_h, roi_bin_grid_w, roi_center_h, - roi_center_w, cos_theta, sin_theta, pre_calc); - - for (int c = 0; c < channels; c++) { - int index_n_c = index_n + c * pooled_width * pooled_height; - const float *offset_input = input + (roi_batch_ind * channels + c) * height * width; - int pre_calc_index = 0; + void ROIAlignRotatedForwardCPU(const int nthreads, const float* input, const float* rois, float* output, const float& spatial_scale, const int aligned, const int clockwise, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int sampling_ratio) + { + int n_rois = nthreads / channels / pooled_width / pooled_height; + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) + { + int index_n = n * channels * pooled_width * pooled_height; + + const float* current_roi = rois + n * 6; + int roi_batch_ind = current_roi[0]; + + // Do not use rounding; this implementation detail is critical + float offset = aligned ? (float)0.5 : (float)0.0; + float roi_center_w = current_roi[1] * spatial_scale - offset; + float roi_center_h = current_roi[2] * spatial_scale - offset; + float roi_width = current_roi[3] * spatial_scale; + float roi_height = current_roi[4] * spatial_scale; + // float theta = current_roi[5] * M_PI / 180.0; + float theta = current_roi[5]; // Radian angle by default + if (clockwise) + { + theta = -theta; + } + float cos_theta = cos(theta); + float sin_theta = sin(theta); + if (!aligned) + { // for backward-compatibility only + roi_width = std::max(roi_width, (float)1.); + roi_height = std::max(roi_height, (float)1.); + } - for (int ph = 0; ph < pooled_height; ph++) { - for (int pw = 0; pw < pooled_width; pw++) { - int index = index_n_c + ph * pooled_width + pw; + float bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + float bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const float count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 + + // we want to precalculate indices and weights shared by all channels, + // this is the key point of optimization + std::vector pre_calc(roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); + + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. 
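Inside pre_calc_for_bilinear_interpolate() every in-bounds sample is split over its four neighbours with weights hy*hx, hy*lx, ly*hx, ly*lx, which always sum to one, so the later division by `count` yields an unbiased bin average. A quick standalone check of that partition-of-unity property (coordinates are arbitrary):

#include <cassert>
#include <cmath>

int main()
{
    // A sample point strictly inside a cell; the floor gives x_low / y_low.
    const float x = 3.25f, y = 7.6f;
    const float lx = x - std::floor(x), ly = y - std::floor(y);
    const float hx = 1.f - lx, hy = 1.f - ly;
    const float w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx;
    // (hy + ly) * (hx + lx) == 1, so the four weights partition unity.
    assert(std::fabs((w1 + w2 + w3 + w4) - 1.0f) < 1e-6f);
    return 0;
}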
+ float roi_start_h = -roi_height / 2.0; + float roi_start_w = -roi_width / 2.0; + + pre_calc_for_bilinear_interpolate(height, width, pooled_height, pooled_width, roi_bin_grid_h, roi_bin_grid_w, roi_start_h, roi_start_w, bin_size_h, bin_size_w, roi_bin_grid_h, roi_bin_grid_w, roi_center_h, roi_center_w, cos_theta, sin_theta, pre_calc); + + for (int c = 0; c < channels; c++) + { + int index_n_c = index_n + c * pooled_width * pooled_height; + const float* offset_input = input + (roi_batch_ind * channels + c) * height * width; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height; ph++) + { + for (int pw = 0; pw < pooled_width; pw++) + { + int index = index_n_c + ph * pooled_width + pw; + + float output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) + { + for (int ix = 0; ix < roi_bin_grid_w; ix++) + { + PreCalc pc = pre_calc[pre_calc_index]; + output_val += pc.w1 * offset_input[pc.pos1] + pc.w2 * offset_input[pc.pos2] + + pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4]; + + pre_calc_index += 1; + } + } + output_val /= count; + + output[index] = output_val; + } // for pw + } // for ph + } // for c + } // for n + } - float output_val = 0.; - for (int iy = 0; iy < roi_bin_grid_h; iy++) { - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - PreCalc pc = pre_calc[pre_calc_index]; - output_val += pc.w1 * offset_input[pc.pos1] + pc.w2 * offset_input[pc.pos2] + - pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4]; + void MMCVRoIAlignRotatedKernel::Compute(OrtKernelContext* context) + { + // Setup inputs + const OrtValue* input_X = ort_.KernelContext_GetInput(context, 0); + const float* X_data = reinterpret_cast(ort_.GetTensorData(input_X)); + const OrtValue* input_rois = ort_.KernelContext_GetInput(context, 1); + const float* rois = + reinterpret_cast(ort_.GetTensorData(input_rois)); + + // Setup output + OrtTensorDimensions out_dimensions(ort_, input_X); + OrtTensorDimensions roi_dimensions(ort_, input_rois); + + int batch_size = out_dimensions.data()[0]; + int input_channels = out_dimensions.data()[1]; + int input_height = out_dimensions.data()[2]; + int input_width = out_dimensions.data()[3]; + + out_dimensions.data()[0] = roi_dimensions.data()[0]; + out_dimensions.data()[2] = aligned_height_; + out_dimensions.data()[3] = aligned_width_; + + OrtValue* output = + ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size()); + float* out = ort_.GetTensorMutableData(output); + OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output); + ort_.ReleaseTensorTypeAndShapeInfo(output_info); + + // TODO: forward here + int output_size = out_dimensions.data()[0]; + for (auto i = 1; i < out_dimensions.size(); ++i) + { + output_size *= out_dimensions.data()[i]; + } + ROIAlignRotatedForwardCPU(output_size, X_data, rois, out, spatial_scale_, aligned_, clockwise_, input_channels, input_height, input_width, aligned_height_, aligned_width_, sampling_ratio_); + } - pre_calc_index += 1; - } - } - output_val /= count; - - output[index] = output_val; - } // for pw - } // for ph - } // for c - } // for n -} - -void MMCVRoIAlignRotatedKernel::Compute(OrtKernelContext *context) { - // Setup inputs - const OrtValue *input_X = ort_.KernelContext_GetInput(context, 0); - const float *X_data = reinterpret_cast(ort_.GetTensorData(input_X)); - const OrtValue *input_rois = ort_.KernelContext_GetInput(context, 1); - const float *rois = - reinterpret_cast(ort_.GetTensorData(input_rois)); - - // Setup output - OrtTensorDimensions 
out_dimensions(ort_, input_X); - OrtTensorDimensions roi_dimensions(ort_, input_rois); - - int batch_size = out_dimensions.data()[0]; - int input_channels = out_dimensions.data()[1]; - int input_height = out_dimensions.data()[2]; - int input_width = out_dimensions.data()[3]; - - out_dimensions.data()[0] = roi_dimensions.data()[0]; - out_dimensions.data()[2] = aligned_height_; - out_dimensions.data()[3] = aligned_width_; - - OrtValue *output = - ort_.KernelContext_GetOutput(context, 0, out_dimensions.data(), out_dimensions.size()); - float *out = ort_.GetTensorMutableData(output); - OrtTensorTypeAndShapeInfo *output_info = ort_.GetTensorTypeAndShape(output); - ort_.ReleaseTensorTypeAndShapeInfo(output_info); - - // TODO: forward here - int output_size = out_dimensions.data()[0]; - for (auto i = 1; i < out_dimensions.size(); ++i) { - output_size *= out_dimensions.data()[i]; - } - ROIAlignRotatedForwardCPU(output_size, X_data, rois, out, spatial_scale_, aligned_, clockwise_, - input_channels, input_height, input_width, aligned_height_, - aligned_width_, sampling_ratio_); -} - -REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVRoIAlignRotatedCustomOp); + REGISTER_ONNXRUNTIME_OPS(mmdeploy, MMCVRoIAlignRotatedCustomOp); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.h b/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.h index c0129d31f8..24a90e5321 100644 --- a/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.h +++ b/csrc/mmdeploy/backend_ops/onnxruntime/roi_align_rotated/roi_align_rotated.h @@ -10,50 +10,70 @@ #include #include -namespace mmdeploy { -struct MMCVRoIAlignRotatedKernel { - public: - MMCVRoIAlignRotatedKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info) : ort_(ort) { - aligned_height_ = ort_.KernelInfoGetAttribute(info, "output_height"); - aligned_width_ = ort_.KernelInfoGetAttribute(info, "output_width"); - sampling_ratio_ = ort_.KernelInfoGetAttribute(info, "sampling_ratio"); - spatial_scale_ = ort_.KernelInfoGetAttribute(info, "spatial_scale"); - aligned_ = ort_.KernelInfoGetAttribute(info, "aligned"); - clockwise_ = ort_.KernelInfoGetAttribute(info, "clockwise"); - } - - void Compute(OrtKernelContext* context); - - private: - Ort::CustomOpApi ort_; - int aligned_height_; - int aligned_width_; - float spatial_scale_; - int sampling_ratio_; - int aligned_; - int clockwise_; -}; - -struct MMCVRoIAlignRotatedCustomOp - : Ort::CustomOpBase { - void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const { - return new MMCVRoIAlignRotatedKernel(api, info); - } - const char* GetName() const { return "MMCVRoIAlignRotated"; } - - size_t GetInputTypeCount() const { return 2; } - ONNXTensorElementDataType GetInputType(size_t) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } - - size_t GetOutputTypeCount() const { return 1; } - ONNXTensorElementDataType GetOutputType(size_t) const { - return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; - } - - // force cpu - const char* GetExecutionProviderType() const { return "CPUExecutionProvider"; } -}; +namespace mmdeploy +{ + struct MMCVRoIAlignRotatedKernel + { + public: + MMCVRoIAlignRotatedKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info) + : ort_(ort) + { + aligned_height_ = ort_.KernelInfoGetAttribute(info, "output_height"); + aligned_width_ = ort_.KernelInfoGetAttribute(info, "output_width"); + sampling_ratio_ = ort_.KernelInfoGetAttribute(info, "sampling_ratio"); + spatial_scale_ = ort_.KernelInfoGetAttribute(info, 
"spatial_scale"); + aligned_ = ort_.KernelInfoGetAttribute(info, "aligned"); + clockwise_ = ort_.KernelInfoGetAttribute(info, "clockwise"); + } + + void Compute(OrtKernelContext* context); + + private: + Ort::CustomOpApi ort_; + int aligned_height_; + int aligned_width_; + float spatial_scale_; + int sampling_ratio_; + int aligned_; + int clockwise_; + }; + + struct MMCVRoIAlignRotatedCustomOp + : Ort::CustomOpBase + { + void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const + { + return new MMCVRoIAlignRotatedKernel(api, info); + } + const char* GetName() const + { + return "MMCVRoIAlignRotated"; + } + + size_t GetInputTypeCount() const + { + return 2; + } + ONNXTensorElementDataType GetInputType(size_t) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + size_t GetOutputTypeCount() const + { + return 1; + } + ONNXTensorElementDataType GetOutputType(size_t) const + { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT; + } + + // force cpu + const char* GetExecutionProviderType() const + { + return "CPUExecutionProvider"; + } + }; } // namespace mmdeploy #endif // ONNXRUNTIME_ROI_ALIGN_ROTATED_H diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp index 431f2dd63b..8edec279c5 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.cpp @@ -9,225 +9,294 @@ #include "nms/kernel.h" #include "trt_serialize.hpp" -namespace mmdeploy { -using namespace nvinfer1; -using nvinfer1::plugin::NMSParameters; - -namespace { -static const char* NMS_PLUGIN_VERSION{"1"}; -static const char* NMS_PLUGIN_NAME{"TRTBatchedNMS"}; -} // namespace - -TRTBatchedNMS::TRTBatchedNMS(const std::string& name, NMSParameters params, bool returnIndex) - : TRTPluginBase(name), param(params), mReturnIndex(returnIndex) {} - -TRTBatchedNMS::TRTBatchedNMS(const std::string& name, const void* data, size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, ¶m); - deserialize_value(&data, &length, &mClipBoxes); - deserialize_value(&data, &length, &mReturnIndex); -} - -int TRTBatchedNMS::getNbOutputs() const TRT_NOEXCEPT { - int num = mReturnIndex ? 3 : 2; - return num; -} - -nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { - ASSERT(nbInputs == 2); - ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs()); - ASSERT(inputs[0].nbDims == 4); - ASSERT(inputs[1].nbDims == 3); - - nvinfer1::DimsExprs ret; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = exprBuilder.constant(param.keepTopK); - switch (outputIndex) { - case 0: - ret.nbDims = 3; - ret.d[2] = exprBuilder.constant(5); - break; - case 1: - ret.nbDims = 2; - break; - case 2: - ret.nbDims = 2; - default: - break; - } - - return ret; -} - -size_t TRTBatchedNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT { - size_t batch_size = inputs[0].dims.d[0]; - size_t boxes_size = inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3]; - size_t score_size = inputs[1].dims.d[1] * inputs[1].dims.d[2]; - size_t num_priors = inputs[0].dims.d[1]; - bool shareLocation = (inputs[0].dims.d[2] == 1); - int topk = param.topK > 0 && param.topK <= inputs[1].dims.d[1] ? 
param.topK : inputs[1].dims.d[1]; - return detectionInferenceWorkspaceSize(shareLocation, batch_size, boxes_size, score_size, - param.numClasses, num_priors, topk, DataType::kFLOAT, - DataType::kFLOAT); -} - -int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, - void* const* outputs, void* workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - const void* const locData = inputs[0]; - const void* const confData = inputs[1]; - - void* nmsedDets = outputs[0]; - void* nmsedLabels = outputs[1]; - void* nmsedIndex = mReturnIndex ? outputs[2] : nullptr; - - size_t batch_size = inputDesc[0].dims.d[0]; - size_t boxes_size = inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3]; - size_t score_size = inputDesc[1].dims.d[1] * inputDesc[1].dims.d[2]; - size_t num_priors = inputDesc[0].dims.d[1]; - bool shareLocation = (inputDesc[0].dims.d[2] == 1); - - int topk = - param.topK > 0 && param.topK <= inputDesc[1].dims.d[1] ? param.topK : inputDesc[1].dims.d[1]; - bool rotated = false; - pluginStatus_t status = nmsInference( - stream, batch_size, boxes_size, score_size, shareLocation, param.backgroundLabelId, - num_priors, param.numClasses, topk, param.keepTopK, param.scoreThreshold, param.iouThreshold, - DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, nmsedIndex, - workSpace, param.isNormalized, false, mClipBoxes, rotated); - ASSERT(status == STATUS_SUCCESS); - - return 0; -} - -size_t TRTBatchedNMS::getSerializationSize() const TRT_NOEXCEPT { - // NMSParameters - return sizeof(NMSParameters) + sizeof(mClipBoxes) + sizeof(mReturnIndex); -} - -void TRTBatchedNMS::serialize(void* buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, param); - serialize_value(&buffer, mClipBoxes); - serialize_value(&buffer, mReturnIndex); -} - -void TRTBatchedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* outputs, - int nbOutputs) TRT_NOEXCEPT { - // Validate input arguments -} - -bool TRTBatchedNMS::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 3 || pos == 4) { - return ioDesc[pos].type == nvinfer1::DataType::kINT32 && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; - } - return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -} - -const char* TRTBatchedNMS::getPluginType() const TRT_NOEXCEPT { return NMS_PLUGIN_NAME; } - -const char* TRTBatchedNMS::getPluginVersion() const TRT_NOEXCEPT { return NMS_PLUGIN_VERSION; } - -IPluginV2DynamicExt* TRTBatchedNMS::clone() const TRT_NOEXCEPT { - auto* plugin = new TRTBatchedNMS(mLayerName, param, mReturnIndex); - plugin->setPluginNamespace(mNamespace.c_str()); - plugin->setClipParam(mClipBoxes); - return plugin; -} - -nvinfer1::DataType TRTBatchedNMS::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, - int nbInputs) const TRT_NOEXCEPT { - ASSERT(index >= 0 && index < this->getNbOutputs()); - if (index == 1 || index == 2) { - return nvinfer1::DataType::kINT32; - } - return inputTypes[0]; -} - -void TRTBatchedNMS::setClipParam(bool clip) { mClipBoxes = clip; } - -TRTBatchedNMSCreator::TRTBatchedNMSCreator() { - mPluginAttributes.emplace_back( - PluginField("background_label_id", nullptr, PluginFieldType::kINT32, 1)); - mPluginAttributes.emplace_back(PluginField("num_classes", nullptr, 
PluginFieldType::kINT32, 1));
-  mPluginAttributes.emplace_back(PluginField("topk", nullptr, PluginFieldType::kINT32, 1));
-  mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1));
-  mPluginAttributes.emplace_back(
-      PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
-  mPluginAttributes.emplace_back(
-      PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
-  mPluginAttributes.emplace_back(PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1));
-  mPluginAttributes.emplace_back(PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1));
-  mPluginAttributes.emplace_back(PluginField("return_index", nullptr, PluginFieldType::kINT32, 1));
-
-  mFC.nbFields = mPluginAttributes.size();
-  mFC.fields = mPluginAttributes.data();
-}
-
-const char* TRTBatchedNMSCreator::getPluginName() const TRT_NOEXCEPT { return NMS_PLUGIN_NAME; }
-
-const char* TRTBatchedNMSCreator::getPluginVersion() const TRT_NOEXCEPT {
-  return NMS_PLUGIN_VERSION;
-}
-
-IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(const char* name,
-                                                 const PluginFieldCollection* fc) TRT_NOEXCEPT {
-  const PluginField* fields = fc->fields;
-  bool clipBoxes = true;
-  bool returnIndex = false;
-  nvinfer1::plugin::NMSParameters params{};
-
-  for (int i = 0; i < fc->nbFields; ++i) {
-    const char* attrName = fields[i].name;
-    if (!strcmp(attrName, "background_label_id")) {
-      ASSERT(fields[i].type == PluginFieldType::kINT32);
-      params.backgroundLabelId = *(static_cast<const int*>(fields[i].data));
-    } else if (!strcmp(attrName, "num_classes")) {
-      ASSERT(fields[i].type == PluginFieldType::kINT32);
-      params.numClasses = *(static_cast<const int*>(fields[i].data));
-    } else if (!strcmp(attrName, "topk")) {
-      ASSERT(fields[i].type == PluginFieldType::kINT32);
-      params.topK = *(static_cast<const int*>(fields[i].data));
-    } else if (!strcmp(attrName, "keep_topk")) {
-      ASSERT(fields[i].type == PluginFieldType::kINT32);
-      params.keepTopK = *(static_cast<const int*>(fields[i].data));
-    } else if (!strcmp(attrName, "score_threshold")) {
-      ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
-      params.scoreThreshold = *(static_cast<const float*>(fields[i].data));
-    } else if (!strcmp(attrName, "iou_threshold")) {
-      ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
-      params.iouThreshold = *(static_cast<const float*>(fields[i].data));
-    } else if (!strcmp(attrName, "is_normalized")) {
-      params.isNormalized = *(static_cast<const bool*>(fields[i].data));
-    } else if (!strcmp(attrName, "clip_boxes")) {
-      clipBoxes = *(static_cast<const bool*>(fields[i].data));
-    } else if (!strcmp(attrName, "return_index")) {
-      returnIndex = *(static_cast<const bool*>(fields[i].data));
-    }
-  }
-
-  TRTBatchedNMS* plugin = new TRTBatchedNMS(name, params, returnIndex);
-  plugin->setClipParam(clipBoxes);
-  plugin->setPluginNamespace(mNamespace.c_str());
-  return plugin;
-}
-
-IPluginV2Ext* TRTBatchedNMSCreator::deserializePlugin(const char* name, const void* serialData,
-                                                      size_t serialLength) TRT_NOEXCEPT {
-  // This object will be deleted when the network is destroyed, which will
-  // call NMS::destroy()
-  TRTBatchedNMS* plugin = new TRTBatchedNMS(name, serialData, serialLength);
-  plugin->setPluginNamespace(mNamespace.c_str());
-  return plugin;
-}
-
-REGISTER_TENSORRT_PLUGIN(TRTBatchedNMSCreator);
+namespace mmdeploy
+{
+    using namespace nvinfer1;
+    using nvinfer1::plugin::NMSParameters;
+
+    namespace
+    {
+        static const char* NMS_PLUGIN_VERSION{"1"};
+        static const char* NMS_PLUGIN_NAME{"TRTBatchedNMS"};
+    }  // namespace
+
+    TRTBatchedNMS::TRTBatchedNMS(const std::string& name, NMSParameters params, bool returnIndex)
+        : TRTPluginBase(name)
+        , param(params)
+        , mReturnIndex(returnIndex)
+    {
+    }
+
+    TRTBatchedNMS::TRTBatchedNMS(const std::string& name, const void* data, size_t length)
+        : TRTPluginBase(name)
+    {
+        deserialize_value(&data, &length, &param);
+        deserialize_value(&data, &length, &mClipBoxes);
+        deserialize_value(&data, &length, &mReturnIndex);
+    }
+
+    int TRTBatchedNMS::getNbOutputs() const TRT_NOEXCEPT
+    {
+        int num = mReturnIndex ? 3 : 2;
+        return num;
+    }
+
+    nvinfer1::DimsExprs TRTBatchedNMS::getOutputDimensions(
+        int outputIndex,
+        const nvinfer1::DimsExprs* inputs,
+        int nbInputs,
+        nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT
+    {
+        ASSERT(nbInputs == 2);
+        ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs());
+        ASSERT(inputs[0].nbDims == 4);
+        ASSERT(inputs[1].nbDims == 3);
+
+        nvinfer1::DimsExprs ret;
+        ret.d[0] = inputs[0].d[0];
+        ret.d[1] = exprBuilder.constant(param.keepTopK);
+        switch (outputIndex)
+        {
+            case 0:
+                ret.nbDims = 3;
+                ret.d[2] = exprBuilder.constant(5);
+                break;
+            case 1:
+                ret.nbDims = 2;
+                break;
+            case 2:
+                ret.nbDims = 2;
+            default:
+                break;
+        }
+
+        return ret;
+    }
+
+    size_t TRTBatchedNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT
+    {
+        size_t batch_size = inputs[0].dims.d[0];
+        size_t boxes_size = inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3];
+        size_t score_size = inputs[1].dims.d[1] * inputs[1].dims.d[2];
+        size_t num_priors = inputs[0].dims.d[1];
+        bool shareLocation = (inputs[0].dims.d[2] == 1);
+        int topk = param.topK > 0 && param.topK <= inputs[1].dims.d[1] ? param.topK : inputs[1].dims.d[1];
+        return detectionInferenceWorkspaceSize(shareLocation, batch_size, boxes_size, score_size, param.numClasses, num_priors, topk, DataType::kFLOAT, DataType::kFLOAT);
+    }
+
+    int TRTBatchedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+                               const nvinfer1::PluginTensorDesc* outputDesc,
+                               const void* const* inputs,
+                               void* const* outputs,
+                               void* workSpace,
+                               cudaStream_t stream) TRT_NOEXCEPT
+    {
+        const void* const locData = inputs[0];
+        const void* const confData = inputs[1];
+
+        void* nmsedDets = outputs[0];
+        void* nmsedLabels = outputs[1];
+        void* nmsedIndex = mReturnIndex ? outputs[2] : nullptr;
+
+        size_t batch_size = inputDesc[0].dims.d[0];
+        size_t boxes_size = inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3];
+        size_t score_size = inputDesc[1].dims.d[1] * inputDesc[1].dims.d[2];
+        size_t num_priors = inputDesc[0].dims.d[1];
+        bool shareLocation = (inputDesc[0].dims.d[2] == 1);
+
+        int topk =
+            param.topK > 0 && param.topK <= inputDesc[1].dims.d[1] ? param.topK : inputDesc[1].dims.d[1];
+        bool rotated = false;
+        pluginStatus_t status = nmsInference(
+            stream,
+            batch_size,
+            boxes_size,
+            score_size,
+            shareLocation,
+            param.backgroundLabelId,
+            num_priors,
+            param.numClasses,
+            topk,
+            param.keepTopK,
+            param.scoreThreshold,
+            param.iouThreshold,
+            DataType::kFLOAT,
+            locData,
+            DataType::kFLOAT,
+            confData,
+            nmsedDets,
+            nmsedLabels,
+            nmsedIndex,
+            workSpace,
+            param.isNormalized,
+            false,
+            mClipBoxes,
+            rotated);
+        ASSERT(status == STATUS_SUCCESS);
+
+        return 0;
+    }
+
+    size_t TRTBatchedNMS::getSerializationSize() const TRT_NOEXCEPT
+    {
+        // NMSParameters
+        return sizeof(NMSParameters) + sizeof(mClipBoxes) + sizeof(mReturnIndex);
+    }
+
+    void TRTBatchedNMS::serialize(void* buffer) const TRT_NOEXCEPT
+    {
+        serialize_value(&buffer, param);
+        serialize_value(&buffer, mClipBoxes);
+        serialize_value(&buffer, mReturnIndex);
+    }
+
+    void TRTBatchedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) TRT_NOEXCEPT
+    {
+        // Validate input arguments
+    }
+
+    bool TRTBatchedNMS::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT
+    {
+        if (pos == 3 || pos == 4)
+        {
+            return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
+                   ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
+        }
+        return ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
+               ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
+    }
+
+    const char* TRTBatchedNMS::getPluginType() const TRT_NOEXCEPT
+    {
+        return NMS_PLUGIN_NAME;
+    }
+
+    const char* TRTBatchedNMS::getPluginVersion() const TRT_NOEXCEPT
+    {
+        return NMS_PLUGIN_VERSION;
+    }
+
+    IPluginV2DynamicExt* TRTBatchedNMS::clone() const TRT_NOEXCEPT
+    {
+        auto* plugin = new TRTBatchedNMS(mLayerName, param, mReturnIndex);
+        plugin->setPluginNamespace(mNamespace.c_str());
+        plugin->setClipParam(mClipBoxes);
+        return plugin;
+    }
+
+    nvinfer1::DataType TRTBatchedNMS::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT
+    {
+        ASSERT(index >= 0 && index < this->getNbOutputs());
+        if (index == 1 || index == 2)
+        {
+            return nvinfer1::DataType::kINT32;
+        }
+        return inputTypes[0];
+    }
+
+    void TRTBatchedNMS::setClipParam(bool clip)
+    {
+        mClipBoxes = clip;
+    }
+
+    TRTBatchedNMSCreator::TRTBatchedNMSCreator()
+    {
+        mPluginAttributes.emplace_back(
+            PluginField("background_label_id", nullptr, PluginFieldType::kINT32, 1));
+        mPluginAttributes.emplace_back(PluginField("num_classes", nullptr, PluginFieldType::kINT32, 1));
+        mPluginAttributes.emplace_back(PluginField("topk", nullptr, PluginFieldType::kINT32, 1));
+        mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1));
+        mPluginAttributes.emplace_back(
+            PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
+        mPluginAttributes.emplace_back(
+            PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
+        mPluginAttributes.emplace_back(PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1));
+        mPluginAttributes.emplace_back(PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1));
+        mPluginAttributes.emplace_back(PluginField("return_index", nullptr, PluginFieldType::kINT32, 1));
+
+        mFC.nbFields = mPluginAttributes.size();
+        mFC.fields = mPluginAttributes.data();
+    }
+
+    const char* TRTBatchedNMSCreator::getPluginName() const TRT_NOEXCEPT
+    {
+        return NMS_PLUGIN_NAME;
+    }
+
+    const char* TRTBatchedNMSCreator::getPluginVersion() const TRT_NOEXCEPT
+    {
+        return NMS_PLUGIN_VERSION;
+    }
+
+    IPluginV2Ext* TRTBatchedNMSCreator::createPlugin(const char* name,
+                                                     const PluginFieldCollection* fc) TRT_NOEXCEPT
+    {
+        const PluginField* fields = fc->fields;
+        bool clipBoxes = true;
+        bool returnIndex = false;
+        nvinfer1::plugin::NMSParameters params{};
+
+        for (int i = 0; i < fc->nbFields; ++i)
+        {
+            const char* attrName = fields[i].name;
+            if (!strcmp(attrName, "background_label_id"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kINT32);
+                params.backgroundLabelId = *(static_cast<const int*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "num_classes"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kINT32);
+                params.numClasses = *(static_cast<const int*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "topk"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kINT32);
+                params.topK = *(static_cast<const int*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "keep_topk"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kINT32);
+                params.keepTopK = *(static_cast<const int*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "score_threshold"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
+                params.scoreThreshold = *(static_cast<const float*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "iou_threshold"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
+                params.iouThreshold = *(static_cast<const float*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "is_normalized"))
+            {
+                params.isNormalized = *(static_cast<const bool*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "clip_boxes"))
+            {
+                clipBoxes = *(static_cast<const bool*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "return_index"))
+            {
+                returnIndex = *(static_cast<const bool*>(fields[i].data));
+            }
+        }
+
+        TRTBatchedNMS* plugin = new TRTBatchedNMS(name, params, returnIndex);
+        plugin->setClipParam(clipBoxes);
+        plugin->setPluginNamespace(mNamespace.c_str());
+        return plugin;
+    }
+
+    IPluginV2Ext* TRTBatchedNMSCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT
+    {
+        // This object will be deleted when the network is destroyed, which will
+        // call NMS::destroy()
+        TRTBatchedNMS* plugin = new TRTBatchedNMS(name, serialData, serialLength);
+        plugin->setPluginNamespace(mNamespace.c_str());
+        return plugin;
+    }
+
+    REGISTER_TENSORRT_PLUGIN(TRTBatchedNMSCreator);
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp
index d1e5d643db..2cd276a931 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_nms/trt_batched_nms.hpp
@@ -8,75 +8,77 @@
 #include "NvInferPluginUtils.h"
 #include "trt_plugin_base.hpp"
-namespace mmdeploy {
+namespace mmdeploy
+{
-enum NMSReturnType { RETURN_DETS = 1, RETURN_INDEX = 1 << 1 };
+    enum NMSReturnType
+    {
+        RETURN_DETS = 1,
+        RETURN_INDEX = 1 << 1
+    };
-class TRTBatchedNMS : public TRTPluginBase {
- public:
-  TRTBatchedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param, bool returnIndex);
+    class TRTBatchedNMS : public TRTPluginBase
+    {
+      public:
+        TRTBatchedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param, bool returnIndex);
-  TRTBatchedNMS(const std::string& name, const void* data, size_t length);
+        TRTBatchedNMS(const std::string& name, const void* data, size_t length);
-  ~TRTBatchedNMS() TRT_NOEXCEPT override = default;
+        ~TRTBatchedNMS() TRT_NOEXCEPT override = default;
-  int getNbOutputs() const TRT_NOEXCEPT override;
+        int getNbOutputs() const TRT_NOEXCEPT override;
-  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs,
-                                          int nbInputs, nvinfer1::IExprBuilder& exprBuilder)
-      TRT_NOEXCEPT override;
+        nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder)
+            TRT_NOEXCEPT override;
-  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
-                          const nvinfer1::PluginTensorDesc* outputs,
-                          int nbOutputs) const TRT_NOEXCEPT override;
+        size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override;
-  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
-              const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
-              void* const* outputs, void* workSpace, cudaStream_t stream) TRT_NOEXCEPT override;
+        int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+                    const nvinfer1::PluginTensorDesc* outputDesc,
+                    const void* const* inputs,
+                    void* const* outputs,
+                    void* workSpace,
+                    cudaStream_t stream) TRT_NOEXCEPT override;
-  size_t getSerializationSize() const TRT_NOEXCEPT override;
+        size_t getSerializationSize() const TRT_NOEXCEPT override;
-  void serialize(void* buffer) const TRT_NOEXCEPT override;
+        void serialize(void* buffer) const TRT_NOEXCEPT override;
-  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
-                       const nvinfer1::DynamicPluginTensorDesc* outputs,
-                       int nbOutputs) TRT_NOEXCEPT override;
+        void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) TRT_NOEXCEPT override;
-  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs,
-                                 int nbOutputs) TRT_NOEXCEPT override;
+        bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override;
-  const char* getPluginType() const TRT_NOEXCEPT override;
+        const char* getPluginType() const TRT_NOEXCEPT override;
-  const char* getPluginVersion() const TRT_NOEXCEPT override;
+        const char* getPluginVersion() const TRT_NOEXCEPT override;
-  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
+        nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
-  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputType,
-                                       int nbInputs) const TRT_NOEXCEPT override;
+        nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputType, int nbInputs) const TRT_NOEXCEPT override;
-  void setClipParam(bool clip);
+        void setClipParam(bool clip);
- private:
-  nvinfer1::plugin::NMSParameters param{};
-  bool mClipBoxes{};
-  bool mReturnIndex{};
-};
+      private:
+        nvinfer1::plugin::NMSParameters param{};
+        bool mClipBoxes{};
+        bool mReturnIndex{};
+    };
-class TRTBatchedNMSCreator : public TRTPluginCreatorBase {
- public:
-  TRTBatchedNMSCreator();
+    class TRTBatchedNMSCreator : public TRTPluginCreatorBase
+    {
+      public:
+        TRTBatchedNMSCreator();
-  ~TRTBatchedNMSCreator() TRT_NOEXCEPT override = default;
+        ~TRTBatchedNMSCreator() TRT_NOEXCEPT override = default;
-  const char* getPluginName() const TRT_NOEXCEPT override;
+        const char* getPluginName() const TRT_NOEXCEPT override;
-  const char* getPluginVersion() const TRT_NOEXCEPT override;
+        const char* getPluginVersion() const TRT_NOEXCEPT override;
-  nvinfer1::IPluginV2Ext* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc)
-      TRT_NOEXCEPT override;
+        nvinfer1::IPluginV2Ext* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc)
+            TRT_NOEXCEPT override;
-  nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, const void* serialData,
-                                            size_t serialLength) TRT_NOEXCEPT override;
-};
+        nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
+    };
 }  // namespace mmdeploy
 #endif  // TRT_BATCHED_NMS_PLUGIN_CUSTOM_H
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp
index 9d977bc937..80b5be6abc 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.cpp
@@ -8,222 +8,295 @@
 #include "nms/kernel.h"
 #include "trt_serialize.hpp"
-namespace mmdeploy {
-using namespace nvinfer1;
-using nvinfer1::plugin::NMSParameters;
-
-namespace {
-static const char* NMS_PLUGIN_VERSION{"1"};
-static const char* NMS_PLUGIN_NAME{"TRTBatchedRotatedNMS"};
-}  // namespace
-
-TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, NMSParameters params)
-    : TRTPluginBase(name), param(params) {}
-
-TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length)
-    : TRTPluginBase(name) {
-  deserialize_value(&data, &length, &param);
-  deserialize_value(&data, &length, &mClipBoxes);
-}
-
-int TRTBatchedRotatedNMS::getNbOutputs() const TRT_NOEXCEPT { return 2; }
-
-nvinfer1::DimsExprs TRTBatchedRotatedNMS::getOutputDimensions(
-    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
-    nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT {
-  ASSERT(nbInputs == 2);
-  ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs());
-  ASSERT(inputs[0].nbDims == 4);
-  ASSERT(inputs[1].nbDims == 3);
-
-  nvinfer1::DimsExprs ret;
-  ret.d[0] = inputs[0].d[0];
-  ret.d[1] = exprBuilder.constant(param.keepTopK);
-  switch (outputIndex) {
-    case 0:
-      ret.nbDims = 3;
-      ret.d[2] = exprBuilder.constant(6);
-      break;
-    case 1:
-      ret.nbDims = 2;
-      break;
-    default:
-      break;
-  }
-
-  return ret;
-}
-
-size_t TRTBatchedRotatedNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
-                                              int nbInputs,
-                                              const nvinfer1::PluginTensorDesc* outputs,
-                                              int nbOutputs) const TRT_NOEXCEPT {
-  size_t batch_size = inputs[0].dims.d[0];
-  size_t boxes_size = inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3];
-  size_t score_size = inputs[1].dims.d[1] * inputs[1].dims.d[2];
-  size_t num_priors = inputs[0].dims.d[1];
-  bool shareLocation = (inputs[0].dims.d[2] == 1);
-  int topk = param.topK > 0 && param.topK <= inputs[1].dims.d[1] ? param.topK : inputs[1].dims.d[1];
-  return detectionInferenceWorkspaceSize(shareLocation, batch_size, boxes_size, score_size,
-                                         param.numClasses, num_priors, topk, DataType::kFLOAT,
-                                         DataType::kFLOAT);
-}
-
-int TRTBatchedRotatedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
-                                  const nvinfer1::PluginTensorDesc* outputDesc,
-                                  const void* const* inputs, void* const* outputs, void* workSpace,
-                                  cudaStream_t stream) TRT_NOEXCEPT {
-  const void* const locData = inputs[0];
-  const void* const confData = inputs[1];
-
-  void* nmsedDets = outputs[0];
-  void* nmsedLabels = outputs[1];
-
-  size_t batch_size = inputDesc[0].dims.d[0];
-  size_t boxes_size = inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3];
-  size_t score_size = inputDesc[1].dims.d[1] * inputDesc[1].dims.d[2];
-  size_t num_priors = inputDesc[0].dims.d[1];
-  bool shareLocation = (inputDesc[0].dims.d[2] == 1);
-
-  int topk =
-      param.topK > 0 && param.topK <= inputDesc[1].dims.d[1] ? param.topK : inputDesc[1].dims.d[1];
-  bool rotated = true;
-  pluginStatus_t status = nmsInference(
-      stream, batch_size, boxes_size, score_size, shareLocation, param.backgroundLabelId,
-      num_priors, param.numClasses, topk, param.keepTopK, param.scoreThreshold, param.iouThreshold,
-      DataType::kFLOAT, locData, DataType::kFLOAT, confData, nmsedDets, nmsedLabels, nullptr,
-      workSpace, param.isNormalized, false, mClipBoxes, rotated);
-  ASSERT(status == STATUS_SUCCESS);
-
-  return 0;
-}
-
-size_t TRTBatchedRotatedNMS::getSerializationSize() const TRT_NOEXCEPT {
-  // NMSParameters,
-  return sizeof(NMSParameters) + sizeof(bool);
-}
-
-void TRTBatchedRotatedNMS::serialize(void* buffer) const TRT_NOEXCEPT {
-  serialize_value(&buffer, param);
-  serialize_value(&buffer, mClipBoxes);
-}
-
-void TRTBatchedRotatedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
-                                           int nbInputs,
-                                           const nvinfer1::DynamicPluginTensorDesc* outputs,
-                                           int nbOutputs) TRT_NOEXCEPT {
-  // Validate input arguments
-}
-
-bool TRTBatchedRotatedNMS::supportsFormatCombination(int pos,
-                                                     const nvinfer1::PluginTensorDesc* ioDesc,
-                                                     int nbInputs, int nbOutputs) TRT_NOEXCEPT {
-  if (pos == 3) {
-    return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
-           ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
-  }
-  return ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
-         ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
-}
-
-const char* TRTBatchedRotatedNMS::getPluginType() const TRT_NOEXCEPT { return NMS_PLUGIN_NAME; }
-
-const char* TRTBatchedRotatedNMS::getPluginVersion() const TRT_NOEXCEPT {
-  return NMS_PLUGIN_VERSION;
-}
-
-IPluginV2DynamicExt* TRTBatchedRotatedNMS::clone() const TRT_NOEXCEPT {
-  auto* plugin = new TRTBatchedRotatedNMS(mLayerName, param);
-  plugin->setPluginNamespace(mNamespace.c_str());
-  plugin->setClipParam(mClipBoxes);
-  return plugin;
-}
-
-nvinfer1::DataType TRTBatchedRotatedNMS::getOutputDataType(int index,
-                                                           const nvinfer1::DataType* inputTypes,
-                                                           int nbInputs) const TRT_NOEXCEPT {
-  ASSERT(index >= 0 && index < this->getNbOutputs());
-  if (index == 1) {
-    return nvinfer1::DataType::kINT32;
-  }
-  return inputTypes[0];
-}
-
-void TRTBatchedRotatedNMS::setClipParam(bool clip) { mClipBoxes = clip; }
-
-TRTBatchedRotatedNMSCreator::TRTBatchedRotatedNMSCreator() {
-  mPluginAttributes.emplace_back(
-      PluginField("background_label_id", nullptr, PluginFieldType::kINT32, 1));
-  mPluginAttributes.emplace_back(PluginField("num_classes", nullptr, PluginFieldType::kINT32, 1));
-  mPluginAttributes.emplace_back(PluginField("topk", nullptr, PluginFieldType::kINT32, 1));
-  mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1));
-  mPluginAttributes.emplace_back(
-      PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
-  mPluginAttributes.emplace_back(
-      PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
-  mPluginAttributes.emplace_back(PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1));
-  mPluginAttributes.emplace_back(PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1));
-
-  mFC.nbFields = mPluginAttributes.size();
-  mFC.fields = mPluginAttributes.data();
-}
-
-const char* TRTBatchedRotatedNMSCreator::getPluginName() const TRT_NOEXCEPT {
-  return NMS_PLUGIN_NAME;
-}
-
-const char* TRTBatchedRotatedNMSCreator::getPluginVersion() const TRT_NOEXCEPT {
-  return NMS_PLUGIN_VERSION;
-}
-
-IPluginV2Ext* TRTBatchedRotatedNMSCreator::createPlugin(
-    const char* name, const PluginFieldCollection* fc) TRT_NOEXCEPT {
-  const PluginField* fields = fc->fields;
-  bool clipBoxes = true;
-  nvinfer1::plugin::NMSParameters params{};
-
-  for (int i = 0; i < fc->nbFields; ++i) {
-    const char* attrName = fields[i].name;
-    if (!strcmp(attrName, "background_label_id")) {
-      ASSERT(fields[i].type == PluginFieldType::kINT32);
-      params.backgroundLabelId = *(static_cast<const int*>(fields[i].data));
-    } else if (!strcmp(attrName, "num_classes")) {
-      ASSERT(fields[i].type == PluginFieldType::kINT32);
-      params.numClasses = *(static_cast<const int*>(fields[i].data));
-    } else if (!strcmp(attrName, "topk")) {
-      ASSERT(fields[i].type == PluginFieldType::kINT32);
-      params.topK = *(static_cast<const int*>(fields[i].data));
-    } else if (!strcmp(attrName, "keep_topk")) {
-      ASSERT(fields[i].type == PluginFieldType::kINT32);
-      params.keepTopK = *(static_cast<const int*>(fields[i].data));
-    } else if (!strcmp(attrName, "score_threshold")) {
-      ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
-      params.scoreThreshold = *(static_cast<const float*>(fields[i].data));
-    } else if (!strcmp(attrName, "iou_threshold")) {
-      ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
-      params.iouThreshold = *(static_cast<const float*>(fields[i].data));
-    } else if (!strcmp(attrName, "is_normalized")) {
-      params.isNormalized = *(static_cast<const bool*>(fields[i].data));
-    } else if (!strcmp(attrName, "clip_boxes")) {
-      clipBoxes = *(static_cast<const bool*>(fields[i].data));
-    }
-  }
-
-  TRTBatchedRotatedNMS* plugin = new TRTBatchedRotatedNMS(name, params);
-  plugin->setClipParam(clipBoxes);
-  plugin->setPluginNamespace(mNamespace.c_str());
-  return plugin;
-}
-
-IPluginV2Ext* TRTBatchedRotatedNMSCreator::deserializePlugin(const char* name,
-                                                             const void* serialData,
-                                                             size_t serialLength) TRT_NOEXCEPT {
-  // This object will be deleted when the network is destroyed, which will
-  // call NMS::destroy()
-  TRTBatchedRotatedNMS* plugin = new TRTBatchedRotatedNMS(name, serialData, serialLength);
-  plugin->setPluginNamespace(mNamespace.c_str());
-  return plugin;
-}
-
-REGISTER_TENSORRT_PLUGIN(TRTBatchedRotatedNMSCreator);
+namespace mmdeploy
+{
+    using namespace nvinfer1;
+    using nvinfer1::plugin::NMSParameters;
+
+    namespace
+    {
+        static const char* NMS_PLUGIN_VERSION{"1"};
+        static const char* NMS_PLUGIN_NAME{"TRTBatchedRotatedNMS"};
+    }  // namespace
+
+    TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, NMSParameters params)
+        : TRTPluginBase(name)
+        , param(params)
+    {
+    }
+
+    TRTBatchedRotatedNMS::TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length)
+        : TRTPluginBase(name)
+    {
+        deserialize_value(&data, &length, &param);
+        deserialize_value(&data, &length, &mClipBoxes);
+    }
+
+    int TRTBatchedRotatedNMS::getNbOutputs() const TRT_NOEXCEPT
+    {
+        return 2;
+    }
+
+    nvinfer1::DimsExprs TRTBatchedRotatedNMS::getOutputDimensions(
+        int outputIndex,
+        const nvinfer1::DimsExprs* inputs,
+        int nbInputs,
+        nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT
+    {
+        ASSERT(nbInputs == 2);
+        ASSERT(outputIndex >= 0 && outputIndex < this->getNbOutputs());
+        ASSERT(inputs[0].nbDims == 4);
+        ASSERT(inputs[1].nbDims == 3);
+
+        nvinfer1::DimsExprs ret;
+        ret.d[0] = inputs[0].d[0];
+        ret.d[1] = exprBuilder.constant(param.keepTopK);
+        switch (outputIndex)
+        {
+            case 0:
+                ret.nbDims = 3;
+                ret.d[2] = exprBuilder.constant(6);
+                break;
+            case 1:
+                ret.nbDims = 2;
+                break;
+            default:
+                break;
+        }
+
+        return ret;
+    }
+
+    size_t TRTBatchedRotatedNMS::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
+                                                  int nbInputs,
+                                                  const nvinfer1::PluginTensorDesc* outputs,
+                                                  int nbOutputs) const TRT_NOEXCEPT
+    {
+        size_t batch_size = inputs[0].dims.d[0];
+        size_t boxes_size = inputs[0].dims.d[1] * inputs[0].dims.d[2] * inputs[0].dims.d[3];
+        size_t score_size = inputs[1].dims.d[1] * inputs[1].dims.d[2];
+        size_t num_priors = inputs[0].dims.d[1];
+        bool shareLocation = (inputs[0].dims.d[2] == 1);
+        int topk = param.topK > 0 && param.topK <= inputs[1].dims.d[1] ? param.topK : inputs[1].dims.d[1];
+        return detectionInferenceWorkspaceSize(shareLocation, batch_size, boxes_size, score_size, param.numClasses, num_priors, topk, DataType::kFLOAT, DataType::kFLOAT);
+    }
+
+    int TRTBatchedRotatedNMS::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+                                      const nvinfer1::PluginTensorDesc* outputDesc,
+                                      const void* const* inputs,
+                                      void* const* outputs,
+                                      void* workSpace,
+                                      cudaStream_t stream) TRT_NOEXCEPT
+    {
+        const void* const locData = inputs[0];
+        const void* const confData = inputs[1];
+
+        void* nmsedDets = outputs[0];
+        void* nmsedLabels = outputs[1];
+
+        size_t batch_size = inputDesc[0].dims.d[0];
+        size_t boxes_size = inputDesc[0].dims.d[1] * inputDesc[0].dims.d[2] * inputDesc[0].dims.d[3];
+        size_t score_size = inputDesc[1].dims.d[1] * inputDesc[1].dims.d[2];
+        size_t num_priors = inputDesc[0].dims.d[1];
+        bool shareLocation = (inputDesc[0].dims.d[2] == 1);
+
+        int topk =
+            param.topK > 0 && param.topK <= inputDesc[1].dims.d[1] ? param.topK : inputDesc[1].dims.d[1];
+        bool rotated = true;
+        pluginStatus_t status = nmsInference(
+            stream,
+            batch_size,
+            boxes_size,
+            score_size,
+            shareLocation,
+            param.backgroundLabelId,
+            num_priors,
+            param.numClasses,
+            topk,
+            param.keepTopK,
+            param.scoreThreshold,
+            param.iouThreshold,
+            DataType::kFLOAT,
+            locData,
+            DataType::kFLOAT,
+            confData,
+            nmsedDets,
+            nmsedLabels,
+            nullptr,
+            workSpace,
+            param.isNormalized,
+            false,
+            mClipBoxes,
+            rotated);
+        ASSERT(status == STATUS_SUCCESS);
+
+        return 0;
+    }
+
+    size_t TRTBatchedRotatedNMS::getSerializationSize() const TRT_NOEXCEPT
+    {
+        // NMSParameters,
+        return sizeof(NMSParameters) + sizeof(bool);
+    }
+
+    void TRTBatchedRotatedNMS::serialize(void* buffer) const TRT_NOEXCEPT
+    {
+        serialize_value(&buffer, param);
+        serialize_value(&buffer, mClipBoxes);
+    }
+
+    void TRTBatchedRotatedNMS::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
+                                               int nbInputs,
+                                               const nvinfer1::DynamicPluginTensorDesc* outputs,
+                                               int nbOutputs) TRT_NOEXCEPT
+    {
+        // Validate input arguments
+    }
+
+    bool TRTBatchedRotatedNMS::supportsFormatCombination(int pos,
+                                                         const nvinfer1::PluginTensorDesc* ioDesc,
+                                                         int nbInputs,
+                                                         int nbOutputs) TRT_NOEXCEPT
+    {
+        if (pos == 3)
+        {
+            return ioDesc[pos].type == nvinfer1::DataType::kINT32 &&
+                   ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
+        }
+        return ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
+               ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR;
+    }
+
+    const char* TRTBatchedRotatedNMS::getPluginType() const TRT_NOEXCEPT
+    {
+        return NMS_PLUGIN_NAME;
+    }
+
+    const char* TRTBatchedRotatedNMS::getPluginVersion() const TRT_NOEXCEPT
+    {
+        return NMS_PLUGIN_VERSION;
+    }
+
+    IPluginV2DynamicExt* TRTBatchedRotatedNMS::clone() const TRT_NOEXCEPT
+    {
+        auto* plugin = new TRTBatchedRotatedNMS(mLayerName, param);
+        plugin->setPluginNamespace(mNamespace.c_str());
+        plugin->setClipParam(mClipBoxes);
+        return plugin;
+    }
+
+    nvinfer1::DataType TRTBatchedRotatedNMS::getOutputDataType(int index,
+                                                               const nvinfer1::DataType* inputTypes,
+                                                               int nbInputs) const TRT_NOEXCEPT
+    {
+        ASSERT(index >= 0 && index < this->getNbOutputs());
+        if (index == 1)
+        {
+            return nvinfer1::DataType::kINT32;
+        }
+        return inputTypes[0];
+    }
+
+    void TRTBatchedRotatedNMS::setClipParam(bool clip)
+    {
+        mClipBoxes = clip;
+    }
+
+    TRTBatchedRotatedNMSCreator::TRTBatchedRotatedNMSCreator()
+    {
+        mPluginAttributes.emplace_back(
+            PluginField("background_label_id", nullptr, PluginFieldType::kINT32, 1));
+        mPluginAttributes.emplace_back(PluginField("num_classes", nullptr, PluginFieldType::kINT32, 1));
+        mPluginAttributes.emplace_back(PluginField("topk", nullptr, PluginFieldType::kINT32, 1));
+        mPluginAttributes.emplace_back(PluginField("keep_topk", nullptr, PluginFieldType::kINT32, 1));
+        mPluginAttributes.emplace_back(
+            PluginField("score_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
+        mPluginAttributes.emplace_back(
+            PluginField("iou_threshold", nullptr, PluginFieldType::kFLOAT32, 1));
+        mPluginAttributes.emplace_back(PluginField("is_normalized", nullptr, PluginFieldType::kINT32, 1));
+        mPluginAttributes.emplace_back(PluginField("clip_boxes", nullptr, PluginFieldType::kINT32, 1));
+
+        mFC.nbFields = mPluginAttributes.size();
+        mFC.fields = mPluginAttributes.data();
+    }
+
+    const char* TRTBatchedRotatedNMSCreator::getPluginName() const TRT_NOEXCEPT
+    {
+        return NMS_PLUGIN_NAME;
+    }
+
+    const char* TRTBatchedRotatedNMSCreator::getPluginVersion() const TRT_NOEXCEPT
+    {
+        return NMS_PLUGIN_VERSION;
+    }
+
+    IPluginV2Ext* TRTBatchedRotatedNMSCreator::createPlugin(
+        const char* name,
+        const PluginFieldCollection* fc) TRT_NOEXCEPT
+    {
+        const PluginField* fields = fc->fields;
+        bool clipBoxes = true;
+        nvinfer1::plugin::NMSParameters params{};
+
+        for (int i = 0; i < fc->nbFields; ++i)
+        {
+            const char* attrName = fields[i].name;
+            if (!strcmp(attrName, "background_label_id"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kINT32);
+                params.backgroundLabelId = *(static_cast<const int*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "num_classes"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kINT32);
+                params.numClasses = *(static_cast<const int*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "topk"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kINT32);
+                params.topK = *(static_cast<const int*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "keep_topk"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kINT32);
+                params.keepTopK = *(static_cast<const int*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "score_threshold"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
+                params.scoreThreshold = *(static_cast<const float*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "iou_threshold"))
+            {
+                ASSERT(fields[i].type == PluginFieldType::kFLOAT32);
+                params.iouThreshold = *(static_cast<const float*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "is_normalized"))
+            {
+                params.isNormalized = *(static_cast<const bool*>(fields[i].data));
+            }
+            else if (!strcmp(attrName, "clip_boxes"))
+            {
+                clipBoxes = *(static_cast<const bool*>(fields[i].data));
+            }
+        }
+
+        TRTBatchedRotatedNMS* plugin = new TRTBatchedRotatedNMS(name, params);
+        plugin->setClipParam(clipBoxes);
+        plugin->setPluginNamespace(mNamespace.c_str());
+        return plugin;
+    }
+
+    IPluginV2Ext* TRTBatchedRotatedNMSCreator::deserializePlugin(const char* name,
+                                                                 const void* serialData,
+                                                                 size_t serialLength) TRT_NOEXCEPT
+    {
+        // This object will be deleted when the network is destroyed, which will
+        // call NMS::destroy()
+        TRTBatchedRotatedNMS* plugin = new TRTBatchedRotatedNMS(name, serialData, serialLength);
+        plugin->setPluginNamespace(mNamespace.c_str());
+        return plugin;
+    }
+
+    REGISTER_TENSORRT_PLUGIN(TRTBatchedRotatedNMSCreator);
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp
index 66479eb7e7..49b5cb650d 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/batched_rotated_nms/trt_batched_rotated_nms.hpp
@@ -7,72 +7,70 @@
 #include "NvInferPluginUtils.h"
 #include "trt_plugin_base.hpp"
-namespace mmdeploy {
-class TRTBatchedRotatedNMS : public TRTPluginBase {
- public:
-  TRTBatchedRotatedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param);
+namespace mmdeploy
+{
+    class TRTBatchedRotatedNMS : public TRTPluginBase
+    {
+      public:
+        TRTBatchedRotatedNMS(const std::string& name, nvinfer1::plugin::NMSParameters param);
-  TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length);
+        TRTBatchedRotatedNMS(const std::string& name, const void* data, size_t length);
-  ~TRTBatchedRotatedNMS() TRT_NOEXCEPT override = default;
+        ~TRTBatchedRotatedNMS() TRT_NOEXCEPT override = default;
-  int getNbOutputs() const TRT_NOEXCEPT override;
+        int getNbOutputs() const TRT_NOEXCEPT override;
-  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs,
-                                          int nbInputs, nvinfer1::IExprBuilder& exprBuilder)
-      TRT_NOEXCEPT override;
+        nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder)
+            TRT_NOEXCEPT override;
-  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
-                          const nvinfer1::PluginTensorDesc* outputs,
-                          int nbOutputs) const TRT_NOEXCEPT override;
+        size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override;
-  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
-              const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs,
-              void* const* outputs, void* workSpace, cudaStream_t stream) TRT_NOEXCEPT override;
+        int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+                    const nvinfer1::PluginTensorDesc* outputDesc,
+                    const void* const* inputs,
+                    void* const* outputs,
+                    void* workSpace,
+                    cudaStream_t stream) TRT_NOEXCEPT override;
-  size_t getSerializationSize() const TRT_NOEXCEPT override;
+        size_t getSerializationSize() const TRT_NOEXCEPT override;
-  void serialize(void* buffer) const TRT_NOEXCEPT override;
+        void serialize(void* buffer) const TRT_NOEXCEPT override;
-  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs,
-                       const nvinfer1::DynamicPluginTensorDesc* outputs,
-                       int nbOutputs) TRT_NOEXCEPT override;
+        void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) TRT_NOEXCEPT override;
-  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs,
-                                 int nbOutputs) TRT_NOEXCEPT override;
+        bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override;
-  const char* getPluginType() const TRT_NOEXCEPT override;
+        const char* getPluginType() const TRT_NOEXCEPT override;
-  const char* getPluginVersion() const TRT_NOEXCEPT override;
+        const char* getPluginVersion() const TRT_NOEXCEPT override;
-  nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
+        nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
-  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputType,
-                                       int nbInputs) const TRT_NOEXCEPT override;
+        nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputType, int nbInputs) const TRT_NOEXCEPT override;
-  void setClipParam(bool clip);
+        void setClipParam(bool clip);
- private:
-  nvinfer1::plugin::NMSParameters param{};
-  bool mClipBoxes{};
-};
+      private:
+        nvinfer1::plugin::NMSParameters param{};
+        bool mClipBoxes{};
+    };
-class TRTBatchedRotatedNMSCreator : public TRTPluginCreatorBase {
- public:
-  TRTBatchedRotatedNMSCreator();
+    class TRTBatchedRotatedNMSCreator : public TRTPluginCreatorBase
+    {
+      public:
+        TRTBatchedRotatedNMSCreator();
-  ~TRTBatchedRotatedNMSCreator() TRT_NOEXCEPT override = default;
+        ~TRTBatchedRotatedNMSCreator() TRT_NOEXCEPT override = default;
-  const char* getPluginName() const TRT_NOEXCEPT override;
+        const char* getPluginName() const TRT_NOEXCEPT override;
-  const char* getPluginVersion() const TRT_NOEXCEPT override;
+        const char* getPluginVersion() const TRT_NOEXCEPT override;
-  nvinfer1::IPluginV2Ext* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc)
-      TRT_NOEXCEPT override;
+        nvinfer1::IPluginV2Ext* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc)
+            TRT_NOEXCEPT override;
-  nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, const void* serialData,
-                                            size_t serialLength) TRT_NOEXCEPT override;
-};
+        nvinfer1::IPluginV2Ext* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
+    };
 }  // namespace mmdeploy
 #endif
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp
index 0f236e4956..db2063d235 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.cpp
@@ -10,176 +10,222 @@
 #include "trt_serialize.hpp"
 using namespace nvinfer1;
-namespace mmdeploy {
-namespace {
-static const char *PLUGIN_VERSION{"1"};
-static const char *PLUGIN_NAME{"TRTBicubicInterpolate"};
-}  // namespace
-
-TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string &name,
-                                             std::vector<float> scale_factor, bool align_corners)
-    : TRTPluginBase(name), mScaleFactor(scale_factor), mAlignCorners(align_corners) {}
-
-TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string name, const void *data,
-                                             size_t length)
-    : TRTPluginBase(name) {
-  deserialize_value(&data, &length, &mScaleFactor);
-  deserialize_value(&data, &length, &mAlignCorners);
-}
-
-nvinfer1::IPluginV2DynamicExt *TRTBicubicInterpolate::clone() const TRT_NOEXCEPT {
-  TRTBicubicInterpolate *plugin =
-      new TRTBicubicInterpolate(mLayerName, mScaleFactor, mAlignCorners);
-  plugin->setPluginNamespace(getPluginNamespace());
-
-  return plugin;
-}
-
-nvinfer1::DimsExprs TRTBicubicInterpolate::getOutputDimensions(
-    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
-    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
-  nvinfer1::DimsExprs ret;
-  ret.nbDims = 4;
-  ret.d[0] = inputs[0].d[0];
-  ret.d[1] = inputs[0].d[1];
-  auto height = exprBuilder.constant(mScaleFactor[0]);
-  auto width = exprBuilder.constant(mScaleFactor[1]);
-  auto d2 = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2], *height);
-  auto d3 = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[3], *width);
-  ret.d[2] = d2;
-  ret.d[3] = d3;
-
-  return ret;
-}
-
-bool TRTBicubicInterpolate::supportsFormatCombination(int pos,
-                                                      const nvinfer1::PluginTensorDesc *ioDesc,
-                                                      int nbInputs, int nbOutputs) TRT_NOEXCEPT {
-  if (pos == 0) {
-    return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
-            ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
-
-  } else {
-    return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
-  }
-}
-
-void TRTBicubicInterpolate::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs,
-                                            int nbInputs,
-                                            const nvinfer1::DynamicPluginTensorDesc *outputs,
-                                            int nbOutputs) TRT_NOEXCEPT {}
-
-size_t TRTBicubicInterpolate::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs,
-                                               int nbInputs,
-                                               const nvinfer1::PluginTensorDesc *outputs,
-                                               int nbOutputs) const TRT_NOEXCEPT {
-  return 0;
-}
-
-int TRTBicubicInterpolate::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
-                                   const nvinfer1::PluginTensorDesc *outputDesc,
-                                   const void *const *inputs, void *const *outputs, void *workSpace,
-                                   cudaStream_t stream) TRT_NOEXCEPT {
-  int batch = inputDesc[0].dims.d[0];
-  int channels = inputDesc[0].dims.d[1];
-  int height = inputDesc[0].dims.d[2];
-  int width = inputDesc[0].dims.d[3];
-
-  int height_out = outputDesc[0].dims.d[2];
-  int width_out = outputDesc[0].dims.d[3];
-  const void *x = inputs[0];
-  void *output = outputs[0];
-
-  // TODO: add fp16 support
-  auto data_type = inputDesc[0].type;
-  switch (data_type) {
-    case nvinfer1::DataType::kFLOAT:
-      bicubic_interpolate<float>((float *)x, (float *)output, batch, channels, height, width,
-                                 height_out, width_out, mAlignCorners, stream);
-      break;
-    default:
-      return 1;
-      break;
-  }
-
-  return 0;
-}
-
-nvinfer1::DataType TRTBicubicInterpolate::getOutputDataType(int index,
-                                                            const nvinfer1::DataType *inputTypes,
-                                                            int nbInputs) const TRT_NOEXCEPT {
-  return inputTypes[0];
-}
-
-// IPluginV2 Methods
-const char *TRTBicubicInterpolate::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; }
-
-const char *TRTBicubicInterpolate::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; }
-
-int TRTBicubicInterpolate::getNbOutputs() const TRT_NOEXCEPT { return 1; }
-
-size_t TRTBicubicInterpolate::getSerializationSize() const TRT_NOEXCEPT {
-  return serialized_size(mScaleFactor) + serialized_size(mAlignCorners);
-}
-
-void TRTBicubicInterpolate::serialize(void *buffer) const TRT_NOEXCEPT {
-  serialize_value(&buffer, mScaleFactor);
-  serialize_value(&buffer, mAlignCorners);
-}
-
-////////////////////// creator /////////////////////////////
-
-TRTBicubicInterpolateCreator::TRTBicubicInterpolateCreator() {
-  mPluginAttributes.clear();
-  mPluginAttributes.emplace_back(nvinfer1::PluginField("scale_factor"));
-  mPluginAttributes.emplace_back(nvinfer1::PluginField("align_corners"));
-  mFC.nbFields = mPluginAttributes.size();
-  mFC.fields = mPluginAttributes.data();
-}
-
-const char *TRTBicubicInterpolateCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; }
-
-const char *TRTBicubicInterpolateCreator::getPluginVersion() const TRT_NOEXCEPT {
-  return PLUGIN_VERSION;
-}
-
-nvinfer1::IPluginV2 *TRTBicubicInterpolateCreator::createPlugin(
-    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
-  nvinfer1::Dims size{2, {1, 1}};
-  std::vector<float> scale_factor;
-  bool align_corners = 1;
-
-  for (int i = 0; i < fc->nbFields; i++) {
-    if (fc->fields[i].data == nullptr) {
-      continue;
-    }
-    std::string field_name(fc->fields[i].name);
-
-    if (field_name.compare("scale_factor") == 0) {
-      int data_size = (fc->fields[i].length);
-      if (data_size != 2) {
-        data_size = data_size / sizeof(float);
-      }
-      ASSERT(data_size == 2)
-      const float *data_start = static_cast<const float *>(fc->fields[i].data);
-      scale_factor = std::vector<float>(data_start, data_start + data_size);
-    }
-
-    if (field_name.compare("align_corners") == 0) {
-      align_corners = static_cast<const int *>(fc->fields[i].data)[0];
-    }
-  }
-
-  TRTBicubicInterpolate *plugin = new TRTBicubicInterpolate(name, scale_factor, align_corners);
-  plugin->setPluginNamespace(getPluginNamespace());
-  return plugin;
-}
-
-nvinfer1::IPluginV2 *TRTBicubicInterpolateCreator::deserializePlugin(
-    const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
-  auto plugin = new TRTBicubicInterpolate(name, serialData, serialLength);
-  plugin->setPluginNamespace(getPluginNamespace());
-  return plugin;
-}
-REGISTER_TENSORRT_PLUGIN(TRTBicubicInterpolateCreator);
+namespace mmdeploy
+{
+    namespace
+    {
+        static const char* PLUGIN_VERSION{"1"};
+        static const char* PLUGIN_NAME{"TRTBicubicInterpolate"};
+    }  // namespace
+
+    TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string& name,
+                                                 std::vector<float> scale_factor,
+                                                 bool align_corners)
+        : TRTPluginBase(name)
+        , mScaleFactor(scale_factor)
+        , mAlignCorners(align_corners)
+    {
+    }
+
+    TRTBicubicInterpolate::TRTBicubicInterpolate(const std::string name, const void* data, size_t length)
+        : TRTPluginBase(name)
+    {
+        deserialize_value(&data, &length, &mScaleFactor);
+        deserialize_value(&data, &length, &mAlignCorners);
+    }
+
+    nvinfer1::IPluginV2DynamicExt* TRTBicubicInterpolate::clone() const TRT_NOEXCEPT
+    {
+        TRTBicubicInterpolate* plugin =
+            new TRTBicubicInterpolate(mLayerName, mScaleFactor, mAlignCorners);
+        plugin->setPluginNamespace(getPluginNamespace());
+
+        return plugin;
+    }
+
+    nvinfer1::DimsExprs TRTBicubicInterpolate::getOutputDimensions(
+        int outputIndex,
+        const nvinfer1::DimsExprs* inputs,
+        int nbInputs,
+        nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT
+    {
+        nvinfer1::DimsExprs ret;
+        ret.nbDims = 4;
+        ret.d[0] = inputs[0].d[0];
+        ret.d[1] = inputs[0].d[1];
+        auto height = exprBuilder.constant(mScaleFactor[0]);
+        auto width = exprBuilder.constant(mScaleFactor[1]);
+        auto d2 = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2], *height);
+        auto d3 = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[3], *width);
+        ret.d[2] = d2;
+        ret.d[3] = d3;
+
+        return ret;
+    }
+
+    bool TRTBicubicInterpolate::supportsFormatCombination(int pos,
+                                                          const nvinfer1::PluginTensorDesc* ioDesc,
+                                                          int nbInputs,
+                                                          int nbOutputs) TRT_NOEXCEPT
+    {
+        if (pos == 0)
+        {
+            return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT &&
+                    ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR);
+        }
+        else
+        {
+            return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format;
+        }
+    }
+
+    void TRTBicubicInterpolate::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs,
+                                                int nbInputs,
+                                                const nvinfer1::DynamicPluginTensorDesc* outputs,
+                                                int nbOutputs) TRT_NOEXCEPT {}
+
+    size_t TRTBicubicInterpolate::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
+                                                   int nbInputs,
+                                                   const nvinfer1::PluginTensorDesc* outputs,
+                                                   int nbOutputs) const TRT_NOEXCEPT
+    {
+        return 0;
+    }
+
+    int TRTBicubicInterpolate::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+                                       const nvinfer1::PluginTensorDesc* outputDesc,
+                                       const void* const* inputs,
+                                       void* const* outputs,
+                                       void* workSpace,
+                                       cudaStream_t stream) TRT_NOEXCEPT
+    {
+        int batch = inputDesc[0].dims.d[0];
+        int channels = inputDesc[0].dims.d[1];
+        int height = inputDesc[0].dims.d[2];
+        int width = inputDesc[0].dims.d[3];
+
+        int height_out = outputDesc[0].dims.d[2];
+        int width_out = outputDesc[0].dims.d[3];
+        const void* x = inputs[0];
+        void* output = outputs[0];
+
+        // TODO: add fp16 support
+        auto data_type = inputDesc[0].type;
+        switch (data_type)
+        {
+            case nvinfer1::DataType::kFLOAT:
+                bicubic_interpolate<float>((float*)x, (float*)output, batch, channels, height, width, height_out, width_out, mAlignCorners, stream);
+                break;
+            default:
+                return 1;
+                break;
+        }
+
+        return 0;
+    }
+
+    nvinfer1::DataType TRTBicubicInterpolate::getOutputDataType(int index,
+                                                                const nvinfer1::DataType* inputTypes,
+                                                                int nbInputs) const TRT_NOEXCEPT
+    {
+        return inputTypes[0];
+    }
+
+    // IPluginV2 Methods
+    const char* TRTBicubicInterpolate::getPluginType() const TRT_NOEXCEPT
+    {
+        return PLUGIN_NAME;
+    }
+
+    const char* TRTBicubicInterpolate::getPluginVersion() const TRT_NOEXCEPT
+    {
+        return PLUGIN_VERSION;
+    }
+
+    int TRTBicubicInterpolate::getNbOutputs() const TRT_NOEXCEPT
+    {
+        return 1;
+    }
+
+    size_t TRTBicubicInterpolate::getSerializationSize() const TRT_NOEXCEPT
+    {
+        return serialized_size(mScaleFactor) + serialized_size(mAlignCorners);
+    }
+
+    void TRTBicubicInterpolate::serialize(void* buffer) const TRT_NOEXCEPT
+    {
+        serialize_value(&buffer, mScaleFactor);
+        serialize_value(&buffer, mAlignCorners);
+    }
+
+    ////////////////////// creator /////////////////////////////
+
+    TRTBicubicInterpolateCreator::TRTBicubicInterpolateCreator()
+    {
+        mPluginAttributes.clear();
+        mPluginAttributes.emplace_back(nvinfer1::PluginField("scale_factor"));
+        mPluginAttributes.emplace_back(nvinfer1::PluginField("align_corners"));
+        mFC.nbFields = mPluginAttributes.size();
+        mFC.fields = mPluginAttributes.data();
+    }
+
+    const char* TRTBicubicInterpolateCreator::getPluginName() const TRT_NOEXCEPT
+    {
+        return PLUGIN_NAME;
+    }
+
+    const char* TRTBicubicInterpolateCreator::getPluginVersion() const TRT_NOEXCEPT
+    {
+        return PLUGIN_VERSION;
+    }
+
+    nvinfer1::IPluginV2* TRTBicubicInterpolateCreator::createPlugin(
+        const char* name,
+        const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT
+    {
+        nvinfer1::Dims size{2, {1, 1}};
+        std::vector<float> scale_factor;
+        bool align_corners = 1;
+
+        for (int i = 0; i < fc->nbFields; i++)
+        {
+            if (fc->fields[i].data == nullptr)
+            {
+                continue;
+            }
+            std::string field_name(fc->fields[i].name);
+
+            if (field_name.compare("scale_factor") == 0)
+            {
+                int data_size = (fc->fields[i].length);
+                if (data_size != 2)
+                {
+                    data_size = data_size / sizeof(float);
+                }
+                ASSERT(data_size == 2)
+                const float* data_start = static_cast<const float*>(fc->fields[i].data);
+                scale_factor = std::vector<float>(data_start, data_start + data_size);
+            }
+
+            if (field_name.compare("align_corners") == 0)
+            {
+                align_corners = static_cast<const int*>(fc->fields[i].data)[0];
+            }
+        }
+
+        TRTBicubicInterpolate* plugin = new TRTBicubicInterpolate(name, scale_factor, align_corners);
+        plugin->setPluginNamespace(getPluginNamespace());
+        return plugin;
+    }
+
+    nvinfer1::IPluginV2* TRTBicubicInterpolateCreator::deserializePlugin(
+        const char* name,
+        const void* serialData,
+        size_t serialLength) TRT_NOEXCEPT
+    {
+        auto plugin = new TRTBicubicInterpolate(name, serialData, serialLength);
+        plugin->setPluginNamespace(getPluginNamespace());
+        return plugin;
+    }
+    REGISTER_TENSORRT_PLUGIN(TRTBicubicInterpolateCreator);
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp
index 37ad7cf9ff..709976ce32 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate.hpp
@@ -7,61 +7,58 @@
 #include
 #include "trt_plugin_base.hpp"
-namespace mmdeploy {
-class TRTBicubicInterpolate : public TRTPluginBase {
- public:
-  TRTBicubicInterpolate(const std::string &name, std::vector<float> scale_factor,
-                        bool align_corners);
+namespace mmdeploy
+{
+    class TRTBicubicInterpolate : public TRTPluginBase
+    {
+      public:
+        TRTBicubicInterpolate(const std::string& name, std::vector<float> scale_factor, bool align_corners);
-  TRTBicubicInterpolate(const std::string name, const void *data, size_t length);
+        TRTBicubicInterpolate(const std::string name, const void* data, size_t length);
-  TRTBicubicInterpolate() = delete;
+        TRTBicubicInterpolate() = delete;
-  // IPluginV2DynamicExt Methods
-  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
-  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
-                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
-      TRT_NOEXCEPT override;
-  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
-                                 int nbOutputs) TRT_NOEXCEPT override;
-  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
-                       const nvinfer1::DynamicPluginTensorDesc *out,
-                       int nbOutputs) TRT_NOEXCEPT override;
-  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
-                          const nvinfer1::PluginTensorDesc *outputs,
-                          int nbOutputs) const TRT_NOEXCEPT override;
-  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
-              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
-              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
+        // IPluginV2DynamicExt Methods
+        nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
+        nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder)
+            TRT_NOEXCEPT override;
+        bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override;
+        void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override;
+        size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override;
+        int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+                    const nvinfer1::PluginTensorDesc* outputDesc,
+                    const void* const* inputs,
+                    void* const* outputs,
+                    void* workspace,
+                    cudaStream_t stream) TRT_NOEXCEPT override;
-  // IPluginV2Ext Methods
-  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes,
-                                       int nbInputs) const TRT_NOEXCEPT override;
+        // IPluginV2Ext Methods
+        nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override;
-  // IPluginV2 Methods
-  const char *getPluginType() const TRT_NOEXCEPT override;
-  const char *getPluginVersion() const TRT_NOEXCEPT override;
-  int getNbOutputs() const TRT_NOEXCEPT override;
-  size_t getSerializationSize() const TRT_NOEXCEPT override;
-  void serialize(void *buffer) const TRT_NOEXCEPT override;
+        // IPluginV2 Methods
+        const char* getPluginType() const TRT_NOEXCEPT override;
+        const char* getPluginVersion() const TRT_NOEXCEPT override;
+        int getNbOutputs() const TRT_NOEXCEPT override;
+        size_t getSerializationSize() const TRT_NOEXCEPT override;
+        void serialize(void* buffer) const TRT_NOEXCEPT override;
- private:
-  std::vector<float> mScaleFactor;
-  bool mAlignCorners;
-};
+      private:
+        std::vector<float> mScaleFactor;
+        bool mAlignCorners;
+    };
-class TRTBicubicInterpolateCreator : public TRTPluginCreatorBase {
- public:
-  TRTBicubicInterpolateCreator();
+    class TRTBicubicInterpolateCreator : public TRTPluginCreatorBase
+    {
+      public:
+        TRTBicubicInterpolateCreator();
-  const char *getPluginName() const TRT_NOEXCEPT override;
+        const char* getPluginName() const TRT_NOEXCEPT override;
-  const char *getPluginVersion() const TRT_NOEXCEPT override;
-  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
-      TRT_NOEXCEPT override;
+        const char* getPluginVersion() const TRT_NOEXCEPT override;
+        nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc)
+            TRT_NOEXCEPT override;
-  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
-                                         size_t serialLength) TRT_NOEXCEPT override;
-};
+        nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_BICUBIC_INTERPOLATE_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu index efb078c431..7a03aa3144 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.cu @@ -12,159 +12,176 @@ // Based on // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm -template -__device__ __forceinline__ static scalar_t cubic_convolution1(scalar_t x, scalar_t A) { - return ((A + 2) * x - (A + 3)) * x * x + 1; +template +__device__ __forceinline__ static scalar_t cubic_convolution1(scalar_t x, scalar_t A) +{ + return ((A + 2) * x - (A + 3)) * x * x + 1; } -template -__device__ __forceinline__ static scalar_t cubic_convolution2(scalar_t x, scalar_t A) { - return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; +template +__device__ __forceinline__ static scalar_t cubic_convolution2(scalar_t x, scalar_t A) +{ + return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A; } -template +template __device__ __forceinline__ static void get_cubic_upsample_coefficients(scalar_t coeffs[4], - scalar_t t) { - scalar_t A = -0.75; - - scalar_t x1 = t; - coeffs[0] = cubic_convolution2(x1 + 1.0, A); - coeffs[1] = cubic_convolution1(x1, A); - - // opposite coefficients - scalar_t x2 = 1.0 - t; - coeffs[2] = cubic_convolution1(x2, A); - coeffs[3] = cubic_convolution2(x2 + 1.0, A); + scalar_t t) +{ + scalar_t A = -0.75; + + scalar_t x1 = t; + coeffs[0] = cubic_convolution2(x1 + 1.0, A); + coeffs[1] = cubic_convolution1(x1, A); + + // opposite coefficients + scalar_t x2 = 1.0 - t; + coeffs[2] = cubic_convolution1(x2, A); + coeffs[3] = cubic_convolution2(x2 + 1.0, A); } -template -__device__ __forceinline__ static scalar_t cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, - scalar_t x3, scalar_t t) { - scalar_t coeffs[4]; - get_cubic_upsample_coefficients(coeffs, t); +template +__device__ __forceinline__ static scalar_t cubic_interp1d(scalar_t x0, scalar_t x1, scalar_t x2, scalar_t x3, scalar_t t) +{ + scalar_t coeffs[4]; + get_cubic_upsample_coefficients(coeffs, t); - return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; + return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3]; } /* Used by UpSampleBicubic2d.cu */ -template -__device__ __forceinline__ static scalar_t upsample_get_value_bounded(const scalar_t *data, - int batch, int channel, - int batchsize, int channels, - int height, int width, int y, - int x) { - int access_y = max(min(y, height - 1), 0); - int access_x = max(min(x, width - 1), 0); - return data[batch * channels * height * width + channel * height * width + access_y * width + - access_x]; +template +__device__ __forceinline__ static scalar_t upsample_get_value_bounded(const scalar_t* data, + int batch, + int channel, + int batchsize, + int channels, + int height, + int width, + int y, + int x) +{ + int access_y = max(min(y, height - 1), 0); + int access_x = max(min(x, width - 1), 0); + return data[batch * channels * height * width + channel * height * width + access_y * width + + access_x]; } -template +template __device__ __forceinline__ scalar_t -area_pixel_compute_source_index(scalar_t scale, int64_t dst_index, bool align_corners, bool cubic) { - if (align_corners) { - return scale * dst_index; - } else { - 
scalar_t src_idx = scale * (dst_index + 0.5) - 0.5; - // [Note] Follow Opencv resize logic: - // We allow negative src_idx here and later will use - // dx = src_idx - floorf(src_idx) - // to compute the "distance"(which affects weights). - // For linear modes, weight distribution doesn't matter - // for negative indices as they use 2 pixels to interpolate. - // For example, [-1, 0], they both use pixel 0 value so it - // doesn't affect if we bound the src_idx to 0 or not. - // TODO: Our current linear mode impls use unbound indices - // where we should and then remove this cubic flag. - // This matters in cubic mode, as we might need [-1, 0, 1, 2] - // to interpolate and the weights can be affected. - return (!cubic && src_idx < 0) ? scalar_t(0) : src_idx; - } + area_pixel_compute_source_index(scalar_t scale, int64_t dst_index, bool align_corners, bool cubic) +{ + if (align_corners) + { + return scale * dst_index; + } + else + { + scalar_t src_idx = scale * (dst_index + 0.5) - 0.5; + // [Note] Follow Opencv resize logic: + // We allow negative src_idx here and later will use + // dx = src_idx - floorf(src_idx) + // to compute the "distance"(which affects weights). + // For linear modes, weight distribution doesn't matter + // for negative indices as they use 2 pixels to interpolate. + // For example, [-1, 0], they both use pixel 0 value so it + // doesn't affect if we bound the src_idx to 0 or not. + // TODO: Our current linear mode impls use unbound indices + // where we should and then remove this cubic flag. + // This matters in cubic mode, as we might need [-1, 0, 1, 2] + // to interpolate and the weights can be affected. + return (!cubic && src_idx < 0) ? scalar_t(0) : src_idx; + } } // cubic interpolation pytorch -template -__global__ void resize_cubic_kernel_torch(const int num_elements, const scalar_t *src, - const int batchsize, const int channels, int srcWidth, - int srcHeight, scalar_t *dst, int dstWidth, int dstHeight, - bool align_corners, float height_scale, - float width_scale) { - CUDA_1D_KERNEL_LOOP(index, num_elements) { - // Special case: input and output are the same size, just copy - const int output_x = index % dstWidth; - const int output_y = index / dstWidth; - - if (srcHeight == dstHeight && srcWidth == dstWidth) { - for (int n = 0; n < batchsize; n++) { - for (int c = 0; c < channels; c++) { - const scalar_t val = src[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + - output_y * dstWidth + output_x]; - dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth + - output_x] = val; +template +__global__ void resize_cubic_kernel_torch(const int num_elements, const scalar_t* src, const int batchsize, const int channels, int srcWidth, int srcHeight, scalar_t* dst, int dstWidth, int dstHeight, bool align_corners, float height_scale, float width_scale) +{ + CUDA_1D_KERNEL_LOOP(index, num_elements) + { + // Special case: input and output are the same size, just copy + const int output_x = index % dstWidth; + const int output_y = index / dstWidth; + + if (srcHeight == dstHeight && srcWidth == dstWidth) + { + for (int n = 0; n < batchsize; n++) + { + for (int c = 0; c < channels; c++) + { + const scalar_t val = src[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + + output_y * dstWidth + output_x]; + dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth + + output_x] = val; + } + } + return; } - } - return; - } - // Interpolation kernel - scalar_t real_x = - 
area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true); - int in_x = floorf(real_x); - scalar_t t_x = real_x - in_x; - - scalar_t real_y = - area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true); - int in_y = floorf(real_y); - scalar_t t_y = real_y - in_y; - - for (int n = 0; n < batchsize; n++) { - for (int c = 0; c < channels; c++) { - scalar_t coefficients[4]; - - for (int k = 0; k < 4; k++) { - coefficients[k] = cubic_interp1d( - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x - 1), - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x + 0), - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x + 1), - upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, - in_y - 1 + k, in_x + 2), - t_x); + // Interpolation kernel + scalar_t real_x = + area_pixel_compute_source_index(width_scale, output_x, align_corners, /*cubic=*/true); + int in_x = floorf(real_x); + scalar_t t_x = real_x - in_x; + + scalar_t real_y = + area_pixel_compute_source_index(height_scale, output_y, align_corners, /*cubic=*/true); + int in_y = floorf(real_y); + scalar_t t_y = real_y - in_y; + + for (int n = 0; n < batchsize; n++) + { + for (int c = 0; c < channels; c++) + { + scalar_t coefficients[4]; + + for (int k = 0; k < 4; k++) + { + coefficients[k] = cubic_interp1d( + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, in_y - 1 + k, in_x - 1), + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, in_y - 1 + k, in_x + 0), + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, in_y - 1 + k, in_x + 1), + upsample_get_value_bounded(src, n, c, batchsize, channels, srcHeight, srcWidth, in_y - 1 + k, in_x + 2), + t_x); + } + + dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth + + output_x] = scalar_t(cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], coefficients[3], t_y)); + } } - - dst[n * channels * dstHeight * dstWidth + c * dstHeight * dstWidth + output_y * dstWidth + - output_x] = scalar_t(cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], - coefficients[3], t_y)); - } } - } } -template -void resizeGPU(const scalar_t *pIn_d, scalar_t *pOut_d, int batch, int channels, int srcWidth, - int srcHeight, int dstWidth, int dstHeight, bool align_corners, - cudaStream_t stream) { - float height_scale = float(srcHeight) / dstHeight; - float width_scale = float(srcWidth) / dstWidth; - if (align_corners && dstWidth > 1 && dstHeight > 1) { - height_scale = (float)(srcHeight - 1) / (dstHeight - 1); - width_scale = (float)(srcWidth - 1) / (dstWidth - 1); - } - int n = batch * dstWidth * dstHeight * channels; - resize_cubic_kernel_torch<<>>( - dstWidth * dstHeight, pIn_d, batch, channels, srcWidth, srcHeight, pOut_d, dstWidth, - dstHeight, align_corners, height_scale, width_scale); +template +void resizeGPU(const scalar_t* pIn_d, scalar_t* pOut_d, int batch, int channels, int srcWidth, int srcHeight, int dstWidth, int dstHeight, bool align_corners, cudaStream_t stream) +{ + float height_scale = float(srcHeight) / dstHeight; + float width_scale = float(srcWidth) / dstWidth; + if (align_corners && dstWidth > 1 && dstHeight > 1) + { + height_scale = (float)(srcHeight - 1) / (dstHeight - 1); + width_scale = (float)(srcWidth - 1) / (dstWidth - 1); + } + 
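+    // Illustration: for a 4 -> 8 upscale, the default scale is srcHeight / dstHeight = 0.5,
+    // so output row y samples around (y + 0.5) * 0.5 - 0.5 in the source (see
+    // area_pixel_compute_source_index above); with align_corners the scale becomes
+    // (srcHeight - 1) / (dstHeight - 1) = 3 / 7, pinning the first and last output rows
+    // exactly to the first and last source rows.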
int n = batch * dstWidth * dstHeight * channels;
+    resize_cubic_kernel_torch<<<GET_BLOCKS(n), THREADS_PER_BLOCK, 0, stream>>>(
+        dstWidth * dstHeight,
+        pIn_d,
+        batch,
+        channels,
+        srcWidth,
+        srcHeight,
+        pOut_d,
+        dstWidth,
+        dstHeight,
+        align_corners,
+        height_scale,
+        width_scale);
 }
 
-template <typename scalar_t>
-void bicubic_interpolate(const scalar_t *input, scalar_t *output, int batch, int channels,
-                         int in_height, int in_width, int out_height, int out_width,
-                         bool align_corners, cudaStream_t stream) {
-  resizeGPU(input, output, batch, channels, in_width, in_height, out_width, out_height,
-            align_corners, stream);
+template<typename scalar_t>
+void bicubic_interpolate(const scalar_t* input, scalar_t* output, int batch, int channels, int in_height, int in_width, int out_height, int out_width, bool align_corners, cudaStream_t stream)
+{
+    resizeGPU(input, output, batch, channels, in_width, in_height, out_width, out_height, align_corners, stream);
 }
 
-template void bicubic_interpolate<float>(const float *input, float *output, int batch, int channels,
-                                         int in_height, int in_width, int out_height, int out_width,
-                                         bool align_corners, cudaStream_t stream);
+template void bicubic_interpolate<float>(const float* input, float* output, int batch, int channels, int in_height, int in_width, int out_height, int out_width, bool align_corners, cudaStream_t stream);
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp
index 66560f59f5..28a89a71db 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/bicubic_interpolate/trt_bicubic_interpolate_kernel.hpp
@@ -4,8 +4,6 @@
 
 #include "common_cuda_helper.hpp"
 
-template <typename scalar_t>
-void bicubic_interpolate(const scalar_t *input, scalar_t *output, int batch, int channels,
-                         int in_height, int in_width, int out_height, int out_width,
-                         bool align_corners, cudaStream_t stream);
+template<typename scalar_t>
+void bicubic_interpolate(const scalar_t* input, scalar_t* output, int batch, int channels, int in_height, int in_width, int out_height, int out_width, bool align_corners, cudaStream_t stream);
 #endif  // TRT_BICUBIC_INTERPOLATE_KERNEL_HPP
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/common_cuda_helper.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/common_cuda_helper.hpp
index c76cac8a32..97738f8f02 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/common/common_cuda_helper.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/common/common_cuda_helper.hpp
@@ -9,25 +9,27 @@
 #include
 
 #define CUDA_1D_KERNEL_LOOP(i, n) \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
 
 #define THREADS_PER_BLOCK 512
 
 #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
 
-inline int GET_BLOCKS(const int N) {
-  int optimal_block_num = DIVUP(N, THREADS_PER_BLOCK);
-  int max_block_num = 4096;
-  return std::min(optimal_block_num, max_block_num);
+inline int GET_BLOCKS(const int N)
+{
+    int optimal_block_num = DIVUP(N, THREADS_PER_BLOCK);
+    int max_block_num = 4096;
+    return std::min(optimal_block_num, max_block_num);
 }
 
-#define cudaCheckError()                                                                 \
-  {                                                                                      \
-    cudaError_t e = cudaGetLastError();                                                  \
-    if (e != cudaSuccess) {                                                              \
-      printf("Cuda failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e));   \
-      exit(0);                                                                           \
-    }                                                                                    \
-  }
+#define cudaCheckError()                                                                     \
+    {                                                                                        \
+        cudaError_t e = cudaGetLastError();                                                  \
+        if (e != cudaSuccess)                                                                \
+        {                                                                                    \
+            printf("Cuda 
failure %s:%d: '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \ + exit(0); \ + } \ + } /** * Returns a view of the original tensor with its dimensions permuted. @@ -39,44 +41,43 @@ inline int GET_BLOCKS(const int N) { * @param[in] src_dim dim of src tensor * @param[in] stream cuda stream handle */ -template -void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size, int* permute, int src_dim, - cudaStream_t stream = 0); +template +void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size, int* permute, int src_dim, cudaStream_t stream = 0); -template -cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, const scalar_t* alpha, - const scalar_t* A, int lda, const scalar_t* B, int ldb, - const scalar_t* beta, scalar_t* C, int ldc); +template +cublasStatus_t cublasGemmWrap(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const scalar_t* alpha, const scalar_t* A, int lda, const scalar_t* B, int ldb, const scalar_t* beta, scalar_t* C, int ldc); -template +template __device__ __forceinline__ scalar_t bilinear_interpolate(const scalar_t* __restrict__ input, - const int height, const int width, - scalar_t y, scalar_t x) { - // deal with cases that inverse elements are out of feature map boundary - if (y < -1.0 || y > height || x < -1.0 || x > width) return 0; + const int height, + const int width, + scalar_t y, + scalar_t x) +{ + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) return 0; - y = min(scalar_t(height - 1), max(scalar_t(0), y)); - x = min(scalar_t(width - 1), max(scalar_t(0), x)); + y = min(scalar_t(height - 1), max(scalar_t(0), y)); + x = min(scalar_t(width - 1), max(scalar_t(0), x)); - const int y_low = floor(y); - const int x_low = floor(x); - const int y_high = ceil(y); - const int x_high = ceil(x); + const int y_low = floor(y); + const int x_low = floor(x); + const int y_high = ceil(y); + const int x_high = ceil(x); - const scalar_t v1 = input[y_low * width + x_low]; - const scalar_t v2 = input[y_low * width + x_high]; - const scalar_t v3 = input[y_high * width + x_low]; - const scalar_t v4 = input[y_high * width + x_high]; + const scalar_t v1 = input[y_low * width + x_low]; + const scalar_t v2 = input[y_low * width + x_high]; + const scalar_t v3 = input[y_high * width + x_low]; + const scalar_t v4 = input[y_high * width + x_high]; - // lerp can be performed by fma - const scalar_t ly = y - y_low; - const scalar_t lx = x - x_low; - const scalar_t v_low = fma(v2 - v1, lx, v1); - const scalar_t v_high = fma(v4 - v3, lx, v3); - const scalar_t val = fma(v_high - v_low, ly, v_low); + // lerp can be performed by fma + const scalar_t ly = y - y_low; + const scalar_t lx = x - x_low; + const scalar_t v_low = fma(v2 - v1, lx, v1); + const scalar_t v_high = fma(v4 - v3, lx, v3); + const scalar_t val = fma(v_high - v_low, ly, v_low); - return val; + return val; } #endif // COMMON_CUDA_HELPER diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp index 22cffa0605..8b28458fd0 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/batched_nms_kernel.hpp @@ -6,14 +6,6 @@ #include "cuda_runtime_api.h" #include "kernel.h" -pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int 
perBatchBoxesSize, - const int perBatchScoresSize, const bool shareLocation, - const int backgroundLabelId, const int numPredsPerClass, - const int numClasses, const int topK, const int keepTopK, - const float scoreThreshold, const float iouThreshold, - const DataType DT_BBOX, const void* locData, const DataType DT_SCORE, - const void* confData, void* nmsedDets, void* nmsedLabels, - void* nmsedIndex, void* workspace, bool isNormalized, bool confSigmoid, - bool clipBoxes, bool rotated = false); +pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatchBoxesSize, const int perBatchScoresSize, const bool shareLocation, const int backgroundLabelId, const int numPredsPerClass, const int numClasses, const int topK, const int keepTopK, const float scoreThreshold, const float iouThreshold, const DataType DT_BBOX, const void* locData, const DataType DT_SCORE, const void* confData, void* nmsedDets, void* nmsedLabels, void* nmsedIndex, void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes, bool rotated = false); #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/cub_helper.h b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/cub_helper.h index 93fd2a4fb9..81500147e7 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/cub_helper.h +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/cub_helper.h @@ -2,14 +2,14 @@ // modify from // https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin #include "kernel.h" -template -size_t cubSortPairsWorkspaceSize(int num_items, int num_segments) { - size_t temp_storage_bytes = 0; - cub::DeviceSegmentedRadixSort::SortPairsDescending((void*)NULL, temp_storage_bytes, - (const KeyT*)NULL, (KeyT*)NULL, - (const ValueT*)NULL, (ValueT*)NULL, - num_items, // # items - num_segments, // # segments - (const int*)NULL, (const int*)NULL); - return temp_storage_bytes; +template +size_t cubSortPairsWorkspaceSize(int num_items, int num_segments) +{ + size_t temp_storage_bytes = 0; + cub::DeviceSegmentedRadixSort::SortPairsDescending((void*)NULL, temp_storage_bytes, (const KeyT*)NULL, (KeyT*)NULL, (const ValueT*)NULL, (ValueT*)NULL, + num_items, // # items + num_segments, // # segments + (const int*)NULL, + (const int*)NULL); + return temp_storage_bytes; } diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h index 1b50fa4e9f..87b089b623 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/nms/kernel.h @@ -15,72 +15,54 @@ using namespace nvinfer1; #define DEBUG_ENABLE 0 -template -struct Bbox { - T xmin, ymin, xmax, ymax; - Bbox(T xmin, T ymin, T xmax, T ymax) : xmin(xmin), ymin(ymin), xmax(xmax), ymax(ymax) {} - Bbox() = default; +template +struct Bbox +{ + T xmin, ymin, xmax, ymax; + Bbox(T xmin, T ymin, T xmax, T ymax) + : xmin(xmin) + , ymin(ymin) + , xmax(xmax) + , ymax(ymax) + { + } + Bbox() = default; }; -size_t get_cuda_arch(int devID); +size_t get_cuda_arch(int devID); -int8_t* alignPtr(int8_t* ptr, uintptr_t to); +int8_t* alignPtr(int8_t* ptr, uintptr_t to); -int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize); +int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize); -void setUniformOffsets(cudaStream_t stream, int num_segments, int offset, int* d_offsets); +void setUniformOffsets(cudaStream_t stream, int num_segments, int offset, int* d_offsets); -pluginStatus_t allClassNMS(cudaStream_t stream, int num, int num_classes, int 
num_preds_per_class, - int top_k, float nms_threshold, bool share_location, bool isNormalized, - DataType DT_SCORE, DataType DT_BBOX, void* bbox_data, - void* beforeNMS_scores, void* beforeNMS_index_array, - void* afterNMS_scores, void* afterNMS_index_array, bool flipXY = false); +pluginStatus_t allClassNMS(cudaStream_t stream, int num, int num_classes, int num_preds_per_class, int top_k, float nms_threshold, bool share_location, bool isNormalized, DataType DT_SCORE, DataType DT_BBOX, void* bbox_data, void* beforeNMS_scores, void* beforeNMS_index_array, void* afterNMS_scores, void* afterNMS_index_array, bool flipXY = false); -pluginStatus_t allClassRotatedNMS(cudaStream_t stream, int num, int num_classes, - int num_preds_per_class, int top_k, float nms_threshold, - bool share_location, bool isNormalized, DataType DT_SCORE, - DataType DT_BBOX, void* bbox_data, void* beforeNMS_scores, - void* beforeNMS_index_array, void* afterNMS_scores, - void* afterNMS_index_array, bool flipXY = false); +pluginStatus_t allClassRotatedNMS(cudaStream_t stream, int num, int num_classes, int num_preds_per_class, int top_k, float nms_threshold, bool share_location, bool isNormalized, DataType DT_SCORE, DataType DT_BBOX, void* bbox_data, void* beforeNMS_scores, void* beforeNMS_index_array, void* afterNMS_scores, void* afterNMS_index_array, bool flipXY = false); -size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX); +size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX); -size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1, DataType DT_BBOX); +size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1, DataType DT_BBOX); -size_t sortScoresPerClassWorkspaceSize(int num, int num_classes, int num_preds_per_class, - DataType DT_CONF); +size_t sortScoresPerClassWorkspaceSize(int num, int num_classes, int num_preds_per_class, DataType DT_CONF); -size_t sortScoresPerImageWorkspaceSize(int num_images, int num_items_per_image, DataType DT_SCORE); +size_t sortScoresPerImageWorkspaceSize(int num_images, int num_items_per_image, DataType DT_SCORE); -pluginStatus_t sortScoresPerImage(cudaStream_t stream, int num_images, int num_items_per_image, - DataType DT_SCORE, void* unsorted_scores, - void* unsorted_bbox_indices, void* sorted_scores, - void* sorted_bbox_indices, void* workspace); +pluginStatus_t sortScoresPerImage(cudaStream_t stream, int num_images, int num_items_per_image, DataType DT_SCORE, void* unsorted_scores, void* unsorted_bbox_indices, void* sorted_scores, void* sorted_bbox_indices, void* workspace); -pluginStatus_t sortScoresPerClass(cudaStream_t stream, int num, int num_classes, - int num_preds_per_class, int background_label_id, - float confidence_threshold, DataType DT_SCORE, - void* conf_scores_gpu, void* index_array_gpu, void* workspace); +pluginStatus_t sortScoresPerClass(cudaStream_t stream, int num, int num_classes, int num_preds_per_class, int background_label_id, float confidence_threshold, DataType DT_SCORE, void* conf_scores_gpu, void* index_array_gpu, void* workspace); -size_t calculateTotalWorkspaceSize(size_t* workspaces, int count); +size_t calculateTotalWorkspaceSize(size_t* workspaces, int count); -pluginStatus_t permuteData(cudaStream_t stream, int nthreads, int num_classes, int num_data, - int num_dim, DataType DT_DATA, bool confSigmoid, const void* data, - void* new_data); +pluginStatus_t permuteData(cudaStream_t stream, int nthreads, int num_classes, int num_data, int num_dim, DataType DT_DATA, bool confSigmoid, const void* 
data, void* new_data); -size_t detectionForwardPreNMSSize(int N, int C2); +size_t detectionForwardPreNMSSize(int N, int C2); -size_t detectionForwardPostNMSSize(int N, int numClasses, int topK); +size_t detectionForwardPostNMSSize(int N, int numClasses, int topK); -pluginStatus_t gatherNMSOutputs(cudaStream_t stream, bool shareLocation, int numImages, - int numPredsPerClass, int numClasses, int topK, int keepTopK, - DataType DT_BBOX, DataType DT_SCORE, const void* indices, - const void* scores, const void* bboxData, void* nmsedDets, - void* nmsedLabels, void* nmsedIndex = nullptr, - bool clipBoxes = true, bool rotated = false); +pluginStatus_t gatherNMSOutputs(cudaStream_t stream, bool shareLocation, int numImages, int numPredsPerClass, int numClasses, int topK, int keepTopK, DataType DT_BBOX, DataType DT_SCORE, const void* indices, const void* scores, const void* bboxData, void* nmsedDets, void* nmsedLabels, void* nmsedIndex = nullptr, bool clipBoxes = true, bool rotated = false); -size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1, int C2, int numClasses, - int numPredsPerClass, int topK, DataType DT_BBOX, - DataType DT_SCORE); +size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1, int C2, int numClasses, int numPredsPerClass, int topK, DataType DT_BBOX, DataType DT_SCORE); #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp index 8440bb6219..482d11a924 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_base.hpp @@ -5,73 +5,98 @@ #include "NvInferVersion.h" #include "trt_plugin_helper.hpp" -namespace mmdeploy { +namespace mmdeploy +{ #if NV_TENSORRT_MAJOR > 7 -#define TRT_NOEXCEPT noexcept + #define TRT_NOEXCEPT noexcept #else -#define TRT_NOEXCEPT + #define TRT_NOEXCEPT #endif -class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt { - public: - TRTPluginBase(const std::string &name) : mLayerName(name) {} - // IPluginV2 Methods - const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; } - int initialize() TRT_NOEXCEPT override { return STATUS_SUCCESS; } - void terminate() TRT_NOEXCEPT override {} - void destroy() TRT_NOEXCEPT override { delete this; } - void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override { - mNamespace = pluginNamespace; - } - const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); } + class TRTPluginBase : public nvinfer1::IPluginV2DynamicExt + { + public: + TRTPluginBase(const std::string& name) + : mLayerName(name) + { + } + // IPluginV2 Methods + const char* getPluginVersion() const TRT_NOEXCEPT override + { + return "1"; + } + int initialize() TRT_NOEXCEPT override + { + return STATUS_SUCCESS; + } + void terminate() TRT_NOEXCEPT override {} + void destroy() TRT_NOEXCEPT override + { + delete this; + } + void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override + { + mNamespace = pluginNamespace; + } + const char* getPluginNamespace() const TRT_NOEXCEPT override + { + return mNamespace.c_str(); + } - virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override {} + virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override {} - virtual 
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override { - return 0; - } + virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override + { + return 0; + } - virtual void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override {} + virtual void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override {} - virtual void detachFromContext() TRT_NOEXCEPT override {} + virtual void detachFromContext() TRT_NOEXCEPT override {} - protected: - const std::string mLayerName; - std::string mNamespace; + protected: + const std::string mLayerName; + std::string mNamespace; #if NV_TENSORRT_MAJOR < 8 - protected: - // To prevent compiler warnings. - using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; - using nvinfer1::IPluginV2DynamicExt::enqueue; - using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; - using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; - using nvinfer1::IPluginV2DynamicExt::supportsFormat; + protected: + // To prevent compiler warnings. + using nvinfer1::IPluginV2DynamicExt::canBroadcastInputAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::enqueue; + using nvinfer1::IPluginV2DynamicExt::getOutputDimensions; + using nvinfer1::IPluginV2DynamicExt::isOutputBroadcastAcrossBatch; + using nvinfer1::IPluginV2DynamicExt::supportsFormat; #endif -}; + }; -class TRTPluginCreatorBase : public nvinfer1::IPluginCreator { - public: - const char *getPluginVersion() const TRT_NOEXCEPT override { return "1"; }; + class TRTPluginCreatorBase : public nvinfer1::IPluginCreator + { + public: + const char* getPluginVersion() const TRT_NOEXCEPT override + { + return "1"; + }; - const nvinfer1::PluginFieldCollection *getFieldNames() TRT_NOEXCEPT override { return &mFC; } + const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override + { + return &mFC; + } - void setPluginNamespace(const char *pluginNamespace) TRT_NOEXCEPT override { - mNamespace = pluginNamespace; - } + void setPluginNamespace(const char* pluginNamespace) TRT_NOEXCEPT override + { + mNamespace = pluginNamespace; + } - const char *getPluginNamespace() const TRT_NOEXCEPT override { return mNamespace.c_str(); } + const char* getPluginNamespace() const TRT_NOEXCEPT override + { + return mNamespace.c_str(); + } - protected: - nvinfer1::PluginFieldCollection mFC; - std::vector mPluginAttributes; - std::string mNamespace; -}; + protected: + nvinfer1::PluginFieldCollection mFC; + std::vector mPluginAttributes; + std::string mNamespace; + }; } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_helper.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_helper.hpp index 41b47acdbe..050c0dd308 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_helper.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_plugin_helper.hpp @@ -11,145 +11,159 @@ cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype, cudnnDataType_t* cudnn_dtype); // Enumerator for status -typedef enum { - STATUS_SUCCESS = 0, - STATUS_FAILURE = 1, - STATUS_BAD_PARAM = 2, - STATUS_NOT_SUPPORTED = 3, - STATUS_NOT_INITIALIZED = 4 +typedef enum +{ + STATUS_SUCCESS = 0, + STATUS_FAILURE 
= 1, + STATUS_BAD_PARAM = 2, + STATUS_NOT_SUPPORTED = 3, + STATUS_NOT_INITIALIZED = 4 } pluginStatus_t; -#define ASSERT(assertion) \ - { \ - if (!(assertion)) { \ - std::cerr << "#assertion" << __FILE__ << "," << __LINE__ << std::endl; \ - abort(); \ - } \ - } - -#define CUASSERT(status_) \ - { \ - auto s_ = status_; \ - if (s_ != cudaSuccess) { \ - std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << ", " << cudaGetErrorString(s_) \ - << std::endl; \ - } \ - } -#define CUBLASASSERT(status_) \ - { \ - auto s_ = status_; \ - if (s_ != CUBLAS_STATUS_SUCCESS) { \ - std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ - } \ - } -#define CUERRORMSG(status_) \ - { \ - auto s_ = status_; \ - if (s_ != 0) std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ - } +#define ASSERT(assertion) \ + { \ + if (!(assertion)) \ + { \ + std::cerr << "#assertion" << __FILE__ << "," << __LINE__ << std::endl; \ + abort(); \ + } \ + } + +#define CUASSERT(status_) \ + { \ + auto s_ = status_; \ + if (s_ != cudaSuccess) \ + { \ + std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << ", " << cudaGetErrorString(s_) \ + << std::endl; \ + } \ + } +#define CUBLASASSERT(status_) \ + { \ + auto s_ = status_; \ + if (s_ != CUBLAS_STATUS_SUCCESS) \ + { \ + std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ + } \ + } +#define CUERRORMSG(status_) \ + { \ + auto s_ = status_; \ + if (s_ != 0) std::cerr << __FILE__ << ", " << __LINE__ << ", " << s_ << std::endl; \ + } #ifndef DEBUG -#define CHECK(status) \ - do { \ - if (status != 0) abort(); \ - } while (0) - -#define ASSERT_PARAM(exp) \ - do { \ - if (!(exp)) return STATUS_BAD_PARAM; \ - } while (0) - -#define ASSERT_FAILURE(exp) \ - do { \ - if (!(exp)) return STATUS_FAILURE; \ - } while (0) - -#define CSC(call, err) \ - do { \ - cudaError_t cudaStatus = call; \ - if (cudaStatus != cudaSuccess) { \ - return err; \ - } \ - } while (0) - -#define DEBUG_PRINTF(...) \ - do { \ - } while (0) + #define CHECK(status) \ + do { \ + if (status != 0) abort(); \ + } while (0) + + #define ASSERT_PARAM(exp) \ + do { \ + if (!(exp)) return STATUS_BAD_PARAM; \ + } while (0) + + #define ASSERT_FAILURE(exp) \ + do { \ + if (!(exp)) return STATUS_FAILURE; \ + } while (0) + + #define CSC(call, err) \ + do { \ + cudaError_t cudaStatus = call; \ + if (cudaStatus != cudaSuccess) \ + { \ + return err; \ + } \ + } while (0) + + #define DEBUG_PRINTF(...) \ + do { \ + } while (0) #else -#define ASSERT_PARAM(exp) \ - do { \ - if (!(exp)) { \ - fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \ - return STATUS_BAD_PARAM; \ - } \ - } while (0) - -#define ASSERT_FAILURE(exp) \ - do { \ - if (!(exp)) { \ - fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \ - return STATUS_FAILURE; \ - } \ - } while (0) - -#define CSC(call, err) \ - do { \ - cudaError_t cudaStatus = call; \ - if (cudaStatus != cudaSuccess) { \ - printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \ - return err; \ - } \ - } while (0) - -#define CHECK(status) \ - { \ - if (status != 0) { \ - DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \ - abort(); \ - } \ - } - -#define DEBUG_PRINTF(...) 
\ - do { \ - printf(__VA_ARGS__); \ - } while (0) + #define ASSERT_PARAM(exp) \ + do { \ + if (!(exp)) \ + { \ + fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \ + return STATUS_BAD_PARAM; \ + } \ + } while (0) + + #define ASSERT_FAILURE(exp) \ + do { \ + if (!(exp)) \ + { \ + fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \ + return STATUS_FAILURE; \ + } \ + } while (0) + + #define CSC(call, err) \ + do { \ + cudaError_t cudaStatus = call; \ + if (cudaStatus != cudaSuccess) \ + { \ + printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \ + return err; \ + } \ + } while (0) + + #define CHECK(status) \ + { \ + if (status != 0) \ + { \ + DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \ + abort(); \ + } \ + } + + #define DEBUG_PRINTF(...) \ + do { \ + printf(__VA_ARGS__); \ + } while (0) #endif -namespace mmdeploy { - -const int MAXTENSORDIMS = 10; - -struct TensorDesc { - int shape[MAXTENSORDIMS]; - int stride[MAXTENSORDIMS]; - int dim; -}; - -inline unsigned int getElementSize(nvinfer1::DataType t) { - switch (t) { - case nvinfer1::DataType::kINT32: - return 4; - case nvinfer1::DataType::kFLOAT: - return 4; - case nvinfer1::DataType::kHALF: - return 2; - // case nvinfer1::DataType::kBOOL: - case nvinfer1::DataType::kINT8: - return 1; - default: - throw std::runtime_error("Invalid DataType."); - } - throw std::runtime_error("Invalid DataType."); - return 0; -} - -inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) { - return size_t((origin_size + aligned_number - 1) / aligned_number) * aligned_number; -} +namespace mmdeploy +{ + + const int MAXTENSORDIMS = 10; + + struct TensorDesc + { + int shape[MAXTENSORDIMS]; + int stride[MAXTENSORDIMS]; + int dim; + }; + + inline unsigned int getElementSize(nvinfer1::DataType t) + { + switch (t) + { + case nvinfer1::DataType::kINT32: + return 4; + case nvinfer1::DataType::kFLOAT: + return 4; + case nvinfer1::DataType::kHALF: + return 2; + // case nvinfer1::DataType::kBOOL: + case nvinfer1::DataType::kINT8: + return 1; + default: + throw std::runtime_error("Invalid DataType."); + } + throw std::runtime_error("Invalid DataType."); + return 0; + } + + inline size_t getAlignedSize(size_t origin_size, size_t aligned_number = 16) + { + return size_t((origin_size + aligned_number - 1) / aligned_number) * aligned_number; + } } // namespace mmdeploy #endif // TRT_PLUGIN_HELPER_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_serialize.hpp b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_serialize.hpp index db88184432..d1d2fff678 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common/trt_serialize.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/common/trt_serialize.hpp @@ -9,89 +9,111 @@ #include #include -template +template inline void serialize_value(void** buffer, T const& value); -template +template inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value); -namespace { +namespace +{ -template -struct Serializer {}; + template + struct Serializer + { + }; -template -struct Serializer::value || std::is_enum::value || - std::is_pod::value>::type> { - static size_t serialized_size(T const& value) { return sizeof(T); } - static void serialize(void** buffer, T const& value) { - ::memcpy(*buffer, &value, sizeof(T)); - reinterpret_cast(*buffer) += sizeof(T); - } - static void deserialize(void const** buffer, size_t* buffer_size, T* value) { - assert(*buffer_size >= sizeof(T)); - ::memcpy(value, 
*buffer, sizeof(T)); - reinterpret_cast(*buffer) += sizeof(T); - *buffer_size -= sizeof(T); - } -}; + template + struct Serializer::value || std::is_enum::value || + std::is_pod::value>::type> + { + static size_t serialized_size(T const& value) + { + return sizeof(T); + } + static void serialize(void** buffer, T const& value) + { + ::memcpy(*buffer, &value, sizeof(T)); + reinterpret_cast(*buffer) += sizeof(T); + } + static void deserialize(void const** buffer, size_t* buffer_size, T* value) + { + assert(*buffer_size >= sizeof(T)); + ::memcpy(value, *buffer, sizeof(T)); + reinterpret_cast(*buffer) += sizeof(T); + *buffer_size -= sizeof(T); + } + }; -template <> -struct Serializer { - static size_t serialized_size(const char* value) { return strlen(value) + 1; } - static void serialize(void** buffer, const char* value) { - ::strcpy(static_cast(*buffer), value); - reinterpret_cast(*buffer) += strlen(value) + 1; - } - static void deserialize(void const** buffer, size_t* buffer_size, const char** value) { - *value = static_cast(*buffer); - size_t data_size = strnlen(*value, *buffer_size) + 1; - assert(*buffer_size >= data_size); - reinterpret_cast(*buffer) += data_size; - *buffer_size -= data_size; - } -}; + template<> + struct Serializer + { + static size_t serialized_size(const char* value) + { + return strlen(value) + 1; + } + static void serialize(void** buffer, const char* value) + { + ::strcpy(static_cast(*buffer), value); + reinterpret_cast(*buffer) += strlen(value) + 1; + } + static void deserialize(void const** buffer, size_t* buffer_size, const char** value) + { + *value = static_cast(*buffer); + size_t data_size = strnlen(*value, *buffer_size) + 1; + assert(*buffer_size >= data_size); + reinterpret_cast(*buffer) += data_size; + *buffer_size -= data_size; + } + }; -template -struct Serializer, - typename std::enable_if::value || std::is_enum::value || - std::is_pod::value>::type> { - static size_t serialized_size(std::vector const& value) { - return sizeof(value.size()) + value.size() * sizeof(T); - } - static void serialize(void** buffer, std::vector const& value) { - serialize_value(buffer, value.size()); - size_t nbyte = value.size() * sizeof(T); - ::memcpy(*buffer, value.data(), nbyte); - reinterpret_cast(*buffer) += nbyte; - } - static void deserialize(void const** buffer, size_t* buffer_size, std::vector* value) { - size_t size; - deserialize_value(buffer, buffer_size, &size); - value->resize(size); - size_t nbyte = value->size() * sizeof(T); - assert(*buffer_size >= nbyte); - ::memcpy(value->data(), *buffer, nbyte); - reinterpret_cast(*buffer) += nbyte; - *buffer_size -= nbyte; - } -}; + template + struct Serializer, + typename std::enable_if::value || std::is_enum::value || + std::is_pod::value>::type> + { + static size_t serialized_size(std::vector const& value) + { + return sizeof(value.size()) + value.size() * sizeof(T); + } + static void serialize(void** buffer, std::vector const& value) + { + serialize_value(buffer, value.size()); + size_t nbyte = value.size() * sizeof(T); + ::memcpy(*buffer, value.data(), nbyte); + reinterpret_cast(*buffer) += nbyte; + } + static void deserialize(void const** buffer, size_t* buffer_size, std::vector* value) + { + size_t size; + deserialize_value(buffer, buffer_size, &size); + value->resize(size); + size_t nbyte = value->size() * sizeof(T); + assert(*buffer_size >= nbyte); + ::memcpy(value->data(), *buffer, nbyte); + reinterpret_cast(*buffer) += nbyte; + *buffer_size -= nbyte; + } + }; } // namespace -template -inline size_t 
serialized_size(T const& value) { - return Serializer::serialized_size(value); +template +inline size_t serialized_size(T const& value) +{ + return Serializer::serialized_size(value); } -template -inline void serialize_value(void** buffer, T const& value) { - return Serializer::serialize(buffer, value); +template +inline void serialize_value(void** buffer, T const& value) +{ + return Serializer::serialize(buffer, value); } -template -inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value) { - return Serializer::deserialize(buffer, buffer_size, value); +template +inline void deserialize_value(void const** buffer, size_t* buffer_size, T* value) +{ + return Serializer::deserialize(buffer, buffer_size, value); } #endif // TRT_SERIALIZE_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu index 44c08152db..08a6a617ce 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassNMS.cu @@ -7,62 +7,78 @@ const static int BS = 512; -template -__device__ T_BBOX bboxSize(const Bbox &bbox, const bool normalized, T_BBOX offset) { - if (bbox.xmax < bbox.xmin || bbox.ymax < bbox.ymin) { - // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. - return 0; - } else { - T_BBOX width = bbox.xmax - bbox.xmin; - T_BBOX height = bbox.ymax - bbox.ymin; - if (normalized) { - return width * height; - } else { - // If bbox is not within range [0, 1]. - return (width + offset) * (height + offset); +template +__device__ T_BBOX bboxSize(const Bbox& bbox, const bool normalized, T_BBOX offset) +{ + if (bbox.xmax < bbox.xmin || bbox.ymax < bbox.ymin) + { + // If bbox is invalid (e.g. xmax < xmin or ymax < ymin), return 0. + return 0; + } + else + { + T_BBOX width = bbox.xmax - bbox.xmin; + T_BBOX height = bbox.ymax - bbox.ymin; + if (normalized) + { + return width * height; + } + else + { + // If bbox is not within range [0, 1]. + return (width + offset) * (height + offset); + } } - } } -template -__device__ void intersectBbox(const Bbox &bbox1, const Bbox &bbox2, - Bbox *intersect_bbox) { - if (bbox2.xmin > bbox1.xmax || bbox2.xmax < bbox1.xmin || bbox2.ymin > bbox1.ymax || - bbox2.ymax < bbox1.ymin) { - // Return [0, 0, 0, 0] if there is no intersection. - intersect_bbox->xmin = T_BBOX(0); - intersect_bbox->ymin = T_BBOX(0); - intersect_bbox->xmax = T_BBOX(0); - intersect_bbox->ymax = T_BBOX(0); - } else { - intersect_bbox->xmin = max(bbox1.xmin, bbox2.xmin); - intersect_bbox->ymin = max(bbox1.ymin, bbox2.ymin); - intersect_bbox->xmax = min(bbox1.xmax, bbox2.xmax); - intersect_bbox->ymax = min(bbox1.ymax, bbox2.ymax); - } +template +__device__ void intersectBbox(const Bbox& bbox1, const Bbox& bbox2, Bbox* intersect_bbox) +{ + if (bbox2.xmin > bbox1.xmax || bbox2.xmax < bbox1.xmin || bbox2.ymin > bbox1.ymax || + bbox2.ymax < bbox1.ymin) + { + // Return [0, 0, 0, 0] if there is no intersection. 
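+        // (A zeroed box then scores zero overlap in jaccardOverlap below, which
+        // computes IoU as intersection / (area1 + area2 - intersection).)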
+ intersect_bbox->xmin = T_BBOX(0); + intersect_bbox->ymin = T_BBOX(0); + intersect_bbox->xmax = T_BBOX(0); + intersect_bbox->ymax = T_BBOX(0); + } + else + { + intersect_bbox->xmin = max(bbox1.xmin, bbox2.xmin); + intersect_bbox->ymin = max(bbox1.ymin, bbox2.ymin); + intersect_bbox->xmax = min(bbox1.xmax, bbox2.xmax); + intersect_bbox->ymax = min(bbox1.ymax, bbox2.ymax); + } } -template -__device__ float jaccardOverlap(const Bbox &bbox1, const Bbox &bbox2, - const bool normalized, T_BBOX offset) { - Bbox intersect_bbox; - intersectBbox(bbox1, bbox2, &intersect_bbox); - float intersect_width, intersect_height; - if (normalized) { - intersect_width = intersect_bbox.xmax - intersect_bbox.xmin; - intersect_height = intersect_bbox.ymax - intersect_bbox.ymin; - } else { - intersect_width = intersect_bbox.xmax - intersect_bbox.xmin + offset; - intersect_height = intersect_bbox.ymax - intersect_bbox.ymin + offset; - } - if (intersect_width > 0 && intersect_height > 0) { - float intersect_size = intersect_width * intersect_height; - float bbox1_size = bboxSize(bbox1, normalized, offset); - float bbox2_size = bboxSize(bbox2, normalized, offset); - return intersect_size / (bbox1_size + bbox2_size - intersect_size); - } else { - return 0.; - } +template +__device__ float jaccardOverlap(const Bbox& bbox1, const Bbox& bbox2, const bool normalized, T_BBOX offset) +{ + Bbox intersect_bbox; + intersectBbox(bbox1, bbox2, &intersect_bbox); + float intersect_width, intersect_height; + if (normalized) + { + intersect_width = intersect_bbox.xmax - intersect_bbox.xmin; + intersect_height = intersect_bbox.ymax - intersect_bbox.ymin; + } + else + { + intersect_width = intersect_bbox.xmax - intersect_bbox.xmin + offset; + intersect_height = intersect_bbox.ymax - intersect_bbox.ymin + offset; + } + if (intersect_width > 0 && intersect_height > 0) + { + float intersect_size = intersect_width * intersect_height; + float bbox1_size = bboxSize(bbox1, normalized, offset); + float bbox2_size = bboxSize(bbox2, normalized, offset); + return intersect_size / (bbox1_size + bbox2_size - intersect_size); + } + else + { + return 0.; + } } /********** new NMS for only score and index array **********/ @@ -82,186 +98,211 @@ allClassNMS_kernel(const int num, const int num_classes, const int num_preds_per // location information T_SCORE *beforeNMS_scores, int *beforeNMS_index_array, T_SCORE *afterNMS_scores, int *afterNMS_index_array, bool flipXY = false) { - // clang-format on - //__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE]; - __shared__ bool kept_bboxinfo_flag[TSIZE * BS]; - for (int i = 0; i < num; i++) { - const int offset = i * num_classes * num_preds_per_class + blockIdx.x * num_preds_per_class; - const int max_idx = offset + top_k; // put top_k bboxes into NMS calculation - const int bbox_idx_offset = - share_location ? (i * num_preds_per_class) : (i * num_classes * num_preds_per_class); - - // local thread data - int loc_bboxIndex[TSIZE]; - Bbox loc_bbox[TSIZE]; - - // initialize Bbox, Bboxinfo, kept_bboxinfo_flag - // Eliminate shared memory RAW hazard - __syncthreads(); + // clang-format on + //__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE]; + __shared__ bool kept_bboxinfo_flag[TSIZE * BS]; + for (int i = 0; i < num; i++) + { + const int offset = i * num_classes * num_preds_per_class + blockIdx.x * num_preds_per_class; + const int max_idx = offset + top_k; // put top_k bboxes into NMS calculation + const int bbox_idx_offset = + share_location ? 
(i * num_preds_per_class) : (i * num_classes * num_preds_per_class); + + // local thread data + int loc_bboxIndex[TSIZE]; + Bbox loc_bbox[TSIZE]; + + // initialize Bbox, Bboxinfo, kept_bboxinfo_flag + // Eliminate shared memory RAW hazard + __syncthreads(); #pragma unroll - for (int t = 0; t < TSIZE; t++) { - const int cur_idx = threadIdx.x + blockDim.x * t; - const int item_idx = offset + cur_idx; + for (int t = 0; t < TSIZE; t++) + { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int item_idx = offset + cur_idx; + + if (item_idx < max_idx) + { + loc_bboxIndex[t] = beforeNMS_index_array[item_idx]; + + if (loc_bboxIndex[t] >= 0) + // if (loc_bboxIndex[t] != -1) + { + const int bbox_data_idx = share_location ? (loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset) : loc_bboxIndex[t]; + + loc_bbox[t].xmin = + flipXY ? bbox_data[bbox_data_idx * 4 + 1] : bbox_data[bbox_data_idx * 4 + 0]; + loc_bbox[t].ymin = + flipXY ? bbox_data[bbox_data_idx * 4 + 0] : bbox_data[bbox_data_idx * 4 + 1]; + loc_bbox[t].xmax = + flipXY ? bbox_data[bbox_data_idx * 4 + 3] : bbox_data[bbox_data_idx * 4 + 2]; + loc_bbox[t].ymax = + flipXY ? bbox_data[bbox_data_idx * 4 + 2] : bbox_data[bbox_data_idx * 4 + 3]; + kept_bboxinfo_flag[cur_idx] = true; + } + else + { + kept_bboxinfo_flag[cur_idx] = false; + } + } + else + { + kept_bboxinfo_flag[cur_idx] = false; + } + } - if (item_idx < max_idx) { - loc_bboxIndex[t] = beforeNMS_index_array[item_idx]; + // filter out overlapped boxes with lower scores + int ref_item_idx = offset; + int ref_bbox_idx = + share_location ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) : beforeNMS_index_array[ref_item_idx]; - if (loc_bboxIndex[t] >= 0) - // if (loc_bboxIndex[t] != -1) + while ((ref_bbox_idx != -1) && ref_item_idx < max_idx) { - const int bbox_data_idx = share_location - ? (loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset) - : loc_bboxIndex[t]; - - loc_bbox[t].xmin = - flipXY ? bbox_data[bbox_data_idx * 4 + 1] : bbox_data[bbox_data_idx * 4 + 0]; - loc_bbox[t].ymin = - flipXY ? bbox_data[bbox_data_idx * 4 + 0] : bbox_data[bbox_data_idx * 4 + 1]; - loc_bbox[t].xmax = - flipXY ? bbox_data[bbox_data_idx * 4 + 3] : bbox_data[bbox_data_idx * 4 + 2]; - loc_bbox[t].ymax = - flipXY ? bbox_data[bbox_data_idx * 4 + 2] : bbox_data[bbox_data_idx * 4 + 3]; - kept_bboxinfo_flag[cur_idx] = true; - } else { - kept_bboxinfo_flag[cur_idx] = false; + Bbox ref_bbox; + ref_bbox.xmin = flipXY ? bbox_data[ref_bbox_idx * 4 + 1] : bbox_data[ref_bbox_idx * 4 + 0]; + ref_bbox.ymin = flipXY ? bbox_data[ref_bbox_idx * 4 + 0] : bbox_data[ref_bbox_idx * 4 + 1]; + ref_bbox.xmax = flipXY ? bbox_data[ref_bbox_idx * 4 + 3] : bbox_data[ref_bbox_idx * 4 + 2]; + ref_bbox.ymax = flipXY ? bbox_data[ref_bbox_idx * 4 + 2] : bbox_data[ref_bbox_idx * 4 + 3]; + + // Eliminate shared memory RAW hazard + __syncthreads(); + + for (int t = 0; t < TSIZE; t++) + { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int item_idx = offset + cur_idx; + + if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) + { + // TODO: may need to add bool normalized as argument, HERE true means + // normalized + if (jaccardOverlap(ref_bbox, loc_bbox[t], isNormalized, T_BBOX(0)) > nms_threshold) + { + kept_bboxinfo_flag[cur_idx] = false; + } + } + } + __syncthreads(); + + do { + ref_item_idx++; + } while (ref_item_idx < max_idx && !kept_bboxinfo_flag[ref_item_idx - offset]); + + ref_bbox_idx = + share_location ? 
(beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) : beforeNMS_index_array[ref_item_idx]; } - } else { - kept_bboxinfo_flag[cur_idx] = false; - } - } - // filter out overlapped boxes with lower scores - int ref_item_idx = offset; - int ref_bbox_idx = - share_location - ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) - : beforeNMS_index_array[ref_item_idx]; - - while ((ref_bbox_idx != -1) && ref_item_idx < max_idx) { - Bbox ref_bbox; - ref_bbox.xmin = flipXY ? bbox_data[ref_bbox_idx * 4 + 1] : bbox_data[ref_bbox_idx * 4 + 0]; - ref_bbox.ymin = flipXY ? bbox_data[ref_bbox_idx * 4 + 0] : bbox_data[ref_bbox_idx * 4 + 1]; - ref_bbox.xmax = flipXY ? bbox_data[ref_bbox_idx * 4 + 3] : bbox_data[ref_bbox_idx * 4 + 2]; - ref_bbox.ymax = flipXY ? bbox_data[ref_bbox_idx * 4 + 2] : bbox_data[ref_bbox_idx * 4 + 3]; - - // Eliminate shared memory RAW hazard - __syncthreads(); - - for (int t = 0; t < TSIZE; t++) { - const int cur_idx = threadIdx.x + blockDim.x * t; - const int item_idx = offset + cur_idx; - - if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) { - // TODO: may need to add bool normalized as argument, HERE true means - // normalized - if (jaccardOverlap(ref_bbox, loc_bbox[t], isNormalized, T_BBOX(0)) > nms_threshold) { - kept_bboxinfo_flag[cur_idx] = false; - } + // store data + for (int t = 0; t < TSIZE; t++) + { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int read_item_idx = offset + cur_idx; + const int write_item_idx = (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx; + /* + * If not not keeping the bbox + * Set the score to 0 + * Set the bounding box index to -1 + */ + if (read_item_idx < max_idx) + { + afterNMS_scores[write_item_idx] = + kept_bboxinfo_flag[cur_idx] ? beforeNMS_scores[read_item_idx] : 0.0f; + afterNMS_index_array[write_item_idx] = kept_bboxinfo_flag[cur_idx] ? loc_bboxIndex[t] : -1; + } } - } - __syncthreads(); - - do { - ref_item_idx++; - } while (ref_item_idx < max_idx && !kept_bboxinfo_flag[ref_item_idx - offset]); - - ref_bbox_idx = - share_location - ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) - : beforeNMS_index_array[ref_item_idx]; } - - // store data - for (int t = 0; t < TSIZE; t++) { - const int cur_idx = threadIdx.x + blockDim.x * t; - const int read_item_idx = offset + cur_idx; - const int write_item_idx = (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx; - /* - * If not not keeping the bbox - * Set the score to 0 - * Set the bounding box index to -1 - */ - if (read_item_idx < max_idx) { - afterNMS_scores[write_item_idx] = - kept_bboxinfo_flag[cur_idx] ? beforeNMS_scores[read_item_idx] : 0.0f; - afterNMS_index_array[write_item_idx] = kept_bboxinfo_flag[cur_idx] ? 
loc_bboxIndex[t] : -1; - } - } - } } -template -pluginStatus_t allClassNMS_gpu(cudaStream_t stream, const int num, const int num_classes, - const int num_preds_per_class, const int top_k, - const float nms_threshold, const bool share_location, - const bool isNormalized, void *bbox_data, void *beforeNMS_scores, - void *beforeNMS_index_array, void *afterNMS_scores, - void *afterNMS_index_array, bool flipXY = false) { +template +pluginStatus_t allClassNMS_gpu(cudaStream_t stream, const int num, const int num_classes, const int num_preds_per_class, const int top_k, const float nms_threshold, const bool share_location, const bool isNormalized, void* bbox_data, void* beforeNMS_scores, void* beforeNMS_index_array, void* afterNMS_scores, void* afterNMS_index_array, bool flipXY = false) +{ #define P(tsize) allClassNMS_kernel - void (*kernel[10])(const int, const int, const int, const int, const float, const bool, - const bool, float *, T_SCORE *, int *, T_SCORE *, int *, bool) = { - P(1), P(2), P(3), P(4), P(5), P(6), P(7), P(8), P(9), P(10), - }; - - const int GS = num_classes; - const int t_size = (top_k + BS - 1) / BS; - - ASSERT(t_size <= 10); - kernel[t_size - 1]<<>>( - num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized, - (T_BBOX *)bbox_data, (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array, - (T_SCORE *)afterNMS_scores, (int *)afterNMS_index_array, flipXY); - - cudaError_t code = cudaGetLastError(); - CUASSERT(code); - CSC(code, STATUS_FAILURE); - return STATUS_SUCCESS; + void (*kernel[10])(const int, const int, const int, const int, const float, const bool, const bool, float*, T_SCORE*, int*, T_SCORE*, int*, bool) = { + P(1), + P(2), + P(3), + P(4), + P(5), + P(6), + P(7), + P(8), + P(9), + P(10), + }; + + const int GS = num_classes; + const int t_size = (top_k + BS - 1) / BS; + + ASSERT(t_size <= 10); + kernel[t_size - 1]<<>>( + num, + num_classes, + num_preds_per_class, + top_k, + nms_threshold, + share_location, + isNormalized, + (T_BBOX*)bbox_data, + (T_SCORE*)beforeNMS_scores, + (int*)beforeNMS_index_array, + (T_SCORE*)afterNMS_scores, + (int*)afterNMS_index_array, + flipXY); + + cudaError_t code = cudaGetLastError(); + CUASSERT(code); + CSC(code, STATUS_FAILURE); + return STATUS_SUCCESS; } // allClassNMS LAUNCH CONFIG -typedef pluginStatus_t (*nmsFunc)(cudaStream_t, const int, const int, const int, const int, - const float, const bool, const bool, void *, void *, void *, - void *, void *, bool); - -struct nmsLaunchConfigSSD { - DataType t_score; - DataType t_bbox; - nmsFunc function; - - nmsLaunchConfigSSD(DataType t_score, DataType t_bbox) : t_score(t_score), t_bbox(t_bbox) {} - nmsLaunchConfigSSD(DataType t_score, DataType t_bbox, nmsFunc function) - : t_score(t_score), t_bbox(t_bbox), function(function) {} - bool operator==(const nmsLaunchConfigSSD &other) { - return t_score == other.t_score && t_bbox == other.t_bbox; - } +typedef pluginStatus_t (*nmsFunc)(cudaStream_t, const int, const int, const int, const int, const float, const bool, const bool, void*, void*, void*, void*, void*, bool); + +struct nmsLaunchConfigSSD +{ + DataType t_score; + DataType t_bbox; + nmsFunc function; + + nmsLaunchConfigSSD(DataType t_score, DataType t_bbox) + : t_score(t_score) + , t_bbox(t_bbox) + { + } + nmsLaunchConfigSSD(DataType t_score, DataType t_bbox, nmsFunc function) + : t_score(t_score) + , t_bbox(t_bbox) + , function(function) + { + } + bool operator==(const nmsLaunchConfigSSD& other) + { + return t_score == other.t_score && t_bbox == 
other.t_bbox; + } }; static std::vector nmsFuncVec; -bool nmsInit() { - nmsFuncVec.push_back( - nmsLaunchConfigSSD(DataType::kFLOAT, DataType::kFLOAT, allClassNMS_gpu)); - return true; +bool nmsInit() +{ + nmsFuncVec.push_back( + nmsLaunchConfigSSD(DataType::kFLOAT, DataType::kFLOAT, allClassNMS_gpu)); + return true; } -static bool initialized = nmsInit(); - -pluginStatus_t allClassNMS(cudaStream_t stream, const int num, const int num_classes, - const int num_preds_per_class, const int top_k, - const float nms_threshold, const bool share_location, - const bool isNormalized, const DataType DT_SCORE, const DataType DT_BBOX, - void *bbox_data, void *beforeNMS_scores, void *beforeNMS_index_array, - void *afterNMS_scores, void *afterNMS_index_array, bool flipXY) { - nmsLaunchConfigSSD lc(DT_SCORE, DT_BBOX); - for (unsigned i = 0; i < nmsFuncVec.size(); ++i) { - if (lc == nmsFuncVec[i]) { - DEBUG_PRINTF("all class nms kernel %d\n", i); - return nmsFuncVec[i].function(stream, num, num_classes, num_preds_per_class, top_k, - nms_threshold, share_location, isNormalized, bbox_data, - beforeNMS_scores, beforeNMS_index_array, afterNMS_scores, - afterNMS_index_array, flipXY); +static bool initialized = nmsInit(); + +pluginStatus_t allClassNMS(cudaStream_t stream, const int num, const int num_classes, const int num_preds_per_class, const int top_k, const float nms_threshold, const bool share_location, const bool isNormalized, const DataType DT_SCORE, const DataType DT_BBOX, void* bbox_data, void* beforeNMS_scores, void* beforeNMS_index_array, void* afterNMS_scores, void* afterNMS_index_array, bool flipXY) +{ + nmsLaunchConfigSSD lc(DT_SCORE, DT_BBOX); + for (unsigned i = 0; i < nmsFuncVec.size(); ++i) + { + if (lc == nmsFuncVec[i]) + { + DEBUG_PRINTF("all class nms kernel %d\n", i); + return nmsFuncVec[i].function(stream, num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized, bbox_data, beforeNMS_scores, beforeNMS_index_array, afterNMS_scores, afterNMS_index_array, flipXY); + } } - } - return STATUS_BAD_PARAM; + return STATUS_BAD_PARAM; } diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu index 0edea2bfaf..52758ea247 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/allClassRotatedNMS.cu @@ -6,490 +6,559 @@ #include "nms/kernel.h" -template -struct RotatedBox { - T x_ctr, y_ctr, w, h, a; +template +struct RotatedBox +{ + T x_ctr, y_ctr, w, h, a; }; -template -struct Point { - T x, y; - __host__ __device__ __forceinline__ Point(const T &px = 0, const T &py = 0) : x(px), y(py) {} - __host__ __device__ __forceinline__ Point operator+(const Point &p) const { - return Point(x + p.x, y + p.y); - } - __host__ __device__ __forceinline__ Point &operator+=(const Point &p) { - x += p.x; - y += p.y; - return *this; - } - __host__ __device__ __forceinline__ Point operator-(const Point &p) const { - return Point(x - p.x, y - p.y); - } - __host__ __device__ __forceinline__ Point operator*(const T coeff) const { - return Point(x * coeff, y * coeff); - } +template +struct Point +{ + T x, y; + __host__ __device__ __forceinline__ Point(const T& px = 0, const T& py = 0) + : x(px) + , y(py) + { + } + __host__ __device__ __forceinline__ Point operator+(const Point& p) const + { + return Point(x + p.x, y + p.y); + } + __host__ __device__ __forceinline__ Point& operator+=(const Point& p) + { + 
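+        // component-wise accumulate; together with dot_2d / cross_2d below, these
+        // small 2D vector ops drive the polygon clipping used for rotated-box IoU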
x += p.x; + y += p.y; + return *this; + } + __host__ __device__ __forceinline__ Point operator-(const Point& p) const + { + return Point(x - p.x, y - p.y); + } + __host__ __device__ __forceinline__ Point operator*(const T coeff) const + { + return Point(x * coeff, y * coeff); + } }; -template -__host__ __device__ __forceinline__ T dot_2d(const Point &A, const Point &B) { - return A.x * B.x + A.y * B.y; +template +__host__ __device__ __forceinline__ T dot_2d(const Point& A, const Point& B) +{ + return A.x * B.x + A.y * B.y; } -template -__host__ __device__ __forceinline__ T cross_2d(const Point &A, const Point &B) { - return A.x * B.y - B.x * A.y; +template +__host__ __device__ __forceinline__ T cross_2d(const Point& A, const Point& B) +{ + return A.x * B.y - B.x * A.y; } -template -__host__ __device__ __forceinline__ void get_rotated_vertices(const RotatedBox &box, - Point (&pts)[4]) { - // M_PI / 180. == 0.01745329251 - // double theta = box.a * 0.01745329251; - // MODIFIED - double theta = box.a; - T cosTheta2 = (T)cos(theta) * 0.5f; - T sinTheta2 = (T)sin(theta) * 0.5f; - - // y: top --> down; x: left --> right - pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w; - pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w; - pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w; - pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w; - pts[2].x = 2 * box.x_ctr - pts[0].x; - pts[2].y = 2 * box.y_ctr - pts[0].y; - pts[3].x = 2 * box.x_ctr - pts[1].x; - pts[3].y = 2 * box.y_ctr - pts[1].y; +template +__host__ __device__ __forceinline__ void get_rotated_vertices(const RotatedBox& box, + Point (&pts)[4]) +{ + // M_PI / 180. == 0.01745329251 + // double theta = box.a * 0.01745329251; + // MODIFIED + double theta = box.a; + T cosTheta2 = (T)cos(theta) * 0.5f; + T sinTheta2 = (T)sin(theta) * 0.5f; + + // y: top --> down; x: left --> right + pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w; + pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w; + pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w; + pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w; + pts[2].x = 2 * box.x_ctr - pts[0].x; + pts[2].y = 2 * box.y_ctr - pts[0].y; + pts[3].x = 2 * box.x_ctr - pts[1].x; + pts[3].y = 2 * box.y_ctr - pts[1].y; } -template +template __host__ __device__ __forceinline__ int get_intersection_points(const Point (&pts1)[4], const Point (&pts2)[4], - Point (&intersections)[24]) { - // Line vector - // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] - Point vec1[4], vec2[4]; - for (int i = 0; i < 4; i++) { - vec1[i] = pts1[(i + 1) % 4] - pts1[i]; - vec2[i] = pts2[(i + 1) % 4] - pts2[i]; - } - - // Line test - test all line combos for intersection - int num = 0; // number of intersections - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - // Solve for 2x2 Ax=b - T det = cross_2d(vec2[j], vec1[i]); - - // This takes care of parallel lines - if (fabs(det) <= 1e-14) { - continue; - } - - auto vec12 = pts2[j] - pts1[i]; - - T t1 = cross_2d(vec2[j], vec12) / det; - T t2 = cross_2d(vec1[i], vec12) / det; - - if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) { - intersections[num++] = pts1[i] + vec1[i] * t1; - } + Point (&intersections)[24]) +{ + // Line vector + // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] + Point vec1[4], vec2[4]; + for (int i = 0; i < 4; i++) + { + vec1[i] = pts1[(i + 1) % 4] - pts1[i]; + vec2[i] = pts2[(i + 1) % 4] - pts2[i]; } - } - - // Check for vertices of rect1 inside rect2 - { - const auto 
&AB = vec2[0]; - const auto &DA = vec2[3]; - auto ABdotAB = dot_2d(AB, AB); - auto ADdotAD = dot_2d(DA, DA); - for (int i = 0; i < 4; i++) { - // assume ABCD is the rectangle, and P is the point to be judged - // P is inside ABCD iff. P's projection on AB lies within AB - // and P's projection on AD lies within AD - - auto AP = pts1[i] - pts2[0]; - - auto APdotAB = dot_2d(AP, AB); - auto APdotAD = -dot_2d(AP, DA); - - if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) { - intersections[num++] = pts1[i]; - } + + // Line test - test all line combos for intersection + int num = 0; // number of intersections + for (int i = 0; i < 4; i++) + { + for (int j = 0; j < 4; j++) + { + // Solve for 2x2 Ax=b + T det = cross_2d(vec2[j], vec1[i]); + + // This takes care of parallel lines + if (fabs(det) <= 1e-14) + { + continue; + } + + auto vec12 = pts2[j] - pts1[i]; + + T t1 = cross_2d(vec2[j], vec12) / det; + T t2 = cross_2d(vec1[i], vec12) / det; + + if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) + { + intersections[num++] = pts1[i] + vec1[i] * t1; + } + } } - } - - // Reverse the check - check for vertices of rect2 inside rect1 - { - const auto &AB = vec1[0]; - const auto &DA = vec1[3]; - auto ABdotAB = dot_2d(AB, AB); - auto ADdotAD = dot_2d(DA, DA); - for (int i = 0; i < 4; i++) { - auto AP = pts2[i] - pts1[0]; - - auto APdotAB = dot_2d(AP, AB); - auto APdotAD = -dot_2d(AP, DA); - - if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) { - intersections[num++] = pts2[i]; - } + + // Check for vertices of rect1 inside rect2 + { + const auto& AB = vec2[0]; + const auto& DA = vec2[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) + { + // assume ABCD is the rectangle, and P is the point to be judged + // P is inside ABCD iff. P's projection on AB lies within AB + // and P's projection on AD lies within AD + + auto AP = pts1[i] - pts2[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) + { + intersections[num++] = pts1[i]; + } + } + } + + // Reverse the check - check for vertices of rect2 inside rect1 + { + const auto& AB = vec1[0]; + const auto& DA = vec1[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) + { + auto AP = pts2[i] - pts1[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= ADdotAD)) + { + intersections[num++] = pts2[i]; + } + } } - } - return num; + return num; } -template +template __host__ __device__ __forceinline__ int convex_hull_graham(const Point (&p)[24], - const int &num_in, Point (&q)[24], - bool shift_to_zero = false) { - assert(num_in >= 2); - - // Step 1: - // Find point with minimum y - // if more than 1 points have the same minimum y, - // pick the one with the minimum x. - int t = 0; - for (int i = 1; i < num_in; i++) { - if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) { - t = i; + const int& num_in, + Point (&q)[24], + bool shift_to_zero = false) +{ + assert(num_in >= 2); + + // Step 1: + // Find point with minimum y + // if more than 1 points have the same minimum y, + // pick the one with the minimum x. 
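+    // (Graham-scan invariant: the bottom-most, then left-most, point is always
+    // a vertex of the convex hull, so it is a safe pivot for the angular sort
+    // performed in Step 3 below.)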
+ int t = 0; + for (int i = 1; i < num_in; i++) + { + if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) + { + t = i; + } + } + auto& start = p[t]; // starting point + + // Step 2: + // Subtract starting point from every points (for sorting in the next step) + for (int i = 0; i < num_in; i++) + { + q[i] = p[i] - start; + } + + // Swap the starting point to position 0 + auto tmp = q[0]; + q[0] = q[t]; + q[t] = tmp; + + // Step 3: + // Sort point 1 ~ num_in according to their relative cross-product values + // (essentially sorting according to angles) + // If the angles are the same, sort according to their distance to origin + T dist[24]; + for (int i = 0; i < num_in; i++) + { + dist[i] = dot_2d(q[i], q[i]); + } + + for (int i = 1; i < num_in - 1; i++) + { + for (int j = i + 1; j < num_in; j++) + { + T crossProduct = cross_2d(q[i], q[j]); + if ((crossProduct < -1e-6) || (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) + { + auto q_tmp = q[i]; + q[i] = q[j]; + q[j] = q_tmp; + auto dist_tmp = dist[i]; + dist[i] = dist[j]; + dist[j] = dist_tmp; + } + } } - } - auto &start = p[t]; // starting point - - // Step 2: - // Subtract starting point from every points (for sorting in the next step) - for (int i = 0; i < num_in; i++) { - q[i] = p[i] - start; - } - - // Swap the starting point to position 0 - auto tmp = q[0]; - q[0] = q[t]; - q[t] = tmp; - - // Step 3: - // Sort point 1 ~ num_in according to their relative cross-product values - // (essentially sorting according to angles) - // If the angles are the same, sort according to their distance to origin - T dist[24]; - for (int i = 0; i < num_in; i++) { - dist[i] = dot_2d(q[i], q[i]); - } - - for (int i = 1; i < num_in - 1; i++) { - for (int j = i + 1; j < num_in; j++) { - T crossProduct = cross_2d(q[i], q[j]); - if ((crossProduct < -1e-6) || (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) { - auto q_tmp = q[i]; - q[i] = q[j]; - q[j] = q_tmp; - auto dist_tmp = dist[i]; - dist[i] = dist[j]; - dist[j] = dist_tmp; - } + + // Step 4: + // Make sure there are at least 2 points (that don't overlap with each other) + // in the stack + int k; // index of the non-overlapped second point + for (k = 1; k < num_in; k++) + { + if (dist[k] > 1e-8) + { + break; + } } - } - - // Step 4: - // Make sure there are at least 2 points (that don't overlap with each other) - // in the stack - int k; // index of the non-overlapped second point - for (k = 1; k < num_in; k++) { - if (dist[k] > 1e-8) { - break; + if (k == num_in) + { + // We reach the end, which means the convex hull is just one point + q[0] = p[t]; + return 1; } - } - if (k == num_in) { - // We reach the end, which means the convex hull is just one point - q[0] = p[t]; - return 1; - } - q[1] = q[k]; - int m = 2; // 2 points in the stack - // Step 5: - // Finally we can start the scanning process. - // When a non-convex relationship between the 3 points is found - // (either concave shape or duplicated points), - // we pop the previous point from the stack - // until the 3-point relationship is convex again, or - // until the stack only contains two points - for (int i = k + 1; i < num_in; i++) { - while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) { - m--; + q[1] = q[k]; + int m = 2; // 2 points in the stack + // Step 5: + // Finally we can start the scanning process. 
+ // When a non-convex relationship between the 3 points is found + // (either concave shape or duplicated points), + // we pop the previous point from the stack + // until the 3-point relationship is convex again, or + // until the stack only contains two points + for (int i = k + 1; i < num_in; i++) + { + while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) + { + m--; + } + q[m++] = q[i]; } - q[m++] = q[i]; - } - - // Step 6 (Optional): - // In general sense we need the original coordinates, so we - // need to shift the points back (reverting Step 2) - // But if we're only interested in getting the area/perimeter of the shape - // We can simply return. - if (!shift_to_zero) { - for (int i = 0; i < m; i++) { - q[i] += start; + + // Step 6 (Optional): + // In general sense we need the original coordinates, so we + // need to shift the points back (reverting Step 2) + // But if we're only interested in getting the area/perimeter of the shape + // We can simply return. + if (!shift_to_zero) + { + for (int i = 0; i < m; i++) + { + q[i] += start; + } } - } - return m; + return m; } -template -__host__ __device__ __forceinline__ T polygon_area(const Point (&q)[24], const int &m) { - if (m <= 2) { - return 0; - } +template +__host__ __device__ __forceinline__ T polygon_area(const Point (&q)[24], const int& m) +{ + if (m <= 2) + { + return 0; + } - T area = 0; - for (int i = 1; i < m - 1; i++) { - area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); - } + T area = 0; + for (int i = 1; i < m - 1; i++) + { + area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); + } - return area / 2.0; + return area / 2.0; } -template -__host__ __device__ __forceinline__ T rotated_boxes_intersection(const RotatedBox &box1, - const RotatedBox &box2) { - // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned - // from rotated_rect_intersection_pts - Point intersectPts[24], orderedPts[24]; +template +__host__ __device__ __forceinline__ T rotated_boxes_intersection(const RotatedBox& box1, + const RotatedBox& box2) +{ + // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned + // from rotated_rect_intersection_pts + Point intersectPts[24], orderedPts[24]; - Point pts1[4]; - Point pts2[4]; - get_rotated_vertices(box1, pts1); - get_rotated_vertices(box2, pts2); + Point pts1[4]; + Point pts2[4]; + get_rotated_vertices(box1, pts1); + get_rotated_vertices(box2, pts2); - int num = get_intersection_points(pts1, pts2, intersectPts); + int num = get_intersection_points(pts1, pts2, intersectPts); - if (num <= 2) { - return 0.0; - } + if (num <= 2) + { + return 0.0; + } - // Convex Hull to order the intersection points in clockwise order and find - // the contour area. - int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true); - return polygon_area(orderedPts, num_convex); + // Convex Hull to order the intersection points in clockwise order and find + // the contour area. 
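+    // Note: shift_to_zero is passed as true because only the intersection area
+    // is needed here, so the hull vertices may stay in the pivot-relative frame
+    // (see Step 6 above).  single_box_iou_rotated() then divides this area by
+    // (area1 + area2 - intersection) to obtain the IoU.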
+ int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true); + return polygon_area(orderedPts, num_convex); } -template -__host__ __device__ __forceinline__ T single_box_iou_rotated(T const *const box1_raw, - T const *const box2_raw) { - // shift center to the middle point to achieve higher precision in result - RotatedBox box1, box2; - auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0; - auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0; - box1.x_ctr = box1_raw[0] - center_shift_x; - box1.y_ctr = box1_raw[1] - center_shift_y; - box1.w = box1_raw[2]; - box1.h = box1_raw[3]; - box1.a = box1_raw[4]; - box2.x_ctr = box2_raw[0] - center_shift_x; - box2.y_ctr = box2_raw[1] - center_shift_y; - box2.w = box2_raw[2]; - box2.h = box2_raw[3]; - box2.a = box2_raw[4]; - - const T area1 = box1.w * box1.h; - const T area2 = box2.w * box2.h; - if (area1 < 1e-14 || area2 < 1e-14) { - return 1.0f; - } - - const T intersection = rotated_boxes_intersection(box1, box2); - T baseS = 1.0; - baseS = (area1 + area2 - intersection); - const T iou = intersection / baseS; - return iou; +template +__host__ __device__ __forceinline__ T single_box_iou_rotated(T const* const box1_raw, + T const* const box2_raw) +{ + // shift center to the middle point to achieve higher precision in result + RotatedBox box1, box2; + auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0; + auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0; + box1.x_ctr = box1_raw[0] - center_shift_x; + box1.y_ctr = box1_raw[1] - center_shift_y; + box1.w = box1_raw[2]; + box1.h = box1_raw[3]; + box1.a = box1_raw[4]; + box2.x_ctr = box2_raw[0] - center_shift_x; + box2.y_ctr = box2_raw[1] - center_shift_y; + box2.w = box2_raw[2]; + box2.h = box2_raw[3]; + box2.a = box2_raw[4]; + + const T area1 = box1.w * box1.h; + const T area2 = box2.w * box2.h; + if (area1 < 1e-14 || area2 < 1e-14) + { + return 1.0f; + } + + const T intersection = rotated_boxes_intersection(box1, box2); + T baseS = 1.0; + baseS = (area1 + area2 - intersection); + const T iou = intersection / baseS; + return iou; } /********** new NMS for only score and index array **********/ -template -__global__ void allClassRotatedNMS_kernel(const int num, const int num_classes, - const int num_preds_per_class, const int top_k, - const float nms_threshold, const bool share_location, - const bool isNormalized, - T_BBOX *bbox_data, // bbox_data should be float to - // preserve location information - T_SCORE *beforeNMS_scores, int *beforeNMS_index_array, - T_SCORE *afterNMS_scores, int *afterNMS_index_array) { - //__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE]; - extern __shared__ bool kept_bboxinfo_flag[]; - for (int i = 0; i < num; i++) { - const int offset = i * num_classes * num_preds_per_class + blockIdx.x * num_preds_per_class; - const int max_idx = offset + top_k; // put top_k bboxes into NMS calculation - const int bbox_idx_offset = - share_location ? 
(i * num_preds_per_class) : (i * num_classes * num_preds_per_class); - - // local thread data - int loc_bboxIndex[TSIZE]; - T_BBOX loc_bbox[TSIZE * 5]; - - // initialize Bbox, Bboxinfo, kept_bboxinfo_flag - // Eliminate shared memory RAW hazard - __syncthreads(); +template +__global__ void allClassRotatedNMS_kernel(const int num, const int num_classes, const int num_preds_per_class, const int top_k, const float nms_threshold, const bool share_location, const bool isNormalized, + T_BBOX* bbox_data, // bbox_data should be float to + // preserve location information + T_SCORE* beforeNMS_scores, + int* beforeNMS_index_array, + T_SCORE* afterNMS_scores, + int* afterNMS_index_array) +{ + //__shared__ bool kept_bboxinfo_flag[CAFFE_CUDA_NUM_THREADS * TSIZE]; + extern __shared__ bool kept_bboxinfo_flag[]; + for (int i = 0; i < num; i++) + { + const int offset = i * num_classes * num_preds_per_class + blockIdx.x * num_preds_per_class; + const int max_idx = offset + top_k; // put top_k bboxes into NMS calculation + const int bbox_idx_offset = + share_location ? (i * num_preds_per_class) : (i * num_classes * num_preds_per_class); + + // local thread data + int loc_bboxIndex[TSIZE]; + T_BBOX loc_bbox[TSIZE * 5]; + + // initialize Bbox, Bboxinfo, kept_bboxinfo_flag + // Eliminate shared memory RAW hazard + __syncthreads(); #pragma unroll - for (int t = 0; t < TSIZE; t++) { - const int cur_idx = threadIdx.x + blockDim.x * t; - const int item_idx = offset + cur_idx; + for (int t = 0; t < TSIZE; t++) + { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int item_idx = offset + cur_idx; + + if (item_idx < max_idx) + { + loc_bboxIndex[t] = beforeNMS_index_array[item_idx]; + + if (loc_bboxIndex[t] >= 0) + // if (loc_bboxIndex[t] != -1) + { + const int bbox_data_idx = share_location ? (loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset) : loc_bboxIndex[t]; + memcpy(&loc_bbox[t * 5], &bbox_data[bbox_data_idx * 5], 5 * sizeof(T_BBOX)); + kept_bboxinfo_flag[cur_idx] = true; + } + else + { + kept_bboxinfo_flag[cur_idx] = false; + } + } + else + { + kept_bboxinfo_flag[cur_idx] = false; + } + } - if (item_idx < max_idx) { - loc_bboxIndex[t] = beforeNMS_index_array[item_idx]; + // filter out overlapped boxes with lower scores + int ref_item_idx = offset; + int ref_bbox_idx = + share_location ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) : beforeNMS_index_array[ref_item_idx]; - if (loc_bboxIndex[t] >= 0) - // if (loc_bboxIndex[t] != -1) + while ((ref_bbox_idx != -1) && ref_item_idx < max_idx) { - const int bbox_data_idx = share_location - ? 
(loc_bboxIndex[t] % num_preds_per_class + bbox_idx_offset) - : loc_bboxIndex[t]; - memcpy(&loc_bbox[t * 5], &bbox_data[bbox_data_idx * 5], 5 * sizeof(T_BBOX)); - kept_bboxinfo_flag[cur_idx] = true; - } else { - kept_bboxinfo_flag[cur_idx] = false; + T_BBOX ref_bbox[5]; + memcpy(&ref_bbox[0], &bbox_data[ref_bbox_idx * 5], 5 * sizeof(T_BBOX)); + + // Eliminate shared memory RAW hazard + __syncthreads(); + + for (int t = 0; t < TSIZE; t++) + { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int item_idx = offset + cur_idx; + + if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) + { + // TODO: may need to add bool normalized as argument, HERE true means + // normalized + if (single_box_iou_rotated(&ref_bbox[0], loc_bbox + t * 5) > nms_threshold) + { + kept_bboxinfo_flag[cur_idx] = false; + } + } + } + __syncthreads(); + + do { + ref_item_idx++; + } while (ref_item_idx < max_idx && !kept_bboxinfo_flag[ref_item_idx - offset]); + + ref_bbox_idx = + share_location ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) : beforeNMS_index_array[ref_item_idx]; } - } else { - kept_bboxinfo_flag[cur_idx] = false; - } - } - // filter out overlapped boxes with lower scores - int ref_item_idx = offset; - int ref_bbox_idx = - share_location - ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) - : beforeNMS_index_array[ref_item_idx]; - - while ((ref_bbox_idx != -1) && ref_item_idx < max_idx) { - T_BBOX ref_bbox[5]; - memcpy(&ref_bbox[0], &bbox_data[ref_bbox_idx * 5], 5 * sizeof(T_BBOX)); - - // Eliminate shared memory RAW hazard - __syncthreads(); - - for (int t = 0; t < TSIZE; t++) { - const int cur_idx = threadIdx.x + blockDim.x * t; - const int item_idx = offset + cur_idx; - - if ((kept_bboxinfo_flag[cur_idx]) && (item_idx > ref_item_idx)) { - // TODO: may need to add bool normalized as argument, HERE true means - // normalized - if (single_box_iou_rotated(&ref_bbox[0], loc_bbox + t * 5) > nms_threshold) { - kept_bboxinfo_flag[cur_idx] = false; - } + // store data + for (int t = 0; t < TSIZE; t++) + { + const int cur_idx = threadIdx.x + blockDim.x * t; + const int read_item_idx = offset + cur_idx; + const int write_item_idx = (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx; + /* + * If not not keeping the bbox + * Set the score to 0 + * Set the bounding box index to -1 + */ + if (read_item_idx < max_idx) + { + afterNMS_scores[write_item_idx] = + kept_bboxinfo_flag[cur_idx] ? beforeNMS_scores[read_item_idx] : 0.0f; + afterNMS_index_array[write_item_idx] = kept_bboxinfo_flag[cur_idx] ? loc_bboxIndex[t] : -1; + } } - } - __syncthreads(); - - do { - ref_item_idx++; - } while (ref_item_idx < max_idx && !kept_bboxinfo_flag[ref_item_idx - offset]); - - ref_bbox_idx = - share_location - ? (beforeNMS_index_array[ref_item_idx] % num_preds_per_class + bbox_idx_offset) - : beforeNMS_index_array[ref_item_idx]; - } - - // store data - for (int t = 0; t < TSIZE; t++) { - const int cur_idx = threadIdx.x + blockDim.x * t; - const int read_item_idx = offset + cur_idx; - const int write_item_idx = (i * num_classes * top_k + blockIdx.x * top_k) + cur_idx; - /* - * If not not keeping the bbox - * Set the score to 0 - * Set the bounding box index to -1 - */ - if (read_item_idx < max_idx) { - afterNMS_scores[write_item_idx] = - kept_bboxinfo_flag[cur_idx] ? beforeNMS_scores[read_item_idx] : 0.0f; - afterNMS_index_array[write_item_idx] = kept_bboxinfo_flag[cur_idx] ? 
loc_bboxIndex[t] : -1; - } } - } } -template -pluginStatus_t allClassRotatedNMS_gpu(cudaStream_t stream, const int num, const int num_classes, - const int num_preds_per_class, const int top_k, - const float nms_threshold, const bool share_location, - const bool isNormalized, void *bbox_data, - void *beforeNMS_scores, void *beforeNMS_index_array, - void *afterNMS_scores, void *afterNMS_index_array) { +template +pluginStatus_t allClassRotatedNMS_gpu(cudaStream_t stream, const int num, const int num_classes, const int num_preds_per_class, const int top_k, const float nms_threshold, const bool share_location, const bool isNormalized, void* bbox_data, void* beforeNMS_scores, void* beforeNMS_index_array, void* afterNMS_scores, void* afterNMS_index_array) +{ #define P(tsize) allClassRotatedNMS_kernel - void (*kernel[10])(const int, const int, const int, const int, const float, const bool, - const bool, float *, T_SCORE *, int *, T_SCORE *, int *) = { - P(1), P(2), P(3), P(4), P(5), P(6), P(7), P(8), P(9), P(10), - }; - - const int BS = 512; - const int GS = num_classes; - const int t_size = (top_k + BS - 1) / BS; - - ASSERT(t_size <= 10); - kernel[t_size - 1]<<>>( - num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized, - (T_BBOX *)bbox_data, (T_SCORE *)beforeNMS_scores, (int *)beforeNMS_index_array, - (T_SCORE *)afterNMS_scores, (int *)afterNMS_index_array); - - CSC(cudaGetLastError(), STATUS_FAILURE); - return STATUS_SUCCESS; + void (*kernel[10])(const int, const int, const int, const int, const float, const bool, const bool, float*, T_SCORE*, int*, T_SCORE*, int*) = { + P(1), + P(2), + P(3), + P(4), + P(5), + P(6), + P(7), + P(8), + P(9), + P(10), + }; + + const int BS = 512; + const int GS = num_classes; + const int t_size = (top_k + BS - 1) / BS; + + ASSERT(t_size <= 10); + kernel[t_size - 1]<<>>( + num, + num_classes, + num_preds_per_class, + top_k, + nms_threshold, + share_location, + isNormalized, + (T_BBOX*)bbox_data, + (T_SCORE*)beforeNMS_scores, + (int*)beforeNMS_index_array, + (T_SCORE*)afterNMS_scores, + (int*)afterNMS_index_array); + + CSC(cudaGetLastError(), STATUS_FAILURE); + return STATUS_SUCCESS; } // allClassNMS LAUNCH CONFIG -typedef pluginStatus_t (*rotatedNmsFunc)(cudaStream_t, const int, const int, const int, const int, - const float, const bool, const bool, void *, void *, - void *, void *, void *); - -struct rotatedNmsLaunchConfig { - DataType t_score; - DataType t_bbox; - rotatedNmsFunc function; - - rotatedNmsLaunchConfig(DataType t_score, DataType t_bbox) : t_score(t_score), t_bbox(t_bbox) {} - rotatedNmsLaunchConfig(DataType t_score, DataType t_bbox, rotatedNmsFunc function) - : t_score(t_score), t_bbox(t_bbox), function(function) {} - bool operator==(const rotatedNmsLaunchConfig &other) { - return t_score == other.t_score && t_bbox == other.t_bbox; - } +typedef pluginStatus_t (*rotatedNmsFunc)(cudaStream_t, const int, const int, const int, const int, const float, const bool, const bool, void*, void*, void*, void*, void*); + +struct rotatedNmsLaunchConfig +{ + DataType t_score; + DataType t_bbox; + rotatedNmsFunc function; + + rotatedNmsLaunchConfig(DataType t_score, DataType t_bbox) + : t_score(t_score) + , t_bbox(t_bbox) + { + } + rotatedNmsLaunchConfig(DataType t_score, DataType t_bbox, rotatedNmsFunc function) + : t_score(t_score) + , t_bbox(t_bbox) + , function(function) + { + } + bool operator==(const rotatedNmsLaunchConfig& other) + { + return t_score == other.t_score && t_bbox == other.t_bbox; + } }; static std::vector 
rotatedNmsFuncVec;
-bool rotatedNmsInit() {
-  rotatedNmsFuncVec.push_back(rotatedNmsLaunchConfig(DataType::kFLOAT, DataType::kFLOAT,
-                                                     allClassRotatedNMS_gpu));
-  return true;
+bool rotatedNmsInit()
+{
+    rotatedNmsFuncVec.push_back(rotatedNmsLaunchConfig(DataType::kFLOAT, DataType::kFLOAT, allClassRotatedNMS_gpu));
+    return true;
 }
-static bool initialized = rotatedNmsInit();
-
-pluginStatus_t allClassRotatedNMS(cudaStream_t stream, const int num, const int num_classes,
-                                  const int num_preds_per_class, const int top_k,
-                                  const float nms_threshold, const bool share_location,
-                                  const bool isNormalized, const DataType DT_SCORE,
-                                  const DataType DT_BBOX, void *bbox_data, void *beforeNMS_scores,
-                                  void *beforeNMS_index_array, void *afterNMS_scores,
-                                  void *afterNMS_index_array, bool) {
-  auto __cuda_arch__ = get_cuda_arch(0);  // assume there is only one arch 7.2 device
-  if (__cuda_arch__ == 720 && top_k >= 1000) {
-    printf("Warning: pre_top_k need to be reduced for devices with arch 7.2, got pre_top_k=%d\n",
-           top_k);
-  }
-  rotatedNmsLaunchConfig lc(DT_SCORE, DT_BBOX);
-
-  for (unsigned i = 0; i < rotatedNmsFuncVec.size(); ++i) {
-    if (lc == rotatedNmsFuncVec[i]) {
-      DEBUG_PRINTF("all class rotated nms kernel %d\n", i);
-      return rotatedNmsFuncVec[i].function(stream, num, num_classes, num_preds_per_class, top_k,
-                                           nms_threshold, share_location, isNormalized, bbox_data,
-                                           beforeNMS_scores, beforeNMS_index_array, afterNMS_scores,
-                                           afterNMS_index_array);
+static bool initialized = rotatedNmsInit();
+
+pluginStatus_t allClassRotatedNMS(cudaStream_t stream, const int num, const int num_classes, const int num_preds_per_class, const int top_k, const float nms_threshold, const bool share_location, const bool isNormalized, const DataType DT_SCORE, const DataType DT_BBOX, void* bbox_data, void* beforeNMS_scores, void* beforeNMS_index_array, void* afterNMS_scores, void* afterNMS_index_array, bool)
+{
+    auto __cuda_arch__ = get_cuda_arch(0);  // assume there is only one arch 7.2 device
+    if (__cuda_arch__ == 720 && top_k >= 1000)
+    {
+        printf("Warning: pre_top_k needs to be reduced for devices with arch 7.2, got pre_top_k=%d\n",
+               top_k);
+    }
+    rotatedNmsLaunchConfig lc(DT_SCORE, DT_BBOX);
+
+    for (unsigned i = 0; i < rotatedNmsFuncVec.size(); ++i)
+    {
+        if (lc == rotatedNmsFuncVec[i])
+        {
+            DEBUG_PRINTF("all class rotated nms kernel %d\n", i);
+            return rotatedNmsFuncVec[i].function(stream, num, num_classes, num_preds_per_class, top_k, nms_threshold, share_location, isNormalized, bbox_data, beforeNMS_scores, beforeNMS_index_array, afterNMS_scores, afterNMS_index_array);
+        }
    }
-  }
-  return STATUS_BAD_PARAM;
+    return STATUS_BAD_PARAM;
 }
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp
index 71cb7a8592..903624d86b 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/batched_nms_kernel.cpp
@@ -3,123 +3,111 @@
 // https://github.com/NVIDIA/TensorRT/tree/master/plugin/batchedNMSPlugin
 #include "nms/batched_nms_kernel.hpp"
 
-pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatchBoxesSize,
-                            const int perBatchScoresSize, const bool shareLocation,
-                            const int backgroundLabelId, const int numPredsPerClass,
-                            const int numClasses, const int topK, const int keepTopK,
-                            const float scoreThreshold, const float iouThreshold,
-                            const DataType DT_BBOX, const void* locData, const DataType DT_SCORE,
-                            const void* confData, void* nmsedDets, void* nmsedLabels,
-                            void* nmsedIndex, void* workspace, bool isNormalized, bool confSigmoid,
-                            bool clipBoxes, bool rotated) {
-  const int topKVal = topK < 0 ? numPredsPerClass : topK;
-  const int keepTopKVal = keepTopK < 0 ? numPredsPerClass : keepTopK;
-  // locCount = batch_size * number_boxes_per_sample * 4
-  const int locCount = N * perBatchBoxesSize;
-  /*
-   * shareLocation
-   * Bounding box are shared among all classes, i.e., a bounding box could be
-   * classified as any candidate class. Otherwise Bounding box are designed for
-   * specific classes, i.e., a bounding box could be classified as one certain
-   * class or not (binary classification).
-   */
-  const int numLocClasses = shareLocation ? 1 : numClasses;
-
-  size_t bboxDataSize = detectionForwardBBoxDataSize(N, perBatchBoxesSize, DataType::kFLOAT);
-  void* bboxDataRaw = workspace;
-  cudaMemcpyAsync(bboxDataRaw, locData, bboxDataSize, cudaMemcpyDeviceToDevice, stream);
-  pluginStatus_t status;
-
-  /*
-   * bboxDataRaw format:
-   * [batch size, numPriors (per sample), numLocClasses, 4]
-   */
-  // float for now
-  void* bboxData;
-  size_t bboxPermuteSize =
-      detectionForwardBBoxPermuteSize(shareLocation, N, perBatchBoxesSize, DataType::kFLOAT);
-  void* bboxPermute = nextWorkspacePtr((int8_t*)bboxDataRaw, bboxDataSize);
-
-  /*
-   * After permutation, bboxData format:
-   * [batch_size, numLocClasses, numPriors (per sample) (numPredsPerClass), 4]
-   * This is equivalent to swapping axis
-   */
-  if (!shareLocation) {
-    status = permuteData(stream, locCount, numLocClasses, numPredsPerClass, rotated ? 5 : 4,
-                         DataType::kFLOAT, false, bboxDataRaw, bboxPermute);
+pluginStatus_t nmsInference(cudaStream_t stream, const int N, const int perBatchBoxesSize, const int perBatchScoresSize, const bool shareLocation, const int backgroundLabelId, const int numPredsPerClass, const int numClasses, const int topK, const int keepTopK, const float scoreThreshold, const float iouThreshold, const DataType DT_BBOX, const void* locData, const DataType DT_SCORE, const void* confData, void* nmsedDets, void* nmsedLabels, void* nmsedIndex, void* workspace, bool isNormalized, bool confSigmoid, bool clipBoxes, bool rotated)
+{
+    const int topKVal = topK < 0 ? numPredsPerClass : topK;
+    const int keepTopKVal = keepTopK < 0 ? numPredsPerClass : keepTopK;
+    // locCount = batch_size * number_boxes_per_sample * 4
+    const int locCount = N * perBatchBoxesSize;
+    /*
+     * shareLocation
+     * Bounding boxes are shared among all classes, i.e., a bounding box could be
+     * classified as any candidate class. Otherwise, bounding boxes are specific to
+     * certain classes, i.e., a bounding box could be classified as one certain
+     * class or not (binary classification).
+     */
+    const int numLocClasses = shareLocation ?
1 : numClasses; + + size_t bboxDataSize = detectionForwardBBoxDataSize(N, perBatchBoxesSize, DataType::kFLOAT); + void* bboxDataRaw = workspace; + cudaMemcpyAsync(bboxDataRaw, locData, bboxDataSize, cudaMemcpyDeviceToDevice, stream); + pluginStatus_t status; + + /* + * bboxDataRaw format: + * [batch size, numPriors (per sample), numLocClasses, 4] + */ + // float for now + void* bboxData; + size_t bboxPermuteSize = + detectionForwardBBoxPermuteSize(shareLocation, N, perBatchBoxesSize, DataType::kFLOAT); + void* bboxPermute = nextWorkspacePtr((int8_t*)bboxDataRaw, bboxDataSize); + + /* + * After permutation, bboxData format: + * [batch_size, numLocClasses, numPriors (per sample) (numPredsPerClass), 4] + * This is equivalent to swapping axis + */ + if (!shareLocation) + { + status = permuteData(stream, locCount, numLocClasses, numPredsPerClass, rotated ? 5 : 4, DataType::kFLOAT, false, bboxDataRaw, bboxPermute); + ASSERT_FAILURE(status == STATUS_SUCCESS); + bboxData = bboxPermute; + } + /* + * If shareLocation, numLocClasses = 1 + * No need to permute data on linear memory + */ + else + { + bboxData = bboxDataRaw; + } + + /* + * Conf data format + * [batch size, numPriors * param.numClasses, 1, 1] + */ + const int numScores = N * perBatchScoresSize; + size_t totalScoresSize = detectionForwardPreNMSSize(N, perBatchScoresSize); + void* scores = nextWorkspacePtr((int8_t*)bboxPermute, bboxPermuteSize); + + // need a conf_scores + /* + * After permutation, bboxData format: + * [batch_size, numClasses, numPredsPerClass, 1] + */ + status = permuteData(stream, numScores, numClasses, numPredsPerClass, 1, DataType::kFLOAT, confSigmoid, confData, scores); ASSERT_FAILURE(status == STATUS_SUCCESS); - bboxData = bboxPermute; - } - /* - * If shareLocation, numLocClasses = 1 - * No need to permute data on linear memory - */ - else { - bboxData = bboxDataRaw; - } - - /* - * Conf data format - * [batch size, numPriors * param.numClasses, 1, 1] - */ - const int numScores = N * perBatchScoresSize; - size_t totalScoresSize = detectionForwardPreNMSSize(N, perBatchScoresSize); - void* scores = nextWorkspacePtr((int8_t*)bboxPermute, bboxPermuteSize); - - // need a conf_scores - /* - * After permutation, bboxData format: - * [batch_size, numClasses, numPredsPerClass, 1] - */ - status = permuteData(stream, numScores, numClasses, numPredsPerClass, 1, DataType::kFLOAT, - confSigmoid, confData, scores); - ASSERT_FAILURE(status == STATUS_SUCCESS); - - size_t indicesSize = detectionForwardPreNMSSize(N, perBatchScoresSize); - void* indices = nextWorkspacePtr((int8_t*)scores, totalScoresSize); - - size_t postNMSScoresSize = detectionForwardPostNMSSize(N, numClasses, topKVal); - size_t postNMSIndicesSize = detectionForwardPostNMSSize(N, numClasses, topKVal); - void* postNMSScores = nextWorkspacePtr((int8_t*)indices, indicesSize); - void* postNMSIndices = nextWorkspacePtr((int8_t*)postNMSScores, postNMSScoresSize); - - void* sortingWorkspace = nextWorkspacePtr((int8_t*)postNMSIndices, postNMSIndicesSize); - // Sort the scores so that the following NMS could be applied. - - status = sortScoresPerClass(stream, N, numClasses, numPredsPerClass, backgroundLabelId, - scoreThreshold, DataType::kFLOAT, scores, indices, sortingWorkspace); - ASSERT_FAILURE(status == STATUS_SUCCESS); - - // This is set to true as the input bounding boxes are of the format [ymin, - // xmin, ymax, xmax]. 
The default implementation assumes [xmin, ymin, xmax, - // ymax] - bool flipXY = false; - // NMS - if (rotated) { - status = allClassRotatedNMS(stream, N, numClasses, numPredsPerClass, topKVal, iouThreshold, - shareLocation, isNormalized, DataType::kFLOAT, DataType::kFLOAT, - bboxData, scores, indices, postNMSScores, postNMSIndices, flipXY); - } else { - status = allClassNMS(stream, N, numClasses, numPredsPerClass, topKVal, iouThreshold, - shareLocation, isNormalized, DataType::kFLOAT, DataType::kFLOAT, bboxData, - scores, indices, postNMSScores, postNMSIndices, flipXY); - } - - ASSERT_FAILURE(status == STATUS_SUCCESS); - - // Sort the bounding boxes after NMS using scores - status = sortScoresPerImage(stream, N, numClasses * topKVal, DataType::kFLOAT, postNMSScores, - postNMSIndices, scores, indices, sortingWorkspace); - - ASSERT_FAILURE(status == STATUS_SUCCESS); - - // Gather data from the sorted bounding boxes after NMS - status = gatherNMSOutputs(stream, shareLocation, N, numPredsPerClass, numClasses, topKVal, - keepTopKVal, DataType::kFLOAT, DataType::kFLOAT, indices, scores, - bboxData, nmsedDets, nmsedLabels, nmsedIndex, clipBoxes, rotated); - - ASSERT_FAILURE(status == STATUS_SUCCESS); - - return STATUS_SUCCESS; + + size_t indicesSize = detectionForwardPreNMSSize(N, perBatchScoresSize); + void* indices = nextWorkspacePtr((int8_t*)scores, totalScoresSize); + + size_t postNMSScoresSize = detectionForwardPostNMSSize(N, numClasses, topKVal); + size_t postNMSIndicesSize = detectionForwardPostNMSSize(N, numClasses, topKVal); + void* postNMSScores = nextWorkspacePtr((int8_t*)indices, indicesSize); + void* postNMSIndices = nextWorkspacePtr((int8_t*)postNMSScores, postNMSScoresSize); + + void* sortingWorkspace = nextWorkspacePtr((int8_t*)postNMSIndices, postNMSIndicesSize); + // Sort the scores so that the following NMS could be applied. + + status = sortScoresPerClass(stream, N, numClasses, numPredsPerClass, backgroundLabelId, scoreThreshold, DataType::kFLOAT, scores, indices, sortingWorkspace); + ASSERT_FAILURE(status == STATUS_SUCCESS); + + // This is set to true as the input bounding boxes are of the format [ymin, + // xmin, ymax, xmax]. 
The default implementation assumes [xmin, ymin, xmax, + // ymax] + bool flipXY = false; + // NMS + if (rotated) + { + status = allClassRotatedNMS(stream, N, numClasses, numPredsPerClass, topKVal, iouThreshold, shareLocation, isNormalized, DataType::kFLOAT, DataType::kFLOAT, bboxData, scores, indices, postNMSScores, postNMSIndices, flipXY); + } + else + { + status = allClassNMS(stream, N, numClasses, numPredsPerClass, topKVal, iouThreshold, shareLocation, isNormalized, DataType::kFLOAT, DataType::kFLOAT, bboxData, scores, indices, postNMSScores, postNMSIndices, flipXY); + } + + ASSERT_FAILURE(status == STATUS_SUCCESS); + + // Sort the bounding boxes after NMS using scores + status = sortScoresPerImage(stream, N, numClasses * topKVal, DataType::kFLOAT, postNMSScores, postNMSIndices, scores, indices, sortingWorkspace); + + ASSERT_FAILURE(status == STATUS_SUCCESS); + + // Gather data from the sorted bounding boxes after NMS + status = gatherNMSOutputs(stream, shareLocation, N, numPredsPerClass, numClasses, topKVal, keepTopKVal, DataType::kFLOAT, DataType::kFLOAT, indices, scores, bboxData, nmsedDets, nmsedLabels, nmsedIndex, clipBoxes, rotated); + + ASSERT_FAILURE(status == STATUS_SUCCESS); + + return STATUS_SUCCESS; } diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu index 58419f8c16..22d901565c 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/gatherNMSOutputs.cu @@ -6,159 +6,170 @@ #include "nms/kernel.h" #include "trt_plugin_helper.hpp" -template +template __launch_bounds__(nthds_per_cta) __global__ - void gatherNMSOutputs_kernel(const bool shareLocation, const int numImages, - const int numPredsPerClass, const int numClasses, const int topK, - const int keepTopK, const int *indices, const T_SCORE *scores, - const T_BBOX *bboxData, T_BBOX *nmsedDets, int *nmsedLabels, - int *nmsedIndex, bool clipBoxes) { - if (keepTopK > topK) return; - for (int i = blockIdx.x * nthds_per_cta + threadIdx.x; i < numImages * keepTopK; - i += gridDim.x * nthds_per_cta) { - const int imgId = i / keepTopK; - const int detId = i % keepTopK; - const int offset = imgId * numClasses * topK; - const int index = indices[offset + detId]; - const T_SCORE score = scores[offset + detId]; - if (index == -1) { - nmsedLabels[i] = -1; - if (nmsedIndex != nullptr) { - nmsedIndex[i] = -1; - } - if (rotated) { - nmsedDets[i * 6] = 0; - nmsedDets[i * 6 + 1] = 0; - nmsedDets[i * 6 + 2] = 0; - nmsedDets[i * 6 + 3] = 0; - nmsedDets[i * 6 + 4] = 0; - nmsedDets[i * 6 + 5] = 0; - } else { - nmsedDets[i * 5] = 0; - nmsedDets[i * 5 + 1] = 0; - nmsedDets[i * 5 + 2] = 0; - nmsedDets[i * 5 + 3] = 0; - nmsedDets[i * 5 + 4] = 0; - } - } else { - const int bboxOffset = - imgId * (shareLocation ? numPredsPerClass : (numClasses * numPredsPerClass)); - nmsedLabels[i] = (index % (numClasses * numPredsPerClass)) / numPredsPerClass; // label - if (rotated) { - const int bboxId = ((shareLocation ? 
(index % numPredsPerClass) - : index % (numClasses * numPredsPerClass)) + - bboxOffset) * - 5; - if (nmsedIndex != nullptr) { - nmsedIndex[i] = bboxId / 5 - bboxOffset; + void gatherNMSOutputs_kernel(const bool shareLocation, const int numImages, const int numPredsPerClass, const int numClasses, const int topK, const int keepTopK, const int* indices, const T_SCORE* scores, const T_BBOX* bboxData, T_BBOX* nmsedDets, int* nmsedLabels, int* nmsedIndex, bool clipBoxes) +{ + if (keepTopK > topK) return; + for (int i = blockIdx.x * nthds_per_cta + threadIdx.x; i < numImages * keepTopK; + i += gridDim.x * nthds_per_cta) + { + const int imgId = i / keepTopK; + const int detId = i % keepTopK; + const int offset = imgId * numClasses * topK; + const int index = indices[offset + detId]; + const T_SCORE score = scores[offset + detId]; + if (index == -1) + { + nmsedLabels[i] = -1; + if (nmsedIndex != nullptr) + { + nmsedIndex[i] = -1; + } + if (rotated) + { + nmsedDets[i * 6] = 0; + nmsedDets[i * 6 + 1] = 0; + nmsedDets[i * 6 + 2] = 0; + nmsedDets[i * 6 + 3] = 0; + nmsedDets[i * 6 + 4] = 0; + nmsedDets[i * 6 + 5] = 0; + } + else + { + nmsedDets[i * 5] = 0; + nmsedDets[i * 5 + 1] = 0; + nmsedDets[i * 5 + 2] = 0; + nmsedDets[i * 5 + 3] = 0; + nmsedDets[i * 5 + 4] = 0; + } } - // clipped bbox xmin - nmsedDets[i * 6] = - clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId]; - // clipped bbox ymin - nmsedDets[i * 6 + 1] = clipBoxes ? max(min(bboxData[bboxId + 1], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 1]; - // clipped bbox xmax - nmsedDets[i * 6 + 2] = clipBoxes ? max(min(bboxData[bboxId + 2], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 2]; - // clipped bbox ymax - nmsedDets[i * 6 + 3] = clipBoxes ? max(min(bboxData[bboxId + 3], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 3]; - // clipped bbox angle - nmsedDets[i * 6 + 4] = clipBoxes ? max(min(bboxData[bboxId + 4], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 4]; - nmsedDets[i * 6 + 5] = score; - } else { - const int bboxId = ((shareLocation ? (index % numPredsPerClass) - : index % (numClasses * numPredsPerClass)) + - bboxOffset) * - 4; - if (nmsedIndex != nullptr) { - nmsedIndex[i] = bboxId / 4 - bboxOffset; + else + { + const int bboxOffset = + imgId * (shareLocation ? numPredsPerClass : (numClasses * numPredsPerClass)); + nmsedLabels[i] = (index % (numClasses * numPredsPerClass)) / numPredsPerClass; // label + if (rotated) + { + const int bboxId = ((shareLocation ? (index % numPredsPerClass) : index % (numClasses * numPredsPerClass)) + + bboxOffset) * + 5; + if (nmsedIndex != nullptr) + { + nmsedIndex[i] = bboxId / 5 - bboxOffset; + } + // clipped bbox xmin + nmsedDets[i * 6] = + clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId]; + // clipped bbox ymin + nmsedDets[i * 6 + 1] = clipBoxes ? max(min(bboxData[bboxId + 1], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 1]; + // clipped bbox xmax + nmsedDets[i * 6 + 2] = clipBoxes ? max(min(bboxData[bboxId + 2], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 2]; + // clipped bbox ymax + nmsedDets[i * 6 + 3] = clipBoxes ? max(min(bboxData[bboxId + 3], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 3]; + // clipped bbox angle + nmsedDets[i * 6 + 4] = clipBoxes ? max(min(bboxData[bboxId + 4], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 4]; + nmsedDets[i * 6 + 5] = score; + } + else + { + const int bboxId = ((shareLocation ? 
(index % numPredsPerClass) : index % (numClasses * numPredsPerClass)) + + bboxOffset) * + 4; + if (nmsedIndex != nullptr) + { + nmsedIndex[i] = bboxId / 4 - bboxOffset; + } + // clipped bbox xmin + nmsedDets[i * 5] = + clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId]; + // clipped bbox ymin + nmsedDets[i * 5 + 1] = clipBoxes ? max(min(bboxData[bboxId + 1], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 1]; + // clipped bbox xmax + nmsedDets[i * 5 + 2] = clipBoxes ? max(min(bboxData[bboxId + 2], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 2]; + // clipped bbox ymax + nmsedDets[i * 5 + 3] = clipBoxes ? max(min(bboxData[bboxId + 3], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId + 3]; + nmsedDets[i * 5 + 4] = score; + } } - // clipped bbox xmin - nmsedDets[i * 5] = - clipBoxes ? max(min(bboxData[bboxId], T_BBOX(1.)), T_BBOX(0.)) : bboxData[bboxId]; - // clipped bbox ymin - nmsedDets[i * 5 + 1] = clipBoxes ? max(min(bboxData[bboxId + 1], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 1]; - // clipped bbox xmax - nmsedDets[i * 5 + 2] = clipBoxes ? max(min(bboxData[bboxId + 2], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 2]; - // clipped bbox ymax - nmsedDets[i * 5 + 3] = clipBoxes ? max(min(bboxData[bboxId + 3], T_BBOX(1.)), T_BBOX(0.)) - : bboxData[bboxId + 3]; - nmsedDets[i * 5 + 4] = score; - } } - } } -template -pluginStatus_t gatherNMSOutputs_gpu(cudaStream_t stream, const bool shareLocation, - const int numImages, const int numPredsPerClass, - const int numClasses, const int topK, const int keepTopK, - const void *indices, const void *scores, const void *bboxData, - void *nmsedDets, void *nmsedLabels, void *nmsedIndex, - bool clipBoxes) { - const int BS = 32; - const int GS = 32; - gatherNMSOutputs_kernel<<>>( - shareLocation, numImages, numPredsPerClass, numClasses, topK, keepTopK, (int *)indices, - (T_SCORE *)scores, (T_BBOX *)bboxData, (T_BBOX *)nmsedDets, (int *)nmsedLabels, - (int *)nmsedIndex, clipBoxes); +template +pluginStatus_t gatherNMSOutputs_gpu(cudaStream_t stream, const bool shareLocation, const int numImages, const int numPredsPerClass, const int numClasses, const int topK, const int keepTopK, const void* indices, const void* scores, const void* bboxData, void* nmsedDets, void* nmsedLabels, void* nmsedIndex, bool clipBoxes) +{ + const int BS = 32; + const int GS = 32; + gatherNMSOutputs_kernel<<>>( + shareLocation, + numImages, + numPredsPerClass, + numClasses, + topK, + keepTopK, + (int*)indices, + (T_SCORE*)scores, + (T_BBOX*)bboxData, + (T_BBOX*)nmsedDets, + (int*)nmsedLabels, + (int*)nmsedIndex, + clipBoxes); - CSC(cudaGetLastError(), STATUS_FAILURE); - return STATUS_SUCCESS; + CSC(cudaGetLastError(), STATUS_FAILURE); + return STATUS_SUCCESS; } // gatherNMSOutputs LAUNCH CONFIG {{{ -typedef pluginStatus_t (*nmsOutFunc)(cudaStream_t, const bool, const int, const int, const int, - const int, const int, const void *, const void *, const void *, - void *, void *, void *, bool); -struct nmsOutLaunchConfig { - DataType t_bbox; - DataType t_score; - bool rotated; - nmsOutFunc function; +typedef pluginStatus_t (*nmsOutFunc)(cudaStream_t, const bool, const int, const int, const int, const int, const int, const void*, const void*, const void*, void*, void*, void*, bool); +struct nmsOutLaunchConfig +{ + DataType t_bbox; + DataType t_score; + bool rotated; + nmsOutFunc function; - nmsOutLaunchConfig(DataType t_bbox, DataType t_score, bool rotated) - : t_bbox(t_bbox), t_score(t_score), rotated(rotated) {} - nmsOutLaunchConfig(DataType t_bbox, DataType 
t_score, bool rotated, nmsOutFunc function) - : t_bbox(t_bbox), t_score(t_score), rotated(rotated), function(function) {} - bool operator==(const nmsOutLaunchConfig &other) { - return t_bbox == other.t_bbox && t_score == other.t_score && rotated == other.rotated; - } + nmsOutLaunchConfig(DataType t_bbox, DataType t_score, bool rotated) + : t_bbox(t_bbox) + , t_score(t_score) + , rotated(rotated) + { + } + nmsOutLaunchConfig(DataType t_bbox, DataType t_score, bool rotated, nmsOutFunc function) + : t_bbox(t_bbox) + , t_score(t_score) + , rotated(rotated) + , function(function) + { + } + bool operator==(const nmsOutLaunchConfig& other) + { + return t_bbox == other.t_bbox && t_score == other.t_score && rotated == other.rotated; + } }; using nvinfer1::DataType; static std::vector nmsOutFuncVec; -bool nmsOutputInit() { - nmsOutFuncVec.push_back(nmsOutLaunchConfig(DataType::kFLOAT, DataType::kFLOAT, false, - gatherNMSOutputs_gpu)); - nmsOutFuncVec.push_back(nmsOutLaunchConfig(DataType::kFLOAT, DataType::kFLOAT, true, - gatherNMSOutputs_gpu)); - return true; +bool nmsOutputInit() +{ + nmsOutFuncVec.push_back(nmsOutLaunchConfig(DataType::kFLOAT, DataType::kFLOAT, false, gatherNMSOutputs_gpu)); + nmsOutFuncVec.push_back(nmsOutLaunchConfig(DataType::kFLOAT, DataType::kFLOAT, true, gatherNMSOutputs_gpu)); + return true; } -static bool initialized = nmsOutputInit(); +static bool initialized = nmsOutputInit(); -pluginStatus_t gatherNMSOutputs(cudaStream_t stream, const bool shareLocation, const int numImages, - const int numPredsPerClass, const int numClasses, const int topK, - const int keepTopK, const DataType DT_BBOX, const DataType DT_SCORE, - const void *indices, const void *scores, const void *bboxData, - void *nmsedDets, void *nmsedLabels, void *nmsedIndex, - bool clipBoxes, bool rotated) { - nmsOutLaunchConfig lc = nmsOutLaunchConfig(DT_BBOX, DT_SCORE, rotated); - for (unsigned i = 0; i < nmsOutFuncVec.size(); ++i) { - if (lc == nmsOutFuncVec[i]) { - DEBUG_PRINTF("gatherNMSOutputs kernel %d\n", i); - return nmsOutFuncVec[i].function(stream, shareLocation, numImages, numPredsPerClass, - numClasses, topK, keepTopK, indices, scores, bboxData, - nmsedDets, nmsedLabels, nmsedIndex, clipBoxes); +pluginStatus_t gatherNMSOutputs(cudaStream_t stream, const bool shareLocation, const int numImages, const int numPredsPerClass, const int numClasses, const int topK, const int keepTopK, const DataType DT_BBOX, const DataType DT_SCORE, const void* indices, const void* scores, const void* bboxData, void* nmsedDets, void* nmsedLabels, void* nmsedIndex, bool clipBoxes, bool rotated) +{ + nmsOutLaunchConfig lc = nmsOutLaunchConfig(DT_BBOX, DT_SCORE, rotated); + for (unsigned i = 0; i < nmsOutFuncVec.size(); ++i) + { + if (lc == nmsOutFuncVec[i]) + { + DEBUG_PRINTF("gatherNMSOutputs kernel %d\n", i); + return nmsOutFuncVec[i].function(stream, shareLocation, numImages, numPredsPerClass, numClasses, topK, keepTopK, indices, scores, bboxData, nmsedDets, nmsedLabels, nmsedIndex, clipBoxes); + } } - } - return STATUS_BAD_PARAM; + return STATUS_BAD_PARAM; } diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/kernel.cu index f0e1c9d0cc..e13f8969d4 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/kernel.cu @@ -12,96 +12,109 @@ #define CUDA_MEM_ALIGN 256 // return cuda arch -size_t get_cuda_arch(int devID) { - int computeMode = -1, major = 0, minor = 0; - 
CUASSERT(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, devID)); - CUASSERT(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, devID)); - CUASSERT(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, devID)); - return major * 100 + minor * 10; +size_t get_cuda_arch(int devID) +{ + int computeMode = -1, major = 0, minor = 0; + CUASSERT(cudaDeviceGetAttribute(&computeMode, cudaDevAttrComputeMode, devID)); + CUASSERT(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, devID)); + CUASSERT(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, devID)); + return major * 100 + minor * 10; } // ALIGNPTR -int8_t *alignPtr(int8_t *ptr, uintptr_t to) { - uintptr_t addr = (uintptr_t)ptr; - if (addr % to) { - addr += to - addr % to; - } - return (int8_t *)addr; +int8_t* alignPtr(int8_t* ptr, uintptr_t to) +{ + uintptr_t addr = (uintptr_t)ptr; + if (addr % to) + { + addr += to - addr % to; + } + return (int8_t*)addr; } // NEXTWORKSPACEPTR -int8_t *nextWorkspacePtr(int8_t *ptr, uintptr_t previousWorkspaceSize) { - uintptr_t addr = (uintptr_t)ptr; - addr += previousWorkspaceSize; - return alignPtr((int8_t *)addr, CUDA_MEM_ALIGN); +int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize) +{ + uintptr_t addr = (uintptr_t)ptr; + addr += previousWorkspaceSize; + return alignPtr((int8_t*)addr, CUDA_MEM_ALIGN); } // CALCULATE TOTAL WORKSPACE SIZE -size_t calculateTotalWorkspaceSize(size_t *workspaces, int count) { - size_t total = 0; - for (int i = 0; i < count; i++) { - total += workspaces[i]; - if (workspaces[i] % CUDA_MEM_ALIGN) { - total += CUDA_MEM_ALIGN - (workspaces[i] % CUDA_MEM_ALIGN); +size_t calculateTotalWorkspaceSize(size_t* workspaces, int count) +{ + size_t total = 0; + for (int i = 0; i < count; i++) + { + total += workspaces[i]; + if (workspaces[i] % CUDA_MEM_ALIGN) + { + total += CUDA_MEM_ALIGN - (workspaces[i] % CUDA_MEM_ALIGN); + } } - } - return total; + return total; } using nvinfer1::DataType; -template +template __launch_bounds__(nthds_per_cta) __global__ - void setUniformOffsets_kernel(const int num_segments, const int offset, int *d_offsets) { - const int idx = blockIdx.x * nthds_per_cta + threadIdx.x; - if (idx <= num_segments) d_offsets[idx] = idx * offset; + void setUniformOffsets_kernel(const int num_segments, const int offset, int* d_offsets) +{ + const int idx = blockIdx.x * nthds_per_cta + threadIdx.x; + if (idx <= num_segments) d_offsets[idx] = idx * offset; } -void setUniformOffsets(cudaStream_t stream, const int num_segments, const int offset, - int *d_offsets) { - const int BS = 32; - const int GS = (num_segments + 1 + BS - 1) / BS; - setUniformOffsets_kernel<<>>(num_segments, offset, d_offsets); +void setUniformOffsets(cudaStream_t stream, const int num_segments, const int offset, int* d_offsets) +{ + const int BS = 32; + const int GS = (num_segments + 1 + BS - 1) / BS; + setUniformOffsets_kernel<<>>(num_segments, offset, d_offsets); } -size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX) { - if (DT_BBOX == DataType::kFLOAT) { - return N * C1 * sizeof(float); - } +size_t detectionForwardBBoxDataSize(int N, int C1, DataType DT_BBOX) +{ + if (DT_BBOX == DataType::kFLOAT) + { + return N * C1 * sizeof(float); + } - printf("Only FP32 type bounding boxes are supported.\n"); - return (size_t)-1; + printf("Only FP32 type bounding boxes are supported.\n"); + return (size_t)-1; } -size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1, DataType DT_BBOX) { - if (DT_BBOX 
== DataType::kFLOAT) { - return shareLocation ? 0 : N * C1 * sizeof(float); - } - printf("Only FP32 type bounding boxes are supported.\n"); - return (size_t)-1; +size_t detectionForwardBBoxPermuteSize(bool shareLocation, int N, int C1, DataType DT_BBOX) +{ + if (DT_BBOX == DataType::kFLOAT) + { + return shareLocation ? 0 : N * C1 * sizeof(float); + } + printf("Only FP32 type bounding boxes are supported.\n"); + return (size_t)-1; } -size_t detectionForwardPreNMSSize(int N, int C2) { - ASSERT(sizeof(float) == sizeof(int)); - return N * C2 * sizeof(float); +size_t detectionForwardPreNMSSize(int N, int C2) +{ + ASSERT(sizeof(float) == sizeof(int)); + return N * C2 * sizeof(float); } -size_t detectionForwardPostNMSSize(int N, int numClasses, int topK) { - ASSERT(sizeof(float) == sizeof(int)); - return N * numClasses * topK * sizeof(float); +size_t detectionForwardPostNMSSize(int N, int numClasses, int topK) +{ + ASSERT(sizeof(float) == sizeof(int)); + return N * numClasses * topK * sizeof(float); } -size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1, int C2, int numClasses, - int numPredsPerClass, int topK, DataType DT_BBOX, - DataType DT_SCORE) { - size_t wss[7]; - wss[0] = detectionForwardBBoxDataSize(N, C1, DT_BBOX); - wss[1] = detectionForwardBBoxPermuteSize(shareLocation, N, C1, DT_BBOX); - wss[2] = detectionForwardPreNMSSize(N, C2); - wss[3] = detectionForwardPreNMSSize(N, C2); - wss[4] = detectionForwardPostNMSSize(N, numClasses, topK); - wss[5] = detectionForwardPostNMSSize(N, numClasses, topK); - wss[6] = std::max(sortScoresPerClassWorkspaceSize(N, numClasses, numPredsPerClass, DT_SCORE), - sortScoresPerImageWorkspaceSize(N, numClasses * topK, DT_SCORE)); - return calculateTotalWorkspaceSize(wss, 7); +size_t detectionInferenceWorkspaceSize(bool shareLocation, int N, int C1, int C2, int numClasses, int numPredsPerClass, int topK, DataType DT_BBOX, DataType DT_SCORE) +{ + size_t wss[7]; + wss[0] = detectionForwardBBoxDataSize(N, C1, DT_BBOX); + wss[1] = detectionForwardBBoxPermuteSize(shareLocation, N, C1, DT_BBOX); + wss[2] = detectionForwardPreNMSSize(N, C2); + wss[3] = detectionForwardPreNMSSize(N, C2); + wss[4] = detectionForwardPostNMSSize(N, numClasses, topK); + wss[5] = detectionForwardPostNMSSize(N, numClasses, topK); + wss[6] = std::max(sortScoresPerClassWorkspaceSize(N, numClasses, numPredsPerClass, DT_SCORE), + sortScoresPerImageWorkspaceSize(N, numClasses * topK, DT_SCORE)); + return calculateTotalWorkspaceSize(wss, 7); } diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/permuteData.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/permuteData.cu index 659c964970..23600a3ce8 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/permuteData.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/permuteData.cu @@ -5,72 +5,80 @@ #include "nms/kernel.h" -template +template __launch_bounds__(nthds_per_cta) __global__ - void permuteData_kernel(const int nthreads, const int num_classes, const int num_data, - const int num_dim, bool confSigmoid, const Dtype *data, - Dtype *new_data) { - // data format: [batch_size, num_data, num_classes, num_dim] - for (int index = blockIdx.x * nthds_per_cta + threadIdx.x; index < nthreads; - index += nthds_per_cta * gridDim.x) { - const int i = index % num_dim; - const int c = (index / num_dim) % num_classes; - const int d = (index / num_dim / num_classes) % num_data; - const int n = index / num_dim / num_classes / num_data; - const int new_index = ((n * num_classes + c) * num_data + d) * 
num_dim + i; - float result = data[index]; - if (confSigmoid) result = exp(result) / (1 + exp(result)); + void permuteData_kernel(const int nthreads, const int num_classes, const int num_data, const int num_dim, bool confSigmoid, const Dtype* data, Dtype* new_data) +{ + // data format: [batch_size, num_data, num_classes, num_dim] + for (int index = blockIdx.x * nthds_per_cta + threadIdx.x; index < nthreads; + index += nthds_per_cta * gridDim.x) + { + const int i = index % num_dim; + const int c = (index / num_dim) % num_classes; + const int d = (index / num_dim / num_classes) % num_data; + const int n = index / num_dim / num_classes / num_data; + const int new_index = ((n * num_classes + c) * num_data + d) * num_dim + i; + float result = data[index]; + if (confSigmoid) result = exp(result) / (1 + exp(result)); - new_data[new_index] = result; - } - // new data format: [batch_size, num_classes, num_data, num_dim] + new_data[new_index] = result; + } + // new data format: [batch_size, num_classes, num_data, num_dim] } -template -pluginStatus_t permuteData_gpu(cudaStream_t stream, const int nthreads, const int num_classes, - const int num_data, const int num_dim, bool confSigmoid, - const void *data, void *new_data) { - const int BS = 512; - const int GS = (nthreads + BS - 1) / BS; - permuteData_kernel<<>>(nthreads, num_classes, num_data, num_dim, - confSigmoid, (const Dtype *)data, - (Dtype *)new_data); - CSC(cudaGetLastError(), STATUS_FAILURE); - return STATUS_SUCCESS; +template +pluginStatus_t permuteData_gpu(cudaStream_t stream, const int nthreads, const int num_classes, const int num_data, const int num_dim, bool confSigmoid, const void* data, void* new_data) +{ + const int BS = 512; + const int GS = (nthreads + BS - 1) / BS; + permuteData_kernel<<>>(nthreads, num_classes, num_data, num_dim, confSigmoid, (const Dtype*)data, (Dtype*)new_data); + CSC(cudaGetLastError(), STATUS_FAILURE); + return STATUS_SUCCESS; } // permuteData LAUNCH CONFIG -typedef pluginStatus_t (*pdFunc)(cudaStream_t, const int, const int, const int, const int, bool, - const void *, void *); +typedef pluginStatus_t (*pdFunc)(cudaStream_t, const int, const int, const int, const int, bool, const void*, void*); -struct pdLaunchConfig { - DataType t_data; - pdFunc function; +struct pdLaunchConfig +{ + DataType t_data; + pdFunc function; - pdLaunchConfig(DataType t_data) : t_data(t_data) {} - pdLaunchConfig(DataType t_data, pdFunc function) : t_data(t_data), function(function) {} - bool operator==(const pdLaunchConfig &other) { return t_data == other.t_data; } + pdLaunchConfig(DataType t_data) + : t_data(t_data) + { + } + pdLaunchConfig(DataType t_data, pdFunc function) + : t_data(t_data) + , function(function) + { + } + bool operator==(const pdLaunchConfig& other) + { + return t_data == other.t_data; + } }; static std::vector pdFuncVec; -bool permuteDataInit() { - pdFuncVec.push_back(pdLaunchConfig(DataType::kFLOAT, permuteData_gpu)); - return true; +bool permuteDataInit() +{ + pdFuncVec.push_back(pdLaunchConfig(DataType::kFLOAT, permuteData_gpu)); + return true; } -static bool initialized = permuteDataInit(); +static bool initialized = permuteDataInit(); -pluginStatus_t permuteData(cudaStream_t stream, const int nthreads, const int num_classes, - const int num_data, const int num_dim, const DataType DT_DATA, - bool confSigmoid, const void *data, void *new_data) { - pdLaunchConfig lc = pdLaunchConfig(DT_DATA); - for (unsigned i = 0; i < pdFuncVec.size(); ++i) { - if (lc == pdFuncVec[i]) { - DEBUG_PRINTF("permuteData kernel 
-      return pdFuncVec[i].function(stream, nthreads, num_classes, num_data, num_dim, confSigmoid,
-                                   data, new_data);
+pluginStatus_t permuteData(cudaStream_t stream, const int nthreads, const int num_classes, const int num_data, const int num_dim, const DataType DT_DATA, bool confSigmoid, const void* data, void* new_data)
+{
+    pdLaunchConfig lc = pdLaunchConfig(DT_DATA);
+    for (unsigned i = 0; i < pdFuncVec.size(); ++i)
+    {
+        if (lc == pdFuncVec[i])
+        {
+            DEBUG_PRINTF("permuteData kernel %d\n", i);
+            return pdFuncVec[i].function(stream, nthreads, num_classes, num_data, num_dim, confSigmoid, data, new_data);
+        }
     }
-  }
-  return STATUS_BAD_PARAM;
+    return STATUS_BAD_PARAM;
 }
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerClass.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerClass.cu
index e72f040cc9..284974e801 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerClass.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerClass.cu
@@ -8,134 +8,166 @@
 #include "nms/kernel.h"
 #include "trt_plugin_helper.hpp"
 
-template <typename T_SCORE, unsigned nthds_per_cta>
+template<typename T_SCORE, unsigned nthds_per_cta>
 __launch_bounds__(nthds_per_cta) __global__
-    void prepareSortData(const int num, const int num_classes, const int num_preds_per_class,
-                         const int background_label_id, const float confidence_threshold,
-                         T_SCORE *conf_scores_gpu, T_SCORE *temp_scores, int *temp_idx,
-                         int *d_offsets) {
-  // Prepare scores data for sort
-  const int cur_idx = blockIdx.x * nthds_per_cta + threadIdx.x;
-  const int numPredsPerBatch = num_classes * num_preds_per_class;
-  if (cur_idx < numPredsPerBatch) {
-    const int class_idx = cur_idx / num_preds_per_class;
-    for (int i = 0; i < num; i++) {
-      const int targetIdx = i * numPredsPerBatch + cur_idx;
-      const T_SCORE score = conf_scores_gpu[targetIdx];
+    void prepareSortData(const int num, const int num_classes, const int num_preds_per_class, const int background_label_id, const float confidence_threshold, T_SCORE* conf_scores_gpu, T_SCORE* temp_scores, int* temp_idx, int* d_offsets)
+{
+    // Prepare scores data for sort
+    const int cur_idx = blockIdx.x * nthds_per_cta + threadIdx.x;
+    const int numPredsPerBatch = num_classes * num_preds_per_class;
+    if (cur_idx < numPredsPerBatch)
+    {
+        const int class_idx = cur_idx / num_preds_per_class;
+        for (int i = 0; i < num; i++)
+        {
+            const int targetIdx = i * numPredsPerBatch + cur_idx;
+            const T_SCORE score = conf_scores_gpu[targetIdx];
 
-      // "Clear" background labeled score and index
-      // Because we do not care about background
-      if (class_idx == background_label_id) {
-        // Set scores to 0
-        // Set label = -1
-        temp_scores[targetIdx] = 0.0f;
-        temp_idx[targetIdx] = -1;
-        conf_scores_gpu[targetIdx] = 0.0f;
-      }
-      // "Clear" scores lower than threshold
-      else {
-        if (score > confidence_threshold) {
-          temp_scores[targetIdx] = score;
-          temp_idx[targetIdx] = cur_idx + i * numPredsPerBatch;
-        } else {
-          // Set scores to 0
-          // Set label = -1
-          temp_scores[targetIdx] = 0.0f;
-          temp_idx[targetIdx] = -1;
-          conf_scores_gpu[targetIdx] = 0.0f;
-          // TODO: HERE writing memory too many times
-        }
-      }
+            // "Clear" background labeled score and index
+            // Because we do not care about background
+            if (class_idx == background_label_id)
+            {
+                // Set scores to 0
+                // Set label = -1
+                temp_scores[targetIdx] = 0.0f;
+                temp_idx[targetIdx] = -1;
+                conf_scores_gpu[targetIdx] = 0.0f;
+            }
+            // "Clear" scores lower than threshold
+            else
+            {
+                if (score > confidence_threshold)
+                {
+                    temp_scores[targetIdx] = score;
+                    temp_idx[targetIdx] = cur_idx + i * numPredsPerBatch;
+                }
+                else
+                {
+                    // Set scores to 0
+                    // Set label = -1
+                    temp_scores[targetIdx] = 0.0f;
+                    temp_idx[targetIdx] = -1;
+                    conf_scores_gpu[targetIdx] = 0.0f;
+                    // TODO: HERE writing memory too many times
+                }
+            }
 
-      if ((cur_idx % num_preds_per_class) == 0) {
-        const int offset_ct = i * num_classes + cur_idx / num_preds_per_class;
-        d_offsets[offset_ct] = offset_ct * num_preds_per_class;
-        // set the last element in d_offset
-        if (blockIdx.x == 0 && threadIdx.x == 0)
-          d_offsets[num * num_classes] = num * numPredsPerBatch;
-      }
+            if ((cur_idx % num_preds_per_class) == 0)
+            {
+                const int offset_ct = i * num_classes + cur_idx / num_preds_per_class;
+                d_offsets[offset_ct] = offset_ct * num_preds_per_class;
+                // set the last element in d_offset
+                if (blockIdx.x == 0 && threadIdx.x == 0)
+                    d_offsets[num * num_classes] = num * numPredsPerBatch;
+            }
+        }
     }
-  }
 }
 
-template <typename T_SCORE>
-pluginStatus_t sortScoresPerClass_gpu(cudaStream_t stream, const int num, const int num_classes,
-                                      const int num_preds_per_class, const int background_label_id,
-                                      const float confidence_threshold, void *conf_scores_gpu,
-                                      void *index_array_gpu, void *workspace) {
-  const int num_segments = num * num_classes;
-  void *temp_scores = workspace;
-  const int arrayLen = num * num_classes * num_preds_per_class;
-  void *temp_idx = nextWorkspacePtr((int8_t *)temp_scores, arrayLen * sizeof(T_SCORE));
-  void *d_offsets = nextWorkspacePtr((int8_t *)temp_idx, arrayLen * sizeof(int));
-  size_t cubOffsetSize = (num_segments + 1) * sizeof(int);
-  void *cubWorkspace = nextWorkspacePtr((int8_t *)d_offsets, cubOffsetSize);
+template<typename T_SCORE>
+pluginStatus_t sortScoresPerClass_gpu(cudaStream_t stream, const int num, const int num_classes, const int num_preds_per_class, const int background_label_id, const float confidence_threshold, void* conf_scores_gpu, void* index_array_gpu, void* workspace)
+{
+    const int num_segments = num * num_classes;
+    void* temp_scores = workspace;
+    const int arrayLen = num * num_classes * num_preds_per_class;
+    void* temp_idx = nextWorkspacePtr((int8_t*)temp_scores, arrayLen * sizeof(T_SCORE));
+    void* d_offsets = nextWorkspacePtr((int8_t*)temp_idx, arrayLen * sizeof(int));
+    size_t cubOffsetSize = (num_segments + 1) * sizeof(int);
+    void* cubWorkspace = nextWorkspacePtr((int8_t*)d_offsets, cubOffsetSize);
 
-  const int BS = 512;
-  const int GS = (num_classes * num_preds_per_class + BS - 1) / BS;
-  prepareSortData<T_SCORE, BS><<<GS, BS, 0, stream>>>(
-      num, num_classes, num_preds_per_class, background_label_id, confidence_threshold,
-      (T_SCORE *)conf_scores_gpu, (T_SCORE *)temp_scores, (int *)temp_idx, (int *)d_offsets);
+    const int BS = 512;
+    const int GS = (num_classes * num_preds_per_class + BS - 1) / BS;
+    prepareSortData<T_SCORE, BS><<<GS, BS, 0, stream>>>(
+        num,
+        num_classes,
+        num_preds_per_class,
+        background_label_id,
+        confidence_threshold,
+        (T_SCORE*)conf_scores_gpu,
+        (T_SCORE*)temp_scores,
+        (int*)temp_idx,
+        (int*)d_offsets);
 
-  size_t temp_storage_bytes = cubSortPairsWorkspaceSize<T_SCORE, int>(arrayLen, num_segments);
-  cub::DeviceSegmentedRadixSort::SortPairsDescending(
-      cubWorkspace, temp_storage_bytes, (const T_SCORE *)(temp_scores),
-      (T_SCORE *)(conf_scores_gpu), (const int *)(temp_idx), (int *)(index_array_gpu), arrayLen,
-      num_segments, (const int *)d_offsets, (const int *)d_offsets + 1, 0, sizeof(T_SCORE) * 8,
-      stream);
-  CSC(cudaGetLastError(), STATUS_FAILURE);
-  return STATUS_SUCCESS;
+    size_t temp_storage_bytes = cubSortPairsWorkspaceSize<T_SCORE, int>(arrayLen, num_segments);
+    cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        cubWorkspace,
+        temp_storage_bytes,
+        (const T_SCORE*)(temp_scores),
+        (T_SCORE*)(conf_scores_gpu),
+        (const int*)(temp_idx),
+        (int*)(index_array_gpu),
+        arrayLen,
+        num_segments,
+        (const int*)d_offsets,
+        (const int*)d_offsets + 1,
+        0,
+        sizeof(T_SCORE) * 8,
+        stream);
+    CSC(cudaGetLastError(), STATUS_FAILURE);
+    return STATUS_SUCCESS;
 }
 
 // sortScoresPerClass LAUNCH CONFIG
-typedef pluginStatus_t (*sspcFunc)(cudaStream_t, const int, const int, const int, const int,
-                                   const float, void *, void *, void *);
-struct sspcLaunchConfig {
-  DataType t_score;
-  sspcFunc function;
+typedef pluginStatus_t (*sspcFunc)(cudaStream_t, const int, const int, const int, const int, const float, void*, void*, void*);
+struct sspcLaunchConfig
+{
+    DataType t_score;
+    sspcFunc function;
 
-  sspcLaunchConfig(DataType t_score) : t_score(t_score) {}
-  sspcLaunchConfig(DataType t_score, sspcFunc function) : t_score(t_score), function(function) {}
-  bool operator==(const sspcLaunchConfig &other) { return t_score == other.t_score; }
+    sspcLaunchConfig(DataType t_score)
+        : t_score(t_score)
+    {
+    }
+    sspcLaunchConfig(DataType t_score, sspcFunc function)
+        : t_score(t_score)
+        , function(function)
+    {
+    }
+    bool operator==(const sspcLaunchConfig& other)
+    {
+        return t_score == other.t_score;
+    }
 };
 
 static std::vector<sspcLaunchConfig> sspcFuncVec;
-bool sspcInit() {
-  sspcFuncVec.push_back(sspcLaunchConfig(DataType::kFLOAT, sortScoresPerClass_gpu<float>));
-  return true;
+bool sspcInit()
+{
+    sspcFuncVec.push_back(sspcLaunchConfig(DataType::kFLOAT, sortScoresPerClass_gpu<float>));
+    return true;
 }
-static bool initialized = sspcInit();
+static bool initialized = sspcInit();
 
-pluginStatus_t sortScoresPerClass(cudaStream_t stream, const int num, const int num_classes,
-                                  const int num_preds_per_class, const int background_label_id,
-                                  const float confidence_threshold, const DataType DT_SCORE,
-                                  void *conf_scores_gpu, void *index_array_gpu, void *workspace) {
-  sspcLaunchConfig lc = sspcLaunchConfig(DT_SCORE);
-  for (unsigned i = 0; i < sspcFuncVec.size(); ++i) {
-    if (lc == sspcFuncVec[i]) {
-      DEBUG_PRINTF("sortScoresPerClass kernel %d\n", i);
-      return sspcFuncVec[i].function(stream, num, num_classes, num_preds_per_class,
-                                     background_label_id, confidence_threshold, conf_scores_gpu,
-                                     index_array_gpu, workspace);
+pluginStatus_t sortScoresPerClass(cudaStream_t stream, const int num, const int num_classes, const int num_preds_per_class, const int background_label_id, const float confidence_threshold, const DataType DT_SCORE, void* conf_scores_gpu, void* index_array_gpu, void* workspace)
+{
+    sspcLaunchConfig lc = sspcLaunchConfig(DT_SCORE);
+    for (unsigned i = 0; i < sspcFuncVec.size(); ++i)
+    {
+        if (lc == sspcFuncVec[i])
+        {
+            DEBUG_PRINTF("sortScoresPerClass kernel %d\n", i);
+            return sspcFuncVec[i].function(stream, num, num_classes, num_preds_per_class, background_label_id, confidence_threshold, conf_scores_gpu, index_array_gpu, workspace);
+        }
     }
-  }
-  return STATUS_BAD_PARAM;
+    return STATUS_BAD_PARAM;
 }
 
-size_t sortScoresPerClassWorkspaceSize(const int num, const int num_classes,
-                                       const int num_preds_per_class, const DataType DT_CONF) {
-  size_t wss[4];
-  const int arrayLen = num * num_classes * num_preds_per_class;
-  wss[0] = arrayLen * mmdeploy::getElementSize(DT_CONF);  // temp scores
-  wss[1] = arrayLen * sizeof(int);                        // temp indices
-  wss[2] = (num * num_classes + 1) * sizeof(int);         // offsets
-  if (DT_CONF == DataType::kFLOAT) {
-    wss[3] = cubSortPairsWorkspaceSize<float, int>(arrayLen, num * num_classes);  // cub workspace
-  } else {
-    printf("SCORE type not supported\n");
-    return (size_t)-1;
-  }
+size_t sortScoresPerClassWorkspaceSize(const int num, const int num_classes, const int num_preds_per_class, const DataType DT_CONF)
+{
+    size_t wss[4];
+    const int arrayLen = num * num_classes * num_preds_per_class;
+    wss[0] = arrayLen * mmdeploy::getElementSize(DT_CONF);  // temp scores
+    wss[1] = arrayLen * sizeof(int);                        // temp indices
+    wss[2] = (num * num_classes + 1) * sizeof(int);         // offsets
+    if (DT_CONF == DataType::kFLOAT)
+    {
+        wss[3] = cubSortPairsWorkspaceSize<float, int>(arrayLen, num * num_classes);  // cub workspace
+    }
+    else
+    {
+        printf("SCORE type not supported\n");
+        return (size_t)-1;
+    }
 
-  return calculateTotalWorkspaceSize(wss, 4);
+    return calculateTotalWorkspaceSize(wss, 4);
 }
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerImage.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerImage.cu
index a6ad70262d..2a940b691a 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerImage.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/nms/sortScoresPerImage.cu
@@ -7,75 +7,94 @@
 #include "nms/cub_helper.h"
 #include "nms/kernel.h"
 
-template <typename T_SCORE>
-pluginStatus_t sortScoresPerImage_gpu(cudaStream_t stream, const int num_images,
-                                      const int num_items_per_image, void *unsorted_scores,
-                                      void *unsorted_bbox_indices, void *sorted_scores,
-                                      void *sorted_bbox_indices, void *workspace) {
-  void *d_offsets = workspace;
-  void *cubWorkspace = nextWorkspacePtr((int8_t *)d_offsets, (num_images + 1) * sizeof(int));
+template<typename T_SCORE>
+pluginStatus_t sortScoresPerImage_gpu(cudaStream_t stream, const int num_images, const int num_items_per_image, void* unsorted_scores, void* unsorted_bbox_indices, void* sorted_scores, void* sorted_bbox_indices, void* workspace)
+{
+    void* d_offsets = workspace;
+    void* cubWorkspace = nextWorkspacePtr((int8_t*)d_offsets, (num_images + 1) * sizeof(int));
 
-  setUniformOffsets(stream, num_images, num_items_per_image, (int *)d_offsets);
+    setUniformOffsets(stream, num_images, num_items_per_image, (int*)d_offsets);
 
-  const int arrayLen = num_images * num_items_per_image;
-  size_t temp_storage_bytes = cubSortPairsWorkspaceSize<T_SCORE, int>(arrayLen, num_images);
-  cub::DeviceSegmentedRadixSort::SortPairsDescending(
-      cubWorkspace, temp_storage_bytes, (const T_SCORE *)(unsorted_scores),
-      (T_SCORE *)(sorted_scores), (const int *)(unsorted_bbox_indices),
-      (int *)(sorted_bbox_indices), arrayLen, num_images, (const int *)d_offsets,
-      (const int *)d_offsets + 1, 0, sizeof(T_SCORE) * 8, stream);
-  CSC(cudaGetLastError(), STATUS_FAILURE);
-  return STATUS_SUCCESS;
+    const int arrayLen = num_images * num_items_per_image;
+    size_t temp_storage_bytes = cubSortPairsWorkspaceSize<T_SCORE, int>(arrayLen, num_images);
+    cub::DeviceSegmentedRadixSort::SortPairsDescending(
+        cubWorkspace,
+        temp_storage_bytes,
+        (const T_SCORE*)(unsorted_scores),
+        (T_SCORE*)(sorted_scores),
+        (const int*)(unsorted_bbox_indices),
+        (int*)(sorted_bbox_indices),
+        arrayLen,
+        num_images,
+        (const int*)d_offsets,
+        (const int*)d_offsets + 1,
+        0,
+        sizeof(T_SCORE) * 8,
+        stream);
+    CSC(cudaGetLastError(), STATUS_FAILURE);
+    return STATUS_SUCCESS;
 }
 
 // sortScoresPerImage LAUNCH CONFIG
-typedef pluginStatus_t (*sspiFunc)(cudaStream_t, const int, const int, void *, void *, void *,
-                                   void *, void *);
-struct sspiLaunchConfig {
-  DataType t_score;
-  sspiFunc function;
+typedef pluginStatus_t (*sspiFunc)(cudaStream_t, const int, const int, void*, void*, void*, void*, void*);
+struct sspiLaunchConfig
+{
+    DataType t_score;
+    sspiFunc function;
 
-  sspiLaunchConfig(DataType t_score) : t_score(t_score) {}
-  sspiLaunchConfig(DataType t_score, sspiFunc function) : t_score(t_score), function(function) {}
-  bool operator==(const sspiLaunchConfig &other) { return t_score == other.t_score; }
+    sspiLaunchConfig(DataType t_score)
+        : t_score(t_score)
+    {
+    }
+    sspiLaunchConfig(DataType t_score, sspiFunc function)
+        : t_score(t_score)
+        , function(function)
+    {
+    }
+    bool operator==(const sspiLaunchConfig& other)
+    {
+        return t_score == other.t_score;
+    }
 };
 
 static std::vector<sspiLaunchConfig> sspiFuncVec;
-bool sspiInit() {
-  sspiFuncVec.push_back(sspiLaunchConfig(DataType::kFLOAT, sortScoresPerImage_gpu<float>));
-  return true;
+bool sspiInit()
+{
+    sspiFuncVec.push_back(sspiLaunchConfig(DataType::kFLOAT, sortScoresPerImage_gpu<float>));
+    return true;
 }
-static bool initialized = sspiInit();
+static bool initialized = sspiInit();
 
-pluginStatus_t sortScoresPerImage(cudaStream_t stream, const int num_images,
-                                  const int num_items_per_image, const DataType DT_SCORE,
-                                  void *unsorted_scores, void *unsorted_bbox_indices,
-                                  void *sorted_scores, void *sorted_bbox_indices, void *workspace) {
-  sspiLaunchConfig lc = sspiLaunchConfig(DT_SCORE);
-  for (unsigned i = 0; i < sspiFuncVec.size(); ++i) {
-    if (lc == sspiFuncVec[i]) {
-      DEBUG_PRINTF("sortScoresPerImage kernel %d\n", i);
-      return sspiFuncVec[i].function(stream, num_images, num_items_per_image, unsorted_scores,
-                                     unsorted_bbox_indices, sorted_scores, sorted_bbox_indices,
-                                     workspace);
+pluginStatus_t sortScoresPerImage(cudaStream_t stream, const int num_images, const int num_items_per_image, const DataType DT_SCORE, void* unsorted_scores, void* unsorted_bbox_indices, void* sorted_scores, void* sorted_bbox_indices, void* workspace)
+{
+    sspiLaunchConfig lc = sspiLaunchConfig(DT_SCORE);
+    for (unsigned i = 0; i < sspiFuncVec.size(); ++i)
+    {
+        if (lc == sspiFuncVec[i])
+        {
+            DEBUG_PRINTF("sortScoresPerImage kernel %d\n", i);
+            return sspiFuncVec[i].function(stream, num_images, num_items_per_image, unsorted_scores, unsorted_bbox_indices, sorted_scores, sorted_bbox_indices, workspace);
+        }
    }
-  }
-  return STATUS_BAD_PARAM;
+    return STATUS_BAD_PARAM;
 }
 
-size_t sortScoresPerImageWorkspaceSize(const int num_images, const int num_items_per_image,
-                                       const DataType DT_SCORE) {
-  const int arrayLen = num_images * num_items_per_image;
-  size_t wss[2];
-  wss[0] = (num_images + 1) * sizeof(int);  // offsets
-  if (DT_SCORE == DataType::kFLOAT) {
-    wss[1] = cubSortPairsWorkspaceSize<float, int>(arrayLen,
-                                                   num_images);  // cub workspace
-  } else {
-    printf("SCORE type not supported.\n");
-    return (size_t)-1;
-  }
+size_t sortScoresPerImageWorkspaceSize(const int num_images, const int num_items_per_image, const DataType DT_SCORE)
+{
+    const int arrayLen = num_images * num_items_per_image;
+    size_t wss[2];
+    wss[0] = (num_images + 1) * sizeof(int);  // offsets
+    if (DT_SCORE == DataType::kFLOAT)
+    {
+        wss[1] = cubSortPairsWorkspaceSize<float, int>(arrayLen,
+                                                       num_images);  // cub workspace
+    }
+    else
+    {
+        printf("SCORE type not supported.\n");
+        return (size_t)-1;
+    }
 
-  return calculateTotalWorkspaceSize(wss, 2);
+    return calculateTotalWorkspaceSize(wss, 2);
 }
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu
index 47e8ae8615..67fa9d7961 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/common_impl/trt_cuda_helper.cu
@@ -4,92 +4,98 @@
 
 using mmdeploy::TensorDesc;
 
-template <typename scalar_t>
-__global__ void copy_permute_kernel(scalar_t *__restrict__ dst, const scalar_t *__restrict__ src,
-                                    int n, TensorDesc ts_src_stride, TensorDesc ts_dst_stride,
-                                    TensorDesc ts_permute) {
-  const int src_dim = ts_src_stride.dim;
-  const auto src_stride = ts_src_stride.stride;
-  const auto dst_stride = ts_dst_stride.stride;
-  const auto permute = ts_permute.shape;
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    size_t dst_index = index;
-    size_t src_index = 0;
-    for (int i = 0; i < src_dim; ++i) {
-      int dim_index = dst_index / dst_stride[i];
-      dst_index = dst_index % dst_stride[i];
-      src_index += dim_index * src_stride[permute[i]];
+template<typename scalar_t>
+__global__ void copy_permute_kernel(scalar_t* __restrict__ dst, const scalar_t* __restrict__ src, int n, TensorDesc ts_src_stride, TensorDesc ts_dst_stride, TensorDesc ts_permute)
+{
+    const int src_dim = ts_src_stride.dim;
+    const auto src_stride = ts_src_stride.stride;
+    const auto dst_stride = ts_dst_stride.stride;
+    const auto permute = ts_permute.shape;
+    CUDA_1D_KERNEL_LOOP(index, n)
+    {
+        size_t dst_index = index;
+        size_t src_index = 0;
+        for (int i = 0; i < src_dim; ++i)
+        {
+            int dim_index = dst_index / dst_stride[i];
+            dst_index = dst_index % dst_stride[i];
+            src_index += dim_index * src_stride[permute[i]];
+        }
+        dst[index] = src[src_index];
     }
-    dst[index] = src[src_index];
-  }
 }
 
-template <typename scalar_t>
-void memcpyPermute(scalar_t *dst, const scalar_t *src, int *src_size, int *permute, int src_dim,
-                   cudaStream_t stream) {
-  size_t copy_size = 1;
-  TensorDesc ts_permute;
-  memcpy(&(ts_permute.shape[0]), permute, src_dim * sizeof(int));
+template<typename scalar_t>
+void memcpyPermute(scalar_t* dst, const scalar_t* src, int* src_size, int* permute, int src_dim, cudaStream_t stream)
+{
+    size_t copy_size = 1;
+    TensorDesc ts_permute;
+    memcpy(&(ts_permute.shape[0]), permute, src_dim * sizeof(int));
 
-  TensorDesc ts_src_stride;
-  TensorDesc ts_dst_stride;
-  ts_src_stride.dim = src_dim;
-  ts_dst_stride.dim = src_dim;
-  int *src_stride = &(ts_src_stride.stride[0]);
-  int *dst_stride = &(ts_dst_stride.stride[0]);
-  int *dst_size = &(ts_dst_stride.shape[0]);
-  src_stride[src_dim - 1] = 1;
-  dst_stride[src_dim - 1] = 1;
+    TensorDesc ts_src_stride;
+    TensorDesc ts_dst_stride;
+    ts_src_stride.dim = src_dim;
+    ts_dst_stride.dim = src_dim;
+    int* src_stride = &(ts_src_stride.stride[0]);
+    int* dst_stride = &(ts_dst_stride.stride[0]);
+    int* dst_size = &(ts_dst_stride.shape[0]);
+    src_stride[src_dim - 1] = 1;
+    dst_stride[src_dim - 1] = 1;
 
-  for (int i = src_dim - 1; i >= 0; --i) {
-    dst_size[i] = src_size[permute[i]];
-    if (i < src_dim - 1) {
-      src_stride[i] = src_stride[i + 1] * src_size[i + 1];
+    for (int i = src_dim - 1; i >= 0; --i)
+    {
+        dst_size[i] = src_size[permute[i]];
+        if (i < src_dim - 1)
+        {
+            src_stride[i] = src_stride[i + 1] * src_size[i + 1];
+        }
    }
-  }
 
-  for (int i = src_dim - 1; i >= 0; --i) {
-    copy_size *= dst_size[i];
-    if (i < src_dim - 1) {
-      dst_stride[i] = dst_stride[i + 1] * dst_size[i + 1];
+    for (int i = src_dim - 1; i >= 0; --i)
+    {
+        copy_size *= dst_size[i];
+        if (i < src_dim - 1)
+        {
+            dst_stride[i] = dst_stride[i + 1] * dst_size[i + 1];
+        }
    }
-  }
 
-  copy_permute_kernel<scalar_t><<<GET_BLOCKS(copy_size), THREADS_PER_BLOCK, 0, stream>>>(
-      dst, src, copy_size, ts_src_stride, ts_dst_stride, ts_permute);
+    copy_permute_kernel<scalar_t><<<GET_BLOCKS(copy_size), THREADS_PER_BLOCK, 0, stream>>>(
+        dst,
+        src,
+        copy_size,
+        ts_src_stride,
+        ts_dst_stride,
+        ts_permute);
 }
 
-template void memcpyPermute<float>(float *dst, const float *src, int *src_size, int *permute,
-                                   int src_dim, cudaStream_t stream);
-template void memcpyPermute<half>(half *dst, const half *src, int *src_size, int *permute,
-                                  int src_dim, cudaStream_t stream);
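[Editor's aside, not part of the patch: the cubSortPairsWorkspaceSize()/SortPairsDescending() pairing in the NMS files above follows CUB's standard two-phase contract: a first call with a null temp-storage pointer only reports the required byte count, and a second call with real storage does the work. The helper below is a minimal, self-contained sketch of that pattern; segmentedSortDescending and its parameter names are invented for illustration, while the cub:: and CUDA runtime calls are the real API.]

#include <cub/cub.cuh>

// Sorts each segment of (score, index) pairs in descending score order.
// d_offsets holds num_segments + 1 boundary entries, as in the code above.
cudaError_t segmentedSortDescending(cudaStream_t stream,
                                    const float* keys_in, float* keys_out,
                                    const int* vals_in, int* vals_out,
                                    int num_items, int num_segments,
                                    const int* d_offsets)
{
    size_t temp_bytes = 0;
    // Phase 1: null workspace pointer => CUB only computes temp_bytes.
    cub::DeviceSegmentedRadixSort::SortPairsDescending(
        nullptr, temp_bytes, keys_in, keys_out, vals_in, vals_out,
        num_items, num_segments, d_offsets, d_offsets + 1,
        0, int(sizeof(float)) * 8, stream);

    void* d_temp = nullptr;
    cudaError_t err = cudaMalloc(&d_temp, temp_bytes);
    if (err != cudaSuccess) return err;

    // Phase 2: the same call with a real workspace performs the sort.
    err = cub::DeviceSegmentedRadixSort::SortPairsDescending(
        d_temp, temp_bytes, keys_in, keys_out, vals_in, vals_out,
        num_items, num_segments, d_offsets, d_offsets + 1,
        0, int(sizeof(float)) * 8, stream);
    cudaFree(d_temp);
    return err;
}

[The plugins above avoid the cudaMalloc by pre-computing the same size on the host and carving it out of the TensorRT-provided workspace.]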
+template void memcpyPermute<float>(float* dst, const float* src, int* src_size, int* permute, int src_dim, cudaStream_t stream);
+template void memcpyPermute<half>(half* dst, const half* src, int* src_size, int* permute, int src_dim, cudaStream_t stream);
 
-cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype, cudnnDataType_t *cudnn_dtype) {
-  switch (trt_dtype) {
-    case nvinfer1::DataType::kFLOAT:
-      *cudnn_dtype = CUDNN_DATA_FLOAT;
-      break;
-    case nvinfer1::DataType::kHALF:
-      *cudnn_dtype = CUDNN_DATA_HALF;
-      break;
-    default:
-      return CUDNN_STATUS_BAD_PARAM;
-  }
-  return CUDNN_STATUS_SUCCESS;
+cudnnStatus_t convert_trt2cudnn_dtype(nvinfer1::DataType trt_dtype, cudnnDataType_t* cudnn_dtype)
+{
+    switch (trt_dtype)
+    {
+        case nvinfer1::DataType::kFLOAT:
+            *cudnn_dtype = CUDNN_DATA_FLOAT;
+            break;
+        case nvinfer1::DataType::kHALF:
+            *cudnn_dtype = CUDNN_DATA_HALF;
+            break;
+        default:
+            return CUDNN_STATUS_BAD_PARAM;
+    }
+    return CUDNN_STATUS_SUCCESS;
 }
 
-template <>
-cublasStatus_t cublasGemmWrap<float>(cublasHandle_t handle, cublasOperation_t transa,
-                                     cublasOperation_t transb, int m, int n, int k,
-                                     const float *alpha, const float *A, int lda, const float *B,
-                                     int ldb, const float *beta, float *C, int ldc) {
-  return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+template<>
+cublasStatus_t cublasGemmWrap<float>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float* alpha, const float* A, int lda, const float* B, int ldb, const float* beta, float* C, int ldc)
+{
+    return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
 }
 
-template <>
-cublasStatus_t cublasGemmWrap<half>(cublasHandle_t handle, cublasOperation_t transa,
-                                    cublasOperation_t transb, int m, int n, int k,
-                                    const half *alpha, const half *A, int lda, const half *B,
-                                    int ldb, const half *beta, half *C, int ldc) {
-  return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
+template<>
+cublasStatus_t cublasGemmWrap<half>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const half* alpha, const half* A, int lda, const half* B, int ldb, const half* beta, half* C, int ldc)
+{
+    return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
 }
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp
index 0d518323d2..b833a7e19a 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.cpp
@@ -10,254 +10,302 @@
 
 using namespace nvinfer1;
 
-namespace mmdeploy {
-namespace {
-static const char *PLUGIN_VERSION{"1"};
-static const char *PLUGIN_NAME{"MMCVDeformConv2d"};
-}  // namespace
-
-DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string &name,
-                                                         const nvinfer1::Dims stride,
-                                                         const nvinfer1::Dims padding,
-                                                         const nvinfer1::Dims dilation,
-                                                         const int deformableGroup, const int group)
-    : TRTPluginBase(name),
-      mStride(stride),
-      mPadding(padding),
-      mDilation(dilation),
-      mDeformableGroup(deformableGroup),
-      mGroup(group) {}
-
-DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name, const void *data,
-                                                         size_t length)
-    : TRTPluginBase(name) {
-  deserialize_value(&data, &length, &mStride);
-  deserialize_value(&data, &length, &mPadding);
-  deserialize_value(&data, &length, &mDilation);
-  deserialize_value(&data,
&length, &mDeformableGroup); - deserialize_value(&data, &length, &mGroup); -} -DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {} - -nvinfer1::IPluginV2DynamicExt *DeformableConvPluginDynamic::clone() const TRT_NOEXCEPT { - DeformableConvPluginDynamic *plugin = new DeformableConvPluginDynamic( - mLayerName, mStride, mPadding, mDilation, mDeformableGroup, mGroup); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs DeformableConvPluginDynamic::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - // input[0] == input - // input[1] == offset - // input[2] == weight - nvinfer1::DimsExprs ret; - ret.nbDims = 4; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[2].d[0]; - - ret.d[2] = inputs[1].d[2]; - ret.d[3] = inputs[1].d[3]; - - return ret; -} - -bool DeformableConvPluginDynamic::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 0) { - return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || - ioDesc[pos].type == nvinfer1::DataType::kHALF) && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - } else { - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - } -} - -void DeformableConvPluginDynamic::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT {} - -size_t DeformableConvPluginDynamic::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - int sizeof_dtype = mmdeploy::getElementSize(outputs[0].type); - - int batch_size = inputs[0].dims.d[0]; - int nInputPlane = inputs[0].dims.d[1]; - int inputHeight = inputs[0].dims.d[2]; - int inputWidth = inputs[0].dims.d[3]; - - int nOutputPlane = outputs[0].dims.d[1]; - int outputHeight = outputs[0].dims.d[2]; - int outputWidth = outputs[0].dims.d[3]; - - int kW = inputs[2].dims.d[2]; - int kH = inputs[2].dims.d[3]; - int im2col_step = std::min(32, batch_size); - - size_t col_size = mmdeploy::getAlignedSize(nInputPlane * kW * kH * im2col_step * outputHeight * - outputWidth * sizeof_dtype); - - size_t out_size = 0; - if (im2col_step != 1) - out_size = mmdeploy::getAlignedSize(batch_size * nOutputPlane * outputHeight * outputWidth * - sizeof_dtype); - - return col_size + out_size; -} - -int DeformableConvPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, - void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - int batch = inputDesc[0].dims.d[0]; - int channels = inputDesc[0].dims.d[1]; - int height = inputDesc[0].dims.d[2]; - int width = inputDesc[0].dims.d[3]; - int channels_out = outputDesc[0].dims.d[1]; - int kernel_h = inputDesc[2].dims.d[2]; - int kernel_w = inputDesc[2].dims.d[3]; - - const void *x = inputs[0]; - const void *offset = inputs[1]; - const void *weight = inputs[2]; - void *output = outputs[0]; - int im2col_step = std::min(batch, 32); - - auto data_type = inputDesc[0].type; - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - deform_conv((float *)x, (float *)weight, (float *)offset, (float *)output, workSpace, - batch, channels, height, width, channels_out, kernel_w, kernel_h, - mStride.d[0], mStride.d[1], mPadding.d[0], 
mPadding.d[1], mDilation.d[0], - mDilation.d[1], mGroup, mDeformableGroup, im2col_step, m_cublas_handle, - stream); - break; - case nvinfer1::DataType::kHALF: - deform_conv((half *)x, (half *)weight, (half *)offset, (half *)output, workSpace, batch, - channels, height, width, channels_out, kernel_w, kernel_h, mStride.d[0], - mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], - mGroup, mDeformableGroup, im2col_step, m_cublas_handle, stream); - break; - default: - return 1; - } - - return 0; -} - -nvinfer1::DataType DeformableConvPluginDynamic::getOutputDataType( - int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *DeformableConvPluginDynamic::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *DeformableConvPluginDynamic::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -int DeformableConvPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t DeformableConvPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mStride) + serialized_size(mPadding) + serialized_size(mDilation) + - serialized_size(mDeformableGroup) + serialized_size(mGroup); -} - -void DeformableConvPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mStride); - serialize_value(&buffer, mPadding); - serialize_value(&buffer, mDilation); - serialize_value(&buffer, mDeformableGroup); - serialize_value(&buffer, mGroup); -} - -void DeformableConvPluginDynamic::attachToContext( - cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT { - m_cublas_handle = cublasContext; -} - -void DeformableConvPluginDynamic::detachFromContext() TRT_NOEXCEPT {} - -////////////////////// creator ///////////////////////////// - -DeformableConvPluginDynamicCreator::DeformableConvPluginDynamicCreator() { - mPluginAttributes.clear(); - mPluginAttributes.emplace_back(nvinfer1::PluginField("stride")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("padding")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("groups")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups")); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *DeformableConvPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *DeformableConvPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - nvinfer1::Dims stride{2, {1, 1}}; - nvinfer1::Dims padding{2, {0, 0}}; - nvinfer1::Dims dilation{2, {1, 1}}; - int deformableGroup = 1; - int group = 1; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"MMCVDeformConv2d"}; + } // namespace + + DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string& name, + const nvinfer1::Dims stride, + const nvinfer1::Dims padding, + const nvinfer1::Dims dilation, + const int deformableGroup, + const int group) + : TRTPluginBase(name) + , mStride(stride) + , mPadding(padding) + , mDilation(dilation) + , 
mDeformableGroup(deformableGroup) + , mGroup(group) + { } - std::string field_name(fc->fields[i].name); - if (field_name.compare("deform_groups") == 0) { - deformableGroup = static_cast(fc->fields[i].data)[0]; + DeformableConvPluginDynamic::DeformableConvPluginDynamic(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mStride); + deserialize_value(&data, &length, &mPadding); + deserialize_value(&data, &length, &mDilation); + deserialize_value(&data, &length, &mDeformableGroup); + deserialize_value(&data, &length, &mGroup); + } + DeformableConvPluginDynamic::~DeformableConvPluginDynamic() {} + + nvinfer1::IPluginV2DynamicExt* DeformableConvPluginDynamic::clone() const TRT_NOEXCEPT + { + DeformableConvPluginDynamic* plugin = new DeformableConvPluginDynamic( + mLayerName, + mStride, + mPadding, + mDilation, + mDeformableGroup, + mGroup); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs DeformableConvPluginDynamic::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + // input[0] == input + // input[1] == offset + // input[2] == weight + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[2].d[0]; + + ret.d[2] = inputs[1].d[2]; + ret.d[3] = inputs[1].d[3]; + + return ret; + } + + bool DeformableConvPluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos == 0) + { + return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || + ioDesc[pos].type == nvinfer1::DataType::kHALF) && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } + else + { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } + } + + void DeformableConvPluginDynamic::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT {} + + size_t DeformableConvPluginDynamic::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + int sizeof_dtype = mmdeploy::getElementSize(outputs[0].type); + + int batch_size = inputs[0].dims.d[0]; + int nInputPlane = inputs[0].dims.d[1]; + int inputHeight = inputs[0].dims.d[2]; + int inputWidth = inputs[0].dims.d[3]; + + int nOutputPlane = outputs[0].dims.d[1]; + int outputHeight = outputs[0].dims.d[2]; + int outputWidth = outputs[0].dims.d[3]; + + int kW = inputs[2].dims.d[2]; + int kH = inputs[2].dims.d[3]; + int im2col_step = std::min(32, batch_size); + + size_t col_size = mmdeploy::getAlignedSize(nInputPlane * kW * kH * im2col_step * outputHeight * + outputWidth * sizeof_dtype); + + size_t out_size = 0; + if (im2col_step != 1) + out_size = mmdeploy::getAlignedSize(batch_size * nOutputPlane * outputHeight * outputWidth * + sizeof_dtype); + + return col_size + out_size; + } + + int DeformableConvPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int batch = inputDesc[0].dims.d[0]; + int channels = inputDesc[0].dims.d[1]; + int height = inputDesc[0].dims.d[2]; + int width = inputDesc[0].dims.d[3]; + int channels_out = 
outputDesc[0].dims.d[1]; + int kernel_h = inputDesc[2].dims.d[2]; + int kernel_w = inputDesc[2].dims.d[3]; + + const void* x = inputs[0]; + const void* offset = inputs[1]; + const void* weight = inputs[2]; + void* output = outputs[0]; + int im2col_step = std::min(batch, 32); + + auto data_type = inputDesc[0].type; + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + deform_conv((float*)x, (float*)weight, (float*)offset, (float*)output, workSpace, batch, channels, height, width, channels_out, kernel_w, kernel_h, mStride.d[0], mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], mGroup, mDeformableGroup, im2col_step, m_cublas_handle, stream); + break; + case nvinfer1::DataType::kHALF: + deform_conv((half*)x, (half*)weight, (half*)offset, (half*)output, workSpace, batch, channels, height, width, channels_out, kernel_w, kernel_h, mStride.d[0], mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], mGroup, mDeformableGroup, im2col_step, m_cublas_handle, stream); + break; + default: + return 1; + } + + return 0; + } + + nvinfer1::DataType DeformableConvPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* DeformableConvPluginDynamic::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* DeformableConvPluginDynamic::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int DeformableConvPluginDynamic::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t DeformableConvPluginDynamic::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mStride) + serialized_size(mPadding) + serialized_size(mDilation) + + serialized_size(mDeformableGroup) + serialized_size(mGroup); + } + + void DeformableConvPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mStride); + serialize_value(&buffer, mPadding); + serialize_value(&buffer, mDilation); + serialize_value(&buffer, mDeformableGroup); + serialize_value(&buffer, mGroup); + } + + void DeformableConvPluginDynamic::attachToContext( + cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT + { + m_cublas_handle = cublasContext; + } + + void DeformableConvPluginDynamic::detachFromContext() TRT_NOEXCEPT {} + + ////////////////////// creator ///////////////////////////// + + DeformableConvPluginDynamicCreator::DeformableConvPluginDynamicCreator() + { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("padding")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("groups")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); } - if (field_name.compare("groups") == 0) { - group = static_cast(fc->fields[i].data)[0]; + const char* DeformableConvPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; } - if (field_name.compare("stride") == 0) { - stride.nbDims = 2; - stride.d[0] = static_cast(fc->fields[i].data)[0]; - stride.d[1] = static_cast(fc->fields[i].data)[1]; + const char* DeformableConvPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; } - if 
(field_name.compare("padding") == 0) { - padding.nbDims = 2; - padding.d[0] = static_cast(fc->fields[i].data)[0]; - padding.d[1] = static_cast(fc->fields[i].data)[1]; + nvinfer1::IPluginV2* DeformableConvPluginDynamicCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + nvinfer1::Dims stride{2, {1, 1}}; + nvinfer1::Dims padding{2, {0, 0}}; + nvinfer1::Dims dilation{2, {1, 1}}; + int deformableGroup = 1; + int group = 1; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("deform_groups") == 0) + { + deformableGroup = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("groups") == 0) + { + group = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("stride") == 0) + { + stride.nbDims = 2; + stride.d[0] = static_cast(fc->fields[i].data)[0]; + stride.d[1] = static_cast(fc->fields[i].data)[1]; + } + + if (field_name.compare("padding") == 0) + { + padding.nbDims = 2; + padding.d[0] = static_cast(fc->fields[i].data)[0]; + padding.d[1] = static_cast(fc->fields[i].data)[1]; + } + + if (field_name.compare("dilation") == 0) + { + dilation.nbDims = 2; + dilation.d[0] = static_cast(fc->fields[i].data)[0]; + dilation.d[1] = static_cast(fc->fields[i].data)[1]; + } + } + + DeformableConvPluginDynamic* plugin = + new DeformableConvPluginDynamic(name, stride, padding, dilation, deformableGroup, group); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; } - if (field_name.compare("dilation") == 0) { - dilation.nbDims = 2; - dilation.d[0] = static_cast(fc->fields[i].data)[0]; - dilation.d[1] = static_cast(fc->fields[i].data)[1]; + nvinfer1::IPluginV2* DeformableConvPluginDynamicCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new DeformableConvPluginDynamic(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; } - } - - DeformableConvPluginDynamic *plugin = - new DeformableConvPluginDynamic(name, stride, padding, dilation, deformableGroup, group); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *DeformableConvPluginDynamicCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new DeformableConvPluginDynamic(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} -REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator); + REGISTER_TENSORRT_PLUGIN(DeformableConvPluginDynamicCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp index 3ea0ccbefe..6d3b4f936c 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv.hpp @@ -9,73 +9,68 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class DeformableConvPluginDynamic : public TRTPluginBase { - public: - DeformableConvPluginDynamic(const std::string &name, const nvinfer1::Dims stride, - const nvinfer1::Dims padding, const nvinfer1::Dims dilation, - const int deformableGroup, const int group); - - DeformableConvPluginDynamic(const std::string name, const void *data, size_t length); - - DeformableConvPluginDynamic() = delete; - 
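[Editor's aside, not part of the patch: the creator registered above with REGISTER_TENSORRT_PLUGIN is normally recovered through TensorRT's global plugin registry. The sketch below is illustrative only; makeDeformConvPlugin and its arguments are hypothetical, while getPluginRegistry(), getPluginCreator() and createPlugin() are the real TensorRT API, and the name/version strings and field names mirror the code above.]

#include <NvInfer.h>
#include <vector>

nvinfer1::IPluginV2* makeDeformConvPlugin(int groups, int deformGroups)
{
    using namespace nvinfer1;
    // Look up the creator that REGISTER_TENSORRT_PLUGIN added to the registry.
    IPluginCreator* creator = getPluginRegistry()->getPluginCreator("MMCVDeformConv2d", "1");
    if (creator == nullptr) return nullptr;

    // Only two of the five fields are supplied here; createPlugin() above
    // falls back to its defaults for stride, padding and dilation.
    std::vector<PluginField> fields;
    fields.emplace_back("groups", &groups, PluginFieldType::kINT32, 1);
    fields.emplace_back("deform_groups", &deformGroups, PluginFieldType::kINT32, 1);

    PluginFieldCollection fc{};
    fc.nbFields = static_cast<int>(fields.size());
    fc.fields = fields.data();
    return creator->createPlugin("deform_conv_0", &fc);
}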
- ~DeformableConvPluginDynamic() TRT_NOEXCEPT override; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override; - void detachFromContext() TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; - - private: - nvinfer1::Dims mStride; - nvinfer1::Dims mPadding; - nvinfer1::Dims mDilation; - int mDeformableGroup; - int mGroup; - - cublasHandle_t m_cublas_handle; -}; - -class DeformableConvPluginDynamicCreator : public TRTPluginCreatorBase { - public: - DeformableConvPluginDynamicCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class DeformableConvPluginDynamic : public TRTPluginBase + { + public: + DeformableConvPluginDynamic(const std::string& name, const nvinfer1::Dims stride, const nvinfer1::Dims padding, const nvinfer1::Dims dilation, const int deformableGroup, const int group); + + DeformableConvPluginDynamic(const std::string name, const void* data, size_t length); + + DeformableConvPluginDynamic() = delete; + + ~DeformableConvPluginDynamic() TRT_NOEXCEPT override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const 
nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override; + void detachFromContext() TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + + private: + nvinfer1::Dims mStride; + nvinfer1::Dims mPadding; + nvinfer1::Dims mDilation; + int mDeformableGroup; + int mGroup; + + cublasHandle_t m_cublas_handle; + }; + + class DeformableConvPluginDynamicCreator : public TRTPluginCreatorBase + { + public: + DeformableConvPluginDynamicCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_DEFORM_CONV_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu index 3f401fc9e2..8fe86280af 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cu @@ -68,105 +68,107 @@ #include "trt_deform_conv_kernel.hpp" #include "trt_plugin_helper.hpp" -template -void deform_conv_im2col(const scalar_t* input, const scalar_t* offset, scalar_t* column, - const int channels, const int height, const int width, const int ksize_h, - const int ksize_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, const int dilation_h, const int dilation_w, - const int parallel_imgs, const int deformable_group, cudaStream_t stream) { - int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; - int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; - int num_kernels = channels * height_col * width_col * parallel_imgs; - int channel_per_deformable_group = channels / deformable_group; - - deformable_im2col_gpu_kernel<<>>( - num_kernels, input, offset, height, width, ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, channel_per_deformable_group, parallel_imgs, channels, - deformable_group, height_col, width_col, column); - - cudaCheckError(); +template +void deform_conv_im2col(const scalar_t* input, const scalar_t* offset, scalar_t* column, const int channels, const int height, const int width, const int ksize_h, const int ksize_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int parallel_imgs, const int deformable_group, cudaStream_t stream) +{ + 
   int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+    int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+    int num_kernels = channels * height_col * width_col * parallel_imgs;
+    int channel_per_deformable_group = channels / deformable_group;
+
+    deformable_im2col_gpu_kernel<scalar_t><<<GET_BLOCKS(num_kernels), THREADS_PER_BLOCK, 0, stream>>>(
+        num_kernels,
+        input,
+        offset,
+        height,
+        width,
+        ksize_h,
+        ksize_w,
+        pad_h,
+        pad_w,
+        stride_h,
+        stride_w,
+        dilation_h,
+        dilation_w,
+        channel_per_deformable_group,
+        parallel_imgs,
+        channels,
+        deformable_group,
+        height_col,
+        width_col,
+        column);
+
+    cudaCheckError();
 }
 
-template <typename scalar_t>
-void deform_conv(const scalar_t* input, const scalar_t* weight, const scalar_t* offset,
-                 scalar_t* output, void* workspace, int batchSize, int nInputPlane, int inputHeight,
-                 int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW,
-                 int padH, int dilationW, int dilationH, int group, int deformable_group,
-                 int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream) {
-  size_t word_size = sizeof(scalar_t);
-
-  im2col_step = std::min(int(batchSize), im2col_step);
-  long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
-  long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
-
-  long outputHW = outputHeight * outputWidth;
-  long kHW = kH * kW;
-  long columns_size =
-      mmdeploy::getAlignedSize(nInputPlane * kHW * im2col_step * outputHW * word_size);
-
-  // column buffer for img2col
-  char* workspace_ptr = reinterpret_cast<char*>(workspace);
-  scalar_t* columns = reinterpret_cast<scalar_t*>(workspace_ptr);
-  workspace_ptr = workspace_ptr + columns_size;
-
-  scalar_t* output_buffer;
-  if (im2col_step == 1) {
-    output_buffer = output;
-  } else {
-    // output need permute when im2col_step!=1
-    output_buffer = reinterpret_cast<scalar_t*>(workspace_ptr);
-  }
-
-  long input_elt_step = im2col_step * nInputPlane * inputHeight * inputWidth;
-  long offset_elt_step = im2col_step * deformable_group * 2 * kHW * outputHW;
-  long out_buffer_step = nOutputPlane * im2col_step * outputHW;
-  long col_g_step = nInputPlane * kHW * im2col_step * outputHW / group;
-  long weight_g_step = nOutputPlane * nInputPlane * kHW / (group * group);
-  long out_buffer_g_step = out_buffer_step / group;
-  int m = nOutputPlane / group;
-  int n = im2col_step * outputHW;
-  int k = nInputPlane * kHW / group;
-  scalar_t alpha = 1.f;
-  scalar_t beta = 0.f;
-
-  for (int elt = 0; elt < batchSize / im2col_step; elt++) {
-    const scalar_t* input_start = input + elt * input_elt_step;
-    const scalar_t* offset_start = offset + elt * offset_elt_step;
-
-    deform_conv_im2col<scalar_t>(input_start, offset_start, columns, nInputPlane, inputHeight,
-                                 inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW,
-                                 im2col_step, deformable_group, stream);
-
-    for (int g = 0; g < group; ++g) {
-      const scalar_t* weight_start = weight + g * weight_g_step;
-      scalar_t* col_start = columns + g * col_g_step;
-      scalar_t* out_buffer_start = output_buffer + elt * out_buffer_step + g * out_buffer_g_step;
-
-      cublasGemmWrap<scalar_t>(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, col_start,
-                               n, weight_start, k, &beta, out_buffer_start, n);
-      cudaCheckError();
+template<typename scalar_t>
+void deform_conv(const scalar_t* input, const scalar_t* weight, const scalar_t* offset, scalar_t* output, void* workspace, int batchSize, int nInputPlane, int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, int group, int deformable_group, int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream)
+{
+    size_t word_size = sizeof(scalar_t);
+
+    im2col_step = std::min(int(batchSize), im2col_step);
+    long outputWidth = (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+    long outputHeight = (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+    long outputHW = outputHeight * outputWidth;
+    long kHW = kH * kW;
+    long columns_size =
+        mmdeploy::getAlignedSize(nInputPlane * kHW * im2col_step * outputHW * word_size);
+
+    // column buffer for img2col
+    char* workspace_ptr = reinterpret_cast<char*>(workspace);
+    scalar_t* columns = reinterpret_cast<scalar_t*>(workspace_ptr);
+    workspace_ptr = workspace_ptr + columns_size;
+
+    scalar_t* output_buffer;
+    if (im2col_step == 1)
+    {
+        output_buffer = output;
+    }
+    else
+    {
+        // output need permute when im2col_step!=1
+        output_buffer = reinterpret_cast<scalar_t*>(workspace_ptr);
+    }
+
+    long input_elt_step = im2col_step * nInputPlane * inputHeight * inputWidth;
+    long offset_elt_step = im2col_step * deformable_group * 2 * kHW * outputHW;
+    long out_buffer_step = nOutputPlane * im2col_step * outputHW;
+    long col_g_step = nInputPlane * kHW * im2col_step * outputHW / group;
+    long weight_g_step = nOutputPlane * nInputPlane * kHW / (group * group);
+    long out_buffer_g_step = out_buffer_step / group;
+    int m = nOutputPlane / group;
+    int n = im2col_step * outputHW;
+    int k = nInputPlane * kHW / group;
+    scalar_t alpha = 1.f;
+    scalar_t beta = 0.f;
+
+    for (int elt = 0; elt < batchSize / im2col_step; elt++)
+    {
+        const scalar_t* input_start = input + elt * input_elt_step;
+        const scalar_t* offset_start = offset + elt * offset_elt_step;
+
+        deform_conv_im2col<scalar_t>(input_start, offset_start, columns, nInputPlane, inputHeight, inputWidth, kH, kW, padH, padW, dH, dW, dilationH, dilationW, im2col_step, deformable_group, stream);
+
+        for (int g = 0; g < group; ++g)
+        {
+            const scalar_t* weight_start = weight + g * weight_g_step;
+            scalar_t* col_start = columns + g * col_g_step;
+            scalar_t* out_buffer_start = output_buffer + elt * out_buffer_step + g * out_buffer_g_step;
+
+            cublasGemmWrap<scalar_t>(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, col_start, n, weight_start, k, &beta, out_buffer_start, n);
+            cudaCheckError();
+        }
+    }
+
+    if (im2col_step != 1)
+    {
+        int output_buffer_shape[5] = {batchSize / im2col_step, nOutputPlane, im2col_step, static_cast<int>(outputHeight), static_cast<int>(outputWidth)};
+        int output_buffer_permute[5] = {0, 2, 1, 3, 4};
+        memcpyPermute<scalar_t>(output, output_buffer, &output_buffer_shape[0], &output_buffer_permute[0], 5, stream);
    }
-  }
-
-  if (im2col_step != 1) {
-    int output_buffer_shape[5] = {batchSize / im2col_step, nOutputPlane, im2col_step,
-                                  static_cast<int>(outputHeight), static_cast<int>(outputWidth)};
-    int output_buffer_permute[5] = {0, 2, 1, 3, 4};
-    memcpyPermute<scalar_t>(output, output_buffer, &output_buffer_shape[0],
-                            &output_buffer_permute[0], 5, stream);
-  }
 }
 
-template void deform_conv<float>(const float* input, const float* weight, const float* offset,
-                                 float* output, void* workspace, int batchSize, int nInputPlane,
-                                 int inputHeight, int inputWidth, int nOutputPlane,
int kW, int kH, - int dW, int dH, int padW, int padH, int dilationW, int dilationH, - int group, int deformable_group, int im2col_step, - cublasHandle_t cublas_handle, cudaStream_t stream); +template void deform_conv(const float* input, const float* weight, const float* offset, float* output, void* workspace, int batchSize, int nInputPlane, int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, int group, int deformable_group, int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream); + +template void deform_conv<__half>(const __half* input, const __half* weight, const __half* offset, __half* output, void* workspace, int batchSize, int nInputPlane, int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, int group, int deformable_group, int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh index c91f17ca4a..330f4b331a 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh +++ b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.cuh @@ -67,108 +67,134 @@ #include "common_cuda_helper.hpp" -template +template __device__ __forceinline__ scalar_t deformable_im2col_bilinear(const scalar_t* __restrict__ input, - const int height, const int width, - float h, float w) { - if (h <= -1 || height <= h || w <= -1 || width <= w) { - return 0; - } + const int height, + const int width, + float h, + float w) +{ + if (h <= -1 || height <= h || w <= -1 || width <= w) + { + return 0; + } - const int h_low = floorf(h); - const int w_low = floorf(w); + const int h_low = floorf(h); + const int w_low = floorf(w); - input += h_low * width; - const scalar_t v1 = (h_low >= 0 && w_low >= 0) ? input[w_low] : static_cast(0.0f); - const int w_high = w_low + 1; - const scalar_t v2 = - (h_low >= 0 && w_high <= width - 1) ? input[w_high] : static_cast(0.0f); - const scalar_t lw = w - w_low; - const scalar_t v_low = fmaf(v2 - v1, lw, v1); - input += width; - const scalar_t v3 = - (h_low <= height - 2 && w_low >= 0) ? input[w_low] : static_cast(0.0f); - const scalar_t v4 = - (h_low <= height - 2 && w_high <= width - 1) ? input[w_high] : static_cast(0.0f); - const scalar_t v_high = fmaf(v4 - v3, lw, v3); - const scalar_t lh = h - h_low; - const scalar_t val = fmaf(v_high - v_low, lh, v_low); - return val; + input += h_low * width; + const scalar_t v1 = (h_low >= 0 && w_low >= 0) ? input[w_low] : static_cast(0.0f); + const int w_high = w_low + 1; + const scalar_t v2 = + (h_low >= 0 && w_high <= width - 1) ? input[w_high] : static_cast(0.0f); + const scalar_t lw = w - w_low; + const scalar_t v_low = fmaf(v2 - v1, lw, v1); + input += width; + const scalar_t v3 = + (h_low <= height - 2 && w_low >= 0) ? input[w_low] : static_cast(0.0f); + const scalar_t v4 = + (h_low <= height - 2 && w_high <= width - 1) ? 
input[w_high] : static_cast(0.0f); + const scalar_t v_high = fmaf(v4 - v3, lw, v3); + const scalar_t lh = h - h_low; + const scalar_t val = fmaf(v_high - v_low, lh, v_low); + return val; } -template <> +template<> __device__ __forceinline__ __half deformable_im2col_bilinear(const __half* __restrict__ input, - const int height, const int width, - float h, float w) { - if (h <= -1 || height <= h || w <= -1 || width <= w) { - return 0; - } + const int height, + const int width, + float h, + float w) +{ + if (h <= -1 || height <= h || w <= -1 || width <= w) + { + return 0; + } - const int h_low = floorf(h); - const int w_low = floorf(w); + const int h_low = floorf(h); + const int w_low = floorf(w); - input += h_low * width; - const float v1 = (h_low >= 0 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f; - const int w_high = w_low + 1; - const float v2 = (h_low >= 0 && w_high <= width - 1) ? __half2float(input[w_high]) : 0.0f; - const float lw = w - w_low; - const float v_low = fmaf(v2 - v1, lw, v1); - input += width; - const float v3 = (h_low <= height - 2 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f; - const float v4 = - (h_low <= height - 2 && w_high <= width - 1) ? __half2float(input[w_high]) : 0.0f; - const float v_high = fmaf(v4 - v3, lw, v3); - const float lh = h - h_low; - const float val = fmaf(v_high - v_low, lh, v_low); - return __float2half(val); + input += h_low * width; + const float v1 = (h_low >= 0 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f; + const int w_high = w_low + 1; + const float v2 = (h_low >= 0 && w_high <= width - 1) ? __half2float(input[w_high]) : 0.0f; + const float lw = w - w_low; + const float v_low = fmaf(v2 - v1, lw, v1); + input += width; + const float v3 = (h_low <= height - 2 && w_low >= 0) ? __half2float(input[w_low]) : 0.0f; + const float v4 = + (h_low <= height - 2 && w_high <= width - 1) ? 
__half2float(input[w_high]) : 0.0f; + const float v_high = fmaf(v4 - v3, lw, v3); + const float lh = h - h_low; + const float val = fmaf(v_high - v_low, lh, v_low); + return __float2half(val); } -template +template __global__ void deformable_im2col_gpu_kernel( - const int n, const scalar_t* __restrict__ data_im, const scalar_t* __restrict__ data_offset, - const int height, const int width, const int kernel_h, const int kernel_w, const int pad_h, - const int pad_w, const int stride_h, const int stride_w, const int dilation_h, - const int dilation_w, const int channel_per_deformable_group, const int batch_size, - const int num_channels, const int deformable_group, const int height_col, const int width_col, - scalar_t* __restrict__ data_col) { - const int hw_col = height_col * width_col; - const int data_col_step = batch_size * hw_col; + const int n, + const scalar_t* __restrict__ data_im, + const scalar_t* __restrict__ data_offset, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int num_channels, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* __restrict__ data_col) +{ + const int hw_col = height_col * width_col; + const int data_col_step = batch_size * hw_col; - CUDA_1D_KERNEL_LOOP(index, n) { - // index index of output matrix - int tmp_index = index; - const int w_col = tmp_index % width_col; - tmp_index /= width_col; - const int h_col = tmp_index % height_col; - tmp_index /= height_col; - const int b_col = tmp_index % batch_size; - const int c_im = tmp_index / batch_size; - const int c_col = c_im * kernel_h * kernel_w; + CUDA_1D_KERNEL_LOOP(index, n) + { + // index index of output matrix + int tmp_index = index; + const int w_col = tmp_index % width_col; + tmp_index /= width_col; + const int h_col = tmp_index % height_col; + tmp_index /= height_col; + const int b_col = tmp_index % batch_size; + const int c_im = tmp_index / batch_size; + const int c_col = c_im * kernel_h * kernel_w; - // compute deformable group index - const int deformable_group_index = c_im / channel_per_deformable_group; + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - scalar_t* __restrict__ data_col_ptr = data_col + c_col * data_col_step + index % data_col_step; - const scalar_t* __restrict__ data_im_ptr = - data_im + (b_col * num_channels + c_im) * height * width; - const scalar_t* __restrict__ data_offset_ptr = - data_offset + - ((b_col * deformable_group + deformable_group_index) << 1) * kernel_h * kernel_w * hw_col + - h_col * width_col + w_col; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h = (i * kernel_w + j) * hw_col << 1; - const scalar_t offset_h = data_offset_ptr[data_offset_h]; - const int data_offset_w = data_offset_h + hw_col; - const scalar_t offset_w = data_offset_ptr[data_offset_w]; - const scalar_t h_im = h_in + i * dilation_h + (float)offset_h; - const scalar_t w_im = w_in + j * dilation_w + (float)offset_w; - const scalar_t val = deformable_im2col_bilinear(data_im_ptr, height, width, h_im, w_im); - *data_col_ptr = val; - data_col_ptr += data_col_step; - } + const int h_in = h_col * stride_h - pad_h; + const int w_in = 
w_col * stride_w - pad_w; + scalar_t* __restrict__ data_col_ptr = data_col + c_col * data_col_step + index % data_col_step; + const scalar_t* __restrict__ data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t* __restrict__ data_offset_ptr = + data_offset + + ((b_col * deformable_group + deformable_group_index) << 1) * kernel_h * kernel_w * hw_col + + h_col * width_col + w_col; + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h = (i * kernel_w + j) * hw_col << 1; + const scalar_t offset_h = data_offset_ptr[data_offset_h]; + const int data_offset_w = data_offset_h + hw_col; + const scalar_t offset_w = data_offset_ptr[data_offset_w]; + const scalar_t h_im = h_in + i * dilation_h + (float)offset_h; + const scalar_t w_im = w_in + j * dilation_w + (float)offset_w; + const scalar_t val = deformable_im2col_bilinear(data_im_ptr, height, width, h_im, w_im); + *data_col_ptr = val; + data_col_ptr += data_col_step; + } + } } - } } diff --git a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp index 3d8f6dfc45..35f08be1b4 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/deform_conv/trt_deform_conv_kernel.hpp @@ -4,17 +4,9 @@ #include #include -template -void deform_conv_im2col(const scalar_t* input, const scalar_t* offset, scalar_t* column, - const int channels, const int height, const int width, const int ksize_h, - const int ksize_w, const int pad_h, const int pad_w, const int stride_h, - const int stride_w, const int dilation_h, const int dilation_w, - const int parallel_imgs, const int deformable_group, cudaStream_t stream); +template +void deform_conv_im2col(const scalar_t* input, const scalar_t* offset, scalar_t* column, const int channels, const int height, const int width, const int ksize_h, const int ksize_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int parallel_imgs, const int deformable_group, cudaStream_t stream); -template -void deform_conv(const scalar_t* input, const scalar_t* weight, const scalar_t* offset, - scalar_t* output, void* workspace, int batchSize, int nInputPlane, int inputHeight, - int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW, - int padH, int dilationW, int dilationH, int group, int deformable_group, - int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream); +template +void deform_conv(const scalar_t* input, const scalar_t* weight, const scalar_t* offset, scalar_t* output, void* workspace, int batchSize, int nInputPlane, int inputHeight, int inputWidth, int nOutputPlane, int kW, int kH, int dW, int dH, int padW, int padH, int dilationW, int dilationH, int group, int deformable_group, int im2col_step, cublasHandle_t cublas_handle, cudaStream_t stream); #endif // TRT_DEFORM_CONV_KERNEL_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp index b5e6c0b677..7dd688e089 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.cpp @@ -10,141 +10,176 @@ #include "gather_topk_kernel.hpp" #include "trt_serialize.hpp" -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char 
*PLUGIN_NAME{"GatherTopk"}; -} // namespace - -GatherTopk::GatherTopk(const std::string &name) : TRTPluginBase(name) {} - -GatherTopk::GatherTopk(const std::string name, const void *data, size_t length) - : TRTPluginBase(name) {} - -nvinfer1::IPluginV2DynamicExt *GatherTopk::clone() const TRT_NOEXCEPT { - GatherTopk *plugin = new GatherTopk(mLayerName); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs GatherTopk::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - assert(inputs[0].nbDims >= inputs[1].nbDims); - nvinfer1::DimsExprs ret; - ret.nbDims = inputs[0].nbDims; - for (int i = 0; i < inputs[1].nbDims; ++i) { - ret.d[i] = inputs[1].d[i]; - } - for (int i = inputs[1].nbDims; i < inputs[0].nbDims; ++i) { - ret.d[i] = inputs[0].d[i]; - } - return ret; -} - -bool GatherTopk::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - switch (pos) { - case 0: - // data - return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) || - (ioDesc[pos].type == nvinfer1::DataType::kINT32 && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - case 1: - // indices - return ioDesc[pos].type == nvinfer1::DataType::kINT32 && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; - case 2: - // output - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - default: - return true; - } - return true; -} - -void GatherTopk::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT {} - -size_t GatherTopk::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int GatherTopk::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - const int *dims = &(inputDesc[0].dims.d[0]); - const int *indices_dims = &(inputDesc[1].dims.d[0]); - int nbDims = inputDesc[0].dims.nbDims; - int indice_nbDims = inputDesc[1].dims.nbDims; - - const void *data = inputs[0]; - const void *indices = inputs[1]; - void *output = outputs[0]; - - auto data_type = inputDesc[0].type; - - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - gather_topk_impl((float *)data, (int *)indices, dims, nbDims, indices_dims, - indice_nbDims, (float *)output, stream); - break; - - case nvinfer1::DataType::kINT32: - gather_topk_impl((int *)data, (int *)indices, dims, nbDims, indices_dims, indice_nbDims, - (int *)output, stream); - break; - default: - break; - } - - return 0; -} - -nvinfer1::DataType GatherTopk::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *GatherTopk::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *GatherTopk::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int GatherTopk::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t GatherTopk::getSerializationSize() const TRT_NOEXCEPT { return 0; } - -void GatherTopk::serialize(void *buffer) const TRT_NOEXCEPT {} - 
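For orientation while reading this plugin: GatherTopk gathers rows of a [batch, num_input, channel] tensor using per-batch top-k indices, i.e. output[b, n, c] = input[b, indices[b, n], c]. The following CPU sketch of that indexing is illustrative only (hypothetical names, contiguous row-major layout as assumed by the kernel in gather_topk_kernel.cu) and is not part of the patch:

    // CPU reference for the GatherTopk indexing:
    // output[b, n, c] = input[b, indices[b, n], c]
    template <typename T>
    void gather_topk_cpu(const T* input, const int* indices, T* output,
                         int batch, int num_input, int num_indices, int channel)
    {
        for (int b = 0; b < batch; ++b)
            for (int n = 0; n < num_indices; ++n)
            {
                const int src = indices[b * num_indices + n];  // row selected by top-k
                for (int c = 0; c < channel; ++c)
                    output[(b * num_indices + n) * channel + c] =
                        input[(b * num_input + src) * channel + c];
            }
    }

Keeping the channel loop innermost mirrors the flat 1-D thread index of the CUDA kernel, where c varies fastest and neighbouring threads touch neighbouring elements.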
-GatherTopkCreator::GatherTopkCreator() { - mPluginAttributes.clear(); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *GatherTopkCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *GatherTopkCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -nvinfer1::IPluginV2 *GatherTopkCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - auto *plugin = new GatherTopk(name); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *GatherTopkCreator::deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT { - auto plugin = new GatherTopk(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -REGISTER_TENSORRT_PLUGIN(GatherTopkCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"GatherTopk"}; + } // namespace + + GatherTopk::GatherTopk(const std::string& name) + : TRTPluginBase(name) + { + } + + GatherTopk::GatherTopk(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + } + + nvinfer1::IPluginV2DynamicExt* GatherTopk::clone() const TRT_NOEXCEPT + { + GatherTopk* plugin = new GatherTopk(mLayerName); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs GatherTopk::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + assert(inputs[0].nbDims >= inputs[1].nbDims); + nvinfer1::DimsExprs ret; + ret.nbDims = inputs[0].nbDims; + for (int i = 0; i < inputs[1].nbDims; ++i) + { + ret.d[i] = inputs[1].d[i]; + } + for (int i = inputs[1].nbDims; i < inputs[0].nbDims; ++i) + { + ret.d[i] = inputs[0].d[i]; + } + return ret; + } + + bool GatherTopk::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT + { + switch (pos) + { + case 0: + // data + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) || + (ioDesc[pos].type == nvinfer1::DataType::kINT32 && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + case 1: + // indices + return ioDesc[pos].type == nvinfer1::DataType::kINT32 && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + case 2: + // output + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + default: + return true; + } + return true; + } + + void GatherTopk::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) TRT_NOEXCEPT {} + + size_t GatherTopk::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int GatherTopk::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + const int* dims = &(inputDesc[0].dims.d[0]); + const int* indices_dims = &(inputDesc[1].dims.d[0]); + int nbDims = inputDesc[0].dims.nbDims; + int indice_nbDims = inputDesc[1].dims.nbDims; + + const void* data = inputs[0]; + const void* indices = inputs[1]; + void* output 
= outputs[0]; + + auto data_type = inputDesc[0].type; + + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + gather_topk_impl((float*)data, (int*)indices, dims, nbDims, indices_dims, indice_nbDims, (float*)output, stream); + break; + + case nvinfer1::DataType::kINT32: + gather_topk_impl((int*)data, (int*)indices, dims, nbDims, indices_dims, indice_nbDims, (int*)output, stream); + break; + default: + break; + } + + return 0; + } + + nvinfer1::DataType GatherTopk::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* GatherTopk::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* GatherTopk::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int GatherTopk::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t GatherTopk::getSerializationSize() const TRT_NOEXCEPT + { + return 0; + } + + void GatherTopk::serialize(void* buffer) const TRT_NOEXCEPT {} + + GatherTopkCreator::GatherTopkCreator() + { + mPluginAttributes.clear(); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* GatherTopkCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* GatherTopkCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* GatherTopkCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + auto* plugin = new GatherTopk(name); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* GatherTopkCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new GatherTopk(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + REGISTER_TENSORRT_PLUGIN(GatherTopkCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.hpp b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.hpp index 72f76a2678..b3db9b4058 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk.hpp @@ -9,56 +9,54 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class GatherTopk : public TRTPluginBase { - public: - GatherTopk(const std::string &name); - - GatherTopk(const std::string name, const void *data, size_t length); - - GatherTopk() = delete; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) 
TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; -}; - -class GatherTopkCreator : public TRTPluginCreatorBase { - public: - GatherTopkCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class GatherTopk : public TRTPluginBase + { + public: + GatherTopk(const std::string& name); + + GatherTopk(const std::string name, const void* data, size_t length); + + GatherTopk() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + }; + + class GatherTopkCreator : public TRTPluginCreatorBase + { + public: + GatherTopkCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_SCATTERND_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.cu index 9a5c8ec963..873876ec12 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.cu @@ -8,39 +8,34 @@ #include 
"gather_topk_kernel.hpp" #include "trt_plugin_helper.hpp" -template -__global__ void gather_topk_kernel(const scalar_t* input, const int* indices, scalar_t* output, - int batch, int num_input, int num_indices, int channel) { - CUDA_1D_KERNEL_LOOP(index, batch * num_indices * channel) { - const int b_id = index / (num_indices * channel); - const int n_id = (index / channel) % num_indices; - const int c_id = index % channel; +template +__global__ void gather_topk_kernel(const scalar_t* input, const int* indices, scalar_t* output, int batch, int num_input, int num_indices, int channel) +{ + CUDA_1D_KERNEL_LOOP(index, batch * num_indices * channel) + { + const int b_id = index / (num_indices * channel); + const int n_id = (index / channel) % num_indices; + const int c_id = index % channel; - const int input_n_id = indices[b_id * num_indices + n_id]; - const scalar_t value = input[b_id * num_input * channel + input_n_id * channel + c_id]; - output[b_id * num_indices * channel + n_id * channel + c_id] = value; - } + const int input_n_id = indices[b_id * num_indices + n_id]; + const scalar_t value = input[b_id * num_input * channel + input_n_id * channel + c_id]; + output[b_id * num_indices * channel + n_id * channel + c_id] = value; + } } -template -void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims, - const int* indices_dims, int indice_nbDims, scalar_t* output, - cudaStream_t stream) { - int batch = 1; - for (int i = 0; i < indice_nbDims - 1; ++i) batch *= dims[i]; - int num_input = dims[indice_nbDims - 1]; - int num_indices = indices_dims[indice_nbDims - 1]; - int channel = 1; - for (int i = indice_nbDims; i < nbDims; ++i) channel *= dims[i]; - const int col_block = DIVUP(batch * num_indices * channel, THREADS_PER_BLOCK); - gather_topk_kernel<<>>(input, indices, output, batch, - num_input, num_indices, channel); +template +void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims, const int* indices_dims, int indice_nbDims, scalar_t* output, cudaStream_t stream) +{ + int batch = 1; + for (int i = 0; i < indice_nbDims - 1; ++i) batch *= dims[i]; + int num_input = dims[indice_nbDims - 1]; + int num_indices = indices_dims[indice_nbDims - 1]; + int channel = 1; + for (int i = indice_nbDims; i < nbDims; ++i) channel *= dims[i]; + const int col_block = DIVUP(batch * num_indices * channel, THREADS_PER_BLOCK); + gather_topk_kernel<<>>(input, indices, output, batch, num_input, num_indices, channel); } -template void gather_topk_impl(const float* input, const int* indices, const int* dims, - int nbDims, const int* indices_dims, int indice_nbDims, - float* output, cudaStream_t stream); +template void gather_topk_impl(const float* input, const int* indices, const int* dims, int nbDims, const int* indices_dims, int indice_nbDims, float* output, cudaStream_t stream); -template void gather_topk_impl(const int32_t* input, const int* indices, const int* dims, - int nbDims, const int* indices_dims, int indice_nbDims, - int32_t* output, cudaStream_t stream); +template void gather_topk_impl(const int32_t* input, const int* indices, const int* dims, int nbDims, const int* indices_dims, int indice_nbDims, int32_t* output, cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.hpp index 1f9b428394..e5ee6b987e 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.hpp +++ 
b/csrc/mmdeploy/backend_ops/tensorrt/gather_topk/gather_topk_kernel.hpp @@ -3,8 +3,6 @@ #define TRT_GRID_SAMPLER_KERNEL_HPP #include -template -void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims, - const int* indices_dims, int indice_nbDims, scalar_t* output, - cudaStream_t stream); +template +void gather_topk_impl(const scalar_t* input, const int* indices, const int* dims, int nbDims, const int* indices_dims, int indice_nbDims, scalar_t* output, cudaStream_t stream); #endif // TRT_GRID_SAMPLER_KERNEL_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp index 1850fbfc1a..ef99b1fba6 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.cpp @@ -10,145 +10,190 @@ using namespace nvinfer1; -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"GridPriorsTRT"}; -} // namespace - -GridPriorsTRT::GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride) - : TRTPluginBase(name), mStride(stride) {} - -GridPriorsTRT::GridPriorsTRT(const std::string name, const void *data, size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mStride); -} -GridPriorsTRT::~GridPriorsTRT() {} - -nvinfer1::IPluginV2DynamicExt *GridPriorsTRT::clone() const TRT_NOEXCEPT { - GridPriorsTRT *plugin = new GridPriorsTRT(mLayerName, mStride); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs GridPriorsTRT::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - // input[0] == base_anchor - // input[1] == empty_h - // input[2] == empty_w - - nvinfer1::DimsExprs ret; - ret.nbDims = 2; - auto area = - exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[2].d[0], *inputs[1].d[0]); - ret.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *area, *(inputs[0].d[0])); - ret.d[1] = exprBuilder.constant(4); - - return ret; -} - -bool GridPriorsTRT::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 0) { - return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - } else if (pos - nbInputs == 0) { - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - } else { - return true; - } -} - -int GridPriorsTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - int num_base_anchors = inputDesc[0].dims.d[0]; - int feat_h = inputDesc[1].dims.d[0]; - int feat_w = inputDesc[2].dims.d[0]; - - const void *base_anchor = inputs[0]; - void *output = outputs[0]; - - auto data_type = inputDesc[0].type; - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - trt_grid_priors_impl((float *)base_anchor, (float *)output, num_base_anchors, feat_w, - feat_h, mStride.d[0], mStride.d[1], stream); - break; - default: - return 1; - } - - return 0; -} - -nvinfer1::DataType GridPriorsTRT::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char 
*GridPriorsTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *GridPriorsTRT::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int GridPriorsTRT::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t GridPriorsTRT::getSerializationSize() const TRT_NOEXCEPT { return serialized_size(mStride); } - -void GridPriorsTRT::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mStride); - ; -} - -////////////////////// creator ///////////////////////////// - -GridPriorsTRTCreator::GridPriorsTRTCreator() { - mPluginAttributes.clear(); - mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_h")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_w")); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *GridPriorsTRTCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *GridPriorsTRTCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -nvinfer1::IPluginV2 *GridPriorsTRTCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - int stride_w = 1; - int stride_h = 1; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; - } - std::string field_name(fc->fields[i].name); - - if (field_name.compare("stride_w") == 0) { - stride_w = static_cast(fc->fields[i].data)[0]; - } - if (field_name.compare("stride_h") == 0) { - stride_h = static_cast(fc->fields[i].data)[0]; - } - } - nvinfer1::Dims stride{2, {stride_w, stride_h}}; - - GridPriorsTRT *plugin = new GridPriorsTRT(name, stride); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *GridPriorsTRTCreator::deserializePlugin(const char *name, - const void *serialData, - size_t serialLength) TRT_NOEXCEPT { - auto plugin = new GridPriorsTRT(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} -REGISTER_TENSORRT_PLUGIN(GridPriorsTRTCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"GridPriorsTRT"}; + } // namespace + + GridPriorsTRT::GridPriorsTRT(const std::string& name, const nvinfer1::Dims stride) + : TRTPluginBase(name) + , mStride(stride) + { + } + + GridPriorsTRT::GridPriorsTRT(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mStride); + } + GridPriorsTRT::~GridPriorsTRT() {} + + nvinfer1::IPluginV2DynamicExt* GridPriorsTRT::clone() const TRT_NOEXCEPT + { + GridPriorsTRT* plugin = new GridPriorsTRT(mLayerName, mStride); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs GridPriorsTRT::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + // input[0] == base_anchor + // input[1] == empty_h + // input[2] == empty_w + + nvinfer1::DimsExprs ret; + ret.nbDims = 2; + auto area = + exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *inputs[2].d[0], *inputs[1].d[0]); + ret.d[0] = exprBuilder.operation(nvinfer1::DimensionOperation::kPROD, *area, *(inputs[0].d[0])); + ret.d[1] = exprBuilder.constant(4); + + return ret; + } + + bool GridPriorsTRT::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT + { + if (pos == 0) + { + return 
(ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } + else if (pos - nbInputs == 0) + { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } + else + { + return true; + } + } + + int GridPriorsTRT::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int num_base_anchors = inputDesc[0].dims.d[0]; + int feat_h = inputDesc[1].dims.d[0]; + int feat_w = inputDesc[2].dims.d[0]; + + const void* base_anchor = inputs[0]; + void* output = outputs[0]; + + auto data_type = inputDesc[0].type; + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + trt_grid_priors_impl((float*)base_anchor, (float*)output, num_base_anchors, feat_w, feat_h, mStride.d[0], mStride.d[1], stream); + break; + default: + return 1; + } + + return 0; + } + + nvinfer1::DataType GridPriorsTRT::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* GridPriorsTRT::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* GridPriorsTRT::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int GridPriorsTRT::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t GridPriorsTRT::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mStride); + } + + void GridPriorsTRT::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mStride); + ; + } + + ////////////////////// creator ///////////////////////////// + + GridPriorsTRTCreator::GridPriorsTRTCreator() + { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_h")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride_w")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* GridPriorsTRTCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* GridPriorsTRTCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* GridPriorsTRTCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + int stride_w = 1; + int stride_h = 1; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("stride_w") == 0) + { + stride_w = static_cast(fc->fields[i].data)[0]; + } + if (field_name.compare("stride_h") == 0) + { + stride_h = static_cast(fc->fields[i].data)[0]; + } + } + nvinfer1::Dims stride{2, {stride_w, stride_h}}; + + GridPriorsTRT* plugin = new GridPriorsTRT(name, stride); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* GridPriorsTRTCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new GridPriorsTRT(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + REGISTER_TENSORRT_PLUGIN(GridPriorsTRTCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp index 0036f62586..a555b2d54a 100644 --- 
a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors.hpp @@ -9,58 +9,60 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class GridPriorsTRT : public TRTPluginBase { - public: - GridPriorsTRT(const std::string &name, const nvinfer1::Dims stride); +namespace mmdeploy +{ + class GridPriorsTRT : public TRTPluginBase + { + public: + GridPriorsTRT(const std::string& name, const nvinfer1::Dims stride); - GridPriorsTRT(const std::string name, const void *data, size_t length); + GridPriorsTRT(const std::string name, const void* data, size_t length); - GridPriorsTRT() = delete; + GridPriorsTRT() = delete; - ~GridPriorsTRT() TRT_NOEXCEPT override; + ~GridPriorsTRT() TRT_NOEXCEPT override; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; - private: - nvinfer1::Dims mStride; + private: + nvinfer1::Dims mStride; - cublasHandle_t m_cublas_handle; -}; + cublasHandle_t m_cublas_handle; + }; -class GridPriorsTRTCreator : public TRTPluginCreatorBase { - public: - GridPriorsTRTCreator(); + class GridPriorsTRTCreator : public TRTPluginCreatorBase + { + public: + GridPriorsTRTCreator(); - const char *getPluginName() const TRT_NOEXCEPT override; + const char* getPluginName() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() 
const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_GRID_PRIORS_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu index 72c33d179a..9decc3ba6e 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.cu @@ -5,39 +5,42 @@ #include "trt_grid_priors_kernel.hpp" #include "trt_plugin_helper.hpp" -template -__global__ void trt_grid_priors_kernel(const scalar_t* base_anchor, scalar_t* output, - int num_base_anchors, int feat_w, int feat_h, int stride_w, - int stride_h) { - // load base anchor into shared memory. - extern __shared__ scalar_t shared_base_anchor[]; - for (int i = threadIdx.x; i < num_base_anchors * 4; i += blockDim.x) { - shared_base_anchor[i] = base_anchor[i]; - } - __syncthreads(); +template +__global__ void trt_grid_priors_kernel(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors, int feat_w, int feat_h, int stride_w, int stride_h) +{ + // load base anchor into shared memory. + extern __shared__ scalar_t shared_base_anchor[]; + for (int i = threadIdx.x; i < num_base_anchors * 4; i += blockDim.x) + { + shared_base_anchor[i] = base_anchor[i]; + } + __syncthreads(); - CUDA_1D_KERNEL_LOOP(index, num_base_anchors * feat_w * feat_h) { - const int a_offset = (index % num_base_anchors) << 2; - const scalar_t w = scalar_t(((index / num_base_anchors) % feat_w) * stride_w); - const scalar_t h = scalar_t((index / (feat_w * num_base_anchors)) * stride_h); + CUDA_1D_KERNEL_LOOP(index, num_base_anchors * feat_w * feat_h) + { + const int a_offset = (index % num_base_anchors) << 2; + const scalar_t w = scalar_t(((index / num_base_anchors) % feat_w) * stride_w); + const scalar_t h = scalar_t((index / (feat_w * num_base_anchors)) * stride_h); - auto out_start = output + index * 4; - out_start[0] = shared_base_anchor[a_offset] + w; - out_start[1] = shared_base_anchor[a_offset + 1] + h; - out_start[2] = shared_base_anchor[a_offset + 2] + w; - out_start[3] = shared_base_anchor[a_offset + 3] + h; - } + auto out_start = output + index * 4; + out_start[0] = shared_base_anchor[a_offset] + w; + out_start[1] = shared_base_anchor[a_offset + 1] + h; + out_start[2] = shared_base_anchor[a_offset + 2] + w; + out_start[3] = shared_base_anchor[a_offset + 3] + h; + } } -template -void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors, - int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream) { - trt_grid_priors_kernel<<>>( - base_anchor, output, (int)num_base_anchors, (int)feat_w, (int)feat_h, (int)stride_w, - (int)stride_h); +template +void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors, int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream) +{ + trt_grid_priors_kernel<<>>( + base_anchor, + output, + (int)num_base_anchors, + (int)feat_w, + (int)feat_h, + 
(int)stride_w, + (int)stride_h); } -template void trt_grid_priors_impl(const float* base_anchor, float* output, - int num_base_anchors, int feat_w, int feat_h, - int stride_w, int stride_h, cudaStream_t stream); +template void trt_grid_priors_impl(const float* base_anchor, float* output, int num_base_anchors, int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp index 77cef58c54..e050eb1047 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_priors/trt_grid_priors_kernel.hpp @@ -3,8 +3,7 @@ #define TRT_GRID_PRIORS_KERNEL_HPP #include -template -void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors, - int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream); +template +void trt_grid_priors_impl(const scalar_t* base_anchor, scalar_t* output, int num_base_anchors, int feat_w, int feat_h, int stride_w, int stride_h, cudaStream_t stream); #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.cpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.cpp index 7e55686902..0d7ebf32da 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.cpp @@ -9,194 +9,237 @@ #include "trt_plugin_helper.hpp" #include "trt_serialize.hpp" -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"grid_sampler"}; -} // namespace - -TRTGridSampler::TRTGridSampler(const std::string &name, int mode, int paddingMode, - bool alignCorners) - : TRTPluginBase(name), mMode(mode), mPaddingMode(paddingMode), mAlignCorners(alignCorners) {} - -TRTGridSampler::TRTGridSampler(const std::string name, const void *data, size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mMode); - deserialize_value(&data, &length, &mPaddingMode); - deserialize_value(&data, &length, &mAlignCorners); -} - -nvinfer1::IPluginV2DynamicExt *TRTGridSampler::clone() const TRT_NOEXCEPT { - TRTGridSampler *plugin = new TRTGridSampler(mLayerName, mMode, mPaddingMode, mAlignCorners); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs TRTGridSampler::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - nvinfer1::DimsExprs ret; - ret.nbDims = inputs[0].nbDims; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[0].d[1]; - for (int i = 2; i < ret.nbDims; ++i) { - ret.d[i] = inputs[1].d[i - 1]; - } - return ret; -} - -bool TRTGridSampler::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 0) { - return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - } else { - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - } -} - -void TRTGridSampler::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT { - // Validate input arguments -} - -size_t TRTGridSampler::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int 
nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int TRTGridSampler::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - nvinfer1::Dims input_dims = inputDesc[0].dims; - nvinfer1::Dims grid_dims = inputDesc[1].dims; - nvinfer1::Dims output_dims = outputDesc[0].dims; - - GridSamplerInterpolation interp_mode = GridSamplerInterpolation::Bilinear; - switch (mMode) { - case 0: - interp_mode = GridSamplerInterpolation::Bilinear; - break; - case 1: - interp_mode = GridSamplerInterpolation::Nearest; - break; - default: - break; - } - - GridSamplerPadding padding_mode = GridSamplerPadding::Zeros; - switch (mPaddingMode) { - case 0: - padding_mode = GridSamplerPadding::Zeros; - break; - - case 1: - padding_mode = GridSamplerPadding::Border; - break; - - case 2: - padding_mode = GridSamplerPadding::Reflection; - break; - default: - break; - } - - auto data_type = inputDesc[0].type; - - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - grid_sample((float *)outputs[0], (float *)inputs[0], (float *)inputs[1], - &(output_dims.d[0]), &(input_dims.d[0]), &(grid_dims.d[0]), - input_dims.nbDims, interp_mode, padding_mode, mAlignCorners, stream); - break; - default: - return 1; - break; - } - - return 0; -} - -nvinfer1::DataType TRTGridSampler::getOutputDataType(int index, - const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *TRTGridSampler::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTGridSampler::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int TRTGridSampler::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t TRTGridSampler::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mMode) + serialized_size(mPaddingMode) + serialized_size(mAlignCorners); -} - -void TRTGridSampler::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mMode); - serialize_value(&buffer, mPaddingMode); - serialize_value(&buffer, mAlignCorners); -} - -////////////////////// creator ///////////////////////////// - -TRTGridSamplerCreator::TRTGridSamplerCreator() { - mPluginAttributes = std::vector( - {nvinfer1::PluginField("interpolation_mode"), nvinfer1::PluginField("padding_mode"), - nvinfer1::PluginField("align_corners")}); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *TRTGridSamplerCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTGridSamplerCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -nvinfer1::IPluginV2 *TRTGridSamplerCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - int mode = 0; - int paddingMode = 0; - bool alignCorners = false; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; - } - std::string field_name(fc->fields[i].name); - - if (field_name.compare("interpolation_mode") == 0) { - mode = static_cast(fc->fields[i].data)[0]; - } - - if (field_name.compare("padding_mode") == 0) { - paddingMode = static_cast(fc->fields[i].data)[0]; - } - - if (field_name.compare("align_corners") == 0) { - alignCorners = (bool)(static_cast(fc->fields[i].data)[0]); - } - } - - TRTGridSampler *plugin = new 
TRTGridSampler(name, mode, paddingMode, alignCorners); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"grid_sampler"}; + } // namespace + + TRTGridSampler::TRTGridSampler(const std::string& name, int mode, int paddingMode, bool alignCorners) + : TRTPluginBase(name) + , mMode(mode) + , mPaddingMode(paddingMode) + , mAlignCorners(alignCorners) + { + } + + TRTGridSampler::TRTGridSampler(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mMode); + deserialize_value(&data, &length, &mPaddingMode); + deserialize_value(&data, &length, &mAlignCorners); + } + + nvinfer1::IPluginV2DynamicExt* TRTGridSampler::clone() const TRT_NOEXCEPT + { + TRTGridSampler* plugin = new TRTGridSampler(mLayerName, mMode, mPaddingMode, mAlignCorners); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs TRTGridSampler::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + nvinfer1::DimsExprs ret; + ret.nbDims = inputs[0].nbDims; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + for (int i = 2; i < ret.nbDims; ++i) + { + ret.d[i] = inputs[1].d[i - 1]; + } + return ret; + } + + bool TRTGridSampler::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT + { + if (pos == 0) + { + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } + else + { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } + } + + void TRTGridSampler::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) TRT_NOEXCEPT + { + // Validate input arguments + } + + size_t TRTGridSampler::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int TRTGridSampler::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + nvinfer1::Dims input_dims = inputDesc[0].dims; + nvinfer1::Dims grid_dims = inputDesc[1].dims; + nvinfer1::Dims output_dims = outputDesc[0].dims; + + GridSamplerInterpolation interp_mode = GridSamplerInterpolation::Bilinear; + switch (mMode) + { + case 0: + interp_mode = GridSamplerInterpolation::Bilinear; + break; + case 1: + interp_mode = GridSamplerInterpolation::Nearest; + break; + default: + break; + } + + GridSamplerPadding padding_mode = GridSamplerPadding::Zeros; + switch (mPaddingMode) + { + case 0: + padding_mode = GridSamplerPadding::Zeros; + break; + + case 1: + padding_mode = GridSamplerPadding::Border; + break; + + case 2: + padding_mode = GridSamplerPadding::Reflection; + break; + default: + break; + } + + auto data_type = inputDesc[0].type; + + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + grid_sample((float*)outputs[0], (float*)inputs[0], (float*)inputs[1], &(output_dims.d[0]), &(input_dims.d[0]), &(grid_dims.d[0]), input_dims.nbDims, interp_mode, padding_mode, mAlignCorners, stream); + break; + default: + return 1; + break; + } + + 
return 0; + } + + nvinfer1::DataType TRTGridSampler::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* TRTGridSampler::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTGridSampler::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int TRTGridSampler::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t TRTGridSampler::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mMode) + serialized_size(mPaddingMode) + serialized_size(mAlignCorners); + } + + void TRTGridSampler::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mMode); + serialize_value(&buffer, mPaddingMode); + serialize_value(&buffer, mAlignCorners); + } + + ////////////////////// creator ///////////////////////////// -nvinfer1::IPluginV2 *TRTGridSamplerCreator::deserializePlugin(const char *name, - const void *serialData, - size_t serialLength) TRT_NOEXCEPT { - // This object will be deleted when the network is destroyed, which will - // call FCPluginDynamic::destroy() - auto plugin = new TRTGridSampler(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} + TRTGridSamplerCreator::TRTGridSamplerCreator() + { + mPluginAttributes = std::vector( + {nvinfer1::PluginField("interpolation_mode"), nvinfer1::PluginField("padding_mode"), nvinfer1::PluginField("align_corners")}); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTGridSamplerCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTGridSamplerCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* TRTGridSamplerCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + int mode = 0; + int paddingMode = 0; + bool alignCorners = false; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("interpolation_mode") == 0) + { + mode = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("padding_mode") == 0) + { + paddingMode = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("align_corners") == 0) + { + alignCorners = (bool)(static_cast(fc->fields[i].data)[0]); + } + } + + TRTGridSampler* plugin = new TRTGridSampler(name, mode, paddingMode, alignCorners); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* TRTGridSamplerCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + // This object will be deleted when the network is destroyed, which will + // call FCPluginDynamic::destroy() + auto plugin = new TRTGridSampler(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } -REGISTER_TENSORRT_PLUGIN(TRTGridSamplerCreator); + REGISTER_TENSORRT_PLUGIN(TRTGridSamplerCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.hpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.hpp index 0f62bce7c8..1fc41e5bb8 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.hpp +++ 
b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler.hpp @@ -9,76 +9,74 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { +namespace mmdeploy +{ -class TRTGridSampler : public TRTPluginBase { - public: - TRTGridSampler(const std::string &name, int mode, int paddingMode, bool alignCorners); + class TRTGridSampler : public TRTPluginBase + { + public: + TRTGridSampler(const std::string& name, int mode, int paddingMode, bool alignCorners); - TRTGridSampler(const std::string name, const void *data, size_t length); + TRTGridSampler(const std::string name, const void* data, size_t length); - TRTGridSampler() = delete; + TRTGridSampler() = delete; - ~TRTGridSampler() TRT_NOEXCEPT override = default; + ~TRTGridSampler() TRT_NOEXCEPT override = default; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; - void 
serialize(void *buffer) const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; - private: - int mMode; - int mPaddingMode; - bool mAlignCorners; -}; + private: + int mMode; + int mPaddingMode; + bool mAlignCorners; + }; -class TRTGridSamplerCreator : public TRTPluginCreatorBase { - public: - TRTGridSamplerCreator(); + class TRTGridSamplerCreator : public TRTPluginCreatorBase + { + public: + TRTGridSamplerCreator(); - ~TRTGridSamplerCreator() TRT_NOEXCEPT override = default; + ~TRTGridSamplerCreator() TRT_NOEXCEPT override = default; - const char *getPluginName() const TRT_NOEXCEPT override; + const char* getPluginName() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_GRID_SAMPLER_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.cu index 5d83f98d2c..6dafbbb126 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.cu @@ -27,370 +27,434 @@ using mmdeploy::TensorDesc; // -1 --> -0.5 // +1 --> (size - 1) + 0.5 == size - 0.5 // scale_factor = size / 2 -template <typename scalar_t> -static __forceinline__ __device__ scalar_t grid_sampler_unnormalize(scalar_t coord, int size, - bool align_corners) { - if (align_corners) { - // unnormalize coord from [-1, 1] to [0, size - 1] - return ((coord + 1.f) / 2) * (size - 1); - } else { - // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] - return ((coord + 1.f) * size - 1) / 2; - } +template<typename scalar_t> +static __forceinline__ __device__ scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) +{ + if (align_corners) + { + // unnormalize coord from [-1, 1] to [0, size - 1] + return ((coord + 1.f) / 2) * (size - 1); + } + else + { + // unnormalize coord from [-1, 1] to [-0.5, size - 0.5] + return ((coord + 1.f) * size - 1) / 2; + } } // Clips coordinates to between 0 and clip_limit - 1 -template <typename scalar_t> -static __forceinline__ __device__ scalar_t clip_coordinates(scalar_t in, int clip_limit) { - return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0))); +template<typename scalar_t> +static __forceinline__ __device__ scalar_t clip_coordinates(scalar_t in, int clip_limit) +{ + return ::min(static_cast<scalar_t>(clip_limit - 1), ::max(in, static_cast<scalar_t>(0))); } // Reflects coordinates until they fall between low and high (inclusive). // The bounds are passed as twice their value so that half-integer values // can be represented as ints.
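// For example, reflect_coordinates(-0.5, 0, 6), the align_corners bounds for size 4, folds -0.5 to 0.5.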
-template <typename scalar_t> -static __forceinline__ __device__ scalar_t reflect_coordinates(scalar_t in, int twice_low, - int twice_high) { - if (twice_low == twice_high) { - return static_cast<scalar_t>(0); - } - scalar_t min = static_cast<scalar_t>(twice_low) / 2; - scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2; - in = ::fabs(in - min); - // `fmod` returns same sign as `in`, which is positive after the `fabs` above. - scalar_t extra = ::fmod(in, span); - int flips = static_cast<int>(::floor(in / span)); - if (flips % 2 == 0) { - return extra + min; - } else { - return span - extra + min; - } +template<typename scalar_t> +static __forceinline__ __device__ scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) +{ + if (twice_low == twice_high) + { + return static_cast<scalar_t>(0); + } + scalar_t min = static_cast<scalar_t>(twice_low) / 2; + scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2; + in = ::fabs(in - min); + // `fmod` returns same sign as `in`, which is positive after the `fabs` above. + scalar_t extra = ::fmod(in, span); + int flips = static_cast<int>(::floor(in / span)); + if (flips % 2 == 0) + { + return extra + min; + } + else + { + return span - extra + min; + } } -template <typename scalar_t> -static __forceinline__ __device__ scalar_t safe_downgrade_to_int_range(scalar_t x) { - // -100.0 does not have special meaning. This is just to make sure - // it's not within_bounds_2d or within_bounds_3d, and does not cause - // undefined behavior. See #35506. - if (x > INT_MAX - 1 || x < INT_MIN || !::isfinite(static_cast<double>(x))) - return static_cast<scalar_t>(-100.0); - return x; +template<typename scalar_t> +static __forceinline__ __device__ scalar_t safe_downgrade_to_int_range(scalar_t x) +{ + // -100.0 does not have special meaning. This is just to make sure + // it's not within_bounds_2d or within_bounds_3d, and does not cause + // undefined behavior. See #35506.
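+ // (Any negative sentinel would work, since valid indices are non-negative.)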
+ if (x > INT_MAX - 1 || x < INT_MIN || !::isfinite(static_cast<double>(x))) + return static_cast<scalar_t>(-100.0); + return x; } // Computes the pixel source index value for a grid coordinate -template <typename scalar_t> +template<typename scalar_t> static __forceinline__ __device__ scalar_t grid_sampler_compute_source_index( - scalar_t coord, int size, GridSamplerPadding padding_mode, bool align_corners) { - coord = grid_sampler_unnormalize(coord, size, align_corners); - if (padding_mode == GridSamplerPadding::Border) { - // clip coordinates to image borders - coord = clip_coordinates(coord, size); - } else if (padding_mode == GridSamplerPadding::Reflection) { - // reflect coordinates by image borders - if (align_corners) { - coord = reflect_coordinates(coord, 0, 2 * (size - 1)); - } else { - coord = reflect_coordinates(coord, -1, 2 * size - 1); + scalar_t coord, + int size, + GridSamplerPadding padding_mode, + bool align_corners) +{ + coord = grid_sampler_unnormalize(coord, size, align_corners); + if (padding_mode == GridSamplerPadding::Border) + { + // clip coordinates to image borders + coord = clip_coordinates(coord, size); + } + else if (padding_mode == GridSamplerPadding::Reflection) + { + // reflect coordinates by image borders + if (align_corners) + { + coord = reflect_coordinates(coord, 0, 2 * (size - 1)); + } + else + { + coord = reflect_coordinates(coord, -1, 2 * size - 1); + } + // clip coordinates to image borders + coord = clip_coordinates(coord, size); } - // clip coordinates to image borders - coord = clip_coordinates(coord, size); - } - coord = safe_downgrade_to_int_range(coord); - return coord; + coord = safe_downgrade_to_int_range(coord); + return coord; } -static __forceinline__ __device__ bool within_bounds_2d(int h, int w, int H, int W) { - return h >= 0 && h < H && w >= 0 && w < W; +static __forceinline__ __device__ bool within_bounds_2d(int h, int w, int H, int W) +{ + return h >= 0 && h < H && w >= 0 && w < W; } -static __forceinline__ __device__ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) { - return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; +static __forceinline__ __device__ bool within_bounds_3d(int d, int h, int w, int D, int H, int W) +{ + return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W; } -template <typename scalar_t> -__global__ void grid_sampler_2d_kernel(const int nthreads, const scalar_t *input, - const scalar_t *grid, scalar_t *output, - TensorDesc input_desc, TensorDesc grid_desc, - TensorDesc output_desc, - const GridSamplerInterpolation interpolation_mode, - const GridSamplerPadding padding_mode, bool align_corners) { - int C = input_desc.shape[1]; - int inp_H = input_desc.shape[2]; - int inp_W = input_desc.shape[3]; - int out_H = grid_desc.shape[1]; - int out_W = grid_desc.shape[2]; - int inp_sN = input_desc.stride[0]; - int inp_sC = input_desc.stride[1]; - int inp_sH = input_desc.stride[2]; - int inp_sW = input_desc.stride[3]; - int grid_sN = grid_desc.stride[0]; - int grid_sH = grid_desc.stride[1]; - int grid_sW = grid_desc.stride[2]; - int grid_sCoor = grid_desc.stride[3]; - int out_sN = output_desc.stride[0]; - int out_sC = output_desc.stride[1]; - int out_sH = output_desc.stride[2]; - int out_sW = output_desc.stride[3]; - - CUDA_1D_KERNEL_LOOP(index, nthreads) { - const int w = index % out_W; - const int h = (index / out_W) % out_H; - const int n = index / (out_H * out_W); - const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; - - // get the corresponding input x, y coordinates from grid - scalar_t ix = grid[grid_offset]; - scalar_t iy = grid[grid_offset + 
grid_sCoor]; - - ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); - iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); - - if (interpolation_mode == GridSamplerInterpolation::Bilinear) { - // get NE, NW, SE, SW pixel values from (x, y) - int ix_nw = static_cast<int>(::floor(ix)); - int iy_nw = static_cast<int>(::floor(iy)); - int ix_ne = ix_nw + 1; - int iy_ne = iy_nw; - int ix_sw = ix_nw; - int iy_sw = iy_nw + 1; - int ix_se = ix_nw + 1; - int iy_se = iy_nw + 1; - - // get surfaces to each neighbor: - scalar_t nw = (ix_se - ix) * (iy_se - iy); - scalar_t ne = (ix - ix_sw) * (iy_sw - iy); - scalar_t sw = (ix_ne - ix) * (iy - iy_ne); - scalar_t se = (ix - ix_nw) * (iy - iy_nw); - - // calculate bilinear weighted pixel value and set output pixel - auto inp_ptr_NC = input + n * inp_sN; - auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; - for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { - *out_ptr_NCHW = static_cast<scalar_t>(0); - if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) { - *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; - } - if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) { - *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; +template<typename scalar_t> +__global__ void grid_sampler_2d_kernel(const int nthreads, const scalar_t* input, const scalar_t* grid, scalar_t* output, TensorDesc input_desc, TensorDesc grid_desc, TensorDesc output_desc, const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, bool align_corners) +{ + int C = input_desc.shape[1]; + int inp_H = input_desc.shape[2]; + int inp_W = input_desc.shape[3]; + int out_H = grid_desc.shape[1]; + int out_W = grid_desc.shape[2]; + int inp_sN = input_desc.stride[0]; + int inp_sC = input_desc.stride[1]; + int inp_sH = input_desc.stride[2]; + int inp_sW = input_desc.stride[3]; + int grid_sN = grid_desc.stride[0]; + int grid_sH = grid_desc.stride[1]; + int grid_sW = grid_desc.stride[2]; + int grid_sCoor = grid_desc.stride[3]; + int out_sN = output_desc.stride[0]; + int out_sC = output_desc.stride[1]; + int out_sH = output_desc.stride[2]; + int out_sW = output_desc.stride[3]; + + CUDA_1D_KERNEL_LOOP(index, nthreads) + { + const int w = index % out_W; + const int h = (index / out_W) % out_H; + const int n = index / (out_H * out_W); + const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y coordinates from grid + scalar_t ix = grid[grid_offset]; + scalar_t iy = grid[grid_offset + grid_sCoor]; + + ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); + iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) + { + // get NE, NW, SE, SW pixel values from (x, y) + int ix_nw = static_cast<int>(::floor(ix)); + int iy_nw = static_cast<int>(::floor(iy)); + int ix_ne = ix_nw + 1; + int iy_ne = iy_nw; + int ix_sw = ix_nw; + int iy_sw = iy_nw + 1; + int ix_se = ix_nw + 1; + int iy_se = iy_nw + 1; + + // get surfaces to each neighbor: + scalar_t nw = (ix_se - ix) * (iy_se - iy); + scalar_t ne = (ix - ix_sw) * (iy_sw - iy); + scalar_t sw = (ix_ne - ix) * (iy - iy_ne); + scalar_t se = (ix - ix_nw) * (iy - iy_nw);
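+ // e.g. ix = 1.25, iy = 2.5 gives nw = 0.375, ne = 0.125, sw = 0.375, se = 0.125; the four weights always sum to 1.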
+ + // calculate bilinear weighted pixel value and set output pixel + auto inp_ptr_NC = input + n * inp_sN; + auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) + { + *out_ptr_NCHW = static_cast<scalar_t>(0); + if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) + { + *out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw; + } + if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) + { + *out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne; + } + if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) + { + *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; + } + if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) + { + *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; + } + } } - if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) { - *out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw; + else if (interpolation_mode == GridSamplerInterpolation::Nearest) + { + int ix_nearest = static_cast<int>(::round(ix)); + int iy_nearest = static_cast<int>(::round(iy)); + + // assign nearest neighbor pixel value to output pixel + auto inp_ptr_NC = input + n * inp_sN; + auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) + { + if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) + { + *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; + } + else + { + *out_ptr_NCHW = static_cast<scalar_t>(0); + } + } } - if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) { - *out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se; - } - } - } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { - int ix_nearest = static_cast<int>(::round(ix)); - int iy_nearest = static_cast<int>(::round(iy)); - - // assign nearest neighbor pixel value to output pixel - auto inp_ptr_NC = input + n * inp_sN; - auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW; - for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) { - if (within_bounds_2d(iy_nearest, ix_nearest, inp_H, inp_W)) { - *out_ptr_NCHW = inp_ptr_NC[iy_nearest * inp_sH + ix_nearest * inp_sW]; - } else { - *out_ptr_NCHW = static_cast<scalar_t>(0); - } - } } - } } -template <typename scalar_t> -__global__ void grid_sampler_3d_kernel(const int nthreads, const scalar_t *input, - const scalar_t *grid, scalar_t *output, - TensorDesc input_desc, TensorDesc grid_desc, - TensorDesc output_desc, - const GridSamplerInterpolation interpolation_mode, - const GridSamplerPadding padding_mode, bool align_corners) { - int C = input_desc.shape[1]; - int inp_D = input_desc.shape[2]; - int inp_H = input_desc.shape[3]; - int inp_W = input_desc.shape[4]; - int out_D = grid_desc.shape[1]; - int out_H = grid_desc.shape[2]; - int out_W = grid_desc.shape[3]; - int inp_sN = input_desc.stride[0]; - int inp_sC = input_desc.stride[1]; - int inp_sD = input_desc.stride[2]; - int inp_sH = input_desc.stride[3]; - int inp_sW = input_desc.stride[4]; - int grid_sN = grid_desc.stride[0]; - int grid_sD = grid_desc.stride[1]; - int grid_sH = grid_desc.stride[2]; - int grid_sW = grid_desc.stride[3]; - int grid_sCoor = grid_desc.stride[4]; - int out_sN = output_desc.stride[0]; - int out_sC = output_desc.stride[1]; - int out_sD = output_desc.stride[2]; - int out_sH = output_desc.stride[3]; - int out_sW = output_desc.stride[4]; - - CUDA_1D_KERNEL_LOOP(index, nthreads) { - const int w = index % out_W; - const int h = (index / out_W) % out_H; - const int d = (index / (out_H * out_W)) % out_D; - const int n = index / (out_D * out_H * out_W); - const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; - - // get the corresponding input x, y, z coordinates from grid - scalar_t ix = 
grid[grid_offset]; - scalar_t iy = grid[grid_offset + grid_sCoor]; - scalar_t iz = grid[grid_offset + 2 * grid_sCoor]; - - ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); - iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); - iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners); - - if (interpolation_mode == GridSamplerInterpolation::Bilinear) { - // get corner pixel values from (x, y, z) - // for 4d, we used north-east-south-west - // for 5d, we add top-bottom - int ix_tnw = static_cast<int>(::floor(ix)); - int iy_tnw = static_cast<int>(::floor(iy)); - int iz_tnw = static_cast<int>(::floor(iz)); - - int ix_tne = ix_tnw + 1; - int iy_tne = iy_tnw; - int iz_tne = iz_tnw; - - int ix_tsw = ix_tnw; - int iy_tsw = iy_tnw + 1; - int iz_tsw = iz_tnw; - - int ix_tse = ix_tnw + 1; - int iy_tse = iy_tnw + 1; - int iz_tse = iz_tnw; - - int ix_bnw = ix_tnw; - int iy_bnw = iy_tnw; - int iz_bnw = iz_tnw + 1; - - int ix_bne = ix_tnw + 1; - int iy_bne = iy_tnw; - int iz_bne = iz_tnw + 1; - - int ix_bsw = ix_tnw; - int iy_bsw = iy_tnw + 1; - int iz_bsw = iz_tnw + 1; - - int ix_bse = ix_tnw + 1; - int iy_bse = iy_tnw + 1; - int iz_bse = iz_tnw + 1; - - // get surfaces to each neighbor: - scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); - scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); - scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); - scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); - scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); - scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); - scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); - scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); - - auto inp_ptr_NC = input + n * inp_sN; - auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; - for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { - // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * - // tne - // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * - // tse - // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * - // bne - // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * - // bse - *out_ptr_NCDHW = static_cast<scalar_t>(0); - if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; - } - if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; +template<typename scalar_t> +__global__ void grid_sampler_3d_kernel(const int nthreads, const scalar_t* input, const scalar_t* grid, scalar_t* output, TensorDesc input_desc, TensorDesc grid_desc, TensorDesc output_desc, const GridSamplerInterpolation interpolation_mode, const GridSamplerPadding padding_mode, bool align_corners) +{ + int C = input_desc.shape[1]; + int inp_D = input_desc.shape[2]; + int inp_H = input_desc.shape[3]; + int inp_W = input_desc.shape[4]; + int out_D = grid_desc.shape[1]; + int out_H = grid_desc.shape[2]; + int out_W = grid_desc.shape[3]; + int inp_sN = input_desc.stride[0]; + int inp_sC = input_desc.stride[1]; + int inp_sD = input_desc.stride[2]; + int inp_sH = input_desc.stride[3]; + int inp_sW = input_desc.stride[4]; + int grid_sN = grid_desc.stride[0]; + int grid_sD = grid_desc.stride[1]; + int grid_sH = grid_desc.stride[2]; + int grid_sW = grid_desc.stride[3]; + 
int grid_sCoor = grid_desc.stride[4]; + int out_sN = output_desc.stride[0]; + int out_sC = output_desc.stride[1]; + int out_sD = output_desc.stride[2]; + int out_sH = output_desc.stride[3]; + int out_sW = output_desc.stride[4]; + + CUDA_1D_KERNEL_LOOP(index, nthreads) + { + const int w = index % out_W; + const int h = (index / out_W) % out_H; + const int d = (index / (out_H * out_W)) % out_D; + const int n = index / (out_D * out_H * out_W); + const int grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z coordinates from grid + scalar_t ix = grid[grid_offset]; + scalar_t iy = grid[grid_offset + grid_sCoor]; + scalar_t iz = grid[grid_offset + 2 * grid_sCoor]; + + ix = grid_sampler_compute_source_index(ix, inp_W, padding_mode, align_corners); + iy = grid_sampler_compute_source_index(iy, inp_H, padding_mode, align_corners); + iz = grid_sampler_compute_source_index(iz, inp_D, padding_mode, align_corners); + + if (interpolation_mode == GridSamplerInterpolation::Bilinear) + { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + int ix_tnw = static_cast<int>(::floor(ix)); + int iy_tnw = static_cast<int>(::floor(iy)); + int iz_tnw = static_cast<int>(::floor(iz)); + + int ix_tne = ix_tnw + 1; + int iy_tne = iy_tnw; + int iz_tne = iz_tnw; + + int ix_tsw = ix_tnw; + int iy_tsw = iy_tnw + 1; + int iz_tsw = iz_tnw; + + int ix_tse = ix_tnw + 1; + int iy_tse = iy_tnw + 1; + int iz_tse = iz_tnw; + + int ix_bnw = ix_tnw; + int iy_bnw = iy_tnw; + int iz_bnw = iz_tnw + 1; + + int ix_bne = ix_tnw + 1; + int iy_bne = iy_tnw; + int iz_bne = iz_tnw + 1; + + int ix_bsw = ix_tnw; + int iy_bsw = iy_tnw + 1; + int iz_bsw = iz_tnw + 1; + + int ix_bse = ix_tnw + 1; + int iy_bse = iy_tnw + 1; + int iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + auto inp_ptr_NC = input + n * inp_sN; + auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) + { + // (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * + // tne + // + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * + // tse + // + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * + // bne + // + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * + // bse + *out_ptr_NCDHW = static_cast<scalar_t>(0); + if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw; + } + if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne; + } + if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + } + if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += 
inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; + } + if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; + } + if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; + } + if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; + } + if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; + } + } } - if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw; + else if (interpolation_mode == GridSamplerInterpolation::Nearest) + { + int ix_nearest = static_cast<int>(::round(ix)); + int iy_nearest = static_cast<int>(::round(iy)); + int iz_nearest = static_cast<int>(::round(iz)); + + // assign nearest neighbor pixel value to output pixel + auto inp_ptr_NC = input + n * inp_sN; + auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; + for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) + { + if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) + { + *out_ptr_NCDHW = + inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; + } + else + { + *out_ptr_NCDHW = static_cast<scalar_t>(0); + } + } } - if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse; - } - if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw; - } - if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne; - } - if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw; - } - if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse; - } - } - } else if (interpolation_mode == GridSamplerInterpolation::Nearest) { - int ix_nearest = static_cast<int>(::round(ix)); - int iy_nearest = static_cast<int>(::round(iy)); - int iz_nearest = static_cast<int>(::round(iz)); - - // assign nearest neighbor pixel value to output pixel - auto inp_ptr_NC = input + n * inp_sN; - auto out_ptr_NCDHW = output + n * out_sN + d * out_sD + h * out_sH + w * out_sW; - for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { - if (within_bounds_3d(iz_nearest, iy_nearest, ix_nearest, inp_D, inp_H, inp_W)) { - *out_ptr_NCDHW = - inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + ix_nearest * inp_sW]; - } else { - *out_ptr_NCDHW = static_cast<scalar_t>(0); - } - } } - } } -void create_desc(const int *dims, int nb_dims, TensorDesc &desc) { - memcpy(&desc.shape[0], dims, sizeof(int) * nb_dims); - desc.stride[nb_dims - 1] = 1; - for (int i = nb_dims - 2; i >= 0; --i) { - desc.stride[i] = desc.stride[i + 1] * desc.shape[i + 1]; - } +void create_desc(const int* dims, int nb_dims, TensorDesc& desc) +{ + memcpy(&desc.shape[0], dims, sizeof(int) * nb_dims);
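+ // Row-major strides; e.g. an NCHW shape of (2, 3, 32, 32) yields strides (3072, 1024, 32, 1).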
+ desc.stride[nb_dims - 1] = 1; + for (int i = nb_dims - 2; i >= 0; --i) + { + desc.stride[i] = desc.stride[i + 1] * desc.shape[i + 1]; + } } -template <typename T> -void grid_sample(T *output, const T *input, const T *grid, int *output_dims, int *input_dims, - int *grid_dims, int nb_dims, GridSamplerInterpolation interp, - GridSamplerPadding padding, bool align_corners, cudaStream_t stream) { - TensorDesc input_desc; - create_desc(input_dims, nb_dims, input_desc); +template<typename T> +void grid_sample(T* output, const T* input, const T* grid, int* output_dims, int* input_dims, int* grid_dims, int nb_dims, GridSamplerInterpolation interp, GridSamplerPadding padding, bool align_corners, cudaStream_t stream) +{ + TensorDesc input_desc; + create_desc(input_dims, nb_dims, input_desc); + + TensorDesc output_desc; + create_desc(output_dims, nb_dims, output_desc); - TensorDesc output_desc; - create_desc(output_dims, nb_dims, output_desc); + TensorDesc grid_desc; + create_desc(grid_dims, nb_dims, grid_desc); - TensorDesc grid_desc; - create_desc(grid_dims, nb_dims, grid_desc); + int count = 1; + for (int i = 0; i < nb_dims; ++i) + { + if (i == 1) + { + continue; + } + count *= output_desc.shape[i]; + } - int count = 1; - for (int i = 0; i < nb_dims; ++i) { - if (i == 1) { - continue; + if (nb_dims == 4) + { + grid_sampler_2d_kernel<T><<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>( + count, + input, + grid, + output, + input_desc, + grid_desc, + output_desc, + interp, + padding, + align_corners); + } + else if (nb_dims == 5) + { + grid_sampler_3d_kernel<T><<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>( + count, + input, + grid, + output, + input_desc, + grid_desc, + output_desc, + interp, + padding, + align_corners); + } + else + { + printf("input and grid dims should be 4 or 5\n"); } - count *= output_desc.shape[i]; - } - - if (nb_dims == 4) { - grid_sampler_2d_kernel<T><<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>( - count, input, grid, output, input_desc, grid_desc, output_desc, interp, padding, - align_corners); - } else if (nb_dims == 5) { - grid_sampler_3d_kernel<T><<<GET_BLOCKS(count), THREADS_PER_BLOCK, 0, stream>>>( - count, input, grid, output, input_desc, grid_desc, output_desc, interp, padding, - align_corners); - } else { - printf("input and grid dims should be 4 or 5\n"); - } } -template void grid_sample<float>(float *output, const float *input, const float *grid, - int *output_dims, int *input_dims, int *grid_dims, int nb_dims, - GridSamplerInterpolation interp, GridSamplerPadding padding, - bool align_corners, cudaStream_t stream); +template void grid_sample<float>(float* output, const float* input, const float* grid, int* output_dims, int* input_dims, int* grid_dims, int nb_dims, GridSamplerInterpolation interp, GridSamplerPadding padding, bool align_corners, cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.hpp index e4e50332f4..b73bd91213 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/grid_sampler/trt_grid_sampler_kernel.hpp @@ -3,11 +3,18 @@ #define TRT_GRID_SAMPLER_KERNEL_HPP #include <cuda_runtime.h> -enum class GridSamplerInterpolation { Bilinear, Nearest }; -enum class GridSamplerPadding { Zeros, Border, Reflection }; +enum class GridSamplerInterpolation +{ + Bilinear, + Nearest +}; +enum class GridSamplerPadding +{ + Zeros, + Border, + Reflection +}; -template <typename T> -void grid_sample(T *output, const T *input, const T *grid, int *output_dims, int *input_dims, - int *grid_dims, int nb_dims, GridSamplerInterpolation interp, - GridSamplerPadding padding, bool align_corners, cudaStream_t 
stream); +template<typename T> +void grid_sample(T* output, const T* input, const T* grid, int* output_dims, int* input_dims, int* grid_dims, int nb_dims, GridSamplerInterpolation interp, GridSamplerPadding padding, bool align_corners, cudaStream_t stream); #endif  // TRT_GRID_SAMPLER_KERNEL_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.cpp b/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.cpp index e6aab92f4c..a3ead6d507 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.cpp @@ -12,203 +12,241 @@ using namespace nvinfer1; -namespace mmdeploy { -namespace { -constexpr const char* PLUGIN_VERSION{"1"}; -constexpr const char* PLUGIN_NAME{"TRTInstanceNormalization"}; -}  // namespace - -TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name, float epsilon) - : TRTPluginBase(name), mEpsilon(epsilon) {} - -TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name, void const* serialData, - size_t serialLength) - : TRTPluginBase(name) { - deserialize_value(&serialData, &serialLength, &mEpsilon); -} - -TRTInstanceNormalization::~TRTInstanceNormalization() {} - -// TRTInstanceNormalization returns one output. -int TRTInstanceNormalization::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -DimsExprs TRTInstanceNormalization::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT { - nvinfer1::DimsExprs output(inputs[0]); - return output; -} - -size_t TRTInstanceNormalization::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT { - int n = inputs[0].dims.d[0]; - int c = inputs[0].dims.d[1]; - int elem_size = sizeof(float); - return getAlignedSize(n * c * elem_size) * 2; -} - -int TRTInstanceNormalization::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, - const void* const* inputs, void* const* outputs, - void* workspace, cudaStream_t stream) TRT_NOEXCEPT { - nvinfer1::Dims input_dims = inputDesc[0].dims; - int n = input_dims.d[0]; - int c = input_dims.d[1]; - int h = input_dims.d[2]; - int w = input_dims.nbDims > 3 ? input_dims.d[3] : 1; - int elem_size = sizeof(float); - - void* n_scales = (void*)workspace; - void* n_bias = (void*)((char*)workspace + getAlignedSize(n * c * elem_size)); - - const void* scales = (const void*)inputs[1]; - const void* bias = (const void*)inputs[2]; - - for (int i = 0; i < n; ++i) { - cudaMemcpyAsync((char*)n_scales + i * c * elem_size, scales, c * elem_size, - cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync((char*)n_bias + i * c * elem_size, bias, c * elem_size, - cudaMemcpyDeviceToDevice, stream); - } - - cudnnSetTensor4dDescriptor(_b_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1); - cudnnDataType_t cudnn_dtype{}; - convert_trt2cudnn_dtype(inputDesc[0].type, &cudnn_dtype); - cudnnSetTensor4dDescriptor(_x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w); - cudnnSetTensor4dDescriptor(_y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w); - float alpha = 1; - float beta = 0; - void const* x_ptr = inputs[0]; - void* y_ptr = outputs[0]; - cudnnSetStream(_cudnn_handle, stream); - // Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical - // overflows (NaNs) for fp32 data in some circumstances. 
The lower- - // performance CUDNN_BATCHNORM_SPATIAL should be used if this is not - // acceptable. - cudnnBatchNormalizationForwardTraining(_cudnn_handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, - &beta, _x_desc, x_ptr, _y_desc, y_ptr, _b_desc, n_scales, - n_bias, 1., nullptr, nullptr, mEpsilon, nullptr, nullptr); - return 0; -} - -size_t TRTInstanceNormalization::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mEpsilon); -} - -void TRTInstanceNormalization::serialize(void* buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mEpsilon); -} - -bool TRTInstanceNormalization::supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc* ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - switch (pos) { - case 0: - case 3: - return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || - ioDesc[pos].type == nvinfer1::DataType::kHALF) && - ioDesc[pos].format == nvinfer1::PluginFormat::kLINEAR && - ioDesc[pos].type == ioDesc[0].type); - case 1: - case 2: - return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::PluginFormat::kLINEAR; - default: - return false; - } - return false; -} - -const char* TRTInstanceNormalization::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char* TRTInstanceNormalization::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -IPluginV2DynamicExt* TRTInstanceNormalization::clone() const TRT_NOEXCEPT { - auto* plugin = new TRTInstanceNormalization{mLayerName, mEpsilon}; - plugin->setPluginNamespace(mPluginNamespace.c_str()); - return plugin; -} - -nvinfer1::DataType TRTInstanceNormalization::getOutputDataType(int index, - const nvinfer1::DataType* inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// Attach the plugin object to an execution context and grant the plugin the -// access to some context resource. -void TRTInstanceNormalization::attachToContext(cudnnContext* cudnnContext, - cublasContext* cublasContext, - IGpuAllocator* gpuAllocator) TRT_NOEXCEPT { - _cudnn_handle = cudnnContext; - cudnnCreateTensorDescriptor(&_b_desc); - cudnnCreateTensorDescriptor(&_x_desc); - cudnnCreateTensorDescriptor(&_y_desc); -} - -// Detach the plugin object from its execution context. 
-void TRTInstanceNormalization::detachFromContext() TRT_NOEXCEPT { - if (_y_desc) { - cudnnDestroyTensorDescriptor(_y_desc); - _y_desc = nullptr; - } - if (_x_desc) { - cudnnDestroyTensorDescriptor(_x_desc); - _x_desc = nullptr; - } - if (_b_desc) { - cudnnDestroyTensorDescriptor(_b_desc); - _b_desc = nullptr; - } -} - -void TRTInstanceNormalization::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, - int nbOutputs) TRT_NOEXCEPT {} - -// TRTInstanceNormalizationCreator methods -TRTInstanceNormalizationCreator::TRTInstanceNormalizationCreator() { - mPluginAttributes.clear(); - mPluginAttributes.emplace_back(PluginField("epsilon", nullptr, PluginFieldType::kFLOAT32, 1)); - - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char* TRTInstanceNormalizationCreator::getPluginName() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char* TRTInstanceNormalizationCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -IPluginV2DynamicExt* TRTInstanceNormalizationCreator::createPlugin( - const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT { - float epsilon = 1e-5; - const PluginField* fields = fc->fields; - for (int i = 0; i < fc->nbFields; ++i) { - const char* attrName = fields[i].name; - if (!strcmp(attrName, "epsilon")) { - epsilon = *(static_cast<const float*>(fields[i].data)); - } - } - - TRTInstanceNormalization* obj = new TRTInstanceNormalization(name, epsilon); - obj->setPluginNamespace(mNamespace.c_str()); - return obj; -} - -IPluginV2DynamicExt* TRTInstanceNormalizationCreator::deserializePlugin( - const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT { - TRTInstanceNormalization* obj = new TRTInstanceNormalization{name, serialData, serialLength}; - obj->setPluginNamespace(mNamespace.c_str()); - return obj; -} -REGISTER_TENSORRT_PLUGIN(TRTInstanceNormalizationCreator); +namespace mmdeploy +{ + namespace + { + constexpr const char* PLUGIN_VERSION{"1"}; + constexpr const char* PLUGIN_NAME{"TRTInstanceNormalization"}; + }  // namespace + + TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name, float epsilon) + : TRTPluginBase(name) + , mEpsilon(epsilon) + { + } + + TRTInstanceNormalization::TRTInstanceNormalization(const std::string& name, void const* serialData, size_t serialLength) + : TRTPluginBase(name) + { + deserialize_value(&serialData, &serialLength, &mEpsilon); + } + + TRTInstanceNormalization::~TRTInstanceNormalization() {} + + // TRTInstanceNormalization returns one output. 
+ int TRTInstanceNormalization::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + DimsExprs TRTInstanceNormalization::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + nvinfer1::DimsExprs output(inputs[0]); + return output; + } + + size_t TRTInstanceNormalization::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + int n = inputs[0].dims.d[0]; + int c = inputs[0].dims.d[1]; + int elem_size = sizeof(float); + return getAlignedSize(n * c * elem_size) * 2; + } + + int TRTInstanceNormalization::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT + { + nvinfer1::Dims input_dims = inputDesc[0].dims; + int n = input_dims.d[0]; + int c = input_dims.d[1]; + int h = input_dims.d[2]; + int w = input_dims.nbDims > 3 ? input_dims.d[3] : 1; + int elem_size = sizeof(float); + + void* n_scales = (void*)workspace; + void* n_bias = (void*)((char*)workspace + getAlignedSize(n * c * elem_size)); + + const void* scales = (const void*)inputs[1]; + const void* bias = (const void*)inputs[2]; + + for (int i = 0; i < n; ++i) + { + cudaMemcpyAsync((char*)n_scales + i * c * elem_size, scales, c * elem_size, cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync((char*)n_bias + i * c * elem_size, bias, c * elem_size, cudaMemcpyDeviceToDevice, stream); + } + + cudnnSetTensor4dDescriptor(_b_desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, n * c, 1, 1); + cudnnDataType_t cudnn_dtype{}; + convert_trt2cudnn_dtype(inputDesc[0].type, &cudnn_dtype); + cudnnSetTensor4dDescriptor(_x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w); + cudnnSetTensor4dDescriptor(_y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, n * c, h, w); + float alpha = 1; + float beta = 0; + void const* x_ptr = inputs[0]; + void* y_ptr = outputs[0]; + cudnnSetStream(_cudnn_handle, stream); + // Note: Use of CUDNN_BATCHNORM_SPATIAL_PERSISTENT can cause numerical + // overflows (NaNs) for fp32 data in some circumstances. The lower- + // performance CUDNN_BATCHNORM_SPATIAL should be used if this is not + // acceptable. 
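+ // Instance norm is computed as batch norm over a (1, N*C, H, W) view of the input, so the per-sample scale/bias copies prepared above act as per-instance affine parameters.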
+ cudnnBatchNormalizationForwardTraining(_cudnn_handle, CUDNN_BATCHNORM_SPATIAL_PERSISTENT, &alpha, &beta, _x_desc, x_ptr, _y_desc, y_ptr, _b_desc, n_scales, n_bias, 1., nullptr, nullptr, mEpsilon, nullptr, nullptr); + return 0; + } + + size_t TRTInstanceNormalization::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mEpsilon); + } + + void TRTInstanceNormalization::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mEpsilon); + } + + bool TRTInstanceNormalization::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + switch (pos) + { + case 0: + case 3: + return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || + ioDesc[pos].type == nvinfer1::DataType::kHALF) && + ioDesc[pos].format == nvinfer1::PluginFormat::kLINEAR && + ioDesc[pos].type == ioDesc[0].type); + case 1: + case 2: + return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::PluginFormat::kLINEAR; + default: + return false; + } + return false; + } + + const char* TRTInstanceNormalization::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTInstanceNormalization::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + IPluginV2DynamicExt* TRTInstanceNormalization::clone() const TRT_NOEXCEPT + { + auto* plugin = new TRTInstanceNormalization{mLayerName, mEpsilon}; + plugin->setPluginNamespace(mPluginNamespace.c_str()); + return plugin; + } + + nvinfer1::DataType TRTInstanceNormalization::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // Attach the plugin object to an execution context and grant the plugin the + // access to some context resource. + void TRTInstanceNormalization::attachToContext(cudnnContext* cudnnContext, + cublasContext* cublasContext, + IGpuAllocator* gpuAllocator) TRT_NOEXCEPT + { + _cudnn_handle = cudnnContext; + cudnnCreateTensorDescriptor(&_b_desc); + cudnnCreateTensorDescriptor(&_x_desc); + cudnnCreateTensorDescriptor(&_y_desc); + } + + // Detach the plugin object from its execution context. 
+ void TRTInstanceNormalization::detachFromContext() TRT_NOEXCEPT + { + if (_y_desc) + { + cudnnDestroyTensorDescriptor(_y_desc); + _y_desc = nullptr; + } + if (_x_desc) + { + cudnnDestroyTensorDescriptor(_x_desc); + _x_desc = nullptr; + } + if (_b_desc) + { + cudnnDestroyTensorDescriptor(_b_desc); + _b_desc = nullptr; + } + } + + void TRTInstanceNormalization::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* out, + int nbOutputs) TRT_NOEXCEPT {} + + // TRTInstanceNormalizationCreator methods + TRTInstanceNormalizationCreator::TRTInstanceNormalizationCreator() + { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(PluginField("epsilon", nullptr, PluginFieldType::kFLOAT32, 1)); + + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTInstanceNormalizationCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTInstanceNormalizationCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + IPluginV2DynamicExt* TRTInstanceNormalizationCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + float epsilon = 1e-5; + const PluginField* fields = fc->fields; + for (int i = 0; i < fc->nbFields; ++i) + { + const char* attrName = fields[i].name; + if (!strcmp(attrName, "epsilon")) + { + epsilon = *(static_cast<const float*>(fields[i].data)); + } + } + + TRTInstanceNormalization* obj = new TRTInstanceNormalization(name, epsilon); + obj->setPluginNamespace(mNamespace.c_str()); + return obj; + } + + IPluginV2DynamicExt* TRTInstanceNormalizationCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + TRTInstanceNormalization* obj = new TRTInstanceNormalization{name, serialData, serialLength}; + obj->setPluginNamespace(mNamespace.c_str()); + return obj; + } + REGISTER_TENSORRT_PLUGIN(TRTInstanceNormalizationCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.hpp b/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.hpp index 2df04a5f6d..d513a59301 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/instance_norm/trt_instance_norm.hpp @@ -14,80 +14,78 @@ typedef unsigned short half_type; -namespace mmdeploy { -class TRTInstanceNormalization final : public TRTPluginBase { - public: - TRTInstanceNormalization(const std::string& name, float epsilon); +namespace mmdeploy +{ + class TRTInstanceNormalization final : public TRTPluginBase + { + public: + TRTInstanceNormalization(const std::string& name, float epsilon); - TRTInstanceNormalization(const std::string& name, void const* serialData, size_t serialLength); + TRTInstanceNormalization(const std::string& name, void const* serialData, size_t serialLength); - TRTInstanceNormalization() = delete; + TRTInstanceNormalization() = delete; - ~TRTInstanceNormalization() TRT_NOEXCEPT override; + ~TRTInstanceNormalization() TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; - // DynamicExt plugins returns DimsExprs class instead of 
Dims + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, - const nvinfer1::PluginTensorDesc* outputs, - int nbOutputs) const TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, - const nvinfer1::PluginTensorDesc* outputDesc, const void* const* inputs, - void* const* outputs, void* workspace, cudaStream_t stream) TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void* buffer) const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; - // DynamicExt plugin supportsFormat update. - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; + // DynamicExt plugin supportsFormat update. + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; - const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginType() const TRT_NOEXCEPT override; - const char* getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, - int nbInputs) const TRT_NOEXCEPT override; + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; - void attachToContext(cudnnContext* cudnn, cublasContext* cublas, - nvinfer1::IGpuAllocator* allocator) TRT_NOEXCEPT override; + void attachToContext(cudnnContext* cudnn, cublasContext* cublas, nvinfer1::IGpuAllocator* allocator) TRT_NOEXCEPT override; - void detachFromContext() TRT_NOEXCEPT override; + void detachFromContext() TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc* out, - int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; - private: - float mEpsilon{}; - cudnnHandle_t _cudnn_handle{}; - cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _b_desc{}; - std::string mPluginNamespace{}; -}; + private: + float mEpsilon{}; + cudnnHandle_t _cudnn_handle{}; + cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _b_desc{}; + std::string mPluginNamespace{}; + }; -class TRTInstanceNormalizationCreator : public TRTPluginCreatorBase { - public: - TRTInstanceNormalizationCreator(); + class TRTInstanceNormalizationCreator : public TRTPluginCreatorBase + { + public: + TRTInstanceNormalizationCreator(); - ~TRTInstanceNormalizationCreator() override = default; + 
~TRTInstanceNormalizationCreator() override = default; - const char* getPluginName() const TRT_NOEXCEPT override; + const char* getPluginName() const TRT_NOEXCEPT override; - const char* getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2DynamicExt* createPlugin( - const char* name, const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; + nvinfer1::IPluginV2DynamicExt* createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT override; - nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name, const void* serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; + nvinfer1::IPluginV2DynamicExt* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_INSTANCE_NORMALIZATION_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.cpp b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.cpp index 692000b740..363242e8e1 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.cpp @@ -10,297 +10,406 @@ using namespace nvinfer1; -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"MMCVModulatedDeformConv2d"}; -} // namespace - -ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic( - const std::string &name, const nvinfer1::Dims stride, const nvinfer1::Dims padding, - const nvinfer1::Dims dilation, const int deformableGroup, const int group) - : TRTPluginBase(name), - mStride(stride), - mPadding(padding), - mDilation(dilation), - mDeformableGroup(deformableGroup), - mGroup(group) { - mWithBias = false; -} - -ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic(const std::string name, - const void *data, - size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mStride); - deserialize_value(&data, &length, &mPadding); - deserialize_value(&data, &length, &mDilation); - deserialize_value(&data, &length, &mDeformableGroup); - deserialize_value(&data, &length, &mGroup); - mWithBias = false; -} -ModulatedDeformableConvPluginDynamic::~ModulatedDeformableConvPluginDynamic() {} - -nvinfer1::IPluginV2DynamicExt *ModulatedDeformableConvPluginDynamic::clone() const TRT_NOEXCEPT { - ModulatedDeformableConvPluginDynamic *plugin = new ModulatedDeformableConvPluginDynamic( - mLayerName, mStride, mPadding, mDilation, mDeformableGroup, mGroup); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -static const nvinfer1::IDimensionExpr *get_hw(const nvinfer1::IDimensionExpr *input, - const nvinfer1::IDimensionExpr *weight, - const nvinfer1::IDimensionExpr *stride, - const nvinfer1::IDimensionExpr *pad, - const nvinfer1::IDimensionExpr *dilation, - nvinfer1::IExprBuilder &exprBuilder) { - using DimOp = nvinfer1::DimensionOperation; - auto expr_1 = exprBuilder.constant(1); - - // d*(w-1)+1 - auto kernel_0 = exprBuilder.operation(DimOp::kSUB, *weight, *expr_1); - auto kernel_1 = exprBuilder.operation(DimOp::kPROD, *dilation, *kernel_0); - auto kernel = exprBuilder.operation(DimOp::kSUM, *kernel_1, *expr_1); - - // (1+2*p-k)//stride -1 - auto out_0 = exprBuilder.operation(DimOp::kSUM, *pad, *pad); - auto out_1 = exprBuilder.operation(DimOp::kSUM, 
*input, *out_0); - auto out_2 = exprBuilder.operation(DimOp::kSUB, *out_1, *kernel); - auto out_3 = exprBuilder.operation(DimOp::kFLOOR_DIV, *out_2, *stride); - auto out = exprBuilder.operation(DimOp::kSUM, *out_3, *expr_1); - - return out; -} - -nvinfer1::DimsExprs ModulatedDeformableConvPluginDynamic::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - using DimOp = nvinfer1::DimensionOperation; - auto weight_dim = inputs[3].d; - nvinfer1::DimsExprs ret; - ret.nbDims = 4; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[3].d[0]; - - auto input_h = inputs[0].d[2]; - auto input_w = inputs[0].d[3]; - auto weight_h = weight_dim[2]; - auto weight_w = weight_dim[3]; - auto dilation_w = exprBuilder.constant(mDilation.d[0]); - auto dilation_h = exprBuilder.constant(mDilation.d[1]); - auto pad_w = exprBuilder.constant(mPadding.d[0]); - auto pad_h = exprBuilder.constant(mPadding.d[1]); - auto stride_w = exprBuilder.constant(mStride.d[0]); - auto stride_h = exprBuilder.constant(mStride.d[1]); - auto expr_1 = exprBuilder.constant(1); - auto expr_2 = exprBuilder.constant(2); - - ret.d[2] = get_hw(input_h, weight_h, stride_h, pad_h, dilation_h, exprBuilder); - ret.d[3] = get_hw(input_w, weight_w, stride_w, pad_w, dilation_w, exprBuilder); - - return ret; -} - -bool ModulatedDeformableConvPluginDynamic::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 0) { - return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || - ioDesc[pos].type == nvinfer1::DataType::kHALF) && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - } else { - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - } -} - -void ModulatedDeformableConvPluginDynamic::configurePlugin( - const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) TRT_NOEXCEPT { - if (nbInputs == 5) { - mWithBias = true; - } -} - -size_t ModulatedDeformableConvPluginDynamic::getWorkspaceSize( - const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const TRT_NOEXCEPT { - int sizeof_dtype = mmdeploy::getElementSize(outputs[0].type); - - int batch_size = inputs[0].dims.d[0]; - int nInputPlane = inputs[0].dims.d[1]; - int inputHeight = inputs[0].dims.d[2]; - int inputWidth = inputs[0].dims.d[3]; - - int nOutputPlane = outputs[0].dims.d[1]; - int outputHeight = outputs[0].dims.d[2]; - int outputWidth = outputs[0].dims.d[3]; - - int kW = inputs[3].dims.d[2]; - int kH = inputs[3].dims.d[3]; - int im2col_step = std::min(32, batch_size); - - size_t col_size = - mmdeploy::getAlignedSize(nInputPlane * kW * kH * outputHeight * outputWidth * sizeof_dtype); - - return col_size; -} - -int ModulatedDeformableConvPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, - void *workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - int batch = inputDesc[0].dims.d[0]; - int channels = inputDesc[0].dims.d[1]; - int height = inputDesc[0].dims.d[2]; - int width = inputDesc[0].dims.d[3]; - int channels_out = outputDesc[0].dims.d[1]; - int kernel_h = inputDesc[3].dims.d[2]; - int kernel_w = inputDesc[3].dims.d[3]; - - const void *x = inputs[0]; - const void *offset = inputs[1]; - const void *mask = inputs[2]; - const void *weight 
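// [Illustrative note added in review - not part of the patch] get_hw() above builds the
// standard convolution output-size expression out of IDimensionExpr nodes; the
// "(1+2*p-k)//stride -1" comment in the source is garbled, the chain actually computes
//   out = (in + 2*pad - (dilation*(kernel-1) + 1)) / stride + 1   (floor division).
// The same arithmetic in plain constexpr form, handy for checking shapes by hand:
constexpr int conv_out_size(int in, int kernel, int stride, int pad, int dilation)
{
    const int effective_kernel = dilation * (kernel - 1) + 1;  // d*(k-1)+1
    return (in + 2 * pad - effective_kernel) / stride + 1;     // integer division == floor here
}
static_assert(conv_out_size(64, 3, 1, 1, 1) == 64, "3x3, stride 1, pad 1 preserves size");
static_assert(conv_out_size(64, 3, 2, 1, 1) == 32, "stride 2 halves the size");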
= inputs[3]; - const void *bias = mWithBias ? inputs[4] : nullptr; - void *output = outputs[0]; - int im2col_step = std::min(batch, 32); - - // TODO: add fp16 support - auto data_type = inputDesc[0].type; - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - ModulatedDeformConvForwardCUDAKernelLauncher( - (float *)x, (float *)weight, (float *)bias, (float *)offset, (float *)mask, - (float *)output, workSpace, batch, channels, height, width, channels_out, kernel_w, - kernel_h, mStride.d[0], mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0], - mDilation.d[1], mGroup, mDeformableGroup, im2col_step, m_cublas_handle, stream); - break; - case nvinfer1::DataType::kHALF: - ModulatedDeformConvForwardCUDAKernelLauncher( - (half *)x, (half *)weight, (half *)bias, (half *)offset, (half *)mask, (half *)output, - workSpace, batch, channels, height, width, channels_out, kernel_w, kernel_h, mStride.d[0], - mStride.d[1], mPadding.d[0], mPadding.d[1], mDilation.d[0], mDilation.d[1], mGroup, - mDeformableGroup, im2col_step, m_cublas_handle, stream); - break; - default: - return 1; - break; - } - - return 0; -} - -nvinfer1::DataType ModulatedDeformableConvPluginDynamic::getOutputDataType( - int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *ModulatedDeformableConvPluginDynamic::getPluginType() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *ModulatedDeformableConvPluginDynamic::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -int ModulatedDeformableConvPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t ModulatedDeformableConvPluginDynamic::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mStride) + serialized_size(mPadding) + serialized_size(mDilation) + - serialized_size(mDeformableGroup) + serialized_size(mGroup); -} - -void ModulatedDeformableConvPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mStride); - serialize_value(&buffer, mPadding); - serialize_value(&buffer, mDilation); - serialize_value(&buffer, mDeformableGroup); - serialize_value(&buffer, mGroup); -} - -void ModulatedDeformableConvPluginDynamic::attachToContext( - cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT { - m_cublas_handle = cublasContext; -} - -void ModulatedDeformableConvPluginDynamic::detachFromContext() TRT_NOEXCEPT {} - -////////////////////// creator ///////////////////////////// - -ModulatedDeformableConvPluginDynamicCreator::ModulatedDeformableConvPluginDynamicCreator() { - mPluginAttributes.clear(); - mPluginAttributes.emplace_back(nvinfer1::PluginField("stride")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("padding")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("groups")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups")); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *ModulatedDeformableConvPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *ModulatedDeformableConvPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *ModulatedDeformableConvPluginDynamicCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - 
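// [Illustrative note added in review - not part of the patch] createPlugin() consumes a
// PluginFieldCollection whose entries match the attribute names registered in the
// creator's constructor ("stride", "padding", "dilation", "groups", "deform_groups").
// A caller-side sketch of wiring those fields up; the creator reference and the example
// values here are hypothetical:
#include <NvInferRuntime.h>
#include <array>

nvinfer1::IPluginV2* make_dcn_plugin_example(nvinfer1::IPluginCreator& creator)
{
    static const int32_t stride[2] = {1, 1};
    static const int32_t padding[2] = {1, 1};
    static const int32_t dilation[2] = {1, 1};
    static const int32_t groups = 1;
    static const int32_t deform_groups = 1;

    std::array<nvinfer1::PluginField, 5> fields = {{
        {"stride", stride, nvinfer1::PluginFieldType::kINT32, 2},
        {"padding", padding, nvinfer1::PluginFieldType::kINT32, 2},
        {"dilation", dilation, nvinfer1::PluginFieldType::kINT32, 2},
        {"groups", &groups, nvinfer1::PluginFieldType::kINT32, 1},
        {"deform_groups", &deform_groups, nvinfer1::PluginFieldType::kINT32, 1},
    }};
    nvinfer1::PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};
    return creator.createPlugin("dcn_example", &fc);
}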
nvinfer1::Dims stride{2, {1, 1}}; - nvinfer1::Dims padding{2, {0, 0}}; - nvinfer1::Dims dilation{2, {1, 1}}; - int deformableGroup = 1; - int group = 1; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"MMCVModulatedDeformConv2d"}; + } // namespace + + ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic( + const std::string& name, + const nvinfer1::Dims stride, + const nvinfer1::Dims padding, + const nvinfer1::Dims dilation, + const int deformableGroup, + const int group) + : TRTPluginBase(name) + , mStride(stride) + , mPadding(padding) + , mDilation(dilation) + , mDeformableGroup(deformableGroup) + , mGroup(group) + { + mWithBias = false; } - std::string field_name(fc->fields[i].name); - if (field_name.compare("deform_groups") == 0) { - deformableGroup = static_cast(fc->fields[i].data)[0]; + ModulatedDeformableConvPluginDynamic::ModulatedDeformableConvPluginDynamic(const std::string name, + const void* data, + size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mStride); + deserialize_value(&data, &length, &mPadding); + deserialize_value(&data, &length, &mDilation); + deserialize_value(&data, &length, &mDeformableGroup); + deserialize_value(&data, &length, &mGroup); + mWithBias = false; + } + ModulatedDeformableConvPluginDynamic::~ModulatedDeformableConvPluginDynamic() {} + + nvinfer1::IPluginV2DynamicExt* ModulatedDeformableConvPluginDynamic::clone() const TRT_NOEXCEPT + { + ModulatedDeformableConvPluginDynamic* plugin = new ModulatedDeformableConvPluginDynamic( + mLayerName, + mStride, + mPadding, + mDilation, + mDeformableGroup, + mGroup); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + static const nvinfer1::IDimensionExpr* get_hw(const nvinfer1::IDimensionExpr* input, + const nvinfer1::IDimensionExpr* weight, + const nvinfer1::IDimensionExpr* stride, + const nvinfer1::IDimensionExpr* pad, + const nvinfer1::IDimensionExpr* dilation, + nvinfer1::IExprBuilder& exprBuilder) + { + using DimOp = nvinfer1::DimensionOperation; + auto expr_1 = exprBuilder.constant(1); + + // d*(w-1)+1 + auto kernel_0 = exprBuilder.operation(DimOp::kSUB, *weight, *expr_1); + auto kernel_1 = exprBuilder.operation(DimOp::kPROD, *dilation, *kernel_0); + auto kernel = exprBuilder.operation(DimOp::kSUM, *kernel_1, *expr_1); + + // (1+2*p-k)//stride -1 + auto out_0 = exprBuilder.operation(DimOp::kSUM, *pad, *pad); + auto out_1 = exprBuilder.operation(DimOp::kSUM, *input, *out_0); + auto out_2 = exprBuilder.operation(DimOp::kSUB, *out_1, *kernel); + auto out_3 = exprBuilder.operation(DimOp::kFLOOR_DIV, *out_2, *stride); + auto out = exprBuilder.operation(DimOp::kSUM, *out_3, *expr_1); + + return out; + } + + nvinfer1::DimsExprs ModulatedDeformableConvPluginDynamic::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + using DimOp = nvinfer1::DimensionOperation; + auto weight_dim = inputs[3].d; + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[3].d[0]; + + auto input_h = inputs[0].d[2]; + auto input_w = inputs[0].d[3]; + auto weight_h = weight_dim[2]; + auto weight_w = weight_dim[3]; + auto dilation_w = exprBuilder.constant(mDilation.d[0]); + auto dilation_h = exprBuilder.constant(mDilation.d[1]); + auto pad_w = 
exprBuilder.constant(mPadding.d[0]); + auto pad_h = exprBuilder.constant(mPadding.d[1]); + auto stride_w = exprBuilder.constant(mStride.d[0]); + auto stride_h = exprBuilder.constant(mStride.d[1]); + auto expr_1 = exprBuilder.constant(1); + auto expr_2 = exprBuilder.constant(2); + + ret.d[2] = get_hw(input_h, weight_h, stride_h, pad_h, dilation_h, exprBuilder); + ret.d[3] = get_hw(input_w, weight_w, stride_w, pad_w, dilation_w, exprBuilder); + + return ret; + } + + bool ModulatedDeformableConvPluginDynamic::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos == 0) + { + return ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT || + ioDesc[pos].type == nvinfer1::DataType::kHALF) && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } + else + { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } + } + + void ModulatedDeformableConvPluginDynamic::configurePlugin( + const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT + { + if (nbInputs == 5) + { + mWithBias = true; + } + } + + size_t ModulatedDeformableConvPluginDynamic::getWorkspaceSize( + const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + int sizeof_dtype = mmdeploy::getElementSize(outputs[0].type); + + int batch_size = inputs[0].dims.d[0]; + int nInputPlane = inputs[0].dims.d[1]; + int inputHeight = inputs[0].dims.d[2]; + int inputWidth = inputs[0].dims.d[3]; + + int nOutputPlane = outputs[0].dims.d[1]; + int outputHeight = outputs[0].dims.d[2]; + int outputWidth = outputs[0].dims.d[3]; + + int kW = inputs[3].dims.d[2]; + int kH = inputs[3].dims.d[3]; + int im2col_step = std::min(32, batch_size); + + size_t col_size = + mmdeploy::getAlignedSize(nInputPlane * kW * kH * outputHeight * outputWidth * sizeof_dtype); + + return col_size; + } + + int ModulatedDeformableConvPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int batch = inputDesc[0].dims.d[0]; + int channels = inputDesc[0].dims.d[1]; + int height = inputDesc[0].dims.d[2]; + int width = inputDesc[0].dims.d[3]; + int channels_out = outputDesc[0].dims.d[1]; + int kernel_h = inputDesc[3].dims.d[2]; + int kernel_w = inputDesc[3].dims.d[3]; + + const void* x = inputs[0]; + const void* offset = inputs[1]; + const void* mask = inputs[2]; + const void* weight = inputs[3]; + const void* bias = mWithBias ? 
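// [Illustrative note added in review - not part of the patch] getWorkspaceSize() above
// reserves a single im2col column buffer of nInputPlane * kW * kH * outH * outW
// elements (rounded up by getAlignedSize()); the kernel launcher then reuses that one
// buffer batch element by batch element. Back-of-the-envelope check of the size:
#include <cstddef>
constexpr std::size_t im2col_bytes(std::size_t c_in, std::size_t kh, std::size_t kw,
                                   std::size_t out_h, std::size_t out_w, std::size_t elem)
{
    return c_in * kh * kw * out_h * out_w * elem;  // before alignment padding
}
// e.g. C=64, 3x3 kernel, 56x56 output, fp32: 64*3*3*56*56*4 = 7,225,344 B (~6.9 MiB)
static_assert(im2col_bytes(64, 3, 3, 56, 56, 4) == 7225344, "im2col column buffer size");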
inputs[4] : nullptr; + void* output = outputs[0]; + int im2col_step = std::min(batch, 32); + + // TODO: add fp16 support + auto data_type = inputDesc[0].type; + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + ModulatedDeformConvForwardCUDAKernelLauncher( + (float*)x, + (float*)weight, + (float*)bias, + (float*)offset, + (float*)mask, + (float*)output, + workSpace, + batch, + channels, + height, + width, + channels_out, + kernel_w, + kernel_h, + mStride.d[0], + mStride.d[1], + mPadding.d[0], + mPadding.d[1], + mDilation.d[0], + mDilation.d[1], + mGroup, + mDeformableGroup, + im2col_step, + m_cublas_handle, + stream); + break; + case nvinfer1::DataType::kHALF: + ModulatedDeformConvForwardCUDAKernelLauncher( + (half*)x, + (half*)weight, + (half*)bias, + (half*)offset, + (half*)mask, + (half*)output, + workSpace, + batch, + channels, + height, + width, + channels_out, + kernel_w, + kernel_h, + mStride.d[0], + mStride.d[1], + mPadding.d[0], + mPadding.d[1], + mDilation.d[0], + mDilation.d[1], + mGroup, + mDeformableGroup, + im2col_step, + m_cublas_handle, + stream); + break; + default: + return 1; + break; + } + + return 0; + } + + nvinfer1::DataType ModulatedDeformableConvPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* ModulatedDeformableConvPluginDynamic::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* ModulatedDeformableConvPluginDynamic::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int ModulatedDeformableConvPluginDynamic::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t ModulatedDeformableConvPluginDynamic::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mStride) + serialized_size(mPadding) + serialized_size(mDilation) + + serialized_size(mDeformableGroup) + serialized_size(mGroup); + } + + void ModulatedDeformableConvPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mStride); + serialize_value(&buffer, mPadding); + serialize_value(&buffer, mDilation); + serialize_value(&buffer, mDeformableGroup); + serialize_value(&buffer, mGroup); + } + + void ModulatedDeformableConvPluginDynamic::attachToContext( + cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT + { + m_cublas_handle = cublasContext; + } + + void ModulatedDeformableConvPluginDynamic::detachFromContext() TRT_NOEXCEPT {} + + ////////////////////// creator ///////////////////////////// + + ModulatedDeformableConvPluginDynamicCreator::ModulatedDeformableConvPluginDynamicCreator() + { + mPluginAttributes.clear(); + mPluginAttributes.emplace_back(nvinfer1::PluginField("stride")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("padding")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("dilation")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("groups")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("deform_groups")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); } - if (field_name.compare("groups") == 0) { - group = static_cast(fc->fields[i].data)[0]; + const char* ModulatedDeformableConvPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; } - if (field_name.compare("stride") == 0) { - stride.nbDims = 2; - stride.d[0] = static_cast(fc->fields[i].data)[0]; - stride.d[1] = 
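// [Illustrative note added in review - not part of the patch] serialize() above and the
// (name, data, length) constructor form a pair: serialize_value() writes mStride,
// mPadding, mDilation, mDeformableGroup, mGroup back-to-back, and deserialize_value()
// must pop them in exactly that order. A minimal round-trip sketch:
#include <vector>
void dcn_plugin_roundtrip(const mmdeploy::ModulatedDeformableConvPluginDynamic& plugin)
{
    std::vector<char> blob(plugin.getSerializationSize());
    plugin.serialize(blob.data());  // packs the five fields in declaration order

    // deserializePlugin() hands the same bytes to the deserializing constructor,
    // which reads the fields back in the identical order.
    mmdeploy::ModulatedDeformableConvPluginDynamic copy("roundtrip", blob.data(), blob.size());
    (void)copy;
}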
static_cast(fc->fields[i].data)[1]; + const char* ModulatedDeformableConvPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; } - if (field_name.compare("padding") == 0) { - padding.nbDims = 2; - padding.d[0] = static_cast(fc->fields[i].data)[0]; - padding.d[1] = static_cast(fc->fields[i].data)[1]; + nvinfer1::IPluginV2* ModulatedDeformableConvPluginDynamicCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + nvinfer1::Dims stride{2, {1, 1}}; + nvinfer1::Dims padding{2, {0, 0}}; + nvinfer1::Dims dilation{2, {1, 1}}; + int deformableGroup = 1; + int group = 1; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("deform_groups") == 0) + { + deformableGroup = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("groups") == 0) + { + group = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("stride") == 0) + { + stride.nbDims = 2; + stride.d[0] = static_cast(fc->fields[i].data)[0]; + stride.d[1] = static_cast(fc->fields[i].data)[1]; + } + + if (field_name.compare("padding") == 0) + { + padding.nbDims = 2; + padding.d[0] = static_cast(fc->fields[i].data)[0]; + padding.d[1] = static_cast(fc->fields[i].data)[1]; + } + + if (field_name.compare("dilation") == 0) + { + dilation.nbDims = 2; + dilation.d[0] = static_cast(fc->fields[i].data)[0]; + dilation.d[1] = static_cast(fc->fields[i].data)[1]; + } + } + + ModulatedDeformableConvPluginDynamic* plugin = new ModulatedDeformableConvPluginDynamic( + name, + stride, + padding, + dilation, + deformableGroup, + group); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; } - if (field_name.compare("dilation") == 0) { - dilation.nbDims = 2; - dilation.d[0] = static_cast(fc->fields[i].data)[0]; - dilation.d[1] = static_cast(fc->fields[i].data)[1]; + nvinfer1::IPluginV2* ModulatedDeformableConvPluginDynamicCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new ModulatedDeformableConvPluginDynamic(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; } - } - - ModulatedDeformableConvPluginDynamic *plugin = new ModulatedDeformableConvPluginDynamic( - name, stride, padding, dilation, deformableGroup, group); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *ModulatedDeformableConvPluginDynamicCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new ModulatedDeformableConvPluginDynamic(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} -REGISTER_TENSORRT_PLUGIN(ModulatedDeformableConvPluginDynamicCreator); + REGISTER_TENSORRT_PLUGIN(ModulatedDeformableConvPluginDynamicCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.hpp b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.hpp index 2dc6ed2f20..2082d83b9a 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv.hpp @@ -9,74 +9,69 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class 
ModulatedDeformableConvPluginDynamic : public TRTPluginBase { - public: - ModulatedDeformableConvPluginDynamic(const std::string &name, const nvinfer1::Dims stride, - const nvinfer1::Dims padding, const nvinfer1::Dims dilation, - const int deformableGroup, const int group); - - ModulatedDeformableConvPluginDynamic(const std::string name, const void *data, size_t length); - - ModulatedDeformableConvPluginDynamic() = delete; - - ~ModulatedDeformableConvPluginDynamic() TRT_NOEXCEPT override; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override; - void detachFromContext() TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; - - private: - nvinfer1::Dims mStride; - nvinfer1::Dims mPadding; - nvinfer1::Dims mDilation; - int mDeformableGroup; - int mGroup; - bool mWithBias; - - cublasHandle_t m_cublas_handle; -}; - -class ModulatedDeformableConvPluginDynamicCreator : public TRTPluginCreatorBase { - public: - ModulatedDeformableConvPluginDynamicCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class ModulatedDeformableConvPluginDynamic : public TRTPluginBase + { + public: + ModulatedDeformableConvPluginDynamic(const std::string& name, const nvinfer1::Dims stride, const nvinfer1::Dims padding, const nvinfer1::Dims dilation, const int deformableGroup, const int group); + + ModulatedDeformableConvPluginDynamic(const std::string name, const void* data, size_t length); + + ModulatedDeformableConvPluginDynamic() = delete; + + ~ModulatedDeformableConvPluginDynamic() TRT_NOEXCEPT override; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const 
nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override; + void detachFromContext() TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + + private: + nvinfer1::Dims mStride; + nvinfer1::Dims mPadding; + nvinfer1::Dims mDilation; + int mDeformableGroup; + int mGroup; + bool mWithBias; + + cublasHandle_t m_cublas_handle; + }; + + class ModulatedDeformableConvPluginDynamicCreator : public TRTPluginCreatorBase + { + public: + ModulatedDeformableConvPluginDynamicCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_MODULATED_DEFORM_CONV_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.cu index 1e1f99d5ff..21fc6cacf5 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.cu @@ -7,132 +7,228 @@ #include "trt_modulated_deform_conv_kernel.hpp" #include "trt_plugin_helper.hpp" -template -void trt_modulated_deformable_im2col(const T* data_im_, const T* data_offset_, const T* data_mask_, - const int batch_size, const int channels, const int height_im, - const int width_im, const int height_col, const int width_col, - const int kernel_h, const int kenerl_w, const int pad_h, - const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group, T* data_col_, - cudaStream_t stream) { - // num_axes should be smaller than block size - const int channel_per_deformable_group = channels / deformable_group; - const int num_kernels = channels * batch_size * height_col * width_col; - - modulated_deformable_im2col_gpu_kernel - <<>>( - num_kernels, data_im_, data_offset_, 
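// [Illustrative note added in review - not part of the patch] In the bias kernels below,
// output_add_bias_kernel recovers the channel of a flat NCHW offset via
// (index % step_batch) / step_channel, where step_batch = C*H*W and step_channel = H*W;
// the __half specialization uses native __hadd only on sm_53+ and otherwise falls back
// to float math. The indexing, spelled out:
#include <cstddef>
constexpr std::size_t channel_of(std::size_t flat_index, std::size_t channels, std::size_t hw)
{
    return (flat_index % (channels * hw)) / hw;  // (index % step_batch) / step_channel
}
// e.g. C=8, H*W=100: element 1234 sits in sample 1, channel 4 (1234 % 800 = 434; 434/100 = 4)
static_assert(channel_of(1234, 8, 100) == 4, "flat NCHW index -> channel");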
data_mask_, height_im, width_im, kernel_h, kenerl_w, - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, - batch_size, channels, deformable_group, height_col, width_col, data_col_); - - cudaCheckError(); +template +void trt_modulated_deformable_im2col(const T* data_im_, const T* data_offset_, const T* data_mask_, const int batch_size, const int channels, const int height_im, const int width_im, const int height_col, const int width_col, const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w, const int stride_h, const int stride_w, const int dilation_h, const int dilation_w, const int deformable_group, T* data_col_, cudaStream_t stream) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + modulated_deformable_im2col_gpu_kernel + <<>>( + num_kernels, + data_im_, + data_offset_, + data_mask_, + height_im, + width_im, + kernel_h, + kenerl_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + batch_size, + channels, + deformable_group, + height_col, + width_col, + data_col_); + + cudaCheckError(); } -template -__global__ void output_add_bias_kernel(scalar_t* output, const scalar_t* bias, size_t step_batch, - size_t step_channel, size_t n) { - CUDA_1D_KERNEL_LOOP(index, n) { output[index] += bias[(index % step_batch) / step_channel]; } +template +__global__ void output_add_bias_kernel(scalar_t* output, const scalar_t* bias, size_t step_batch, size_t step_channel, size_t n) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + output[index] += bias[(index % step_batch) / step_channel]; + } } #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) -template <> -__global__ void output_add_bias_kernel<__half>(__half* output, const __half* bias, - size_t step_batch, size_t step_channel, size_t n) { - CUDA_1D_KERNEL_LOOP(index, n) { - const __half b = bias[(index % step_batch) / step_channel]; - const __half o = output[index]; - output[index] = __hadd(o, b); - } +template<> +__global__ void output_add_bias_kernel<__half>(__half* output, const __half* bias, size_t step_batch, size_t step_channel, size_t n) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + const __half b = bias[(index % step_batch) / step_channel]; + const __half o = output[index]; + output[index] = __hadd(o, b); + } } #else -template <> -__global__ void output_add_bias_kernel<__half>(__half* output, const __half* bias, - size_t step_batch, size_t step_channel, size_t n) { - CUDA_1D_KERNEL_LOOP(index, n) { - const __half b = bias[(index % step_batch) / step_channel]; - const __half o = output[index]; - output[index] = __float2half(__half2float(o) + __half2float(b)); - } +template<> +__global__ void output_add_bias_kernel<__half>(__half* output, const __half* bias, size_t step_batch, size_t step_channel, size_t n) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + const __half b = bias[(index % step_batch) / step_channel]; + const __half o = output[index]; + output[index] = __float2half(__half2float(o) + __half2float(b)); + } } #endif -template -static void output_add_bias(scalar_t* output, const scalar_t* bias, size_t batch, size_t channel, - size_t height, size_t width, cudaStream_t stream) { - size_t step_channel = height * width; - size_t step_batch = step_channel * channel; - size_t n = step_batch * batch; - output_add_bias_kernel<<>>(output, bias, step_batch, - step_channel, n); +template +static void 
output_add_bias(scalar_t* output, const scalar_t* bias, size_t batch, size_t channel, size_t height, size_t width, cudaStream_t stream)
+{
+    size_t step_channel = height * width;
+    size_t step_batch = step_channel * channel;
+    size_t n = step_batch * batch;
+    output_add_bias_kernel<<<GET_BLOCKS(n), THREADS_PER_BLOCK, 0, stream>>>(output, bias, step_batch, step_channel, n);
 }
 
-template <typename scalar_t>
+template<typename scalar_t>
 void ModulatedDeformConvForwardCUDAKernelLauncher(
-    const scalar_t* input, const scalar_t* weight, const scalar_t* bias, const scalar_t* offset,
-    const scalar_t* mask, scalar_t* output, void* workspace, int batch, int channels, int height,
-    int width, int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w,
-    int pad_h, int dilation_w, int dilation_h, int group, int deformable_group, int im2col_step,
-    cublasHandle_t cublas_handle, cudaStream_t stream) {
-  bool with_bias = (bias != nullptr);
-
-  im2col_step = std::min(int(batch), im2col_step);
-  assert(batch % im2col_step == 0);
-
-  const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
-  const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
-
-  scalar_t* columns = (scalar_t*)workspace;
-
-  const size_t input_step = channels * height * width;
-  const size_t offset_step = deformable_group * kernel_h * kernel_w * 2 * height_out * width_out;
-  const size_t mask_step = deformable_group * kernel_h * kernel_w * height_out * width_out;
-  const size_t out_step = channels_out * height_out * width_out;
-  const size_t out_group_step = out_step / group;
-  const size_t col_g_step = channels * kernel_w * kernel_h / group * height_out * width_out;
-  const size_t weight_g_step = channels_out / group * channels / group * kernel_h * kernel_w;
-
-  const int m = channels_out / group;
-  const int n = height_out * width_out;
-  const int k = channels / group * kernel_h * kernel_w;
-  scalar_t alpha = 1.;
-  scalar_t beta = 0.;
-
-  for (int b = 0; b < batch; b++) {
-    const scalar_t* input_start = input + b * input_step;
-    const scalar_t* offset_start = offset + b * offset_step;
-    const scalar_t* mask_start = mask + b * mask_step;
-    trt_modulated_deformable_im2col(
-        input_start, offset_start, mask_start, 1, channels, height, width, height_out, width_out,
-        kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
-        deformable_group, columns, stream);
-
-    for (int g = 0; g < group; g++) {
-      const scalar_t* weight_start = weight + g * weight_g_step;
-      scalar_t* col_start = columns + g * col_g_step;
-      scalar_t* out_buffer_start = output + b * out_step + g * out_group_step;
-
-      cublasGemmWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, col_start,
-                     n, weight_start, k, &beta, out_buffer_start, n);
-      cudaCheckError();
+    const scalar_t* input,
+    const scalar_t* weight,
+    const scalar_t* bias,
+    const scalar_t* offset,
+    const scalar_t* mask,
+    scalar_t* output,
+    void* workspace,
+    int batch,
+    int channels,
+    int height,
+    int width,
+    int channels_out,
+    int kernel_w,
+    int kernel_h,
+    int stride_w,
+    int stride_h,
+    int pad_w,
+    int pad_h,
+    int dilation_w,
+    int dilation_h,
+    int group,
+    int deformable_group,
+    int im2col_step,
+    cublasHandle_t cublas_handle,
+    cudaStream_t stream)
+{
+    bool with_bias = (bias != nullptr);
+
+    im2col_step = std::min(int(batch), im2col_step);
+    assert(batch % im2col_step == 0);
+
+    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w -
1) + 1)) / stride_w + 1; + + scalar_t* columns = (scalar_t*)workspace; + + const size_t input_step = channels * height * width; + const size_t offset_step = deformable_group * kernel_h * kernel_w * 2 * height_out * width_out; + const size_t mask_step = deformable_group * kernel_h * kernel_w * height_out * width_out; + const size_t out_step = channels_out * height_out * width_out; + const size_t out_group_step = out_step / group; + const size_t col_g_step = channels * kernel_w * kernel_h / group * height_out * width_out; + const size_t weight_g_step = channels_out / group * channels / group * kernel_h * kernel_w; + + const int m = channels_out / group; + const int n = height_out * width_out; + const int k = channels / group * kernel_h * kernel_w; + scalar_t alpha = 1.; + scalar_t beta = 0.; + + for (int b = 0; b < batch; b++) + { + const scalar_t* input_start = input + b * input_step; + const scalar_t* offset_start = offset + b * offset_step; + const scalar_t* mask_start = mask + b * mask_step; + trt_modulated_deformable_im2col( + input_start, + offset_start, + mask_start, + 1, + channels, + height, + width, + height_out, + width_out, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + deformable_group, + columns, + stream); + + for (int g = 0; g < group; g++) + { + const scalar_t* weight_start = weight + g * weight_g_step; + scalar_t* col_start = columns + g * col_g_step; + scalar_t* out_buffer_start = output + b * out_step + g * out_group_step; + + cublasGemmWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, col_start, n, weight_start, k, &beta, out_buffer_start, n); + cudaCheckError(); + } } - } - if (with_bias) { - output_add_bias(output, bias, batch, channels_out, height_out, width_out, stream); - } + if (with_bias) + { + output_add_bias(output, bias, batch, channels_out, height_out, width_out, stream); + } } template void ModulatedDeformConvForwardCUDAKernelLauncher( - const float* input, const float* weight, const float* bias, const float* offset, - const float* mask, float* output, void* workspace, int batch, int channels, int height, - int width, int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w, - int pad_h, int dilation_w, int dilation_h, int group, int deformable_group, int im2col_step, - cublasHandle_t cublas_handle, cudaStream_t stream); + const float* input, + const float* weight, + const float* bias, + const float* offset, + const float* mask, + float* output, + void* workspace, + int batch, + int channels, + int height, + int width, + int channels_out, + int kernel_w, + int kernel_h, + int stride_w, + int stride_h, + int pad_w, + int pad_h, + int dilation_w, + int dilation_h, + int group, + int deformable_group, + int im2col_step, + cublasHandle_t cublas_handle, + cudaStream_t stream); template void ModulatedDeformConvForwardCUDAKernelLauncher<__half>( - const __half* input, const __half* weight, const __half* bias, const __half* offset, - const __half* mask, __half* output, void* workspace, int batch, int channels, int height, - int width, int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w, - int pad_h, int dilation_w, int dilation_h, int group, int deformable_group, int im2col_step, - cublasHandle_t cublas_handle, cudaStream_t stream); + const __half* input, + const __half* weight, + const __half* bias, + const __half* offset, + const __half* mask, + __half* output, + void* workspace, + int batch, + int channels, + int height, + int width, + int 
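// [Illustrative note added in review - not part of the patch] Each per-group GEMM above
// multiplies a (C_out/g) x (C_in/g * kH * kW) weight slice by the im2col matrix of shape
// (C_in/g * kH * kW) x (outH * outW). cuBLAS is column-major, so passing the row-major
// operands with dimensions (n, m, k) and no transpose yields the row-major m x n output
// slice directly. Shape check:
constexpr int gemm_m(int c_out, int g) { return c_out / g; }
constexpr int gemm_k(int c_in, int g, int kh, int kw) { return c_in / g * kh * kw; }
constexpr int gemm_n(int out_h, int out_w) { return out_h * out_w; }
// C_in=64, C_out=128, g=2, 3x3 kernel, 56x56 output -> (64 x 288) * (288 x 3136)
static_assert(gemm_m(128, 2) == 64 && gemm_k(64, 2, 3, 3) == 288 && gemm_n(56, 56) == 3136,
              "per-group GEMM dimensions");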
channels_out,
+    int kernel_w,
+    int kernel_h,
+    int stride_w,
+    int stride_h,
+    int pad_w,
+    int pad_h,
+    int dilation_w,
+    int dilation_h,
+    int group,
+    int deformable_group,
+    int im2col_step,
+    cublasHandle_t cublas_handle,
+    cudaStream_t stream);
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.hpp
index 4cdec4fb38..3a1298558c 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/modulated_deform_conv/trt_modulated_deform_conv_kernel.hpp
@@ -4,12 +4,32 @@
 #include <cublas_v2.h>
 #include <cuda_runtime.h>
-template <typename scalar_t>
+template<typename scalar_t>
 void ModulatedDeformConvForwardCUDAKernelLauncher(
-    const scalar_t* input, const scalar_t* weight, const scalar_t* bias, const scalar_t* offset,
-    const scalar_t* mask, scalar_t* output, void* workspace, int batch, int channels, int height,
-    int width, int channels_out, int kernel_w, int kernel_h, int stride_w, int stride_h, int pad_w,
-    int pad_h, int dilation_w, int dilation_h, int group, int deformable_group, int im2col_step,
-    cublasHandle_t cublas_handle, cudaStream_t stream);
+    const scalar_t* input,
+    const scalar_t* weight,
+    const scalar_t* bias,
+    const scalar_t* offset,
+    const scalar_t* mask,
+    scalar_t* output,
+    void* workspace,
+    int batch,
+    int channels,
+    int height,
+    int width,
+    int channels_out,
+    int kernel_w,
+    int kernel_h,
+    int stride_w,
+    int stride_h,
+    int pad_w,
+    int pad_h,
+    int dilation_w,
+    int dilation_h,
+    int group,
+    int deformable_group,
+    int im2col_step,
+    cublasHandle_t cublas_handle,
+    cudaStream_t stream);
 #endif
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp
index ad9a518da7..456acca9b4 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.cpp
@@ -9,219 +9,263 @@
 #include "trt_multi_level_roi_align_kernel.hpp"
 #include "trt_plugin_helper.hpp"
 #include "trt_serialize.hpp"
-namespace mmdeploy {
-namespace {
-static const char *PLUGIN_VERSION{"1"};
-static const char *PLUGIN_NAME{"MMCVMultiLevelRoiAlign"};
-}  // namespace
-
-TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string &name, int alignedHeight,
-                                             int alignedWidth, int poolMode, int sampleNum,
-                                             const std::vector<float> &featmapStrides,
-                                             float roiScaleFactor, int finestScale, bool aligned)
-    : TRTPluginBase(name),
-      mAlignedHeight(alignedHeight),
-      mAlignedWidth(alignedWidth),
-      mPoolMode(poolMode),
-      mSampleNum(sampleNum),
-      mFeatmapStrides(featmapStrides),
-      mRoiScaleFactor(roiScaleFactor),
-      mFinestScale(finestScale),
-      mAligned(aligned) {}
-
-TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string name, const void *data,
-                                             size_t length)
-    : TRTPluginBase(name) {
-  deserialize_value(&data, &length, &mAlignedHeight);
-  deserialize_value(&data, &length, &mAlignedWidth);
-  deserialize_value(&data, &length, &mPoolMode);
-  deserialize_value(&data, &length, &mSampleNum);
-  deserialize_value(&data, &length, &mRoiScaleFactor);
-  deserialize_value(&data, &length, &mFinestScale);
-  deserialize_value(&data, &length, &mAligned);
-  deserialize_value(&data, &length, &mFeatmapStrides);
-}
-
-nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRoiAlign::clone() const TRT_NOEXCEPT {
-
TRTMultiLevelRoiAlign *plugin = - new TRTMultiLevelRoiAlign(mLayerName, mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum, - mFeatmapStrides, mRoiScaleFactor, mFinestScale, mAligned); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - // warning, nbInputs should equal to mFeatmapStrides.size() + 1 - nvinfer1::DimsExprs ret; - ret.nbDims = 4; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[1].d[1]; - ret.d[2] = exprBuilder.constant(mAlignedHeight); - ret.d[3] = exprBuilder.constant(mAlignedWidth); - - return ret; -} - -bool TRTMultiLevelRoiAlign::supportsFormatCombination(int pos, - const nvinfer1::PluginTensorDesc *ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -} - -void TRTMultiLevelRoiAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT { - // Validate input arguments - ASSERT(nbOutputs == 1); - ASSERT(nbInputs >= 1); - mFeatmapStrides = - std::vector(mFeatmapStrides.begin(), mFeatmapStrides.begin() + (nbInputs - 1)); -} - -size_t TRTMultiLevelRoiAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int TRTMultiLevelRoiAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, void *workSpace, - cudaStream_t stream) TRT_NOEXCEPT { - int num_rois = inputDesc[0].dims.d[0]; - int batch_size = inputDesc[1].dims.d[0]; - int channels = inputDesc[1].dims.d[1]; - - const int kMaxFeatMap = 10; - int heights[kMaxFeatMap]; - int widths[kMaxFeatMap]; - float strides[kMaxFeatMap]; - - int num_feats = mFeatmapStrides.size(); - for (int i = 0; i < num_feats; ++i) { - heights[i] = inputDesc[i + 1].dims.d[2]; - widths[i] = inputDesc[i + 1].dims.d[3]; - strides[i] = mFeatmapStrides[i]; - } - - const void *rois = inputs[0]; - const void *const *feats = inputs + 1; - - multi_level_roi_align((float *)outputs[0], (const float *)rois, num_rois, feats, num_feats, - batch_size, channels, &heights[0], &widths[0], &strides[0], - mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum, - mRoiScaleFactor, mFinestScale, mAligned, stream); - - return 0; -} - -nvinfer1::DataType TRTMultiLevelRoiAlign::getOutputDataType(int index, - const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return nvinfer1::DataType::kFLOAT; -} - -// IPluginV2 Methods -const char *TRTMultiLevelRoiAlign::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTMultiLevelRoiAlign::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int TRTMultiLevelRoiAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t TRTMultiLevelRoiAlign::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) + - serialized_size(mAlignedWidth) + serialized_size(mPoolMode) + serialized_size(mSampleNum) + - serialized_size(mRoiScaleFactor) + serialized_size(mFinestScale) + - serialized_size(mAligned); -} - -void TRTMultiLevelRoiAlign::serialize(void *buffer) const 
TRT_NOEXCEPT { - serialize_value(&buffer, mAlignedHeight); - serialize_value(&buffer, mAlignedWidth); - serialize_value(&buffer, mPoolMode); - serialize_value(&buffer, mSampleNum); - serialize_value(&buffer, mRoiScaleFactor); - serialize_value(&buffer, mFinestScale); - serialize_value(&buffer, mAligned); - serialize_value(&buffer, mFeatmapStrides); -} - -TRTMultiLevelRoiAlignCreator::TRTMultiLevelRoiAlignCreator() { - mPluginAttributes = std::vector( - {nvinfer1::PluginField("output_height"), nvinfer1::PluginField("output_width"), - nvinfer1::PluginField("pool_mode"), nvinfer1::PluginField("sampling_ratio"), - nvinfer1::PluginField("featmap_strides"), nvinfer1::PluginField("roi_scale_factor"), - nvinfer1::PluginField("finest_scale"), nvinfer1::PluginField("aligned")}); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *TRTMultiLevelRoiAlignCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTMultiLevelRoiAlignCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - int alignedHeight = 7; - int alignedWidth = 7; - int poolMode = 0; - int sampleNum = 2; - std::vector featmapStrides; - float roiScaleFactor = -1; - int finestScale = 56; - bool aligned = false; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; - } - std::string field_name(fc->fields[i].name); - - if (field_name.compare("output_height") == 0) { - alignedHeight = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("output_width") == 0) { - alignedWidth = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("pool_mode") == 0) { - poolMode = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("sampling_ratio") == 0) { - sampleNum = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("roi_scale_factor") == 0) { - roiScaleFactor = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("finest_scale") == 0) { - finestScale = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("featmap_strides") == 0) { - int data_size = (fc->fields[i].length); - const float *data_start = static_cast(fc->fields[i].data); - featmapStrides = std::vector(data_start, data_start + data_size); - } else if (field_name.compare("aligned") == 0) { - int aligned_int = static_cast(fc->fields[i].data)[0]; - aligned = aligned_int != 0; - } - } - - ASSERT(featmapStrides.size() != 0); - - TRTMultiLevelRoiAlign *plugin = - new TRTMultiLevelRoiAlign(name, alignedHeight, alignedWidth, poolMode, sampleNum, - featmapStrides, roiScaleFactor, finestScale, aligned); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *TRTMultiLevelRoiAlignCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new TRTMultiLevelRoiAlign(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -REGISTER_TENSORRT_PLUGIN(TRTMultiLevelRoiAlignCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"MMCVMultiLevelRoiAlign"}; + } // namespace + + TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string& name, int alignedHeight, int alignedWidth, int poolMode, int sampleNum, const 
std::vector& featmapStrides, float roiScaleFactor, int finestScale, bool aligned) + : TRTPluginBase(name) + , mAlignedHeight(alignedHeight) + , mAlignedWidth(alignedWidth) + , mPoolMode(poolMode) + , mSampleNum(sampleNum) + , mFeatmapStrides(featmapStrides) + , mRoiScaleFactor(roiScaleFactor) + , mFinestScale(finestScale) + , mAligned(aligned) + { + } + + TRTMultiLevelRoiAlign::TRTMultiLevelRoiAlign(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mAlignedHeight); + deserialize_value(&data, &length, &mAlignedWidth); + deserialize_value(&data, &length, &mPoolMode); + deserialize_value(&data, &length, &mSampleNum); + deserialize_value(&data, &length, &mRoiScaleFactor); + deserialize_value(&data, &length, &mFinestScale); + deserialize_value(&data, &length, &mAligned); + deserialize_value(&data, &length, &mFeatmapStrides); + } + + nvinfer1::IPluginV2DynamicExt* TRTMultiLevelRoiAlign::clone() const TRT_NOEXCEPT + { + TRTMultiLevelRoiAlign* plugin = + new TRTMultiLevelRoiAlign(mLayerName, mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum, mFeatmapStrides, mRoiScaleFactor, mFinestScale, mAligned); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs TRTMultiLevelRoiAlign::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + // warning, nbInputs should equal to mFeatmapStrides.size() + 1 + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[1].d[1]; + ret.d[2] = exprBuilder.constant(mAlignedHeight); + ret.d[3] = exprBuilder.constant(mAlignedWidth); + + return ret; + } + + bool TRTMultiLevelRoiAlign::supportsFormatCombination(int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + } + + void TRTMultiLevelRoiAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT + { + // Validate input arguments + ASSERT(nbOutputs == 1); + ASSERT(nbInputs >= 1); + mFeatmapStrides = + std::vector(mFeatmapStrides.begin(), mFeatmapStrides.begin() + (nbInputs - 1)); + } + + size_t TRTMultiLevelRoiAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int TRTMultiLevelRoiAlign::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int num_rois = inputDesc[0].dims.d[0]; + int batch_size = inputDesc[1].dims.d[0]; + int channels = inputDesc[1].dims.d[1]; + + const int kMaxFeatMap = 10; + int heights[kMaxFeatMap]; + int widths[kMaxFeatMap]; + float strides[kMaxFeatMap]; + + int num_feats = mFeatmapStrides.size(); + for (int i = 0; i < num_feats; ++i) + { + heights[i] = inputDesc[i + 1].dims.d[2]; + widths[i] = inputDesc[i + 1].dims.d[3]; + strides[i] = mFeatmapStrides[i]; + } + + const void* rois = inputs[0]; + const void* const* feats = inputs + 1; + + multi_level_roi_align((float*)outputs[0], (const float*)rois, num_rois, feats, num_feats, batch_size, channels, &heights[0], &widths[0], &strides[0], 
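// [Illustrative note added in review - not part of the patch] getOutputDimensions() above
// fixes the output shape from the rois tensor and the first feature map: d[0] = num_rois
// (inputs[0].d[0]), d[1] = channels (inputs[1].d[1]), and d[2] = d[3] = the configured
// pooled size. Plain-int version of the same rule:
#include <array>
constexpr std::array<int, 4> roi_align_out_shape(int num_rois, int channels, int out_h, int out_w)
{
    return {num_rois, channels, out_h, out_w};
}
// e.g. 100 rois over 256-channel feature maps with the default 7x7 pool -> [100, 256, 7, 7]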
mAlignedHeight, mAlignedWidth, mPoolMode, mSampleNum, mRoiScaleFactor, mFinestScale, mAligned, stream); + + return 0; + } + + nvinfer1::DataType TRTMultiLevelRoiAlign::getOutputDataType(int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return nvinfer1::DataType::kFLOAT; + } + + // IPluginV2 Methods + const char* TRTMultiLevelRoiAlign::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTMultiLevelRoiAlign::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int TRTMultiLevelRoiAlign::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t TRTMultiLevelRoiAlign::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) + + serialized_size(mAlignedWidth) + serialized_size(mPoolMode) + serialized_size(mSampleNum) + + serialized_size(mRoiScaleFactor) + serialized_size(mFinestScale) + + serialized_size(mAligned); + } + + void TRTMultiLevelRoiAlign::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mAlignedHeight); + serialize_value(&buffer, mAlignedWidth); + serialize_value(&buffer, mPoolMode); + serialize_value(&buffer, mSampleNum); + serialize_value(&buffer, mRoiScaleFactor); + serialize_value(&buffer, mFinestScale); + serialize_value(&buffer, mAligned); + serialize_value(&buffer, mFeatmapStrides); + } + + TRTMultiLevelRoiAlignCreator::TRTMultiLevelRoiAlignCreator() + { + mPluginAttributes = std::vector( + {nvinfer1::PluginField("output_height"), nvinfer1::PluginField("output_width"), nvinfer1::PluginField("pool_mode"), nvinfer1::PluginField("sampling_ratio"), nvinfer1::PluginField("featmap_strides"), nvinfer1::PluginField("roi_scale_factor"), nvinfer1::PluginField("finest_scale"), nvinfer1::PluginField("aligned")}); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTMultiLevelRoiAlignCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTMultiLevelRoiAlignCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* TRTMultiLevelRoiAlignCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + int alignedHeight = 7; + int alignedWidth = 7; + int poolMode = 0; + int sampleNum = 2; + std::vector featmapStrides; + float roiScaleFactor = -1; + int finestScale = 56; + bool aligned = false; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("output_height") == 0) + { + alignedHeight = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("output_width") == 0) + { + alignedWidth = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("pool_mode") == 0) + { + poolMode = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("sampling_ratio") == 0) + { + sampleNum = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("roi_scale_factor") == 0) + { + roiScaleFactor = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("finest_scale") == 0) + { + finestScale = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("featmap_strides") == 0) + { + int data_size = (fc->fields[i].length); + const float* data_start = static_cast(fc->fields[i].data); + featmapStrides = std::vector(data_start, data_start + data_size); + 
}
+            else if (field_name.compare("aligned") == 0)
+            {
+                int aligned_int = static_cast<const int*>(fc->fields[i].data)[0];
+                aligned = aligned_int != 0;
+            }
+        }
+
+        ASSERT(featmapStrides.size() != 0);
+
+        TRTMultiLevelRoiAlign* plugin =
+            new TRTMultiLevelRoiAlign(name, alignedHeight, alignedWidth, poolMode, sampleNum, featmapStrides, roiScaleFactor, finestScale, aligned);
+        plugin->setPluginNamespace(getPluginNamespace());
+        return plugin;
+    }
+
+    nvinfer1::IPluginV2* TRTMultiLevelRoiAlignCreator::deserializePlugin(
+        const char* name,
+        const void* serialData,
+        size_t serialLength) TRT_NOEXCEPT
+    {
+        auto plugin = new TRTMultiLevelRoiAlign(name, serialData, serialLength);
+        plugin->setPluginNamespace(getPluginNamespace());
+        return plugin;
+    }
+
+    REGISTER_TENSORRT_PLUGIN(TRTMultiLevelRoiAlignCreator);
 }  // namespace mmdeploy
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp
index a9a06236e0..814118d29b 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align.hpp
@@ -10,69 +10,65 @@
 #include "trt_plugin_base.hpp"
-namespace mmdeploy {
-class TRTMultiLevelRoiAlign : public TRTPluginBase {
- public:
-  TRTMultiLevelRoiAlign(const std::string &name, int alignedHeight, int alignedWidth, int poolMode,
-                        int sampleNum, const std::vector<float> &featmapStrides,
-                        float roiScaleFactor = -1, int finestScale = 56, bool aligned = false);
+namespace mmdeploy
+{
+    class TRTMultiLevelRoiAlign : public TRTPluginBase
+    {
+      public:
+        TRTMultiLevelRoiAlign(const std::string& name, int alignedHeight, int alignedWidth, int poolMode, int sampleNum, const std::vector<float>& featmapStrides, float roiScaleFactor = -1, int finestScale = 56, bool aligned = false);
 
-  TRTMultiLevelRoiAlign(const std::string name, const void *data, size_t length);
+        TRTMultiLevelRoiAlign(const std::string name, const void* data, size_t length);
 
-  TRTMultiLevelRoiAlign() = delete;
+        TRTMultiLevelRoiAlign() = delete;
 
-  // IPluginV2DynamicExt Methods
-  nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override;
-  nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs,
-                                          int nbInputs, nvinfer1::IExprBuilder &exprBuilder)
-      TRT_NOEXCEPT override;
-  bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs,
-                                 int nbOutputs) TRT_NOEXCEPT override;
-  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs,
-                       const nvinfer1::DynamicPluginTensorDesc *out,
-                       int nbOutputs) TRT_NOEXCEPT override;
-  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
-                          const nvinfer1::PluginTensorDesc *outputs,
-                          int nbOutputs) const TRT_NOEXCEPT override;
-  int enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
-              const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs,
-              void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override;
+        // IPluginV2DynamicExt Methods
+        nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override;
+        nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder)
+            TRT_NOEXCEPT override;
+        bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override;
+        void 
configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; - private: - int mAlignedHeight; - int mAlignedWidth; - int mPoolMode; - int mSampleNum; - std::vector mFeatmapStrides; - float mRoiScaleFactor; - int mFinestScale; - bool mAligned; -}; + private: + int mAlignedHeight; + int mAlignedWidth; + int mPoolMode; + int mSampleNum; + std::vector mFeatmapStrides; + float mRoiScaleFactor; + int mFinestScale; + bool mAligned; + }; -class TRTMultiLevelRoiAlignCreator : public TRTPluginCreatorBase { - public: - TRTMultiLevelRoiAlignCreator(); + class TRTMultiLevelRoiAlignCreator : public TRTPluginCreatorBase + { + public: + TRTMultiLevelRoiAlignCreator(); - const char *getPluginName() const TRT_NOEXCEPT override; + const char* getPluginName() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_ROI_ALIGN_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu index 9eefbe3f32..1663088e30 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.cu @@ -10,167 +10,234 @@ #include "trt_plugin_helper.hpp" const int kMAX_FEATMAP_SIZE = 10; -struct FeatData { - const void *data[kMAX_FEATMAP_SIZE]; - int batch_size; - int channels; - int h[kMAX_FEATMAP_SIZE]; - int 
w[kMAX_FEATMAP_SIZE]; - float spatial_scale[kMAX_FEATMAP_SIZE]; - int num_featmap; +struct FeatData +{ + const void* data[kMAX_FEATMAP_SIZE]; + int batch_size; + int channels; + int h[kMAX_FEATMAP_SIZE]; + int w[kMAX_FEATMAP_SIZE]; + float spatial_scale[kMAX_FEATMAP_SIZE]; + int num_featmap; }; -template -__device__ scalar_t roi_align_single(const scalar_t *__restrict__ bottom_data, - const int roi_batch_ind, const scalar_t roi_start_w, - const scalar_t roi_start_h, const scalar_t roi_end_w, - const scalar_t roi_end_h, const scalar_t spatial_scale, - const int pw, const int ph, const int c, const int sample_num, - const int channels, const int height, const int width, - const int pooled_height, const int pooled_width) { - // Force malformed ROIs to be 1x1 - scalar_t roi_width = max(roi_end_w - roi_start_w, (scalar_t)(aligned ? 0. : 1.)); - scalar_t roi_height = max(roi_end_h - roi_start_h, (scalar_t)(aligned ? 0. : 1.)); - - const scalar_t bin_size_h = roi_height / pooled_height; - const scalar_t bin_size_w = roi_width / pooled_width; - - const scalar_t *offset_bottom_data = - bottom_data + (roi_batch_ind * channels + c) * height * width; - - const int sample_num_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height); - const int sample_num_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); - - scalar_t output_val = (pool_mode == 0) ? -FLT_MAX : 0; - const scalar_t y_offset = roi_start_h + ph * bin_size_h; - const scalar_t y_scale = bin_size_h / (scalar_t)(sample_num_h); - const scalar_t x_offset = roi_start_w + pw * bin_size_w; - const scalar_t x_scale = bin_size_w / (scalar_t)(sample_num_w); - for (int iy = 0; iy < sample_num_h; iy++) { - const scalar_t y = fma(scalar_t(iy) + scalar_t(.5f), y_scale, y_offset); - for (int ix = 0; ix < sample_num_w; ix++) { - const scalar_t x = fma(scalar_t(ix) + scalar_t(.5f), x_scale, x_offset); - scalar_t val = bilinear_interpolate(offset_bottom_data, height, width, y, x); - if (pool_mode == 0) { - output_val = max(output_val, val); - } else { - output_val += val; - } +template +__device__ scalar_t roi_align_single(const scalar_t* __restrict__ bottom_data, + const int roi_batch_ind, + const scalar_t roi_start_w, + const scalar_t roi_start_h, + const scalar_t roi_end_w, + const scalar_t roi_end_h, + const scalar_t spatial_scale, + const int pw, + const int ph, + const int c, + const int sample_num, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width) +{ + // Force malformed ROIs to be 1x1 + scalar_t roi_width = max(roi_end_w - roi_start_w, (scalar_t)(aligned ? 0. : 1.)); + scalar_t roi_height = max(roi_end_h - roi_start_h, (scalar_t)(aligned ? 0. : 1.)); + + const scalar_t bin_size_h = roi_height / pooled_height; + const scalar_t bin_size_w = roi_width / pooled_width; + + const scalar_t* offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + + const int sample_num_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height); + const int sample_num_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); + + scalar_t output_val = (pool_mode == 0) ? 
-FLT_MAX : 0; + const scalar_t y_offset = roi_start_h + ph * bin_size_h; + const scalar_t y_scale = bin_size_h / (scalar_t)(sample_num_h); + const scalar_t x_offset = roi_start_w + pw * bin_size_w; + const scalar_t x_scale = bin_size_w / (scalar_t)(sample_num_w); + for (int iy = 0; iy < sample_num_h; iy++) + { + const scalar_t y = fma(scalar_t(iy) + scalar_t(.5f), y_scale, y_offset); + for (int ix = 0; ix < sample_num_w; ix++) + { + const scalar_t x = fma(scalar_t(ix) + scalar_t(.5f), x_scale, x_offset); + scalar_t val = bilinear_interpolate(offset_bottom_data, height, width, y, x); + if (pool_mode == 0) + { + output_val = max(output_val, val); + } + else + { + output_val += val; + } + } + } + if (pool_mode != 0) + { + output_val /= max(sample_num_h * sample_num_w, 1); } - } - if (pool_mode != 0) { - output_val /= max(sample_num_h * sample_num_w, 1); - } - return output_val; + return output_val; } -template -__global__ void roi_extractor_kernel(scalar_t *__restrict__ output, - const scalar_t *__restrict__ bottom_rois, FeatData feat_data, - const int pool_mode, const int sample_num, - const float roi_scale_factor, const int finest_scale, - const int pooled_height, const int pooled_width, - int nThreads) { - CUDA_1D_KERNEL_LOOP(index, nThreads) { - const int channels = feat_data.channels; - int tmp_index = index; - const int pw = tmp_index % pooled_width; - tmp_index /= pooled_width; - const int ph = tmp_index % pooled_height; - tmp_index /= pooled_height; - const int c = tmp_index % channels; - const int n = tmp_index / channels; - - const scalar_t *offset_bottom_rois = bottom_rois + n * 5; - - scalar_t roi_offset_x0 = offset_bottom_rois[1]; - scalar_t roi_offset_y0 = offset_bottom_rois[2]; - scalar_t roi_offset_x1 = offset_bottom_rois[3]; - scalar_t roi_offset_y1 = offset_bottom_rois[4]; - - const scalar_t scale = sqrtf((roi_offset_y1 - roi_offset_y0) * (roi_offset_x1 - roi_offset_x0)); - - const int target_lvls = - min(feat_data.num_featmap - 1, - max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6))))); - - if (roi_scale_factor > 0.) { - const scalar_t roi_off_cx = (roi_offset_x0 + roi_offset_x1) * 0.5; - const scalar_t roi_off_cy = (roi_offset_y0 + roi_offset_y1) * 0.5; - const scalar_t half_scale_factor = roi_scale_factor * 0.5; - const scalar_t half_roi_off_w = - fma(roi_offset_x1 - roi_offset_x0 + 1, half_scale_factor, scalar_t(-0.5)); - const scalar_t half_roi_off_h = - fma(roi_offset_y1 - roi_offset_y0 + 1, half_scale_factor, scalar_t(-0.5)); - - roi_offset_x0 = roi_off_cx - half_roi_off_w; - roi_offset_x1 = roi_off_cx + half_roi_off_w; - roi_offset_y0 = roi_off_cy - half_roi_off_h; - roi_offset_y1 = roi_off_cy + half_roi_off_h; - } - - const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls]; - const int height = feat_data.h[target_lvls]; - const int width = feat_data.w[target_lvls]; - const scalar_t *bottom_data = (scalar_t *)feat_data.data[target_lvls]; - - const int roi_batch_ind = offset_bottom_rois[0]; - const scalar_t offset = aligned ? 
(scalar_t)-0.5 : (scalar_t)0.0; - const scalar_t roi_start_w = - fma(roi_offset_x0, spatial_scale, offset); // roi_offset_x0 * spatial_scale + offset; - const scalar_t roi_start_h = - fma(roi_offset_y0, spatial_scale, offset); // roi_offset_y0 * spatial_scale + offset; - const scalar_t roi_end_w = - fma(roi_offset_x1, spatial_scale, offset); // (roi_offset_x1) * spatial_scale - offset; - const scalar_t roi_end_h = - fma(roi_offset_y1, spatial_scale, offset); // (roi_offset_y1)*spatial_scale - offset; - - if (pool_mode == 0) { - const scalar_t output_val = roi_align_single( - bottom_data, roi_batch_ind, roi_start_w, roi_start_h, roi_end_w, roi_end_h, spatial_scale, - pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width); - output[index] = output_val; - } else { - const scalar_t output_val = roi_align_single( - bottom_data, roi_batch_ind, roi_start_w, roi_start_h, roi_end_w, roi_end_h, spatial_scale, - pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width); - output[index] = output_val; +template +__global__ void roi_extractor_kernel(scalar_t* __restrict__ output, + const scalar_t* __restrict__ bottom_rois, + FeatData feat_data, + const int pool_mode, + const int sample_num, + const float roi_scale_factor, + const int finest_scale, + const int pooled_height, + const int pooled_width, + int nThreads) +{ + CUDA_1D_KERNEL_LOOP(index, nThreads) + { + const int channels = feat_data.channels; + int tmp_index = index; + const int pw = tmp_index % pooled_width; + tmp_index /= pooled_width; + const int ph = tmp_index % pooled_height; + tmp_index /= pooled_height; + const int c = tmp_index % channels; + const int n = tmp_index / channels; + + const scalar_t* offset_bottom_rois = bottom_rois + n * 5; + + scalar_t roi_offset_x0 = offset_bottom_rois[1]; + scalar_t roi_offset_y0 = offset_bottom_rois[2]; + scalar_t roi_offset_x1 = offset_bottom_rois[3]; + scalar_t roi_offset_y1 = offset_bottom_rois[4]; + + const scalar_t scale = sqrtf((roi_offset_y1 - roi_offset_y0) * (roi_offset_x1 - roi_offset_x0)); + + const int target_lvls = + min(feat_data.num_featmap - 1, + max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6))))); + + if (roi_scale_factor > 0.) + { + const scalar_t roi_off_cx = (roi_offset_x0 + roi_offset_x1) * 0.5; + const scalar_t roi_off_cy = (roi_offset_y0 + roi_offset_y1) * 0.5; + const scalar_t half_scale_factor = roi_scale_factor * 0.5; + const scalar_t half_roi_off_w = + fma(roi_offset_x1 - roi_offset_x0 + 1, half_scale_factor, scalar_t(-0.5)); + const scalar_t half_roi_off_h = + fma(roi_offset_y1 - roi_offset_y0 + 1, half_scale_factor, scalar_t(-0.5)); + + roi_offset_x0 = roi_off_cx - half_roi_off_w; + roi_offset_x1 = roi_off_cx + half_roi_off_w; + roi_offset_y0 = roi_off_cy - half_roi_off_h; + roi_offset_y1 = roi_off_cy + half_roi_off_h; + } + + const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls]; + const int height = feat_data.h[target_lvls]; + const int width = feat_data.w[target_lvls]; + const scalar_t* bottom_data = (scalar_t*)feat_data.data[target_lvls]; + + const int roi_batch_ind = offset_bottom_rois[0]; + const scalar_t offset = aligned ? 
(scalar_t)-0.5 : (scalar_t)0.0;
+        const scalar_t roi_start_w =
+            fma(roi_offset_x0, spatial_scale, offset);  // roi_offset_x0 * spatial_scale + offset;
+        const scalar_t roi_start_h =
+            fma(roi_offset_y0, spatial_scale, offset);  // roi_offset_y0 * spatial_scale + offset;
+        const scalar_t roi_end_w =
+            fma(roi_offset_x1, spatial_scale, offset);  // (roi_offset_x1) * spatial_scale - offset;
+        const scalar_t roi_end_h =
+            fma(roi_offset_y1, spatial_scale, offset);  // (roi_offset_y1)*spatial_scale - offset;
+
+        if (pool_mode == 0)
+        {
+            const scalar_t output_val = roi_align_single<scalar_t, 0, aligned>(
+                bottom_data,
+                roi_batch_ind,
+                roi_start_w,
+                roi_start_h,
+                roi_end_w,
+                roi_end_h,
+                spatial_scale,
+                pw,
+                ph,
+                c,
+                sample_num,
+                channels,
+                height,
+                width,
+                pooled_height,
+                pooled_width);
+            output[index] = output_val;
+        }
+        else
+        {
+            const scalar_t output_val = roi_align_single<scalar_t, 1, aligned>(
+                bottom_data,
+                roi_batch_ind,
+                roi_start_w,
+                roi_start_h,
+                roi_end_w,
+                roi_end_h,
+                spatial_scale,
+                pw,
+                ph,
+                c,
+                sample_num,
+                channels,
+                height,
+                width,
+                pooled_height,
+                pooled_width);
+            output[index] = output_val;
+        }
+    }
 }

-template <typename T>
-void multi_level_roi_align(T *output, const T *rois, int num_rois, const void *const *feats,
-                           int num_feats, int n, int c, int *h, int *w, float *strides,
-                           int aligned_height, int aligned_width, int pool_mode, int sample_num,
-                           float roi_scale_factor, int finest_scale, bool aligned,
-                           cudaStream_t stream) {
-  FeatData feat_data;
-  feat_data.batch_size = n;
-  feat_data.channels = c;
-  feat_data.num_featmap = num_feats;
-  for (int i = 0; i < num_feats; ++i) {
-    feat_data.data[i] = feats[i];
-    feat_data.h[i] = h[i];
-    feat_data.w[i] = w[i];
-    feat_data.spatial_scale[i] = 1. / float(strides[i]);
-  }
-  int nThreads = num_rois * c * aligned_height * aligned_width;
-  if (aligned) {
-    roi_extractor_kernel<T, true><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
-        output, rois, feat_data, pool_mode, sample_num, roi_scale_factor, finest_scale,
-        aligned_height, aligned_width, nThreads);
-  } else {
-    roi_extractor_kernel<T, false><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
-        output, rois, feat_data, pool_mode, sample_num, roi_scale_factor, finest_scale,
-        aligned_height, aligned_width, nThreads);
-  }
+template <typename T>
+void multi_level_roi_align(T* output, const T* rois, int num_rois, const void* const* feats, int num_feats, int n, int c, int* h, int* w, float* strides, int aligned_height, int aligned_width, int pool_mode, int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream)
+{
+    FeatData feat_data;
+    feat_data.batch_size = n;
+    feat_data.channels = c;
+    feat_data.num_featmap = num_feats;
+    for (int i = 0; i < num_feats; ++i)
+    {
+        feat_data.data[i] = feats[i];
+        feat_data.h[i] = h[i];
+        feat_data.w[i] = w[i];
+        feat_data.spatial_scale[i] = 1. / float(strides[i]);
+    }
+    int nThreads = num_rois * c * aligned_height * aligned_width;
+    if (aligned)
+    {
+        roi_extractor_kernel<T, true><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
+            output,
+            rois,
+            feat_data,
+            pool_mode,
+            sample_num,
+            roi_scale_factor,
+            finest_scale,
+            aligned_height,
+            aligned_width,
+            nThreads);
+    }
+    else
+    {
+        roi_extractor_kernel<T, false><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
+            output,
+            rois,
+            feat_data,
+            pool_mode,
+            sample_num,
+            roi_scale_factor,
+            finest_scale,
+            aligned_height,
+            aligned_width,
+            nThreads);
+    }
 }

-template void multi_level_roi_align<float>(float *output, const float *rois, int num_rois,
-                                           const void *const *feats, int num_feats, int n, int c,
-                                           int *h, int *w, float *strides, int aligned_height,
-                                           int aligned_width, int pool_mode, int sample_num,
-                                           float roi_scale_factor, int finest_scale, bool aligned,
-                                           cudaStream_t stream);
+template void multi_level_roi_align<float>(float* output, const float* rois, int num_rois, const void* const* feats, int num_feats, int n, int c, int* h, int* w, float* strides, int aligned_height, int aligned_width, int pool_mode, int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream);
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp
index 5f7220dbf0..efd5564a27 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_roi_align/trt_multi_level_roi_align_kernel.hpp
@@ -3,11 +3,7 @@
 #define TRT_MULTI_LEVEL_ROI_ALIGN_KERNEL_HPP
 #include <cuda_runtime.h>
-template <typename T>
-void multi_level_roi_align(T *output, const T *rois, int num_rois, const void *const *feats,
-                           int num_feats, int n, int c, int *h, int *w, float *strides,
-                           int aligned_height, int aligned_width, int pool_mode, int sample_num,
-                           float roi_scale_factor, int finest_scale, bool aligned,
-                           cudaStream_t stream);
+template <typename T>
+void multi_level_roi_align(T* output, const T* rois, int num_rois, const void* const* feats, int num_feats, int n, int c, int* h, int* w, float* strides, int aligned_height, int aligned_width, int pool_mode, int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream);
 #endif  // TRT_MULTI_LEVEL_ROI_ALIGN_KERNEL_HPP
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp
index 6637603128..492a171efd 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.cpp
@@ -9,220 +9,282 @@
 #include "trt_multi_level_rotated_roi_align_kernel.hpp"
 #include "trt_plugin_helper.hpp"
 #include "trt_serialize.hpp"
-namespace mmdeploy {
-namespace {
-static const char *PLUGIN_VERSION{"1"};
-static const char *PLUGIN_NAME{"MMCVMultiLevelRotatedRoiAlign"};
-}  // namespace
-
-TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign(
-    const std::string &name, int alignedHeight, int alignedWidth, int clockwise, int sampleNum,
-    const std::vector<float> &featmapStrides, float roiScaleFactor, int finestScale, bool aligned)
-    : TRTPluginBase(name),
-      mAlignedHeight(alignedHeight),
-      mAlignedWidth(alignedWidth),
-      mClockwise(clockwise),
-      mSampleNum(sampleNum),
-      mFeatmapStrides(featmapStrides),
-      mRoiScaleFactor(roiScaleFactor),
-
mFinestScale(finestScale), - mAligned(aligned) {} - -TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign(const std::string name, const void *data, - size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mAlignedHeight); - deserialize_value(&data, &length, &mAlignedWidth); - deserialize_value(&data, &length, &mClockwise); - deserialize_value(&data, &length, &mSampleNum); - deserialize_value(&data, &length, &mRoiScaleFactor); - deserialize_value(&data, &length, &mFinestScale); - deserialize_value(&data, &length, &mAligned); - deserialize_value(&data, &length, &mFeatmapStrides); -} - -nvinfer1::IPluginV2DynamicExt *TRTMultiLevelRotatedRoiAlign::clone() const TRT_NOEXCEPT { - TRTMultiLevelRotatedRoiAlign *plugin = new TRTMultiLevelRotatedRoiAlign( - mLayerName, mAlignedHeight, mAlignedWidth, mClockwise, mSampleNum, mFeatmapStrides, - mRoiScaleFactor, mFinestScale, mAligned); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs TRTMultiLevelRotatedRoiAlign::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - // warning, nbInputs should equal to mFeatmapStrides.size() + 1 - nvinfer1::DimsExprs ret; - ret.nbDims = 4; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[1].d[1]; - ret.d[2] = exprBuilder.constant(mAlignedHeight); - ret.d[3] = exprBuilder.constant(mAlignedWidth); - - return ret; -} - -bool TRTMultiLevelRotatedRoiAlign::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { - return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -} - -void TRTMultiLevelRotatedRoiAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT { - // Validate input arguments - ASSERT(nbOutputs == 1); - ASSERT(nbInputs >= 1); - mFeatmapStrides = - std::vector(mFeatmapStrides.begin(), mFeatmapStrides.begin() + nbInputs - 1); -} - -size_t TRTMultiLevelRotatedRoiAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, - int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int TRTMultiLevelRotatedRoiAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, - void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - int num_rois = inputDesc[0].dims.d[0]; - int batch_size = inputDesc[1].dims.d[0]; - int channels = inputDesc[1].dims.d[1]; - - const int kMaxFeatMap = 10; - int heights[kMaxFeatMap]; - int widths[kMaxFeatMap]; - float strides[kMaxFeatMap]; - - int num_feats = mFeatmapStrides.size(); - for (int i = 0; i < num_feats; ++i) { - heights[i] = inputDesc[i + 1].dims.d[2]; - widths[i] = inputDesc[i + 1].dims.d[3]; - strides[i] = mFeatmapStrides[i]; - } - - const void *rois = inputs[0]; - const void *const *feats = inputs + 1; - - multi_level_rotated_roi_align((float *)outputs[0], (const float *)rois, num_rois, feats, - num_feats, batch_size, channels, &heights[0], &widths[0], - &strides[0], mAlignedHeight, mAlignedWidth, mClockwise, - mSampleNum, mRoiScaleFactor, mFinestScale, mAligned, stream); - - return 0; -} - -nvinfer1::DataType TRTMultiLevelRotatedRoiAlign::getOutputDataType( - int index, const nvinfer1::DataType *inputTypes, int nbInputs) 
const TRT_NOEXCEPT { - return nvinfer1::DataType::kFLOAT; -} - -// IPluginV2 Methods -const char *TRTMultiLevelRotatedRoiAlign::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTMultiLevelRotatedRoiAlign::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -int TRTMultiLevelRotatedRoiAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t TRTMultiLevelRotatedRoiAlign::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) + - serialized_size(mAlignedWidth) + serialized_size(mClockwise) + - serialized_size(mSampleNum) + serialized_size(mRoiScaleFactor) + - serialized_size(mFinestScale) + serialized_size(mAligned); -} - -void TRTMultiLevelRotatedRoiAlign::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mAlignedHeight); - serialize_value(&buffer, mAlignedWidth); - serialize_value(&buffer, mClockwise); - serialize_value(&buffer, mSampleNum); - serialize_value(&buffer, mRoiScaleFactor); - serialize_value(&buffer, mFinestScale); - serialize_value(&buffer, mAligned); - serialize_value(&buffer, mFeatmapStrides); -} - -TRTMultiLevelRotatedRoiAlignCreator::TRTMultiLevelRotatedRoiAlignCreator() { - mPluginAttributes = std::vector( - {nvinfer1::PluginField("output_height"), nvinfer1::PluginField("output_width"), - nvinfer1::PluginField("clockwise"), nvinfer1::PluginField("sampling_ratio"), - nvinfer1::PluginField("featmap_strides"), nvinfer1::PluginField("roi_scale_factor"), - nvinfer1::PluginField("finest_scale"), nvinfer1::PluginField("aligned")}); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *TRTMultiLevelRotatedRoiAlignCreator::getPluginName() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *TRTMultiLevelRotatedRoiAlignCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *TRTMultiLevelRotatedRoiAlignCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - int alignedHeight = 7; - int alignedWidth = 7; - int clockwise = 0; - int sampleNum = 2; - std::vector featmapStrides; - float roiScaleFactor = -1; - int finestScale = 56; - bool aligned = false; - - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; - } - std::string field_name(fc->fields[i].name); - - if (field_name.compare("output_height") == 0) { - alignedHeight = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("output_width") == 0) { - alignedWidth = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("clockwise") == 0) { - clockwise = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("sampling_ratio") == 0) { - sampleNum = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("roi_scale_factor") == 0) { - roiScaleFactor = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("finest_scale") == 0) { - finestScale = static_cast(fc->fields[i].data)[0]; - } else if (field_name.compare("featmap_strides") == 0) { - int data_size = (fc->fields[i].length); - const float *data_start = static_cast(fc->fields[i].data); - featmapStrides = std::vector(data_start, data_start + data_size); - } else if (field_name.compare("aligned") == 0) { - int aligned_int = static_cast(fc->fields[i].data)[0]; - aligned = aligned_int != 0; - } - } - - ASSERT(featmapStrides.size() != 0); - - TRTMultiLevelRotatedRoiAlign *plugin = - new 
TRTMultiLevelRotatedRoiAlign(name, alignedHeight, alignedWidth, clockwise, sampleNum, - featmapStrides, roiScaleFactor, finestScale, aligned); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *TRTMultiLevelRotatedRoiAlignCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new TRTMultiLevelRotatedRoiAlign(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -REGISTER_TENSORRT_PLUGIN(TRTMultiLevelRotatedRoiAlignCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"MMCVMultiLevelRotatedRoiAlign"}; + } // namespace + + TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign( + const std::string& name, + int alignedHeight, + int alignedWidth, + int clockwise, + int sampleNum, + const std::vector& featmapStrides, + float roiScaleFactor, + int finestScale, + bool aligned) + : TRTPluginBase(name) + , mAlignedHeight(alignedHeight) + , mAlignedWidth(alignedWidth) + , mClockwise(clockwise) + , mSampleNum(sampleNum) + , mFeatmapStrides(featmapStrides) + , mRoiScaleFactor(roiScaleFactor) + , mFinestScale(finestScale) + , mAligned(aligned) + { + } + + TRTMultiLevelRotatedRoiAlign::TRTMultiLevelRotatedRoiAlign(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mAlignedHeight); + deserialize_value(&data, &length, &mAlignedWidth); + deserialize_value(&data, &length, &mClockwise); + deserialize_value(&data, &length, &mSampleNum); + deserialize_value(&data, &length, &mRoiScaleFactor); + deserialize_value(&data, &length, &mFinestScale); + deserialize_value(&data, &length, &mAligned); + deserialize_value(&data, &length, &mFeatmapStrides); + } + + nvinfer1::IPluginV2DynamicExt* TRTMultiLevelRotatedRoiAlign::clone() const TRT_NOEXCEPT + { + TRTMultiLevelRotatedRoiAlign* plugin = new TRTMultiLevelRotatedRoiAlign( + mLayerName, + mAlignedHeight, + mAlignedWidth, + mClockwise, + mSampleNum, + mFeatmapStrides, + mRoiScaleFactor, + mFinestScale, + mAligned); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs TRTMultiLevelRotatedRoiAlign::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + // warning, nbInputs should equal to mFeatmapStrides.size() + 1 + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[1].d[1]; + ret.d[2] = exprBuilder.constant(mAlignedHeight); + ret.d[3] = exprBuilder.constant(mAlignedWidth); + + return ret; + } + + bool TRTMultiLevelRotatedRoiAlign::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + } + + void TRTMultiLevelRotatedRoiAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, + int nbInputs, + const nvinfer1::DynamicPluginTensorDesc* outputs, + int nbOutputs) TRT_NOEXCEPT + { + // Validate input arguments + ASSERT(nbOutputs == 1); + ASSERT(nbInputs >= 1); + mFeatmapStrides = + std::vector(mFeatmapStrides.begin(), mFeatmapStrides.begin() + nbInputs - 1); + } + + size_t TRTMultiLevelRotatedRoiAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, + int nbInputs, + const 
nvinfer1::PluginTensorDesc* outputs, + int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int TRTMultiLevelRotatedRoiAlign::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int num_rois = inputDesc[0].dims.d[0]; + int batch_size = inputDesc[1].dims.d[0]; + int channels = inputDesc[1].dims.d[1]; + + const int kMaxFeatMap = 10; + int heights[kMaxFeatMap]; + int widths[kMaxFeatMap]; + float strides[kMaxFeatMap]; + + int num_feats = mFeatmapStrides.size(); + for (int i = 0; i < num_feats; ++i) + { + heights[i] = inputDesc[i + 1].dims.d[2]; + widths[i] = inputDesc[i + 1].dims.d[3]; + strides[i] = mFeatmapStrides[i]; + } + + const void* rois = inputs[0]; + const void* const* feats = inputs + 1; + + multi_level_rotated_roi_align((float*)outputs[0], (const float*)rois, num_rois, feats, num_feats, batch_size, channels, &heights[0], &widths[0], &strides[0], mAlignedHeight, mAlignedWidth, mClockwise, mSampleNum, mRoiScaleFactor, mFinestScale, mAligned, stream); + + return 0; + } + + nvinfer1::DataType TRTMultiLevelRotatedRoiAlign::getOutputDataType( + int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return nvinfer1::DataType::kFLOAT; + } + + // IPluginV2 Methods + const char* TRTMultiLevelRotatedRoiAlign::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTMultiLevelRotatedRoiAlign::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int TRTMultiLevelRotatedRoiAlign::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t TRTMultiLevelRotatedRoiAlign::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mFeatmapStrides) + serialized_size(mAlignedHeight) + + serialized_size(mAlignedWidth) + serialized_size(mClockwise) + + serialized_size(mSampleNum) + serialized_size(mRoiScaleFactor) + + serialized_size(mFinestScale) + serialized_size(mAligned); + } + + void TRTMultiLevelRotatedRoiAlign::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mAlignedHeight); + serialize_value(&buffer, mAlignedWidth); + serialize_value(&buffer, mClockwise); + serialize_value(&buffer, mSampleNum); + serialize_value(&buffer, mRoiScaleFactor); + serialize_value(&buffer, mFinestScale); + serialize_value(&buffer, mAligned); + serialize_value(&buffer, mFeatmapStrides); + } + + TRTMultiLevelRotatedRoiAlignCreator::TRTMultiLevelRotatedRoiAlignCreator() + { + mPluginAttributes = std::vector( + {nvinfer1::PluginField("output_height"), nvinfer1::PluginField("output_width"), nvinfer1::PluginField("clockwise"), nvinfer1::PluginField("sampling_ratio"), nvinfer1::PluginField("featmap_strides"), nvinfer1::PluginField("roi_scale_factor"), nvinfer1::PluginField("finest_scale"), nvinfer1::PluginField("aligned")}); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTMultiLevelRotatedRoiAlignCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTMultiLevelRotatedRoiAlignCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* TRTMultiLevelRotatedRoiAlignCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + int alignedHeight = 7; + int alignedWidth = 7; + int clockwise = 0; + int sampleNum = 2; + std::vector featmapStrides; + float 
roiScaleFactor = -1; + int finestScale = 56; + bool aligned = false; + + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("output_height") == 0) + { + alignedHeight = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("output_width") == 0) + { + alignedWidth = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("clockwise") == 0) + { + clockwise = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("sampling_ratio") == 0) + { + sampleNum = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("roi_scale_factor") == 0) + { + roiScaleFactor = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("finest_scale") == 0) + { + finestScale = static_cast(fc->fields[i].data)[0]; + } + else if (field_name.compare("featmap_strides") == 0) + { + int data_size = (fc->fields[i].length); + const float* data_start = static_cast(fc->fields[i].data); + featmapStrides = std::vector(data_start, data_start + data_size); + } + else if (field_name.compare("aligned") == 0) + { + int aligned_int = static_cast(fc->fields[i].data)[0]; + aligned = aligned_int != 0; + } + } + + ASSERT(featmapStrides.size() != 0); + + TRTMultiLevelRotatedRoiAlign* plugin = + new TRTMultiLevelRotatedRoiAlign(name, alignedHeight, alignedWidth, clockwise, sampleNum, featmapStrides, roiScaleFactor, finestScale, aligned); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* TRTMultiLevelRotatedRoiAlignCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new TRTMultiLevelRotatedRoiAlign(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + REGISTER_TENSORRT_PLUGIN(TRTMultiLevelRotatedRoiAlignCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp index cf0bab7584..570317ebde 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align.hpp @@ -10,70 +10,65 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class TRTMultiLevelRotatedRoiAlign : public TRTPluginBase { - public: - TRTMultiLevelRotatedRoiAlign(const std::string &name, int alignedHeight, int alignedWidth, - int clockwise, int sampleNum, - const std::vector &featmapStrides, float roiScaleFactor = -1, - int finestScale = 56, bool aligned = false); +namespace mmdeploy +{ + class TRTMultiLevelRotatedRoiAlign : public TRTPluginBase + { + public: + TRTMultiLevelRotatedRoiAlign(const std::string& name, int alignedHeight, int alignedWidth, int clockwise, int sampleNum, const std::vector& featmapStrides, float roiScaleFactor = -1, int finestScale = 56, bool aligned = false); - TRTMultiLevelRotatedRoiAlign(const std::string name, const void *data, size_t length); + TRTMultiLevelRotatedRoiAlign(const std::string name, const void* data, size_t length); - TRTMultiLevelRotatedRoiAlign() = delete; + TRTMultiLevelRotatedRoiAlign() = delete; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - 
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; - private: - int mAlignedHeight; - int mAlignedWidth; - int mClockwise; - int mSampleNum; - std::vector mFeatmapStrides; - float mRoiScaleFactor; - int mFinestScale; - bool mAligned; -}; + private: + int mAlignedHeight; + int mAlignedWidth; + int mClockwise; + int mSampleNum; + std::vector mFeatmapStrides; + float mRoiScaleFactor; + int mFinestScale; + bool mAligned; + }; -class TRTMultiLevelRotatedRoiAlignCreator : public TRTPluginCreatorBase { - public: - TRTMultiLevelRotatedRoiAlignCreator(); + class TRTMultiLevelRotatedRoiAlignCreator : public TRTPluginCreatorBase + { + public: + TRTMultiLevelRotatedRoiAlignCreator(); - const char *getPluginName() const TRT_NOEXCEPT override; + const char* getPluginName() const TRT_NOEXCEPT override; - const char 
*getPluginVersion() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu index 1c6f292bae..897ae69e8b 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.cu @@ -10,155 +10,223 @@ #include "trt_plugin_helper.hpp" const int kMAX_FEATMAP_SIZE = 10; -struct FeatData { - const void *data[kMAX_FEATMAP_SIZE]; - int batch_size; - int channels; - int h[kMAX_FEATMAP_SIZE]; - int w[kMAX_FEATMAP_SIZE]; - float spatial_scale[kMAX_FEATMAP_SIZE]; - int num_featmap; +struct FeatData +{ + const void* data[kMAX_FEATMAP_SIZE]; + int batch_size; + int channels; + int h[kMAX_FEATMAP_SIZE]; + int w[kMAX_FEATMAP_SIZE]; + float spatial_scale[kMAX_FEATMAP_SIZE]; + int num_featmap; }; -template -__device__ scalar_t roi_align_single(const scalar_t *__restrict__ bottom_data, - const int roi_batch_ind, scalar_t roi_center_w, - scalar_t roi_center_h, scalar_t roi_width, scalar_t roi_height, - scalar_t theta, const scalar_t spatial_scale, const int pw, - const int ph, const int c, const int sample_num, - const int channels, const int height, const int width, - const int pooled_height, const int pooled_width) { - // Force malformed ROIs to be 1x1 - - roi_width = max(roi_width, (scalar_t)1.); - roi_height = max(roi_height, (scalar_t)1.); - - const scalar_t bin_size_h = roi_height / scalar_t(pooled_height); - const scalar_t bin_size_w = roi_width / scalar_t(pooled_width); - - const scalar_t *offset_bottom_data = - bottom_data + (roi_batch_ind * channels + c) * height * width; - - const int roi_bin_grid_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height); - const int roi_bin_grid_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); - - const scalar_t roi_start_h = -roi_height / scalar_t(2.0); - const scalar_t roi_start_w = -roi_width / scalar_t(2.0); - const scalar_t cosscalar_theta = cos(theta); - const scalar_t sinscalar_theta = sin(theta); - - // We do average (integral) pooling inside a bin - const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. 
= 4 - - scalar_t output_val = 0.; - - for (int iy = 0; iy < roi_bin_grid_h; iy++) { // e.g., iy = 0, 1 - const scalar_t yy = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 - for (int ix = 0; ix < roi_bin_grid_w; ix++) { - const scalar_t xx = - roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); - - // Rotate by theta (counterclockwise) around the center and translate - scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h; - scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w; - - scalar_t val = bilinear_interpolate(offset_bottom_data, height, width, y, x); - output_val += val; +template +__device__ scalar_t roi_align_single(const scalar_t* __restrict__ bottom_data, + const int roi_batch_ind, + scalar_t roi_center_w, + scalar_t roi_center_h, + scalar_t roi_width, + scalar_t roi_height, + scalar_t theta, + const scalar_t spatial_scale, + const int pw, + const int ph, + const int c, + const int sample_num, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width) +{ + // Force malformed ROIs to be 1x1 + + roi_width = max(roi_width, (scalar_t)1.); + roi_height = max(roi_height, (scalar_t)1.); + + const scalar_t bin_size_h = roi_height / scalar_t(pooled_height); + const scalar_t bin_size_w = roi_width / scalar_t(pooled_width); + + const scalar_t* offset_bottom_data = + bottom_data + (roi_batch_ind * channels + c) * height * width; + + const int roi_bin_grid_h = (sample_num > 0) ? sample_num : ceil(roi_height / pooled_height); + const int roi_bin_grid_w = (sample_num > 0) ? sample_num : ceil(roi_width / pooled_width); + + const scalar_t roi_start_h = -roi_height / scalar_t(2.0); + const scalar_t roi_start_w = -roi_width / scalar_t(2.0); + const scalar_t cosscalar_theta = cos(theta); + const scalar_t sinscalar_theta = sin(theta); + + // We do average (integral) pooling inside a bin + const scalar_t count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. 
= 4 + + scalar_t output_val = 0.; + + for (int iy = 0; iy < roi_bin_grid_h; iy++) + { // e.g., iy = 0, 1 + const scalar_t yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) + { + const scalar_t xx = + roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); + + // Rotate by theta (counterclockwise) around the center and translate + scalar_t y = yy * cosscalar_theta - xx * sinscalar_theta + roi_center_h; + scalar_t x = yy * sinscalar_theta + xx * cosscalar_theta + roi_center_w; + + scalar_t val = bilinear_interpolate(offset_bottom_data, height, width, y, x); + output_val += val; + } } - } - return output_val / count; + return output_val / count; } -template -__global__ void rotated_roi_extractor_kernel(scalar_t *__restrict__ output, - const scalar_t *__restrict__ bottom_rois, - FeatData feat_data, const int clockwise, - const int sample_num, const float roi_scale_factor, - const int finest_scale, const int pooled_height, - const int pooled_width, int nThreads) { - CUDA_1D_KERNEL_LOOP(index, nThreads) { - const int channels = feat_data.channels; - int tmp_index = index; - const int pw = tmp_index % pooled_width; - tmp_index /= pooled_width; - const int ph = tmp_index % pooled_height; - tmp_index /= pooled_height; - const int c = tmp_index % channels; - const int n = tmp_index / channels; - - const scalar_t *offset_bottom_rois = bottom_rois + n * 6; - - scalar_t roi_offset_x0 = offset_bottom_rois[1]; - scalar_t roi_offset_y0 = offset_bottom_rois[2]; - scalar_t roi_offset_width = offset_bottom_rois[3]; - scalar_t roi_offset_height = offset_bottom_rois[4]; - scalar_t theta = offset_bottom_rois[5]; - - const scalar_t scale = sqrtf(roi_offset_width * roi_offset_height); - - const int target_lvls = - min(feat_data.num_featmap - 1, - max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6))))); - - if (roi_scale_factor > 0.) { - roi_offset_width = roi_offset_width * roi_scale_factor; - roi_offset_height = roi_offset_height * roi_scale_factor; +template +__global__ void rotated_roi_extractor_kernel(scalar_t* __restrict__ output, + const scalar_t* __restrict__ bottom_rois, + FeatData feat_data, + const int clockwise, + const int sample_num, + const float roi_scale_factor, + const int finest_scale, + const int pooled_height, + const int pooled_width, + int nThreads) +{ + CUDA_1D_KERNEL_LOOP(index, nThreads) + { + const int channels = feat_data.channels; + int tmp_index = index; + const int pw = tmp_index % pooled_width; + tmp_index /= pooled_width; + const int ph = tmp_index % pooled_height; + tmp_index /= pooled_height; + const int c = tmp_index % channels; + const int n = tmp_index / channels; + + const scalar_t* offset_bottom_rois = bottom_rois + n * 6; + + scalar_t roi_offset_x0 = offset_bottom_rois[1]; + scalar_t roi_offset_y0 = offset_bottom_rois[2]; + scalar_t roi_offset_width = offset_bottom_rois[3]; + scalar_t roi_offset_height = offset_bottom_rois[4]; + scalar_t theta = offset_bottom_rois[5]; + + const scalar_t scale = sqrtf(roi_offset_width * roi_offset_height); + + const int target_lvls = + min(feat_data.num_featmap - 1, + max(0, int(floorf(log2f(scale / (scalar_t)(finest_scale) + 1e-6))))); + + if (roi_scale_factor > 0.) 
+        {
+            roi_offset_width = roi_offset_width * roi_scale_factor;
+            roi_offset_height = roi_offset_height * roi_scale_factor;
+        }
+
+        const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls];
+        const int height = feat_data.h[target_lvls];
+        const int width = feat_data.w[target_lvls];
+        const scalar_t* bottom_data = (scalar_t*)feat_data.data[target_lvls];
+
+        const int roi_batch_ind = offset_bottom_rois[0];
+        const scalar_t offset = aligned ? (scalar_t)-0.5 : (scalar_t)0.0;
+        const scalar_t roi_center_w = fma(roi_offset_x0, spatial_scale, offset);
+        const scalar_t roi_center_h = fma(roi_offset_y0, spatial_scale, offset);
+        const scalar_t roi_width = roi_offset_width * spatial_scale;
+        const scalar_t roi_height = roi_offset_height * spatial_scale;
+
+        theta = clockwise > 0 ? -theta : theta;
+
+        const scalar_t output_val = roi_align_single<scalar_t>(
+            bottom_data,
+            roi_batch_ind,
+            roi_center_w,
+            roi_center_h,
+            roi_width,
+            roi_height,
+            theta,
+            spatial_scale,
+            pw,
+            ph,
+            c,
+            sample_num,
+            channels,
+            height,
+            width,
+            pooled_height,
+            pooled_width);
+        output[index] = output_val;
 }
-
-    const scalar_t spatial_scale = (scalar_t)feat_data.spatial_scale[target_lvls];
-    const int height = feat_data.h[target_lvls];
-    const int width = feat_data.w[target_lvls];
-    const scalar_t *bottom_data = (scalar_t *)feat_data.data[target_lvls];
-
-    const int roi_batch_ind = offset_bottom_rois[0];
-    const scalar_t offset = aligned ? (scalar_t)-0.5 : (scalar_t)0.0;
-    const scalar_t roi_center_w = fma(roi_offset_x0, spatial_scale, offset);
-    const scalar_t roi_center_h = fma(roi_offset_y0, spatial_scale, offset);
-    const scalar_t roi_width = roi_offset_width * spatial_scale;
-    const scalar_t roi_height = roi_offset_height * spatial_scale;
-
-    theta = clockwise > 0 ? -theta : theta;
-
-    const scalar_t output_val = roi_align_single<scalar_t>(
-        bottom_data, roi_batch_ind, roi_center_w, roi_center_h, roi_width, roi_height, theta,
-        spatial_scale, pw, ph, c, sample_num, channels, height, width, pooled_height, pooled_width);
-    output[index] = output_val;
-  }
 }
-template <typename T>
-void multi_level_rotated_roi_align(T *output, const T *rois, int num_rois, const void *const *feats,
-                                   int num_feats, int n, int c, int *h, int *w, float *strides,
-                                   int aligned_height, int aligned_width, int clockwise,
-                                   int sample_num, float roi_scale_factor, int finest_scale,
-                                   bool aligned, cudaStream_t stream) {
-  FeatData feat_data;
-  feat_data.batch_size = n;
-  feat_data.channels = c;
-  feat_data.num_featmap = num_feats;
-  for (int i = 0; i < num_feats; ++i) {
-    feat_data.data[i] = feats[i];
-    feat_data.h[i] = h[i];
-    feat_data.w[i] = w[i];
-    feat_data.spatial_scale[i] = 1. / float(strides[i]);
-  }
-  int nThreads = num_rois * c * aligned_height * aligned_width;
-  if (aligned) {
-    rotated_roi_extractor_kernel<T, true><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
-        output, rois, feat_data, clockwise, sample_num, roi_scale_factor, finest_scale,
-        aligned_height, aligned_width, nThreads);
-  } else {
-    rotated_roi_extractor_kernel<T, false><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
-        output, rois, feat_data, clockwise, sample_num, roi_scale_factor, finest_scale,
-        aligned_height, aligned_width, nThreads);
-  }
+template <typename T>
+void multi_level_rotated_roi_align(T* output, const T* rois, int num_rois, const void* const* feats, int num_feats, int n, int c, int* h, int* w, float* strides, int aligned_height, int aligned_width, int clockwise, int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream)
+{
+    FeatData feat_data;
+    feat_data.batch_size = n;
+    feat_data.channels = c;
+    feat_data.num_featmap = num_feats;
+    for (int i = 0; i < num_feats; ++i)
+    {
+        feat_data.data[i] = feats[i];
+        feat_data.h[i] = h[i];
+        feat_data.w[i] = w[i];
+        feat_data.spatial_scale[i] = 1. / float(strides[i]);
+    }
+    int nThreads = num_rois * c * aligned_height * aligned_width;
+    if (aligned)
+    {
+        rotated_roi_extractor_kernel<T, true><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
+            output,
+            rois,
+            feat_data,
+            clockwise,
+            sample_num,
+            roi_scale_factor,
+            finest_scale,
+            aligned_height,
+            aligned_width,
+            nThreads);
+    }
+    else
+    {
+        rotated_roi_extractor_kernel<T, false><<<GET_BLOCKS(nThreads), THREADS_PER_BLOCK, 0, stream>>>(
+            output,
+            rois,
+            feat_data,
+            clockwise,
+            sample_num,
+            roi_scale_factor,
+            finest_scale,
+            aligned_height,
+            aligned_width,
+            nThreads);
+    }
 }
 template void multi_level_rotated_roi_align<float>(
-    float *output, const float *rois, int num_rois, const void *const *feats, int num_feats, int n,
-    int c, int *h, int *w, float *strides, int aligned_height, int aligned_width, int clockwise,
-    int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream);
+    float* output,
+    const float* rois,
+    int num_rois,
+    const void* const* feats,
+    int num_feats,
+    int n,
+    int c,
+    int* h,
+    int* w,
+    float* strides,
+    int aligned_height,
+    int aligned_width,
+    int clockwise,
+    int sample_num,
+    float roi_scale_factor,
+    int finest_scale,
+    bool aligned,
+    cudaStream_t stream);
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp
index fc3700df3b..f3fb25df83 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_level_rotated_roi_align/trt_multi_level_rotated_roi_align_kernel.hpp
@@ -3,11 +3,7 @@
 #define TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_KERNEL_HPP
 #include <cuda_runtime.h>
-template <typename T>
-void multi_level_rotated_roi_align(T *output, const T *rois, int num_rois, const void *const *feats,
-                                   int num_feats, int n, int c, int *h, int *w, float *strides,
-                                   int aligned_height, int aligned_width, int clockwise,
-                                   int sample_num, float roi_scale_factor, int finest_scale,
-                                   bool aligned, cudaStream_t stream);
+template <typename T>
+void multi_level_rotated_roi_align(T* output, const T* rois, int num_rois, const void* const* feats, int num_feats, int n, int c, int* h, int* w, float* strides, int aligned_height, int aligned_width, int clockwise, int sample_num, float roi_scale_factor, int finest_scale, bool aligned, cudaStream_t stream);
 #endif  // TRT_MULTI_LEVEL_ROTATED_ROI_ALIGN_KERNEL_HPP
diff --git
a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp
index d14a25e929..ce9e81290d 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.cpp
@@ -10,164 +10,208 @@
 using namespace nvinfer1;

-namespace mmdeploy {
-namespace {
-static const char *PLUGIN_VERSION{"1"};
-static const char *PLUGIN_NAME{"MMCVMultiScaleDeformableAttention"};
-}  // namespace
-
-MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(
-    const std::string &name)
-    : TRTPluginBase(name) {}
-
-MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string name,
-                                                                             const void *data,
-                                                                             size_t length)
-    : TRTPluginBase(name) {}
-MultiScaleDeformableAttnPluginDynamic::~MultiScaleDeformableAttnPluginDynamic() {}
-
-nvinfer1::IPluginV2DynamicExt *MultiScaleDeformableAttnPluginDynamic::clone() const TRT_NOEXCEPT {
-  MultiScaleDeformableAttnPluginDynamic *plugin =
-      new MultiScaleDeformableAttnPluginDynamic(mLayerName);
-  plugin->setPluginNamespace(getPluginNamespace());
-
-  return plugin;
-}
-
-nvinfer1::DimsExprs MultiScaleDeformableAttnPluginDynamic::getOutputDimensions(
-    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs,
-    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT {
-  nvinfer1::DimsExprs ret;
-  ret.nbDims = 3;
-  ret.d[0] = inputs[0].d[0];
-  ret.d[1] = inputs[3].d[1];
-
-  ret.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2], *inputs[0].d[3]);
-
-  return ret;
-}
-
-bool MultiScaleDeformableAttnPluginDynamic::supportsFormatCombination(
-    int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT {
-  if (ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) {
-    if ((pos == 1) || (pos == 2)) {
-      return (ioDesc[pos].type == nvinfer1::DataType::kINT32);
-    } else {
-      return ((ioDesc[pos].type == ioDesc[0].type) &&
-              ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT) ||
-               (ioDesc[pos].type == nvinfer1::DataType::kHALF)));
-    }
-  } else {
-    return false;
-  }
-}
-
-void MultiScaleDeformableAttnPluginDynamic::configurePlugin(
-    const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs,
-    const nvinfer1::DynamicPluginTensorDesc *outputs, int nbOutputs) TRT_NOEXCEPT {}
-
-size_t MultiScaleDeformableAttnPluginDynamic::getWorkspaceSize(
-    const nvinfer1::PluginTensorDesc *inputs, int nbInputs,
-    const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const TRT_NOEXCEPT {
-  return 0;
-}
-
-int MultiScaleDeformableAttnPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *inputDesc,
-                                                   const nvinfer1::PluginTensorDesc *outputDesc,
-                                                   const void *const *inputs, void *const *outputs,
-                                                   void *workSpace,
-                                                   cudaStream_t stream) TRT_NOEXCEPT {
-  int32_t const batch = inputDesc[0].dims.d[0];
-  int32_t spatial_size = inputDesc[0].dims.d[1];
-  int32_t num_heads = inputDesc[0].dims.d[2];
-  int32_t channels = inputDesc[0].dims.d[3];
-  int32_t num_levels = inputDesc[1].dims.d[0];
-  int32_t num_query = inputDesc[3].dims.d[1];
-  int32_t num_point = inputDesc[3].dims.d[4];
-  int32_t rc = 0;
-  if (inputDesc[0].type == nvinfer1::DataType::kFLOAT) {
-    float const *value = static_cast<float const *>(inputs[0]);
-    int32_t const *spatialShapes = static_cast<int32_t const *>(inputs[1]);
-    int32_t const *levelStartIndex = static_cast<int32_t const *>(inputs[2]);
-    float const *samplingLoc = static_cast<float const *>(inputs[3]);
-    float const *attnWeight = static_cast<float const *>(inputs[4]);
-    float *output = static_cast<float *>(outputs[0]);
-
-    rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight,
-                                     output, batch, spatial_size, num_heads, channels, num_levels,
-                                     num_query, num_point, stream);
-  } else if (inputDesc[0].type == nvinfer1::DataType::kHALF) {
-    const __half *value = static_cast<const __half *>(inputs[0]);
-    int32_t const *spatialShapes = static_cast<int32_t const *>(inputs[1]);
-    int32_t const *levelStartIndex = static_cast<int32_t const *>(inputs[2]);
-    const __half *samplingLoc = static_cast<const __half *>(inputs[3]);
-    const __half *attnWeight = static_cast<const __half *>(inputs[4]);
-    __half *output = static_cast<__half *>(outputs[0]);
-
-    rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight,
-                                     output, batch, spatial_size, num_heads, channels, num_levels,
-                                     num_query, num_point, stream);
-  }
-
-  return rc;
-}
-
-nvinfer1::DataType MultiScaleDeformableAttnPluginDynamic::getOutputDataType(
-    int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT {
-  return inputTypes[0];
-}
-
-// IPluginV2 Methods
-const char *MultiScaleDeformableAttnPluginDynamic::getPluginType() const TRT_NOEXCEPT {
-  return PLUGIN_NAME;
-}
-
-const char *MultiScaleDeformableAttnPluginDynamic::getPluginVersion() const TRT_NOEXCEPT {
-  return PLUGIN_VERSION;
-}
-
-int MultiScaleDeformableAttnPluginDynamic::getNbOutputs() const TRT_NOEXCEPT { return 1; }
-
-size_t MultiScaleDeformableAttnPluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
-  return 0;
-}
-
-void MultiScaleDeformableAttnPluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {}
-
-void MultiScaleDeformableAttnPluginDynamic::attachToContext(
-    cudnnContext *cudnnContext, cublasContext *cublasContext,
-    nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT {}
-
-void MultiScaleDeformableAttnPluginDynamic::detachFromContext() TRT_NOEXCEPT {}
-
-////////////////////// creator /////////////////////////////
-
-MultiScaleDeformableAttnPluginDynamicCreator::MultiScaleDeformableAttnPluginDynamicCreator() {
-  mPluginAttributes.clear();
-  mFC.nbFields = mPluginAttributes.size();
-  mFC.fields = mPluginAttributes.data();
-}
-
-const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT {
-  return PLUGIN_NAME;
-}
-
-const char *MultiScaleDeformableAttnPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT {
-  return PLUGIN_VERSION;
-}
-
-nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::createPlugin(
-    const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT {
-  MultiScaleDeformableAttnPluginDynamic *plugin = new MultiScaleDeformableAttnPluginDynamic(name);
-  plugin->setPluginNamespace(getPluginNamespace());
-  return plugin;
-}
-
-nvinfer1::IPluginV2 *MultiScaleDeformableAttnPluginDynamicCreator::deserializePlugin(
-    const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT {
-  auto plugin = new MultiScaleDeformableAttnPluginDynamic(name, serialData, serialLength);
-  plugin->setPluginNamespace(getPluginNamespace());
-  return plugin;
-}
-REGISTER_TENSORRT_PLUGIN(MultiScaleDeformableAttnPluginDynamicCreator);
+namespace mmdeploy
+{
+    namespace
+    {
+        static const char* PLUGIN_VERSION{"1"};
+        static const char* PLUGIN_NAME{"MMCVMultiScaleDeformableAttention"};
+    }  // namespace
+
+    MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(
+        const std::string& name)
+        : TRTPluginBase(name)
+    {
+    }
+
+    MultiScaleDeformableAttnPluginDynamic::MultiScaleDeformableAttnPluginDynamic(const std::string name,
+                                                                                 const void* data,
+                                                                                 size_t length)
+        : TRTPluginBase(name)
+    {
+    }
+    MultiScaleDeformableAttnPluginDynamic::~MultiScaleDeformableAttnPluginDynamic() {}
+
+    nvinfer1::IPluginV2DynamicExt* MultiScaleDeformableAttnPluginDynamic::clone() const TRT_NOEXCEPT
+    {
+        MultiScaleDeformableAttnPluginDynamic* plugin =
+            new MultiScaleDeformableAttnPluginDynamic(mLayerName);
+        plugin->setPluginNamespace(getPluginNamespace());
+
+        return plugin;
+    }
+
+    nvinfer1::DimsExprs MultiScaleDeformableAttnPluginDynamic::getOutputDimensions(
+        int outputIndex,
+        const nvinfer1::DimsExprs* inputs,
+        int nbInputs,
+        nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT
+    {
+        nvinfer1::DimsExprs ret;
+        ret.nbDims = 3;
+        ret.d[0] = inputs[0].d[0];
+        ret.d[1] = inputs[3].d[1];
+
+        ret.d[2] = exprBuilder.operation(DimensionOperation::kPROD, *inputs[0].d[2], *inputs[0].d[3]);
+
+        return ret;
+    }
+
+    bool MultiScaleDeformableAttnPluginDynamic::supportsFormatCombination(
+        int pos,
+        const nvinfer1::PluginTensorDesc* ioDesc,
+        int nbInputs,
+        int nbOutputs) TRT_NOEXCEPT
+    {
+        if (ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR)
+        {
+            if ((pos == 1) || (pos == 2))
+            {
+                return (ioDesc[pos].type == nvinfer1::DataType::kINT32);
+            }
+            else
+            {
+                return ((ioDesc[pos].type == ioDesc[0].type) &&
+                        ((ioDesc[pos].type == nvinfer1::DataType::kFLOAT) ||
+                         (ioDesc[pos].type == nvinfer1::DataType::kHALF)));
+            }
+        }
+        else
+        {
+            return false;
+        }
+    }
+
+    void MultiScaleDeformableAttnPluginDynamic::configurePlugin(
+        const nvinfer1::DynamicPluginTensorDesc* inputs,
+        int nbInputs,
+        const nvinfer1::DynamicPluginTensorDesc* outputs,
+        int nbOutputs) TRT_NOEXCEPT {}
+
+    size_t MultiScaleDeformableAttnPluginDynamic::getWorkspaceSize(
+        const nvinfer1::PluginTensorDesc* inputs,
+        int nbInputs,
+        const nvinfer1::PluginTensorDesc* outputs,
+        int nbOutputs) const TRT_NOEXCEPT
+    {
+        return 0;
+    }
+
+    int MultiScaleDeformableAttnPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+                                                       const nvinfer1::PluginTensorDesc* outputDesc,
+                                                       const void* const* inputs,
+                                                       void* const* outputs,
+                                                       void* workSpace,
+                                                       cudaStream_t stream) TRT_NOEXCEPT
+    {
+        int32_t const batch = inputDesc[0].dims.d[0];
+        int32_t spatial_size = inputDesc[0].dims.d[1];
+        int32_t num_heads = inputDesc[0].dims.d[2];
+        int32_t channels = inputDesc[0].dims.d[3];
+        int32_t num_levels = inputDesc[1].dims.d[0];
+        int32_t num_query = inputDesc[3].dims.d[1];
+        int32_t num_point = inputDesc[3].dims.d[4];
+        int32_t rc = 0;
+        if (inputDesc[0].type == nvinfer1::DataType::kFLOAT)
+        {
+            float const* value = static_cast<float const*>(inputs[0]);
+            int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
+            int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
+            float const* samplingLoc = static_cast<float const*>(inputs[3]);
+            float const* attnWeight = static_cast<float const*>(inputs[4]);
+            float* output = static_cast<float*>(outputs[0]);
+
+            rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc, attnWeight, output, batch, spatial_size, num_heads, channels, num_levels, num_query, num_point, stream);
+        }
+        else if (inputDesc[0].type == nvinfer1::DataType::kHALF)
+        {
+            const __half* value = static_cast<const __half*>(inputs[0]);
+            int32_t const* spatialShapes = static_cast<int32_t const*>(inputs[1]);
+            int32_t const* levelStartIndex = static_cast<int32_t const*>(inputs[2]);
+            const __half* samplingLoc = static_cast<const __half*>(inputs[3]);
+            const __half* attnWeight = static_cast<const __half*>(inputs[4]);
+            __half* output = static_cast<__half*>(outputs[0]);
+
+            rc = ms_deform_attn_cuda_forward(value, spatialShapes, levelStartIndex, samplingLoc,
attnWeight, output, batch, spatial_size, num_heads, channels, num_levels, num_query, num_point, stream); + } + + return rc; + } + + nvinfer1::DataType MultiScaleDeformableAttnPluginDynamic::getOutputDataType( + int index, + const nvinfer1::DataType* inputTypes, + int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* MultiScaleDeformableAttnPluginDynamic::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* MultiScaleDeformableAttnPluginDynamic::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int MultiScaleDeformableAttnPluginDynamic::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t MultiScaleDeformableAttnPluginDynamic::getSerializationSize() const TRT_NOEXCEPT + { + return 0; + } + + void MultiScaleDeformableAttnPluginDynamic::serialize(void* buffer) const TRT_NOEXCEPT {} + + void MultiScaleDeformableAttnPluginDynamic::attachToContext( + cudnnContext* cudnnContext, + cublasContext* cublasContext, + nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT {} + + void MultiScaleDeformableAttnPluginDynamic::detachFromContext() TRT_NOEXCEPT {} + + ////////////////////// creator ///////////////////////////// + + MultiScaleDeformableAttnPluginDynamicCreator::MultiScaleDeformableAttnPluginDynamicCreator() + { + mPluginAttributes.clear(); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* MultiScaleDeformableAttnPluginDynamicCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* MultiScaleDeformableAttnPluginDynamicCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* MultiScaleDeformableAttnPluginDynamicCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + MultiScaleDeformableAttnPluginDynamic* plugin = new MultiScaleDeformableAttnPluginDynamic(name); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* MultiScaleDeformableAttnPluginDynamicCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new MultiScaleDeformableAttnPluginDynamic(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + REGISTER_TENSORRT_PLUGIN(MultiScaleDeformableAttnPluginDynamicCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp index 7e66e9e54d..5a2c78baf9 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn.hpp @@ -9,62 +9,59 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class MultiScaleDeformableAttnPluginDynamic : public TRTPluginBase { - public: - MultiScaleDeformableAttnPluginDynamic(const std::string &name); +namespace mmdeploy +{ + class MultiScaleDeformableAttnPluginDynamic : public TRTPluginBase + { + public: + MultiScaleDeformableAttnPluginDynamic(const std::string& name); - MultiScaleDeformableAttnPluginDynamic(const std::string name, const void *data, size_t length); + MultiScaleDeformableAttnPluginDynamic(const std::string name, const void* data, size_t length); - MultiScaleDeformableAttnPluginDynamic(); + MultiScaleDeformableAttnPluginDynamic(); 
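[Reviewer note, not part of the patch] The reformatted getOutputDimensions above preserves the plugin's output contract: with value shaped [batch, spatial_size, num_heads, channels] and sampling_loc shaped [batch, num_query, num_heads, num_levels, num_point, 2], the output is [batch, num_query, num_heads * channels]. A minimal plain-C++ sketch of that shape arithmetic, with a hypothetical helper name and example sizes (independent of the TensorRT headers):

#include <array>
#include <cstdint>
#include <cstdio>

// Mirrors getOutputDimensions: d[0] = inputs[0].d[0], d[1] = inputs[3].d[1],
// d[2] = inputs[0].d[2] * inputs[0].d[3] (the kPROD expression above).
std::array<int64_t, 3> msdaOutputShape(const std::array<int64_t, 4>& value,
                                       const std::array<int64_t, 6>& samplingLoc)
{
    return {value[0], samplingLoc[1], value[2] * value[3]};
}

int main()
{
    auto out = msdaOutputShape({2, 13101, 8, 32}, {2, 900, 8, 4, 4, 2});
    std::printf("%lld x %lld x %lld\n", (long long)out[0], (long long)out[1], (long long)out[2]);
    return 0;
}

The same flattening of (num_heads, channels) into one trailing dimension is what the enqueue path assumes when it hands raw pointers to ms_deform_attn_cuda_forward.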
- ~MultiScaleDeformableAttnPluginDynamic() TRT_NOEXCEPT override; + ~MultiScaleDeformableAttnPluginDynamic() TRT_NOEXCEPT override; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - void attachToContext(cudnnContext *cudnnContext, cublasContext *cublasContext, - nvinfer1::IGpuAllocator *gpuAllocator) TRT_NOEXCEPT override; - void detachFromContext() TRT_NOEXCEPT override; + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + void attachToContext(cudnnContext* cudnnContext, cublasContext* cublasContext, nvinfer1::IGpuAllocator* gpuAllocator) TRT_NOEXCEPT override; + void detachFromContext() TRT_NOEXCEPT override; - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; -}; + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + }; -class MultiScaleDeformableAttnPluginDynamicCreator : public TRTPluginCreatorBase { - public: - 
MultiScaleDeformableAttnPluginDynamicCreator();
+    class MultiScaleDeformableAttnPluginDynamicCreator : public TRTPluginCreatorBase
+    {
+      public:
+        MultiScaleDeformableAttnPluginDynamicCreator();

-  const char *getPluginName() const TRT_NOEXCEPT override;
+        const char* getPluginName() const TRT_NOEXCEPT override;

-  const char *getPluginVersion() const TRT_NOEXCEPT override;
+        const char* getPluginVersion() const TRT_NOEXCEPT override;

-  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc)
-      TRT_NOEXCEPT override;
+        nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc)
+            TRT_NOEXCEPT override;

-  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData,
-                                         size_t serialLength) TRT_NOEXCEPT override;
-};
+        nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override;
+    };
 }  // namespace mmdeploy

 #endif  // TRT_MS_DEFORM_ATTN_HPP
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu
index 6b7588eae0..81ddcc6585 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu
+++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cu
@@ -7,58 +7,91 @@
 #include "trt_ms_deform_attn_kernel.hpp"
 #include "trt_plugin_helper.hpp"

-template <typename scalar_t>
-void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue,
-                               int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex,
-                               scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight,
-                               int32_t const batchSize, int32_t const spatialSize,
-                               int32_t const numHeads, int32_t const channels,
-                               int32_t const numLevels, int32_t const numQuery,
-                               int32_t const numPoint, scalar_t* dataCol) {
-  int32_t const numKernels = batchSize * numQuery * numHeads * channels;
-  int32_t const numActualKernels = batchSize * numQuery * numHeads * channels;
+template<typename scalar_t>
+void ms_deformable_im2col_cuda(cudaStream_t stream, scalar_t const* dataValue, int32_t const* dataSpatialShapes, int32_t const* dataLevelStartIndex, scalar_t const* dataSamplingLoc, scalar_t const* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize, int32_t const numHeads, int32_t const channels, int32_t const numLevels, int32_t const numQuery, int32_t const numPoint, scalar_t* dataCol)
+{
+    int32_t const numKernels = batchSize * numQuery * numHeads * channels;
+    int32_t const numActualKernels = batchSize * numQuery * numHeads * channels;

-  ms_deformable_im2col_gpu_kernel<scalar_t>
-      <<<GET_BLOCKS(numActualKernels), THREADS_PER_BLOCK, 0, stream>>>(
-          numKernels, dataValue, dataSpatialShapes, dataLevelStartIndex, dataSamplingLoc,
-          dataAttnWeight, batchSize, spatialSize, numHeads, channels, numLevels, numQuery, numPoint,
-          dataCol);
+    ms_deformable_im2col_gpu_kernel<scalar_t>
+        <<<GET_BLOCKS(numActualKernels), THREADS_PER_BLOCK, 0, stream>>>(
+            numKernels,
+            dataValue,
+            dataSpatialShapes,
+            dataLevelStartIndex,
+            dataSamplingLoc,
+            dataAttnWeight,
+            batchSize,
+            spatialSize,
+            numHeads,
+            channels,
+            numLevels,
+            numQuery,
+            numPoint,
+            dataCol);
 }

-template <typename scalar_t>
-int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes,
-                                    const int32_t* levelStartIndex, const scalar_t* samplingLoc,
-                                    const scalar_t* attnWeight, scalar_t* output, int32_t batch,
-                                    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels,
-                                    int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint,
-                                    cudaStream_t stream) {
-  auto perValueSize = mSpatialSize * mNumHeads * mChannels;
-  auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2;
-  auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint;
-  auto perOutputSize = mNumQuery * mNumHeads * mChannels;
+template<typename scalar_t>
+int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes, const int32_t* levelStartIndex, const scalar_t* samplingLoc, const scalar_t* attnWeight, scalar_t* output, int32_t batch, int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint, cudaStream_t stream)
+{
+    auto perValueSize = mSpatialSize * mNumHeads * mChannels;
+    auto perSampleLocSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint * 2;
+    auto perAttnWeightSize = mNumQuery * mNumHeads * mNumLevels * mNumPoint;
+    auto perOutputSize = mNumQuery * mNumHeads * mChannels;

-  int32_t mIm2colStep = batch;
+    int32_t mIm2colStep = batch;

-  for (int32_t n = 0; n < batch / mIm2colStep; ++n) {
-    auto columns = output + n * mIm2colStep * perOutputSize;
-    ms_deformable_im2col_cuda<scalar_t>(
-        stream, value + n * mIm2colStep * perValueSize, spatialShapes, levelStartIndex,
-        samplingLoc + n * mIm2colStep * perSampleLocSize,
-        attnWeight + n * mIm2colStep * perAttnWeightSize, mIm2colStep, mSpatialSize, mNumHeads,
-        mChannels, mNumLevels, mNumQuery, mNumPoint, columns);
-  }
+    for (int32_t n = 0; n < batch / mIm2colStep; ++n)
+    {
+        auto columns = output + n * mIm2colStep * perOutputSize;
+        ms_deformable_im2col_cuda<scalar_t>(
+            stream,
+            value + n * mIm2colStep * perValueSize,
+            spatialShapes,
+            levelStartIndex,
+            samplingLoc + n * mIm2colStep * perSampleLocSize,
+            attnWeight + n * mIm2colStep * perAttnWeightSize,
+            mIm2colStep,
+            mSpatialSize,
+            mNumHeads,
+            mChannels,
+            mNumLevels,
+            mNumQuery,
+            mNumPoint,
+            columns);
+    }

-  return 0;
+    return 0;
 }

 template int32_t ms_deform_attn_cuda_forward<float>(
-    const float* value, const int32_t* spatialShapes, const int32_t* levelStartIndex,
-    const float* samplingLoc, const float* attnWeight, float* output, int32_t batch,
-    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels,
-    int32_t mNumQuery, int32_t mNumPoint, cudaStream_t stream);
+    const float* value,
+    const int32_t* spatialShapes,
+    const int32_t* levelStartIndex,
+    const float* samplingLoc,
+    const float* attnWeight,
+    float* output,
+    int32_t batch,
+    int32_t mSpatialSize,
+    int32_t mNumHeads,
+    int32_t mChannels,
+    int32_t mNumLevels,
+    int32_t mNumQuery,
+    int32_t mNumPoint,
+    cudaStream_t stream);

 template int32_t ms_deform_attn_cuda_forward<__half>(
-    const __half* value, const int32_t* spatialShapes, const int32_t* levelStartIndex,
-    const __half* samplingLoc, const __half* attnWeight, __half* output, int32_t batch,
-    int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels,
-    int32_t mNumQuery, int32_t mNumPoint, cudaStream_t stream);
+    const __half* value,
+    const int32_t* spatialShapes,
+    const int32_t* levelStartIndex,
+    const __half* samplingLoc,
+    const __half* attnWeight,
+    __half* output,
+    int32_t batch,
+    int32_t mSpatialSize,
+    int32_t mNumHeads,
+    int32_t mChannels,
+    int32_t mNumLevels,
+    int32_t mNumQuery,
+    int32_t mNumPoint,
+    cudaStream_t stream);
diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cuh b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cuh
index cee34cfe65..2b62e7fc30 100644
--- a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cuh
+++
b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.cuh @@ -4,254 +4,294 @@ #include "common_cuda_helper.hpp" -template -__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t*& bottom_data, const int& height, - const int& width, const int& nheads, - const int& channels, const scalar_t& h, - const scalar_t& w, const int& m, const int& c) { - const int h_low = floorf(h); - const int w_low = floorf(w); - const int h_high = h_low + 1; - const int w_high = w_low + 1; +template +__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t*& bottom_data, const int& height, const int& width, const int& nheads, const int& channels, const scalar_t& h, const scalar_t& w, const int& m, const int& c) +{ + const int h_low = floorf(h); + const int w_low = floorf(w); + const int h_high = h_low + 1; + const int w_high = w_low + 1; - const scalar_t lh = h - h_low; - const scalar_t lw = w - w_low; - const scalar_t hh = 1 - lh, hw = 1 - lw; + const scalar_t lh = h - h_low; + const scalar_t lw = w - w_low; + const scalar_t hh = 1 - lh, hw = 1 - lw; - const int w_stride = nheads * channels; - const int h_stride = width * w_stride; - const int h_low_ptr_offset = h_low * h_stride; - const int h_high_ptr_offset = h_low_ptr_offset + h_stride; - const int w_low_ptr_offset = w_low * w_stride; - const int w_high_ptr_offset = w_low_ptr_offset + w_stride; - const int base_ptr = m * channels + c; + const int w_stride = nheads * channels; + const int h_stride = width * w_stride; + const int h_low_ptr_offset = h_low * h_stride; + const int h_high_ptr_offset = h_low_ptr_offset + h_stride; + const int w_low_ptr_offset = w_low * w_stride; + const int w_high_ptr_offset = w_low_ptr_offset + w_stride; + const int base_ptr = m * channels + c; - scalar_t v1 = 0; - if (h_low >= 0 && w_low >= 0) { - const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; - v1 = bottom_data[ptr1]; - } - scalar_t v2 = 0; - if (h_low >= 0 && w_high <= width - 1) { - const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; - v2 = bottom_data[ptr2]; - } - scalar_t v3 = 0; - if (h_high <= height - 1 && w_low >= 0) { - const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; - v3 = bottom_data[ptr3]; - } - scalar_t v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) { - const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; - v4 = bottom_data[ptr4]; - } + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + { + const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; + v1 = bottom_data[ptr1]; + } + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + { + const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; + v2 = bottom_data[ptr2]; + } + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + { + const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; + v3 = bottom_data[ptr3]; + } + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + { + const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; + v4 = bottom_data[ptr4]; + } - const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; + const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; } -template <> +template<> __device__ __half ms_deform_attn_im2col_bilinear<__half>( - const __half*& bottomData, int32_t const& height, int32_t const& width, int32_t const& 
nHeads, - int32_t const& channels, const __half& h, const __half& w, int32_t const& m, int32_t const& c) { - int32_t const hLow = __half2int_rd(h); - int32_t const wLow = __half2int_rd(w); - int32_t const hHigh = hLow + 1; - int32_t const wHigh = wLow + 1; + const __half*& bottomData, + int32_t const& height, + int32_t const& width, + int32_t const& nHeads, + int32_t const& channels, + const __half& h, + const __half& w, + int32_t const& m, + int32_t const& c) +{ + int32_t const hLow = __half2int_rd(h); + int32_t const wLow = __half2int_rd(w); + int32_t const hHigh = hLow + 1; + int32_t const wHigh = wLow + 1; - const __half kZERO = __int2half_rz(0); - const __half one = __int2half_rz(1); + const __half kZERO = __int2half_rz(0); + const __half one = __int2half_rz(1); #if __CUDA_ARCH__ >= 530 - const __half lh = __hsub(h, __int2half_rd(hLow)); - const __half lw = __hsub(w, __int2half_rd(wLow)); - const __half hh = __hsub(one, lh), hw = __hsub(one, lw); + const __half lh = __hsub(h, __int2half_rd(hLow)); + const __half lw = __hsub(w, __int2half_rd(wLow)); + const __half hh = __hsub(one, lh), hw = __hsub(one, lw); #else - const __half lh = __float2half(__half2float(h) - hLow); - const __half lw = __float2half(__half2float(w) - wLow); - const __half hh = __float2half(__half2float(one) - __half2float(lh)); - const __half hw = __float2half(__half2float(one) - __half2float(lw)); + const __half lh = __float2half(__half2float(h) - hLow); + const __half lw = __float2half(__half2float(w) - wLow); + const __half hh = __float2half(__half2float(one) - __half2float(lh)); + const __half hw = __float2half(__half2float(one) - __half2float(lw)); #endif - int32_t const wStride = nHeads * channels; - int32_t const hStride = width * wStride; - int32_t const hLowPtrOffset = hLow * hStride; - int32_t const hHighPtrOffset = hLowPtrOffset + hStride; - int32_t const wLowPtrOffset = wLow * wStride; - int32_t const wHighPtrOffset = wLowPtrOffset + wStride; - int32_t const basePtr = m * channels + c; + int32_t const wStride = nHeads * channels; + int32_t const hStride = width * wStride; + int32_t const hLowPtrOffset = hLow * hStride; + int32_t const hHighPtrOffset = hLowPtrOffset + hStride; + int32_t const wLowPtrOffset = wLow * wStride; + int32_t const wHighPtrOffset = wLowPtrOffset + wStride; + int32_t const basePtr = m * channels + c; - __half v1 = kZERO; - if (hLow >= 0 && wLow >= 0) { - int32_t const ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; - v1 = bottomData[ptr1]; - } - __half v2 = kZERO; - if (hLow >= 0 && wHigh <= width - 1) { - int32_t const ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr; - v2 = bottomData[ptr2]; - } - __half v3 = kZERO; - if (hHigh <= height - 1 && wLow >= 0) { - int32_t const ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr; - v3 = bottomData[ptr3]; - } - __half v4 = kZERO; - if (hHigh <= height - 1 && wHigh <= width - 1) { - int32_t const ptr4 = hHighPtrOffset + wHighPtrOffset + basePtr; - v4 = bottomData[ptr4]; - } + __half v1 = kZERO; + if (hLow >= 0 && wLow >= 0) + { + int32_t const ptr1 = hLowPtrOffset + wLowPtrOffset + basePtr; + v1 = bottomData[ptr1]; + } + __half v2 = kZERO; + if (hLow >= 0 && wHigh <= width - 1) + { + int32_t const ptr2 = hLowPtrOffset + wHighPtrOffset + basePtr; + v2 = bottomData[ptr2]; + } + __half v3 = kZERO; + if (hHigh <= height - 1 && wLow >= 0) + { + int32_t const ptr3 = hHighPtrOffset + wLowPtrOffset + basePtr; + v3 = bottomData[ptr3]; + } + __half v4 = kZERO; + if (hHigh <= height - 1 && wHigh <= width - 1) + { + int32_t const ptr4 = hHighPtrOffset + 
wHighPtrOffset + basePtr; + v4 = bottomData[ptr4]; + } #if __CUDA_ARCH__ >= 530 - __half w1 = __hmul(__hmul(hh, hw), v1); - __half w2 = __hmul(__hmul(hh, lw), v2); - __half w3 = __hmul(__hmul(lh, hw), v3); - __half w4 = __hmul(__hmul(lh, lw), v4); + __half w1 = __hmul(__hmul(hh, hw), v1); + __half w2 = __hmul(__hmul(hh, lw), v2); + __half w3 = __hmul(__hmul(lh, hw), v3); + __half w4 = __hmul(__hmul(lh, lw), v4); - w1 = __hadd(w1, w2); - w3 = __hadd(w3, w4); + w1 = __hadd(w1, w2); + w3 = __hadd(w3, w4); - const __half val = __hadd(w1, w3); + const __half val = __hadd(w1, w3); #else - __half w1 = __float2half((__half2float(hh) * __half2float(hw)) * __half2float(v1)); - __half w2 = __float2half((__half2float(hh) * __half2float(lw)) * __half2float(v2)); - __half w3 = __float2half((__half2float(lh) * __half2float(hw)) * __half2float(v3)); - __half w4 = __float2half((__half2float(lh) * __half2float(lw)) * __half2float(v4)); + __half w1 = __float2half((__half2float(hh) * __half2float(hw)) * __half2float(v1)); + __half w2 = __float2half((__half2float(hh) * __half2float(lw)) * __half2float(v2)); + __half w3 = __float2half((__half2float(lh) * __half2float(hw)) * __half2float(v3)); + __half w4 = __float2half((__half2float(lh) * __half2float(lw)) * __half2float(v4)); - w1 = __float2half(__half2float(w1) + __half2float(w2)); - w3 = __float2half(__half2float(w3) + __half2float(w4)); + w1 = __float2half(__half2float(w1) + __half2float(w2)); + w3 = __float2half(__half2float(w3) + __half2float(w4)); - const __half val = __float2half(__half2float(w1) + __half2float(w3)); + const __half val = __float2half(__half2float(w1) + __half2float(w3)); #endif - return val; + return val; } #if 1 -template +template __global__ void ms_deformable_im2col_gpu_kernel( - int32_t const n, scalar_t const* dataValue, int32_t const* dataSpatialShapes, - int32_t const* dataLevelStartIndex, scalar_t const* dataSamplingLoc, - scalar_t const* dataAttnWeight, int32_t const batchSize, int32_t const spatialSize, - int32_t const numHeads, int32_t const channels, int32_t const numLevels, int32_t const numQuery, - int32_t const numPoint, scalar_t* dataCol) { - CUDA_1D_KERNEL_LOOP(index, n) { - int32_t _temp = index; - int32_t const cCol = _temp % channels; - _temp /= channels; - int32_t const samplingIndex = _temp; - int32_t const mCol = _temp % numHeads; - _temp /= numHeads; - _temp /= numQuery; - int32_t const bCol = _temp; + int32_t const n, + scalar_t const* dataValue, + int32_t const* dataSpatialShapes, + int32_t const* dataLevelStartIndex, + scalar_t const* dataSamplingLoc, + scalar_t const* dataAttnWeight, + int32_t const batchSize, + int32_t const spatialSize, + int32_t const numHeads, + int32_t const channels, + int32_t const numLevels, + int32_t const numQuery, + int32_t const numPoint, + scalar_t* dataCol) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + int32_t _temp = index; + int32_t const cCol = _temp % channels; + _temp /= channels; + int32_t const samplingIndex = _temp; + int32_t const mCol = _temp % numHeads; + _temp /= numHeads; + _temp /= numQuery; + int32_t const bCol = _temp; - scalar_t* dataColPtr = dataCol + index; - int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; - int32_t dataLocWPtr = dataWeightPtr << 1; - int32_t const qidStride = numHeads * channels; - int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; - scalar_t col = 0; + scalar_t* dataColPtr = dataCol + index; + int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; + int32_t dataLocWPtr = dataWeightPtr << 1; + int32_t const 
qidStride = numHeads * channels; + int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; + scalar_t col = 0; - for (int32_t lCol = 0; lCol < numLevels; ++lCol) { - int32_t const levelStartId = dataLevelStartIndex[lCol]; - int32_t const spatialHPtr = lCol << 1; - int32_t const spatialH = dataSpatialShapes[spatialHPtr]; - int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; - scalar_t const* dataValuePtr = - dataValue + (dataValuePtrInitOffset + levelStartId * qidStride); - for (int32_t pCol = 0; pCol < numPoint; ++pCol) { - scalar_t const locW = dataSamplingLoc[dataLocWPtr]; - scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1]; - scalar_t const weight = dataAttnWeight[dataWeightPtr]; + for (int32_t lCol = 0; lCol < numLevels; ++lCol) + { + int32_t const levelStartId = dataLevelStartIndex[lCol]; + int32_t const spatialHPtr = lCol << 1; + int32_t const spatialH = dataSpatialShapes[spatialHPtr]; + int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; + scalar_t const* dataValuePtr = + dataValue + (dataValuePtrInitOffset + levelStartId * qidStride); + for (int32_t pCol = 0; pCol < numPoint; ++pCol) + { + scalar_t const locW = dataSamplingLoc[dataLocWPtr]; + scalar_t const locH = dataSamplingLoc[dataLocWPtr + 1]; + scalar_t const weight = dataAttnWeight[dataWeightPtr]; - scalar_t const hIm = locH * spatialH - 0.5; - scalar_t const wIm = locW * spatialW - 0.5; + scalar_t const hIm = locH * spatialH - 0.5; + scalar_t const wIm = locW * spatialW - 0.5; - if (hIm > -1 && wIm > -1 && hIm < spatialH && wIm < spatialW) { - col += ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads, - channels, hIm, wIm, mCol, cCol) * - weight; - } + if (hIm > -1 && wIm > -1 && hIm < spatialH && wIm < spatialW) + { + col += ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads, channels, hIm, wIm, mCol, cCol) * + weight; + } - dataWeightPtr += 1; - dataLocWPtr += 2; - } + dataWeightPtr += 1; + dataLocWPtr += 2; + } + } + *dataColPtr = col; } - *dataColPtr = col; - } } -template <> +template<> __global__ void ms_deformable_im2col_gpu_kernel<__half>( - int32_t const n, const __half* dataValue, int32_t const* dataSpatialShapes, - int32_t const* dataLevelStartIndex, const __half* dataSamplingLoc, const __half* dataAttnWeight, - int32_t const batchSize, int32_t const spatialSize, int32_t const numHeads, - int32_t const channels, int32_t const numLevels, int32_t const numQuery, int32_t const numPoint, - __half* dataCol) { - CUDA_1D_KERNEL_LOOP(index, n) { - int32_t _temp = index; - int32_t const cCol = _temp % channels; - _temp /= channels; - int32_t const samplingIndex = _temp; - int32_t const mCol = _temp % numHeads; - _temp /= numHeads; - _temp /= numQuery; - int32_t const bCol = _temp; + int32_t const n, + const __half* dataValue, + int32_t const* dataSpatialShapes, + int32_t const* dataLevelStartIndex, + const __half* dataSamplingLoc, + const __half* dataAttnWeight, + int32_t const batchSize, + int32_t const spatialSize, + int32_t const numHeads, + int32_t const channels, + int32_t const numLevels, + int32_t const numQuery, + int32_t const numPoint, + __half* dataCol) +{ + CUDA_1D_KERNEL_LOOP(index, n) + { + int32_t _temp = index; + int32_t const cCol = _temp % channels; + _temp /= channels; + int32_t const samplingIndex = _temp; + int32_t const mCol = _temp % numHeads; + _temp /= numHeads; + _temp /= numQuery; + int32_t const bCol = _temp; - __half* dataColPtr = dataCol + index; - int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; - 
int32_t dataLocWPtr = dataWeightPtr << 1; - int32_t const qidStride = numHeads * channels; - int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; - const __half kZERO_POINT_FIVE = __float2half(0.5f); - const __half kMINUS_ONE = __float2half(-1.0f); - const __half kZERO = __int2half_rz(0); - __half tpVal = kZERO; - __half col = kZERO; + __half* dataColPtr = dataCol + index; + int32_t dataWeightPtr = samplingIndex * numLevels * numPoint; + int32_t dataLocWPtr = dataWeightPtr << 1; + int32_t const qidStride = numHeads * channels; + int32_t const dataValuePtrInitOffset = bCol * spatialSize * qidStride; + const __half kZERO_POINT_FIVE = __float2half(0.5f); + const __half kMINUS_ONE = __float2half(-1.0f); + const __half kZERO = __int2half_rz(0); + __half tpVal = kZERO; + __half col = kZERO; - for (int32_t lCol = 0; lCol < numLevels; ++lCol) { - int32_t const levelStartId = dataLevelStartIndex[lCol]; - int32_t const spatialHPtr = lCol << 1; - int32_t const spatialH = dataSpatialShapes[spatialHPtr]; - int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; - const __half spatialHHalf = __int2half_rd(spatialH); - const __half spatialWHalf = __int2half_rd(spatialW); - const __half* dataValuePtr = dataValue + (dataValuePtrInitOffset + levelStartId * qidStride); - for (int32_t pCol = 0; pCol < numPoint; ++pCol) { - const __half locW = dataSamplingLoc[dataLocWPtr]; - const __half locH = dataSamplingLoc[dataLocWPtr + 1]; - const __half weight = dataAttnWeight[dataWeightPtr]; -#if __CUDA_ARCH__ >= 530 - const __half hIm = __hsub(__hmul(locH, spatialHHalf), kZERO_POINT_FIVE); - const __half wIm = __hsub(__hmul(locW, spatialWHalf), kZERO_POINT_FIVE); + for (int32_t lCol = 0; lCol < numLevels; ++lCol) + { + int32_t const levelStartId = dataLevelStartIndex[lCol]; + int32_t const spatialHPtr = lCol << 1; + int32_t const spatialH = dataSpatialShapes[spatialHPtr]; + int32_t const spatialW = dataSpatialShapes[spatialHPtr + 1]; + const __half spatialHHalf = __int2half_rd(spatialH); + const __half spatialWHalf = __int2half_rd(spatialW); + const __half* dataValuePtr = dataValue + (dataValuePtrInitOffset + levelStartId * qidStride); + for (int32_t pCol = 0; pCol < numPoint; ++pCol) + { + const __half locW = dataSamplingLoc[dataLocWPtr]; + const __half locH = dataSamplingLoc[dataLocWPtr + 1]; + const __half weight = dataAttnWeight[dataWeightPtr]; + #if __CUDA_ARCH__ >= 530 + const __half hIm = __hsub(__hmul(locH, spatialHHalf), kZERO_POINT_FIVE); + const __half wIm = __hsub(__hmul(locW, spatialWHalf), kZERO_POINT_FIVE); - if (__hgt(hIm, kMINUS_ONE) && __hgt(wIm, kMINUS_ONE) && __hlt(hIm, spatialHHalf) && - __hlt(wIm, spatialWHalf)) { - tpVal = ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads, - channels, hIm, wIm, mCol, cCol); - col = __hadd(col, __hmul(tpVal, weight)); - } -#else - const __half hIm = __float2half(__half2float(locH) * __half2float(spatialHHalf) - - __half2float(kZERO_POINT_FIVE)); - const __half wIm = __float2half(__half2float(locW) * __half2float(spatialWHalf) - - __half2float(kZERO_POINT_FIVE)); + if (__hgt(hIm, kMINUS_ONE) && __hgt(wIm, kMINUS_ONE) && __hlt(hIm, spatialHHalf) && + __hlt(wIm, spatialWHalf)) + { + tpVal = ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads, channels, hIm, wIm, mCol, cCol); + col = __hadd(col, __hmul(tpVal, weight)); + } + #else + const __half hIm = __float2half(__half2float(locH) * __half2float(spatialHHalf) - + __half2float(kZERO_POINT_FIVE)); + const __half wIm = __float2half(__half2float(locW) 
* __half2float(spatialWHalf) - + __half2float(kZERO_POINT_FIVE)); - if ((__half2float(hIm) > __half2float(kMINUS_ONE)) && - (__half2float(wIm) > __half2float(kMINUS_ONE)) && - (__half2float(hIm) < __half2float(spatialHHalf)) && - (__half2float(wIm) < __half2float(spatialWHalf))) { - tpVal = ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads, - channels, hIm, wIm, mCol, cCol); - col = __float2half(__half2float(col) + (__half2float(tpVal) * __half2float(weight))); + if ((__half2float(hIm) > __half2float(kMINUS_ONE)) && + (__half2float(wIm) > __half2float(kMINUS_ONE)) && + (__half2float(hIm) < __half2float(spatialHHalf)) && + (__half2float(wIm) < __half2float(spatialWHalf))) + { + tpVal = ms_deform_attn_im2col_bilinear(dataValuePtr, spatialH, spatialW, numHeads, channels, hIm, wIm, mCol, cCol); + col = __float2half(__half2float(col) + (__half2float(tpVal) * __half2float(weight))); + } + #endif + dataWeightPtr += 1; + dataLocWPtr += 2; + } } -#endif - dataWeightPtr += 1; - dataLocWPtr += 2; - } + *dataColPtr = col; } - *dataColPtr = col; - } } #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.hpp index adbe2566fd..5dafa5a169 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/multi_scale_deform_attn/trt_ms_deform_attn_kernel.hpp @@ -4,12 +4,7 @@ #include #include -template -int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes, - const int32_t* levelStartIndex, const scalar_t* samplingLoc, - const scalar_t* attnWeight, scalar_t* output, int32_t batch, - int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, - int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint, - cudaStream_t stream); +template +int32_t ms_deform_attn_cuda_forward(const scalar_t* value, const int32_t* spatialShapes, const int32_t* levelStartIndex, const scalar_t* samplingLoc, const scalar_t* attnWeight, scalar_t* output, int32_t batch, int32_t mSpatialSize, int32_t mNumHeads, int32_t mChannels, int32_t mNumLevels, int32_t mNumQuery, int32_t mNumPoint, cudaStream_t stream); #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp index 988893125d..4325f7b89c 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.cpp @@ -9,233 +9,290 @@ #include "trt_roi_align_kernel.hpp" #include "trt_serialize.hpp" -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"MMCVRoiAlign"}; -} // namespace - -TRTRoIAlign::TRTRoIAlign(const std::string &name, int outWidth, int outHeight, float spatialScale, - int sampleRatio, int poolMode, bool aligned) - : TRTPluginBase(name), - mOutWidth(outWidth), - mOutHeight(outHeight), - mSpatialScale(spatialScale), - mSampleRatio(sampleRatio), - mPoolMode(poolMode), - mAligned(aligned) {} - -TRTRoIAlign::TRTRoIAlign(const std::string name, const void *data, size_t length) - : TRTPluginBase(name) { - deserialize_value(&data, &length, &mOutWidth); - deserialize_value(&data, &length, &mOutHeight); - deserialize_value(&data, &length, &mSpatialScale); - deserialize_value(&data, &length, &mSampleRatio); - deserialize_value(&data, &length, &mPoolMode); - deserialize_value(&data, &length, 
&mAligned); -} - -nvinfer1::IPluginV2DynamicExt *TRTRoIAlign::clone() const TRT_NOEXCEPT { - TRTRoIAlign *plugin = new TRTRoIAlign(mLayerName, mOutWidth, mOutHeight, mSpatialScale, - mSampleRatio, mPoolMode, mAligned); - plugin->setPluginNamespace(getPluginNamespace()); - - return plugin; -} - -nvinfer1::DimsExprs TRTRoIAlign::getOutputDimensions( - int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - nvinfer1::DimsExprs ret; - ret.nbDims = 4; - ret.d[0] = inputs[1].d[0]; - ret.d[1] = inputs[0].d[1]; - ret.d[2] = exprBuilder.constant(mOutHeight); - ret.d[3] = exprBuilder.constant(mOutWidth); - - return ret; -} - -bool TRTRoIAlign::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, - int nbInputs, int nbOutputs) TRT_NOEXCEPT { - return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -} - -void TRTRoIAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT {} - -size_t TRTRoIAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - size_t output_size = 0; - size_t word_size = 0; - switch (mPoolMode) { - case 0: // max - output_size = - outputs[0].dims.d[0] * outputs[0].dims.d[1] * outputs[0].dims.d[2] * outputs[0].dims.d[3]; - word_size = mmdeploy::getElementSize(outputs[0].type); - return output_size * word_size * 2; - break; - case 1: - return 0; - break; - default: - return 0; - } - return 0; -} - -int TRTRoIAlign::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - int channels = inputDesc[0].dims.d[1]; - int height = inputDesc[0].dims.d[2]; - int width = inputDesc[0].dims.d[3]; - - int output_size = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] * outputDesc[0].dims.d[2] * - outputDesc[0].dims.d[3]; - int word_size = mmdeploy::getElementSize(outputDesc[0].type); - - const void *feat = inputs[0]; - const void *rois = inputs[1]; - void *output = outputs[0]; - void *argmax_y = nullptr; - void *argmax_x = nullptr; - - switch (mPoolMode) { - case 0: // max - argmax_y = workSpace; - argmax_x = (char *)argmax_y + output_size * word_size; - break; - case 1: // avg - break; - } - - switch (outputDesc[0].type) { - case nvinfer1::DataType::kFLOAT: - TRTRoIAlignForwardCUDAKernelLauncher( - (const float *)feat, (const float *)rois, (float *)output, (float *)argmax_y, - (float *)argmax_x, output_size, channels, height, width, mOutHeight, mOutWidth, - mSpatialScale, mSampleRatio, mPoolMode, mAligned, stream); - break; - - default: - break; - } - - return 0; -} - -nvinfer1::DataType TRTRoIAlign::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char *TRTRoIAlign::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTRoIAlign::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int TRTRoIAlign::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t TRTRoIAlign::getSerializationSize() const TRT_NOEXCEPT { - return serialized_size(mOutWidth) + serialized_size(mOutHeight) + serialized_size(mSpatialScale) + - 
serialized_size(mSampleRatio) + serialized_size(mPoolMode) + serialized_size(mAligned); -} - -void TRTRoIAlign::serialize(void *buffer) const TRT_NOEXCEPT { - serialize_value(&buffer, mOutWidth); - serialize_value(&buffer, mOutHeight); - serialize_value(&buffer, mSpatialScale); - serialize_value(&buffer, mSampleRatio); - serialize_value(&buffer, mPoolMode); - serialize_value(&buffer, mAligned); -} - -TRTRoIAlignCreator::TRTRoIAlignCreator() { - mPluginAttributes.emplace_back(nvinfer1::PluginField("output_height")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("output_width")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("spatial_scale")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("sampling_ratio")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("mode")); - mPluginAttributes.emplace_back(nvinfer1::PluginField("aligned")); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *TRTRoIAlignCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTRoIAlignCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -nvinfer1::IPluginV2 *TRTRoIAlignCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - int outWidth = 7; - int outHeight = 7; - float spatialScale = 1.0; - int sampleRatio = 0; - int poolMode = -1; - bool aligned = true; - for (int i = 0; i < fc->nbFields; i++) { - if (fc->fields[i].data == nullptr) { - continue; - } - std::string field_name(fc->fields[i].name); - - if (field_name.compare("output_height") == 0) { - outHeight = static_cast(fc->fields[i].data)[0]; - } - - if (field_name.compare("output_width") == 0) { - outWidth = static_cast(fc->fields[i].data)[0]; - } - - if (field_name.compare("spatial_scale") == 0) { - spatialScale = static_cast(fc->fields[i].data)[0]; - } - - if (field_name.compare("sampling_ratio") == 0) { - sampleRatio = static_cast(fc->fields[i].data)[0]; - } - - if (field_name.compare("mode") == 0) { - int data_size = fc->fields[i].length; - ASSERT(data_size > 0); - const char *data_start = static_cast(fc->fields[i].data); - std::string pool_mode(data_start); - if (pool_mode == "avg") { - poolMode = 1; - } else if (pool_mode == "max") { - poolMode = 0; - } else { - std::cout << "Unknown pool mode \"" << pool_mode << "\"." 
<< std::endl; - } - ASSERT(poolMode >= 0); - } - - if (field_name.compare("aligned") == 0) { - int aligned_int = static_cast(fc->fields[i].data)[0]; - aligned = aligned_int != 0; - } - } - - ASSERT(outHeight > 0); - ASSERT(outWidth > 0); - ASSERT(spatialScale > 0.); - ASSERT(poolMode >= 0); - - TRTRoIAlign *plugin = - new TRTRoIAlign(name, outWidth, outHeight, spatialScale, sampleRatio, poolMode, aligned); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *TRTRoIAlignCreator::deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT { - auto plugin = new TRTRoIAlign(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} -REGISTER_TENSORRT_PLUGIN(TRTRoIAlignCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"MMCVRoiAlign"}; + } // namespace + + TRTRoIAlign::TRTRoIAlign(const std::string& name, int outWidth, int outHeight, float spatialScale, int sampleRatio, int poolMode, bool aligned) + : TRTPluginBase(name) + , mOutWidth(outWidth) + , mOutHeight(outHeight) + , mSpatialScale(spatialScale) + , mSampleRatio(sampleRatio) + , mPoolMode(poolMode) + , mAligned(aligned) + { + } + + TRTRoIAlign::TRTRoIAlign(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { + deserialize_value(&data, &length, &mOutWidth); + deserialize_value(&data, &length, &mOutHeight); + deserialize_value(&data, &length, &mSpatialScale); + deserialize_value(&data, &length, &mSampleRatio); + deserialize_value(&data, &length, &mPoolMode); + deserialize_value(&data, &length, &mAligned); + } + + nvinfer1::IPluginV2DynamicExt* TRTRoIAlign::clone() const TRT_NOEXCEPT + { + TRTRoIAlign* plugin = new TRTRoIAlign(mLayerName, mOutWidth, mOutHeight, mSpatialScale, mSampleRatio, mPoolMode, mAligned); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs TRTRoIAlign::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + nvinfer1::DimsExprs ret; + ret.nbDims = 4; + ret.d[0] = inputs[1].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = exprBuilder.constant(mOutHeight); + ret.d[3] = exprBuilder.constant(mOutWidth); + + return ret; + } + + bool TRTRoIAlign::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT + { + return ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + } + + void TRTRoIAlign::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) TRT_NOEXCEPT {} + + size_t TRTRoIAlign::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT + { + size_t output_size = 0; + size_t word_size = 0; + switch (mPoolMode) + { + case 0: // max + output_size = + outputs[0].dims.d[0] * outputs[0].dims.d[1] * outputs[0].dims.d[2] * outputs[0].dims.d[3]; + word_size = mmdeploy::getElementSize(outputs[0].type); + return output_size * word_size * 2; + break; + case 1: + return 0; + break; + default: + return 0; + } + return 0; + } + + int TRTRoIAlign::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + 
void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + int channels = inputDesc[0].dims.d[1]; + int height = inputDesc[0].dims.d[2]; + int width = inputDesc[0].dims.d[3]; + + int output_size = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] * outputDesc[0].dims.d[2] * + outputDesc[0].dims.d[3]; + int word_size = mmdeploy::getElementSize(outputDesc[0].type); + + const void* feat = inputs[0]; + const void* rois = inputs[1]; + void* output = outputs[0]; + void* argmax_y = nullptr; + void* argmax_x = nullptr; + + switch (mPoolMode) + { + case 0: // max + argmax_y = workSpace; + argmax_x = (char*)argmax_y + output_size * word_size; + break; + case 1: // avg + break; + } + + switch (outputDesc[0].type) + { + case nvinfer1::DataType::kFLOAT: + TRTRoIAlignForwardCUDAKernelLauncher( + (const float*)feat, + (const float*)rois, + (float*)output, + (float*)argmax_y, + (float*)argmax_x, + output_size, + channels, + height, + width, + mOutHeight, + mOutWidth, + mSpatialScale, + mSampleRatio, + mPoolMode, + mAligned, + stream); + break; + + default: + break; + } + + return 0; + } + + nvinfer1::DataType TRTRoIAlign::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* TRTRoIAlign::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTRoIAlign::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int TRTRoIAlign::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t TRTRoIAlign::getSerializationSize() const TRT_NOEXCEPT + { + return serialized_size(mOutWidth) + serialized_size(mOutHeight) + serialized_size(mSpatialScale) + + serialized_size(mSampleRatio) + serialized_size(mPoolMode) + serialized_size(mAligned); + } + + void TRTRoIAlign::serialize(void* buffer) const TRT_NOEXCEPT + { + serialize_value(&buffer, mOutWidth); + serialize_value(&buffer, mOutHeight); + serialize_value(&buffer, mSpatialScale); + serialize_value(&buffer, mSampleRatio); + serialize_value(&buffer, mPoolMode); + serialize_value(&buffer, mAligned); + } + + TRTRoIAlignCreator::TRTRoIAlignCreator() + { + mPluginAttributes.emplace_back(nvinfer1::PluginField("output_height")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("output_width")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("spatial_scale")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("sampling_ratio")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("mode")); + mPluginAttributes.emplace_back(nvinfer1::PluginField("aligned")); + mFC.nbFields = mPluginAttributes.size(); + mFC.fields = mPluginAttributes.data(); + } + + const char* TRTRoIAlignCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTRoIAlignCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* TRTRoIAlignCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + int outWidth = 7; + int outHeight = 7; + float spatialScale = 1.0; + int sampleRatio = 0; + int poolMode = -1; + bool aligned = true; + for (int i = 0; i < fc->nbFields; i++) + { + if (fc->fields[i].data == nullptr) + { + continue; + } + std::string field_name(fc->fields[i].name); + + if (field_name.compare("output_height") == 0) + { + outHeight = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("output_width") == 0) + { + outWidth = 
static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("spatial_scale") == 0) + { + spatialScale = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("sampling_ratio") == 0) + { + sampleRatio = static_cast(fc->fields[i].data)[0]; + } + + if (field_name.compare("mode") == 0) + { + int data_size = fc->fields[i].length; + ASSERT(data_size > 0); + const char* data_start = static_cast(fc->fields[i].data); + std::string pool_mode(data_start); + if (pool_mode == "avg") + { + poolMode = 1; + } + else if (pool_mode == "max") + { + poolMode = 0; + } + else + { + std::cout << "Unknown pool mode \"" << pool_mode << "\"." << std::endl; + } + ASSERT(poolMode >= 0); + } + + if (field_name.compare("aligned") == 0) + { + int aligned_int = static_cast(fc->fields[i].data)[0]; + aligned = aligned_int != 0; + } + } + + ASSERT(outHeight > 0); + ASSERT(outWidth > 0); + ASSERT(spatialScale > 0.); + ASSERT(poolMode >= 0); + + TRTRoIAlign* plugin = + new TRTRoIAlign(name, outWidth, outHeight, spatialScale, sampleRatio, poolMode, aligned); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* TRTRoIAlignCreator::deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new TRTRoIAlign(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + REGISTER_TENSORRT_PLUGIN(TRTRoIAlignCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.hpp b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.hpp index cfc14758f7..45301e014e 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align.hpp @@ -8,65 +8,62 @@ #include #include "trt_plugin_base.hpp" -namespace mmdeploy { -class TRTRoIAlign : public TRTPluginBase { - public: - TRTRoIAlign(const std::string &name, int outWidth, int outHeight, float spatialScale, - int sampleRatio, int poolMode, bool aligned); +namespace mmdeploy +{ + class TRTRoIAlign : public TRTPluginBase + { + public: + TRTRoIAlign(const std::string& name, int outWidth, int outHeight, float spatialScale, int sampleRatio, int poolMode, bool aligned); - TRTRoIAlign(const std::string name, const void *data, size_t length); + TRTRoIAlign(const std::string name, const void* data, size_t length); - TRTRoIAlign() = delete; + TRTRoIAlign() = delete; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT 
override; +        nvinfer1::DimsExprs            getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) +            TRT_NOEXCEPT override; +        bool   supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; +        void   configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; +        size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override; +        int    enqueue(const nvinfer1::PluginTensorDesc* inputDesc, +                       const nvinfer1::PluginTensorDesc* outputDesc, +                       const void* const*                inputs, +                       void* const*                      outputs, +                       void*                             workspace, +                       cudaStream_t                      stream) TRT_NOEXCEPT override; -  // IPluginV2Ext Methods -  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, -                                       int nbInputs) const TRT_NOEXCEPT override; +        // IPluginV2Ext Methods +        nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; -  // IPluginV2 Methods -  const char *getPluginType() const TRT_NOEXCEPT override; -  const char *getPluginVersion() const TRT_NOEXCEPT override; -  int getNbOutputs() const TRT_NOEXCEPT override; -  size_t getSerializationSize() const TRT_NOEXCEPT override; -  void serialize(void *buffer) const TRT_NOEXCEPT override; +        // IPluginV2 Methods +        const char* getPluginType() const TRT_NOEXCEPT override; +        const char* getPluginVersion() const TRT_NOEXCEPT override; +        int         getNbOutputs() const TRT_NOEXCEPT override; +        size_t      getSerializationSize() const TRT_NOEXCEPT override; +        void        serialize(void* buffer) const TRT_NOEXCEPT override; - private: -  int mOutWidth; -  int mOutHeight; -  float mSpatialScale; -  int mSampleRatio; -  int mPoolMode;  // 1:avg 0:max -  bool mAligned; -}; +      private: +        int   mOutWidth; +        int   mOutHeight; +        float mSpatialScale; +        int   mSampleRatio; +        int   mPoolMode;  // 1:avg 0:max +        bool  mAligned; +    }; -class TRTRoIAlignCreator : public TRTPluginCreatorBase { - public: -  TRTRoIAlignCreator(); +    class TRTRoIAlignCreator : public TRTPluginCreatorBase +    { +      public: +        TRTRoIAlignCreator(); -  const char *getPluginName() const TRT_NOEXCEPT override; +        const char* getPluginName() const TRT_NOEXCEPT override; -  const char *getPluginVersion() const TRT_NOEXCEPT override; -  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) -      TRT_NOEXCEPT override; +        const char*          getPluginVersion() const TRT_NOEXCEPT override; +        nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) +            TRT_NOEXCEPT override; -  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, -                                         size_t serialLength) TRT_NOEXCEPT override; -}; +        nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; +    }; }  // namespace mmdeploy #endif  // TRT_ROI_ALIGN_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.cu index 4e1a825d4f..4cd448aa52 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.cu @@ -4,104 +4,135 @@ #include "trt_roi_align_kernel.hpp" /*** Forward ***/ -template <typename T> -__global__ void
roi_align_forward_cuda_kernel(const int nthreads, const T* input, const T* rois, -                                              T* output, T* argmax_y, T* argmax_x, -                                              const int pooled_height, const int pooled_width, -                                              const T spatial_scale, const int sampling_ratio, -                                              const int pool_mode,  // 0 - max pool, 1 - avg pool -                                              const bool aligned, const int channels, -                                              const int height, const int width) { -  CUDA_1D_KERNEL_LOOP(index, nthreads) { -    // (n, c, ph, pw) is an element in the pooled output -    int pw = index % pooled_width; -    int ph = (index / pooled_width) % pooled_height; -    int c = (index / pooled_width / pooled_height) % channels; -    int n = index / pooled_width / pooled_height / channels; +template<typename T> +__global__ void roi_align_forward_cuda_kernel(const int nthreads, const T* input, const T* rois, T* output, T* argmax_y, T* argmax_x, const int pooled_height, const int pooled_width, const T spatial_scale, const int sampling_ratio, +                                              const int pool_mode,  // 0 - max pool, 1 - avg pool +                                              const bool aligned, +                                              const int  channels, +                                              const int  height, +                                              const int  width) +{ +    CUDA_1D_KERNEL_LOOP(index, nthreads) +    { +        // (n, c, ph, pw) is an element in the pooled output +        int pw = index % pooled_width; +        int ph = (index / pooled_width) % pooled_height; +        int c  = (index / pooled_width / pooled_height) % channels; +        int n  = index / pooled_width / pooled_height / channels; -    const T* offset_rois = rois + n * 5; -    int roi_batch_ind = offset_rois[0]; +        const T* offset_rois   = rois + n * 5; +        int      roi_batch_ind = offset_rois[0]; -    // Do not using rounding; this implementation detail is critical -    T offset = aligned ? (T)0.5 : (T)0.0; -    T roi_start_w = offset_rois[1] * spatial_scale - offset; -    T roi_start_h = offset_rois[2] * spatial_scale - offset; -    T roi_end_w = offset_rois[3] * spatial_scale - offset; -    T roi_end_h = offset_rois[4] * spatial_scale - offset; +        // Do not use rounding; this implementation detail is critical +        T offset      = aligned ? (T)0.5 : (T)0.0; +        T roi_start_w = offset_rois[1] * spatial_scale - offset; +        T roi_start_h = offset_rois[2] * spatial_scale - offset; +        T roi_end_w   = offset_rois[3] * spatial_scale - offset; +        T roi_end_h   = offset_rois[4] * spatial_scale - offset; -    T roi_width = roi_end_w - roi_start_w; -    T roi_height = roi_end_h - roi_start_h; -    if (!aligned) {  // for backward-compatibility only -      roi_width = max(roi_width, (T)1.); -      roi_height = max(roi_height, (T)1.); -    } +        T roi_width  = roi_end_w - roi_start_w; +        T roi_height = roi_end_h - roi_start_h; +        if (!aligned) +        {  // for backward-compatibility only +            roi_width  = max(roi_width, (T)1.); +            roi_height = max(roi_height, (T)1.); +        } -    T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); -    T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); +        T bin_size_h = static_cast<T>(roi_height) / static_cast<T>(pooled_height); +        T bin_size_w = static_cast<T>(roi_width) / static_cast<T>(pooled_width); -    const T* offset_input = input + (roi_batch_ind * channels + c) * height * width; +        const T* offset_input = input + (roi_batch_ind * channels + c) * height * width; -    // We use roi_bin_grid to sample the grid and mimic integral -    int roi_bin_grid_h = -        (sampling_ratio > 0) ? sampling_ratio : static_cast<int>(ceilf(roi_height / pooled_height)); -    int roi_bin_grid_w = -        (sampling_ratio > 0) ? sampling_ratio : static_cast<int>(ceilf(roi_width / pooled_width)); +        // We use roi_bin_grid to sample the grid and mimic integral +        int roi_bin_grid_h = +            (sampling_ratio > 0) ? 
sampling_ratio : static_cast<int>(ceilf(roi_height / pooled_height)); +        int roi_bin_grid_w = +            (sampling_ratio > 0) ? sampling_ratio : static_cast<int>(ceilf(roi_width / pooled_width)); -    if (pool_mode == 0) { -      // We do max pooling inside a bin -      T maxval = -FLT_MAX; -      T maxidx_y = -1.f, maxidx_x = -1.f; -      for (int iy = 0; iy < roi_bin_grid_h; iy++) { -        const T y = roi_start_h + ph * bin_size_h + -                    static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); -        for (int ix = 0; ix < roi_bin_grid_w; ix++) { -          const T x = roi_start_w + pw * bin_size_w + -                      static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w); -          T val = bilinear_interpolate(offset_input, height, width, y, x); -          if (val > maxval) { -            maxval = val; -            maxidx_y = y; -            maxidx_x = x; -          } +        if (pool_mode == 0) +        { +            // We do max pooling inside a bin +            T maxval   = -FLT_MAX; +            T maxidx_y = -1.f, maxidx_x = -1.f; +            for (int iy = 0; iy < roi_bin_grid_h; iy++) +            { +                const T y = roi_start_h + ph * bin_size_h + +                            static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); +                for (int ix = 0; ix < roi_bin_grid_w; ix++) +                { +                    const T x = roi_start_w + pw * bin_size_w + +                                static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w); +                    T val = bilinear_interpolate(offset_input, height, width, y, x); +                    if (val > maxval) +                    { +                        maxval   = val; +                        maxidx_y = y; +                        maxidx_x = x; +                    } +                } +            } +            output[index]   = maxval; +            argmax_y[index] = maxidx_y; +            argmax_x[index] = maxidx_x; }
-      } -      output[index] = maxval; -      argmax_y[index] = maxidx_y; -      argmax_x[index] = maxidx_x; -    } else if (pool_mode == 1) { -      // We do average pooling inside a bin -      const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); -      T output_val = 0.; -      for (int iy = 0; iy < roi_bin_grid_h; iy++) { -        const T y = roi_start_h + ph * bin_size_h + -                    static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); -        for (int ix = 0; ix < roi_bin_grid_w; ix++) { -          const T x = roi_start_w + pw * bin_size_w + -                      static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w); -          T val = bilinear_interpolate(offset_input, height, width, y, x); -          output_val += val; +        else if (pool_mode == 1) +        { +            // We do average pooling inside a bin +            const T count      = max(roi_bin_grid_h * roi_bin_grid_w, 1); +            T       output_val = 0.; +            for (int iy = 0; iy < roi_bin_grid_h; iy++) +            { +                const T y = roi_start_h + ph * bin_size_h + +                            static_cast<T>(iy + .5f) * bin_size_h / static_cast<T>(roi_bin_grid_h); +                for (int ix = 0; ix < roi_bin_grid_w; ix++) +                { +                    const T x = roi_start_w + pw * bin_size_w + +                                static_cast<T>(ix + .5f) * bin_size_w / static_cast<T>(roi_bin_grid_w); +                    T val = bilinear_interpolate(offset_input, height, width, y, x); +                    output_val += val; +                } +            } +            output[index] = output_val / count; }
-        } -      } -      output[index] = output_val / count; }
-  } } -template <typename scalar_t> -void TRTRoIAlignForwardCUDAKernelLauncher(const scalar_t* input, const scalar_t* rois, -                                          scalar_t* output, scalar_t* argmax_y, scalar_t* argmax_x, -                                          int output_size, int channels, int height, int width, -                                          int aligned_height, int aligned_width, -                                          scalar_t spatial_scale, int sampling_ratio, int pool_mode, -                                          bool aligned, cudaStream_t stream) { -  roi_align_forward_cuda_kernel -      <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>( -          output_size, input, rois, output, argmax_y, argmax_x, aligned_height, aligned_width, -          static_cast<scalar_t>(spatial_scale), sampling_ratio, pool_mode, aligned, channels, -          height, width);
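+// One CUDA thread per output element: GET_BLOCKS(output_size) blocks of THREADS_PER_BLOCK threads, launched on the caller's stream.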
+template<typename scalar_t> +void TRTRoIAlignForwardCUDAKernelLauncher(const scalar_t* input, const scalar_t* rois, scalar_t* output, scalar_t* argmax_y, scalar_t* argmax_x, int output_size, int channels, int height, int width, int aligned_height, int aligned_width, scalar_t spatial_scale, int sampling_ratio, int pool_mode, bool aligned, cudaStream_t stream) +{ +    roi_align_forward_cuda_kernel +        <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>( +            output_size, +            input, +            rois, +            output, +            argmax_y, +            argmax_x, +            aligned_height, +            aligned_width, +            static_cast<scalar_t>(spatial_scale), +            sampling_ratio, +            pool_mode, +            aligned, +            channels, +            height, +            width); }  template void TRTRoIAlignForwardCUDAKernelLauncher( -    const float* input, const float* rois, float* output, float* argmax_y, float* argmax_x, -    int output_size, int channels, int height, int width, int aligned_height, int aligned_width, -    float spatial_scale, int sampling_ratio, int pool_mode, bool aligned, cudaStream_t stream); +    const float* input, +    const float* rois, +    float*       output, +    float*       argmax_y, +    float*       argmax_x, +    int          output_size, +    int          channels, +    int          height, +    int          width, +    int          aligned_height, +    int          aligned_width, +    float        spatial_scale, +    int          sampling_ratio, +    int          pool_mode, +    bool         aligned, +    cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.hpp index 3db656bff9..39e8dc7893 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/roi_align/trt_roi_align_kernel.hpp @@ -4,12 +4,7 @@ #include "common_cuda_helper.hpp" -template <typename scalar_t> -void TRTRoIAlignForwardCUDAKernelLauncher(const scalar_t* input, const scalar_t* rois, -                                          scalar_t* output, scalar_t* argmax_y, scalar_t* argmax_x, -                                          int output_size, int channels, int height, int width, -                                          int aligned_height, int aligned_width, -                                          scalar_t spatial_scale, int sampling_ratio, int pool_mode, -                                          bool aligned, cudaStream_t stream); +template<typename scalar_t> +void TRTRoIAlignForwardCUDAKernelLauncher(const scalar_t* input, const scalar_t* rois, scalar_t* output, scalar_t* argmax_y, scalar_t* argmax_x, int output_size, int channels, int height, int width, int aligned_height, int aligned_width, scalar_t spatial_scale, int sampling_ratio, int pool_mode, bool aligned, cudaStream_t stream); #endif  // ROI_ALIGN_CUDA_KERNEL_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.cpp b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.cpp index a4ecb2356a..b20a4b37ea 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.cpp @@ -10,174 +10,223 @@ using namespace nvinfer1; -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"ScaledDotProductAttentionTRT"}; -}  // namespace - -ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string &name) -    : TRTPluginBase(name), mask_dim(0) {} - -ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string name, const void *data, -                                                           size_t length) -    : TRTPluginBase(name), mask_dim(0) {} - -ScaledDotProductAttentionTRT::~ScaledDotProductAttentionTRT() {} - -nvinfer1::IPluginV2DynamicExt *ScaledDotProductAttentionTRT::clone() const TRT_NOEXCEPT { -  ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(mLayerName); -  plugin->setPluginNamespace(getPluginNamespace()); -  return plugin; -} - -nvinfer1::DimsExprs ScaledDotProductAttentionTRT::getOutputDimensions( -    int outputIndex, const nvinfer1::DimsExprs *inputs, int 
nbInputs, - nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { - if (outputIndex == 0) return inputs[0]; - nvinfer1::DimsExprs ret; - ret.nbDims = 3; - ret.d[0] = inputs[0].d[0]; - ret.d[1] = inputs[0].d[1]; - ret.d[2] = inputs[1].d[1]; - - return ret; -} - -bool ScaledDotProductAttentionTRT::supportsFormatCombination( - int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT { - if (pos == 0) { - return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && - ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); - } else { - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - } -} - -// Attach the plugin object to an execution context and grant the plugin the -// access to some context resource. -void ScaledDotProductAttentionTRT::attachToContext(cudnnContext *cudnnContext, - cublasContext *cublasContext, - IGpuAllocator *gpuAllocator) TRT_NOEXCEPT { - _cublas_handle = cublasContext; - _cudnn_handle = cudnnContext; - cudnnCreateTensorDescriptor(&_x_desc); - cudnnCreateTensorDescriptor(&_y_desc); - cudnnCreateTensorDescriptor(&_mask_desc); -} - -// Detach the plugin object from its execution context. -void ScaledDotProductAttentionTRT::detachFromContext() TRT_NOEXCEPT { - cudnnDestroyTensorDescriptor(_y_desc); - cudnnDestroyTensorDescriptor(_x_desc); - cudnnDestroyTensorDescriptor(_mask_desc); -} - -void ScaledDotProductAttentionTRT::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, - int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT { - if (nbInputs != 4) { - mask_dim = 0; - } else { - mask_dim = in[3].desc.dims.nbDims; - } -} - -int ScaledDotProductAttentionTRT::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, - const void *const *inputs, void *const *outputs, - void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - if (CUDNN_STATUS_SUCCESS != cudnnSetStream(_cudnn_handle, stream)) return 1; - if (CUBLAS_STATUS_SUCCESS != cublasSetStream(_cublas_handle, stream)) return 1; - int B = inputDesc[0].dims.d[0]; // batch * heads - int Nt = inputDesc[0].dims.d[1]; - int Ns = inputDesc[1].dims.d[1]; - int E = inputDesc[0].dims.d[2]; // embeding size - - const void *query = inputs[0]; - const void *key = inputs[1]; - const void *value = inputs[2]; - const void *mask = nullptr; - - int mask_dims[3]; - mask_dims[0] = 0; - if (mask_dim > 0) { - mask = inputs[3]; - // check if mask need broadcast - if (mask_dim == 2) { - mask_dims[0] = 1; - mask_dims[1] = inputDesc[3].dims.d[0]; - mask_dims[2] = inputDesc[3].dims.d[1]; - } else { - mask_dims[0] = inputDesc[3].dims.d[0]; - mask_dims[1] = inputDesc[3].dims.d[1]; - mask_dims[2] = inputDesc[3].dims.d[2]; - } - } - - void *output = outputs[0]; - void *attn = outputs[1]; - - auto data_type = inputDesc[0].type; - cudnnDataType_t cudnn_dtype{}; - convert_trt2cudnn_dtype(data_type, &cudnn_dtype); - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - dot_product_attention_impl((float *)query, (float *)key, (float *)value, (float *)mask, - (float *)attn, (float *)output, B, Nt, Ns, E, &mask_dims[0], - _x_desc, _y_desc, _mask_desc, cudnn_dtype, stream, - _cublas_handle, _cudnn_handle); - break; - default: - return 1; - } - - return 0; -} - -nvinfer1::DataType ScaledDotProductAttentionTRT::getOutputDataType( - int index, const nvinfer1::DataType *inputTypes, int nbInputs) const TRT_NOEXCEPT { - return inputTypes[0]; -} - -// IPluginV2 Methods -const char 
*ScaledDotProductAttentionTRT::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *ScaledDotProductAttentionTRT::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -int ScaledDotProductAttentionTRT::getNbOutputs() const TRT_NOEXCEPT { return 2; } - -size_t ScaledDotProductAttentionTRT::getSerializationSize() const TRT_NOEXCEPT { return 0; } - -void ScaledDotProductAttentionTRT::serialize(void *buffer) const TRT_NOEXCEPT {} - -////////////////////// creator ///////////////////////////// - -ScaledDotProductAttentionTRTCreator::ScaledDotProductAttentionTRTCreator() {} - -const char *ScaledDotProductAttentionTRTCreator::getPluginName() const TRT_NOEXCEPT { - return PLUGIN_NAME; -} - -const char *ScaledDotProductAttentionTRTCreator::getPluginVersion() const TRT_NOEXCEPT { - return PLUGIN_VERSION; -} - -nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - ScaledDotProductAttentionTRT *plugin = new ScaledDotProductAttentionTRT(name); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *ScaledDotProductAttentionTRTCreator::deserializePlugin( - const char *name, const void *serialData, size_t serialLength) TRT_NOEXCEPT { - auto plugin = new ScaledDotProductAttentionTRT(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} -REGISTER_TENSORRT_PLUGIN(ScaledDotProductAttentionTRTCreator); +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"ScaledDotProductAttentionTRT"}; + } // namespace + + ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string& name) + : TRTPluginBase(name) + , mask_dim(0) + { + } + + ScaledDotProductAttentionTRT::ScaledDotProductAttentionTRT(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + , mask_dim(0) + { + } + + ScaledDotProductAttentionTRT::~ScaledDotProductAttentionTRT() {} + + nvinfer1::IPluginV2DynamicExt* ScaledDotProductAttentionTRT::clone() const TRT_NOEXCEPT + { + ScaledDotProductAttentionTRT* plugin = new ScaledDotProductAttentionTRT(mLayerName); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::DimsExprs ScaledDotProductAttentionTRT::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + if (outputIndex == 0) return inputs[0]; + nvinfer1::DimsExprs ret; + ret.nbDims = 3; + ret.d[0] = inputs[0].d[0]; + ret.d[1] = inputs[0].d[1]; + ret.d[2] = inputs[1].d[1]; + + return ret; + } + + bool ScaledDotProductAttentionTRT::supportsFormatCombination( + int pos, + const nvinfer1::PluginTensorDesc* ioDesc, + int nbInputs, + int nbOutputs) TRT_NOEXCEPT + { + if (pos == 0) + { + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + } + else + { + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + } + } + + // Attach the plugin object to an execution context and grant the plugin the + // access to some context resource. 
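+    // Note: TensorRT owns the cuDNN/cuBLAS handles passed in here; the plugin only
+    // creates its tensor descriptors, which detachFromContext() destroys again.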
+    void ScaledDotProductAttentionTRT::attachToContext(cudnnContext*  cudnnContext, +                                                       cublasContext* cublasContext, +                                                       IGpuAllocator* gpuAllocator) TRT_NOEXCEPT +    { +        _cublas_handle = cublasContext; +        _cudnn_handle  = cudnnContext; +        cudnnCreateTensorDescriptor(&_x_desc); +        cudnnCreateTensorDescriptor(&_y_desc); +        cudnnCreateTensorDescriptor(&_mask_desc); +    } + +    // Detach the plugin object from its execution context. +    void ScaledDotProductAttentionTRT::detachFromContext() TRT_NOEXCEPT +    { +        cudnnDestroyTensorDescriptor(_y_desc); +        cudnnDestroyTensorDescriptor(_x_desc); +        cudnnDestroyTensorDescriptor(_mask_desc); +    } + +    void ScaledDotProductAttentionTRT::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, +                                                       int                                      nbInputs, +                                                       const nvinfer1::DynamicPluginTensorDesc* out, +                                                       int                                      nbOutputs) TRT_NOEXCEPT +    { +        if (nbInputs != 4) +        { +            mask_dim = 0; +        } +        else +        { +            mask_dim = in[3].desc.dims.nbDims; +        } +    } + +    int ScaledDotProductAttentionTRT::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, +                                              const nvinfer1::PluginTensorDesc* outputDesc, +                                              const void* const*                inputs, +                                              void* const*                      outputs, +                                              void*                             workSpace, +                                              cudaStream_t                      stream) TRT_NOEXCEPT +    { +        if (CUDNN_STATUS_SUCCESS != cudnnSetStream(_cudnn_handle, stream)) return 1; +        if (CUBLAS_STATUS_SUCCESS != cublasSetStream(_cublas_handle, stream)) return 1; +        int B  = inputDesc[0].dims.d[0];  // batch * heads +        int Nt = inputDesc[0].dims.d[1]; +        int Ns = inputDesc[1].dims.d[1]; +        int E  = inputDesc[0].dims.d[2];  // embedding size + +        const void* query = inputs[0]; +        const void* key   = inputs[1]; +        const void* value = inputs[2]; +        const void* mask  = nullptr; + +        int mask_dims[3]; +        mask_dims[0] = 0; +        if (mask_dim > 0) +        { +            mask = inputs[3]; +            // check if mask needs broadcast +            if (mask_dim == 2) +            { +                mask_dims[0] = 1; +                mask_dims[1] = inputDesc[3].dims.d[0]; +                mask_dims[2] = inputDesc[3].dims.d[1]; +            } +            else +            { +                mask_dims[0] = inputDesc[3].dims.d[0]; +                mask_dims[1] = inputDesc[3].dims.d[1]; +                mask_dims[2] = inputDesc[3].dims.d[2]; +            } +        } + +        void* output = outputs[0]; +        void* attn   = outputs[1]; + +        auto            data_type = inputDesc[0].type; +        cudnnDataType_t cudnn_dtype{}; +        convert_trt2cudnn_dtype(data_type, &cudnn_dtype); +        switch (data_type) +        { +            case nvinfer1::DataType::kFLOAT: +                dot_product_attention_impl((float*)query, (float*)key, (float*)value, (float*)mask, (float*)attn, (float*)output, B, Nt, Ns, E, &mask_dims[0], _x_desc, _y_desc, _mask_desc, cudnn_dtype, stream, _cublas_handle, _cudnn_handle); +                break; +            default: +                return 1; +        } + +        return 0; +    } + +    nvinfer1::DataType ScaledDotProductAttentionTRT::getOutputDataType( +        int                       index, +        const nvinfer1::DataType* inputTypes, +        int                       nbInputs) const TRT_NOEXCEPT +    { +        return inputTypes[0]; +    } + +    // IPluginV2 Methods +    const char* ScaledDotProductAttentionTRT::getPluginType() const TRT_NOEXCEPT +    { +        return PLUGIN_NAME; +    } + +    const char* ScaledDotProductAttentionTRT::getPluginVersion() const TRT_NOEXCEPT +    { +        return PLUGIN_VERSION; +    } + +    int ScaledDotProductAttentionTRT::getNbOutputs() const TRT_NOEXCEPT +    { +        return 2; +    } + +    size_t ScaledDotProductAttentionTRT::getSerializationSize() const TRT_NOEXCEPT +    { +        return 0; +    } + +    void ScaledDotProductAttentionTRT::serialize(void* buffer) const TRT_NOEXCEPT {} + +    ////////////////////// creator ///////////////////////////// + +    ScaledDotProductAttentionTRTCreator::ScaledDotProductAttentionTRTCreator() {} + +    const char* ScaledDotProductAttentionTRTCreator::getPluginName() const TRT_NOEXCEPT +    { +        return PLUGIN_NAME; +    } + +    const char* 
ScaledDotProductAttentionTRTCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* ScaledDotProductAttentionTRTCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + ScaledDotProductAttentionTRT* plugin = new ScaledDotProductAttentionTRT(name); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* ScaledDotProductAttentionTRTCreator::deserializePlugin( + const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new ScaledDotProductAttentionTRT(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + REGISTER_TENSORRT_PLUGIN(ScaledDotProductAttentionTRTCreator); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.hpp b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.hpp index 86d35616a9..9e184626cb 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention.hpp @@ -9,65 +9,64 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class ScaledDotProductAttentionTRT : public TRTPluginBase { - public: - ScaledDotProductAttentionTRT(const std::string &name); +namespace mmdeploy +{ + class ScaledDotProductAttentionTRT : public TRTPluginBase + { + public: + ScaledDotProductAttentionTRT(const std::string& name); - ScaledDotProductAttentionTRT(const std::string name, const void *data, size_t length); + ScaledDotProductAttentionTRT(const std::string name, const void* data, size_t length); - ScaledDotProductAttentionTRT() = delete; + ScaledDotProductAttentionTRT() = delete; - ~ScaledDotProductAttentionTRT() TRT_NOEXCEPT override; + ~ScaledDotProductAttentionTRT() TRT_NOEXCEPT override; - virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; + virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* 
const* outputs, +                    void*                             workspace, +                    cudaStream_t                      stream) TRT_NOEXCEPT override; -  // IPluginV2Ext Methods -  nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, -                                       int nbInputs) const TRT_NOEXCEPT override; +        // IPluginV2Ext Methods +        nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; -  // IPluginV2 Methods -  const char *getPluginType() const TRT_NOEXCEPT override; -  const char *getPluginVersion() const TRT_NOEXCEPT override; -  int getNbOutputs() const TRT_NOEXCEPT override; -  size_t getSerializationSize() const TRT_NOEXCEPT override; -  void serialize(void *buffer) const TRT_NOEXCEPT override; -  void attachToContext(cudnnContext *cudnn, cublasContext *cublas, -                       nvinfer1::IGpuAllocator *allocator) TRT_NOEXCEPT override; -  void detachFromContext() TRT_NOEXCEPT override; +        // IPluginV2 Methods +        const char* getPluginType() const TRT_NOEXCEPT override; +        const char* getPluginVersion() const TRT_NOEXCEPT override; +        int         getNbOutputs() const TRT_NOEXCEPT override; +        size_t      getSerializationSize() const TRT_NOEXCEPT override; +        void        serialize(void* buffer) const TRT_NOEXCEPT override; +        void        attachToContext(cudnnContext* cudnn, cublasContext* cublas, nvinfer1::IGpuAllocator* allocator) TRT_NOEXCEPT override; +        void        detachFromContext() TRT_NOEXCEPT override; - private: -  int mask_dim; -  cublasHandle_t _cublas_handle{}; -  cudnnHandle_t _cudnn_handle{}; -  cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _mask_desc{}; -}; +      private: +        int                     mask_dim; +        cublasHandle_t          _cublas_handle{}; +        cudnnHandle_t           _cudnn_handle{}; +        cudnnTensorDescriptor_t _x_desc{}, _y_desc{}, _mask_desc{}; +    }; -class ScaledDotProductAttentionTRTCreator : public TRTPluginCreatorBase { - public: -  ScaledDotProductAttentionTRTCreator(); +    class ScaledDotProductAttentionTRTCreator : public TRTPluginCreatorBase +    { +      public: +        ScaledDotProductAttentionTRTCreator(); -  const char *getPluginName() const TRT_NOEXCEPT override; +        const char* getPluginName() const TRT_NOEXCEPT override; -  const char *getPluginVersion() const TRT_NOEXCEPT override; +        const char* getPluginVersion() const TRT_NOEXCEPT override; -  nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) -      TRT_NOEXCEPT override; +        nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) +            TRT_NOEXCEPT override; -  nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, -                                         size_t serialLength) TRT_NOEXCEPT override; -}; +        nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; +    }; }  // namespace mmdeploy #endif  // TRT_SCALED_DOT_PRODUCT_ATTENTION_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.cu index a0ee16c998..738316b9a8 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.cu @@ -11,93 +11,79 @@ #include "scaled_dot_product_attention_kernel.hpp" #include "trt_plugin_helper.hpp" -template <typename scalar_t> -cublasStatus_t cublasgemmStridedBatchedWrap(cublasHandle_t handle, cublasOperation_t transa, -                                            cublasOperation_t transb, int m, int n, int k, -                                            const scalar_t* alpha, const scalar_t* 
A, int lda, -                                            long long int strideA, const scalar_t* B, int ldb, -                                            long long int strideB, const scalar_t* beta, -                                            scalar_t* C, int ldc, long long int strideC, -                                            int batchCount); +template<typename scalar_t> +cublasStatus_t cublasgemmStridedBatchedWrap(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const scalar_t* alpha, const scalar_t* A, int lda, long long int strideA, const scalar_t* B, int ldb, long long int strideB, const scalar_t* beta, scalar_t* C, int ldc, long long int strideC, int batchCount); -template <> -cublasStatus_t cublasgemmStridedBatchedWrap<float>(cublasHandle_t handle, cublasOperation_t transa, -                                            cublasOperation_t transb, int m, int n, int k, -                                            const float* alpha, const float* A, int lda, -                                            long long int strideA, const float* B, int ldb, -                                            long long int strideB, const float* beta, -                                            float* C, int ldc, long long int strideC, -                                            int batchCount) { -  return cublasSgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, -                                   strideB, beta, C, ldc, strideC, batchCount); +template<> +cublasStatus_t cublasgemmStridedBatchedWrap<float>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const float* alpha, const float* A, int lda, long long int strideA, const float* B, int ldb, long long int strideB, const float* beta, float* C, int ldc, long long int strideC, int batchCount) +{ +    return cublasSgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); } -template <> -cublasStatus_t cublasgemmStridedBatchedWrap<__half>(cublasHandle_t handle, cublasOperation_t transa, -                                                    cublasOperation_t transb, int m, int n, int k, -                                                    const __half* alpha, const __half* A, int lda, -                                                    long long int strideA, const __half* B, int ldb, -                                                    long long int strideB, const __half* beta, -                                                    __half* C, int ldc, long long int strideC, -                                                    int batchCount) { -  return cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, -                                   strideB, beta, C, ldc, strideC, batchCount); +template<> +cublasStatus_t cublasgemmStridedBatchedWrap<__half>(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, int k, const __half* alpha, const __half* A, int lda, long long int strideA, const __half* B, int ldb, long long int strideB, const __half* beta, __half* C, int ldc, long long int strideC, int batchCount) +{ +    return cublasHgemmStridedBatched(handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount); } -template <typename scalar_t> -void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value, -                                const scalar_t* mask, scalar_t* attn, scalar_t* output, int B, -                                int Nt, int Ns, int E, const int* mask_dims, -                                cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc, -                                cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, -                                cudaStream_t stream, cublasHandle_t cublas_handle, -                                cudnnHandle_t cudnn_handle) { -  { -    // Q @ K -    const int m = Ns; -    const int n = Nt; -    const int k = E; -    const auto alpha = scalar_t(1.0f / sqrt(float(E))); -    const auto beta = scalar_t(0); -    cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, key, k, -                                 Ns * E, query, k, Nt * E, &beta, attn, m, Nt * Ns, B); -  }
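+// dot_product_attention_impl computes softmax(Q @ K^T / sqrt(E) + mask) @ V: the two matrix products run as cuBLAS strided-batched GEMMs, while the mask addition and the softmax go through cuDNN.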
+template<typename scalar_t> +void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value, const scalar_t* mask, scalar_t* attn, scalar_t* output, int B, int Nt, int Ns, int E, const int* mask_dims, 
cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc, cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, cudaStream_t stream, cublasHandle_t cublas_handle, cudnnHandle_t cudnn_handle) +{ +    { +        // Q @ K +        const int  m     = Ns; +        const int  n     = Nt; +        const int  k     = E; +        const auto alpha = scalar_t(1.0f / sqrt(float(E))); +        const auto beta  = scalar_t(0); +        cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k, &alpha, key, k, Ns * E, query, k, Nt * E, &beta, attn, m, Nt * Ns, B); +    } -  if (mask_dims != nullptr && mask_dims[0] != 0) { -    const auto alpha = scalar_t(1); -    const auto beta = scalar_t(1); -    cudnnSetTensor4dDescriptor(mask_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, mask_dims[0], -                               mask_dims[1], mask_dims[2]); -    cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, B, Nt, Ns); -    cudnnAddTensor(cudnn_handle, &alpha, mask_desc, mask, &beta, x_desc, attn); -  } +    if (mask_dims != nullptr && mask_dims[0] != 0) +    { +        const auto alpha = scalar_t(1); +        const auto beta  = scalar_t(1); +        cudnnSetTensor4dDescriptor(mask_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, mask_dims[0], mask_dims[1], mask_dims[2]); +        cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, 1, B, Nt, Ns); +        cudnnAddTensor(cudnn_handle, &alpha, mask_desc, mask, &beta, x_desc, attn); +    } -  { -    // softmax attention -    const auto alpha = scalar_t(1); -    const auto beta = scalar_t(0); -    cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1); -    cudnnSetTensor4dDescriptor(y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1); -    cudnnSoftmaxForward(cudnn_handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha, -                        x_desc, attn, &beta, y_desc, attn); -  } +    { +        // softmax attention +        const auto alpha = scalar_t(1); +        const auto beta  = scalar_t(0); +        cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1); +        cudnnSetTensor4dDescriptor(y_desc, CUDNN_TENSOR_NCHW, cudnn_dtype, B * Nt, Ns, 1, 1); +        cudnnSoftmaxForward(cudnn_handle, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha, x_desc, attn, &beta, y_desc, attn); +    } -  { -    // attn @ v -    const int m = E; -    const int n = Nt; -    const int k = Ns; -    const auto alpha = scalar_t(1); -    const auto beta = scalar_t(0); -    cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, value, m, -                                 Ns * E, (const scalar_t*)(attn), k, Ns * Nt, &beta, output, m, -                                 Nt * E, B); -  } +    { +        // attn @ v +        const int  m     = E; +        const int  n     = Nt; +        const int  k     = Ns; +        const auto alpha = scalar_t(1); +        const auto beta  = scalar_t(0); +        cublasgemmStridedBatchedWrap(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, value, m, Ns * E, (const scalar_t*)(attn), k, Ns * Nt, &beta, output, m, Nt * E, B); +    } }  template void dot_product_attention_impl( -    const float* query, const float* key, const float* value, const float* mask, float* attn, -    float* output, int B, int Nt, int Ns, int E, const int* mask_dims, -    cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc, -    cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, cudaStream_t stream, -    cublasHandle_t cublas_handle, cudnnHandle_t cudnn_handle); +    const float*             query, +    const float*             key, +    const float*             value, +    const float*             mask, +    float*                   attn, +    float*                   output, +    int                      B, +    int                      Nt, +    int                      Ns, +    int                      E, +    const int*               mask_dims, +    cudnnTensorDescriptor_t& x_desc, +    cudnnTensorDescriptor_t& y_desc, +    cudnnTensorDescriptor_t& mask_desc, +    cudnnDataType_t          cudnn_dtype, +    
cudaStream_t             stream, +    cublasHandle_t           cublas_handle, +    cudnnHandle_t            cudnn_handle); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.hpp index d1cdc7773a..10db2aade1 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scaled_dot_product_attention/scaled_dot_product_attention_kernel.hpp @@ -5,13 +5,7 @@ #include  #include  -template <typename scalar_t> -void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value, -                                const scalar_t* mask, scalar_t* attn, scalar_t* output, int B, -                                int Nt, int Ns, int E, const int* mask_dims, -                                cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc, -                                cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, -                                cudaStream_t stream, cublasHandle_t cublas_handle, -                                cudnnHandle_t cudnn_handle); +template<typename scalar_t> +void dot_product_attention_impl(const scalar_t* query, const scalar_t* key, const scalar_t* value, const scalar_t* mask, scalar_t* attn, scalar_t* output, int B, int Nt, int Ns, int E, const int* mask_dims, cudnnTensorDescriptor_t& x_desc, cudnnTensorDescriptor_t& y_desc, cudnnTensorDescriptor_t& mask_desc, cudnnDataType_t cudnn_dtype, cudaStream_t stream, cublasHandle_t cublas_handle, cudnnHandle_t cudnn_handle); #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.cpp b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.cpp index 13c637f408..ca0ed9afa0 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.cpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.cpp @@ -2,155 +2,192 @@ #include "NvInferVersion.h" // ScatterND is supported since TensorRT8 #if NV_TENSORRT_MAJOR <= 7 -#include  -#include  - -#include  - -#include "trt_scatternd.hpp" -#include "trt_scatternd_kernel.hpp" -#include "trt_serialize.hpp" - -namespace mmdeploy { -namespace { -static const char *PLUGIN_VERSION{"1"}; -static const char *PLUGIN_NAME{"ScatterND"}; -}  // namespace - -TRTScatterND::TRTScatterND(const std::string &name) : TRTPluginBase(name) {} - -TRTScatterND::TRTScatterND(const std::string name, const void *data, size_t length) -    : TRTPluginBase(name) {} - -nvinfer1::IPluginV2DynamicExt *TRTScatterND::clone() const TRT_NOEXCEPT { -  TRTScatterND *plugin = new TRTScatterND(mLayerName); -  plugin->setPluginNamespace(getPluginNamespace()); - -  return plugin; -} - -nvinfer1::DimsExprs TRTScatterND::getOutputDimensions( -    int outputIndex, const nvinfer1::DimsExprs *inputs, int nbInputs, -    nvinfer1::IExprBuilder &exprBuilder) TRT_NOEXCEPT { -  return inputs[0]; -} - -bool TRTScatterND::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, -                                             int nbInputs, int nbOutputs) TRT_NOEXCEPT { -  if (pos < nbInputs) { -    switch (pos) { -      case 0: -        // data -        return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && -                ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) || -               (ioDesc[pos].type == nvinfer1::DataType::kINT32 && -                ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); -      case 1: -        // indices -        return ioDesc[pos].type == nvinfer1::DataType::kINT32 && -               ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; -      case 2: -        // updates -        return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; -      default: -        return true; +    #include  +    #include  + +    #include  + +    #include 
"trt_scatternd.hpp" + #include "trt_scatternd_kernel.hpp" + #include "trt_serialize.hpp" + +namespace mmdeploy +{ + namespace + { + static const char* PLUGIN_VERSION{"1"}; + static const char* PLUGIN_NAME{"ScatterND"}; + } // namespace + + TRTScatterND::TRTScatterND(const std::string& name) + : TRTPluginBase(name) + { + } + + TRTScatterND::TRTScatterND(const std::string name, const void* data, size_t length) + : TRTPluginBase(name) + { } - } else { - switch (pos - nbInputs) { - case 0: - // output - return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; - default: + + nvinfer1::IPluginV2DynamicExt* TRTScatterND::clone() const TRT_NOEXCEPT + { + TRTScatterND* plugin = new TRTScatterND(mLayerName); + plugin->setPluginNamespace(getPluginNamespace()); + + return plugin; + } + + nvinfer1::DimsExprs TRTScatterND::getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) TRT_NOEXCEPT + { + return inputs[0]; + } + + bool TRTScatterND::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT + { + if (pos < nbInputs) + { + switch (pos) + { + case 0: + // data + return (ioDesc[pos].type == nvinfer1::DataType::kFLOAT && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR) || + (ioDesc[pos].type == nvinfer1::DataType::kINT32 && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR); + case 1: + // indices + return ioDesc[pos].type == nvinfer1::DataType::kINT32 && + ioDesc[pos].format == nvinfer1::TensorFormat::kLINEAR; + case 2: + // updates + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + default: + return true; + } + } + else + { + switch (pos - nbInputs) + { + case 0: + // output + return ioDesc[pos].type == ioDesc[0].type && ioDesc[pos].format == ioDesc[0].format; + default: + return true; + } + } return true; } - } - return true; -} - -void TRTScatterND::configurePlugin(const nvinfer1::DynamicPluginTensorDesc *inputs, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *outputs, - int nbOutputs) TRT_NOEXCEPT {} - -size_t TRTScatterND::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT { - return 0; -} - -int TRTScatterND::enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workSpace, cudaStream_t stream) TRT_NOEXCEPT { - const int *dims = &(inputDesc[0].dims.d[0]); - const int *indices_dims = &(inputDesc[1].dims.d[0]); - int nbDims = inputDesc[0].dims.nbDims; - int indice_nbDims = inputDesc[1].dims.nbDims; - - const void *data = inputs[0]; - const void *indices = inputs[1]; - const void *update = inputs[2]; - void *output = outputs[0]; - - auto data_type = inputDesc[0].type; - - switch (data_type) { - case nvinfer1::DataType::kFLOAT: - TRTONNXScatterNDKernelLauncher((float *)data, (int *)indices, (float *)update, dims, - nbDims, indices_dims, indice_nbDims, (float *)output, - stream); - break; - - case nvinfer1::DataType::kINT32: - TRTONNXScatterNDKernelLauncher((int *)data, (int *)indices, (int *)update, dims, nbDims, - indices_dims, indice_nbDims, (int *)output, stream); - break; - default: - break; - } - - return 0; -} - -nvinfer1::DataType TRTScatterND::getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT { - return 
inputTypes[0]; -} - -// IPluginV2 Methods -const char *TRTScatterND::getPluginType() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTScatterND::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -int TRTScatterND::getNbOutputs() const TRT_NOEXCEPT { return 1; } - -size_t TRTScatterND::getSerializationSize() const TRT_NOEXCEPT { return 0; } - -void TRTScatterND::serialize(void *buffer) const TRT_NOEXCEPT {} - -TRTScatterNDCreator::TRTScatterNDCreator() { - mPluginAttributes.clear(); - mFC.nbFields = mPluginAttributes.size(); - mFC.fields = mPluginAttributes.data(); -} - -const char *TRTScatterNDCreator::getPluginName() const TRT_NOEXCEPT { return PLUGIN_NAME; } - -const char *TRTScatterNDCreator::getPluginVersion() const TRT_NOEXCEPT { return PLUGIN_VERSION; } - -nvinfer1::IPluginV2 *TRTScatterNDCreator::createPlugin( - const char *name, const nvinfer1::PluginFieldCollection *fc) TRT_NOEXCEPT { - TRTScatterND *plugin = new TRTScatterND(name); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -nvinfer1::IPluginV2 *TRTScatterNDCreator::deserializePlugin(const char *name, - const void *serialData, - size_t serialLength) TRT_NOEXCEPT { - auto plugin = new TRTScatterND(name, serialData, serialLength); - plugin->setPluginNamespace(getPluginNamespace()); - return plugin; -} - -REGISTER_TENSORRT_PLUGIN(TRTScatterNDCreator); + + void TRTScatterND::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* inputs, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* outputs, int nbOutputs) TRT_NOEXCEPT {} + + size_t TRTScatterND::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT + { + return 0; + } + + int TRTScatterND::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workSpace, + cudaStream_t stream) TRT_NOEXCEPT + { + const int* dims = &(inputDesc[0].dims.d[0]); + const int* indices_dims = &(inputDesc[1].dims.d[0]); + int nbDims = inputDesc[0].dims.nbDims; + int indice_nbDims = inputDesc[1].dims.nbDims; + + const void* data = inputs[0]; + const void* indices = inputs[1]; + const void* update = inputs[2]; + void* output = outputs[0]; + + auto data_type = inputDesc[0].type; + + switch (data_type) + { + case nvinfer1::DataType::kFLOAT: + TRTONNXScatterNDKernelLauncher((float*)data, (int*)indices, (float*)update, dims, nbDims, indices_dims, indice_nbDims, (float*)output, stream); + break; + + case nvinfer1::DataType::kINT32: + TRTONNXScatterNDKernelLauncher((int*)data, (int*)indices, (int*)update, dims, nbDims, indices_dims, indice_nbDims, (int*)output, stream); + break; + default: + break; + } + + return 0; + } + + nvinfer1::DataType TRTScatterND::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT + { + return inputTypes[0]; + } + + // IPluginV2 Methods + const char* TRTScatterND::getPluginType() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTScatterND::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + int TRTScatterND::getNbOutputs() const TRT_NOEXCEPT + { + return 1; + } + + size_t TRTScatterND::getSerializationSize() const TRT_NOEXCEPT + { + return 0; + } + + void TRTScatterND::serialize(void* buffer) const TRT_NOEXCEPT {} + + TRTScatterNDCreator::TRTScatterNDCreator() + { + mPluginAttributes.clear(); + mFC.nbFields = mPluginAttributes.size(); 
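+        // ScatterND declares no plugin attributes, so the advertised field collection stays empty.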
+ mFC.fields = mPluginAttributes.data(); + } + + const char* TRTScatterNDCreator::getPluginName() const TRT_NOEXCEPT + { + return PLUGIN_NAME; + } + + const char* TRTScatterNDCreator::getPluginVersion() const TRT_NOEXCEPT + { + return PLUGIN_VERSION; + } + + nvinfer1::IPluginV2* TRTScatterNDCreator::createPlugin( + const char* name, + const nvinfer1::PluginFieldCollection* fc) TRT_NOEXCEPT + { + TRTScatterND* plugin = new TRTScatterND(name); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + nvinfer1::IPluginV2* TRTScatterNDCreator::deserializePlugin(const char* name, + const void* serialData, + size_t serialLength) TRT_NOEXCEPT + { + auto plugin = new TRTScatterND(name, serialData, serialLength); + plugin->setPluginNamespace(getPluginNamespace()); + return plugin; + } + + REGISTER_TENSORRT_PLUGIN(TRTScatterNDCreator); } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.hpp b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.hpp index d6b859855e..b75adc40c2 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd.hpp @@ -9,56 +9,54 @@ #include "trt_plugin_base.hpp" -namespace mmdeploy { -class TRTScatterND : public TRTPluginBase { - public: - TRTScatterND(const std::string &name); - - TRTScatterND(const std::string name, const void *data, size_t length); - - TRTScatterND() = delete; - - // IPluginV2DynamicExt Methods - nvinfer1::IPluginV2DynamicExt *clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs *inputs, - int nbInputs, nvinfer1::IExprBuilder &exprBuilder) - TRT_NOEXCEPT override; - bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *ioDesc, int nbInputs, - int nbOutputs) TRT_NOEXCEPT override; - void configurePlugin(const nvinfer1::DynamicPluginTensorDesc *in, int nbInputs, - const nvinfer1::DynamicPluginTensorDesc *out, - int nbOutputs) TRT_NOEXCEPT override; - size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, - const nvinfer1::PluginTensorDesc *outputs, - int nbOutputs) const TRT_NOEXCEPT override; - int enqueue(const nvinfer1::PluginTensorDesc *inputDesc, - const nvinfer1::PluginTensorDesc *outputDesc, const void *const *inputs, - void *const *outputs, void *workspace, cudaStream_t stream) TRT_NOEXCEPT override; - - // IPluginV2Ext Methods - nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, - int nbInputs) const TRT_NOEXCEPT override; - - // IPluginV2 Methods - const char *getPluginType() const TRT_NOEXCEPT override; - const char *getPluginVersion() const TRT_NOEXCEPT override; - int getNbOutputs() const TRT_NOEXCEPT override; - size_t getSerializationSize() const TRT_NOEXCEPT override; - void serialize(void *buffer) const TRT_NOEXCEPT override; -}; - -class TRTScatterNDCreator : public TRTPluginCreatorBase { - public: - TRTScatterNDCreator(); - - const char *getPluginName() const TRT_NOEXCEPT override; - - const char *getPluginVersion() const TRT_NOEXCEPT override; - nvinfer1::IPluginV2 *createPlugin(const char *name, const nvinfer1::PluginFieldCollection *fc) - TRT_NOEXCEPT override; - - nvinfer1::IPluginV2 *deserializePlugin(const char *name, const void *serialData, - size_t serialLength) TRT_NOEXCEPT override; -}; +namespace mmdeploy +{ + class TRTScatterND : public TRTPluginBase + { + public: + TRTScatterND(const std::string& name); + + 
TRTScatterND(const std::string name, const void* data, size_t length); + + TRTScatterND() = delete; + + // IPluginV2DynamicExt Methods + nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; + nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) + TRT_NOEXCEPT override; + bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* ioDesc, int nbInputs, int nbOutputs) TRT_NOEXCEPT override; + void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs, const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override; + size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs, const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const TRT_NOEXCEPT override; + int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, + const nvinfer1::PluginTensorDesc* outputDesc, + const void* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) TRT_NOEXCEPT override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const TRT_NOEXCEPT override; + + // IPluginV2 Methods + const char* getPluginType() const TRT_NOEXCEPT override; + const char* getPluginVersion() const TRT_NOEXCEPT override; + int getNbOutputs() const TRT_NOEXCEPT override; + size_t getSerializationSize() const TRT_NOEXCEPT override; + void serialize(void* buffer) const TRT_NOEXCEPT override; + }; + + class TRTScatterNDCreator : public TRTPluginCreatorBase + { + public: + TRTScatterNDCreator(); + + const char* getPluginName() const TRT_NOEXCEPT override; + + const char* getPluginVersion() const TRT_NOEXCEPT override; + nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc) + TRT_NOEXCEPT override; + + nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength) TRT_NOEXCEPT override; + }; } // namespace mmdeploy #endif // TRT_SCATTERND_HPP diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu index c763992e9f..a9ec98fa36 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu +++ b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.cu @@ -8,68 +8,70 @@ using mmdeploy::TensorDesc; -template -__global__ void onnx_scatternd_kernel(const int n, const int* indices, const T* update, T* output, - TensorDesc tensor_desc, TensorDesc indice_desc) { - const int indice_cols = indice_desc.shape[indice_desc.dim - 1]; - const int copy_stride = tensor_desc.stride[indice_cols - 1]; - const int* stride = &(tensor_desc.stride[0]); - CUDA_1D_KERNEL_LOOP(index, n) { - int output_offset = 0; - const int* indices_current = indices + index * indice_cols; - for (int i = 0; i < indice_cols; ++i) { - output_offset += stride[i] * indices_current[i]; +template +__global__ void onnx_scatternd_kernel(const int n, const int* indices, const T* update, T* output, TensorDesc tensor_desc, TensorDesc indice_desc) +{ + const int indice_cols = indice_desc.shape[indice_desc.dim - 1]; + const int copy_stride = tensor_desc.stride[indice_cols - 1]; + const int* stride = &(tensor_desc.stride[0]); + CUDA_1D_KERNEL_LOOP(index, n) + { + int output_offset = 0; + const int* indices_current = indices + index * indice_cols; + for (int i = 0; i < indice_cols; ++i) + { + output_offset += 
stride[i] * indices_current[i]; + } + memcpy(output + output_offset, update + index * copy_stride, copy_stride * sizeof(T)); } - memcpy(output + output_offset, update + index * copy_stride, copy_stride * sizeof(T)); - } } -template -void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices, const T* update, - const int* dims, int nbDims, const int* indices_dims, - int indice_nbDims, T* output, cudaStream_t stream) { - // fill tensordesc and initial - TensorDesc tensor_desc; - memset((void*)&tensor_desc, 0, sizeof(TensorDesc)); - tensor_desc.dim = nbDims; - tensor_desc.shape[nbDims - 1] = dims[nbDims - 1]; - tensor_desc.stride[nbDims - 1] = 1; - for (int i = nbDims - 2; i >= 0; --i) { - tensor_desc.shape[i] = dims[i]; - tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1]; - } - const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0]; +template +void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices, const T* update, const int* dims, int nbDims, const int* indices_dims, int indice_nbDims, T* output, cudaStream_t stream) +{ + // fill tensordesc and initial + TensorDesc tensor_desc; + memset((void*)&tensor_desc, 0, sizeof(TensorDesc)); + tensor_desc.dim = nbDims; + tensor_desc.shape[nbDims - 1] = dims[nbDims - 1]; + tensor_desc.stride[nbDims - 1] = 1; + for (int i = nbDims - 2; i >= 0; --i) + { + tensor_desc.shape[i] = dims[i]; + tensor_desc.stride[i] = dims[i + 1] * tensor_desc.stride[i + 1]; + } + const int data_size = tensor_desc.stride[0] * tensor_desc.shape[0]; - TensorDesc indice_desc; - memset((void*)&indice_desc, 0, sizeof(TensorDesc)); - indice_desc.dim = indice_nbDims; - indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1]; - indice_desc.stride[indice_nbDims - 1] = 1; - for (int i = indice_nbDims - 2; i >= 0; --i) { - indice_desc.shape[i] = indices_dims[i]; - indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1]; - } + TensorDesc indice_desc; + memset((void*)&indice_desc, 0, sizeof(TensorDesc)); + indice_desc.dim = indice_nbDims; + indice_desc.shape[indice_nbDims - 1] = indices_dims[indice_nbDims - 1]; + indice_desc.stride[indice_nbDims - 1] = 1; + for (int i = indice_nbDims - 2; i >= 0; --i) + { + indice_desc.shape[i] = indices_dims[i]; + indice_desc.stride[i] = indices_dims[i + 1] * indice_desc.stride[i + 1]; + } - // output = np.copy(data) - cudaMemcpyAsync(output, data, data_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); + // output = np.copy(data) + cudaMemcpyAsync(output, data, data_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); - int num_update_indice = 1; - for (int i = 0; i < indice_nbDims - 1; ++i) { - num_update_indice *= indice_desc.shape[i]; - } - // scatter - const int col_block = DIVUP(num_update_indice, THREADS_PER_BLOCK); - onnx_scatternd_kernel<<>>( - num_update_indice, indices, update, output, tensor_desc, indice_desc); + int num_update_indice = 1; + for (int i = 0; i < indice_nbDims - 1; ++i) + { + num_update_indice *= indice_desc.shape[i]; + } + // scatter + const int col_block = DIVUP(num_update_indice, THREADS_PER_BLOCK); + onnx_scatternd_kernel<<>>( + num_update_indice, + indices, + update, + output, + tensor_desc, + indice_desc); } -template void TRTONNXScatterNDKernelLauncher(const float* data, const int* indices, - const float* update, const int* dims, - int nbDims, const int* indices_dims, - int indice_nbDims, float* output, - cudaStream_t stream); +template void TRTONNXScatterNDKernelLauncher(const float* data, const int* indices, const float* update, const 
int* dims, int nbDims, const int* indices_dims, int indice_nbDims, float* output, cudaStream_t stream); -template void TRTONNXScatterNDKernelLauncher(const int* data, const int* indices, - const int* update, const int* dims, int nbDims, - const int* indices_dims, int indice_nbDims, - int* output, cudaStream_t stream); +template void TRTONNXScatterNDKernelLauncher(const int* data, const int* indices, const int* update, const int* dims, int nbDims, const int* indices_dims, int indice_nbDims, int* output, cudaStream_t stream); diff --git a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.hpp b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.hpp index b64b66494d..ae8ae2c34b 100644 --- a/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.hpp +++ b/csrc/mmdeploy/backend_ops/tensorrt/scatternd/trt_scatternd_kernel.hpp @@ -3,9 +3,7 @@ #define TRT_SCATTERND_KERNEL_HPP #include -template -void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices, const T* update, - const int* dims, int nbDims, const int* indices_dims, - int indice_nbDims, T* output, cudaStream_t stream); +template +void TRTONNXScatterNDKernelLauncher(const T* data, const int* indices, const T* update, const int* dims, int nbDims, const int* indices_dims, int indice_nbDims, T* output, cudaStream_t stream); #endif // TRT_SCATTERND_KERNEL_HPP diff --git a/csrc/mmdeploy/backend_ops/torchscript/ops/bind.cpp b/csrc/mmdeploy/backend_ops/torchscript/ops/bind.cpp index f236ac9b66..777b2b1eed 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/ops/bind.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/ops/bind.cpp @@ -1,13 +1,14 @@ // Copyright (c) OpenMMLab. All rights reserved. #include "torch/script.h" -TORCH_LIBRARY(mmdeploy, m) { - m.def( - "modulated_deform_conv(Tensor input, Tensor weight, Tensor bias, Tensor offset, Tensor " - "mask, " - "int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int " - "dilation_h,int dilation_w, int groups, int deform_groups, bool with_bias) -> Tensor") - .def( - "coreml_nms(Tensor boxes, Tensor scores, float iou_threshold, " - "float score_threshold, int max_boxes) -> Tensor[]"); +TORCH_LIBRARY(mmdeploy, m) +{ + m.def( + "modulated_deform_conv(Tensor input, Tensor weight, Tensor bias, Tensor offset, Tensor " + "mask, " + "int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, int pad_w, int " + "dilation_h,int dilation_w, int groups, int deform_groups, bool with_bias) -> Tensor") + .def( + "coreml_nms(Tensor boxes, Tensor scores, float iou_threshold, " + "float score_threshold, int max_boxes) -> Tensor[]"); } diff --git a/csrc/mmdeploy/backend_ops/torchscript/ops/coreml_nms/coreml_nms_cpu.cpp b/csrc/mmdeploy/backend_ops/torchscript/ops/coreml_nms/coreml_nms_cpu.cpp index a78b701349..77fc5c6388 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/ops/coreml_nms/coreml_nms_cpu.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/ops/coreml_nms/coreml_nms_cpu.cpp @@ -4,28 +4,32 @@ #include #include "torch/script.h" -namespace mmdeploy { - -using at::Tensor; - -std::vector coreml_nms_cpu(Tensor boxes, Tensor scores, double iou_threshold, - double score_threshold, int64_t max_boxes) { - assert(boxes.dim() == 3); // bboxes with shape (batch_size, num_bboxes, 4) - assert(boxes.size(2) == 4); - assert(boxes.size(0) == scores.size(0)); // check batch size - assert(boxes.size(1) == scores.size(1)); // check num boxes - - auto batch_size = boxes.size(0); - auto num_boxes = boxes.size(1); - auto num_classes = scores.size(2); - 
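[Editor's aside] The coreml_nms CPU stub reformatted in this hunk is reachable through the schema registered in the TORCH_LIBRARY block of bind.cpp above. A minimal sketch of how such a registered op can be called from C++ once the library is loaded into the process; the tensor shapes and thresholds below are illustrative assumptions, not values from the patch:

#include <torch/script.h>

// Sketch only: look up the op registered as "mmdeploy::coreml_nms" and call it
// through a typed dispatcher handle. Shapes and thresholds are assumed.
std::vector<at::Tensor> run_coreml_nms_example()
{
    auto boxes  = at::rand({1, 100, 4});   // (batch_size, num_bboxes, 4)
    auto scores = at::rand({1, 100, 80});  // (batch_size, num_bboxes, num_classes)
    auto op     = c10::Dispatcher::singleton()
                  .findSchemaOrThrow("mmdeploy::coreml_nms", "")
                  .typed<std::vector<at::Tensor>(at::Tensor, at::Tensor, double, double, int64_t)>();
    // Returns {ret_boxes, ret_scores, indices, num_outputs}, matching the CPU stub.
    return op.call(boxes, scores, /*iou_threshold=*/0.5, /*score_threshold=*/0.05, /*max_boxes=*/10);
}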
- Tensor ret_boxes = at::zeros({batch_size, max_boxes, 4}); - Tensor ret_scores = at::zeros({batch_size, max_boxes, num_classes}); - Tensor indices = at::zeros({batch_size, max_boxes}, at::kInt); - Tensor num_outputs = at::zeros({batch_size}, at::kInt); - - return std::vector({ret_boxes, ret_scores, indices, num_outputs}); -} - -TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) { m.impl("coreml_nms", coreml_nms_cpu); } +namespace mmdeploy +{ + + using at::Tensor; + + std::vector coreml_nms_cpu(Tensor boxes, Tensor scores, double iou_threshold, double score_threshold, int64_t max_boxes) + { + assert(boxes.dim() == 3); // bboxes with shape (batch_size, num_bboxes, 4) + assert(boxes.size(2) == 4); + assert(boxes.size(0) == scores.size(0)); // check batch size + assert(boxes.size(1) == scores.size(1)); // check num boxes + + auto batch_size = boxes.size(0); + auto num_boxes = boxes.size(1); + auto num_classes = scores.size(2); + + Tensor ret_boxes = at::zeros({batch_size, max_boxes, 4}); + Tensor ret_scores = at::zeros({batch_size, max_boxes, num_classes}); + Tensor indices = at::zeros({batch_size, max_boxes}, at::kInt); + Tensor num_outputs = at::zeros({batch_size}, at::kInt); + + return std::vector({ret_boxes, ret_scores, indices, num_outputs}); + } + + TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) + { + m.impl("coreml_nms", coreml_nms_cpu); + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cpu.cpp b/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cpu.cpp index c6d980919f..cf404849b4 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cpu.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cpu.cpp @@ -3,19 +3,37 @@ #include "torch/script.h" -namespace mmdeploy { - -void modulated_deformable_im2col_cpu( - const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, - const int64_t batch_size, const int64_t channels, const int64_t height_im, - const int64_t width_im, const int64_t height_col, const int64_t width_col, - const int64_t kernel_h, const int64_t kernel_w, const int64_t pad_h, const int64_t pad_w, - const int64_t stride_h, const int64_t stride_w, const int64_t dilation_h, - const int64_t dilation_w, int64_t deformable_group, at::Tensor data_col) { - // num_axes should be smaller than block size - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - data_im.scalar_type(), "modulated_deformable_im2col_cpu", ([&] { +namespace mmdeploy +{ + + void modulated_deformable_im2col_cpu( + const at::Tensor data_im, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int64_t batch_size, + const int64_t channels, + const int64_t height_im, + const int64_t width_im, + const int64_t height_col, + const int64_t width_col, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + int64_t deformable_group, + at::Tensor data_col) + { + // num_axes should be smaller than block size + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), + "modulated_deformable_im2col_cpu", + ([&] + { const scalar_t *data_im_ = data_im.data_ptr(); const scalar_t *data_offset_ = data_offset.data_ptr(); const scalar_t *data_mask_ = data_mask.data_ptr(); @@ -24,71 +42,66 @@ void modulated_deformable_im2col_cpu( deformable_im2col_2d(data_im_, data_offset_, 
data_mask_, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channels, deformable_group, - height_col, width_col, data_mask_ != nullptr, data_col_); - })); -} - -at::Tensor modulated_deform_conv_forward_cpu(at::Tensor input, at::Tensor weight, at::Tensor bias, - at::Tensor offset, at::Tensor mask, int64_t kernel_h, - int64_t kernel_w, int64_t stride_h, int64_t stride_w, - int64_t pad_h, int64_t pad_w, int64_t dilation_h, - int64_t dilation_w, int64_t group, - int64_t deformable_group, bool with_bias) { - at::DeviceGuard guard(input.device()); - - const int batch = input.size(0); - const int channels = input.size(1); - const int height = input.size(2); - const int width = input.size(3); - - const int channels_out = weight.size(0); - const int channels_kernel = weight.size(1); - const int kernel_h_ = weight.size(2); - const int kernel_w_ = weight.size(3); - - if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) - AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, - kernel_h_, kernel_w_); - if (channels != channels_kernel * group) - AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels, - channels_kernel * group); - - const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - - // resize output - at::Tensor output = - at::zeros({batch, group, channels_out / group, height_out, width_out}, input.options()); - // resize temporary columns - at::Tensor columns = at::zeros( - {group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out}, input.options()); - - // divide into group - weight = - weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)}); - for (int b = 0; b < batch; b++) { - modulated_deformable_im2col_cpu(input[b], offset[b], mask[b], 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, - stride_h, stride_w, dilation_h, dilation_w, deformable_group, - columns); - - for (int g = 0; g < group; g++) { - output[b][g] = - output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]); + height_col, width_col, data_mask_ != nullptr, data_col_); })); } - } - output = output.view( - {output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)}); - - if (with_bias) { - output += bias.view({1, bias.size(0), 1, 1}); - } - - return output; -} + at::Tensor modulated_deform_conv_forward_cpu(at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor offset, at::Tensor mask, int64_t kernel_h, int64_t kernel_w, int64_t stride_h, int64_t stride_w, int64_t pad_h, int64_t pad_w, int64_t dilation_h, int64_t dilation_w, int64_t group, int64_t deformable_group, bool with_bias) + { + at::DeviceGuard guard(input.device()); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels, 
channels_kernel * group); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + // resize output + at::Tensor output = + at::zeros({batch, group, channels_out / group, height_out, width_out}, input.options()); + // resize temporary columns + at::Tensor columns = at::zeros( + {group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out}, + input.options()); + + // divide into group + weight = + weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)}); + for (int b = 0; b < batch; b++) + { + modulated_deformable_im2col_cpu(input[b], offset[b], mask[b], 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, columns); + + for (int g = 0; g < group; g++) + { + output[b][g] = + output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]); + } + } + + output = output.view( + {output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)}); + + if (with_bias) + { + output += bias.view({1, bias.size(0), 1, 1}); + } + + return output; + } -TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) { - m.impl("modulated_deform_conv", modulated_deform_conv_forward_cpu); -} + TORCH_LIBRARY_IMPL(mmdeploy, CPU, m) + { + m.impl("modulated_deform_conv", modulated_deform_conv_forward_cpu); + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cuda.cu b/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cuda.cu index 3f9b6aef08..83fddb8a8c 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cuda.cu +++ b/csrc/mmdeploy/backend_ops/torchscript/ops/modulated_deform_conv/modulated_deform_conv_cuda.cu @@ -3,21 +3,39 @@ #include "modulated_deform_conv/modulated_deform_conv_cuda.cuh" #include "torch/script.h" -namespace mmdeploy { +namespace mmdeploy +{ -void modulated_deformable_im2col_cuda( - const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask, - const int64_t batch_size, const int64_t channels, const int64_t height_im, - const int64_t width_im, const int64_t height_col, const int64_t width_col, - const int64_t kernel_h, const int64_t kernel_w, const int64_t pad_h, const int64_t pad_w, - const int64_t stride_h, const int64_t stride_w, const int64_t dilation_h, - const int64_t dilation_w, const int64_t deformable_group, at::Tensor data_col) { - // num_axes should be smaller than block size - const int channel_per_deformable_group = channels / deformable_group; - const int num_kernels = channels * batch_size * height_col * width_col; + void modulated_deformable_im2col_cuda( + const at::Tensor data_im, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int64_t batch_size, + const int64_t channels, + const int64_t height_im, + const int64_t width_im, + const int64_t height_col, + const int64_t width_col, + const int64_t kernel_h, + const int64_t kernel_w, + const int64_t pad_h, + const int64_t pad_w, + const int64_t stride_h, + const int64_t stride_w, + const int64_t dilation_h, + const int64_t dilation_w, + const int64_t deformable_group, + at::Tensor data_col) + { + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = 
channels * batch_size * height_col * width_col; - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - data_im.scalar_type(), "modulated_deformable_im2col_cuda", ([&] { + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), + "modulated_deformable_im2col_cuda", + ([&] + { const scalar_t *data_im_ = data_im.data_ptr(); const scalar_t *data_offset_ = data_offset.data_ptr(); const scalar_t *data_mask_ = data_mask.data_ptr(); @@ -27,71 +45,66 @@ void modulated_deformable_im2col_cuda( num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, batch_size, channels, deformable_group, height_col, - width_col, data_col_); - })); -} + width_col, data_col_); })); + } -at::Tensor modulated_deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, at::Tensor bias, - at::Tensor offset, at::Tensor mask, int64_t kernel_h, - int64_t kernel_w, int64_t stride_h, int64_t stride_w, - int64_t pad_h, int64_t pad_w, int64_t dilation_h, - int64_t dilation_w, int64_t group, - int64_t deformable_group, bool with_bias) { - at::DeviceGuard guard(input.device()); + at::Tensor modulated_deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor offset, at::Tensor mask, int64_t kernel_h, int64_t kernel_w, int64_t stride_h, int64_t stride_w, int64_t pad_h, int64_t pad_w, int64_t dilation_h, int64_t dilation_w, int64_t group, int64_t deformable_group, bool with_bias) + { + at::DeviceGuard guard(input.device()); - const int batch = input.size(0); - const int channels = input.size(1); - const int height = input.size(2); - const int width = input.size(3); + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); - const int channels_out = weight.size(0); - const int channels_kernel = weight.size(1); - const int kernel_h_ = weight.size(2); - const int kernel_w_ = weight.size(3); + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); - if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) - AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, - kernel_h_, kernel_w_); - if (channels != channels_kernel * group) - AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels, - channels_kernel * group); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).", channels, channels_kernel * group); - const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - // resize output - at::Tensor output = - at::zeros({batch, group, channels_out / group, height_out, width_out}, input.options()); - // resize temporary columns - at::Tensor columns = at::zeros( - {group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out}, input.options()); + // resize output + at::Tensor output = 
+ at::zeros({batch, group, channels_out / group, height_out, width_out}, input.options()); + // resize temporary columns + at::Tensor columns = at::zeros( + {group, channels * kernel_h * kernel_w / group, 1 * height_out * width_out}, + input.options()); - // divide into group - weight = - weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)}); - for (int b = 0; b < batch; b++) { - modulated_deformable_im2col_cuda(input[b], offset[b], mask[b], 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, - stride_h, stride_w, dilation_h, dilation_w, deformable_group, - columns); + // divide into group + weight = + weight.view({group, weight.size(0) / group, weight.size(1), weight.size(2), weight.size(3)}); + for (int b = 0; b < batch; b++) + { + modulated_deformable_im2col_cuda(input[b], offset[b], mask[b], 1, channels, height, width, height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, deformable_group, columns); - for (int g = 0; g < group; g++) { - output[b][g] = - output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]); - } - } + for (int g = 0; g < group; g++) + { + output[b][g] = + output[b][g].flatten(1).addmm_(weight[g].flatten(1), columns[g]).view_as(output[b][g]); + } + } - output = output.view( - {output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)}); + output = output.view( + {output.size(0), output.size(1) * output.size(2), output.size(3), output.size(4)}); - if (with_bias) { - output += bias.view({1, bias.size(0), 1, 1}); - } + if (with_bias) + { + output += bias.view({1, bias.size(0), 1, 1}); + } - return output; -} + return output; + } -TORCH_LIBRARY_IMPL(mmdeploy, CUDA, m) { - m.impl("modulated_deform_conv", modulated_deform_conv_forward_cuda); -} + TORCH_LIBRARY_IMPL(mmdeploy, CUDA, m) + { + m.impl("modulated_deform_conv", modulated_deform_conv_forward_cuda); + } } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp index 3b8bb0f632..58cf0c6018 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp @@ -12,36 +12,39 @@ #include "passes/onnx/merge_shape_concate.h" #include "passes/onnx/onnx_peephole.h" -namespace mmdeploy { -namespace torch_jit { +namespace mmdeploy +{ + namespace torch_jit + { -void optimize_for_backend(torch::jit::Module& model, const std::string& ir = "torchscript", - const std::string& backend = "torchscript") { - if (ir == "torchscript") { - model = optimize_for_torchscript(model); - } else if (ir == "onnx") { - model = optimize_for_onnx(model); - } else { - fprintf(stderr, "No optimize for combination ir: %s backend: %s\n", ir.c_str(), - backend.c_str()); - exit(-1); - } -} + void optimize_for_backend(torch::jit::Module& model, const std::string& ir = "torchscript", const std::string& backend = "torchscript") + { + if (ir == "torchscript") + { + model = optimize_for_torchscript(model); + } + else if (ir == "onnx") + { + model = optimize_for_onnx(model); + } + else + { + fprintf(stderr, "No optimize for combination ir: %s backend: %s\n", ir.c_str(), backend.c_str()); + exit(-1); + } + } -PYBIND11_MODULE(ts_optimizer, m) { - namespace py = pybind11; - m.def("optimize_for_backend", optimize_for_backend, py::arg("module"), - py::arg("ir") = std::string("torchscript"), - py::arg("backend") = 
std::string("torchscript")); - py::module_ onnx_module = m.def_submodule("onnx"); - onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph")); - onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph")); - onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph")); - onnx_module.def("_jit_pass_fuse_select_assign", FuseSelectAssign, py::arg("graph"), - py::arg("params")); - onnx_module.def("_jit_pass_common_subgraph_elimination", CommonSubgraphElimination, - py::arg("graph"), py::arg("params")); -} + PYBIND11_MODULE(ts_optimizer, m) + { + namespace py = pybind11; + m.def("optimize_for_backend", optimize_for_backend, py::arg("module"), py::arg("ir") = std::string("torchscript"), py::arg("backend") = std::string("torchscript")); + py::module_ onnx_module = m.def_submodule("onnx"); + onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph")); + onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph")); + onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph")); + onnx_module.def("_jit_pass_fuse_select_assign", FuseSelectAssign, py::arg("graph"), py::arg("params")); + onnx_module.def("_jit_pass_common_subgraph_elimination", CommonSubgraphElimination, py::arg("graph"), py::arg("params")); + } -} // namespace torch_jit + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp index 10ce9829d5..e5f06e9c8b 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp @@ -8,306 +8,355 @@ #include #include -namespace mmdeploy { -namespace torch_jit { - -using torch::jit::AttributeKind; -using torch::jit::ClassType; -using torch::jit::Node; -using torch::jit::Symbol; -using torch::jit::Value; - -namespace prim { -using namespace ::c10::prim; -} - -namespace attr { -using namespace ::c10::attr; -} - -/** - * \brief A class implementing an API for comparing subgraphs. - */ -class SubgraphMatcher::SubgraphMatcherImpl { - public: - explicit SubgraphMatcherImpl(const Graph& pattern, MatchAttribute match_attribute) - : pattern_(pattern), match_attribute_(match_attribute) {} - - /** - * \brief Compare matchGraph with the part of the graph denoted by a node \p - * ANCHOR. - * - * The anchor node would be compared against the deepest node in the - * match-graph. A node is considered matching if its number of inputs/outputs - * is the same as in the corresponding matchGraph node, its type is the same, - * and all nodes producing input-values also match. - */ - bool matchesSubgraphFromAnchorNode(Node* anchor); - - /** \brief Return match map for nodes. */ - std::unordered_map nodes_map() const { return nodes_map_; } - - /** \brief Return match map for values. 
*/ - std::unordered_map values_map() const { return values_map_; } - - private: - bool matchValues(const Value* v1, Value* v2); - bool matchNodes(const Node* n1, Node* n2); - bool matchAttributes(const Node* n1, Node* n2); - - static bool isInput(const Value* v); - static bool isOutput(const Value* v); - - std::unordered_map nodes_map_; - std::unordered_map values_map_; - - const MatchAttribute match_attribute_; - const Graph& pattern_; - const Node* anchor_ = nullptr; -}; - -bool SubgraphMatcher::SubgraphMatcherImpl::isInput(const Value* v) { - return v->node()->kind() == prim::Param; -} - -bool SubgraphMatcher::SubgraphMatcherImpl::isOutput(const Value* v) { - for (const Value* output : v->owningGraph()->outputs()) { - if (v == output) { - return true; - } - } - return false; -} - -/** - * Compare two Values. V1 is from pattern, V2 is from the actual graph. - * - * The values are considered matching if: - * 1) the nodes defining them match - * 2) they have the same number of uses, except they are entry or exit nodes. - */ -bool SubgraphMatcher::SubgraphMatcherImpl::matchValues(const Value* v1, Value* v2) { - // Check if we've already visited these values. - if (values_map_.count(v1)) { - if (values_map_.at(v1) != v2) { - GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), - " did not match because %", v1->debugName(), " has already been matched with %", - values_map_.at(v1)->debugName(), ".\n"); - return false; - } - return true; - } - - // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is - // PARAM, we're comparing entering values - in these two cases the number of - // uses don't need to be the same. - if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) { - GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), - " did not match because number of their uses is different.\n"); - return false; - } - - // Add the values to the map before calling matchNodes to avoid infinite - // recursion. 
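[Editor's aside] One behavior carried through this reformat is easy to miss in matchAttributes below: for AttributeKind::s, the pattern node's string attribute is compiled as a regular expression and matched against the whole attribute of the candidate node. A standalone sketch of that semantic, with illustrative attribute values:

#include <cassert>
#include <regex>
#include <string>

// Sketch only: the pattern's string attribute acts as a regex, so a pattern
// carrying "aten::.*" matches any candidate attribute starting with "aten::".
int main()
{
    const std::string pattern_attr = "aten::.*";    // attribute stored on the pattern node
    const std::string target_attr  = "aten::relu";  // attribute stored on the graph node
    assert(std::regex_match(target_attr, std::regex(pattern_attr)));
    assert(!std::regex_match(std::string("prim::Constant"), std::regex(pattern_attr)));
    return 0;
}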
- GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n"); - values_map_[v1] = v2; - return matchNodes(v1->node(), v2->node()); -} - -bool SubgraphMatcher::SubgraphMatcherImpl::matchAttributes(const Node* n1, Node* n2) { - if (match_attribute_ == FORCE_MATCH && n1->numAttributes() != n2->numAttributes()) { - GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2); - return false; - } - for (const Symbol& attr_name : n1->attributeNames()) { - if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) { - GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(), - "' did not match:\n", *n1, *n2); - return false; - } - std::vector n1is, n2is; - std::vector n1fs, n2fs; - switch (n1->kindOf(attr_name)) { - case AttributeKind::s: - if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) { - GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), - "' did not match: ", n1->s(attr_name), " != ", n2->s(attr_name), " \n", *n1, - *n2); - return false; +namespace mmdeploy +{ + namespace torch_jit + { + + using torch::jit::AttributeKind; + using torch::jit::ClassType; + using torch::jit::Node; + using torch::jit::Symbol; + using torch::jit::Value; + + namespace prim + { + using namespace ::c10::prim; } - break; - case AttributeKind::f: - if (n1->f(attr_name) != n2->f(attr_name)) { - GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), - "' did not match:", n1->f(attr_name), " != ", n2->f(attr_name), " \n", *n1, - *n2); - return false; + + namespace attr + { + using namespace ::c10::attr; + } + + /** + * \brief A class implementing an API for comparing subgraphs. + */ + class SubgraphMatcher::SubgraphMatcherImpl + { + public: + explicit SubgraphMatcherImpl(const Graph& pattern, MatchAttribute match_attribute) + : pattern_(pattern) + , match_attribute_(match_attribute) + { + } + + /** + * \brief Compare matchGraph with the part of the graph denoted by a node \p + * ANCHOR. + * + * The anchor node would be compared against the deepest node in the + * match-graph. A node is considered matching if its number of inputs/outputs + * is the same as in the corresponding matchGraph node, its type is the same, + * and all nodes producing input-values also match. + */ + bool matchesSubgraphFromAnchorNode(Node* anchor); + + /** \brief Return match map for nodes. */ + std::unordered_map nodes_map() const + { + return nodes_map_; + } + + /** \brief Return match map for values. */ + std::unordered_map values_map() const + { + return values_map_; + } + + private: + bool matchValues(const Value* v1, Value* v2); + bool matchNodes(const Node* n1, Node* n2); + bool matchAttributes(const Node* n1, Node* n2); + + static bool isInput(const Value* v); + static bool isOutput(const Value* v); + + std::unordered_map nodes_map_; + std::unordered_map values_map_; + + const MatchAttribute match_attribute_; + const Graph& pattern_; + const Node* anchor_ = nullptr; + }; + + bool SubgraphMatcher::SubgraphMatcherImpl::isInput(const Value* v) + { + return v->node()->kind() == prim::Param; + } + + bool SubgraphMatcher::SubgraphMatcherImpl::isOutput(const Value* v) + { + for (const Value* output : v->owningGraph()->outputs()) + { + if (v == output) + { + return true; + } + } + return false; + } + + /** + * Compare two Values. V1 is from pattern, V2 is from the actual graph. 
+ * + * The values are considered matching if: + * 1) the nodes defining them match + * 2) they have the same number of uses, except they are entry or exit nodes. + */ + bool SubgraphMatcher::SubgraphMatcherImpl::matchValues(const Value* v1, Value* v2) + { + // Check if we've already visited these values. + if (values_map_.count(v1)) + { + if (values_map_.at(v1) != v2) + { + GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), " did not match because %", v1->debugName(), " has already been matched with %", values_map_.at(v1)->debugName(), ".\n"); + return false; + } + return true; + } + + // When V2 is ANCHOR, we're comparing exiting values, and when V1->node is + // PARAM, we're comparing entering values - in these two cases the number of + // uses don't need to be the same. + if (v1->uses().size() != v2->uses().size() && !isOutput(v1) && !isInput(v1)) + { + GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), " did not match because number of their uses is different.\n"); + return false; + } + + // Add the values to the map before calling matchNodes to avoid infinite + // recursion. + GRAPH_DEBUG("Values %", v1->debugName(), " and %", v2->debugName(), " matched.\n"); + values_map_[v1] = v2; + return matchNodes(v1->node(), v2->node()); + } + + bool SubgraphMatcher::SubgraphMatcherImpl::matchAttributes(const Node* n1, Node* n2) + { + if (match_attribute_ == FORCE_MATCH && n1->numAttributes() != n2->numAttributes()) + { + GRAPH_DEBUG("Nodes did not match in number attributes:\n", *n1, *n2); + return false; + } + for (const Symbol& attr_name : n1->attributeNames()) + { + if (n1->kindOf(attr_name) != n2->kindOf(attr_name)) + { + GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(), "' did not match:\n", *n1, *n2); + return false; + } + std::vector n1is, n2is; + std::vector n1fs, n2fs; + switch (n1->kindOf(attr_name)) + { + case AttributeKind::s: + if (!std::regex_match(n2->s(attr_name), std::regex(n1->s(attr_name)))) + { + GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), "' did not match: ", n1->s(attr_name), " != ", n2->s(attr_name), " \n", *n1, *n2); + return false; + } + break; + case AttributeKind::f: + if (n1->f(attr_name) != n2->f(attr_name)) + { + GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), "' did not match:", n1->f(attr_name), " != ", n2->f(attr_name), " \n", *n1, *n2); + return false; + } + break; + case AttributeKind::i: + if (n1->i(attr_name) != n2->i(attr_name)) + { + GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), "' did not match:", n1->i(attr_name), " != ", n2->i(attr_name), " \n", *n1, *n2); + return false; + } + break; + case AttributeKind::is: + n1is = n1->is(attr_name); + n2is = n2->is(attr_name); + if (n1is.size() != n2is.size()) return false; + for (size_t i = 0; i < n1is.size(); ++i) + { + if (n1is[i] != n2is[i]) return false; + } + break; + case AttributeKind::fs: + n1fs = n1->fs(attr_name); + n2fs = n2->fs(attr_name); + if (n1fs.size() != n2fs.size()) return false; + for (size_t i = 0; i < n1fs.size(); ++i) + { + if (n1fs[i] != n2fs[i]) return false; + } + break; + default: + { + // Other attributes types not supported yet + GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(), "' is not supported.\n", *n1, *n2); + return false; + } + } + } + return true; + } + + static bool endsWith(const std::string& str, const std::string& suffix) + { + return str.size() >= 
suffix.size() && + 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); } - break; - case AttributeKind::i: - if (n1->i(attr_name) != n2->i(attr_name)) { - GRAPH_DEBUG("Nodes did not match because attribute '", attr_name.toQualString(), - "' did not match:", n1->i(attr_name), " != ", n2->i(attr_name), " \n", *n1, - *n2); - return false; + + /** + * Compare two Nodes. N1 is from pattern, N2 is from the actual graph. + * + * The nodes are considered matching if: + * 1) N1 and N2 are of the same kind. + * 2) Number of inputs and outputs is the same. + * 3) All input and output values match. + * + * A special case is when N1 is PARAM - this is considered outside the pattern, + * so it matches everything. + */ + bool SubgraphMatcher::SubgraphMatcherImpl::matchNodes(const Node* n1, Node* n2) + { + // Check if we've already visited these nodes. + if (nodes_map_.count(n1)) + { + return nodes_map_.at(n1) == n2; + } + + // Param node in pattern graph matches everything. + if (n1->kind() == prim::Param) + { + GRAPH_DEBUG("Nodes matched:\n", *n1, *n2); + return true; + } + + // We don't allow matches to span across blocks, so check if N2 is in the same + // block as the first (anchor) node. + if (n2->owningBlock() != anchor_->owningBlock()) + { + GRAPH_DEBUG("Nodes did not match because it is in the different block:\n", *n1, *n2); + return false; + } + + // Special handling for matching modules + if (n1->kind() == Symbol::fromQualString("match::module")) + { + if (n2->kind() == prim::GetAttr) + { + if (!n1->hasAttributeS("name")) + { + GRAPH_DEBUG( + "Nodes did not match because special node match::module does not have 'name' " + "attribute:\n", + *n1, + *n2); + return false; + } + auto t = n2->output()->type()->expect(); + auto real_typename = t->name()->qualifiedName(); + auto pattern_typename = n1->s(attr::name); + if (!endsWith(real_typename, pattern_typename)) + { + GRAPH_DEBUG("Nodes did not match because expected module type is different:\n"); + GRAPH_DEBUG(" actualtype: ", real_typename, "\n"); + GRAPH_DEBUG(" expected type: ", pattern_typename, "\n"); + GRAPH_DEBUG("Nodes:", *n1, *n2); + return false; + } + } + } + else + { + if (n1->kind() != n2->kind() || n1->outputs().size() != n2->outputs().size() || + n1->inputs().size() != n2->inputs().size()) + { + GRAPH_DEBUG("Nodes did not match in their kind or number of inputs/outputs:\n", *n1, *n2); + return false; + } + + if (match_attribute_ != NO_MATCH) + { + if (!matchAttributes(n1, n2)) + { + return false; + } + } + } + + // Add nodes to the map before calling matchValues to avoid infinite + // recursion. + nodes_map_[n1] = n2; + for (const auto i : c10::irange(n1->outputs().size())) + { + if (!matchValues(n1->outputs()[i], n2->outputs()[i])) + { + return false; + } + } + for (const auto i : c10::irange(n1->inputs().size())) + { + if (!matchValues(n1->inputs()[i], n2->inputs()[i])) + { + return false; + } + } + + GRAPH_DEBUG("Nodes matched:\n", *n1, *n2); + return true; + } + + /** + * Recursively try to match pattern with the actual graph starting from the + * exiting node in the pattern and anchor node in the actual graph. 
+ */ + bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* anchor) + { + GRAPH_UPDATE("Starting match from a new anchor: ", *anchor); + nodes_map_.clear(); + values_map_.clear(); + anchor_ = anchor; + + const Node* bottom_node = *(pattern_.nodes().end()); + bottom_node = bottom_node->input(0)->node(); + + if (!matchNodes(bottom_node, anchor)) + { + return false; + } + + for (const Value* output : pattern_.outputs()) + { + AT_ASSERT(values_map_.count(output)); + } + + GRAPH_UPDATE("Pattern matched!\n"); + return true; } - break; - case AttributeKind::is: - n1is = n1->is(attr_name); - n2is = n2->is(attr_name); - if (n1is.size() != n2is.size()) return false; - for (size_t i = 0; i < n1is.size(); ++i) { - if (n1is[i] != n2is[i]) return false; + + SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute) + : impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) + { } - break; - case AttributeKind::fs: - n1fs = n1->fs(attr_name); - n2fs = n2->fs(attr_name); - if (n1fs.size() != n2fs.size()) return false; - for (size_t i = 0; i < n1fs.size(); ++i) { - if (n1fs[i] != n2fs[i]) return false; + + SubgraphMatcher::~SubgraphMatcher() = default; + + bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) + { + return impl_->matchesSubgraphFromAnchorNode(anchor); } - break; - default: { - // Other attributes types not supported yet - GRAPH_DEBUG("Nodes did not match because type of attribute '", attr_name.toQualString(), - "' is not supported.\n", *n1, *n2); - return false; - } - } - } - return true; -} - -static bool endsWith(const std::string& str, const std::string& suffix) { - return str.size() >= suffix.size() && - 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); -} - -/** - * Compare two Nodes. N1 is from pattern, N2 is from the actual graph. - * - * The nodes are considered matching if: - * 1) N1 and N2 are of the same kind. - * 2) Number of inputs and outputs is the same. - * 3) All input and output values match. - * - * A special case is when N1 is PARAM - this is considered outside the pattern, - * so it matches everything. - */ -bool SubgraphMatcher::SubgraphMatcherImpl::matchNodes(const Node* n1, Node* n2) { - // Check if we've already visited these nodes. - if (nodes_map_.count(n1)) { - return nodes_map_.at(n1) == n2; - } - - // Param node in pattern graph matches everything. - if (n1->kind() == prim::Param) { - GRAPH_DEBUG("Nodes matched:\n", *n1, *n2); - return true; - } - - // We don't allow matches to span across blocks, so check if N2 is in the same - // block as the first (anchor) node. 
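[Editor's aside] The match::module special case in matchNodes compares only a suffix of the qualified class name via endsWith, so a pattern can name a module type without spelling out its full Python path. A standalone sketch of that rule; the qualified name below is an assumption for illustration:

#include <cassert>
#include <string>

// Sketch only: suffix comparison as used for match::module type names.
static bool ends_with(const std::string& str, const std::string& suffix)
{
    return str.size() >= suffix.size() &&
           str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0;
}

int main()
{
    assert(ends_with("__torch__.torchvision.models.resnet.ResNet", "ResNet"));
    assert(!ends_with("__torch__.my.Model", "OtherModel"));
    return 0;
}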
- if (n2->owningBlock() != anchor_->owningBlock()) { - GRAPH_DEBUG("Nodes did not match because it is in the different block:\n", *n1, *n2); - return false; - } - - // Special handling for matching modules - if (n1->kind() == Symbol::fromQualString("match::module")) { - if (n2->kind() == prim::GetAttr) { - if (!n1->hasAttributeS("name")) { - GRAPH_DEBUG( - "Nodes did not match because special node match::module does not have 'name' " - "attribute:\n", - *n1, *n2); - return false; - } - auto t = n2->output()->type()->expect(); - auto real_typename = t->name()->qualifiedName(); - auto pattern_typename = n1->s(attr::name); - if (!endsWith(real_typename, pattern_typename)) { - GRAPH_DEBUG("Nodes did not match because expected module type is different:\n"); - GRAPH_DEBUG(" actualtype: ", real_typename, "\n"); - GRAPH_DEBUG(" expected type: ", pattern_typename, "\n"); - GRAPH_DEBUG("Nodes:", *n1, *n2); - return false; - } - } - } else { - if (n1->kind() != n2->kind() || n1->outputs().size() != n2->outputs().size() || - n1->inputs().size() != n2->inputs().size()) { - GRAPH_DEBUG("Nodes did not match in their kind or number of inputs/outputs:\n", *n1, *n2); - return false; - } - - if (match_attribute_ != NO_MATCH) { - if (!matchAttributes(n1, n2)) { - return false; - } - } - } - - // Add nodes to the map before calling matchValues to avoid infinite - // recursion. - nodes_map_[n1] = n2; - for (const auto i : c10::irange(n1->outputs().size())) { - if (!matchValues(n1->outputs()[i], n2->outputs()[i])) { - return false; - } - } - for (const auto i : c10::irange(n1->inputs().size())) { - if (!matchValues(n1->inputs()[i], n2->inputs()[i])) { - return false; - } - } - - GRAPH_DEBUG("Nodes matched:\n", *n1, *n2); - return true; -} - -/** - * Recursively try to match pattern with the actual graph starting from the - * exiting node in the pattern and anchor node in the actual graph. 
- */ -bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* anchor) { - GRAPH_UPDATE("Starting match from a new anchor: ", *anchor); - nodes_map_.clear(); - values_map_.clear(); - anchor_ = anchor; - - const Node* bottom_node = *(pattern_.nodes().end()); - bottom_node = bottom_node->input(0)->node(); - - if (!matchNodes(bottom_node, anchor)) { - return false; - } - - for (const Value* output : pattern_.outputs()) { - AT_ASSERT(values_map_.count(output)); - } - - GRAPH_UPDATE("Pattern matched!\n"); - return true; -} - -SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute) - : impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) {} - -SubgraphMatcher::~SubgraphMatcher() = default; - -bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) { - return impl_->matchesSubgraphFromAnchorNode(anchor); -} - -std::unordered_map SubgraphMatcher::nodes_map() const { - return impl_->nodes_map(); -} - -std::unordered_map SubgraphMatcher::values_map() const { - return impl_->values_map(); -} - -} // namespace torch_jit + + std::unordered_map SubgraphMatcher::nodes_map() const + { + return impl_->nodes_map(); + } + + std::unordered_map SubgraphMatcher::values_map() const + { + return impl_->values_map(); + } + + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h index e2488e252c..ffe1b51aa8 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h @@ -5,34 +5,42 @@ #include #include -namespace mmdeploy { -namespace torch_jit { -using torch::jit::Graph; -using torch::jit::Node; -using torch::jit::Value; - -enum MatchAttribute { FORCE_MATCH, TRY_MATCH, NO_MATCH }; - -class SubgraphMatcher { - public: - explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH); - - ~SubgraphMatcher(); - - bool matchesSubgraphFromAnchorNode(Node* anchor); - - /** \brief Return match map for nodes. */ - std::unordered_map nodes_map() const; - - /** \brief Return match map for values. */ - std::unordered_map values_map() const; - - private: - class SubgraphMatcherImpl; - std::unique_ptr impl_; -}; - -} // namespace torch_jit +namespace mmdeploy +{ + namespace torch_jit + { + using torch::jit::Graph; + using torch::jit::Node; + using torch::jit::Value; + + enum MatchAttribute + { + FORCE_MATCH, + TRY_MATCH, + NO_MATCH + }; + + class SubgraphMatcher + { + public: + explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH); + + ~SubgraphMatcher(); + + bool matchesSubgraphFromAnchorNode(Node* anchor); + + /** \brief Return match map for nodes. */ + std::unordered_map nodes_map() const; + + /** \brief Return match map for values. 
*/ + std::unordered_map values_map() const; + + private: + class SubgraphMatcherImpl; + std::unique_ptr impl_; + }; + + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.cpp index 05ef9d54cd..2178bb3a4e 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.cpp @@ -12,59 +12,63 @@ #include #if TORCH_VERSION_MINOR >= 9 -#include -#include -#include + #include + #include + #include #endif -namespace mmdeploy { +namespace mmdeploy +{ -using torch::jit::Graph; -const std::shared_ptr& required_passes(const std::shared_ptr& graph) { - RemoveExpands(graph); - CanonicalizeOps(graph); - EliminateDeadCode(graph); - return graph; -} + using torch::jit::Graph; + const std::shared_ptr& required_passes(const std::shared_ptr& graph) + { + RemoveExpands(graph); + CanonicalizeOps(graph); + EliminateDeadCode(graph); + return graph; + } -Module optimize_for_torchscript(const Module& model) { - auto frozen_model = freeze_module(model); - auto graph = frozen_model.get_method("forward").graph(); - OptimizeFrozenGraph(graph, true); + Module optimize_for_torchscript(const Module& model) + { + auto frozen_model = freeze_module(model); + auto graph = frozen_model.get_method("forward").graph(); + OptimizeFrozenGraph(graph, true); #if TORCH_VERSION_MINOR >= 9 - FuseFrozenConvAddRelu(graph); - ConvertFrozenOpsToMKLDNN(graph); - FrozenLinearTranspose(graph); + FuseFrozenConvAddRelu(graph); + ConvertFrozenOpsToMKLDNN(graph); + FrozenLinearTranspose(graph); #endif - graph = required_passes(graph); - EliminateCommonSubexpression(graph); - PeepholeOptimize(graph); - ConstantPropagation(graph); - ConstantPooling(graph); + graph = required_passes(graph); + EliminateCommonSubexpression(graph); + PeepholeOptimize(graph); + ConstantPropagation(graph); + ConstantPooling(graph); - // TODO: add more custom passes + // TODO: add more custom passes - return frozen_model; -} + return frozen_model; + } -Module optimize_for_onnx(const Module& model) { - auto frozen_model = freeze_module(model, {"training"}); - auto graph = frozen_model.get_method("forward").graph(); - OptimizeFrozenGraph(graph, true); + Module optimize_for_onnx(const Module& model) + { + auto frozen_model = freeze_module(model, {"training"}); + auto graph = frozen_model.get_method("forward").graph(); + OptimizeFrozenGraph(graph, true); #if TORCH_VERSION_MINOR >= 9 - FuseFrozenConvAddRelu(graph); - ConvertFrozenOpsToMKLDNN(graph); - FrozenLinearTranspose(graph); + FuseFrozenConvAddRelu(graph); + ConvertFrozenOpsToMKLDNN(graph); + FrozenLinearTranspose(graph); #endif - // TODO: add more custom passes + // TODO: add more custom passes - return frozen_model; -} + return frozen_model; + } -// TODO: add optimizer for other backend/onnx + // TODO: add optimizer for other backend/onnx } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.h index d0d91c627d..fc5a3725d1 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/optimizer.h @@ -1,10 +1,11 @@ // Copyright (c) OpenMMLab. All rights reserved. 
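[Editor's aside] The two entry points reformatted just above in optimizer.cpp, and declared in this header, are meant to run on a loaded TorchScript module. A hypothetical end-to-end driver; the file names and the include path are assumptions, not part of the patch:

#include <torch/script.h>

#include "optimizer.h"  // assumed include path for the declarations in this header

int main()
{
    auto module = torch::jit::load("model.pt");  // illustrative file name
    module.eval();                               // freezing expects eval mode
    auto frozen = mmdeploy::optimize_for_torchscript(module);
    frozen.save("model_opt.pt");
    return 0;
}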
#include -namespace mmdeploy { -using torch::jit::script::Module; +namespace mmdeploy +{ + using torch::jit::script::Module; -Module optimize_for_torchscript(const Module &model); + Module optimize_for_torchscript(const Module& model); -Module optimize_for_onnx(const Module &model); + Module optimize_for_onnx(const Module& model); } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp index c6541e630a..c26db5a34f 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp @@ -4,135 +4,161 @@ #include #include -namespace mmdeploy { -namespace torch_jit { - -using c10::Symbol; -using torch::jit::Block; -using torch::jit::EqualNode; -using torch::jit::HashNode; -using torch::jit::Node; -using torch::jit::Value; - -struct EqualNodeWithParams { - EqualNodeWithParams(std::unordered_map& params) : params_(params) {} - - bool operator()(const Node* lhs, const Node* rhs) const { - auto lhs_inputs = lhs->inputs(); - auto rhs_inputs = rhs->inputs(); - } - - private: - std::unordered_map& params_; -}; - -struct CommonSubexpressionEliminator { - using ParamMapType = std::unordered_map>; - CommonSubexpressionEliminator(std::shared_ptr graph, - std::unordered_map& params) - : graph_(std::move(graph)), params_(params) {} - - bool run(std::function parent_lookup_fn) { - ParamMapType param_map; - return run(graph_->block(), std::move(parent_lookup_fn), param_map); - } - - // The function implements common subexpression elimination. - // Since the nodes are visited in topological order, one pass is enough. - // returns true if CSE made changes to a graph - bool run(Block* block, std::function parent_lookup_fn, ParamMapType& param_map) { - std::unordered_set subexprs; - bool changed = false; - for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { - auto node = *it; - - // check if inputs come from params(graph input) - auto node_inputs = node->inputs(); - for (auto input : node_inputs) { - if (input->node()->kind() == Symbol::fromQualString("prim::Param")) { - auto debug_name = input->debugName(); - - // check if input in params_ - if (params_.find(debug_name) == params_.end()) continue; - - // check if input is already visited. - if (param_map.find(debug_name) != param_map.end()) continue; - - // check if there is a param has same value with input - auto val = params_[debug_name]; - bool update_map = true; - for (auto kv : param_map) { - auto param_val = kv.second.first; - if (val.device() != param_val.device()) continue; - if (val.dtype() != param_val.dtype()) continue; - if (!val.equal(param_val)) continue; - input->replaceAllUsesWith(kv.second.second); - update_map = false; - break; - } - - // add input to param_map - if (update_map) { - param_map.emplace(debug_name, - std::make_pair(std::move(val), std::move(input))); - } - } - } - - if (!node->blocks().empty()) { - // Traverse sub-blocks. 
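[Editor's aside] The parameter-deduplication loop in this hunk guards at::Tensor comparison with device and dtype checks before calling equal(), since an element-wise comparison only makes sense on like-typed operands. A standalone sketch of that comparison, with illustrative weights:

#include <ATen/ATen.h>

#include <cassert>

// Sketch only: two graph parameters count as the same constant when device,
// dtype, and contents all agree; the pass then rewires uses to one of them.
bool same_param(const at::Tensor& a, const at::Tensor& b)
{
    if (a.device() != b.device()) return false;
    if (a.dtype() != b.dtype()) return false;
    return a.equal(b);  // element-wise equality, including shape
}

int main()
{
    assert(same_param(at::ones({2, 2}), at::ones({2, 2})));    // would be merged
    assert(!same_param(at::ones({2, 2}), at::zeros({2, 2})));  // kept separate
    return 0;
}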
- for (auto block : node->blocks()) { - changed |= run( - block, - [&](Node* n) { - auto existing = subexprs.find(n); - if (existing != subexprs.end()) { - return *existing; +namespace mmdeploy +{ + namespace torch_jit + { + + using c10::Symbol; + using torch::jit::Block; + using torch::jit::EqualNode; + using torch::jit::HashNode; + using torch::jit::Node; + using torch::jit::Value; + + struct EqualNodeWithParams + { + EqualNodeWithParams(std::unordered_map& params) + : params_(params) + { + } + + bool operator()(const Node* lhs, const Node* rhs) const + { + auto lhs_inputs = lhs->inputs(); + auto rhs_inputs = rhs->inputs(); + } + + private: + std::unordered_map& params_; + }; + + struct CommonSubexpressionEliminator + { + using ParamMapType = std::unordered_map>; + CommonSubexpressionEliminator(std::shared_ptr graph, + std::unordered_map& params) + : graph_(std::move(graph)) + , params_(params) + { + } + + bool run(std::function parent_lookup_fn) + { + ParamMapType param_map; + return run(graph_->block(), std::move(parent_lookup_fn), param_map); + } + + // The function implements common subexpression elimination. + // Since the nodes are visited in topological order, one pass is enough. + // returns true if CSE made changes to a graph + bool run(Block* block, std::function parent_lookup_fn, ParamMapType& param_map) + { + std::unordered_set subexprs; + bool changed = false; + for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) + { + auto node = *it; + + // check if inputs come from params(graph input) + auto node_inputs = node->inputs(); + for (auto input : node_inputs) + { + if (input->node()->kind() == Symbol::fromQualString("prim::Param")) + { + auto debug_name = input->debugName(); + + // check if input in params_ + if (params_.find(debug_name) == params_.end()) continue; + + // check if input is already visited. + if (param_map.find(debug_name) != param_map.end()) continue; + + // check if there is a param has same value with input + auto val = params_[debug_name]; + bool update_map = true; + for (auto kv : param_map) + { + auto param_val = kv.second.first; + if (val.device() != param_val.device()) continue; + if (val.dtype() != param_val.dtype()) continue; + if (!val.equal(param_val)) continue; + input->replaceAllUsesWith(kv.second.second); + update_map = false; + break; + } + + // add input to param_map + if (update_map) + { + param_map.emplace(debug_name, + std::make_pair(std::move(val), std::move(input))); + } + } + } + + if (!node->blocks().empty()) + { + // Traverse sub-blocks. + for (auto block : node->blocks()) + { + changed |= run( + block, + [&](Node* n) + { + auto existing = subexprs.find(n); + if (existing != subexprs.end()) + { + return *existing; + } + + return parent_lookup_fn(n); + }, + param_map); + } + + continue; + } + + // Check for CSE opportunities in the parent block. + auto parent_lookup = parent_lookup_fn(node); + auto g_out = node->owningGraph()->outputs(); + if (parent_lookup != nullptr) + { + changed = true; + node->replaceAllUsesWith(parent_lookup); + it.destroyCurrent(); + continue; + } + + // Check whether the same subexpression already exists. + auto subit = subexprs.insert(node); + if (!subit.second) + { + // Subexpression exists, replace the uses of node, and destroy it. + auto existing = *subit.first; + + changed = true; + node->replaceAllUsesWith(existing); + // Destroy the node. 
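[Editor's aside] The subexprs set used in this function is classic hash-consing: nodes are keyed by a structural hash (HashNode/EqualNode), and a failed insertion means an equivalent expression already exists and can absorb the duplicate's uses. A toy model of that step, with strings standing in for Node pointers:

#include <cassert>
#include <string>
#include <unordered_set>

// Toy model only: a failed insert signals a structurally identical
// subexpression, which the real pass handles via replaceAllUsesWith
// followed by destroying the duplicate node.
int main()
{
    std::unordered_set<std::string> subexprs;
    auto first  = subexprs.insert("aten::add(x, y)");
    auto second = subexprs.insert("aten::add(x, y)");  // duplicate structure
    assert(first.second);    // inserted
    assert(!second.second);  // rejected: reuse *second.first instead
    return 0;
}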
+ it.destroyCurrent(); + } } - return parent_lookup_fn(n); - }, - param_map); - } + return changed; + } - continue; - } - - // Check for CSE opportunities in the parent block. - auto parent_lookup = parent_lookup_fn(node); - auto g_out = node->owningGraph()->outputs(); - if (parent_lookup != nullptr) { - changed = true; - node->replaceAllUsesWith(parent_lookup); - it.destroyCurrent(); - continue; - } - - // Check whether the same subexpression already exists. - auto subit = subexprs.insert(node); - if (!subit.second) { - // Subexpression exists, replace the uses of node, and destroy it. - auto existing = *subit.first; - - changed = true; - node->replaceAllUsesWith(existing); - // Destroy the node. - it.destroyCurrent(); - } - } - - return changed; - } - - private: - std::shared_ptr graph_; - std::unordered_map& params_; -}; - -void CommonSubgraphElimination(std::shared_ptr& graph, - std::unordered_map& params) { - CommonSubexpressionEliminator cse(graph, params); - cse.run([](Node*) { return nullptr; }); -} -} // namespace torch_jit + private: + std::shared_ptr graph_; + std::unordered_map& params_; + }; + + void CommonSubgraphElimination(std::shared_ptr& graph, + std::unordered_map& params) + { + CommonSubexpressionEliminator cse(graph, params); + cse.run([](Node*) + { return nullptr; }); + } + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h index d90b98073e..da108ff733 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h @@ -3,18 +3,20 @@ #define _COMMON_SUBGRAPH_ELIMINATION_H_ #include -namespace mmdeploy { -namespace torch_jit { -using torch::Tensor; -using torch::jit::Graph; +namespace mmdeploy +{ + namespace torch_jit + { + using torch::Tensor; + using torch::jit::Graph; -// This pass is used eliminate the common subgraph. -// There are two main difference between the one in torch/csrc/jit/pass -// 1. AliasDb is not needed in ONNX model -// 2. params might also participated in the elimination -void CommonSubgraphElimination(std::shared_ptr& graph, - std::unordered_map& params); -} // namespace torch_jit + // This pass is used eliminate the common subgraph. + // There are two main difference between the one in torch/csrc/jit/pass + // 1. AliasDb is not needed in ONNX model + // 2. params might also participated in the elimination + void CommonSubgraphElimination(std::shared_ptr& graph, + std::unordered_map& params); + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp index 73f8965412..db44bdb4c1 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.cpp @@ -9,89 +9,94 @@ #include "utils.h" -namespace mmdeploy { -namespace torch_jit { - -using c10::Symbol; -using torch::jit::IValue; -using torch::jit::Match; -using torch::jit::TensorType; -using torch::jit::TypeKind; -using torch::jit::Value; - -static bool matchClsHead(const Match& match, const std::unordered_map& map) { - // TODO: check if value map in latest pytorch can ease the filter. 
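[Editor's aside] A hypothetical driver for the pass declared in the header above: parse a small graph that computes the same expression twice and let the pass fold it. The IR string and the include paths are assumptions for illustration:

#include <torch/csrc/jit/ir/irparser.h>
#include <torch/script.h>

#include "common_subgraph_elimination.h"  // assumed include path

int main()
{
    auto graph = std::make_shared<torch::jit::Graph>();
    torch::jit::parseIR(R"IR(
graph(%x : Tensor):
  %a : Tensor = aten::relu(%x)
  %b : Tensor = aten::relu(%x)
  return (%a, %b))IR",
                        graph.get());

    std::unordered_map<std::string, at::Tensor> params;  // empty: no weights to dedup
    mmdeploy::torch_jit::CommonSubgraphElimination(graph, params);
    // Afterwards the uses of %b point at %a and the duplicate relu is destroyed.
    return 0;
}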
- - // check cat -1 - { - // check if the shape of second inputs is 1 - auto cat_v1 = match.values_map.at(map.at("cat1")); - if (cat_v1->type()->kind() != TypeKind::TensorType) return false; - auto cat_v1_type = cat_v1->type()->cast(); - auto cat_v1_size = cat_v1_type->sizes().concrete_sizes(); - if (!cat_v1_size.has_value()) return false; - IValue cat_v1_size_value(cat_v1_size.value()); - auto size_list = cat_v1_size_value.toIntList(); - if (size_list.size() != 1 || size_list[0] != 1) return false; - } - - // check unsqueeze - auto cat_v0 = match.values_map.at(map.at("cat0")); - auto unsqueeze_node = cat_v0->node(); - { - if (!is_kind(unsqueeze_node, "onnx::Unsqueeze")) return false; - auto unsqueeze_axes = unsqueeze_node->is(Symbol::attr("axes")); - if (unsqueeze_axes.size() != 1 || unsqueeze_axes[0] != 0) return false; - } - - // check gather - auto gather_node = unsqueeze_node->input()->node(); - auto gather_inputs = gather_node->inputs(); - { - if (!is_kind(gather_node, "onnx::Gather")) return false; - auto gather_axis = gather_node->i(Symbol::attr("axis")); - if (gather_axis != 0) return false; - } - - auto x = match.values_map.at(map.at("x")); - // check shape - auto shape_node = gather_inputs[0]->node(); - { - if (!is_kind(shape_node, "onnx::Shape")) return false; - if (shape_node->input() != x) return false; - } - - // check constant - auto const_node = gather_inputs[1]->node(); - { - if (!is_kind(const_node, "onnx::Constant")) return false; - auto ival = const_node->t(Symbol::attr("value")); - if (ival.dim() != 0) return false; - auto ival_dataptr = ival.data_ptr(); - if (ival_dataptr[0] != 0) return false; - } - - // check if reshape is the output of the graph - auto reshape_pattern = map.at("reshape"); - auto reshape_node = match.values_map.at(reshape_pattern); - auto uses = reshape_node->uses(); - for (auto use : uses) { - auto user = use.user; - if (is_kind(user, "prim::Return")) return false; - } - - return true; -} - -// from: -// x->shape->gather->unsqueeze->concat -// | | -// gap--------------------------reshape -// -// to: -// x->gap->flatten -void FlattenClsHead(std::shared_ptr& graph) { - std::string pattern = R"IR( +namespace mmdeploy +{ + namespace torch_jit + { + + using c10::Symbol; + using torch::jit::IValue; + using torch::jit::Match; + using torch::jit::TensorType; + using torch::jit::TypeKind; + using torch::jit::Value; + + static bool matchClsHead(const Match& match, const std::unordered_map& map) + { + // TODO: check if value map in latest pytorch can ease the filter. 
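+            // The checks below only accept the canonical classifier-head pattern:
+            // `cat1` must be a [1]-shaped tensor, the Unsqueeze/Gather/Shape chain
+            // must trace back to the same input `x`, and the final Reshape must not
+            // be returned directly as a graph output.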
+ + // check cat -1 + { + // check if the shape of second inputs is 1 + auto cat_v1 = match.values_map.at(map.at("cat1")); + if (cat_v1->type()->kind() != TypeKind::TensorType) return false; + auto cat_v1_type = cat_v1->type()->cast(); + auto cat_v1_size = cat_v1_type->sizes().concrete_sizes(); + if (!cat_v1_size.has_value()) return false; + IValue cat_v1_size_value(cat_v1_size.value()); + auto size_list = cat_v1_size_value.toIntList(); + if (size_list.size() != 1 || size_list[0] != 1) return false; + } + + // check unsqueeze + auto cat_v0 = match.values_map.at(map.at("cat0")); + auto unsqueeze_node = cat_v0->node(); + { + if (!is_kind(unsqueeze_node, "onnx::Unsqueeze")) return false; + auto unsqueeze_axes = unsqueeze_node->is(Symbol::attr("axes")); + if (unsqueeze_axes.size() != 1 || unsqueeze_axes[0] != 0) return false; + } + + // check gather + auto gather_node = unsqueeze_node->input()->node(); + auto gather_inputs = gather_node->inputs(); + { + if (!is_kind(gather_node, "onnx::Gather")) return false; + auto gather_axis = gather_node->i(Symbol::attr("axis")); + if (gather_axis != 0) return false; + } + + auto x = match.values_map.at(map.at("x")); + // check shape + auto shape_node = gather_inputs[0]->node(); + { + if (!is_kind(shape_node, "onnx::Shape")) return false; + if (shape_node->input() != x) return false; + } + + // check constant + auto const_node = gather_inputs[1]->node(); + { + if (!is_kind(const_node, "onnx::Constant")) return false; + auto ival = const_node->t(Symbol::attr("value")); + if (ival.dim() != 0) return false; + auto ival_dataptr = ival.data_ptr(); + if (ival_dataptr[0] != 0) return false; + } + + // check if reshape is the output of the graph + auto reshape_pattern = map.at("reshape"); + auto reshape_node = match.values_map.at(reshape_pattern); + auto uses = reshape_node->uses(); + for (auto use : uses) + { + auto user = use.user; + if (is_kind(user, "prim::Return")) return false; + } + + return true; + } + + // from: + // x->shape->gather->unsqueeze->concat + // | | + // gap--------------------------reshape + // + // to: + // x->gap->flatten + void FlattenClsHead(std::shared_ptr& graph) + { + std::string pattern = R"IR( graph(%x, %cat0, %cat1): %gap = onnx::GlobalAveragePool(%x) %cat = onnx::Concat[axis=0](%cat0, %cat1) @@ -99,21 +104,22 @@ void FlattenClsHead(std::shared_ptr& graph) { return (%reshape) )IR"; - std::string replacement = R"IR( + std::string replacement = R"IR( graph(%x, %cat0, %cat1): %gap = onnx::GlobalAveragePool(%x) %flatten = onnx::Flatten(%gap) return (%flatten) )IR"; - torch::jit::SubgraphRewriter subgraph_rewriter; - subgraph_rewriter.RegisterRewritePattern(pattern, replacement); - subgraph_rewriter.runOnGraph(graph, matchClsHead); + torch::jit::SubgraphRewriter subgraph_rewriter; + subgraph_rewriter.RegisterRewritePattern(pattern, replacement); + subgraph_rewriter.runOnGraph(graph, matchClsHead); - torch::jit::EliminateDeadCode( - graph->block(), true, - torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); -} + torch::jit::EliminateDeadCode( + graph->block(), + true, + torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); + } -} // namespace torch_jit + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h index b66b700d1c..64d8ea3352 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h +++ 
b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/flatten_cls_head.h @@ -3,12 +3,14 @@ #define _FLATTEN_CLS_HEAD_H_ #include -namespace mmdeploy { -namespace torch_jit { -using torch::jit::Graph; +namespace mmdeploy +{ + namespace torch_jit + { + using torch::jit::Graph; -void FlattenClsHead(std::shared_ptr& graph); -} // namespace torch_jit + void FlattenClsHead(std::shared_ptr& graph); + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp index 8dc5847753..2798abaa8c 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp @@ -6,131 +6,149 @@ #include "common_subgraph_elimination.h" #include "torch/csrc/jit/ir/irparser.h" -namespace mmdeploy { -namespace torch_jit { - -using c10::Symbol; -using torch::jit::Block; -using torch::jit::IValue; -using torch::jit::Node; - -bool RemoveBoolCast(Node* node) { - auto bottom_node = node->input()->node(); - if (bottom_node->kind() != Symbol::onnx("Greater") && - bottom_node->kind() != Symbol::onnx("Less")) { - return false; - } - node->output()->replaceAllUsesWith(bottom_node->output()); - return true; -} - -bool FuseSelectAssign(Node* node, std::unordered_map& params, - std::unordered_map& vmap, SubgraphMatcher& matcher) { - auto values_map = matcher.values_map(); - - auto cmp1 = values_map[vmap["cmp_1"]]->node(); - auto cmp2 = values_map[vmap["cmp_2"]]->node(); - if (cmp1 != cmp2) { - // cmp_1 == cmp_2, cmp in (Great, Less) - if (cmp1->kind() != cmp2->kind()) return false; - if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less"))) - return false; - - // check threshold - Node* cmps[] = {cmp1, cmp2}; - float thres = 0.0f; - Node* x = nullptr; - for (int i = 0; i < 2; ++i) { - auto cmp = cmps[i]; - auto threshold = cmp->inputs()[1]->node(); - if (threshold->kind() != Symbol::onnx("Constant")) return false; - auto thres_val = threshold->t(Symbol::attr("value")); - if (i == 0) { - thres = thres_val.data_ptr()[0]; - x = cmp->inputs()[0]->node(); - } else { - float tmp_val = thres_val.data_ptr()[0]; - if (fabs(thres - tmp_val) > 1e-10) { - return false; +namespace mmdeploy +{ + namespace torch_jit + { + + using c10::Symbol; + using torch::jit::Block; + using torch::jit::IValue; + using torch::jit::Node; + + bool RemoveBoolCast(Node* node) + { + auto bottom_node = node->input()->node(); + if (bottom_node->kind() != Symbol::onnx("Greater") && + bottom_node->kind() != Symbol::onnx("Less")) + { + return false; + } + node->output()->replaceAllUsesWith(bottom_node->output()); + return true; } - if (x != cmp->inputs()[0]->node()) { - return false; + + bool FuseSelectAssign(Node* node, std::unordered_map& params, std::unordered_map& vmap, SubgraphMatcher& matcher) + { + auto values_map = matcher.values_map(); + + auto cmp1 = values_map[vmap["cmp_1"]]->node(); + auto cmp2 = values_map[vmap["cmp_2"]]->node(); + if (cmp1 != cmp2) + { + // cmp_1 == cmp_2, cmp in (Great, Less) + if (cmp1->kind() != cmp2->kind()) return false; + if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less"))) + return false; + + // check threshold + Node* cmps[] = {cmp1, cmp2}; + float thres = 0.0f; + Node* x = nullptr; + for (int i = 0; i < 2; ++i) + { + auto cmp = cmps[i]; + auto threshold = 
cmp->inputs()[1]->node(); + if (threshold->kind() != Symbol::onnx("Constant")) return false; + auto thres_val = threshold->t(Symbol::attr("value")); + if (i == 0) + { + thres = thres_val.data_ptr()[0]; + x = cmp->inputs()[0]->node(); + } + else + { + float tmp_val = thres_val.data_ptr()[0]; + if (fabs(thres - tmp_val) > 1e-10) + { + return false; + } + if (x != cmp->inputs()[0]->node()) + { + return false; + } + } + } + } + + { + // check shape of reshape + Node* shape = values_map[vmap["reshape_1_shape"]]->node(); + auto shape_val = shape->t(Symbol::attr("value")); + if (shape_val.dim() != 1) return false; + if (shape_val.data_ptr()[0] != -1) return false; + } + + { + // check transpose + Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()}; + for (auto tran : trans) + { + auto tran_perm = tran->is(Symbol::attr("perm")); + if (tran_perm.size() != 2) return false; + if (tran_perm[0] != 1 || tran_perm[1] != 0) return false; + } + } + + { + // check gather indice + Node* gather_inds = values_map[vmap["gather_inds_2"]]->node(); + auto inds_val = gather_inds->t(Symbol::attr("value")); + if (inds_val.dim() != 0) return false; + if (inds_val.data_ptr()[0] != 0) return false; + } + + { + // check slice start + Node* slice = values_map[vmap["slice_2"]]->node(); + auto start_name = slice->inputs()[1]->debugName(); + auto start_val = params[start_name]; + if (start_val.dim() != 1) return false; + if (start_val.data_ptr()[0] != 0) return false; + } + + // create new node + auto graph = node->owningGraph(); + auto z = values_map[vmap["z"]]; + auto y = values_map[vmap["y"]]; + auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y}); + where_node->insertBefore(node); + where_node->output()->copyMetadata(node->output()); + node->output()->replaceAllUsesWith(where_node->output()); + return true; + } + + void FuseSelectAssign(Block* block, std::unordered_map& params, std::unordered_map& vmap, SubgraphMatcher& matcher) + { + auto graph = block->owningGraph(); + auto it = block->nodes().begin(); + while (it != block->nodes().end()) + { + auto node = *it; + ++it; + for (auto block : node->blocks()) + { + FuseSelectAssign(block, params, vmap, matcher); + } + + if (node->kind() == Symbol::onnx("Cast") && node->i(Symbol::attr("to")) == 9) + { + RemoveBoolCast(node); + } + else if (matcher.matchesSubgraphFromAnchorNode(node)) + { + FuseSelectAssign(node, params, vmap, matcher); + } + } } - } - } - } - - { - // check shape of reshape - Node* shape = values_map[vmap["reshape_1_shape"]]->node(); - auto shape_val = shape->t(Symbol::attr("value")); - if (shape_val.dim() != 1) return false; - if (shape_val.data_ptr()[0] != -1) return false; - } - - { - // check transpose - Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()}; - for (auto tran : trans) { - auto tran_perm = tran->is(Symbol::attr("perm")); - if (tran_perm.size() != 2) return false; - if (tran_perm[0] != 1 || tran_perm[1] != 0) return false; - } - } - - { - // check gather indice - Node* gather_inds = values_map[vmap["gather_inds_2"]]->node(); - auto inds_val = gather_inds->t(Symbol::attr("value")); - if (inds_val.dim() != 0) return false; - if (inds_val.data_ptr()[0] != 0) return false; - } - - { - // check slice start - Node* slice = values_map[vmap["slice_2"]]->node(); - auto start_name = slice->inputs()[1]->debugName(); - auto start_val = params[start_name]; - if (start_val.dim() != 1) return false; - if (start_val.data_ptr()[0] != 0) return false; - } - 
- // create new node - auto graph = node->owningGraph(); - auto z = values_map[vmap["z"]]; - auto y = values_map[vmap["y"]]; - auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y}); - where_node->insertBefore(node); - where_node->output()->copyMetadata(node->output()); - node->output()->replaceAllUsesWith(where_node->output()); - return true; -} - -void FuseSelectAssign(Block* block, std::unordered_map& params, - std::unordered_map& vmap, SubgraphMatcher& matcher) { - auto graph = block->owningGraph(); - auto it = block->nodes().begin(); - while (it != block->nodes().end()) { - auto node = *it; - ++it; - for (auto block : node->blocks()) { - FuseSelectAssign(block, params, vmap, matcher); - } - - if (node->kind() == Symbol::onnx("Cast") && node->i(Symbol::attr("to")) == 9) { - RemoveBoolCast(node); - } else if (matcher.matchesSubgraphFromAnchorNode(node)) { - FuseSelectAssign(node, params, vmap, matcher); - } - } -} - -void FuseSelectAssign(std::shared_ptr& graph, - std::unordered_map& params) { - // cse before search - CommonSubgraphElimination(graph, params); - - std::string pattern_str = R"IR( + + void FuseSelectAssign(std::shared_ptr& graph, + std::unordered_map& params) + { + // cse before search + CommonSubgraphElimination(graph, params); + + std::string pattern_str = R"IR( graph(%y, %z, %cmp_1, %cmp_2, %start, %axes, %shape_2): %nz_1 = onnx::NonZero(%cmp_1) %trans_1 = onnx::Transpose(%nz_1) @@ -149,15 +167,16 @@ void FuseSelectAssign(std::shared_ptr& graph, return (%scatter_2) )IR"; - Graph pattern; - std::unordered_map vmap; - torch::jit::parseIR(pattern_str, &pattern, vmap); - - SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH); - FuseSelectAssign(graph->block(), params, vmap, matcher); - torch::jit::EliminateDeadCode( - graph->block(), true, - torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); -} -} // namespace torch_jit + Graph pattern; + std::unordered_map vmap; + torch::jit::parseIR(pattern_str, &pattern, vmap); + + SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH); + FuseSelectAssign(graph->block(), params, vmap, matcher); + torch::jit::EliminateDeadCode( + graph->block(), + true, + torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); + } + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h index afa0dc56d6..0e80ec1d67 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h @@ -3,15 +3,17 @@ #define _FUSE_SELECT_ASSIGN_H_ #include -namespace mmdeploy { -namespace torch_jit { -using torch::Tensor; -using torch::jit::Graph; +namespace mmdeploy +{ + namespace torch_jit + { + using torch::Tensor; + using torch::jit::Graph; -// this pass is used to fuse y[x>thres] = z[x>thres] -void FuseSelectAssign(std::shared_ptr& graph, - std::unordered_map& params); -} // namespace torch_jit + // this pass is used to fuse y[x>thres] = z[x>thres] + void FuseSelectAssign(std::shared_ptr& graph, + std::unordered_map& params); + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.cpp index 3da4933b15..dea6909f8b 100644 --- 
a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.cpp @@ -5,111 +5,131 @@ #include "utils.h" -namespace mmdeploy { -namespace torch_jit { - -using c10::Symbol; -using torch::jit::Block; -using torch::jit::IValue; -using torch::jit::Node; -using torch::jit::TensorType; -using torch::jit::Value; - -void MergeShapeConcate(Node* node) { - auto inputs = node->inputs(); - - std::vector gather_value; - Value* shape_from = nullptr; - - std::vector node_to_remove{node}; - - // check pattern shape->gather->unsqueeze->concate - for (auto input : inputs) { - auto unsqueeze_node = input->node(); - if (!is_kind(unsqueeze_node, "onnx::Unsqueeze") || unsqueeze_node->output()->uses().size() != 1) - return; - - if (unsqueeze_node->hasAttribute(Symbol::attr("axes"))) { - auto axes = unsqueeze_node->is(Symbol::attr("axes")); - if (axes.size() != 1 && axes[0] != 0) return; - } - - auto gather_node = unsqueeze_node->input(0)->node(); - if (!is_kind(gather_node, "onnx::Gather") || gather_node->i(Symbol::attr("axis")) != 0 || - gather_node->output()->uses().size() != 1) - return; - - auto gather_inputs = gather_node->inputs(); - auto gather_data = gather_inputs[0]; - auto gather_indices = gather_inputs[1]; - auto shape_node = gather_data->node(); - if (!is_kind(shape_node, "onnx::Shape") || shape_node->output()->uses().size() != 1) return; - - auto current_shape_from = shape_node->input(); - if (!shape_from) { - shape_from = current_shape_from; - } else { - if (shape_from != current_shape_from) return; - } - - auto constant_node = gather_indices->node(); - if (!is_kind(constant_node, "onnx::Constant")) return; - - auto gather_indices_val = constant_node->t(Symbol::attr("value")); - int64_t* data_ptr = gather_indices_val.data_ptr(); - if (gather_indices_val.dim() == 0) { - gather_value.push_back(data_ptr[0]); - } else { - int element_size = gather_indices_val.element_size(); - for (int j = 0; j < element_size; ++j) { - gather_value.push_back(data_ptr[j]); - } - } - - node_to_remove.insert(node_to_remove.end(), {unsqueeze_node, gather_node, shape_node}); - } - - // create constant value - auto graph = node->owningGraph(); - auto const_node = graph->create(Symbol::onnx("Constant")); - const_node->t_(Symbol::attr("value"), at::tensor(gather_value)); - auto first_node = node->owningGraph()->block()->nodes().front(); - if (const_node != first_node) const_node->insertBefore(first_node); - - // recreate shape node - auto shape_node = graph->create(Symbol::onnx("Shape"), {shape_from}); - shape_node->insertBefore(node); - - // create gather node - auto gather_node = - graph->create(Symbol::onnx("Gather"), {shape_node->output(), const_node->output()}); - - // insert into graph - gather_node->insertAfter(node); - node->output()->replaceAllUsesWith(gather_node->output()); - - for (auto n : node_to_remove) { - n->destroy(); - } -} - -void MergeShapeConcate(Block* block) { - auto graph = block->owningGraph(); - auto it = block->nodes().begin(); - while (it != block->nodes().end()) { - auto node = *it; - ++it; - for (auto block : node->blocks()) { - MergeShapeConcate(block); - } - - if (is_kind(node, "onnx::Concat")) { - MergeShapeConcate(node); - } - } -} - -void MergeShapeConcate(const std::shared_ptr& graph) { MergeShapeConcate(graph->block()); } - -} // namespace torch_jit +namespace mmdeploy +{ + namespace torch_jit + { + + using c10::Symbol; + using torch::jit::Block; + using torch::jit::IValue; + using 
torch::jit::Node; + using torch::jit::TensorType; + using torch::jit::Value; + + void MergeShapeConcate(Node* node) + { + auto inputs = node->inputs(); + + std::vector gather_value; + Value* shape_from = nullptr; + + std::vector node_to_remove{node}; + + // check pattern shape->gather->unsqueeze->concate + for (auto input : inputs) + { + auto unsqueeze_node = input->node(); + if (!is_kind(unsqueeze_node, "onnx::Unsqueeze") || unsqueeze_node->output()->uses().size() != 1) + return; + + if (unsqueeze_node->hasAttribute(Symbol::attr("axes"))) + { + auto axes = unsqueeze_node->is(Symbol::attr("axes")); + if (axes.size() != 1 && axes[0] != 0) return; + } + + auto gather_node = unsqueeze_node->input(0)->node(); + if (!is_kind(gather_node, "onnx::Gather") || gather_node->i(Symbol::attr("axis")) != 0 || + gather_node->output()->uses().size() != 1) + return; + + auto gather_inputs = gather_node->inputs(); + auto gather_data = gather_inputs[0]; + auto gather_indices = gather_inputs[1]; + auto shape_node = gather_data->node(); + if (!is_kind(shape_node, "onnx::Shape") || shape_node->output()->uses().size() != 1) return; + + auto current_shape_from = shape_node->input(); + if (!shape_from) + { + shape_from = current_shape_from; + } + else + { + if (shape_from != current_shape_from) return; + } + + auto constant_node = gather_indices->node(); + if (!is_kind(constant_node, "onnx::Constant")) return; + + auto gather_indices_val = constant_node->t(Symbol::attr("value")); + int64_t* data_ptr = gather_indices_val.data_ptr(); + if (gather_indices_val.dim() == 0) + { + gather_value.push_back(data_ptr[0]); + } + else + { + int element_size = gather_indices_val.element_size(); + for (int j = 0; j < element_size; ++j) + { + gather_value.push_back(data_ptr[j]); + } + } + + node_to_remove.insert(node_to_remove.end(), {unsqueeze_node, gather_node, shape_node}); + } + + // create constant value + auto graph = node->owningGraph(); + auto const_node = graph->create(Symbol::onnx("Constant")); + const_node->t_(Symbol::attr("value"), at::tensor(gather_value)); + auto first_node = node->owningGraph()->block()->nodes().front(); + if (const_node != first_node) const_node->insertBefore(first_node); + + // recreate shape node + auto shape_node = graph->create(Symbol::onnx("Shape"), {shape_from}); + shape_node->insertBefore(node); + + // create gather node + auto gather_node = + graph->create(Symbol::onnx("Gather"), {shape_node->output(), const_node->output()}); + + // insert into graph + gather_node->insertAfter(node); + node->output()->replaceAllUsesWith(gather_node->output()); + + for (auto n : node_to_remove) + { + n->destroy(); + } + } + + void MergeShapeConcate(Block* block) + { + auto graph = block->owningGraph(); + auto it = block->nodes().begin(); + while (it != block->nodes().end()) + { + auto node = *it; + ++it; + for (auto block : node->blocks()) + { + MergeShapeConcate(block); + } + + if (is_kind(node, "onnx::Concat")) + { + MergeShapeConcate(node); + } + } + } + + void MergeShapeConcate(const std::shared_ptr& graph) + { + MergeShapeConcate(graph->block()); + } + + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.h index 8656da63c2..13a67f0f47 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/merge_shape_concate.h @@ -3,12 +3,14 @@ 
#define _MERGE_SHAPE_CONCATE_H_ #include -namespace mmdeploy { -namespace torch_jit { -using torch::jit::Graph; +namespace mmdeploy +{ + namespace torch_jit + { + using torch::jit::Graph; -void MergeShapeConcate(const std::shared_ptr& graph); -} // namespace torch_jit + void MergeShapeConcate(const std::shared_ptr& graph); + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.cpp index f0ef5a5230..0b687c5083 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.cpp @@ -7,75 +7,91 @@ #include "utils.h" -namespace mmdeploy { -namespace torch_jit { - -using c10::Symbol; -using torch::jit::Block; -using torch::jit::IValue; -using torch::jit::Node; -using torch::jit::TensorType; -using torch::jit::Value; - -void RemoveReshapeChain(Node* node) { - // reshape->reshape => reshape - auto output = node->output(); - if (!(output->hasUses())) { - return; - } - auto uses = output->uses(); - - for (auto use : uses) { - if (!is_kind(use.user, "onnx::Reshape") || use.offset != 0) { - return; - } - } - - auto input = node->inputs()[0]; - output->replaceAllUsesWith(input); - - node->destroy(); -} - -void RemoveRedundantCast(Node* node) { - // Cast(type n)->Cast(type n) => Cast(type n) - - auto to_type = node->i(Symbol::attr("to")); - auto input = node->input(); - - auto input_node = input->node(); - if (is_kind(input_node, "onnx::Cast") && input_node->i(Symbol::attr("to")) == to_type) { - auto output = node->output(); - - output->replaceAllUsesWith(input); - node->destroy(); - } -} - -void ONNXPeephole(Block* block) { - auto graph = block->owningGraph(); - auto it = block->nodes().begin(); - while (it != block->nodes().end()) { - auto node = *it; - ++it; - for (auto block : node->blocks()) { - ONNXPeephole(block); - } - - if (is_kind(node, "onnx::Reshape")) { - RemoveReshapeChain(node); - } else if (is_kind(node, "onnx::Cast")) { - RemoveRedundantCast(node); - } - } -} - -void ONNXPeephole(const std::shared_ptr& graph) { - ONNXPeephole(graph->block()); - torch::jit::EliminateDeadCode( - graph->block(), true, - torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); -} - -} // namespace torch_jit +namespace mmdeploy +{ + namespace torch_jit + { + + using c10::Symbol; + using torch::jit::Block; + using torch::jit::IValue; + using torch::jit::Node; + using torch::jit::TensorType; + using torch::jit::Value; + + void RemoveReshapeChain(Node* node) + { + // reshape->reshape => reshape + auto output = node->output(); + if (!(output->hasUses())) + { + return; + } + auto uses = output->uses(); + + for (auto use : uses) + { + if (!is_kind(use.user, "onnx::Reshape") || use.offset != 0) + { + return; + } + } + + auto input = node->inputs()[0]; + output->replaceAllUsesWith(input); + + node->destroy(); + } + + void RemoveRedundantCast(Node* node) + { + // Cast(type n)->Cast(type n) => Cast(type n) + + auto to_type = node->i(Symbol::attr("to")); + auto input = node->input(); + + auto input_node = input->node(); + if (is_kind(input_node, "onnx::Cast") && input_node->i(Symbol::attr("to")) == to_type) + { + auto output = node->output(); + + output->replaceAllUsesWith(input); + node->destroy(); + } + } + + void ONNXPeephole(Block* block) + { + auto graph = block->owningGraph(); + auto it = block->nodes().begin(); + while (it 
!= block->nodes().end()) + { + auto node = *it; + ++it; + for (auto block : node->blocks()) + { + ONNXPeephole(block); + } + + if (is_kind(node, "onnx::Reshape")) + { + RemoveReshapeChain(node); + } + else if (is_kind(node, "onnx::Cast")) + { + RemoveRedundantCast(node); + } + } + } + + void ONNXPeephole(const std::shared_ptr& graph) + { + ONNXPeephole(graph->block()); + torch::jit::EliminateDeadCode( + graph->block(), + true, + torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); + } + + } // namespace torch_jit } // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.h index f388da1bfa..21b7be15d1 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/onnx_peephole.h @@ -3,13 +3,15 @@ #define _ONNX_PEEPHOLE_H_ #include -namespace mmdeploy { -namespace torch_jit { -using torch::jit::Graph; +namespace mmdeploy +{ + namespace torch_jit + { + using torch::jit::Graph; -void ONNXPeephole(const std::shared_ptr& graph); + void ONNXPeephole(const std::shared_ptr& graph); -} // namespace torch_jit + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/utils.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/utils.h index 1c92cd15a1..147e5b1349 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/utils.h +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/utils.h @@ -3,18 +3,24 @@ #include -namespace mmdeploy { -namespace torch_jit { -using c10::Symbol; -using torch::jit::Node; +namespace mmdeploy +{ + namespace torch_jit + { + using c10::Symbol; + using torch::jit::Node; -inline bool is_kind(const Node* node, const Symbol& symbol) { return node->kind() == symbol; } + inline bool is_kind(const Node* node, const Symbol& symbol) + { + return node->kind() == symbol; + } -inline bool is_kind(const Node* node, const char* symbol_name) { - return is_kind(node, Symbol::fromQualString(symbol_name)); -} + inline bool is_kind(const Node* node, const char* symbol_name) + { + return is_kind(node, Symbol::fromQualString(symbol_name)); + } -} // namespace torch_jit + } // namespace torch_jit } // namespace mmdeploy #endif diff --git a/csrc/mmdeploy/codebase/common.h b/csrc/mmdeploy/codebase/common.h index 391f177590..f5d01c3bbe 100644 --- a/csrc/mmdeploy/codebase/common.h +++ b/csrc/mmdeploy/codebase/common.h @@ -9,69 +9,87 @@ #include "mmdeploy/core/utils/formatter.h" #include "mmdeploy/experimental/module_adapter.h" -namespace mmdeploy { +namespace mmdeploy +{ -using namespace framework; + using namespace framework; -class Context { - public: - explicit Context(const Value& config) { - MMDEPLOY_DEBUG("config: {}", config); - device_ = config["context"]["device"].get(); - stream_ = config["context"]["stream"].get(); - } + class Context + { + public: + explicit Context(const Value& config) + { + MMDEPLOY_DEBUG("config: {}", config); + device_ = config["context"]["device"].get(); + stream_ = config["context"]["stream"].get(); + } - Device& device() { return device_; } - Stream& stream() { return stream_; } + Device& device() + { + return device_; + } + Stream& stream() + { + return stream_; + } - protected: - Device device_; - Stream stream_; -}; + protected: + Device device_; + Stream stream_; + }; -template -class CodebaseCreator : public 
Creator { - public: - std::string_view name() const noexcept override { return Tag::name; } - std::unique_ptr Create(const Value& cfg) override { - constexpr auto key{"component"}; - if (!cfg.contains(key)) { - MMDEPLOY_ERROR("no key '{}' in config {}", key, cfg); - throw_exception(eInvalidArgument); - } - if (!cfg[key].is_string()) { - MMDEPLOY_ERROR("key '{}' is not a string", key); - throw_exception(eInvalidArgument); - } - auto postprocess_type = cfg[key].get(); - auto creator = gRegistry().Get(postprocess_type); - if (creator == nullptr) { - MMDEPLOY_ERROR("Could not found entry '{}' in {}. Available components: {}", postprocess_type, - Tag::name, gRegistry().List()); - throw_exception(eEntryNotFound); - } - return creator->Create(cfg); - } -}; + template + class CodebaseCreator : public Creator + { + public: + std::string_view name() const noexcept override + { + return Tag::name; + } + std::unique_ptr Create(const Value& cfg) override + { + constexpr auto key{"component"}; + if (!cfg.contains(key)) + { + MMDEPLOY_ERROR("no key '{}' in config {}", key, cfg); + throw_exception(eInvalidArgument); + } + if (!cfg[key].is_string()) + { + MMDEPLOY_ERROR("key '{}' is not a string", key); + throw_exception(eInvalidArgument); + } + auto postprocess_type = cfg[key].get(); + auto creator = gRegistry().Get(postprocess_type); + if (creator == nullptr) + { + MMDEPLOY_ERROR("Could not found entry '{}' in {}. Available components: {}", postprocess_type, Tag::name, gRegistry().List()); + throw_exception(eEntryNotFound); + } + return creator->Create(cfg); + } + }; -#define MMDEPLOY_DECLARE_CODEBASE(codebase_type, codebase_name) \ - class codebase_type : public Context { \ - public: \ - static constexpr const auto name = #codebase_name; \ - using type = std::unique_ptr; \ - explicit codebase_type(const Value& config) : Context(config) {} \ - }; \ - MMDEPLOY_DECLARE_REGISTRY(codebase_type, std::unique_ptr(const Value& config)); +#define MMDEPLOY_DECLARE_CODEBASE(codebase_type, codebase_name) \ + class codebase_type : public Context \ + { \ + public: \ + static constexpr const auto name = #codebase_name; \ + using type = std::unique_ptr; \ + explicit codebase_type(const Value& config) \ + : Context(config) \ + { \ + } \ + }; \ + MMDEPLOY_DECLARE_REGISTRY(codebase_type, std::unique_ptr(const Value& config)); -#define MMDEPLOY_REGISTER_CODEBASE(codebase) \ - using codebase##_##Creator = CodebaseCreator; \ - MMDEPLOY_REGISTER_CREATOR(Module, codebase##_##Creator) \ - MMDEPLOY_DEFINE_REGISTRY(codebase) +#define MMDEPLOY_REGISTER_CODEBASE(codebase) \ + using codebase##_##Creator = CodebaseCreator; \ + MMDEPLOY_REGISTER_CREATOR(Module, codebase##_##Creator) \ + MMDEPLOY_DEFINE_REGISTRY(codebase) -#define MMDEPLOY_REGISTER_CODEBASE_COMPONENT(codebase, component_type) \ - MMDEPLOY_REGISTER_FACTORY_FUNC(codebase, (component_type, 0), [](const Value& config) { \ - return CreateTask(component_type(config)); \ - }) +#define MMDEPLOY_REGISTER_CODEBASE_COMPONENT(codebase, component_type) \ + MMDEPLOY_REGISTER_FACTORY_FUNC(codebase, (component_type, 0), [](const Value& config) { return CreateTask(component_type(config)); }) } // namespace mmdeploy diff --git a/csrc/mmdeploy/codebase/mmaction/base_head.cpp b/csrc/mmdeploy/codebase/mmaction/base_head.cpp index 931c9663eb..2e541fd660 100644 --- a/csrc/mmdeploy/codebase/mmaction/base_head.cpp +++ b/csrc/mmdeploy/codebase/mmaction/base_head.cpp @@ -7,66 +7,75 @@ #include "mmdeploy/core/tensor.h" #include "mmdeploy/core/utils/device_utils.h" -namespace mmdeploy::mmaction { 
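+// BaseHead picks the top-k classification labels from the model's score tensor;
+// SlowFastHead and TSNHead below are aliases that reuse this implementation.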
+namespace mmdeploy::mmaction +{ -class BaseHead : public MMAction { - public: - explicit BaseHead(const Value& cfg) : MMAction(cfg) { - if (cfg.contains("params")) { - topk_ = cfg["params"].value("topk", 1); - if (topk_ <= 0) { - MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_); - throw_exception(eInvalidArgument); - } - } - } + class BaseHead : public MMAction + { + public: + explicit BaseHead(const Value& cfg) + : MMAction(cfg) + { + if (cfg.contains("params")) + { + topk_ = cfg["params"].value("topk", 1); + if (topk_ <= 0) + { + MMDEPLOY_ERROR("'topk' should be greater than 0, but got '{}'", topk_); + throw_exception(eInvalidArgument); + } + } + } - Result operator()(const Value& infer_res) { - MMDEPLOY_DEBUG("infer_res: {}", infer_res); - auto output = infer_res["output"].get(); + Result operator()(const Value& infer_res) + { + MMDEPLOY_DEBUG("infer_res: {}", infer_res); + auto output = infer_res["output"].get(); - if (!(output.shape().size() >= 2 && output.data_type() == DataType::kFLOAT)) { - MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", output.shape(), - (int)output.data_type()); - return Status(eNotSupported); - } + if (!(output.shape().size() >= 2 && output.data_type() == DataType::kFLOAT)) + { + MMDEPLOY_ERROR("unsupported `output` tensor, shape: {}, dtype: {}", output.shape(), (int)output.data_type()); + return Status(eNotSupported); + } - auto class_num = (int)output.shape(1); + auto class_num = (int)output.shape(1); - OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output, kHost, stream())); - OUTCOME_TRY(stream().Wait()); + OUTCOME_TRY(auto _scores, MakeAvailableOnDevice(output, kHost, stream())); + OUTCOME_TRY(stream().Wait()); - return GetLabels(_scores, class_num); - } + return GetLabels(_scores, class_num); + } - private: - Value GetLabels(const Tensor& scores, int class_num) const { - auto scores_data = scores.data(); - Labels output; - output.reserve(topk_); - std::vector idx(class_num); - iota(begin(idx), end(idx), 0); - partial_sort(begin(idx), begin(idx) + topk_, end(idx), - [&](int i, int j) { return scores_data[i] > scores_data[j]; }); - for (int i = 0; i < topk_; ++i) { - auto label = Label{idx[i], scores_data[idx[i]]}; - MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score); - output.push_back(label); - } - return to_value(std::move(output)); - } + private: + Value GetLabels(const Tensor& scores, int class_num) const + { + auto scores_data = scores.data(); + Labels output; + output.reserve(topk_); + std::vector idx(class_num); + iota(begin(idx), end(idx), 0); + partial_sort(begin(idx), begin(idx) + topk_, end(idx), [&](int i, int j) + { return scores_data[i] > scores_data[j]; }); + for (int i = 0; i < topk_; ++i) + { + auto label = Label{idx[i], scores_data[idx[i]]}; + MMDEPLOY_DEBUG("label_id: {}, score: {}", label.label_id, label.score); + output.push_back(label); + } + return to_value(std::move(output)); + } - private: - static constexpr const auto kHost = Device{0}; - int topk_{1}; -}; + private: + static constexpr const auto kHost = Device{0}; + int topk_{1}; + }; -MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, BaseHead); + MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, BaseHead); -using SlowFastHead = BaseHead; -MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, SlowFastHead); + using SlowFastHead = BaseHead; + MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, SlowFastHead); -using TSNHead = BaseHead; -MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, TSNHead); + using TSNHead = BaseHead; + 
MMDEPLOY_REGISTER_CODEBASE_COMPONENT(MMAction, TSNHead); } // namespace mmdeploy::mmaction diff --git a/csrc/mmdeploy/codebase/mmaction/format_shape.cpp b/csrc/mmdeploy/codebase/mmaction/format_shape.cpp index 7d8c6ac5c6..ff65fe184d 100644 --- a/csrc/mmdeploy/codebase/mmaction/format_shape.cpp +++ b/csrc/mmdeploy/codebase/mmaction/format_shape.cpp @@ -7,122 +7,141 @@ using namespace std; -namespace mmdeploy::mmaction { - -FormatShape::FormatShape(const Value& args) { - input_format_ = args.value("input_format", std::string("")); - if (input_format_ != "NCHW" && input_format_ != "NCTHW") { - MMDEPLOY_ERROR("'input_format' should be 'NCHW' or 'NCTHW'"); - throw_exception(eInvalidArgument); - } - permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); -} - -Result FormatShape::MergeInputs(const std::vector& images, Tensor& inputs) { - auto N = static_cast(images.size()); - auto H = images[0].shape(1); - auto W = images[0].shape(2); - auto C = images[0].shape(3); - auto& device = operation::gContext().device(); - auto& stream = operation::gContext().stream(); - - TensorDesc desc = {device, DataType::kFLOAT, {N, H, W, C}}; - inputs = Tensor(desc); - auto offset = 0UL; - auto n_item = H * W * C; - auto copy_size = n_item * sizeof(float); - for (int i = 0; i < N; i++) { - auto src_buffer = images[i].buffer(); - auto dst_buffer = inputs.buffer(); - OUTCOME_TRY(stream.Copy(src_buffer, dst_buffer, copy_size, 0, offset)); - offset += copy_size; - } - return success(); -} - -Result FormatShape::Format(const std::vector& images, Tensor& output, int clip_len, - int num_clips) { - Tensor inputs; - OUTCOME_TRY(MergeInputs(images, inputs)); - - // Tensor dst; - if (input_format_ == "NCHW") { - OUTCOME_TRY(FormatNCHW(inputs, clip_len, num_clips, output)); - } - if (input_format_ == "NCTHW") { - OUTCOME_TRY(FormatNCTHW(inputs, clip_len, num_clips, output)); - } - - TensorShape expand_dim = output.shape(); - expand_dim.insert(expand_dim.begin(), 1); - output.Reshape(expand_dim); - - return success(); -} - -Result FormatShape::FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) { - const vector axes = {0, 3, 1, 2}; - OUTCOME_TRY(permute_.Apply(src, dst, axes)); - return success(); -} - -Result FormatShape::FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) { - auto N = src.shape(0); - auto H = src.shape(1); - auto W = src.shape(2); - auto C = src.shape(3); - int L = clip_len; - if (N % L != 0) { - return Status(eInvalidArgument); - } - int M = N / L; - src.Reshape({M, L, H, W, C}); - const vector axes = {0, 4, 1, 2, 3}; - OUTCOME_TRY(permute_.Apply(src, dst, axes)); - return success(); -} - -Result FormatShape::Apply(Value& data) { - MMDEPLOY_DEBUG("input: {}", data); - - if (!data.is_array()) { - MMDEPLOY_ERROR("input of format shape should be array"); - return Status(eInvalidArgument); - } - if (!(data[0].contains("imgs") || data[0].contains("img"))) { - MMDEPLOY_ERROR("input should contains imgs or img"); - return Status(eInvalidArgument); - } - - int n_image = data.size(); - int clip_len = data[0]["clip_len"].get(); - int num_clips = data[0]["num_clips"].get(); - std::vector images; - - if (data[0].contains("imgs")) { - int n_crop = data[0]["imgs"].size(); - int total = n_image * n_crop; - images.reserve(total); - for (int i = 0; i < n_crop; i++) { - for (int j = 0; j < n_image; j++) { - images.push_back(data[j]["imgs"][i].get()); - } +namespace mmdeploy::mmaction +{ + + FormatShape::FormatShape(const Value& args) + { + input_format_ = 
args.value("input_format", std::string("")); + if (input_format_ != "NCHW" && input_format_ != "NCTHW") + { + MMDEPLOY_ERROR("'input_format' should be 'NCHW' or 'NCTHW'"); + throw_exception(eInvalidArgument); + } + permute_ = ::mmdeploy::operation::Managed<::mmdeploy::operation::Permute>::Create(); } - } else if (data[0].contains("img")) { - images.reserve(n_image); - for (int i = 0; i < n_image; i++) { - images.push_back(data[i]["img"].get()); + + Result FormatShape::MergeInputs(const std::vector& images, Tensor& inputs) + { + auto N = static_cast(images.size()); + auto H = images[0].shape(1); + auto W = images[0].shape(2); + auto C = images[0].shape(3); + auto& device = operation::gContext().device(); + auto& stream = operation::gContext().stream(); + + TensorDesc desc = {device, DataType::kFLOAT, {N, H, W, C}}; + inputs = Tensor(desc); + auto offset = 0UL; + auto n_item = H * W * C; + auto copy_size = n_item * sizeof(float); + for (int i = 0; i < N; i++) + { + auto src_buffer = images[i].buffer(); + auto dst_buffer = inputs.buffer(); + OUTCOME_TRY(stream.Copy(src_buffer, dst_buffer, copy_size, 0, offset)); + offset += copy_size; + } + return success(); + } + + Result FormatShape::Format(const std::vector& images, Tensor& output, int clip_len, int num_clips) + { + Tensor inputs; + OUTCOME_TRY(MergeInputs(images, inputs)); + + // Tensor dst; + if (input_format_ == "NCHW") + { + OUTCOME_TRY(FormatNCHW(inputs, clip_len, num_clips, output)); + } + if (input_format_ == "NCTHW") + { + OUTCOME_TRY(FormatNCTHW(inputs, clip_len, num_clips, output)); + } + + TensorShape expand_dim = output.shape(); + expand_dim.insert(expand_dim.begin(), 1); + output.Reshape(expand_dim); + + return success(); } - } - Tensor dst; - data = Value{}; - OUTCOME_TRY(Format(images, dst, clip_len, num_clips)); - data["img"] = std::move(dst); + Result FormatShape::FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) + { + const vector axes = {0, 3, 1, 2}; + OUTCOME_TRY(permute_.Apply(src, dst, axes)); + return success(); + } - return success(); -} + Result FormatShape::FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst) + { + auto N = src.shape(0); + auto H = src.shape(1); + auto W = src.shape(2); + auto C = src.shape(3); + int L = clip_len; + if (N % L != 0) + { + return Status(eInvalidArgument); + } + int M = N / L; + src.Reshape({M, L, H, W, C}); + const vector axes = {0, 4, 1, 2, 3}; + OUTCOME_TRY(permute_.Apply(src, dst, axes)); + return success(); + } + + Result FormatShape::Apply(Value& data) + { + MMDEPLOY_DEBUG("input: {}", data); + + if (!data.is_array()) + { + MMDEPLOY_ERROR("input of format shape should be array"); + return Status(eInvalidArgument); + } + if (!(data[0].contains("imgs") || data[0].contains("img"))) + { + MMDEPLOY_ERROR("input should contains imgs or img"); + return Status(eInvalidArgument); + } + + int n_image = data.size(); + int clip_len = data[0]["clip_len"].get(); + int num_clips = data[0]["num_clips"].get(); + std::vector images; + + if (data[0].contains("imgs")) + { + int n_crop = data[0]["imgs"].size(); + int total = n_image * n_crop; + images.reserve(total); + for (int i = 0; i < n_crop; i++) + { + for (int j = 0; j < n_image; j++) + { + images.push_back(data[j]["imgs"][i].get()); + } + } + } + else if (data[0].contains("img")) + { + images.reserve(n_image); + for (int i = 0; i < n_image; i++) + { + images.push_back(data[i]["img"].get()); + } + } + + Tensor dst; + data = Value{}; + OUTCOME_TRY(Format(images, dst, clip_len, num_clips)); + data["img"] = 
std::move(dst);
+
+            return success();
+        }
-MMDEPLOY_REGISTER_TRANSFORM(FormatShape);
+    MMDEPLOY_REGISTER_TRANSFORM(FormatShape);
 
 }  // namespace mmdeploy::mmaction
diff --git a/csrc/mmdeploy/codebase/mmaction/format_shape.h b/csrc/mmdeploy/codebase/mmaction/format_shape.h
index 97e4f99356..7ea0326c84 100644
--- a/csrc/mmdeploy/codebase/mmaction/format_shape.h
+++ b/csrc/mmdeploy/codebase/mmaction/format_shape.h
@@ -12,27 +12,28 @@
 #include "mmdeploy/operation/vision.h"
 #include "mmdeploy/preprocess/transform/transform.h"
 
-namespace mmdeploy::mmaction {
+namespace mmdeploy::mmaction
+{
 
-class FormatShape : public Transform {
- public:
-  explicit FormatShape(const Value& args);
+    class FormatShape : public Transform
+    {
+      public:
+        explicit FormatShape(const Value& args);
 
-  Result<void> Apply(Value& data) override;
+        Result<void> Apply(Value& data) override;
 
-  Result<void> Format(const std::vector<Tensor>& images, Tensor& output, int clip_len,
-                      int num_clips);
+        Result<void> Format(const std::vector<Tensor>& images, Tensor& output, int clip_len, int num_clips);
 
-  Result<void> FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst);
+        Result<void> FormatNCHW(Tensor& src, int clip_len, int num_clips, Tensor& dst);
 
-  Result<void> FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst);
+        Result<void> FormatNCTHW(Tensor& src, int clip_len, int num_clips, Tensor& dst);
 
-  Result<void> MergeInputs(const std::vector<Tensor>& images, Tensor& inputs);
+        Result<void> MergeInputs(const std::vector<Tensor>& images, Tensor& inputs);
 
- private:
-  std::string input_format_;
-  operation::Managed<operation::Permute> permute_;
-};
+      private:
+        std::string input_format_;
+        operation::Managed<operation::Permute> permute_;
+    };
 
 }  // namespace mmdeploy::mmaction
diff --git a/csrc/mmdeploy/codebase/mmaction/mmaction.cpp b/csrc/mmdeploy/codebase/mmaction/mmaction.cpp
index dc590a1800..7de226ecd1 100644
--- a/csrc/mmdeploy/codebase/mmaction/mmaction.cpp
+++ b/csrc/mmdeploy/codebase/mmaction/mmaction.cpp
@@ -2,8 +2,9 @@
 
 #include "mmdeploy/codebase/mmaction/mmaction.h"
 
-namespace mmdeploy::mmaction {
+namespace mmdeploy::mmaction
+{
 
-MMDEPLOY_REGISTER_CODEBASE(MMAction);
+    MMDEPLOY_REGISTER_CODEBASE(MMAction);
 
 }  // namespace mmdeploy::mmaction
diff --git a/csrc/mmdeploy/codebase/mmaction/mmaction.h b/csrc/mmdeploy/codebase/mmaction/mmaction.h
index ef097e6f20..a3add86894 100644
--- a/csrc/mmdeploy/codebase/mmaction/mmaction.h
+++ b/csrc/mmdeploy/codebase/mmaction/mmaction.h
@@ -8,17 +8,19 @@
 #include "mmdeploy/core/module.h"
 #include "mmdeploy/core/serialization.h"
 
-namespace mmdeploy::mmaction {
+namespace mmdeploy::mmaction
+{
 
-struct Label {
-  int label_id;
-  float score;
-  MMDEPLOY_ARCHIVE_MEMBERS(label_id, score);
-};
+    struct Label
+    {
+        int label_id;
+        float score;
+        MMDEPLOY_ARCHIVE_MEMBERS(label_id, score);
+    };
 
-using Labels = std::vector