__init__.py 688 B

123456789101112131415161718192021
  1. from .dilated_encoder import DilatedEncoder
  2. from .fpn import BasicFPN
  3. from typing import List
  4. # build neck
  5. def build_neck(cfg, in_dim, out_dim):
  6. print('==============================')
  7. print('Neck: {}'.format(cfg.neck))
  8. # ----------------------- Neck module -----------------------
  9. if cfg.neck == 'dilated_encoder':
  10. model = DilatedEncoder(cfg, in_dim, out_dim)
  11. # ----------------------- FPN Neck -----------------------
  12. elif cfg.neck == 'basic_fpn':
  13. assert isinstance(in_dim, List)
  14. model = BasicFPN(cfg, in_dim, out_dim)
  15. else:
  16. raise NotImplementedError("Unknown Neck: <{}>".format(cfg.fpn))
  17. return model