__init__.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. from .yolov3.build import build_yolov3
  7. from .yolov4.build import build_yolov4
  8. from .yolox.build import build_yolox
  9. # build object detector
  10. def build_model(args,
  11. model_cfg,
  12. device,
  13. num_classes=80,
  14. trainable=False):
  15. # YOLOv1
  16. if args.model == 'yolov1':
  17. model, criterion = build_yolov1(
  18. args, model_cfg, device, num_classes, trainable)
  19. # YOLOv2
  20. elif args.model == 'yolov2':
  21. model, criterion = build_yolov2(
  22. args, model_cfg, device, num_classes, trainable)
  23. # YOLOv3
  24. elif args.model == 'yolov3':
  25. model, criterion = build_yolov3(
  26. args, model_cfg, device, num_classes, trainable)
  27. # YOLOv4
  28. elif args.model == 'yolov4':
  29. model, criterion = build_yolov4(
  30. args, model_cfg, device, num_classes, trainable)
  31. # YOLOX
  32. elif args.model == 'yolox':
  33. model, criterion = build_yolox(
  34. args, model_cfg, device, num_classes, trainable)
  35. if trainable:
  36. # Load pretrained weight
  37. if args.pretrained is not None:
  38. print('Loading COCO pretrained weight ...')
  39. checkpoint = torch.load(args.pretrained, map_location='cpu')
  40. # checkpoint state dict
  41. checkpoint_state_dict = checkpoint.pop("model")
  42. # model state dict
  43. model_state_dict = model.state_dict()
  44. # check
  45. for k in list(checkpoint_state_dict.keys()):
  46. if k in model_state_dict:
  47. shape_model = tuple(model_state_dict[k].shape)
  48. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  49. if shape_model != shape_checkpoint:
  50. checkpoint_state_dict.pop(k)
  51. print(k)
  52. else:
  53. checkpoint_state_dict.pop(k)
  54. print(k)
  55. model.load_state_dict(checkpoint_state_dict, strict=False)
  56. # keep training
  57. if args.resume is not None:
  58. print('keep training: ', args.resume)
  59. checkpoint = torch.load(args.resume, map_location='cpu')
  60. # checkpoint state dict
  61. checkpoint_state_dict = checkpoint.pop("model")
  62. model.load_state_dict(checkpoint_state_dict)
  63. return model, criterion
  64. else:
  65. return model