retinanet_head.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. import math
  2. import torch
  3. import torch.nn as nn
  4. from ..basic.conv import ConvModule
  5. class RetinaNetHead(nn.Module):
  6. def __init__(self, cfg, in_dim, out_dim, num_classes, num_cls_head=1, num_reg_head=1, act_type='relu', norm_type='BN'):
  7. super().__init__()
  8. self.fmp_size = None
  9. self.DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16)
  10. # ------------------ Basic parameters -------------------
  11. self.cfg = cfg
  12. self.in_dim = in_dim
  13. self.num_classes = num_classes
  14. self.num_cls_head=num_cls_head
  15. self.num_reg_head=num_reg_head
  16. self.act_type=act_type
  17. self.norm_type=norm_type
  18. self.stride = cfg['out_stride']
  19. # ------------------ Anchor parameters -------------------
  20. self.anchor_size = self.get_anchor_sizes(cfg) # [S, KA, 2]
  21. self.num_anchors = self.anchor_size.shape[1]
  22. # ------------------ Network parameters -------------------
  23. ## cls head
  24. cls_heads = []
  25. self.cls_head_dim = out_dim
  26. for i in range(self.num_cls_head):
  27. if i == 0:
  28. cls_heads.append(
  29. ConvModule(in_dim, self.cls_head_dim, k=3, p=1, s=1,
  30. act_type=self.act_type,
  31. norm_type=self.norm_type)
  32. )
  33. else:
  34. cls_heads.append(
  35. ConvModule(self.cls_head_dim, self.cls_head_dim, k=3, p=1, s=1,
  36. act_type=self.act_type,
  37. norm_type=self.norm_type)
  38. )
  39. ## reg head
  40. reg_heads = []
  41. self.reg_head_dim = out_dim
  42. for i in range(self.num_reg_head):
  43. if i == 0:
  44. reg_heads.append(
  45. ConvModule(in_dim, self.reg_head_dim, k=3, p=1, s=1,
  46. act_type=self.act_type,
  47. norm_type=self.norm_type)
  48. )
  49. else:
  50. reg_heads.append(
  51. ConvModule(self.reg_head_dim, self.reg_head_dim, k=3, p=1, s=1,
  52. act_type=self.act_type,
  53. norm_type=self.norm_type)
  54. )
  55. self.cls_heads = nn.Sequential(*cls_heads)
  56. self.reg_heads = nn.Sequential(*reg_heads)
  57. ## pred layers
  58. self.cls_pred = nn.Conv2d(self.cls_head_dim, num_classes * self.num_anchors, kernel_size=3, padding=1)
  59. self.reg_pred = nn.Conv2d(self.reg_head_dim, 4 * self.num_anchors, kernel_size=3, padding=1)
  60. # init bias
  61. self._init_layers()
  62. def _init_layers(self):
  63. for module in [self.cls_heads, self.reg_heads, self.cls_pred, self.reg_pred]:
  64. for layer in module.modules():
  65. if isinstance(layer, nn.Conv2d):
  66. torch.nn.init.normal_(layer.weight, mean=0, std=0.01)
  67. torch.nn.init.constant_(layer.bias, 0)
  68. if isinstance(layer, nn.GroupNorm):
  69. torch.nn.init.constant_(layer.weight, 1)
  70. torch.nn.init.constant_(layer.bias, 0)
  71. # init the bias of cls pred
  72. init_prob = 0.01
  73. bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
  74. torch.nn.init.constant_(self.cls_pred.bias, bias_value)
  75. def get_anchor_sizes(self, cfg):
  76. basic_anchor_size = cfg['anchor_config']['basic_size']
  77. anchor_aspect_ratio = cfg['anchor_config']['aspect_ratio']
  78. anchor_area_scale = cfg['anchor_config']['area_scale']
  79. num_scales = len(basic_anchor_size)
  80. num_anchors = len(anchor_aspect_ratio) * len(anchor_area_scale)
  81. anchor_sizes = []
  82. for size in basic_anchor_size:
  83. for ar in anchor_aspect_ratio:
  84. for s in anchor_area_scale:
  85. ah, aw = size
  86. area = ah * aw * s
  87. anchor_sizes.append(
  88. [torch.sqrt(torch.tensor(ar * area)),
  89. torch.sqrt(torch.tensor(area / ar))]
  90. )
  91. # [S * KA, 2] -> [S, KA, 2]
  92. anchor_sizes = torch.as_tensor(anchor_sizes).view(num_scales, num_anchors, 2)
  93. return anchor_sizes
  94. def get_anchors(self, level, fmp_size):
  95. """
  96. fmp_size: (List) [H, W]
  97. """
  98. # generate grid cells
  99. fmp_h, fmp_w = fmp_size
  100. # [KA, 2]
  101. anchor_size = self.anchor_size[level]
  102. anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
  103. # [H, W, 2] -> [HW, 2]
  104. anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
  105. # [HW, 2] -> [HW, 1, 2] -> [HW, KA, 2]
  106. anchor_xy = anchor_xy[:, None, :].repeat(1, self.num_anchors, 1)
  107. anchor_xy *= self.stride[level]
  108. # [KA, 2] -> [1, KA, 2] -> [HW, KA, 2]
  109. anchor_wh = anchor_size[None, :, :].repeat(fmp_h*fmp_w, 1, 1)
  110. # [HW, KA, 4] -> [M, 4], M = HW x KA
  111. anchor_boxes = torch.cat([anchor_xy, anchor_wh], dim=-1)
  112. anchor_boxes = anchor_boxes.view(-1, 4)
  113. return anchor_boxes
  114. def decode_boxes(self, anchor_boxes, pred_reg):
  115. """
  116. anchor_boxes: (List[Tensor]) [1, M, 4] or [M, 4]
  117. pred_reg: (List[Tensor]) [B, M, 4] or [M, 4]
  118. """
  119. # x = x_anchor + dx * w_anchor
  120. # y = y_anchor + dy * h_anchor
  121. pred_ctr_offset = pred_reg[..., :2] * anchor_boxes[..., 2:]
  122. pred_ctr_xy = anchor_boxes[..., :2] + pred_ctr_offset
  123. # w = w_anchor * exp(tw)
  124. # h = h_anchor * exp(th)
  125. pred_dwdh = pred_reg[..., 2:]
  126. pred_dwdh = torch.clamp(pred_dwdh, max=self.DEFAULT_SCALE_CLAMP)
  127. pred_wh = anchor_boxes[..., 2:] * pred_dwdh.exp()
  128. # convert [x, y, w, h] -> [x1, y1, x2, y2]
  129. pred_x1y1 = pred_ctr_xy - 0.5 * pred_wh
  130. pred_x2y2 = pred_ctr_xy + 0.5 * pred_wh
  131. pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
  132. return pred_box
  133. def forward(self, pyramid_feats, mask=None):
  134. all_masks = []
  135. all_anchors = []
  136. all_cls_preds = []
  137. all_reg_preds = []
  138. all_box_preds = []
  139. for level, feat in enumerate(pyramid_feats):
  140. # ------------------- Decoupled head -------------------
  141. cls_feat = self.cls_heads(feat)
  142. reg_feat = self.reg_heads(feat)
  143. # ------------------- Generate anchor box -------------------
  144. B, _, H, W = cls_feat.size()
  145. fmp_size = [H, W]
  146. anchor_boxes = self.get_anchors(level, fmp_size) # [M, 4]
  147. anchor_boxes = anchor_boxes.to(cls_feat.device)
  148. # ------------------- Predict -------------------
  149. cls_pred = self.cls_pred(cls_feat)
  150. reg_pred = self.reg_pred(reg_feat)
  151. # ------------------- Process preds -------------------
  152. ## [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
  153. cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
  154. reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
  155. ## Decode bbox
  156. box_pred = self.decode_boxes(anchor_boxes, reg_pred)
  157. ## Adjust mask
  158. if mask is not None:
  159. # [B, H, W]
  160. mask_i = torch.nn.functional.interpolate(mask[None].float(), size=[H, W]).bool()[0]
  161. # [B, H, W] -> [B, M]
  162. mask_i = mask_i.flatten(1)
  163. # [B, HW] -> [B, HW, KA] -> [B, M], M= HW x KA
  164. mask_i = mask_i[..., None].repeat(1, 1, self.num_anchors).flatten(1)
  165. all_masks.append(mask_i)
  166. all_anchors.append(anchor_boxes)
  167. all_cls_preds.append(cls_pred)
  168. all_reg_preds.append(reg_pred)
  169. all_box_preds.append(box_pred)
  170. outputs = {"pred_cls": all_cls_preds, # List [B, M, C]
  171. "pred_reg": all_reg_preds, # List [B, M, 4]
  172. "pred_box": all_box_preds, # List [B, M, 4]
  173. "anchors": all_anchors, # List [B, M, 2]
  174. "strides": self.stride,
  175. "mask": all_masks} # List [B, M,]
  176. return outputs