# yolov5_head.py
  1. import torch
  2. import torch.nn as nn
  3. from .yolov5_basic import Conv
  4. class DecoupledHead(nn.Module):
  5. def __init__(self, cfg, in_dim, out_dim, num_classes=80):
  6. super().__init__()
  7. print('==============================')
  8. print('Head: Decoupled Head')
  9. # --------- Basic Parameters ----------
  10. self.in_dim = in_dim
  11. self.num_cls_head=cfg['num_cls_head']
  12. self.num_reg_head=cfg['num_reg_head']
  13. # --------- Network Parameters ----------
  14. ## cls head
  15. cls_feats = []
  16. self.cls_out_dim = max(out_dim, num_classes)
  17. for i in range(cfg['num_cls_head']):
  18. if i == 0:
  19. cls_feats.append(
  20. Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1,
  21. act_type=cfg['head_act'],
  22. norm_type=cfg['head_norm'],
  23. depthwise=cfg['head_depthwise'])
  24. )
  25. else:
  26. cls_feats.append(
  27. Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1,
  28. act_type=cfg['head_act'],
  29. norm_type=cfg['head_norm'],
  30. depthwise=cfg['head_depthwise'])
  31. )
  32. ## reg head
  33. reg_feats = []
  34. self.reg_out_dim = max(out_dim, 64)
  35. for i in range(cfg['num_reg_head']):
  36. if i == 0:
  37. reg_feats.append(
  38. Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1,
  39. act_type=cfg['head_act'],
  40. norm_type=cfg['head_norm'],
  41. depthwise=cfg['head_depthwise'])
  42. )
  43. else:
  44. reg_feats.append(
  45. Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1,
  46. act_type=cfg['head_act'],
  47. norm_type=cfg['head_norm'],
  48. depthwise=cfg['head_depthwise'])
  49. )
  50. self.cls_feats = nn.Sequential(*cls_feats)
  51. self.reg_feats = nn.Sequential(*reg_feats)
  52. def forward(self, x):
  53. """
  54. in_feats: (Tensor) [B, C, H, W]
  55. """
  56. cls_feats = self.cls_feats(x)
  57. reg_feats = self.reg_feats(x)
  58. return cls_feats, reg_feats
  59. # build detection head
  60. def build_head(cfg, in_dim, out_dim, num_classes=80):
  61. head = DecoupledHead(cfg, in_dim, out_dim, num_classes)
  62. return head