# yolov1_head.py
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .modules import ConvModule
  5. except:
  6. from modules import ConvModule
  7. class Yolov1DetHead(nn.Module):
  8. def __init__(self, cfg, in_dim: int = 256):
  9. super().__init__()
  10. # --------- Basic Parameters ----------
  11. self.in_dim = in_dim
  12. self.cls_head_dim = cfg.head_dim
  13. self.reg_head_dim = cfg.head_dim
  14. self.num_cls_head = cfg.num_cls_head
  15. self.num_reg_head = cfg.num_reg_head
  16. # --------- Network Parameters ----------
  17. ## cls head
  18. cls_feats = []
  19. for i in range(self.num_cls_head):
  20. if i == 0:
  21. cls_feats.append(ConvModule(in_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
  22. else:
  23. cls_feats.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
  24. ## reg head
  25. reg_feats = []
  26. for i in range(self.num_reg_head):
  27. if i == 0:
  28. reg_feats.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
  29. else:
  30. reg_feats.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
  31. self.cls_feats = nn.Sequential(*cls_feats)
  32. self.reg_feats = nn.Sequential(*reg_feats)
  33. self.init_weights()
  34. def init_weights(self):
  35. for m in self.modules():
  36. if isinstance(m, torch.nn.Conv2d):
  37. m.reset_parameters()
  38. def forward(self, x):
  39. """
  40. in_feats: (Tensor) [B, C, H, W]
  41. """
  42. cls_feats = self.cls_feats(x)
  43. reg_feats = self.reg_feats(x)
  44. return cls_feats, reg_feats
  45. if __name__=='__main__':
  46. from thop import profile
  47. # YOLOv1 configuration
  48. class Yolov1BaseConfig(object):
  49. def __init__(self) -> None:
  50. # ---------------- Model config ----------------
  51. self.out_stride = 32
  52. self.max_stride = 32
  53. ## Head
  54. self.head_dim = 256
  55. self.num_cls_head = 2
  56. self.num_reg_head = 2
  57. cfg = Yolov1BaseConfig()
  58. # Build a head
  59. model = Yolov1DetHead(cfg, 512)
  60. # Randomly generate a input data
  61. x = torch.randn(2, 512, 20, 20)
  62. # Inference
  63. cls_feats, reg_feats = model(x)
  64. print(' - the shape of input : ', x.shape)
  65. print(' - the shape of cls feats : ', cls_feats.shape)
  66. print(' - the shape of reg feats : ', reg_feats.shape)
  67. x = torch.randn(1, 512, 20, 20)
  68. flops, params = profile(model, inputs=(x, ), verbose=False)
  69. print('============== FLOPs & Params ================')
  70. print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
  71. print(' - Params : {:.2f} M'.format(params / 1e6))