eval.py
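"""Evaluate a trained YOLO detector on VOC, COCO (val / test-dev), or a
custom dataset, and report the model's FLOPs and parameter count."""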

import argparse
import os
from copy import deepcopy

import torch

from evaluator.voc_evaluator import VOCAPIEvaluator
from evaluator.coco_evaluator import COCOAPIEvaluator
from evaluator.ourdataset_evaluator import OurDatasetEvaluator

# load transform
from dataset.ourdataset import our_class_labels
from dataset.data_augment import build_transform

# load some utils
from utils.misc import load_weight
from utils.com_flops_params import FLOPs_and_Params

from models import build_model
from config import build_model_config, build_trans_config
def parse_args():
    parser = argparse.ArgumentParser(description='YOLO-Tutorial')
    # basic
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='the max size of input image')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda')
    # model
    parser.add_argument('-m', '--model', default='yolov1', type=str,
                        help='build yolo')
    parser.add_argument('--weight', default=None, type=str,
                        help='path to the trained state_dict file')
    parser.add_argument('-ct', '--conf_thresh', default=0.001, type=float,
                        help='confidence threshold')
    parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=1000, type=int,
                        help='topk candidates for testing')
    parser.add_argument('--no_decode', action='store_true', default=False,
                        help='return raw model outputs without decoding in inference')
    parser.add_argument('--fuse_repconv', action='store_true', default=False,
                        help='fuse RepConv')
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN')
    # dataset
    parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco-val',
                        help='voc, coco-val, coco-test, or ourdataset')
    # TTA
    parser.add_argument('-tta', '--test_aug', action='store_true', default=False,
                        help='use test-time augmentation')
    return parser.parse_args()
def voc_test(model, data_dir, device, transform):
    evaluator = VOCAPIEvaluator(data_dir=data_dir,
                                device=device,
                                transform=transform,
                                display=True)

    # VOC evaluation
    evaluator.evaluate(model)
def coco_test(model, data_dir, device, transform, test=False):
    if test:
        # test-dev
        print('test on test-dev 2017')
        evaluator = COCOAPIEvaluator(data_dir=data_dir,
                                     device=device,
                                     testset=True,
                                     transform=transform)
    else:
        # eval
        evaluator = COCOAPIEvaluator(data_dir=data_dir,
                                     device=device,
                                     testset=False,
                                     transform=transform)

    # COCO evaluation
    evaluator.evaluate(model)
def our_test(model, data_dir, device, transform):
    evaluator = OurDatasetEvaluator(data_dir=data_dir,
                                    device=device,
                                    image_set='val',
                                    transform=transform)

    # OurDataset evaluation
    evaluator.evaluate(model)
if __name__ == '__main__':
    args = parse_args()

    # cuda
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # dataset
    if args.dataset == 'voc':
        print('eval on voc ...')
        num_classes = 20
        data_dir = os.path.join(args.root, 'VOCdevkit')
    elif args.dataset == 'coco-val':
        print('eval on coco-val ...')
        num_classes = 80
        data_dir = os.path.join(args.root, 'COCO')
    elif args.dataset == 'coco-test':
        print('eval on coco-test-dev ...')
        num_classes = 80
        data_dir = os.path.join(args.root, 'COCO')
    elif args.dataset == 'ourdataset':
        print('eval on ourdataset ...')
        num_classes = len(our_class_labels)
        data_dir = os.path.join(args.root, 'OurDataset')
    else:
        print('unknown dataset! we only support voc, coco-val, coco-test, and ourdataset.')
        exit(0)
    # config
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])

    # build model
    model = build_model(args, model_cfg, device, num_classes, False)

    # load trained weight
    model = load_weight(model, args.weight, args.fuse_conv_bn, args.fuse_repconv)
    model.to(device).eval()

    # compute FLOPs and Params
    model_copy = deepcopy(model)
    model_copy.trainable = False
    model_copy.eval()
    FLOPs_and_Params(model=model_copy,
                     img_size=args.img_size,
                     device=device)
    del model_copy

    # transform
    transform = build_transform(args.img_size, trans_cfg, is_train=False)

    # evaluation
    with torch.no_grad():
        if args.dataset == 'voc':
            voc_test(model, data_dir, device, transform)
        elif args.dataset == 'coco-val':
            coco_test(model, data_dir, device, transform, test=False)
        elif args.dataset == 'coco-test':
            coco_test(model, data_dir, device, transform, test=True)
        elif args.dataset == 'ourdataset':
            our_test(model, data_dir, device, transform)
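
# Example invocation (hypothetical paths -- point --root and --weight at your
# own dataset directory and checkpoint):
#   python eval.py --cuda -d coco-val -m yolov1 \
#                  --root /path/to/datasets --weight /path/to/yolov1.pth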