  1. """VOC Dataset Classes
  2. Original author: Francisco Massa
  3. https://github.com/fmassa/vision/blob/voc_dataset/torchvision/datasets/voc.py
  4. Updated by: Ellis Brown, Max deGroot
  5. """
  6. import os.path as osp
  7. import random
  8. import torch.utils.data as data
  9. import cv2
  10. import numpy as np
  11. import xml.etree.ElementTree as ET
  12. try:
  13. from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  14. except:
  15. from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment

VOC_CLASSES = (  # always index 0
    'aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse',
    'motorbike', 'person', 'pottedplant',
    'sheep', 'sofa', 'train', 'tvmonitor')


class VOCAnnotationTransform(object):
    """Transforms a VOC annotation into a list of bbox coords and label indices.

    Initialized with a dictionary lookup of classnames to indexes.

    Arguments:
        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
            (default: alphabetic indexing of VOC's 20 classes)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
    """

    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        """
        Arguments:
            target (annotation): the target annotation to be made usable;
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes [bbox coords, label index]
        """
        res = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')
            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for pt in pts:
                # VOC coordinates are 1-indexed; shift to 0-indexed pixels.
                cur_pt = int(bbox.find(pt).text) - 1
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res += [bndbox]  # [x1, y1, x2, y2, label_ind]
        return res  # [[x1, y1, x2, y2, label_ind], ... ]
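
# Example (illustrative, not executed): for an <object> named 'dog' with a
# <bndbox> of (96, 13, 438, 332), VOCAnnotationTransform()(anno) yields
# [[95, 12, 437, 331, 11]] -- coordinates shifted to 0-indexed pixels and
# 'dog' mapped to index 11 in VOC_CLASSES.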


class VOCDetection(data.Dataset):
    """VOC Detection Dataset Object

    Input is an image; target is an annotation.

    Arguments:
        img_size (int): desired input image size (default: 640)
        data_dir (string): filepath to the VOCdevkit folder
        image_sets (list of tuples): (year, image set) pairs to use,
            e.g. [('2007', 'trainval'), ('2012', 'trainval')]
        trans_config (dict, optional): augmentation configuration
        transform (callable, optional): transformation to perform on the
            input image
        is_train (bool, optional): whether the dataset is used for training
        load_cache (bool, optional): load cached annotations (and, eventually,
            images) into memory
    """

    def __init__(self,
                 img_size=640,
                 data_dir=None,
                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],
                 trans_config=None,
                 transform=None,
                 is_train=False,
                 load_cache=False
                 ):
        self.root = data_dir
        self.img_size = img_size
        self.image_set = image_sets
        self.target_transform = VOCAnnotationTransform()
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = list()
        self.is_train = is_train
        self.load_cache = load_cache
        for (year, name) in image_sets:
            rootpath = osp.join(self.root, 'VOC' + year)
            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append((rootpath, line.strip()))

        # augmentation
        self.transform = transform
        self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
        self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
        self.trans_config = trans_config
        print('==============================')
        print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
        print('use Mixup Augmentation: {}'.format(self.mixup_prob))
        print('==============================')

        # load cache data
        if load_cache:
            self._load_cache()

    def __getitem__(self, index):
        image, target, deltas = self.pull_item(index)
        return image, target, deltas

    def __len__(self):
        return len(self.ids)

    def _load_cache(self):
        # load image cache
        self.image_list = None  # TODO: H5PY file
        # load target cache
        self.target_list = []
        for img_id in self.ids:
            anno = ET.parse(self._annopath % img_id).getroot()
            anno = self.target_transform(anno)
            anno = np.array(anno).reshape(-1, 5)
            self.target_list.append({"boxes": anno[:, :4], "labels": anno[:, 4]})
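
    # NOTE: the image cache above is still a TODO (image_list stays None), so
    # the load_cache branch of load_image_target() below will fail for images
    # until it is populated; only annotations are cached for now.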

    def load_image_target(self, index):
        if self.load_cache:
            image = self.image_list[index]
            target = self.target_list[index]
            height, width, channels = image.shape
            target["orig_size"] = [height, width]
        else:
            # load an image
            img_id = self.ids[index]
            image = cv2.imread(self._imgpath % img_id)
            height, width, channels = image.shape

            # load an annotation
            anno = ET.parse(self._annopath % img_id).getroot()
            if self.target_transform is not None:
                anno = self.target_transform(anno)

            # guard against images with no boxes by reshaping to (0, 5)
            anno = np.array(anno).reshape(-1, 5)
            target = {
                "boxes": anno[:, :4],
                "labels": anno[:, 4],
                "orig_size": [height, width]
            }
        return image, target

    def load_mosaic(self, index):
        # pick the anchor image plus three random, distinct partners
        index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
        id1 = index
        id2, id3, id4 = random.sample(index_list, 3)
        indices = [id1, id2, id3, id4]

        # load images and targets
        image_list = []
        target_list = []
        for idx in indices:
            img_i, target_i = self.load_image_target(idx)
            image_list.append(img_i)
            target_list.append(target_i)

        # Mosaic (only 'yolov5_mosaic' is supported here)
        if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
            image, target = yolov5_mosaic_augment(
                image_list, target_list, self.img_size, self.trans_config, self.is_train)
        else:
            raise ValueError('unknown mosaic type: {}'.format(self.trans_config['mosaic_type']))
        return image, target
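
    # yolov5_mosaic_augment is imported from data_augment/yolov5_augment.py;
    # in the usual YOLOv5 recipe it stitches the four images onto one large
    # canvas around a random center and remaps their boxes accordingly -- see
    # that module for the actual behavior.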

    def load_mixup(self, origin_image, origin_target):
        # YOLOv5-type MixUp: blend with another mosaic image
        if self.trans_config['mixup_type'] == 'yolov5_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_mosaic(new_index)
            image, target = yolov5_mixup_augment(
                origin_image, origin_target, new_image, new_target)
        # YOLOX-type MixUp: blend with a single image
        elif self.trans_config['mixup_type'] == 'yolox_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_image_target(new_index)
            image, target = yolox_mixup_augment(
                origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])
        else:
            raise ValueError('unknown mixup type: {}'.format(self.trans_config['mixup_type']))
        return image, target
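
    # Assumed semantics (see data_augment/yolov5_augment.py for the ground
    # truth): YOLOv5 MixUp blends the two images with a random weight and
    # concatenates their boxes, while YOLOX MixUp additionally jitters the
    # new image by 'mixup_scale' before blending.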

    def pull_item(self, index):
        if random.random() < self.mosaic_prob:
            # load a mosaic image
            mosaic = True
            image, target = self.load_mosaic(index)
        else:
            mosaic = False
            # load an image and target
            image, target = self.load_image_target(index)

        # MixUp
        if random.random() < self.mixup_prob:
            image, target = self.load_mixup(image, target)

        # augment
        image, target, deltas = self.transform(image, target, mosaic)
        return image, target, deltas
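
    # pull_item above is the full sampling pipeline: optional mosaic (with
    # probability mosaic_prob), optional mixup on top of it (with probability
    # mixup_prob), then the user-supplied transform, which returns the final
    # image, target, and deltas.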

    def pull_image(self, index):
        '''Returns the original image at index, as read by OpenCV (BGR ndarray).

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to show
        Return:
            (ndarray img, img_id)
        '''
        img_id = self.ids[index]
        return cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR), img_id

    def pull_anno(self, index):
        '''Returns the original annotation of image at index

        Note: not using self.__getitem__(), as any transformations passed in
        could mess up this functionality.

        Argument:
            index (int): index of img to get annotation of
        Return:
            (img_id, [[x1, y1, x2, y2, label_ind], ...])
            eg: ('001718', [[95, 12, 437, 331, 11]])
        '''
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        gt = self.target_transform(anno)  # __call__ takes only the annotation
        return img_id[1], gt


if __name__ == "__main__":
    import argparse
    from build import build_transform

    parser = argparse.ArgumentParser(description='VOC-Dataset')
    # opt
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/VOCdevkit/',
                        help='data root')
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='input image size.')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation probability.')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation probability.')
    parser.add_argument('--is_train', action="store_true", default=False,
                        help='train mode: apply training-time augmentation.')
    parser.add_argument('--load_cache', action="store_true", default=False,
                        help='load cached data.')
    args = parser.parse_args()

    trans_config = {
        'aug_type': 'yolov5',   # options: ssd, yolov5
        # Basic Augment
        'degrees': 0.0,
        'translate': 0.2,
        'scale': [0.5, 2.0],
        'shear': 0.0,
        'perspective': 0.0,
        'hsv_h': 0.015,
        'hsv_s': 0.7,
        'hsv_v': 0.4,
        # Mosaic & Mixup
        'mosaic_prob': 1.0,
        'mixup_prob': 1.0,
        'mosaic_type': 'yolov5_mosaic',
        'mixup_type': 'yolov5_mixup',
        'mixup_scale': [0.5, 1.5]
    }

    transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)

    dataset = VOCDetection(
        img_size=args.img_size,
        data_dir=args.root,
        trans_config=trans_config,
        transform=transform,
        is_train=args.is_train
        )

    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(20)]
    print('Data length: ', len(dataset))

    for i in range(1000):
        image, target, deltas = dataset.pull_item(i)
        # CHW tensor -> HWC numpy array
        image = image.permute(1, 2, 0).numpy()
        # to uint8
        image = image.astype(np.uint8)
        image = image.copy()
        img_h, img_w = image.shape[:2]

        boxes = target["boxes"]
        labels = target["labels"]
        for box, label in zip(boxes, labels):
            x1, y1, x2, y2 = box
            cls_id = int(label)
            color = class_colors[cls_id]
            # class name
            label = VOC_CLASSES[cls_id]
            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            # put the text on the bbox
            cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
        cv2.imshow('gt', image)
        # cv2.imwrite(str(i) + '.jpg', image)
        cv2.waitKey(0)
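
# Quick sanity-check run (the default --root path is machine-specific; point
# it at your own VOCdevkit):
#   python voc.py --root /path/to/VOCdevkit --img_size 640 --is_train
# Each window shows one augmented sample; press any key to advance.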