__init__.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. from .retinanet_head import RetinaNetHead
  2. from .yolof_head import YOLOFHead
  3. from .fcos_head import FCOSHead
  4. # build head
  5. def build_head(cfg, in_dim, out_dim, num_classes):
  6. print('==============================')
  7. print('Head: {}'.format(cfg['head']))
  8. if cfg['head'] == 'retinanet_head':
  9. model = RetinaNetHead(cfg = cfg,
  10. in_dim = in_dim,
  11. out_dim = out_dim,
  12. num_classes = num_classes,
  13. num_cls_head = cfg['num_cls_head'],
  14. num_reg_head = cfg['num_reg_head'],
  15. act_type = cfg['head_act'],
  16. norm_type = cfg['head_norm']
  17. )
  18. elif cfg['head'] == 'fcos_head':
  19. model = FCOSHead(cfg = cfg,
  20. in_dim = in_dim,
  21. out_dim = out_dim,
  22. num_classes = num_classes,
  23. num_cls_head = cfg['num_cls_head'],
  24. num_reg_head = cfg['num_reg_head'],
  25. act_type = cfg['head_act'],
  26. norm_type = cfg['head_norm']
  27. )
  28. elif cfg['head'] == 'yolof_head':
  29. model = YOLOFHead(cfg = cfg,
  30. in_dim = in_dim,
  31. out_dim = out_dim,
  32. num_classes = num_classes,
  33. num_cls_head = cfg['num_cls_head'],
  34. num_reg_head = cfg['num_reg_head'],
  35. act_type = cfg['head_act'],
  36. norm_type = cfg['head_norm']
  37. )
  38. return model