# rtcdet_pred.py

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# -------------------- Detection Pred Layer --------------------
  6. ## Single-level pred layer
  7. class DetPredLayer(nn.Module):
  8. def __init__(self,
  9. cls_dim :int = 256,
  10. reg_dim :int = 256,
  11. stride :int = 32,
  12. num_classes :int = 80,
  13. num_coords :int = 4):
  14. super().__init__()
  15. # --------- Basic Parameters ----------
  16. self.stride = stride
  17. self.cls_dim = cls_dim
  18. self.reg_dim = reg_dim
  19. self.num_classes = num_classes
  20. self.num_coords = num_coords
  21. # --------- Network Parameters ----------
  22. self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
  23. self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1, groups=4)
  24. self.init_bias()
  25. def init_bias(self):
  26. # cls pred bias
  27. b = self.cls_pred.bias.view(1, -1)
  28. b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
  29. self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  30. # reg pred bias
  31. b = self.reg_pred.bias.view(-1, )
  32. b.data.fill_(1.0)
  33. self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
  34. w = self.reg_pred.weight
  35. w.data.fill_(0.)
  36. self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
  37. def generate_anchors(self, fmp_size):
  38. """
  39. fmp_size: (List) [H, W]
  40. """
  41. # generate grid cells
  42. fmp_h, fmp_w = fmp_size
  43. anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
  44. # [H, W, 2] -> [HW, 2]
  45. anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
  46. anchors += 0.5 # add center offset
  47. anchors *= self.stride
  48. return anchors
  49. def forward(self, cls_feat, reg_feat):
  50. # pred
  51. cls_pred = self.cls_pred(cls_feat)
  52. reg_pred = self.reg_pred(reg_feat)
  53. # generate anchor boxes: [M, 4]
  54. B, _, H, W = cls_pred.size()
  55. fmp_size = [H, W]
  56. anchors = self.generate_anchors(fmp_size)
  57. anchors = anchors.to(cls_pred.device)
  58. # stride tensor: [M, 1]
  59. stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
  60. # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
  61. cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
  62. reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_coords)
  63. # output dict
  64. outputs = {"pred_cls": cls_pred, # List(Tensor) [B, M, C]
  65. "pred_reg": reg_pred, # List(Tensor) [B, M, 4*(reg_max)]
  66. "anchors": anchors, # List(Tensor) [M, 2]
  67. "strides": self.stride, # List(Int) = [8, 16, 32]
  68. "stride_tensor": stride_tensor # List(Tensor) [M, 1]
  69. }
  70. return outputs
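
# Worked example (illustrative, not in the original file): with
# fmp_size = [2, 2] and stride = 32, generate_anchors() returns the four
# grid-cell centers in row-major order,
#   [[16., 16.], [48., 16.], [16., 48.], [48., 48.]]
# i.e. (x + 0.5) * stride and (y + 0.5) * stride for each cell.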


## Multi-scale pred layer
class MSDetPredLayer(nn.Module):
    def __init__(self,
                 cfg,
                 cls_dim,
                 reg_dim,
                 ):
        super().__init__()
        # --------- Basic Parameters ----------
        self.cfg = cfg
        self.cls_dim = cls_dim
        self.reg_dim = reg_dim
        self.reg_max = cfg.reg_max
        self.num_levels = cfg.num_levels
        self.out_stride = cfg.out_stride

        # ----------- Network Parameters -----------
        ## pred layers
        self.multi_level_preds = nn.ModuleList(
            [DetPredLayer(cls_dim     = cls_dim,
                          reg_dim     = reg_dim,
                          stride      = cfg.out_stride[level],
                          num_classes = cfg.num_classes,
                          num_coords  = cfg.reg_max * 4)
             for level in range(cfg.num_levels)
             ])
        ## proj conv: fixed 1x1 conv whose weights are the DFL bin indices [0, ..., reg_max - 1]
        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
        self.proj_conv.weight.data.copy_(proj_init.view(1, cfg.reg_max, 1, 1))

    def forward(self, cls_feats, reg_feats):
        all_anchors = []
        all_strides = []
        all_cls_preds = []
        all_reg_preds = []
        all_box_preds = []
        all_delta_preds = []
        for level in range(self.num_levels):
            # -------------- Single-level prediction --------------
            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])

            # -------------- Decode bbox --------------
            B, M = outputs["pred_reg"].shape[:2]
            # [B, M, 4*reg_max] -> [B, M, 4, reg_max]
            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.reg_max])
            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
            # [B, reg_max, 4, M] -> [B, 1, 4, M]: expected bin value per side
            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
            ## ltrb -> xyxy
            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.out_stride[level]
            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.out_stride[level]
            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)

            # collect results
            all_cls_preds.append(outputs["pred_cls"])
            all_reg_preds.append(outputs["pred_reg"])
            all_box_preds.append(box_pred)
            all_delta_preds.append(delta_pred)
            all_anchors.append(outputs["anchors"])
            all_strides.append(outputs["stride_tensor"])

        # output dict
        outputs = {"pred_cls": all_cls_preds,       # List(Tensor) [B, M, C]
                   "pred_reg": all_reg_preds,       # List(Tensor) [B, M, 4*reg_max]
                   "pred_box": all_box_preds,       # List(Tensor) [B, M, 4]
                   "pred_delta": all_delta_preds,   # List(Tensor) [B, M, 4]
                   "anchors": all_anchors,          # List(Tensor) [M, 2]
                   "stride_tensor": all_strides,    # List(Tensor) [M, 1]
                   "strides": self.out_stride,      # List(int), e.g. [8, 16, 32]
                   }

        return outputs
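
# Note (illustrative): proj_conv above implements the DFL-style decode — a
# softmax over the reg_max bins followed by a dot product with the bin indices
# [0, 1, ..., reg_max - 1], i.e. the expected bin value. The helper below is a
# minimal conv-free sketch of the same computation; the name
# dfl_expected_value is introduced here and is not part of the original file.
def dfl_expected_value(pred_reg: torch.Tensor, reg_max: int) -> torch.Tensor:
    """[B, M, 4 * reg_max] raw logits -> [B, M, 4] per-side distances."""
    B, M = pred_reg.shape[:2]
    prob = F.softmax(pred_reg.view(B, M, 4, reg_max), dim=-1)  # bin probabilities
    bins = torch.arange(reg_max, dtype=prob.dtype, device=prob.device)
    return (prob * bins).sum(dim=-1)  # expectation over the bins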


# -------------------- Segmentation Pred Layer --------------------
## Single-level pred layer
class SegPredLayer(nn.Module):
    def __init__(self,
                 cls_dim: int = 256,
                 reg_dim: int = 256,
                 seg_dim: int = 256,
                 stride: int = 32,
                 num_classes: int = 80,
                 num_coords: int = 4,
                 num_masks: int = 1):
        super().__init__()
        # --------- Basic Parameters ----------
        self.stride = stride
        self.cls_dim = cls_dim
        self.reg_dim = reg_dim
        self.seg_dim = seg_dim
        self.num_classes = num_classes
        self.num_coords = num_coords
        self.num_masks = num_masks
        # --------- Network Parameters ----------
        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1)
        self.seg_pred = nn.Conv2d(seg_dim, num_masks, kernel_size=1)

        self.init_bias()

    def init_bias(self):
        # cls pred: prior-probability bias so that initial scores are low
        b = self.cls_pred.bias.view(1, -1)
        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
        # reg pred: bias = 1.0, weight = 0.0
        b = self.reg_pred.bias.view(-1)
        b.data.fill_(1.0)
        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
        w = self.reg_pred.weight
        w.data.fill_(0.)
        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
        # seg pred: bias = 1.0, weight = 0.0
        b = self.seg_pred.bias.view(-1)
        b.data.fill_(1.0)
        self.seg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
        w = self.seg_pred.weight
        w.data.fill_(0.)
        self.seg_pred.weight = torch.nn.Parameter(w, requires_grad=True)

    def generate_anchors(self, fmp_size):
        """
        fmp_size: (List) [H, W]
        """
        # generate grid cells; indexing="ij" keeps the legacy meshgrid behavior
        fmp_h, fmp_w = fmp_size
        anchor_y, anchor_x = torch.meshgrid(torch.arange(fmp_h), torch.arange(fmp_w), indexing="ij")
        # [H, W, 2] -> [HW, 2]
        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
        anchors += 0.5  # add center offset
        anchors *= self.stride
        return anchors

    def forward(self, cls_feat, reg_feat, seg_feat):
        # pred
        cls_pred = self.cls_pred(cls_feat)
        reg_pred = self.reg_pred(reg_feat)
        seg_pred = self.seg_pred(seg_feat)

        # generate anchor points: [M, 2]
        B, _, H, W = cls_pred.size()
        fmp_size = [H, W]
        anchors = self.generate_anchors(fmp_size)
        anchors = anchors.to(cls_pred.device)

        # stride tensor: [M, 1]
        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride

        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_coords)
        seg_pred = seg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_masks)

        # output dict
        outputs = {"pred_cls": cls_pred,            # (Tensor) [B, M, Nc]
                   "pred_reg": reg_pred,            # (Tensor) [B, M, Na]
                   "pred_seg": seg_pred,            # (Tensor) [B, M, Nm]
                   "anchors": anchors,              # (Tensor) [M, 2]
                   "strides": self.stride,          # (int) one of e.g. 8, 16, 32
                   "stride_tensor": stride_tensor   # (Tensor) [M, 1]
                   }

        return outputs


## Multi-level pred layer
class RTCSegPredLayer(nn.Module):
    def __init__(self,
                 cfg,
                 cls_dim,
                 reg_dim,
                 seg_dim,
                 ):
        super().__init__()
        # --------- Basic Parameters ----------
        self.cfg = cfg
        self.cls_dim = cls_dim
        self.reg_dim = reg_dim
        self.seg_dim = seg_dim

        # ----------- Network Parameters -----------
        ## pred layers
        self.multi_level_preds = nn.ModuleList(
            [SegPredLayer(cls_dim     = cls_dim,
                          reg_dim     = reg_dim,
                          seg_dim     = seg_dim,
                          stride      = cfg.out_stride[level],
                          num_classes = cfg.num_classes,
                          num_coords  = cfg.reg_max * 4,
                          num_masks   = cfg.mask_dim)
             for level in range(cfg.num_levels)
             ])
        ## proj conv: fixed 1x1 conv whose weights are the DFL bin indices [0, ..., reg_max - 1]
        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
        self.proj_conv.weight.data.copy_(proj_init.view(1, cfg.reg_max, 1, 1))

    def forward(self, cls_feats, reg_feats, seg_feats):
        all_anchors = []
        all_strides = []
        all_cls_preds = []
        all_reg_preds = []
        all_box_preds = []
        all_seg_preds = []
        for level in range(self.cfg.num_levels):
            # -------------- Single-level prediction --------------
            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level], seg_feats[level])

            # -------------- Decode bbox --------------
            B, M = outputs["pred_reg"].shape[:2]
            # [B, M, 4*reg_max] -> [B, M, 4, reg_max]
            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.cfg.reg_max])
            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
            # [B, reg_max, 4, M] -> [B, 1, 4, M]: expected bin value per side
            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
            ## ltrb -> xyxy
            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.cfg.out_stride[level]
            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.cfg.out_stride[level]
            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)

            # collect results
            all_cls_preds.append(outputs["pred_cls"])
            all_reg_preds.append(outputs["pred_reg"])
            all_seg_preds.append(outputs["pred_seg"])
            all_box_preds.append(box_pred)
            all_anchors.append(outputs["anchors"])
            all_strides.append(outputs["stride_tensor"])

        # output dict
        outputs = {"pred_cls": all_cls_preds,       # List(Tensor) [B, M, C]
                   "pred_reg": all_reg_preds,       # List(Tensor) [B, M, 4*reg_max]
                   "pred_box": all_box_preds,       # List(Tensor) [B, M, 4]
                   "pred_seg": all_seg_preds,       # List(Tensor) [B, M, Nm]
                   "anchors": all_anchors,          # List(Tensor) [M, 2]
                   "stride_tensor": all_strides,    # List(Tensor) [M, 1]
                   "strides": self.cfg.out_stride,  # List(int), e.g. [8, 16, 32]
                   }

        return outputs
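

# Minimal usage sketch (not part of the original file). It builds a dummy
# config exposing the fields read above (out_stride, num_levels, num_classes,
# reg_max, mask_dim) and runs MSDetPredLayer on random features; every name
# and size below is an illustrative assumption.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(out_stride=[8, 16, 32], num_levels=3,
                          num_classes=80, reg_max=16, mask_dim=32)
    pred_layer = MSDetPredLayer(cfg, cls_dim=256, reg_dim=4 * cfg.reg_max)
    # three pyramid levels for a 640x640 input: 80x80, 40x40, 20x20
    cls_feats = [torch.randn(2, 256, s, s) for s in (80, 40, 20)]
    reg_feats = [torch.randn(2, 4 * cfg.reg_max, s, s) for s in (80, 40, 20)]
    outputs = pred_layer(cls_feats, reg_feats)
    for level, box in enumerate(outputs["pred_box"]):
        print(level, box.shape)  # torch.Size([2, HW, 4]) per level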