| 123456789101112131415161718192021 |
- from .dilated_encoder import DilatedEncoder
- from .fpn import BasicFPN
- from typing import List
- # build neck
def build_neck(cfg, in_dim, out_dim):
    """Build the neck module selected by ``cfg.neck``.

    Args:
        cfg: Config object. ``cfg.neck`` names the neck to build, and the
            whole object is forwarded to the neck's constructor.
        in_dim: Input dimension(s) forwarded to the neck. Must be a ``list``
            when ``cfg.neck == 'basic_fpn'``.
        out_dim: Output dimension forwarded to the neck.

    Returns:
        The constructed neck: a ``DilatedEncoder`` or a ``BasicFPN``.

    Raises:
        AssertionError: If ``cfg.neck == 'basic_fpn'`` and ``in_dim`` is not
            a list.
        NotImplementedError: If ``cfg.neck`` is not a supported neck name.
    """
    # Read the selector once; both the dispatch and the error message use it.
    neck_name = cfg.neck
    print('==============================')
    print('Neck: {}'.format(neck_name))

    # ----------------------- Neck module -----------------------
    if neck_name == 'dilated_encoder':
        model = DilatedEncoder(cfg, in_dim, out_dim)

    # ----------------------- FPN Neck -----------------------
    elif neck_name == 'basic_fpn':
        # BasicFPN takes a list of input dims (presumably one per feature
        # level -- confirm against BasicFPN).
        assert isinstance(in_dim, list), \
            "basic_fpn expects in_dim to be a list, got {}".format(type(in_dim).__name__)
        model = BasicFPN(cfg, in_dim, out_dim)

    else:
        # Report the value that was actually dispatched on. The original
        # read `cfg.fpn`, which is the wrong field and may not exist.
        raise NotImplementedError("Unknown Neck: <{}>".format(neck_name))

    return model
|