rtcdet.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. # Real-time Convolutional Object Detector
  2. # --------------- Torch components ---------------
  3. import torch
  4. import torch.nn as nn
  5. # --------------- Model components ---------------
  6. from .rtcdet_backbone import build_backbone
  7. from .rtcdet_neck import build_neck
  8. from .rtcdet_pafpn import build_fpn
  9. from .rtcdet_head import build_det_head, build_seg_head, build_pose_head
  10. from .rtcdet_pred import build_det_pred, build_seg_pred, build_pose_pred
  11. # --------------- External components ---------------
  12. from utils.misc import multiclass_nms
  13. # Real-time Convolutional General Object Detector
  class RTCDet(nn.Module):
      """Real-time Convolutional general object detector.

      Pipeline: backbone -> SPP neck (last feature level only) -> PaFPN ->
      detection head, plus optional segmentation / human-pose heads.

      Args:
          cfg: model config dict. Keys read directly here: 'stride',
              'det_head' (with 'reg_max'), 'bk_pretrained', and 'seg_head' /
              'pos_head' (each with 'name'). The dict is also forwarded to the
              build_* helpers.
          device: stored on ``self.device``; not read anywhere in this class.
          num_classes: number of object categories.
          conf_thresh: candidates must score strictly above this to survive.
          nms_thresh: threshold forwarded to ``multiclass_nms``.
          topk: max number of candidates kept per pyramid level.
          trainable: if True, ``forward`` returns raw head outputs; otherwise
              the detection outputs are post-processed.
          deploy: at inference, return one ``[n_anchors_all, 4 + C]`` tensor
              (boxes + class probabilities) instead of running NMS.
          no_multi_labels: keep only the best class per anchor.
          nms_class_agnostic: forwarded to ``multiclass_nms``.
      """

      def __init__(self,
                   cfg,
                   device,
                   num_classes=20,
                   conf_thresh=0.01,
                   nms_thresh=0.5,
                   topk=1000,
                   trainable=False,
                   deploy=False,
                   no_multi_labels=False,
                   nms_class_agnostic=False,
                   ):
          super().__init__()
          # ---------------- Basic settings ----------------
          ## Basic parameters
          self.cfg = cfg
          self.device = device
          self.deploy = deploy
          self.trainable = trainable
          self.num_classes = num_classes
          ## Network parameters
          self.strides = cfg['stride']
          self.reg_max = cfg['det_head']['reg_max']
          self.num_levels = len(self.strides)
          ## Post-process parameters
          self.nms_thresh = nms_thresh
          self.conf_thresh = conf_thresh
          self.topk_candidates = topk
          self.no_multi_labels = no_multi_labels
          self.nms_class_agnostic = nms_class_agnostic

          # ---------------- Network settings ----------------
          ## Backbone: pretrained weights are requested only when training.
          # Logical `and` (not bitwise `&`) so any truthy/falsy config value works.
          use_pretrained = bool(cfg['bk_pretrained'] and trainable)
          self.backbone, self.fpn_feat_dims = build_backbone(cfg, pretrained=use_pretrained)
          ## Neck: SPP, applied to the last backbone level only; it may change
          ## that level's channel count, so the recorded dim is updated.
          self.neck = build_neck(cfg, self.fpn_feat_dims[-1], self.fpn_feat_dims[-1])
          self.fpn_feat_dims[-1] = self.neck.out_dim
          ## Neck: FPN
          self.fpn = build_fpn(cfg, self.fpn_feat_dims)
          self.fpn_dims = self.fpn.out_dim
          self.cls_head_dim = max(self.fpn_dims[0], min(num_classes, 100))
          self.reg_head_dim = max(self.fpn_dims[0] // 4, 16, 4 * self.reg_max)
          ## Head: Detection (feature head followed by the prediction layer)
          self.det_head = nn.Sequential(
              build_det_head(cfg=cfg['det_head'],
                             in_dims=self.fpn_dims,
                             cls_head_dim=self.cls_head_dim,
                             reg_head_dim=self.reg_head_dim,
                             num_levels=self.num_levels
                             ),
              build_det_pred(cls_dim=self.cls_head_dim,
                             reg_dim=self.reg_head_dim,
                             strides=self.strides,
                             num_classes=num_classes,
                             num_coords=4,
                             reg_max=self.reg_max,
                             num_levels=self.num_levels
                             )
          )
          ## Head: Segmentation (optional; disabled when cfg name is None)
          self.seg_head = nn.Sequential(
              build_seg_head(cfg['seg_head']),
              build_seg_pred()
          ) if cfg['seg_head']['name'] is not None else None
          ## Head: Human-pose (optional; disabled when cfg name is None)
          self.pos_head = nn.Sequential(
              build_pose_head(cfg['pos_head']),
              build_pose_pred()
          ) if cfg['pos_head']['name'] is not None else None

      # Post process
      def _select_candidates(self, cls_pred, box_pred):
          """Top-k + confidence filtering for one pyramid level of one image.

          Input:
              cls_pred: torch.Tensor -> [M, C] class logits
              box_pred: torch.Tensor -> [M, 4] boxes (used as-is, no decoding)
          Output:
              scores: torch.Tensor -> [K,]
              labels: torch.Tensor -> [K,]
              bboxes: torch.Tensor -> [K, 4]
          where K <= min(self.topk_candidates, M).
          """
          num_topk = min(self.topk_candidates, box_pred.size(0))
          if self.no_multi_labels:
              # One candidate per anchor: its best-scoring class. [M,]
              scores, best_labels = torch.max(cls_pred.sigmoid(), dim=1)
          else:
              # Every (anchor, class) pair is a candidate. [M, C] -> [MC,]
              scores = cls_pred.sigmoid().flatten()
          # torch.sort is actually faster than .topk (at least on GPUs)
          sorted_scores, sorted_idxs = scores.sort(descending=True)
          topk_scores = sorted_scores[:num_topk]
          topk_idxs = sorted_idxs[:num_topk]
          # filter out the proposals with low confidence score
          keep = topk_scores > self.conf_thresh
          topk_scores = topk_scores[keep]
          topk_idxs = topk_idxs[keep]
          if self.no_multi_labels:
              anchor_idxs = topk_idxs
              labels = best_labels[topk_idxs]
          else:
              # Decode flat indices with the tensor's own class count.
              num_classes = cls_pred.size(1)
              anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode='floor')
              labels = topk_idxs % num_classes
          return topk_scores, labels, box_pred[anchor_idxs]

      def post_process(self, cls_preds, box_preds):
          """
          Input:
              cls_preds: List[torch.Tensor] -> [[B, M, C], ...], one per level
              box_preds: List[torch.Tensor] -> [[B, M, 4], ...], one per level
          Output:
              bboxes: np.array -> [N, 4]
              scores: np.array -> [N,]
              labels: np.array -> [N,]
          Only the first image of the batch (index 0) is processed.
          """
          assert len(cls_preds) == self.num_levels and len(box_preds) == self.num_levels
          all_scores, all_labels, all_bboxes = [], [], []
          for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
              scores, labels, bboxes = self._select_candidates(cls_pred_i[0], box_pred_i[0])
              all_scores.append(scores)
              all_labels.append(labels)
              all_bboxes.append(bboxes)
          # to cpu & numpy
          scores = torch.cat(all_scores, dim=0).cpu().numpy()
          labels = torch.cat(all_labels, dim=0).cpu().numpy()
          bboxes = torch.cat(all_bboxes, dim=0).cpu().numpy()
          # nms
          scores, labels, bboxes = multiclass_nms(
              scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
          return bboxes, scores, labels

      # Main process
      def forward(self, x):
          """Run the full network.

          Training (``self.trainable``): returns a dict with keys
          'det_outputs', 'seg_outputs', 'pos_outputs' (the latter two may be
          None). Inference: returns the detection output; when it is a dict,
          any seg/pose outputs are merged into it with ``dict.update``. In
          deploy mode the detection output is a single tensor and is returned
          unchanged.
          """
          # ---------------- Backbone ----------------
          pyramid_feats = self.backbone(x)
          # ---------------- Neck: SPP ----------------
          pyramid_feats[-1] = self.neck(pyramid_feats[-1])
          # ---------------- Neck: PaFPN ----------------
          pyramid_feats = self.fpn(pyramid_feats)
          # ---------------- Head ----------------
          det_outputs = self.forward_det_head(pyramid_feats)
          seg_outputs = self.forward_seg_head(pyramid_feats)
          pos_outputs = self.forward_pos_head(pyramid_feats)

          if self.trainable:
              return {
                  'det_outputs': det_outputs,
                  'seg_outputs': seg_outputs,
                  'pos_outputs': pos_outputs
              }
          # A deploy-mode tensor cannot absorb dict outputs; only merge into dicts.
          if isinstance(det_outputs, dict):
              for extra_outputs in (seg_outputs, pos_outputs):
                  if extra_outputs is not None:
                      det_outputs.update(extra_outputs)
          return det_outputs

      def forward_det_head(self, x):
          """Run the detection head and, at inference, post-process it.

          Training: returns the raw head output dict.
          Inference + deploy: returns a ``[n_anchors_all, 4 + C]`` tensor of
          boxes concatenated with sigmoid class scores (batch index 0 only).
          Inference otherwise: returns {"scores", "labels", "bboxes"} after NMS.
          """
          # ---------------- Heads ----------------
          outputs = self.det_head(x)
          # ---------------- Post-process ----------------
          if not self.trainable:
              all_cls_preds = outputs['pred_cls']
              all_box_preds = outputs['pred_box']
              if self.deploy:
                  # Concatenate all levels along the anchor dim, take image 0.
                  cls_preds = torch.cat(all_cls_preds, dim=1)[0]
                  box_preds = torch.cat(all_box_preds, dim=1)[0]
                  scores = cls_preds.sigmoid()
                  bboxes = box_preds
                  # [n_anchors_all, 4 + C]
                  outputs = torch.cat([bboxes, scores], dim=-1)
              else:
                  # post process
                  bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
                  outputs = {
                      "scores": scores,
                      "labels": labels,
                      "bboxes": bboxes
                  }
          return outputs

      def forward_seg_head(self, x):
          """Run the optional segmentation head; None when it is not configured.

          NOTE(review): at inference, forward() merges this result with
          dict.update, so the head is presumably expected to return a dict.
          """
          if self.seg_head is None:
              return None
          return self.seg_head(x)

      def forward_pos_head(self, x):
          """Run the optional human-pose head; None when it is not configured.

          NOTE(review): at inference, forward() merges this result with
          dict.update, so the head is presumably expected to return a dict.
          """
          if self.pos_head is None:
              return None
          return self.pos_head(x)