__init__.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import torch
  3. from .fcos.build import build_fcos, build_fcos_rt
  4. from .fcos_e2e.build import build_fcos_e2e
  5. from .yolof.build import build_yolof
  6. from .detr.build import build_detr
  7. def build_model(args, cfg, is_val=False):
  8. # ------------ build object detector ------------
  9. ## RT-FCOS
  10. if 'fcos_rt' in args.model:
  11. model, criterion = build_fcos_rt(cfg, is_val)
  12. ## E2E-FCOS
  13. elif 'fcos_e2e' in args.model:
  14. model, criterion = build_fcos_e2e(cfg, is_val)
  15. ## FCOS
  16. elif 'fcos' in args.model:
  17. model, criterion = build_fcos(cfg, is_val)
  18. ## YOLOF
  19. elif 'yolof' in args.model:
  20. model, criterion = build_yolof(cfg, is_val)
  21. ## DETR
  22. elif 'detr' in args.model:
  23. model, criterion = build_detr(cfg, is_val)
  24. else:
  25. raise NotImplementedError("Unknown detector: {}".args.model)
  26. if is_val:
  27. # ------------ Keep training from the given weight ------------
  28. if args.resume is not None and args.resume.lower() != "none":
  29. print('Load model from the checkpoint: ', args.resume)
  30. checkpoint = torch.load(args.resume, map_location='cpu')
  31. # checkpoint state dict
  32. checkpoint_state_dict = checkpoint.pop("model")
  33. model.load_state_dict(checkpoint_state_dict)
  34. return model, criterion
  35. else:
  36. return model