__init__.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. # YOLO series
  5. from .yolov1.build import build_yolov1
  6. from .yolov2.build import build_yolov2
  7. from .yolov3.build import build_yolov3
  8. from .yolov4.build import build_yolov4
  9. from .yolov5.build import build_yolov5
  10. from .yolov7.build import build_yolov7
  11. from .yolov8.build import build_yolov8
  12. from .yolox.build import build_yolox
  13. # My RTCDet series
  14. from .rtcdet.build import build_rtcdet
  15. from .rtdetr.build import build_rtdetr
  16. # build object detector
  17. def build_model(args,
  18. model_cfg,
  19. device,
  20. num_classes=80,
  21. trainable=False,
  22. deploy=False):
  23. # YOLOv1
  24. if args.model == 'yolov1':
  25. model, criterion = build_yolov1(
  26. args, model_cfg, device, num_classes, trainable, deploy)
  27. # YOLOv2
  28. elif args.model == 'yolov2':
  29. model, criterion = build_yolov2(
  30. args, model_cfg, device, num_classes, trainable, deploy)
  31. # YOLOv3
  32. elif args.model in ['yolov3', 'yolov3_tiny']:
  33. model, criterion = build_yolov3(
  34. args, model_cfg, device, num_classes, trainable, deploy)
  35. # YOLOv4
  36. elif args.model in ['yolov4', 'yolov4_tiny']:
  37. model, criterion = build_yolov4(
  38. args, model_cfg, device, num_classes, trainable, deploy)
  39. # YOLOv5
  40. elif args.model in ['yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x']:
  41. model, criterion = build_yolov5(
  42. args, model_cfg, device, num_classes, trainable, deploy)
  43. # YOLOv5-AdamW
  44. elif args.model in ['yolov5_n_adamw', 'yolov5_s_adamw', 'yolov5_m_adamw', 'yolov5_l_adamw', 'yolov5_x_adamw']:
  45. model, criterion = build_yolov5(
  46. args, model_cfg, device, num_classes, trainable, deploy)
  47. # YOLOv7
  48. elif args.model in ['yolov7_tiny', 'yolov7', 'yolov7_x']:
  49. model, criterion = build_yolov7(
  50. args, model_cfg, device, num_classes, trainable, deploy)
  51. # YOLOv8
  52. elif args.model in ['yolov8_n', 'yolov8_s', 'yolov8_m', 'yolov8_l', 'yolov8_x']:
  53. model, criterion = build_yolov8(
  54. args, model_cfg, device, num_classes, trainable, deploy)
  55. # YOLOX
  56. elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
  57. model, criterion = build_yolox(
  58. args, model_cfg, device, num_classes, trainable, deploy)
  59. # YOLOX-AdamW
  60. elif args.model in ['yolox_n_adamw', 'yolox_s_adamw', 'yolox_m_adamw', 'yolox_l_adamw', 'yolox_x_adamw']:
  61. model, criterion = build_yolox(
  62. args, model_cfg, device, num_classes, trainable, deploy)
  63. # RTCDet
  64. elif args.model in ['rtcdet_n', 'rtcdet_t', 'rtcdet_s', 'rtcdet_m', 'rtcdet_l', 'rtcdet_x']:
  65. model, criterion = build_rtcdet(
  66. args, model_cfg, device, num_classes, trainable, deploy)
  67. # RT-DETR
  68. elif args.model in ['rtdetr_r18', 'rtdetr_r34', 'rtdetr_r50', 'rtdetr_r101']:
  69. model, criterion = build_rtdetr(
  70. args, model_cfg, num_classes, trainable, deploy)
  71. if trainable:
  72. # Load pretrained weight
  73. if args.pretrained is not None:
  74. print('Loading COCO pretrained weight ...')
  75. checkpoint = torch.load(args.pretrained, map_location='cpu')
  76. # checkpoint state dict
  77. checkpoint_state_dict = checkpoint.pop("model")
  78. # model state dict
  79. model_state_dict = model.state_dict()
  80. # check
  81. for k in list(checkpoint_state_dict.keys()):
  82. if k in model_state_dict:
  83. shape_model = tuple(model_state_dict[k].shape)
  84. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  85. if shape_model != shape_checkpoint:
  86. checkpoint_state_dict.pop(k)
  87. print(k)
  88. else:
  89. checkpoint_state_dict.pop(k)
  90. print(k)
  91. model.load_state_dict(checkpoint_state_dict, strict=False)
  92. # keep training
  93. if args.resume and args.resume != "None":
  94. print('keep training: ', args.resume)
  95. checkpoint = torch.load(args.resume, map_location='cpu')
  96. # checkpoint state dict
  97. checkpoint_state_dict = checkpoint.pop("model")
  98. model.load_state_dict(checkpoint_state_dict)
  99. del checkpoint, checkpoint_state_dict
  100. return model, criterion
  101. else:
  102. return model