__init__.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. from .yolov2.build import build_yolov2
  6. # build object detector
  7. def build_model(args,
  8. model_cfg,
  9. device,
  10. num_classes=80,
  11. trainable=False):
  12. # YOLOv1
  13. if args.model == 'yolov1':
  14. model, criterion = build_yolov1(
  15. args, model_cfg, device, num_classes, trainable)
  16. # YOLOv2
  17. elif args.model == 'yolov2':
  18. model, criterion = build_yolov2(
  19. args, model_cfg, device, num_classes, trainable)
  20. if trainable:
  21. # Load pretrained weight
  22. if args.pretrained is not None:
  23. print('Loading COCO pretrained weight ...')
  24. checkpoint = torch.load(args.pretrained, map_location='cpu')
  25. # checkpoint state dict
  26. checkpoint_state_dict = checkpoint.pop("model")
  27. # model state dict
  28. model_state_dict = model.state_dict()
  29. # check
  30. for k in list(checkpoint_state_dict.keys()):
  31. if k in model_state_dict:
  32. shape_model = tuple(model_state_dict[k].shape)
  33. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  34. if shape_model != shape_checkpoint:
  35. checkpoint_state_dict.pop(k)
  36. print(k)
  37. else:
  38. checkpoint_state_dict.pop(k)
  39. print(k)
  40. model.load_state_dict(checkpoint_state_dict, strict=False)
  41. # keep training
  42. if args.resume is not None:
  43. print('keep training: ', args.resume)
  44. checkpoint = torch.load(args.resume, map_location='cpu')
  45. # checkpoint state dict
  46. checkpoint_state_dict = checkpoint.pop("model")
  47. model.load_state_dict(checkpoint_state_dict)
  48. return model, criterion
  49. else:
  50. return model