yolov10_head.py

import math
import torch
import torch.nn as nn
from typing import List

try:
    from .modules import ConvModule, DflLayer
except ImportError:
    from modules import ConvModule, DflLayer
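
# `ConvModule` and `DflLayer` come from the sibling `modules.py`, which is not
# shown here. A minimal sketch of the interfaces this head assumes (a hypothetical
# reading of the standard YOLOv8/v10 design, not the authoritative code):
#   - ConvModule(in_dim, out_dim, kernel_size=1, stride=1, groups=1):
#       a Conv2d + norm + activation block.
#   - DflLayer(reg_max): the Distribution Focal Loss decoder. Given
#     reg_pred [B, M, 4*reg_max], anchors [1, M, 2], and a stride, it softmaxes
#     each group of reg_max bins, takes the expectation to get (l, t, r, b)
#     distances, and returns boxes [B, M, 4] via x1y1 = anchor - lt*stride,
#     x2y2 = anchor + rb*stride.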

# YOLOv10 detection head
class Yolov10DetHead(nn.Module):
    def __init__(self, cfg, fpn_dims: List = [64, 128, 256]):
        super().__init__()
        self.out_stride = cfg.out_stride
        self.reg_max = cfg.reg_max
        self.num_classes = cfg.num_classes
        self.cls_dim = max(fpn_dims[0], min(cfg.num_classes, 128))
        self.reg_dim = max(fpn_dims[0] // 4, 16, 4 * cfg.reg_max)
        # classification head (depthwise 3x3 + pointwise 1x1 pairs keep this branch lightweight)
        self.cls_heads = nn.ModuleList(
            nn.Sequential(
                nn.Sequential(ConvModule(dim, dim, kernel_size=3, stride=1, groups=dim),
                              ConvModule(dim, self.cls_dim, kernel_size=1)),
                nn.Sequential(ConvModule(self.cls_dim, self.cls_dim, kernel_size=3, stride=1, groups=self.cls_dim),
                              ConvModule(self.cls_dim, self.cls_dim, kernel_size=1)),
                nn.Conv2d(self.cls_dim, cfg.num_classes, kernel_size=1),
            )
            for dim in fpn_dims
        )
        # bbox regression head
        self.reg_heads = nn.ModuleList(
            nn.Sequential(
                ConvModule(dim, self.reg_dim, kernel_size=3, stride=1),
                ConvModule(self.reg_dim, self.reg_dim, kernel_size=3, stride=1),
                nn.Conv2d(self.reg_dim, 4 * cfg.reg_max, kernel_size=1),
            )
            for dim in fpn_dims
        )
        # DFL layer for decoding bbox
        self.dfl_layer = DflLayer(cfg.reg_max)
        for p in self.dfl_layer.parameters():
            p.requires_grad = False

        self.init_bias()

    def init_bias(self):
        # cls pred: prior-probability bias init (YOLOv8-style; assumes roughly
        # 5 objects per 640x640 image, so early training predicts few positives)
        for i, m in enumerate(self.cls_heads):
            b = m[-1].bias.view(1, -1)
            b.data.fill_(math.log(5 / self.num_classes / (640. / self.out_stride[i]) ** 2))
            m[-1].bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
        # reg pred: zero weights with a constant bias yield a uniform DFL
        # distribution over the reg_max bins at the start of training
        for m in self.reg_heads:
            b = m[-1].bias.view(-1)
            b.data.fill_(1.0)
            m[-1].bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
            w = m[-1].weight
            w.data.fill_(0.)
            m[-1].weight = torch.nn.Parameter(w, requires_grad=True)

    def generate_anchors(self, fmp_size, level):
        """
        fmp_size: (List) [H, W]
        """
        # generate grid cells
        fmp_h, fmp_w = fmp_size
        anchor_y, anchor_x = torch.meshgrid(
            torch.arange(fmp_h), torch.arange(fmp_w), indexing="ij")
        # [H, W, 2] -> [HW, 2]
        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
        anchors += 0.5  # add center offset
        anchors *= self.out_stride[level]  # scale to input-image coordinates

        return anchors
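
    # Worked example (illustrative): at level 0 (stride 8) with an 80x80 map,
    # the anchors are the grid-cell centers in input-image coordinates:
    # (4, 4), (12, 4), (20, 4), ... row by row, giving a [6400, 2] tensor.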

    def forward(self, fpn_feats):
        anchors = []
        strides = []
        cls_preds = []
        reg_preds = []
        box_preds = []
        for lvl, (feat, cls_head, reg_head) in enumerate(zip(fpn_feats, self.cls_heads, self.reg_heads)):
            bs, c, h, w = feat.size()
            device = feat.device
            # Prediction
            cls_pred = cls_head(feat)
            reg_pred = reg_head(feat)
            # [bs, c, h, w] -> [bs, c, hw] -> [bs, hw, c]
            cls_pred = cls_pred.flatten(2).permute(0, 2, 1).contiguous()
            reg_pred = reg_pred.flatten(2).permute(0, 2, 1).contiguous()
            # anchor points: [M, 2]
            anchor = self.generate_anchors(fmp_size=[h, w], level=lvl).to(device)
            stride = torch.ones_like(anchor[..., :1]) * self.out_stride[lvl]
            # Decode bbox coords
            box_pred = self.dfl_layer(reg_pred, anchor[None], self.out_stride[lvl])
            # collect results
            anchors.append(anchor)
            strides.append(stride)
            cls_preds.append(cls_pred)
            reg_preds.append(reg_pred)
            box_preds.append(box_pred)

        # output dict
        outputs = {"pred_cls": cls_preds,       # List(Tensor) [B, M, C]
                   "pred_reg": reg_preds,       # List(Tensor) [B, M, 4*reg_max]
                   "pred_box": box_preds,       # List(Tensor) [B, M, 4]
                   "anchors": anchors,          # List(Tensor) [M, 2]
                   "stride_tensor": strides,    # List(Tensor) [M, 1]
                   "strides": self.out_stride,  # List(Int) = [8, 16, 32]
                   }

        return outputs
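
# For a 640x640 input, the per-level token counts M are 80*80 = 6400,
# 40*40 = 1600, and 20*20 = 400. A downstream loss would typically flatten
# the level lists (illustrative usage, not part of this module):
#   cls_pred = torch.cat(outputs["pred_cls"], dim=1)  # [B, 8400, num_classes]
#   box_pred = torch.cat(outputs["pred_box"], dim=1)  # [B, 8400, 4]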

if __name__ == '__main__':
    from thop import profile

    # YOLOv10-Base config
    class Yolov10BaseConfig(object):
        def __init__(self) -> None:
            # ---------------- Model config ----------------
            self.width = 0.50
            self.depth = 0.34
            self.ratio = 2.0
            self.reg_max = 16
            self.out_stride = [8, 16, 32]
            self.max_stride = 32
            self.num_levels = 3
            self.num_classes = 80

    cfg = Yolov10BaseConfig()
    # Random data: FPN features for a 640x640 input
    fpn_dims = [256, 512, 512]
    x = [torch.randn(1, fpn_dims[0], 80, 80),
         torch.randn(1, fpn_dims[1], 40, 40),
         torch.randn(1, fpn_dims[2], 20, 20)]
    # Head model
    model = Yolov10DetHead(cfg, fpn_dims)
    # Inference
    outputs = model(x)

    print('============ FLOPs & Params ===========')
    flops, params = profile(model, inputs=(x, ), verbose=False)
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))  # thop reports MACs; x2 for FLOPs
    print('Params : {:.2f} M'.format(params / 1e6))