__init__.py 1.2 KB

1234567891011121314151617181920212223242526272829303132333435
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import torch
  3. from .fcos.build import build_fcos, build_fcos_rt
  4. from .yolof.build import build_yolof
  5. def build_model(args, cfg, is_val=False):
  6. # ------------ build object detector ------------
  7. ## RT-FCOS
  8. if 'fcos_rt' in args.model:
  9. model, criterion = build_fcos_rt(cfg, is_val)
  10. ## FCOS
  11. elif 'fcos' in args.model:
  12. model, criterion = build_fcos(cfg, is_val)
  13. ## YOLOF
  14. elif 'yolof' in args.model:
  15. model, criterion = build_yolof(cfg, is_val)
  16. else:
  17. raise NotImplementedError("Unknown detector: {}".args.model)
  18. if is_val:
  19. # ------------ Keep training from the given weight ------------
  20. if args.resume is not None and args.resume.lower() != "none":
  21. print('Load model from the checkpoint: ', args.resume)
  22. checkpoint = torch.load(args.resume, map_location='cpu')
  23. # checkpoint state dict
  24. checkpoint_state_dict = checkpoint.pop("model")
  25. model.load_state_dict(checkpoint_state_dict)
  26. return model, criterion
  27. else:
  28. return model