__init__.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov8.build import build_yolov8
  6. from .rtdetr.build import build_rtdetr
  7. # build object detector
  8. def build_model(args, cfg, is_val=False):
  9. # ------------ build object detector ------------
  10. ## YOLOv8
  11. if 'yolov1' in args.model:
  12. model, criterion = build_yolov1(cfg, is_val)
  13. elif 'yolov8' in args.model:
  14. model, criterion = build_yolov8(cfg, is_val)
  15. ## RT-DETR
  16. elif 'rtdetr' in args.model:
  17. model, criterion = build_rtdetr(cfg, is_val)
  18. if is_val:
  19. # ------------ Load pretrained weight ------------
  20. if args.pretrained is not None:
  21. print('Loading COCO pretrained weight ...')
  22. checkpoint = torch.load(args.pretrained, map_location='cpu')
  23. # checkpoint state dict
  24. checkpoint_state_dict = checkpoint.pop("model")
  25. # model state dict
  26. model_state_dict = model.state_dict()
  27. # check
  28. for k in list(checkpoint_state_dict.keys()):
  29. if k in model_state_dict:
  30. shape_model = tuple(model_state_dict[k].shape)
  31. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  32. if shape_model != shape_checkpoint:
  33. checkpoint_state_dict.pop(k)
  34. print(k)
  35. else:
  36. checkpoint_state_dict.pop(k)
  37. print(k)
  38. model.load_state_dict(checkpoint_state_dict, strict=False)
  39. # ------------ Keep training from the given checkpoint ------------
  40. if args.resume and args.resume != "None":
  41. checkpoint = torch.load(args.resume, map_location='cpu')
  42. # checkpoint state dict
  43. try:
  44. checkpoint_state_dict = checkpoint.pop("model")
  45. print('Load model from the checkpoint: ', args.resume)
  46. model.load_state_dict(checkpoint_state_dict)
  47. del checkpoint, checkpoint_state_dict
  48. except:
  49. print("No model in the given checkpoint.")
  50. return model, criterion
  51. else:
  52. return model