retinanet.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import numpy as np
  2. import math
  3. import torch
  4. import torch.nn as nn
  5. # --------------- Model components ---------------
  6. from ...backbone import build_backbone
  7. from ...neck import build_neck
  8. from ...head import build_head
  9. # --------------- External components ---------------
  10. from utils.misc import multiclass_nms
  11. # ------------------------ RetinaNet ------------------------
  12. class RetinaNet(nn.Module):
  13. def __init__(self,
  14. cfg,
  15. num_classes :int = 80,
  16. conf_thresh :float = 0.05,
  17. nms_thresh :float = 0.6,
  18. topk :int = 1000,
  19. trainable :bool = False,
  20. ca_nms :bool = False):
  21. super(RetinaNet, self).__init__()
  22. # ---------------------- Basic Parameters ----------------------
  23. self.cfg = cfg
  24. self.trainable = trainable
  25. self.topk = topk
  26. self.num_classes = num_classes
  27. self.conf_thresh = conf_thresh
  28. self.nms_thresh = nms_thresh
  29. self.ca_nms = ca_nms
  30. # ---------------------- Network Parameters ----------------------
  31. ## Backbone
  32. self.backbone, feat_dims = build_backbone(cfg, trainable&cfg['pretrained'])
  33. ## Neck
  34. self.fpn = build_neck(cfg, feat_dims, cfg['head_dim'])
  35. ## Heads
  36. self.head = build_head(cfg, cfg['head_dim'], cfg['head_dim'], num_classes)
  37. def post_process(self, cls_preds, box_preds):
  38. """
  39. Input:
  40. cls_preds: List(Tensor) [[B, H x W, KA x C], ...]
  41. box_preds: List(Tensor) [[B, H x W, KA x 4], ...]
  42. """
  43. all_scores = []
  44. all_labels = []
  45. all_bboxes = []
  46. for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
  47. cls_pred_i = cls_pred_i[0]
  48. box_pred_i = box_pred_i[0]
  49. # (H x W x KA x C,)
  50. scores_i = cls_pred_i.sigmoid().flatten()
  51. # Keep top k top scoring indices only.
  52. num_topk = min(self.topk, box_pred_i.size(0))
  53. # torch.sort is actually faster than .topk (at least on GPUs)
  54. predicted_prob, topk_idxs = scores_i.sort(descending=True)
  55. topk_scores = predicted_prob[:num_topk]
  56. topk_idxs = topk_idxs[:num_topk]
  57. # filter out the proposals with low confidence score
  58. keep_idxs = topk_scores > self.conf_thresh
  59. topk_idxs = topk_idxs[keep_idxs]
  60. # final scores
  61. scores = topk_scores[keep_idxs]
  62. # final labels
  63. labels = topk_idxs % self.num_classes
  64. # final bboxes
  65. anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
  66. bboxes = box_pred_i[anchor_idxs]
  67. all_scores.append(scores)
  68. all_labels.append(labels)
  69. all_bboxes.append(bboxes)
  70. scores = torch.cat(all_scores)
  71. labels = torch.cat(all_labels)
  72. bboxes = torch.cat(all_bboxes)
  73. # to cpu & numpy
  74. scores = scores.cpu().numpy()
  75. labels = labels.cpu().numpy()
  76. bboxes = bboxes.cpu().numpy()
  77. # nms
  78. scores, labels, bboxes = multiclass_nms(
  79. scores, labels, bboxes, self.nms_thresh, self.num_classes, self.ca_nms)
  80. return bboxes, scores, labels
  81. def forward(self, src, src_mask=None, targets=None):
  82. # ---------------- Backbone ----------------
  83. pyramid_feats = self.backbone(src)
  84. # ---------------- Neck ----------------
  85. pyramid_feats = self.fpn(pyramid_feats)
  86. # ---------------- Heads ----------------
  87. outputs = self.head(pyramid_feats, src_mask)
  88. if not self.training:
  89. # ---------------- PostProcess ----------------
  90. cls_pred = outputs["pred_cls"]
  91. box_pred = outputs["pred_box"]
  92. bboxes, scores, labels = self.post_process(cls_pred, box_pred)
  93. # normalize bbox
  94. bboxes[..., 0::2] /= src.shape[-1]
  95. bboxes[..., 1::2] /= src.shape[-2]
  96. bboxes = bboxes.clip(0., 1.)
  97. return bboxes, scores, labels
  98. return outputs