import math
import torch
import torch.nn as nn

from ...backbone import build_backbone
from ...basic.mlp import MLP
from ...basic.conv import BasicConv, UpSampleWrapper
from ...basic.transformer import TransformerEncoder, PlainDETRTransformer, get_clones

from utils.misc import multiclass_nms


# DETR
class DETR(nn.Module):
    def __init__(self,
                 cfg,
                 num_classes = 80,
                 conf_thresh = 0.1,
                 nms_thresh  = 0.5,
                 topk        = 300,
                 use_nms     = False,
                 ca_nms      = False,
                 ):
        super().__init__()
        # ---------------- Basic setting ----------------
        self.stride = cfg['out_stride']
        self.upsample_factor = cfg['max_stride'] // cfg['out_stride']
        self.num_classes = num_classes
        ## Transformer parameters
        self.num_queries_one2one = cfg['num_queries_one2one']
        self.num_queries_one2many = cfg['num_queries_one2many']
        self.num_queries = self.num_queries_one2one + self.num_queries_one2many
        ## Post-process parameters
        self.ca_nms = ca_nms
        self.use_nms = use_nms
        self.num_topk = topk
        self.nms_thresh = nms_thresh
        self.conf_thresh = conf_thresh

        # ---------------- Network setting ----------------
        ## Backbone Network
        self.backbone, feat_dims = build_backbone(cfg)

        ## Input projection
        self.input_proj = BasicConv(feat_dims[-1], cfg['hidden_dim'], kernel_size=1, act_type=None, norm_type='GN')

        ## Transformer Encoder
        self.transformer_encoder = TransformerEncoder(d_model    = cfg['hidden_dim'],
                                                      num_heads  = cfg['en_num_heads'],
                                                      num_layers = cfg['en_num_layers'],
                                                      ffn_dim    = cfg['en_ffn_dim'],
                                                      dropout    = cfg['en_dropout'],
                                                      act_type   = cfg['en_act'],
                                                      pre_norm   = cfg['en_pre_norm'],
                                                      )

        ## Upsample layer
        self.upsample = UpSampleWrapper(cfg['hidden_dim'], self.upsample_factor)

        ## Output projection
        self.output_proj = BasicConv(cfg['hidden_dim'], cfg['hidden_dim'], kernel_size=3, padding=1, act_type='silu', norm_type='BN')

        ## Transformer
        self.query_embed = nn.Embedding(self.num_queries, cfg['hidden_dim'])
        self.transformer = PlainDETRTransformer(d_model                 = cfg['hidden_dim'],
                                                num_heads               = cfg['de_num_heads'],
                                                ffn_dim                 = cfg['de_ffn_dim'],
                                                dropout                 = cfg['de_dropout'],
                                                act_type                = cfg['de_act'],
                                                pre_norm                = cfg['de_pre_norm'],
                                                rpe_hidden_dim          = cfg['rpe_hidden_dim'],
                                                feature_stride          = cfg['out_stride'],
                                                num_layers              = cfg['de_num_layers'],
                                                use_checkpoint          = cfg['use_checkpoint'],
                                                num_queries_one2one     = cfg['num_queries_one2one'],
                                                num_queries_one2many    = cfg['num_queries_one2many'],
                                                proposal_feature_levels = cfg['proposal_feature_levels'],
                                                proposal_in_stride      = cfg['out_stride'],
                                                proposal_tgt_strides    = cfg['proposal_tgt_strides'],
                                                return_intermediate     = True,
                                                )

        ## Detect Head
        class_embed = nn.Linear(cfg['hidden_dim'], num_classes)
        bbox_embed = MLP(cfg['hidden_dim'], cfg['hidden_dim'], 4, 3)

        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)

        self.class_embed = get_clones(class_embed, cfg['de_num_layers'] + 1)
        self.bbox_embed = get_clones(bbox_embed, cfg['de_num_layers'] + 1)
        nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
        self.transformer.decoder.bbox_embed = self.bbox_embed
        self.transformer.decoder.class_embed = self.class_embed

    def get_posembed(self, d_model, mask, temperature=10000, normalize=False):
        """Build [B, C, H, W] sine-cosine position embeddings from the padding mask."""
        not_mask = ~mask
        scale = 2 * torch.pi
        num_pos_feats = d_model // 2

        # -------------- Generate XY coords --------------
        ## [B, H, W]
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)

        ## Normalize coords
        if normalize:
            y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + 1e-6)
            x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + 1e-6)
        else:
            y_embed = y_embed - 0.5
            x_embed = x_embed - 0.5

        # [B, H, W] -> [B, H, W, 2]
        pos = torch.stack([x_embed, y_embed], dim=-1)

        # -------------- Sine-PosEmbedding --------------
        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
        dim_t = temperature ** (2 * dim_t_)

        x_embed = pos[..., 0] * scale
        y_embed = pos[..., 1] * scale
        pos_x = x_embed[..., None] / dim_t
        pos_y = y_embed[..., None] / dim_t
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
        pos_embed = torch.cat((pos_y, pos_x), dim=-1)

        # [B, H, W, C] -> [B, C, H, W]
        pos_embed = pos_embed.permute(0, 3, 1, 2)

        return pos_embed

    def post_process(self, box_pred, cls_pred):
        # Flatten the [num_queries, num_classes] scores of the single image
        cls_pred = cls_pred[0].flatten().sigmoid_()
        box_pred = box_pred[0]

        # Keep top k top scoring indices only.
        num_topk = min(self.num_topk, box_pred.size(0))

        # Topk candidates
        predicted_prob, topk_idxs = cls_pred.sort(descending=True)
        topk_scores = predicted_prob[:num_topk]
        topk_idxs = topk_idxs[:num_topk]

        # Filter out the proposals with low confidence score
        keep_idxs = topk_scores > self.conf_thresh
        topk_scores = topk_scores[keep_idxs]
        topk_idxs = topk_idxs[keep_idxs]
        topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')

        ## Top-k results
        topk_labels = topk_idxs % self.num_classes
        topk_bboxes = box_pred[topk_box_idxs]
        topk_scores = topk_scores.cpu().numpy()
        topk_labels = topk_labels.cpu().numpy()
        topk_bboxes = topk_bboxes.cpu().numpy()

        # NMS (class-agnostic if self.ca_nms is set)
        if self.use_nms:
            topk_scores, topk_labels, topk_bboxes = multiclass_nms(
                topk_scores, topk_labels, topk_bboxes, self.nms_thresh, self.num_classes, self.ca_nms)

        return topk_bboxes, topk_scores, topk_labels

    def resize_mask(self, src, mask=None):
        bs, c, h, w = src.shape
        if mask is not None:
            # [B, H, W]
            mask = nn.functional.interpolate(mask[None].float(), size=[h, w]).bool()[0]
        else:
            mask = torch.zeros([bs, h, w], device=src.device, dtype=torch.bool)

        return mask

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord, outputs_coord_old, outputs_deltas):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_boxes": b, "pred_boxes_old": c, "pred_deltas": d}
            for a, b, c, d in zip(outputs_class[:-1], outputs_coord[:-1], outputs_coord_old[:-1], outputs_deltas[:-1])
        ]

    def inference_single_image(self, x):
        # ----------- Image Encoder -----------
        pyramid_feats = self.backbone(x)
        src = self.input_proj(pyramid_feats[-1])
        src = self.transformer_encoder(src)
        src = self.upsample(src)
        src = self.output_proj(src)

        # ----------- Prepare inputs for Transformer -----------
        mask = self.resize_mask(src)
        pos_embed = self.get_posembed(src.shape[1], mask, normalize=False)
        query_embeds = self.query_embed.weight[:self.num_queries_one2one]
        self_attn_mask = None

        # ----------- Transformer -----------
        (
            hs,
            init_reference,
            inter_references,
            _,
            _,
            _,
            _,
            max_shape
        ) = self.transformer(src, mask, pos_embed, query_embeds, self_attn_mask)

        # ----------- Process outputs -----------
        outputs_classes_one2one = []
        outputs_coords_one2one = []
        outputs_deltas_one2one = []
        for lid in range(hs.shape[0]):
            if lid == 0:
                reference = init_reference
            else:
                reference = inter_references[lid - 1]
            outputs_class = self.class_embed[lid](hs[lid])
            tmp = self.bbox_embed[lid](hs[lid])
            outputs_coord = self.transformer.decoder.delta2bbox(reference, tmp, max_shape)  # xyxy
            outputs_classes_one2one.append(outputs_class[:, :self.num_queries_one2one])
            outputs_coords_one2one.append(outputs_coord[:, :self.num_queries_one2one])
            outputs_deltas_one2one.append(tmp[:, :self.num_queries_one2one])
        outputs_classes_one2one = torch.stack(outputs_classes_one2one)
        outputs_coords_one2one = torch.stack(outputs_coords_one2one)

        # ------------ Post process ------------
        cls_pred = outputs_classes_one2one[-1]
        box_pred = outputs_coords_one2one[-1]
        bboxes, scores, labels = self.post_process(box_pred, cls_pred)

        # Normalize bboxes by the input image size
        bboxes[..., 0::2] /= x.shape[-1]
        bboxes[..., 1::2] /= x.shape[-2]
        bboxes = bboxes.clip(0., 1.)

        return bboxes, scores, labels

    def forward(self, x, src_mask=None, targets=None):
        if not self.training:
            return self.inference_single_image(x)

        # ----------- Image Encoder -----------
        pyramid_feats = self.backbone(x)
        src = self.input_proj(pyramid_feats[-1])
        src = self.transformer_encoder(src)
        src = self.upsample(src)
        src = self.output_proj(src)

        # ----------- Prepare inputs for Transformer -----------
        mask = self.resize_mask(src, src_mask)
        pos_embed = self.get_posembed(src.shape[1], mask, normalize=False)
        query_embeds = self.query_embed.weight

        # Self-attention mask: the one-to-one and one-to-many query groups must not attend to each other
        self_attn_mask = torch.zeros(
            [self.num_queries, self.num_queries]).bool().to(src.device)
        self_attn_mask[self.num_queries_one2one:, :self.num_queries_one2one] = True
        self_attn_mask[:self.num_queries_one2one, self.num_queries_one2one:] = True

        # ----------- Transformer -----------
        (
            hs,
            init_reference,
            inter_references,
            enc_outputs_class,
            enc_outputs_coord_unact,
            enc_outputs_delta,
            output_proposals,
            max_shape
        ) = self.transformer(src, mask, pos_embed, query_embeds, self_attn_mask)

        # ----------- Process outputs -----------
        outputs_classes_one2one = []
        outputs_coords_one2one = []
        outputs_coords_old_one2one = []
        outputs_deltas_one2one = []
        outputs_classes_one2many = []
        outputs_coords_one2many = []
        outputs_coords_old_one2many = []
        outputs_deltas_one2many = []
        for lid in range(hs.shape[0]):
            if lid == 0:
                reference = init_reference
            else:
                reference = inter_references[lid - 1]
            outputs_class = self.class_embed[lid](hs[lid])
            tmp = self.bbox_embed[lid](hs[lid])
            outputs_coord = self.transformer.decoder.box_xyxy_to_cxcywh(
                self.transformer.decoder.delta2bbox(reference, tmp, max_shape))

            outputs_classes_one2one.append(outputs_class[:, :self.num_queries_one2one])
            outputs_classes_one2many.append(outputs_class[:, self.num_queries_one2one:])
            outputs_coords_one2one.append(outputs_coord[:, :self.num_queries_one2one])
            outputs_coords_one2many.append(outputs_coord[:, self.num_queries_one2one:])
            outputs_coords_old_one2one.append(reference[:, :self.num_queries_one2one])
            outputs_coords_old_one2many.append(reference[:, self.num_queries_one2one:])
            outputs_deltas_one2one.append(tmp[:, :self.num_queries_one2one])
            outputs_deltas_one2many.append(tmp[:, self.num_queries_one2one:])
        outputs_classes_one2one = torch.stack(outputs_classes_one2one)
        outputs_coords_one2one = torch.stack(outputs_coords_one2one)
        outputs_classes_one2many = torch.stack(outputs_classes_one2many)
        outputs_coords_one2many = torch.stack(outputs_coords_one2many)

        out = {
            "pred_logits": outputs_classes_one2one[-1],
            "pred_boxes": outputs_coords_one2one[-1],
            "pred_logits_one2many": outputs_classes_one2many[-1],
            "pred_boxes_one2many": outputs_coords_one2many[-1],
            "pred_boxes_old": outputs_coords_old_one2one[-1],
            "pred_deltas": outputs_deltas_one2one[-1],
            "pred_boxes_old_one2many": outputs_coords_old_one2many[-1],
            "pred_deltas_one2many": outputs_deltas_one2many[-1],
        }
        out["aux_outputs"] = self._set_aux_loss(
            outputs_classes_one2one, outputs_coords_one2one, outputs_coords_old_one2one, outputs_deltas_one2one
        )
        out["aux_outputs_one2many"] = self._set_aux_loss(
            outputs_classes_one2many, outputs_coords_one2many, outputs_coords_old_one2many, outputs_deltas_one2many
        )
        out["enc_outputs"] = {
            "pred_logits": enc_outputs_class,
            "pred_boxes": enc_outputs_coord_unact,
            "pred_boxes_old": output_proposals,
            "pred_deltas": enc_outputs_delta,
        }

        return out
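

# ------------------------------------------------------------------------
# Minimal standalone sketch, not wired into the model above: it replays two
# small pieces of the logic in this file on random tensors so they can be
# inspected in isolation. Part (1) mirrors the flattened top-k decoding in
# DETR.post_process, where a flat index over num_queries * num_classes
# scores is split back into a box index and a class label. Part (2) mirrors
# the hybrid-matching self-attention mask built in DETR.forward, which
# blocks attention between the one-to-one and one-to-many query groups.
# All sizes below are illustrative placeholders, not values from any config.
# Because of the relative imports at the top of this file, run it as a
# module inside its package (python -m ...) rather than as a loose script.
# ------------------------------------------------------------------------
if __name__ == '__main__':
    num_queries, num_classes, topk = 8, 5, 4

    # (1) Flattened top-k decoding
    cls_pred = torch.randn(num_queries, num_classes).flatten().sigmoid_()
    box_pred = torch.rand(num_queries, 4)
    scores, idxs = cls_pred.sort(descending=True)
    scores, idxs = scores[:topk], idxs[:topk]
    box_idxs = torch.div(idxs, num_classes, rounding_mode='floor')  # which query each score came from
    labels = idxs % num_classes                                     # which class each score belongs to
    print('top-k scores :', scores.tolist())
    print('box indices  :', box_idxs.tolist())
    print('class labels :', labels.tolist())
    print('top-k boxes  :', box_pred[box_idxs].shape)

    # (2) Hybrid-matching self-attention mask (True = attention blocked)
    n_one2one, n_one2many = 3, 5
    attn_mask = torch.zeros([n_one2one + n_one2many, n_one2one + n_one2many]).bool()
    attn_mask[n_one2one:, :n_one2one] = True
    attn_mask[:n_one2one, n_one2one:] = True
    print(attn_mask.int())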