# rtdetr.py

import torch
import torch.nn as nn

from .rtdetr_encoder import build_encoder
from .rtdetr_compressor import build_compressor
from .rtdetr_decoder import build_decoder
from .rtdetr_dethead import build_dethead


# Real-time DETR
class RTDETR(nn.Module):
    def __init__(self,
                 cfg,
                 device,
                 num_classes=20,
                 trainable=False,
                 aux_loss=False,
                 with_box_refine=False,
                 deploy=False):
        super(RTDETR, self).__init__()
        # --------- Basic Parameters ----------
        self.cfg = cfg
        self.device = device
        self.num_classes = num_classes
        self.trainable = trainable
        self.max_stride = max(cfg['stride'])
        self.d_model = round(cfg['d_model'] * self.cfg['width'])
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.deploy = deploy

        # --------- Network Parameters ----------
        ## Encoder
        self.encoder = build_encoder(cfg, trainable, 'img_encoder')
        ## Compressor
        self.compressor = build_compressor(cfg, self.d_model)
        ## Decoder
        self.decoder = build_decoder(cfg, self.d_model, return_intermediate=aux_loss)
        ## DetHead
        self.dethead = build_dethead(cfg, self.d_model, num_classes, with_box_refine)

        # set for TR-Decoder
        self.decoder.class_embed = self.dethead.class_embed
        self.decoder.bbox_embed = self.dethead.bbox_embed
    # ---------------------- Basic Functions ----------------------
    def position_embedding(self, x, temperature=10000):
        hs, ws = x.shape[-2:]
        device = x.device
        num_pos_feats = x.shape[1] // 2
        scale = 2 * 3.141592653589793

        # generate xy coord mat
        y_embed, x_embed = torch.meshgrid(
            [torch.arange(1, hs+1, dtype=torch.float32),
             torch.arange(1, ws+1, dtype=torch.float32)])
        y_embed = y_embed / (hs + 1e-6) * scale
        x_embed = x_embed / (ws + 1e-6) * scale

        # [H, W] -> [1, H, W]
        y_embed = y_embed[None, :, :].to(device)
        x_embed = x_embed[None, :, :].to(device)

        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=device)
        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
        dim_t = temperature ** (2 * dim_t_)

        pos_x = torch.div(x_embed[:, :, :, None], dim_t)
        pos_y = torch.div(y_embed[:, :, :, None], dim_t)
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)

        # [B, C, H, W]
        pos_embed = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

        return pos_embed
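
    # For reference, the method above is the DETR-style 2-D sine encoding:
    # with d = num_pos_feats = C // 2 and temperature T, each coordinate p
    # (row or column index, normalized to (0, 2*pi]) is mapped to
    #   PE(p, 2k)   = sin(p / T^(2k/d)),   PE(p, 2k+1) = cos(p / T^(2k/d)),
    # and the y- and x-encodings are concatenated along the channel axis, so
    # the result has the same [B, C, H, W] shape as the input feature map.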
    @torch.jit.unused
    def set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionaries with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
    # ---------------------- Main Process for Inference ----------------------
    @torch.no_grad()
    def inference_single_image(self, x):
        # -------------------- Encoder --------------------
        pyramid_feats = self.encoder(x)

        # -------------------- Pos Embed --------------------
        memory = torch.cat([feat.flatten(2) for feat in pyramid_feats], dim=-1)
        memory_pos = torch.cat([self.position_embedding(feat).flatten(2) for feat in pyramid_feats], dim=-1)
        memory = memory.permute(0, 2, 1).contiguous()
        memory_pos = memory_pos.permute(0, 2, 1).contiguous()

        # -------------------- Compressor --------------------
        compressed_memory = self.compressor(memory, memory_pos)

        # -------------------- Decoder --------------------
        hs, reference = self.decoder(compressed_memory, None)

        # -------------------- DetHead --------------------
        out_logits, out_bbox = self.dethead(hs, reference, False)
        cls_pred, box_pred = out_logits[0], out_bbox[0]

        # -------------------- Top-k --------------------
        # scores are flattened over (queries x classes), so a flat index maps to
        # query = index // num_classes and class = index % num_classes
        cls_pred = cls_pred.flatten().sigmoid_()
        num_topk = 100
        predicted_prob, topk_idxs = cls_pred.sort(descending=True)
        topk_idxs = topk_idxs[:num_topk]
        topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
        topk_scores = predicted_prob[:num_topk]
        topk_labels = topk_idxs % self.num_classes
        topk_bboxes = box_pred[topk_box_idxs]

        # denormalize bbox
        img_h, img_w = x.shape[-2:]
        topk_bboxes[..., 0::2] *= img_w
        topk_bboxes[..., 1::2] *= img_h

        if self.deploy:
            return topk_bboxes, topk_scores, topk_labels
        else:
            return topk_bboxes.cpu().numpy(), topk_scores.cpu().numpy(), topk_labels.cpu().numpy()
    # ---------------------- Main Process for Training ----------------------
    def forward(self, x):
        if not self.trainable:
            return self.inference_single_image(x)
        else:
            # -------------------- Encoder --------------------
            pyramid_feats = self.encoder(x)

            # -------------------- Pos Embed --------------------
            memory = torch.cat([feat.flatten(2) for feat in pyramid_feats], dim=-1)
            memory_pos = torch.cat([self.position_embedding(feat).flatten(2) for feat in pyramid_feats], dim=-1)
            memory = memory.permute(0, 2, 1).contiguous()
            memory_pos = memory_pos.permute(0, 2, 1).contiguous()

            # -------------------- Compressor --------------------
            compressed_memory = self.compressor(memory, memory_pos)

            # -------------------- Decoder --------------------
            hs, reference = self.decoder(compressed_memory, None)

            # -------------------- DetHead --------------------
            outputs_class, outputs_coords = self.dethead(hs, reference, True)
            outputs = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coords[-1]}
            if self.aux_loss:
                outputs['aux_outputs'] = self.set_aux_loss(outputs_class, outputs_coords)

            return outputs
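

# ---------------------- Quick sanity check ----------------------
# A minimal, self-contained sketch (not part of the original file): building the
# full RTDETR requires project-specific cfg keys consumed by build_encoder /
# build_compressor / build_decoder / build_dethead, so this only exercises the
# position embedding above on a dummy pyramid feature. position_embedding does
# not read any instance state, so it can be called unbound here. Note that the
# relative imports at the top require running this file as a module inside its
# package for this block to be reached.
if __name__ == '__main__':
    dummy_feat = torch.randn(1, 256, 20, 20)            # [B, C, H, W] feature map
    pos = RTDETR.position_embedding(None, dummy_feat)   # sinusoidal embedding
    print(pos.shape)                                     # torch.Size([1, 256, 20, 20])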