yolof.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import torch
  2. import torch.nn as nn
  3. # --------------- Model components ---------------
  4. from .yolof_backbone import YolofBackbone
  5. from .yolof_encoder import DilatedEncoder
  6. from .yolof_decoder import YolofHead
  7. # --------------- External components ---------------
  8. from utils.misc import multiclass_nms
  9. # ------------------------ You Only Look One-level Feature ------------------------
  10. class Yolof(nn.Module):
  11. def __init__(self, cfg, is_val: bool = False):
  12. super(Yolof, self).__init__()
  13. # ---------------------- Basic setting ----------------------
  14. self.cfg = cfg
  15. self.num_classes = cfg.num_classes
  16. ## Post-process parameters
  17. self.topk_candidates = cfg.val_topk if is_val else cfg.test_topk
  18. self.conf_thresh = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
  19. self.nms_thresh = cfg.val_nms_thresh if is_val else cfg.test_nms_thresh
  20. self.no_multi_labels = False if is_val else True
  21. # ---------------------- Network Parameters ----------------------
  22. self.backbone = YolofBackbone(cfg)
  23. self.encoder = DilatedEncoder(cfg, self.backbone.feat_dim, cfg.head_dim)
  24. self.decoder = YolofHead(cfg, self.encoder.out_dim, cfg.head_dim)
  25. def post_process(self, cls_pred, box_pred):
  26. """
  27. Input:
  28. cls_pred: (Tensor) [[H x W x KA, C]
  29. box_pred: (Tensor) [H x W x KA, 4]
  30. """
  31. cls_pred = cls_pred[0]
  32. box_pred = box_pred[0]
  33. # (H x W x KA x C,)
  34. scores_i = cls_pred.sigmoid().flatten()
  35. # Keep top k top scoring indices only.
  36. num_topk = min(self.topk_candidates, box_pred.size(0))
  37. # torch.sort is actually faster than .topk (at least on GPUs)
  38. predicted_prob, topk_idxs = scores_i.sort(descending=True)
  39. topk_scores = predicted_prob[:num_topk]
  40. topk_idxs = topk_idxs[:num_topk]
  41. # filter out the proposals with low confidence score
  42. keep_idxs = topk_scores > self.conf_thresh
  43. topk_idxs = topk_idxs[keep_idxs]
  44. # final scores
  45. scores = topk_scores[keep_idxs]
  46. # final labels
  47. labels = topk_idxs % self.num_classes
  48. # final bboxes
  49. anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
  50. bboxes = box_pred[anchor_idxs]
  51. # to cpu & numpy
  52. scores = scores.cpu().numpy()
  53. labels = labels.cpu().numpy()
  54. bboxes = bboxes.cpu().numpy()
  55. # nms
  56. scores, labels, bboxes = multiclass_nms(
  57. scores, labels, bboxes, self.nms_thresh, self.num_classes)
  58. return bboxes, scores, labels
  59. def forward(self, x):
  60. x = self.backbone(x)
  61. x = self.encoder(x)
  62. outputs = self.decoder(x)
  63. if not self.training:
  64. # ---------------- PostProcess ----------------
  65. cls_pred = outputs["pred_cls"]
  66. box_pred = outputs["pred_box"]
  67. bboxes, scores, labels = self.post_process(cls_pred, box_pred)
  68. outputs = {
  69. 'scores': scores,
  70. 'labels': labels,
  71. 'bboxes': bboxes
  72. }
  73. return outputs