__init__.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov3.build import build_yolov3
  7. from .yolov5.build import build_yolov5
  8. from .yolov5_af.build import build_yolov5af
  9. from .yolov8.build import build_yolov8
  10. from .gelan.build import build_gelan
  11. from .rtdetr.build import build_rtdetr
  12. # build object detector
  13. def build_model(args, cfg, is_val=False):
  14. # ------------ build object detector ------------
  15. ## Modified YOLOv1
  16. if 'yolov1' in args.model:
  17. model, criterion = build_yolov1(cfg, is_val)
  18. ## Modified YOLOv2
  19. elif 'yolov2' in args.model:
  20. model, criterion = build_yolov2(cfg, is_val)
  21. ## Modified YOLOv3
  22. elif 'yolov3' in args.model:
  23. model, criterion = build_yolov3(cfg, is_val)
  24. ## Anchor-free YOLOv5
  25. elif 'yolov5_af' in args.model:
  26. model, criterion = build_yolov5af(cfg, is_val)
  27. ## Modified YOLOv5
  28. elif 'yolov5' in args.model:
  29. model, criterion = build_yolov5(cfg, is_val)
  30. ## YOLOv8
  31. elif 'yolov8' in args.model:
  32. model, criterion = build_yolov8(cfg, is_val)
  33. ## GElan
  34. elif 'gelan' in args.model:
  35. model, criterion = build_gelan(cfg, is_val)
  36. ## RT-DETR
  37. elif 'rtdetr' in args.model:
  38. model, criterion = build_rtdetr(cfg, is_val)
  39. if is_val:
  40. # ------------ Load pretrained weight ------------
  41. if args.pretrained is not None:
  42. print('Loading COCO pretrained weight ...')
  43. checkpoint = torch.load(args.pretrained, map_location='cpu')
  44. # checkpoint state dict
  45. checkpoint_state_dict = checkpoint.pop("model")
  46. # model state dict
  47. model_state_dict = model.state_dict()
  48. # check
  49. for k in list(checkpoint_state_dict.keys()):
  50. if k in model_state_dict:
  51. shape_model = tuple(model_state_dict[k].shape)
  52. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  53. if shape_model != shape_checkpoint:
  54. checkpoint_state_dict.pop(k)
  55. print(k)
  56. else:
  57. checkpoint_state_dict.pop(k)
  58. print(k)
  59. model.load_state_dict(checkpoint_state_dict, strict=False)
  60. # ------------ Keep training from the given checkpoint ------------
  61. if args.resume and args.resume != "None":
  62. checkpoint = torch.load(args.resume, map_location='cpu')
  63. # checkpoint state dict
  64. try:
  65. checkpoint_state_dict = checkpoint.pop("model")
  66. print('Load model from the checkpoint: ', args.resume)
  67. model.load_state_dict(checkpoint_state_dict)
  68. del checkpoint, checkpoint_state_dict
  69. except:
  70. print("No model in the given checkpoint.")
  71. return model, criterion
  72. else:
  73. return model