Skip to content

Commit

Permalink
Updated NonMaxSupression Op with score_threshold attribute
Browse files Browse the repository at this point in the history
Signed-off-by: Sai Chaitanya Gajula <[email protected]>
  • Loading branch information
quic-gsaichai authored and quic-akhobare committed Oct 30, 2023
1 parent 91da3da commit 4ca3836
Showing 1 changed file with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,10 @@ class NonMaxSuppression(torch.nn.Module):
"""
Implementation of NMS Op in the form of nn.Module
"""
def __init__(self, iou_threshold: float, max_output_boxes_per_class: int):
def __init__(self, iou_threshold: float, score_threshold: float, max_output_boxes_per_class: int):
super().__init__()
self.iou_threshold = iou_threshold
self.score_threshold = score_threshold
self.max_output_boxes_per_class = max_output_boxes_per_class

def forward(self, *args) -> torch.Tensor:
Expand All @@ -360,7 +361,11 @@ def forward(self, *args) -> torch.Tensor:
res = []
for index, (boxes, scores) in enumerate(zip(batches_boxes, batch_scores)):
for class_index, classes_score in enumerate(scores):
res_ = torchvision.ops.nms(boxes, classes_score, self.iou_threshold)
filtered_score_ind = (classes_score > self.score_threshold).nonzero()[:, 0]
boxes = boxes[filtered_score_ind, :]
classes_score = classes_score[filtered_score_ind]
temp_res = torchvision.ops.nms(boxes, classes_score, self.iou_threshold)
res_ = filtered_score_ind[temp_res]
for val in res_:
res.append([index, class_index, val.detach()])
res = res[:(self.max_output_boxes_per_class *(index+1))]
Expand Down

0 comments on commit 4ca3836

Please sign in to comment.