rtdetr.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .basic_modules.basic import multiclass_nms
  5. from .rtdetr_encoder import build_image_encoder
  6. from .rtdetr_decoder import build_transformer
  7. except:
  8. from basic_modules.basic import multiclass_nms
  9. from rtdetr_encoder import build_image_encoder
  10. from rtdetr_decoder import build_transformer
  11. # Real-time DETR
  12. class RT_DETR(nn.Module):
  13. def __init__(self,
  14. cfg,
  15. num_classes = 80,
  16. conf_thresh = 0.1,
  17. nms_thresh = 0.5,
  18. topk = 300,
  19. deploy = False,
  20. no_multi_labels = False,
  21. use_nms = False,
  22. nms_class_agnostic = False,
  23. ):
  24. super().__init__()
  25. # ----------- Basic setting -----------
  26. self.num_classes = num_classes
  27. self.num_topk = topk
  28. self.deploy = deploy
  29. ## Post-process parameters
  30. self.use_nms = use_nms
  31. self.nms_thresh = nms_thresh
  32. self.conf_thresh = conf_thresh
  33. self.no_multi_labels = no_multi_labels
  34. self.nms_class_agnostic = nms_class_agnostic
  35. # ----------- Network setting -----------
  36. ## Image encoder
  37. self.image_encoder = build_image_encoder(cfg)
  38. self.fpn_dims = self.image_encoder.fpn_dims
  39. ## Detect decoder
  40. self.detect_decoder = build_transformer(cfg, self.fpn_dims, num_classes, return_intermediate=self.training)
  41. def post_process(self, box_pred, cls_pred):
  42. # xywh -> xyxy
  43. box_preds_x1y1 = box_pred[..., :2] - 0.5 * box_pred[..., 2:]
  44. box_preds_x2y2 = box_pred[..., :2] + 0.5 * box_pred[..., 2:]
  45. box_pred = torch.cat([box_preds_x1y1, box_preds_x2y2], dim=-1)
  46. cls_pred = cls_pred[0]
  47. box_pred = box_pred[0]
  48. if self.no_multi_labels:
  49. # [M,]
  50. scores, labels = torch.max(cls_pred.sigmoid(), dim=1)
  51. # Keep top k top scoring indices only.
  52. num_topk = min(self.num_topk, box_pred.size(0))
  53. # Topk candidates
  54. predicted_prob, topk_idxs = scores.sort(descending=True)
  55. topk_scores = predicted_prob[:num_topk]
  56. topk_idxs = topk_idxs[:num_topk]
  57. # Filter out the proposals with low confidence score
  58. keep_idxs = topk_scores > self.conf_thresh
  59. topk_idxs = topk_idxs[keep_idxs]
  60. # Top-k results
  61. topk_scores = topk_scores[keep_idxs]
  62. topk_labels = labels[topk_idxs]
  63. topk_bboxes = box_pred[topk_idxs]
  64. else:
  65. # Top-k select
  66. cls_pred = cls_pred.flatten().sigmoid_()
  67. box_pred = box_pred
  68. # Keep top k top scoring indices only.
  69. num_topk = min(self.num_topk, box_pred.size(0))
  70. # Topk candidates
  71. predicted_prob, topk_idxs = cls_pred.sort(descending=True)
  72. topk_scores = predicted_prob[:num_topk]
  73. topk_idxs = topk_idxs[:self.num_topk]
  74. # Filter out the proposals with low confidence score
  75. keep_idxs = topk_scores > self.conf_thresh
  76. topk_scores = topk_scores[keep_idxs]
  77. topk_idxs = topk_idxs[keep_idxs]
  78. topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
  79. ## Top-k results
  80. topk_labels = topk_idxs % self.num_classes
  81. topk_bboxes = box_pred[topk_box_idxs]
  82. if not self.deploy:
  83. topk_scores = topk_scores.cpu().numpy()
  84. topk_labels = topk_labels.cpu().numpy()
  85. topk_bboxes = topk_bboxes.cpu().numpy()
  86. # nms
  87. if self.use_nms:
  88. topk_scores, topk_labels, topk_bboxes = multiclass_nms(
  89. topk_scores, topk_labels, topk_bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
  90. return topk_bboxes, topk_scores, topk_labels
  91. def forward(self, x, targets=None):
  92. # ----------- Image Encoder -----------
  93. pyramid_feats = self.image_encoder(x)
  94. # ----------- Transformer -----------
  95. outputs = self.detect_decoder(pyramid_feats, targets)
  96. if not self.training:
  97. img_h, img_w = x.shape[2:]
  98. box_pred = outputs["pred_boxes"]
  99. cls_pred = outputs["pred_logits"]
  100. # rescale bbox
  101. box_pred[..., [0, 2]] *= img_h
  102. box_pred[..., [1, 3]] *= img_w
  103. # post-process
  104. bboxes, scores, labels = self.post_process(box_pred, cls_pred)
  105. outputs = {
  106. "scores": scores,
  107. "labels": labels,
  108. "bboxes": bboxes,
  109. }
  110. return outputs
  111. if __name__ == '__main__':
  112. import time
  113. from thop import profile
  114. from loss import build_criterion
  115. # Model config
  116. cfg = {
  117. # Image Encoder - Backbone
  118. 'backbone': 'resnet101',
  119. 'backbone_norm': 'BN',
  120. 'res5_dilation': False,
  121. 'pretrained': False,
  122. 'pretrained_weight': 'imagenet1k_v1',
  123. 'freeze_at': 0,
  124. 'freeze_stem_only': False,
  125. 'out_stride': [8, 16, 32],
  126. 'max_stride': 32,
  127. # Image Encoder - FPN
  128. 'fpn': 'hybrid_encoder',
  129. 'fpn_num_blocks': 4,
  130. 'fpn_act': 'silu',
  131. 'fpn_norm': 'BN',
  132. 'fpn_depthwise': False,
  133. 'hidden_dim': 384,
  134. 'en_num_heads': 8,
  135. 'en_num_layers': 1,
  136. 'en_ffn_dim': 2048,
  137. 'en_dropout': 0.0,
  138. 'pe_temperature': 10000.,
  139. 'en_act': 'gelu',
  140. # Transformer Decoder
  141. 'transformer': 'rtdetr_transformer',
  142. 'de_num_heads': 8,
  143. 'de_num_layers': 6,
  144. 'de_ffn_dim': 2048,
  145. 'de_dropout': 0.0,
  146. 'de_act': 'gelu',
  147. 'de_num_points': 4,
  148. 'num_queries': 300,
  149. 'learnt_init_query': False,
  150. 'pe_temperature': 10000.,
  151. 'dn_num_denoising': 100,
  152. 'dn_label_noise_ratio': 0.5,
  153. 'dn_box_noise_scale': 1,
  154. # Matcher
  155. 'matcher_hpy': {'cost_class': 2.0,
  156. 'cost_bbox': 5.0,
  157. 'cost_giou': 2.0,},
  158. # Loss
  159. 'use_vfl': True,
  160. 'loss_coeff': {'class': 1,
  161. 'bbox': 5,
  162. 'giou': 2,
  163. 'no_object': 0.1,},
  164. }
  165. bs = 1
  166. # Create a batch of images & targets
  167. image = torch.randn(bs, 3, 640, 640).cuda()
  168. targets = [{
  169. 'labels': torch.tensor([2, 4, 5, 8]).long().cuda(),
  170. 'boxes': torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float().cuda() / 640.
  171. }] * bs
  172. # Create model
  173. model = RT_DETR(cfg, num_classes=20)
  174. model.train().cuda()
  175. # Create criterion
  176. criterion = build_criterion(cfg, num_classes=20)
  177. # Model inference
  178. outputs = model(image, targets)
  179. # Compute loss
  180. loss = criterion(outputs, targets)
  181. for k in loss.keys():
  182. print("{} : {}".format(k, loss[k].item()))
  183. # Inference
  184. with torch.no_grad():
  185. model.eval()
  186. t0 = time.time()
  187. outputs = model(image)
  188. t1 = time.time()
  189. print('Infer time: ', t1 - t0)
  190. print('==============================')
  191. model.eval()
  192. flops, params = profile(model, inputs=(image, ), verbose=False)
  193. print('==============================')
  194. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  195. print('Params : {:.2f} M'.format(params / 1e6))