fcos.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import copy
  2. import torch
  3. import torch.nn as nn
  4. # --------------- Model components ---------------
  5. from ...backbone import build_backbone
  6. from ...neck import build_neck
  7. from ...head import build_head
  8. # --------------------- End-to-End RT-FCOS ---------------------
  class FcosE2E(nn.Module):
      """End-to-end (NMS-free) RT-FCOS detector.

      Pipeline: backbone -> neck (``backbone_fpn``) -> detection heads.
        * ``detection_head_o2m``: one-to-many head, only run in training mode.
        * ``detection_head_o2o``: one-to-one head, created as a deep copy of the
          o2m head (independent parameters); used for inference without NMS.
      """

      def __init__(self,
                   cfg,
                   conf_thresh: float = 0.05,
                   topk_results: int = 1000,
                   ):
          """Build the detector.

          Args:
              cfg: model config. This class reads ``cfg.num_classes`` and
                  ``cfg.head_dim`` and hands ``cfg`` to the backbone, neck and
                  head builders.
              conf_thresh: candidates whose sigmoid score is not strictly above
                  this value are discarded by ``post_process``.
              topk_results: maximum number of candidates kept per pyramid level
                  (before the confidence filter) by ``post_process``.
          """
          super().__init__()
          # ---------------------- Basic Parameters ----------------------
          self.conf_thresh = conf_thresh
          self.num_classes = cfg.num_classes
          self.topk_results = topk_results
          # ---------------------- Network Parameters ----------------------
          ## Backbone (second return value is passed to the neck builder;
          ## presumably per-level feature dims — confirm in build_backbone)
          self.backbone, pyramid_feats = build_backbone(cfg)
          ## Neck
          self.backbone_fpn = build_neck(cfg, pyramid_feats, cfg.head_dim)
          ## Heads (one-to-many)
          self.detection_head_o2m = build_head(cfg, cfg.head_dim, cfg.head_dim)
          ## Heads (one-to-one): same architecture/initial weights, separate params
          self.detection_head_o2o = copy.deepcopy(self.detection_head_o2m)

      def _extract_feats(self, src):
          """Run backbone then neck on ``src`` and return the neck's features."""
          pyramid_feats = self.backbone(src)
          return self.backbone_fpn(pyramid_feats)

      def post_process(self, cls_preds, box_preds):
          """Keep the highest-scoring (anchor, class) pairs of the first image.

          Input:
              cls_preds: List(Tensor) [[B, H x W, C], ...] classification logits,
                  one tensor per pyramid level.
              box_preds: List(Tensor) [[B, H x W, 4], ...] box predictions, one
                  tensor per pyramid level.
          Output:
              (bboxes [N, 4], scores [N], labels [N]) as numpy arrays, levels
              concatenated in input order. No NMS is applied.

          Only batch index 0 is decoded, i.e. batch size 1 is assumed.
          """
          all_scores, all_labels, all_bboxes = [], [], []
          for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
              # Batch size 1 is assumed: only the first image is decoded.
              cls_pred_i = cls_pred_i[0]
              box_pred_i = box_pred_i[0]
              # (H x W x C,) flattened per-(anchor, class) scores.
              scores_i = cls_pred_i.sigmoid().flatten()
              # Keep the top-k scoring candidates. The bound must be the number of
              # scores (H x W x C), not the number of anchors (H x W); bounding by
              # anchors silently drops valid candidates on coarse levels.
              num_topk = min(self.topk_results, scores_i.numel())
              # torch.sort is actually faster than .topk (at least on GPUs)
              predicted_prob, topk_idxs = scores_i.sort(descending=True)
              topk_scores = predicted_prob[:num_topk]
              topk_idxs = topk_idxs[:num_topk]
              # Filter out the proposals with low confidence score.
              keep_idxs = topk_scores > self.conf_thresh
              scores = topk_scores[keep_idxs]
              topk_idxs = topk_idxs[keep_idxs]
              # Row-major flatten of [H x W, C]: flat_idx = anchor_idx * C + class_idx.
              labels = topk_idxs % self.num_classes
              anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
              bboxes = box_pred_i[anchor_idxs]

              all_scores.append(scores)
              all_labels.append(labels)
              all_bboxes.append(bboxes)

          # Merge levels and convert to numpy. detach() is required when this is
          # called with grad enabled: Tensor.numpy() rejects tensors requiring grad.
          scores = torch.cat(all_scores).detach().cpu().numpy()
          labels = torch.cat(all_labels).detach().cpu().numpy()
          bboxes = torch.cat(all_bboxes).detach().cpu().numpy()
          return bboxes, scores, labels

      @torch.no_grad()
      def inference_o2o(self, src):
          """NMS-free inference with the one-to-one head.

          Args:
              src: input image tensor; the last two dims are used as (H, W).
                  Only the first image is decoded (batch size 1 assumed).
          Returns:
              dict of numpy arrays: 'scores' [N], 'labels' [N], 'bboxes' [N, 4].
              Even box columns are divided by the input width, odd columns by
              the input height, then clipped to [0, 1] (presumably x1y1x2y2
              pixel boxes -> normalized; confirm against the head).
          """
          pyramid_feats = self._extract_feats(src)
          outputs = self.detection_head_o2o(pyramid_feats)
          cls_pred = outputs["pred_cls"]
          box_pred = outputs["pred_box"]
          # PostProcess (no NMS)
          bboxes, scores, labels = self.post_process(cls_pred, box_pred)
          # Normalize bbox: x-coordinates by width, y-coordinates by height.
          bboxes[..., 0::2] /= src.shape[-1]
          bboxes[..., 1::2] /= src.shape[-2]
          bboxes = bboxes.clip(0., 1.)
          return {
              'scores': scores,
              'labels': labels,
              'bboxes': bboxes,
          }

      def forward(self, src, src_mask=None):
          """Eval mode: ``inference_o2o(src)``. Training mode: run both heads.

          Args:
              src: input image tensor.
              src_mask: optional mask forwarded to both heads in training mode;
                  ignored in eval mode.
          Returns:
              eval: the dict returned by ``inference_o2o``.
              training: {"outputs_o2m": <o2m head output>,
                         "outputs_o2o": <o2o head output>}.
          """
          if not self.training:
              return self.inference_o2o(src)

          pyramid_feats = self._extract_feats(src)
          outputs = {}
          ## One-to-many detection
          outputs["outputs_o2m"] = self.detection_head_o2m(pyramid_feats, src_mask)
          ## One-to-one detection on detached features, so gradients from the
          ## o2o branch do not flow back into the backbone / neck.
          pyramid_feats_detach = [feat.detach() for feat in pyramid_feats]
          outputs["outputs_o2o"] = self.detection_head_o2o(pyramid_feats_detach, src_mask)
          return outputs