yolov1_head.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov1_basic import BasicConv
  5. except:
  6. from yolov1_basic import BasicConv
  7. class Yolov1DetHead(nn.Module):
  8. def __init__(self, cfg, in_dim: int = 256):
  9. super().__init__()
  10. # --------- Basic Parameters ----------
  11. self.in_dim = in_dim
  12. self.cls_head_dim = cfg.head_dim
  13. self.reg_head_dim = cfg.head_dim
  14. self.num_cls_head = cfg.num_cls_head
  15. self.num_reg_head = cfg.num_reg_head
  16. self.act_type = cfg.head_act
  17. self.norm_type = cfg.head_norm
  18. self.depthwise = cfg.head_depthwise
  19. # --------- Network Parameters ----------
  20. ## cls head
  21. cls_feats = []
  22. for i in range(self.num_cls_head):
  23. if i == 0:
  24. cls_feats.append(
  25. BasicConv(in_dim, self.cls_head_dim,
  26. kernel_size=3, padding=1, stride=1,
  27. act_type = self.act_type,
  28. norm_type = self.norm_type,
  29. depthwise = self.depthwise)
  30. )
  31. else:
  32. cls_feats.append(
  33. BasicConv(self.cls_head_dim, self.cls_head_dim,
  34. kernel_size=3, padding=1, stride=1,
  35. act_type = self.act_type,
  36. norm_type = self.norm_type,
  37. depthwise = self.depthwise)
  38. )
  39. ## reg head
  40. reg_feats = []
  41. for i in range(self.num_reg_head):
  42. if i == 0:
  43. reg_feats.append(
  44. BasicConv(in_dim, self.reg_head_dim,
  45. kernel_size=3, padding=1, stride=1,
  46. act_type = self.act_type,
  47. norm_type = self.norm_type,
  48. depthwise = self.depthwise)
  49. )
  50. else:
  51. reg_feats.append(
  52. BasicConv(self.reg_head_dim, self.reg_head_dim,
  53. kernel_size=3, padding=1, stride=1,
  54. act_type = self.act_type,
  55. norm_type = self.norm_type,
  56. depthwise = self.depthwise)
  57. )
  58. self.cls_feats = nn.Sequential(*cls_feats)
  59. self.reg_feats = nn.Sequential(*reg_feats)
  60. self.init_weights()
  61. def init_weights(self):
  62. """Initialize the parameters."""
  63. for m in self.modules():
  64. if isinstance(m, torch.nn.Conv2d):
  65. # In order to be consistent with the source code,
  66. # reset the Conv2d initialization parameters
  67. m.reset_parameters()
  68. def forward(self, x):
  69. """
  70. in_feats: (Tensor) [B, C, H, W]
  71. """
  72. cls_feats = self.cls_feats(x)
  73. reg_feats = self.reg_feats(x)
  74. return cls_feats, reg_feats
  75. if __name__=='__main__':
  76. from thop import profile
  77. # YOLOv1 configuration
  78. class Yolov1BaseConfig(object):
  79. def __init__(self) -> None:
  80. # ---------------- Model config ----------------
  81. self.out_stride = 32
  82. self.max_stride = 32
  83. ## Head
  84. self.head_act = 'lrelu'
  85. self.head_norm = 'BN'
  86. self.head_depthwise = False
  87. self.head_dim = 256
  88. self.num_cls_head = 2
  89. self.num_reg_head = 2
  90. cfg = Yolov1BaseConfig()
  91. # Build a head
  92. model = Yolov1DetHead(cfg, 512)
  93. # Randomly generate a input data
  94. x = torch.randn(2, 512, 20, 20)
  95. # Inference
  96. cls_feats, reg_feats = model(x)
  97. print(' - the shape of input : ', x.shape)
  98. print(' - the shape of cls feats : ', cls_feats.shape)
  99. print(' - the shape of reg feats : ', reg_feats.shape)
  100. x = torch.randn(1, 512, 20, 20)
  101. flops, params = profile(model, inputs=(x, ), verbose=False)
  102. print('============== FLOPs & Params ================')
  103. print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
  104. print(' - Params : {:.2f} M'.format(params / 1e6))