ctrnet_pred.py

import math
import torch
import torch.nn as nn


def build_det_pred(cls_dim, reg_dim, stride, num_classes, num_coords=4):
    pred_layers = SDetPDLayer(cls_dim=cls_dim,
                              reg_dim=reg_dim,
                              stride=stride,
                              num_classes=num_classes,
                              num_coords=num_coords)
    return pred_layers

# ---------------------------- Detection predictor ----------------------------
## Single-level Detection Prediction Layer
class SDetPDLayer(nn.Module):
    def __init__(self,
                 cls_dim: int = 256,
                 reg_dim: int = 256,
                 stride: int = 32,
                 num_classes: int = 80,
                 num_coords: int = 4):
        super().__init__()
        # --------- Basic Parameters ----------
        self.stride = stride
        self.cls_dim = cls_dim
        self.reg_dim = reg_dim
        self.num_classes = num_classes
        self.num_coords = num_coords
        # --------- Network Parameters ----------
        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1)
        self.init_bias()

    def init_bias(self):
        # cls pred bias: start from a large negative value so that, on a 640x640
        # input, the summed initial class scores are roughly 5 (a low-confidence prior)
        b = self.cls_pred.bias.view(1, -1)
        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
        # reg pred bias
        b = self.reg_pred.bias.view(-1)
        b.data.fill_(1.0)
        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)

    def generate_anchors(self, fmp_size):
        """
        fmp_size: (List) [H, W]
        """
        # generate grid cells
        fmp_h, fmp_w = fmp_size
        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
        # [H, W, 2] -> [HW, 2]
        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
        anchors += 0.5  # add center offset
        anchors *= self.stride

        return anchors
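
    # Illustrative note (not in the original file): with fmp_size = [2, 2] and
    # stride = 32, generate_anchors returns the four grid-cell centers in input
    # image coordinates: [[16, 16], [48, 16], [16, 48], [48, 48]].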

    def forward(self, cls_feat, reg_feat):
        # pred
        cls_pred = self.cls_pred(cls_feat)
        reg_pred = self.reg_pred(reg_feat)

        # generate anchor points: [M, 2]
        B, _, H, W = cls_pred.size()
        fmp_size = [H, W]
        anchors = self.generate_anchors(fmp_size)
        anchors = anchors.to(cls_pred.device)

        # stride tensor: [M, 1]
        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride

        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)

        # ---------------- Decode bbox ----------------
        ctr_pred = reg_pred[..., :2] * self.stride + anchors[..., :2]
        wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride
        pred_x1y1 = ctr_pred - wh_pred * 0.5
        pred_x2y2 = ctr_pred + wh_pred * 0.5
        box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)

        # output dict
        outputs = {"pred_cls": cls_pred,            # (Tensor) [B, M, C]
                   "pred_reg": reg_pred,            # (Tensor) [B, M, 4]
                   "pred_box": box_pred,            # (Tensor) [B, M, 4]
                   "anchors": anchors,              # (Tensor) [M, 2]
                   "stride": self.stride,           # (Int)
                   "stride_tensors": stride_tensor  # (Tensor) [M, 1]
                   }

        return outputs
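

# ---------------------------- Usage sketch (illustrative) ----------------------------
# Not part of the original file: a minimal, hypothetical example of driving the
# predictor, assuming stride-32 (P5) feature maps of shape [B, 256, 20, 20] from
# some upstream decoupled head. Batch size and spatial size are arbitrary.
if __name__ == "__main__":
    pred_layer = build_det_pred(cls_dim=256, reg_dim=256, stride=32, num_classes=80)

    cls_feat = torch.randn(2, 256, 20, 20)  # [B, cls_dim, H, W]
    reg_feat = torch.randn(2, 256, 20, 20)  # [B, reg_dim, H, W]

    outputs = pred_layer(cls_feat, reg_feat)
    for key, value in outputs.items():
        print(key, tuple(value.shape) if torch.is_tensor(value) else value)
    # With H = W = 20 there are M = 400 anchors, so the expected shapes are:
    #   pred_cls (2, 400, 80), pred_reg / pred_box (2, 400, 4),
    #   anchors (400, 2), stride_tensors (400, 1), stride 32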