import math

import torch
import torch.nn as nn

# Try package-relative imports first; fall back to flat imports when run as a script.
try:
    from .basic_modules.basic import MLP
    from .basic_modules.transformer import get_clones
    from .rtpdetr_encoder import build_image_encoder
    from .rtpdetr_decoder import build_transformer
except ImportError:
    from basic_modules.basic import MLP
    from basic_modules.transformer import get_clones
    from rtpdetr_encoder import build_image_encoder
    from rtpdetr_decoder import build_transformer


# Real-time Plain Transformer-based Object Detector
class RT_PDETR(nn.Module):
    def __init__(self,
                 cfg,
                 num_classes=80,
                 conf_thresh=0.1,
                 topk=300,
                 deploy=False,
                 no_multi_labels=False,
                 ):
        super().__init__()
        # ----------- Basic setting -----------
        self.num_queries_one2one = cfg['num_queries_one2one']
        self.num_queries_one2many = cfg['num_queries_one2many']
        self.num_queries = self.num_queries_one2one + self.num_queries_one2many
        self.num_classes = num_classes
        self.num_topk = topk
        self.conf_thresh = conf_thresh
        self.no_multi_labels = no_multi_labels
        self.deploy = deploy

        # ----------- Network setting -----------
        ## Image encoder
        self.image_encoder = build_image_encoder(cfg)

        ## Transformer Decoder
        self.transformer = build_transformer(cfg, return_intermediate=self.training)
        self.query_embed = nn.Embedding(self.num_queries, cfg['hidden_dim'])

        ## Detect Head
        class_embed = nn.Linear(cfg['hidden_dim'], num_classes)
        bbox_embed = MLP(cfg['hidden_dim'], cfg['hidden_dim'], 4, 3)
        prior_prob = 0.01
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        class_embed.bias.data = torch.ones(num_classes) * bias_value
        nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
        nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)

        self.class_embed = get_clones(class_embed, cfg['de_num_layers'] + 1)
        self.bbox_embed = get_clones(bbox_embed, cfg['de_num_layers'] + 1)
        nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
        self.transformer.decoder.bbox_embed = self.bbox_embed
        self.transformer.decoder.class_embed = self.class_embed
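
    # pos2posembed() maps normalized (x, y) or (x, y, w, h) coordinates in [0, 1]
    # to fixed sine/cosine features: each coordinate contributes d_model // 2
    # interleaved sin/cos dims, so a 2-coordinate input yields a d_model-dim
    # embedding and a 4-coordinate input yields 2 * d_model dims.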
    def pos2posembed(self, d_model, pos, temperature=10000):
        scale = 2 * torch.pi
        num_pos_feats = d_model // 2
        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
        dim_t = temperature ** (2 * dim_t_)

        # Position embedding for XY
        x_embed = pos[..., 0] * scale
        y_embed = pos[..., 1] * scale
        pos_x = x_embed[..., None] / dim_t
        pos_y = y_embed[..., None] / dim_t
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
        posemb = torch.cat((pos_y, pos_x), dim=-1)

        # Position embedding for WH
        if pos.size(-1) == 4:
            w_embed = pos[..., 2] * scale
            h_embed = pos[..., 3] * scale
            pos_w = w_embed[..., None] / dim_t
            pos_h = h_embed[..., None] / dim_t
            pos_w = torch.stack((pos_w[..., 0::2].sin(), pos_w[..., 1::2].cos()), dim=-1).flatten(-2)
            pos_h = torch.stack((pos_h[..., 0::2].sin(), pos_h[..., 1::2].cos()), dim=-1).flatten(-2)
            posemb = torch.cat((posemb, pos_w, pos_h), dim=-1)

        return posemb
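
    # get_posembed() converts a padding mask of shape [B, H, W] into a 2D sine/cosine
    # position embedding of shape [B, d_model, H, W]; pixel centers are normalized
    # to [0, 1] along each axis before being passed to pos2posembed().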
    def get_posembed(self, d_model, mask, temperature=10000):
        not_mask = ~mask
        # [B, H, W]
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + 1e-6)
        x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + 1e-6)

        # [B, H, W] -> [B, H, W, 2]
        pos = torch.stack([x_embed, y_embed], dim=-1)

        # [B, H, W, 2] -> [B, H, W, C] -> [B, C, H, W]
        pos_embed = self.pos2posembed(d_model, pos, temperature)
        pos_embed = pos_embed.permute(0, 3, 1, 2)

        return pos_embed
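
    # post_process() expects predictions for a single image. With no_multi_labels=True,
    # each query keeps only its best class; otherwise the top-k is taken over all
    # (query, class) pairs, so one query may produce several labels. In both branches,
    # candidates scoring below conf_thresh are discarded.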
    def post_process(self, box_pred, cls_pred):
        # Take the predictions of the first (only) image.
        cls_pred = cls_pred[0]
        box_pred = box_pred[0]

        if self.no_multi_labels:
            # [M,]
            scores, labels = torch.max(cls_pred.sigmoid(), dim=1)

            # Keep top k top scoring indices only.
            num_topk = min(self.num_topk, box_pred.size(0))

            # Topk candidates
            predicted_prob, topk_idxs = scores.sort(descending=True)
            topk_scores = predicted_prob[:num_topk]
            topk_idxs = topk_idxs[:num_topk]

            # Filter out the proposals with low confidence score
            keep_idxs = topk_scores > self.conf_thresh
            topk_idxs = topk_idxs[keep_idxs]

            # Top-k results
            topk_scores = topk_scores[keep_idxs]
            topk_labels = labels[topk_idxs]
            topk_bboxes = box_pred[topk_idxs]
        else:
            # Top-k select over the flattened (query, class) scores.
            cls_pred = cls_pred.flatten().sigmoid_()

            # Keep top k top scoring indices only.
            num_topk = min(self.num_topk, box_pred.size(0))

            # Topk candidates
            predicted_prob, topk_idxs = cls_pred.sort(descending=True)
            topk_scores = predicted_prob[:num_topk]
            topk_idxs = topk_idxs[:num_topk]

            # Filter out the proposals with low confidence score
            keep_idxs = topk_scores > self.conf_thresh

            # Top-k results
            topk_scores = topk_scores[keep_idxs]
            topk_idxs = topk_idxs[keep_idxs]
            topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
            topk_labels = topk_idxs % self.num_classes
            topk_bboxes = box_pred[topk_box_idxs]

        return topk_bboxes, topk_scores, topk_labels

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord, outputs_coord_old, outputs_deltas):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {"pred_logits": a, "pred_boxes": b, "pred_boxes_old": c, "pred_deltas": d}
            for a, b, c, d in zip(outputs_class[:-1], outputs_coord[:-1], outputs_coord_old[:-1], outputs_deltas[:-1])
        ]
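
    # inference_single_image() is the eval-time path: only the one2one queries are
    # used, no query self-attention mask is needed, and the last decoder layer's
    # predictions are post-processed into numpy arrays of scores/labels/bboxes.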
    def inference_single_image(self, x):
        # ----------- Image Encoder -----------
        src = self.image_encoder(x)

        # ----------- Prepare inputs for Transformer -----------
        mask = torch.zeros([src.shape[0], src.shape[2], src.shape[3]]).bool().to(src.device)
        pos_embed = self.get_posembed(src.shape[1], mask)
        self_attn_mask = None
        query_embeds = self.query_embed.weight[:self.num_queries_one2one]

        # ----------- Transformer -----------
        (
            hs,
            init_reference,
            inter_references,
            _,
            _,
            _,
            _,
            max_shape
        ) = self.transformer(src, mask, pos_embed, query_embeds, self_attn_mask)

        # ----------- Process outputs -----------
        outputs_classes_one2one = []
        outputs_coords_one2one = []
        outputs_deltas_one2one = []
        for lid in range(hs.shape[0]):
            if lid == 0:
                reference = init_reference
            else:
                reference = inter_references[lid - 1]
            outputs_class = self.class_embed[lid](hs[lid])
            tmp = self.bbox_embed[lid](hs[lid])
            outputs_coord = self.transformer.decoder.delta2bbox(reference, tmp, max_shape)  # xyxy
            outputs_classes_one2one.append(outputs_class[:, :self.num_queries_one2one])
            outputs_coords_one2one.append(outputs_coord[:, :self.num_queries_one2one])
            outputs_deltas_one2one.append(tmp[:, :self.num_queries_one2one])
        outputs_classes_one2one = torch.stack(outputs_classes_one2one)
        outputs_coords_one2one = torch.stack(outputs_coords_one2one)

        # ------------ Post process ------------
        cls_pred = outputs_classes_one2one[-1]
        box_pred = outputs_coords_one2one[-1]
        bboxes, scores, labels = self.post_process(box_pred, cls_pred)

        outputs = {
            "scores": scores.cpu().numpy(),
            "labels": labels.cpu().numpy(),
            "bboxes": bboxes.cpu().numpy(),
        }

        return outputs
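
    # The training forward pass below follows the hybrid one2one + one2many query
    # scheme: both query groups go through the decoder together, and the boolean
    # self-attention mask built here blocks attention between the two groups
    # (assuming the usual PyTorch convention that True entries are masked out).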
    def forward(self, x):
        if not self.training:
            return self.inference_single_image(x)

        # ----------- Image Encoder -----------
        src = self.image_encoder(x)

        # ----------- Prepare inputs for Transformer -----------
        mask = torch.zeros([src.shape[0], src.shape[2], src.shape[3]]).bool().to(src.device)
        pos_embed = self.get_posembed(src.shape[1], mask)
        if self.training:
            self_attn_mask = torch.zeros(
                [self.num_queries, self.num_queries]).bool().to(src.device)
            self_attn_mask[self.num_queries_one2one:, 0:self.num_queries_one2one] = True
            self_attn_mask[0:self.num_queries_one2one, self.num_queries_one2one:] = True
            query_embeds = self.query_embed.weight
        else:
            self_attn_mask = None
            query_embeds = self.query_embed.weight[:self.num_queries_one2one]

        # ----------- Transformer -----------
        (
            hs,
            init_reference,
            inter_references,
            enc_outputs_class,
            enc_outputs_coord_unact,
            enc_outputs_delta,
            output_proposals,
            max_shape
        ) = self.transformer(src, mask, pos_embed, query_embeds, self_attn_mask)

        # ----------- Process outputs -----------
        outputs_classes_one2one = []
        outputs_coords_one2one = []
        outputs_classes_one2many = []
        outputs_coords_one2many = []

        outputs_coords_old_one2one = []
        outputs_deltas_one2one = []
        outputs_coords_old_one2many = []
        outputs_deltas_one2many = []

        for lid in range(hs.shape[0]):
            if lid == 0:
                reference = init_reference
            else:
                reference = inter_references[lid - 1]
            outputs_class = self.class_embed[lid](hs[lid])
            tmp = self.bbox_embed[lid](hs[lid])
            outputs_coord = self.transformer.decoder.box_xyxy_to_cxcywh(
                self.transformer.decoder.delta2bbox(reference, tmp, max_shape))

            outputs_classes_one2one.append(outputs_class[:, 0:self.num_queries_one2one])
            outputs_classes_one2many.append(outputs_class[:, self.num_queries_one2one:])
            outputs_coords_one2one.append(outputs_coord[:, 0:self.num_queries_one2one])
            outputs_coords_one2many.append(outputs_coord[:, self.num_queries_one2one:])

            outputs_coords_old_one2one.append(reference[:, :self.num_queries_one2one])
            outputs_coords_old_one2many.append(reference[:, self.num_queries_one2one:])
            outputs_deltas_one2one.append(tmp[:, :self.num_queries_one2one])
            outputs_deltas_one2many.append(tmp[:, self.num_queries_one2one:])

        outputs_classes_one2one = torch.stack(outputs_classes_one2one)
        outputs_coords_one2one = torch.stack(outputs_coords_one2one)
        outputs_classes_one2many = torch.stack(outputs_classes_one2many)
        outputs_coords_one2many = torch.stack(outputs_coords_one2many)

        out = {
            "pred_logits": outputs_classes_one2one[-1],
            "pred_boxes": outputs_coords_one2one[-1],
            "pred_logits_one2many": outputs_classes_one2many[-1],
            "pred_boxes_one2many": outputs_coords_one2many[-1],
            "pred_boxes_old": outputs_coords_old_one2one[-1],
            "pred_deltas": outputs_deltas_one2one[-1],
            "pred_boxes_old_one2many": outputs_coords_old_one2many[-1],
            "pred_deltas_one2many": outputs_deltas_one2many[-1],
        }
        out["aux_outputs"] = self._set_aux_loss(
            outputs_classes_one2one, outputs_coords_one2one, outputs_coords_old_one2one, outputs_deltas_one2one
        )
        out["aux_outputs_one2many"] = self._set_aux_loss(
            outputs_classes_one2many, outputs_coords_one2many, outputs_coords_old_one2many, outputs_deltas_one2many
        )
        out["enc_outputs"] = {
            "pred_logits": enc_outputs_class,
            "pred_boxes": enc_outputs_coord_unact,
            "pred_boxes_old": output_proposals,
            "pred_deltas": enc_outputs_delta,
        }

        return out
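

# The block below is a quick smoke test: it builds the model from an inline config,
# runs a training-mode forward pass on a random 640x640 image, and then profiles
# FLOPs/params with thop in eval mode.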
if __name__ == '__main__':
    import time
    from thop import profile
    # from loss import build_criterion

    # Model config
    cfg = {
        'width': 1.0,
        'depth': 1.0,
        'max_stride': 32,
        'out_stride': 16,
        'hidden_dim': 256,
        # Image Encoder - Backbone
        'backbone': 'resnet50',
        'backbone_norm': 'FrozeBN',
        'pretrained': True,
        'freeze_at': 0,
        'freeze_stem_only': False,
        # Transformer Decoder
        'transformer': 'plain_detr_transformer',
        'de_num_heads': 8,
        'de_num_layers': 6,
        'de_mlp_ratio': 4.0,
        'de_dropout': 0.1,
        'de_act': 'gelu',
        'de_pre_norm': True,
        'rpe_hidden_dim': 512,
        'use_checkpoint': False,
        'proposal_feature_levels': 3,
        'proposal_tgt_strides': [8, 16, 32],
        'num_queries_one2one': 300,
        'num_queries_one2many': 300,
        # Matcher
        'matcher_hpy': {'cost_class': 2.0,
                        'cost_bbox': 5.0,
                        'cost_giou': 2.0},
        # Loss
        'use_vfl': True,
        'loss_coeff': {'class': 1,
                       'bbox': 5,
                       'giou': 2,
                       'no_object': 0.1},
    }
    bs = 1

    # Create a batch of images & targets
    image = torch.randn(bs, 3, 640, 640)
    targets = [{
        'labels': torch.tensor([2, 4, 5, 8]).long(),
        'boxes': torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float() / 640.
    }] * bs

    # Create model
    model = RT_PDETR(cfg, num_classes=80)
    model.train()

    # Model inference
    t0 = time.time()
    outputs = model(image)
    t1 = time.time()
    print('Infer time: ', t1 - t0)

    # # Create criterion
    # criterion = build_criterion(cfg, num_classes=80)
    # # Compute loss
    # loss = criterion(*outputs, targets)
    # for k in loss.keys():
    #     print("{} : {}".format(k, loss[k].item()))

    print('==============================')
    model.eval()
    flops, params = profile(model, inputs=(image, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))
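
    # Quick check of the eval-time inference path: forward() dispatches to
    # inference_single_image() and returns post-processed detections for one
    # image as numpy arrays keyed by 'scores', 'labels', 'bboxes'.
    with torch.no_grad():
        results = model(image)
    print('Num detections :', results['bboxes'].shape[0])
    print('Output keys    :', list(results.keys()))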