yolof.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. import torch
  2. import torch.nn as nn
  3. # --------------- Model components ---------------
  4. from ...backbone import build_backbone
  5. from ...neck import build_neck
  6. from ...head import build_head
  7. # --------------- External components ---------------
  8. from utils.misc import multiclass_nms
  9. # ------------------------ You Only Look One-level Feature ------------------------
  10. class YOLOF(nn.Module):
  11. def __init__(self,
  12. cfg,
  13. num_classes :int = 80,
  14. conf_thresh :float = 0.05,
  15. nms_thresh :float = 0.6,
  16. topk :int = 1000,
  17. ca_nms :bool = False):
  18. super(YOLOF, self).__init__()
  19. # ---------------------- Basic Parameters ----------------------
  20. self.cfg = cfg
  21. self.topk = topk
  22. self.num_classes = num_classes
  23. self.conf_thresh = conf_thresh
  24. self.nms_thresh = nms_thresh
  25. self.ca_nms = ca_nms
  26. # ---------------------- Network Parameters ----------------------
  27. ## Backbone
  28. self.backbone, feat_dims = build_backbone(cfg)
  29. ## Neck
  30. self.neck = build_neck(cfg, feat_dims[-1], cfg['head_dim'])
  31. ## Heads
  32. self.head = build_head(cfg, cfg['head_dim'], cfg['head_dim'], num_classes)
  33. def post_process(self, cls_pred, box_pred):
  34. """
  35. Input:
  36. cls_pred: (Tensor) [[H x W x KA, C]
  37. box_pred: (Tensor) [H x W x KA, 4]
  38. """
  39. cls_pred = cls_pred[0]
  40. box_pred = box_pred[0]
  41. # (H x W x KA x C,)
  42. scores_i = cls_pred.sigmoid().flatten()
  43. # Keep top k top scoring indices only.
  44. num_topk = min(self.topk, box_pred.size(0))
  45. # torch.sort is actually faster than .topk (at least on GPUs)
  46. predicted_prob, topk_idxs = scores_i.sort(descending=True)
  47. topk_scores = predicted_prob[:num_topk]
  48. topk_idxs = topk_idxs[:num_topk]
  49. # filter out the proposals with low confidence score
  50. keep_idxs = topk_scores > self.conf_thresh
  51. topk_idxs = topk_idxs[keep_idxs]
  52. # final scores
  53. scores = topk_scores[keep_idxs]
  54. # final labels
  55. labels = topk_idxs % self.num_classes
  56. # final bboxes
  57. anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
  58. bboxes = box_pred[anchor_idxs]
  59. # to cpu & numpy
  60. scores = scores.cpu().numpy()
  61. labels = labels.cpu().numpy()
  62. bboxes = bboxes.cpu().numpy()
  63. # nms
  64. scores, labels, bboxes = multiclass_nms(
  65. scores, labels, bboxes, self.nms_thresh, self.num_classes, self.ca_nms)
  66. return bboxes, scores, labels
  67. def forward(self, src, src_mask=None):
  68. # ---------------- Backbone ----------------
  69. pyramid_feats = self.backbone(src)
  70. # ---------------- Neck ----------------
  71. feat = self.neck(pyramid_feats[-1])
  72. # ---------------- Heads ----------------
  73. outputs = self.head(feat, src_mask)
  74. if not self.training:
  75. # ---------------- PostProcess ----------------
  76. cls_pred = outputs["pred_cls"]
  77. box_pred = outputs["pred_box"]
  78. bboxes, scores, labels = self.post_process(cls_pred, box_pred)
  79. # normalize bbox
  80. bboxes[..., 0::2] /= src.shape[-1]
  81. bboxes[..., 1::2] /= src.shape[-2]
  82. bboxes = bboxes.clip(0., 1.)
  83. return bboxes, scores, labels
  84. return outputs