diff --git a/README.md b/README.md
index 7425b05..8f8787c 100644
--- a/README.md
+++ b/README.md
@@ -10,19 +10,27 @@ for coco's AP metrics, especially when dealing with a high number of instances i
 ### Comparison
-For our use case with a test dataset of 1500 images that contains up to 2000 instances per image we saw up to a 100x faster
-evaluation using faster-coco-eval (FCE) compared to the original pycocotools code.
-````
-Seg eval pycocotools 4 hours
-Seg eval FCE: 2.5 min
+For our use case we benchmarked on a test dataset of 5000 images from the COCO val2017 set.
+Testing was carried out with the mmdetection framework and the `eval_metric.py` script; the results are presented below.
+
+A visualization of the test run, **comparison.ipynb**, is available in the directory [examples/comparison](./examples/comparison/comparison.ipynb).
+Tested with a YOLOv3 model (bbox eval) and a YOLACT model (segm eval).
+
+Type | COCOeval | COCOeval_faster | Speedup
+-----|----------|-----------------|--------
+bbox | 22.854 sec. | 8.714 sec. | more than 2x
+segm | 35.356 sec. | 18.403 sec. | 2x
 
-BBox eval pycocotools: 4 hours
-BBox eval FCE: 2 min
-````
 # Getting started
 
-### Install
+## Local build
+Build from source:
+```bash
+python3 setup.py sdist
+```
+
+## Install
 Install from source
 ```bash
 pip3 install git+https://github.com/MiXaiLL76/faster_coco_eval
@@ -55,7 +63,7 @@ For usage, look at the original `COCOEval` [class documentation.](https://github
 - [x] Append unittest
 - [x] Append ROC / AUC curves
 - [x] Check if it works on windows
-- [ ] Append fp fn error analysis 
+- [ ] Append fp fn error analysis
 
 # License
 The original module was licensed with apache 2, I will continue with the same license.
diff --git a/examples/comparison/README.md b/examples/comparison/README.md
new file mode 100644
index 0000000..ab4920e
--- /dev/null
+++ b/examples/comparison/README.md
@@ -0,0 +1,6 @@
+Visualization of the test run: [comparison.ipynb](./comparison.ipynb)
+
+Type | COCOeval | COCOeval_faster | Speedup
+-----|----------|-----------------|--------
+bbox | 22.854 sec. | 8.714 sec. | more than 2x
+segm | 35.356 sec. | 18.403 sec. | 2x
\ No newline at end of file
diff --git a/examples/comparison/coco_fast.py b/examples/comparison/coco_fast.py
new file mode 100644
index 0000000..04dbcbc
--- /dev/null
+++ b/examples/comparison/coco_fast.py
@@ -0,0 +1,225 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import contextlib
+import io
+import itertools
+import logging
+import os.path as osp
+import tempfile
+import warnings
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+
+from mmdet.core import eval_recalls
+from mmdet.datasets.builder import DATASETS
+from mmdet.datasets.coco import CocoDataset
+from faster_coco_eval import COCOeval_faster
+
+@DATASETS.register_module()
+class FasterCocoDataset(CocoDataset):
+    def evaluate_det_segm(self,
+                          results,
+                          result_files,
+                          coco_gt,
+                          metrics,
+                          logger=None,
+                          classwise=False,
+                          proposal_nums=(100, 300, 1000),
+                          iou_thrs=None,
+                          metric_items=None):
+        """Instance segmentation and object detection evaluation in COCO
+        protocol.
+        Args:
+            results (list[list | tuple | dict]): Testing results of the
+                dataset.
+            result_files (dict[str, str]): a dict contains json file path.
+            coco_gt (COCO): COCO API object with ground truth annotation.
+            metric (str | list[str]): Metrics to be evaluated. Options are
+                'bbox', 'segm', 'proposal', 'proposal_fast'.
+ logger (logging.Logger | str | None): Logger used for printing + related information during evaluation. Default: None. + classwise (bool): Whether to evaluating the AP for each class. + proposal_nums (Sequence[int]): Proposal number used for evaluating + recalls, such as recall@100, recall@1000. + Default: (100, 300, 1000). + iou_thrs (Sequence[float], optional): IoU threshold used for + evaluating recalls/mAPs. If set to a list, the average of all + IoUs will also be computed. If not specified, [0.50, 0.55, + 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used. + Default: None. + metric_items (list[str] | str, optional): Metric items that will + be returned. If not specified, ``['AR@100', 'AR@300', + 'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be + used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75', + 'mAP_s', 'mAP_m', 'mAP_l']`` will be used when + ``metric=='bbox' or metric=='segm'``. + Returns: + dict[str, float]: COCO style evaluation metric. + """ + if iou_thrs is None: + iou_thrs = np.linspace( + .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True) + if metric_items is not None: + if not isinstance(metric_items, list): + metric_items = [metric_items] + + eval_results = OrderedDict() + for metric in metrics: + msg = f'Evaluating {metric}...' + if logger is None: + msg = '\n' + msg + print_log(msg, logger=logger) + + if metric == 'proposal_fast': + if isinstance(results[0], tuple): + raise KeyError('proposal_fast is not supported for ' + 'instance segmentation result.') + ar = self.fast_eval_recall( + results, proposal_nums, iou_thrs, logger='silent') + log_msg = [] + for i, num in enumerate(proposal_nums): + eval_results[f'AR@{num}'] = ar[i] + log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}') + log_msg = ''.join(log_msg) + print_log(log_msg, logger=logger) + continue + + iou_type = 'bbox' if metric == 'proposal' else metric + if metric not in result_files: + raise KeyError(f'{metric} is not in results') + try: + predictions = mmcv.load(result_files[metric]) + if iou_type == 'segm': + # Refer to https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/coco.py#L331 # noqa + # When evaluating mask AP, if the results contain bbox, + # cocoapi will use the box area instead of the mask area + # for calculating the instance area. Though the overall AP + # is not affected, this leads to different + # small/medium/large mask AP results. + for x in predictions: + x.pop('bbox') + warnings.simplefilter('once') + warnings.warn( + 'The key "bbox" is deleted for more accurate mask AP ' + 'of small/medium/large instances since v2.12.0. 
This ' + 'does not change the overall mAP calculation.', + UserWarning) + coco_det = coco_gt.loadRes(predictions) + except IndexError: + print_log( + 'The testing results of the whole dataset is empty.', + logger=logger, + level=logging.ERROR) + break + + cocoEval = COCOeval_faster(coco_gt, coco_det, iou_type) + cocoEval.params.catIds = self.cat_ids + cocoEval.params.imgIds = self.img_ids + cocoEval.params.maxDets = list(proposal_nums) + cocoEval.params.iouThrs = iou_thrs + # mapping of cocoEval.stats + coco_metric_names = { + 'mAP': 0, + 'mAP_50': 1, + 'mAP_75': 2, + 'mAP_s': 3, + 'mAP_m': 4, + 'mAP_l': 5, + 'AR@100': 6, + 'AR@300': 7, + 'AR@1000': 8, + 'AR_s@1000': 9, + 'AR_m@1000': 10, + 'AR_l@1000': 11 + } + if metric_items is not None: + for metric_item in metric_items: + if metric_item not in coco_metric_names: + raise KeyError( + f'metric item {metric_item} is not supported') + + if metric == 'proposal': + cocoEval.params.useCats = 0 + cocoEval.evaluate() + cocoEval.accumulate() + + # Save coco summarize print information to logger + redirect_string = io.StringIO() + with contextlib.redirect_stdout(redirect_string): + cocoEval.summarize() + print_log('\n' + redirect_string.getvalue(), logger=logger) + + if metric_items is None: + metric_items = [ + 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000', + 'AR_m@1000', 'AR_l@1000' + ] + + for item in metric_items: + val = float( + f'{cocoEval.stats[coco_metric_names[item]]:.3f}') + eval_results[item] = val + else: + cocoEval.evaluate() + cocoEval.accumulate() + + # Save coco summarize print information to logger + redirect_string = io.StringIO() + with contextlib.redirect_stdout(redirect_string): + cocoEval.summarize() + print_log('\n' + redirect_string.getvalue(), logger=logger) + + if classwise: # Compute per-category AP + # Compute per-category AP + # from https://github.com/facebookresearch/detectron2/ + precisions = cocoEval.eval['precision'] + # precision: (iou, recall, cls, area range, max dets) + assert len(self.cat_ids) == precisions.shape[2] + + results_per_category = [] + for idx, catId in enumerate(self.cat_ids): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + nm = self.coco.loadCats(catId)[0] + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + if precision.size: + ap = np.mean(precision) + else: + ap = float('nan') + results_per_category.append( + (f'{nm["name"]}', f'{float(ap):0.3f}')) + + num_columns = min(6, len(results_per_category) * 2) + results_flatten = list( + itertools.chain(*results_per_category)) + headers = ['category', 'AP'] * (num_columns // 2) + results_2d = itertools.zip_longest(*[ + results_flatten[i::num_columns] + for i in range(num_columns) + ]) + table_data = [headers] + table_data += [result for result in results_2d] + table = AsciiTable(table_data) + print_log('\n' + table.table, logger=logger) + + if metric_items is None: + metric_items = [ + 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l' + ] + + for metric_item in metric_items: + key = f'{metric}_{metric_item}' + val = float( + f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}' + ) + eval_results[key] = val + ap = cocoEval.stats[:6] + eval_results[f'{metric}_mAP_copypaste'] = ( + f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} ' + f'{ap[4]:.3f} {ap[5]:.3f}') + + return eval_results \ No newline at end of file diff --git a/examples/comparison/comparison.ipynb b/examples/comparison/comparison.ipynb new file mode 100644 index 0000000..2491ab6 --- /dev/null +++ 
b/examples/comparison/comparison.ipynb @@ -0,0 +1,397 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e6c1a344-fdad-47b3-bd20-f1f90ffff2c9", + "metadata": {}, + "source": [ + "https://mmdetection.readthedocs.io/en/latest/tutorials/test_results_submission.html" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "880b2f8f-8d5a-4d34-8cdd-6cd95495d635", + "metadata": {}, + "outputs": [], + "source": [ + "!pip3 install git+https://github.com/MiXaiLL76/faster_coco_eval" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9662b48e-ebd2-4cb3-a13c-1298d989937c", + "metadata": {}, + "outputs": [], + "source": [ + "!mkdir -pv data/coco/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac7b6b1f-4059-43f6-a6d5-64363fb5d289", + "metadata": {}, + "outputs": [], + "source": [ + "!wget -P data/coco/ http://images.cocodataset.org/annotations/annotations_trainval2017.zip\n", + "!wget -P data/coco/ http://images.cocodataset.org/zips/val2017.zip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbcc1e4f-069c-4ad2-9c26-864bf111016a", + "metadata": {}, + "outputs": [], + "source": [ + "!unzip data/coco/annotations_trainval2017.zip -d data/coco/\n", + "!unzip data/coco/val2017.zip -d data/coco/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a1403ee-eff8-4e80-9e9a-69ba74d736f3", + "metadata": {}, + "outputs": [], + "source": [ + "!rm -rf data/coco/*.zip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54ea5c5a-d294-45b2-a858-560050d9ecf8", + "metadata": {}, + "outputs": [], + "source": [ + "yolo3_model_path = \"https://download.openmmlab.com/mmdetection/v2.0/yolo/yolov3_d53_320_273e_coco/yolov3_d53_320_273e_coco-421362b6.pth\"\n", + "!wget -P model {yolo3_model_path}" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e67a6ecb-d868-408e-8e52-661bd25df496", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cfg_path='configs/yolo/yolov3_d53_320_273e_coco.py'\n" + ] + } + ], + "source": [ + "import os.path as osp\n", + "\n", + "_BASE_CONFIG_DIR = \"configs/\"\n", + "CONFIG_FILE = \"yolo/yolov3_d53_320_273e_coco.py\"\n", + "CHECKPOINT_FILE = \"model/yolov3_d53_320_273e_coco-421362b6.pth\"\n", + "WORK_DIR = \".\"\n", + "use_cpu = False\n", + "\n", + "cfg_path = osp.join(_BASE_CONFIG_DIR, CONFIG_FILE)\n", + "print(f\"{cfg_path=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e37d2b3c-72bb-4586-b379-2f1281978f05", + "metadata": {}, + "outputs": [], + "source": [ + "_dop = \"\"\n", + "if use_cpu:\n", + " _dop += f\" --gpu-id -1 \"\n", + "\n", + "!python3 test.py \\\n", + " {cfg_path} \\\n", + " {CHECKPOINT_FILE} \\\n", + " --format-only {_dop}\\\n", + " --cfg-options data.test.ann_file=data/coco/annotations/instances_val2017.json \\\n", + " data.test.img_prefix=data/coco/val2017 \\\n", + " --out yolo_result.pkl" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4eca90de-fd3d-4dc7-b3f4-ada1300d0b90", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading annotations into memory...\n", + "Done (t=0.28s)\n", + "creating index...\n", + "index created!\n", + "Data uploaded for 1.057 sec.\n", + "\n", + "Evaluating bbox...\n", + "Loading and preparing results...\n", + "DONE (t=0.47s)\n", + "creating index...\n", + "index created!\n", + "Running per image evaluation...\n", + "Evaluate annotation type *bbox*\n", + "DONE 
(t=16.69s).\n", + "Accumulating evaluation results...\n", + "DONE (t=3.02s).\n", + "\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.279\n", + " Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 0.491\n", + " Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.283\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.105\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.301\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.438\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.395\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.395\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.395\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.185\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.423\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.574\n", + "\n", + "OrderedDict([('bbox_mAP', 0.279), ('bbox_mAP_50', 0.491), ('bbox_mAP_75', 0.283), ('bbox_mAP_s', 0.105), ('bbox_mAP_m', 0.301), ('bbox_mAP_l', 0.438), ('bbox_mAP_copypaste', '0.279 0.491 0.283 0.105 0.301 0.438')])\n", + "Data validate for 22.854 sec.\n", + "CPU times: user 160 ms, sys: 22.4 ms, total: 183 ms\n", + "Wall time: 26 s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "!python3 eval_metric.py {cfg_path} yolo_result.pkl \\\n", + " --eval bbox \\\n", + " --cfg-options data.test.ann_file=data/coco/annotations/instances_val2017.json" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8bda9c97-ecce-4ae2-ab1d-c36680fb3a70", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading annotations into memory...\n", + "Done (t=0.28s)\n", + "creating index...\n", + "index created!\n", + "Data uploaded for 1.065 sec.\n", + "\n", + "Evaluating bbox...\n", + "Loading and preparing results...\n", + "DONE (t=0.47s)\n", + "creating index...\n", + "index created!\n", + "\n", + "\n", + "OrderedDict([('bbox_mAP', 0.279), ('bbox_mAP_50', 0.491), ('bbox_mAP_75', 0.283), ('bbox_mAP_s', 0.105), ('bbox_mAP_m', 0.301), ('bbox_mAP_l', 0.438), ('bbox_mAP_copypaste', '0.279 0.491 0.283 0.105 0.301 0.438')])\n", + "Data validate for 8.714 sec.\n", + "CPU times: user 91.8 ms, sys: 5.82 ms, total: 97.6 ms\n", + "Wall time: 11.8 s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "!python3 eval_metric.py {cfg_path} yolo_result.pkl \\\n", + " --eval bbox \\\n", + " --cfg-options data.test.ann_file=data/coco/annotations/instances_val2017.json data.test.type='FasterCocoDataset'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6ce86d4-f61e-4c18-821c-676d66a10d08", + "metadata": {}, + "outputs": [], + "source": [ + "yoloact_model_path = \"https://download.openmmlab.com/mmdetection/v2.0/yolact/yolact_r50_1x8_coco/yolact_r50_1x8_coco_20200908-f38d58df.pth\"\n", + "!wget -P model {yoloact_model_path}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "5685b448-6079-4f0a-a781-a3260b59e579", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "cfg_path='configs/yolact/yolact_r50_1x8_coco.py'\n" + ] + } + ], + "source": [ + "CONFIG_FILE = \"yolact/yolact_r50_1x8_coco.py\"\n", + "CHECKPOINT_FILE = \"model/yolact_r50_1x8_coco_20200908-f38d58df.pth\"\n", + "\n", + "cfg_path = 
osp.join(_BASE_CONFIG_DIR, CONFIG_FILE)\n", + "print(f\"{cfg_path=}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0129cb16-65ac-4192-b933-9d38fc4d5aee", + "metadata": {}, + "outputs": [], + "source": [ + "_dop = \"\"\n", + "if use_cpu:\n", + " _dop += f\" --gpu-id -1 \"\n", + "\n", + "!python3 test.py \\\n", + " {cfg_path} \\\n", + " {CHECKPOINT_FILE} \\\n", + " --format-only {_dop}\\\n", + " --cfg-options data.test.ann_file=data/coco/annotations/instances_val2017.json \\\n", + " data.test.img_prefix=data/coco/val2017 \\\n", + " --out yoloact_result.pkl" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "91c15c82-b82a-46c2-8171-f517ab677529", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading annotations into memory...\n", + "Done (t=0.28s)\n", + "creating index...\n", + "index created!\n", + "Data uploaded for 2.235 sec.\n", + "\n", + "Evaluating segm...\n", + "/opt/conda/lib/python3.9/site-packages/mmdet/datasets/coco.py:470: UserWarning: The key \"bbox\" is deleted for more accurate mask AP of small/medium/large instances since v2.12.0. This does not change the overall mAP calculation.\n", + " warnings.warn(\n", + "Loading and preparing results...\n", + "DONE (t=1.21s)\n", + "creating index...\n", + "index created!\n", + "Running per image evaluation...\n", + "Evaluate annotation type *segm*\n", + "DONE (t=22.02s).\n", + "Accumulating evaluation results...\n", + "/opt/conda/lib/python3.9/site-packages/pycocotools/cocoeval.py:378: DeprecationWarning: `np.float` is a deprecated alias for the builtin `float`. To silence this warning, use `float` by itself. Doing this will not modify any behavior and is safe. If you specifically wanted the numpy scalar type, use `np.float64` here.\n", + "Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float)\n", + "DONE (t=3.49s).\n", + "\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.290\n", + " Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 0.486\n", + " Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.296\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.100\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.315\n", + " Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.465\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.392\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.392\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.392\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.176\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.439\n", + " Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.568\n", + "\n", + "OrderedDict([('segm_mAP', 0.29), ('segm_mAP_50', 0.486), ('segm_mAP_75', 0.296), ('segm_mAP_s', 0.1), ('segm_mAP_m', 0.315), ('segm_mAP_l', 0.465), ('segm_mAP_copypaste', '0.290 0.486 0.296 0.100 0.315 0.465')])\n", + "Data validate for 35.356 sec.\n", + "CPU times: user 235 ms, sys: 22 ms, total: 257 ms\n", + "Wall time: 39.6 s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "!python3 eval_metric.py {cfg_path} yoloact_result.pkl \\\n", + " --eval segm \\\n", + " --cfg-options 
data.test.ann_file=data/coco/annotations/instances_val2017.json" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "312e736f-7108-4788-a47f-8b9bba26e10c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loading annotations into memory...\n", + "Done (t=0.28s)\n", + "creating index...\n", + "index created!\n", + "Data uploaded for 2.305 sec.\n", + "\n", + "Evaluating segm...\n", + "/home/rdl/storage/eval/coco_fast.py:105: UserWarning: The key \"bbox\" is deleted for more accurate mask AP of small/medium/large instances since v2.12.0. This does not change the overall mAP calculation.\n", + " warnings.warn(\n", + "Loading and preparing results...\n", + "DONE (t=1.25s)\n", + "creating index...\n", + "index created!\n", + "\n", + "\n", + "OrderedDict([('segm_mAP', 0.29), ('segm_mAP_50', 0.486), ('segm_mAP_75', 0.296), ('segm_mAP_s', 0.1), ('segm_mAP_m', 0.315), ('segm_mAP_l', 0.465), ('segm_mAP_copypaste', '0.290 0.486 0.296 0.100 0.315 0.465')])\n", + "Data validate for 18.403 sec.\n", + "CPU times: user 103 ms, sys: 52.8 ms, total: 156 ms\n", + "Wall time: 22.7 s\n" + ] + } + ], + "source": [ + "%%time\n", + "\n", + "!python3 eval_metric.py {cfg_path} yoloact_result.pkl \\\n", + " --eval segm \\\n", + " --cfg-options data.test.ann_file=data/coco/annotations/instances_val2017.json data.test.type='FasterCocoDataset'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e85b179-87cc-4c7d-999a-a501ff9bb1d2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/comparison/eval_metric.py b/examples/comparison/eval_metric.py new file mode 100644 index 0000000..4d5c3d6 --- /dev/null +++ b/examples/comparison/eval_metric.py @@ -0,0 +1,96 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import mmcv +from mmcv import Config, DictAction + +from mmdet.datasets import build_dataset +from mmdet.utils import replace_cfg_vals, update_data_root +import coco_fast +import time + +def parse_args(): + parser = argparse.ArgumentParser(description='Evaluate metric of the ' + 'results saved in pkl format') + parser.add_argument('config', help='Config of the model') + parser.add_argument('pkl_results', help='Results in pickle format') + parser.add_argument( + '--format-only', + action='store_true', + help='Format the output results without perform evaluation. It is' + 'useful when you want to format the result to a specific format and ' + 'submit it to the test server') + parser.add_argument( + '--eval', + type=str, + nargs='+', + help='Evaluation metrics, which depends on the dataset, e.g., "bbox",' + ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--eval-options', + nargs='+', + action=DictAction, + help='custom options for evaluation, the key-value pair in xxx=yyy ' + 'format will be kwargs for dataset.evaluate() function') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + cfg = Config.fromfile(args.config) + + # replace the ${key} with the value of cfg.key + cfg = replace_cfg_vals(cfg) + + # update data root according to MMDET_DATASETS + update_data_root(cfg) + + assert args.eval or args.format_only, ( + 'Please specify at least one operation (eval/format the results) with ' + 'the argument "--eval", "--format-only"') + if args.eval and args.format_only: + raise ValueError('--eval and --format_only cannot be both specified') + + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + cfg.data.test.test_mode = True + + ts = time.time() + dataset = build_dataset(cfg.data.test) + outputs = mmcv.load(args.pkl_results) + te = time.time() + print(f"Data uploaded for {te - ts:.3f} sec.") + + kwargs = {} if args.eval_options is None else args.eval_options + if args.format_only: + dataset.format_results(outputs, **kwargs) + + if args.eval: + ts = time.time() + eval_kwargs = cfg.get('evaluation', {}).copy() + # hard-code way to remove EvalHook args + for key in [ + 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', + 'rule' + ]: + eval_kwargs.pop(key, None) + eval_kwargs.update(dict(metric=args.eval, **kwargs)) + print(dataset.evaluate(outputs, **eval_kwargs)) + + te = time.time() + print(f"Data validate for {te - ts:.3f} sec.") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/comparison/test.py b/examples/comparison/test.py new file mode 100644 index 0000000..e076735 --- /dev/null +++ b/examples/comparison/test.py @@ -0,0 +1,283 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
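+# Adapted from mmdetection's tools/test.py: runs inference with a detection model
+# on the configured test set and can save the raw results to a .pkl file
+# (comparison.ipynb uses it to produce yolo_result.pkl and yoloact_result.pkl).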
+import argparse +import os +import os.path as osp +import time +import warnings + +import mmcv +import torch +from mmcv import Config, DictAction +from mmcv.cnn import fuse_conv_bn +from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, + wrap_fp16_model) + +from mmdet.apis import multi_gpu_test, single_gpu_test +from mmdet.datasets import (build_dataloader, build_dataset, + replace_ImageToTensor) +from mmdet.models import build_detector +from mmdet.utils import (build_ddp, build_dp, compat_cfg, get_device, + replace_cfg_vals, setup_multi_processes, + update_data_root) + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMDet test (and eval) a model') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument( + '--work-dir', + help='the directory to save the file containing evaluation metrics') + parser.add_argument('--out', help='output result file in pickle format') + parser.add_argument( + '--fuse-conv-bn', + action='store_true', + help='Whether to fuse conv and bn, this will slightly increase' + 'the inference speed') + parser.add_argument( + '--gpu-ids', + type=int, + nargs='+', + help='(Deprecated, please use --gpu-id) ids of gpus to use ' + '(only applicable to non-distributed training)') + parser.add_argument( + '--gpu-id', + type=int, + default=0, + help='id of gpu to use ' + '(only applicable to non-distributed testing)') + parser.add_argument( + '--format-only', + action='store_true', + help='Format the output results without perform evaluation. It is' + 'useful when you want to format the result to a specific format and ' + 'submit it to the test server') + parser.add_argument( + '--eval', + type=str, + nargs='+', + help='evaluation metrics, which depends on the dataset, e.g., "bbox",' + ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC') + parser.add_argument('--show', action='store_true', help='show results') + parser.add_argument( + '--show-dir', help='directory where painted images will be saved') + parser.add_argument( + '--show-score-thr', + type=float, + default=0.3, + help='score threshold (default: 0.3)') + parser.add_argument( + '--gpu-collect', + action='store_true', + help='whether to use gpu to collect results.') + parser.add_argument( + '--tmpdir', + help='tmp directory used for collecting results from multiple ' + 'workers, available when gpu-collect is not specified') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--options', + nargs='+', + action=DictAction, + help='custom options for evaluation, the key-value pair in xxx=yyy ' + 'format will be kwargs for dataset.evaluate() function (deprecate), ' + 'change to --eval-options instead.') + parser.add_argument( + '--eval-options', + nargs='+', + action=DictAction, + help='custom options for evaluation, the key-value pair in xxx=yyy ' + 'format will be kwargs for dataset.evaluate() function') + parser.add_argument( + '--launcher', + choices=['none', 'pytorch', 'slurm', 'mpi'], + default='none', + help='job launcher') + parser.add_argument('--local_rank', type=int, default=0) + args = parser.parse_args() + if 'LOCAL_RANK' not in os.environ: + os.environ['LOCAL_RANK'] = str(args.local_rank) + + if args.options and args.eval_options: + raise ValueError( + '--options and --eval-options cannot be both ' + 'specified, --options is deprecated in favor of --eval-options') + if args.options: + warnings.warn('--options is deprecated in favor of --eval-options') + args.eval_options = args.options + return args + + +def main(): + args = parse_args() + + assert args.out or args.eval or args.format_only or args.show \ + or args.show_dir, \ + ('Please specify at least one operation (save/eval/format/show the ' + 'results / save the results) with the argument "--out", "--eval"' + ', "--format-only", "--show" or "--show-dir"') + + if args.eval and args.format_only: + raise ValueError('--eval and --format_only cannot be both specified') + + if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): + raise ValueError('The output file must be a pkl file.') + + cfg = Config.fromfile(args.config) + + # replace the ${key} with the value of cfg.key + cfg = replace_cfg_vals(cfg) + + # update data root according to MMDET_DATASETS + update_data_root(cfg) + + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + cfg = compat_cfg(cfg) + + # set multi-process settings + setup_multi_processes(cfg) + + # set cudnn_benchmark + if cfg.get('cudnn_benchmark', False): + torch.backends.cudnn.benchmark = True + + if 'pretrained' in cfg.model: + cfg.model.pretrained = None + elif 'init_cfg' in cfg.model.backbone: + cfg.model.backbone.init_cfg = None + + if cfg.model.get('neck'): + if isinstance(cfg.model.neck, list): + for neck_cfg in cfg.model.neck: + if neck_cfg.get('rfp_backbone'): + if neck_cfg.rfp_backbone.get('pretrained'): + neck_cfg.rfp_backbone.pretrained = None + elif cfg.model.neck.get('rfp_backbone'): + if cfg.model.neck.rfp_backbone.get('pretrained'): + cfg.model.neck.rfp_backbone.pretrained = None + + if args.gpu_ids is not None: + cfg.gpu_ids = args.gpu_ids[0:1] + warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. ' + 'Because we only support single GPU mode in ' + 'non-distributed testing. Use the first GPU ' + 'in `gpu_ids` now.') + else: + cfg.gpu_ids = [args.gpu_id] + + if cfg.gpu_ids == [-1]: + cfg.device = "cpu" + else: + cfg.device = get_device() + + # init distributed env first, since logger depends on the dist info. 
+ if args.launcher == 'none': + distributed = False + else: + distributed = True + init_dist(args.launcher, **cfg.dist_params) + + + test_dataloader_default_args = dict( + samples_per_gpu=1, workers_per_gpu=2, dist=distributed, shuffle=False) + + # in case the test dataset is concatenated + if isinstance(cfg.data.test, dict): + cfg.data.test.test_mode = True + if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1: + # Replace 'ImageToTensor' to 'DefaultFormatBundle' + cfg.data.test.pipeline = replace_ImageToTensor( + cfg.data.test.pipeline) + elif isinstance(cfg.data.test, list): + for ds_cfg in cfg.data.test: + ds_cfg.test_mode = True + if cfg.data.test_dataloader.get('samples_per_gpu', 1) > 1: + for ds_cfg in cfg.data.test: + ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) + + test_loader_cfg = { + **test_dataloader_default_args, + **cfg.data.get('test_dataloader', {}) + } + + rank, _ = get_dist_info() + # allows not to create + if args.work_dir is not None and rank == 0: + mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + json_file = osp.join(args.work_dir, f'eval_{timestamp}.json') + + # build the dataloader + dataset = build_dataset(cfg.data.test) + data_loader = build_dataloader(dataset, **test_loader_cfg) + + # build the model and load checkpoint + cfg.model.train_cfg = None + model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) + fp16_cfg = cfg.get('fp16', None) + if fp16_cfg is not None: + print(f"{fp16_cfg=}") + wrap_fp16_model(model) + + checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') + if args.fuse_conv_bn: + model = fuse_conv_bn(model) + # old versions did not save class info in checkpoints, this walkaround is + # for backward compatibility + if 'CLASSES' in checkpoint.get('meta', {}): + model.CLASSES = checkpoint['meta']['CLASSES'] + else: + model.CLASSES = dataset.CLASSES + + if not distributed: + model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids) + outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, + args.show_score_thr) + else: + model = build_ddp( + model, + cfg.device, + device_ids=[int(os.environ['LOCAL_RANK'])], + broadcast_buffers=False) + outputs = multi_gpu_test( + model, data_loader, args.tmpdir, args.gpu_collect + or cfg.evaluation.get('gpu_collect', False)) + + rank, _ = get_dist_info() + if rank == 0: + if args.out: + print(f'\nwriting results to {args.out}') + mmcv.dump(outputs, args.out) + kwargs = {} if args.eval_options is None else args.eval_options + if args.format_only: + dataset.format_results(outputs, **kwargs) + if args.eval: + eval_kwargs = cfg.get('evaluation', {}).copy() + # hard-code way to remove EvalHook args + for key in [ + 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', + 'rule', 'dynamic_intervals' + ]: + eval_kwargs.pop(key, None) + eval_kwargs.update(dict(metric=args.eval, **kwargs)) + metric = dataset.evaluate(outputs, **eval_kwargs) + print(metric) + metric_dict = dict(config=args.config, metric=metric) + if args.work_dir is not None and rank == 0: + mmcv.dump(metric_dict, json_file) + + +if __name__ == '__main__': + main() \ No newline at end of file
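
For quick reference, below is a minimal usage sketch of `COCOeval_faster` as a drop-in for pycocotools' `COCOeval`, mirroring the calls made in `coco_fast.py` above. The annotation and result file paths are placeholders, and the ground truth is assumed to be loadable with a pycocotools-style `COCO` object:

```python
# Minimal usage sketch: COCOeval_faster driven with the same calls as in coco_fast.py.
from pycocotools.coco import COCO
from faster_coco_eval import COCOeval_faster

# Placeholder paths: COCO-format ground truth and a COCO-format detection result file.
coco_gt = COCO("data/coco/annotations/instances_val2017.json")
coco_dt = coco_gt.loadRes("detections.bbox.json")

coco_eval = COCOeval_faster(coco_gt, coco_dt, "bbox")  # iou_type: 'bbox' or 'segm'
coco_eval.evaluate()    # per-image matching
coco_eval.accumulate()  # accumulate per-category / per-area results
coco_eval.summarize()   # fills coco_eval.stats with the standard COCO metrics

# stats[:6] == [mAP, mAP_50, mAP_75, mAP_s, mAP_m, mAP_l], as mapped in coco_fast.py
print(coco_eval.stats[:6])
```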