yolov2_head.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolov2_basic import Conv
  5. except:
  6. from yolov2_basic import Conv
  7. class DecoupledHead(nn.Module):
  8. def __init__(self, cfg, in_dim, out_dim, num_classes=80):
  9. super().__init__()
  10. print('==============================')
  11. print('Head: Decoupled Head')
  12. self.in_dim = in_dim
  13. self.num_cls_head=cfg['num_cls_head']
  14. self.num_reg_head=cfg['num_reg_head']
  15. self.act_type=cfg['head_act']
  16. self.norm_type=cfg['head_norm']
  17. # cls head
  18. cls_feats = []
  19. self.cls_out_dim = max(out_dim, num_classes)
  20. for i in range(cfg['num_cls_head']):
  21. if i == 0:
  22. cls_feats.append(
  23. Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1,
  24. act_type=self.act_type,
  25. norm_type=self.norm_type,
  26. depthwise=cfg['head_depthwise'])
  27. )
  28. else:
  29. cls_feats.append(
  30. Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1,
  31. act_type=self.act_type,
  32. norm_type=self.norm_type,
  33. depthwise=cfg['head_depthwise'])
  34. )
  35. # reg head
  36. reg_feats = []
  37. self.reg_out_dim = max(out_dim, 64)
  38. for i in range(cfg['num_reg_head']):
  39. if i == 0:
  40. reg_feats.append(
  41. Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1,
  42. act_type=self.act_type,
  43. norm_type=self.norm_type,
  44. depthwise=cfg['head_depthwise'])
  45. )
  46. else:
  47. reg_feats.append(
  48. Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1,
  49. act_type=self.act_type,
  50. norm_type=self.norm_type,
  51. depthwise=cfg['head_depthwise'])
  52. )
  53. self.cls_feats = nn.Sequential(*cls_feats)
  54. self.reg_feats = nn.Sequential(*reg_feats)
  55. def forward(self, x):
  56. """
  57. in_feats: (Tensor) [B, C, H, W]
  58. """
  59. cls_feats = self.cls_feats(x)
  60. reg_feats = self.reg_feats(x)
  61. return cls_feats, reg_feats
  62. # build detection head
  63. def build_head(cfg, in_dim, out_dim, num_classes=80):
  64. head = DecoupledHead(cfg, in_dim, out_dim, num_classes)
  65. return head
  66. if __name__ == '__main__':
  67. import time
  68. from thop import profile
  69. cfg = {
  70. 'num_cls_head': 2,
  71. 'num_reg_head': 2,
  72. 'head_act': 'silu',
  73. 'head_norm': 'BN',
  74. 'head_depthwise': False,
  75. 'reg_max': 16,
  76. }
  77. fpn_dims = [256, 512, 512]
  78. # Head-1
  79. model = build_head(cfg, 256, fpn_dims, num_classes=80)
  80. x = torch.randn(1, 256, 80, 80)
  81. t0 = time.time()
  82. outputs = model(x)
  83. t1 = time.time()
  84. print('Time: ', t1 - t0)
  85. # for out in outputs:
  86. # print(out.shape)
  87. print('==============================')
  88. flops, params = profile(model, inputs=(x, ), verbose=False)
  89. print('==============================')
  90. print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  91. print('Head-1: Params : {:.2f} M'.format(params / 1e6))
  92. # Head-2
  93. model = build_head(cfg, 512, fpn_dims, num_classes=80)
  94. x = torch.randn(1, 512, 40, 40)
  95. t0 = time.time()
  96. outputs = model(x)
  97. t1 = time.time()
  98. print('Time: ', t1 - t0)
  99. # for out in outputs:
  100. # print(out.shape)
  101. print('==============================')
  102. flops, params = profile(model, inputs=(x, ), verbose=False)
  103. print('==============================')
  104. print('Head-2: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  105. print('Head-2: Params : {:.2f} M'.format(params / 1e6))
  106. # Head-3
  107. model = build_head(cfg, 512, fpn_dims, num_classes=80)
  108. x = torch.randn(1, 512, 20, 20)
  109. t0 = time.time()
  110. outputs = model(x)
  111. t1 = time.time()
  112. print('Time: ', t1 - t0)
  113. # for out in outputs:
  114. # print(out.shape)
  115. print('==============================')
  116. flops, params = profile(model, inputs=(x, ), verbose=False)
  117. print('==============================')
  118. print('Head-3: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  119. print('Head-3: Params : {:.2f} M'.format(params / 1e6))