# rtdetr_decoder.py

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import constant_, xavier_uniform_, uniform_, normal_
from typing import List
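
# The try/except below lets this file be imported as a package module (relative
# imports) or run directly as a script (see the __main__ smoke test at the end).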
try:
    from .basic_modules.basic import BasicConv, MLP
    from .basic_modules.transformer import DeformableTransformerDecoder
    from .basic_modules.dn_compoments import get_contrastive_denoising_training_group
except ImportError:
    from basic_modules.basic import BasicConv, MLP
    from basic_modules.transformer import DeformableTransformerDecoder
    from basic_modules.dn_compoments import get_contrastive_denoising_training_group

def build_transformer(cfg, in_dims, num_classes, return_intermediate=False):
    if cfg['transformer'] == 'rtdetr_transformer':
        return RTDETRTransformer(in_dims             = in_dims,
                                 hidden_dim          = cfg['hidden_dim'],
                                 strides             = cfg['out_stride'],
                                 num_classes         = num_classes,
                                 num_queries         = cfg['num_queries'],
                                 pos_embed_type      = 'sine',
                                 num_heads           = cfg['de_num_heads'],
                                 num_layers          = cfg['de_num_layers'],
                                 num_levels          = len(cfg['out_stride']),
                                 num_points          = cfg['de_num_points'],
                                 mlp_ratio           = cfg['de_mlp_ratio'],
                                 dropout             = cfg['de_dropout'],
                                 act_type            = cfg['de_act'],
                                 return_intermediate = return_intermediate,
                                 num_denoising       = cfg['dn_num_denoising'],
                                 label_noise_ratio   = cfg['dn_label_noise_ratio'],
                                 box_noise_scale     = cfg['dn_box_noise_scale'],
                                 learnt_init_query   = cfg['learnt_init_query'],
                                 )
    else:
        # fail loudly instead of silently returning None
        raise NotImplementedError("Unknown transformer: {}".format(cfg['transformer']))
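
# A minimal cfg sketch for build_transformer, assuming the same keys the
# __main__ smoke test at the bottom uses; the values are illustrative only.
#
#   cfg = {
#       'transformer': 'rtdetr_transformer',
#       'out_stride': [8, 16, 32],        # also determines num_levels
#       'hidden_dim': 256, 'num_queries': 300,
#       'de_num_heads': 8, 'de_num_layers': 6, 'de_num_points': 4,
#       'de_mlp_ratio': 4.0, 'de_dropout': 0.1, 'de_act': 'relu',
#       'dn_num_denoising': 100, 'dn_label_noise_ratio': 0.5,
#       'dn_box_noise_scale': 1.0, 'learnt_init_query': False,
#   }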

# ----------------- Decoder for Detection task -----------------
## RT-DETR's Transformer decoder for the detection task
class RTDETRTransformer(nn.Module):
    def __init__(self,
                 # basic parameters
                 in_dims             :List  = [256, 512, 1024],
                 hidden_dim          :int   = 256,
                 strides             :List  = [8, 16, 32],
                 num_classes         :int   = 80,
                 num_queries         :int   = 300,
                 pos_embed_type      :str   = 'sine',
                 # transformer parameters
                 num_heads           :int   = 8,
                 num_layers          :int   = 1,
                 num_levels          :int   = 3,
                 num_points          :int   = 4,
                 mlp_ratio           :float = 4.0,
                 dropout             :float = 0.1,
                 act_type            :str   = "relu",
                 return_intermediate :bool  = False,
                 # denoising parameters
                 num_denoising       :int   = 100,
                 label_noise_ratio   :float = 0.5,
                 box_noise_scale     :float = 1.0,
                 learnt_init_query   :bool  = True,
                 ):
        super().__init__()
        # --------------- Basic setting ---------------
        ## Basic parameters
        self.in_dims = in_dims
        self.strides = strides
        self.num_queries = num_queries
        self.pos_embed_type = pos_embed_type
        self.num_classes = num_classes
        self.eps = 1e-2
        ## Transformer parameters
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_levels = num_levels
        self.num_points = num_points
        self.mlp_ratio = mlp_ratio
        self.dropout = dropout
        self.act_type = act_type
        self.return_intermediate = return_intermediate
        ## Denoising parameters
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.learnt_init_query = learnt_init_query

        # --------------- Network setting ---------------
        ## Input proj layers
        self.input_proj_layers = nn.ModuleList(
            BasicConv(in_dims[i], hidden_dim, kernel_size=1, act_type=None, norm_type="BN")
            for i in range(num_levels)
        )
        ## Deformable transformer decoder
        self.decoder = DeformableTransformerDecoder(
            d_model             = hidden_dim,
            num_heads           = num_heads,
            num_layers          = num_layers,
            num_levels          = num_levels,
            num_points          = num_points,
            mlp_ratio           = mlp_ratio,
            dropout             = dropout,
            act_type            = act_type,
            return_intermediate = return_intermediate
        )
        ## Detection head for Encoder
        self.enc_output = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim)
        )
        self.enc_class_head = nn.Linear(hidden_dim, num_classes)
        self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
        ## Detection head for Decoder
        self.dec_class_head = nn.ModuleList([
            nn.Linear(hidden_dim, num_classes)
            for _ in range(num_layers)
        ])
        self.dec_bbox_head = nn.ModuleList([
            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
            for _ in range(num_layers)
        ])
        ## Object query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
        self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
        ## Denoising part
        self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim)

        self._reset_parameters()
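
    # Note on initialization: with prior probability p = 0.01, the class-head
    # bias -log((1 - p) / p) makes sigmoid(bias) = p at the start of training,
    # the focal-loss-style "rare foreground" prior used by RetinaNet-like heads.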
    def _reset_parameters(self):
        def linear_init_(module):
            bound = 1 / math.sqrt(module.weight.shape[0])
            uniform_(module.weight, -bound, bound)
            if hasattr(module, "bias") and module.bias is not None:
                uniform_(module.bias, -bound, bound)
        # class and bbox head init
        prior_prob = 0.01
        cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))
        linear_init_(self.enc_class_head)
        constant_(self.enc_class_head.bias, cls_bias_init)
        constant_(self.enc_bbox_head.layers[-1].weight, 0.)
        constant_(self.enc_bbox_head.layers[-1].bias, 0.)
        for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
            linear_init_(cls_)
            constant_(cls_.bias, cls_bias_init)
            constant_(reg_.layers[-1].weight, 0.)
            constant_(reg_.layers[-1].bias, 0.)
        linear_init_(self.enc_output[0])
        xavier_uniform_(self.enc_output[0].weight)  # overrides the uniform weight init above
        if self.learnt_init_query:
            xavier_uniform_(self.tgt_embed.weight)
        xavier_uniform_(self.query_pos_head.layers[0].weight)
        xavier_uniform_(self.query_pos_head.layers[1].weight)
        for l in self.input_proj_layers:
            xavier_uniform_(l.conv.weight)
        normal_(self.denoising_class_embed.weight)
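
    # generate_anchors places one anchor per feature-map cell in normalized
    # (cx, cy, w, h): on level 0 every anchor gets w = h = grid_size = 0.05,
    # and each higher level doubles that (grid_size * 2**lvl). Anchors are then
    # mapped to logit space with the inverse sigmoid log(x / (1 - x)), so that
    # sigmoid(anchor + bbox_head_output) is a valid box; anchors within
    # self.eps of the [0, 1] border are masked out with +inf.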
    def generate_anchors(self, spatial_shapes, grid_size=0.05):
        anchors = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w))
            # [H, W, 2]
            grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
            valid_WH = torch.as_tensor([w, h]).float()
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
            wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
            # [1, H, W, 4] -> [1, N, 4], N = HxW
            anchors.append(torch.cat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4))
        # List of [1, N_i, 4] -> [1, N, 4], N = N_0 + N_1 + N_2 + ...
        anchors = torch.cat(anchors, dim=1)
        valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
        anchors = torch.log(anchors / (1 - anchors))
        # Equivalent to: anchors = torch.masked_fill(anchors, ~valid_mask, float("inf"))
        anchors = torch.where(valid_mask, anchors, torch.as_tensor(float("inf")))
        return anchors, valid_mask
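
    # get_encoder_input projects every pyramid level to hidden_dim and flattens
    # the levels into one token sequence. For 80x80, 40x40, 20x20 maps this
    # gives N = 6400 + 1600 + 400 = 8400 tokens and level_start_index = [0, 6400, 8000].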
    def get_encoder_input(self, feats):
        # get projection features
        proj_feats = [self.input_proj_layers[i](feat) for i, feat in enumerate(feats)]
        # get encoder inputs
        feat_flatten = []
        spatial_shapes = []
        level_start_index = [0, ]
        for i, feat in enumerate(proj_feats):
            _, _, h, w = feat.shape
            spatial_shapes.append([h, w])
            # [L], start index of each level
            level_start_index.append(h * w + level_start_index[-1])
            # [B, C, H, W] -> [B, N, C], N = HxW
            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
        # [B, N, C], N = N_0 + N_1 + ...
        feat_flatten = torch.cat(feat_flatten, dim=1)
        level_start_index.pop()
        return (feat_flatten, spatial_shapes, level_start_index)
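
    # get_decoder_input performs RT-DETR's query selection: the encoder memory
    # is scored by enc_class_head, the num_queries tokens with the highest class
    # logit become the initial (unactivated) reference boxes, and denoising
    # queries, when present, are prepended to both targets and references.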
    def get_decoder_input(self,
                          memory,
                          spatial_shapes,
                          denoising_class=None,
                          denoising_bbox_unact=None):
        bs, _, _ = memory.shape
        # Prepare input for decoder
        anchors, valid_mask = self.generate_anchors(spatial_shapes)
        anchors = anchors.to(memory.device)
        valid_mask = valid_mask.to(memory.device)
        # Process encoder's output
        memory = torch.where(valid_mask, memory, torch.as_tensor(0., device=memory.device))
        output_memory = self.enc_output(memory)
        # Head for encoder's output: [bs, num_queries, c]
        enc_outputs_class = self.enc_class_head(output_memory)
        enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
        # Top-k proposals from encoder's output
        topk = self.num_queries
        topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]  # [bs, num_queries]
        enc_topk_logits = torch.gather(
            enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))  # [bs, num_queries, nc]
        reference_points_unact = torch.gather(
            enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4))  # [bs, num_queries, 4]
        enc_topk_bboxes = F.sigmoid(reference_points_unact)
        if denoising_bbox_unact is not None:
            reference_points_unact = torch.cat(
                [denoising_bbox_unact, reference_points_unact], 1)
        if self.training:
            reference_points_unact = reference_points_unact.detach()
        # Extract region features
        if self.learnt_init_query:
            # [num_queries, c] -> [bs, num_queries, c]
            target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
        else:
            # [bs, N, c] -> [bs, num_queries, c]
            target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
            if self.training:
                target = target.detach()
        if denoising_class is not None:
            target = torch.cat([denoising_class, target], dim=1)
        return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits
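
    # During training, forward builds contrastive denoising (CDN) groups from
    # the ground-truth targets: noised copies of GT labels and boxes are fed in
    # as extra queries, attn_mask keeps them and the matching queries from
    # attending to each other, and dn_meta carries what the criterion needs to
    # split the two groups apart again.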
    def forward(self, feats, targets=None):
        # input projection and embedding
        memory, spatial_shapes, _ = self.get_encoder_input(feats)

        # prepare denoising training
        if self.training:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
                get_contrastive_denoising_training_group(targets,
                                                         self.num_classes,
                                                         self.num_queries,
                                                         self.denoising_class_embed.weight,
                                                         self.num_denoising,
                                                         self.label_noise_ratio,
                                                         self.box_noise_scale)
        else:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None

        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
            self.get_decoder_input(
                memory, spatial_shapes, denoising_class, denoising_bbox_unact)

        # decoder
        out_bboxes, out_logits = self.decoder(target,
                                              init_ref_points_unact,
                                              memory,
                                              spatial_shapes,
                                              self.dec_bbox_head,
                                              self.dec_class_head,
                                              self.query_pos_head,
                                              attn_mask)

        return out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta
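
# A minimal decoding sketch for the outputs above, assuming return_intermediate
# is True and the decoder returns sigmoid-normalized cxcywh boxes stacked per
# layer (as in RT-DETR); the variable names are illustrative:
#
#   boxes  = out_bboxes[-1]            # [bs, num_queries, 4], last decoder layer
#   scores = out_logits[-1].sigmoid()  # [bs, num_queries, num_classes]
#   conf, labels = scores.max(-1)      # per-query confidence and class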

# ----------------- Decoder for Segmentation task -----------------
## RT-DETR's Transformer decoder for the segmentation task
class SegTransformerDecoder(nn.Module):
    def __init__(self, ):
        super().__init__()
        # TODO: design seg-decoder

    def forward(self, x):
        return

# ----------------- Decoder for Pose estimation task -----------------
## RT-DETR's Transformer decoder for the pose estimation task
class PosTransformerDecoder(nn.Module):
    def __init__(self, ):
        super().__init__()
        # TODO: design pose-decoder

    def forward(self, x):
        return

if __name__ == '__main__':
    import time
    from thop import profile

    cfg = {
        'out_stride': [8, 16, 32],
        # Transformer Decoder
        'transformer': 'rtdetr_transformer',
        'hidden_dim': 256,
        'de_num_heads': 8,
        'de_num_layers': 6,
        'de_mlp_ratio': 4.0,
        'de_dropout': 0.1,
        'de_act': 'gelu',
        'de_num_points': 4,
        'num_queries': 300,
        'learnt_init_query': False,
        'pe_temperature': 10000.,
        'dn_num_denoising': 100,
        'dn_label_noise_ratio': 0.5,
        'dn_box_noise_scale': 1,
    }
    bs = 1
    hidden_dim = cfg['hidden_dim']
    in_dims = [hidden_dim] * 3
    targets = [{
        'labels': torch.tensor([2, 4, 5, 8]).long(),
        'boxes': torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float()
    }] * bs
    pyramid_feats = [torch.randn(bs, hidden_dim, 80, 80),
                     torch.randn(bs, hidden_dim, 40, 40),
                     torch.randn(bs, hidden_dim, 20, 20)]

    model = build_transformer(cfg, in_dims, 80, True)
    model.train()

    t0 = time.time()
    outputs = model(pyramid_feats, targets)
    out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = outputs
    t1 = time.time()
    print('Time: ', t1 - t0)
    print(out_bboxes.shape)
    print(out_logits.shape)
    print(enc_topk_bboxes.shape)
    print(enc_topk_logits.shape)
    print('==============================')

    model.eval()
    flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))