__init__.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. # YOLO series
  5. from .yolov1.build import build_yolov1
  6. from .yolov2.build import build_yolov2
  7. from .yolov3.build import build_yolov3
  8. from .yolov4.build import build_yolov4
  9. from .yolov5.build import build_yolov5
  10. from .yolov7.build import build_yolov7
  11. from .yolov8.build import build_yolov8
  12. from .yolox.build import build_yolox
  13. # My RTCDet
  14. from .rtcdet.build import build_rtcdet
  15. # My RTRDet
  16. from .rtrdet.build import build_rtrdet
  17. # build object detector
  18. def build_model(args,
  19. model_cfg,
  20. device,
  21. num_classes=80,
  22. trainable=False,
  23. deploy=False):
  24. # YOLOv1
  25. if args.model == 'yolov1':
  26. model, criterion = build_yolov1(
  27. args, model_cfg, device, num_classes, trainable, deploy)
  28. # YOLOv2
  29. elif args.model == 'yolov2':
  30. model, criterion = build_yolov2(
  31. args, model_cfg, device, num_classes, trainable, deploy)
  32. # YOLOv3
  33. elif args.model in ['yolov3', 'yolov3_tiny']:
  34. model, criterion = build_yolov3(
  35. args, model_cfg, device, num_classes, trainable, deploy)
  36. # YOLOv4
  37. elif args.model in ['yolov4', 'yolov4_tiny']:
  38. model, criterion = build_yolov4(
  39. args, model_cfg, device, num_classes, trainable, deploy)
  40. # YOLOv5
  41. elif args.model in ['yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x']:
  42. model, criterion = build_yolov5(
  43. args, model_cfg, device, num_classes, trainable, deploy)
  44. # YOLOv7
  45. elif args.model in ['yolov7_tiny', 'yolov7', 'yolov7_x']:
  46. model, criterion = build_yolov7(
  47. args, model_cfg, device, num_classes, trainable, deploy)
  48. # YOLOv8
  49. elif args.model in ['yolov8_n', 'yolov8_s', 'yolov8_m', 'yolov8_l', 'yolov8_x']:
  50. model, criterion = build_yolov8(
  51. args, model_cfg, device, num_classes, trainable, deploy)
  52. # YOLOX
  53. elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
  54. model, criterion = build_yolox(
  55. args, model_cfg, device, num_classes, trainable, deploy)
  56. # RTCDet
  57. elif args.model in ['rtcdet_p', 'rtcdet_n', 'rtcdet_t', 'rtcdet_s', 'rtcdet_m', 'rtcdet_l', 'rtcdet_x']:
  58. model, criterion = build_rtcdet(
  59. args, model_cfg, device, num_classes, trainable, deploy)
  60. # RTRDet
  61. elif args.model in ['rtrdet_p', 'rtrdet_n', 'rtrdet_t', 'rtrdet_s', 'rtrdet_m', 'rtrdet_l', 'rtrdet_x']:
  62. model, criterion = build_rtrdet(
  63. args, model_cfg, device, num_classes, trainable, deploy)
  64. if trainable:
  65. # Load pretrained weight
  66. if args.pretrained is not None:
  67. print('Loading COCO pretrained weight ...')
  68. checkpoint = torch.load(args.pretrained, map_location='cpu')
  69. # checkpoint state dict
  70. checkpoint_state_dict = checkpoint.pop("model")
  71. # model state dict
  72. model_state_dict = model.state_dict()
  73. # check
  74. for k in list(checkpoint_state_dict.keys()):
  75. if k in model_state_dict:
  76. shape_model = tuple(model_state_dict[k].shape)
  77. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  78. if shape_model != shape_checkpoint:
  79. checkpoint_state_dict.pop(k)
  80. print(k)
  81. else:
  82. checkpoint_state_dict.pop(k)
  83. print(k)
  84. model.load_state_dict(checkpoint_state_dict, strict=False)
  85. # keep training
  86. if args.resume is not None:
  87. print('keep training: ', args.resume)
  88. checkpoint = torch.load(args.resume, map_location='cpu')
  89. # checkpoint state dict
  90. checkpoint_state_dict = checkpoint.pop("model")
  91. model.load_state_dict(checkpoint_state_dict)
  92. del checkpoint, checkpoint_state_dict
  93. return model, criterion
  94. else:
  95. return model