# eval.py — evaluate a trained YOLO model on VOC / COCO.
import argparse
import os
from copy import deepcopy

import torch

# project-local modules
from evaluator.voc_evaluator import VOCAPIEvaluator
from evaluator.coco_evaluator import COCOAPIEvaluator
# load transform
from dataset.data_augment import build_transform
# load some utils
from utils.misc import load_weight
from utils.com_flops_params import FLOPs_and_Params
from models import build_model
from config import build_model_config, build_trans_config
  14. def parse_args():
  15. parser = argparse.ArgumentParser(description='YOLO-Tutorial')
  16. # basic
  17. parser.add_argument('-size', '--img_size', default=640, type=int,
  18. help='the max size of input image')
  19. parser.add_argument('--cuda', action='store_true', default=False,
  20. help='Use cuda')
  21. # model
  22. parser.add_argument('-m', '--model', default='yolov1', type=str,
  23. help='build yolo')
  24. parser.add_argument('--weight', default=None,
  25. type=str, help='Trained state_dict file path to open')
  26. parser.add_argument('--conf_thresh', default=0.001, type=float,
  27. help='NMS threshold')
  28. parser.add_argument('--nms_thresh', default=0.6, type=float,
  29. help='NMS threshold')
  30. parser.add_argument('--topk', default=1000, type=int,
  31. help='topk candidates for testing')
  32. parser.add_argument("--no_decode", action="store_true", default=False,
  33. help="not decode in inference or yes")
  34. parser.add_argument('--fuse_repconv', action='store_true', default=False,
  35. help='fuse RepConv')
  36. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  37. help='fuse Conv & BN')
  38. # dataset
  39. parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
  40. help='data root')
  41. parser.add_argument('-d', '--dataset', default='coco',
  42. help='coco, voc.')
  43. # TTA
  44. parser.add_argument('-tta', '--test_aug', action='store_true', default=False,
  45. help='use test augmentation.')
  46. return parser.parse_args()
  47. def voc_test(model, data_dir, device, transform):
  48. evaluator = VOCAPIEvaluator(data_dir=data_dir,
  49. device=device,
  50. transform=transform,
  51. display=True)
  52. # VOC evaluation
  53. evaluator.evaluate(model)
  54. def coco_test(model, data_dir, device, transform, test=False):
  55. if test:
  56. # test-dev
  57. print('test on test-dev 2017')
  58. evaluator = COCOAPIEvaluator(
  59. data_dir=data_dir,
  60. device=device,
  61. testset=True,
  62. transform=transform)
  63. else:
  64. # eval
  65. evaluator = COCOAPIEvaluator(
  66. data_dir=data_dir,
  67. device=device,
  68. testset=False,
  69. transform=transform)
  70. # COCO evaluation
  71. evaluator.evaluate(model)
  72. if __name__ == '__main__':
  73. args = parse_args()
  74. # cuda
  75. if args.cuda:
  76. print('use cuda')
  77. device = torch.device("cuda")
  78. else:
  79. device = torch.device("cpu")
  80. # dataset
  81. if args.dataset == 'voc':
  82. print('eval on voc ...')
  83. num_classes = 20
  84. data_dir = os.path.join(args.root, 'VOCdevkit')
  85. elif args.dataset == 'coco-val':
  86. print('eval on coco-val ...')
  87. num_classes = 80
  88. data_dir = os.path.join(args.root, 'COCO')
  89. elif args.dataset == 'coco-test':
  90. print('eval on coco-test-dev ...')
  91. num_classes = 80
  92. data_dir = os.path.join(args.root, 'COCO')
  93. else:
  94. print('unknow dataset !! we only support voc, coco-val, coco-test !!!')
  95. exit(0)
  96. # config
  97. model_cfg = build_model_config(args)
  98. trans_cfg = build_trans_config(model_cfg['trans_type'])
  99. # build model
  100. model = build_model(args, model_cfg, device, num_classes, False)
  101. # load trained weight
  102. model = load_weight(model, args.weight, args.fuse_conv_bn, args.fuse_repconv)
  103. model.to(device).eval()
  104. # compute FLOPs and Params
  105. model_copy = deepcopy(model)
  106. model_copy.trainable = False
  107. model_copy.eval()
  108. FLOPs_and_Params(
  109. model=model_copy,
  110. img_size=args.img_size,
  111. device=device)
  112. del model_copy
  113. # transform
  114. transform = build_transform(args.img_size, trans_cfg, is_train=False)
  115. # evaluation
  116. with torch.no_grad():
  117. if args.dataset == 'voc':
  118. voc_test(model, data_dir, device, transform)
  119. elif args.dataset == 'coco-val':
  120. coco_test(model, data_dir, device, transform, test=False)
  121. elif args.dataset == 'coco-test':
  122. coco_test(model, data_dir, device, transform, test=True)