__init__.py 1.0 KB

1234567891011121314151617181920212223242526272829303132
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import torch
  3. from .fcos.build import build_fcos
  4. from .yolof.build import build_yolof
  5. def build_model(args, cfg, is_val=False):
  6. # ------------ build object detector ------------
  7. ## FCOS
  8. if 'fcos' in args.model:
  9. model, criterion = build_fcos(cfg, is_val)
  10. ## YOLOF
  11. elif 'yolof' in args.model:
  12. model, criterion = build_yolof(cfg, is_val)
  13. else:
  14. raise NotImplementedError("Unknown detector: {}".args.model)
  15. if is_val:
  16. # ------------ Keep training from the given weight ------------
  17. if args.resume is not None:
  18. print('Load model from the checkpoint: ', args.resume)
  19. checkpoint = torch.load(args.resume, map_location='cpu')
  20. # checkpoint state dict
  21. checkpoint_state_dict = checkpoint.pop("model")
  22. model.load_state_dict(checkpoint_state_dict)
  23. return model, criterion
  24. else:
  25. return model