fcos.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import copy
  2. import torch
  3. import torch.nn as nn
  4. # --------------- Model components ---------------
  5. from ...backbone import build_backbone
  6. from ...neck import build_neck
  7. from ...head import build_head
  8. # --------------------- End-to-End RT-FCOS ---------------------
  9. class FcosPSS(nn.Module):
  10. def __init__(self,
  11. cfg,
  12. conf_thresh :float = 0.05,
  13. topk_results :int = 1000,
  14. ):
  15. super(FcosPSS, self).__init__()
  16. # ---------------------- Basic Parameters ----------------------
  17. self.conf_thresh = conf_thresh
  18. self.num_classes = cfg.num_classes
  19. self.topk_results = topk_results
  20. # ---------------------- Network Parameters ----------------------
  21. ## Backbone
  22. self.backbone, pyramid_feats = build_backbone(cfg)
  23. ## Neck
  24. self.backbone_fpn = build_neck(cfg, pyramid_feats, cfg.head_dim)
  25. ## Heads
  26. self.detection_head = build_head(cfg, cfg.head_dim, cfg.head_dim)
  27. def post_process(self, cls_preds, box_preds, pss_preds):
  28. """
  29. Input:
  30. cls_preds: List(Tensor) [[B, H x W, C], ...]
  31. box_preds: List(Tensor) [[B, H x W, 4], ...]
  32. pss_preds: List(Tensor) [[B, H x W, 1], ...]
  33. """
  34. all_scores = []
  35. all_labels = []
  36. all_bboxes = []
  37. for cls_pred_i, box_pred_i, pss_pred_i in zip(cls_preds, box_preds, pss_preds):
  38. cls_pred_i = cls_pred_i[0]
  39. box_pred_i = box_pred_i[0]
  40. pss_pred_i = pss_pred_i[0]
  41. # [H, W, C] -> [HWC,]
  42. scores_i = (cls_pred_i.sigmoid() * pss_pred_i.sigmoid()).flatten()
  43. # Keep top k top scoring indices only.
  44. num_topk = min(self.topk_results, box_pred_i.size(0))
  45. # torch.sort is actually faster than .topk (at least on GPUs)
  46. predicted_prob, topk_idxs = scores_i.sort(descending=True)
  47. topk_scores = predicted_prob[:num_topk]
  48. topk_idxs = topk_idxs[:num_topk]
  49. # filter out the proposals with low confidence score
  50. keep_idxs = topk_scores > self.conf_thresh
  51. topk_idxs = topk_idxs[keep_idxs]
  52. # final scores
  53. scores = topk_scores[keep_idxs]
  54. # final labels
  55. labels = topk_idxs % self.num_classes
  56. # final bboxes
  57. anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
  58. bboxes = box_pred_i[anchor_idxs]
  59. all_scores.append(scores)
  60. all_labels.append(labels)
  61. all_bboxes.append(bboxes)
  62. scores = torch.cat(all_scores)
  63. labels = torch.cat(all_labels)
  64. bboxes = torch.cat(all_bboxes)
  65. # to cpu & numpy
  66. scores = scores.cpu().numpy()
  67. labels = labels.cpu().numpy()
  68. bboxes = bboxes.cpu().numpy()
  69. return bboxes, scores, labels
  70. def inference(self, src):
  71. # ---------------- Backbone ----------------
  72. pyramid_feats = self.backbone(src)
  73. # ---------------- Neck ----------------
  74. pyramid_feats = self.backbone_fpn(pyramid_feats)
  75. # ---------------- Heads ----------------
  76. outputs = self.detection_head(pyramid_feats)
  77. cls_pred = outputs["pred_cls"]
  78. box_pred = outputs["pred_box"]
  79. pss_pred = outputs["pred_pss"]
  80. # Post-process (no NMS)
  81. bboxes, scores, labels = self.post_process(cls_pred, box_pred, pss_pred)
  82. # Normalize bbox
  83. bboxes[..., 0::2] /= src.shape[-1]
  84. bboxes[..., 1::2] /= src.shape[-2]
  85. bboxes = bboxes.clip(0., 1.)
  86. outputs = {
  87. 'scores': scores,
  88. 'labels': labels,
  89. 'bboxes': bboxes
  90. }
  91. return outputs
  92. def forward(self, src, src_mask=None):
  93. if not self.training:
  94. return self.inference(src)
  95. else:
  96. # ---------------- Backbone ----------------
  97. pyramid_feats = self.backbone(src)
  98. # ---------------- Neck ----------------
  99. pyramid_feats = self.backbone_fpn(pyramid_feats)
  100. # ---------------- Heads ----------------
  101. outputs = self.detection_head(pyramid_feats, src_mask)
  102. return outputs