# coco.py
  1. import os
  2. import cv2
  3. import time
  4. import random
  5. import numpy as np
  6. import torch
  7. from torch.utils.data import Dataset
  8. try:
  9. from pycocotools.coco import COCO
  10. except:
  11. print("It seems that the COCOAPI is not installed.")
  12. try:
  13. from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  14. except:
  15. from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  16. coco_class_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  17. coco_class_labels = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  18. class COCODataset(Dataset):
  19. def __init__(self,
  20. img_size :int = 640,
  21. data_dir :str = None,
  22. image_set :str = 'train2017',
  23. trans_config = None,
  24. transform = None,
  25. is_train :bool =False,
  26. load_cache :bool = False,
  27. ):
  28. # ----------- Basic parameters -----------
  29. self.img_size = img_size
  30. self.image_set = image_set
  31. self.is_train = is_train
  32. # ----------- Path parameters -----------
  33. self.data_dir = data_dir
  34. if image_set == 'train2017':
  35. self.json_file='instances_train2017.json'
  36. elif image_set == 'val2017':
  37. self.json_file='instances_val2017.json'
  38. elif image_set == 'test2017':
  39. self.json_file='image_info_test-dev2017.json'
  40. # ----------- Data parameters -----------
  41. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  42. self.ids = self.coco.getImgIds()
  43. self.class_ids = sorted(self.coco.getCatIds())
  44. self.dataset_size = len(self.ids)
  45. # ----------- Transform parameters -----------
  46. self.transform = transform
  47. self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
  48. self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
  49. self.trans_config = trans_config
  50. print('==============================')
  51. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  52. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  53. print('==============================')
  54. # ----------- Cached data -----------
  55. self.load_cache = load_cache
  56. self.cached_datas = None
  57. if self.load_cache:
  58. self.cached_datas = self._load_cache()
  59. # ------------ Basic dataset function ------------
  60. def __len__(self):
  61. return len(self.ids)
  62. def __getitem__(self, index):
  63. return self.pull_item(index)
  64. def _load_cache(self):
  65. data_items = []
  66. for idx in range(self.dataset_size):
  67. if idx % 2000 == 0:
  68. print("Caching images and targets : {} / {} ...".format(idx, self.dataset_size))
  69. # load a data
  70. image, target = self.load_image_target(idx)
  71. orig_h, orig_w, _ = image.shape
  72. # resize image
  73. r = self.img_size / max(orig_h, orig_w)
  74. if r != 1:
  75. interp = cv2.INTER_LINEAR
  76. new_size = (int(orig_w * r), int(orig_h * r))
  77. image = cv2.resize(image, new_size, interpolation=interp)
  78. img_h, img_w = image.shape[:2]
  79. # rescale bbox
  80. boxes = target["boxes"].copy()
  81. boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
  82. boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
  83. target["boxes"] = boxes
  84. dict_item = {}
  85. dict_item["image"] = image
  86. dict_item["target"] = target
  87. data_items.append(dict_item)
  88. return data_items
  89. # ------------ Mosaic & Mixup ------------
  90. def load_mosaic(self, index):
  91. # load 4x mosaic image
  92. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  93. id1 = index
  94. id2, id3, id4 = random.sample(index_list, 3)
  95. indexs = [id1, id2, id3, id4]
  96. # load images and targets
  97. image_list = []
  98. target_list = []
  99. for index in indexs:
  100. img_i, target_i = self.load_image_target(index)
  101. image_list.append(img_i)
  102. target_list.append(target_i)
  103. # Mosaic
  104. if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
  105. image, target = yolov5_mosaic_augment(
  106. image_list, target_list, self.img_size, self.trans_config, self.is_train)
  107. return image, target
  108. def load_mixup(self, origin_image, origin_target):
  109. # YOLOv5 type Mixup
  110. if self.trans_config['mixup_type'] == 'yolov5_mixup':
  111. new_index = np.random.randint(0, len(self.ids))
  112. new_image, new_target = self.load_mosaic(new_index)
  113. image, target = yolov5_mixup_augment(
  114. origin_image, origin_target, new_image, new_target)
  115. # YOLOX type Mixup
  116. elif self.trans_config['mixup_type'] == 'yolox_mixup':
  117. new_index = np.random.randint(0, len(self.ids))
  118. new_image, new_target = self.load_image_target(new_index)
  119. image, target = yolox_mixup_augment(
  120. origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])
  121. return image, target
  122. # ------------ Load data function ------------
  123. def load_image_target(self, index):
  124. # == Load a data from the cached data ==
  125. if self.cached_datas is not None:
  126. # load a data
  127. data_item = self.cached_datas[index]
  128. image = data_item["image"]
  129. target = data_item["target"]
  130. # == Load a data from the local disk ==
  131. else:
  132. # load an image
  133. image, _ = self.pull_image(index)
  134. height, width, channels = image.shape
  135. # load a target
  136. bboxes, labels = self.pull_anno(index)
  137. target = {
  138. "boxes": bboxes,
  139. "labels": labels,
  140. "orig_size": [height, width]
  141. }
  142. return image, target
  143. def pull_item(self, index):
  144. if random.random() < self.mosaic_prob:
  145. # load a mosaic image
  146. mosaic = True
  147. image, target = self.load_mosaic(index)
  148. else:
  149. mosaic = False
  150. # load an image and target
  151. image, target = self.load_image_target(index)
  152. # MixUp
  153. if random.random() < self.mixup_prob:
  154. image, target = self.load_mixup(image, target)
  155. # augment
  156. image, target, deltas = self.transform(image, target, mosaic)
  157. return image, target, deltas
  158. def pull_image(self, index):
  159. img_id = self.ids[index]
  160. img_file = os.path.join(self.data_dir, self.image_set,
  161. '{:012}'.format(img_id) + '.jpg')
  162. image = cv2.imread(img_file)
  163. if self.json_file == 'instances_val5k.json' and image is None:
  164. img_file = os.path.join(self.data_dir, 'train2017',
  165. '{:012}'.format(img_id) + '.jpg')
  166. image = cv2.imread(img_file)
  167. assert image is not None
  168. return image, img_id
  169. def pull_anno(self, index):
  170. img_id = self.ids[index]
  171. im_ann = self.coco.loadImgs(img_id)[0]
  172. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  173. annotations = self.coco.loadAnns(anno_ids)
  174. # image infor
  175. width = im_ann['width']
  176. height = im_ann['height']
  177. #load a target
  178. bboxes = []
  179. labels = []
  180. for anno in annotations:
  181. if 'bbox' in anno and anno['area'] > 0:
  182. # bbox
  183. x1 = np.max((0, anno['bbox'][0]))
  184. y1 = np.max((0, anno['bbox'][1]))
  185. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  186. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  187. if x2 < x1 or y2 < y1:
  188. continue
  189. # class label
  190. cls_id = self.class_ids.index(anno['category_id'])
  191. bboxes.append([x1, y1, x2, y2])
  192. labels.append(cls_id)
  193. # guard against no boxes via resizing
  194. bboxes = np.array(bboxes).reshape(-1, 4)
  195. labels = np.array(labels).reshape(-1)
  196. return bboxes, labels
  197. if __name__ == "__main__":
  198. import time
  199. import argparse
  200. from build import build_transform
  201. parser = argparse.ArgumentParser(description='COCO-Dataset')
  202. # opt
  203. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
  204. help='data root')
  205. parser.add_argument('-size', '--img_size', default=640, type=int,
  206. help='input image size.')
  207. parser.add_argument('--mosaic', default=None, type=float,
  208. help='mosaic augmentation.')
  209. parser.add_argument('--mixup', default=None, type=float,
  210. help='mixup augmentation.')
  211. parser.add_argument('--is_train', action="store_true", default=False,
  212. help='mixup augmentation.')
  213. parser.add_argument('--load_cache', action="store_true", default=False,
  214. help='load cached data.')
  215. args = parser.parse_args()
  216. trans_config = {
  217. 'aug_type': 'yolov5', # optional: ssd, yolov5
  218. # Basic Augment
  219. 'degrees': 0.0,
  220. 'translate': 0.2,
  221. 'scale': [0.5, 2.0],
  222. 'shear': 0.0,
  223. 'perspective': 0.0,
  224. 'hsv_h': 0.015,
  225. 'hsv_s': 0.7,
  226. 'hsv_v': 0.4,
  227. # Mosaic & Mixup
  228. 'mosaic_prob': 1.0,
  229. 'mixup_prob': 1.0,
  230. 'mosaic_type': 'yolov5_mosaic',
  231. 'mixup_type': 'yolov5_mixup',
  232. 'mixup_scale': [0.5, 1.5]
  233. }
  234. transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
  235. dataset = COCODataset(
  236. img_size=args.img_size,
  237. data_dir=args.root,
  238. image_set='val2017',
  239. trans_config=trans_config,
  240. transform=transform,
  241. is_train=args.is_train,
  242. load_cache=args.load_cache
  243. )
  244. np.random.seed(0)
  245. class_colors = [(np.random.randint(255),
  246. np.random.randint(255),
  247. np.random.randint(255)) for _ in range(80)]
  248. print('Data length: ', len(dataset))
  249. for i in range(1000):
  250. t0 = time.time()
  251. image, target, deltas = dataset.pull_item(i)
  252. print("Load data: {} s".format(time.time() - t0))
  253. # to numpy
  254. image = image.permute(1, 2, 0).numpy()
  255. # to uint8
  256. image = image.astype(np.uint8)
  257. image = image.copy()
  258. img_h, img_w = image.shape[:2]
  259. boxes = target["boxes"]
  260. labels = target["labels"]
  261. for box, label in zip(boxes, labels):
  262. x1, y1, x2, y2 = box
  263. cls_id = int(label)
  264. color = class_colors[cls_id]
  265. # class name
  266. label = coco_class_labels[coco_class_index[cls_id]]
  267. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2)
  268. # put the test on the bbox
  269. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  270. cv2.imshow('gt', image)
  271. # cv2.imwrite(str(i)+'.jpg', img)
  272. cv2.waitKey(0)