test.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import argparse
  2. import cv2
  3. import os
  4. import time
  5. import numpy as np
  6. from copy import deepcopy
  7. import torch
  8. # load transform
  9. from dataset.build import build_dataset, build_transform
  10. # load some utils
  11. from utils.misc import load_weight, compute_flops
  12. from utils.box_ops import rescale_bboxes
  13. from utils.vis_tools import visualize
  14. from config import build_config
  15. from models import build_model
  16. def parse_args():
  17. parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
  18. # Basic setting
  19. parser.add_argument('--show', action='store_true', default=False,
  20. help='show the visulization results.')
  21. parser.add_argument('--save', action='store_true', default=False,
  22. help='save the visulization results.')
  23. parser.add_argument('--cuda', action='store_true', default=False,
  24. help='use cuda.')
  25. parser.add_argument('--save_folder', default='det_results/', type=str,
  26. help='Dir to save results')
  27. parser.add_argument('--window_scale', default=1.0, type=float,
  28. help='resize window of cv2 for visualization.')
  29. # Model setting
  30. parser.add_argument('--model', default='yolo_n', type=str,
  31. help='build yolo')
  32. parser.add_argument('--weight', default=None,
  33. type=str, help='Trained state_dict file path to open')
  34. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  35. help='fuse Conv & BN')
  36. # Data setting
  37. parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
  38. help='data root')
  39. parser.add_argument('--dataset', default='coco',
  40. help='coco, voc.')
  41. return parser.parse_args()
  42. @torch.no_grad()
  43. def test_det(args,
  44. model,
  45. device,
  46. dataset,
  47. transform=None,
  48. class_colors=None,
  49. class_names=None):
  50. num_images = len(dataset)
  51. save_path = os.path.join('det_results/', args.dataset, args.model)
  52. os.makedirs(save_path, exist_ok=True)
  53. for index in range(num_images):
  54. print('Testing image {:d}/{:d}....'.format(index+1, num_images))
  55. image, _ = dataset.pull_image(index)
  56. orig_h, orig_w, _ = image.shape
  57. orig_size = [orig_w, orig_h]
  58. # prepare
  59. x, _, ratio = transform(image)
  60. x = x.unsqueeze(0).to(device)
  61. t0 = time.time()
  62. # inference
  63. outputs = model(x)
  64. scores = outputs['scores']
  65. labels = outputs['labels']
  66. bboxes = outputs['bboxes']
  67. print("detection time used ", time.time() - t0, "s")
  68. # rescale bboxes
  69. bboxes = rescale_bboxes(bboxes, orig_size, ratio)
  70. # vis detection
  71. img_processed = visualize(image=image,
  72. bboxes=bboxes,
  73. scores=scores,
  74. labels=labels,
  75. class_colors=class_colors,
  76. class_names=class_names)
  77. if args.show:
  78. h, w = img_processed.shape[:2]
  79. sw, sh = int(w*args.window_scale), int(h*args.window_scale)
  80. cv2.namedWindow('detection', 0)
  81. cv2.resizeWindow('detection', sw, sh)
  82. cv2.imshow('detection', img_processed)
  83. cv2.waitKey(0)
  84. if args.save:
  85. # save result
  86. cv2.imwrite(os.path.join(save_path, str(index).zfill(6) +'.jpg'), img_processed)
  87. if __name__ == '__main__':
  88. args = parse_args()
  89. # Set cuda
  90. if args.cuda and torch.cuda.is_available():
  91. print('use cuda')
  92. device = torch.device("cuda")
  93. else:
  94. device = torch.device("cpu")
  95. # Build config
  96. cfg = build_config(args)
  97. # Build data processor
  98. transform = build_transform(cfg, is_train=False)
  99. # Build dataset
  100. dataset = build_dataset(args, cfg, transform, is_train=False)
  101. # Build model
  102. model = build_model(args, cfg, is_val=False)
  103. # Load trained weight
  104. model = load_weight(model, args.weight, args.fuse_conv_bn)
  105. model.to(device).eval()
  106. # Compute FLOPs and Params
  107. model_copy = deepcopy(model)
  108. model_copy.trainable = False
  109. model_copy.eval()
  110. compute_flops(model_copy, cfg.test_img_size, device)
  111. del model_copy
  112. print("================= DETECT =================")
  113. # Color for beautiful visualization
  114. np.random.seed(0)
  115. class_colors = [(np.random.randint(255),
  116. np.random.randint(255),
  117. np.random.randint(255))
  118. for _ in range(cfg.num_classes)]
  119. # Run
  120. test_det(args = args,
  121. model = model,
  122. device = device,
  123. dataset = dataset,
  124. transform = transform,
  125. class_colors = class_colors,
  126. class_names = cfg.class_labels,
  127. )