__init__.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov8.build import build_yolov8
  7. from .rtdetr.build import build_rtdetr
  8. # build object detector
  9. def build_model(args, cfg, is_val=False):
  10. # ------------ build object detector ------------
  11. ## Modified YOLOv1
  12. if 'yolov1' in args.model:
  13. model, criterion = build_yolov1(cfg, is_val)
  14. ## Modified YOLOv2
  15. elif 'yolov2' in args.model:
  16. model, criterion = build_yolov2(cfg, is_val)
  17. ## YOLOv8
  18. elif 'yolov8' in args.model:
  19. model, criterion = build_yolov8(cfg, is_val)
  20. ## RT-DETR
  21. elif 'rtdetr' in args.model:
  22. model, criterion = build_rtdetr(cfg, is_val)
  23. if is_val:
  24. # ------------ Load pretrained weight ------------
  25. if args.pretrained is not None:
  26. print('Loading COCO pretrained weight ...')
  27. checkpoint = torch.load(args.pretrained, map_location='cpu')
  28. # checkpoint state dict
  29. checkpoint_state_dict = checkpoint.pop("model")
  30. # model state dict
  31. model_state_dict = model.state_dict()
  32. # check
  33. for k in list(checkpoint_state_dict.keys()):
  34. if k in model_state_dict:
  35. shape_model = tuple(model_state_dict[k].shape)
  36. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  37. if shape_model != shape_checkpoint:
  38. checkpoint_state_dict.pop(k)
  39. print(k)
  40. else:
  41. checkpoint_state_dict.pop(k)
  42. print(k)
  43. model.load_state_dict(checkpoint_state_dict, strict=False)
  44. # ------------ Keep training from the given checkpoint ------------
  45. if args.resume and args.resume != "None":
  46. checkpoint = torch.load(args.resume, map_location='cpu')
  47. # checkpoint state dict
  48. try:
  49. checkpoint_state_dict = checkpoint.pop("model")
  50. print('Load model from the checkpoint: ', args.resume)
  51. model.load_state_dict(checkpoint_state_dict)
  52. del checkpoint, checkpoint_state_dict
  53. except:
  54. print("No model in the given checkpoint.")
  55. return model, criterion
  56. else:
  57. return model