ctrnet.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # Objects as Points
  2. # --------------- Torch components ---------------
  3. import torch
  4. import torch.nn as nn
  5. # --------------- Model components ---------------
  6. from .ctrnet_encoder import build_encoder
  7. from .ctrnet_decoder import build_decoder
  8. from .ctrnet_neck import build_neck
  9. from .ctrnet_head import build_det_head
  10. from .ctrnet_pred import build_det_pred
  11. # CenterNet
  12. class CenterNet(nn.Module):
  13. def __init__(self,
  14. cfg,
  15. device,
  16. num_classes = 20,
  17. conf_thresh = 0.01,
  18. topk = 1000,
  19. trainable = False,
  20. deploy = False,
  21. no_multi_labels = False,
  22. nms_class_agnostic = False,
  23. ):
  24. super(CenterNet, self).__init__()
  25. # ---------------- Basic Parameters ----------------
  26. self.cfg = cfg
  27. self.device = device
  28. self.stride = cfg['out_stride']
  29. self.num_classes = num_classes
  30. self.trainable = trainable
  31. self.conf_thresh = conf_thresh
  32. self.num_classes = num_classes
  33. self.topk_candidates = topk
  34. self.deploy = deploy
  35. self.no_multi_labels = no_multi_labels
  36. self.nms_class_agnostic = nms_class_agnostic
  37. self.head_dim = round(256 * cfg['width'])
  38. # ---------------- Network Parameters ----------------
  39. ## Encoder
  40. self.encoder, feat_dims = build_encoder(cfg)
  41. ## Neck
  42. self.neck = build_neck(cfg, feat_dims[-1], feat_dims[-1])
  43. self.feat_dim = self.neck.out_dim
  44. ## Decoder
  45. self.decoder = build_decoder(cfg, self.feat_dim, self.head_dim)
  46. ## Head
  47. self.det_head = nn.Sequential(
  48. build_det_head(cfg, self.head_dim, self.head_dim),
  49. build_det_pred(self.head_dim, self.head_dim, self.stride, num_classes, 4)
  50. )
  51. ## Aux Head
  52. self.aux_det_head = nn.Sequential(
  53. build_det_head(cfg, self.head_dim, self.head_dim),
  54. build_det_pred(self.head_dim, self.head_dim, self.stride, num_classes, 4)
  55. )
  56. # Post process
  57. def post_process(self, cls_pred, box_pred):
  58. """
  59. Input:
  60. cls_pred: torch.Tensor -> [M, C]
  61. box_pred: torch.Tensor -> [M, 4]
  62. Output:
  63. bboxes: np.array -> [N, 4]
  64. scores: np.array -> [N,]
  65. labels: np.array -> [N,]
  66. """
  67. cls_pred = cls_pred[0]
  68. box_pred = box_pred[0]
  69. if self.no_multi_labels:
  70. # [M,]
  71. scores, labels = torch.max(cls_pred.sigmoid(), dim=1)
  72. # Keep top k top scoring indices only.
  73. num_topk = min(self.topk_candidates, box_pred.size(0))
  74. # topk candidates
  75. predicted_prob, topk_idxs = scores.sort(descending=True)
  76. topk_scores = predicted_prob[:num_topk]
  77. topk_idxs = topk_idxs[:num_topk]
  78. # filter out the proposals with low confidence score
  79. keep_idxs = topk_scores > self.conf_thresh
  80. scores = topk_scores[keep_idxs]
  81. topk_idxs = topk_idxs[keep_idxs]
  82. labels = labels[topk_idxs]
  83. bboxes = box_pred[topk_idxs]
  84. else:
  85. # [M, C] -> [MC,]
  86. scores = cls_pred.sigmoid().flatten()
  87. # Keep top k top scoring indices only.
  88. num_topk = min(self.topk_candidates, box_pred.size(0))
  89. # torch.sort is actually faster than .topk (at least on GPUs)
  90. predicted_prob, topk_idxs = scores.sort(descending=True)
  91. topk_scores = predicted_prob[:num_topk]
  92. topk_idxs = topk_idxs[:num_topk]
  93. # filter out the proposals with low confidence score
  94. keep_idxs = topk_scores > self.conf_thresh
  95. scores = topk_scores[keep_idxs]
  96. topk_idxs = topk_idxs[keep_idxs]
  97. anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
  98. labels = topk_idxs % self.num_classes
  99. bboxes = box_pred[anchor_idxs]
  100. # to cpu & numpy
  101. scores = scores.cpu().numpy()
  102. labels = labels.cpu().numpy()
  103. bboxes = bboxes.cpu().numpy()
  104. return bboxes, scores, labels
  105. # Main process
  106. def forward(self, x):
  107. # ---------------- Backbone ----------------
  108. pyramid_feats = self.encoder(x)
  109. # ---------------- Neck ----------------
  110. feat = self.neck(pyramid_feats[-1])
  111. # ---------------- Encoder ----------------
  112. feat = self.decoder(feat)
  113. # ---------------- Head ----------------
  114. outputs = self.det_head(feat)
  115. if self.trainable:
  116. outputs['aux_outputs'] = self.aux_det_head(feat)
  117. # ---------------- Post-process ----------------
  118. if not self.trainable:
  119. cls_preds = outputs['pred_cls']
  120. box_preds = outputs['pred_box']
  121. if self.deploy:
  122. scores = cls_preds[0].sigmoid()
  123. bboxes = box_preds[0]
  124. # [n_anchors_all, 4 + C]
  125. outputs = torch.cat([bboxes, scores], dim=-1)
  126. else:
  127. # post process
  128. bboxes, scores, labels = self.post_process(cls_preds, box_preds)
  129. outputs = {
  130. "scores": scores,
  131. "labels": labels,
  132. "bboxes": bboxes
  133. }
  134. return outputs