diff --git a/docs/detection/double_detection_filter.md b/docs/detection/double_detection_filter.md new file mode 100644 index 000000000..1631852f4 --- /dev/null +++ b/docs/detection/double_detection_filter.md @@ -0,0 +1,30 @@ +--- +comments: true +status: new +--- + +# Double Detection Filter + +
+

OverlapFilter

+
+ +:::supervision.detection.overlap_filter.OverlapFilter + +
+

box_non_max_suppression

+
+ +:::supervision.detection.overlap_filter.box_non_max_suppression + +
+

mask_non_max_suppression

+
+ +:::supervision.detection.overlap_filter.mask_non_max_suppression + +
+

box_non_max_merge

+
+ +:::supervision.detection.overlap_filter.box_non_max_merge diff --git a/docs/detection/utils.md b/docs/detection/utils.md index f9c9473bc..369746a3e 100644 --- a/docs/detection/utils.md +++ b/docs/detection/utils.md @@ -17,18 +17,6 @@ status: new :::supervision.detection.utils.mask_iou_batch -
-

box_non_max_suppression

-
- -:::supervision.detection.utils.box_non_max_suppression - -
-

mask_non_max_suppression

-
- -:::supervision.detection.utils.mask_non_max_suppression -

polygon_to_mask

diff --git a/mkdocs.yml b/mkdocs.yml index f257238df..19d6a4fd4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -48,6 +48,7 @@ nav: - Core: detection/core.md - Annotators: detection/annotators.md - Metrics: detection/metrics.md + - Double Detection Filter: detection/double_detection_filter.md - Utils: detection/utils.md - Keypoint Detection: - Core: keypoint/core.md diff --git a/supervision/__init__.py b/supervision/__init__.py index 1a46226fd..4f28d49f4 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -40,6 +40,12 @@ from supervision.detection.core import Detections from supervision.detection.line_zone import LineZone, LineZoneAnnotator from supervision.detection.lmm import LMM +from supervision.detection.overlap_filter import ( + OverlapFilter, + box_non_max_merge, + box_non_max_suppression, + mask_non_max_suppression, +) from supervision.detection.tools.csv_sink import CSVSink from supervision.detection.tools.inference_slicer import InferenceSlicer from supervision.detection.tools.json_sink import JSONSink @@ -47,15 +53,12 @@ from supervision.detection.tools.smoother import DetectionsSmoother from supervision.detection.utils import ( box_iou_batch, - box_non_max_merge, - box_non_max_suppression, calculate_masks_centroids, clip_boxes, contains_holes, contains_multiple_segments, filter_polygons_by_area, mask_iou_batch, - mask_non_max_suppression, mask_to_polygons, mask_to_xyxy, move_boxes, diff --git a/supervision/detection/core.py b/supervision/detection/core.py index f93aed1c4..a1239d969 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -8,15 +8,17 @@ from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES from supervision.detection.lmm import LMM, from_paligemma, validate_lmm_and_kwargs -from supervision.detection.utils import ( - box_iou_batch, +from supervision.detection.overlap_filter import ( box_non_max_merge, box_non_max_suppression, + mask_non_max_suppression, +) +from 
supervision.detection.utils import ( + box_iou_batch, calculate_masks_centroids, extract_ultralytics_masks, get_data_item, is_data_equal, - mask_non_max_suppression, mask_to_xyxy, merge_data, process_roboflow_result, diff --git a/supervision/detection/overlap_filter.py b/supervision/detection/overlap_filter.py new file mode 100644 index 000000000..ab4408d18 --- /dev/null +++ b/supervision/detection/overlap_filter.py @@ -0,0 +1,263 @@ +from enum import Enum +from typing import List, Union + +import numpy as np +import numpy.typing as npt + +from supervision.detection.utils import box_iou_batch, mask_iou_batch + + +def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray: + """ + Resize all masks in the array to have a maximum dimension of max_dimension, + maintaining aspect ratio. + + Args: + masks (np.ndarray): 3D array of binary masks with shape (N, H, W). + max_dimension (int): The maximum dimension for the resized masks. + + Returns: + np.ndarray: Array of resized masks. + """ + max_height = np.max(masks.shape[1]) + max_width = np.max(masks.shape[2]) + scale = min(max_dimension / max_height, max_dimension / max_width) + + new_height = int(scale * max_height) + new_width = int(scale * max_width) + + x = np.linspace(0, max_width - 1, new_width).astype(int) + y = np.linspace(0, max_height - 1, new_height).astype(int) + xv, yv = np.meshgrid(x, y) + + resized_masks = masks[:, yv, xv] + + resized_masks = resized_masks.reshape(masks.shape[0], new_height, new_width) + return resized_masks + + +def mask_non_max_suppression( + predictions: np.ndarray, + masks: np.ndarray, + iou_threshold: float = 0.5, + mask_dimension: int = 640, +) -> np.ndarray: + """ + Perform Non-Maximum Suppression (NMS) on segmentation predictions. + + Args: + predictions (np.ndarray): A 2D array of object detection predictions in + the format of `(x_min, y_min, x_max, y_max, score)` + or `(x_min, y_min, x_max, y_max, score, class)`. 
Shape: `(N, 5)` or + `(N, 6)`, where N is the number of predictions. + masks (np.ndarray): A 3D array of binary masks corresponding to the predictions. + Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the + dimensions of each mask. + iou_threshold (float, optional): The intersection-over-union threshold + to use for non-maximum suppression. + mask_dimension (int, optional): The dimension to which the masks should be + resized before computing IOU values. Defaults to 640. + + Returns: + np.ndarray: A boolean array indicating which predictions to keep after + non-maximum suppression. + + Raises: + AssertionError: If `iou_threshold` is not within the closed + range from `0` to `1`. + """ + assert 0 <= iou_threshold <= 1, ( + "Value of `iou_threshold` must be in the closed range from 0 to 1, " + f"{iou_threshold} given." + ) + rows, columns = predictions.shape + + if columns == 5: + predictions = np.c_[predictions, np.zeros(rows)] + + sort_index = predictions[:, 4].argsort()[::-1] + predictions = predictions[sort_index] + masks = masks[sort_index] + masks_resized = resize_masks(masks, mask_dimension) + ious = mask_iou_batch(masks_resized, masks_resized) + categories = predictions[:, 5] + + keep = np.ones(rows, dtype=bool) + for i in range(rows): + if keep[i]: + condition = (ious[i] > iou_threshold) & (categories[i] == categories) + keep[i + 1 :] = np.where(condition[i + 1 :], False, keep[i + 1 :]) + + return keep[sort_index.argsort()] + + +def box_non_max_suppression( + predictions: np.ndarray, iou_threshold: float = 0.5 +) -> np.ndarray: + """ + Perform Non-Maximum Suppression (NMS) on object detection predictions. + + Args: + predictions (np.ndarray): An array of object detection predictions in + the format of `(x_min, y_min, x_max, y_max, score)` + or `(x_min, y_min, x_max, y_max, score, class)`. + iou_threshold (float, optional): The intersection-over-union threshold + to use for non-maximum suppression. 
+ + Returns: + np.ndarray: A boolean array indicating which predictions to keep after + non-maximum suppression. + + Raises: + AssertionError: If `iou_threshold` is not within the + closed range from `0` to `1`. + """ + assert 0 <= iou_threshold <= 1, ( + "Value of `iou_threshold` must be in the closed range from 0 to 1, " + f"{iou_threshold} given." + ) + rows, columns = predictions.shape + + # add column #5 - category filled with zeros for agnostic nms + if columns == 5: + predictions = np.c_[predictions, np.zeros(rows)] + + # sort predictions column #4 - score + sort_index = np.flip(predictions[:, 4].argsort()) + predictions = predictions[sort_index] + + boxes = predictions[:, :4] + categories = predictions[:, 5] + ious = box_iou_batch(boxes, boxes) + ious = ious - np.eye(rows) + + keep = np.ones(rows, dtype=bool) + + for index, (iou, category) in enumerate(zip(ious, categories)): + if not keep[index]: + continue + + # drop detections with iou > iou_threshold and + # same category as current detections + condition = (iou > iou_threshold) & (categories == category) + keep = keep & ~condition + + return keep[sort_index.argsort()] + + +def group_overlapping_boxes( + predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5 +) -> List[List[int]]: + """ + Apply greedy version of non-maximum merging to avoid detecting too many + overlapping bounding boxes for a given object. + + Args: + predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` containing + the bounding boxes coordinates in format `[x1, y1, x2, y2]` + and the confidence scores. + iou_threshold (float, optional): The intersection-over-union threshold + to use for non-maximum merging. Defaults to 0.5. + + Returns: + List[List[int]]: Groups of prediction indices to be merged. + Each group may have 1 or more elements. 
+ """ + merge_groups: List[List[int]] = [] + + scores = predictions[:, 4] + order = scores.argsort() + + while len(order) > 0: + idx = int(order[-1]) + + order = order[:-1] + if len(order) == 0: + merge_groups.append([idx]) + break + + merge_candidate = np.expand_dims(predictions[idx], axis=0) + ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4]) + ious = ious.flatten() + + above_threshold = ious >= iou_threshold + merge_group = [idx] + np.flip(order[above_threshold]).tolist() + merge_groups.append(merge_group) + order = order[~above_threshold] + return merge_groups + + +def box_non_max_merge( + predictions: npt.NDArray[np.float64], + iou_threshold: float = 0.5, +) -> List[List[int]]: + """ + Apply greedy version of non-maximum merging per category to avoid detecting + too many overlapping bounding boxes for a given object. + + Args: + predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` or `(n, 6)` + containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`, + the confidence scores and class_ids. Omit class_id column to allow + detections of different classes to be merged. + iou_threshold (float, optional): The intersection-over-union threshold + to use for non-maximum merging. Defaults to 0.5. + + Returns: + List[List[int]]: Groups of prediction indices to be merged. + Each group may have 1 or more elements. 
+ """ + if predictions.shape[1] == 5: + return group_overlapping_boxes(predictions, iou_threshold) + + category_ids = predictions[:, 5] + merge_groups = [] + for category_id in np.unique(category_ids): + curr_indices = np.where(category_ids == category_id)[0] + merge_class_groups = group_overlapping_boxes( + predictions[curr_indices], iou_threshold + ) + + for merge_class_group in merge_class_groups: + merge_groups.append(curr_indices[merge_class_group].tolist()) + + for merge_group in merge_groups: + if len(merge_group) == 0: + raise ValueError( + f"Empty group detected when non-max-merging " + f"detections: {merge_groups}" + ) + return merge_groups + + +class OverlapFilter(Enum): + """ + Enum specifying the strategy for filtering overlapping detections. + + Attributes: + NONE: Do not filter detections based on overlap. + NON_MAX_SUPPRESSION: Filter detections using non-max suppression. This means, + detections that overlap by more than a set threshold will be discarded, + except for the one with the highest confidence. + NON_MAX_MERGE: Merge detections with non-max merging. This means, + detections that overlap by more than a set threshold will be merged + into a single detection. + """ + + NONE = "none" + NON_MAX_SUPPRESSION = "non_max_suppression" + NON_MAX_MERGE = "non_max_merge" + + +def validate_overlap_filter( + strategy: Union[OverlapFilter, str], +) -> OverlapFilter: + if isinstance(strategy, str): + try: + strategy = OverlapFilter(strategy.lower()) + except ValueError: + raise ValueError( + f"Invalid strategy value: {strategy}. 
Must be one of " + f"{[e.value for e in OverlapFilter]}" + ) + return strategy diff --git a/supervision/detection/tools/inference_slicer.py b/supervision/detection/tools/inference_slicer.py index 82551434e..134361bd3 100644 --- a/supervision/detection/tools/inference_slicer.py +++ b/supervision/detection/tools/inference_slicer.py @@ -1,11 +1,14 @@ +import warnings from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Callable, Optional, Tuple +from typing import Callable, Optional, Tuple, Union import numpy as np from supervision.detection.core import Detections +from supervision.detection.overlap_filter import OverlapFilter, validate_overlap_filter from supervision.detection.utils import move_boxes, move_masks from supervision.utils.image import crop_image +from supervision.utils.internal import SupervisionWarnings def move_detections( @@ -50,8 +53,10 @@ class InferenceSlicer: `(width, height)`. overlap_ratio_wh (Tuple[float, float]): Overlap ratio between consecutive slices in the format `(width_ratio, height_ratio)`. - iou_threshold (Optional[float]): Intersection over Union (IoU) threshold - used for non-max suppression. + overlap_filter_strategy (Union[OverlapFilter, str]): Strategy for + filtering or merging overlapping detections in slices. + iou_threshold (float): Intersection over Union (IoU) threshold + used when filtering by overlap. callback (Callable): A function that performs inference on a given image slice and returns detections. thread_workers (int): Number of threads for parallel execution. 
@@ -68,12 +73,18 @@ def __init__( callback: Callable[[np.ndarray], Detections], slice_wh: Tuple[int, int] = (320, 320), overlap_ratio_wh: Tuple[float, float] = (0.2, 0.2), - iou_threshold: Optional[float] = 0.5, + overlap_filter_strategy: Union[ + OverlapFilter, str + ] = OverlapFilter.NON_MAX_SUPPRESSION, + iou_threshold: float = 0.5, thread_workers: int = 1, ): + overlap_filter_strategy = validate_overlap_filter(overlap_filter_strategy) + self.slice_wh = slice_wh self.overlap_ratio_wh = overlap_ratio_wh self.iou_threshold = iou_threshold + self.overlap_filter_strategy = overlap_filter_strategy self.callback = callback self.thread_workers = thread_workers @@ -104,7 +115,10 @@ def callback(image_slice: np.ndarray) -> sv.Detections: result = model(image_slice)[0] return sv.Detections.from_ultralytics(result) - slicer = sv.InferenceSlicer(callback = callback) + slicer = sv.InferenceSlicer( + callback=callback, + overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION, + ) detections = slicer(image) ``` @@ -124,9 +138,19 @@ def callback(image_slice: np.ndarray) -> sv.Detections: for future in as_completed(futures): detections_list.append(future.result()) - return Detections.merge(detections_list=detections_list).with_nms( - threshold=self.iou_threshold - ) + merged = Detections.merge(detections_list=detections_list) + if self.overlap_filter_strategy == OverlapFilter.NONE: + return merged + elif self.overlap_filter_strategy == OverlapFilter.NON_MAX_SUPPRESSION: + return merged.with_nms(threshold=self.iou_threshold) + elif self.overlap_filter_strategy == OverlapFilter.NON_MAX_MERGE: + return merged.with_nmm(threshold=self.iou_threshold) + else: + warnings.warn( + f"Invalid overlap filter strategy: {self.overlap_filter_strategy}", + category=SupervisionWarnings, + ) + return merged def _run_callback(self, image, offset) -> Detections: """ diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py index 1ca487916..b36b6853f 100644 --- 
a/supervision/detection/utils.py +++ b/supervision/detection/utils.py @@ -139,229 +139,6 @@ def mask_iou_batch( return np.vstack(ious) -def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray: - """ - Resize all masks in the array to have a maximum dimension of max_dimension, - maintaining aspect ratio. - - Args: - masks (np.ndarray): 3D array of binary masks with shape (N, H, W). - max_dimension (int): The maximum dimension for the resized masks. - - Returns: - np.ndarray: Array of resized masks. - """ - max_height = np.max(masks.shape[1]) - max_width = np.max(masks.shape[2]) - scale = min(max_dimension / max_height, max_dimension / max_width) - - new_height = int(scale * max_height) - new_width = int(scale * max_width) - - x = np.linspace(0, max_width - 1, new_width).astype(int) - y = np.linspace(0, max_height - 1, new_height).astype(int) - xv, yv = np.meshgrid(x, y) - - resized_masks = masks[:, yv, xv] - - resized_masks = resized_masks.reshape(masks.shape[0], new_height, new_width) - return resized_masks - - -def mask_non_max_suppression( - predictions: np.ndarray, - masks: np.ndarray, - iou_threshold: float = 0.5, - mask_dimension: int = 640, -) -> np.ndarray: - """ - Perform Non-Maximum Suppression (NMS) on segmentation predictions. - - Args: - predictions (np.ndarray): A 2D array of object detection predictions in - the format of `(x_min, y_min, x_max, y_max, score)` - or `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or - `(N, 6)`, where N is the number of predictions. - masks (np.ndarray): A 3D array of binary masks corresponding to the predictions. - Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the - dimensions of each mask. - iou_threshold (float, optional): The intersection-over-union threshold - to use for non-maximum suppression. - mask_dimension (int, optional): The dimension to which the masks should be - resized before computing IOU values. Defaults to 640. 
- - Returns: - np.ndarray: A boolean array indicating which predictions to keep after - non-maximum suppression. - - Raises: - AssertionError: If `iou_threshold` is not within the closed - range from `0` to `1`. - """ - assert 0 <= iou_threshold <= 1, ( - "Value of `iou_threshold` must be in the closed range from 0 to 1, " - f"{iou_threshold} given." - ) - rows, columns = predictions.shape - - if columns == 5: - predictions = np.c_[predictions, np.zeros(rows)] - - sort_index = predictions[:, 4].argsort()[::-1] - predictions = predictions[sort_index] - masks = masks[sort_index] - masks_resized = resize_masks(masks, mask_dimension) - ious = mask_iou_batch(masks_resized, masks_resized) - categories = predictions[:, 5] - - keep = np.ones(rows, dtype=bool) - for i in range(rows): - if keep[i]: - condition = (ious[i] > iou_threshold) & (categories[i] == categories) - keep[i + 1 :] = np.where(condition[i + 1 :], False, keep[i + 1 :]) - - return keep[sort_index.argsort()] - - -def box_non_max_suppression( - predictions: np.ndarray, iou_threshold: float = 0.5 -) -> np.ndarray: - """ - Perform Non-Maximum Suppression (NMS) on object detection predictions. - - Args: - predictions (np.ndarray): An array of object detection predictions in - the format of `(x_min, y_min, x_max, y_max, score)` - or `(x_min, y_min, x_max, y_max, score, class)`. - iou_threshold (float, optional): The intersection-over-union threshold - to use for non-maximum suppression. - - Returns: - np.ndarray: A boolean array indicating which predictions to keep after n - on-maximum suppression. - - Raises: - AssertionError: If `iou_threshold` is not within the - closed range from `0` to `1`. - """ - assert 0 <= iou_threshold <= 1, ( - "Value of `iou_threshold` must be in the closed range from 0 to 1, " - f"{iou_threshold} given." 
- ) - rows, columns = predictions.shape - - # add column #5 - category filled with zeros for agnostic nms - if columns == 5: - predictions = np.c_[predictions, np.zeros(rows)] - - # sort predictions column #4 - score - sort_index = np.flip(predictions[:, 4].argsort()) - predictions = predictions[sort_index] - - boxes = predictions[:, :4] - categories = predictions[:, 5] - ious = box_iou_batch(boxes, boxes) - ious = ious - np.eye(rows) - - keep = np.ones(rows, dtype=bool) - - for index, (iou, category) in enumerate(zip(ious, categories)): - if not keep[index]: - continue - - # drop detections with iou > iou_threshold and - # same category as current detections - condition = (iou > iou_threshold) & (categories == category) - keep = keep & ~condition - - return keep[sort_index.argsort()] - - -def group_overlapping_boxes( - predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5 -) -> List[List[int]]: - """ - Apply greedy version of non-maximum merging to avoid detecting too many - overlapping bounding boxes for a given object. - - Args: - predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` containing - the bounding boxes coordinates in format `[x1, y1, x2, y2]` - and the confidence scores. - iou_threshold (float, optional): The intersection-over-union threshold - to use for non-maximum suppression. Defaults to 0.5. - - Returns: - List[List[int]]: Groups of prediction indices be merged. - Each group may have 1 or more elements. 
- """ - merge_groups: List[List[int]] = [] - - scores = predictions[:, 4] - order = scores.argsort() - - while len(order) > 0: - idx = int(order[-1]) - - order = order[:-1] - if len(order) == 0: - merge_groups.append([idx]) - break - - merge_candidate = np.expand_dims(predictions[idx], axis=0) - ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4]) - ious = ious.flatten() - - above_threshold = ious >= iou_threshold - merge_group = [idx] + np.flip(order[above_threshold]).tolist() - merge_groups.append(merge_group) - order = order[~above_threshold] - return merge_groups - - -def box_non_max_merge( - predictions: npt.NDArray[np.float64], - iou_threshold: float = 0.5, -) -> List[List[int]]: - """ - Apply greedy version of non-maximum merging per category to avoid detecting - too many overlapping bounding boxes for a given object. - - Args: - predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` or `(n, 6)` - containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`, - the confidence scores and class_ids. Omit class_id column to allow - detections of different classes to be merged. - iou_threshold (float, optional): The intersection-over-union threshold - to use for non-maximum suppression. Defaults to 0.5. - - Returns: - List[List[int]]: Groups of prediction indices be merged. - Each group may have 1 or more elements. 
- """ - if predictions.shape[1] == 5: - return group_overlapping_boxes(predictions, iou_threshold) - - category_ids = predictions[:, 5] - merge_groups = [] - for category_id in np.unique(category_ids): - curr_indices = np.where(category_ids == category_id)[0] - merge_class_groups = group_overlapping_boxes( - predictions[curr_indices], iou_threshold - ) - - for merge_class_group in merge_class_groups: - merge_groups.append(curr_indices[merge_class_group].tolist()) - - for merge_group in merge_groups: - if len(merge_group) == 0: - raise ValueError( - f"Empty group detected when non-max-merging " - f"detections: {merge_groups}" - ) - return merge_groups - - def clip_boxes(xyxy: np.ndarray, resolution_wh: Tuple[int, int]) -> np.ndarray: """ Clips bounding boxes coordinates to fit within the frame resolution. diff --git a/test/detection/test_overlap_filter.py b/test/detection/test_overlap_filter.py new file mode 100644 index 000000000..f628c30f9 --- /dev/null +++ b/test/detection/test_overlap_filter.py @@ -0,0 +1,449 @@ +from contextlib import ExitStack as DoesNotRaise +from typing import List, Optional + +import numpy as np +import pytest + +from supervision.detection.overlap_filter import ( + box_non_max_suppression, + group_overlapping_boxes, + mask_non_max_suppression, +) + + +@pytest.mark.parametrize( + "predictions, iou_threshold, expected_result, exception", + [ + ( + np.empty(shape=(0, 5), dtype=float), + 0.5, + [], + DoesNotRaise(), + ), + ( + np.array([[0, 0, 10, 10, 1.0]]), + 0.5, + [[0]], + DoesNotRaise(), + ), + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]), + 0.5, + [[1, 0]], + DoesNotRaise(), + ), # High overlap, tie-break to second det + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 0.99]]), + 0.5, + [[0, 1]], + DoesNotRaise(), + ), # High overlap, merge to high confidence + ( + np.array([[0, 0, 10, 10, 0.99], [0, 0, 9, 9, 1.0]]), + 0.5, + [[1, 0]], + DoesNotRaise(), + ), # (test symmetry) High overlap, merge to high confidence + ( + 
np.array([[0, 0, 10, 10, 0.90], [0, 0, 9, 9, 1.0]]), + 0.5, + [[1, 0]], + DoesNotRaise(), + ), # (test symmetry) High overlap, merge to high confidence + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]), + 1.0, + [[1], [0]], + DoesNotRaise(), + ), # High IOU required + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]), + 0.0, + [[1, 0]], + DoesNotRaise(), + ), # No IOU required + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 5, 5, 0.9]]), + 0.25, + [[0, 1]], + DoesNotRaise(), + ), # Below IOU requirement + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 5, 5, 0.9]]), + 0.26, + [[0], [1]], + DoesNotRaise(), + ), # Above IOU requirement + ( + np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0], [0, 0, 8, 8, 1.0]]), + 0.5, + [[2, 1, 0]], + DoesNotRaise(), + ), # 3 boxes + ( + np.array( + [ + [0, 0, 10, 10, 1.0], + [0, 0, 9, 9, 1.0], + [5, 5, 10, 10, 1.0], + [6, 6, 10, 10, 1.0], + [9, 9, 10, 10, 1.0], + ] + ), + 0.5, + [[4], [3, 2], [1, 0]], + DoesNotRaise(), + ), # 5 boxes, 2 merges, 1 separate + ( + np.array( + [ + [0, 0, 2, 1, 1.0], + [1, 0, 3, 1, 1.0], + [2, 0, 4, 1, 1.0], + [3, 0, 5, 1, 1.0], + [4, 0, 6, 1, 1.0], + ] + ), + 0.33, + [[4, 3], [2, 1], [0]], + DoesNotRaise(), + ), # sequential merge, half overlap + ( + np.array( + [ + [0, 0, 2, 1, 0.9], + [1, 0, 3, 1, 0.9], + [2, 0, 4, 1, 1.0], + [3, 0, 5, 1, 0.9], + [4, 0, 6, 1, 0.9], + ] + ), + 0.33, + [[2, 3, 1], [4], [0]], + DoesNotRaise(), + ), # confidence + ], +) +def test_group_overlapping_boxes( + predictions: np.ndarray, + iou_threshold: float, + expected_result: List[List[int]], + exception: Exception, +) -> None: + with exception: + result = group_overlapping_boxes( + predictions=predictions, iou_threshold=iou_threshold + ) + + assert result == expected_result + + +@pytest.mark.parametrize( + "predictions, iou_threshold, expected_result, exception", + [ + ( + np.empty(shape=(0, 5)), + 0.5, + np.array([]), + DoesNotRaise(), + ), # single box with no category + ( + np.array([[10.0, 10.0, 40.0, 40.0, 0.8]]), + 
0.5, + np.array([True]), + DoesNotRaise(), + ), # single box with no category + ( + np.array([[10.0, 10.0, 40.0, 40.0, 0.8, 0]]), + 0.5, + np.array([True]), + DoesNotRaise(), + ), # single box with category + ( + np.array( + [ + [10.0, 10.0, 40.0, 40.0, 0.8], + [15.0, 15.0, 40.0, 40.0, 0.9], + ] + ), + 0.5, + np.array([False, True]), + DoesNotRaise(), + ), # two boxes with no category + ( + np.array( + [ + [10.0, 10.0, 40.0, 40.0, 0.8, 0], + [15.0, 15.0, 40.0, 40.0, 0.9, 1], + ] + ), + 0.5, + np.array([True, True]), + DoesNotRaise(), + ), # two boxes with different category + ( + np.array( + [ + [10.0, 10.0, 40.0, 40.0, 0.8, 0], + [15.0, 15.0, 40.0, 40.0, 0.9, 0], + ] + ), + 0.5, + np.array([False, True]), + DoesNotRaise(), + ), # two boxes with same category + ( + np.array( + [ + [0.0, 0.0, 30.0, 40.0, 0.8], + [5.0, 5.0, 35.0, 45.0, 0.9], + [10.0, 10.0, 40.0, 50.0, 0.85], + ] + ), + 0.5, + np.array([False, True, False]), + DoesNotRaise(), + ), # three boxes with no category + ( + np.array( + [ + [0.0, 0.0, 30.0, 40.0, 0.8, 0], + [5.0, 5.0, 35.0, 45.0, 0.9, 1], + [10.0, 10.0, 40.0, 50.0, 0.85, 2], + ] + ), + 0.5, + np.array([True, True, True]), + DoesNotRaise(), + ), # three boxes with different categories + ( + np.array( + [ + [0.0, 0.0, 30.0, 40.0, 0.8, 0], + [5.0, 5.0, 35.0, 45.0, 0.9, 0], + [10.0, 10.0, 40.0, 50.0, 0.85, 1], + ] + ), + 0.5, + np.array([False, True, True]), + DoesNotRaise(), + ), # three boxes, two sharing a category + ], +) +def test_box_non_max_suppression( + predictions: np.ndarray, + iou_threshold: float, + expected_result: Optional[np.ndarray], + exception: Exception, +) -> None: + with exception: + result = box_non_max_suppression( + predictions=predictions, iou_threshold=iou_threshold + ) + assert np.array_equal(result, expected_result) + + +@pytest.mark.parametrize( + "predictions, masks, iou_threshold, expected_result, exception", + [ + ( + np.empty((0, 6)), + np.empty((0, 5, 5)), + 0.5, + np.array([]), + DoesNotRaise(), + ), # empty 
predictions and masks + ( + np.array([[0, 0, 0, 0, 0.8]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ] + ] + ), + 0.5, + np.array([True]), + DoesNotRaise(), + ), # single mask with no category + ( + np.array([[0, 0, 0, 0, 0.8, 0]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ] + ] + ), + 0.5, + np.array([True]), + DoesNotRaise(), + ), # single mask with category + ( + np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, True, True], + [False, False, False, True, True], + [False, False, False, False, False], + ], + ] + ), + 0.5, + np.array([True, True]), + DoesNotRaise(), + ), # two masks non-overlapping with no category + ( + np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, False, False, False], + ], + ] + ), + 0.4, + np.array([False, True]), + DoesNotRaise(), + ), # two masks partially overlapping with no category + ( + np.array([[0, 0, 0, 0, 0.8, 0], [0, 0, 0, 0, 0.9, 1]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, 
True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, False, False, False], + ], + ] + ), + 0.5, + np.array([True, True]), + DoesNotRaise(), + ), # two masks partially overlapping with different category + ( + np.array( + [ + [0, 0, 0, 0, 0.8], + [0, 0, 0, 0, 0.85], + [0, 0, 0, 0, 0.9], + ] + ), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, True, True], + [False, False, False, True, True], + [False, False, False, False, False], + [False, False, False, False, False], + ], + ] + ), + 0.5, + np.array([False, True, True]), + DoesNotRaise(), + ), # three masks with no category + ( + np.array( + [ + [0, 0, 0, 0, 0.8, 0], + [0, 0, 0, 0, 0.85, 1], + [0, 0, 0, 0, 0.9, 2], + ] + ), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + ] + ), + 0.5, + 
np.array([True, True, True]), + DoesNotRaise(), + ), # three masks with different category + ], +) +def test_mask_non_max_suppression( + predictions: np.ndarray, + masks: np.ndarray, + iou_threshold: float, + expected_result: Optional[np.ndarray], + exception: Exception, +) -> None: + with exception: + result = mask_non_max_suppression( + predictions=predictions, masks=masks, iou_threshold=iou_threshold + ) + assert np.array_equal(result, expected_result) diff --git a/test/detection/test_utils.py b/test/detection/test_utils.py index 837b3b840..f0f0a6b13 100644 --- a/test/detection/test_utils.py +++ b/test/detection/test_utils.py @@ -7,15 +7,12 @@ from supervision.config import CLASS_NAME_DATA_FIELD from supervision.detection.utils import ( - box_non_max_suppression, calculate_masks_centroids, clip_boxes, contains_holes, contains_multiple_segments, filter_polygons_by_area, get_data_item, - group_overlapping_boxes, - mask_non_max_suppression, merge_data, move_boxes, process_roboflow_result, @@ -26,444 +23,6 @@ TEST_MASK[:, 300:351, 200:251] = True -@pytest.mark.parametrize( - "predictions, iou_threshold, expected_result, exception", - [ - ( - np.empty(shape=(0, 5)), - 0.5, - np.array([]), - DoesNotRaise(), - ), # single box with no category - ( - np.array([[10.0, 10.0, 40.0, 40.0, 0.8]]), - 0.5, - np.array([True]), - DoesNotRaise(), - ), # single box with no category - ( - np.array([[10.0, 10.0, 40.0, 40.0, 0.8, 0]]), - 0.5, - np.array([True]), - DoesNotRaise(), - ), # single box with category - ( - np.array( - [ - [10.0, 10.0, 40.0, 40.0, 0.8], - [15.0, 15.0, 40.0, 40.0, 0.9], - ] - ), - 0.5, - np.array([False, True]), - DoesNotRaise(), - ), # two boxes with no category - ( - np.array( - [ - [10.0, 10.0, 40.0, 40.0, 0.8, 0], - [15.0, 15.0, 40.0, 40.0, 0.9, 1], - ] - ), - 0.5, - np.array([True, True]), - DoesNotRaise(), - ), # two boxes with different category - ( - np.array( - [ - [10.0, 10.0, 40.0, 40.0, 0.8, 0], - [15.0, 15.0, 40.0, 40.0, 0.9, 0], - ] - ), - 0.5, 
- np.array([False, True]), - DoesNotRaise(), - ), # two boxes with same category - ( - np.array( - [ - [0.0, 0.0, 30.0, 40.0, 0.8], - [5.0, 5.0, 35.0, 45.0, 0.9], - [10.0, 10.0, 40.0, 50.0, 0.85], - ] - ), - 0.5, - np.array([False, True, False]), - DoesNotRaise(), - ), # three boxes with no category - ( - np.array( - [ - [0.0, 0.0, 30.0, 40.0, 0.8, 0], - [5.0, 5.0, 35.0, 45.0, 0.9, 1], - [10.0, 10.0, 40.0, 50.0, 0.85, 2], - ] - ), - 0.5, - np.array([True, True, True]), - DoesNotRaise(), - ), # three boxes with same category - ( - np.array( - [ - [0.0, 0.0, 30.0, 40.0, 0.8, 0], - [5.0, 5.0, 35.0, 45.0, 0.9, 0], - [10.0, 10.0, 40.0, 50.0, 0.85, 1], - ] - ), - 0.5, - np.array([False, True, True]), - DoesNotRaise(), - ), # three boxes with different category - ], -) -def test_box_non_max_suppression( - predictions: np.ndarray, - iou_threshold: float, - expected_result: Optional[np.ndarray], - exception: Exception, -) -> None: - with exception: - result = box_non_max_suppression( - predictions=predictions, iou_threshold=iou_threshold - ) - assert np.array_equal(result, expected_result) - - -@pytest.mark.parametrize( - "predictions, iou_threshold, expected_result, exception", - [ - ( - np.empty(shape=(0, 5), dtype=float), - 0.5, - [], - DoesNotRaise(), - ), - ( - np.array([[0, 0, 10, 10, 1.0]]), - 0.5, - [[0]], - DoesNotRaise(), - ), - ( - np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]), - 0.5, - [[1, 0]], - DoesNotRaise(), - ), # High overlap, tie-break to second det - ( - np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 0.99]]), - 0.5, - [[0, 1]], - DoesNotRaise(), - ), # High overlap, merge to high confidence - ( - np.array([[0, 0, 10, 10, 0.99], [0, 0, 9, 9, 1.0]]), - 0.5, - [[1, 0]], - DoesNotRaise(), - ), # (test symmetry) High overlap, merge to high confidence - ( - np.array([[0, 0, 10, 10, 0.90], [0, 0, 9, 9, 1.0]]), - 0.5, - [[1, 0]], - DoesNotRaise(), - ), # (test symmetry) High overlap, merge to high confidence - ( - np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 
1.0]]), - 1.0, - [[1], [0]], - DoesNotRaise(), - ), # High IOU required - ( - np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]), - 0.0, - [[1, 0]], - DoesNotRaise(), - ), # No IOU required - ( - np.array([[0, 0, 10, 10, 1.0], [0, 0, 5, 5, 0.9]]), - 0.25, - [[0, 1]], - DoesNotRaise(), - ), # Below IOU requirement - ( - np.array([[0, 0, 10, 10, 1.0], [0, 0, 5, 5, 0.9]]), - 0.26, - [[0], [1]], - DoesNotRaise(), - ), # Above IOU requirement - ( - np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0], [0, 0, 8, 8, 1.0]]), - 0.5, - [[2, 1, 0]], - DoesNotRaise(), - ), # 3 boxes - ( - np.array( - [ - [0, 0, 10, 10, 1.0], - [0, 0, 9, 9, 1.0], - [5, 5, 10, 10, 1.0], - [6, 6, 10, 10, 1.0], - [9, 9, 10, 10, 1.0], - ] - ), - 0.5, - [[4], [3, 2], [1, 0]], - DoesNotRaise(), - ), # 5 boxes, 2 merges, 1 separate - ( - np.array( - [ - [0, 0, 2, 1, 1.0], - [1, 0, 3, 1, 1.0], - [2, 0, 4, 1, 1.0], - [3, 0, 5, 1, 1.0], - [4, 0, 6, 1, 1.0], - ] - ), - 0.33, - [[4, 3], [2, 1], [0]], - DoesNotRaise(), - ), # sequential merge, half overlap - ( - np.array( - [ - [0, 0, 2, 1, 0.9], - [1, 0, 3, 1, 0.9], - [2, 0, 4, 1, 1.0], - [3, 0, 5, 1, 0.9], - [4, 0, 6, 1, 0.9], - ] - ), - 0.33, - [[2, 3, 1], [4], [0]], - DoesNotRaise(), - ), # confidence - ], -) -def test_group_overlapping_boxes( - predictions: np.ndarray, - iou_threshold: float, - expected_result: List[List[int]], - exception: Exception, -) -> None: - with exception: - result = group_overlapping_boxes( - predictions=predictions, iou_threshold=iou_threshold - ) - - assert result == expected_result - - -@pytest.mark.parametrize( - "predictions, masks, iou_threshold, expected_result, exception", - [ - ( - np.empty((0, 6)), - np.empty((0, 5, 5)), - 0.5, - np.array([]), - DoesNotRaise(), - ), # empty predictions and masks - ( - np.array([[0, 0, 0, 0, 0.8]]), - np.array( - [ - [ - [False, False, False, False, False], - [False, True, True, True, False], - [False, True, True, True, False], - [False, True, True, True, False], - [False, False, False, 
False, False], - ] - ] - ), - 0.5, - np.array([True]), - DoesNotRaise(), - ), # single mask with no category - ( - np.array([[0, 0, 0, 0, 0.8, 0]]), - np.array( - [ - [ - [False, False, False, False, False], - [False, True, True, True, False], - [False, True, True, True, False], - [False, True, True, True, False], - [False, False, False, False, False], - ] - ] - ), - 0.5, - np.array([True]), - DoesNotRaise(), - ), # single mask with category - ( - np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]), - np.array( - [ - [ - [False, False, False, False, False], - [False, True, True, False, False], - [False, True, True, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - ], - [ - [False, False, False, False, False], - [False, False, False, False, False], - [False, False, False, True, True], - [False, False, False, True, True], - [False, False, False, False, False], - ], - ] - ), - 0.5, - np.array([True, True]), - DoesNotRaise(), - ), # two masks non-overlapping with no category - ( - np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]), - np.array( - [ - [ - [False, False, False, False, False], - [False, True, True, True, False], - [False, True, True, True, False], - [False, True, True, True, False], - [False, False, False, False, False], - ], - [ - [False, False, False, False, False], - [False, False, True, True, True], - [False, False, True, True, True], - [False, False, True, True, True], - [False, False, False, False, False], - ], - ] - ), - 0.4, - np.array([False, True]), - DoesNotRaise(), - ), # two masks partially overlapping with no category - ( - np.array([[0, 0, 0, 0, 0.8, 0], [0, 0, 0, 0, 0.9, 1]]), - np.array( - [ - [ - [False, False, False, False, False], - [False, True, True, True, False], - [False, True, True, True, False], - [False, True, True, True, False], - [False, False, False, False, False], - ], - [ - [False, False, False, False, False], - [False, False, True, True, True], - [False, False, True, True, True], - 
[False, False, True, True, True], - [False, False, False, False, False], - ], - ] - ), - 0.5, - np.array([True, True]), - DoesNotRaise(), - ), # two masks partially overlapping with different category - ( - np.array( - [ - [0, 0, 0, 0, 0.8], - [0, 0, 0, 0, 0.85], - [0, 0, 0, 0, 0.9], - ] - ), - np.array( - [ - [ - [False, False, False, False, False], - [False, True, True, False, False], - [False, True, True, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - ], - [ - [False, False, False, False, False], - [False, True, True, False, False], - [False, True, True, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - ], - [ - [False, False, False, False, False], - [False, False, False, True, True], - [False, False, False, True, True], - [False, False, False, False, False], - [False, False, False, False, False], - ], - ] - ), - 0.5, - np.array([False, True, True]), - DoesNotRaise(), - ), # three masks with no category - ( - np.array( - [ - [0, 0, 0, 0, 0.8, 0], - [0, 0, 0, 0, 0.85, 1], - [0, 0, 0, 0, 0.9, 2], - ] - ), - np.array( - [ - [ - [False, False, False, False, False], - [False, True, True, False, False], - [False, True, True, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - ], - [ - [False, False, False, False, False], - [False, True, True, False, False], - [False, True, True, False, False], - [False, True, True, False, False], - [False, False, False, False, False], - ], - [ - [False, False, False, False, False], - [False, True, True, False, False], - [False, True, True, False, False], - [False, False, False, False, False], - [False, False, False, False, False], - ], - ] - ), - 0.5, - np.array([True, True, True]), - DoesNotRaise(), - ), # three masks with different category - ], -) -def test_mask_non_max_suppression( - predictions: np.ndarray, - masks: np.ndarray, - iou_threshold: float, - expected_result: Optional[np.ndarray], - 
exception: Exception, -) -> None: - with exception: - result = mask_non_max_suppression( - predictions=predictions, masks=masks, iou_threshold=iou_threshold - ) - assert np.array_equal(result, expected_result) - - @pytest.mark.parametrize( "xyxy, resolution_wh, expected_result", [