__init__.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import torch
  3. from .fcos.build import build_fcos, build_fcos_rt
  4. from .fcos_e2e.build import build_fcos_e2e
  5. from .fcos_pss.build import build_fcos_pss
  6. from .yolof.build import build_yolof
  7. from .detr.build import build_detr
  8. def build_model(args, cfg, is_val=False):
  9. # ------------ build object detector ------------
  10. ## RT-FCOS
  11. if 'fcos_rt' in args.model:
  12. model, criterion = build_fcos_rt(cfg, is_val)
  13. ## E2E-FCOS
  14. elif 'fcos_e2e' in args.model:
  15. model, criterion = build_fcos_e2e(cfg, is_val)
  16. ## PSS-FCOS
  17. elif 'fcos_pss' in args.model:
  18. model, criterion = build_fcos_pss(cfg, is_val)
  19. ## FCOS
  20. elif 'fcos' in args.model:
  21. model, criterion = build_fcos(cfg, is_val)
  22. ## YOLOF
  23. elif 'yolof' in args.model:
  24. model, criterion = build_yolof(cfg, is_val)
  25. ## DETR
  26. elif 'detr' in args.model:
  27. model, criterion = build_detr(cfg, is_val)
  28. else:
  29. raise NotImplementedError("Unknown detector: {}".args.model)
  30. if is_val:
  31. # ------------ Keep training from the given weight ------------
  32. if args.resume is not None and args.resume.lower() != "none":
  33. print('Load model from the checkpoint: ', args.resume)
  34. checkpoint = torch.load(args.resume, map_location='cpu')
  35. # checkpoint state dict
  36. checkpoint_state_dict = checkpoint.pop("model")
  37. model.load_state_dict(checkpoint_state_dict)
  38. return model, criterion
  39. else:
  40. return model