rtrdet.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from .rtrdet_backbone import build_backbone
from .rtrdet_encoder import build_encoder
from .rtrdet_decoder import build_decoder
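
# NOTE: `cfg` is assumed here to be a plain dict of model hyperparameters.
# The keys read directly in this file are 'max_stride', 'num_topk', 'd_model',
# 'width', and 'pretrained'; build_backbone / build_encoder / build_decoder
# may consume additional keys.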

# Real-time Detection with Transformer
class RTRDet(nn.Module):
    def __init__(self,
                 cfg,
                 device,
                 num_classes: int = 20,
                 trainable: bool = False,
                 aux_loss: bool = False,
                 deploy: bool = False):
        super(RTRDet, self).__init__()
        # ------------------ Basic parameters ------------------
        self.cfg = cfg
        self.device = device
        self.max_stride = cfg['max_stride']
        self.num_topk = cfg['num_topk']
        self.d_model = round(cfg['d_model'] * cfg['width'])
        self.num_classes = num_classes
        self.aux_loss = aux_loss
        self.trainable = trainable
        self.deploy = deploy

        # ------------------ Network parameters ------------------
        ## Backbone
        self.backbone, self.feat_dims = build_backbone(cfg, trainable and cfg['pretrained'])
        self.input_proj1 = nn.Conv2d(self.feat_dims[-1], self.d_model, kernel_size=1)
        self.input_proj2 = nn.Conv2d(self.feat_dims[-2], self.d_model, kernel_size=1)

        ## Transformer Encoder
        self.encoder = build_encoder(cfg)

        ## Transformer Decoder
        self.decoder = build_decoder(cfg, num_classes, return_intermediate=aux_loss)

    # ---------------------- Basic Functions ----------------------
    def position_embedding(self, x, temperature=10000):
        # Sine-cosine positional embedding for a [B, C, H, W] feature map.
        hs, ws = x.shape[-2:]
        device = x.device
        num_pos_feats = x.shape[1] // 2
        scale = 2 * math.pi

        # Generate the xy coordinate grid. indexing='ij' matches the old
        # default behavior and silences the newer PyTorch deprecation warning.
        y_embed, x_embed = torch.meshgrid(
            torch.arange(1, hs + 1, dtype=torch.float32),
            torch.arange(1, ws + 1, dtype=torch.float32),
            indexing='ij')
        y_embed = y_embed / (hs + 1e-6) * scale
        x_embed = x_embed / (ws + 1e-6) * scale

        # [H, W] -> [1, H, W]
        y_embed = y_embed[None, :, :].to(device)
        x_embed = x_embed[None, :, :].to(device)

        # Geometric frequency schedule shared by the sin/cos channel pairs.
        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=device)
        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
        dim_t = temperature ** (2 * dim_t_)

        pos_x = torch.div(x_embed[:, :, :, None], dim_t)
        pos_y = torch.div(y_embed[:, :, :, None], dim_t)
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)

        # [B, C, H, W]
        pos_embed = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

        return pos_embed
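
    # Shape sketch (assumed input, for illustration only): for x of shape
    # [1, 256, 20, 20], position_embedding returns a [1, 256, 20, 20] tensor
    # whose first 128 channels encode y and whose last 128 channels encode x.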

    @torch.jit.unused
    def set_aux_loss(self, outputs_class, outputs_coord):
        # This is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionaries with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

    # ---------------------- Main Process for Inference ----------------------
    @torch.no_grad()
    def inference_single_image(self, x):
        # -------------------- Inference --------------------
        ## Backbone
        pyramid_feats = self.backbone(x)
        high_level_feat = self.input_proj1(pyramid_feats[-1])
        bs, c, h, w = high_level_feat.size()

        ## Transformer Encoder
        pos_embed1 = self.position_embedding(high_level_feat)
        high_level_feat = self.encoder(high_level_feat, pos_embed1, self.decoder.adapt_pos2d)
        high_level_feat = high_level_feat.permute(0, 2, 1).reshape(bs, c, h, w)
        p4_level_feat = self.input_proj2(pyramid_feats[-2]) + F.interpolate(high_level_feat, scale_factor=2.0)

        ## Transformer Decoder
        pos_embed2 = self.position_embedding(p4_level_feat)
        output_classes, output_coords = self.decoder(p4_level_feat, pos_embed2)

        # -------------------- Post-process --------------------
        ## Top-k: take the last decoder layer's predictions, score every
        ## (query, class) pair, and keep the num_topk highest.
        cls_pred, box_pred = output_classes[-1], output_coords[-1]
        cls_pred = cls_pred[0].flatten().sigmoid_()
        box_pred = box_pred[0]

        predicted_prob, topk_idxs = cls_pred.sort(descending=True)
        topk_idxs = topk_idxs[:self.num_topk]
        topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')

        topk_scores = predicted_prob[:self.num_topk]
        topk_labels = topk_idxs % self.num_classes
        topk_bboxes = box_pred[topk_box_idxs]

        ## Denormalize bbox to input-image pixels
        img_h, img_w = x.shape[-2:]
        topk_bboxes[..., 0::2] *= img_w
        topk_bboxes[..., 1::2] *= img_h

        if self.deploy:
            return topk_bboxes, topk_scores, topk_labels
        else:
            return topk_bboxes.cpu().numpy(), topk_scores.cpu().numpy(), topk_labels.cpu().numpy()

    # ---------------------- Main Process for Training ----------------------
    def forward(self, x):
        if not self.trainable:
            return self.inference_single_image(x)
        else:
            # -------------------- Network forward --------------------
            ## Backbone
            pyramid_feats = self.backbone(x)
            high_level_feat = self.input_proj1(pyramid_feats[-1])
            bs, c, h, w = high_level_feat.size()

            ## Transformer Encoder
            pos_embed1 = self.position_embedding(high_level_feat)
            high_level_feat = self.encoder(high_level_feat, pos_embed1, self.decoder.adapt_pos2d)
            high_level_feat = high_level_feat.permute(0, 2, 1).reshape(bs, c, h, w)
            p4_level_feat = self.input_proj2(pyramid_feats[-2]) + F.interpolate(high_level_feat, scale_factor=2.0)

            ## Transformer Decoder
            pos_embed2 = self.position_embedding(p4_level_feat)
            output_classes, output_coords = self.decoder(p4_level_feat, pos_embed2)

            outputs = {'pred_logits': output_classes[-1], 'pred_boxes': output_coords[-1]}
            if self.aux_loss:
                outputs['aux_outputs'] = self.set_aux_loss(output_classes, output_coords)

            return outputs
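

# ------------------------- Usage sketch -------------------------
# A minimal smoke test, not part of the original module. The cfg values below
# are assumptions for illustration; real configs live alongside the build_*
# factories and may require additional keys, so treat this as a sketch rather
# than a definitive entry point.
if __name__ == '__main__':
    cfg = {
        'max_stride': 32,     # assumed: stride of the last pyramid level
        'num_topk': 100,      # keep the top-100 (query, class) pairs at inference
        'd_model': 256,       # base transformer width ...
        'width': 1.0,         # ... scaled by this multiplier
        'pretrained': False,  # skip loading pretrained backbone weights
        # ... plus whatever build_backbone / build_encoder / build_decoder expect
    }
    device = torch.device('cpu')
    model = RTRDet(cfg, device, num_classes=20, trainable=False).eval()

    # Inference path: returns numpy arrays of denormalized boxes, scores, labels.
    dummy = torch.randn(1, 3, 640, 640)
    bboxes, scores, labels = model(dummy)
    print(bboxes.shape, scores.shape, labels.shape)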