ourdataset.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. import os
  2. import cv2
  3. import random
  4. import numpy as np
  5. import time
  6. from torch.utils.data import Dataset
  7. try:
  8. from pycocotools.coco import COCO
  9. except:
  10. print("It seems that the COCOAPI is not installed.")
  11. try:
  12. from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  13. except:
  14. from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  15. # please define our class labels
  16. our_class_labels = ('bird', 'butterfly', 'cat', 'cow', 'dog', 'lion', 'person', 'pig', 'tiger', )
  17. class OurDataset(Dataset):
  18. """
  19. Our dataset class.
  20. """
  21. def __init__(self,
  22. img_size=640,
  23. data_dir=None,
  24. image_set='train',
  25. transform=None,
  26. trans_config=None,
  27. is_train=False,
  28. load_cache=False):
  29. """
  30. COCO dataset initialization. Annotation data are read into memory by COCO API.
  31. Args:
  32. data_dir (str): dataset root directory
  33. json_file (str): COCO json file name
  34. name (str): COCO data name (e.g. 'train2017' or 'val2017')
  35. debug (bool): if True, only one data id is selected from the dataset
  36. """
  37. self.img_size = img_size
  38. self.image_set = image_set
  39. self.json_file = '{}.json'.format(image_set)
  40. self.data_dir = data_dir
  41. self.coco = COCO(os.path.join(self.data_dir, image_set, 'annotations', self.json_file))
  42. self.ids = self.coco.getImgIds()
  43. self.class_ids = sorted(self.coco.getCatIds())
  44. self.is_train = is_train
  45. self.load_cache = load_cache
  46. # augmentation
  47. self.transform = transform
  48. self.mosaic_prob = 0
  49. self.mixup_prob = 0
  50. self.trans_config = trans_config
  51. if trans_config is not None:
  52. self.mosaic_prob = trans_config['mosaic_prob']
  53. self.mixup_prob = trans_config['mixup_prob']
  54. print('==============================')
  55. print('Image Set: {}'.format(image_set))
  56. print('Json file: {}'.format(self.json_file))
  57. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  58. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  59. print('==============================')
  60. # load cache data
  61. if load_cache:
  62. self._load_cache()
  63. def __len__(self):
  64. return len(self.ids)
  65. def __getitem__(self, index):
  66. return self.pull_item(index)
  67. def _load_cache(self):
  68. # load image cache
  69. self.cached_images = []
  70. self.cached_targets = []
  71. dataset_size = len(self.ids)
  72. print('loading data into memory ...')
  73. for i in range(dataset_size):
  74. if i % 5000 == 0:
  75. print("[{} / {}]".format(i, dataset_size))
  76. # load an image
  77. image, image_id = self.pull_image(i)
  78. orig_h, orig_w, _ = image.shape
  79. # resize image
  80. r = self.img_size / max(orig_h, orig_w)
  81. if r != 1:
  82. interp = cv2.INTER_LINEAR
  83. new_size = (int(orig_w * r), int(orig_h * r))
  84. image = cv2.resize(image, new_size, interpolation=interp)
  85. img_h, img_w = image.shape[:2]
  86. self.cached_images.append(image)
  87. # load target cache
  88. bboxes, labels = self.pull_anno(i)
  89. bboxes[:, [0, 2]] = bboxes[:, [0, 2]] / orig_w * img_w
  90. bboxes[:, [1, 3]] = bboxes[:, [1, 3]] / orig_h * img_h
  91. self.cached_targets.append({"boxes": bboxes, "labels": labels})
  92. def load_image_target(self, index):
  93. if self.load_cache:
  94. # load data from cache
  95. image = self.cached_images[index]
  96. target = self.cached_targets[index]
  97. height, width, channels = image.shape
  98. target["orig_size"] = [height, width]
  99. else:
  100. # load an image
  101. image, _ = self.pull_image(index)
  102. height, width, channels = image.shape
  103. # load a target
  104. bboxes, labels = self.pull_anno(index)
  105. target = {
  106. "boxes": bboxes,
  107. "labels": labels,
  108. "orig_size": [height, width]
  109. }
  110. return image, target
  111. def load_mosaic(self, index):
  112. # load 4x mosaic image
  113. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  114. id1 = index
  115. id2, id3, id4 = random.sample(index_list, 3)
  116. indexs = [id1, id2, id3, id4]
  117. # load images and targets
  118. image_list = []
  119. target_list = []
  120. for index in indexs:
  121. img_i, target_i = self.load_image_target(index)
  122. image_list.append(img_i)
  123. target_list.append(target_i)
  124. # Mosaic
  125. if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
  126. image, target = yolov5_mosaic_augment(
  127. image_list, target_list, self.img_size, self.trans_config, self.is_train)
  128. return image, target
  129. def load_mixup(self, origin_image, origin_target):
  130. # YOLOv5 type Mixup
  131. if self.trans_config['mixup_type'] == 'yolov5_mixup':
  132. new_index = np.random.randint(0, len(self.ids))
  133. new_image, new_target = self.load_mosaic(new_index)
  134. image, target = yolov5_mixup_augment(
  135. origin_image, origin_target, new_image, new_target)
  136. # YOLOX type Mixup
  137. elif self.trans_config['mixup_type'] == 'yolox_mixup':
  138. new_index = np.random.randint(0, len(self.ids))
  139. new_image, new_target = self.load_image_target(new_index)
  140. image, target = yolox_mixup_augment(
  141. origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])
  142. return image, target
  143. def pull_item(self, index):
  144. if random.random() < self.mosaic_prob:
  145. # load a mosaic image
  146. mosaic = True
  147. image, target = self.load_mosaic(index)
  148. else:
  149. mosaic = False
  150. # load an image and target
  151. image, target = self.load_image_target(index)
  152. # MixUp
  153. if random.random() < self.mixup_prob:
  154. image, target = self.load_mixup(image, target)
  155. # augment
  156. image, target, deltas = self.transform(image, target, mosaic)
  157. return image, target, deltas
  158. def pull_image(self, index):
  159. id_ = self.ids[index]
  160. im_ann = self.coco.loadImgs(id_)[0]
  161. img_file = os.path.join(
  162. self.data_dir, self.image_set, 'images', im_ann["file_name"])
  163. image = cv2.imread(img_file)
  164. return image, id_
  165. def pull_anno(self, index):
  166. img_id = self.ids[index]
  167. im_ann = self.coco.loadImgs(img_id)[0]
  168. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=0)
  169. annotations = self.coco.loadAnns(anno_ids)
  170. # image infor
  171. width = im_ann['width']
  172. height = im_ann['height']
  173. #load a target
  174. bboxes = []
  175. labels = []
  176. for anno in annotations:
  177. if 'bbox' in anno and anno['area'] > 0:
  178. # bbox
  179. x1 = np.max((0, anno['bbox'][0]))
  180. y1 = np.max((0, anno['bbox'][1]))
  181. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  182. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  183. if x2 <= x1 or y2 <= y1:
  184. continue
  185. # class label
  186. cls_id = self.class_ids.index(anno['category_id'])
  187. bboxes.append([x1, y1, x2, y2])
  188. labels.append(cls_id)
  189. # guard against no boxes via resizing
  190. bboxes = np.array(bboxes).reshape(-1, 4)
  191. labels = np.array(labels).reshape(-1)
  192. return bboxes, labels
  193. if __name__ == "__main__":
  194. import argparse
  195. from build import build_transform
  196. parser = argparse.ArgumentParser(description='FreeYOLOv2')
  197. # opt
  198. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/AnimalDataset/',
  199. help='data root')
  200. parser.add_argument('--split', default='train',
  201. help='data split')
  202. parser.add_argument('-size', '--img_size', default=640, type=int,
  203. help='input image size')
  204. parser.add_argument('--min_box_size', default=8.0, type=float,
  205. help='min size of target bounding box.')
  206. parser.add_argument('--mosaic', default=None, type=float,
  207. help='mosaic augmentation.')
  208. parser.add_argument('--mixup', default=None, type=float,
  209. help='mixup augmentation.')
  210. parser.add_argument('--is_train', action="store_true", default=False,
  211. help='mixup augmentation.')
  212. parser.add_argument('--load_cache', action="store_true", default=False,
  213. help='load cached data.')
  214. args = parser.parse_args()
  215. trans_config = {
  216. 'aug_type': 'yolov5', # optional: ssd, yolov5
  217. # Basic Augment
  218. 'degrees': 0.0,
  219. 'translate': 0.2,
  220. 'scale': [0.5, 2.0],
  221. 'shear': 0.0,
  222. 'perspective': 0.0,
  223. 'hsv_h': 0.015,
  224. 'hsv_s': 0.7,
  225. 'hsv_v': 0.4,
  226. # Mosaic & Mixup
  227. 'mosaic_prob': 1.0,
  228. 'mixup_prob': 1.0,
  229. 'mosaic_type': 'yolov5_mosaic',
  230. 'mixup_type': 'yolov5_mixup',
  231. 'mixup_scale': [0.5, 1.5]
  232. }
  233. transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
  234. dataset = OurDataset(
  235. img_size=args.img_size,
  236. data_dir=args.root,
  237. image_set=args.split,
  238. transform=transform,
  239. trans_config=trans_config,
  240. is_train=args.is_train,
  241. load_cache=args.load_cache
  242. )
  243. np.random.seed(0)
  244. class_colors = [(np.random.randint(255),
  245. np.random.randint(255),
  246. np.random.randint(255)) for _ in range(80)]
  247. print('Data length: ', len(dataset))
  248. for i in range(1000):
  249. image, target, deltas = dataset.pull_item(i)
  250. # to numpy
  251. image = image.permute(1, 2, 0).numpy()
  252. image = image.astype(np.uint8)
  253. image = image.copy()
  254. img_h, img_w = image.shape[:2]
  255. boxes = target["boxes"]
  256. labels = target["labels"]
  257. for box, label in zip(boxes, labels):
  258. x1, y1, x2, y2 = box
  259. cls_id = int(label)
  260. color = class_colors[cls_id]
  261. # class name
  262. label = our_class_labels[cls_id]
  263. if x2 - x1 > 0. and y2 - y1 > 0.:
  264. # draw bbox
  265. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  266. # put the test on the bbox
  267. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  268. cv2.imshow('gt', image)
  269. # cv2.imwrite(str(i)+'.jpg', img)
  270. cv2.waitKey(0)