eval.py
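"""Evaluation entry point: builds the selected detector, loads trained weights,
and runs dataset-specific evaluation (VOC, COCO, or a custom dataset)."""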

import argparse

import torch

# evaluators
from evaluator.voc_evaluator import VOCAPIEvaluator
from evaluator.coco_evaluator import COCOAPIEvaluator
from evaluator.custom_evaluator import CustomEvaluator

# load transform
from dataset.build import build_dataset, build_transform

# load some utils
from utils.misc import load_weight

from config import build_config
from models import build_model
def parse_args():
    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
    # Basic setting
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='the max size of input image')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='Use cuda')
    # Model setting
    parser.add_argument('-m', '--model', default='yolov1', type=str,
                        help='build yolo')
    parser.add_argument('--weight', default=None, type=str,
                        help='Trained state_dict file path to open')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='keep training')
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN')
    # Data setting
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc, custom.')
    # TTA
    parser.add_argument('-tta', '--test_aug', action='store_true', default=False,
                        help='use test augmentation.')

    return parser.parse_args()
def voc_test(cfg, model, data_dir, device, transform):
    evaluator = VOCAPIEvaluator(cfg=cfg,
                                data_dir=data_dir,
                                device=device,
                                transform=transform,
                                display=True)

    # VOC evaluation
    evaluator.evaluate(model)
def coco_test(cfg, model, data_dir, device, transform):
    # eval
    evaluator = COCOAPIEvaluator(
        cfg=cfg,
        data_dir=data_dir,
        device=device,
        transform=transform)

    # COCO evaluation
    evaluator.evaluate(model)
def custom_test(cfg, model, data_dir, device, transform):
    evaluator = CustomEvaluator(
        cfg=cfg,
        data_dir=data_dir,
        device=device,
        image_set='val',
        transform=transform)

    # custom-dataset evaluation
    evaluator.evaluate(model)
if __name__ == '__main__':
    args = parse_args()

    # cuda
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Dataset & Model Config
    cfg = build_config(args)

    # Transform
    transform = build_transform(cfg, is_train=False)

    # Dataset
    dataset = build_dataset(args, cfg, transform, is_train=False)

    # build model
    model, _ = build_model(args, cfg, is_val=True)

    # load trained weight
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()

    # evaluation
    with torch.no_grad():
        if args.dataset == 'voc':
            voc_test(cfg, model, args.root, device, transform)
        elif args.dataset == 'coco':
            coco_test(cfg, model, args.root, device, transform)
        elif args.dataset == 'custom':
            custom_test(cfg, model, args.root, device, transform)
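
# Example invocation (the dataset root and checkpoint paths below are
# placeholders, not files shipped with this script; all flags come from
# parse_args() above):
#   python eval.py --cuda -d coco --root /path/to/dataset -m yolov1 --weight /path/to/checkpoint.pth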