# fcos_head.py (7.1 KB)
import math

import torch
import torch.nn as nn

from ..basic.conv import ConvModule


  4. class Scale(nn.Module):
  5. """
  6. Multiply the output regression range by a learnable constant value
  7. """
  8. def __init__(self, init_value=1.0):
  9. """
  10. init_value : initial value for the scalar
  11. """
  12. super().__init__()
  13. self.scale = nn.Parameter(
  14. torch.tensor(init_value, dtype=torch.float32),
  15. requires_grad=True
  16. )
  17. def forward(self, x):
  18. """
  19. input -> scale * input
  20. """
  21. return x * self.scale
  22. class FcosHead(nn.Module):
  23. def __init__(self, cfg, in_dim, out_dim, num_classes, num_cls_head=1, num_reg_head=1, act_type='relu', norm_type='BN'):
  24. super().__init__()
  25. self.fmp_size = None
  26. # ------------------ Basic parameters -------------------
  27. self.cfg = cfg
  28. self.in_dim = in_dim
  29. self.num_classes = num_classes
  30. self.num_cls_head = num_cls_head
  31. self.num_reg_head = num_reg_head
  32. self.act_type = act_type
  33. self.norm_type = norm_type
  34. self.stride = cfg['out_stride']
  35. # ------------------ Network parameters -------------------
  36. ## cls head
  37. cls_heads = []
  38. self.cls_head_dim = out_dim
  39. for i in range(self.num_cls_head):
  40. if i == 0:
  41. cls_heads.append(
  42. ConvModule(in_dim, self.cls_head_dim, k=3, p=1, s=1,
  43. act_type=self.act_type,
  44. norm_type=self.norm_type)
  45. )
  46. else:
  47. cls_heads.append(
  48. ConvModule(self.cls_head_dim, self.cls_head_dim, k=3, p=1, s=1,
  49. act_type=self.act_type,
  50. norm_type=self.norm_type)
  51. )
  52. ## reg head
  53. reg_heads = []
  54. self.reg_head_dim = out_dim
  55. for i in range(self.num_reg_head):
  56. if i == 0:
  57. reg_heads.append(
  58. ConvModule(in_dim, self.reg_head_dim, k=3, p=1, s=1,
  59. act_type=self.act_type,
  60. norm_type=self.norm_type)
  61. )
  62. else:
  63. reg_heads.append(
  64. ConvModule(self.reg_head_dim, self.reg_head_dim, k=3, p=1, s=1,
  65. act_type=self.act_type,
  66. norm_type=self.norm_type)
  67. )
  68. self.cls_heads = nn.Sequential(*cls_heads)
  69. self.reg_heads = nn.Sequential(*reg_heads)
  70. ## pred layers
  71. self.cls_pred = nn.Conv2d(self.cls_head_dim, num_classes, kernel_size=3, padding=1)
  72. self.reg_pred = nn.Conv2d(self.reg_head_dim, 4, kernel_size=3, padding=1)
  73. self.ctn_pred = nn.Conv2d(self.reg_head_dim, 1, kernel_size=3, padding=1)
  74. ## scale layers
  75. self.scales = nn.ModuleList(
  76. Scale() for _ in range(len(self.stride))
  77. )
  78. # init bias
  79. self._init_layers()
  80. def _init_layers(self):
  81. for module in [self.cls_heads, self.reg_heads, self.cls_pred, self.reg_pred, self.ctn_pred]:
  82. for layer in module.modules():
  83. if isinstance(layer, nn.Conv2d):
  84. torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
  85. torch.nn.init.constant_(layer.bias, 0)
  86. if isinstance(layer, nn.GroupNorm):
  87. torch.nn.init.constant_(layer.weight, 1)
  88. torch.nn.init.constant_(layer.bias, 0)
  89. # init the bias of cls pred
  90. init_prob = 0.01
  91. bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
  92. torch.nn.init.constant_(self.cls_pred.bias, bias_value)
  93. def get_anchors(self, level, fmp_size):
  94. """
  95. fmp_size: (List) [H, W]
  96. """
  97. # generate grid cells
  98. fmp_h, fmp_w = fmp_size
  99. anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
  100. # [H, W, 2] -> [HW, 2]
  101. anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
  102. anchors *= self.stride[level]
  103. return anchors
  104. def decode_boxes(self, pred_deltas, anchors):
  105. """
  106. pred_deltas: (List[Tensor]) [B, M, 4] or [M, 4] (l, t, r, b)
  107. anchors: (List[Tensor]) [1, M, 2] or [M, 2]
  108. """
  109. # x1 = x_anchor - l, x2 = x_anchor + r
  110. # y1 = y_anchor - t, y2 = y_anchor + b
  111. pred_x1y1 = anchors - pred_deltas[..., :2]
  112. pred_x2y2 = anchors + pred_deltas[..., 2:]
  113. pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
  114. return pred_box
  115. def forward(self, pyramid_feats, mask=None):
  116. all_masks = []
  117. all_anchors = []
  118. all_cls_preds = []
  119. all_reg_preds = []
  120. all_box_preds = []
  121. all_ctn_preds = []
  122. for level, feat in enumerate(pyramid_feats):
  123. # ------------------- Decoupled head -------------------
  124. cls_feat = self.cls_heads(feat)
  125. reg_feat = self.reg_heads(feat)
  126. # ------------------- Generate anchor box -------------------
  127. B, _, H, W = cls_feat.size()
  128. fmp_size = [H, W]
  129. anchors = self.get_anchors(level, fmp_size) # [M, 4]
  130. anchors = anchors.to(cls_feat.device)
  131. # ------------------- Predict -------------------
  132. cls_pred = self.cls_pred(cls_feat)
  133. reg_pred = self.reg_pred(reg_feat)
  134. ctn_pred = self.ctn_pred(reg_feat)
  135. # ------------------- Process preds -------------------
  136. ## [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
  137. cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
  138. ctn_pred = ctn_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
  139. reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
  140. reg_pred = nn.functional.relu(self.scales[level](reg_pred)) * self.stride[level]
  141. ## Decode bbox
  142. box_pred = self.decode_boxes(reg_pred, anchors)
  143. ## Adjust mask
  144. if mask is not None:
  145. # [B, H, W]
  146. mask_i = torch.nn.functional.interpolate(mask[None].float(), size=[H, W]).bool()[0]
  147. # [B, H, W] -> [B, M]
  148. mask_i = mask_i.flatten(1)
  149. all_masks.append(mask_i)
  150. all_anchors.append(anchors)
  151. all_cls_preds.append(cls_pred)
  152. all_reg_preds.append(reg_pred)
  153. all_box_preds.append(box_pred)
  154. all_ctn_preds.append(ctn_pred)
  155. outputs = {"pred_cls": all_cls_preds, # List [B, M, C]
  156. "pred_reg": all_reg_preds, # List [B, M, 4]
  157. "pred_box": all_box_preds, # List [B, M, 4]
  158. "pred_ctn": all_ctn_preds, # List [B, M, 1]
  159. "anchors": all_anchors, # List [B, M, 2]
  160. "strides": self.stride,
  161. "mask": all_masks} # List [B, M,]
  162. return outputs