__init__.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536
  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. import torch
  3. from .fcos.build import build_fcos
  4. from .yolof.build import build_yolof
  5. from .detr.build import build_detr
  6. def build_model(args, cfg, num_classes=80, is_val=False):
  7. # ------------ build object detector ------------
  8. ## FCOS
  9. if 'fcos' in args.model:
  10. model, criterion = build_fcos(cfg, num_classes, is_val)
  11. ## YOLOF
  12. elif 'yolof' in args.model:
  13. model, criterion = build_yolof(cfg, num_classes, is_val)
  14. ## DETR
  15. elif 'detr' in args.model:
  16. model, criterion = build_detr(cfg, num_classes, is_val)
  17. else:
  18. raise NotImplementedError("Unknown detector: {}".args.model)
  19. if is_val:
  20. # ------------ Keep training from the given weight ------------
  21. if args.resume is not None:
  22. print('keep training: ', args.resume)
  23. checkpoint = torch.load(args.resume, map_location='cpu')
  24. # checkpoint state dict
  25. checkpoint_state_dict = checkpoint.pop("model")
  26. model.load_state_dict(checkpoint_state_dict)
  27. return model, criterion
  28. else:
  29. return model