From f7bd7011fedd9f2ec802997f20b6d061e595d710 Mon Sep 17 00:00:00 2001 From: Linas Kondrackis Date: Thu, 30 May 2024 13:45:12 +0300 Subject: [PATCH] OverlapFilter: Add and validate docs, rename --- docs/detection/double_detection_filter.md | 30 ++++++++++++++++ docs/detection/tools/inference_slicer.md | 4 --- docs/detection/utils.md | 18 ---------- mkdocs.yml | 1 + supervision/__init__.py | 4 +-- supervision/detection/core.py | 2 +- ...{overlap_handling.py => overlap_filter.py} | 16 ++++----- .../detection/tools/inference_slicer.py | 35 ++++++++----------- ...lap_handling.py => test_overlap_filter.py} | 2 +- 9 files changed, 57 insertions(+), 55 deletions(-) create mode 100644 docs/detection/double_detection_filter.md rename supervision/detection/{overlap_handling.py => overlap_filter.py} (94%) rename test/detection/{test_overlap_handling.py => test_overlap_filter.py} (99%) diff --git a/docs/detection/double_detection_filter.md b/docs/detection/double_detection_filter.md new file mode 100644 index 000000000..1631852f4 --- /dev/null +++ b/docs/detection/double_detection_filter.md @@ -0,0 +1,30 @@ +--- +comments: true +status: new +--- + +# Double Detection Filter + +
+

OverlapFilter

+
+ +:::supervision.detection.overlap_filter.OverlapFilter + +
+

box_non_max_suppression

+
+ +:::supervision.detection.overlap_filter.box_non_max_suppression + +
+

mask_non_max_suppression

+
+ +:::supervision.detection.overlap_filter.mask_non_max_suppression + +
+

box_non_max_merge

+
+ +:::supervision.detection.overlap_filter.box_non_max_merge diff --git a/docs/detection/tools/inference_slicer.md b/docs/detection/tools/inference_slicer.md index 51301e86f..5d5d08bc5 100644 --- a/docs/detection/tools/inference_slicer.md +++ b/docs/detection/tools/inference_slicer.md @@ -5,7 +5,3 @@ comments: true # InferenceSlicer :::supervision.detection.tools.inference_slicer.InferenceSlicer - -# Overlap Handling Strategy - -:::supervision.detection.overlap_handling.OverlapHandlingStrategy diff --git a/docs/detection/utils.md b/docs/detection/utils.md index dd14a23e2..369746a3e 100644 --- a/docs/detection/utils.md +++ b/docs/detection/utils.md @@ -17,24 +17,6 @@ status: new :::supervision.detection.utils.mask_iou_batch -
-

box_non_max_suppression

-
- -:::supervision.detection.overlap_handling.box_non_max_suppression - -
-

mask_non_max_suppression

-
- -:::supervision.detection.overlap_handling.mask_non_max_suppression - -
-

box_non_max_merge

-
- -:::supervision.detection.overlap_handling.box_non_max_merge -

polygon_to_mask

diff --git a/mkdocs.yml b/mkdocs.yml index f257238df..19d6a4fd4 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -48,6 +48,7 @@ nav: - Core: detection/core.md - Annotators: detection/annotators.md - Metrics: detection/metrics.md + - Double Detection Filter: detection/double_detection_filter.md - Utils: detection/utils.md - Keypoint Detection: - Core: keypoint/core.md diff --git a/supervision/__init__.py b/supervision/__init__.py index 83e677952..4f28d49f4 100644 --- a/supervision/__init__.py +++ b/supervision/__init__.py @@ -40,8 +40,8 @@ from supervision.detection.core import Detections from supervision.detection.line_zone import LineZone, LineZoneAnnotator from supervision.detection.lmm import LMM -from supervision.detection.overlap_handling import ( - OverlapHandlingStrategy, +from supervision.detection.overlap_filter import ( + OverlapFilter, box_non_max_merge, box_non_max_suppression, mask_non_max_suppression, diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 482e11093..c47bc8197 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -8,7 +8,7 @@ from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES from supervision.detection.lmm import LMM, from_paligemma, validate_lmm_and_kwargs -from supervision.detection.overlap_handling import ( +from supervision.detection.overlap_filter import ( box_non_max_merge, box_non_max_suppression, mask_non_max_suppression, diff --git a/supervision/detection/overlap_handling.py b/supervision/detection/overlap_filter.py similarity index 94% rename from supervision/detection/overlap_handling.py rename to supervision/detection/overlap_filter.py index f9acc4a6c..ab4408d18 100644 --- a/supervision/detection/overlap_handling.py +++ b/supervision/detection/overlap_filter.py @@ -230,7 +230,7 @@ def box_non_max_merge( return merge_groups -class OverlapHandlingStrategy(Enum): +class OverlapFilter(Enum): """ Enum specifying the strategy for filtering overlapping detections. @@ -239,11 +239,9 @@ class OverlapHandlingStrategy(Enum): NON_MAX_SUPPRESSION: Filter detections using non-max suppression. This means, detections that overlap by more than a set threshold will be discarded, except for the one with the highest confidence. - NON_MAX_MERGE: Merge detections with non-max-merging. This means, + NON_MAX_MERGE: Merge detections with non-max merging. This means, detections that overlap by more than a set threshold will be merged into a single detection. - - ![overlap-handling-strategies-example](https://media.roboflow.com/supervision-docs/overlap-handling-strategies-example.png) """ NONE = "none" @@ -251,15 +249,15 @@ class OverlapHandlingStrategy(Enum): NON_MAX_MERGE = "non_max_merge" -def validate_overlapping_handling_strategy( - strategy: Union[OverlapHandlingStrategy, str], -) -> OverlapHandlingStrategy: +def validate_overlap_filter( + strategy: Union[OverlapFilter, str], +) -> OverlapFilter: if isinstance(strategy, str): try: - strategy = OverlapHandlingStrategy(strategy.lower()) + strategy = OverlapFilter(strategy.lower()) except ValueError: raise ValueError( f"Invalid strategy value: {strategy}. Must be one of " - f"{[e.value for e in OverlapHandlingStrategy]}" + f"{[e.value for e in OverlapFilter]}" ) return strategy diff --git a/supervision/detection/tools/inference_slicer.py b/supervision/detection/tools/inference_slicer.py index 372352e6c..134361bd3 100644 --- a/supervision/detection/tools/inference_slicer.py +++ b/supervision/detection/tools/inference_slicer.py @@ -5,10 +5,7 @@ import numpy as np from supervision.detection.core import Detections -from supervision.detection.overlap_handling import ( - OverlapHandlingStrategy, - validate_overlapping_handling_strategy, -) +from supervision.detection.overlap_filter import OverlapFilter, validate_overlap_filter from supervision.detection.utils import move_boxes, move_masks from supervision.utils.image import crop_image from supervision.utils.internal import SupervisionWarnings @@ -56,7 +53,7 @@ class InferenceSlicer: `(width, height)`. overlap_ratio_wh (Tuple[float, float]): Overlap ratio between consecutive slices in the format `(width_ratio, height_ratio)`. - overlap_handling_strategy (Union[OverlapHandlingStrategy, str]): Strategy for + overlap_filter_strategy (Union[OverlapFilter, str]): Strategy for filtering or merging overlapping detections in slices. iou_threshold (float): Intersection over Union (IoU) threshold used when filtering by overlap. @@ -76,20 +73,18 @@ def __init__( callback: Callable[[np.ndarray], Detections], slice_wh: Tuple[int, int] = (320, 320), overlap_ratio_wh: Tuple[float, float] = (0.2, 0.2), - overlap_handling_strategy: Union[ - OverlapHandlingStrategy, str - ] = OverlapHandlingStrategy.NON_MAX_SUPPRESSION, + overlap_filter_strategy: Union[ + OverlapFilter, str + ] = OverlapFilter.NON_MAX_SUPPRESSION, iou_threshold: float = 0.5, thread_workers: int = 1, ): - overlap_handling_strategy = validate_overlapping_handling_strategy( - overlap_handling_strategy - ) + overlap_filter_strategy = validate_overlap_filter(overlap_filter_strategy) self.slice_wh = slice_wh self.overlap_ratio_wh = overlap_ratio_wh self.iou_threshold = iou_threshold - self.overlap_handling_strategy = overlap_handling_strategy + self.overlap_filter_strategy = overlap_filter_strategy self.callback = callback self.thread_workers = thread_workers @@ -120,7 +115,10 @@ def callback(image_slice: np.ndarray) -> sv.Detections: result = model(image_slice)[0] return sv.Detections.from_ultralytics(result) - slicer = sv.InferenceSlicer(callback = callback) + slicer = sv.InferenceSlicer( + callback=callback, + overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION, + ) detections = slicer(image) ``` @@ -141,18 +139,15 @@ def callback(image_slice: np.ndarray) -> sv.Detections: detections_list.append(future.result()) merged = Detections.merge(detections_list=detections_list) - if self.overlap_handling_strategy == OverlapHandlingStrategy.NONE: + if self.overlap_filter_strategy == OverlapFilter.NONE: return merged - elif ( - self.overlap_handling_strategy - == OverlapHandlingStrategy.NON_MAX_SUPPRESSION - ): + elif self.overlap_filter_strategy == OverlapFilter.NON_MAX_SUPPRESSION: return merged.with_nms(threshold=self.iou_threshold) - elif self.overlap_handling_strategy == OverlapHandlingStrategy.NON_MAX_MERGE: + elif self.overlap_filter_strategy == OverlapFilter.NON_MAX_MERGE: return merged.with_nmm(threshold=self.iou_threshold) else: warnings.warn( - f"Invalid overlap filter strategy: {self.overlap_handling_strategy}", + f"Invalid overlap filter strategy: {self.overlap_filter_strategy}", category=SupervisionWarnings, ) return merged diff --git a/test/detection/test_overlap_handling.py b/test/detection/test_overlap_filter.py similarity index 99% rename from test/detection/test_overlap_handling.py rename to test/detection/test_overlap_filter.py index 0186a23e2..f628c30f9 100644 --- a/test/detection/test_overlap_handling.py +++ b/test/detection/test_overlap_filter.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from supervision.detection.overlap_handling import ( +from supervision.detection.overlap_filter import ( box_non_max_suppression, group_overlapping_boxes, mask_non_max_suppression,