rtrdet.py

import torch
import torch.nn as nn

from .rtrdet_backbone import build_backbone
from .rtrdet_transformer import build_transformer


# Real-time Detection with Transformer
class RTRDet(nn.Module):
    def __init__(self,
                 cfg,
                 device,
                 num_classes: int = 20,
                 trainable: bool = False,
                 aux_loss: bool = False,
                 deploy: bool = False):
        super(RTRDet, self).__init__()
        assert cfg['out_stride'] in (16, 32)
        # ------------------ Basic parameters ------------------
        self.cfg = cfg
        self.device = device
        self.out_stride = cfg['out_stride']
        self.max_stride = cfg['max_stride']
        self.num_levels = 2 if cfg['out_stride'] == 16 else 1
        self.num_topk = cfg['num_topk']
        self.num_classes = num_classes
        self.d_model = round(cfg['d_model'] * cfg['width'])
        self.aux_loss = aux_loss
        self.trainable = trainable
        self.deploy = deploy

        # ------------------ Network parameters ------------------
        ## Backbone
        self.backbone, self.feat_dims = build_backbone(cfg, trainable and cfg['pretrained'])
        self.input_projs = nn.ModuleList(
            [nn.Conv2d(self.feat_dims[-i], self.d_model, kernel_size=1)
             for i in range(1, self.num_levels + 1)])

        ## Transformer
        self.transformer = build_transformer(cfg, num_classes, return_intermediate=aux_loss)
    @torch.jit.unused
    def set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
    # ---------------------- Main Process for Inference ----------------------
    @torch.no_grad()
    def inference_single_image(self, x):
        # -------------------- Inference --------------------
        ## Backbone
        pyramid_feats = self.backbone(x)

        ## Input proj
        for idx in range(1, self.num_levels + 1):
            pyramid_feats[-idx] = self.input_projs[idx - 1](pyramid_feats[-idx])

        ## Transformer
        if self.num_levels == 2:
            src1, src2 = pyramid_feats[-2], pyramid_feats[-1]
        else:
            src1, src2 = None, pyramid_feats[-1]
        output_classes, output_coords = self.transformer(src1, src2)

        # -------------------- Post-process --------------------
        ## Top-k: take the last decoder layer's outputs, flatten the class scores
        ## over (num_queries x num_classes), and keep the num_topk highest ones.
        cls_pred, box_pred = output_classes[-1], output_coords[-1]
        cls_pred = cls_pred[0].flatten().sigmoid_()
        box_pred = box_pred[0]
        predicted_prob, topk_idxs = cls_pred.sort(descending=True)
        topk_idxs = topk_idxs[:self.num_topk]
        ## Each flattened index maps back to a (query index, class label) pair.
        topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
        topk_scores = predicted_prob[:self.num_topk]
        topk_labels = topk_idxs % self.num_classes
        topk_bboxes = box_pred[topk_box_idxs]

        ## Denormalize bbox
        img_h, img_w = x.shape[-2:]
        topk_bboxes[..., 0::2] *= img_w
        topk_bboxes[..., 1::2] *= img_h

        if self.deploy:
            return topk_bboxes, topk_scores, topk_labels
        else:
            return topk_bboxes.cpu().numpy(), topk_scores.cpu().numpy(), topk_labels.cpu().numpy()
    # ---------------------- Main Process for Training ----------------------
    def forward(self, x):
        if not self.trainable:
            return self.inference_single_image(x)
        else:
            # -------------------- Forward pass --------------------
            ## Backbone
            pyramid_feats = self.backbone(x)

            ## Input proj
            for idx in range(1, self.num_levels + 1):
                pyramid_feats[-idx] = self.input_projs[idx - 1](pyramid_feats[-idx])

            ## Transformer
            if self.num_levels == 2:
                src1, src2 = pyramid_feats[-2], pyramid_feats[-1]
            else:
                src1, src2 = None, pyramid_feats[-1]
            output_classes, output_coords = self.transformer(src1, src2)

            outputs = {'pred_logits': output_classes[-1], 'pred_boxes': output_coords[-1]}
            if self.aux_loss:
                outputs['aux_outputs'] = self.set_aux_loss(output_classes, output_coords)

            return outputs
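

# Minimal usage sketch (not part of the original file). The config keys below are
# the ones read directly in __init__ (out_stride, max_stride, num_topk, d_model,
# width, pretrained); their values and any extra keys consumed by build_backbone /
# build_transformer are assumptions and will differ in the repo's real config files.
# Because of the relative imports, this must be run as a module from the package
# root (e.g. `python -m <package>.rtrdet`, package name assumed), not as a script.
if __name__ == "__main__":
    cfg = {
        'out_stride': 32,     # single feature level (num_levels == 1)
        'max_stride': 32,
        'num_topk': 100,
        'd_model': 256,
        'width': 1.0,
        'pretrained': False,
        # ... backbone / transformer specific keys go here ...
    }
    device = torch.device('cpu')
    model = RTRDet(cfg, device, num_classes=20, trainable=False).eval()

    # Dummy 640x640 image; the inference path returns numpy arrays when deploy=False.
    dummy = torch.randn(1, 3, 640, 640)
    bboxes, scores, labels = model(dummy)
    print(bboxes.shape, scores.shape, labels.shape)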