voc.py

import cv2
import random
import numpy as np
import os.path as osp
import xml.etree.ElementTree as ET

import torch.utils.data as data

try:
    # relative import when voc.py is used as part of the package
    from .data_augment.strong_augment import MosaicAugment, MixupAugment
except ImportError:
    # absolute import when voc.py is run as a standalone script
    from data_augment.strong_augment import MosaicAugment, MixupAugment


# VOC class names
VOC_CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person',
               'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')


class VOCAnnotationTransform(object):
    """Transforms a VOC annotation into a list of bbox coords and label indices.
    Initialized with a dictionary lookup of class names to indices.

    Arguments:
        class_to_ind (dict, optional): dictionary lookup of classnames -> indexes
            (default: alphabetic indexing of VOC's 20 classes)
        keep_difficult (bool, optional): keep difficult instances or not
            (default: False)
    """
    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        """
        Arguments:
            target (annotation): the target annotation to be made usable;
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes [bbox coords, label index]
        """
        res = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')
            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            bndbox = []
            for pt in pts:
                # VOC coordinates are 1-based; shift to 0-based pixel coords
                cur_pt = int(bbox.find(pt).text) - 1
                bndbox.append(cur_pt)
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res += [bndbox]  # [x1, y1, x2, y2, label_ind]
        return res  # [[x1, y1, x2, y2, label_ind], ... ]
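
# A minimal usage sketch for VOCAnnotationTransform (the XML path below is
# illustrative, assuming a standard VOCdevkit layout):
#
#   root = ET.parse('VOCdevkit/VOC2007/Annotations/000001.xml').getroot()
#   VOCAnnotationTransform()(root)
#   # -> [[x1, y1, x2, y2, label_idx], ...] in absolute (0-based) pixel coords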


class VOCDataset(data.Dataset):
    def __init__(self,
                 img_size     :int  = 640,
                 data_dir     :str  = None,
                 image_sets   = [('2007', 'trainval'), ('2012', 'trainval')],
                 trans_config = None,
                 transform    = None,
                 is_train     :bool = False,
                 load_cache   :bool = False,
                 ):
        # ----------- Basic parameters -----------
        self.img_size = img_size
        self.image_set = image_sets
        self.is_train = is_train
        self.target_transform = VOCAnnotationTransform()
        # ----------- Path parameters -----------
        self.root = data_dir
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        # ----------- Data parameters -----------
        self.ids = list()
        for (year, name) in image_sets:
            rootpath = osp.join(self.root, 'VOC' + year)
            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append((rootpath, line.strip()))
        self.dataset_size = len(self.ids)
        # ----------- Transform parameters -----------
        self.trans_config = trans_config
        self.transform = transform
        # ----------- Strong augmentation -----------
        if is_train:
            self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
            self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
            self.mosaic_augment = MosaicAugment(img_size, trans_config, is_train)
            self.mixup_augment = MixupAugment(img_size, trans_config)
        else:
            self.mosaic_prob = 0.0
            self.mixup_prob = 0.0
            self.mosaic_augment = None
            self.mixup_augment = None
        print('==============================')
        print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
        print('use Mixup Augmentation: {}'.format(self.mixup_prob))
        print('==============================')
        # ----------- Cached data -----------
        self.load_cache = load_cache
        self.cached_datas = None
        if self.load_cache:
            self.cached_datas = self._load_cache()

    # ------------ Basic dataset function ------------
    def __getitem__(self, index):
        image, target, deltas = self.pull_item(index)
        return image, target, deltas

    def __len__(self):
        return self.dataset_size

    def _load_cache(self):
        data_items = []
        for idx in range(self.dataset_size):
            if idx % 2000 == 0:
                print("Caching images and targets : {} / {} ...".format(idx, self.dataset_size))
            # load one sample
            image, target = self.load_image_target(idx)
            orig_h, orig_w, _ = image.shape
            # resize image
            r = self.img_size / max(orig_h, orig_w)
            if r != 1:
                interp = cv2.INTER_LINEAR
                new_size = (int(orig_w * r), int(orig_h * r))
                image = cv2.resize(image, new_size, interpolation=interp)
            img_h, img_w = image.shape[:2]
            # rescale bbox
            boxes = target["boxes"].copy()
            boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
            boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
            target["boxes"] = boxes
            dict_item = {}
            dict_item["image"] = image
            dict_item["target"] = target
            data_items.append(dict_item)
        return data_items
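
    # Worked example of the caching resize (numbers are illustrative): with
    # img_size=640 and an original image of (orig_h, orig_w) = (375, 500),
    # r = 640 / 500 = 1.28, so the image becomes 640x480 (w x h) and the boxes
    # are scaled by the same 1.28 factor on both axes, preserving aspect ratio.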

    # ------------ Mosaic & Mixup ------------
    def load_mosaic(self, index):
        # ------------ Prepare 4 indexes of images ------------
        ## Load 4x mosaic image
        index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
        id1 = index
        id2, id3, id4 = random.sample(index_list, 3)
        indexs = [id1, id2, id3, id4]
        ## Load images and targets
        image_list = []
        target_list = []
        for idx in indexs:
            img_i, target_i = self.load_image_target(idx)
            image_list.append(img_i)
            target_list.append(target_i)
        # ------------ Mosaic augmentation ------------
        image, target = self.mosaic_augment(image_list, target_list)
        return image, target
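
    # Note on the sampling above: index_list contains every dataset index
    # except the anchor `index`, so the three companion tiles are drawn
    # without replacement from distinct images before MosaicAugment stitches
    # the four samples together.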

    def load_mixup(self, origin_image, origin_target):
        # ------------ Load a new image & target ------------
        if self.mixup_augment.mixup_type == 'yolov5':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_mosaic(new_index)
        elif self.mixup_augment.mixup_type == 'yolox':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_image_target(new_index)
        # ------------ Mixup augmentation ------------
        image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target)
        return image, target
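
    # The yolov5-style branch mixes the current sample with a freshly built
    # mosaic, while the yolox-style branch mixes it with a single plain image.
    # Note this method assumes mixup_type is one of 'yolov5' or 'yolox'; any
    # other value would leave new_image/new_target unbound.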

    # ------------ Load data function ------------
    def load_image_target(self, index):
        # == Load a sample from the cached data ==
        if self.cached_datas is not None:
            data_item = self.cached_datas[index]
            image = data_item["image"]
            target = data_item["target"]
        # == Load a sample from the local disk ==
        else:
            # load an image
            image, _ = self.pull_image(index)
            height, width, channels = image.shape
            # load an annotation
            anno, _ = self.pull_anno(index)
            # reshape(-1, 5) guards against images with no boxes
            anno = np.array(anno).reshape(-1, 5)
            target = {
                "boxes": anno[:, :4],
                "labels": anno[:, 4],
                "orig_size": [height, width]
            }
        return image, target
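
    # Shape of the returned target dict (a sketch, not enforced anywhere):
    #   target["boxes"]     -> np.ndarray of shape [N, 4], absolute xyxy coords
    #   target["labels"]    -> np.ndarray of shape [N], class indices in [0, 19]
    #   target["orig_size"] -> [height, width] of the image as loaded from disk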

    def pull_item(self, index):
        if random.random() < self.mosaic_prob:
            # load a mosaic image
            mosaic = True
            image, target = self.load_mosaic(index)
        else:
            mosaic = False
            # load an image and target
            image, target = self.load_image_target(index)
        # MixUp
        if random.random() < self.mixup_prob:
            image, target = self.load_mixup(image, target)
        # augment
        image, target, deltas = self.transform(image, target, mosaic)
        return image, target, deltas
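
    # Pipeline order in pull_item: (1) mosaic with probability mosaic_prob,
    # (2) mixup with probability mixup_prob on whatever step 1 produced,
    # (3) the basic transform, which receives the `mosaic` flag, presumably so
    # it can adjust its behavior for mosaic inputs.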

    def pull_image(self, index):
        img_id = self.ids[index]
        image = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
        return image, img_id

    def pull_anno(self, index):
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        anno = self.target_transform(anno)
        return anno, img_id
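
    # Usage sketch for the raw accessors (index 0 is illustrative):
    #   image, img_id = dataset.pull_image(0)  # BGR np.ndarray straight from cv2
    #   anno, img_id  = dataset.pull_anno(0)   # [[x1, y1, x2, y2, label_ind], ...]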


if __name__ == "__main__":
    import time
    import argparse
    from build import build_transform

    parser = argparse.ArgumentParser(description='VOC-Dataset')
    # opt
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/VOCdevkit/',
                        help='data root')
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='input image size.')
    parser.add_argument('--aug_type', type=str, default='ssd',
                        help='augmentation type: ssd, yolov5, rtdetr.')
    parser.add_argument('--mosaic', default=0., type=float,
                        help='probability of mosaic augmentation.')
    parser.add_argument('--mixup', default=0., type=float,
                        help='probability of mixup augmentation.')
    parser.add_argument('--mixup_type', type=str, default='yolov5_mixup',
                        help='type of mixup augmentation.')
    parser.add_argument('--is_train', action="store_true", default=False,
                        help='enable training-time augmentation.')
    parser.add_argument('--load_cache', action="store_true", default=False,
                        help='cache images and targets in memory.')
    args = parser.parse_args()

    trans_config = {
        'aug_type': args.aug_type,    # optional: ssd, yolov5, rtdetr
        'pixel_mean': [123.675, 116.28, 103.53],
        'pixel_std': [58.395, 57.12, 57.375],
        'use_ablu': True,
        # Basic Augment
        'affine_params': {
            'degrees': 0.0,
            'translate': 0.2,
            'scale': [0.1, 2.0],
            'shear': 0.0,
            'perspective': 0.0,
            'hsv_h': 0.015,
            'hsv_s': 0.7,
            'hsv_v': 0.4,
        },
        # Mosaic & Mixup
        'mosaic_keep_ratio': False,
        'mosaic_prob': args.mosaic,
        'mixup_prob': args.mixup,
        'mosaic_type': 'yolov5',
        'mixup_type': 'yolov5',
        'mixup_scale': [0.5, 1.5]
    }

    transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
    pixel_mean = transform.pixel_mean
    pixel_std = transform.pixel_std
    color_format = transform.color_format

    dataset = VOCDataset(
        img_size=args.img_size,
        data_dir=args.root,
        image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if args.is_train else [('2007', 'test')],
        trans_config=trans_config,
        transform=transform,
        is_train=args.is_train,
        load_cache=args.load_cache
        )

    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255))
                    for _ in range(20)]
    print('Data length: ', len(dataset))

    for i in range(1000):
        t0 = time.time()
        image, target, deltas = dataset.pull_item(i)
        print("Load data: {} s".format(time.time() - t0))

        # to numpy
        image = image.permute(1, 2, 0).numpy()
        # denormalize
        image = image * pixel_std + pixel_mean
        if color_format == 'rgb':
            # RGB to BGR
            image = image[..., (2, 1, 0)]
        # to uint8
        image = image.astype(np.uint8)
        image = image.copy()
        img_h, img_w = image.shape[:2]

        boxes = target["boxes"]
        labels = target["labels"]
        for box, label in zip(boxes, labels):
            x1, y1, x2, y2 = box
            if x2 - x1 > 1 and y2 - y1 > 1:
                cls_id = int(label)
                color = class_colors[cls_id]
                # class name
                label = VOC_CLASSES[cls_id]
                image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
                # put the text on the bbox
                cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
        cv2.imshow('gt', image)
        # cv2.imwrite(str(i)+'.jpg', image)
        cv2.waitKey(0)
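
# Example invocation of this demo (the dataset path is illustrative; point
# --root at your own VOCdevkit directory):
#   python voc.py --root /path/to/VOCdevkit/ --img_size 640 --aug_type yolov5 \
#                 --mosaic 1.0 --mixup 0.15 --is_train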