__init__.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. #!/usr/bin/env python3
  2. # -*- coding:utf-8 -*-
  3. import torch
  4. from .yolov1.build import build_yolov1
  5. # build object detector
  6. def build_model(args,
  7. model_cfg,
  8. device,
  9. num_classes=80,
  10. trainable=False):
  11. # YOLOv1
  12. if args.model == 'yolov1':
  13. model, criterion = build_yolov1(
  14. args, model_cfg, device, num_classes, trainable)
  15. if trainable:
  16. # Load pretrained weight
  17. if args.pretrained is not None:
  18. print('Loading COCO pretrained weight ...')
  19. checkpoint = torch.load(args.pretrained, map_location='cpu')
  20. # checkpoint state dict
  21. checkpoint_state_dict = checkpoint.pop("model")
  22. # model state dict
  23. model_state_dict = model.state_dict()
  24. # check
  25. for k in list(checkpoint_state_dict.keys()):
  26. if k in model_state_dict:
  27. shape_model = tuple(model_state_dict[k].shape)
  28. shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
  29. if shape_model != shape_checkpoint:
  30. checkpoint_state_dict.pop(k)
  31. print(k)
  32. else:
  33. checkpoint_state_dict.pop(k)
  34. print(k)
  35. model.load_state_dict(checkpoint_state_dict, strict=False)
  36. # keep training
  37. if args.resume is not None:
  38. print('keep training: ', args.resume)
  39. checkpoint = torch.load(args.resume, map_location='cpu')
  40. # checkpoint state dict
  41. checkpoint_state_dict = checkpoint.pop("model")
  42. model.load_state_dict(checkpoint_state_dict)
  43. return model, criterion
  44. else:
  45. return model