voc.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. import os
  2. import cv2
  3. import time
  4. import random
  5. import numpy as np
  6. import torch
  7. from pycocotools.coco import COCO
  8. try:
  9. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  10. except:
  11. from data_augment.strong_augment import MosaicAugment, MixupAugment
  12. voc_class_indexs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
  13. voc_class_labels = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
  14. class VOCDataset(torch.utils.data.Dataset):
  15. def __init__(self,
  16. cfg,
  17. data_dir :str = None,
  18. transform = None,
  19. is_train :bool = False,
  20. ):
  21. # ----------- Basic parameters -----------
  22. self.data_dir = data_dir
  23. self.image_set = "train" if is_train else "val"
  24. self.is_train = is_train
  25. self.num_classes = 20
  26. # ----------- Data parameters -----------
  27. self.json_file = "instances_{}.json".format(self.image_set)
  28. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  29. self.ids = self.coco.getImgIds()
  30. self.class_ids = sorted(self.coco.getCatIds())
  31. self.dataset_size = len(self.ids)
  32. self.class_labels = voc_class_labels
  33. self.class_indexs = voc_class_indexs
  34. # ----------- Transform parameters -----------
  35. self.transform = transform
  36. if is_train:
  37. if cfg.mosaic_prob == 0.:
  38. self.mosaic_augment = None
  39. else:
  40. self.mosaic_augment = MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
  41. self.mosaic_prob = cfg.mosaic_prob
  42. if cfg.mixup_prob == 0.:
  43. self.mixup_augment = None
  44. else:
  45. self.mixup_augment = MixupAugment(cfg.train_img_size)
  46. self.mixup_prob = cfg.mixup_prob
  47. self.copy_paste = cfg.copy_paste
  48. else:
  49. self.mosaic_prob = 0.0
  50. self.mixup_prob = 0.0
  51. self.copy_paste = 0.0
  52. self.mosaic_augment = None
  53. self.mixup_augment = None
  54. print(' ============ Strong augmentation info. ============ ')
  55. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  56. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  57. print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
  58. # ------------ Basic dataset function ------------
  59. def __len__(self):
  60. return len(self.ids)
  61. def __getitem__(self, index):
  62. return self.pull_item(index)
  63. # ------------ Mosaic & Mixup ------------
  64. def load_mosaic(self, index):
  65. # ------------ Prepare 4 indexes of images ------------
  66. ## Load 4x mosaic image
  67. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  68. id1 = index
  69. id2, id3, id4 = random.sample(index_list, 3)
  70. indexs = [id1, id2, id3, id4]
  71. ## Load images and targets
  72. image_list = []
  73. target_list = []
  74. for index in indexs:
  75. img_i, target_i = self.load_image_target(index)
  76. image_list.append(img_i)
  77. target_list.append(target_i)
  78. # ------------ Mosaic augmentation ------------
  79. image, target = self.mosaic_augment(image_list, target_list)
  80. return image, target
  81. def load_mixup(self, origin_image, origin_target, yolox_style=False):
  82. # ------------ Load a new image & target ------------
  83. if yolox_style:
  84. new_index = np.random.randint(0, len(self.ids))
  85. new_image, new_target = self.load_image_target(new_index)
  86. else:
  87. new_index = np.random.randint(0, len(self.ids))
  88. new_image, new_target = self.load_mosaic(new_index)
  89. # ------------ Mixup augmentation ------------
  90. image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target, yolox_style)
  91. return image, target
  92. # ------------ Load data function ------------
  93. def load_image_target(self, index):
  94. # load an image
  95. image, _ = self.pull_image(index)
  96. height, width, channels = image.shape
  97. # load a target
  98. bboxes, labels = self.pull_anno(index)
  99. target = {
  100. "boxes": bboxes,
  101. "labels": labels,
  102. "orig_size": [height, width]
  103. }
  104. return image, target
  105. def pull_item(self, index):
  106. if random.random() < self.mosaic_prob:
  107. # load a mosaic image
  108. mosaic = True
  109. image, target = self.load_mosaic(index)
  110. else:
  111. mosaic = False
  112. # load an image and target
  113. image, target = self.load_image_target(index)
  114. # Yolov5-MixUp
  115. mixup = False
  116. if random.random() < self.mixup_prob:
  117. mixup = True
  118. image, target = self.load_mixup(image, target)
  119. # Copy-paste (use Yolox-Mixup to approximate copy-paste)
  120. if not mixup and random.random() < self.copy_paste:
  121. image, target = self.load_mixup(image, target, yolox_style=True)
  122. # augment
  123. image, target, deltas = self.transform(image, target, mosaic)
  124. return image, target, deltas
  125. def pull_image(self, index):
  126. # get the image file name
  127. image_dict = self.coco.dataset['images'][index]
  128. image_id = image_dict["id"]
  129. filename = image_dict["file_name"]
  130. # load the image
  131. image_path = os.path.join(self.data_dir, "images", filename)
  132. image = cv2.imread(image_path)
  133. assert image is not None
  134. return image, image_id
  135. def pull_anno(self, index):
  136. img_id = self.ids[index]
  137. # image infor
  138. im_ann = self.coco.loadImgs(img_id)[0]
  139. width = im_ann['width']
  140. height = im_ann['height']
  141. # annotation infor
  142. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=None)
  143. annotations = self.coco.loadAnns(anno_ids)
  144. #load a target
  145. bboxes = []
  146. labels = []
  147. for anno in annotations:
  148. if 'bbox' in anno and anno['area'] > 0:
  149. # bbox
  150. x1 = np.max((0, anno['bbox'][0]))
  151. y1 = np.max((0, anno['bbox'][1]))
  152. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  153. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  154. if x2 < x1 or y2 < y1:
  155. continue
  156. # class label
  157. cls_id = self.class_ids.index(anno['category_id'])
  158. bboxes.append([x1, y1, x2, y2])
  159. labels.append(cls_id)
  160. # guard against no boxes via resizing
  161. bboxes = np.array(bboxes).reshape(-1, 4)
  162. labels = np.array(labels).reshape(-1)
  163. return bboxes, labels
  164. if __name__ == "__main__":
  165. import time
  166. import argparse
  167. from build import build_transform
  168. parser = argparse.ArgumentParser(description='COCO-Dataset')
  169. # opt
  170. parser.add_argument('--root', default="D:/python_work/dataset/VOCdevkit/",
  171. help='data root')
  172. parser.add_argument('--is_train', action="store_true", default=False,
  173. help='mixup augmentation.')
  174. parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
  175. help='yolo, ssd.')
  176. args = parser.parse_args()
  177. class YoloBaseConfig(object):
  178. def __init__(self) -> None:
  179. self.max_stride = 32
  180. # ---------------- Data process config ----------------
  181. self.box_format = 'xywh'
  182. self.normalize_coords = False
  183. self.mosaic_prob = 1.0
  184. self.mixup_prob = 0.15
  185. self.copy_paste = 0.3
  186. ## Pixel mean & std
  187. self.pixel_mean = [0., 0., 0.]
  188. self.pixel_std = [255., 255., 255.]
  189. ## Transforms
  190. self.train_img_size = 640
  191. self.test_img_size = 640
  192. self.use_ablu = True
  193. self.aug_type = 'yolo'
  194. self.affine_params = {
  195. 'degrees': 0.0,
  196. 'translate': 0.2,
  197. 'scale': [0.1, 2.0],
  198. 'shear': 0.0,
  199. 'perspective': 0.0,
  200. 'hsv_h': 0.015,
  201. 'hsv_s': 0.7,
  202. 'hsv_v': 0.4,
  203. }
  204. class SSDBaseConfig(object):
  205. def __init__(self) -> None:
  206. self.max_stride = 32
  207. # ---------------- Data process config ----------------
  208. self.box_format = 'xywh'
  209. self.normalize_coords = False
  210. self.mosaic_prob = 0.0
  211. self.mixup_prob = 0.0
  212. self.copy_paste = 0.0
  213. ## Pixel mean & std
  214. self.pixel_mean = [0., 0., 0.]
  215. self.pixel_std = [255., 255., 255.]
  216. ## Transforms
  217. self.train_img_size = 640
  218. self.test_img_size = 640
  219. self.aug_type = 'ssd'
  220. if args.aug_type == "yolo":
  221. cfg = YoloBaseConfig()
  222. elif args.aug_type == "ssd":
  223. cfg = SSDBaseConfig()
  224. transform = build_transform(cfg, args.is_train)
  225. dataset = VOCDataset(cfg, args.root, transform, args.is_train)
  226. np.random.seed(0)
  227. class_colors = [(np.random.randint(255),
  228. np.random.randint(255),
  229. np.random.randint(255)) for _ in range(80)]
  230. print('Data length: ', len(dataset))
  231. for i in range(1000):
  232. t0 = time.time()
  233. image, target, deltas = dataset.pull_item(i)
  234. print("Load data: {} s".format(time.time() - t0))
  235. # to numpy
  236. image = image.permute(1, 2, 0).numpy()
  237. # denormalize
  238. image = image * cfg.pixel_std + cfg.pixel_mean
  239. # rgb -> bgr
  240. if transform.color_format == 'rgb':
  241. image = image[..., (2, 1, 0)]
  242. # to uint8
  243. image = image.astype(np.uint8)
  244. image = image.copy()
  245. img_h, img_w = image.shape[:2]
  246. boxes = target["boxes"]
  247. labels = target["labels"]
  248. for box, label in zip(boxes, labels):
  249. if cfg.box_format == 'xyxy':
  250. x1, y1, x2, y2 = box
  251. elif cfg.box_format == 'xywh':
  252. cx, cy, bw, bh = box
  253. x1 = cx - 0.5 * bw
  254. y1 = cy - 0.5 * bh
  255. x2 = cx + 0.5 * bw
  256. y2 = cy + 0.5 * bh
  257. if cfg.normalize_coords:
  258. x1 *= img_w
  259. y1 *= img_h
  260. x2 *= img_w
  261. y2 *= img_h
  262. cls_id = int(label)
  263. color = class_colors[cls_id]
  264. # class name
  265. label = voc_class_labels[cls_id]
  266. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  267. # put the test on the bbox
  268. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  269. cv2.imshow('gt', image)
  270. # cv2.imwrite(str(i)+'.jpg', img)
  271. cv2.waitKey(0)