test.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import argparse
  2. import cv2
  3. import os
  4. import time
  5. import numpy as np
  6. from copy import deepcopy
  7. import torch
  8. # load transform
  9. from dataset.build import build_dataset, build_transform
  10. # load some utils
  11. from utils.misc import load_weight, compute_flops
  12. from utils.box_ops import rescale_bboxes
  13. from utils.vis_tools import visualize
  14. from config import build_dataset_config, build_model_config, build_trans_config
  15. from models.detectors import build_model
  16. def parse_args():
  17. parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
  18. # Basic setting
  19. parser.add_argument('-size', '--img_size', default=640, type=int,
  20. help='the max size of input image')
  21. parser.add_argument('--show', action='store_true', default=False,
  22. help='show the visulization results.')
  23. parser.add_argument('--save', action='store_true', default=False,
  24. help='save the visulization results.')
  25. parser.add_argument('--cuda', action='store_true', default=False,
  26. help='use cuda.')
  27. parser.add_argument('--save_folder', default='det_results/', type=str,
  28. help='Dir to save results')
  29. parser.add_argument('-ws', '--window_scale', default=1.0, type=float,
  30. help='resize window of cv2 for visualization.')
  31. parser.add_argument('--resave', action='store_true', default=False,
  32. help='resave checkpoints without optimizer state dict.')
  33. # Model setting
  34. parser.add_argument('-m', '--model', default='yolov1', type=str,
  35. help='build yolo')
  36. parser.add_argument('--weight', default=None,
  37. type=str, help='Trained state_dict file path to open')
  38. parser.add_argument('-ct', '--conf_thresh', default=0.3, type=float,
  39. help='confidence threshold')
  40. parser.add_argument('-nt', '--nms_thresh', default=0.5, type=float,
  41. help='NMS threshold')
  42. parser.add_argument('--topk', default=100, type=int,
  43. help='topk candidates dets of each level before NMS')
  44. parser.add_argument("--no_decode", action="store_true", default=False,
  45. help="not decode in inference or yes")
  46. parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
  47. help='fuse Conv & BN')
  48. parser.add_argument('--no_multi_labels', action='store_true', default=False,
  49. help='Perform post-process with multi-labels trick.')
  50. parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
  51. help='Perform NMS operations regardless of category.')
  52. # Data setting
  53. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
  54. help='data root')
  55. parser.add_argument('-d', '--dataset', default='coco',
  56. help='coco, voc.')
  57. parser.add_argument('--min_box_size', default=8.0, type=float,
  58. help='min size of target bounding box.')
  59. parser.add_argument('--mosaic', default=None, type=float,
  60. help='mosaic augmentation.')
  61. parser.add_argument('--mixup', default=None, type=float,
  62. help='mixup augmentation.')
  63. parser.add_argument('--load_cache', action='store_true', default=False,
  64. help='load data into memory.')
  65. # Task setting
  66. parser.add_argument('-t', '--task', default='det', choices=['det', 'det_seg', 'det_pos', 'det_seg_pos'],
  67. help='task type.')
  68. return parser.parse_args()
  69. @torch.no_grad()
  70. def test_det(args,
  71. model,
  72. device,
  73. dataset,
  74. transform=None,
  75. class_colors=None,
  76. class_names=None,
  77. class_indexs=None):
  78. num_images = len(dataset)
  79. save_path = os.path.join('det_results/', args.dataset, args.model)
  80. os.makedirs(save_path, exist_ok=True)
  81. for index in range(num_images):
  82. print('Testing image {:d}/{:d}....'.format(index+1, num_images))
  83. image, _ = dataset.pull_image(index)
  84. orig_h, orig_w, _ = image.shape
  85. # prepare
  86. x, _, ratio = transform(image)
  87. x = x.unsqueeze(0).to(device) / 255.
  88. t0 = time.time()
  89. # inference
  90. outputs = model(x)
  91. scores = outputs['scores']
  92. labels = outputs['labels']
  93. bboxes = outputs['bboxes']
  94. print("detection time used ", time.time() - t0, "s")
  95. # rescale bboxes
  96. bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)
  97. # vis detection
  98. img_processed = visualize(image=image,
  99. bboxes=bboxes,
  100. scores=scores,
  101. labels=labels,
  102. class_colors=class_colors,
  103. class_names=class_names,
  104. class_indexs=class_indexs)
  105. if args.show:
  106. h, w = img_processed.shape[:2]
  107. sw, sh = int(w*args.window_scale), int(h*args.window_scale)
  108. cv2.namedWindow('detection', 0)
  109. cv2.resizeWindow('detection', sw, sh)
  110. cv2.imshow('detection', img_processed)
  111. cv2.waitKey(0)
  112. if args.save:
  113. # save result
  114. cv2.imwrite(os.path.join(save_path, str(index).zfill(6) +'.jpg'), img_processed)
  115. @torch.no_grad()
  116. def test_det_seg():
  117. pass
  118. @torch.no_grad()
  119. def test_det_pos():
  120. pass
  121. @torch.no_grad()
  122. def test_det_seg_pos():
  123. pass
  124. if __name__ == '__main__':
  125. args = parse_args()
  126. # cuda
  127. if args.cuda:
  128. print('use cuda')
  129. device = torch.device("cuda")
  130. else:
  131. device = torch.device("cpu")
  132. # Dataset & Model Config
  133. data_cfg = build_dataset_config(args)
  134. model_cfg = build_model_config(args)
  135. trans_cfg = build_trans_config(model_cfg['trans_type'])
  136. # Transform
  137. val_transform, trans_cfg = build_transform(args, trans_cfg, model_cfg['max_stride'], is_train=False)
  138. # Dataset
  139. dataset, dataset_info = build_dataset(args, data_cfg, trans_cfg, val_transform, is_train=False)
  140. num_classes = dataset_info['num_classes']
  141. np.random.seed(0)
  142. class_colors = [(np.random.randint(255),
  143. np.random.randint(255),
  144. np.random.randint(255)) for _ in range(num_classes)]
  145. # build model
  146. model = build_model(args, model_cfg, device, num_classes, False)
  147. # load trained weight
  148. model = load_weight(model, args.weight, args.fuse_conv_bn)
  149. model.to(device).eval()
  150. # compute FLOPs and Params
  151. model_copy = deepcopy(model)
  152. model_copy.trainable = False
  153. model_copy.eval()
  154. compute_flops(
  155. model=model_copy,
  156. img_size=args.img_size,
  157. device=device)
  158. del model_copy
  159. # resave model weight
  160. if args.resave:
  161. print('Resave: {}'.format(args.model.upper()))
  162. checkpoint = torch.load(args.weight, map_location='cpu')
  163. checkpoint_path = 'weights/{}/{}/{}_pure.pth'.format(args.dataset, args.model, args.model)
  164. torch.save({'model': model.state_dict(),
  165. 'mAP': checkpoint.pop("mAP"),
  166. 'epoch': checkpoint.pop("epoch")},
  167. checkpoint_path)
  168. print("================= DETECT =================")
  169. # run
  170. if args.task == "det":
  171. test_det(args=args,
  172. model=model,
  173. device=device,
  174. dataset=dataset,
  175. transform=val_transform,
  176. class_colors=class_colors,
  177. class_names=dataset_info['class_names'],
  178. class_indexs=dataset_info['class_indexs'],
  179. )
  180. elif args.task == "det_seg":
  181. test_det_seg()
  182. elif args.task == "det_pos":
  183. test_det_pos()
  184. elif args.task == "det_seg_pos":
  185. test_det_seg_pos()