__init__.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from .dilated_encoder import DilatedEncoder
  2. from .fpn import BasicFPN
  3. from .spp import SPPF
  4. # build neck
  5. def build_neck(cfg, in_dim, out_dim):
  6. print('==============================')
  7. print('Neck: {}'.format(cfg.neck))
  8. # ----------------------- Neck module -----------------------
  9. if cfg.neck == 'dilated_encoder':
  10. model = DilatedEncoder(in_dim = in_dim,
  11. out_dim = out_dim,
  12. expand_ratio = cfg.neck_expand_ratio,
  13. dilations = cfg.neck_dilations,
  14. act_type = cfg.neck_act,
  15. norm_type = cfg.neck_norm,
  16. )
  17. elif cfg.neck == 'spp_block':
  18. model = SPPF(in_dim = in_dim,
  19. out_dim = out_dim,
  20. expand_ratio = cfg.neck_expand_ratio,
  21. pooling_size = cfg.spp_pooling_size,
  22. act_type = cfg.neck_act,
  23. norm_type = cfg.neck_norm,
  24. )
  25. # ----------------------- FPN Neck -----------------------
  26. elif cfg.neck == 'basic_fpn':
  27. model = BasicFPN(in_dims = in_dim,
  28. out_dim = out_dim,
  29. p6_feat = cfg.fpn_p6_feat,
  30. p7_feat = cfg.fpn_p7_feat,
  31. from_c5 = cfg.fpn_p6_from_c5,
  32. )
  33. else:
  34. raise NotImplementedError("Unknown Neck: <{}>".format(cfg.fpn))
  35. return model