fcos.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. import torch
  2. import torch.nn as nn
  3. # --------------- Model components ---------------
  4. from ...backbone import build_backbone
  5. from ...neck import build_neck
  6. from ...head import build_head
  7. # --------------- External components ---------------
  8. from utils.misc import multiclass_nms
  9. # ------------------------ Fully Convolutional One-Stage Detector ------------------------
  class FCOS(nn.Module):
      """Fully Convolutional One-Stage (FCOS) detector.

      The network is ``backbone -> neck (self.fpn) -> head``. In training mode
      ``forward`` returns the raw head outputs. In eval mode it decodes the
      predictions, runs ``multiclass_nms`` and returns boxes normalized by the
      input width/height.
      """

      def __init__(self,
                   cfg,
                   num_classes: int = 80,
                   conf_thresh: float = 0.05,
                   nms_thresh: float = 0.6,
                   topk: int = 1000,
                   ca_nms: bool = False):
          """Build the detector.

          Args:
              cfg: model config. It must expose ``head_dim``; whatever else
                  ``build_backbone``/``build_neck``/``build_head`` read from it
                  is defined outside this file (TODO confirm).
              num_classes: number of object classes. It is also used to decode
                  flattened (anchor, class) score indices at inference.
              conf_thresh: candidates with score <= this value are discarded.
              nms_thresh: threshold forwarded to ``multiclass_nms``
                  (presumably an IoU threshold; verify against utils.misc).
              topk: maximum number of candidates kept per pyramid level
                  before NMS.
              ca_nms: flag forwarded to ``multiclass_nms`` (presumably
                  class-agnostic NMS; verify against utils.misc).
          """
          super().__init__()
          # ---------------------- Basic Parameters ----------------------
          self.cfg = cfg
          self.topk = topk
          self.num_classes = num_classes
          self.conf_thresh = conf_thresh
          self.nms_thresh = nms_thresh
          self.ca_nms = ca_nms
          # ---------------------- Network Parameters ----------------------
          ## Backbone: also reports the channel dims of its output features.
          self.backbone, feat_dims = build_backbone(cfg)
          ## Neck: projects every backbone level to cfg.head_dim channels.
          self.fpn = build_neck(cfg, feat_dims, cfg.head_dim)
          ## Heads
          self.head = build_head(cfg, cfg.head_dim, cfg.head_dim)

      def post_process(self, cls_preds, ctn_preds, box_preds):
          """Decode per-level predictions of the FIRST image and apply NMS.

          Input:
              cls_preds: List(Tensor) [[B, H x W, C], ...]
              ctn_preds: List(Tensor) [[B, H x W, 1], ...]
              box_preds: List(Tensor) [[B, H x W, 4], ...]

          Only batch index 0 is decoded; any other images in the batch are
          ignored.

          Returns:
              (bboxes, scores, labels) as produced by ``multiclass_nms`` from
              NumPy inputs. Boxes are in the same coordinate space as
              ``box_preds``.
          """
          all_scores = []
          all_labels = []
          all_bboxes = []
          for cls_pred_i, ctn_pred_i, box_pred_i in zip(cls_preds, ctn_preds, box_preds):
              # Only the first image of the batch is post-processed.
              cls_pred_i = cls_pred_i[0]
              ctn_pred_i = ctn_pred_i[0]
              box_pred_i = box_pred_i[0]

              # Score per (anchor, class): geometric mean of the class
              # probability and the centerness (ctn broadcasts over C).
              # Flattened layout is anchor_idx * C + class_idx -> (H x W x C,)
              scores_i = torch.sqrt(cls_pred_i.sigmoid() * ctn_pred_i.sigmoid()).flatten()

              # Keep the top-k candidates. The cap must be the number of
              # (anchor, class) scores, not the number of anchors. Otherwise
              # coarse levels would keep at most H x W candidates even though
              # they have H x W x C.
              num_topk = min(self.topk, scores_i.size(0))
              # torch.sort is actually faster than .topk (at least on GPUs)
              predicted_prob, topk_idxs = scores_i.sort(descending=True)
              topk_scores = predicted_prob[:num_topk]
              topk_idxs = topk_idxs[:num_topk]

              # Filter out the proposals with low confidence score.
              keep_idxs = topk_scores > self.conf_thresh
              topk_idxs = topk_idxs[keep_idxs]
              scores = topk_scores[keep_idxs]

              # Undo the flattening: class = idx % C, anchor = idx // C.
              labels = topk_idxs % self.num_classes
              anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
              bboxes = box_pred_i[anchor_idxs]

              all_scores.append(scores)
              all_labels.append(labels)
              all_bboxes.append(bboxes)

          scores = torch.cat(all_scores)
          labels = torch.cat(all_labels)
          bboxes = torch.cat(all_bboxes)

          # detach() so that .numpy() also works when the model is in eval mode
          # but called outside torch.no_grad() (tensors still requiring grad
          # cannot be converted to NumPy).
          scores = scores.detach().cpu().numpy()
          labels = labels.detach().cpu().numpy()
          bboxes = bboxes.detach().cpu().numpy()

          # nms
          scores, labels, bboxes = multiclass_nms(
              scores, labels, bboxes, self.nms_thresh, self.num_classes, self.ca_nms)

          return bboxes, scores, labels

      def forward(self, src, src_mask=None):
          """Run the detector.

          Args:
              src: input image batch. Its last two dims are used as
                  (height, width) for box normalization in eval mode.
              src_mask: optional mask forwarded unchanged to the head
                  (semantics defined by the head; TODO confirm).

          Returns:
              Training mode: the head's raw output dict.
              Eval mode: (bboxes, scores, labels) for the first image. Box
              x-coordinates are divided by the input width and y-coordinates by
              the input height, then clipped to [0, 1].
          """
          # ---------------- Backbone ----------------
          pyramid_feats = self.backbone(src)
          # ---------------- Neck ----------------
          pyramid_feats = self.fpn(pyramid_feats)
          # ---------------- Heads ----------------
          outputs = self.head(pyramid_feats, src_mask)

          if self.training:
              return outputs

          # ---------------- PostProcess ----------------
          bboxes, scores, labels = self.post_process(
              outputs["pred_cls"], outputs["pred_ctn"], outputs["pred_box"])

          # Normalize boxes: even columns by width, odd columns by height.
          bboxes[..., 0::2] /= src.shape[-1]
          bboxes[..., 1::2] /= src.shape[-2]
          bboxes = bboxes.clip(0., 1.)

          return bboxes, scores, labels