__init__.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. # YOLO series
  5. from .yolov1.build import build_yolov1
  6. from .yolov2.build import build_yolov2
  7. from .yolov3.build import build_yolov3
  8. from .yolov4.build import build_yolov4
  9. from .yolov5.build import build_yolov5
  10. from .yolov7.build import build_yolov7
  11. from .yolovx.build import build_yolovx
  12. # My custom YOLO
  13. from .yolox.build import build_yolox
  14. # Real-time DETR
  15. from .rtdetr.build import build_rtdetr
  16. # build object detector
  17. def build_model(args,
  18. model_cfg,
  19. device,
  20. num_classes=80,
  21. trainable=False,
  22. deploy=False):
  23. # YOLOv1
  24. if args.model == 'yolov1':
  25. model, criterion = build_yolov1(
  26. args, model_cfg, device, num_classes, trainable, deploy)
  27. # YOLOv2
  28. elif args.model == 'yolov2':
  29. model, criterion = build_yolov2(
  30. args, model_cfg, device, num_classes, trainable, deploy)
  31. # YOLOv3
  32. elif args.model in ['yolov3', 'yolov3_tiny']:
  33. model, criterion = build_yolov3(
  34. args, model_cfg, device, num_classes, trainable, deploy)
  35. # YOLOv4
  36. elif args.model in ['yolov4', 'yolov4_tiny']:
  37. model, criterion = build_yolov4(
  38. args, model_cfg, device, num_classes, trainable, deploy)
  39. # YOLOv5
  40. elif args.model in ['yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x']:
  41. model, criterion = build_yolov5(
  42. args, model_cfg, device, num_classes, trainable, deploy)
  43. # YOLOv7
  44. elif args.model in ['yolov7_tiny', 'yolov7', 'yolov7_x']:
  45. model, criterion = build_yolov7(
  46. args, model_cfg, device, num_classes, trainable, deploy)
  47. # YOLOvx
  48. elif args.model in ['yolovx_n', 'yolovx_t', 'yolovx_s', 'yolovx_m', 'yolovx_l', 'yolovx_x']:
  49. model, criterion = build_yolovx(
  50. args, model_cfg, device, num_classes, trainable, deploy)
  51. # YOLOX
  52. elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
  53. model, criterion = build_yolox(
  54. args, model_cfg, device, num_classes, trainable, deploy)
  55. # RT-DETR
  56. elif args.model in ['rtdetr_n', 'rtdetr_s', 'rtdetr_m', 'rtdetr_l', 'rtdetr_x']:
  57. model, criterion = build_rtdetr(
  58. args, model_cfg, device, num_classes, trainable, deploy)
  59. if trainable:
  60. # Load pretrained weight
  61. if args.pretrained is not None:
  62. print('Loading COCO pretrained weight ...')
  63. checkpoint = torch.load(args.pretrained, map_location='cpu')
  64. # checkpoint state dict
  65. checkpoint_state_dict = checkpoint.pop("model")
  66. # model state dict
  67. model_state_dict = model.state_dict()
  68. # check
  69. for k in list(checkpoint_state_dict.keys()):
  70. if k in model_state_dict:
  71. shape_model = tuple(model_state_dict[k].shape)
  72. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  73. if shape_model != shape_checkpoint:
  74. checkpoint_state_dict.pop(k)
  75. print(k)
  76. else:
  77. checkpoint_state_dict.pop(k)
  78. print(k)
  79. model.load_state_dict(checkpoint_state_dict, strict=False)
  80. # keep training
  81. if args.resume is not None:
  82. print('keep training: ', args.resume)
  83. checkpoint = torch.load(args.resume, map_location='cpu')
  84. # checkpoint state dict
  85. checkpoint_state_dict = checkpoint.pop("model")
  86. model.load_state_dict(checkpoint_state_dict)
  87. del checkpoint, checkpoint_state_dict
  88. return model, criterion
  89. else:
  90. return model