# rtpdetr_decoder.py

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from .basic_modules.basic import LayerNorm2D
    from .basic_modules.transformer import GlobalDecoder
except ImportError:
    from basic_modules.basic import LayerNorm2D
    from basic_modules.transformer import GlobalDecoder


def build_transformer(cfg, return_intermediate=False):
    if cfg['transformer'] == 'plain_detr_transformer':
        return PlainDETRTransformer(d_model = cfg['hidden_dim'],
                                    num_heads = cfg['de_num_heads'],
                                    mlp_ratio = cfg['de_mlp_ratio'],
                                    dropout = cfg['de_dropout'],
                                    act_type = cfg['de_act'],
                                    pre_norm = cfg['de_pre_norm'],
                                    rpe_hidden_dim = cfg['rpe_hidden_dim'],
                                    feature_stride = cfg['out_stride'],
                                    num_layers = cfg['de_num_layers'],
                                    return_intermediate = return_intermediate,
                                    use_checkpoint = cfg['use_checkpoint'],
                                    num_queries_one2one = cfg['num_queries_one2one'],
                                    num_queries_one2many = cfg['num_queries_one2many'],
                                    proposal_feature_levels = cfg['proposal_feature_levels'],
                                    proposal_in_stride = cfg['out_stride'],
                                    proposal_tgt_strides = cfg['proposal_tgt_strides'],
                                    )
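
# A quick reference for the config keys read by build_transformer() above. The values
# mirror the smoke test at the bottom of this file; treat this as an illustrative
# sketch, not the only valid setting:
#
#   cfg = {
#       'transformer': 'plain_detr_transformer',
#       'hidden_dim': 256, 'de_num_heads': 8, 'de_num_layers': 6,
#       'de_mlp_ratio': 4.0, 'de_dropout': 0.1, 'de_act': 'gelu', 'de_pre_norm': True,
#       'rpe_hidden_dim': 512, 'out_stride': 16, 'use_checkpoint': False,
#       'num_queries_one2one': 300, 'num_queries_one2many': 1500,
#       'proposal_feature_levels': 3, 'proposal_tgt_strides': [8, 16, 32],
#   }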


# ----------------- Decoder for Detection task -----------------
## PlainDETR's Transformer for Detection task
class PlainDETRTransformer(nn.Module):
    def __init__(self,
                 # Decoder layer params
                 d_model        :int   = 256,
                 num_heads      :int   = 8,
                 mlp_ratio      :float = 4.0,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 pre_norm       :bool  = False,
                 rpe_hidden_dim :int   = 512,
                 feature_stride :int   = 16,
                 num_layers     :int   = 6,
                 # Decoder params
                 return_intermediate     :bool = False,
                 use_checkpoint          :bool = False,
                 num_queries_one2one     :int  = 300,
                 num_queries_one2many    :int  = 1500,
                 proposal_feature_levels :int  = 3,
                 proposal_in_stride      :int  = 16,
                 proposal_tgt_strides    :list = [8, 16, 32],
                 ):
        super().__init__()
        # ------------ Basic setting ------------
        ## Model
        self.d_model = d_model
        self.num_heads = num_heads
        self.rpe_hidden_dim = rpe_hidden_dim
        self.mlp_ratio = mlp_ratio
        self.act_type = act_type
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        ## Trick
        self.use_checkpoint = use_checkpoint
        self.num_queries_one2one = num_queries_one2one
        self.num_queries_one2many = num_queries_one2many
        self.proposal_feature_levels = proposal_feature_levels
        self.proposal_tgt_strides = proposal_tgt_strides
        self.proposal_in_stride = proposal_in_stride
        self.proposal_min_size = 50

        # --------------- Network setting ---------------
        ## Global Decoder
        self.decoder = GlobalDecoder(d_model, num_heads, mlp_ratio, dropout, act_type, pre_norm,
                                     rpe_hidden_dim, feature_stride, num_layers, return_intermediate,
                                     use_checkpoint,)
        ## Two stage
        self.enc_output = nn.Linear(d_model, d_model)
        self.enc_output_norm = nn.LayerNorm(d_model)
        self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
        self.pos_trans_norm = nn.LayerNorm(d_model * 2)
        ## Expand layers
        if proposal_feature_levels > 1:
            assert len(proposal_tgt_strides) == proposal_feature_levels
            self.enc_output_proj = nn.ModuleList([])
            for stride in proposal_tgt_strides:
                if stride == proposal_in_stride:
                    self.enc_output_proj.append(nn.Identity())
                elif stride > proposal_in_stride:
                    scale = int(math.log2(stride / proposal_in_stride))
                    layers = []
                    for _ in range(scale - 1):
                        layers += [
                            nn.Conv2d(d_model, d_model, kernel_size=2, stride=2),
                            LayerNorm2D(d_model),
                            nn.GELU()
                        ]
                    layers.append(nn.Conv2d(d_model, d_model, kernel_size=2, stride=2))
                    self.enc_output_proj.append(nn.Sequential(*layers))
                else:
                    scale = int(math.log2(proposal_in_stride / stride))
                    layers = []
                    for _ in range(scale - 1):
                        layers += [
                            nn.ConvTranspose2d(d_model, d_model, kernel_size=2, stride=2),
                            LayerNorm2D(d_model),
                            nn.GELU()
                        ]
                    layers.append(nn.ConvTranspose2d(d_model, d_model, kernel_size=2, stride=2))
                    self.enc_output_proj.append(nn.Sequential(*layers))

        self._reset_parameters()
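
    # Sketch of what `enc_output_proj` holds with the defaults above
    # (proposal_in_stride=16, proposal_tgt_strides=[8, 16, 32]):
    #   stride  8 -> one ConvTranspose2d(k=2, s=2)   (x2 upsample)
    #   stride 16 -> nn.Identity()
    #   stride 32 -> one Conv2d(k=2, s=2)            (x2 downsample)
    # A stride ratio of 4 (e.g. target stride 64 from input stride 16) would instead
    # give Conv2d -> LayerNorm2D -> GELU -> Conv2d, i.e. log2(ratio) strided convs.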

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        if hasattr(self.decoder, '_reset_parameters'):
            print('decoder re-init')
            self.decoder._reset_parameters()

    def get_proposal_pos_embed(self, proposals):
        num_pos_feats = self.d_model // 2
        temperature = 10000
        scale = 2 * torch.pi

        dim_t = torch.arange(
            num_pos_feats, dtype=torch.float32, device=proposals.device
        )
        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
        # N, L, 4
        proposals = proposals * scale
        # N, L, 4, 128
        pos = proposals[:, :, :, None] / dim_t
        # N, L, 4, 64, 2
        pos = torch.stack(
            (pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4
        ).flatten(2)
        return pos
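
    # Shape walkthrough for get_proposal_pos_embed() with d_model=256 (so
    # num_pos_feats=128): proposals come in as (N, L, 4) normalized boxes, are scaled
    # by 2*pi, expanded to (N, L, 4, 128) against dim_t, interleaved as sin/cos into
    # (N, L, 4, 64, 2), and flattened to (N, L, 512) = (N, L, 2*d_model) -- which is
    # exactly the input width expected by self.pos_trans.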

    def get_valid_ratio(self, mask):
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio
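
    # Example: for a (1, 40, 40) padding mask whose rightmost 10 columns are padding
    # (True), valid_W=30 and valid_H=40, so get_valid_ratio() returns [[0.75, 1.0]]
    # as (w_ratio, h_ratio). Note this helper is not referenced by forward() in this
    # file; it is presumably kept for parity with deformable-DETR-style transformers.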

    def expand_encoder_output(self, memory, memory_padding_mask, spatial_shapes):
        assert spatial_shapes.size(0) == 1, f'Get encoder output of shape {spatial_shapes}, not sure how to expand'
        bs, _, c = memory.shape
        h, w = spatial_shapes[0]

        _out_memory = memory.view(bs, h, w, c).permute(0, 3, 1, 2)
        _out_memory_padding_mask = memory_padding_mask.view(bs, h, w)

        out_memory, out_memory_padding_mask, out_spatial_shapes = [], [], []
        for i in range(self.proposal_feature_levels):
            mem = self.enc_output_proj[i](_out_memory)
            mask = F.interpolate(
                _out_memory_padding_mask[None].float(), size=mem.shape[-2:]
            ).to(torch.bool)

            out_memory.append(mem)
            out_memory_padding_mask.append(mask.squeeze(0))
            out_spatial_shapes.append(mem.shape[-2:])

        out_memory = torch.cat([mem.flatten(2).transpose(1, 2) for mem in out_memory], dim=1)
        out_memory_padding_mask = torch.cat([mask.flatten(1) for mask in out_memory_padding_mask], dim=1)
        out_spatial_shapes = torch.as_tensor(out_spatial_shapes, dtype=torch.long, device=out_memory.device)

        return out_memory, out_memory_padding_mask, out_spatial_shapes
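
    # Rough size example for expand_encoder_output(): a 40x40 stride-16 encoder map
    # with proposal_tgt_strides=[8, 16, 32] expands to 80x80, 40x40 and 20x20 grids,
    # i.e. 6400 + 1600 + 400 = 8400 flattened tokens, with the padding mask resized
    # to each level by (default) nearest interpolation.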

    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
        if self.proposal_feature_levels > 1:
            memory, memory_padding_mask, spatial_shapes = self.expand_encoder_output(
                memory, memory_padding_mask, spatial_shapes
            )
        N_, S_, C_ = memory.shape
        # base_scale = 4.0

        proposals = []
        _cur = 0
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            stride = self.proposal_tgt_strides[lvl]

            grid_y, grid_x = torch.meshgrid(
                torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
                torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
                indexing='ij',
            )
            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
            grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) * stride
            wh = torch.ones_like(grid) * self.proposal_min_size * (2.0 ** lvl)

            proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
            proposals.append(proposal)
            _cur += H_ * W_

        output_proposals = torch.cat(proposals, 1)

        H_, W_ = spatial_shapes[0]
        stride = self.proposal_tgt_strides[0]
        mask_flatten_ = memory_padding_mask[:, :H_*W_].view(N_, H_, W_, 1)
        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1, keepdim=True) * stride
        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1, keepdim=True) * stride
        img_size = torch.cat([valid_W, valid_H, valid_W, valid_H], dim=-1)
        img_size = img_size.unsqueeze(1)  # [BS, 1, 4]

        output_proposals_valid = (
            (output_proposals > 0.01 * img_size) & (output_proposals < 0.99 * img_size)
        ).all(-1, keepdim=True)
        output_proposals = output_proposals.masked_fill(
            memory_padding_mask.unsqueeze(-1),
            max(H_, W_) * stride,
        )
        output_proposals = output_proposals.masked_fill(
            ~output_proposals_valid,
            max(H_, W_) * stride,
        )

        output_memory = memory
        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
        output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
        output_memory = self.enc_output_norm(self.enc_output(output_memory))

        max_shape = (valid_H[:, None, :], valid_W[:, None, :])

        return output_memory, output_proposals, max_shape
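
    # How the seed proposals above are laid out: the token at (row i, col j) of level
    # lvl (stride s) gets a box whose center entries are ((j + 0.5) * s, (i + 0.5) * s)
    # and whose width/height entries are proposal_min_size * 2**lvl, i.e. 50 / 100 / 200 px
    # for the three default levels. Proposals on padded tokens or outside roughly
    # [1%, 99%] of the valid image extent are pushed to a sentinel value
    # (max(H_, W_) * stride) and their memory features are zeroed before enc_output.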

    def get_reference_points(self, memory, mask_flatten, spatial_shapes):
        output_memory, output_proposals, max_shape = self.gen_encoder_output_proposals(
            memory, mask_flatten, spatial_shapes
        )

        # hack implementation for two-stage Deformable DETR
        enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
        enc_outputs_delta = self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
        enc_outputs_coord_unact = self.decoder.box_xyxy_to_cxcywh(self.decoder.delta2bbox(
            output_proposals,
            enc_outputs_delta,
            max_shape
        ))

        topk = self.two_stage_num_proposals
        topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
        topk_coords_unact = torch.gather(
            enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
        )
        topk_coords_unact = topk_coords_unact.detach()
        reference_points = topk_coords_unact

        return (reference_points, max_shape, enc_outputs_class,
                enc_outputs_coord_unact, enc_outputs_delta, output_proposals)
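
    # Two-stage selection in get_reference_points(): the extra (index num_layers)
    # class/bbox heads hung on the decoder score and refine every encoder token, the
    # top `self.two_stage_num_proposals` boxes (ranked by class logit channel 0) are
    # gathered, and they are detach()-ed so no gradient flows back through the
    # reference points. Note that `two_stage_num_proposals` is set in forward() right
    # before this method is called.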

    def forward(self, src, mask, pos_embed, query_embed=None, self_attn_mask=None):
        # Prepare input for encoder
        bs, c, h, w = src.shape
        src_flatten = src.flatten(2).transpose(1, 2)
        mask_flatten = mask.flatten(1)
        pos_embed_flatten = pos_embed.flatten(2).transpose(1, 2)
        spatial_shapes = torch.as_tensor([(h, w)], dtype=torch.long, device=src_flatten.device)

        # Prepare input for decoder
        memory = src_flatten
        bs, seq_l, c = memory.shape

        # Two stage trick
        if self.training:
            self.two_stage_num_proposals = self.num_queries_one2one + self.num_queries_one2many
        else:
            self.two_stage_num_proposals = self.num_queries_one2one
        (reference_points, max_shape, enc_outputs_class,
         enc_outputs_coord_unact, enc_outputs_delta, output_proposals) \
            = self.get_reference_points(memory, mask_flatten, spatial_shapes)
        init_reference_out = reference_points
        pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(reference_points)))

        # Mixed selection trick
        tgt = query_embed.unsqueeze(0).expand(bs, -1, -1)
        query_embed, _ = torch.split(pos_trans_out, c, dim=2)

        # Decoder
        hs, inter_references = self.decoder(tgt,
                                            reference_points,
                                            memory,
                                            pos_embed_flatten,
                                            spatial_shapes,
                                            query_embed,
                                            mask_flatten,
                                            self_attn_mask,
                                            max_shape
                                            )

        inter_references_out = inter_references
        return (hs,
                init_reference_out,
                inter_references_out,
                enc_outputs_class,
                enc_outputs_coord_unact,
                enc_outputs_delta,
                output_proposals,
                max_shape
                )
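
    # Query bookkeeping in forward(): during training the decoder runs with
    # num_queries_one2one + num_queries_one2many queries (300 + 1500 = 1800 with the
    # defaults); at eval time only the one-to-one group is kept. The content queries
    # (`tgt`) come from the learned `query_embed`, while the positional part is
    # re-derived from the top-k proposals (mixed selection). `self_attn_mask` is
    # passed straight through to the decoder; in PlainDETR-style hybrid matching it
    # is typically used to keep the two query groups from attending to each other.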


# ----------------- Decoder for Segmentation task -----------------
## PlainDETR's Transformer for Segmentation task
class SegTransformerDecoder(nn.Module):
    def __init__(self, ):
        super().__init__()
        # TODO: design seg-decoder

    def forward(self, x):
        return


# ----------------- Decoder for Pose estimation task -----------------
## PlainDETR's Transformer for Pose estimation task
class PosTransformerDecoder(nn.Module):
    def __init__(self, ):
        super().__init__()
        # TODO: design pose-decoder

    def forward(self, x):
        return


if __name__ == '__main__':
    import time
    from thop import profile
    from basic_modules.basic import MLP
    from basic_modules.transformer import get_clones

    cfg = {
        'out_stride': 16,
        # Transformer Decoder
        'transformer': 'plain_detr_transformer',
        'hidden_dim': 256,
        'num_queries_one2one': 300,
        'num_queries_one2many': 1500,
        'de_num_heads': 8,
        'de_num_layers': 6,
        'de_mlp_ratio': 4.0,
        'de_dropout': 0.1,
        'de_act': 'gelu',
        'de_pre_norm': True,
        'rpe_hidden_dim': 512,
        'use_checkpoint': False,
        'proposal_feature_levels': 3,
        'proposal_tgt_strides': [8, 16, 32],
    }
    # Content queries must cover both the one-to-one and one-to-many groups for training.
    num_queries = cfg['num_queries_one2one'] + cfg['num_queries_one2many']
    feat = torch.randn(1, cfg['hidden_dim'], 40, 40)
    mask = torch.zeros(1, 40, 40, dtype=torch.bool)
    pos_embed = torch.randn(1, cfg['hidden_dim'], 40, 40)
    query_embed = torch.randn(num_queries, cfg['hidden_dim'])

    model = build_transformer(cfg, True)
    class_embed = nn.Linear(cfg['hidden_dim'], 80)
    bbox_embed = MLP(cfg['hidden_dim'], cfg['hidden_dim'], 4, 3)
    class_embed = get_clones(class_embed, cfg['de_num_layers'] + 1)
    bbox_embed = get_clones(bbox_embed, cfg['de_num_layers'] + 1)
    model.decoder.bbox_embed = bbox_embed
    model.decoder.class_embed = class_embed

    model.train()
    t0 = time.time()
    outputs = model(feat, mask, pos_embed, query_embed)
    (hs,
     init_reference_out,
     inter_references_out,
     enc_outputs_class,
     enc_outputs_coord_unact,
     enc_outputs_delta,
     output_proposals,
     max_shape
     ) = outputs
    t1 = time.time()
    print('Time: ', t1 - t0)
    print(hs.shape)
    print(init_reference_out.shape)
    print(inter_references_out.shape)
    print(enc_outputs_class.shape)
    print(enc_outputs_coord_unact.shape)
    print(enc_outputs_delta.shape)
    print(output_proposals.shape)

    print('==============================')
    model.eval()
    # At inference only the one-to-one queries are used, so slice the content queries to match.
    query_embed_eval = query_embed[:cfg['num_queries_one2one']]
    flops, params = profile(model, inputs=(feat, mask, pos_embed, query_embed_eval, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))
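
    # Note: assuming the usual thop convention, profile() reports multiply-accumulate
    # counts, hence the "* 2" above to convert MACs to (approximate) FLOPs.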