__init__.py 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov3.build import build_yolov3
  7. from .yolov4.build import build_yolov4
  8. from .yolov5.build import build_yolov5
  9. from .yolov7.build import build_yolov7
  10. from .yolox.build import build_yolox
  11. from .yolox2.build import build_yolox2
  12. from .rtdetr.build import build_rtdetr
  13. # build object detector
  def build_model(args,
                  model_cfg,
                  device,
                  num_classes=80,
                  trainable=False,
                  deploy=False):
      """Build the object detector selected by ``args.model``.

      Args:
          args: Namespace-like object. ``args.model`` selects the detector
              family. When ``trainable`` is true, ``args.pretrained`` and
              ``args.resume`` (checkpoint paths, or ``None``) are also read.
          model_cfg: Model configuration, forwarded unchanged to the builder.
          device: Device, forwarded unchanged to the builder.
          num_classes: Number of object classes (default 80), forwarded to
              the builder.
          trainable: If true, optionally load pretrained / resume weights and
              also return the loss criterion produced by the builder.
          deploy: Forwarded unchanged to the builder.

      Returns:
          ``(model, criterion)`` when ``trainable`` is true, otherwise just
          ``model``.

      Raises:
          ValueError: If ``args.model`` is not a supported model name.
      """
      name = args.model
      if name == 'yolov1':
          builder = build_yolov1
      elif name == 'yolov2':
          builder = build_yolov2
      elif name in ('yolov3', 'yolov3_t'):
          builder = build_yolov3
      elif name in ('yolov4', 'yolov4_t'):
          builder = build_yolov4
      elif name in ('yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x'):
          builder = build_yolov5
      elif name in ('yolov7_t', 'yolov7_l', 'yolov7_x'):
          builder = build_yolov7
      elif name in ('yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x'):
          builder = build_yolox
      elif name in ('yolox2_n', 'yolox2_s', 'yolox2_m', 'yolox2_l', 'yolox2_x'):
          builder = build_yolox2
      elif name in ('rtdetr_n', 'rtdetr_s', 'rtdetr_m', 'rtdetr_l', 'rtdetr_x'):
          builder = build_rtdetr
      else:
          # Previously an unknown name fell through silently and failed later
          # with an UnboundLocalError on `model`; fail early and clearly instead.
          raise ValueError(f'Unknown model: {name!r}')

      model, criterion = builder(
          args, model_cfg, device, num_classes, trainable, deploy)

      if not trainable:
          return model

      # Load pretrained weight
      if args.pretrained is not None:
          print('Loading COCO pretrained weight ...')
          _load_matching_weights(model, args.pretrained)

      # keep training: every entry of the checkpoint must match (strict load).
      if args.resume is not None:
          print('keep training: ', args.resume)
          # NOTE: torch.load unpickles the file; only load trusted checkpoints.
          checkpoint = torch.load(args.resume, map_location='cpu')
          model.load_state_dict(checkpoint.pop("model"))

      return model, criterion


  def _load_matching_weights(model, checkpoint_path):
      """Load the compatible part of a checkpoint's ``"model"`` state dict.

      A checkpoint entry is kept only if its key exists in
      ``model.state_dict()`` with an identical shape. Every other entry is
      dropped and its key printed. The remainder is loaded with
      ``strict=False``, so model parameters absent from the filtered dict keep
      their current values.
      Presumably this lets a checkpoint trained with different head shapes
      (e.g. another ``num_classes``) initialize the model — confirm with
      callers.

      Args:
          model: Module to load weights into (modified in place).
          checkpoint_path: Path to a file saved as a dict with a ``"model"``
              state-dict entry.
      """
      # NOTE: torch.load unpickles the file; only load trusted checkpoints.
      checkpoint = torch.load(checkpoint_path, map_location='cpu')
      checkpoint_state_dict = checkpoint.pop("model")
      model_state_dict = model.state_dict()
      # Iterate over a snapshot of the keys because entries are popped below.
      for k in list(checkpoint_state_dict.keys()):
          compatible = (
              k in model_state_dict
              and tuple(model_state_dict[k].shape)
              == tuple(checkpoint_state_dict[k].shape)
          )
          if not compatible:
              checkpoint_state_dict.pop(k)
              print(k)
      model.load_state_dict(checkpoint_state_dict, strict=False)