yolov2_head.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule
  5. except:
  6. from modules import ConvModule
  7. class Yolov2DetHead(nn.Module):
  8. def __init__(self, cfg, in_dim: int = 256):
  9. super().__init__()
  10. # --------- Basic Parameters ----------
  11. self.in_dim = in_dim
  12. self.cls_head_dim = cfg.head_dim
  13. self.reg_head_dim = cfg.head_dim
  14. self.num_cls_head = cfg.num_cls_head
  15. self.num_reg_head = cfg.num_reg_head
  16. # --------- Network Parameters ----------
  17. ## cls head
  18. cls_feats = []
  19. for i in range(self.num_cls_head):
  20. if i == 0:
  21. cls_feats.append(ConvModule(in_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
  22. else:
  23. cls_feats.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
  24. ## reg head
  25. reg_feats = []
  26. for i in range(self.num_reg_head):
  27. if i == 0:
  28. reg_feats.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
  29. else:
  30. reg_feats.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
  31. self.cls_feats = nn.Sequential(*cls_feats)
  32. self.reg_feats = nn.Sequential(*reg_feats)
  33. self.init_weights()
  34. def init_weights(self):
  35. """Initialize the parameters."""
  36. for m in self.modules():
  37. if isinstance(m, torch.nn.Conv2d):
  38. m.reset_parameters()
  39. def forward(self, x):
  40. """
  41. in_feats: (Tensor) [B, C, H, W]
  42. """
  43. cls_feats = self.cls_feats(x)
  44. reg_feats = self.reg_feats(x)
  45. return cls_feats, reg_feats
  46. if __name__=='__main__':
  47. import time
  48. from thop import profile
  49. # Model config
  50. # YOLOv8-Base config
  51. class Yolov2BaseConfig(object):
  52. def __init__(self) -> None:
  53. # ---------------- Model config ----------------
  54. self.out_stride = 32
  55. self.max_stride = 32
  56. ## Head
  57. self.head_act = 'lrelu'
  58. self.head_norm = 'BN'
  59. self.head_depthwise = False
  60. self.head_dim = 256
  61. self.num_cls_head = 2
  62. self.num_reg_head = 2
  63. cfg = Yolov2BaseConfig()
  64. # Build a head
  65. head = Yolov2DetHead(cfg, 512)
  66. # Inference
  67. x = torch.randn(1, 512, 20, 20)
  68. t0 = time.time()
  69. cls_feat, reg_feat = head(x)
  70. t1 = time.time()
  71. print('Time: ', t1 - t0)
  72. print(cls_feat.shape, reg_feat.shape)
  73. print('==============================')
  74. flops, params = profile(head, inputs=(x, ), verbose=False)
  75. print('==============================')
  76. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  77. print('Params : {:.2f} M'.format(params / 1e6))