__init__.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import torch
  3. from .fcos.build import build_fcos, build_fcos_rt
  4. from .yolof.build import build_yolof
  5. from .detr.build import build_detr
  6. def build_model(args, cfg, is_val=False):
  7. # ------------ build object detector ------------
  8. ## RT-FCOS
  9. if 'fcos_rt' in args.model:
  10. model, criterion = build_fcos_rt(cfg, is_val)
  11. ## FCOS
  12. elif 'fcos' in args.model:
  13. model, criterion = build_fcos(cfg, is_val)
  14. ## YOLOF
  15. elif 'yolof' in args.model:
  16. model, criterion = build_yolof(cfg, is_val)
  17. ## DETR
  18. elif 'detr' in args.model:
  19. model, criterion = build_detr(cfg, is_val)
  20. else:
  21. raise NotImplementedError("Unknown detector: {}".args.model)
  22. if is_val:
  23. # ------------ Keep training from the given weight ------------
  24. if args.resume is not None and args.resume.lower() != "none":
  25. print('Load model from the checkpoint: ', args.resume)
  26. checkpoint = torch.load(args.resume, map_location='cpu')
  27. # checkpoint state dict
  28. checkpoint_state_dict = checkpoint.pop("model")
  29. model.load_state_dict(checkpoint_state_dict)
  30. return model, criterion
  31. else:
  32. return model