__init__.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov3.build import build_yolov3
  7. from .yolov4.build import build_yolov4
  8. from .yolov8.build import build_yolov8
  9. from .rtdetr.build import build_rtdetr
  10. # build object detector
  11. def build_model(args, cfg, is_val=False):
  12. # ------------ build object detector ------------
  13. ## Modified YOLOv1
  14. if 'yolov1' in args.model:
  15. model, criterion = build_yolov1(cfg, is_val)
  16. ## Modified YOLOv2
  17. elif 'yolov2' in args.model:
  18. model, criterion = build_yolov2(cfg, is_val)
  19. ## Modified YOLOv3
  20. elif 'yolov3' in args.model:
  21. model, criterion = build_yolov3(cfg, is_val)
  22. ## Modified YOLOv4
  23. elif 'yolov4' in args.model:
  24. model, criterion = build_yolov4(cfg, is_val)
  25. ## YOLOv8
  26. elif 'yolov8' in args.model:
  27. model, criterion = build_yolov8(cfg, is_val)
  28. ## RT-DETR
  29. elif 'rtdetr' in args.model:
  30. model, criterion = build_rtdetr(cfg, is_val)
  31. if is_val:
  32. # ------------ Load pretrained weight ------------
  33. if args.pretrained is not None:
  34. print('Loading COCO pretrained weight ...')
  35. checkpoint = torch.load(args.pretrained, map_location='cpu')
  36. # checkpoint state dict
  37. checkpoint_state_dict = checkpoint.pop("model")
  38. # model state dict
  39. model_state_dict = model.state_dict()
  40. # check
  41. for k in list(checkpoint_state_dict.keys()):
  42. if k in model_state_dict:
  43. shape_model = tuple(model_state_dict[k].shape)
  44. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  45. if shape_model != shape_checkpoint:
  46. checkpoint_state_dict.pop(k)
  47. print(k)
  48. else:
  49. checkpoint_state_dict.pop(k)
  50. print(k)
  51. model.load_state_dict(checkpoint_state_dict, strict=False)
  52. # ------------ Keep training from the given checkpoint ------------
  53. if args.resume and args.resume != "None":
  54. checkpoint = torch.load(args.resume, map_location='cpu')
  55. # checkpoint state dict
  56. try:
  57. checkpoint_state_dict = checkpoint.pop("model")
  58. print('Load model from the checkpoint: ', args.resume)
  59. model.load_state_dict(checkpoint_state_dict)
  60. del checkpoint, checkpoint_state_dict
  61. except:
  62. print("No model in the given checkpoint.")
  63. return model, criterion
  64. else:
  65. return model