rtdetr_decoder.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. import math
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. from torch.nn.init import constant_, xavier_uniform_, uniform_
  6. from typing import List
  7. try:
  8. from .basic_modules.basic import BasicConv, MLP
  9. from .basic_modules.transformer import DeformableTransformerDecoder
  10. from .basic_modules.dn_compoments import get_contrastive_denoising_training_group
  11. except:
  12. from basic_modules.basic import BasicConv, MLP
  13. from basic_modules.transformer import DeformableTransformerDecoder
  14. from basic_modules.dn_compoments import get_contrastive_denoising_training_group
  15. def build_transformer(cfg, in_dims, num_classes, return_intermediate=False):
  16. if cfg['transformer'] == 'rtdetr_transformer':
  17. return RTDETRTransformer(in_dims = in_dims,
  18. hidden_dim = cfg['hidden_dim'],
  19. strides = cfg['out_stride'],
  20. num_classes = num_classes,
  21. num_queries = cfg['num_queries'],
  22. pos_embed_type = 'sine',
  23. num_heads = cfg['de_num_heads'],
  24. num_layers = cfg['de_num_layers'],
  25. num_levels = len(cfg['out_stride']),
  26. num_points = cfg['de_num_points'],
  27. mlp_ratio = cfg['de_mlp_ratio'],
  28. dropout = cfg['de_dropout'],
  29. act_type = cfg['de_act'],
  30. return_intermediate = return_intermediate,
  31. num_denoising = cfg['dn_num_denoising'],
  32. label_noise_ratio = cfg['dn_label_noise_ratio'],
  33. box_noise_scale = cfg['dn_box_noise_scale'],
  34. learnt_init_query = cfg['learnt_init_query'],
  35. )
# ----------------- Decoder for Detection task -----------------
  37. ## RTDETR's Transformer for Detection task
  38. class RTDETRTransformer(nn.Module):
  39. def __init__(self,
  40. # basic parameters
  41. in_dims :List = [256, 512, 1024],
  42. hidden_dim :int = 256,
  43. strides :List = [8, 16, 32],
  44. num_classes :int = 80,
  45. num_queries :int = 300,
  46. pos_embed_type :str = 'sine',
  47. # transformer parameters
  48. num_heads :int = 8,
  49. num_layers :int = 1,
  50. num_levels :int = 3,
  51. num_points :int = 4,
  52. mlp_ratio :float = 4.0,
  53. dropout :float = 0.1,
  54. act_type :str = "relu",
  55. return_intermediate :bool = False,
  56. # Denoising parameters
  57. num_denoising :int = 100,
  58. label_noise_ratio :float = 0.5,
  59. box_noise_scale :float = 1.0,
  60. learnt_init_query :bool = True,
  61. ):
  62. super().__init__()
  63. # --------------- Basic setting ---------------
  64. ## Basic parameters
  65. self.in_dims = in_dims
  66. self.strides = strides
  67. self.num_queries = num_queries
  68. self.pos_embed_type = pos_embed_type
  69. self.num_classes = num_classes
  70. self.eps = 1e-2
  71. ## Transformer parameters
  72. self.num_heads = num_heads
  73. self.num_layers = num_layers
  74. self.num_levels = num_levels
  75. self.num_points = num_points
  76. self.mlp_ratio = mlp_ratio
  77. self.dropout = dropout
  78. self.act_type = act_type
  79. self.return_intermediate = return_intermediate
  80. ## Denoising parameters
  81. self.num_denoising = num_denoising
  82. self.label_noise_ratio = label_noise_ratio
  83. self.box_noise_scale = box_noise_scale
  84. self.learnt_init_query = learnt_init_query
  85. # --------------- Network setting ---------------
  86. ## Input proj layers
  87. self.input_proj_layers = nn.ModuleList(
  88. BasicConv(in_dims[i], hidden_dim, kernel_size=1, act_type=None, norm_type="BN")
  89. for i in range(num_levels)
  90. )
  91. ## Deformable transformer decoder
  92. self.transformer_decoder = DeformableTransformerDecoder(
  93. d_model = hidden_dim,
  94. num_heads = num_heads,
  95. num_layers = num_layers,
  96. num_levels = num_levels,
  97. num_points = num_points,
  98. mlp_ratio = mlp_ratio,
  99. dropout = dropout,
  100. act_type = act_type,
  101. return_intermediate = return_intermediate
  102. )
  103. ## Detection head for Encoder
  104. self.enc_output = nn.Sequential(
  105. nn.Linear(hidden_dim, hidden_dim),
  106. nn.LayerNorm(hidden_dim)
  107. )
  108. self.enc_class_head = nn.Linear(hidden_dim, num_classes)
  109. self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
  110. ## Detection head for Decoder
  111. self.dec_class_head = nn.ModuleList([
  112. nn.Linear(hidden_dim, num_classes)
  113. for _ in range(num_layers)
  114. ])
  115. self.dec_bbox_head = nn.ModuleList([
  116. MLP(hidden_dim, hidden_dim, 4, num_layers=3)
  117. for _ in range(num_layers)
  118. ])
  119. ## Denoising part
  120. self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim)
  121. ## Object query
  122. if learnt_init_query:
  123. self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
  124. self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
  125. self._reset_parameters()
  126. def _reset_parameters(self):
  127. def _linear_init(module):
  128. bound = 1 / math.sqrt(module.weight.shape[0])
  129. uniform_(module.weight, -bound, bound)
  130. if hasattr(module, "bias") and module.bias is not None:
  131. uniform_(module.bias, -bound, bound)
  132. # class and bbox head init
  133. prior_prob = 0.01
  134. cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))
  135. _linear_init(self.enc_class_head)
  136. constant_(self.enc_class_head.bias, cls_bias_init)
  137. constant_(self.enc_bbox_head.layers[-1].weight, 0.)
  138. constant_(self.enc_bbox_head.layers[-1].bias, 0.)
  139. for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
  140. _linear_init(cls_)
  141. constant_(cls_.bias, cls_bias_init)
  142. constant_(reg_.layers[-1].weight, 0.)
  143. constant_(reg_.layers[-1].bias, 0.)
  144. _linear_init(self.enc_output[0])
  145. xavier_uniform_(self.enc_output[0].weight)
  146. if self.learnt_init_query:
  147. xavier_uniform_(self.tgt_embed.weight)
  148. xavier_uniform_(self.query_pos_head.layers[0].weight)
  149. xavier_uniform_(self.query_pos_head.layers[1].weight)
  150. for l in self.input_proj_layers:
  151. xavier_uniform_(l.conv.weight)
  152. def generate_anchors(self, spatial_shapes, grid_size=0.05):
  153. anchors = []
  154. for lvl, (h, w) in enumerate(spatial_shapes):
  155. grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w))
  156. grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
  157. valid_WH = torch.as_tensor([w, h]).float()
  158. grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
  159. wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
  160. anchors.append(torch.cat([grid_xy, wh], -1).reshape([-1, h * w, 4]))
  161. anchors = torch.cat(anchors, 1)
  162. valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
  163. anchors = torch.log(anchors / (1 - anchors))
  164. anchors = torch.where(valid_mask, anchors, torch.as_tensor(float("inf")))
  165. return anchors, valid_mask
  166. def get_encoder_input(self, feats):
  167. # get projection features
  168. proj_feats = [self.input_proj_layers[i](feat) for i, feat in enumerate(feats)]
  169. # get encoder inputs
  170. feat_flatten = []
  171. spatial_shapes = []
  172. level_start_index = [0, ]
  173. for i, feat in enumerate(proj_feats):
  174. _, _, h, w = feat.shape
  175. # [b, c, h, w] -> [b, h*w, c]
  176. feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
  177. # [num_levels, 2]
  178. spatial_shapes.append([h, w])
  179. # [l], start index of each level
  180. level_start_index.append(h * w + level_start_index[-1])
  181. # [b, l, c]
  182. feat_flatten = torch.cat(feat_flatten, 1)
  183. level_start_index.pop()
  184. return (feat_flatten, spatial_shapes, level_start_index)
  185. def get_decoder_input(self,
  186. memory,
  187. spatial_shapes,
  188. denoising_class=None,
  189. denoising_bbox_unact=None):
  190. bs, _, _ = memory.shape
  191. # prepare input for decoder
  192. anchors, valid_mask = self.generate_anchors(spatial_shapes)
  193. memory = torch.where(valid_mask, memory, torch.as_tensor(0.))
  194. output_memory = self.enc_output(memory)
  195. # [bs, num_quries, c]
  196. enc_outputs_class = self.enc_class_head(output_memory)
  197. enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
  198. topk = self.num_queries
  199. topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1] # [bs, topk]
  200. reference_points_unact = torch.gather(enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4)) # [bs, topk, 4]
  201. enc_topk_bboxes = F.sigmoid(reference_points_unact)
  202. if denoising_bbox_unact is not None:
  203. reference_points_unact = torch.cat(
  204. [denoising_bbox_unact, reference_points_unact], 1)
  205. if self.training:
  206. reference_points_unact = reference_points_unact.detach()
  207. enc_topk_logits = torch.gather(enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes)) # [bs, topk, nc]
  208. # extract region features
  209. if self.learnt_init_query:
  210. target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
  211. else:
  212. target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
  213. if self.training:
  214. target = target.detach()
  215. if denoising_class is not None:
  216. target = torch.cat([denoising_class, target], dim=1)
  217. return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits
  218. def forward(self, feats, gt_meta=None):
  219. # input projection and embedding
  220. memory, spatial_shapes, _ = self.get_encoder_input(feats)
  221. # prepare denoising training
  222. if self.training:
  223. denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
  224. get_contrastive_denoising_training_group(gt_meta,
  225. self.num_classes,
  226. self.num_queries,
  227. self.denoising_class_embed.weight,
  228. self.num_denoising,
  229. self.label_noise_ratio,
  230. self.box_noise_scale)
  231. else:
  232. denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
  233. target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
  234. self.get_decoder_input(
  235. memory, spatial_shapes, denoising_class, denoising_bbox_unact)
  236. # decoder
  237. out_bboxes, out_logits = self.transformer_decoder(target,
  238. init_ref_points_unact,
  239. memory,
  240. spatial_shapes,
  241. self.dec_bbox_head,
  242. self.dec_class_head,
  243. self.query_pos_head,
  244. attn_mask)
  245. return out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta
# ----------------- Decoder for Segmentation task -----------------
  247. ## RTDETR's Transformer for Segmentation task
  248. class SegTransformerDecoder(nn.Module):
  249. def __init__(self, ):
  250. super().__init__()
  251. # TODO: design seg-decoder
  252. def forward(self, x):
  253. return
# ----------------- Decoder for Pose estimation task -----------------
  255. ## RTDETR's Transformer for Pose estimation task
  256. class PosTransformerDecoder(nn.Module):
  257. def __init__(self, ):
  258. super().__init__()
  259. # TODO: design seg-decoder
  260. def forward(self, x):
  261. return
  262. if __name__ == '__main__':
  263. import time
  264. from thop import profile
  265. cfg = {
  266. 'out_stride': [8, 16, 32],
  267. # Transformer Decoder
  268. 'transformer': 'rtdetr_transformer',
  269. 'hidden_dim': 256,
  270. 'de_num_heads': 8,
  271. 'de_num_layers': 6,
  272. 'de_mlp_ratio': 4.0,
  273. 'de_dropout': 0.1,
  274. 'de_act': 'gelu',
  275. 'de_num_points': 4,
  276. 'num_queries': 300,
  277. 'learnt_init_query': False,
  278. 'pe_temperature': 10000.,
  279. 'dn_num_denoising': 100,
  280. 'dn_label_noise_ratio': 0.5,
  281. 'dn_box_noise_scale': 1,
  282. }
  283. bs = 1
  284. hidden_dim = cfg['hidden_dim']
  285. in_dims = [hidden_dim] * 3
  286. targets = [{
  287. 'labels': torch.tensor([2, 4, 5, 8]).long(),
  288. 'boxes': torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float()
  289. }] * bs
  290. pyramid_feats = [torch.randn(bs, hidden_dim, 80, 80),
  291. torch.randn(bs, hidden_dim, 40, 40),
  292. torch.randn(bs, hidden_dim, 20, 20)]
  293. model = build_transformer(cfg, in_dims, 80, True)
  294. model.train()
  295. t0 = time.time()
  296. outputs = model(pyramid_feats, targets)
  297. out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = outputs
  298. t1 = time.time()
  299. print('Time: ', t1 - t0)
  300. print(out_bboxes.shape)
  301. print(out_logits.shape)
  302. print(enc_topk_bboxes.shape)
  303. print(enc_topk_logits.shape)
  304. print('==============================')
  305. model.eval()
  306. flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
  307. print('==============================')
  308. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  309. print('Params : {:.2f} M'.format(params / 1e6))