__init__.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov3.build import build_yolov3
  7. from .yolov4.build import build_yolov4
  8. from .yolov5.build import build_yolov5
  9. from .yolov7.build import build_yolov7
  10. from .yolox.build import build_yolox
  11. # build object detector
  12. def build_model(args,
  13. model_cfg,
  14. device,
  15. num_classes=80,
  16. trainable=False):
  17. # YOLOv1
  18. if args.model == 'yolov1':
  19. model, criterion = build_yolov1(
  20. args, model_cfg, device, num_classes, trainable)
  21. # YOLOv2
  22. elif args.model == 'yolov2':
  23. model, criterion = build_yolov2(
  24. args, model_cfg, device, num_classes, trainable)
  25. # YOLOv3
  26. elif args.model == 'yolov3':
  27. model, criterion = build_yolov3(
  28. args, model_cfg, device, num_classes, trainable)
  29. # YOLOv4
  30. elif args.model == 'yolov4':
  31. model, criterion = build_yolov4(
  32. args, model_cfg, device, num_classes, trainable)
  33. # YOLOv5
  34. elif args.model == 'yolov5':
  35. model, criterion = build_yolov5(
  36. args, model_cfg, device, num_classes, trainable)
  37. # YOLOv5
  38. elif args.model == 'yolov7':
  39. model, criterion = build_yolov7(
  40. args, model_cfg, device, num_classes, trainable)
  41. # YOLOX
  42. elif args.model == 'yolox':
  43. model, criterion = build_yolox(
  44. args, model_cfg, device, num_classes, trainable)
  45. if trainable:
  46. # Load pretrained weight
  47. if args.pretrained is not None:
  48. print('Loading COCO pretrained weight ...')
  49. checkpoint = torch.load(args.pretrained, map_location='cpu')
  50. # checkpoint state dict
  51. checkpoint_state_dict = checkpoint.pop("model")
  52. # model state dict
  53. model_state_dict = model.state_dict()
  54. # check
  55. for k in list(checkpoint_state_dict.keys()):
  56. if k in model_state_dict:
  57. shape_model = tuple(model_state_dict[k].shape)
  58. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  59. if shape_model != shape_checkpoint:
  60. checkpoint_state_dict.pop(k)
  61. print(k)
  62. else:
  63. checkpoint_state_dict.pop(k)
  64. print(k)
  65. model.load_state_dict(checkpoint_state_dict, strict=False)
  66. # keep training
  67. if args.resume is not None:
  68. print('keep training: ', args.resume)
  69. checkpoint = torch.load(args.resume, map_location='cpu')
  70. # checkpoint state dict
  71. checkpoint_state_dict = checkpoint.pop("model")
  72. model.load_state_dict(checkpoint_state_dict)
  73. return model, criterion
  74. else:
  75. return model