# rtdetr_decoder.py

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

from .basic_modules.conv import BasicConv
from .basic_modules.mlp import MLP
from .basic_modules.transformer import DeformableTransformerDecoder
from .basic_modules.dn_compoments import get_contrastive_denoising_training_group


# ----------------- Decoder for Detection task -----------------
## RTDETR's Transformer for Detection task
class RTDetrTransformer(nn.Module):
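    # Detection decoder of RT-DETR: multi-scale features are projected to a shared
    # hidden_dim, the top-k encoder proposals are selected as initial object queries,
    # and a deformable transformer decoder refines class and box predictions layer by
    # layer, with optional contrastive denoising queries at training time.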
    def __init__(self,
                 # basic parameters
                 in_dims             :List  = [256, 512, 1024],
                 hidden_dim          :int   = 256,
                 strides             :List  = [8, 16, 32],
                 num_classes         :int   = 80,
                 num_queries         :int   = 300,
                 # transformer parameters
                 num_heads           :int   = 8,
                 num_layers          :int   = 1,
                 num_levels          :int   = 3,
                 num_points          :int   = 4,
                 ffn_dim             :int   = 1024,
                 dropout             :float = 0.1,
                 act_type            :str   = "relu",
                 return_intermediate :bool  = False,
                 # Denoising parameters
                 num_denoising       :int   = 100,
                 label_noise_ratio   :float = 0.5,
                 box_noise_scale     :float = 1.0,
                 learnt_init_query   :bool  = False,
                 aux_loss            :bool  = True
                 ):
        super().__init__()
        # --------------- Basic setting ---------------
        ## Basic parameters
        self.in_dims = in_dims
        self.strides = strides
        self.num_queries = num_queries
        self.num_classes = num_classes
        self.eps = 1e-2
        self.aux_loss = aux_loss
        ## Transformer parameters
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_levels = num_levels
        self.num_points = num_points
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.act_type = act_type
        self.return_intermediate = return_intermediate
        ## Denoising parameters
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.learnt_init_query = learnt_init_query

        # --------------- Network setting ---------------
        ## Input proj layers
        self.input_proj_layers = nn.ModuleList(
            BasicConv(in_dims[i], hidden_dim, kernel_size=1, act_type=None, norm_type="BN")
            for i in range(num_levels)
        )

        ## Deformable transformer decoder
        self.decoder = DeformableTransformerDecoder(
            d_model             = hidden_dim,
            num_heads           = num_heads,
            num_layers          = num_layers,
            num_levels          = num_levels,
            num_points          = num_points,
            ffn_dim             = ffn_dim,
            dropout             = dropout,
            act_type            = act_type,
            return_intermediate = return_intermediate
        )

        ## Detection head for Encoder
        self.enc_output = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim)
        )
        self.enc_class_head = nn.Linear(hidden_dim, num_classes)
        self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)

        ## Detection head for Decoder
        self.dec_class_head = nn.ModuleList([
            nn.Linear(hidden_dim, num_classes)
            for _ in range(num_layers)
        ])
        self.dec_bbox_head = nn.ModuleList([
            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
            for _ in range(num_layers)
        ])

        ## Object query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
        self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)

        ## Denoising part
        if num_denoising > 0:
            self.denoising_class_embed = nn.Embedding(num_classes + 1, hidden_dim, padding_idx=num_classes)

        self._reset_parameters()

    def _reset_parameters(self):
        # class and bbox head init
        prior_prob = 0.01
        cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))

        nn.init.constant_(self.enc_class_head.bias, cls_bias_init)
        nn.init.constant_(self.enc_bbox_head.layers[-1].weight, 0.)
        nn.init.constant_(self.enc_bbox_head.layers[-1].bias, 0.)
        for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
            nn.init.constant_(cls_.bias, cls_bias_init)
            nn.init.constant_(reg_.layers[-1].weight, 0.)
            nn.init.constant_(reg_.layers[-1].bias, 0.)

        nn.init.xavier_uniform_(self.enc_output[0].weight)
        if self.learnt_init_query:
            nn.init.xavier_uniform_(self.tgt_embed.weight)
        nn.init.xavier_uniform_(self.query_pos_head.layers[0].weight)
        nn.init.xavier_uniform_(self.query_pos_head.layers[1].weight)

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class, outputs_coord)]

    def generate_anchors(self, spatial_shapes, grid_size=0.05):
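        # Build one (cx, cy, w, h) anchor per feature-map cell at every level:
        # centers are the normalized cell centers, widths/heights scale with the
        # level (grid_size * 2**lvl). Anchors are then mapped to logit space
        # (inverse sigmoid) so the bbox heads predict offsets, and anchors whose
        # coordinates fall outside (eps, 1 - eps) are filled with inf.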
        anchors = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w))
            # [H, W, 2]
            grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
            valid_WH = torch.as_tensor([w, h]).float()
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
            wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
            # [H, W, 4] -> [1, N, 4], N=HxW
            anchors.append(torch.cat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4))
        # List[L, 1, N_i, 4] -> [1, N, 4], N=N_0 + N_1 + N_2 + ...
        anchors = torch.cat(anchors, dim=1)
        valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
        anchors = torch.log(anchors / (1 - anchors))
        # Equal to operation: anchors = torch.masked_fill(anchors, ~valid_mask, torch.as_tensor(float("inf")))
        anchors = torch.where(valid_mask, anchors, torch.inf)

        return anchors, valid_mask

    def get_encoder_input(self, feats):
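        # Flatten the projected multi-scale feature maps into a single [B, N, C]
        # token sequence for the transformer, recording each level's (H, W) and
        # its start index inside the flattened sequence.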
        # get projection features
        proj_feats = [self.input_proj_layers[i](feat) for i, feat in enumerate(feats)]
        # get encoder inputs
        feat_flatten = []
        spatial_shapes = []
        level_start_index = [0, ]
        for i, feat in enumerate(proj_feats):
            _, _, h, w = feat.shape
            spatial_shapes.append([h, w])
            # [l], start index of each level
            level_start_index.append(h * w + level_start_index[-1])
            # [B, C, H, W] -> [B, N, C], N=HxW
            feat_flatten.append(feat.flatten(2).permute(0, 2, 1).contiguous())
        # [B, N, C], N = N_0 + N_1 + ...
        feat_flatten = torch.cat(feat_flatten, dim=1)
        level_start_index.pop()

        return (feat_flatten, spatial_shapes, level_start_index)

    def get_decoder_input(self,
                          memory,
                          spatial_shapes,
                          denoising_class=None,
                          denoising_bbox_unact=None):
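        # Score every encoder token, keep the top num_queries proposals as the
        # initial (unactivated) reference boxes and query features, and prepend
        # the denoising queries/boxes when denoising training is enabled.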
        bs, _, _ = memory.shape
        # Prepare input for decoder
        anchors, valid_mask = self.generate_anchors(spatial_shapes)
        anchors = anchors.to(memory.device)
        valid_mask = valid_mask.to(memory.device)

        # Process encoder's output
        memory = torch.where(valid_mask, memory, torch.as_tensor(0., device=memory.device))
        output_memory = self.enc_output(memory)

        # Heads over all encoder tokens: [bs, N, nc] and [bs, N, 4]
        enc_outputs_class = self.enc_class_head(output_memory)
        enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors

        # Topk proposals from encoder's output
        topk = self.num_queries
        topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]  # [bs, num_queries]
        enc_topk_logits = torch.gather(
            enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))  # [bs, num_queries, nc]
        reference_points_unact = torch.gather(
            enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4))  # [bs, num_queries, 4]
        enc_topk_bboxes = F.sigmoid(reference_points_unact)
        if denoising_bbox_unact is not None:
            reference_points_unact = torch.cat(
                [denoising_bbox_unact, reference_points_unact], dim=1)

        # Extract region features
        if self.learnt_init_query:
            # [num_queries, c] -> [b, num_queries, c]
            target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
        else:
            # [num_queries, c] -> [b, num_queries, c]
            target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
            target = target.detach()
        if denoising_class is not None:
            target = torch.cat([denoising_class, target], dim=1)

        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits

    def forward(self, feats, targets=None):
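        # feats: multi-scale features, one [B, C_i, H_i, W_i] tensor per level.
        # targets: ground-truth annotations, only needed for denoising training.
        # Returns a dict with 'pred_logits' [B, num_queries, num_classes] and
        # 'pred_boxes' [B, num_queries, 4], plus auxiliary/denoising outputs
        # during training.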
        # input projection and embedding
        memory, spatial_shapes, _ = self.get_encoder_input(feats)

        # prepare denoising training
        if self.training and self.num_denoising > 0:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
                get_contrastive_denoising_training_group(targets,
                                                         self.num_classes,
                                                         self.num_queries,
                                                         self.denoising_class_embed,
                                                         num_denoising=self.num_denoising,
                                                         label_noise_ratio=self.label_noise_ratio,
                                                         box_noise_scale=self.box_noise_scale,
                                                         )
        else:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None

        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
            self.get_decoder_input(
                memory, spatial_shapes, denoising_class, denoising_bbox_unact)

        # decoder
        out_bboxes, out_logits = self.decoder(target,
                                              init_ref_points_unact,
                                              memory,
                                              spatial_shapes,
                                              self.dec_bbox_head,
                                              self.dec_class_head,
                                              self.query_pos_head,
                                              attn_mask)

        if self.training and dn_meta is not None:
            dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2)
            dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2)

        out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}

        if self.training and self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1])
            out['aux_outputs'].extend(self._set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))

            if self.training and dn_meta is not None:
                out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
                out['dn_meta'] = dn_meta

        return out


## RTDETR's Transformer for Instance Segmentation task (not complete yet)
class MaskRTDetrTransformer(RTDetrTransformer):
    def __init__(self,
                 # basic parameters
                 in_dims             :List  = [256, 512, 1024],
                 hidden_dim          :int   = 256,
                 strides             :List  = [8, 16, 32],
                 num_classes         :int   = 80,
                 num_queries         :int   = 300,
                 # transformer parameters
                 num_heads           :int   = 8,
                 num_layers          :int   = 1,
                 num_levels          :int   = 3,
                 num_points          :int   = 4,
                 ffn_dim             :int   = 1024,
                 dropout             :float = 0.1,
                 act_type            :str   = "relu",
                 return_intermediate :bool  = False,
                 # Denoising parameters
                 num_denoising       :int   = 100,
                 label_noise_ratio   :float = 0.5,
                 box_noise_scale     :float = 1.0,
                 learnt_init_query   :bool  = False,
                 aux_loss            :bool  = True
                 ):
        # Forward all arguments to the detection transformer; the mask branch is not implemented yet.
        super().__init__(in_dims=in_dims, hidden_dim=hidden_dim, strides=strides,
                         num_classes=num_classes, num_queries=num_queries,
                         num_heads=num_heads, num_layers=num_layers, num_levels=num_levels,
                         num_points=num_points, ffn_dim=ffn_dim, dropout=dropout,
                         act_type=act_type, return_intermediate=return_intermediate,
                         num_denoising=num_denoising, label_noise_ratio=label_noise_ratio,
                         box_noise_scale=box_noise_scale, learnt_init_query=learnt_init_query,
                         aux_loss=aux_loss)

    def forward(self, feats, targets=None):
        # Segmentation head is not implemented yet.
        return
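

# A minimal smoke-test sketch (not part of the original file). It assumes the
# relative imports above resolve (i.e. the file is run as a module inside its
# package) and that BasicConv / MLP / DeformableTransformerDecoder behave as
# they are used in RTDetrTransformer. Feature shapes correspond to a 640x640
# input at strides 8/16/32.
if __name__ == "__main__":
    model = RTDetrTransformer(in_dims=[256, 512, 1024],
                              hidden_dim=256,
                              strides=[8, 16, 32],
                              num_classes=80,
                              num_queries=300)
    model.eval()  # eval mode: no denoising group, so no targets are required

    # One pyramid level per stride: [B, C_i, H_i, W_i]
    feats = [torch.randn(2, 256, 80, 80),
             torch.randn(2, 512, 40, 40),
             torch.randn(2, 1024, 20, 20)]

    with torch.no_grad():
        outputs = model(feats)

    # Expected: pred_logits [2, 300, 80], pred_boxes [2, 300, 4]
    print(outputs['pred_logits'].shape)
    print(outputs['pred_boxes'].shape)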