test.py

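"""Inference / visualization script for the YOLO-Tutorial detectors.

Load a trained detector, run it over a detection dataset (VOC or COCO), and
optionally display or save the rendered results. A typical invocation might
look like this (the data root and weight path are placeholders, adjust them
to your own setup):

    python test.py -d coco --root /path/to/COCO --cuda -m yolov1 \
        --weight /path/to/yolov1_coco.pth --show
"""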
import argparse
import os
import time
from copy import deepcopy

import cv2
import numpy as np
import torch

# dataset & transform builders
from dataset.build import build_dataset, build_transform
# misc utilities
from utils.misc import load_weight, compute_flops
from utils.box_ops import rescale_bboxes
from config import build_dataset_config, build_model_config, build_trans_config
from models.detectors import build_model


def parse_args():
    parser = argparse.ArgumentParser(description='YOLO-Tutorial')
    # basic
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='max size of the input image')
    parser.add_argument('--show', action='store_true', default=False,
                        help='show the visualization results.')
    parser.add_argument('--save', action='store_true', default=False,
                        help='save the visualization results.')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda.')
    parser.add_argument('--save_folder', default='det_results/', type=str,
                        help='directory to save results')
    parser.add_argument('-vt', '--visual_threshold', default=0.3, type=float,
                        help='final confidence threshold for visualization')
    parser.add_argument('-ws', '--window_scale', default=1.0, type=float,
                        help='resize the cv2 window for visualization.')
    parser.add_argument('--resave', action='store_true', default=False,
                        help='resave the checkpoint without the optimizer state dict.')
    # model
    parser.add_argument('-m', '--model', default='yolov1', type=str,
                        help='build yolo')
    parser.add_argument('--weight', default=None, type=str,
                        help='path of the trained state_dict file to load')
    parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,
                        help='confidence threshold')
    parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
                        help='NMS threshold')
    parser.add_argument('--topk', default=100, type=int,
                        help='keep the top-k candidate detections of each level before NMS')
    parser.add_argument("--no_decode", action="store_true", default=False,
                        help="do not decode the raw predictions during inference")
    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                        help='fuse Conv & BN')
    parser.add_argument('--no_multi_labels', action='store_true', default=False,
                        help='disable the multi-label trick during post-processing.')
    parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
                        help='perform NMS regardless of category.')
    # dataset
    parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, voc.')
    parser.add_argument('--min_box_size', default=8.0, type=float,
                        help='min size of a target bounding box.')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation.')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation.')
    parser.add_argument('--load_cache', action='store_true', default=False,
                        help='load the dataset into memory.')

    return parser.parse_args()


def plot_bbox_labels(img, bbox, label=None, cls_color=None, text_scale=0.4):
    x1, y1, x2, y2 = map(int, bbox)
    # plot bbox
    cv2.rectangle(img, (x1, y1), (x2, y2), cls_color, 2)
    if label is not None:
        t_size = cv2.getTextSize(label, 0, fontScale=1, thickness=2)[0]
        # plot the title bbox as a filled background for the label
        cv2.rectangle(img, (x1, y1 - t_size[1]), (int(x1 + t_size[0] * text_scale), y1), cls_color, -1)
        # put the text on the title bbox
        cv2.putText(img, label, (x1, y1 - 5), 0, text_scale, (0, 0, 0), 1, lineType=cv2.LINE_AA)

    return img
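
# Example (hypothetical values): draw a red box labeled "dog: 0.87" on a BGR frame.
#   frame = plot_bbox_labels(frame, (50, 60, 200, 220), label='dog: 0.87', cls_color=(0, 0, 255))
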
def visualize(img,
              bboxes,
              scores,
              labels,
              vis_thresh,
              class_colors,
              class_names,
              class_indexs=None,
              dataset_name='voc'):
    ts = 0.4
    for i, bbox in enumerate(bboxes):
        if scores[i] > vis_thresh:
            cls_id = int(labels[i])
            cls_color = class_colors[cls_id]
            if dataset_name == 'coco':
                # COCO: remap the contiguous class id to the original category index
                cls_id = class_indexs[cls_id]
            mess = '%s: %.2f' % (class_names[cls_id], scores[i])
            img = plot_bbox_labels(img, bbox, mess, cls_color, text_scale=ts)

    return img
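
# Example (hypothetical values): overlay one image's detections in place.
#   img = visualize(img, bboxes, scores, labels, vis_thresh=0.3,
#                   class_colors=class_colors, class_names=class_names,
#                   class_indexs=class_indexs, dataset_name='coco')
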
@torch.no_grad()
def test(args,
         model,
         device,
         dataset,
         transform=None,
         class_colors=None,
         class_names=None,
         class_indexs=None):
    num_images = len(dataset)
    save_path = os.path.join(args.save_folder, args.dataset, args.model)
    os.makedirs(save_path, exist_ok=True)

    for index in range(num_images):
        print('Testing image {:d}/{:d}....'.format(index + 1, num_images))
        image, _ = dataset.pull_image(index)
        orig_h, orig_w, _ = image.shape

        # prepare: resize/pad, then normalize to [0, 1]
        x, _, deltas = transform(image)
        x = x.unsqueeze(0).to(device) / 255.

        # inference
        t0 = time.time()
        bboxes, scores, labels = model(x)
        print("detection time used ", time.time() - t0, "s")

        # rescale bboxes from the network input size back to the original image size
        origin_img_size = [orig_h, orig_w]
        cur_img_size = [*x.shape[-2:]]
        bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)

        # visualize detections
        img_processed = visualize(
            img=image,
            bboxes=bboxes,
            scores=scores,
            labels=labels,
            vis_thresh=args.visual_threshold,
            class_colors=class_colors,
            class_names=class_names,
            class_indexs=class_indexs,
            dataset_name=args.dataset)

        if args.show:
            h, w = img_processed.shape[:2]
            sw, sh = int(w * args.window_scale), int(h * args.window_scale)
            cv2.namedWindow('detection', cv2.WINDOW_NORMAL)
            cv2.resizeWindow('detection', sw, sh)
            cv2.imshow('detection', img_processed)
            cv2.waitKey(0)

        if args.save:
            # save result
            cv2.imwrite(os.path.join(save_path, str(index).zfill(6) + '.jpg'), img_processed)

if __name__ == '__main__':
    args = parse_args()

    # device
    if args.cuda:
        print('use cuda')
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # dataset & model config
    data_cfg = build_dataset_config(args)
    model_cfg = build_model_config(args)
    trans_cfg = build_trans_config(model_cfg['trans_type'])

    # transform
    val_transform, trans_cfg = build_transform(args, trans_cfg, model_cfg['max_stride'], is_train=False)

    # dataset
    dataset, dataset_info = build_dataset(args, data_cfg, trans_cfg, val_transform, is_train=False)
    num_classes = dataset_info['num_classes']

    # fixed seed so the class colors are reproducible across runs
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255))
                    for _ in range(num_classes)]

    # build model
    model = build_model(args, model_cfg, device, num_classes, False)

    # load trained weight
    model = load_weight(model, args.weight, args.fuse_conv_bn)
    model.to(device).eval()

    # compute FLOPs and Params on a deep copy so the model itself is untouched
    model_copy = deepcopy(model)
    model_copy.trainable = False
    model_copy.eval()
    compute_flops(
        model=model_copy,
        img_size=args.img_size,
        device=device)
    del model_copy

    # resave the checkpoint without the optimizer state dict
    if args.resave:
        print('Resave: {}'.format(args.model.upper()))
        checkpoint = torch.load(args.weight, map_location='cpu')
        checkpoint_path = 'weights/{}/{}/{}_pure.pth'.format(args.dataset, args.model, args.model)
        torch.save({'model': model.state_dict(),
                    'mAP': checkpoint.pop("mAP"),
                    'epoch': checkpoint.pop("epoch")},
                   checkpoint_path)

    print("================= DETECT =================")
    # run
    test(args=args,
         model=model,
         device=device,
         dataset=dataset,
         transform=val_transform,
         class_colors=class_colors,
         class_names=dataset_info['class_names'],
         class_indexs=dataset_info['class_indexs'],
         )