Skip to content

Commit

Permalink
OverlapFilter: Add and validate docs, rename
Browse files Browse the repository at this point in the history
  • Loading branch information
Linas Kondrackis committed May 30, 2024
1 parent 34353aa commit f7bd701
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 55 deletions.
30 changes: 30 additions & 0 deletions docs/detection/double_detection_filter.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
comments: true
status: new
---

# Double Detection Filter

<div class="md-typeset">
<h2><a href="#supervision.detection.overlap_filter.OverlapFilter">OverlapFilter</a></h2>
</div>

:::supervision.detection.overlap_filter.OverlapFilter

<div class="md-typeset">
<h2><a href="#supervision.detection.overlap_filter.box_non_max_suppression">box_non_max_suppression</a></h2>
</div>

:::supervision.detection.overlap_filter.box_non_max_suppression

<div class="md-typeset">
<h2><a href="#supervision.detection.overlap_filter.mask_non_max_suppression">mask_non_max_suppression</a></h2>
</div>

:::supervision.detection.overlap_filter.mask_non_max_suppression

<div class="md-typeset">
<h2><a href="#supervision.detection.overlap_filter.box_non_max_merge">box_non_max_merge</a></h2>
</div>

:::supervision.detection.overlap_filter.box_non_max_merge
4 changes: 0 additions & 4 deletions docs/detection/tools/inference_slicer.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,3 @@ comments: true
# InferenceSlicer

:::supervision.detection.tools.inference_slicer.InferenceSlicer

# Overlap Handling Strategy

:::supervision.detection.overlap_handling.OverlapHandlingStrategy
18 changes: 0 additions & 18 deletions docs/detection/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,6 @@ status: new

:::supervision.detection.utils.mask_iou_batch

<div class="md-typeset">
<h2><a href="#supervision.detection.overlap_handling.box_non_max_suppression">box_non_max_suppression</a></h2>
</div>

:::supervision.detection.overlap_handling.box_non_max_suppression

<div class="md-typeset">
<h2><a href="#supervision.detection.overlap_handling.mask_non_max_suppression">mask_non_max_suppression</a></h2>
</div>

:::supervision.detection.overlap_handling.mask_non_max_suppression

<div class="md-typeset">
<h2><a href="#supervision.detection.overlap_handling.box_non_max_merge">box_non_max_merge</a></h2>
</div>

:::supervision.detection.overlap_handling.box_non_max_merge

<div class="md-typeset">
<h2><a href="#supervision.detection.utils.polygon_to_mask">polygon_to_mask</a></h2>
</div>
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ nav:
- Core: detection/core.md
- Annotators: detection/annotators.md
- Metrics: detection/metrics.md
- Double Detection Filter: detection/double_detection_filter.md
- Utils: detection/utils.md
- Keypoint Detection:
- Core: keypoint/core.md
Expand Down
4 changes: 2 additions & 2 deletions supervision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
from supervision.detection.core import Detections
from supervision.detection.line_zone import LineZone, LineZoneAnnotator
from supervision.detection.lmm import LMM
from supervision.detection.overlap_handling import (
OverlapHandlingStrategy,
from supervision.detection.overlap_filter import (
OverlapFilter,
box_non_max_merge,
box_non_max_suppression,
mask_non_max_suppression,
Expand Down
2 changes: 1 addition & 1 deletion supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from supervision.config import CLASS_NAME_DATA_FIELD, ORIENTED_BOX_COORDINATES
from supervision.detection.lmm import LMM, from_paligemma, validate_lmm_and_kwargs
from supervision.detection.overlap_handling import (
from supervision.detection.overlap_filter import (
box_non_max_merge,
box_non_max_suppression,
mask_non_max_suppression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def box_non_max_merge(
return merge_groups


class OverlapHandlingStrategy(Enum):
class OverlapFilter(Enum):
"""
Enum specifying the strategy for filtering overlapping detections.
Expand All @@ -239,27 +239,25 @@ class OverlapHandlingStrategy(Enum):
NON_MAX_SUPPRESSION: Filter detections using non-max suppression. This means,
detections that overlap by more than a set threshold will be discarded,
except for the one with the highest confidence.
NON_MAX_MERGE: Merge detections with non-max-merging. This means,
NON_MAX_MERGE: Merge detections with non-max merging. This means,
detections that overlap by more than a set threshold will be merged
into a single detection.
![overlap-handling-strategies-example](https://media.roboflow.com/supervision-docs/overlap-handling-strategies-example.png)
"""

NONE = "none"
NON_MAX_SUPPRESSION = "non_max_suppression"
NON_MAX_MERGE = "non_max_merge"


def validate_overlapping_handling_strategy(
strategy: Union[OverlapHandlingStrategy, str],
) -> OverlapHandlingStrategy:
def validate_overlap_filter(
strategy: Union[OverlapFilter, str],
) -> OverlapFilter:
if isinstance(strategy, str):
try:
strategy = OverlapHandlingStrategy(strategy.lower())
strategy = OverlapFilter(strategy.lower())
except ValueError:
raise ValueError(
f"Invalid strategy value: {strategy}. Must be one of "
f"{[e.value for e in OverlapHandlingStrategy]}"
f"{[e.value for e in OverlapFilter]}"
)
return strategy
35 changes: 15 additions & 20 deletions supervision/detection/tools/inference_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
import numpy as np

from supervision.detection.core import Detections
from supervision.detection.overlap_handling import (
OverlapHandlingStrategy,
validate_overlapping_handling_strategy,
)
from supervision.detection.overlap_filter import OverlapFilter, validate_overlap_filter
from supervision.detection.utils import move_boxes, move_masks
from supervision.utils.image import crop_image
from supervision.utils.internal import SupervisionWarnings
Expand Down Expand Up @@ -56,7 +53,7 @@ class InferenceSlicer:
`(width, height)`.
overlap_ratio_wh (Tuple[float, float]): Overlap ratio between consecutive
slices in the format `(width_ratio, height_ratio)`.
overlap_handling_strategy (Union[OverlapHandlingStrategy, str]): Strategy for
overlap_filter_strategy (Union[OverlapFilter, str]): Strategy for
filtering or merging overlapping detections in slices.
iou_threshold (float): Intersection over Union (IoU) threshold
used when filtering by overlap.
Expand All @@ -76,20 +73,18 @@ def __init__(
callback: Callable[[np.ndarray], Detections],
slice_wh: Tuple[int, int] = (320, 320),
overlap_ratio_wh: Tuple[float, float] = (0.2, 0.2),
overlap_handling_strategy: Union[
OverlapHandlingStrategy, str
] = OverlapHandlingStrategy.NON_MAX_SUPPRESSION,
overlap_filter_strategy: Union[
OverlapFilter, str
] = OverlapFilter.NON_MAX_SUPPRESSION,
iou_threshold: float = 0.5,
thread_workers: int = 1,
):
overlap_handling_strategy = validate_overlapping_handling_strategy(
overlap_handling_strategy
)
overlap_filter_strategy = validate_overlap_filter(overlap_filter_strategy)

self.slice_wh = slice_wh
self.overlap_ratio_wh = overlap_ratio_wh
self.iou_threshold = iou_threshold
self.overlap_handling_strategy = overlap_handling_strategy
self.overlap_filter_strategy = overlap_filter_strategy
self.callback = callback
self.thread_workers = thread_workers

Expand Down Expand Up @@ -120,7 +115,10 @@ def callback(image_slice: np.ndarray) -> sv.Detections:
result = model(image_slice)[0]
return sv.Detections.from_ultralytics(result)
slicer = sv.InferenceSlicer(callback = callback)
slicer = sv.InferenceSlicer(
callback=callback,
overlap_filter_strategy=sv.OverlapFilter.NON_MAX_SUPPRESSION,
)
detections = slicer(image)
```
Expand All @@ -141,18 +139,15 @@ def callback(image_slice: np.ndarray) -> sv.Detections:
detections_list.append(future.result())

merged = Detections.merge(detections_list=detections_list)
if self.overlap_handling_strategy == OverlapHandlingStrategy.NONE:
if self.overlap_filter_strategy == OverlapFilter.NONE:
return merged
elif (
self.overlap_handling_strategy
== OverlapHandlingStrategy.NON_MAX_SUPPRESSION
):
elif self.overlap_filter_strategy == OverlapFilter.NON_MAX_SUPPRESSION:
return merged.with_nms(threshold=self.iou_threshold)
elif self.overlap_handling_strategy == OverlapHandlingStrategy.NON_MAX_MERGE:
elif self.overlap_filter_strategy == OverlapFilter.NON_MAX_MERGE:
return merged.with_nmm(threshold=self.iou_threshold)
else:
warnings.warn(
f"Invalid overlap filter strategy: {self.overlap_handling_strategy}",
f"Invalid overlap filter strategy: {self.overlap_filter_strategy}",
category=SupervisionWarnings,
)
return merged
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pytest

from supervision.detection.overlap_handling import (
from supervision.detection.overlap_filter import (
box_non_max_suppression,
group_overlapping_boxes,
mask_non_max_suppression,
Expand Down

0 comments on commit f7bd701

Please sign in to comment.