diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 00000000..17d59c48
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,20 @@
+{
+    "name": "rt_detr",
+    "dockerComposeFile": [
+        "../docker-compose.yaml"
+    ],
+    "service": "rt_detr",
+    "workspaceFolder": "/home/ros/RT-DETR",
+    "shutdownAction": "stopCompose",
+    "customizations": {
+        "vscode": {
+            "settings": {
+                "remote.autoForwardPorts": false
+            },
+            "extensions": [
+                "ms-python.python",
+                "ms-python.debugpy"
+            ]
+        }
+    }
+}
\ No newline at end of file
diff --git a/docker-compose.yaml b/docker-compose.yaml
new file mode 100644
index 00000000..da6da0a5
--- /dev/null
+++ b/docker-compose.yaml
@@ -0,0 +1,27 @@
+services:
+  rt_detr:
+    container_name: rt_detr
+    image: rt_detr:latest
+    build:
+      context: ./
+      dockerfile: docker/Dockerfile
+      args:
+        - TARGET_PATH=.
+    volumes:
+      - ./rtdetrv2_pytorch:/home/ros/RT-DETR
+      - ./benchmark:/home/ros/RT-DETR/benchmark
+      - /mnt/gr-nas/visionai-data:/home/ros/RT-DETR/data
+      - /tmp/.X11-unix:/tmp/.X11-unix:rw
+    env_file: docker/.env
+    privileged: true
+    working_dir: /home/ros/RT-DETR
+    user: ros
+    ipc: host
+    network_mode: host
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              device_ids: ['0']
+              capabilities: [gpu]
\ No newline at end of file
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 00000000..9c6836f6
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,11 @@
+FROM ghcr.io/greenroom-robotics/ros_builder:jazzy-latest-cuda
+
+COPY rtdetrv2_pytorch/requirements.txt /tmp/requirements.txt
+RUN pip install -r /tmp/requirements.txt --user
+
+RUN sudo apt-get update && sudo apt-get install -y python3-opencv
+
+# Misc dependencies
+RUN pip install debugpy
+
+CMD ["tail", "-f", "/dev/null"]
\ No newline at end of file
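Note (not part of the patch): the image keeps the container alive via `tail -f /dev/null`, and the launch.json below only *attaches* to port 5678, so the process to debug has to be started under debugpy inside the container first. One assumed workflow, using debugpy's documented CLI:

    python -m debugpy --listen 5678 --wait-for-client tools/train.py -c configs/gr/rtdetrv2_r18vd_10e_gr.yml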
diff --git a/rtdetrv2_pytorch/.vscode/launch.json b/rtdetrv2_pytorch/.vscode/launch.json
new file mode 100644
index 00000000..5a0a33b9
--- /dev/null
+++ b/rtdetrv2_pytorch/.vscode/launch.json
@@ -0,0 +1,25 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+
+        {
+            "name": "Python Debugger: Remote Attach",
+            "type": "debugpy",
+            "request": "attach",
+            "connect": {
+                "host": "localhost",
+                "port": 5678
+            },
+            "pathMappings": [
+                {
+                    "localRoot": "/home/ros/RT-DETR",
+                    "remoteRoot": "/home/ros/RT-DETR"
+                }
+            ],
+            "justMyCode": false
+        }
+    ]
+}
\ No newline at end of file
diff --git a/rtdetrv2_pytorch/configs/dataset/gr_detection.yml b/rtdetrv2_pytorch/configs/dataset/gr_detection.yml
new file mode 100644
index 00000000..5a65aca2
--- /dev/null
+++ b/rtdetrv2_pytorch/configs/dataset/gr_detection.yml
@@ -0,0 +1,42 @@
+task: detection
+
+evaluator:
+  type: CocoEvaluator
+  iou_types: ['bbox', ]
+
+num_classes: 5 # must be the actual number of classes + 1 (index 0 is reserved; see rtdetr_postprocessor.py)
+remap_mscoco_category: False
+
+
+train_dataloader:
+  type: DataLoader
+  dataset:
+    type: CocoDetection
+    img_folder: ./data
+    ann_file: ./data/datasets/image/experiments/e24-007-d24-002-traintestsplit/coco_labels/split_train.mapped.json
+    return_masks: False
+    transforms:
+      type: Compose
+      ops: ~
+  shuffle: True
+  num_workers: 4
+  drop_last: True
+  collate_fn:
+    type: BatchImageCollateFuncion
+
+
+val_dataloader:
+  type: DataLoader
+  dataset:
+    type: CocoDetection
+    img_folder: ./data
+    ann_file: ./data/datasets/image/experiments/e24-007-d24-002-traintestsplit/coco_labels/split_test.mapped.json
+    return_masks: False
+    transforms:
+      type: Compose
+      ops: ~
+  shuffle: False
+  num_workers: 4
+  drop_last: False
+  collate_fn:
+    type: BatchImageCollateFuncion
diff --git a/rtdetrv2_pytorch/configs/gr/include/dataloader.yml b/rtdetrv2_pytorch/configs/gr/include/dataloader.yml
new file mode 100644
index 00000000..2bb532df
--- /dev/null
+++ b/rtdetrv2_pytorch/configs/gr/include/dataloader.yml
@@ -0,0 +1,38 @@
+
+train_dataloader:
+  dataset:
+    transforms:
+      ops:
+        - {type: RandomPhotometricDistort, p: 0.5}
+        - {type: RandomZoomOut, fill: 0}
+        - {type: RandomIoUCrop, p: 0.8}
+        - {type: SanitizeBoundingBoxes, min_size: 1}
+        - {type: RandomHorizontalFlip}
+        - {type: Resize, size: [640, 640], }
+        - {type: SanitizeBoundingBoxes, min_size: 1}
+        - {type: ConvertPILImage, dtype: 'float32', scale: True}
+        - {type: ConvertBoxes, fmt: 'cxcywh', normalize: True}
+      policy:
+        name: stop_epoch
+        epoch: 71 # stop applying the `ops` listed below from epoch 71 onward
+        ops: ['RandomPhotometricDistort', 'RandomZoomOut', 'RandomIoUCrop']
+
+  collate_fn:
+    type: BatchImageCollateFuncion
+    scales: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800]
+    stop_epoch: 71 # stop multi-scale sampling from epoch 71 onward
+
+  shuffle: True
+  total_batch_size: 16 # total batch size of 16 (e.g. 4 GPUs x 4 images each)
+  num_workers: 4
+
+
+val_dataloader:
+  dataset:
+    transforms:
+      ops:
+        - {type: Resize, size: [640, 640]}
+        - {type: ConvertPILImage, dtype: 'float32', scale: True}
+  shuffle: False
+  total_batch_size: 32
+  num_workers: 4
\ No newline at end of file
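Note (not part of the patch): the `policy: stop_epoch` block is consumed by the trainer, not by the transforms themselves. A minimal sketch of the assumed semantics — past `policy.epoch`, the ops named in `policy.ops` are dropped so the final epochs train on clean images:

    def active_ops(ops, policy, cur_epoch):
        # ops: list of dicts like {'type': 'RandomZoomOut', ...}
        if policy.get('name') == 'stop_epoch' and cur_epoch >= policy['epoch']:
            return [op for op in ops if op['type'] not in policy['ops']]
        return ops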
diff --git a/rtdetrv2_pytorch/configs/gr/include/optimizer.yml b/rtdetrv2_pytorch/configs/gr/include/optimizer.yml
new file mode 100644
index 00000000..189a9a1d
--- /dev/null
+++ b/rtdetrv2_pytorch/configs/gr/include/optimizer.yml
@@ -0,0 +1,37 @@
+
+use_amp: True
+use_ema: True
+ema:
+  type: ModelEMA
+  decay: 0.9999
+  warmups: 2000
+
+
+epoches: 72
+clip_max_norm: 0.1
+
+
+optimizer:
+  type: AdamW
+  params:
+    -
+      params: '^(?=.*backbone)(?!.*norm).*$'
+      lr: 0.00001
+    -
+      params: '^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn)).*$'
+      weight_decay: 0.
+
+  lr: 0.0001
+  betas: [0.9, 0.999]
+  weight_decay: 0.0001
+
+
+lr_scheduler:
+  type: MultiStepLR
+  milestones: [1000]
+  gamma: 0.1
+
+
+lr_warmup_scheduler:
+  type: LinearWarmup
+  warmup_duration: 2000
\ No newline at end of file
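Note (not part of the patch): the `params` entries are regexes matched against parameter names; anything unmatched falls into the default group (lr 0.0001, weight_decay 0.0001). A standalone sketch with illustrative parameter names (the real grouping lives in the repo's optimizer factory):

    import re

    groups = {
        r'^(?=.*backbone)(?!.*norm).*$': {'lr': 1e-5},
        r'^(?=.*(?:encoder|decoder))(?=.*(?:norm|bn)).*$': {'weight_decay': 0.0},
    }
    for name in ['backbone.res2.conv1.weight',   # -> low-lr backbone group
                 'decoder.layers.0.norm1.bias',  # -> no-weight-decay group
                 'encoder.fpn.conv.weight']:     # -> default group
        hits = [opts for pat, opts in groups.items() if re.match(pat, name)]
        print(name, hits or 'default group')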
diff --git a/rtdetrv2_pytorch/configs/gr/include/rtdetrv2_r50vd.yml b/rtdetrv2_pytorch/configs/gr/include/rtdetrv2_r50vd.yml
new file mode 100644
index 00000000..a5c14909
--- /dev/null
+++ b/rtdetrv2_pytorch/configs/gr/include/rtdetrv2_r50vd.yml
@@ -0,0 +1,83 @@
+task: detection
+
+model: RTDETR
+criterion: RTDETRCriterionv2
+postprocessor: RTDETRPostProcessor
+
+
+use_focal_loss: True
+eval_spatial_size: [640, 640] # h w
+
+
+RTDETR:
+  backbone: PResNet
+  encoder: HybridEncoder
+  decoder: RTDETRTransformerv2
+
+
+PResNet:
+  depth: 50
+  variant: d
+  freeze_at: 0
+  return_idx: [1, 2, 3]
+  num_stages: 4
+  freeze_norm: True
+  pretrained: True
+
+
+HybridEncoder:
+  in_channels: [512, 1024, 2048]
+  feat_strides: [8, 16, 32]
+
+  # intra
+  hidden_dim: 256
+  use_encoder_idx: [2]
+  num_encoder_layers: 1
+  nhead: 8
+  dim_feedforward: 1024
+  dropout: 0.
+  enc_act: 'gelu'
+
+  # cross
+  expansion: 1.0
+  depth_mult: 1
+  act: 'silu'
+
+
+RTDETRTransformerv2:
+  feat_channels: [256, 256, 256]
+  feat_strides: [8, 16, 32]
+  hidden_dim: 256
+  num_levels: 3
+
+  num_layers: 6
+  num_queries: 300
+
+  num_denoising: 100
+  label_noise_ratio: 0.5
+  box_noise_scale: 1.0 # 1.0 0.4
+
+  eval_idx: -1
+
+  # NEW
+  num_points: [4, 4, 4] # [3,3,3] [2,2,2]
+  cross_attn_method: default # default, discrete
+  query_select_method: default # default, agnostic
+
+
+RTDETRPostProcessor:
+  num_top_queries: 300
+
+
+RTDETRCriterionv2:
+  weight_dict: {loss_vfl: 1, loss_bbox: 5, loss_giou: 2,}
+  losses: ['vfl', 'boxes', ]
+  alpha: 0.75
+  gamma: 2.0
+
+  matcher:
+    type: HungarianMatcher
+    weight_dict: {cost_class: 2, cost_bbox: 5, cost_giou: 2}
+    alpha: 0.25
+    gamma: 2.0
+
diff --git a/rtdetrv2_pytorch/configs/gr/rtdetrv2_r18vd_10e_gr.yml b/rtdetrv2_pytorch/configs/gr/rtdetrv2_r18vd_10e_gr.yml
new file mode 100644
index 00000000..c459e9a0
--- /dev/null
+++ b/rtdetrv2_pytorch/configs/gr/rtdetrv2_r18vd_10e_gr.yml
@@ -0,0 +1,67 @@
+__include__: [
+  '../dataset/gr_detection.yml',
+  './include/dataloader.yml',
+  './include/optimizer.yml',
+  './include/rtdetrv2_r50vd.yml',
+  '../runtime.yml',
+]
+
+
+output_dir: ./output/rtdetrv2_r18vd_10e
+
+
+PResNet:
+  depth: 18
+  freeze_at: -1
+  freeze_norm: False
+  pretrained: True
+
+
+HybridEncoder:
+  in_channels: [128, 256, 512]
+  hidden_dim: 256
+  expansion: 0.5
+
+
+RTDETRTransformerv2:
+  num_layers: 3
+
+
+epoches: 10
+
+optimizer:
+  type: AdamW
+  params:
+    -
+      params: '^(?=.*(?:norm|bn)).*$'
+      weight_decay: 0.
+
+eval_spatial_size: [1280, 1280]
+
+train_dataloader:
+  dataset:
+    transforms:
+      ops:
+        - {type: RandomPhotometricDistort, p: 0.5}
+        - {type: RandomZoomOut, fill: 0}
+        - {type: RandomIoUCrop, p: 0.8}
+        - {type: SanitizeBoundingBoxes, min_size: 1}
+        - {type: RandomHorizontalFlip}
+        - {type: Resize, size: [1280, 1280], }
+        - {type: SanitizeBoundingBoxes, min_size: 1}
+        - {type: ConvertPILImage, dtype: 'float32', scale: True}
+        - {type: ConvertBoxes, fmt: 'cxcywh', normalize: True}
+      policy:
+        epoch: 45
+  collate_fn:
+    scales: ~
+
+  total_batch_size: 6
+
+val_dataloader:
+  dataset:
+    transforms:
+      ops:
+        - {type: Resize, size: [1280, 1280]}
+        - {type: ConvertPILImage, dtype: 'float32', scale: True}
+  total_batch_size: 12
\ No newline at end of file
diff --git a/rtdetrv2_pytorch/configs/gr/rtdetrv2_r18vd_50e_gr_overfit.yml b/rtdetrv2_pytorch/configs/gr/rtdetrv2_r18vd_50e_gr_overfit.yml
new file mode 100644
index 00000000..49264663
--- /dev/null
+++ b/rtdetrv2_pytorch/configs/gr/rtdetrv2_r18vd_50e_gr_overfit.yml
@@ -0,0 +1,67 @@
+__include__: [
+  '../dataset/gr_detection_pilot.yml',
+  './include/dataloader.yml',
+  './include/optimizer.yml',
+  './include/rtdetrv2_r50vd.yml',
+  '../runtime.yml',
+]
+
+
+output_dir: ./output/rtdetrv2_r18vd_50e_gr_pilot_tmp
+
+
+PResNet:
+  depth: 18
+  freeze_at: -1
+  freeze_norm: False
+  pretrained: True
+
+
+HybridEncoder:
+  in_channels: [128, 256, 512]
+  hidden_dim: 256
+  expansion: 0.5
+
+
+RTDETRTransformerv2:
+  num_layers: 3
+
+
+epoches: 50
+
+optimizer:
+  type: AdamW
+  params:
+    -
+      params: '^(?=.*(?:norm|bn)).*$'
+      weight_decay: 0.
+
+eval_spatial_size: [1280, 1280]
+
+train_dataloader:
+  dataset:
+    transforms:
+      ops:
+        - {type: RandomPhotometricDistort, p: 0.5}
+        - {type: RandomZoomOut, fill: 0}
+        - {type: RandomIoUCrop, p: 0.8}
+        - {type: SanitizeBoundingBoxes, min_size: 1}
+        - {type: RandomHorizontalFlip}
+        - {type: Resize, size: [1280, 1280], }
+        - {type: SanitizeBoundingBoxes, min_size: 1}
+        - {type: ConvertPILImage, dtype: 'float32', scale: True}
+        - {type: ConvertBoxes, fmt: 'cxcywh', normalize: True}
+      policy:
+        epoch: 45
+  collate_fn:
+    scales: ~
+
+  total_batch_size: 6
+
+val_dataloader:
+  dataset:
+    transforms:
+      ops:
+        - {type: Resize, size: [1280, 1280]}
+        - {type: ConvertPILImage, dtype: 'float32', scale: True}
+  total_batch_size: 12
\ No newline at end of file
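Note (not part of the patch): the two leaf configs rely on `__include__` plus deep-merge override semantics — keys in the leaf file win over the included r50vd defaults. Assumed behaviour of YAMLConfig, sketched:

    def deep_merge(dst, src):
        for k, v in src.items():
            if isinstance(v, dict) and isinstance(dst.get(k), dict):
                deep_merge(dst[k], v)
            else:
                dst[k] = v
        return dst

    base = {'PResNet': {'depth': 50, 'freeze_at': 0, 'pretrained': True}}
    leaf = {'PResNet': {'depth': 18, 'freeze_at': -1}}
    print(deep_merge(base, leaf))
    # {'PResNet': {'depth': 18, 'freeze_at': -1, 'pretrained': True}}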
diff --git a/rtdetrv2_pytorch/references/deploy/rtdetrv2_onnxruntime.py b/rtdetrv2_pytorch/references/deploy/rtdetrv2_onnxruntime.py
index 0f94dd29..fe5e16b1 100644
--- a/rtdetrv2_pytorch/references/deploy/rtdetrv2_onnxruntime.py
+++ b/rtdetrv2_pytorch/references/deploy/rtdetrv2_onnxruntime.py
@@ -9,7 +9,7 @@ from PIL import Image, ImageDraw
 
 
-def draw(images, labels, boxes, scores, thrh = 0.6):
+def draw(images, labels, boxes, scores, thrh = 0.7):
     for i, im in enumerate(images):
         draw = ImageDraw.Draw(im)
 
@@ -17,9 +17,10 @@ def draw(images, labels, boxes, scores, thrh = 0.6):
         lab = labels[i][scr > thrh]
         box = boxes[i][scr > thrh]
 
-        for b in box:
+        for b, l, s in zip(box, lab, scr[scr > thrh]):  # scores filtered to match box/lab
             draw.rectangle(list(b), outline='red',)
-            draw.text((b[0], b[1]), text=str(lab[i].item()), fill='blue', )
+            draw.text((b[0], b[1]), text=str(np.round(l, 2)), fill='blue', )
+            draw.text((b[2], b[3]), text=str(np.round(s, 2)), fill='black', )
 
         im.save(f'results_{i}.jpg')
 
@@ -35,7 +36,7 @@ def main(args, ):
     orig_size = torch.tensor([w, h])[None]
 
     transforms = T.Compose([
-        T.Resize((640, 640)),
+        T.Resize((1280, 1280)),
         T.ToTensor(),
     ])
     im_data = transforms(im_pil)[None]
diff --git a/rtdetrv2_pytorch/requirements.txt b/rtdetrv2_pytorch/requirements.txt
index f1be6a6e..2f6f53c5 100644
--- a/rtdetrv2_pytorch/requirements.txt
+++ b/rtdetrv2_pytorch/requirements.txt
@@ -1,5 +1,12 @@
-torch>=2.0.1
-torchvision>=0.15.2
+torch==2.4.*
+torchvision==0.19.*
 pycocotools
 PyYAML
-tensorboard
\ No newline at end of file
+tensorboard
+scipy
+
+# ONNX deployment and inference
+onnx
+onnxsim
+onnxruntime-gpu
+tqdm
\ No newline at end of file
diff --git a/rtdetrv2_pytorch/src/data/transforms/_transforms.py b/rtdetrv2_pytorch/src/data/transforms/_transforms.py
index 53840c30..787d64f7 100644
--- a/rtdetrv2_pytorch/src/data/transforms/_transforms.py
+++ b/rtdetrv2_pytorch/src/data/transforms/_transforms.py
@@ -13,6 +13,10 @@
 import PIL
 import PIL.Image
 
+# Suppress DecompressionBombWarning
+# caused by RandomZoomOut augmentation on 4K images
+PIL.Image.MAX_IMAGE_PIXELS = None
+
 from typing import Any, Dict, List, Optional
 
 from .._misc import convert_to_tv_tensor, _boxes_keys
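Note (not part of the patch): `MAX_IMAGE_PIXELS = None` disables PIL's decompression-bomb guard entirely — the warning and the hard error alike. If only the warning on large-but-trusted training images is the problem, a narrower alternative is:

    import warnings
    import PIL.Image

    # keep the hard error for absurd sizes; silence only the warning
    warnings.simplefilter('ignore', PIL.Image.DecompressionBombWarning)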
diff --git a/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py b/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py
index dcac0df2..97a8949c 100644
--- a/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py
+++ b/rtdetrv2_pytorch/src/zoo/rtdetr/rtdetr_postprocessor.py
@@ -52,23 +52,54 @@ def forward(self, outputs, orig_target_sizes: torch.Tensor):
         bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
         bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)
 
-        if self.use_focal_loss:
-            scores = F.sigmoid(logits)
-            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
-            # TODO for older tensorrt
-            # labels = index % self.num_classes
-            labels = mod(index, self.num_classes)
-            index = index // self.num_classes
-            boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]))
+        # Classes are 1-indexed (hence num_classes = actual classes + 1); a trained
+        # model always assigns the unused 0-th index a near-zero score
+        if not self.training:
+
+            # Ignore the 0-th index (see comment above)
+            scores = F.sigmoid(logits[:, :, 1:])
+            # Integer division below can yield duplicate box indices
+            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
+            index = index // (self.num_classes - 1)
+
+            # Keep each unique box index once, with the max score over its duplicates
+            batch_size = logits.shape[0]
+            max_scores = torch.zeros(scores.shape, device=scores.device, dtype=scores.dtype)
+            padded_unique_indices = torch.zeros(scores.shape, device=scores.device, dtype=torch.long)
+            for b in range(batch_size):
+                unique_box_indices, inverse_indices = torch.unique(index[b], dim=0, return_inverse=True)
+                max_per_box = torch.scatter_reduce(torch.zeros(self.num_top_queries, device=scores.device, dtype=scores.dtype),
+                                                   dim=0, index=inverse_indices, src=scores[b], reduce='amax')
+                max_scores[b] = max_per_box
+                padded_unique_indices[b, :len(unique_box_indices)] = unique_box_indices
+
+            scores = max_scores
+            index = padded_unique_indices
+
+            # Probability of each class; tile by the class count, not boxes.shape[-1]
+            # (the two only coincide when num_classes - 1 == 4)
+            soft_labels = F.softmax(logits[:, :, 1:], dim=-1)
+            labels = soft_labels.gather(dim=1, index=index.unsqueeze(-1).tile(1, 1, soft_labels.shape[-1]))
+            boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).tile(1, 1, bbox_pred.shape[-1]))
+
         else:
-            scores = F.softmax(logits)[:, :, :-1]
-            scores, labels = scores.max(dim=-1)
-            if scores.shape[1] > self.num_top_queries:
-                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
-                labels = torch.gather(labels, dim=1, index=index)
-                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
-
+            if self.use_focal_loss:
+                scores = F.sigmoid(logits)
+                scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
+                # TODO for older tensorrt
+                # labels = index % self.num_classes
+                labels = mod(index, self.num_classes) # never selects class 0: its score is near zero for a trained model
+                index = index // self.num_classes
+                boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]))
+
+            else:
+                scores = F.softmax(logits)[:, :, :-1]
+                scores, labels = scores.max(dim=-1)
+                if scores.shape[1] > self.num_top_queries:
+                    scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
+                    labels = torch.gather(labels, dim=1, index=index)
+                    boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))
+
         # TODO for onnx export
         if self.deploy_mode:
             return labels, boxes, scores
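Note (not part of the patch): a toy check of the de-duplication step above — two of the three top-k hits point at box 3, and `scatter_reduce(..., reduce='amax')` keeps the higher score per unique box index:

    import torch

    index = torch.tensor([3, 1, 3])
    scores = torch.tensor([0.9, 0.5, 0.7])
    uniq, inv = torch.unique(index, return_inverse=True)
    best = torch.zeros(3).scatter_reduce(0, inv, scores, reduce='amax')
    print(uniq, best[:len(uniq)])  # tensor([1, 3]) tensor([0.5000, 0.9000])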
diff --git a/rtdetrv2_pytorch/test_images/1729581571653708530.jpg b/rtdetrv2_pytorch/test_images/1729581571653708530.jpg
new file mode 100755
index 00000000..85ea0816
Binary files /dev/null and b/rtdetrv2_pytorch/test_images/1729581571653708530.jpg differ
diff --git a/rtdetrv2_pytorch/test_images/53173655170_ecd2e5796a_o.jpg b/rtdetrv2_pytorch/test_images/53173655170_ecd2e5796a_o.jpg
new file mode 100755
index 00000000..60af66e7
Binary files /dev/null and b/rtdetrv2_pytorch/test_images/53173655170_ecd2e5796a_o.jpg differ
diff --git a/rtdetrv2_pytorch/test_images/frame_0017-6.png b/rtdetrv2_pytorch/test_images/frame_0017-6.png
new file mode 100755
index 00000000..1e5b486a
Binary files /dev/null and b/rtdetrv2_pytorch/test_images/frame_0017-6.png differ
diff --git a/rtdetrv2_pytorch/test_images/raymarine_1080x1920.jpg b/rtdetrv2_pytorch/test_images/raymarine_1080x1920.jpg
new file mode 100644
index 00000000..43867c02
Binary files /dev/null and b/rtdetrv2_pytorch/test_images/raymarine_1080x1920.jpg differ
diff --git a/rtdetrv2_pytorch/test_images/whale_oi_frame_000887.jpg b/rtdetrv2_pytorch/test_images/whale_oi_frame_000887.jpg
new file mode 100644
index 00000000..1180bceb
Binary files /dev/null and b/rtdetrv2_pytorch/test_images/whale_oi_frame_000887.jpg differ
diff --git a/rtdetrv2_pytorch/tools/export_onnx.py b/rtdetrv2_pytorch/tools/export_onnx.py
index 0df6a606..c3a4dcb6 100644
--- a/rtdetrv2_pytorch/tools/export_onnx.py
+++ b/rtdetrv2_pytorch/tools/export_onnx.py
@@ -43,8 +43,8 @@ def forward(self, images, orig_target_sizes):
 
     model = Model()
 
-    data = torch.rand(1, 3, 640, 640)
-    size = torch.tensor([[640, 640]])
+    data = torch.rand(1, 3, 1280, 1280)
+    size = torch.tensor([[1280, 1280]])
     _ = model(data, size)
 
     dynamic_axes = {
@@ -74,10 +74,10 @@ def forward(self, images, orig_target_sizes):
         import onnx
         import onnxsim
         dynamic = True
-        # input_shapes = {'images': [1, 3, 640, 640], 'orig_target_sizes': [1, 2]} if dynamic else None
+        # input_shapes = {'images': [1, 3, 1280, 1280], 'orig_target_sizes': [1, 2]} if dynamic else None
         input_shapes = {'images': data.shape, 'orig_target_sizes': size.shape} if dynamic else None
-        onnx_model_simplify, check = onnxsim.simplify(args.file_name, input_shapes=input_shapes, dynamic_input_shape=dynamic)
-        onnx.save(onnx_model_simplify, args.file_name)
+        onnx_model_simplify, check = onnxsim.simplify(args.output_file, input_shapes=input_shapes, dynamic_input_shape=dynamic)
+        onnx.save(onnx_model_simplify, args.output_file)
         print(f'Simplify onnx model {check}...')
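Note (not part of the patch): a quick sanity check of the exported graph, assuming the input names used elsewhere in this diff ('images', 'orig_target_sizes') and a model.onnx in the working directory:

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
    outputs = sess.run(None, {
        'images': np.random.rand(1, 3, 1280, 1280).astype(np.float32),
        'orig_target_sizes': np.array([[1280, 1280]], dtype=np.int64),
    })
    print([o.shape for o in outputs])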
diff --git a/rtdetrv2_pytorch/tools/predict_onnx.py b/rtdetrv2_pytorch/tools/predict_onnx.py
new file mode 100644
index 00000000..ee09e0c3
--- /dev/null
+++ b/rtdetrv2_pytorch/tools/predict_onnx.py
@@ -0,0 +1,108 @@
+import os
+import sys
+sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))
+
+import torch
+import onnxruntime as ort
+from PIL import Image, ImageDraw, ImageFont
+from torchvision.transforms import ToTensor
+import argparse
+import time
+from pathlib import Path
+
+def read_img(img_path):
+    im = Image.open(img_path).convert('RGB')
+    im = im.resize((1280, 1280))
+    im_data = ToTensor()(im)[None]
+    # (width, height) = im.size
+    # print(im_data.shape)
+    # print(width, height)
+    # size = torch.tensor([[width, height]])
+    size = torch.tensor([[1280, 1280]])
+    return im, im_data, size
+
+def createDirectory(directory):
+    try:
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+    except OSError:
+        print("Error: Failed to create the directory.")
+
+
+def main(args, ):
+
+    print("ort.get_device()", ort.get_device())
+    providers = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}), "CPUExecutionProvider"]
+    sess_options = ort.SessionOptions()
+    sess_options.enable_profiling = True
+    sess = ort.InferenceSession(args.model, sess_options=sess_options, providers=providers)
+
+    img_path_list = []
+    possible_img_extension = ['.jpg', '.jpeg', '.JPG', '.bmp', '.png'] # candidate image extensions
+    for (root, dirs, files) in os.walk(args.img):
+        if len(files) > 0:
+            for file_name in files:
+                if os.path.splitext(file_name)[1] in possible_img_extension:
+                    img_path = root + '/' + file_name
+                    img_path_list.append(img_path)
+
+    all_inf_time = []
+    for img_path in img_path_list:
+        im, im_data, size = read_img(img_path)
+
+        tic = time.time()
+        output = sess.run(
+            # output_names=['labels', 'boxes', 'scores'],
+            output_names=None,
+            input_feed={'images': im_data.data.numpy(), "orig_target_sizes": size.data.numpy()}
+        )
+        inf_time = time.time() - tic
+        fps = float(1/inf_time)
+        print('img_path: {}, inf_time: {:.4f}, FPS: {:.2f}'.format(img_path, inf_time, fps))
+        all_inf_time.append(inf_time)
+
+        #print(type(output))
+        #print([out.shape for out in output])
+
+        labels, boxes, scores = output
+
+        draw = ImageDraw.Draw(im)  # Draw on the original image
+        thrh = 0.6
+
+        for i in range(im_data.shape[0]):
+
+            scr = scores[i]
+            lab = labels[i][scr > thrh]
+            box = boxes[i][scr > thrh]
+
+            #print(i, sum(scr > thrh))
+
+            for b, l in zip(box, lab):  # pair each box with its own label
+                draw.rectangle(list(b), outline='red',)
+                # font = ImageFont.truetype("Arial.ttf", 15)
+                draw.text((b[0], b[1]), text=str(l), fill='yellow', )
+
+        # Save the original image with bounding boxes
+        file_dir = Path(img_path).parent.parent / 'output'
+        createDirectory(file_dir)
+        new_file_name = os.path.basename(img_path).split('.')[0] + '_onnx' + os.path.splitext(img_path)[1]
+        new_file_path = file_dir / new_file_name
+        print('new_file_path: ', new_file_path)
+        print("================================================================================")
+        im.save(new_file_path)
+
+    avr_time = sum(all_inf_time) / len(img_path_list)
+    avr_fps = float(1/avr_time)
+    print('All images count: {}'.format(len(img_path_list)))
+    print("Average Inference time = {:.4f} s".format(avr_time))
+    print("Average FPS = {:.2f} ".format(avr_fps))
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--img', '-i', type=str, ) # dir
+    parser.add_argument('--model', '-m', type=str, default='model.onnx')
+
+    args = parser.parse_args()
+
+    main(args)
\ No newline at end of file
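Note (not part of the patch): usage, per the argparse block above — `python tools/predict_onnx.py -i <image_dir> -m model.onnx`; annotated copies are written to an `output/` directory one level above each input image.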
diff --git a/rtdetrv2_pytorch/tools/predict_torch.py b/rtdetrv2_pytorch/tools/predict_torch.py
new file mode 100644
index 00000000..c13028ed
--- /dev/null
+++ b/rtdetrv2_pytorch/tools/predict_torch.py
@@ -0,0 +1,188 @@
+import numpy as np
+import argparse
+from pathlib import Path
+import sys
+import time
+from tqdm import tqdm
+import cv2
+import json
+import pickle
+
+sys.path.append('/home/ros/RT-DETR')
+from src.core import YAMLConfig
+
+import torch
+from torch import nn
+from PIL import Image
+from torchvision import transforms
+
+class ImageReader:
+    def __init__(self, resize=1280, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
+        self.transform = transforms.Compose([
+            transforms.Resize((resize, resize)),
+            transforms.ToTensor(),
+        ])
+        self.resize = resize
+        self.pil_img = None
+
+    def __call__(self, image_path, *args, **kwargs):
+        self.pil_img = Image.open(image_path).convert('RGB')
+        return self.transform(self.pil_img).unsqueeze(0)
+
+
+class Model(nn.Module):
+    def __init__(self, config=None, ckpt="") -> None:
+        super().__init__()
+        self.cfg = YAMLConfig(config, resume=ckpt)
+        if ckpt:
+            checkpoint = torch.load(ckpt, map_location='cpu')
+            if 'ema' in checkpoint:
+                state = checkpoint['ema']['module']
+            else:
+                state = checkpoint['model']
+        else:
+            raise AttributeError('only support resume to load model.state_dict by now.')
+
+        # NOTE load train mode state -> convert to deploy mode
+        self.cfg.model.load_state_dict(state)
+
+        self.model = self.cfg.model.deploy()
+        self.postprocessor = self.cfg.postprocessor.deploy()
+
+    def forward(self, images, orig_target_sizes):
+        outputs = self.model(images)
+        return self.postprocessor(outputs, orig_target_sizes)
+
+def inference(image_paths, id2cat):
+    device = torch.device(args.device)
+    reader = ImageReader()
+    model = Model(config=args.config, ckpt=args.ckpt)
+    model.to(device=device)
+    torch.manual_seed(21)
+
+    inf_times = []
+    outputs = []
+    for img_path in tqdm(image_paths, total=len(image_paths)):
+        # print(f'inference for: {img_path}')
+
+        img = reader(img_path).to(device)
+        w, h = reader.pil_img.size
+        size = torch.tensor([[w, h]]).to(device)
+
+        with torch.no_grad():
+            start_time = time.time()
+            output = model(img, size)
+            labels, boxes, scores = output # (Batch, Preds)
+
+            # Batch size = 1
+            scr = scores[0].cpu().numpy()
+            if len(labels.shape) > 2:
+                lab = labels[0].cpu().numpy()
+            else:
+                lab = np.array([id2cat[l] for l in labels[0].cpu().numpy()])
+            box = boxes[0].cpu().numpy()
+
+            outputs.append({"scores": scr, "labels": lab, "boxes": box})
+            inf_times.append(time.time() - start_time)
+
+    fps = 1/np.mean(inf_times)
+    print(f"Inference time = {np.mean(inf_times):0.3f} s")
+    print(f"FPS = {fps:0.2f} ")
+
+    return outputs, inf_times
+
+def draw_boxes(image_paths, outputs, score_th=0.1):
+
+    for img_path, output in zip(image_paths, outputs):
+        im = cv2.imread(str(img_path))
+
+        for s, l, b in zip(output["scores"], output["labels"], output["boxes"]):
+            if s < score_th:
+                continue
+            b = list(map(int, b)) # Convert box coordinates to integers
+            cv2.rectangle(im, (b[0], b[1]), (b[2], b[3]), color=(0, 0, 255), thickness=2)
+
+            # Scale text size and thickness
+            font_scale = max(0.4, min(im.shape[1], im.shape[0]) / 1000) # Dynamic scaling
+            font_thickness = max(1, int(font_scale))
+
+            # Add label and score text (labels were already mapped to names in inference())
+            if isinstance(l, np.ndarray):
+                # label_text = str([np.round(p, 2) for p in l.tolist()])
+                label_text = str(np.round(s, 2))
+            else:
+                label_text = f"{l}: {s:.2f}"
+            text_size = cv2.getTextSize(label_text, fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=font_scale, thickness=font_thickness)[0]
+            text_origin = (b[0], b[1] - 10 if b[1] - 10 > 10 else b[1] + 10 + text_size[1])
+
+            cv2.putText(im, label_text, text_origin, fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=font_scale, color=(0, 0, 255), thickness=font_thickness)
+
+        save_path = Path("vis_results") / Path(img_path).name
+        save_path.parent.mkdir(exist_ok=True, parents=True)
+        cv2.imwrite(str(save_path), im)
+
+def read_coco(coco_json_path: str, data_dir: str):
+    with open(coco_json_path, 'r') as f:
+        coco_data = json.load(f)
+
+    image_paths = []
+    for image_entry in coco_data.get('images', []):
+        file_name = image_entry.get('file_name')
+        if file_name:
+            image_path = Path(data_dir) / file_name
+
+            assert image_path.exists(), (f"Image does not exist: {image_path}")
+            image_paths.append(str(image_path))
+
+    id2cat = {cat['id']: cat['name'] for cat in coco_data.get('categories')}
+    return image_paths, id2cat
+
+def save_preds(image_paths, outputs, data_dir, save_path):
+    results = {}
+    for img_path, output in zip(image_paths, outputs):
+        rel_img_path = Path(img_path).resolve().relative_to(data_dir)
+        results[str(rel_img_path)] = output
+
+    with open(save_path, 'wb') as f:
+        pickle.dump(results, f)
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", '-c', type=str)
+    parser.add_argument("--ckpt", '-w', type=str) # pth
+    parser.add_argument("--images", '-i', type=str, help='folder of images', default=None)
+    parser.add_argument("--coco", type=str, help='inference on the image files from a coco json', default=None)
+    parser.add_argument("--device", default="cuda:0")
+
+    # TODO: Implement this
+    parser.add_argument("--batch-size", type=int, default=1)
+
+    # Save preds
+    parser.add_argument("--data-dir", default="/home/ros/RT-DETR", help="Path to data folder")
+    parser.add_argument("--save_path", default="/home/ros/RT-DETR/output2.pkl", help="Path to output file")
+
+    args = parser.parse_args()
+
+    if args.coco is not None:
+        image_paths, id2cat = read_coco(args.coco, args.data_dir)
+    elif args.images is not None:
+        image_paths = [p for p in Path(args.images).glob('*') if p.suffix in [".jpg", ".png"]]
+        id2cat = {
+            1: "marine_mammal",
+            2: "marker",
+            3: "unknown",
+            4: "vessel",
+        }
+    else:
+        raise ValueError("Specify either a coco json path or an image directory")
+
+    outputs, inf_times = inference(image_paths, id2cat)
+
+    # save_preds(image_paths, outputs, args.data_dir, args.save_path)
+
+    # draw cv2 boxes
+    draw_boxes(image_paths, outputs)
+
\ No newline at end of file
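Note (not part of the patch): usage, per the argparse block above — either `python tools/predict_torch.py -c <config> -w <checkpoint.pth> --coco <annotations.json> --data-dir <dataset_root>` or `python tools/predict_torch.py -c <config> -w <checkpoint.pth> -i <image_folder>`; visualisations are written to `vis_results/`.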