  1. """VOC Dataset Classes
  2. Original author: Francisco Massa
  3. https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
  4. Updated by: Ellis Brown, Max deGroot
  5. """
  6. import os.path as osp
  7. import random
  8. import torch.utils.data as data
  9. import cv2
  10. import numpy as np
  11. import xml.etree.ElementTree as ET
  12. try:
  13. from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  14. except:
  15. from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  16. VOC_CLASSES = ( # always index 0
  17. 'aeroplane', 'bicycle', 'bird', 'boat',
  18. 'bottle', 'bus', 'car', 'cat', 'chair',
  19. 'cow', 'diningtable', 'dog', 'horse',
  20. 'motorbike', 'person', 'pottedplant',
  21. 'sheep', 'sofa', 'train', 'tvmonitor')


class VOCAnnotationTransform(object):
    """Transforms a VOC annotation into a list of bbox coords and label indices,
    initialized with a dictionary lookup of class names to indices.

    Arguments:
        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
            (default: alphabetic indexing of VOC's 20 classes)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
    """

    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        """
        Arguments:
            target (annotation): the target annotation to be made usable,
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes [bbox coords, class index]
        """
        res = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for pt in pts:
                # VOC coordinates are 1-based; shift them to 0-based pixels
                cur_pt = int(bbox.find(pt).text) - 1
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res += [bndbox]  # [x1, y1, x2, y2, label_ind]

        return res  # [[x1, y1, x2, y2, label_ind], ... ]
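

# A minimal usage sketch of VOCAnnotationTransform (the XML below is
# illustrative, not taken from a real VOC file):
#
#   xml_str = (
#       "<annotation><object>"
#       "<name>dog</name><difficult>0</difficult>"
#       "<bndbox><xmin>96</xmin><ymin>13</ymin><xmax>438</xmax><ymax>332</ymax></bndbox>"
#       "</object></annotation>")
#   print(VOCAnnotationTransform()(ET.fromstring(xml_str)))
#   # -> [[95, 12, 437, 331, 11]]  ('dog' is class index 11; coords shifted to 0-based)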


class VOCDetection(data.Dataset):
    """VOC Detection Dataset Object

    input is image, target is annotation

    Arguments:
        img_size (int): desired input image size (default: 640)
        data_dir (string): filepath to the VOCdevkit folder
        image_sets (list of tuples): (year, imageset) pairs to load,
            e.g. [('2007', 'trainval'), ('2012', 'trainval')]
        trans_config (dict, optional): configuration for mosaic/mixup augmentation
        transform (callable, optional): transformation to perform on the
            input image and target
        is_train (bool): whether the dataset is used for training
        load_cache (bool): pre-load and pre-resize all images into memory
    """

    def __init__(self,
                 img_size=640,
                 data_dir=None,
                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
                 trans_config=None,
                 transform=None,
                 is_train=False,
                 load_cache=False
                 ):
        self.root = data_dir
        self.img_size = img_size
        self.image_set = image_sets
        self.target_transform = VOCAnnotationTransform()
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = list()
        self.is_train = is_train
        self.load_cache = load_cache
        for (year, name) in image_sets:
            rootpath = osp.join(self.root, 'VOC' + year)
            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append((rootpath, line.strip()))

        # augmentation
        self.transform = transform
        self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
        self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
        self.trans_config = trans_config
        print('==============================')
        print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
        print('use Mixup Augmentation: {}'.format(self.mixup_prob))
        print('==============================')

        # load cache data
        if load_cache:
            self._load_cache()
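
    # Expected on-disk layout (standard VOCdevkit, shown for reference):
    #   <data_dir>/
    #       VOC2007/
    #           Annotations/<image_id>.xml
    #           ImageSets/Main/{train,val,trainval,test}.txt
    #           JPEGImages/<image_id>.jpg
    #       VOC2012/  (same structure)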

    def __getitem__(self, index):
        image, target, deltas = self.pull_item(index)
        return image, target, deltas

    def __len__(self):
        return len(self.ids)

    def _load_cache(self):
        # load image cache
        self.cached_images = []
        self.cached_targets = []
        dataset_size = len(self.ids)
        print('loading data into memory ...')
        for i in range(dataset_size):
            if i % 5000 == 0:
                print("[{} / {}]".format(i, dataset_size))
            # load an image
            image, image_id = self.pull_image(i)
            orig_h, orig_w, _ = image.shape

            # resize image, keeping the aspect ratio
            r = self.img_size / max(orig_h, orig_w)
            if r != 1:
                interp = cv2.INTER_LINEAR
                new_size = (int(orig_w * r), int(orig_h * r))
                image = cv2.resize(image, new_size, interpolation=interp)
            img_h, img_w = image.shape[:2]
            self.cached_images.append(image)

            # load target cache
            anno = ET.parse(self._annopath % image_id).getroot()
            anno = self.target_transform(anno)
            anno = np.array(anno).reshape(-1, 5)
            boxes = anno[:, :4].astype(np.float32)  # float, so the rescaling below is not truncated
            labels = anno[:, 4]
            # rescale boxes to the resized image
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
            self.cached_targets.append({"boxes": boxes, "labels": labels})

    def load_image_target(self, index):
        if self.load_cache:
            image = self.cached_images[index]
            target = self.cached_targets[index]
            height, width, channels = image.shape
            target["orig_size"] = [height, width]
        else:
            # load an image
            img_id = self.ids[index]
            image = cv2.imread(self._imgpath % img_id)
            height, width, channels = image.shape

            # load an annotation
            anno = ET.parse(self._annopath % img_id).getroot()
            if self.target_transform is not None:
                anno = self.target_transform(anno)

            # guard against no boxes via resizing
            anno = np.array(anno).reshape(-1, 5)
            target = {
                "boxes": anno[:, :4],
                "labels": anno[:, 4],
                "orig_size": [height, width]
            }

        return image, target

    def load_mosaic(self, index):
        # load a 4x mosaic image: the given index plus 3 other random indices
        index_list = np.arange(index).tolist() + np.arange(index + 1, len(self.ids)).tolist()
        id1 = index
        id2, id3, id4 = random.sample(index_list, 3)
        indices = [id1, id2, id3, id4]

        # load images and targets
        image_list = []
        target_list = []
        for idx in indices:
            img_i, target_i = self.load_image_target(idx)
            image_list.append(img_i)
            target_list.append(target_i)

        # Mosaic
        if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
            image, target = yolov5_mosaic_augment(
                image_list, target_list, self.img_size, self.trans_config, self.is_train)

        return image, target

    def load_mixup(self, origin_image, origin_target):
        # YOLOv5-style MixUp: blend with another mosaic image
        if self.trans_config['mixup_type'] == 'yolov5_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_mosaic(new_index)
            image, target = yolov5_mixup_augment(
                origin_image, origin_target, new_image, new_target)
        # YOLOX-style MixUp: blend with a single (non-mosaic) image
        elif self.trans_config['mixup_type'] == 'yolox_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_image_target(new_index)
            image, target = yolox_mixup_augment(
                origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])

        return image, target

    def pull_item(self, index):
        if random.random() < self.mosaic_prob:
            # load a mosaic image
            mosaic = True
            image, target = self.load_mosaic(index)
        else:
            mosaic = False
            # load an image and target
            image, target = self.load_image_target(index)

        # MixUp
        if random.random() < self.mixup_prob:
            image, target = self.load_mixup(image, target)

        # augment
        image, target, deltas = self.transform(image, target, mosaic)

        return image, target, deltas

    def pull_image(self, index):
        '''Returns the original image object at index in OpenCV (BGR ndarray) form

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            tuple: (BGR ndarray, img_id)
        '''
        img_id = self.ids[index]
        return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR), img_id

    def pull_anno(self, index):
        '''Returns the original annotation of image at index

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to get annotation of
        Return:
            tuple: (img_id, [[x1, y1, x2, y2, label_ind], ...])
        '''
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        gt = self.target_transform(anno)  # __call__ takes only the annotation element
        return img_id[1], gt
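

# Minimal usage sketch (the data path is illustrative):
#   dataset = VOCDetection(img_size=640, data_dir='path/to/VOCdevkit')
#   image, target = dataset.load_image_target(0)
#   # -> raw BGR image and {"boxes", "labels", "orig_size"}
# The `__main__` demo below exercises the full pipeline: mosaic, mixup,
# and the yolov5-style transform built by `build_transform`.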


if __name__ == "__main__":
    import argparse
    from build import build_transform

    parser = argparse.ArgumentParser(description='VOC-Dataset')
    # opt
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/VOCdevkit/',
                        help='data root')
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='input image size.')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation.')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation.')
    parser.add_argument('--is_train', action="store_true", default=False,
                        help='train mode.')
    parser.add_argument('--load_cache', action="store_true", default=False,
                        help='load cached data.')
    args = parser.parse_args()

    trans_config = {
        'aug_type': 'yolov5',  # optional: ssd, yolov5
        # Basic Augment
        'degrees': 0.0,
        'translate': 0.2,
        'scale': [0.1, 2.0],
        'shear': 0.0,
        'perspective': 0.0,
        'hsv_h': 0.015,
        'hsv_s': 0.7,
        'hsv_v': 0.4,
        'use_ablu': True,
        # Mosaic & Mixup
        'mosaic_prob': 1.0,
        'mixup_prob': 1.0,
        'mosaic_type': 'yolov5_mosaic',
        'mixup_type': 'yolov5_mixup',
        'mixup_scale': [0.5, 1.5]
    }

    transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)

    dataset = VOCDetection(
        img_size=args.img_size,
        data_dir=args.root,
        trans_config=trans_config,
        transform=transform,
        is_train=args.is_train,
        load_cache=args.load_cache
        )

    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(20)]
    print('Data length: ', len(dataset))

    for i in range(min(1000, len(dataset))):  # guard against datasets shorter than 1000
        image, target, deltas = dataset.pull_item(i)
        # to numpy
        image = image.permute(1, 2, 0).numpy()
        # to uint8
        image = image.astype(np.uint8)
        image = image.copy()
        img_h, img_w = image.shape[:2]

        boxes = target["boxes"]
        labels = target["labels"]
        for box, label in zip(boxes, labels):
            x1, y1, x2, y2 = box
            if x2 - x1 > 1 and y2 - y1 > 1:
                cls_id = int(label)
                color = class_colors[cls_id]
                # class name
                label = VOC_CLASSES[cls_id]
                image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
                # put the class name on the bbox
                cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
        cv2.imshow('gt', image)
        # cv2.imwrite(str(i)+'.jpg', img)
        cv2.waitKey(0)