fcos.py 8.0 KB


  1. import torch
  2. import torch.nn as nn
  3. # --------------- Model components ---------------
  4. from ...backbone import build_backbone
  5. from ...neck import build_neck
  6. from ...head import build_head
  7. # --------------- External components ---------------
  8. from utils.misc import multiclass_nms
  9. # ------------------------ Fully Convolutional One-Stage Detector ------------------------
  class FCOS(nn.Module):
      """Fully Convolutional One-Stage (FCOS) detector.

      Composes a backbone, a neck (``self.fpn``) and a detection head, all
      built from ``cfg``. In training mode ``forward`` returns the raw head
      outputs; in eval mode it decodes, filters and NMS-es the predictions
      and returns boxes normalized by the input size.

      Args:
          cfg: model config; must expose ``head_dim`` (read here) plus whatever
              ``build_backbone`` / ``build_neck`` / ``build_head`` need.
          num_classes: number of object classes ``C``.
          conf_thresh: candidates whose score is not strictly greater than
              this are discarded.
          nms_thresh: threshold forwarded to ``multiclass_nms``.
          topk: maximum number of candidates kept per pyramid level before NMS.
          ca_nms: flag forwarded to ``multiclass_nms`` (presumably selects
              class-agnostic NMS -- confirm against ``utils.misc``).
      """

      def __init__(self,
                   cfg,
                   num_classes: int = 80,
                   conf_thresh: float = 0.05,
                   nms_thresh: float = 0.6,
                   topk: int = 1000,
                   ca_nms: bool = False):
          super().__init__()
          # ---------------------- Basic Parameters ----------------------
          self.cfg = cfg
          self.topk = topk
          self.num_classes = num_classes
          self.conf_thresh = conf_thresh
          self.nms_thresh = nms_thresh
          self.ca_nms = ca_nms

          # ---------------------- Network Parameters ----------------------
          ## Backbone: also returns the channel dims of its output features.
          self.backbone, feat_dims = build_backbone(cfg)
          ## Neck: takes the backbone dims, outputs ``cfg.head_dim`` channels
          ## (presumably per pyramid level -- confirm in build_neck).
          self.fpn = build_neck(cfg, feat_dims, cfg.head_dim)
          ## Heads
          self.head = build_head(cfg, cfg.head_dim, cfg.head_dim)

      @torch.no_grad()
      def post_process(self, cls_preds, ctn_preds, box_preds):
          """Decode multi-level predictions into final detections.

          Only the first image of the batch (index ``0``) is decoded, so this
          assumes inference with batch size 1.

          Runs without autograd: the results are converted to NumPy below, and
          ``Tensor.numpy()`` raises on tensors that require grad, which is what
          the head produces when ``forward`` is called outside
          ``torch.no_grad()``.

          Input:
              cls_preds: List(Tensor) [[B, H x W, C], ...] class logits per level
              ctn_preds: List(Tensor) [[B, H x W, 1], ...] centerness logits per level
              box_preds: List(Tensor) [[B, H x W, 4], ...] boxes per level

          Returns:
              (bboxes, scores, labels): the output of ``multiclass_nms`` on the
              concatenated NumPy candidates, reordered to this tuple.
          """
          all_scores = []
          all_labels = []
          all_bboxes = []
          for cls_pred_i, ctn_pred_i, box_pred_i in zip(cls_preds, ctn_preds, box_preds):
              cls_pred_i = cls_pred_i[0]
              ctn_pred_i = ctn_pred_i[0]
              box_pred_i = box_pred_i[0]

              # Score = sqrt(class prob * centerness). The [HW, C] * [HW, 1]
              # broadcast is flattened to (H x W x C,), class index fastest.
              scores_i = torch.sqrt(cls_pred_i.sigmoid() * ctn_pred_i.sigmoid()).flatten()

              # Keep the top-k scoring (location, class) pairs. The cap is the
              # number of scores, not the number of locations: bounding it by
              # box_pred_i.size(0) would keep at most H x W candidates.
              num_topk = min(self.topk, scores_i.numel())
              # torch.sort is actually faster than .topk (at least on GPUs)
              predicted_prob, topk_idxs = scores_i.sort(descending=True)
              topk_scores = predicted_prob[:num_topk]
              topk_idxs = topk_idxs[:num_topk]

              # Filter out the proposals with low confidence score.
              keep_mask = topk_scores > self.conf_thresh
              topk_idxs = topk_idxs[keep_mask]
              scores = topk_scores[keep_mask]

              # Undo the flatten: flat index = location * C + class.
              labels = topk_idxs % self.num_classes
              anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
              bboxes = box_pred_i[anchor_idxs]

              all_scores.append(scores)
              all_labels.append(labels)
              all_bboxes.append(bboxes)

          # to cpu & numpy
          scores = torch.cat(all_scores).cpu().numpy()
          labels = torch.cat(all_labels).cpu().numpy()
          bboxes = torch.cat(all_bboxes).cpu().numpy()

          # nms
          scores, labels, bboxes = multiclass_nms(
              scores, labels, bboxes, self.nms_thresh, self.num_classes, self.ca_nms)

          return bboxes, scores, labels

      def forward(self, src, src_mask=None):
          """Run backbone, neck and head; post-process in eval mode.

          Args:
              src: input image batch; its last two dims are read as (H, W)
                  to normalize the boxes.
              src_mask: optional mask forwarded unchanged to the head.

          Returns:
              Training mode: the head's outputs, unchanged.
              Eval mode: dict with ``'scores'``, ``'labels'`` and ``'bboxes'``
              as returned by ``multiclass_nms``; box x-columns are divided by
              the input width, y-columns by the input height, then clipped
              to [0, 1].
          """
          # ---------------- Backbone ----------------
          pyramid_feats = self.backbone(src)
          # ---------------- Neck ----------------
          pyramid_feats = self.fpn(pyramid_feats)
          # ---------------- Heads ----------------
          outputs = self.head(pyramid_feats, src_mask)

          if self.training:
              return outputs

          # ---------------- PostProcess ----------------
          bboxes, scores, labels = self.post_process(
              outputs["pred_cls"], outputs["pred_ctn"], outputs["pred_box"])

          # Normalize boxes: even columns (x) by width, odd columns (y) by height.
          img_h, img_w = src.shape[-2], src.shape[-1]
          bboxes[..., 0::2] /= img_w
          bboxes[..., 1::2] /= img_h
          bboxes = bboxes.clip(0., 1.)

          return {
              'scores': scores,
              'labels': labels,
              'bboxes': bboxes,
          }
  102. # ------------------------ Real-time FCOS ------------------------
  class FcosRT(nn.Module):
      """Real-time FCOS variant.

      Same backbone / neck / head pipeline as FCOS, but the post-processing
      scores candidates by class probability alone (no centerness input).
      In training mode ``forward`` returns the raw head outputs; in eval mode
      it decodes, filters and NMS-es the predictions and returns boxes
      normalized by the input size.

      Args:
          cfg: model config; must expose ``head_dim`` (read here) plus whatever
              ``build_backbone`` / ``build_neck`` / ``build_head`` need.
          num_classes: number of object classes ``C``.
          conf_thresh: candidates whose score is not strictly greater than
              this are discarded.
          nms_thresh: threshold forwarded to ``multiclass_nms``.
          topk: maximum number of candidates kept per pyramid level before NMS.
          ca_nms: flag forwarded to ``multiclass_nms`` (presumably selects
              class-agnostic NMS -- confirm against ``utils.misc``).
      """

      def __init__(self,
                   cfg,
                   num_classes: int = 80,
                   conf_thresh: float = 0.05,
                   nms_thresh: float = 0.6,
                   topk: int = 1000,
                   ca_nms: bool = False):
          super().__init__()
          # ---------------------- Basic Parameters ----------------------
          self.cfg = cfg
          self.topk = topk
          self.num_classes = num_classes
          self.conf_thresh = conf_thresh
          self.nms_thresh = nms_thresh
          self.ca_nms = ca_nms

          # ---------------------- Network Parameters ----------------------
          ## Backbone: also returns the channel dims of its output features.
          self.backbone, feat_dims = build_backbone(cfg)
          ## Neck: takes the backbone dims, outputs ``cfg.head_dim`` channels
          ## (presumably per pyramid level -- confirm in build_neck).
          self.fpn = build_neck(cfg, feat_dims, cfg.head_dim)
          ## Heads
          self.head = build_head(cfg, cfg.head_dim, cfg.head_dim)

      @torch.no_grad()
      def post_process(self, cls_preds, box_preds):
          """Decode multi-level predictions into final detections.

          Only the first image of the batch (index ``0``) is decoded, so this
          assumes inference with batch size 1.

          Runs without autograd: the results are converted to NumPy below, and
          ``Tensor.numpy()`` raises on tensors that require grad, which is what
          the head produces when ``forward`` is called outside
          ``torch.no_grad()``.

          Input:
              cls_preds: List(Tensor) [[B, H x W, C], ...] class logits per level
              box_preds: List(Tensor) [[B, H x W, 4], ...] boxes per level

          Returns:
              (bboxes, scores, labels): the output of ``multiclass_nms`` on the
              concatenated NumPy candidates, reordered to this tuple.
          """
          all_scores = []
          all_labels = []
          all_bboxes = []
          for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
              cls_pred_i = cls_pred_i[0]
              box_pred_i = box_pred_i[0]

              # Score = class probability; [HW, C] flattened to (H x W x C,),
              # class index fastest.
              scores_i = cls_pred_i.sigmoid().flatten()

              # Keep the top-k scoring (location, class) pairs. The cap is the
              # number of scores, not the number of locations: bounding it by
              # box_pred_i.size(0) would keep at most H x W candidates.
              num_topk = min(self.topk, scores_i.numel())
              # torch.sort is actually faster than .topk (at least on GPUs)
              predicted_prob, topk_idxs = scores_i.sort(descending=True)
              topk_scores = predicted_prob[:num_topk]
              topk_idxs = topk_idxs[:num_topk]

              # Filter out the proposals with low confidence score.
              keep_mask = topk_scores > self.conf_thresh
              topk_idxs = topk_idxs[keep_mask]
              scores = topk_scores[keep_mask]

              # Undo the flatten: flat index = location * C + class.
              labels = topk_idxs % self.num_classes
              anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
              bboxes = box_pred_i[anchor_idxs]

              all_scores.append(scores)
              all_labels.append(labels)
              all_bboxes.append(bboxes)

          # to cpu & numpy
          scores = torch.cat(all_scores).cpu().numpy()
          labels = torch.cat(all_labels).cpu().numpy()
          bboxes = torch.cat(all_bboxes).cpu().numpy()

          # nms
          scores, labels, bboxes = multiclass_nms(
              scores, labels, bboxes, self.nms_thresh, self.num_classes, self.ca_nms)

          return bboxes, scores, labels

      def forward(self, src, src_mask=None):
          """Run backbone, neck and head; post-process in eval mode.

          Args:
              src: input image batch; its last two dims are read as (H, W)
                  to normalize the boxes.
              src_mask: optional mask forwarded unchanged to the head.

          Returns:
              Training mode: the head's outputs, unchanged.
              Eval mode: dict with ``'scores'``, ``'labels'`` and ``'bboxes'``
              as returned by ``multiclass_nms``; box x-columns are divided by
              the input width, y-columns by the input height, then clipped
              to [0, 1].
          """
          # ---------------- Backbone ----------------
          pyramid_feats = self.backbone(src)
          # ---------------- Neck ----------------
          pyramid_feats = self.fpn(pyramid_feats)
          # ---------------- Heads ----------------
          outputs = self.head(pyramid_feats, src_mask)

          if self.training:
              return outputs

          # ---------------- PostProcess ----------------
          bboxes, scores, labels = self.post_process(
              outputs["pred_cls"], outputs["pred_box"])

          # Normalize boxes: even columns (x) by width, odd columns (y) by height.
          img_h, img_w = src.shape[-2], src.shape[-1]
          bboxes[..., 0::2] /= img_w
          bboxes[..., 1::2] /= img_h
          bboxes = bboxes.clip(0., 1.)

          return {
              'scores': scores,
              'labels': labels,
              'bboxes': bboxes,
          }