ourdataset.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. import os
  2. import cv2
  3. import random
  4. import numpy as np
  5. import time
  6. from torch.utils.data import Dataset
  7. try:
  8. from pycocotools.coco import COCO
  9. except:
  10. print("It seems that the COCOAPI is not installed.")
  11. try:
  12. from .data_augment import build_transform
  13. from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  14. except:
  15. from data_augment import build_transform
  16. from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  # Human-readable class names, looked up by contiguous label id
  # (the position of a category in the dataset's sorted category ids).
  # Edit this tuple so it lists the categories of your own annotation file.
  our_class_labels = ("cat",)
  class OurDataset(Dataset):
      """
      COCO-format detection dataset for a custom label set.

      Annotations are read from
      ``<data_dir>/<image_set>/annotations/<image_set>.json`` and images from
      ``<data_dir>/<image_set>/images/<file_name>``. Each item is produced by
      ``pull_item`` (optionally with Mosaic / Mixup) and then passed through
      ``transform``.
      """

      def __init__(self,
                   img_size=640,
                   data_dir=None,
                   image_set='train',
                   transform=None,
                   trans_config=None,
                   is_train=False):
          """
          Load the COCO-format annotation file for ``image_set`` into memory.

          Args:
              img_size (int): output size handed to the Mosaic / Mixup helpers.
              data_dir (str): dataset root directory.
              image_set (str): split name (e.g. 'train' or 'val'); selects both
                  the sub-directory and the annotation file ``<image_set>.json``.
              transform (callable): called as ``transform(image, target, mosaic)``
                  and expected to return ``(image, target, deltas)``.
              trans_config (dict | None): augmentation settings. Reads
                  'mosaic_prob', 'mixup_prob', 'mosaic_type', 'mixup_type' and
                  'mixup_scale'. ``None`` disables Mosaic and Mixup.
              is_train (bool): stored as ``self.is_train``.
          """
          self.img_size = img_size
          self.image_set = image_set
          self.json_file = '{}.json'.format(image_set)
          self.data_dir = data_dir
          self.coco = COCO(os.path.join(self.data_dir, image_set, 'annotations', self.json_file))
          self.ids = self.coco.getImgIds()
          # Sorted COCO category ids; a category's position here is its label id.
          self.class_ids = sorted(self.coco.getCatIds())
          self.is_train = is_train
          # augmentation
          self.transform = transform
          self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
          self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
          self.trans_config = trans_config
          print('==============================')
          print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
          print('use Mixup Augmentation: {}'.format(self.mixup_prob))
          print('==============================')

      def __len__(self):
          """Number of images in the split."""
          return len(self.ids)

      def __getitem__(self, index):
          """Return ``(image, target, deltas)`` for ``index`` (see ``pull_item``)."""
          return self.pull_item(index)

      def load_image_target(self, index):
          """Load one raw image and its target dict (boxes, labels, orig_size)."""
          image, _ = self.pull_image(index)
          height, width = image.shape[:2]
          bboxes, labels = self.pull_anno(index)
          target = {
              "boxes": bboxes,
              "labels": labels,
              "orig_size": [height, width]
          }
          return image, target

      def _sample_mosaic_partners(self, index):
          """Return three dataset indices to combine with ``index`` in a mosaic.

          With at least four images the partners are distinct and never equal
          ``index``. Smaller datasets fall back to sampling with replacement so
          Mosaic still works instead of ``random.sample`` raising ``ValueError``.
          """
          num_images = len(self.ids)
          if num_images >= 4:
              # Sample from the n-1 indices other than ``index`` without building
              # an O(n) list: draw from range(n - 1) and shift picks >= index by 1.
              picks = random.sample(range(num_images - 1), 3)
              return [k + 1 if k >= index else k for k in picks]
          return random.choices(range(num_images), k=3)

      def load_mosaic(self, index):
          """Build a 4-image Mosaic sample anchored on ``index``.

          Raises:
              ValueError: if ``trans_config['mosaic_type']`` is not supported.
          """
          mosaic_type = self.trans_config['mosaic_type']
          # Validate before doing any image I/O.
          if mosaic_type != 'yolov5_mosaic':
              raise ValueError('Unsupported mosaic_type: {}'.format(mosaic_type))
          indices = [index] + self._sample_mosaic_partners(index)
          image_list = []
          target_list = []
          for idx in indices:
              img_i, target_i = self.load_image_target(idx)
              image_list.append(img_i)
              target_list.append(target_i)
          image, target = yolov5_mosaic_augment(
              image_list, target_list, self.img_size, self.trans_config)
          return image, target

      def load_mixup(self, origin_image, origin_target):
          """Blend the given sample with a randomly chosen second sample.

          Raises:
              ValueError: if ``trans_config['mixup_type']`` is not supported.
          """
          mixup_type = self.trans_config['mixup_type']
          if mixup_type == 'yolov5_mixup':
              # YOLOv5 type Mixup: the partner is itself a Mosaic sample.
              new_index = np.random.randint(0, len(self.ids))
              new_image, new_target = self.load_mosaic(new_index)
              image, target = yolov5_mixup_augment(
                  origin_image, origin_target, new_image, new_target)
          elif mixup_type == 'yolox_mixup':
              # YOLOX type Mixup: the partner is a single plain image.
              new_index = np.random.randint(0, len(self.ids))
              new_image, new_target = self.load_image_target(new_index)
              image, target = yolox_mixup_augment(
                  origin_image, origin_target, new_image, new_target,
                  self.img_size, self.trans_config['mixup_scale'])
          else:
              raise ValueError('Unsupported mixup_type: {}'.format(mixup_type))
          return image, target

      def pull_item(self, index):
          """Load sample ``index``, apply Mosaic/Mixup by probability, then ``transform``."""
          if random.random() < self.mosaic_prob:
              mosaic = True
              image, target = self.load_mosaic(index)
          else:
              mosaic = False
              image, target = self.load_image_target(index)
          if random.random() < self.mixup_prob:
              image, target = self.load_mixup(image, target)
          image, target, deltas = self.transform(image, target, mosaic)
          return image, target, deltas

      def pull_image(self, index):
          """Read image ``index`` from disk with OpenCV; return ``(image, image_id)``.

          Raises:
              FileNotFoundError: if OpenCV cannot read the image file.
          """
          id_ = self.ids[index]
          im_ann = self.coco.loadImgs(id_)[0]
          img_file = os.path.join(
              self.data_dir, self.image_set, 'images', im_ann["file_name"])
          image = cv2.imread(img_file)
          if image is None:
              # cv2.imread reports a missing/unreadable file by returning None.
              raise FileNotFoundError('Failed to read image: {}'.format(img_file))
          return image, id_

      def pull_anno(self, index):
          """Return ``(bboxes, labels)`` for image ``index``.

          ``bboxes`` is an ``(N, 4)`` array of ``[x1, y1, x2, y2]`` with the
          top-left corner clipped at 0; ``labels`` is an ``(N,)`` array of
          positions in ``self.class_ids``. Annotations without a bbox, with
          non-positive area, or that are empty after clipping are skipped.
          """
          id_ = self.ids[index]
          anno_ids = self.coco.getAnnIds(imgIds=[int(id_)], iscrowd=None)
          annotations = self.coco.loadAnns(anno_ids)
          bboxes = []
          labels = []
          for anno in annotations:
              if 'bbox' not in anno or anno['area'] <= 0:
                  continue
              # COCO bbox is [x, y, w, h]. Take the far corner from the unclipped
              # origin so clipping at 0 crops the box rather than shifting it.
              x2 = anno['bbox'][0] + anno['bbox'][2]
              y2 = anno['bbox'][1] + anno['bbox'][3]
              x1 = np.max((0, anno['bbox'][0]))
              y1 = np.max((0, anno['bbox'][1]))
              if x2 < x1 or y2 < y1:
                  # Empty after clipping (or negative size): nothing to keep.
                  continue
              bboxes.append([x1, y1, x2, y2])
              labels.append(self.class_ids.index(anno['category_id']))
          # reshape keeps the (N, 4) / (N,) shapes even when N == 0
          bboxes = np.array(bboxes).reshape(-1, 4)
          labels = np.array(labels).reshape(-1)
          return bboxes, labels
  if __name__ == "__main__":
      # Visual sanity check: draw ground-truth boxes of augmented samples.
      import argparse
      import sys
      # Extend the import path *before* importing local packages from it.
      sys.path.append('.')
      from data_augment import build_transform

      parser = argparse.ArgumentParser(description='Our-Dataset')
      # opt
      parser.add_argument('--root', default='OurDataset',
                          help='data root')
      parser.add_argument('--split', default='train',
                          help='data split')
      args = parser.parse_args()

      is_train = False
      img_size = 640
      yolov5_trans_config = {
          'aug_type': 'yolov5',
          # Basic Augment
          'degrees': 0.0,
          'translate': 0.2,
          'scale': 0.9,
          'shear': 0.0,
          'perspective': 0.0,
          'hsv_h': 0.015,
          'hsv_s': 0.7,
          'hsv_v': 0.4,
          # Mosaic & Mixup
          'mosaic_prob': 1.0,
          'mixup_prob': 0.15,
          'mosaic_type': 'yolov5_mosaic',
          'mixup_type': 'yolov5_mixup',
          'mixup_scale': [0.5, 1.5]
      }
      # Alternative config kept for manual experimentation (not used below).
      ssd_trans_config = {
          'aug_type': 'ssd',
          'mosaic_prob': 0.0,
          'mixup_prob': 0.0
      }
      transform = build_transform(img_size, yolov5_trans_config, is_train)
      dataset = OurDataset(
          img_size=img_size,
          data_dir=args.root,
          # honour --split instead of always loading 'train'
          image_set=args.split,
          trans_config=yolov5_trans_config,
          transform=transform,
          is_train=is_train
      )
      np.random.seed(0)
      class_colors = [(np.random.randint(255),
                       np.random.randint(255),
                       np.random.randint(255)) for _ in range(80)]
      print('Data length: ', len(dataset))
      # Show at most 1000 samples without indexing past the end of the dataset.
      for i in range(min(1000, len(dataset))):
          image, target, deltas = dataset.pull_item(i)
          # to numpy (transform output is assumed to be a CHW tensor)
          image = image.permute(1, 2, 0).numpy()
          # to uint8
          image = image.astype(np.uint8)
          image = image.copy()
          img_h, img_w = image.shape[:2]
          boxes = target["boxes"]
          labels = target["labels"]
          for box, label in zip(boxes, labels):
              x1, y1, x2, y2 = box
              cls_id = int(label)
              color = class_colors[cls_id]
              # class name
              label = our_class_labels[cls_id]
              image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2)
              # put the text on the bbox
              cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
          cv2.imshow('gt', image)
          # cv2.imwrite(str(i) + '.jpg', image)  # uncomment to save instead
          cv2.waitKey(0)