# eval.py
import argparse

import torch

# config / model / data / utils (project-local)
from config import build_config
from dataset.build import build_dataset, build_transform
from evaluator.map_evaluator import MapEvaluator
from models import build_model
from utils.misc import load_weight
  11. def parse_args():
  12. parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
  13. # Basic setting
  14. parser.add_argument('-size', '--img_size', default=640, type=int,
  15. help='the max size of input image')
  16. parser.add_argument('--cuda', action='store_true', default=False,
  17. help='Use cuda')
  18. # Model setting
  19. parser.add_argument('-m', '--model', default='yolov1', type=str,
  20. help='build yolo')
  21. parser.add_argument('--weight', default=None,
  22. type=str, help='Trained state_dict file path to open')
  23. parser.add_argument('-r', '--resume', default=None, type=str,
  24. help='keep training')
  25. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  26. help='fuse Conv & BN')
  27. # Data setting
  28. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
  29. help='data root')
  30. parser.add_argument('-d', '--dataset', default='coco',
  31. help='coco, voc.')
  32. # TTA
  33. parser.add_argument('-tta', '--test_aug', action='store_true', default=False,
  34. help='use test augmentation.')
  35. return parser.parse_args()
  36. if __name__ == '__main__':
  37. args = parse_args()
  38. # cuda
  39. if args.cuda:
  40. print('use cuda')
  41. device = torch.device("cuda")
  42. else:
  43. device = torch.device("cpu")
  44. # Dataset & Model Config
  45. cfg = build_config(args)
  46. # Transform
  47. transform = build_transform(cfg, is_train=False)
  48. # Dataset
  49. dataset = build_dataset(args, cfg, transform, is_train=False)
  50. # build model
  51. model, _ = build_model(args, cfg, is_val=True)
  52. # load trained weight
  53. model = load_weight(model, args.weight, args.fuse_conv_bn)
  54. model.to(device).eval()
  55. # evaluation
  56. evaluator = MapEvaluator(cfg = cfg,
  57. dataset_name = args.dataset,
  58. data_dir = args.root,
  59. device = device,
  60. transform = transform
  61. )
  62. evaluator.evaluate(model)