benchmark.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import argparse
  2. import time
  3. import torch
  4. # load transform
  5. from dataset.build import build_dataset, build_transform
  6. # load some utils
  7. from utils.misc import load_weight, compute_flops
  8. from config import build_config
  9. from models import build_model
  10. def parse_args():
  11. parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
  12. # Basic setting
  13. parser.add_argument('-size', '--img_size', default=640, type=int,
  14. help='the max size of input image')
  15. parser.add_argument('--cuda', action='store_true', default=False,
  16. help='use cuda.')
  17. # Model setting
  18. parser.add_argument('-m', '--model', default='yolov1_r18', type=str,
  19. help='build yolo')
  20. parser.add_argument('--weight', default=None,
  21. type=str, help='Trained state_dict file path to open')
  22. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  23. help='fuse Conv & BN')
  24. # Data setting
  25. parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
  26. help='data root')
  27. return parser.parse_args()
  28. @torch.no_grad()
  29. def test_det(model,
  30. device,
  31. dataset,
  32. transform=None
  33. ):
  34. # Step-1: Compute FLOPs and Params
  35. compute_flops(model, cfg.test_img_size, device)
  36. # Step-2: Compute FPS
  37. num_images = 2002
  38. total_time = 0
  39. count = 0
  40. with torch.no_grad():
  41. for index in range(num_images):
  42. if index % 500 == 0:
  43. print('Testing image {:d}/{:d}....'.format(index+1, num_images))
  44. # Load an image
  45. image, _ = dataset.pull_image(index)
  46. # Preprocess
  47. x, _, ratio = transform(image)
  48. x = x.unsqueeze(0).to(device)
  49. # Start
  50. torch.cuda.synchronize()
  51. start_time = time.perf_counter()
  52. # Inference
  53. outputs = model(x)
  54. # End
  55. torch.cuda.synchronize()
  56. elapsed = time.perf_counter() - start_time
  57. if index > 1:
  58. total_time += elapsed
  59. count += 1
  60. print('- FPS :', 1.0 / (total_time / count))
  61. if __name__ == '__main__':
  62. args = parse_args()
  63. # cuda
  64. if args.cuda:
  65. print('use cuda')
  66. device = torch.device("cuda")
  67. else:
  68. device = torch.device("cpu")
  69. # Model Config
  70. cfg = build_config(args)
  71. # Transform
  72. transform = build_transform(cfg, is_train=False)
  73. # Dataset
  74. args.dataset = 'coco'
  75. dataset = build_dataset(args, cfg, transform, is_train=False)
  76. # Build model
  77. model = build_model(args, cfg, is_val=False)
  78. # Load trained weight
  79. model = load_weight(model, args.weight, args.fuse_conv_bn)
  80. model.to(device).eval()
  81. # Run
  82. test_det(model = model,
  83. device = device,
  84. dataset = dataset,
  85. transform = transform,
  86. )