voc.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. import cv2
  2. import random
  3. import numpy as np
  4. import os.path as osp
  5. import xml.etree.ElementTree as ET
  6. import torch
  7. import torch.utils.data as data
  8. try:
  9. from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  10. except:
  11. from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  12. # VOC class names
  13. VOC_CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
  14. class VOCAnnotationTransform(object):
  15. """Transforms a VOC annotation into a Tensor of bbox coords and label index
  16. Initilized with a dictionary lookup of classnames to indexes
  17. Arguments:
  18. class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
  19. (default: alphabetic indexing of VOC's 20 classes)
  20. keep_difficult (bool, optional): keep difficult instances or not
  21. (default: False)
  22. height (int): height
  23. width (int): width
  24. """
  25. def __init__(self, class_to_ind=None, keep_difficult=False):
  26. self.class_to_ind = class_to_ind or dict(
  27. zip(VOC_CLASSES, range(len(VOC_CLASSES))))
  28. self.keep_difficult = keep_difficult
  29. def __call__(self, target):
  30. """
  31. Arguments:
  32. target (annotation) : the target annotation to be made usable
  33. will be an ET.Element
  34. Returns:
  35. a list containing lists of bounding boxes [bbox coords, class name]
  36. """
  37. res = []
  38. for obj in target.iter('object'):
  39. difficult = int(obj.find('difficult').text) == 1
  40. if not self.keep_difficult and difficult:
  41. continue
  42. name = obj.find('name').text.lower().strip()
  43. bbox = obj.find('bndbox')
  44. pts = ['xmin', 'ymin', 'xmax', 'ymax']
  45. bndbox = []
  46. for i, pt in enumerate(pts):
  47. cur_pt = int(bbox.find(pt).text) - 1
  48. # scale height or width
  49. cur_pt = cur_pt if i % 2 == 0 else cur_pt
  50. bndbox.append(cur_pt)
  51. label_idx = self.class_to_ind[name]
  52. bndbox.append(label_idx)
  53. res += [bndbox] # [x1, y1, x2, y2, label_ind]
  54. return res # [[x1, y1, x2, y2, label_ind], ... ]
  55. class VOCDataset(data.Dataset):
  56. def __init__(self,
  57. img_size :int = 640,
  58. data_dir :str = None,
  59. image_sets = [('2007', 'trainval'), ('2012', 'trainval')],
  60. trans_config = None,
  61. transform = None,
  62. is_train :bool = False,
  63. load_cache :bool = False,
  64. ):
  65. # ----------- Basic parameters -----------
  66. self.img_size = img_size
  67. self.image_set = image_sets
  68. self.is_train = is_train
  69. self.target_transform = VOCAnnotationTransform()
  70. # ----------- Path parameters -----------
  71. self.root = data_dir
  72. self._annopath = osp.join('%s', 'Annotations', '%s.xml')
  73. self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
  74. # ----------- Data parameters -----------
  75. self.ids = list()
  76. for (year, name) in image_sets:
  77. rootpath = osp.join(self.root, 'VOC' + year)
  78. for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
  79. self.ids.append((rootpath, line.strip()))
  80. self.dataset_size = len(self.ids)
  81. # ----------- Transform parameters -----------
  82. self.transform = transform
  83. self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
  84. self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
  85. self.trans_config = trans_config
  86. print('==============================')
  87. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  88. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  89. print('==============================')
  90. # ----------- Cached data -----------
  91. self.load_cache = load_cache
  92. self.cached_datas = None
  93. if self.load_cache:
  94. self.cached_datas = self._load_cache()
  95. # ------------ Basic dataset function ------------
  96. def __getitem__(self, index):
  97. image, target, deltas = self.pull_item(index)
  98. return image, target, deltas
  99. def __len__(self):
  100. return self.dataset_size
  101. def _load_cache(self):
  102. data_items = []
  103. for idx in range(self.dataset_size):
  104. if idx % 2000 == 0:
  105. print("Caching images and targets : {} / {} ...".format(idx, self.dataset_size))
  106. # load a data
  107. image, target = self.load_image_target(idx)
  108. orig_h, orig_w, _ = image.shape
  109. # resize image
  110. r = self.img_size / max(orig_h, orig_w)
  111. if r != 1:
  112. interp = cv2.INTER_LINEAR
  113. new_size = (int(orig_w * r), int(orig_h * r))
  114. image = cv2.resize(image, new_size, interpolation=interp)
  115. img_h, img_w = image.shape[:2]
  116. # rescale bbox
  117. boxes = target["boxes"].copy()
  118. boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
  119. boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
  120. target["boxes"] = boxes
  121. dict_item = {}
  122. dict_item["image"] = image
  123. dict_item["target"] = target
  124. data_items.append(dict_item)
  125. return data_items
  126. # ------------ Mosaic & Mixup ------------
  127. def load_mosaic(self, index):
  128. # load 4x mosaic image
  129. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  130. id1 = index
  131. id2, id3, id4 = random.sample(index_list, 3)
  132. indexs = [id1, id2, id3, id4]
  133. # load images and targets
  134. image_list = []
  135. target_list = []
  136. for index in indexs:
  137. img_i, target_i = self.load_image_target(index)
  138. image_list.append(img_i)
  139. target_list.append(target_i)
  140. # Mosaic
  141. if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
  142. image, target = yolov5_mosaic_augment(
  143. image_list, target_list, self.img_size, self.trans_config, self.trans_config['mosaic_keep_ratio'], self.is_train)
  144. return image, target
  145. def load_mixup(self, origin_image, origin_target):
  146. # YOLOv5 type Mixup
  147. if self.trans_config['mixup_type'] == 'yolov5_mixup':
  148. new_index = np.random.randint(0, len(self.ids))
  149. new_image, new_target = self.load_mosaic(new_index)
  150. image, target = yolov5_mixup_augment(
  151. origin_image, origin_target, new_image, new_target)
  152. # YOLOX type Mixup
  153. elif self.trans_config['mixup_type'] == 'yolox_mixup':
  154. new_index = np.random.randint(0, len(self.ids))
  155. new_image, new_target = self.load_image_target(new_index)
  156. image, target = yolox_mixup_augment(
  157. origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])
  158. return image, target
  159. # ------------ Load data function ------------
  160. def load_image_target(self, index):
  161. # == Load a data from the cached data ==
  162. if self.cached_datas is not None:
  163. # load a data
  164. data_item = self.cached_datas[index]
  165. image = data_item["image"]
  166. target = data_item["target"]
  167. # == Load a data from the local disk ==
  168. else:
  169. # load an image
  170. image, _ = self.pull_image(index)
  171. height, width, channels = image.shape
  172. # laod an annotation
  173. anno, _ = self.pull_anno(index)
  174. # guard against no boxes via resizing
  175. anno = np.array(anno).reshape(-1, 5)
  176. target = {
  177. "boxes": anno[:, :4],
  178. "labels": anno[:, 4],
  179. "orig_size": [height, width]
  180. }
  181. return image, target
  182. def pull_item(self, index):
  183. if random.random() < self.mosaic_prob:
  184. # load a mosaic image
  185. mosaic = True
  186. image, target = self.load_mosaic(index)
  187. else:
  188. mosaic = False
  189. # load an image and target
  190. image, target = self.load_image_target(index)
  191. # MixUp
  192. if random.random() < self.mixup_prob:
  193. image, target = self.load_mixup(image, target)
  194. # augment
  195. image, target, deltas = self.transform(image, target, mosaic)
  196. return image, target, deltas
  197. def pull_image(self, index):
  198. img_id = self.ids[index]
  199. image = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
  200. return image, img_id
  201. def pull_anno(self, index):
  202. img_id = self.ids[index]
  203. anno = ET.parse(self._annopath % img_id).getroot()
  204. anno = self.target_transform(anno)
  205. return anno, img_id
  206. if __name__ == "__main__":
  207. import time
  208. import argparse
  209. from build import build_transform
  210. parser = argparse.ArgumentParser(description='VOC-Dataset')
  211. # opt
  212. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/VOCdevkit/',
  213. help='data root')
  214. parser.add_argument('-size', '--img_size', default=640, type=int,
  215. help='input image size.')
  216. parser.add_argument('--aug_type', type=str, default='ssd',
  217. help='augmentation type: ssd, yolov5, rtdetr.')
  218. parser.add_argument('--mosaic', default=0., type=float,
  219. help='mosaic augmentation.')
  220. parser.add_argument('--mixup', default=0., type=float,
  221. help='mixup augmentation.')
  222. parser.add_argument('--mixup_type', type=str, default='yolov5_mixup',
  223. help='mixup augmentation.')
  224. parser.add_argument('--is_train', action="store_true", default=False,
  225. help='mixup augmentation.')
  226. parser.add_argument('--load_cache', action="store_true", default=False,
  227. help='Path to the cached data.')
  228. args = parser.parse_args()
  229. trans_config = {
  230. 'aug_type': args.aug_type, # optional: ssd, yolov5
  231. 'pixel_mean': [123.675, 116.28, 103.53],
  232. 'pixel_std': [58.395, 57.12, 57.375],
  233. # Basic Augment
  234. 'degrees': 0.0,
  235. 'translate': 0.2,
  236. 'scale': [0.1, 2.0],
  237. 'shear': 0.0,
  238. 'perspective': 0.0,
  239. 'hsv_h': 0.015,
  240. 'hsv_s': 0.7,
  241. 'hsv_v': 0.4,
  242. 'use_ablu': True,
  243. # Mosaic & Mixup
  244. 'mosaic_prob': args.mosaic,
  245. 'mixup_prob': args.mixup,
  246. 'mosaic_type': 'yolov5_mosaic',
  247. 'mixup_type': args.mixup_type, # optional: yolov5_mixup, yolox_mixup
  248. 'mosaic_keep_ratio': False,
  249. 'mixup_scale': [0.5, 1.5]
  250. }
  251. transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
  252. pixel_mean = transform.pixel_mean
  253. pixel_std = transform.pixel_std
  254. color_format = transform.color_format
  255. dataset = VOCDataset(
  256. img_size=args.img_size,
  257. data_dir=args.root,
  258. image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if args.is_train else [('2007', 'test')],
  259. trans_config=trans_config,
  260. transform=transform,
  261. is_train=args.is_train,
  262. load_cache=args.load_cache
  263. )
  264. np.random.seed(0)
  265. class_colors = [(np.random.randint(255),
  266. np.random.randint(255),
  267. np.random.randint(255)) for _ in range(20)]
  268. print('Data length: ', len(dataset))
  269. for i in range(1000):
  270. t0 = time.time()
  271. image, target, deltas = dataset.pull_item(i)
  272. print("Load data: {} s".format(time.time() - t0))
  273. # to numpy
  274. image = image.permute(1, 2, 0).numpy()
  275. # denormalize
  276. image = image * pixel_std + pixel_mean
  277. if color_format == 'rgb':
  278. # RGB to BGR
  279. image = image[..., (2, 1, 0)]
  280. # to uint8
  281. image = image.astype(np.uint8)
  282. image = image.copy()
  283. img_h, img_w = image.shape[:2]
  284. boxes = target["boxes"]
  285. labels = target["labels"]
  286. for box, label in zip(boxes, labels):
  287. x1, y1, x2, y2 = box
  288. if x2 - x1 > 1 and y2 - y1 > 1:
  289. cls_id = int(label)
  290. color = class_colors[cls_id]
  291. # class name
  292. label = VOC_CLASSES[cls_id]
  293. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  294. # put the test on the bbox
  295. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  296. cv2.imshow('gt', image)
  297. # cv2.imwrite(str(i)+'.jpg', img)
  298. cv2.waitKey(0)