yolox.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. import torch
  2. import torch.nn as nn
  3. from .yolox_backbone import build_backbone
  4. from .yolox_pafpn import build_fpn
  5. from .yolox_head import build_head
  6. from utils.misc import multiclass_nms
  7. class YOLOX(nn.Module):
  8. def __init__(self,
  9. cfg,
  10. device,
  11. num_classes = 20,
  12. conf_thresh = 0.05,
  13. nms_thresh = 0.6,
  14. trainable = False,
  15. topk = 1000,
  16. deploy = False,
  17. nms_class_agnostic = False):
  18. super(YOLOX, self).__init__()
  19. # ---------------------- Basic Parameters ----------------------
  20. self.cfg = cfg
  21. self.device = device
  22. self.stride = cfg['stride']
  23. self.num_classes = num_classes
  24. self.trainable = trainable
  25. self.conf_thresh = conf_thresh
  26. self.nms_thresh = nms_thresh
  27. self.topk = topk
  28. self.deploy = deploy
  29. self.nms_class_agnostic = nms_class_agnostic
  30. # ------------------- Network Structure -------------------
  31. ## 主干网络
  32. self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
  33. ## 特征金字塔
  34. self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=round(256*cfg['width']))
  35. self.head_dim = self.fpn.out_dim
  36. ## 检测头
  37. self.non_shared_heads = nn.ModuleList(
  38. [build_head(cfg, head_dim, head_dim, num_classes)
  39. for head_dim in self.head_dim
  40. ])
  41. ## 预测层
  42. self.obj_preds = nn.ModuleList(
  43. [nn.Conv2d(head.reg_out_dim, 1, kernel_size=1)
  44. for head in self.non_shared_heads
  45. ])
  46. self.cls_preds = nn.ModuleList(
  47. [nn.Conv2d(head.cls_out_dim, self.num_classes, kernel_size=1)
  48. for head in self.non_shared_heads
  49. ])
  50. self.reg_preds = nn.ModuleList(
  51. [nn.Conv2d(head.reg_out_dim, 4, kernel_size=1)
  52. for head in self.non_shared_heads
  53. ])
  54. # ---------------------- Basic Functions ----------------------
  55. ## generate anchor points
  56. def generate_anchors(self, level, fmp_size):
  57. """
  58. fmp_size: (List) [H, W]
  59. """
  60. # generate grid cells
  61. fmp_h, fmp_w = fmp_size
  62. anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
  63. # [H, W, 2] -> [HW, 2]
  64. anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
  65. anchor_xy += 0.5 # add center offset
  66. anchor_xy *= self.stride[level]
  67. anchors = anchor_xy.to(self.device)
  68. return anchors
  69. ## post-process
  70. def post_process(self, obj_preds, cls_preds, box_preds):
  71. """
  72. Input:
  73. obj_preds: List(Tensor) [[H x W, 1], ...]
  74. cls_preds: List(Tensor) [[H x W, C], ...]
  75. box_preds: List(Tensor) [[H x W, 4], ...]
  76. anchors: List(Tensor) [[H x W, 2], ...]
  77. """
  78. all_scores = []
  79. all_labels = []
  80. all_bboxes = []
  81. for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
  82. # (H x W x KA x C,)
  83. scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
  84. # Keep top k top scoring indices only.
  85. num_topk = min(self.topk, box_pred_i.size(0))
  86. # torch.sort is actually faster than .topk (at least on GPUs)
  87. predicted_prob, topk_idxs = scores_i.sort(descending=True)
  88. topk_scores = predicted_prob[:num_topk]
  89. topk_idxs = topk_idxs[:num_topk]
  90. # filter out the proposals with low confidence score
  91. keep_idxs = topk_scores > self.conf_thresh
  92. scores = topk_scores[keep_idxs]
  93. topk_idxs = topk_idxs[keep_idxs]
  94. anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
  95. labels = topk_idxs % self.num_classes
  96. bboxes = box_pred_i[anchor_idxs]
  97. all_scores.append(scores)
  98. all_labels.append(labels)
  99. all_bboxes.append(bboxes)
  100. scores = torch.cat(all_scores)
  101. labels = torch.cat(all_labels)
  102. bboxes = torch.cat(all_bboxes)
  103. # to cpu & numpy
  104. scores = scores.cpu().numpy()
  105. labels = labels.cpu().numpy()
  106. bboxes = bboxes.cpu().numpy()
  107. # nms
  108. scores, labels, bboxes = multiclass_nms(
  109. scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
  110. return bboxes, scores, labels
  111. # ---------------------- Main Process for Inference ----------------------
  112. @torch.no_grad()
  113. def inference_single_image(self, x):
  114. # 主干网络
  115. pyramid_feats = self.backbone(x)
  116. # 特征金字塔
  117. pyramid_feats = self.fpn(pyramid_feats)
  118. # 检测头
  119. all_obj_preds = []
  120. all_cls_preds = []
  121. all_box_preds = []
  122. all_anchors = []
  123. for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
  124. cls_feat, reg_feat = head(feat)
  125. # [1, C, H, W]
  126. obj_pred = self.obj_preds[level](reg_feat)
  127. cls_pred = self.cls_preds[level](cls_feat)
  128. reg_pred = self.reg_preds[level](reg_feat)
  129. # anchors: [M, 2]
  130. fmp_size = cls_pred.shape[-2:]
  131. anchors = self.generate_anchors(level, fmp_size)
  132. # [1, C, H, W] -> [H, W, C] -> [M, C]
  133. obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
  134. cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
  135. reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
  136. # decode bbox
  137. ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
  138. wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
  139. pred_x1y1 = ctr_pred - wh_pred * 0.5
  140. pred_x2y2 = ctr_pred + wh_pred * 0.5
  141. box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
  142. all_obj_preds.append(obj_pred)
  143. all_cls_preds.append(cls_pred)
  144. all_box_preds.append(box_pred)
  145. all_anchors.append(anchors)
  146. if self.deploy:
  147. obj_preds = torch.cat(all_obj_preds, dim=0)
  148. cls_preds = torch.cat(all_cls_preds, dim=0)
  149. box_preds = torch.cat(all_box_preds, dim=0)
  150. scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
  151. bboxes = box_preds
  152. # [n_anchors_all, 4 + C]
  153. outputs = torch.cat([bboxes, scores], dim=-1)
  154. return outputs
  155. else:
  156. # post process
  157. bboxes, scores, labels = self.post_process(
  158. all_obj_preds, all_cls_preds, all_box_preds)
  159. return bboxes, scores, labels
  160. # ---------------------- Main Process for Training ----------------------
  161. def forward(self, x):
  162. if not self.trainable:
  163. return self.inference_single_image(x)
  164. else:
  165. # 主干网络
  166. pyramid_feats = self.backbone(x)
  167. # 特征金字塔
  168. pyramid_feats = self.fpn(pyramid_feats)
  169. # 检测头
  170. all_anchors = []
  171. all_strides = []
  172. all_obj_preds = []
  173. all_cls_preds = []
  174. all_box_preds = []
  175. all_reg_preds = []
  176. for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
  177. cls_feat, reg_feat = head(feat)
  178. # [B, C, H, W]
  179. obj_pred = self.obj_preds[level](reg_feat)
  180. cls_pred = self.cls_preds[level](cls_feat)
  181. reg_pred = self.reg_preds[level](reg_feat)
  182. B, _, H, W = cls_pred.size()
  183. fmp_size = [H, W]
  184. # generate anchor boxes: [M, 4]
  185. anchors = self.generate_anchors(level, fmp_size)
  186. # stride tensor: [M, 1]
  187. stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
  188. # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
  189. obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
  190. cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
  191. reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
  192. # decode bbox
  193. ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
  194. wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
  195. pred_x1y1 = ctr_pred - wh_pred * 0.5
  196. pred_x2y2 = ctr_pred + wh_pred * 0.5
  197. box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
  198. all_obj_preds.append(obj_pred)
  199. all_cls_preds.append(cls_pred)
  200. all_box_preds.append(box_pred)
  201. all_reg_preds.append(reg_pred)
  202. all_anchors.append(anchors)
  203. all_strides.append(stride_tensor)
  204. # output dict
  205. outputs = {"pred_obj": all_obj_preds, # List(Tensor) [B, M, 1]
  206. "pred_cls": all_cls_preds, # List(Tensor) [B, M, C]
  207. "pred_box": all_box_preds, # List(Tensor) [B, M, 4]
  208. "pred_reg": all_reg_preds, # List(Tensor) [B, M, 4]
  209. "anchors": all_anchors, # List(Tensor) [M, 2]
  210. "strides": self.stride, # List(Int) [8, 16, 32]
  211. "stride_tensors": all_strides # List(Tensor) [M, 1]
  212. }
  213. return outputs