eval.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. import argparse
  2. import os
  3. from copy import deepcopy
  4. import torch
  5. from evaluator.voc_evaluator import VOCAPIEvaluator
  6. from evaluator.coco_evaluator import COCOAPIEvaluator
  7. from evaluator.crowdhuman_evaluator import CrowdHumanEvaluator
  8. from evaluator.widerface_evaluator import WiderFaceEvaluator
  9. from evaluator.customed_evaluator import CustomedEvaluator
  10. # load transform
  11. from dataset.build import build_transform
  12. # load some utils
  13. from utils.misc import load_weight
  14. from utils.misc import compute_flops
  15. from config import build_dataset_config, build_model_config, build_trans_config
  16. from models.detectors import build_model
  17. def parse_args():
  18. parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
  19. # Basic setting
  20. parser.add_argument('-size', '--img_size', default=640, type=int,
  21. help='the max size of input image')
  22. parser.add_argument('--cuda', action='store_true', default=False,
  23. help='Use cuda')
  24. # Model setting
  25. parser.add_argument('-m', '--model', default='yolov1', type=str,
  26. help='build yolo')
  27. parser.add_argument('--weight', default=None,
  28. type=str, help='Trained state_dict file path to open')
  29. parser.add_argument('-ct', '--conf_thresh', default=0.001, type=float,
  30. help='confidence threshold')
  31. parser.add_argument('-nt', '--nms_thresh', default=0.7, type=float,
  32. help='NMS threshold')
  33. parser.add_argument('--topk', default=1000, type=int,
  34. help='topk candidates dets of each level before NMS')
  35. parser.add_argument("--no_decode", action="store_true", default=False,
  36. help="not decode in inference or yes")
  37. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  38. help='fuse Conv & BN')
  39. parser.add_argument('--no_multi_labels', action='store_true', default=False,
  40. help='Perform post-process with multi-labels trick.')
  41. parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
  42. help='Perform NMS operations regardless of category.')
  43. # Data setting
  44. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
  45. help='data root')
  46. parser.add_argument('-d', '--dataset', default='coco',
  47. help='coco, voc.')
  48. parser.add_argument('--mosaic', default=None, type=float,
  49. help='mosaic augmentation.')
  50. parser.add_argument('--mixup', default=None, type=float,
  51. help='mixup augmentation.')
  52. parser.add_argument('--load_cache', action='store_true', default=False,
  53. help='load data into memory.')
  54. # TTA
  55. parser.add_argument('-tta', '--test_aug', action='store_true', default=False,
  56. help='use test augmentation.')
  57. return parser.parse_args()
  58. def voc_test(model, data_dir, device, transform):
  59. evaluator = VOCAPIEvaluator(data_dir=data_dir,
  60. device=device,
  61. transform=transform,
  62. display=True)
  63. # VOC evaluation
  64. evaluator.evaluate(model)
  65. def coco_test(model, data_dir, device, transform, test=False):
  66. if test:
  67. # test-dev
  68. print('test on test-dev 2017')
  69. evaluator = COCOAPIEvaluator(
  70. data_dir=data_dir,
  71. device=device,
  72. testset=True,
  73. transform=transform)
  74. else:
  75. # eval
  76. evaluator = COCOAPIEvaluator(
  77. data_dir=data_dir,
  78. device=device,
  79. testset=False,
  80. transform=transform)
  81. # COCO evaluation
  82. evaluator.evaluate(model)
  83. def crowdhuman_test(model, data_dir, device, transform):
  84. evaluator = CrowdHumanEvaluator(
  85. data_dir=data_dir,
  86. device=device,
  87. image_set='val',
  88. transform=transform)
  89. # WiderFace evaluation
  90. evaluator.evaluate(model)
  91. def widerface_test(model, data_dir, device, transform):
  92. evaluator = WiderFaceEvaluator(
  93. data_dir=data_dir,
  94. device=device,
  95. image_set='val',
  96. transform=transform)
  97. # WiderFace evaluation
  98. evaluator.evaluate(model)
  99. def customed_test(model, data_dir, device, transform):
  100. evaluator = CustomedEvaluator(
  101. data_dir=data_dir,
  102. device=device,
  103. image_set='val',
  104. transform=transform)
  105. # WiderFace evaluation
  106. evaluator.evaluate(model)
  107. if __name__ == '__main__':
  108. args = parse_args()
  109. # cuda
  110. if args.cuda:
  111. print('use cuda')
  112. device = torch.device("cuda")
  113. else:
  114. device = torch.device("cpu")
  115. # Dataset & Model Config
  116. data_cfg = build_dataset_config(args)
  117. model_cfg = build_model_config(args)
  118. trans_cfg = build_trans_config(model_cfg['trans_type'])
  119. data_dir = os.path.join(args.root, data_cfg['data_name'])
  120. num_classes = data_cfg['num_classes']
  121. # build model
  122. model = build_model(args, model_cfg, device, num_classes, False)
  123. # load trained weight
  124. model = load_weight(model, args.weight, args.fuse_conv_bn)
  125. model.to(device).eval()
  126. # compute FLOPs and Params
  127. model_copy = deepcopy(model)
  128. model_copy.trainable = False
  129. model_copy.eval()
  130. compute_flops(
  131. model=model_copy,
  132. img_size=args.img_size,
  133. device=device)
  134. del model_copy
  135. # transform
  136. val_transform, trans_cfg = build_transform(args, trans_cfg, model_cfg['max_stride'], is_train=False)
  137. # evaluation
  138. with torch.no_grad():
  139. if args.dataset == 'voc':
  140. voc_test(model, data_dir, device, val_transform)
  141. elif args.dataset == 'coco-val' or args.dataset == 'coco':
  142. coco_test(model, data_dir, device, val_transform, test=False)
  143. elif args.dataset == 'coco-test':
  144. coco_test(model, data_dir, device, val_transform, test=True)
  145. elif args.dataset == 'ourdataset':
  146. customed_test(model, data_dir, device, val_transform)