__init__.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov3.build import build_yolov3
  7. from .yolov4.build import build_yolov4
  8. from .yolov5.build import build_yolov5
  9. from .yolox.build import build_yolox
  10. from .yolov6.build import build_yolov6
  11. from .yolov7.build import build_yolov7
  12. from .yolov8.build import build_yolov8
  13. from .yolov9.build import build_gelan
  14. from .yolov10.build import build_yolov10
  15. from .yolo11.build import build_yolo11
  16. from .yolof.build import build_yolof
  17. from .fcos.build import build_fcos
  18. from .rtdetr.build import build_rtdetr
  19. # build object detector
  def build_model(args, cfg, is_val=False):
      """Build the detector selected by ``args.model`` and optionally load weights.

      The detector family is chosen by substring match on ``args.model``
      (e.g. any name containing ``'yolov8'`` is built by ``build_yolov8``).

      Args:
          args: Namespace-like object. ``args.model`` (str) selects the
              detector. When ``is_val`` is truthy the optional attributes
              ``pretrained`` (checkpoint path or None) and ``resume``
              (checkpoint path, a falsy value, or the string ``"None"``) are
              also read.
          cfg: Model config, forwarded unchanged to the selected builder.
          is_val: Forwarded to the selected builder. When truthy, weights are
              loaded from ``args.pretrained`` / ``args.resume`` and the
              criterion is returned alongside the model.
              NOTE(review): despite its name, the truthy path is the one that
              loads training weights and returns the criterion -- confirm the
              intended semantics with callers.

      Returns:
          ``(model, criterion)`` when ``is_val`` is truthy, otherwise ``model``.

      Raises:
          ValueError: if ``args.model`` contains no known detector key.
          RuntimeError: from ``load_state_dict`` if a resume checkpoint does
              not match the model (strict loading).
      """
      # ------------ build object detector ------------
      model, criterion = _build_detector(args.model, cfg, is_val)

      if not is_val:
          return model

      # ------------ Load pretrained weight ------------
      pretrained = getattr(args, "pretrained", None)
      if pretrained is not None:
          _load_pretrained(model, pretrained)

      # ------------ Keep training from the given checkpoint ------------
      resume = getattr(args, "resume", None)
      if resume and resume != "None":
          _resume_from_checkpoint(model, resume)

      return model, criterion


  def _build_detector(model_name, cfg, is_val):
      """Return ``builder(cfg, is_val)`` for the first key contained in *model_name*."""
      # Keys are matched as substrings and the first hit wins, so order matters:
      # 'yolov10' must be tested before 'yolov1' (a substring of it), otherwise
      # every YOLOv10 name would silently build a YOLOv1 model. All other keys
      # keep their original relative order.
      builders = (
          ('yolov10', build_yolov10),  # YOLOv10
          ('yolov1', build_yolov1),    # Modified YOLOv1
          ('yolov2', build_yolov2),    # Modified YOLOv2
          ('yolov3', build_yolov3),    # Modified YOLOv3
          ('yolov4', build_yolov4),    # Modified YOLOv4
          ('yolox', build_yolox),      # Anchor-free YOLOv5
          ('yolov5', build_yolov5),    # Modified YOLOv5
          ('yolov6', build_yolov6),    # YOLOv6
          ('yolov7', build_yolov7),    # YOLOv7
          ('yolov8', build_yolov8),    # YOLOv8
          ('yolov9', build_gelan),     # GElan
          ('yolo11', build_yolo11),    # YOLO11
          ('yolof', build_yolof),      # YOLOF
          ('fcos', build_fcos),        # FCOS
          ('rtdetr', build_rtdetr),    # RT-DETR
      )
      for key, builder in builders:
          if key in model_name:
              return builder(cfg, is_val)
      known = ", ".join(key for key, _ in builders)
      raise ValueError(
          f"Unknown detector {model_name!r}; expected a name containing one of: {known}"
      )


  def _load_pretrained(model, path):
      """Non-strictly load ``checkpoint["model"]`` from *path* into *model*.

      Checkpoint entries whose key is absent from the model, or whose shape
      differs from the model's, are dropped and their names printed.
      Raises KeyError if the checkpoint has no ``"model"`` entry.
      """
      print('Loading COCO pretrained weight ...')
      # torch.load unpickles the file: only load checkpoints from trusted sources.
      checkpoint = torch.load(path, map_location='cpu')
      checkpoint_state_dict = checkpoint.pop("model")
      model_state_dict = model.state_dict()
      # Iterate over a snapshot of the keys because entries are removed below.
      for key in list(checkpoint_state_dict):
          if (key not in model_state_dict
                  or tuple(model_state_dict[key].shape)
                  != tuple(checkpoint_state_dict[key].shape)):
              checkpoint_state_dict.pop(key)
              print(key)
      model.load_state_dict(checkpoint_state_dict, strict=False)


  def _resume_from_checkpoint(model, path):
      """Strictly load ``checkpoint["model"]`` from *path* into *model*, if present."""
      # torch.load unpickles the file: only load checkpoints from trusted sources.
      checkpoint = torch.load(path, map_location='cpu')
      if not isinstance(checkpoint, dict) or "model" not in checkpoint:
          print("No model in the given checkpoint.")
          return
      print('Load model from the checkpoint: ', path)
      # Errors from strict loading (missing/unexpected keys, shape mismatch) are
      # deliberately not caught: continuing with un-restored weights would
      # silently restart training instead of resuming it.
      model.load_state_dict(checkpoint["model"])