__init__.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov3.build import build_yolov3
  7. from .yolov4.build import build_yolov4
  8. from .yolov5.build import build_yolov5
  9. from .yolov7.build import build_yolov7
  10. from .yolox.build import build_yolox
  11. from .yolox2.build import build_yolox2
  12. from .rtdetr.build import build_rtdetr
  13. from .e2eyolo.build import build_e2eyolo
  14. # build object detector
  15. def build_model(args,
  16. model_cfg,
  17. device,
  18. num_classes=80,
  19. trainable=False,
  20. deploy=False):
  21. # YOLOv1
  22. if args.model == 'yolov1':
  23. model, criterion = build_yolov1(
  24. args, model_cfg, device, num_classes, trainable, deploy)
  25. # YOLOv2
  26. elif args.model == 'yolov2':
  27. model, criterion = build_yolov2(
  28. args, model_cfg, device, num_classes, trainable, deploy)
  29. # YOLOv3
  30. elif args.model in ['yolov3', 'yolov3_t']:
  31. model, criterion = build_yolov3(
  32. args, model_cfg, device, num_classes, trainable, deploy)
  33. # YOLOv4
  34. elif args.model in ['yolov4', 'yolov4_t']:
  35. model, criterion = build_yolov4(
  36. args, model_cfg, device, num_classes, trainable, deploy)
  37. # YOLOv5
  38. elif args.model in ['yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x']:
  39. model, criterion = build_yolov5(
  40. args, model_cfg, device, num_classes, trainable, deploy)
  41. # YOLOv7
  42. elif args.model in ['yolov7_t', 'yolov7_l', 'yolov7_x']:
  43. model, criterion = build_yolov7(
  44. args, model_cfg, device, num_classes, trainable, deploy)
  45. # YOLOX
  46. elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
  47. model, criterion = build_yolox(
  48. args, model_cfg, device, num_classes, trainable, deploy)
  49. # YOLOX2
  50. elif args.model in ['yolox2_n', 'yolox2_s', 'yolox2_m', 'yolox2_l', 'yolox2_x']:
  51. model, criterion = build_yolox2(
  52. args, model_cfg, device, num_classes, trainable, deploy)
  53. # RT-DETR
  54. elif args.model in ['rtdetr_n', 'rtdetr_s', 'rtdetr_m', 'rtdetr_l', 'rtdetr_x']:
  55. model, criterion = build_rtdetr(
  56. args, model_cfg, device, num_classes, trainable, deploy)
  57. # E2E-YOLO
  58. elif args.model in ['e2eyolo_n', 'e2eyolo_s', 'e2eyolo_m', 'e2eyolo_l', 'e2eyolo_x']:
  59. model, criterion = build_e2eyolo(
  60. args, model_cfg, device, num_classes, trainable, deploy)
  61. if trainable:
  62. # Load pretrained weight
  63. if args.pretrained is not None:
  64. print('Loading COCO pretrained weight ...')
  65. checkpoint = torch.load(args.pretrained, map_location='cpu')
  66. # checkpoint state dict
  67. checkpoint_state_dict = checkpoint.pop("model")
  68. # model state dict
  69. model_state_dict = model.state_dict()
  70. # check
  71. for k in list(checkpoint_state_dict.keys()):
  72. if k in model_state_dict:
  73. shape_model = tuple(model_state_dict[k].shape)
  74. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  75. if shape_model != shape_checkpoint:
  76. checkpoint_state_dict.pop(k)
  77. print(k)
  78. else:
  79. checkpoint_state_dict.pop(k)
  80. print(k)
  81. model.load_state_dict(checkpoint_state_dict, strict=False)
  82. # keep training
  83. if args.resume is not None:
  84. print('keep training: ', args.resume)
  85. checkpoint = torch.load(args.resume, map_location='cpu')
  86. # checkpoint state dict
  87. checkpoint_state_dict = checkpoint.pop("model")
  88. model.load_state_dict(checkpoint_state_dict)
  89. return model, criterion
  90. else:
  91. return model