test.py

import argparse
import cv2
import os
import time
import numpy as np
from copy import deepcopy

import torch

# load transform
from datasets import build_dataset, build_transform
# load some utils
from utils.misc import load_weight, compute_flops
from config import build_config
from models.detectors import build_model
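# NOTE: datasets, utils.misc, config and models.detectors are modules that live
# in this repository, not pip-installable packages.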


def parse_args():
    parser = argparse.ArgumentParser(description='Object Detection Lab')
    # Basic
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda.')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show the visualization results.')
    parser.add_argument('--save', action='store_true', default=False,
                        help='save the visualization results.')
    parser.add_argument('--save_folder', default='det_results/', type=str,
                        help='directory to save results.')
    parser.add_argument('-vt', '--visual_threshold', default=0.3, type=float,
                        help='final confidence threshold.')
    parser.add_argument('-ws', '--window_scale', default=1.0, type=float,
                        help='resize window of cv2 for visualization.')
    parser.add_argument('--resave', action='store_true', default=False,
                        help='resave checkpoints without optimizer state dict.')
    # Model
    parser.add_argument('-m', '--model', default='yolof_r18_c5_1x', type=str,
                        help='build detector')
    parser.add_argument('--weight', default=None, type=str,
                        help='path to the trained state_dict file to load.')
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN')
    # Dataset
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc.')

    return parser.parse_args()
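
# A typical invocation (sketch; the weight path below is hypothetical and
# depends on where your checkpoints live):
#   python test.py --cuda -d coco --root /path/to/COCO \
#       -m yolof_r18_c5_1x --weight weights/coco/yolof_r18_c5_1x/best.pth \
#       -vt 0.3 --show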


def plot_bbox_labels(img, bbox, label=None, cls_color=None, text_scale=0.4):
    x1, y1, x2, y2 = bbox
    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    # plot bbox
    cv2.rectangle(img, (x1, y1), (x2, y2), cls_color, 2)
    if label is not None:
        t_size = cv2.getTextSize(label, 0, fontScale=1, thickness=2)[0]
        # plot title bbox
        cv2.rectangle(img, (x1, y1 - t_size[1]), (int(x1 + t_size[0] * text_scale), y1), cls_color, -1)
        # put the text on the title bbox
        cv2.putText(img, label, (int(x1), int(y1 - 5)), 0, text_scale, (0, 0, 0), 1, lineType=cv2.LINE_AA)

    return img
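
# Minimal usage sketch for plot_bbox_labels (values are made up): bbox is
# (x1, y1, x2, y2) in absolute pixels and cls_color is a BGR tuple.
#   canvas = np.zeros((480, 640, 3), dtype=np.uint8)
#   canvas = plot_bbox_labels(canvas, (50, 60, 200, 220),
#                             label='person: 0.87', cls_color=(0, 255, 0))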


def visualize(img,
              bboxes,
              scores,
              labels,
              vis_thresh,
              class_colors,
              class_names):
    ts = 0.4
    for i, bbox in enumerate(bboxes):
        if scores[i] > vis_thresh:
            cls_id = int(labels[i])
            cls_color = class_colors[cls_id]
            mess = '%s: %.2f' % (class_names[cls_id], scores[i])
            img = plot_bbox_labels(img, bbox, mess, cls_color, text_scale=ts)

    return img
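
# Note: visualize() draws with OpenCV, so it expects a BGR uint8 image and
# bboxes already rescaled to absolute pixel coordinates; run() below performs
# both conversions before calling it.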


@torch.no_grad()
def run(args, model, device, dataset, transform, class_colors, class_names):
    num_images = len(dataset)
    save_path = os.path.join(args.save_folder, args.dataset, args.model)
    os.makedirs(save_path, exist_ok=True)

    for index, (image, _) in enumerate(dataset):
        print('Testing image {:d}/{:d}....'.format(index + 1, num_images))
        orig_h, orig_w = image.height, image.width

        # PreProcess
        x, _ = transform(image)
        x = x.unsqueeze(0).to(device)

        # Inference
        t0 = time.time()
        bboxes, scores, labels = model(x)
        print("Infer. time: {:.3f} s".format(time.time() - t0))

        # Rescale the normalized box coordinates back to the original image size
        bboxes[..., 0::2] *= orig_w
        bboxes[..., 1::2] *= orig_h

        # vis detection: convert the PIL image (RGB) to a BGR uint8 array for OpenCV
        image = np.array(image).astype(np.uint8)
        image = image[..., (2, 1, 0)].copy()
        img_processed = visualize(
            image, bboxes, scores, labels, args.visual_threshold, class_colors, class_names)

        if args.show:
            h, w = img_processed.shape[:2]
            sw, sh = int(w * args.window_scale), int(h * args.window_scale)
            cv2.namedWindow('detection', 0)
            cv2.resizeWindow('detection', sw, sh)
            cv2.imshow('detection', img_processed)
            cv2.waitKey(0)

        if args.save:
            # save result
            cv2.imwrite(os.path.join(save_path, str(index).zfill(6) + '.jpg'), img_processed)
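
# Timing caveat: CUDA kernels launch asynchronously, so the wall-clock time
# printed in run() can under-report the true GPU latency. A stricter sketch
# would synchronize around the forward pass:
#   if device.type == 'cuda':
#       torch.cuda.synchronize()
#   t0 = time.time()
#   bboxes, scores, labels = model(x)
#   if device.type == 'cuda':
#       torch.cuda.synchronize()
#   print("Infer. time: {:.3f} s".format(time.time() - t0))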


if __name__ == '__main__':
    args = parse_args()

    # cuda
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Dataset & Model Config
    cfg = build_config(args)

    # Transform
    transform = build_transform(cfg, is_train=False)

    # Dataset
    dataset, dataset_info = build_dataset(args, is_train=False)

    # Fixed seed so every class keeps the same color across runs
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255))
                    for _ in range(dataset_info['num_classes'])]

    # Model
    model = build_model(args, cfg, dataset_info['num_classes'], is_val=False)
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()

    # Compute FLOPs and Params on a disposable copy so the model used for
    # inference is left untouched
    model_copy = deepcopy(model)
    model_copy.trainable = False
    model_copy.eval()
    compute_flops(
        model=model_copy,
        min_size=cfg['test_min_size'],
        max_size=cfg['test_max_size'],
        device=device)
    del model_copy

    # Resave model weight
    if args.resave:
        print('Resave: {}'.format(args.model.upper()))
        checkpoint = torch.load(args.weight, map_location='cpu')
        output_dir = 'weights/{}/{}/'.format(args.dataset, args.model)
        os.makedirs(output_dir, exist_ok=True)
        checkpoint_path = os.path.join(output_dir, "{}_pure.pth".format(args.model))
        torch.save({'model': model.state_dict(),
                    'mAP': checkpoint.pop("mAP"),
                    'epoch': checkpoint.pop("epoch")},
                   checkpoint_path)
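
        # The '*_pure.pth' checkpoint keeps only the model weights plus the
        # mAP/epoch bookkeeping, dropping the optimizer state to shrink the
        # file, matching the --resave help text above.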

    print("================= DETECT =================")

    # run
    run(args, model, device, dataset, transform, class_colors, dataset_info['class_labels'])