# yolov3_head.py
from typing import List

import torch
import torch.nn as nn

try:
    from .modules import ConvModule
except ImportError:
    # Fall back to a top-level import when this file is run directly as a
    # script (no parent package, so the relative import above fails).
    from modules import ConvModule
  8. class DecoupledHead(nn.Module):
  9. def __init__(self, cfg, in_dim: int = 256):
  10. super().__init__()
  11. self.in_dim = in_dim
  12. self.cls_head_dim = cfg.head_dim
  13. self.reg_head_dim = cfg.head_dim
  14. self.num_cls_head = cfg.num_cls_head
  15. self.num_reg_head = cfg.num_reg_head
  16. # classification feature head
  17. cls_feats = []
  18. for i in range(self.num_cls_head):
  19. if i == 0:
  20. cls_feats.append(ConvModule(in_dim, self.cls_head_dim, kernel_size=3, stride=1))
  21. else:
  22. cls_feats.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, stride=1))
  23. # box regression feature head
  24. reg_feats = []
  25. for i in range(self.num_reg_head):
  26. if i == 0:
  27. reg_feats.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, stride=1))
  28. else:
  29. reg_feats.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, stride=1))
  30. self.cls_feats = nn.Sequential(*cls_feats)
  31. self.reg_feats = nn.Sequential(*reg_feats)
  32. def forward(self, x):
  33. """
  34. in_feats: (Tensor) [B, C, H, W]
  35. """
  36. cls_feats = self.cls_feats(x)
  37. reg_feats = self.reg_feats(x)
  38. return cls_feats, reg_feats
  39. if __name__=='__main__':
  40. from thop import profile
  41. # YOLOv2 configuration
  42. class Yolov3BaseConfig(object):
  43. def __init__(self) -> None:
  44. # ---------------- Model config ----------------
  45. self.head_dim = 256
  46. self.num_cls_head = 2
  47. self.num_reg_head = 2
  48. cfg = Yolov3BaseConfig()
  49. # Build a head
  50. model = DecoupledHead(cfg, in_dim= 256)
  51. # Randomly generate a input data
  52. x = torch.randn(2, 256, 20, 20)
  53. # Inference
  54. cls_feats, reg_feats = model(x)
  55. print(' - the shape of input : ', x.shape)
  56. print(' - the shape of cls feats : ', cls_feats.shape)
  57. print(' - the shape of reg feats : ', reg_feats.shape)
  58. x = torch.randn(1, 256, 20, 20)
  59. flops, params = profile(model, inputs=(x, ), verbose=False)
  60. print('============== FLOPs & Params ================')
  61. print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
  62. print(' - Params : {:.2f} M'.format(params / 1e6))