diff --git a/docs/detection/double_detection_filter.md b/docs/detection/double_detection_filter.md
new file mode 100644
index 000000000..1631852f4
--- /dev/null
+++ b/docs/detection/double_detection_filter.md
@@ -0,0 +1,30 @@
+---
+comments: true
+status: new
+---
+
+# Double Detection Filter
+
+
+
+:::supervision.detection.overlap_filter.OverlapFilter
+
+
+
+:::supervision.detection.overlap_filter.box_non_max_suppression
+
+
+
+:::supervision.detection.overlap_filter.mask_non_max_suppression
+
+
+
+:::supervision.detection.overlap_filter.box_non_max_merge
diff --git a/docs/detection/utils.md b/docs/detection/utils.md
index f9c9473bc..369746a3e 100644
--- a/docs/detection/utils.md
+++ b/docs/detection/utils.md
@@ -17,18 +17,6 @@ status: new
:::supervision.detection.utils.mask_iou_batch
-
-
-:::supervision.detection.utils.box_non_max_suppression
-
-
-
-:::supervision.detection.utils.mask_non_max_suppression
-
diff --git a/mkdocs.yml b/mkdocs.yml
index f257238df..19d6a4fd4 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -48,6 +48,7 @@ nav:
- Core: detection/core.md
- Annotators: detection/annotators.md
- Metrics: detection/metrics.md
+ - Double Detection Filter: detection/double_detection_filter.md
- Utils: detection/utils.md
- Keypoint Detection:
- Core: keypoint/core.md
diff --git a/supervision/__init__.py b/supervision/__init__.py
index 1a46226fd..4f28d49f4 100644
--- a/supervision/__init__.py
+++ b/supervision/__init__.py
@@ -40,6 +40,12 @@
from supervision.detection.core import Detections
from supervision.detection.line_zone import LineZone, LineZoneAnnotator
from supervision.detection.lmm import LMM
+from supervision.detection.overlap_filter import (
+ OverlapFilter,
+ box_non_max_merge,
+ box_non_max_suppression,
+ mask_non_max_suppression,
+)
from supervision.detection.tools.csv_sink import CSVSink
from supervision.detection.tools.inference_slicer import InferenceSlicer
from supervision.detection.tools.json_sink import JSONSink
@@ -47,15 +53,12 @@
from supervision.detection.tools.smoother import DetectionsSmoother
from supervision.detection.utils import (
box_iou_batch,
- box_non_max_merge,
- box_non_max_suppression,
calculate_masks_centroids,
clip_boxes,
contains_holes,
contains_multiple_segments,
filter_polygons_by_area,
mask_iou_batch,
- mask_non_max_suppression,
mask_to_polygons,
mask_to_xyxy,
move_boxes,
diff --git a/supervision/detection/core.py b/supervision/detection/core.py
index f93aed1c4..a1239d969 100644
--- a/supervision/detection/core.py
+++ b/supervision/detection/core.py
@@ -8,15 +8,17 @@
from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES
from supervision.detection.lmm import LMM, from_paligemma, validate_lmm_and_kwargs
-from supervision.detection.utils import (
- box_iou_batch,
+from supervision.detection.overlap_filter import (
box_non_max_merge,
box_non_max_suppression,
+ mask_non_max_suppression,
+)
+from supervision.detection.utils import (
+ box_iou_batch,
calculate_masks_centroids,
extract_ultralytics_masks,
get_data_item,
is_data_equal,
- mask_non_max_suppression,
mask_to_xyxy,
merge_data,
process_roboflow_result,
diff --git a/supervision/detection/overlap_filter.py b/supervision/detection/overlap_filter.py
new file mode 100644
index 000000000..ab4408d18
--- /dev/null
+++ b/supervision/detection/overlap_filter.py
@@ -0,0 +1,263 @@
+from enum import Enum
+from typing import List, Union
+
+import numpy as np
+import numpy.typing as npt
+
+from supervision.detection.utils import box_iou_batch, mask_iou_batch
+
+
+def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
+ """
+ Resize all masks in the array to have a maximum dimension of max_dimension,
+ maintaining aspect ratio.
+
+ Args:
+ masks (np.ndarray): 3D array of binary masks with shape (N, H, W).
+ max_dimension (int): The maximum dimension for the resized masks.
+
+ Returns:
+ np.ndarray: Array of resized masks.
+ """
+ max_height = np.max(masks.shape[1])
+ max_width = np.max(masks.shape[2])
+ scale = min(max_dimension / max_height, max_dimension / max_width)
+
+ new_height = int(scale * max_height)
+ new_width = int(scale * max_width)
+
+ x = np.linspace(0, max_width - 1, new_width).astype(int)
+ y = np.linspace(0, max_height - 1, new_height).astype(int)
+ xv, yv = np.meshgrid(x, y)
+
+ resized_masks = masks[:, yv, xv]
+
+ resized_masks = resized_masks.reshape(masks.shape[0], new_height, new_width)
+ return resized_masks
+
+
+def mask_non_max_suppression(
+ predictions: np.ndarray,
+ masks: np.ndarray,
+ iou_threshold: float = 0.5,
+ mask_dimension: int = 640,
+) -> np.ndarray:
+ """
+ Perform Non-Maximum Suppression (NMS) on segmentation predictions.
+
+ Args:
+ predictions (np.ndarray): A 2D array of object detection predictions in
+ the format of `(x_min, y_min, x_max, y_max, score)`
+ or `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or
+ `(N, 6)`, where N is the number of predictions.
+ masks (np.ndarray): A 3D array of binary masks corresponding to the predictions.
+ Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the
+ dimensions of each mask.
+ iou_threshold (float, optional): The intersection-over-union threshold
+ to use for non-maximum suppression.
+ mask_dimension (int, optional): The dimension to which the masks should be
+ resized before computing IOU values. Defaults to 640.
+
+ Returns:
+ np.ndarray: A boolean array indicating which predictions to keep after
+ non-maximum suppression.
+
+ Raises:
+ AssertionError: If `iou_threshold` is not within the closed
+ range from `0` to `1`.
+ """
+ assert 0 <= iou_threshold <= 1, (
+ "Value of `iou_threshold` must be in the closed range from 0 to 1, "
+ f"{iou_threshold} given."
+ )
+ rows, columns = predictions.shape
+
+ if columns == 5:
+ predictions = np.c_[predictions, np.zeros(rows)]
+
+ sort_index = predictions[:, 4].argsort()[::-1]
+ predictions = predictions[sort_index]
+ masks = masks[sort_index]
+ masks_resized = resize_masks(masks, mask_dimension)
+ ious = mask_iou_batch(masks_resized, masks_resized)
+ categories = predictions[:, 5]
+
+ keep = np.ones(rows, dtype=bool)
+ for i in range(rows):
+ if keep[i]:
+ condition = (ious[i] > iou_threshold) & (categories[i] == categories)
+ keep[i + 1 :] = np.where(condition[i + 1 :], False, keep[i + 1 :])
+
+ return keep[sort_index.argsort()]
+
+
+def box_non_max_suppression(
+ predictions: np.ndarray, iou_threshold: float = 0.5
+) -> np.ndarray:
+ """
+ Perform Non-Maximum Suppression (NMS) on object detection predictions.
+
+ Args:
+ predictions (np.ndarray): An array of object detection predictions in
+ the format of `(x_min, y_min, x_max, y_max, score)`
+ or `(x_min, y_min, x_max, y_max, score, class)`.
+ iou_threshold (float, optional): The intersection-over-union threshold
+ to use for non-maximum suppression.
+
+ Returns:
+ np.ndarray: A boolean array indicating which predictions to keep after
+ non-maximum suppression.
+
+ Raises:
+ AssertionError: If `iou_threshold` is not within the
+ closed range from `0` to `1`.
+ """
+ assert 0 <= iou_threshold <= 1, (
+ "Value of `iou_threshold` must be in the closed range from 0 to 1, "
+ f"{iou_threshold} given."
+ )
+ rows, columns = predictions.shape
+
+ # add column #5 - category filled with zeros for agnostic nms
+ if columns == 5:
+ predictions = np.c_[predictions, np.zeros(rows)]
+
+ # sort predictions column #4 - score
+ sort_index = np.flip(predictions[:, 4].argsort())
+ predictions = predictions[sort_index]
+
+ boxes = predictions[:, :4]
+ categories = predictions[:, 5]
+ ious = box_iou_batch(boxes, boxes)
+ ious = ious - np.eye(rows)
+
+ keep = np.ones(rows, dtype=bool)
+
+ for index, (iou, category) in enumerate(zip(ious, categories)):
+ if not keep[index]:
+ continue
+
+ # drop detections with iou > iou_threshold and
+ # the same category as the current detection
+ condition = (iou > iou_threshold) & (categories == category)
+ keep = keep & ~condition
+
+ return keep[sort_index.argsort()]
+
+
+def group_overlapping_boxes(
+ predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5
+) -> List[List[int]]:
+ """
+ Apply greedy version of non-maximum merging to avoid detecting too many
+ overlapping bounding boxes for a given object.
+
+ Args:
+ predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` containing
+ the bounding boxes coordinates in format `[x1, y1, x2, y2]`
+ and the confidence scores.
+ iou_threshold (float, optional): The intersection-over-union threshold
+ to use for non-maximum merging. Defaults to 0.5.
+
+ Returns:
+ List[List[int]]: Groups of prediction indices to be merged.
+ Each group may have 1 or more elements.
+ """
+ merge_groups: List[List[int]] = []
+
+ scores = predictions[:, 4]
+ order = scores.argsort()
+
+ while len(order) > 0:
+ idx = int(order[-1])
+
+ order = order[:-1]
+ if len(order) == 0:
+ merge_groups.append([idx])
+ break
+
+ merge_candidate = np.expand_dims(predictions[idx], axis=0)
+ ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4])
+ ious = ious.flatten()
+
+ above_threshold = ious >= iou_threshold
+ merge_group = [idx] + np.flip(order[above_threshold]).tolist()
+ merge_groups.append(merge_group)
+ order = order[~above_threshold]
+ return merge_groups
+
+
+def box_non_max_merge(
+ predictions: npt.NDArray[np.float64],
+ iou_threshold: float = 0.5,
+) -> List[List[int]]:
+ """
+ Apply greedy version of non-maximum merging per category to avoid detecting
+ too many overlapping bounding boxes for a given object.
+
+ Args:
+ predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` or `(n, 6)`
+ containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`,
+ the confidence scores and class_ids. Omit class_id column to allow
+ detections of different classes to be merged.
+ iou_threshold (float, optional): The intersection-over-union threshold
+ to use for non-maximum merging. Defaults to 0.5.
+
+ Returns:
+ List[List[int]]: Groups of prediction indices to be merged.
+ Each group may have 1 or more elements.
+ """
+ if predictions.shape[1] == 5:
+ return group_overlapping_boxes(predictions, iou_threshold)
+
+ category_ids = predictions[:, 5]
+ merge_groups = []
+ for category_id in np.unique(category_ids):
+ curr_indices = np.where(category_ids == category_id)[0]
+ merge_class_groups = group_overlapping_boxes(
+ predictions[curr_indices], iou_threshold
+ )
+
+ for merge_class_group in merge_class_groups:
+ merge_groups.append(curr_indices[merge_class_group].tolist())
+
+ for merge_group in merge_groups:
+ if len(merge_group) == 0:
+ raise ValueError(
+ f"Empty group detected when non-max-merging "
+ f"detections: {merge_groups}"
+ )
+ return merge_groups
+
+
+class OverlapFilter(Enum):
+ """
+ Enum specifying the strategy for filtering overlapping detections.
+
+ Attributes:
+ NONE: Do not filter detections based on overlap.
+ NON_MAX_SUPPRESSION: Filter detections using non-max suppression. This means,
+ detections that overlap by more than a set threshold will be discarded,
+ except for the one with the highest confidence.
+ NON_MAX_MERGE: Merge detections with non-max merging. This means,
+ detections that overlap by more than a set threshold will be merged
+ into a single detection.
+ """
+
+ NONE = "none"
+ NON_MAX_SUPPRESSION = "non_max_suppression"
+ NON_MAX_MERGE = "non_max_merge"
+
+
+def validate_overlap_filter(
+ strategy: Union[OverlapFilter, str],
+) -> OverlapFilter:
+ if isinstance(strategy, str):
+ try:
+ strategy = OverlapFilter(strategy.lower())
+ except ValueError:
+ raise ValueError(
+ f"Invalid strategy value: {strategy}. Must be one of "
+ f"{[e.value for e in OverlapFilter]}"
+ )
+ return strategy
diff --git a/supervision/detection/tools/inference_slicer.py b/supervision/detection/tools/inference_slicer.py
index 82551434e..134361bd3 100644
--- a/supervision/detection/tools/inference_slicer.py
+++ b/supervision/detection/tools/inference_slicer.py
@@ -1,11 +1,14 @@
+import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Callable, Optional, Tuple
+from typing import Callable, Optional, Tuple, Union
import numpy as np
from supervision.detection.core import Detections
+from supervision.detection.overlap_filter import OverlapFilter, validate_overlap_filter
from supervision.detection.utils import move_boxes, move_masks
from supervision.utils.image import crop_image
+from supervision.utils.internal import SupervisionWarnings
def move_detections(
@@ -50,8 +53,10 @@ class InferenceSlicer:
`(width, height)`.
overlap_ratio_wh (Tuple[float, float]): Overlap ratio between consecutive
slices in the format `(width_ratio, height_ratio)`.
- iou_threshold (Optional[float]): Intersection over Union (IoU) threshold
- used for non-max suppression.
+ overlap_filter_strategy (Union[OverlapFilter, str]): Strategy for
+ filtering or merging overlapping detections in slices.
+ iou_threshold (float): Intersection over Union (IoU) threshold
+ used when filtering by overlap.
callback (Callable): A function that performs inference on a given image
slice and returns detections.
thread_workers (int): Number of threads for parallel execution.
@@ -68,12 +73,18 @@ def __init__(
callback: Callable[[np.ndarray], Detections],
slice_wh: Tuple[int, int] = (320, 320),
overlap_ratio_wh: Tuple[float, float] = (0.2, 0.2),
- iou_threshold: Optional[float] = 0.5,
+ overlap_filter_strategy: Union[
+ OverlapFilter, str
+ ] = OverlapFilter.NON_MAX_SUPPRESSION,
+ iou_threshold: float = 0.5,
thread_workers: int = 1,
):
+ overlap_filter_strategy = validate_overlap_filter(overlap_filter_strategy)
+
self.slice_wh = slice_wh
self.overlap_ratio_wh = overlap_ratio_wh
self.iou_threshold = iou_threshold
+ self.overlap_filter_strategy = overlap_filter_strategy
self.callback = callback
self.thread_workers = thread_workers
@@ -104,7 +115,10 @@ def callback(image_slice: np.ndarray) -> sv.Detections:
result = model(image_slice)[0]
return sv.Detections.from_ultralytics(result)
- slicer = sv.InferenceSlicer(callback = callback)
+ slicer = sv.InferenceSlicer(
+ callback=callback,
+ overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION,
+ )
detections = slicer(image)
```
@@ -124,9 +138,19 @@ def callback(image_slice: np.ndarray) -> sv.Detections:
for future in as_completed(futures):
detections_list.append(future.result())
- return Detections.merge(detections_list=detections_list).with_nms(
- threshold=self.iou_threshold
- )
+ merged = Detections.merge(detections_list=detections_list)
+ if self.overlap_filter_strategy == OverlapFilter.NONE:
+ return merged
+ elif self.overlap_filter_strategy == OverlapFilter.NON_MAX_SUPPRESSION:
+ return merged.with_nms(threshold=self.iou_threshold)
+ elif self.overlap_filter_strategy == OverlapFilter.NON_MAX_MERGE:
+ return merged.with_nmm(threshold=self.iou_threshold)
+ else:
+ warnings.warn(
+ f"Invalid overlap filter strategy: {self.overlap_filter_strategy}",
+ category=SupervisionWarnings,
+ )
+ return merged
def _run_callback(self, image, offset) -> Detections:
"""
diff --git a/supervision/detection/utils.py b/supervision/detection/utils.py
index 1ca487916..b36b6853f 100644
--- a/supervision/detection/utils.py
+++ b/supervision/detection/utils.py
@@ -139,229 +139,6 @@ def mask_iou_batch(
return np.vstack(ious)
-def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
- """
- Resize all masks in the array to have a maximum dimension of max_dimension,
- maintaining aspect ratio.
-
- Args:
- masks (np.ndarray): 3D array of binary masks with shape (N, H, W).
- max_dimension (int): The maximum dimension for the resized masks.
-
- Returns:
- np.ndarray: Array of resized masks.
- """
- max_height = np.max(masks.shape[1])
- max_width = np.max(masks.shape[2])
- scale = min(max_dimension / max_height, max_dimension / max_width)
-
- new_height = int(scale * max_height)
- new_width = int(scale * max_width)
-
- x = np.linspace(0, max_width - 1, new_width).astype(int)
- y = np.linspace(0, max_height - 1, new_height).astype(int)
- xv, yv = np.meshgrid(x, y)
-
- resized_masks = masks[:, yv, xv]
-
- resized_masks = resized_masks.reshape(masks.shape[0], new_height, new_width)
- return resized_masks
-
-
-def mask_non_max_suppression(
- predictions: np.ndarray,
- masks: np.ndarray,
- iou_threshold: float = 0.5,
- mask_dimension: int = 640,
-) -> np.ndarray:
- """
- Perform Non-Maximum Suppression (NMS) on segmentation predictions.
-
- Args:
- predictions (np.ndarray): A 2D array of object detection predictions in
- the format of `(x_min, y_min, x_max, y_max, score)`
- or `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or
- `(N, 6)`, where N is the number of predictions.
- masks (np.ndarray): A 3D array of binary masks corresponding to the predictions.
- Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the
- dimensions of each mask.
- iou_threshold (float, optional): The intersection-over-union threshold
- to use for non-maximum suppression.
- mask_dimension (int, optional): The dimension to which the masks should be
- resized before computing IOU values. Defaults to 640.
-
- Returns:
- np.ndarray: A boolean array indicating which predictions to keep after
- non-maximum suppression.
-
- Raises:
- AssertionError: If `iou_threshold` is not within the closed
- range from `0` to `1`.
- """
- assert 0 <= iou_threshold <= 1, (
- "Value of `iou_threshold` must be in the closed range from 0 to 1, "
- f"{iou_threshold} given."
- )
- rows, columns = predictions.shape
-
- if columns == 5:
- predictions = np.c_[predictions, np.zeros(rows)]
-
- sort_index = predictions[:, 4].argsort()[::-1]
- predictions = predictions[sort_index]
- masks = masks[sort_index]
- masks_resized = resize_masks(masks, mask_dimension)
- ious = mask_iou_batch(masks_resized, masks_resized)
- categories = predictions[:, 5]
-
- keep = np.ones(rows, dtype=bool)
- for i in range(rows):
- if keep[i]:
- condition = (ious[i] > iou_threshold) & (categories[i] == categories)
- keep[i + 1 :] = np.where(condition[i + 1 :], False, keep[i + 1 :])
-
- return keep[sort_index.argsort()]
-
-
-def box_non_max_suppression(
- predictions: np.ndarray, iou_threshold: float = 0.5
-) -> np.ndarray:
- """
- Perform Non-Maximum Suppression (NMS) on object detection predictions.
-
- Args:
- predictions (np.ndarray): An array of object detection predictions in
- the format of `(x_min, y_min, x_max, y_max, score)`
- or `(x_min, y_min, x_max, y_max, score, class)`.
- iou_threshold (float, optional): The intersection-over-union threshold
- to use for non-maximum suppression.
-
- Returns:
- np.ndarray: A boolean array indicating which predictions to keep after n
- on-maximum suppression.
-
- Raises:
- AssertionError: If `iou_threshold` is not within the
- closed range from `0` to `1`.
- """
- assert 0 <= iou_threshold <= 1, (
- "Value of `iou_threshold` must be in the closed range from 0 to 1, "
- f"{iou_threshold} given."
- )
- rows, columns = predictions.shape
-
- # add column #5 - category filled with zeros for agnostic nms
- if columns == 5:
- predictions = np.c_[predictions, np.zeros(rows)]
-
- # sort predictions column #4 - score
- sort_index = np.flip(predictions[:, 4].argsort())
- predictions = predictions[sort_index]
-
- boxes = predictions[:, :4]
- categories = predictions[:, 5]
- ious = box_iou_batch(boxes, boxes)
- ious = ious - np.eye(rows)
-
- keep = np.ones(rows, dtype=bool)
-
- for index, (iou, category) in enumerate(zip(ious, categories)):
- if not keep[index]:
- continue
-
- # drop detections with iou > iou_threshold and
- # same category as current detections
- condition = (iou > iou_threshold) & (categories == category)
- keep = keep & ~condition
-
- return keep[sort_index.argsort()]
-
-
-def group_overlapping_boxes(
- predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5
-) -> List[List[int]]:
- """
- Apply greedy version of non-maximum merging to avoid detecting too many
- overlapping bounding boxes for a given object.
-
- Args:
- predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` containing
- the bounding boxes coordinates in format `[x1, y1, x2, y2]`
- and the confidence scores.
- iou_threshold (float, optional): The intersection-over-union threshold
- to use for non-maximum suppression. Defaults to 0.5.
-
- Returns:
- List[List[int]]: Groups of prediction indices be merged.
- Each group may have 1 or more elements.
- """
- merge_groups: List[List[int]] = []
-
- scores = predictions[:, 4]
- order = scores.argsort()
-
- while len(order) > 0:
- idx = int(order[-1])
-
- order = order[:-1]
- if len(order) == 0:
- merge_groups.append([idx])
- break
-
- merge_candidate = np.expand_dims(predictions[idx], axis=0)
- ious = box_iou_batch(predictions[order][:, :4], merge_candidate[:, :4])
- ious = ious.flatten()
-
- above_threshold = ious >= iou_threshold
- merge_group = [idx] + np.flip(order[above_threshold]).tolist()
- merge_groups.append(merge_group)
- order = order[~above_threshold]
- return merge_groups
-
-
-def box_non_max_merge(
- predictions: npt.NDArray[np.float64],
- iou_threshold: float = 0.5,
-) -> List[List[int]]:
- """
- Apply greedy version of non-maximum merging per category to avoid detecting
- too many overlapping bounding boxes for a given object.
-
- Args:
- predictions (npt.NDArray[np.float64]): An array of shape `(n, 5)` or `(n, 6)`
- containing the bounding boxes coordinates in format `[x1, y1, x2, y2]`,
- the confidence scores and class_ids. Omit class_id column to allow
- detections of different classes to be merged.
- iou_threshold (float, optional): The intersection-over-union threshold
- to use for non-maximum suppression. Defaults to 0.5.
-
- Returns:
- List[List[int]]: Groups of prediction indices be merged.
- Each group may have 1 or more elements.
- """
- if predictions.shape[1] == 5:
- return group_overlapping_boxes(predictions, iou_threshold)
-
- category_ids = predictions[:, 5]
- merge_groups = []
- for category_id in np.unique(category_ids):
- curr_indices = np.where(category_ids == category_id)[0]
- merge_class_groups = group_overlapping_boxes(
- predictions[curr_indices], iou_threshold
- )
-
- for merge_class_group in merge_class_groups:
- merge_groups.append(curr_indices[merge_class_group].tolist())
-
- for merge_group in merge_groups:
- if len(merge_group) == 0:
- raise ValueError(
- f"Empty group detected when non-max-merging "
- f"detections: {merge_groups}"
- )
- return merge_groups
-
-
def clip_boxes(xyxy: np.ndarray, resolution_wh: Tuple[int, int]) -> np.ndarray:
"""
Clips bounding boxes coordinates to fit within the frame resolution.
diff --git a/test/detection/test_overlap_filter.py b/test/detection/test_overlap_filter.py
new file mode 100644
index 000000000..f628c30f9
--- /dev/null
+++ b/test/detection/test_overlap_filter.py
@@ -0,0 +1,449 @@
+from contextlib import ExitStack as DoesNotRaise
+from typing import List, Optional
+
+import numpy as np
+import pytest
+
+from supervision.detection.overlap_filter import (
+ box_non_max_suppression,
+ group_overlapping_boxes,
+ mask_non_max_suppression,
+)
+
+
+@pytest.mark.parametrize(
+ "predictions, iou_threshold, expected_result, exception",
+ [
+ (
+ np.empty(shape=(0, 5), dtype=float),
+ 0.5,
+ [],
+ DoesNotRaise(),
+ ),
+ (
+ np.array([[0, 0, 10, 10, 1.0]]),
+ 0.5,
+ [[0]],
+ DoesNotRaise(),
+ ),
+ (
+ np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]),
+ 0.5,
+ [[1, 0]],
+ DoesNotRaise(),
+ ), # High overlap, tie-break to second det
+ (
+ np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 0.99]]),
+ 0.5,
+ [[0, 1]],
+ DoesNotRaise(),
+ ), # High overlap, merge to high confidence
+ (
+ np.array([[0, 0, 10, 10, 0.99], [0, 0, 9, 9, 1.0]]),
+ 0.5,
+ [[1, 0]],
+ DoesNotRaise(),
+ ), # (test symmetry) High overlap, merge to high confidence
+ (
+ np.array([[0, 0, 10, 10, 0.90], [0, 0, 9, 9, 1.0]]),
+ 0.5,
+ [[1, 0]],
+ DoesNotRaise(),
+ ), # (test symmetry) High overlap, merge to high confidence
+ (
+ np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]),
+ 1.0,
+ [[1], [0]],
+ DoesNotRaise(),
+ ), # High IOU required
+ (
+ np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]),
+ 0.0,
+ [[1, 0]],
+ DoesNotRaise(),
+ ), # No IOU required
+ (
+ np.array([[0, 0, 10, 10, 1.0], [0, 0, 5, 5, 0.9]]),
+ 0.25,
+ [[0, 1]],
+ DoesNotRaise(),
+ ), # IOU meets the threshold, boxes are merged
+ (
+ np.array([[0, 0, 10, 10, 1.0], [0, 0, 5, 5, 0.9]]),
+ 0.26,
+ [[0], [1]],
+ DoesNotRaise(),
+ ), # IOU below the threshold, boxes stay separate
+ (
+ np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0], [0, 0, 8, 8, 1.0]]),
+ 0.5,
+ [[2, 1, 0]],
+ DoesNotRaise(),
+ ), # 3 boxes
+ (
+ np.array(
+ [
+ [0, 0, 10, 10, 1.0],
+ [0, 0, 9, 9, 1.0],
+ [5, 5, 10, 10, 1.0],
+ [6, 6, 10, 10, 1.0],
+ [9, 9, 10, 10, 1.0],
+ ]
+ ),
+ 0.5,
+ [[4], [3, 2], [1, 0]],
+ DoesNotRaise(),
+ ), # 5 boxes, 2 merges, 1 separate
+ (
+ np.array(
+ [
+ [0, 0, 2, 1, 1.0],
+ [1, 0, 3, 1, 1.0],
+ [2, 0, 4, 1, 1.0],
+ [3, 0, 5, 1, 1.0],
+ [4, 0, 6, 1, 1.0],
+ ]
+ ),
+ 0.33,
+ [[4, 3], [2, 1], [0]],
+ DoesNotRaise(),
+ ), # sequential merge, half overlap
+ (
+ np.array(
+ [
+ [0, 0, 2, 1, 0.9],
+ [1, 0, 3, 1, 0.9],
+ [2, 0, 4, 1, 1.0],
+ [3, 0, 5, 1, 0.9],
+ [4, 0, 6, 1, 0.9],
+ ]
+ ),
+ 0.33,
+ [[2, 3, 1], [4], [0]],
+ DoesNotRaise(),
+ ), # most confident box merges its two neighbours first
+ ],
+)
+def test_group_overlapping_boxes(
+ predictions: np.ndarray,
+ iou_threshold: float,
+ expected_result: List[List[int]],
+ exception: Exception,
+) -> None:
+ with exception:
+ result = group_overlapping_boxes(
+ predictions=predictions, iou_threshold=iou_threshold
+ )
+
+ assert result == expected_result
+
+
+@pytest.mark.parametrize(
+ "predictions, iou_threshold, expected_result, exception",
+ [
+ (
+ np.empty(shape=(0, 5)),
+ 0.5,
+ np.array([]),
+ DoesNotRaise(),
+ ), # empty predictions
+ (
+ np.array([[10.0, 10.0, 40.0, 40.0, 0.8]]),
+ 0.5,
+ np.array([True]),
+ DoesNotRaise(),
+ ), # single box with no category
+ (
+ np.array([[10.0, 10.0, 40.0, 40.0, 0.8, 0]]),
+ 0.5,
+ np.array([True]),
+ DoesNotRaise(),
+ ), # single box with category
+ (
+ np.array(
+ [
+ [10.0, 10.0, 40.0, 40.0, 0.8],
+ [15.0, 15.0, 40.0, 40.0, 0.9],
+ ]
+ ),
+ 0.5,
+ np.array([False, True]),
+ DoesNotRaise(),
+ ), # two boxes with no category
+ (
+ np.array(
+ [
+ [10.0, 10.0, 40.0, 40.0, 0.8, 0],
+ [15.0, 15.0, 40.0, 40.0, 0.9, 1],
+ ]
+ ),
+ 0.5,
+ np.array([True, True]),
+ DoesNotRaise(),
+ ), # two boxes with different category
+ (
+ np.array(
+ [
+ [10.0, 10.0, 40.0, 40.0, 0.8, 0],
+ [15.0, 15.0, 40.0, 40.0, 0.9, 0],
+ ]
+ ),
+ 0.5,
+ np.array([False, True]),
+ DoesNotRaise(),
+ ), # two boxes with same category
+ (
+ np.array(
+ [
+ [0.0, 0.0, 30.0, 40.0, 0.8],
+ [5.0, 5.0, 35.0, 45.0, 0.9],
+ [10.0, 10.0, 40.0, 50.0, 0.85],
+ ]
+ ),
+ 0.5,
+ np.array([False, True, False]),
+ DoesNotRaise(),
+ ), # three boxes with no category
+ (
+ np.array(
+ [
+ [0.0, 0.0, 30.0, 40.0, 0.8, 0],
+ [5.0, 5.0, 35.0, 45.0, 0.9, 1],
+ [10.0, 10.0, 40.0, 50.0, 0.85, 2],
+ ]
+ ),
+ 0.5,
+ np.array([True, True, True]),
+ DoesNotRaise(),
+ ), # three boxes with different categories
+ (
+ np.array(
+ [
+ [0.0, 0.0, 30.0, 40.0, 0.8, 0],
+ [5.0, 5.0, 35.0, 45.0, 0.9, 0],
+ [10.0, 10.0, 40.0, 50.0, 0.85, 1],
+ ]
+ ),
+ 0.5,
+ np.array([False, True, True]),
+ DoesNotRaise(),
+ ), # three boxes, two of them sharing a category
+ ],
+)
+def test_box_non_max_suppression(
+ predictions: np.ndarray,
+ iou_threshold: float,
+ expected_result: Optional[np.ndarray],
+ exception: Exception,
+) -> None:
+ with exception:
+ result = box_non_max_suppression(
+ predictions=predictions, iou_threshold=iou_threshold
+ )
+ assert np.array_equal(result, expected_result)
+
+
+@pytest.mark.parametrize(
+ "predictions, masks, iou_threshold, expected_result, exception",
+ [
+ (
+ np.empty((0, 6)),
+ np.empty((0, 5, 5)),
+ 0.5,
+ np.array([]),
+ DoesNotRaise(),
+ ), # empty predictions and masks
+ (
+ np.array([[0, 0, 0, 0, 0.8]]),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, True, True, True, False],
+ [False, True, True, True, False],
+ [False, True, True, True, False],
+ [False, False, False, False, False],
+ ]
+ ]
+ ),
+ 0.5,
+ np.array([True]),
+ DoesNotRaise(),
+ ), # single mask with no category
+ (
+ np.array([[0, 0, 0, 0, 0.8, 0]]),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, True, True, True, False],
+ [False, True, True, True, False],
+ [False, True, True, True, False],
+ [False, False, False, False, False],
+ ]
+ ]
+ ),
+ 0.5,
+ np.array([True]),
+ DoesNotRaise(),
+ ), # single mask with category
+ (
+ np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, True, True, False, False],
+ [False, True, True, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ ],
+ [
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ [False, False, False, True, True],
+ [False, False, False, True, True],
+ [False, False, False, False, False],
+ ],
+ ]
+ ),
+ 0.5,
+ np.array([True, True]),
+ DoesNotRaise(),
+ ), # two masks non-overlapping with no category
+ (
+ np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, True, True, True, False],
+ [False, True, True, True, False],
+ [False, True, True, True, False],
+ [False, False, False, False, False],
+ ],
+ [
+ [False, False, False, False, False],
+ [False, False, True, True, True],
+ [False, False, True, True, True],
+ [False, False, True, True, True],
+ [False, False, False, False, False],
+ ],
+ ]
+ ),
+ 0.4,
+ np.array([False, True]),
+ DoesNotRaise(),
+ ), # two masks partially overlapping with no category
+ (
+ np.array([[0, 0, 0, 0, 0.8, 0], [0, 0, 0, 0, 0.9, 1]]),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, True, True, True, False],
+ [False, True, True, True, False],
+ [False, True, True, True, False],
+ [False, False, False, False, False],
+ ],
+ [
+ [False, False, False, False, False],
+ [False, False, True, True, True],
+ [False, False, True, True, True],
+ [False, False, True, True, True],
+ [False, False, False, False, False],
+ ],
+ ]
+ ),
+ 0.5,
+ np.array([True, True]),
+ DoesNotRaise(),
+ ), # two masks partially overlapping with different category
+ (
+ np.array(
+ [
+ [0, 0, 0, 0, 0.8],
+ [0, 0, 0, 0, 0.85],
+ [0, 0, 0, 0, 0.9],
+ ]
+ ),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, True, True, False, False],
+ [False, True, True, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ ],
+ [
+ [False, False, False, False, False],
+ [False, True, True, False, False],
+ [False, True, True, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ ],
+ [
+ [False, False, False, False, False],
+ [False, False, False, True, True],
+ [False, False, False, True, True],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ ],
+ ]
+ ),
+ 0.5,
+ np.array([False, True, True]),
+ DoesNotRaise(),
+ ), # three masks with no category
+ (
+ np.array(
+ [
+ [0, 0, 0, 0, 0.8, 0],
+ [0, 0, 0, 0, 0.85, 1],
+ [0, 0, 0, 0, 0.9, 2],
+ ]
+ ),
+ np.array(
+ [
+ [
+ [False, False, False, False, False],
+ [False, True, True, False, False],
+ [False, True, True, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ ],
+ [
+ [False, False, False, False, False],
+ [False, True, True, False, False],
+ [False, True, True, False, False],
+ [False, True, True, False, False],
+ [False, False, False, False, False],
+ ],
+ [
+ [False, False, False, False, False],
+ [False, True, True, False, False],
+ [False, True, True, False, False],
+ [False, False, False, False, False],
+ [False, False, False, False, False],
+ ],
+ ]
+ ),
+ 0.5,
+ np.array([True, True, True]),
+ DoesNotRaise(),
+ ), # three masks with different category
+ ],
+)
+def test_mask_non_max_suppression(
+ predictions: np.ndarray,
+ masks: np.ndarray,
+ iou_threshold: float,
+ expected_result: Optional[np.ndarray],
+ exception: Exception,
+) -> None:
+ with exception:
+ result = mask_non_max_suppression(
+ predictions=predictions, masks=masks, iou_threshold=iou_threshold
+ )
+ assert np.array_equal(result, expected_result)
diff --git a/test/detection/test_utils.py b/test/detection/test_utils.py
index 837b3b840..f0f0a6b13 100644
--- a/test/detection/test_utils.py
+++ b/test/detection/test_utils.py
@@ -7,15 +7,12 @@
from supervision.config import CLASS_NAME_DATA_FIELD
from supervision.detection.utils import (
- box_non_max_suppression,
calculate_masks_centroids,
clip_boxes,
contains_holes,
contains_multiple_segments,
filter_polygons_by_area,
get_data_item,
- group_overlapping_boxes,
- mask_non_max_suppression,
merge_data,
move_boxes,
process_roboflow_result,
@@ -26,444 +23,6 @@
TEST_MASK[:, 300:351, 200:251] = True
-@pytest.mark.parametrize(
- "predictions, iou_threshold, expected_result, exception",
- [
- (
- np.empty(shape=(0, 5)),
- 0.5,
- np.array([]),
- DoesNotRaise(),
- ), # single box with no category
- (
- np.array([[10.0, 10.0, 40.0, 40.0, 0.8]]),
- 0.5,
- np.array([True]),
- DoesNotRaise(),
- ), # single box with no category
- (
- np.array([[10.0, 10.0, 40.0, 40.0, 0.8, 0]]),
- 0.5,
- np.array([True]),
- DoesNotRaise(),
- ), # single box with category
- (
- np.array(
- [
- [10.0, 10.0, 40.0, 40.0, 0.8],
- [15.0, 15.0, 40.0, 40.0, 0.9],
- ]
- ),
- 0.5,
- np.array([False, True]),
- DoesNotRaise(),
- ), # two boxes with no category
- (
- np.array(
- [
- [10.0, 10.0, 40.0, 40.0, 0.8, 0],
- [15.0, 15.0, 40.0, 40.0, 0.9, 1],
- ]
- ),
- 0.5,
- np.array([True, True]),
- DoesNotRaise(),
- ), # two boxes with different category
- (
- np.array(
- [
- [10.0, 10.0, 40.0, 40.0, 0.8, 0],
- [15.0, 15.0, 40.0, 40.0, 0.9, 0],
- ]
- ),
- 0.5,
- np.array([False, True]),
- DoesNotRaise(),
- ), # two boxes with same category
- (
- np.array(
- [
- [0.0, 0.0, 30.0, 40.0, 0.8],
- [5.0, 5.0, 35.0, 45.0, 0.9],
- [10.0, 10.0, 40.0, 50.0, 0.85],
- ]
- ),
- 0.5,
- np.array([False, True, False]),
- DoesNotRaise(),
- ), # three boxes with no category
- (
- np.array(
- [
- [0.0, 0.0, 30.0, 40.0, 0.8, 0],
- [5.0, 5.0, 35.0, 45.0, 0.9, 1],
- [10.0, 10.0, 40.0, 50.0, 0.85, 2],
- ]
- ),
- 0.5,
- np.array([True, True, True]),
- DoesNotRaise(),
- ), # three boxes with same category
- (
- np.array(
- [
- [0.0, 0.0, 30.0, 40.0, 0.8, 0],
- [5.0, 5.0, 35.0, 45.0, 0.9, 0],
- [10.0, 10.0, 40.0, 50.0, 0.85, 1],
- ]
- ),
- 0.5,
- np.array([False, True, True]),
- DoesNotRaise(),
- ), # three boxes with different category
- ],
-)
-def test_box_non_max_suppression(
- predictions: np.ndarray,
- iou_threshold: float,
- expected_result: Optional[np.ndarray],
- exception: Exception,
-) -> None:
- with exception:
- result = box_non_max_suppression(
- predictions=predictions, iou_threshold=iou_threshold
- )
- assert np.array_equal(result, expected_result)
-
-
-@pytest.mark.parametrize(
- "predictions, iou_threshold, expected_result, exception",
- [
- (
- np.empty(shape=(0, 5), dtype=float),
- 0.5,
- [],
- DoesNotRaise(),
- ),
- (
- np.array([[0, 0, 10, 10, 1.0]]),
- 0.5,
- [[0]],
- DoesNotRaise(),
- ),
- (
- np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]),
- 0.5,
- [[1, 0]],
- DoesNotRaise(),
- ), # High overlap, tie-break to second det
- (
- np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 0.99]]),
- 0.5,
- [[0, 1]],
- DoesNotRaise(),
- ), # High overlap, merge to high confidence
- (
- np.array([[0, 0, 10, 10, 0.99], [0, 0, 9, 9, 1.0]]),
- 0.5,
- [[1, 0]],
- DoesNotRaise(),
- ), # (test symmetry) High overlap, merge to high confidence
- (
- np.array([[0, 0, 10, 10, 0.90], [0, 0, 9, 9, 1.0]]),
- 0.5,
- [[1, 0]],
- DoesNotRaise(),
- ), # (test symmetry) High overlap, merge to high confidence
- (
- np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]),
- 1.0,
- [[1], [0]],
- DoesNotRaise(),
- ), # High IOU required
- (
- np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0]]),
- 0.0,
- [[1, 0]],
- DoesNotRaise(),
- ), # No IOU required
- (
- np.array([[0, 0, 10, 10, 1.0], [0, 0, 5, 5, 0.9]]),
- 0.25,
- [[0, 1]],
- DoesNotRaise(),
- ), # Below IOU requirement
- (
- np.array([[0, 0, 10, 10, 1.0], [0, 0, 5, 5, 0.9]]),
- 0.26,
- [[0], [1]],
- DoesNotRaise(),
- ), # Above IOU requirement
- (
- np.array([[0, 0, 10, 10, 1.0], [0, 0, 9, 9, 1.0], [0, 0, 8, 8, 1.0]]),
- 0.5,
- [[2, 1, 0]],
- DoesNotRaise(),
- ), # 3 boxes
- (
- np.array(
- [
- [0, 0, 10, 10, 1.0],
- [0, 0, 9, 9, 1.0],
- [5, 5, 10, 10, 1.0],
- [6, 6, 10, 10, 1.0],
- [9, 9, 10, 10, 1.0],
- ]
- ),
- 0.5,
- [[4], [3, 2], [1, 0]],
- DoesNotRaise(),
- ), # 5 boxes, 2 merges, 1 separate
- (
- np.array(
- [
- [0, 0, 2, 1, 1.0],
- [1, 0, 3, 1, 1.0],
- [2, 0, 4, 1, 1.0],
- [3, 0, 5, 1, 1.0],
- [4, 0, 6, 1, 1.0],
- ]
- ),
- 0.33,
- [[4, 3], [2, 1], [0]],
- DoesNotRaise(),
- ), # sequential merge, half overlap
- (
- np.array(
- [
- [0, 0, 2, 1, 0.9],
- [1, 0, 3, 1, 0.9],
- [2, 0, 4, 1, 1.0],
- [3, 0, 5, 1, 0.9],
- [4, 0, 6, 1, 0.9],
- ]
- ),
- 0.33,
- [[2, 3, 1], [4], [0]],
- DoesNotRaise(),
- ), # confidence
- ],
-)
-def test_group_overlapping_boxes(
- predictions: np.ndarray,
- iou_threshold: float,
- expected_result: List[List[int]],
- exception: Exception,
-) -> None:
- with exception:
- result = group_overlapping_boxes(
- predictions=predictions, iou_threshold=iou_threshold
- )
-
- assert result == expected_result
-
-
-@pytest.mark.parametrize(
- "predictions, masks, iou_threshold, expected_result, exception",
- [
- (
- np.empty((0, 6)),
- np.empty((0, 5, 5)),
- 0.5,
- np.array([]),
- DoesNotRaise(),
- ), # empty predictions and masks
- (
- np.array([[0, 0, 0, 0, 0.8]]),
- np.array(
- [
- [
- [False, False, False, False, False],
- [False, True, True, True, False],
- [False, True, True, True, False],
- [False, True, True, True, False],
- [False, False, False, False, False],
- ]
- ]
- ),
- 0.5,
- np.array([True]),
- DoesNotRaise(),
- ), # single mask with no category
- (
- np.array([[0, 0, 0, 0, 0.8, 0]]),
- np.array(
- [
- [
- [False, False, False, False, False],
- [False, True, True, True, False],
- [False, True, True, True, False],
- [False, True, True, True, False],
- [False, False, False, False, False],
- ]
- ]
- ),
- 0.5,
- np.array([True]),
- DoesNotRaise(),
- ), # single mask with category
- (
- np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]),
- np.array(
- [
- [
- [False, False, False, False, False],
- [False, True, True, False, False],
- [False, True, True, False, False],
- [False, False, False, False, False],
- [False, False, False, False, False],
- ],
- [
- [False, False, False, False, False],
- [False, False, False, False, False],
- [False, False, False, True, True],
- [False, False, False, True, True],
- [False, False, False, False, False],
- ],
- ]
- ),
- 0.5,
- np.array([True, True]),
- DoesNotRaise(),
- ), # two masks non-overlapping with no category
- (
- np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]),
- np.array(
- [
- [
- [False, False, False, False, False],
- [False, True, True, True, False],
- [False, True, True, True, False],
- [False, True, True, True, False],
- [False, False, False, False, False],
- ],
- [
- [False, False, False, False, False],
- [False, False, True, True, True],
- [False, False, True, True, True],
- [False, False, True, True, True],
- [False, False, False, False, False],
- ],
- ]
- ),
- 0.4,
- np.array([False, True]),
- DoesNotRaise(),
- ), # two masks partially overlapping with no category
- (
- np.array([[0, 0, 0, 0, 0.8, 0], [0, 0, 0, 0, 0.9, 1]]),
- np.array(
- [
- [
- [False, False, False, False, False],
- [False, True, True, True, False],
- [False, True, True, True, False],
- [False, True, True, True, False],
- [False, False, False, False, False],
- ],
- [
- [False, False, False, False, False],
- [False, False, True, True, True],
- [False, False, True, True, True],
- [False, False, True, True, True],
- [False, False, False, False, False],
- ],
- ]
- ),
- 0.5,
- np.array([True, True]),
- DoesNotRaise(),
- ), # two masks partially overlapping with different category
- (
- np.array(
- [
- [0, 0, 0, 0, 0.8],
- [0, 0, 0, 0, 0.85],
- [0, 0, 0, 0, 0.9],
- ]
- ),
- np.array(
- [
- [
- [False, False, False, False, False],
- [False, True, True, False, False],
- [False, True, True, False, False],
- [False, False, False, False, False],
- [False, False, False, False, False],
- ],
- [
- [False, False, False, False, False],
- [False, True, True, False, False],
- [False, True, True, False, False],
- [False, False, False, False, False],
- [False, False, False, False, False],
- ],
- [
- [False, False, False, False, False],
- [False, False, False, True, True],
- [False, False, False, True, True],
- [False, False, False, False, False],
- [False, False, False, False, False],
- ],
- ]
- ),
- 0.5,
- np.array([False, True, True]),
- DoesNotRaise(),
- ), # three masks with no category
- (
- np.array(
- [
- [0, 0, 0, 0, 0.8, 0],
- [0, 0, 0, 0, 0.85, 1],
- [0, 0, 0, 0, 0.9, 2],
- ]
- ),
- np.array(
- [
- [
- [False, False, False, False, False],
- [False, True, True, False, False],
- [False, True, True, False, False],
- [False, False, False, False, False],
- [False, False, False, False, False],
- ],
- [
- [False, False, False, False, False],
- [False, True, True, False, False],
- [False, True, True, False, False],
- [False, True, True, False, False],
- [False, False, False, False, False],
- ],
- [
- [False, False, False, False, False],
- [False, True, True, False, False],
- [False, True, True, False, False],
- [False, False, False, False, False],
- [False, False, False, False, False],
- ],
- ]
- ),
- 0.5,
- np.array([True, True, True]),
- DoesNotRaise(),
- ), # three masks with different category
- ],
-)
-def test_mask_non_max_suppression(
- predictions: np.ndarray,
- masks: np.ndarray,
- iou_threshold: float,
- expected_result: Optional[np.ndarray],
- exception: Exception,
-) -> None:
- with exception:
- result = mask_non_max_suppression(
- predictions=predictions, masks=masks, iou_threshold=iou_threshold
- )
- assert np.array_equal(result, expected_result)
-
-
@pytest.mark.parametrize(
"xyxy, resolution_wh, expected_result",
[