coco.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. import os
  2. import cv2
  3. import time
  4. import random
  5. import numpy as np
  6. from torch.utils.data import Dataset
  7. try:
  8. from pycocotools.coco import COCO
  9. except:
  10. print("It seems that the COCOAPI is not installed.")
  11. try:
  12. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  13. except:
  14. from data_augment.strong_augment import MosaicAugment, MixupAugment
  15. coco_class_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  16. coco_class_labels = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  17. class COCODataset(Dataset):
  18. def __init__(self,
  19. img_size :int = 640,
  20. data_dir :str = None,
  21. image_set :str = 'train2017',
  22. trans_config = None,
  23. transform = None,
  24. is_train :bool =False,
  25. load_cache :bool = False,
  26. ):
  27. # ----------- Basic parameters -----------
  28. self.img_size = img_size
  29. self.image_set = image_set
  30. self.is_train = is_train
  31. # ----------- Path parameters -----------
  32. self.data_dir = data_dir
  33. if image_set == 'train2017':
  34. self.json_file='instances_train2017_clean.json'
  35. elif image_set == 'val2017':
  36. self.json_file='instances_val2017_clean.json'
  37. elif image_set == 'test2017':
  38. self.json_file='image_info_test-dev2017.json'
  39. else:
  40. raise NotImplementedError("Unknown json image set {}.".format(image_set))
  41. # ----------- Data parameters -----------
  42. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  43. self.ids = self.coco.getImgIds()
  44. self.class_ids = sorted(self.coco.getCatIds())
  45. self.dataset_size = len(self.ids)
  46. # ----------- Transform parameters -----------
  47. self.trans_config = trans_config
  48. self.transform = transform
  49. # ----------- Strong augmentation -----------
  50. if is_train:
  51. self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
  52. self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
  53. self.mosaic_augment = MosaicAugment(img_size, trans_config, is_train)
  54. self.mixup_augment = MixupAugment(img_size, trans_config)
  55. else:
  56. self.mosaic_prob = 0.0
  57. self.mixup_prob = 0.0
  58. self.mosaic_augment = None
  59. self.mixup_augment = None
  60. print('==============================')
  61. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  62. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  63. print('==============================')
  64. # ----------- Cached data -----------
  65. self.load_cache = load_cache
  66. self.cached_datas = None
  67. if self.load_cache:
  68. self.cached_datas = self._load_cache()
  69. # ------------ Basic dataset function ------------
  70. def __len__(self):
  71. return len(self.ids)
  72. def __getitem__(self, index):
  73. return self.pull_item(index)
  74. def _load_cache(self):
  75. data_items = []
  76. for idx in range(self.dataset_size):
  77. if idx % 2000 == 0:
  78. print("Caching images and targets : {} / {} ...".format(idx, self.dataset_size))
  79. # load a data
  80. image, target = self.load_image_target(idx)
  81. orig_h, orig_w, _ = image.shape
  82. # resize image
  83. r = self.img_size / max(orig_h, orig_w)
  84. if r != 1:
  85. interp = cv2.INTER_LINEAR
  86. new_size = (int(orig_w * r), int(orig_h * r))
  87. image = cv2.resize(image, new_size, interpolation=interp)
  88. img_h, img_w = image.shape[:2]
  89. # rescale bbox
  90. boxes = target["boxes"].copy()
  91. boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
  92. boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
  93. target["boxes"] = boxes
  94. dict_item = {}
  95. dict_item["image"] = image
  96. dict_item["target"] = target
  97. data_items.append(dict_item)
  98. return data_items
  99. # ------------ Mosaic & Mixup ------------
  100. def load_mosaic(self, index):
  101. # ------------ Prepare 4 indexes of images ------------
  102. ## Load 4x mosaic image
  103. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  104. id1 = index
  105. id2, id3, id4 = random.sample(index_list, 3)
  106. indexs = [id1, id2, id3, id4]
  107. ## Load images and targets
  108. image_list = []
  109. target_list = []
  110. for index in indexs:
  111. img_i, target_i = self.load_image_target(index)
  112. image_list.append(img_i)
  113. target_list.append(target_i)
  114. # ------------ Mosaic augmentation ------------
  115. image, target = self.mosaic_augment(image_list, target_list)
  116. return image, target
  117. def load_mixup(self, origin_image, origin_target):
  118. # ------------ Load a new image & target ------------
  119. if self.mixup_augment.mixup_type == 'yolov5':
  120. new_index = np.random.randint(0, len(self.ids))
  121. new_image, new_target = self.load_mosaic(new_index)
  122. elif self.mixup_augment.mixup_type == 'yolox':
  123. new_index = np.random.randint(0, len(self.ids))
  124. new_image, new_target = self.load_image_target(new_index)
  125. # ------------ Mixup augmentation ------------
  126. image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target)
  127. return image, target
  128. # ------------ Load data function ------------
  129. def load_image_target(self, index):
  130. # == Load a data from the cached data ==
  131. if self.cached_datas is not None:
  132. # load a data
  133. data_item = self.cached_datas[index]
  134. image = data_item["image"]
  135. target = data_item["target"]
  136. # == Load a data from the local disk ==
  137. else:
  138. # load an image
  139. image, _ = self.pull_image(index)
  140. height, width, channels = image.shape
  141. # load a target
  142. bboxes, labels = self.pull_anno(index)
  143. target = {
  144. "boxes": bboxes,
  145. "labels": labels,
  146. "orig_size": [height, width]
  147. }
  148. return image, target
  149. def pull_item(self, index):
  150. if random.random() < self.mosaic_prob:
  151. # load a mosaic image
  152. mosaic = True
  153. image, target = self.load_mosaic(index)
  154. else:
  155. mosaic = False
  156. # load an image and target
  157. image, target = self.load_image_target(index)
  158. # MixUp
  159. if random.random() < self.mixup_prob:
  160. image, target = self.load_mixup(image, target)
  161. # augment
  162. image, target, deltas = self.transform(image, target, mosaic)
  163. return image, target, deltas
  164. def pull_image(self, index):
  165. img_id = self.ids[index]
  166. img_file = os.path.join(self.data_dir, self.image_set,
  167. '{:012}'.format(img_id) + '.jpg')
  168. image = cv2.imread(img_file)
  169. if self.json_file == 'instances_val5k.json' and image is None:
  170. img_file = os.path.join(self.data_dir, 'train2017',
  171. '{:012}'.format(img_id) + '.jpg')
  172. image = cv2.imread(img_file)
  173. assert image is not None
  174. return image, img_id
  175. def pull_anno(self, index):
  176. img_id = self.ids[index]
  177. im_ann = self.coco.loadImgs(img_id)[0]
  178. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  179. annotations = self.coco.loadAnns(anno_ids)
  180. # image infor
  181. width = im_ann['width']
  182. height = im_ann['height']
  183. #load a target
  184. bboxes = []
  185. labels = []
  186. for anno in annotations:
  187. if 'bbox' in anno and anno['area'] > 0:
  188. # bbox
  189. x1 = np.max((0, anno['bbox'][0]))
  190. y1 = np.max((0, anno['bbox'][1]))
  191. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  192. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  193. if x2 < x1 or y2 < y1:
  194. continue
  195. # class label
  196. cls_id = self.class_ids.index(anno['category_id'])
  197. bboxes.append([x1, y1, x2, y2])
  198. labels.append(cls_id)
  199. # guard against no boxes via resizing
  200. bboxes = np.array(bboxes).reshape(-1, 4)
  201. labels = np.array(labels).reshape(-1)
  202. return bboxes, labels
  203. if __name__ == "__main__":
  204. import time
  205. import argparse
  206. from build import build_transform
  207. parser = argparse.ArgumentParser(description='COCO-Dataset')
  208. # opt
  209. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
  210. help='data root')
  211. parser.add_argument('--image_set', type=str, default='train2017',
  212. help='mixup augmentation.')
  213. parser.add_argument('-size', '--img_size', default=640, type=int,
  214. help='input image size.')
  215. parser.add_argument('--aug_type', type=str, default='ssd',
  216. help='augmentation type: ssd, yolov5, rtdetr.')
  217. parser.add_argument('--mosaic', default=0., type=float,
  218. help='mosaic augmentation.')
  219. parser.add_argument('--mixup', default=0., type=float,
  220. help='mixup augmentation.')
  221. parser.add_argument('--mixup_type', type=str, default='yolov5_mixup',
  222. help='mixup augmentation.')
  223. parser.add_argument('--is_train', action="store_true", default=False,
  224. help='mixup augmentation.')
  225. parser.add_argument('--load_cache', action="store_true", default=False,
  226. help='load cached data.')
  227. args = parser.parse_args()
  228. trans_config = {
  229. 'aug_type': args.aug_type, # optional: ssd, yolov5
  230. 'pixel_mean': [123.675, 116.28, 103.53],
  231. 'pixel_std': [58.395, 57.12, 57.375],
  232. 'use_ablu': True,
  233. # Basic Augment
  234. 'affine_params': {
  235. 'degrees': 0.0,
  236. 'translate': 0.2,
  237. 'scale': [0.1, 2.0],
  238. 'shear': 0.0,
  239. 'perspective': 0.0,
  240. 'hsv_h': 0.015,
  241. 'hsv_s': 0.7,
  242. 'hsv_v': 0.4,
  243. },
  244. # Mosaic & Mixup
  245. 'mosaic_keep_ratio': False,
  246. 'mosaic_prob': args.mosaic,
  247. 'mixup_prob': args.mixup,
  248. 'mosaic_type': 'yolov5',
  249. 'mixup_type': 'yolov5',
  250. 'mixup_scale': [0.5, 1.5]
  251. }
  252. transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
  253. pixel_mean = transform.pixel_mean
  254. pixel_std = transform.pixel_std
  255. color_format = transform.color_format
  256. dataset = COCODataset(
  257. img_size=args.img_size,
  258. data_dir=args.root,
  259. image_set='val2017',
  260. trans_config=trans_config,
  261. transform=transform,
  262. is_train=args.is_train,
  263. load_cache=args.load_cache
  264. )
  265. np.random.seed(0)
  266. class_colors = [(np.random.randint(255),
  267. np.random.randint(255),
  268. np.random.randint(255)) for _ in range(80)]
  269. print('Data length: ', len(dataset))
  270. for i in range(1000):
  271. t0 = time.time()
  272. image, target, deltas = dataset.pull_item(i)
  273. print("Load data: {} s".format(time.time() - t0))
  274. # to numpy
  275. image = image.permute(1, 2, 0).numpy()
  276. # denormalize
  277. image = image * pixel_std + pixel_mean
  278. if color_format == 'rgb':
  279. # RGB to BGR
  280. image = image[..., (2, 1, 0)]
  281. # to uint8
  282. image = image.astype(np.uint8)
  283. image = image.copy()
  284. img_h, img_w = image.shape[:2]
  285. boxes = target["boxes"]
  286. labels = target["labels"]
  287. for box, label in zip(boxes, labels):
  288. x1, y1, x2, y2 = box
  289. cls_id = int(label)
  290. color = class_colors[cls_id]
  291. # class name
  292. label = coco_class_labels[coco_class_index[cls_id]]
  293. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  294. # put the test on the bbox
  295. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  296. cv2.imshow('gt', image)
  297. # cv2.imwrite(str(i)+'.jpg', img)
  298. cv2.waitKey(0)