yolov4_head.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import torch
  2. import torch.nn as nn
  3. from .yolov4_basic import Conv
  4. class DecoupledHead(nn.Module):
  5. def __init__(self, cfg, in_dim, out_dim, num_classes=80):
  6. super().__init__()
  7. print('==============================')
  8. print('Head: Decoupled Head')
  9. self.in_dim = in_dim
  10. self.num_cls_head=cfg['num_cls_head']
  11. self.num_reg_head=cfg['num_reg_head']
  12. self.act_type=cfg['head_act']
  13. self.norm_type=cfg['head_norm']
  14. # cls head
  15. cls_feats = []
  16. self.cls_out_dim = max(out_dim, num_classes)
  17. for i in range(cfg['num_cls_head']):
  18. if i == 0:
  19. cls_feats.append(
  20. Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1,
  21. act_type=self.act_type,
  22. norm_type=self.norm_type,
  23. depthwise=cfg['head_depthwise'])
  24. )
  25. else:
  26. cls_feats.append(
  27. Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1,
  28. act_type=self.act_type,
  29. norm_type=self.norm_type,
  30. depthwise=cfg['head_depthwise'])
  31. )
  32. # reg head
  33. reg_feats = []
  34. self.reg_out_dim = max(out_dim, 64)
  35. for i in range(cfg['num_reg_head']):
  36. if i == 0:
  37. reg_feats.append(
  38. Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1,
  39. act_type=self.act_type,
  40. norm_type=self.norm_type,
  41. depthwise=cfg['head_depthwise'])
  42. )
  43. else:
  44. reg_feats.append(
  45. Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1,
  46. act_type=self.act_type,
  47. norm_type=self.norm_type,
  48. depthwise=cfg['head_depthwise'])
  49. )
  50. self.cls_feats = nn.Sequential(*cls_feats)
  51. self.reg_feats = nn.Sequential(*reg_feats)
  52. def forward(self, x):
  53. """
  54. in_feats: (Tensor) [B, C, H, W]
  55. """
  56. cls_feats = self.cls_feats(x)
  57. reg_feats = self.reg_feats(x)
  58. return cls_feats, reg_feats
  59. # build detection head
  60. def build_head(cfg, in_dim, out_dim, num_classes=80):
  61. head = DecoupledHead(cfg, in_dim, out_dim, num_classes)
  62. return head