coco.py

import os
import random
import numpy as np
import time
import torch
from torch.utils.data import Dataset
import cv2

try:
    from pycocotools.coco import COCO
except ImportError:
    print("It seems that the COCO API (pycocotools) is not installed.")

try:
    from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
except ImportError:
    from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment

coco_class_labels = ('background',
                     'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
                     'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign',
                     'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
                     'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella',
                     'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
                     'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
                     'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass',
                     'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
                     'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
                     'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk',
                     'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
                     'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book',
                     'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')

coco_class_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
                    21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
                    46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67,
                    70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
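
# The 80 detection classes use contiguous ids 0-79; coco_class_index maps each
# contiguous id back to the original (non-contiguous) COCO category id, which is
# also the lookup index into coco_class_labels (entry 0 of that tuple is 'background').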


class COCODataset(Dataset):
    """
    COCO dataset class.
    """
    def __init__(self,
                 img_size=640,
                 data_dir=None,
                 image_set='train2017',
                 trans_config=None,
                 transform=None,
                 is_train=False,
                 load_cache=False):
        """
        COCO dataset initialization. Annotation data are read into memory by the COCO API.
        Args:
            img_size (int): target image size after preprocessing
            data_dir (str): dataset root directory
            image_set (str): COCO image set (e.g. 'train2017', 'val2017' or 'test2017')
            trans_config (dict): data augmentation configuration
            transform: image/target transform applied in pull_item
            is_train (bool): whether the dataset is used for training
            load_cache (bool): if True, pre-load and pre-resize all images into memory
        """
        if image_set == 'train2017':
            self.json_file = 'instances_train2017.json'
        elif image_set == 'val2017':
            self.json_file = 'instances_val2017.json'
        elif image_set == 'test2017':
            self.json_file = 'image_info_test-dev2017.json'
        else:
            raise ValueError('Unknown image_set: {}'.format(image_set))
        self.img_size = img_size
        self.image_set = image_set
        self.data_dir = data_dir
        self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
        self.ids = self.coco.getImgIds()
        self.class_ids = sorted(self.coco.getCatIds())
        self.is_train = is_train
        self.load_cache = load_cache
        # augmentation
        self.transform = transform
        self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
        self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
        self.trans_config = trans_config
        print('==============================')
        print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
        print('use Mixup Augmentation: {}'.format(self.mixup_prob))
        print('==============================')
        # load cache data
        if load_cache:
            self._load_cache()

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        return self.pull_item(index)

    def _load_cache(self):
        # load image cache
        self.cached_images = []
        self.cached_targets = []
        dataset_size = len(self.ids)
        for i in range(dataset_size):
            if i % 5000 == 0 and i > 0:
                print("[{} / {}]".format(i, dataset_size))
            # load an image
            image, image_id = self.pull_image(i)
            orig_h, orig_w, _ = image.shape
            # resize the image so that its longer side equals self.img_size
            r = self.img_size / max(orig_h, orig_w)
            if r != 1:
                interp = cv2.INTER_LINEAR
                new_size = (int(orig_w * r), int(orig_h * r))
                image = cv2.resize(image, new_size, interpolation=interp)
            img_h, img_w = image.shape[:2]
            self.cached_images.append(image)
            # load the target cache and rescale the boxes to the resized image
            bboxes, labels = self.pull_anno(i)
            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] / orig_w * img_w
            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] / orig_h * img_h
            self.cached_targets.append({"boxes": bboxes, "labels": labels})

    def load_image_target(self, index):
        if self.load_cache:
            # load data from cache
            image = self.cached_images[index]
            target = self.cached_targets[index]
            height, width, channels = image.shape
            target["orig_size"] = [height, width]
        else:
            # load an image
            image, _ = self.pull_image(index)
            height, width, channels = image.shape
            # load a target
            bboxes, labels = self.pull_anno(index)
            target = {
                "boxes": bboxes,
                "labels": labels,
                "orig_size": [height, width]
            }
        return image, target

    def load_mosaic(self, index):
        # load a 4x mosaic image
        index_list = np.arange(index).tolist() + np.arange(index + 1, len(self.ids)).tolist()
        id1 = index
        id2, id3, id4 = random.sample(index_list, 3)
        indices = [id1, id2, id3, id4]
        # load images and targets
        image_list = []
        target_list = []
        for idx in indices:
            img_i, target_i = self.load_image_target(idx)
            image_list.append(img_i)
            target_list.append(target_i)
        # Mosaic
        if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
            image, target = yolov5_mosaic_augment(
                image_list, target_list, self.img_size, self.trans_config, self.is_train)
        return image, target

    def load_mixup(self, origin_image, origin_target):
        # YOLOv5-style MixUp
        if self.trans_config['mixup_type'] == 'yolov5_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_mosaic(new_index)
            image, target = yolov5_mixup_augment(
                origin_image, origin_target, new_image, new_target)
        # YOLOX-style MixUp
        elif self.trans_config['mixup_type'] == 'yolox_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_image_target(new_index)
            image, target = yolox_mixup_augment(
                origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])
        return image, target
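
    # Note: the YOLOv5-style MixUp above blends the original sample with a newly
    # generated mosaic sample, while the YOLOX-style MixUp blends it with a single
    # plain image drawn from the dataset.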

    def pull_item(self, index):
        if random.random() < self.mosaic_prob:
            # load a mosaic image
            mosaic = True
            image, target = self.load_mosaic(index)
        else:
            mosaic = False
            # load an image and target
            image, target = self.load_image_target(index)
        # MixUp
        if random.random() < self.mixup_prob:
            image, target = self.load_mixup(image, target)
        # augment
        image, target, deltas = self.transform(image, target, mosaic)
        return image, target, deltas

    def pull_image(self, index):
        img_id = self.ids[index]
        img_file = os.path.join(self.data_dir, self.image_set,
                                '{:012}'.format(img_id) + '.jpg')
        image = cv2.imread(img_file)
        if self.json_file == 'instances_val5k.json' and image is None:
            img_file = os.path.join(self.data_dir, 'train2017',
                                    '{:012}'.format(img_id) + '.jpg')
            image = cv2.imread(img_file)
        assert image is not None, 'Failed to read image: {}'.format(img_file)
        return image, img_id

    def pull_anno(self, index):
        img_id = self.ids[index]
        im_ann = self.coco.loadImgs(img_id)[0]
        anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
        annotations = self.coco.loadAnns(anno_ids)
        # image info
        width = im_ann['width']
        height = im_ann['height']
        # load the targets
        bboxes = []
        labels = []
        for anno in annotations:
            if 'bbox' in anno and anno['area'] > 0:
                # convert the COCO [x, y, w, h] box to a clipped [x1, y1, x2, y2] box
                x1 = np.max((0, anno['bbox'][0]))
                y1 = np.max((0, anno['bbox'][1]))
                x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
                y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
                if x2 < x1 or y2 < y1:
                    continue
                # class label (contiguous id)
                cls_id = self.class_ids.index(anno['category_id'])
                bboxes.append([x1, y1, x2, y2])
                labels.append(cls_id)
        # guard against images with no boxes by reshaping to fixed dimensions
        bboxes = np.array(bboxes).reshape(-1, 4)
        labels = np.array(labels).reshape(-1)
        return bboxes, labels
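
    # Illustrative example of the values pull_anno returns for one image:
    #   bboxes -> np.ndarray of shape (N, 4): [x1, y1, x2, y2] in pixel coordinates
    #   labels -> np.ndarray of shape (N,): contiguous class ids in [0, 79]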


if __name__ == "__main__":
    import argparse
    from build import build_transform

    parser = argparse.ArgumentParser(description='COCO-Dataset')
    # options
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
                        help='data root')
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='input image size.')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation.')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation.')
    parser.add_argument('--is_train', action="store_true", default=False,
                        help='train mode.')
    parser.add_argument('--load_cache', action="store_true", default=False,
                        help='load cached data.')
    args = parser.parse_args()

    trans_config = {
        'aug_type': 'yolov5',  # optional: ssd, yolov5
        # Basic augmentation
        'degrees': 0.0,
        'translate': 0.2,
        'scale': [0.5, 2.0],
        'shear': 0.0,
        'perspective': 0.0,
        'hsv_h': 0.015,
        'hsv_s': 0.7,
        'hsv_v': 0.4,
        # Mosaic & MixUp
        'mosaic_prob': 1.0,
        'mixup_prob': 1.0,
        'mosaic_type': 'yolov5_mosaic',
        'mixup_type': 'yolov5_mixup',
        'mixup_scale': [0.5, 1.5]
    }

    transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)

    dataset = COCODataset(
        img_size=args.img_size,
        data_dir=args.root,
        image_set='val2017',
        trans_config=trans_config,
        transform=transform,
        is_train=args.is_train,
        load_cache=args.load_cache
    )

    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(80)]
    print('Data length: ', len(dataset))

    for i in range(1000):
        image, target, deltas = dataset.pull_item(i)
        # to numpy
        image = image.permute(1, 2, 0).numpy()
        # to uint8
        image = image.astype(np.uint8)
        image = image.copy()
        img_h, img_w = image.shape[:2]

        boxes = target["boxes"]
        labels = target["labels"]
        for box, label in zip(boxes, labels):
            x1, y1, x2, y2 = box
            cls_id = int(label)
            color = class_colors[cls_id]
            # class name
            cls_name = coco_class_labels[coco_class_index[cls_id]]
            # draw the bbox
            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
            # put the class name on the bbox
            cv2.putText(image, cls_name, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
        cv2.imshow('gt', image)
        # cv2.imwrite(str(i) + '.jpg', image)
        cv2.waitKey(0)
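
# Example invocation of this visualization script (illustrative; the dataset root
# below is a placeholder and build.py must provide build_transform):
#   python coco.py --root /path/to/COCO --img_size 640 --is_train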