# coco.py
import os
import cv2
import time
import random
import numpy as np

import torch
from torch.utils.data import Dataset

try:
    from pycocotools.coco import COCO
except ImportError:
    print("It seems that the COCOAPI is not installed.")

try:
    from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
except ImportError:
    from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  16. coco_class_labels = ('background',
  17. 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
  18. 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign',
  19. 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
  20. 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella',
  21. 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
  22. 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
  23. 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass',
  24. 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
  25. 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
  26. 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk',
  27. 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
  28. 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book',
  29. 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  30. coco_class_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
  31. 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
  32. 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67,
  33. 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  34. class COCODataset(Dataset):
  35. """
  36. COCO dataset class.
  37. """
  38. def __init__(self,
  39. img_size :int = 640,
  40. data_dir :str = None,
  41. image_set :str = 'train2017',
  42. trans_config = None,
  43. transform = None,
  44. is_train :bool =False,
  45. load_cache :str = None):
  46. """
  47. COCO dataset initialization. Annotation data are read into memory by COCO API.
  48. Args:
  49. data_dir (str): dataset root directory
  50. json_file (str): COCO json file name
  51. name (str): COCO data name (e.g. 'train2017' or 'val2017')
  52. debug (bool): if True, only one data id is selected from the dataset
  53. """
  54. # ----------- Basic parameters -----------
  55. self.img_size = img_size
  56. self.image_set = image_set
  57. self.is_train = is_train
  58. self.load_cache = load_cache
  59. # ----------- Path parameters -----------
  60. self.data_dir = data_dir
  61. if image_set == 'train2017':
  62. self.json_file='instances_train2017.json'
  63. elif image_set == 'val2017':
  64. self.json_file='instances_val2017.json'
  65. elif image_set == 'test2017':
  66. self.json_file='image_info_test-dev2017.json'
  67. # ----------- Data parameters -----------
  68. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  69. self.ids = self.coco.getImgIds()
  70. self.class_ids = sorted(self.coco.getCatIds())
  71. self.dataset_size = len(self.ids)
  72. # ----------- Transform parameters -----------
  73. self.transform = transform
  74. self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
  75. self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
  76. self.trans_config = trans_config
  77. print('==============================')
  78. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  79. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  80. print('==============================')
  81. # load cache data
  82. if load_cache is not None:
  83. self._load_cache()
  84. # ------------ Basic dataset function ------------
  85. def __len__(self):
  86. return len(self.ids)
  87. def __getitem__(self, index):
  88. return self.pull_item(index)
  89. def _load_cache(self):
  90. # load image cache
  91. try:
  92. print("Loading cached data ...")
  93. self.cached_datas = torch.load(self.load_cache)
  94. self.dataset_size = len(self.cached_datas)
  95. print("Loading done !")
  96. except:
  97. self.cached_datas = None
  98. self.load_cache = None
  99. print("{} does not exits.".format(self.load_cache))
  100. # ------------ Mosaic & Mixup ------------
  101. def load_mosaic(self, index):
  102. # load 4x mosaic image
  103. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  104. id1 = index
  105. id2, id3, id4 = random.sample(index_list, 3)
  106. indexs = [id1, id2, id3, id4]
  107. # load images and targets
  108. image_list = []
  109. target_list = []
  110. for index in indexs:
  111. img_i, target_i = self.load_image_target(index)
  112. image_list.append(img_i)
  113. target_list.append(target_i)
  114. # Mosaic
  115. if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
  116. image, target = yolov5_mosaic_augment(
  117. image_list, target_list, self.img_size, self.trans_config, self.is_train)
  118. return image, target
  119. def load_mixup(self, origin_image, origin_target):
  120. # YOLOv5 type Mixup
  121. if self.trans_config['mixup_type'] == 'yolov5_mixup':
  122. new_index = np.random.randint(0, len(self.ids))
  123. new_image, new_target = self.load_mosaic(new_index)
  124. image, target = yolov5_mixup_augment(
  125. origin_image, origin_target, new_image, new_target)
  126. # YOLOX type Mixup
  127. elif self.trans_config['mixup_type'] == 'yolox_mixup':
  128. new_index = np.random.randint(0, len(self.ids))
  129. new_image, new_target = self.load_image_target(new_index)
  130. image, target = yolox_mixup_augment(
  131. origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])
  132. return image, target
  133. # ------------ Load data function ------------
  134. def load_image_target(self, index):
  135. # == Load a data from the cached data ==
  136. if self.load_cache is not None and self.is_train:
  137. # load a data
  138. data_item = self.cached_datas[index]
  139. image = data_item["image"]
  140. target = data_item["target"]
  141. # == Load a data from the local disk ==
  142. else:
  143. # load an image
  144. image, _ = self.pull_image(index)
  145. height, width, channels = image.shape
  146. # load a target
  147. bboxes, labels = self.pull_anno(index)
  148. target = {
  149. "boxes": bboxes,
  150. "labels": labels,
  151. "orig_size": [height, width]
  152. }
  153. return image, target
  154. def pull_item(self, index):
  155. if random.random() < self.mosaic_prob:
  156. # load a mosaic image
  157. mosaic = True
  158. image, target = self.load_mosaic(index)
  159. else:
  160. mosaic = False
  161. # load an image and target
  162. image, target = self.load_image_target(index)
  163. # MixUp
  164. if random.random() < self.mixup_prob:
  165. image, target = self.load_mixup(image, target)
  166. # augment
  167. image, target, deltas = self.transform(image, target, mosaic)
  168. return image, target, deltas
  169. def pull_image(self, index):
  170. img_id = self.ids[index]
  171. img_file = os.path.join(self.data_dir, self.image_set,
  172. '{:012}'.format(img_id) + '.jpg')
  173. image = cv2.imread(img_file)
  174. if self.json_file == 'instances_val5k.json' and image is None:
  175. img_file = os.path.join(self.data_dir, 'train2017',
  176. '{:012}'.format(img_id) + '.jpg')
  177. image = cv2.imread(img_file)
  178. assert image is not None
  179. return image, img_id
  180. def pull_anno(self, index):
  181. img_id = self.ids[index]
  182. im_ann = self.coco.loadImgs(img_id)[0]
  183. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  184. annotations = self.coco.loadAnns(anno_ids)
  185. # image infor
  186. width = im_ann['width']
  187. height = im_ann['height']
  188. #load a target
  189. bboxes = []
  190. labels = []
  191. for anno in annotations:
  192. if 'bbox' in anno and anno['area'] > 0:
  193. # bbox
  194. x1 = np.max((0, anno['bbox'][0]))
  195. y1 = np.max((0, anno['bbox'][1]))
  196. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  197. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  198. if x2 < x1 or y2 < y1:
  199. continue
  200. # class label
  201. cls_id = self.class_ids.index(anno['category_id'])
  202. bboxes.append([x1, y1, x2, y2])
  203. labels.append(cls_id)
  204. # guard against no boxes via resizing
  205. bboxes = np.array(bboxes).reshape(-1, 4)
  206. labels = np.array(labels).reshape(-1)
  207. return bboxes, labels
  208. if __name__ == "__main__":
  209. import argparse
  210. from build import build_transform
  211. parser = argparse.ArgumentParser(description='COCO-Dataset')
  212. # opt
  213. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
  214. help='data root')
  215. parser.add_argument('-size', '--img_size', default=640, type=int,
  216. help='input image size.')
  217. parser.add_argument('--mosaic', default=None, type=float,
  218. help='mosaic augmentation.')
  219. parser.add_argument('--mixup', default=None, type=float,
  220. help='mixup augmentation.')
  221. parser.add_argument('--is_train', action="store_true", default=False,
  222. help='mixup augmentation.')
  223. parser.add_argument('--load_cache', action="store_true", default=False,
  224. help='load cached data.')
  225. args = parser.parse_args()
  226. trans_config = {
  227. 'aug_type': 'yolov5', # optional: ssd, yolov5
  228. # Basic Augment
  229. 'degrees': 0.0,
  230. 'translate': 0.2,
  231. 'scale': [0.5, 2.0],
  232. 'shear': 0.0,
  233. 'perspective': 0.0,
  234. 'hsv_h': 0.015,
  235. 'hsv_s': 0.7,
  236. 'hsv_v': 0.4,
  237. # Mosaic & Mixup
  238. 'mosaic_prob': 1.0,
  239. 'mixup_prob': 1.0,
  240. 'mosaic_type': 'yolov5_mosaic',
  241. 'mixup_type': 'yolov5_mixup',
  242. 'mixup_scale': [0.5, 1.5]
  243. }
  244. transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
  245. dataset = COCODataset(
  246. img_size=args.img_size,
  247. data_dir=args.root,
  248. image_set='val2017',
  249. trans_config=trans_config,
  250. transform=transform,
  251. is_train=args.is_train,
  252. load_cache=args.load_cache
  253. )
  254. np.random.seed(0)
  255. class_colors = [(np.random.randint(255),
  256. np.random.randint(255),
  257. np.random.randint(255)) for _ in range(80)]
  258. print('Data length: ', len(dataset))
  259. for i in range(1000):
  260. image, target, deltas = dataset.pull_item(i)
  261. # to numpy
  262. image = image.permute(1, 2, 0).numpy()
  263. # to uint8
  264. image = image.astype(np.uint8)
  265. image = image.copy()
  266. img_h, img_w = image.shape[:2]
  267. boxes = target["boxes"]
  268. labels = target["labels"]
  269. for box, label in zip(boxes, labels):
  270. x1, y1, x2, y2 = box
  271. cls_id = int(label)
  272. color = class_colors[cls_id]
  273. # class name
  274. label = coco_class_labels[coco_class_index[cls_id]]
  275. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2)
  276. # put the test on the bbox
  277. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  278. cv2.imshow('gt', image)
  279. # cv2.imwrite(str(i)+'.jpg', img)
  280. cv2.waitKey(0)