diff --git a/supervision/tracker/byte_tracker/core.py b/supervision/tracker/byte_tracker/core.py index c77878cf9..bbe92a9cd 100644 --- a/supervision/tracker/byte_tracker/core.py +++ b/supervision/tracker/byte_tracker/core.py @@ -13,7 +13,7 @@ class STrack(BaseTrack): shared_kalman = KalmanFilter() - def __init__(self, tlwh, score, class_ids): + def __init__(self, tlwh, score, class_ids, minimum_consecutive_frames): # wait activate self._tlwh = np.asarray(tlwh, dtype=np.float32) self.kalman_filter = None @@ -23,6 +23,9 @@ def __init__(self, tlwh, score, class_ids): self.score = score self.class_ids = class_ids self.tracklet_len = 0 + self.track_id = -1 + + self.minimum_consecutive_frames = minimum_consecutive_frames def predict(self): mean_state = self.mean.copy() @@ -53,7 +56,6 @@ def multi_predict(stracks): def activate(self, kalman_filter, frame_id): """Start a new tracklet""" self.kalman_filter = kalman_filter - self.track_id = self.next_id() self.mean, self.covariance = self.kalman_filter.initiate( self.tlwh_to_xyah(self._tlwh) ) @@ -73,7 +75,7 @@ def re_activate(self, new_track, frame_id, new_id=False): self.state = TrackState.Tracked self.is_activated = True self.frame_id = frame_id - if new_id: + if new_id and (self.minimum_consecutive_frames == 1): self.track_id = self.next_id() self.score = new_track.score @@ -97,6 +99,9 @@ def update(self, new_track, frame_id): self.score = new_track.score + if self.tracklet_len == self.minimum_consecutive_frames: + self.track_id = self.next_id() + @property def tlwh(self): """Get current position in bounding box format `(top left x, top left y, @@ -186,6 +191,10 @@ class ByteTrack: Increasing minimum_matching_threshold improves accuracy but risks fragmentation. Decreasing it improves completeness but risks false positives and drift. frame_rate (int, optional): The frame rate of the video. + minimum_consecutive_frames (int, optional): Number of consecutive frames that an object must + be tracked before it is considered a 'valid' track. + Increasing minimum_consecutive_frames prevents the creation of accidental tracks from + false detection or double detection, but risks missing shorter tracks. """ # noqa: E501 // docs @deprecated_parameter( @@ -218,6 +227,7 @@ def __init__( lost_track_buffer: int = 30, minimum_matching_threshold: float = 0.8, frame_rate: int = 30, + minimum_consecutive_frames: int = 1, ): self.track_activation_threshold = track_activation_threshold self.minimum_matching_threshold = minimum_matching_threshold @@ -225,12 +235,16 @@ def __init__( self.frame_id = 0 self.det_thresh = self.track_activation_threshold + 0.1 self.max_time_lost = int(frame_rate / 30.0 * lost_track_buffer) + self.minimum_consecutive_frames = minimum_consecutive_frames + self.kalman_filter = KalmanFilter() self.tracked_tracks: List[STrack] = [] self.lost_tracks: List[STrack] = [] self.removed_tracks: List[STrack] = [] + BaseTrack.reset_counter() + def update_with_detections(self, detections: Detections) -> Detections: """ Updates the tracker with the provided detections and returns the updated @@ -290,6 +304,7 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray: return detections[detections.tracker_id != -1] else: + detections = Detections.empty() detections.tracker_id = np.array([], dtype=int) return detections @@ -345,7 +360,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: if len(dets) > 0: """Detections""" detections = [ - STrack(STrack.tlbr_to_tlwh(tlbr), s, c) + STrack(STrack.tlbr_to_tlwh(tlbr), s, c, self.minimum_consecutive_frames) for (tlbr, s, c) in zip(dets, scores_keep, class_ids_keep) ] else: @@ -387,7 +402,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: if len(dets_second) > 0: """Detections""" detections_second = [ - STrack(STrack.tlbr_to_tlwh(tlbr), s, c) + STrack(STrack.tlbr_to_tlwh(tlbr), s, c, self.minimum_consecutive_frames) for (tlbr, s, c) in zip(dets_second, scores_second, class_ids_second) ] else: @@ -458,7 +473,11 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]: self.tracked_tracks, self.lost_tracks = remove_duplicate_tracks( self.tracked_tracks, self.lost_tracks ) - output_stracks = [track for track in self.tracked_tracks if track.is_activated] + output_stracks = [ + track + for track in self.tracked_tracks + if (track.is_activated and track.track_id >= 0) + ] return output_stracks