__init__.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov3.build import build_yolov3
  7. from .yolov4.build import build_yolov4
  8. from .yolov5.build import build_yolov5
  9. from .yolov7.build import build_yolov7
  10. from .yolox.build import build_yolox
  11. from .artdet.build import build_artdet
  12. # build object detector
  13. def build_model(args,
  14. model_cfg,
  15. device,
  16. num_classes=80,
  17. trainable=False,
  18. deploy=False):
  19. # YOLOv1
  20. if args.model == 'yolov1':
  21. model, criterion = build_yolov1(
  22. args, model_cfg, device, num_classes, trainable, deploy)
  23. # YOLOv2
  24. elif args.model == 'yolov2':
  25. model, criterion = build_yolov2(
  26. args, model_cfg, device, num_classes, trainable, deploy)
  27. # YOLOv3
  28. elif args.model in ['yolov3', 'yolov3_t']:
  29. model, criterion = build_yolov3(
  30. args, model_cfg, device, num_classes, trainable, deploy)
  31. # YOLOv4
  32. elif args.model in ['yolov4', 'yolov4_t']:
  33. model, criterion = build_yolov4(
  34. args, model_cfg, device, num_classes, trainable, deploy)
  35. # YOLOv5
  36. elif args.model in ['yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x']:
  37. model, criterion = build_yolov5(
  38. args, model_cfg, device, num_classes, trainable, deploy)
  39. # YOLOv7
  40. elif args.model in ['yolov7_t', 'yolov7_l', 'yolov7_x']:
  41. model, criterion = build_yolov7(
  42. args, model_cfg, device, num_classes, trainable, deploy)
  43. # YOLOX
  44. elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
  45. model, criterion = build_yolox(
  46. args, model_cfg, device, num_classes, trainable, deploy)
  47. # ARTDet
  48. elif args.model in ['artdet_n', 'artdet_s', 'artdet_m', 'artdet_l', 'artdet_x']:
  49. model, criterion = build_artdet(
  50. args, model_cfg, device, num_classes, trainable, deploy)
  51. if trainable:
  52. # Load pretrained weight
  53. if args.pretrained is not None:
  54. print('Loading COCO pretrained weight ...')
  55. checkpoint = torch.load(args.pretrained, map_location='cpu')
  56. # checkpoint state dict
  57. checkpoint_state_dict = checkpoint.pop("model")
  58. # model state dict
  59. model_state_dict = model.state_dict()
  60. # check
  61. for k in list(checkpoint_state_dict.keys()):
  62. if k in model_state_dict:
  63. shape_model = tuple(model_state_dict[k].shape)
  64. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  65. if shape_model != shape_checkpoint:
  66. checkpoint_state_dict.pop(k)
  67. print(k)
  68. else:
  69. checkpoint_state_dict.pop(k)
  70. print(k)
  71. model.load_state_dict(checkpoint_state_dict, strict=False)
  72. # keep training
  73. if args.resume is not None:
  74. print('keep training: ', args.resume)
  75. checkpoint = torch.load(args.resume, map_location='cpu')
  76. # checkpoint state dict
  77. checkpoint_state_dict = checkpoint.pop("model")
  78. model.load_state_dict(checkpoint_state_dict)
  79. return model, criterion
  80. else:
  81. return model