# yolov8.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from .yolov8_backbone import build_backbone
from .yolov8_neck import build_neck
from .yolov8_pafpn import build_fpn
from .yolov8_head import build_head
from utils.nms import multiclass_nms


# Anchor-free YOLO
class YOLOv8(nn.Module):
    def __init__(self,
                 cfg,
                 device,
                 num_classes=20,
                 conf_thresh=0.05,
                 nms_thresh=0.6,
                 trainable=False,
                 topk=1000):
        super(YOLOv8, self).__init__()
        # --------- Basic Parameters ----------
        self.cfg = cfg
        self.device = device
        self.stride = cfg['stride']
        self.reg_max = cfg['reg_max']
        self.use_dfl = cfg['reg_max'] > 1
        self.num_classes = num_classes
        self.trainable = trainable
        self.conf_thresh = conf_thresh
        self.nms_thresh = nms_thresh
        self.topk = topk

        # --------- Network Parameters ----------
        ## fixed 1x1 conv that turns the DFL distribution into its expectation
        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
        ## backbone
        self.backbone, feats_dim = build_backbone(cfg=cfg)
        ## neck
        self.neck = build_neck(cfg=cfg, in_dim=feats_dim[-1], out_dim=feats_dim[-1])
        feats_dim[-1] = self.neck.out_dim
        ## fpn
        self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim)
        fpn_dims = self.fpn.out_dim
        ## non-shared heads
        self.non_shared_heads = nn.ModuleList(
            [build_head(cfg, feat_dim, fpn_dims, num_classes)
             for feat_dim in fpn_dims
             ])
        ## pred layers
        self.cls_preds = nn.ModuleList(
            [nn.Conv2d(head.cls_out_dim, self.num_classes, kernel_size=1)
             for head in self.non_shared_heads
             ])
        self.reg_preds = nn.ModuleList(
            [nn.Conv2d(head.reg_out_dim, 4*cfg['reg_max'], kernel_size=1)
             for head in self.non_shared_heads
             ])

        # --------- Network Initialization ----------
        # init BN settings, prediction biases, and the DFL projection
        self.init_yolo()
    def init_yolo(self):
        # Init BatchNorm hyper-parameters
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eps = 1e-3
                m.momentum = 0.03
        # Init the classification bias so every class starts at a prior
        # probability of ~0.01 (the focal-loss-style init from RetinaNet)
        init_prob = 0.01
        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
        # cls pred
        for cls_pred in self.cls_preds:
            b = cls_pred.bias.view(1, -1)
            b.data.fill_(bias_value.item())
            cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
        # reg pred: unit bias, zero weights
        for reg_pred in self.reg_preds:
            b = reg_pred.bias.view(-1, )
            b.data.fill_(1.0)
            reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
            w = reg_pred.weight
            w.data.fill_(0.)
            reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
        # Fixed projection vector for DFL decoding. NOTE: the canonical DFL
        # formulation uses integer bin centers (torch.arange(reg_max));
        # linspace(0, reg_max, reg_max) spaces the bins slightly differently
        # and is kept here as in the original code.
        self.proj = nn.Parameter(torch.linspace(0, self.reg_max, self.reg_max), requires_grad=False)
        self.proj_conv.weight = nn.Parameter(self.proj.view([1, self.reg_max, 1, 1]).clone().detach(),
                                             requires_grad=False)
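    # Worked check of the classification-bias init above (not executed, just
    # arithmetic): with init_prob = 0.01,
    #   bias = -log((1 - 0.01) / 0.01) = -log(99) ≈ -4.595
    #   sigmoid(-4.595) ≈ 0.01
    # so right after init, every class score decodes back to the 1% prior.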
    def generate_anchors(self, level, fmp_size):
        """
        fmp_size: (List) [H, W]
        """
        # generate grid cells; indexing='ij' keeps the (row, col) = (y, x)
        # order and silences the deprecation warning in newer PyTorch versions
        fmp_h, fmp_w = fmp_size
        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)], indexing='ij')
        # [H, W, 2] -> [HW, 2]; the +0.5 shifts anchors to cell centers
        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
        anchor_xy *= self.stride[level]
        anchors = anchor_xy.to(self.device)

        return anchors
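    # Shape sketch for generate_anchors (the 640x640 input below is an
    # assumption for illustration): at stride[level] == 8 the feature map is
    # 80x80, so fmp_size = [80, 80] yields anchors of shape [6400, 2] whose
    # first rows are (4., 4.), (12., 4.), (20., 4.), i.e. cell centers in
    # input-image pixels.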
    def decode_boxes(self, anchors, pred_regs, stride):
        """
        Input:
            anchors: (Tensor) [1, M, 2] or [M, 2]
            pred_regs: (Tensor) [B, M, 4*reg_max]
        Output:
            pred_box: (Tensor) [B, M, 4]
        """
        if self.use_dfl:
            B, M = pred_regs.shape[:2]
            # [B, M, 4*reg_max] -> [B, M, 4, reg_max]
            pred_regs = pred_regs.reshape([B, M, 4, self.reg_max])
            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
            pred_regs = pred_regs.permute(0, 3, 2, 1).contiguous()
            # softmax over the reg_max bins, then take the expectation with
            # the fixed projection: [B, reg_max, 4, M] -> [B, 1, 4, M]
            pred_regs = self.proj_conv(F.softmax(pred_regs, dim=1))
            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
            pred_regs = pred_regs.view(B, 4, M).permute(0, 2, 1).contiguous()

        # ltrb distances -> xyxy box
        pred_x1y1 = anchors - pred_regs[..., :2] * stride
        pred_x2y2 = anchors + pred_regs[..., 2:] * stride
        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)

        return pred_box
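    # DFL decoding sketch: an equivalent computation to the proj_conv path
    # above, written with plain tensor ops (`logits` stands in for one level's
    # raw regression output of shape [B, M, 4*reg_max]):
    #   logits = logits.reshape(B, M, 4, self.reg_max)  # one distribution per side
    #   probs  = F.softmax(logits, dim=-1)              # [B, M, 4, reg_max]
    #   dists  = (probs * self.proj).sum(dim=-1)        # [B, M, 4], expected bin
    # The fixed 1x1 proj_conv computes exactly this expectation, and the four
    # (l, t, r, b) distances are then scaled by the stride.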
    def post_process(self, cls_preds, reg_preds, anchors):
        """
        Input:
            cls_preds: List(Tensor) [[B, H x W, C], ...]
            reg_preds: List(Tensor) [[B, H x W, 4*reg_max], ...]
            anchors:   List(Tensor) [[H x W, 2], ...]
        """
        all_scores = []
        all_labels = []
        all_bboxes = []

        for level, (cls_pred_i, reg_pred_i, anchors_i) in enumerate(zip(cls_preds, reg_preds, anchors)):
            # [B, M, C] -> [M, C]; batch size is assumed to be 1 here
            cur_cls_pred_i = cls_pred_i[0]
            cur_reg_pred_i = reg_pred_i[0]
            # [M, C] -> [MC,]
            scores_i = cur_cls_pred_i.sigmoid().flatten()
            # Keep only the top-k scoring predictions
            num_topk = min(self.topk, cur_reg_pred_i.size(0))
            # torch.sort is actually faster than .topk (at least on GPUs)
            predicted_prob, topk_idxs = scores_i.sort(descending=True)
            scores = predicted_prob[:num_topk]
            topk_idxs = topk_idxs[:num_topk]
            # recover the anchor index and the class label from the flat index
            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
            labels = topk_idxs % self.num_classes

            cur_reg_pred_i = cur_reg_pred_i[anchor_idxs]
            anchors_i = anchors_i[anchor_idxs]

            # decode box: [M, 4]
            box_pred_i = self.decode_boxes(
                anchors_i[None], cur_reg_pred_i[None], self.stride[level])
            bboxes = box_pred_i[0]

            all_scores.append(scores)
            all_labels.append(labels)
            all_bboxes.append(bboxes)

        scores = torch.cat(all_scores)
        labels = torch.cat(all_labels)
        bboxes = torch.cat(all_bboxes)

        # score thresholding
        keep_idxs = scores.gt(self.conf_thresh)
        scores = scores[keep_idxs]
        labels = labels[keep_idxs]
        bboxes = bboxes[keep_idxs]

        # to cpu & numpy
        scores = scores.cpu().numpy()
        labels = labels.cpu().numpy()
        bboxes = bboxes.cpu().numpy()

        # nms
        scores, labels, bboxes = multiclass_nms(
            scores, labels, bboxes, self.nms_thresh, self.num_classes, False)

        return bboxes, scores, labels
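    # Index-recovery sketch for the flattened top-k above: scores_i flattens a
    # [M, C] score map, so a flat index i encodes (anchor, class) as
    #   anchor_idx = i // C,  label = i % C
    # e.g. with C = 20, flat index 45 -> anchor 2, class 5.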
    @torch.no_grad()
    def inference_single_image(self, x):
        # backbone
        pyramid_feats = self.backbone(x)
        # neck
        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
        # fpn
        pyramid_feats = self.fpn(pyramid_feats)

        # non-shared heads
        all_cls_preds = []
        all_reg_preds = []
        all_anchors = []
        for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
            cls_feat, reg_feat = head(feat)
            # pred
            cls_pred = self.cls_preds[level](cls_feat)  # [B, C, H, W]
            reg_pred = self.reg_preds[level](reg_feat)  # [B, 4*reg_max, H, W]

            B, _, H, W = cls_pred.size()
            fmp_size = [H, W]
            # anchors: [M, 2]
            anchors = self.generate_anchors(level, fmp_size)

            # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)

            all_cls_preds.append(cls_pred)
            all_reg_preds.append(reg_pred)
            all_anchors.append(anchors)

        # post process
        bboxes, scores, labels = self.post_process(
            all_cls_preds, all_reg_preds, all_anchors)

        return bboxes, scores, labels
    def forward(self, x):
        if not self.trainable:
            return self.inference_single_image(x)
        else:
            # backbone
            pyramid_feats = self.backbone(x)
            # neck
            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
            # fpn
            pyramid_feats = self.fpn(pyramid_feats)

            # non-shared heads
            all_anchors = []
            all_cls_preds = []
            all_reg_preds = []
            all_box_preds = []
            all_strides = []
            for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
                cls_feat, reg_feat = head(feat)
                # pred
                cls_pred = self.cls_preds[level](cls_feat)  # [B, C, H, W]
                reg_pred = self.reg_preds[level](reg_feat)  # [B, 4*reg_max, H, W]

                B, _, H, W = cls_pred.size()
                fmp_size = [H, W]
                # generate anchor points: [M, 2]
                anchors = self.generate_anchors(level, fmp_size)

                # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)

                # decode box: [B, M, 4]
                box_pred = self.decode_boxes(anchors, reg_pred, self.stride[level])

                # stride tensor: [M, 1]
                stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]

                all_cls_preds.append(cls_pred)
                all_reg_preds.append(reg_pred)
                all_box_preds.append(box_pred)
                all_anchors.append(anchors)
                all_strides.append(stride_tensor)

            # output dict
            outputs = {"pred_cls": all_cls_preds,       # List(Tensor) [B, M, C]
                       "pred_reg": all_reg_preds,       # List(Tensor) [B, M, 4*reg_max]
                       "pred_box": all_box_preds,       # List(Tensor) [B, M, 4]
                       "anchors": all_anchors,          # List(Tensor) [M, 2]
                       "strides": self.stride,          # List(Int), e.g. [8, 16, 32]
                       "stride_tensor": all_strides     # List(Tensor) [M, 1]
                       }

            return outputs
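

# ----------------------- Usage sketch -----------------------
# A minimal inference sketch, not part of the module. The cfg below shows only
# the keys this file itself reads ('stride', 'reg_max'); a real config must
# also carry whatever build_backbone / build_neck / build_fpn / build_head
# expect, so treat these values as assumptions.
#
#   device = torch.device('cpu')
#   cfg = {'stride': [8, 16, 32], 'reg_max': 16}  # plus builder-specific keys
#   model = YOLOv8(cfg, device, num_classes=20, trainable=False).to(device)
#   model.eval()
#   x = torch.randn(1, 3, 640, 640)
#   bboxes, scores, labels = model(x)  # numpy arrays, after thresholding & NMS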