# rtdetr.py

import torch
import torch.nn as nn

from .rtdetr_encoder import build_encoder
from .rtdetr_decoder import build_decoder
from .rtdetr_dethead import build_dethead


# Real-time DETR
class RTDETR(nn.Module):
    def __init__(self,
                 cfg,
                 device,
                 num_classes=20,
                 trainable=False,
                 aux_loss=False,
                 with_box_refine=False,
                 deploy=False):
        super(RTDETR, self).__init__()
        # --------- Basic Parameters ----------
        self.cfg = cfg
        self.device = device
        self.num_classes = num_classes
        self.trainable = trainable
        self.max_stride = max(cfg['stride'])
        self.d_model = round(cfg['d_model'] * self.cfg['width'])
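        # d_model is the base embedding width scaled by the model-size
        # multiplier cfg['width']; the concrete values come from the config.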
        self.aux_loss = aux_loss
        self.with_box_refine = with_box_refine
        self.deploy = deploy

        # --------- Network Parameters ----------
        ## Encoder
        self.encoder = build_encoder(cfg, trainable, 'img_encoder')

        ## Decoder
        self.decoder = build_decoder(cfg, self.d_model, return_intermediate=aux_loss)

        ## DetHead
        self.dethead = build_dethead(cfg, self.d_model, num_classes, with_box_refine)

        # set for TR-Decoder
        self.decoder.class_embed = self.dethead.class_embed
        self.decoder.bbox_embed = self.dethead.bbox_embed
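        # Sharing the head's class/box embeddings with the decoder presumably
        # lets the decoder refine its reference boxes layer by layer when
        # with_box_refine is enabled (an interpretation, not stated in this file).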

    # ---------------------- Basic Functions ----------------------
    @torch.jit.unused
    def set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support a dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
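        # Returns one {'pred_logits', 'pred_boxes'} dict per intermediate
        # decoder output; the final output is excluded because forward()
        # reports it separately as the main prediction.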
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]

    # ---------------------- Main Process for Inference ----------------------
    @torch.no_grad()
    def inference_single_image(self, x):
        # -------------------- Encoder --------------------
        memory, memory_pos = self.encoder(x)

        # -------------------- Decoder --------------------
        hs, reference = self.decoder(memory, memory_pos)

        # -------------------- DetHead --------------------
        out_logits, out_bbox = self.dethead(hs, reference, False)
        cls_pred, box_pred = out_logits[0], out_bbox[0]

        # -------------------- Top-k --------------------
        # Scores are flattened over the (queries x classes) grid, so a flat
        # index i maps back to query i // num_classes and class i % num_classes.
        cls_pred = cls_pred.flatten().sigmoid_()
        num_topk = 100

        predicted_prob, topk_idxs = cls_pred.sort(descending=True)
        topk_idxs = topk_idxs[:num_topk]
        topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')

        topk_scores = predicted_prob[:num_topk]
        topk_labels = topk_idxs % self.num_classes
        topk_bboxes = box_pred[topk_box_idxs]

        # denormalize bbox
        img_h, img_w = x.shape[-2:]
        topk_bboxes[..., 0::2] *= img_w
        topk_bboxes[..., 1::2] *= img_h

        # In deploy mode keep raw tensors (e.g. for graph export); otherwise
        # return numpy arrays for downstream post-processing.
        if self.deploy:
            return topk_bboxes, topk_scores, topk_labels
        else:
            return topk_bboxes.cpu().numpy(), topk_scores.cpu().numpy(), topk_labels.cpu().numpy()

    # ---------------------- Main Process for Training ----------------------
    def forward(self, x):
        if not self.trainable:
            return self.inference_single_image(x)
        else:
            # -------------------- Encoder --------------------
            memory, memory_pos = self.encoder(x)

            # -------------------- Decoder --------------------
            hs, reference = self.decoder(memory, memory_pos)

            # -------------------- DetHead --------------------
            outputs_class, outputs_coords = self.dethead(hs, reference, True)
            outputs = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coords[-1]}
            if self.aux_loss:
                outputs['aux_outputs'] = self.set_aux_loss(outputs_class, outputs_coords)

            return outputs
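

# ---------------------------------------------------------------------------
# Minimal standalone sketch (not part of the model): it illustrates how the
# flattened top-k selection in inference_single_image() recovers (query, class)
# pairs from flat score indices. The query/class counts below are illustrative
# assumptions, not values taken from this repository's configs. Run it as a
# module from the package root, since this file uses relative imports.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_queries, num_classes, num_topk = 300, 80, 100  # assumed sizes

    # Fake per-query class logits and normalized boxes.
    cls_logits = torch.randn(num_queries, num_classes)
    boxes = torch.rand(num_queries, 4)

    # Flatten the scores, take the top-k (equivalent to sort-then-slice above),
    # and decode each flat index into a query index and a class label.
    scores = cls_logits.flatten().sigmoid()
    topk_scores, topk_idxs = scores.topk(num_topk)
    topk_query_idxs = torch.div(topk_idxs, num_classes, rounding_mode='floor')
    topk_labels = topk_idxs % num_classes
    topk_boxes = boxes[topk_query_idxs]

    print(topk_scores.shape, topk_labels.shape, topk_boxes.shape)
    # torch.Size([100]) torch.Size([100]) torch.Size([100, 4])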