# eval.py — model evaluation entry point (VOC / COCO / custom datasets)
import argparse
import torch
from evaluator.voc_evaluator import VOCAPIEvaluator
from evaluator.coco_evaluator import COCOAPIEvaluator
from evaluator.customed_evaluator import CustomedEvaluator
# load transform
from dataset.build import build_dataset, build_transform
# load some utils
from utils.misc import load_weight
from config import build_config
from models import build_model
  12. def parse_args():
  13. parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
  14. # Basic setting
  15. parser.add_argument('-size', '--img_size', default=640, type=int,
  16. help='the max size of input image')
  17. parser.add_argument('--cuda', action='store_true', default=False,
  18. help='Use cuda')
  19. # Model setting
  20. parser.add_argument('-m', '--model', default='yolov1', type=str,
  21. help='build yolo')
  22. parser.add_argument('--weight', default=None,
  23. type=str, help='Trained state_dict file path to open')
  24. parser.add_argument('-p', '--pretrained', default=None, type=str,
  25. help='load pretrained weight')
  26. parser.add_argument('-r', '--resume', default=None, type=str,
  27. help='keep training')
  28. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  29. help='fuse Conv & BN')
  30. parser.add_argument('--fuse_rep_conv', action='store_true', default=False,
  31. help='fuse Conv & BN')
  32. # Data setting
  33. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
  34. help='data root')
  35. parser.add_argument('-d', '--dataset', default='coco',
  36. help='coco, voc.')
  37. # TTA
  38. parser.add_argument('-tta', '--test_aug', action='store_true', default=False,
  39. help='use test augmentation.')
  40. return parser.parse_args()
  41. def voc_test(cfg, model, data_dir, device, transform):
  42. evaluator = VOCAPIEvaluator(cfg=cfg,
  43. data_dir=data_dir,
  44. device=device,
  45. transform=transform,
  46. display=True)
  47. # VOC evaluation
  48. evaluator.evaluate(model)
  49. def coco_test(cfg, model, data_dir, device, transform):
  50. # eval
  51. evaluator = COCOAPIEvaluator(
  52. cfg=cfg,
  53. data_dir=data_dir,
  54. device=device,
  55. transform=transform)
  56. # COCO evaluation
  57. evaluator.evaluate(model)
  58. def customed_test(cfg, model, data_dir, device, transform):
  59. evaluator = CustomedEvaluator(
  60. cfg=cfg,
  61. data_dir=data_dir,
  62. device=device,
  63. image_set='val',
  64. transform=transform)
  65. # WiderFace evaluation
  66. evaluator.evaluate(model)
  67. if __name__ == '__main__':
  68. args = parse_args()
  69. # cuda
  70. if args.cuda:
  71. print('use cuda')
  72. device = torch.device("cuda")
  73. else:
  74. device = torch.device("cpu")
  75. # Dataset & Model Config
  76. cfg = build_config(args)
  77. # Transform
  78. transform = build_transform(cfg, is_train=False)
  79. # Dataset
  80. dataset = build_dataset(args, cfg, transform, is_train=False)
  81. # build model
  82. model, _ = build_model(args, cfg, is_val=True)
  83. # load trained weight
  84. model = load_weight(model, args.weight, args.fuse_conv_bn)
  85. model.to(device).eval()
  86. # evaluation
  87. with torch.no_grad():
  88. if args.dataset == 'voc':
  89. voc_test(cfg, model, args.root, device, transform)
  90. elif args.dataset == 'coco':
  91. coco_test(cfg, model, args.root, device, transform)
  92. elif args.dataset == 'customed':
  93. customed_test(cfg, model, args.root, device, transform)