# benchmark.py
  1. import argparse
  2. import cv2
  3. import os
  4. import time
  5. import numpy as np
  6. from copy import deepcopy
  7. import torch
  8. # load transform
  9. from dataset.build import build_dataset, build_transform
  10. # load some utils
  11. from utils.misc import load_weight, compute_flops
  12. from utils.box_ops import rescale_bboxes
  13. from utils.vis_tools import visualize
  14. from config import build_config
  15. from models import build_model
  def parse_args(argv=None):
      """Parse the command-line options of the benchmark script.

      Args:
          argv: Optional list of argument strings to parse. When ``None``
              (the default) argparse reads ``sys.argv[1:]``, which is what
              the script entry point relies on.

      Returns:
          argparse.Namespace with the attributes ``img_size``, ``cuda``,
          ``model``, ``weight``, ``fuse_conv_bn``, ``fuse_rep_conv`` and
          ``root``.
      """
      parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')

      # Basic setting
      parser.add_argument('-size', '--img_size', default=640, type=int,
                          help='the max size of input image')
      parser.add_argument('--cuda', action='store_true', default=False,
                          help='use cuda.')

      # Model setting
      parser.add_argument('-m', '--model', default='yolo_n', type=str,
                          help='build yolo')
      parser.add_argument('--weight', default=None,
                          type=str, help='Trained state_dict file path to open')
      parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                          help='fuse Conv & BN')
      # The help text used to be a copy of the --fuse_conv_bn one.
      parser.add_argument('--fuse_rep_conv', action='store_true', default=False,
                          help='fuse RepConv')

      # Data setting
      parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
                          help='data root')

      return parser.parse_args(argv)
  @torch.no_grad()
  def test_det(model,
               device,
               dataset,
               transform=None,
               num_images=2002,
               warmup=2,
               ):
      """Report FLOPs/params of ``model`` and measure its inference FPS.

      Args:
          model: Callable detector; invoked as ``model(x)`` on one image batch.
          device: Device the input is moved to (``torch.device`` or string).
          dataset: Object exposing ``pull_image(index)`` that returns a pair
              whose first item is the raw image.
          transform: Callable applied to the raw image; it must return a
              3-tuple whose first item is the input tensor. Required despite
              the ``None`` default kept for signature compatibility.
          num_images: Maximum number of images to run (default 2002, the
              value that used to be hard-coded).
          warmup: Number of leading iterations excluded from the timing
              (default 2, matching the former ``index > 1`` check).

      Raises:
          ValueError: If ``transform`` is ``None``.

      Note:
          Reads the module-level ``cfg`` (created in the ``__main__`` block)
          for ``cfg.test_img_size``.
      """
      if transform is None:
          raise ValueError('test_det requires a transform to preprocess images')

      # Step-1: Compute FLOPs and Params
      compute_flops(model, cfg.test_img_size, device)

      # Step-2: Compute FPS
      # Never index past the end of a dataset smaller than the requested count.
      try:
          num_images = min(num_images, len(dataset))
      except TypeError:
          pass  # dataset exposes no length; trust the requested count

      # CUDA kernels run asynchronously, so the timer needs explicit syncs --
      # but torch.cuda.synchronize() must not be called on a CPU-only run.
      sync_cuda = torch.device(device).type == 'cuda'

      total_time = 0.0
      count = 0
      # The decorator already disables autograd; no inner no_grad() is needed.
      for index in range(num_images):
          if index % 500 == 0:
              print('Testing image {:d}/{:d}....'.format(index+1, num_images))

          # Load an image
          image, _ = dataset.pull_image(index)

          # Preprocess
          x, _, _ = transform(image)
          x = x.unsqueeze(0).to(device)

          # Start
          if sync_cuda:
              torch.cuda.synchronize()
          start_time = time.perf_counter()

          # Inference
          model(x)

          # End
          if sync_cuda:
              torch.cuda.synchronize()
          elapsed = time.perf_counter() - start_time

          # Skip the warm-up iterations when accumulating the timing.
          if index >= warmup:
              total_time += elapsed
              count += 1

      if count == 0 or total_time <= 0:
          # Nothing was timed (too few images); avoid dividing by zero.
          print('- FPS : n/a (no timed iterations)')
          return

      print('- FPS :', 1.0 / (total_time / count))
  if __name__ == '__main__':
      # Entry point: parse options, build the pieces, then run the benchmark.
      args = parse_args()

      # Select the compute device from the --cuda flag
      if args.cuda:
          print('use cuda')
      device = torch.device("cuda" if args.cuda else "cpu")

      # Model config -- kept as a module-level name because test_det reads `cfg`
      cfg = build_config(args)

      # Evaluation-time transform
      transform = build_transform(cfg, is_train=False)

      # Dataset: this script always benchmarks on COCO
      args.dataset = 'coco'
      dataset = build_dataset(args, cfg, transform, is_train=False)

      # Build the model, then load the trained weight (with optional fusing)
      model = build_model(args, cfg, is_val=False)
      model = load_weight(model, args.weight, args.fuse_conv_bn, args.fuse_rep_conv)
      model.to(device).eval()

      # Run the FLOPs / FPS benchmark
      test_det(model, device, dataset, transform)