__init__.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import torch
  3. from .retinanet.build import build_retinanet
  4. from .fcos.build import build_fcos
  5. from .yolof.build import build_yolof
  6. from .detr.build import build_detr
  7. def build_model(args, cfg, num_classes=80, is_val=False):
  8. # ------------ build object detector ------------
  9. ## RetinaNet
  10. if 'retinanet' in args.model:
  11. model, criterion = build_retinanet(cfg, num_classes, is_val)
  12. ## FCOS
  13. elif 'fcos' in args.model:
  14. model, criterion = build_fcos(cfg, num_classes, is_val)
  15. ## YOLOF
  16. elif 'yolof' in args.model:
  17. model, criterion = build_yolof(cfg, num_classes, is_val)
  18. ## DETR
  19. elif 'detr' in args.model:
  20. model, criterion = build_detr(cfg, num_classes, is_val)
  21. else:
  22. raise NotImplementedError("Unknown detector: {}".args.model)
  23. if is_val:
  24. # ------------ Keep training from the given weight ------------
  25. if args.resume is not None:
  26. print('keep training: ', args.resume)
  27. checkpoint = torch.load(args.resume, map_location='cpu')
  28. # checkpoint state dict
  29. checkpoint_state_dict = checkpoint.pop("model")
  30. model.load_state_dict(checkpoint_state_dict)
  31. return model, criterion
  32. else:
  33. return model