# coco.py

import os
import random
import numpy as np
import time
import torch
from torch.utils.data import Dataset
import cv2

try:
    from pycocotools.coco import COCO
except ImportError:
    print("It seems that the COCOAPI is not installed.")

try:
    # package-relative import when used as part of the project
    from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
except ImportError:
    # fallback when running this file directly
    from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment

coco_class_labels = ('background',
                     'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
                     'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign',
                     'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
                     'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella',
                     'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
                     'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
                     'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass',
                     'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
                     'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
                     'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk',
                     'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
                     'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book',
                     'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')

coco_class_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
                    21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
                    46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67,
                    70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
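# Note: COCO category ids are non-contiguous (1-90 with gaps), so the 80 contiguous
# detection labels (0-79) are mapped back to the original category ids through
# coco_class_index, e.g. coco_class_index[0] == 1 and
# coco_class_labels[coco_class_index[0]] == 'person'.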


class COCODataset(Dataset):
    """
    COCO dataset class.
    """
    def __init__(self,
                 img_size=640,
                 data_dir=None,
                 image_set='train2017',
                 trans_config=None,
                 transform=None,
                 is_train=False,
                 load_cache=False):
        """
        COCO dataset initialization. Annotation data are read into memory by the COCO API.
        Args:
            img_size (int): target image size after preprocessing
            data_dir (str): dataset root directory
            image_set (str): COCO split name (e.g. 'train2017' or 'val2017')
            trans_config (dict): data augmentation config
            transform: image/target transform built outside the dataset
            is_train (bool): whether the dataset is used for training
            load_cache (bool): cache resized images and targets in memory
        """
        if image_set == 'train2017':
            self.json_file = 'instances_train2017.json'
        elif image_set == 'val2017':
            self.json_file = 'instances_val2017.json'
        elif image_set == 'test2017':
            self.json_file = 'image_info_test-dev2017.json'
        self.img_size = img_size
        self.image_set = image_set
        self.data_dir = data_dir
        self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
        self.ids = self.coco.getImgIds()
        self.class_ids = sorted(self.coco.getCatIds())
        self.is_train = is_train

        # augmentation
        self.transform = transform
        self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
        self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
        self.trans_config = trans_config
        print('==============================')
        print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
        print('use Mixup Augmentation: {}'.format(self.mixup_prob))
        print('==============================')

        # load cache data
        self.load_cache = load_cache
        if load_cache:
            self._load_cache()

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        return self.pull_item(index)

    def _load_cache(self):
        # load image cache
        self.cached_images = []
        self.cached_targets = []
        dataset_size = len(self.ids)
        for i in range(dataset_size):
            if i % 5000 == 0:
                print("[{} / {}]".format(i, dataset_size))
            # load an image
            image, image_id = self.pull_image(i)
            orig_h, orig_w, _ = image.shape

            # resize the image so that its longer side equals self.img_size
            r = self.img_size / max(orig_h, orig_w)
            if r != 1:
                interp = cv2.INTER_LINEAR
                new_size = (int(orig_w * r), int(orig_h * r))
                image = cv2.resize(image, new_size, interpolation=interp)
            img_h, img_w = image.shape[:2]
            self.cached_images.append(image)

            # load the target and rescale its boxes to the resized image
            bboxes, labels = self.pull_anno(i)
            bboxes[:, [0, 2]] = bboxes[:, [0, 2]] / orig_w * img_w
            bboxes[:, [1, 3]] = bboxes[:, [1, 3]] / orig_h * img_h
            self.cached_targets.append({"boxes": bboxes, "labels": labels})

    def load_image_target(self, index):
        if self.load_cache:
            # load data from cache
            image = self.cached_images[index]
            target = self.cached_targets[index]
            height, width, channels = image.shape
            target["orig_size"] = [height, width]
        else:
            # load an image
            image, _ = self.pull_image(index)
            height, width, channels = image.shape

            # load a target
            bboxes, labels = self.pull_anno(index)
            target = {
                "boxes": bboxes,
                "labels": labels,
                "orig_size": [height, width]
            }

        return image, target

    def load_mosaic(self, index):
        # load a 4x mosaic image
        index_list = np.arange(index).tolist() + np.arange(index + 1, len(self.ids)).tolist()
        id1 = index
        id2, id3, id4 = random.sample(index_list, 3)
        indexs = [id1, id2, id3, id4]

        # load images and targets
        image_list = []
        target_list = []
        for index in indexs:
            img_i, target_i = self.load_image_target(index)
            image_list.append(img_i)
            target_list.append(target_i)

        # Mosaic
        if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
            image, target = yolov5_mosaic_augment(
                image_list, target_list, self.img_size, self.trans_config, self.is_train)

        return image, target

    def load_mixup(self, origin_image, origin_target):
        # YOLOv5-style MixUp
        if self.trans_config['mixup_type'] == 'yolov5_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_mosaic(new_index)
            image, target = yolov5_mixup_augment(
                origin_image, origin_target, new_image, new_target)
        # YOLOX-style MixUp
        elif self.trans_config['mixup_type'] == 'yolox_mixup':
            new_index = np.random.randint(0, len(self.ids))
            new_image, new_target = self.load_image_target(new_index)
            image, target = yolox_mixup_augment(
                origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])

        return image, target

    def pull_item(self, index):
        if random.random() < self.mosaic_prob:
            # load a mosaic image
            mosaic = True
            image, target = self.load_mosaic(index)
        else:
            mosaic = False
            # load an image and target
            image, target = self.load_image_target(index)

        # MixUp
        if random.random() < self.mixup_prob:
            image, target = self.load_mixup(image, target)

        # augment
        image, target, deltas = self.transform(image, target, mosaic)

        return image, target, deltas

    def pull_image(self, index):
        img_id = self.ids[index]
        img_file = os.path.join(self.data_dir, self.image_set,
                                '{:012}'.format(img_id) + '.jpg')
        image = cv2.imread(img_file)
        if self.json_file == 'instances_val5k.json' and image is None:
            img_file = os.path.join(self.data_dir, 'train2017',
                                    '{:012}'.format(img_id) + '.jpg')
            image = cv2.imread(img_file)

        assert image is not None

        return image, img_id

    def pull_anno(self, index):
        img_id = self.ids[index]
        im_ann = self.coco.loadImgs(img_id)[0]
        anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
        annotations = self.coco.loadAnns(anno_ids)

        # image info
        width = im_ann['width']
        height = im_ann['height']

        # load a target
        bboxes = []
        labels = []
        for anno in annotations:
            if 'bbox' in anno and anno['area'] > 0:
                # bbox: convert COCO [x, y, w, h] to clipped [x1, y1, x2, y2]
                x1 = np.max((0, anno['bbox'][0]))
                y1 = np.max((0, anno['bbox'][1]))
                x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
                y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
                if x2 < x1 or y2 < y1:
                    continue
                # class label
                cls_id = self.class_ids.index(anno['category_id'])
                bboxes.append([x1, y1, x2, y2])
                labels.append(cls_id)

        # guard against images without boxes via reshaping
        bboxes = np.array(bboxes).reshape(-1, 4)
        labels = np.array(labels).reshape(-1)

        return bboxes, labels
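

# Because each __getitem__ call returns an (image, target, deltas) triple with a
# variable number of boxes per image, the default DataLoader collate would fail.
# The helper below is only an illustrative sketch (its name `batch_collate_fn` is
# not part of the original project) and assumes the transform resizes every image
# to the same HxW CHW tensor so torch.stack works; `deltas` is simply dropped.
def batch_collate_fn(batch):
    images = torch.stack([sample[0] for sample in batch], dim=0)  # [B, C, H, W]
    targets = [sample[1] for sample in batch]                     # keep per-image target dicts
    return images, targets

# Hypothetical usage:
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=batch_collate_fn)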


if __name__ == "__main__":
    import argparse
    from build import build_transform

    parser = argparse.ArgumentParser(description='COCO-Dataset')
    # opt
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
                        help='data root')
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='input image size.')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation.')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation.')
    parser.add_argument('--is_train', action="store_true", default=False,
                        help='train mode.')
    parser.add_argument('--load_cache', action="store_true", default=False,
                        help='load cached data.')

    args = parser.parse_args()

    trans_config = {
        'aug_type': 'yolov5',   # optional: ssd, yolov5
        # Basic Augment
        'degrees': 0.0,
        'translate': 0.2,
        'scale': [0.5, 2.0],
        'shear': 0.0,
        'perspective': 0.0,
        'hsv_h': 0.015,
        'hsv_s': 0.7,
        'hsv_v': 0.4,
        # Mosaic & Mixup
        'mosaic_prob': 1.0,
        'mixup_prob': 1.0,
        'mosaic_type': 'yolov5_mosaic',
        'mixup_type': 'yolov5_mixup',
        'mixup_scale': [0.5, 1.5]
    }
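
    # Of the keys above, COCODataset itself only reads 'mosaic_prob', 'mixup_prob',
    # 'mosaic_type', 'mixup_type' and 'mixup_scale'; the remaining keys ('aug_type',
    # the geometric and HSV settings) are consumed by the mosaic/mixup augment
    # functions and by the transform built via build_transform below.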

    transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)

    dataset = COCODataset(
        img_size=args.img_size,
        data_dir=args.root,
        image_set='val2017',
        trans_config=trans_config,
        transform=transform,
        is_train=args.is_train,
        load_cache=args.load_cache
        )

    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(80)]
    print('Data length: ', len(dataset))

    for i in range(1000):
        image, target, deltas = dataset.pull_item(i)
        # to numpy
        image = image.permute(1, 2, 0).numpy()
        # to uint8
        image = image.astype(np.uint8)
        image = image.copy()
        img_h, img_w = image.shape[:2]

        boxes = target["boxes"]
        labels = target["labels"]

        for box, label in zip(boxes, labels):
            x1, y1, x2, y2 = box
            cls_id = int(label)
            color = class_colors[cls_id]
            # class name
            label = coco_class_labels[coco_class_index[cls_id]]
            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
            # put the text on the bbox
            cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
        cv2.imshow('gt', image)
        # cv2.imwrite(str(i)+'.jpg', image)
        cv2.waitKey(0)