coco.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. import os
  2. import cv2
  3. import time
  4. import random
  5. import numpy as np
  6. import torch
  7. from torch.utils.data import Dataset
  8. try:
  9. from pycocotools.coco import COCO
  10. except:
  11. print("It seems that the COCOAPI is not installed.")
  12. try:
  13. from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  14. except:
  15. from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  16. coco_class_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  17. coco_class_labels = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  18. class COCODataset(Dataset):
  19. def __init__(self,
  20. img_size :int = 640,
  21. data_dir :str = None,
  22. image_set :str = 'train2017',
  23. trans_config = None,
  24. transform = None,
  25. is_train :bool =False,
  26. load_cache :bool = False,
  27. ):
  28. # ----------- Basic parameters -----------
  29. self.img_size = img_size
  30. self.image_set = image_set
  31. self.is_train = is_train
  32. # ----------- Path parameters -----------
  33. self.data_dir = data_dir
  34. if image_set == 'train2017':
  35. self.json_file='instances_train2017_clean.json'
  36. elif image_set == 'val2017':
  37. self.json_file='instances_val2017_clean.json'
  38. elif image_set == 'test2017':
  39. self.json_file='image_info_test-dev2017.json'
  40. else:
  41. raise NotImplementedError("Unknown json image set {}.".format(image_set))
  42. # ----------- Data parameters -----------
  43. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  44. self.ids = self.coco.getImgIds()
  45. self.class_ids = sorted(self.coco.getCatIds())
  46. self.dataset_size = len(self.ids)
  47. # ----------- Transform parameters -----------
  48. self.transform = transform
  49. self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
  50. self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
  51. self.trans_config = trans_config
  52. print('==============================')
  53. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  54. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  55. print('==============================')
  56. # ----------- Cached data -----------
  57. self.load_cache = load_cache
  58. self.cached_datas = None
  59. if self.load_cache:
  60. self.cached_datas = self._load_cache()
  61. # ------------ Basic dataset function ------------
  62. def __len__(self):
  63. return len(self.ids)
  64. def __getitem__(self, index):
  65. return self.pull_item(index)
  66. def _load_cache(self):
  67. data_items = []
  68. for idx in range(self.dataset_size):
  69. if idx % 2000 == 0:
  70. print("Caching images and targets : {} / {} ...".format(idx, self.dataset_size))
  71. # load a data
  72. image, target = self.load_image_target(idx)
  73. orig_h, orig_w, _ = image.shape
  74. # resize image
  75. r = self.img_size / max(orig_h, orig_w)
  76. if r != 1:
  77. interp = cv2.INTER_LINEAR
  78. new_size = (int(orig_w * r), int(orig_h * r))
  79. image = cv2.resize(image, new_size, interpolation=interp)
  80. img_h, img_w = image.shape[:2]
  81. # rescale bbox
  82. boxes = target["boxes"].copy()
  83. boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
  84. boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
  85. target["boxes"] = boxes
  86. dict_item = {}
  87. dict_item["image"] = image
  88. dict_item["target"] = target
  89. data_items.append(dict_item)
  90. return data_items
  91. # ------------ Mosaic & Mixup ------------
  92. def load_mosaic(self, index):
  93. # load 4x mosaic image
  94. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  95. id1 = index
  96. id2, id3, id4 = random.sample(index_list, 3)
  97. indexs = [id1, id2, id3, id4]
  98. # load images and targets
  99. image_list = []
  100. target_list = []
  101. for index in indexs:
  102. img_i, target_i = self.load_image_target(index)
  103. image_list.append(img_i)
  104. target_list.append(target_i)
  105. # Mosaic
  106. if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
  107. image, target = yolov5_mosaic_augment(
  108. image_list, target_list, self.img_size, self.trans_config, self.is_train)
  109. return image, target
  110. def load_mixup(self, origin_image, origin_target):
  111. # YOLOv5 type Mixup
  112. if self.trans_config['mixup_type'] == 'yolov5_mixup':
  113. new_index = np.random.randint(0, len(self.ids))
  114. new_image, new_target = self.load_mosaic(new_index)
  115. image, target = yolov5_mixup_augment(
  116. origin_image, origin_target, new_image, new_target)
  117. # YOLOX type Mixup
  118. elif self.trans_config['mixup_type'] == 'yolox_mixup':
  119. new_index = np.random.randint(0, len(self.ids))
  120. new_image, new_target = self.load_image_target(new_index)
  121. image, target = yolox_mixup_augment(
  122. origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])
  123. return image, target
  124. # ------------ Load data function ------------
  125. def load_image_target(self, index):
  126. # == Load a data from the cached data ==
  127. if self.cached_datas is not None:
  128. # load a data
  129. data_item = self.cached_datas[index]
  130. image = data_item["image"]
  131. target = data_item["target"]
  132. # == Load a data from the local disk ==
  133. else:
  134. # load an image
  135. image, _ = self.pull_image(index)
  136. height, width, channels = image.shape
  137. # load a target
  138. bboxes, labels = self.pull_anno(index)
  139. target = {
  140. "boxes": bboxes,
  141. "labels": labels,
  142. "orig_size": [height, width]
  143. }
  144. return image, target
  145. def pull_item(self, index):
  146. if random.random() < self.mosaic_prob:
  147. # load a mosaic image
  148. mosaic = True
  149. image, target = self.load_mosaic(index)
  150. else:
  151. mosaic = False
  152. # load an image and target
  153. image, target = self.load_image_target(index)
  154. # MixUp
  155. if random.random() < self.mixup_prob:
  156. image, target = self.load_mixup(image, target)
  157. # augment
  158. image, target, deltas = self.transform(image, target, mosaic)
  159. return image, target, deltas
  160. def pull_image(self, index):
  161. img_id = self.ids[index]
  162. img_file = os.path.join(self.data_dir, self.image_set,
  163. '{:012}'.format(img_id) + '.jpg')
  164. image = cv2.imread(img_file)
  165. if self.json_file == 'instances_val5k.json' and image is None:
  166. img_file = os.path.join(self.data_dir, 'train2017',
  167. '{:012}'.format(img_id) + '.jpg')
  168. image = cv2.imread(img_file)
  169. assert image is not None
  170. return image, img_id
  171. def pull_anno(self, index):
  172. img_id = self.ids[index]
  173. im_ann = self.coco.loadImgs(img_id)[0]
  174. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  175. annotations = self.coco.loadAnns(anno_ids)
  176. # image infor
  177. width = im_ann['width']
  178. height = im_ann['height']
  179. #load a target
  180. bboxes = []
  181. labels = []
  182. for anno in annotations:
  183. if 'bbox' in anno and anno['area'] > 0:
  184. # bbox
  185. x1 = np.max((0, anno['bbox'][0]))
  186. y1 = np.max((0, anno['bbox'][1]))
  187. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  188. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  189. if x2 < x1 or y2 < y1:
  190. continue
  191. # class label
  192. cls_id = self.class_ids.index(anno['category_id'])
  193. bboxes.append([x1, y1, x2, y2])
  194. labels.append(cls_id)
  195. # guard against no boxes via resizing
  196. bboxes = np.array(bboxes).reshape(-1, 4)
  197. labels = np.array(labels).reshape(-1)
  198. return bboxes, labels
  199. if __name__ == "__main__":
  200. import time
  201. import argparse
  202. from build import build_transform
  203. parser = argparse.ArgumentParser(description='COCO-Dataset')
  204. # opt
  205. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
  206. help='data root')
  207. parser.add_argument('--image_set', type=str, default='train2017',
  208. help='mixup augmentation.')
  209. parser.add_argument('-size', '--img_size', default=640, type=int,
  210. help='input image size.')
  211. parser.add_argument('--aug_type', type=str, default='ssd',
  212. help='augmentation type')
  213. parser.add_argument('--mosaic', default=0., type=float,
  214. help='mosaic augmentation.')
  215. parser.add_argument('--mixup', default=0., type=float,
  216. help='mixup augmentation.')
  217. parser.add_argument('--mixup_type', type=str, default='yolov5_mixup',
  218. help='mixup augmentation.')
  219. parser.add_argument('--is_train', action="store_true", default=False,
  220. help='mixup augmentation.')
  221. parser.add_argument('--load_cache', action="store_true", default=False,
  222. help='load cached data.')
  223. args = parser.parse_args()
  224. trans_config = {
  225. 'aug_type': args.aug_type, # optional: ssd, yolov5
  226. # Basic Augment
  227. 'degrees': 0.0,
  228. 'translate': 0.2,
  229. 'scale': [0.1, 2.0],
  230. 'shear': 0.0,
  231. 'perspective': 0.0,
  232. 'hsv_h': 0.015,
  233. 'hsv_s': 0.7,
  234. 'hsv_v': 0.4,
  235. 'use_ablu': True,
  236. # Mosaic & Mixup
  237. 'mosaic_prob': args.mosaic,
  238. 'mixup_prob': args.mixup,
  239. 'mosaic_type': 'yolov5_mosaic',
  240. 'mixup_type': args.mixup_type, # optional: yolov5_mixup, yolox_mixup
  241. 'mixup_scale': [0.5, 1.5]
  242. }
  243. transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
  244. dataset = COCODataset(
  245. img_size=args.img_size,
  246. data_dir=args.root,
  247. image_set='val2017',
  248. trans_config=trans_config,
  249. transform=transform,
  250. is_train=args.is_train,
  251. load_cache=args.load_cache
  252. )
  253. np.random.seed(0)
  254. class_colors = [(np.random.randint(255),
  255. np.random.randint(255),
  256. np.random.randint(255)) for _ in range(80)]
  257. print('Data length: ', len(dataset))
  258. for i in range(1000):
  259. t0 = time.time()
  260. image, target, deltas = dataset.pull_item(i)
  261. print("Load data: {} s".format(time.time() - t0))
  262. # to numpy
  263. image = image.permute(1, 2, 0).numpy()
  264. # to uint8
  265. image = image.astype(np.uint8)
  266. image = image.copy()
  267. img_h, img_w = image.shape[:2]
  268. boxes = target["boxes"]
  269. labels = target["labels"]
  270. for box, label in zip(boxes, labels):
  271. x1, y1, x2, y2 = box
  272. cls_id = int(label)
  273. color = class_colors[cls_id]
  274. # class name
  275. label = coco_class_labels[coco_class_index[cls_id]]
  276. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2)
  277. # put the test on the bbox
  278. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  279. cv2.imshow('gt', image)
  280. # cv2.imwrite(str(i)+'.jpg', img)
  281. cv2.waitKey(0)