# rtcdet_head.py
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .rtcdet_basic import BasicConv
  5. except:
  6. from rtcdet_basic import BasicConv
  7. # -------------------- Detection Head --------------------
  8. ## Single-level Detection Head
  9. class DetHead(nn.Module):
  10. def __init__(self,
  11. in_dim :int = 256,
  12. cls_head_dim :int = 256,
  13. reg_head_dim :int = 256,
  14. num_cls_head :int = 2,
  15. num_reg_head :int = 2,
  16. act_type :str = "silu",
  17. norm_type :str = "BN",
  18. depthwise :bool = False):
  19. super().__init__()
  20. # --------- Basic Parameters ----------
  21. self.in_dim = in_dim
  22. self.num_cls_head = num_cls_head
  23. self.num_reg_head = num_reg_head
  24. self.act_type = act_type
  25. self.norm_type = norm_type
  26. self.depthwise = depthwise
  27. # --------- Network Parameters ----------
  28. ## cls head
  29. cls_feats = []
  30. self.cls_head_dim = cls_head_dim
  31. for i in range(num_cls_head):
  32. if i == 0:
  33. cls_feats.append(
  34. BasicConv(in_dim, self.cls_head_dim,
  35. kernel_size=3, padding=1, stride=1,
  36. act_type=act_type,
  37. norm_type=norm_type,
  38. depthwise=depthwise)
  39. )
  40. else:
  41. cls_feats.append(
  42. BasicConv(self.cls_head_dim, self.cls_head_dim,
  43. kernel_size=3, padding=1, stride=1,
  44. act_type=act_type,
  45. norm_type=norm_type,
  46. depthwise=depthwise)
  47. )
  48. ## reg head
  49. reg_feats = []
  50. self.reg_head_dim = reg_head_dim
  51. for i in range(num_reg_head):
  52. if i == 0:
  53. reg_feats.append(
  54. BasicConv(in_dim, self.reg_head_dim,
  55. kernel_size=3, padding=1, stride=1, groups=4,
  56. act_type=act_type,
  57. norm_type=norm_type,
  58. depthwise=depthwise)
  59. )
  60. else:
  61. reg_feats.append(
  62. BasicConv(self.reg_head_dim, self.reg_head_dim,
  63. kernel_size=3, padding=1, stride=1, groups=4,
  64. act_type=act_type,
  65. norm_type=norm_type,
  66. depthwise=depthwise)
  67. )
  68. self.cls_feats = nn.Sequential(*cls_feats)
  69. self.reg_feats = nn.Sequential(*reg_feats)
  70. self.init_weights()
  71. def init_weights(self):
  72. """Initialize the parameters."""
  73. for m in self.modules():
  74. if isinstance(m, torch.nn.Conv2d):
  75. # In order to be consistent with the source code,
  76. # reset the Conv2d initialization parameters
  77. m.reset_parameters()
  78. def forward(self, x):
  79. """
  80. in_feats: (Tensor) [B, C, H, W]
  81. """
  82. cls_feats = self.cls_feats(x)
  83. reg_feats = self.reg_feats(x)
  84. return cls_feats, reg_feats
  85. ## Multi-scales Detection Head
  86. class MSDetHead(nn.Module):
  87. def __init__(self, cfg, in_dims):
  88. super().__init__()
  89. ## ----------- Network Parameters -----------
  90. self.multi_level_heads = nn.ModuleList(
  91. [DetHead(in_dim = in_dims[level],
  92. cls_head_dim = max(in_dims[0], min(cfg.num_classes, 128)),
  93. reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
  94. num_cls_head = cfg.num_cls_head,
  95. num_reg_head = cfg.num_reg_head,
  96. act_type = cfg.head_act,
  97. norm_type = cfg.head_norm,
  98. depthwise = cfg.head_depthwise)
  99. for level in range(cfg.num_levels)
  100. ])
  101. # --------- Basic Parameters ----------
  102. self.in_dims = in_dims
  103. self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
  104. self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
  105. def forward(self, feats):
  106. """
  107. feats: List[(Tensor)] [[B, C, H, W], ...]
  108. """
  109. cls_feats = []
  110. reg_feats = []
  111. for feat, head in zip(feats, self.multi_level_heads):
  112. # ---------------- Pred ----------------
  113. cls_feat, reg_feat = head(feat)
  114. cls_feats.append(cls_feat)
  115. reg_feats.append(reg_feat)
  116. return cls_feats, reg_feats
  117. # -------------------- Segmentation Head --------------------
  118. ## Single-level Segmentation Head
  119. class SegHead(nn.Module):
  120. def __init__(self,
  121. in_dim :int = 256,
  122. cls_head_dim :int = 256,
  123. reg_head_dim :int = 256,
  124. seg_head_dim :int = 256,
  125. num_cls_head :int = 2,
  126. num_reg_head :int = 2,
  127. num_seg_head :int = 2,
  128. act_type :str = "silu",
  129. norm_type :str = "BN",
  130. depthwise :bool = False):
  131. super().__init__()
  132. # --------- Basic Parameters ----------
  133. self.in_dim = in_dim
  134. self.num_cls_head = num_cls_head
  135. self.num_reg_head = num_reg_head
  136. self.num_seg_head = num_reg_head
  137. self.act_type = act_type
  138. self.norm_type = norm_type
  139. self.depthwise = depthwise
  140. # --------- Network Parameters ----------
  141. ## cls head
  142. cls_feats = []
  143. self.cls_head_dim = cls_head_dim
  144. for i in range(num_cls_head):
  145. if i == 0:
  146. cls_feats.append(
  147. BasicConv(in_dim, self.cls_head_dim,
  148. kernel_size=3, padding=1, stride=1,
  149. act_type=act_type,
  150. norm_type=norm_type,
  151. depthwise=depthwise)
  152. )
  153. else:
  154. cls_feats.append(
  155. BasicConv(self.cls_head_dim, self.cls_head_dim,
  156. kernel_size=3, padding=1, stride=1,
  157. act_type=act_type,
  158. norm_type=norm_type,
  159. depthwise=depthwise)
  160. )
  161. ## reg head
  162. reg_feats = []
  163. self.reg_head_dim = reg_head_dim
  164. for i in range(num_reg_head):
  165. if i == 0:
  166. reg_feats.append(
  167. BasicConv(in_dim, self.reg_head_dim,
  168. kernel_size=3, padding=1, stride=1,
  169. act_type=act_type,
  170. norm_type=norm_type,
  171. depthwise=depthwise)
  172. )
  173. else:
  174. reg_feats.append(
  175. BasicConv(self.reg_head_dim, self.reg_head_dim,
  176. kernel_size=3, padding=1, stride=1,
  177. act_type=act_type,
  178. norm_type=norm_type,
  179. depthwise=depthwise)
  180. )
  181. ## seg head
  182. seg_feats = []
  183. self.seg_head_dim = seg_head_dim
  184. for i in range(num_seg_head):
  185. if i == 0:
  186. seg_feats.append(
  187. BasicConv(in_dim, self.seg_head_dim,
  188. kernel_size=3, padding=1, stride=1,
  189. act_type=act_type,
  190. norm_type=norm_type,
  191. depthwise=depthwise)
  192. )
  193. else:
  194. seg_feats.append(
  195. BasicConv(self.seg_head_dim, self.seg_head_dim,
  196. kernel_size=3, padding=1, stride=1,
  197. act_type=act_type,
  198. norm_type=norm_type,
  199. depthwise=depthwise)
  200. )
  201. self.cls_feats = nn.Sequential(*cls_feats)
  202. self.reg_feats = nn.Sequential(*reg_feats)
  203. self.seg_feats = nn.Sequential(*seg_feats)
  204. self.init_weights()
  205. def init_weights(self):
  206. """Initialize the parameters."""
  207. for m in self.modules():
  208. if isinstance(m, torch.nn.Conv2d):
  209. # In order to be consistent with the source code,
  210. # reset the Conv2d initialization parameters
  211. m.reset_parameters()
  212. def forward(self, x):
  213. """
  214. in_feats: (Tensor) [B, C, H, W]
  215. """
  216. cls_feats = self.cls_feats(x)
  217. reg_feats = self.reg_feats(x)
  218. seg_feats = self.reg_feats(x)
  219. return cls_feats, reg_feats, seg_feats
  220. ## Multi-scales Segmentation Head
  221. class MSSegHead(nn.Module):
  222. def __init__(self, cfg, in_dims):
  223. super().__init__()
  224. ## ----------- Network Parameters -----------
  225. self.multi_level_heads = nn.ModuleList(
  226. [SegHead(in_dim = in_dims[level],
  227. cls_head_dim = max(in_dims[0], min(cfg.num_classes, 128)),
  228. reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
  229. seg_head_dim = in_dims[0],
  230. num_cls_head = cfg.num_cls_head,
  231. num_reg_head = cfg.num_reg_head,
  232. num_seg_head = cfg.num_seg_head,
  233. act_type = cfg.head_act,
  234. norm_type = cfg.head_norm,
  235. depthwise = cfg.head_depthwise)
  236. for level in range(cfg.num_levels)
  237. ])
  238. # --------- Basic Parameters ----------
  239. self.in_dims = in_dims
  240. self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
  241. self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
  242. self.seg_head_dim = self.multi_level_heads[0].seg_head_dim
  243. def forward(self, feats):
  244. """
  245. feats: List[(Tensor)] [[B, C, H, W], ...]
  246. """
  247. cls_feats = []
  248. reg_feats = []
  249. seg_feats = []
  250. for feat, head in zip(feats, self.multi_level_heads):
  251. # ---------------- Pred ----------------
  252. cls_feat, reg_feat, seg_feat = head(feat)
  253. cls_feats.append(cls_feat)
  254. reg_feats.append(reg_feat)
  255. seg_feats.append(seg_feat)
  256. return cls_feats, reg_feats, seg_feats