import random
import os.path as osp
import xml.etree.ElementTree as ET

import cv2
import numpy as np
import torch
import torch.utils.data as data

try:
    from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
except ImportError:
    from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment

# VOC class names (alphabetical order of the 20 PASCAL VOC categories)
VOC_CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person',
               'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')

class VOCAnnotationTransform(object):
    """Transforms a VOC annotation into a list of bbox coords and label indices.

    Initialized with a dictionary lookup of class names to indices.

    Arguments:
        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
            (default: alphabetic indexing of VOC's 20 classes)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
    """
    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        """
        Arguments:
            target (annotation): the target annotation to be made usable,
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes [bbox coords, label index]
        """
        res = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for pt in pts:
                # VOC coordinates are 1-based; shift to 0-based pixel indices
                cur_pt = int(bbox.find(pt).text) - 1
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res += [bndbox]  # [x1, y1, x2, y2, label_ind]

        return res  # [[x1, y1, x2, y2, label_ind], ...]
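
# A minimal usage sketch for the transform above (the annotation path is
# illustrative, not part of this repo):
#   root = ET.parse('VOCdevkit/VOC2007/Annotations/000001.xml').getroot()
#   boxes = VOCAnnotationTransform()(root)  # -> [[x1, y1, x2, y2, label_ind], ...]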

class VOCDataset(data.Dataset):
    def __init__(self,
                 img_size     :int = 640,
                 data_dir     :str = None,
                 image_sets   = [('2007', 'trainval'), ('2012', 'trainval')],
                 trans_config = None,
                 transform    = None,
                 is_train     :bool = False,
                 load_cache   :str = None,
                 ):
        # ----------- Basic parameters -----------
        self.img_size = img_size
        self.image_set = image_sets
        self.is_train = is_train
        self.load_cache = load_cache
        self.target_transform = VOCAnnotationTransform()
        # ----------- Path parameters -----------
        self.root = data_dir
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        # ----------- Data parameters -----------
        self.ids = list()
        for (year, name) in image_sets:
            rootpath = osp.join(self.root, 'VOC' + year)
            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append((rootpath, line.strip()))
        self.dataset_size = len(self.ids)
        # ----------- Transform parameters -----------
        self.transform = transform
        self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
        self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
        self.trans_config = trans_config
        print('==============================')
        print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
        print('use Mixup Augmentation: {}'.format(self.mixup_prob))
        print('==============================')
        # load cached data
        if load_cache is not None:
            self._load_cache()

    # ------------ Basic dataset functions ------------
    def __getitem__(self, index):
        image, target, deltas = self.pull_item(index)
        return image, target, deltas

    def __len__(self):
        return self.dataset_size

    def _load_cache(self):
        # load the cached dataset: a torch-saved list of {"image", "target"} items
        try:
            print("Loading cached data ...")
            self.cached_datas = torch.load(self.load_cache)
            self.dataset_size = len(self.cached_datas)
            print("Loading done !")
        except FileNotFoundError:
            print("{} does not exist.".format(self.load_cache))
            self.load_cache = None

    # ------------ Mosaic & Mixup ------------
    def load_mosaic(self, index):
        # sample three extra images (excluding the anchor index) for a 4x mosaic
        index_list = np.arange(index).tolist() + np.arange(index + 1, len(self.ids)).tolist()
        id1 = index
        id2, id3, id4 = random.sample(index_list, 3)
        indices = [id1, id2, id3, id4]

        # load images and targets
        image_list = []
        target_list = []
        for idx in indices:
            img_i, target_i = self.load_image_target(idx)
            image_list.append(img_i)
            target_list.append(target_i)

        # Mosaic
        if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
            image, target = yolov5_mosaic_augment(
                image_list, target_list, self.img_size, self.trans_config, self.is_train)
        else:
            raise ValueError('Unknown mosaic type: {}'.format(self.trans_config['mosaic_type']))

        return image, target

    def load_mixup(self, origin_image, origin_target):
        # YOLOv5-style Mixup: blend with a freshly built mosaic image
        if self.trans_config['mixup_type'] == 'yolov5_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_mosaic(new_index)
            image, target = yolov5_mixup_augment(
                origin_image, origin_target, new_image, new_target)
        # YOLOX-style Mixup: blend with a single plain image
        elif self.trans_config['mixup_type'] == 'yolox_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_image_target(new_index)
            image, target = yolox_mixup_augment(
                origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])
        else:
            raise ValueError('Unknown mixup type: {}'.format(self.trans_config['mixup_type']))

        return image, target

    # ------------ Data loading functions ------------
    def load_image_target(self, index):
        # == Load a sample from the cached data ==
        if self.load_cache and self.is_train:
            data_item = self.cached_datas[index]
            image = data_item["image"]
            target = data_item["target"]
        # == Load a sample from the local disk ==
        else:
            # load an image
            image, _ = self.pull_image(index)
            height, width, channels = image.shape
            # load an annotation
            anno, _ = self.pull_anno(index)

            # guard against images with no boxes via reshaping
            anno = np.array(anno).reshape(-1, 5)
            target = {
                "boxes": anno[:, :4],
                "labels": anno[:, 4],
                "orig_size": [height, width]
            }

        return image, target
    def pull_item(self, index):
        if random.random() < self.mosaic_prob:
            # load a mosaic image
            mosaic = True
            image, target = self.load_mosaic(index)
        else:
            mosaic = False
            # load an image and target
            image, target = self.load_image_target(index)

        # MixUp
        if random.random() < self.mixup_prob:
            image, target = self.load_mixup(image, target)

        # augment
        image, target, deltas = self.transform(image, target, mosaic)

        return image, target, deltas

    def pull_image(self, index):
        img_id = self.ids[index]
        image = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
        return image, img_id

    def pull_anno(self, index):
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        anno = self.target_transform(anno)
        return anno, img_id
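
# PyTorch's default collate cannot batch the dict targets returned by
# __getitem__; a minimal collate sketch (hypothetical, not part of this repo),
# assuming `transform` returns fixed-size image tensors:
#   def collate_fn(batch):
#       images, targets, deltas = zip(*batch)
#       return torch.stack(images, 0), list(targets), list(deltas)
#   loader = data.DataLoader(dataset, batch_size=16, collate_fn=collate_fn)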

if __name__ == "__main__":
    import argparse
    from build import build_transform

    parser = argparse.ArgumentParser(description='VOC-Dataset')
    # opt
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/VOCdevkit/',
                        help='data root')
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='input image size.')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation.')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation.')
    parser.add_argument('--is_train', action="store_true", default=False,
                        help='train mode.')
    parser.add_argument('--load_cache', type=str, default=None,
                        help='path to the cached data.')
    args = parser.parse_args()

    trans_config = {
        'aug_type': 'yolov5',    # optional: ssd, yolov5
        # Basic Augment
        'degrees': 0.0,
        'translate': 0.2,
        'scale': [0.1, 2.0],
        'shear': 0.0,
        'perspective': 0.0,
        'hsv_h': 0.015,
        'hsv_s': 0.7,
        'hsv_v': 0.4,
        'use_ablu': True,
        # Mosaic & Mixup
        'mosaic_prob': args.mosaic,
        'mixup_prob': args.mixup,
        'mosaic_type': 'yolov5_mosaic',
        'mixup_type': 'yolov5_mixup',
        'mixup_scale': [0.5, 1.5]
    }
    transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)

    dataset = VOCDataset(
        img_size=args.img_size,
        data_dir=args.root,
        image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if args.is_train else [('2007', 'test')],
        trans_config=trans_config,
        transform=transform,
        is_train=args.is_train,
        load_cache=args.load_cache
        )

    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(20)]
    print('Data length: ', len(dataset))

    for i in range(1000):
        image, target, deltas = dataset.pull_item(i)
        # to numpy
        image = image.permute(1, 2, 0).numpy()
        # to uint8
        image = image.astype(np.uint8)
        image = image.copy()
        img_h, img_w = image.shape[:2]

        boxes = target["boxes"]
        labels = target["labels"]
        for box, label in zip(boxes, labels):
            x1, y1, x2, y2 = box
            if x2 - x1 > 1 and y2 - y1 > 1:
                cls_id = int(label)
                color = class_colors[cls_id]
                # class name
                label = VOC_CLASSES[cls_id]
                image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
                # put the text on the bbox
                cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
        cv2.imshow('gt', image)
        # cv2.imwrite(str(i) + '.jpg', img)
        cv2.waitKey(0)
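
# Example run of the visualization above (the dataset path and probabilities
# are illustrative):
#   python voc.py --root /path/to/VOCdevkit --img_size 640 --mosaic 1.0 --mixup 0.15 --is_train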