eval.py

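"""Evaluate a trained YOLO detector.

The script builds the model from its config, loads the trained weights,
reports FLOPs and parameter counts, and then runs evaluation on the
VOC test set or on COCO (val or test-dev 2017).
"""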
import argparse
import os
from copy import deepcopy

import torch

from evaluator.voc_evaluator import VOCAPIEvaluator
from evaluator.coco_evaluator import COCOAPIEvaluator

# load transform
from dataset.data_augment import build_transform

# load some utils
from utils.misc import load_weight
from utils.com_flops_params import FLOPs_and_Params

from models import build_model
from config import build_model_config, build_trans_config

def parse_args():
    parser = argparse.ArgumentParser(description='YOLO-Tutorial')
    # basic
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='max size of the input image')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA')
    # model
    parser.add_argument('-m', '--model', default='yolov1', type=str,
                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolox'],
                        help='build yolo')
    parser.add_argument('--weight', default=None, type=str,
                        help='path to the trained state_dict file')
    parser.add_argument('--conf_thresh', default=0.001, type=float,
                        help='confidence threshold')
    parser.add_argument('--nms_thresh', default=0.6, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=1000, type=int,
                        help='topk candidates for testing')
    parser.add_argument("--no_decode", action="store_true", default=False,
                        help="do not decode raw predictions during inference")
    # dataset
    parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco-val',
                        help='voc, coco-val, coco-test.')
    # TTA
    parser.add_argument('-tta', '--test_aug', action='store_true', default=False,
                        help='use test-time augmentation.')

    return parser.parse_args()

def voc_test(model, data_dir, device, transform):
    evaluator = VOCAPIEvaluator(data_dir=data_dir,
                                device=device,
                                transform=transform,
                                display=True)

    # VOC evaluation
    evaluator.evaluate(model)

def coco_test(model, data_dir, device, transform, test=False):
    if test:
        # test-dev
        print('test on test-dev 2017')
        evaluator = COCOAPIEvaluator(
            data_dir=data_dir,
            device=device,
            testset=True,
            transform=transform)
    else:
        # eval
        evaluator = COCOAPIEvaluator(
            data_dir=data_dir,
            device=device,
            testset=False,
            transform=transform)

    # COCO evaluation
    evaluator.evaluate(model)

if __name__ == '__main__':
    args = parse_args()

    # cuda
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # dataset
    if args.dataset == 'voc':
        print('eval on voc ...')
        num_classes = 20
        data_dir = os.path.join(args.root, 'VOCdevkit')
    elif args.dataset == 'coco-val':
        print('eval on coco-val ...')
        num_classes = 80
        data_dir = os.path.join(args.root, 'COCO')
    elif args.dataset == 'coco-test':
        print('eval on coco-test-dev ...')
        num_classes = 80
        data_dir = os.path.join(args.root, 'COCO')
    else:
        print('unknown dataset! we only support voc, coco-val, coco-test.')
        exit(0)
    # config
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])

    # build model
    model = build_model(args, model_cfg, device, num_classes, False)

    # load trained weight
    model = load_weight(model=model, path_to_ckpt=args.weight)
    model.to(device).eval()

    # compute FLOPs and Params
    model_copy = deepcopy(model)
    model_copy.trainable = False
    model_copy.eval()
    FLOPs_and_Params(
        model=model_copy,
        img_size=args.img_size,
        device=device)
    del model_copy

    # transform
    transform = build_transform(args.img_size, trans_cfg, is_train=False)

    # evaluation
    with torch.no_grad():
        if args.dataset == 'voc':
            voc_test(model, data_dir, device, transform)
        elif args.dataset == 'coco-val':
            coco_test(model, data_dir, device, transform, test=False)
        elif args.dataset == 'coco-test':
            coco_test(model, data_dir, device, transform, test=True)
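
# Example invocation (dataset root and weight path below are illustrative):
#   python eval.py --cuda -d coco-val --root /path/to/datasets -m yolov1 --weight /path/to/yolov1.pth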