# yolox.py (11 KB)
  1. import torch
  2. import torch.nn as nn
  3. from .yolox_backbone import build_backbone
  4. from .yolox_pafpn import build_fpn
  5. from .yolox_head import build_head
  6. from utils.misc import multiclass_nms
  7. class YOLOX(nn.Module):
  8. def __init__(self,
  9. cfg,
  10. device,
  11. num_classes = 20,
  12. conf_thresh = 0.05,
  13. nms_thresh = 0.6,
  14. trainable = False,
  15. topk = 1000,
  16. deploy = False,
  17. no_multi_labels = False,
  18. nms_class_agnostic = False):
  19. super(YOLOX, self).__init__()
  20. # ---------------------- Basic Parameters ----------------------
  21. self.cfg = cfg
  22. self.device = device
  23. self.stride = cfg['stride']
  24. self.num_classes = num_classes
  25. self.trainable = trainable
  26. self.conf_thresh = conf_thresh
  27. self.nms_thresh = nms_thresh
  28. self.num_levels = 3
  29. self.topk_candidates = topk
  30. self.deploy = deploy
  31. self.no_multi_labels = no_multi_labels
  32. self.nms_class_agnostic = nms_class_agnostic
  33. # ------------------- Network Structure -------------------
  34. ## 主干网络
  35. self.backbone, feats_dim = build_backbone(cfg)
  36. ## 特征金字塔
  37. self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=round(256*cfg['width']))
  38. self.head_dim = self.fpn.out_dim
  39. ## 检测头
  40. self.non_shared_heads = nn.ModuleList(
  41. [build_head(cfg, head_dim, head_dim, num_classes)
  42. for head_dim in self.head_dim
  43. ])
  44. ## 预测层
  45. self.obj_preds = nn.ModuleList(
  46. [nn.Conv2d(head.reg_out_dim, 1, kernel_size=1)
  47. for head in self.non_shared_heads
  48. ])
  49. self.cls_preds = nn.ModuleList(
  50. [nn.Conv2d(head.cls_out_dim, self.num_classes, kernel_size=1)
  51. for head in self.non_shared_heads
  52. ])
  53. self.reg_preds = nn.ModuleList(
  54. [nn.Conv2d(head.reg_out_dim, 4, kernel_size=1)
  55. for head in self.non_shared_heads
  56. ])
  57. # ---------------------- Basic Functions ----------------------
  58. ## generate anchor points
  59. def generate_anchors(self, level, fmp_size):
  60. """
  61. fmp_size: (List) [H, W]
  62. """
  63. # generate grid cells
  64. fmp_h, fmp_w = fmp_size
  65. anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
  66. # [H, W, 2] -> [HW, 2]
  67. anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
  68. anchor_xy += 0.5 # add center offset
  69. anchor_xy *= self.stride[level]
  70. anchors = anchor_xy.to(self.device)
  71. return anchors
  72. ## post-process
  73. def post_process(self, obj_preds, cls_preds, box_preds):
  74. """
  75. Input:
  76. cls_preds: List[np.array] -> [[M, C], ...]
  77. box_preds: List[np.array] -> [[M, 4], ...]
  78. obj_preds: List[np.array] -> [[M, 1], ...] or None
  79. Output:
  80. bboxes: np.array -> [N, 4]
  81. scores: np.array -> [N,]
  82. labels: np.array -> [N,]
  83. """
  84. assert len(cls_preds) == self.num_levels
  85. all_scores = []
  86. all_labels = []
  87. all_bboxes = []
  88. for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
  89. if self.no_multi_labels:
  90. # [M,]
  91. scores, labels = torch.max(torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()), dim=1)
  92. # Keep top k top scoring indices only.
  93. num_topk = min(self.topk_candidates, box_pred_i.size(0))
  94. # topk candidates
  95. predicted_prob, topk_idxs = scores.sort(descending=True)
  96. topk_scores = predicted_prob[:num_topk]
  97. topk_idxs = topk_idxs[:num_topk]
  98. # filter out the proposals with low confidence score
  99. keep_idxs = topk_scores > self.conf_thresh
  100. scores = topk_scores[keep_idxs]
  101. topk_idxs = topk_idxs[keep_idxs]
  102. labels = labels[topk_idxs]
  103. bboxes = box_pred_i[topk_idxs]
  104. else:
  105. # [M, C] -> [MC,]
  106. scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
  107. # Keep top k top scoring indices only.
  108. num_topk = min(self.topk_candidates, box_pred_i.size(0))
  109. # torch.sort is actually faster than .topk (at least on GPUs)
  110. predicted_prob, topk_idxs = scores_i.sort(descending=True)
  111. topk_scores = predicted_prob[:num_topk]
  112. topk_idxs = topk_idxs[:num_topk]
  113. # filter out the proposals with low confidence score
  114. keep_idxs = topk_scores > self.conf_thresh
  115. scores = topk_scores[keep_idxs]
  116. topk_idxs = topk_idxs[keep_idxs]
  117. anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
  118. labels = topk_idxs % self.num_classes
  119. bboxes = box_pred_i[anchor_idxs]
  120. all_scores.append(scores)
  121. all_labels.append(labels)
  122. all_bboxes.append(bboxes)
  123. scores = torch.cat(all_scores)
  124. labels = torch.cat(all_labels)
  125. bboxes = torch.cat(all_bboxes)
  126. # to cpu & numpy
  127. scores = scores.cpu().numpy()
  128. labels = labels.cpu().numpy()
  129. bboxes = bboxes.cpu().numpy()
  130. # nms
  131. scores, labels, bboxes = multiclass_nms(
  132. scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
  133. return bboxes, scores, labels
  134. # ---------------------- Main Process for Inference ----------------------
  135. @torch.no_grad()
  136. def inference_single_image(self, x):
  137. # 主干网络
  138. pyramid_feats = self.backbone(x)
  139. # 特征金字塔
  140. pyramid_feats = self.fpn(pyramid_feats)
  141. # 检测头
  142. all_obj_preds = []
  143. all_cls_preds = []
  144. all_box_preds = []
  145. all_anchors = []
  146. for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
  147. cls_feat, reg_feat = head(feat)
  148. # [1, C, H, W]
  149. obj_pred = self.obj_preds[level](reg_feat)
  150. cls_pred = self.cls_preds[level](cls_feat)
  151. reg_pred = self.reg_preds[level](reg_feat)
  152. # anchors: [M, 2]
  153. fmp_size = cls_pred.shape[-2:]
  154. anchors = self.generate_anchors(level, fmp_size)
  155. # [1, C, H, W] -> [H, W, C] -> [M, C]
  156. obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
  157. cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
  158. reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
  159. # decode bbox
  160. ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
  161. wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
  162. pred_x1y1 = ctr_pred - wh_pred * 0.5
  163. pred_x2y2 = ctr_pred + wh_pred * 0.5
  164. box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
  165. all_obj_preds.append(obj_pred)
  166. all_cls_preds.append(cls_pred)
  167. all_box_preds.append(box_pred)
  168. all_anchors.append(anchors)
  169. if self.deploy:
  170. obj_preds = torch.cat(all_obj_preds, dim=0)
  171. cls_preds = torch.cat(all_cls_preds, dim=0)
  172. box_preds = torch.cat(all_box_preds, dim=0)
  173. scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
  174. bboxes = box_preds
  175. # [n_anchors_all, 4 + C]
  176. outputs = torch.cat([bboxes, scores], dim=-1)
  177. else:
  178. # post process
  179. bboxes, scores, labels = self.post_process(
  180. all_obj_preds, all_cls_preds, all_box_preds)
  181. outputs = {
  182. "scores": scores,
  183. "labels": labels,
  184. "bboxes": bboxes
  185. }
  186. return outputs
  187. # ---------------------- Main Process for Training ----------------------
  188. def forward(self, x):
  189. if not self.trainable:
  190. return self.inference_single_image(x)
  191. else:
  192. # 主干网络
  193. pyramid_feats = self.backbone(x)
  194. # 特征金字塔
  195. pyramid_feats = self.fpn(pyramid_feats)
  196. # 检测头
  197. all_anchors = []
  198. all_strides = []
  199. all_obj_preds = []
  200. all_cls_preds = []
  201. all_box_preds = []
  202. all_reg_preds = []
  203. for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
  204. cls_feat, reg_feat = head(feat)
  205. # [B, C, H, W]
  206. obj_pred = self.obj_preds[level](reg_feat)
  207. cls_pred = self.cls_preds[level](cls_feat)
  208. reg_pred = self.reg_preds[level](reg_feat)
  209. B, _, H, W = cls_pred.size()
  210. fmp_size = [H, W]
  211. # generate anchor boxes: [M, 4]
  212. anchors = self.generate_anchors(level, fmp_size)
  213. # stride tensor: [M, 1]
  214. stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
  215. # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
  216. obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
  217. cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
  218. reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
  219. # decode bbox
  220. ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
  221. wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
  222. pred_x1y1 = ctr_pred - wh_pred * 0.5
  223. pred_x2y2 = ctr_pred + wh_pred * 0.5
  224. box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
  225. all_obj_preds.append(obj_pred)
  226. all_cls_preds.append(cls_pred)
  227. all_box_preds.append(box_pred)
  228. all_reg_preds.append(reg_pred)
  229. all_anchors.append(anchors)
  230. all_strides.append(stride_tensor)
  231. # output dict
  232. outputs = {"pred_obj": all_obj_preds, # List(Tensor) [B, M, 1]
  233. "pred_cls": all_cls_preds, # List(Tensor) [B, M, C]
  234. "pred_box": all_box_preds, # List(Tensor) [B, M, 4]
  235. "pred_reg": all_reg_preds, # List(Tensor) [B, M, 4]
  236. "anchors": all_anchors, # List(Tensor) [M, 2]
  237. "strides": self.stride, # List(Int) [8, 16, 32]
  238. "stride_tensors": all_strides # List(Tensor) [M, 1]
  239. }
  240. return outputs