rtdetr.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. import torch
  2. import torch.nn as nn
  3. from .rtdetr_encoder import ImageEncoder
  4. from .rtdetr_decoder import RTDetrTransformer
  5. from .basic_modules.nms_ops import multiclass_nms
  6. # Real-time DETR
  7. class RTDETR(nn.Module):
  8. def __init__(self,
  9. cfg,
  10. is_val = False,
  11. use_nms = False,
  12. onnx_deploy = False,
  13. ) -> None:
  14. super(RTDETR, self).__init__()
  15. # ---------------------- Basic setting ----------------------
  16. self.cfg = cfg
  17. self.use_nms = use_nms
  18. self.onnx_deploy = onnx_deploy
  19. self.num_classes = cfg.num_classes
  20. ## Post-process parameters
  21. self.topk_candidates = cfg.val_topk if is_val else cfg.test_topk
  22. self.conf_thresh = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
  23. self.nms_thresh = cfg.val_nms_thresh if is_val else cfg.test_nms_thresh
  24. self.no_multi_labels = False if is_val else True
  25. # ----------- Network setting -----------
  26. ## Image encoder
  27. self.image_encoder = ImageEncoder(cfg)
  28. ## Detect decoder
  29. self.detect_decoder = RTDetrTransformer(in_dims = self.image_encoder.fpn_dims,
  30. hidden_dim = cfg.hidden_dim,
  31. strides = cfg.out_stride,
  32. num_classes = cfg.num_classes,
  33. num_queries = cfg.num_queries,
  34. num_heads = cfg.de_num_heads,
  35. num_layers = cfg.de_num_layers,
  36. num_levels = len(cfg.out_stride),
  37. num_points = cfg.de_num_points,
  38. ffn_dim = cfg.de_ffn_dim,
  39. dropout = cfg.de_dropout,
  40. act_type = cfg.de_act,
  41. return_intermediate = True,
  42. num_denoising = cfg.dn_num_denoising,
  43. label_noise_ratio = cfg.dn_label_noise_ratio,
  44. box_noise_scale = cfg.dn_box_noise_scale,
  45. learnt_init_query = cfg.learnt_init_query,
  46. )
  47. def post_process(self, box_pred, cls_pred):
  48. # xywh -> xyxy
  49. box_preds_x1y1 = box_pred[..., :2] - 0.5 * box_pred[..., 2:]
  50. box_preds_x2y2 = box_pred[..., :2] + 0.5 * box_pred[..., 2:]
  51. box_pred = torch.cat([box_preds_x1y1, box_preds_x2y2], dim=-1)
  52. cls_pred = cls_pred[0]
  53. box_pred = box_pred[0]
  54. if self.no_multi_labels:
  55. # [M,]
  56. scores, labels = torch.max(cls_pred.sigmoid(), dim=1)
  57. # Keep top k top scoring indices only.
  58. num_topk = min(self.topk_candidates, box_pred.size(0))
  59. # Topk candidates
  60. predicted_prob, topk_idxs = scores.sort(descending=True)
  61. topk_scores = predicted_prob[:num_topk]
  62. topk_idxs = topk_idxs[:num_topk]
  63. # Filter out the proposals with low confidence score
  64. keep_idxs = topk_scores > self.conf_thresh
  65. topk_idxs = topk_idxs[keep_idxs]
  66. # Top-k results
  67. topk_scores = topk_scores[keep_idxs]
  68. topk_labels = labels[topk_idxs]
  69. topk_bboxes = box_pred[topk_idxs]
  70. else:
  71. # Top-k select
  72. cls_pred = cls_pred.flatten().sigmoid_()
  73. box_pred = box_pred
  74. # Keep top k top scoring indices only.
  75. num_topk = min(self.topk_candidates, box_pred.size(0))
  76. # Topk candidates
  77. predicted_prob, topk_idxs = cls_pred.sort(descending=True)
  78. topk_scores = predicted_prob[:num_topk]
  79. topk_idxs = topk_idxs[:num_topk]
  80. # Filter out the proposals with low confidence score
  81. keep_idxs = topk_scores > self.conf_thresh
  82. topk_scores = topk_scores[keep_idxs]
  83. topk_idxs = topk_idxs[keep_idxs]
  84. topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
  85. ## Top-k results
  86. topk_labels = topk_idxs % self.num_classes
  87. topk_bboxes = box_pred[topk_box_idxs]
  88. if not self.onnx_deploy:
  89. topk_scores = topk_scores.cpu().numpy()
  90. topk_labels = topk_labels.cpu().numpy()
  91. topk_bboxes = topk_bboxes.cpu().numpy()
  92. # nms
  93. if self.use_nms:
  94. topk_scores, topk_labels, topk_bboxes = multiclass_nms(
  95. topk_scores, topk_labels, topk_bboxes, self.nms_thresh, self.num_classes)
  96. return topk_bboxes, topk_scores, topk_labels
  97. def forward(self, x, targets=None):
  98. # ----------- Image Encoder -----------
  99. pyramid_feats = self.image_encoder(x)
  100. # ----------- Transformer -----------
  101. outputs = self.detect_decoder(pyramid_feats, targets)
  102. if not self.training:
  103. img_h, img_w = x.shape[2:]
  104. box_pred = outputs["pred_boxes"]
  105. cls_pred = outputs["pred_logits"]
  106. # rescale bbox
  107. box_pred[..., [0, 2]] *= img_h
  108. box_pred[..., [1, 3]] *= img_w
  109. # post-process
  110. bboxes, scores, labels = self.post_process(box_pred, cls_pred)
  111. outputs = {
  112. "scores": scores,
  113. "labels": labels,
  114. "bboxes": bboxes,
  115. }
  116. return outputs