coco.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. import os
  2. import cv2
  3. import time
  4. import random
  5. import numpy as np
  6. from torch.utils.data import Dataset
  7. from pycocotools.coco import COCO
  8. try:
  9. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  10. except:
  11. from data_augment.strong_augment import MosaicAugment, MixupAugment
  12. coco_class_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  13. coco_class_labels = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  14. class COCODataset(Dataset):
  15. def __init__(self,
  16. img_size :int = 640,
  17. data_dir :str = None,
  18. image_set :str = 'train2017',
  19. trans_config = None,
  20. transform = None,
  21. is_train :bool =False,
  22. ):
  23. # ----------- Basic parameters -----------
  24. self.img_size = img_size
  25. self.image_set = image_set
  26. self.is_train = is_train
  27. # ----------- Path parameters -----------
  28. self.data_dir = data_dir
  29. if image_set == 'train2017':
  30. self.json_file='instances_train2017_clean.json'
  31. elif image_set == 'val2017':
  32. self.json_file='instances_val2017_clean.json'
  33. elif image_set == 'test2017':
  34. self.json_file='image_info_test-dev2017.json'
  35. else:
  36. raise NotImplementedError("Unknown json image set {}.".format(image_set))
  37. # ----------- Data parameters -----------
  38. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  39. self.ids = self.coco.getImgIds()
  40. self.class_ids = sorted(self.coco.getCatIds())
  41. self.dataset_size = len(self.ids)
  42. # ----------- Transform parameters -----------
  43. self.trans_config = trans_config
  44. self.transform = transform
  45. # ----------- Strong augmentation -----------
  46. if is_train:
  47. self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
  48. self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
  49. self.mosaic_augment = MosaicAugment(img_size, trans_config, is_train) if self.mosaic_prob > 0. else None
  50. self.mixup_augment = MixupAugment(img_size, trans_config) if self.mixup_prob > 0. else None
  51. else:
  52. self.mosaic_prob = 0.0
  53. self.mixup_prob = 0.0
  54. self.mosaic_augment = None
  55. self.mixup_augment = None
  56. print('==============================')
  57. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  58. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  59. # ------------ Basic dataset function ------------
  60. def __len__(self):
  61. return len(self.ids)
  62. def __getitem__(self, index):
  63. return self.pull_item(index)
  64. # ------------ Mosaic & Mixup ------------
  65. def load_mosaic(self, index):
  66. # ------------ Prepare 4 indexes of images ------------
  67. ## Load 4x mosaic image
  68. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  69. id1 = index
  70. id2, id3, id4 = random.sample(index_list, 3)
  71. indexs = [id1, id2, id3, id4]
  72. ## Load images and targets
  73. image_list = []
  74. target_list = []
  75. for index in indexs:
  76. img_i, target_i = self.load_image_target(index)
  77. image_list.append(img_i)
  78. target_list.append(target_i)
  79. # ------------ Mosaic augmentation ------------
  80. image, target = self.mosaic_augment(image_list, target_list)
  81. return image, target
  82. def load_mixup(self, origin_image, origin_target):
  83. # ------------ Load a new image & target ------------
  84. if self.mixup_augment.mixup_type == 'yolov5':
  85. new_index = np.random.randint(0, len(self.ids))
  86. new_image, new_target = self.load_mosaic(new_index)
  87. elif self.mixup_augment.mixup_type == 'yolox':
  88. new_index = np.random.randint(0, len(self.ids))
  89. new_image, new_target = self.load_image_target(new_index)
  90. # ------------ Mixup augmentation ------------
  91. image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target)
  92. return image, target
  93. # ------------ Load data function ------------
  94. def load_image_target(self, index):
  95. # load an image
  96. image, _ = self.pull_image(index)
  97. height, width, channels = image.shape
  98. # load a target
  99. bboxes, labels = self.pull_anno(index)
  100. target = {
  101. "boxes": bboxes,
  102. "labels": labels,
  103. "orig_size": [height, width]
  104. }
  105. return image, target
  106. def pull_item(self, index):
  107. if random.random() < self.mosaic_prob:
  108. # load a mosaic image
  109. mosaic = True
  110. image, target = self.load_mosaic(index)
  111. else:
  112. mosaic = False
  113. # load an image and target
  114. image, target = self.load_image_target(index)
  115. # MixUp
  116. if random.random() < self.mixup_prob:
  117. image, target = self.load_mixup(image, target)
  118. # augment
  119. image, target, deltas = self.transform(image, target, mosaic)
  120. return image, target, deltas
  121. def pull_image(self, index):
  122. img_id = self.ids[index]
  123. img_file = os.path.join(self.data_dir, self.image_set,
  124. '{:012}'.format(img_id) + '.jpg')
  125. image = cv2.imread(img_file)
  126. if self.json_file == 'instances_val5k.json' and image is None:
  127. img_file = os.path.join(self.data_dir, 'train2017',
  128. '{:012}'.format(img_id) + '.jpg')
  129. image = cv2.imread(img_file)
  130. assert image is not None
  131. return image, img_id
  132. def pull_anno(self, index):
  133. img_id = self.ids[index]
  134. im_ann = self.coco.loadImgs(img_id)[0]
  135. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  136. annotations = self.coco.loadAnns(anno_ids)
  137. # image infor
  138. width = im_ann['width']
  139. height = im_ann['height']
  140. #load a target
  141. bboxes = []
  142. labels = []
  143. for anno in annotations:
  144. if 'bbox' in anno and anno['area'] > 0:
  145. # bbox
  146. x1 = np.max((0, anno['bbox'][0]))
  147. y1 = np.max((0, anno['bbox'][1]))
  148. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  149. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  150. if x2 < x1 or y2 < y1:
  151. continue
  152. # class label
  153. cls_id = self.class_ids.index(anno['category_id'])
  154. bboxes.append([x1, y1, x2, y2])
  155. labels.append(cls_id)
  156. # guard against no boxes via resizing
  157. bboxes = np.array(bboxes).reshape(-1, 4)
  158. labels = np.array(labels).reshape(-1)
  159. return bboxes, labels
  160. if __name__ == "__main__":
  161. import time
  162. import argparse
  163. from build import build_transform
  164. parser = argparse.ArgumentParser(description='COCO-Dataset')
  165. # opt
  166. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
  167. help='data root')
  168. parser.add_argument('--image_set', type=str, default='train2017',
  169. help='mixup augmentation.')
  170. parser.add_argument('-size', '--img_size', default=640, type=int,
  171. help='input image size.')
  172. parser.add_argument('--aug_type', type=str, default='ssd',
  173. help='augmentation type: ssd, yolo.')
  174. parser.add_argument('--mosaic', default=0., type=float,
  175. help='mosaic augmentation.')
  176. parser.add_argument('--mixup', default=0., type=float,
  177. help='mixup augmentation.')
  178. parser.add_argument('--mixup_type', type=str, default='yolov5_mixup',
  179. help='mixup augmentation.')
  180. parser.add_argument('--is_train', action="store_true", default=False,
  181. help='mixup augmentation.')
  182. args = parser.parse_args()
  183. trans_config = {
  184. 'aug_type': args.aug_type, # optional: ssd, yolov5
  185. 'pixel_mean': [123.675, 116.28, 103.53],
  186. 'pixel_std': [58.395, 57.12, 57.375],
  187. 'use_ablu': True,
  188. # Basic Augment
  189. 'affine_params': {
  190. 'degrees': 0.0,
  191. 'translate': 0.2,
  192. 'scale': [0.1, 2.0],
  193. 'shear': 0.0,
  194. 'perspective': 0.0,
  195. 'hsv_h': 0.015,
  196. 'hsv_s': 0.7,
  197. 'hsv_v': 0.4,
  198. },
  199. # Mosaic & Mixup
  200. 'mosaic_keep_ratio': False,
  201. 'mosaic_prob': args.mosaic,
  202. 'mixup_prob': args.mixup,
  203. 'mosaic_type': 'yolov5',
  204. 'mixup_type': 'yolov5',
  205. 'mixup_scale': [0.5, 1.5]
  206. }
  207. transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
  208. pixel_mean = transform.pixel_mean
  209. pixel_std = transform.pixel_std
  210. color_format = transform.color_format
  211. dataset = COCODataset(
  212. img_size=args.img_size,
  213. data_dir=args.root,
  214. image_set='val2017',
  215. trans_config=trans_config,
  216. transform=transform,
  217. is_train=args.is_train,
  218. )
  219. np.random.seed(0)
  220. class_colors = [(np.random.randint(255),
  221. np.random.randint(255),
  222. np.random.randint(255)) for _ in range(80)]
  223. print('Data length: ', len(dataset))
  224. for i in range(1000):
  225. t0 = time.time()
  226. image, target, deltas = dataset.pull_item(i)
  227. print("Load data: {} s".format(time.time() - t0))
  228. # to numpy
  229. image = image.permute(1, 2, 0).numpy()
  230. # denormalize
  231. image = image * pixel_std + pixel_mean
  232. if color_format == 'rgb':
  233. # RGB to BGR
  234. image = image[..., (2, 1, 0)]
  235. # to uint8
  236. image = image.astype(np.uint8)
  237. image = image.copy()
  238. img_h, img_w = image.shape[:2]
  239. boxes = target["boxes"]
  240. labels = target["labels"]
  241. for box, label in zip(boxes, labels):
  242. x1, y1, x2, y2 = box
  243. cls_id = int(label)
  244. color = class_colors[cls_id]
  245. # class name
  246. label = coco_class_labels[coco_class_index[cls_id]]
  247. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  248. # put the test on the bbox
  249. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  250. cv2.imshow('gt', image)
  251. # cv2.imwrite(str(i)+'.jpg', img)
  252. cv2.waitKey(0)