# yolov2_head.py (2.8 KB)

  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule
  5. except:
  6. from modules import ConvModule
  7. class Yolov2DetHead(nn.Module):
  8. def __init__(self, cfg, in_dim: int = 256):
  9. super().__init__()
  10. # --------- Basic Parameters ----------
  11. self.in_dim = in_dim
  12. self.cls_head_dim = cfg.head_dim
  13. self.reg_head_dim = cfg.head_dim
  14. self.num_cls_head = cfg.num_cls_head
  15. self.num_reg_head = cfg.num_reg_head
  16. # --------- Network Parameters ----------
  17. ## cls head
  18. cls_feats = []
  19. for i in range(self.num_cls_head):
  20. if i == 0:
  21. cls_feats.append(ConvModule(in_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
  22. else:
  23. cls_feats.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
  24. ## reg head
  25. reg_feats = []
  26. for i in range(self.num_reg_head):
  27. if i == 0:
  28. reg_feats.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
  29. else:
  30. reg_feats.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
  31. self.cls_feats = nn.Sequential(*cls_feats)
  32. self.reg_feats = nn.Sequential(*reg_feats)
  33. self.init_weights()
  34. def init_weights(self):
  35. """Initialize the parameters."""
  36. for m in self.modules():
  37. if isinstance(m, torch.nn.Conv2d):
  38. m.reset_parameters()
  39. def forward(self, x):
  40. """
  41. in_feats: (Tensor) [B, C, H, W]
  42. """
  43. cls_feats = self.cls_feats(x)
  44. reg_feats = self.reg_feats(x)
  45. return cls_feats, reg_feats
  46. if __name__=='__main__':
  47. from thop import profile
  48. # YOLOv2 configuration
  49. class Yolov2BaseConfig(object):
  50. def __init__(self) -> None:
  51. # ---------------- Model config ----------------
  52. self.out_stride = 32
  53. self.max_stride = 32
  54. ## Head
  55. self.head_dim = 256
  56. self.num_cls_head = 2
  57. self.num_reg_head = 2
  58. cfg = Yolov2BaseConfig()
  59. # Build a head
  60. model = Yolov2DetHead(cfg, 512)
  61. # Randomly generate a input data
  62. x = torch.randn(2, 512, 20, 20)
  63. # Inference
  64. cls_feats, reg_feats = model(x)
  65. print(' - the shape of input : ', x.shape)
  66. print(' - the shape of cls feats : ', cls_feats.shape)
  67. print(' - the shape of reg feats : ', reg_feats.shape)
  68. x = torch.randn(1, 512, 20, 20)
  69. flops, params = profile(model, inputs=(x, ), verbose=False)
  70. print('============== FLOPs & Params ================')
  71. print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
  72. print(' - Params : {:.2f} M'.format(params / 1e6))