__init__.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov3.build import build_yolov3
  7. from .yolov5.build import build_yolov5
  8. from .yolov5_af.build import build_yolov5af
  9. from .yolov7_af.build import build_yolov7af
  10. from .yolov8.build import build_yolov8
  11. from .gelan.build import build_gelan
  12. from .rtdetr.build import build_rtdetr
  # build object detector
def _build_detector(args, cfg, is_val):
    """Dispatch on ``args.model`` to the matching detector builder.

    Matching is by substring, so a more specific name must be tested before
    any name it contains (e.g. ``'yolov5_af'`` before ``'yolov5'``).

    Returns:
        Whatever the selected builder returns; the caller unpacks it as
        ``(model, criterion)``.

    Raises:
        NotImplementedError: if ``args.model`` matches no known detector.
            (Previously this fell through and failed later with an
            ``UnboundLocalError`` on ``model``.)
    """
    ## Modified YOLOv1
    if 'yolov1' in args.model:
        return build_yolov1(cfg, is_val)
    ## Modified YOLOv2
    if 'yolov2' in args.model:
        return build_yolov2(cfg, is_val)
    ## Modified YOLOv3
    if 'yolov3' in args.model:
        return build_yolov3(cfg, is_val)
    ## Anchor-free YOLOv5 (must precede the plain 'yolov5' check)
    if 'yolov5_af' in args.model:
        return build_yolov5af(cfg, is_val)
    ## Modified YOLOv5
    if 'yolov5' in args.model:
        return build_yolov5(cfg, is_val)
    ## Anchor-free YOLOv7
    if 'yolov7_af' in args.model:
        return build_yolov7af(cfg, is_val)
    ## YOLOv8
    if 'yolov8' in args.model:
        return build_yolov8(cfg, is_val)
    ## GElan
    if 'gelan' in args.model:
        return build_gelan(cfg, is_val)
    ## RT-DETR
    if 'rtdetr' in args.model:
        return build_rtdetr(cfg, is_val)
    raise NotImplementedError(f"Unknown detector: {args.model!r}")


def _load_pretrained(model, path):
    """Load the ``"model"`` entry of checkpoint ``path`` into ``model``.

    Checkpoint entries whose key is absent from ``model`` or whose shape
    differs are dropped (and their key printed). The remaining entries are
    loaded with ``strict=False``.
    """
    print('Loading COCO pretrained weight ...')
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(path, map_location='cpu')
    checkpoint_state_dict = checkpoint.pop("model")
    model_state_dict = model.state_dict()
    # Drop incompatible entries in place (iterating over a key snapshot)
    # instead of building a new dict, so any attributes attached to the
    # state-dict object (such as ``_metadata``) are preserved.
    for k in list(checkpoint_state_dict):
        if (k not in model_state_dict
                or tuple(model_state_dict[k].shape) != tuple(checkpoint_state_dict[k].shape)):
            checkpoint_state_dict.pop(k)
            print(k)
    model.load_state_dict(checkpoint_state_dict, strict=False)


def _resume_from_checkpoint(model, path):
    """Strictly load the ``"model"`` entry of checkpoint ``path`` into ``model``.

    If the checkpoint has no ``"model"`` entry, a message is printed and the
    model is left unchanged. Errors raised by ``load_state_dict`` itself
    (missing or unexpected keys, shape mismatches) propagate. They are no
    longer silently reported as a missing model.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(path, map_location='cpu')
    try:
        checkpoint_state_dict = checkpoint.pop("model")
    except KeyError:
        print("No model in the given checkpoint.")
        return
    print('Load model from the checkpoint: ', path)
    model.load_state_dict(checkpoint_state_dict)


def build_model(args, cfg, is_val=False):
    """Build the object detector selected by ``args.model``.

    Args:
        args: Namespace providing ``model`` (a detector name, matched by
            substring). When ``is_val`` is true, it must also provide
            ``pretrained`` (a path or None) and ``resume`` (a path, a falsy
            value, or the string ``"None"``).
        cfg: Model config, forwarded unchanged to the selected builder.
        is_val: Forwarded to the builder. When true, pretrained and resume
            weights are applied and ``(model, criterion)`` is returned.
            When false, only ``model`` is returned.
            NOTE(review): despite the name, the true branch is the one that
            returns the criterion -- confirm the intended meaning.

    Returns:
        ``(model, criterion)`` if ``is_val`` is true, otherwise ``model``.

    Raises:
        NotImplementedError: if ``args.model`` matches no known detector.
    """
    # ------------ build object detector ------------
    model, criterion = _build_detector(args, cfg, is_val)

    if not is_val:
        return model

    # ------------ Load pretrained weight ------------
    if args.pretrained is not None:
        _load_pretrained(model, args.pretrained)

    # ------------ Keep training from the given checkpoint ------------
    if args.resume and args.resume != "None":
        _resume_from_checkpoint(model, args.resume)

    return model, criterion