byte_tracker.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. import numpy as np
  2. import os
  3. import os.path as osp
  4. from .kalman_filter import KalmanFilter
  5. from .matching import iou_distance, fuse_score, linear_assignment
  6. from .basetrack import BaseTrack, TrackState
  7. class STrack(BaseTrack):
  8. shared_kalman = KalmanFilter()
  9. def __init__(self, xywh, score):
  10. # wait activate
  11. self._xywh = np.asarray(xywh, dtype=np.float)
  12. self.kalman_filter = None
  13. self.mean, self.covariance = None, None
  14. self.is_activated = False
  15. self.score = score
  16. self.tracklet_len = 0
  17. def predict(self):
  18. mean_state = self.mean.copy()
  19. if self.state != TrackState.Tracked:
  20. mean_state[7] = 0
  21. self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
  22. @staticmethod
  23. def multi_predict(stracks):
  24. if len(stracks) > 0:
  25. multi_mean = np.asarray([st.mean.copy() for st in stracks])
  26. multi_covariance = np.asarray([st.covariance for st in stracks])
  27. for i, st in enumerate(stracks):
  28. if st.state != TrackState.Tracked:
  29. multi_mean[i][7] = 0
  30. multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
  31. for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
  32. stracks[i].mean = mean
  33. stracks[i].covariance = cov
  34. def activate(self, kalman_filter, frame_id):
  35. """Start a new tracklet"""
  36. self.kalman_filter = kalman_filter
  37. self.track_id = self.next_id()
  38. self.mean, self.covariance = self.kalman_filter.initiate(self.xywh_to_cxcyah(self._xywh))
  39. self.tracklet_len = 0
  40. self.state = TrackState.Tracked
  41. if frame_id == 1:
  42. self.is_activated = True
  43. # self.is_activated = True
  44. self.frame_id = frame_id
  45. self.start_frame = frame_id
  46. def re_activate(self, new_track, frame_id, new_id=False):
  47. self.mean, self.covariance = self.kalman_filter.update(
  48. self.mean, self.covariance, self.xywh_to_cxcyah(new_track.xywh)
  49. )
  50. self.tracklet_len = 0
  51. self.state = TrackState.Tracked
  52. self.is_activated = True
  53. self.frame_id = frame_id
  54. if new_id:
  55. self.track_id = self.next_id()
  56. self.score = new_track.score
  57. def update(self, new_track, frame_id):
  58. """
  59. Update a matched track
  60. :type new_track: STrack
  61. :type frame_id: int
  62. :type update_feature: bool
  63. :return:
  64. """
  65. self.frame_id = frame_id
  66. self.tracklet_len += 1
  67. new_xywh = new_track.xywh
  68. self.mean, self.covariance = self.kalman_filter.update(
  69. self.mean, self.covariance, self.xywh_to_cxcyah(new_xywh))
  70. self.state = TrackState.Tracked
  71. self.is_activated = True
  72. self.score = new_track.score
  73. @property
  74. # @jit(nopython=True)
  75. def xywh(self):
  76. """Get current position in bounding box format `(top left x, top left y,
  77. width, height)`.
  78. """
  79. if self.mean is None:
  80. return self._xywh.copy()
  81. ret = self.mean[:4].copy()
  82. ret[2] *= ret[3]
  83. ret[:2] -= ret[2:] / 2
  84. return ret
  85. @property
  86. # @jit(nopython=True)
  87. def xyxy(self):
  88. """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
  89. `(top left, bottom right)`.
  90. """
  91. ret = self.xywh.copy()
  92. ret[2:] += ret[:2]
  93. return ret
  94. @staticmethod
  95. # @jit(nopython=True)
  96. def xywh_to_cxcyah(xywh):
  97. """[x1, y1, w, h] -> [cx, cy, aspect ratio, h],
  98. where the aspect ratio is `width / height`.
  99. """
  100. ret = np.asarray(xywh).copy()
  101. ret[:2] += ret[2:] / 2
  102. ret[2] /= ret[3]
  103. return ret
  104. @staticmethod
  105. # @jit(nopython=True)
  106. def xyxy_to_xywh(xyxy):
  107. """ [x1, y1, x2, y2] -> [x1, y1, w, h]"""
  108. ret = np.asarray(xyxy).copy()
  109. ret[2:] -= ret[:2]
  110. return ret
  111. @staticmethod
  112. # @jit(nopython=True)
  113. def xywh_to_xyxy(xywh):
  114. ret = np.asarray(xywh).copy()
  115. ret[2:] += ret[:2]
  116. return ret
  117. def to_cxcyah(self):
  118. return self.xywh_to_cxcyah(self.xywh)
  119. def __repr__(self):
  120. return 'OT_{}_({}-{})'.format(self.track_id, self.start_frame, self.end_frame)
  121. class ByteTracker(object):
  122. def __init__(self, track_thresh=0.6, track_buffer=30, frame_rate=30, match_thresh=0.9, mot20=False):
  123. self.tracked_stracks = [] # type: list[STrack]
  124. self.lost_stracks = [] # type: list[STrack]
  125. self.removed_stracks = [] # type: list[STrack]
  126. self.frame_id = 0
  127. self.track_thresh = track_thresh
  128. self.track_buffer = track_buffer
  129. self.det_thresh = track_thresh + 0.1
  130. self.match_thresh = match_thresh
  131. self.buffer_size = int(frame_rate / 30.0 * track_buffer)
  132. self.max_time_lost = self.buffer_size
  133. self.kalman_filter = KalmanFilter()
  134. self.mot20 = mot20
  135. def update(self, scores, bboxes, labels):
  136. self.frame_id += 1
  137. activated_starcks = []
  138. refind_stracks = []
  139. lost_stracks = []
  140. removed_stracks = []
  141. # process outputs
  142. remain_inds = scores > self.track_thresh
  143. inds_low = scores > 0.1
  144. inds_high = scores < self.track_thresh
  145. inds_second = np.logical_and(inds_low, inds_high)
  146. # high score detections
  147. dets = bboxes[remain_inds]
  148. scores_keep = scores[remain_inds]
  149. # second detections
  150. dets_second = bboxes[inds_second]
  151. scores_second = scores[inds_second]
  152. if len(dets) > 0:
  153. '''Detections'''
  154. detections = [STrack(STrack.xyxy_to_xywh(xyxy), s) for
  155. (xyxy, s) in zip(dets, scores_keep)]
  156. else:
  157. detections = []
  158. ''' Add newly detected tracklets to tracked_stracks'''
  159. unconfirmed = []
  160. tracked_stracks = [] # type: list[STrack]
  161. for track in self.tracked_stracks:
  162. if not track.is_activated:
  163. unconfirmed.append(track)
  164. else:
  165. tracked_stracks.append(track)
  166. ''' Step 2: First association, with high score detection boxes'''
  167. strack_pool = joint_stracks(tracked_stracks, self.lost_stracks)
  168. # Predict the current location with KF
  169. STrack.multi_predict(strack_pool)
  170. dists = iou_distance(strack_pool, detections)
  171. if not self.mot20:
  172. dists = fuse_score(dists, detections)
  173. matches, u_track, u_detection = linear_assignment(dists, thresh=self.match_thresh)
  174. for itracked, idet in matches:
  175. track = strack_pool[itracked]
  176. det = detections[idet]
  177. if track.state == TrackState.Tracked:
  178. track.update(detections[idet], self.frame_id)
  179. activated_starcks.append(track)
  180. else:
  181. track.re_activate(det, self.frame_id, new_id=False)
  182. refind_stracks.append(track)
  183. ''' Step 3: Second association, with low score detection boxes'''
  184. # association the untrack to the low score detections
  185. if len(dets_second) > 0:
  186. '''Detections'''
  187. detections_second = [STrack(STrack.xyxy_to_xywh(xyxy), s) for
  188. (xyxy, s) in zip(dets_second, scores_second)]
  189. else:
  190. detections_second = []
  191. r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
  192. dists = iou_distance(r_tracked_stracks, detections_second)
  193. matches, u_track, u_detection_second = linear_assignment(dists, thresh=0.5)
  194. for itracked, idet in matches:
  195. track = r_tracked_stracks[itracked]
  196. det = detections_second[idet]
  197. if track.state == TrackState.Tracked:
  198. track.update(det, self.frame_id)
  199. activated_starcks.append(track)
  200. else:
  201. track.re_activate(det, self.frame_id, new_id=False)
  202. refind_stracks.append(track)
  203. for it in u_track:
  204. track = r_tracked_stracks[it]
  205. if not track.state == TrackState.Lost:
  206. track.mark_lost()
  207. lost_stracks.append(track)
  208. '''Deal with unconfirmed tracks, usually tracks with only one beginning frame'''
  209. detections = [detections[i] for i in u_detection]
  210. dists = iou_distance(unconfirmed, detections)
  211. if not self.mot20:
  212. dists = fuse_score(dists, detections)
  213. matches, u_unconfirmed, u_detection = linear_assignment(dists, thresh=0.7)
  214. for itracked, idet in matches:
  215. unconfirmed[itracked].update(detections[idet], self.frame_id)
  216. activated_starcks.append(unconfirmed[itracked])
  217. for it in u_unconfirmed:
  218. track = unconfirmed[it]
  219. track.mark_removed()
  220. removed_stracks.append(track)
  221. """ Step 4: Init new stracks"""
  222. for inew in u_detection:
  223. track = detections[inew]
  224. if track.score < self.det_thresh:
  225. continue
  226. track.activate(self.kalman_filter, self.frame_id)
  227. activated_starcks.append(track)
  228. """ Step 5: Update state"""
  229. for track in self.lost_stracks:
  230. if self.frame_id - track.end_frame > self.max_time_lost:
  231. track.mark_removed()
  232. removed_stracks.append(track)
  233. self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
  234. self.tracked_stracks = joint_stracks(self.tracked_stracks, activated_starcks)
  235. self.tracked_stracks = joint_stracks(self.tracked_stracks, refind_stracks)
  236. self.lost_stracks = sub_stracks(self.lost_stracks, self.tracked_stracks)
  237. self.lost_stracks.extend(lost_stracks)
  238. self.lost_stracks = sub_stracks(self.lost_stracks, self.removed_stracks)
  239. self.removed_stracks.extend(removed_stracks)
  240. self.tracked_stracks, self.lost_stracks = remove_duplicate_stracks(
  241. self.tracked_stracks, self.lost_stracks)
  242. # get scores of lost tracks
  243. output_stracks = [track for track in self.tracked_stracks if track.is_activated]
  244. return output_stracks
  245. def joint_stracks(tlista, tlistb):
  246. exists = {}
  247. res = []
  248. for t in tlista:
  249. exists[t.track_id] = 1
  250. res.append(t)
  251. for t in tlistb:
  252. tid = t.track_id
  253. if not exists.get(tid, 0):
  254. exists[tid] = 1
  255. res.append(t)
  256. return res
  257. def sub_stracks(tlista, tlistb):
  258. stracks = {}
  259. for t in tlista:
  260. stracks[t.track_id] = t
  261. for t in tlistb:
  262. tid = t.track_id
  263. if stracks.get(tid, 0):
  264. del stracks[tid]
  265. return list(stracks.values())
  266. def remove_duplicate_stracks(stracksa, stracksb):
  267. pdist = iou_distance(stracksa, stracksb)
  268. pairs = np.where(pdist < 0.15)
  269. dupa, dupb = list(), list()
  270. for p, q in zip(*pairs):
  271. timep = stracksa[p].frame_id - stracksa[p].start_frame
  272. timeq = stracksb[q].frame_id - stracksb[q].start_frame
  273. if timep > timeq:
  274. dupb.append(q)
  275. else:
  276. dupa.append(p)
  277. resa = [t for i, t in enumerate(stracksa) if not i in dupa]
  278. resb = [t for i, t in enumerate(stracksb) if not i in dupb]
  279. return resa, resb