__init__.py 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. # YOLO series
  5. from .yolov1.build import build_yolov1
  6. from .yolov2.build import build_yolov2
  7. from .yolov3.build import build_yolov3
  8. from .yolov4.build import build_yolov4
  9. from .yolov5.build import build_yolov5
  10. from .yolov7.build import build_yolov7
  11. from .rtmdet_v1.build import build_rtmdet_v1
  12. from .rtmdet_v2.build import build_rtmdet_v2
  13. # My custom YOLO
  14. from .yolox.build import build_yolox
  15. # build object detector
  16. def build_model(args,
  17. model_cfg,
  18. device,
  19. num_classes=80,
  20. trainable=False,
  21. deploy=False):
  22. # YOLOv1
  23. if args.model == 'yolov1':
  24. model, criterion = build_yolov1(
  25. args, model_cfg, device, num_classes, trainable, deploy)
  26. # YOLOv2
  27. elif args.model == 'yolov2':
  28. model, criterion = build_yolov2(
  29. args, model_cfg, device, num_classes, trainable, deploy)
  30. # YOLOv3
  31. elif args.model in ['yolov3', 'yolov3_tiny']:
  32. model, criterion = build_yolov3(
  33. args, model_cfg, device, num_classes, trainable, deploy)
  34. # YOLOv4
  35. elif args.model in ['yolov4', 'yolov4_tiny']:
  36. model, criterion = build_yolov4(
  37. args, model_cfg, device, num_classes, trainable, deploy)
  38. # YOLOv5
  39. elif args.model in ['yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x']:
  40. model, criterion = build_yolov5(
  41. args, model_cfg, device, num_classes, trainable, deploy)
  42. # YOLOv7
  43. elif args.model in ['yolov7_tiny', 'yolov7', 'yolov7_x']:
  44. model, criterion = build_yolov7(
  45. args, model_cfg, device, num_classes, trainable, deploy)
  46. # YOLOX
  47. elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
  48. model, criterion = build_yolox(
  49. args, model_cfg, device, num_classes, trainable, deploy)
  50. # My RTMDet-v1
  51. elif args.model in ['rtmdet_v1_n', 'rtmdet_v1_t', 'rtmdet_v1_s', 'rtmdet_v1_m', 'rtmdet_v1_l', 'rtmdet_v1_x']:
  52. model, criterion = build_rtmdet_v1(
  53. args, model_cfg, device, num_classes, trainable, deploy)
  54. # My RTMDet-v2
  55. elif args.model in ['rtmdet_v2_n', 'rtmdet_v2_t', 'rtmdet_v2_s', 'rtmdet_v2_m', 'rtmdet_v2_l', 'rtmdet_v2_x']:
  56. model, criterion = build_rtmdet_v2(
  57. args, model_cfg, device, num_classes, trainable, deploy)
  58. if trainable:
  59. # Load pretrained weight
  60. if args.pretrained is not None:
  61. print('Loading COCO pretrained weight ...')
  62. checkpoint = torch.load(args.pretrained, map_location='cpu')
  63. # checkpoint state dict
  64. checkpoint_state_dict = checkpoint.pop("model")
  65. # model state dict
  66. model_state_dict = model.state_dict()
  67. # check
  68. for k in list(checkpoint_state_dict.keys()):
  69. if k in model_state_dict:
  70. shape_model = tuple(model_state_dict[k].shape)
  71. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  72. if shape_model != shape_checkpoint:
  73. checkpoint_state_dict.pop(k)
  74. print(k)
  75. else:
  76. checkpoint_state_dict.pop(k)
  77. print(k)
  78. model.load_state_dict(checkpoint_state_dict, strict=False)
  79. # keep training
  80. if args.resume is not None:
  81. print('keep training: ', args.resume)
  82. checkpoint = torch.load(args.resume, map_location='cpu')
  83. # checkpoint state dict
  84. checkpoint_state_dict = checkpoint.pop("model")
  85. model.load_state_dict(checkpoint_state_dict)
  86. del checkpoint, checkpoint_state_dict
  87. return model, criterion
  88. else:
  89. return model