yolov2_head.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov2_basic import BasicConv
  5. except:
  6. from yolov2_basic import BasicConv
  7. class Yolov2DetHead(nn.Module):
  8. def __init__(self, cfg, in_dim: int = 256):
  9. super().__init__()
  10. # --------- Basic Parameters ----------
  11. self.in_dim = in_dim
  12. self.cls_head_dim = cfg.head_dim
  13. self.reg_head_dim = cfg.head_dim
  14. self.num_cls_head = cfg.num_cls_head
  15. self.num_reg_head = cfg.num_reg_head
  16. self.act_type = cfg.head_act
  17. self.norm_type = cfg.head_norm
  18. self.depthwise = cfg.head_depthwise
  19. # --------- Network Parameters ----------
  20. ## cls head
  21. cls_feats = []
  22. for i in range(self.num_cls_head):
  23. if i == 0:
  24. cls_feats.append(
  25. BasicConv(in_dim, self.cls_head_dim,
  26. kernel_size=3, padding=1, stride=1,
  27. act_type = self.act_type,
  28. norm_type = self.norm_type,
  29. depthwise = self.depthwise)
  30. )
  31. else:
  32. cls_feats.append(
  33. BasicConv(self.cls_head_dim, self.cls_head_dim,
  34. kernel_size=3, padding=1, stride=1,
  35. act_type = self.act_type,
  36. norm_type = self.norm_type,
  37. depthwise = self.depthwise)
  38. )
  39. ## reg head
  40. reg_feats = []
  41. for i in range(self.num_reg_head):
  42. if i == 0:
  43. reg_feats.append(
  44. BasicConv(in_dim, self.reg_head_dim,
  45. kernel_size=3, padding=1, stride=1,
  46. act_type = self.act_type,
  47. norm_type = self.norm_type,
  48. depthwise = self.depthwise)
  49. )
  50. else:
  51. reg_feats.append(
  52. BasicConv(self.reg_head_dim, self.reg_head_dim,
  53. kernel_size=3, padding=1, stride=1,
  54. act_type = self.act_type,
  55. norm_type = self.norm_type,
  56. depthwise = self.depthwise)
  57. )
  58. self.cls_feats = nn.Sequential(*cls_feats)
  59. self.reg_feats = nn.Sequential(*reg_feats)
  60. self.init_weights()
  61. def init_weights(self):
  62. """Initialize the parameters."""
  63. for m in self.modules():
  64. if isinstance(m, torch.nn.Conv2d):
  65. # In order to be consistent with the source code,
  66. # reset the Conv2d initialization parameters
  67. m.reset_parameters()
  68. def forward(self, x):
  69. """
  70. in_feats: (Tensor) [B, C, H, W]
  71. """
  72. cls_feats = self.cls_feats(x)
  73. reg_feats = self.reg_feats(x)
  74. return cls_feats, reg_feats
  75. if __name__=='__main__':
  76. import time
  77. from thop import profile
  78. # Model config
  79. # YOLOv8-Base config
  80. class Yolov2BaseConfig(object):
  81. def __init__(self) -> None:
  82. # ---------------- Model config ----------------
  83. self.out_stride = 32
  84. self.max_stride = 32
  85. ## Head
  86. self.head_act = 'lrelu'
  87. self.head_norm = 'BN'
  88. self.head_depthwise = False
  89. self.head_dim = 256
  90. self.num_cls_head = 2
  91. self.num_reg_head = 2
  92. cfg = Yolov2BaseConfig()
  93. # Build a head
  94. head = Yolov2DetHead(cfg, 512)
  95. # Inference
  96. x = torch.randn(1, 512, 20, 20)
  97. t0 = time.time()
  98. cls_feat, reg_feat = head(x)
  99. t1 = time.time()
  100. print('Time: ', t1 - t0)
  101. print(cls_feat.shape, reg_feat.shape)
  102. print('==============================')
  103. flops, params = profile(head, inputs=(x, ), verbose=False)
  104. print('==============================')
  105. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  106. print('Params : {:.2f} M'.format(params / 1e6))