benchmark.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import argparse
  2. import time
  3. import torch
  4. # load transform
  5. from datasets import build_dataset, build_transform
  6. # load some utils
  7. from utils.misc import compute_flops, load_weight
  8. from config import build_config
  9. from models.detectors import build_model
  10. parser = argparse.ArgumentParser(description='Benchmark')
  11. # Model
  12. parser.add_argument('-m', '--model', default='fcos_r18_1x',
  13. help='build detector')
  14. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  15. help='fuse conv and bn')
  16. parser.add_argument('--weight', default=None, type=str,
  17. help='Trained state_dict file path to open')
  18. # Data root
  19. parser.add_argument('--root', default='/data/datasets/COCO',
  20. help='data root')
  21. # cuda
  22. parser.add_argument('--cuda', action='store_true', default=False,
  23. help='use cuda.')
  24. args = parser.parse_args()
  25. def test(cfg, model, device, dataset, transform):
  26. # Step-1: Compute FLOPs and Params
  27. compute_flops(
  28. model=model,
  29. min_size=cfg.test_min_size,
  30. max_size=cfg.test_max_size,
  31. device=device)
  32. # Step-2: Compute FPS
  33. num_images = 2002
  34. total_time = 0
  35. count = 0
  36. with torch.no_grad():
  37. for index in range(num_images):
  38. if index % 500 == 0:
  39. print('Testing image {:d}/{:d}....'.format(index+1, num_images))
  40. # Load an image
  41. image, _ = dataset[index]
  42. # Preprocess
  43. x, _ = transform(image)
  44. x = x.unsqueeze(0).to(device)
  45. # Star
  46. torch.cuda.synchronize()
  47. start_time = time.perf_counter()
  48. # Inference
  49. outputs = model(x)
  50. # End
  51. torch.cuda.synchronize()
  52. elapsed = time.perf_counter() - start_time
  53. if index > 1:
  54. total_time += elapsed
  55. count += 1
  56. print('- FPS :', 1.0 / (total_time / count))
  57. if __name__ == '__main__':
  58. # get device
  59. if args.cuda:
  60. print('use cuda')
  61. device = torch.device("cuda")
  62. else:
  63. device = torch.device("cpu")
  64. # Dataset & Model Config
  65. cfg = build_config(args)
  66. # Transform
  67. transform = build_transform(cfg, is_train=False)
  68. # Dataset
  69. args.dataset = 'coco'
  70. dataset = build_dataset(args, cfg, is_train=False)
  71. # Model
  72. model = build_model(args, cfg, is_val=False)
  73. model = load_weight(model, args.weight, args.fuse_conv_bn)
  74. model.to(device).eval()
  75. print("================= DETECT =================")
  76. # Run
  77. test(cfg, model, device, dataset, transform)