eval.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import argparse
  2. import torch
  3. from evaluator.map_evaluator import MapEvaluator
  4. from dataset.build import build_dataset, build_transform
  5. from utils.misc import load_weight
  6. from config import build_config
  7. from models import build_model
  8. def parse_args():
  9. parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
  10. # Basic setting
  11. parser.add_argument('--cuda', action='store_true', default=False,
  12. help='Use cuda')
  13. # Model setting
  14. parser.add_argument('--model', default='yolov1', type=str,
  15. help='build yolo')
  16. parser.add_argument('--weight', default=None,
  17. type=str, help='Trained state_dict file path to open')
  18. parser.add_argument('--resume', default=None, type=str,
  19. help='keep training')
  20. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  21. help='fuse Conv & BN')
  22. # Data setting
  23. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
  24. help='data root')
  25. parser.add_argument('-d', '--dataset', default='coco',
  26. help='coco, voc.')
  27. return parser.parse_args()
  28. if __name__ == '__main__':
  29. args = parse_args()
  30. # cuda
  31. if args.cuda:
  32. print('use cuda')
  33. device = torch.device("cuda")
  34. else:
  35. device = torch.device("cpu")
  36. # Dataset & Model Config
  37. cfg = build_config(args)
  38. # Transform
  39. transform = build_transform(cfg, is_train=False)
  40. # Dataset
  41. dataset = build_dataset(args, cfg, transform, is_train=False)
  42. # build model
  43. model, _ = build_model(args, cfg, is_val=True)
  44. # load trained weight
  45. model = load_weight(model, args.weight, args.fuse_conv_bn)
  46. model.to(device).eval()
  47. # evaluation
  48. evaluator = MapEvaluator(cfg = cfg,
  49. dataset_name = args.dataset,
  50. data_dir = args.root,
  51. device = device,
  52. transform = transform
  53. )
  54. evaluator.evaluate(model)