# voc.py -- PASCAL VOC dataset utilities (annotation parsing and dataset loading).

  1. import os
  2. import cv2
  3. import torch
  4. import random
  5. import numpy as np
  6. import os.path as osp
  7. import xml.etree.ElementTree as ET
  8. import torch
  9. import torch.utils.data as data
  10. try:
  11. from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  12. except:
  13. from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  14. # VOC class names
  15. VOC_CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
  16. class VOCAnnotationTransform(object):
  17. """Transforms a VOC annotation into a Tensor of bbox coords and label index
  18. Initilized with a dictionary lookup of classnames to indexes
  19. Arguments:
  20. class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
  21. (default: alphabetic indexing of VOC's 20 classes)
  22. keep_difficult (bool, optional): keep difficult instances or not
  23. (default: False)
  24. height (int): height
  25. width (int): width
  26. """
  27. def __init__(self, class_to_ind=None, keep_difficult=False):
  28. self.class_to_ind = class_to_ind or dict(
  29. zip(VOC_CLASSES, range(len(VOC_CLASSES))))
  30. self.keep_difficult = keep_difficult
  31. def __call__(self, target):
  32. """
  33. Arguments:
  34. target (annotation) : the target annotation to be made usable
  35. will be an ET.Element
  36. Returns:
  37. a list containing lists of bounding boxes [bbox coords, class name]
  38. """
  39. res = []
  40. for obj in target.iter('object'):
  41. difficult = int(obj.find('difficult').text) == 1
  42. if not self.keep_difficult and difficult:
  43. continue
  44. name = obj.find('name').text.lower().strip()
  45. bbox = obj.find('bndbox')
  46. pts = ['xmin', 'ymin', 'xmax', 'ymax']
  47. bndbox = []
  48. for i, pt in enumerate(pts):
  49. cur_pt = int(bbox.find(pt).text) - 1
  50. # scale height or width
  51. cur_pt = cur_pt if i % 2 == 0 else cur_pt
  52. bndbox.append(cur_pt)
  53. label_idx = self.class_to_ind[name]
  54. bndbox.append(label_idx)
  55. res += [bndbox] # [x1, y1, x2, y2, label_ind]
  56. return res # [[x1, y1, x2, y2, label_ind], ... ]
  57. class VOCDataset(data.Dataset):
  58. def __init__(self,
  59. img_size :int = 640,
  60. data_dir :str = None,
  61. image_sets = [('2007', 'trainval'), ('2012', 'trainval')],
  62. trans_config = None,
  63. transform = None,
  64. is_train :bool = False,
  65. load_cache :str = None,
  66. ):
  67. # ----------- Basic parameters -----------
  68. self.img_size = img_size
  69. self.image_set = image_sets
  70. self.is_train = is_train
  71. self.load_cache = load_cache
  72. self.target_transform = VOCAnnotationTransform()
  73. # ----------- Path parameters -----------
  74. self.root = data_dir
  75. self._annopath = osp.join('%s', 'Annotations', '%s.xml')
  76. self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
  77. # ----------- Data parameters -----------
  78. self.ids = list()
  79. for (year, name) in image_sets:
  80. rootpath = osp.join(self.root, 'VOC' + year)
  81. for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
  82. self.ids.append((rootpath, line.strip()))
  83. self.dataset_size = len(self.ids)
  84. # ----------- Transform parameters -----------
  85. self.transform = transform
  86. self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
  87. self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
  88. self.trans_config = trans_config
  89. print('==============================')
  90. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  91. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  92. print('==============================')
  93. # load cache data
  94. if load_cache is not None:
  95. self._load_cache()
  96. # ------------ Basic dataset function ------------
  97. def __getitem__(self, index):
  98. image, target, deltas = self.pull_item(index)
  99. return image, target, deltas
  100. def __len__(self):
  101. return self.dataset_size
  102. def _load_cache(self):
  103. # load image cache
  104. try:
  105. print("Loading cached data ...")
  106. self.cached_datas = torch.load(self.load_cache)
  107. self.dataset_size = len(self.cached_datas)
  108. print("Loading done !")
  109. except:
  110. self.cached_datas = None
  111. self.load_cache = None
  112. print("{} does not exits.".format(self.load_cache))
  113. # ------------ Mosaic & Mixup ------------
  114. def load_mosaic(self, index):
  115. # load 4x mosaic image
  116. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  117. id1 = index
  118. id2, id3, id4 = random.sample(index_list, 3)
  119. indexs = [id1, id2, id3, id4]
  120. # load images and targets
  121. image_list = []
  122. target_list = []
  123. for index in indexs:
  124. img_i, target_i = self.load_image_target(index)
  125. image_list.append(img_i)
  126. target_list.append(target_i)
  127. # Mosaic
  128. if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
  129. image, target = yolov5_mosaic_augment(
  130. image_list, target_list, self.img_size, self.trans_config, self.is_train)
  131. return image, target
  132. def load_mixup(self, origin_image, origin_target):
  133. # YOLOv5 type Mixup
  134. if self.trans_config['mixup_type'] == 'yolov5_mixup':
  135. new_index = np.random.randint(0, len(self.ids))
  136. new_image, new_target = self.load_mosaic(new_index)
  137. image, target = yolov5_mixup_augment(
  138. origin_image, origin_target, new_image, new_target)
  139. # YOLOX type Mixup
  140. elif self.trans_config['mixup_type'] == 'yolox_mixup':
  141. new_index = np.random.randint(0, len(self.ids))
  142. new_image, new_target = self.load_image_target(new_index)
  143. image, target = yolox_mixup_augment(
  144. origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])
  145. return image, target
  146. # ------------ Load data function ------------
  147. def load_image_target(self, index):
  148. # == Load a data from the cached data ==
  149. if self.load_cache is not None and self.is_train:
  150. # load a data
  151. data_item = self.cached_datas[index]
  152. image = data_item["image"]
  153. target = data_item["target"]
  154. # == Load a data from the local disk ==
  155. else:
  156. # load an image
  157. image, _ = self.pull_image(index)
  158. height, width, channels = image.shape
  159. # laod an annotation
  160. anno, _ = self.pull_anno(index)
  161. # guard against no boxes via resizing
  162. anno = np.array(anno).reshape(-1, 5)
  163. target = {
  164. "boxes": anno[:, :4],
  165. "labels": anno[:, 4],
  166. "orig_size": [height, width]
  167. }
  168. return image, target
  169. def pull_item(self, index):
  170. if random.random() < self.mosaic_prob:
  171. # load a mosaic image
  172. mosaic = True
  173. image, target = self.load_mosaic(index)
  174. else:
  175. mosaic = False
  176. # load an image and target
  177. image, target = self.load_image_target(index)
  178. # MixUp
  179. if random.random() < self.mixup_prob:
  180. image, target = self.load_mixup(image, target)
  181. # augment
  182. image, target, deltas = self.transform(image, target, mosaic)
  183. return image, target, deltas
  184. def pull_image(self, index):
  185. img_id = self.ids[index]
  186. image = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
  187. return image, img_id
  188. def pull_anno(self, index):
  189. img_id = self.ids[index]
  190. anno = ET.parse(self._annopath % img_id).getroot()
  191. anno = self.target_transform(anno)
  192. return anno, img_id
  193. if __name__ == "__main__":
  194. import argparse
  195. from build import build_transform
  196. parser = argparse.ArgumentParser(description='VOC-Dataset')
  197. # opt
  198. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/VOCdevkit/',
  199. help='data root')
  200. parser.add_argument('-size', '--img_size', default=640, type=int,
  201. help='input image size.')
  202. parser.add_argument('--mosaic', default=None, type=float,
  203. help='mosaic augmentation.')
  204. parser.add_argument('--mixup', default=None, type=float,
  205. help='mixup augmentation.')
  206. parser.add_argument('--is_train', action="store_true", default=False,
  207. help='mixup augmentation.')
  208. parser.add_argument('--load_cache', type=str, default=None,
  209. help='Path to the cached data.')
  210. args = parser.parse_args()
  211. trans_config = {
  212. 'aug_type': 'yolov5', # optional: ssd, yolov5
  213. # Basic Augment
  214. 'degrees': 0.0,
  215. 'translate': 0.2,
  216. 'scale': [0.1, 2.0],
  217. 'shear': 0.0,
  218. 'perspective': 0.0,
  219. 'hsv_h': 0.015,
  220. 'hsv_s': 0.7,
  221. 'hsv_v': 0.4,
  222. 'use_ablu': True,
  223. # Mosaic & Mixup
  224. 'mosaic_prob': args.mosaic,
  225. 'mixup_prob': args.mixup,
  226. 'mosaic_type': 'yolov5_mosaic',
  227. 'mixup_type': 'yolov5_mixup',
  228. 'mixup_scale': [0.5, 1.5]
  229. }
  230. transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
  231. dataset = VOCDataset(
  232. img_size=args.img_size,
  233. data_dir=args.root,
  234. image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if args.is_train else [('2007', 'test')],
  235. trans_config=trans_config,
  236. transform=transform,
  237. is_train=args.is_train,
  238. load_cache=args.load_cache
  239. )
  240. np.random.seed(0)
  241. class_colors = [(np.random.randint(255),
  242. np.random.randint(255),
  243. np.random.randint(255)) for _ in range(20)]
  244. print('Data length: ', len(dataset))
  245. for i in range(1000):
  246. image, target, deltas = dataset.pull_item(i)
  247. # to numpy
  248. image = image.permute(1, 2, 0).numpy()
  249. # to uint8
  250. image = image.astype(np.uint8)
  251. image = image.copy()
  252. img_h, img_w = image.shape[:2]
  253. boxes = target["boxes"]
  254. labels = target["labels"]
  255. for box, label in zip(boxes, labels):
  256. x1, y1, x2, y2 = box
  257. if x2 - x1 > 1 and y2 - y1 > 1:
  258. cls_id = int(label)
  259. color = class_colors[cls_id]
  260. # class name
  261. label = VOC_CLASSES[cls_id]
  262. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2)
  263. # put the test on the bbox
  264. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  265. cv2.imshow('gt', image)
  266. # cv2.imwrite(str(i)+'.jpg', img)
  267. cv2.waitKey(0)