__init__.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov3.build import build_yolov3
  7. from .yolov4.build import build_yolov4
  8. from .yolov5.build import build_yolov5
  9. from .yolov7.build import build_yolov7
  10. from .yolov8.build import build_yolov8
  11. from .yolox.build import build_yolox
  # build object detector
def build_model(args,
                model_cfg,
                device,
                num_classes=80,
                trainable=False):
    """Build the YOLO-family detector selected by ``args.model``.

    Args:
        args: Namespace-like object. ``args.model`` selects the detector.
            When ``trainable`` is true, ``args.pretrained`` and
            ``args.resume`` are also read; ``None`` skips that step.
        model_cfg: Model configuration, forwarded unchanged to the builder.
        device: Device, forwarded unchanged to the builder.
        num_classes: Number of object classes, forwarded to the builder.
        trainable: When true, optionally load pretrained / resume weights
            and also return the training criterion.

    Returns:
        ``(model, criterion)`` when ``trainable`` is true, otherwise only
        ``model``.

    Raises:
        ValueError: If ``args.model`` is not a supported model name.
    """
    name = args.model
    # Resolve the builder first so an unknown name fails with a clear error
    # instead of an UnboundLocalError on ``model`` further down.
    if name == 'yolov1':
        builder = build_yolov1
    elif name == 'yolov2':
        builder = build_yolov2
    elif name == 'yolov3':
        builder = build_yolov3
    elif name == 'yolov4':
        builder = build_yolov4
    elif name in ('yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x'):
        builder = build_yolov5
    elif name in ('yolov7_t', 'yolov7_l', 'yolov7_x'):
        builder = build_yolov7
    elif name in ('yolov8_n', 'yolov8_s', 'yolov8_m', 'yolov8_l', 'yolov8_x'):
        builder = build_yolov8
    elif name == 'yolox':
        builder = build_yolox
    else:
        raise ValueError(f"Unknown model: {name!r}")

    model, criterion = builder(args, model_cfg, device, num_classes, trainable)

    if not trainable:
        return model

    if args.pretrained is not None:
        _load_pretrained_weights(model, args.pretrained)
    if args.resume is not None:
        _load_resume_weights(model, args.resume)
    return model, criterion


def _load_pretrained_weights(model, path):
    """Load the compatible entries of a checkpoint's ``"model"`` state dict.

    Entries whose key is missing from ``model`` or whose shape differs are
    skipped (their key is printed); the rest load with ``strict=False``.
    """
    print('Loading COCO pretrained weight ...')
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location='cpu')
    checkpoint_state_dict = checkpoint.pop("model")
    model_state_dict = model.state_dict()
    compatible = {}
    for key, value in checkpoint_state_dict.items():
        target = model_state_dict.get(key)
        if target is not None and tuple(target.shape) == tuple(value.shape):
            compatible[key] = value
        else:
            # Report keys absent from the model or with a mismatched shape.
            print(key)
    model.load_state_dict(compatible, strict=False)


def _load_resume_weights(model, path):
    """Strictly load a training checkpoint, renaming legacy layer keys."""
    print('keep training: ', path)
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location='cpu')
    checkpoint_state_dict = checkpoint.pop("model")
    renamed = {_rename_legacy_key(k): v for k, v in checkpoint_state_dict.items()}
    model.load_state_dict(renamed)


def _rename_legacy_key(key):
    """Map old ``reduce_layer_3``/``reduce_layer_4`` names to ``downsample_layer_1``/``_2``.

    The second dotted component is replaced, as in the original mapping
    (assumes the layer name sits at that position). All remaining
    components are kept, whatever their number.
    """
    for old, new in (('reduce_layer_3', 'downsample_layer_1'),
                     ('reduce_layer_4', 'downsample_layer_2')):
        if old in key:
            parts = key.split('.')
            parts[1] = new
            return '.'.join(parts)
    return key