# detr.py — Detection with Transformer (DETR) model.
import torch
import torch.nn as nn
import torch.nn.functional as F

# --------------- Model components ---------------
from ...backbone import build_backbone
from ...transformer import build_transformer
from ...basic.mlp import MLP


  8. # Detection with Transformer
  9. class DETR(nn.Module):
  10. def __init__(self,
  11. cfg,
  12. num_classes :int = 90,
  13. conf_thresh :float = 0.05,
  14. topk :int = 1000,
  15. ):
  16. super().__init__()
  17. # ---------------------- Basic Parameters ----------------------
  18. self.cfg = cfg
  19. self.topk = topk
  20. self.num_classes = num_classes
  21. self.conf_thresh = conf_thresh
  22. # ---------------------- Network Parameters ----------------------
  23. ## Backbone
  24. backbone, feat_dims = build_backbone(cfg)
  25. self.backbone = nn.Sequential(backbone)
  26. ## Input proj
  27. self.input_proj = nn.Conv2d(feat_dims[-1], cfg.hidden_dim, kernel_size=1)
  28. ## Transformer
  29. self.transformer = build_transformer(cfg, return_intermediate_dec=True)
  30. ## Object queries
  31. self.query_embed = nn.Embedding(cfg.num_queries, cfg.hidden_dim)
  32. ## Output
  33. self.class_embed = nn.Linear(cfg.hidden_dim, num_classes + 1)
  34. self.bbox_embed = MLP(cfg.hidden_dim, cfg.hidden_dim, 4, 3)
  35. @torch.jit.unused
  36. def set_aux_loss(self, outputs_class, outputs_coord):
  37. return [{'pred_logits': a, 'pred_boxes': b}
  38. for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
  39. def post_process(self, cls_pred, box_pred):
  40. """
  41. Input:
  42. cls_pred: (Tensor) [Nq, C]
  43. box_pred: (Tensor) [Nq, 4]
  44. """
  45. # [Nq x C,]
  46. scores_i = cls_pred.flatten()
  47. # Keep top k top scoring indices only.
  48. num_topk = min(self.topk, box_pred.size(0))
  49. # torch.sort is actually faster than .topk (at least on GPUs)
  50. predicted_prob, topk_idxs = scores_i.sort(descending=True)
  51. topk_scores = predicted_prob[:num_topk]
  52. topk_idxs = topk_idxs[:num_topk]
  53. # filter out the proposals with low confidence score
  54. keep_idxs = topk_scores > self.conf_thresh
  55. topk_idxs = topk_idxs[keep_idxs]
  56. # final scores
  57. scores = topk_scores[keep_idxs]
  58. # final labels
  59. labels = topk_idxs % self.num_classes
  60. # final bboxes
  61. anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
  62. bboxes = box_pred[anchor_idxs]
  63. # to cpu & numpy
  64. scores = scores.cpu().numpy()
  65. labels = labels.cpu().numpy()
  66. bboxes = bboxes.cpu().numpy()
  67. return bboxes, scores, labels
  68. def forward(self, src, src_mask=None):
  69. # ---------------- Backbone ----------------
  70. pyramid_feats = self.backbone(src)
  71. feat = self.input_proj(pyramid_feats[-1])
  72. if src_mask is not None:
  73. src_mask = F.interpolate(src_mask[None].float(), size=feat.shape[-2:]).bool()[0]
  74. else:
  75. src_mask = torch.zeros([feat.shape[0], *feat.shape[-2:]], device=feat.device, dtype=torch.bool)
  76. # ---------------- Transformer ----------------
  77. hs = self.transformer(feat, src_mask, self.query_embed.weight)[0]
  78. # ---------------- Head ----------------
  79. outputs_class = self.class_embed(hs)
  80. outputs_coord = self.bbox_embed(hs).sigmoid()
  81. if self.training:
  82. outputs = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
  83. outputs['aux_outputs'] = self.set_aux_loss(outputs_class, outputs_coord)
  84. else:
  85. cls_pred = outputs_class[-1].softmax(-1)[..., :-1]
  86. box_pred = outputs_coord[-1]
  87. # [B, N, C] -> [N, C]
  88. cls_pred = cls_pred[0]
  89. box_pred = box_pred[0]
  90. # xywh -> xyxy
  91. cxcy_pred = box_pred[..., :2]
  92. bwbh_pred = box_pred[..., 2:]
  93. x1y1_pred = cxcy_pred - 0.5 * bwbh_pred
  94. x2y2_pred = cxcy_pred + 0.5 * bwbh_pred
  95. box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
  96. # Post-process (no NMS)
  97. bboxes, scores, labels = self.post_process(cls_pred, box_pred)
  98. outputs = {
  99. 'scores': scores,
  100. 'labels': labels,
  101. 'bboxes': bboxes
  102. }
  103. return outputs