Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add minimum_consecutive_frames argument to ByteTrack #1050

Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 25 additions & 6 deletions supervision/tracker/byte_tracker/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
class STrack(BaseTrack):
shared_kalman = KalmanFilter()

def __init__(self, tlwh, score, class_ids):
def __init__(self, tlwh, score, class_ids, minimum_consecutive_frames):
# wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float32)
self.kalman_filter = None
Expand All @@ -23,6 +23,9 @@ def __init__(self, tlwh, score, class_ids):
self.score = score
self.class_ids = class_ids
self.tracklet_len = 0
self.track_id = -1

self.minimum_consecutive_frames = minimum_consecutive_frames

def predict(self):
mean_state = self.mean.copy()
Expand Down Expand Up @@ -53,7 +56,6 @@ def multi_predict(stracks):
def activate(self, kalman_filter, frame_id):
"""Start a new tracklet"""
self.kalman_filter = kalman_filter
self.track_id = self.next_id()
self.mean, self.covariance = self.kalman_filter.initiate(
self.tlwh_to_xyah(self._tlwh)
)
Expand All @@ -73,7 +75,7 @@ def re_activate(self, new_track, frame_id, new_id=False):
self.state = TrackState.Tracked
self.is_activated = True
self.frame_id = frame_id
if new_id:
if new_id and (self.minimum_consecutive_frames == 1):
self.track_id = self.next_id()
self.score = new_track.score

Expand All @@ -97,6 +99,9 @@ def update(self, new_track, frame_id):

self.score = new_track.score

if self.tracklet_len == self.minimum_consecutive_frames:
self.track_id = self.next_id()

@property
def tlwh(self):
"""Get current position in bounding box format `(top left x, top left y,
Expand Down Expand Up @@ -186,6 +191,10 @@ class ByteTrack:
Increasing minimum_matching_threshold improves accuracy but risks fragmentation.
Decreasing it improves completeness but risks false positives and drift.
frame_rate (int, optional): The frame rate of the video.
minimum_consecutive_frames (int, optional): Number of consecutive frames that an object must
be tracked before it is considered a 'valid' track.
Increasing minimum_consecutive_frames prevents the creation of accidental tracks from
false detection or double detection, but risks missing shorter tracks.
""" # noqa: E501 // docs

@deprecated_parameter(
Expand Down Expand Up @@ -218,19 +227,24 @@ def __init__(
lost_track_buffer: int = 30,
minimum_matching_threshold: float = 0.8,
frame_rate: int = 30,
minimum_consecutive_frames: int = 1,
):
self.track_activation_threshold = track_activation_threshold
self.minimum_matching_threshold = minimum_matching_threshold

self.frame_id = 0
self.det_thresh = self.track_activation_threshold + 0.1
self.max_time_lost = int(frame_rate / 30.0 * lost_track_buffer)
self.minimum_consecutive_frames = minimum_consecutive_frames

self.kalman_filter = KalmanFilter()

self.tracked_tracks: List[STrack] = []
self.lost_tracks: List[STrack] = []
self.removed_tracks: List[STrack] = []

BaseTrack.reset_counter()

def update_with_detections(self, detections: Detections) -> Detections:
"""
Updates the tracker with the provided detections and returns the updated
Expand Down Expand Up @@ -290,6 +304,7 @@ def callback(frame: np.ndarray, index: int) -> np.ndarray:
return detections[detections.tracker_id != -1]

else:
detections = Detections.empty()
detections.tracker_id = np.array([], dtype=int)

return detections
Expand Down Expand Up @@ -345,7 +360,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
if len(dets) > 0:
"""Detections"""
detections = [
STrack(STrack.tlbr_to_tlwh(tlbr), s, c)
STrack(STrack.tlbr_to_tlwh(tlbr), s, c, self.minimum_consecutive_frames)
for (tlbr, s, c) in zip(dets, scores_keep, class_ids_keep)
]
else:
Expand Down Expand Up @@ -387,7 +402,7 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
if len(dets_second) > 0:
"""Detections"""
detections_second = [
STrack(STrack.tlbr_to_tlwh(tlbr), s, c)
STrack(STrack.tlbr_to_tlwh(tlbr), s, c, self.minimum_consecutive_frames)
for (tlbr, s, c) in zip(dets_second, scores_second, class_ids_second)
]
else:
Expand Down Expand Up @@ -458,7 +473,11 @@ def update_with_tensors(self, tensors: np.ndarray) -> List[STrack]:
self.tracked_tracks, self.lost_tracks = remove_duplicate_tracks(
self.tracked_tracks, self.lost_tracks
)
output_stracks = [track for track in self.tracked_tracks if track.is_activated]
output_stracks = [
track
for track in self.tracked_tracks
if (track.is_activated and track.track_id >= 0)
]

return output_stracks

Expand Down