# voc.py
import os
import cv2
import torch
import random
import numpy as np
import os.path as osp
import xml.etree.ElementTree as ET
import torch.utils.data as data

try:
    from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
except ImportError:
    from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment

# VOC class names
VOC_CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person',
               'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')

class VOCAnnotationTransform(object):
    """Transforms a VOC annotation into a list of bbox coords and label indexes.
    Initialized with a dictionary lookup of classnames to indexes.

    Arguments:
        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
            (default: alphabetic indexing of VOC's 20 classes)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
    """
    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        """
        Arguments:
            target (annotation): the target annotation to be made usable,
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes [bbox coords, label index]
        """
        res = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for pt in pts:
                # VOC coordinates are 1-indexed, so shift to 0-indexed pixels
                cur_pt = int(bbox.find(pt).text) - 1
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res += [bndbox]  # [x1, y1, x2, y2, label_ind]

        return res  # [[x1, y1, x2, y2, label_ind], ... ]
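
# --- Usage sketch (illustrative addition, not part of the original file) ---
# VOCAnnotationTransform consumes a parsed <annotation> element and emits
# [x1, y1, x2, y2, label_ind] rows. The XML string below is a hand-written
# stand-in for a real VOC annotation file:
def _demo_annotation_transform():
    xml_str = (
        "<annotation><object>"
        "<name>dog</name><difficult>0</difficult>"
        "<bndbox><xmin>48</xmin><ymin>240</ymin>"
        "<xmax>195</xmax><ymax>371</ymax></bndbox>"
        "</object></annotation>"
    )
    anno = ET.fromstring(xml_str)
    # coordinates are shifted from 1-indexed to 0-indexed; 'dog' maps to 11
    print(VOCAnnotationTransform()(anno))  # -> [[47, 239, 194, 370, 11]]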

class VOCDataset(data.Dataset):
    def __init__(self,
                 img_size :int = 640,
                 data_dir :str = None,
                 image_sets = [('2007', 'trainval'), ('2012', 'trainval')],
                 trans_config = None,
                 transform = None,
                 is_train :bool = False,
                 load_cache :bool = False,
                 ):
        # ----------- Basic parameters -----------
        self.img_size = img_size
        self.image_set = image_sets
        self.is_train = is_train
        self.target_transform = VOCAnnotationTransform()
        # ----------- Path parameters -----------
        self.root = data_dir
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        # ----------- Data parameters -----------
        self.ids = list()
        for (year, name) in image_sets:
            rootpath = osp.join(self.root, 'VOC' + year)
            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append((rootpath, line.strip()))
        self.dataset_size = len(self.ids)
        # ----------- Transform parameters -----------
        self.transform = transform
        self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
        self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
        self.trans_config = trans_config
        print('==============================')
        print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
        print('use Mixup Augmentation: {}'.format(self.mixup_prob))
        print('==============================')
        # ----------- Cached data -----------
        self.load_cache = load_cache
        self.cached_datas = None
        if self.load_cache:
            self.cached_datas = self._load_cache()

    # ------------ Basic dataset function ------------
    def __getitem__(self, index):
        image, target, deltas = self.pull_item(index)
        return image, target, deltas

    def __len__(self):
        return self.dataset_size

    def _load_cache(self):
        data_items = []
        for idx in range(self.dataset_size):
            if idx % 2000 == 0:
                print("Caching images and targets : {} / {} ...".format(idx, self.dataset_size))

            # load one sample
            image, target = self.load_image_target(idx)
            orig_h, orig_w, _ = image.shape

            # resize the image so that its longer side equals img_size
            r = self.img_size / max(orig_h, orig_w)
            if r != 1:
                interp = cv2.INTER_LINEAR
                new_size = (int(orig_w * r), int(orig_h * r))
                image = cv2.resize(image, new_size, interpolation=interp)
            img_h, img_w = image.shape[:2]

            # rescale bbox coordinates to the resized image
            boxes = target["boxes"].copy()
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
            target["boxes"] = boxes

            dict_item = {}
            dict_item["image"] = image
            dict_item["target"] = target
            data_items.append(dict_item)

        return data_items
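
    # For example (illustrative numbers): caching a 375x500 (H x W) image with
    # img_size=640 gives r = 640 / 500 = 1.28, so the image is resized to
    # (640, 480) (W x H) and the box coordinates are scaled by the same ratio,
    # preserving the aspect ratio before the per-sample transform runs.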

    # ------------ Mosaic & Mixup ------------
    def load_mosaic(self, index):
        # load a 4x mosaic image: the given index plus three distinct random indexes
        index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
        id1 = index
        id2, id3, id4 = random.sample(index_list, 3)
        indexs = [id1, id2, id3, id4]

        # load images and targets
        image_list = []
        target_list = []
        for index in indexs:
            img_i, target_i = self.load_image_target(index)
            image_list.append(img_i)
            target_list.append(target_i)

        # Mosaic
        if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
            image, target = yolov5_mosaic_augment(
                image_list, target_list, self.img_size, self.trans_config, self.is_train)

        return image, target

    def load_mixup(self, origin_image, origin_target):
        # YOLOv5-style MixUp: blend with a freshly built mosaic image
        if self.trans_config['mixup_type'] == 'yolov5_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_mosaic(new_index)
            image, target = yolov5_mixup_augment(
                origin_image, origin_target, new_image, new_target)
        # YOLOX-style MixUp: blend with a single raw image
        elif self.trans_config['mixup_type'] == 'yolox_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_image_target(new_index)
            image, target = yolox_mixup_augment(
                origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])

        return image, target

    # ------------ Load data function ------------
    def load_image_target(self, index):
        # == Load a sample from the cached data ==
        if self.cached_datas is not None:
            data_item = self.cached_datas[index]
            image = data_item["image"]
            target = data_item["target"]
        # == Load a sample from the local disk ==
        else:
            # load an image
            image, _ = self.pull_image(index)
            height, width, channels = image.shape

            # load an annotation
            anno, _ = self.pull_anno(index)

            # guard against images with no boxes via the (-1, 5) reshape
            anno = np.array(anno).reshape(-1, 5)
            target = {
                "boxes": anno[:, :4],
                "labels": anno[:, 4],
                "orig_size": [height, width]
            }

        return image, target
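
    # For reference (hypothetical values), an image with two objects yields:
    #   target = {"boxes": array([[ 47, 239, 194, 370], [  7,  11, 351, 497]]),
    #             "labels": array([11, 14]), "orig_size": [500, 353]}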

    def pull_item(self, index):
        if random.random() < self.mosaic_prob:
            # load a mosaic image
            mosaic = True
            image, target = self.load_mosaic(index)
        else:
            mosaic = False
            # load an image and target
            image, target = self.load_image_target(index)

        # MixUp
        if random.random() < self.mixup_prob:
            image, target = self.load_mixup(image, target)

        # augment
        image, target, deltas = self.transform(image, target, mosaic)

        return image, target, deltas

    def pull_image(self, index):
        img_id = self.ids[index]
        image = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
        return image, img_id

    def pull_anno(self, index):
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        anno = self.target_transform(anno)
        return anno, img_id
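
# --- Batching sketch (illustrative addition, not part of the original file) ---
# __getitem__ returns (image, target, deltas), where `target` holds a variable
# number of boxes per image, so the default DataLoader collate cannot stack the
# batch. A minimal collate_fn sketch, assuming `transform` yields CHW tensors
# of identical spatial size (the name `voc_collate_fn` is hypothetical):
def voc_collate_fn(batch):
    images = torch.stack([sample[0] for sample in batch], dim=0)  # [B, C, H, W]
    targets = [sample[1] for sample in batch]                     # list of dicts
    deltas = [sample[2] for sample in batch]
    return images, targets, deltas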

if __name__ == "__main__":
    import time
    import argparse
    from build import build_transform

    parser = argparse.ArgumentParser(description='VOC-Dataset')
    # opt
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/VOCdevkit/',
                        help='data root')
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='input image size.')
    parser.add_argument('--aug_type', type=str, default='ssd',
                        help='augmentation type.')
    parser.add_argument('--mosaic', default=0., type=float,
                        help='mosaic augmentation probability.')
    parser.add_argument('--mixup', default=0., type=float,
                        help='mixup augmentation probability.')
    parser.add_argument('--mixup_type', type=str, default='yolov5_mixup',
                        help='mixup augmentation type.')
    parser.add_argument('--is_train', action="store_true", default=False,
                        help='use the training splits and augmentations.')
    parser.add_argument('--load_cache', action="store_true", default=False,
                        help='cache images and targets in memory.')
    args = parser.parse_args()

    trans_config = {
        'aug_type': args.aug_type,   # optional: ssd, yolov5
        # Basic Augment
        'degrees': 0.0,
        'translate': 0.2,
        'scale': [0.1, 2.0],
        'shear': 0.0,
        'perspective': 0.0,
        'hsv_h': 0.015,
        'hsv_s': 0.7,
        'hsv_v': 0.4,
        'use_ablu': True,
        # Mosaic & Mixup
        'mosaic_prob': args.mosaic,
        'mixup_prob': args.mixup,
        'mosaic_type': 'yolov5_mosaic',
        'mixup_type': args.mixup_type,   # optional: yolov5_mixup, yolox_mixup
        'mixup_scale': [0.5, 1.5]
    }

    transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)

    dataset = VOCDataset(
        img_size=args.img_size,
        data_dir=args.root,
        image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if args.is_train else [('2007', 'test')],
        trans_config=trans_config,
        transform=transform,
        is_train=args.is_train,
        load_cache=args.load_cache
        )

    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(20)]
    print('Data length: ', len(dataset))

    for i in range(1000):
        t0 = time.time()
        image, target, deltas = dataset.pull_item(i)
        print("Load data: {} s".format(time.time() - t0))

        # to numpy
        image = image.permute(1, 2, 0).numpy()
        # to uint8
        image = image.astype(np.uint8)
        image = image.copy()
        img_h, img_w = image.shape[:2]

        boxes = target["boxes"]
        labels = target["labels"]
        for box, label in zip(boxes, labels):
            x1, y1, x2, y2 = box
            if x2 - x1 > 1 and y2 - y1 > 1:
                cls_id = int(label)
                color = class_colors[cls_id]
                # class name
                label = VOC_CLASSES[cls_id]
                image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
                # put the text on the bbox
                cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
        cv2.imshow('gt', image)
        # cv2.imwrite(str(i)+'.jpg', img)
        cv2.waitKey(0)