# voc.py

import cv2
import random
import numpy as np
import os.path as osp
import xml.etree.ElementTree as ET

import torch.utils.data as data

try:
    from .data_augment.strong_augment import MosaicAugment, MixupAugment
except ImportError:
    from data_augment.strong_augment import MosaicAugment, MixupAugment


VOC_CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
               'bus', 'car', 'cat', 'chair', 'cow',
               'diningtable', 'dog', 'horse', 'motorbike', 'person',
               'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
voc_class_indexs = list(range(len(VOC_CLASSES)))
voc_class_labels = VOC_CLASSES


class VOCAnnotationTransform(object):
    """Transform a VOC annotation (an XML tree) into a list of
    [x1, y1, x2, y2, label_ind] boxes."""

    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

    def __call__(self, target):
        res = []
        for obj in target.iter('object'):
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')

            bndbox = []
            for pt in ('xmin', 'ymin', 'xmax', 'ymax'):
                # VOC coordinates are 1-indexed; shift them to 0-indexed pixels
                cur_pt = int(bbox.find(pt).text) - 1
                bndbox.append(cur_pt)
            bndbox.append(self.class_to_ind[name])
            res += [bndbox]  # [x1, y1, x2, y2, label_ind]
        return res  # [[x1, y1, x2, y2, label_ind], ...]

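
# A minimal usage sketch (hypothetical, not part of the original file): parse a
# small hand-written annotation snippet with VOCAnnotationTransform. The XML
# below is illustrative, not taken from the dataset.
def _demo_annotation_transform():
    sample = """
    <annotation>
        <object>
            <name>dog</name>
            <difficult>0</difficult>
            <bndbox>
                <xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax>
            </bndbox>
        </object>
    </annotation>
    """
    root = ET.fromstring(sample)
    # expected: [[47, 239, 194, 370, 11]]  (0-indexed coords, 'dog' -> class 11)
    print(VOCAnnotationTransform()(root))
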

class VOCDataset(data.Dataset):
    def __init__(self,
                 cfg,
                 data_dir: str = None,
                 image_set     = [('2007', 'trainval'), ('2012', 'trainval')],
                 transform     = None,
                 is_train: bool = False,
                 ):
        # ----------- Basic parameters -----------
        self.image_set = image_set
        self.is_train = is_train
        self.num_classes = 20
        # ----------- Path parameters -----------
        self.root = data_dir
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        # ----------- Data parameters -----------
        self.ids = list()
        for (year, name) in image_set:
            rootpath = osp.join(self.root, 'VOC' + year)
            with open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')) as f:
                for line in f:
                    self.ids.append((rootpath, line.strip()))
        self.dataset_size = len(self.ids)
        self.class_labels = voc_class_labels
        self.class_indexs = voc_class_indexs
        # ----------- Transform parameters -----------
        self.target_transform = VOCAnnotationTransform()
        self.transform = transform
        if is_train:
            self.mosaic_prob = cfg.mosaic_prob
            self.mixup_prob = cfg.mixup_prob
            self.copy_paste = cfg.copy_paste
            self.mosaic_augment = None if cfg.mosaic_prob == 0. \
                else MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
            self.mixup_augment = None if cfg.mixup_prob == 0. and cfg.copy_paste == 0. \
                else MixupAugment(cfg.train_img_size)
        else:
            self.mosaic_prob = 0.0
            self.mixup_prob = 0.0
            self.copy_paste = 0.0
            self.mosaic_augment = None
            self.mixup_augment = None
        print('==============================')
        print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
        print('use Mixup Augmentation: {}'.format(self.mixup_prob))
        print('use Copy-paste Augmentation: {}'.format(self.copy_paste))

    # ------------ Basic dataset function ------------
    def __getitem__(self, index):
        image, target, deltas = self.pull_item(index)
        return image, target, deltas

    def __len__(self):
        return self.dataset_size

    # ------------ Mosaic & Mixup ------------
    def load_mosaic(self, index):
        # ------------ Prepare 4 image indexes ------------
        ## the given index plus 3 distinct random indexes (excluding index itself)
        index_list = np.arange(index).tolist() + np.arange(index + 1, len(self.ids)).tolist()
        id2, id3, id4 = random.sample(index_list, 3)
        indexs = [index, id2, id3, id4]
        ## Load images and targets
        image_list = []
        target_list = []
        for idx in indexs:
            img_i, target_i = self.load_image_target(idx)
            image_list.append(img_i)
            target_list.append(target_i)
        # ------------ Mosaic augmentation ------------
        image, target = self.mosaic_augment(image_list, target_list)
        return image, target

    def load_mixup(self, origin_image, origin_target, yolox_style=False):
        # ------------ Load a new image & target ------------
        new_index = np.random.randint(0, len(self.ids))
        if yolox_style:
            # YOLOX-style mixup blends with a plain (non-mosaic) image
            new_image, new_target = self.load_image_target(new_index)
        else:
            # YOLOv5-style mixup blends with another mosaic image
            new_image, new_target = self.load_mosaic(new_index)
        # ------------ Mixup augmentation ------------
        image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target, yolox_style)
        return image, target
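
    # A sketch of what the two flavors are assumed to do (the actual logic
    # lives in data_augment/strong_augment.py; the blend below follows common
    # YOLOv5/YOLOX conventions and is illustrative, not the source of truth):
    #   yolox_style=False (YOLOv5-style):
    #       r = np.random.beta(32.0, 32.0)
    #       mixed = img1 * r + img2 * (1.0 - r)   # boxes are concatenated
    #   yolox_style=True (YOLOX-style, used here to approximate copy-paste):
    #       paste a resized/jittered plain image onto the origin image with a
    #       fixed 0.5/0.5 blend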

    # ------------ Load data function ------------
    def load_image_target(self, index):
        # load an image
        image, _ = self.pull_image(index)
        height, width, channels = image.shape
        # load an annotation
        anno, _ = self.pull_anno(index)
        # guard against images with no boxes by reshaping to (0, 5)
        anno = np.array(anno).reshape(-1, 5)
        target = {
            "boxes": anno[:, :4],
            "labels": anno[:, 4],
            "orig_size": [height, width]
        }
        return image, target

    def pull_item(self, index):
        if random.random() < self.mosaic_prob:
            # load a mosaic image
            mosaic = True
            image, target = self.load_mosaic(index)
        else:
            # load a single image and its target
            mosaic = False
            image, target = self.load_image_target(index)
        # YOLOv5-style MixUp
        mixup = False
        if random.random() < self.mixup_prob:
            mixup = True
            image, target = self.load_mixup(image, target)
        # Copy-paste (use YOLOX-style mixup to approximate copy-paste)
        if not mixup and random.random() < self.copy_paste:
            image, target = self.load_mixup(image, target, yolox_style=True)
        # augment
        image, target, deltas = self.transform(image, target, mosaic)
        return image, target, deltas

    def pull_image(self, index):
        img_id = self.ids[index]
        image = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
        return image, img_id

    def pull_anno(self, index):
        img_id = self.ids[index]
        anno = ET.parse(self._annopath % img_id).getroot()
        anno = self.target_transform(anno)
        return anno, img_id
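
# Hypothetical inspection helper (an illustrative sketch, not part of the
# original API): fetch one raw, untransformed sample. Assumes a VOCdevkit
# checkout on disk and a dataset built as in the __main__ block below.
def _demo_pull_raw(dataset, index=0):
    image, img_id = dataset.pull_image(index)  # raw BGR image, (H, W, 3) or None
    anno, _ = dataset.pull_anno(index)         # [[x1, y1, x2, y2, label], ...]
    print(img_id, None if image is None else image.shape, anno)
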

if __name__ == "__main__":
    import time
    import argparse
    from build import build_transform

    parser = argparse.ArgumentParser(description='VOC-Dataset')
    # opt
    parser.add_argument('--root', default='D:/python_work/dataset/VOCdevkit/',
                        help='data root')
    parser.add_argument('--is_train', action="store_true", default=False,
                        help='train or not.')
    parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
                        help='augmentation pipeline: yolo or ssd.')
    args = parser.parse_args()

    class YoloBaseConfig(object):
        def __init__(self) -> None:
            self.max_stride = 32
            # ---------------- Data process config ----------------
            self.box_format = 'xywh'
            self.normalize_coords = False
            self.mosaic_prob = 1.0
            self.mixup_prob = 0.15
            self.copy_paste = 0.3
            ## Pixel mean & std
            self.pixel_mean = [0., 0., 0.]
            self.pixel_std = [255., 255., 255.]
            ## Transforms
            self.train_img_size = 640
            self.test_img_size = 640
            self.use_ablu = True
            self.aug_type = 'yolo'
            self.affine_params = {
                'degrees': 0.0,
                'translate': 0.2,
                'scale': [0.1, 2.0],
                'shear': 0.0,
                'perspective': 0.0,
                'hsv_h': 0.015,
                'hsv_s': 0.7,
                'hsv_v': 0.4,
            }

    class SSDBaseConfig(object):
        def __init__(self) -> None:
            self.max_stride = 32
            # ---------------- Data process config ----------------
            self.box_format = 'xywh'
            self.normalize_coords = False
            self.mosaic_prob = 0.0
            self.mixup_prob = 0.0
            self.copy_paste = 0.0
            ## Pixel mean & std
            self.pixel_mean = [0., 0., 0.]
            self.pixel_std = [255., 255., 255.]
            ## Transforms
            self.train_img_size = 640
            self.test_img_size = 640
            self.aug_type = 'ssd'

    if args.aug_type == "yolo":
        cfg = YoloBaseConfig()
    elif args.aug_type == "ssd":
        cfg = SSDBaseConfig()

    transform = build_transform(cfg, args.is_train)
    dataset = VOCDataset(cfg, args.root, [('2007', 'test')], transform, args.is_train)

    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(len(VOC_CLASSES))]
    print('Data length: ', len(dataset))

    for i in range(1000):
        t0 = time.time()
        image, target, deltas = dataset.pull_item(i)
        print("Load data: {} s".format(time.time() - t0))

        # to numpy: (C, H, W) tensor -> (H, W, C) array
        image = image.permute(1, 2, 0).numpy()
        # denormalize
        image = image * cfg.pixel_std + cfg.pixel_mean
        # rgb -> bgr for OpenCV
        if transform.color_format == 'rgb':
            image = image[..., (2, 1, 0)]
        # to uint8
        image = image.astype(np.uint8)
        image = image.copy()
        img_h, img_w = image.shape[:2]

        boxes = target["boxes"]
        labels = target["labels"]
        for box, label in zip(boxes, labels):
            if cfg.box_format == 'xyxy':
                x1, y1, x2, y2 = box
            elif cfg.box_format == 'xywh':
                cx, cy, bw, bh = box
                x1 = cx - 0.5 * bw
                y1 = cy - 0.5 * bh
                x2 = cx + 0.5 * bw
                y2 = cy + 0.5 * bh
            if cfg.normalize_coords:
                x1 *= img_w
                y1 *= img_h
                x2 *= img_w
                y2 *= img_h
            cls_id = int(label)
            color = class_colors[cls_id]
            # class name
            label = voc_class_labels[cls_id]
            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            # put the class name on the bbox
            cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
        cv2.imshow('gt', image)
        # cv2.imwrite(str(i) + '.jpg', image)
        cv2.waitKey(0)