benchmark.py

import argparse
import os
import time

import numpy as np
import torch

from datasets import build_dataset, build_transform
from utils.misc import compute_flops, fuse_conv_bn, load_weight
from config import build_config
from models.detectors import build_model

parser = argparse.ArgumentParser(description='Benchmark')
# Model
parser.add_argument('-m', '--model', default='fcos_r18_1x',
                    help='build detector')
parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                    help='fuse conv and bn')
parser.add_argument('--topk', default=100, type=int,
                    help='keep the top-k detections per image')
parser.add_argument('--weight', default=None, type=str,
                    help='path to the trained state_dict file')
# Data root
parser.add_argument('--root', default='/data/datasets/COCO',
                    help='data root')
# CUDA
parser.add_argument('--cuda', action='store_true', default=False,
                    help='use CUDA')

args = parser.parse_args()
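
# Example invocation (the weight path below is illustrative; adjust the model
# tag, data root, and checkpoint to your local setup):
#
#   python benchmark.py -m fcos_r18_1x --cuda \
#       --root /data/datasets/COCO --weight path/to/checkpoint.pth --fuse_conv_bn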


def test(cfg, model, device, dataset, transform):
    # Step 1: compute FLOPs and Params
    compute_flops(
        model=model,
        min_size=cfg['test_min_size'],
        max_size=cfg['test_max_size'],
        device=device)

    # Step 2: compute FPS
    num_images = 2002
    total_time = 0
    count = 0
    with torch.no_grad():
        for index in range(num_images):
            if index % 500 == 0:
                print('Testing image {:d}/{:d}....'.format(index + 1, num_images))
            image, _ = dataset[index]
            orig_h, orig_w = image.height, image.width

            # preprocess
            x, _ = transform(image)
            x = x.unsqueeze(0).to(device)

            # start time
            if device.type == 'cuda':
                torch.cuda.synchronize()
            start_time = time.perf_counter()

            # inference
            bboxes, scores, labels = model(x)

            # rescale normalized bboxes to the original image size
            bboxes[..., 0::2] *= orig_w
            bboxes[..., 1::2] *= orig_h

            # end time
            if device.type == 'cuda':
                torch.cuda.synchronize()
            elapsed = time.perf_counter() - start_time
            # print("detection time used ", elapsed, "s")

            # skip the first two iterations as warm-up
            if index > 1:
                total_time += elapsed
                count += 1

    print('- FPS :', 1.0 / (total_time / count))
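
# Note: the reported FPS is the reciprocal of the mean per-image latency over
# the timed iterations (the first two images are skipped as warm-up); e.g. a
# mean latency of 20 ms per image corresponds to 1 / 0.020 = 50 FPS.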


if __name__ == '__main__':
    # get device
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # dataset & model config
    cfg = build_config(args)

    # transform
    transform = build_transform(cfg, is_train=False)

    # dataset
    args.dataset = 'coco'
    dataset, dataset_info = build_dataset(args, is_train=False)

    # model
    model = build_model(args, cfg, device, dataset_info['num_classes'], False)
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()

    # fuse conv and bn
    if args.fuse_conv_bn:
        print('fuse conv and bn ...')
        model = fuse_conv_bn(model)

    # run the benchmark
    test(cfg, model, device, dataset, transform)