eval.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import argparse
  2. import os
  3. from copy import deepcopy
  4. import torch
  5. from evaluator.voc_evaluator import VOCAPIEvaluator
  6. from evaluator.coco_evaluator import COCOAPIEvaluator
  7. from evaluator.ourdataset_evaluator import OurDatasetEvaluator
  8. # load transform
  9. from dataset.build import build_transform
  10. # load some utils
  11. from utils.misc import load_weight
  12. from utils.misc import compute_flops
  13. from config import build_dataset_config, build_model_config, build_trans_config
  14. from models.detectors import build_model
  15. def parse_args():
  16. parser = argparse.ArgumentParser(description='YOLO-Tutorial')
  17. # basic
  18. parser.add_argument('-size', '--img_size', default=640, type=int,
  19. help='the max size of input image')
  20. parser.add_argument('--cuda', action='store_true', default=False,
  21. help='Use cuda')
  22. # model
  23. parser.add_argument('-m', '--model', default='yolov1', type=str,
  24. help='build yolo')
  25. parser.add_argument('--weight', default=None,
  26. type=str, help='Trained state_dict file path to open')
  27. parser.add_argument('-ct', '--conf_thresh', default=0.005, type=float,
  28. help='confidence threshold')
  29. parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,
  30. help='NMS threshold')
  31. parser.add_argument('--topk', default=1000, type=int,
  32. help='topk candidates for testing')
  33. parser.add_argument("--no_decode", action="store_true", default=False,
  34. help="not decode in inference or yes")
  35. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  36. help='fuse Conv & BN')
  37. # dataset
  38. parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
  39. help='data root')
  40. parser.add_argument('-d', '--dataset', default='coco',
  41. help='coco, voc.')
  42. parser.add_argument('--mosaic', default=None, type=float,
  43. help='mosaic augmentation.')
  44. parser.add_argument('--mixup', default=None, type=float,
  45. help='mixup augmentation.')
  46. parser.add_argument('--load_cache', action='store_true', default=False,
  47. help='load data into memory.')
  48. # TTA
  49. parser.add_argument('-tta', '--test_aug', action='store_true', default=False,
  50. help='use test augmentation.')
  51. return parser.parse_args()
  52. def voc_test(model, data_dir, device, transform):
  53. evaluator = VOCAPIEvaluator(data_dir=data_dir,
  54. device=device,
  55. transform=transform,
  56. display=True)
  57. # VOC evaluation
  58. evaluator.evaluate(model)
  59. def coco_test(model, data_dir, device, transform, test=False):
  60. if test:
  61. # test-dev
  62. print('test on test-dev 2017')
  63. evaluator = COCOAPIEvaluator(
  64. data_dir=data_dir,
  65. device=device,
  66. testset=True,
  67. transform=transform)
  68. else:
  69. # eval
  70. evaluator = COCOAPIEvaluator(
  71. data_dir=data_dir,
  72. device=device,
  73. testset=False,
  74. transform=transform)
  75. # COCO evaluation
  76. evaluator.evaluate(model)
  77. def our_test(model, data_dir, device, transform):
  78. evaluator = OurDatasetEvaluator(
  79. data_dir=data_dir,
  80. device=device,
  81. image_set='val',
  82. transform=transform)
  83. # WiderFace evaluation
  84. evaluator.evaluate(model)
  85. if __name__ == '__main__':
  86. args = parse_args()
  87. # cuda
  88. if args.cuda:
  89. print('use cuda')
  90. device = torch.device("cuda")
  91. else:
  92. device = torch.device("cpu")
  93. # Dataset & Model Config
  94. data_cfg = build_dataset_config(args)
  95. model_cfg = build_model_config(args)
  96. trans_cfg = build_trans_config(model_cfg['trans_type'])
  97. data_dir = os.path.join(args.root, data_cfg['data_name'])
  98. num_classes = data_cfg['num_classes']
  99. # build model
  100. model = build_model(args, model_cfg, device, num_classes, False)
  101. # load trained weight
  102. model = load_weight(model, args.weight, args.fuse_conv_bn)
  103. model.to(device).eval()
  104. # compute FLOPs and Params
  105. model_copy = deepcopy(model)
  106. model_copy.trainable = False
  107. model_copy.eval()
  108. compute_flops(
  109. model=model_copy,
  110. img_size=args.img_size,
  111. device=device)
  112. del model_copy
  113. # transform
  114. val_transform, trans_cfg = build_transform(args, trans_cfg, model_cfg['max_stride'], is_train=False)
  115. # evaluation
  116. with torch.no_grad():
  117. if args.dataset == 'voc':
  118. voc_test(model, data_dir, device, val_transform)
  119. elif args.dataset == 'coco-val':
  120. coco_test(model, data_dir, device, val_transform, test=False)
  121. elif args.dataset == 'coco-test':
  122. coco_test(model, data_dir, device, val_transform, test=True)
  123. elif args.dataset == 'ourdataset':
  124. our_test(model, data_dir, device, val_transform)