__init__.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov3.build import build_yolov3
  7. from .yolov5.build import build_yolov5
  8. from .yolov5_af.build import build_yolov5af
  9. from .yolov6.build import build_yolov6
  10. from .yolov8.build import build_yolov8
  11. from .yolov8_e2e.build import build_yolov8_e2e
  12. from .gelan.build import build_gelan
  13. from .rtdetr.build import build_rtdetr
  14. # build object detector
  15. def build_model(args, cfg, is_val=False):
  16. # ------------ build object detector ------------
  17. ## Modified YOLOv1
  18. if 'yolov1' in args.model:
  19. model, criterion = build_yolov1(cfg, is_val)
  20. ## Modified YOLOv2
  21. elif 'yolov2' in args.model:
  22. model, criterion = build_yolov2(cfg, is_val)
  23. ## Modified YOLOv3
  24. elif 'yolov3' in args.model:
  25. model, criterion = build_yolov3(cfg, is_val)
  26. ## Anchor-free YOLOv5
  27. elif 'yolov5_af' in args.model:
  28. model, criterion = build_yolov5af(cfg, is_val)
  29. ## Modified YOLOv5
  30. elif 'yolov5' in args.model:
  31. model, criterion = build_yolov5(cfg, is_val)
  32. ## YOLOv6
  33. elif 'yolov6' in args.model:
  34. model, criterion = build_yolov6(cfg, is_val)
  35. ## YOLOv8
  36. elif 'yolov8_e2e' in args.model:
  37. model, criterion = build_yolov8_e2e(cfg, is_val)
  38. ## YOLOv8
  39. elif 'yolov8' in args.model:
  40. model, criterion = build_yolov8(cfg, is_val)
  41. ## GElan
  42. elif 'gelan' in args.model:
  43. model, criterion = build_gelan(cfg, is_val)
  44. ## RT-DETR
  45. elif 'rtdetr' in args.model:
  46. model, criterion = build_rtdetr(cfg, is_val)
  47. if is_val:
  48. # ------------ Load pretrained weight ------------
  49. if hasattr(args, "pretrained") and args.pretrained is not None:
  50. print('Loading COCO pretrained weight ...')
  51. checkpoint = torch.load(args.pretrained, map_location='cpu')
  52. # checkpoint state dict
  53. checkpoint_state_dict = checkpoint.pop("model")
  54. # model state dict
  55. model_state_dict = model.state_dict()
  56. # check
  57. for k in list(checkpoint_state_dict.keys()):
  58. if k in model_state_dict:
  59. shape_model = tuple(model_state_dict[k].shape)
  60. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  61. if shape_model != shape_checkpoint:
  62. checkpoint_state_dict.pop(k)
  63. print(k)
  64. else:
  65. checkpoint_state_dict.pop(k)
  66. print(k)
  67. model.load_state_dict(checkpoint_state_dict, strict=False)
  68. # ------------ Keep training from the given checkpoint ------------
  69. if hasattr(args, "resume") and args.resume and args.resume != "None":
  70. checkpoint = torch.load(args.resume, map_location='cpu')
  71. # checkpoint state dict
  72. try:
  73. checkpoint_state_dict = checkpoint.pop("model")
  74. print('Load model from the checkpoint: ', args.resume)
  75. model.load_state_dict(checkpoint_state_dict)
  76. del checkpoint, checkpoint_state_dict
  77. except:
  78. print("No model in the given checkpoint.")
  79. return model, criterion
  80. else:
  81. return model