# coco.py -- COCO detection dataset with Mosaic / Mixup augmentation.
import os
import random
import time

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

try:
    from pycocotools.coco import COCO
except ImportError:
    print("It seems that the COCOAPI is not installed.")

try:
    # Imported as part of a package.
    from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
except ImportError:
    # Executed directly as a script: fall back to a top-level import.
    from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  16. coco_class_labels = ('background',
  17. 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
  18. 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign',
  19. 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
  20. 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella',
  21. 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
  22. 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
  23. 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass',
  24. 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
  25. 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
  26. 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk',
  27. 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
  28. 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book',
  29. 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  30. coco_class_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
  31. 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
  32. 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67,
  33. 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  34. class COCODataset(Dataset):
  35. """
  36. COCO dataset class.
  37. """
  38. def __init__(self,
  39. img_size=640,
  40. data_dir=None,
  41. image_set='train2017',
  42. trans_config=None,
  43. transform=None,
  44. is_train=False,
  45. load_cache=False):
  46. """
  47. COCO dataset initialization. Annotation data are read into memory by COCO API.
  48. Args:
  49. data_dir (str): dataset root directory
  50. json_file (str): COCO json file name
  51. name (str): COCO data name (e.g. 'train2017' or 'val2017')
  52. debug (bool): if True, only one data id is selected from the dataset
  53. """
  54. if image_set == 'train2017':
  55. self.json_file='instances_train2017.json'
  56. elif image_set == 'val2017':
  57. self.json_file='instances_val2017.json'
  58. elif image_set == 'test2017':
  59. self.json_file='image_info_test-dev2017.json'
  60. self.img_size = img_size
  61. self.image_set = image_set
  62. self.data_dir = data_dir
  63. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  64. self.ids = self.coco.getImgIds()
  65. self.class_ids = sorted(self.coco.getCatIds())
  66. self.is_train = is_train
  67. self.load_cache = load_cache
  68. # augmentation
  69. self.transform = transform
  70. self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
  71. self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
  72. self.trans_config = trans_config
  73. print('==============================')
  74. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  75. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  76. print('==============================')
  77. # load cache data
  78. if load_cache:
  79. self._load_cache()
  80. def __len__(self):
  81. return len(self.ids)
  82. def __getitem__(self, index):
  83. return self.pull_item(index)
  84. def _load_cache(self):
  85. # load image cache
  86. self.cached_images = []
  87. self.cached_targets = []
  88. dataset_size = len(self.ids)
  89. print('loading data into cache ...')
  90. for i in range(dataset_size):
  91. if i % 5000 == 0 and i > 0:
  92. print("[{} / {}]".format(i, dataset_size))
  93. break
  94. # load an image
  95. image, image_id = self.pull_image(i)
  96. orig_h, orig_w, _ = image.shape
  97. # resize image
  98. r = args.img_size / max(orig_h, orig_w)
  99. if r != 1:
  100. interp = cv2.INTER_LINEAR
  101. new_size = (int(orig_w * r), int(orig_h * r))
  102. image = cv2.resize(image, new_size, interpolation=interp)
  103. img_h, img_w = image.shape[:2]
  104. self.cached_images.append(image)
  105. # load target cache
  106. bboxes, labels = self.pull_anno(i)
  107. bboxes[:, [0, 2]] = bboxes[:, [0, 2]] / orig_w * img_w
  108. bboxes[:, [1, 3]] = bboxes[:, [1, 3]] / orig_h * img_h
  109. self.cached_targets.append({"boxes": bboxes, "labels": labels})
  110. def load_image_target(self, index):
  111. if self.load_cache:
  112. # load data from cache
  113. image = self.cached_images[index]
  114. target = self.cached_targets[index]
  115. height, width, channels = image.shape
  116. target["orig_size"] = [height, width]
  117. else:
  118. # load an image
  119. image, _ = self.pull_image(index)
  120. height, width, channels = image.shape
  121. # load a target
  122. bboxes, labels = self.pull_anno(index)
  123. target = {
  124. "boxes": bboxes,
  125. "labels": labels,
  126. "orig_size": [height, width]
  127. }
  128. return image, target
  129. def load_mosaic(self, index):
  130. # load 4x mosaic image
  131. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  132. id1 = index
  133. id2, id3, id4 = random.sample(index_list, 3)
  134. indexs = [id1, id2, id3, id4]
  135. # load images and targets
  136. image_list = []
  137. target_list = []
  138. for index in indexs:
  139. img_i, target_i = self.load_image_target(index)
  140. image_list.append(img_i)
  141. target_list.append(target_i)
  142. # Mosaic
  143. if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
  144. image, target = yolov5_mosaic_augment(
  145. image_list, target_list, self.img_size, self.trans_config, self.is_train)
  146. return image, target
  147. def load_mixup(self, origin_image, origin_target):
  148. # YOLOv5 type Mixup
  149. if self.trans_config['mixup_type'] == 'yolov5_mixup':
  150. new_index = np.random.randint(0, len(self.ids))
  151. new_image, new_target = self.load_mosaic(new_index)
  152. image, target = yolov5_mixup_augment(
  153. origin_image, origin_target, new_image, new_target)
  154. # YOLOX type Mixup
  155. elif self.trans_config['mixup_type'] == 'yolox_mixup':
  156. new_index = np.random.randint(0, len(self.ids))
  157. new_image, new_target = self.load_image_target(new_index)
  158. image, target = yolox_mixup_augment(
  159. origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])
  160. return image, target
  161. def pull_item(self, index):
  162. if random.random() < self.mosaic_prob:
  163. # load a mosaic image
  164. mosaic = True
  165. image, target = self.load_mosaic(index)
  166. else:
  167. mosaic = False
  168. # load an image and target
  169. image, target = self.load_image_target(index)
  170. # MixUp
  171. if random.random() < self.mixup_prob:
  172. image, target = self.load_mixup(image, target)
  173. # augment
  174. image, target, deltas = self.transform(image, target, mosaic)
  175. return image, target, deltas
  176. def pull_image(self, index):
  177. img_id = self.ids[index]
  178. img_file = os.path.join(self.data_dir, self.image_set,
  179. '{:012}'.format(img_id) + '.jpg')
  180. image = cv2.imread(img_file)
  181. if self.json_file == 'instances_val5k.json' and image is None:
  182. img_file = os.path.join(self.data_dir, 'train2017',
  183. '{:012}'.format(img_id) + '.jpg')
  184. image = cv2.imread(img_file)
  185. assert image is not None
  186. return image, img_id
  187. def pull_anno(self, index):
  188. img_id = self.ids[index]
  189. im_ann = self.coco.loadImgs(img_id)[0]
  190. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  191. annotations = self.coco.loadAnns(anno_ids)
  192. # image infor
  193. width = im_ann['width']
  194. height = im_ann['height']
  195. #load a target
  196. bboxes = []
  197. labels = []
  198. for anno in annotations:
  199. if 'bbox' in anno and anno['area'] > 0:
  200. # bbox
  201. x1 = np.max((0, anno['bbox'][0]))
  202. y1 = np.max((0, anno['bbox'][1]))
  203. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  204. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  205. if x2 < x1 or y2 < y1:
  206. continue
  207. # class label
  208. cls_id = self.class_ids.index(anno['category_id'])
  209. bboxes.append([x1, y1, x2, y2])
  210. labels.append(cls_id)
  211. # guard against no boxes via resizing
  212. bboxes = np.array(bboxes).reshape(-1, 4)
  213. labels = np.array(labels).reshape(-1)
  214. return bboxes, labels
  215. if __name__ == "__main__":
  216. import argparse
  217. from build import build_transform
  218. parser = argparse.ArgumentParser(description='COCO-Dataset')
  219. # opt
  220. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
  221. help='data root')
  222. parser.add_argument('-size', '--img_size', default=640, type=int,
  223. help='input image size.')
  224. parser.add_argument('--mosaic', default=None, type=float,
  225. help='mosaic augmentation.')
  226. parser.add_argument('--mixup', default=None, type=float,
  227. help='mixup augmentation.')
  228. parser.add_argument('--is_train', action="store_true", default=False,
  229. help='mixup augmentation.')
  230. parser.add_argument('--load_cache', action="store_true", default=False,
  231. help='load cached data.')
  232. args = parser.parse_args()
  233. trans_config = {
  234. 'aug_type': 'yolov5', # optional: ssd, yolov5
  235. # Basic Augment
  236. 'degrees': 0.0,
  237. 'translate': 0.2,
  238. 'scale': [0.5, 2.0],
  239. 'shear': 0.0,
  240. 'perspective': 0.0,
  241. 'hsv_h': 0.015,
  242. 'hsv_s': 0.7,
  243. 'hsv_v': 0.4,
  244. # Mosaic & Mixup
  245. 'mosaic_prob': 1.0,
  246. 'mixup_prob': 1.0,
  247. 'mosaic_type': 'yolov5_mosaic',
  248. 'mixup_type': 'yolov5_mixup',
  249. 'mixup_scale': [0.5, 1.5]
  250. }
  251. transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
  252. dataset = COCODataset(
  253. img_size=args.img_size,
  254. data_dir=args.root,
  255. image_set='val2017',
  256. trans_config=trans_config,
  257. transform=transform,
  258. is_train=args.is_train,
  259. load_cache=args.load_cache
  260. )
  261. np.random.seed(0)
  262. class_colors = [(np.random.randint(255),
  263. np.random.randint(255),
  264. np.random.randint(255)) for _ in range(80)]
  265. print('Data length: ', len(dataset))
  266. for i in range(1000):
  267. image, target, deltas = dataset.pull_item(i)
  268. # to numpy
  269. image = image.permute(1, 2, 0).numpy()
  270. # to uint8
  271. image = image.astype(np.uint8)
  272. image = image.copy()
  273. img_h, img_w = image.shape[:2]
  274. boxes = target["boxes"]
  275. labels = target["labels"]
  276. for box, label in zip(boxes, labels):
  277. x1, y1, x2, y2 = box
  278. cls_id = int(label)
  279. color = class_colors[cls_id]
  280. # class name
  281. label = coco_class_labels[coco_class_index[cls_id]]
  282. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2)
  283. # put the test on the bbox
  284. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  285. cv2.imshow('gt', image)
  286. # cv2.imwrite(str(i)+'.jpg', img)
  287. cv2.waitKey(0)