coco.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. import os
  2. import random
  3. import numpy as np
  4. import time
  5. import torch
  6. from torch.utils.data import Dataset
  7. import cv2
  8. try:
  9. from pycocotools.coco import COCO
  10. except:
  11. print("It seems that the COCOAPI is not installed.")
  12. try:
  13. from .data_augment import build_transform
  14. from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  15. except:
  16. from data_augment import build_transform
  17. from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  # Class names indexed by raw COCO category id: slot 0 is 'background' and
  # slot k (1..90) names category id k. Ids that do not appear in
  # coco_class_index (e.g. 12 -> 'street sign') still keep their slot, so the
  # tuple can be indexed directly with a category id.
  coco_class_labels = (
      'background',
      'person', 'bicycle', 'car', 'motorcycle', 'airplane',                  # 1-5
      'bus', 'train', 'truck', 'boat', 'traffic light',                      # 6-10
      'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench',  # 11-15
      'bird', 'cat', 'dog', 'horse', 'sheep',                                # 16-20
      'cow', 'elephant', 'bear', 'zebra', 'giraffe',                         # 21-25
      'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses',                  # 26-30
      'handbag', 'tie', 'suitcase', 'frisbee', 'skis',                       # 31-35
      'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',  # 36-40
      'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate',         # 41-45
      'wine glass', 'cup', 'fork', 'knife', 'spoon',                         # 46-50
      'bowl', 'banana', 'apple', 'sandwich', 'orange',                       # 51-55
      'broccoli', 'carrot', 'hot dog', 'pizza', 'donut',                     # 56-60
      'cake', 'chair', 'couch', 'potted plant', 'bed',                       # 61-65
      'mirror', 'dining table', 'window', 'desk', 'toilet',                  # 66-70
      'door', 'tv', 'laptop', 'mouse', 'remote',                             # 71-75
      'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',              # 76-80
      'sink', 'refrigerator', 'blender', 'book', 'clock',                    # 81-85
      'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush',          # 86-90
  )
  # Raw COCO category ids for the 80 contiguous labels: contiguous label i maps
  # to category id coco_class_index[i]. The ten ids in the exclusion set below
  # have a name in coco_class_labels but no contiguous label.
  coco_class_index = [cat_id for cat_id in range(1, 91)
                      if cat_id not in {12, 26, 29, 30, 45, 66, 68, 69, 71, 83}]
  class COCODataset(Dataset):
      """
      COCO detection dataset backed by the pycocotools ``COCO`` API.

      ``__getitem__`` returns ``(image, target, deltas)`` exactly as produced by
      the user-supplied ``transform``; Mosaic and/or Mixup augmentation may be
      applied before that transform, controlled by ``trans_config``.
      """

      # Annotation file (under ``<data_dir>/annotations``) for each supported split.
      _JSON_FILES = {
          'train2017': 'instances_train2017.json',
          'val2017': 'instances_val2017.json',
          'test2017': 'image_info_test-dev2017.json',
      }

      def __init__(self,
                   img_size=640,
                   data_dir=None,
                   image_set='train2017',
                   trans_config=None,
                   transform=None,
                   is_train=False):
          """
          COCO dataset initialization. Annotation data are read into memory by COCO API.

          Args:
              img_size (int): output size handed to the Mosaic / YOLOX-Mixup augments.
              data_dir (str): dataset root; must contain ``annotations/`` and an
                  image folder named after ``image_set`` (e.g. ``val2017/``).
              image_set (str): 'train2017', 'val2017' or 'test2017'. Selects both
                  the annotation json and the image sub-folder.
              trans_config (dict or None): augmentation config. Keys read by this
                  class: 'mosaic_prob', 'mixup_prob', 'mosaic_type', 'mixup_type',
                  'mixup_scale'. ``None`` disables Mosaic and Mixup.
              transform (callable): called as ``transform(image, target, mosaic)``
                  and expected to return ``(image, target, deltas)``.
              is_train (bool): forwarded to the Mosaic augment.

          Raises:
              ValueError: if ``image_set`` is not one of the supported splits
                  (previously this surfaced later as an AttributeError).
          """
          if image_set not in self._JSON_FILES:
              raise ValueError(
                  "Unsupported image_set '{}'; expected one of {}".format(
                      image_set, sorted(self._JSON_FILES)))
          self.json_file = self._JSON_FILES[image_set]
          self.img_size = img_size
          self.image_set = image_set
          self.data_dir = data_dir
          self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
          self.ids = self.coco.getImgIds()
          # Sorted category ids; a label is the position of its category id here.
          self.class_ids = sorted(self.coco.getCatIds())
          self.is_train = is_train
          # augmentation
          self.transform = transform
          self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
          self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
          self.trans_config = trans_config
          print('==============================')
          print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
          print('use Mixup Augmentation: {}'.format(self.mixup_prob))
          print('==============================')

      def __len__(self):
          return len(self.ids)

      def __getitem__(self, index):
          return self.pull_item(index)

      def load_image_target(self, index):
          """
          Load the image at ``index`` together with its annotations.

          Returns:
              image: the array returned by ``cv2.imread``.
              target (dict): 'boxes' (N, 4) as [x1, y1, x2, y2], 'labels' (N,)
                  contiguous class indices, 'orig_size' [height, width].
          """
          image, _ = self.pull_image(index)
          height, width = image.shape[:2]
          bboxes, labels = self.pull_anno(index)
          target = {
              "boxes": bboxes,
              "labels": labels,
              "orig_size": [height, width]
          }
          return image, target

      def _sample_mosaic_indices(self, index):
          """
          Return ``[index, j1, j2, j3]`` - the anchor plus three mosaic partners.

          With at least four images the partners are drawn uniformly without
          replacement from the indices other than ``index``. Smaller datasets
          (where that is impossible) draw partners with replacement from all
          indices instead of crashing.
          """
          num_images = len(self.ids)
          if num_images - 1 >= 3:
              # Sample positions among the n-1 "other" images, then shift positions
              # at or after `index` up by one so `index` itself is never picked.
              # Equivalent to sampling from [0..index-1, index+1..n-1] without
              # materialising that O(n) list on every call.
              partners = [j if j < index else j + 1
                          for j in random.sample(range(num_images - 1), 3)]
          else:
              partners = random.choices(range(num_images), k=3)
          return [index] + partners

      def load_mosaic(self, index):
          """
          Build a 4-image mosaic anchored at ``index``.

          Raises:
              ValueError: if ``trans_config['mosaic_type']`` is not supported.
          """
          mosaic_type = self.trans_config['mosaic_type']
          # Validate before loading four images from disk.
          if mosaic_type != 'yolov5_mosaic':
              raise ValueError("Unsupported mosaic_type: {}".format(mosaic_type))
          # load images and targets
          image_list = []
          target_list = []
          for sample_idx in self._sample_mosaic_indices(index):
              img_i, target_i = self.load_image_target(sample_idx)
              image_list.append(img_i)
              target_list.append(target_i)
          # Mosaic
          image, target = yolov5_mosaic_augment(
              image_list, target_list, self.img_size, self.trans_config, self.is_train)
          return image, target

      def load_mixup(self, origin_image, origin_target):
          """
          Mix ``origin_image``/``origin_target`` with a randomly chosen sample.

          'yolov5_mixup' mixes with a freshly built mosaic; 'yolox_mixup' mixes
          with a single image, using ``img_size`` and ``trans_config['mixup_scale']``.

          Raises:
              ValueError: if ``trans_config['mixup_type']`` is not supported.
          """
          mixup_type = self.trans_config['mixup_type']
          if mixup_type == 'yolov5_mixup':
              # YOLOv5 type Mixup
              new_index = np.random.randint(0, len(self.ids))
              new_image, new_target = self.load_mosaic(new_index)
              image, target = yolov5_mixup_augment(
                  origin_image, origin_target, new_image, new_target)
          elif mixup_type == 'yolox_mixup':
              # YOLOX type Mixup
              new_index = np.random.randint(0, len(self.ids))
              new_image, new_target = self.load_image_target(new_index)
              image, target = yolox_mixup_augment(
                  origin_image, origin_target, new_image, new_target,
                  self.img_size, self.trans_config['mixup_scale'])
          else:
              raise ValueError("Unsupported mixup_type: {}".format(mixup_type))
          return image, target

      def pull_item(self, index):
          """
          Return ``(image, target, deltas)`` for ``index`` after augmentation.

          With probability ``mosaic_prob`` a mosaic is built, then with
          probability ``mixup_prob`` Mixup is applied; finally ``transform`` is
          called with a flag telling it whether the sample is a mosaic.
          """
          mosaic = random.random() < self.mosaic_prob
          if mosaic:
              # load a mosaic image
              image, target = self.load_mosaic(index)
          else:
              # load an image and target
              image, target = self.load_image_target(index)
          # MixUp
          if random.random() < self.mixup_prob:
              image, target = self.load_mixup(image, target)
          # augment
          image, target, deltas = self.transform(image, target, mosaic)
          return image, target, deltas

      def pull_image(self, index):
          """
          Read the image at ``index`` from ``<data_dir>/<image_set>/<id:012>.jpg``.

          Returns:
              (image, img_id): the ``cv2.imread`` array and the COCO image id.

          Raises:
              FileNotFoundError: if the image cannot be read. This replaces an
                  ``assert`` that would be stripped under ``python -O``.
          """
          img_id = self.ids[index]
          file_name = '{:012}'.format(img_id) + '.jpg'
          img_file = os.path.join(self.data_dir, self.image_set, file_name)
          image = cv2.imread(img_file)
          # NOTE(review): __init__ never selects 'instances_val5k.json', so this
          # train2017 fallback is only reachable if json_file is set elsewhere.
          if self.json_file == 'instances_val5k.json' and image is None:
              img_file = os.path.join(self.data_dir, 'train2017', file_name)
              image = cv2.imread(img_file)
          if image is None:
              raise FileNotFoundError('Failed to read image: {}'.format(img_file))
          return image, img_id

      def pull_anno(self, index):
          """
          Load the non-crowd box annotations of the image at ``index``.

          Annotations without a 'bbox' or with non-positive 'area' are dropped.
          Boxes are converted from [x, y, w, h] to inclusive [x1, y1, x2, y2].
          x1/y1 are clamped at 0 and x2/y2 at width-1/height-1; boxes that end up
          inverted are dropped.

          Returns:
              bboxes (np.ndarray): shape (N, 4); (0, 4) when there are no boxes.
              labels (np.ndarray): shape (N,), index of each category id in
                  ``self.class_ids``.
          """
          img_id = self.ids[index]
          im_ann = self.coco.loadImgs(img_id)[0]
          anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
          annotations = self.coco.loadAnns(anno_ids)
          # image info
          width = im_ann['width']
          height = im_ann['height']
          # load a target
          bboxes = []
          labels = []
          for anno in annotations:
              if 'bbox' in anno and anno['area'] > 0:
                  # bbox
                  x1 = np.max((0, anno['bbox'][0]))
                  y1 = np.max((0, anno['bbox'][1]))
                  x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
                  y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
                  if x2 < x1 or y2 < y1:
                      continue
                  # class label
                  cls_id = self.class_ids.index(anno['category_id'])
                  bboxes.append([x1, y1, x2, y2])
                  labels.append(cls_id)
          # reshape keeps the (0, 4) / (0,) shapes when no boxes survive
          bboxes = np.array(bboxes).reshape(-1, 4)
          labels = np.array(labels).reshape(-1)
          return bboxes, labels
  if __name__ == "__main__":
      # Visual sanity check: draw the augmented ground-truth boxes of the first
      # samples of a split and show them one at a time (any key = next, ESC = quit).
      # build_transform comes from the module-level import at the top of the file.
      import argparse

      parser = argparse.ArgumentParser(description='COCO-Dataset')
      # opt
      parser.add_argument('--root', default='D:\\python_work\\object-detection\\dataset\\COCO',
                          help='data root')
      args = parser.parse_args()

      is_train = False
      img_size = 640
      yolov5_trans_config = {
          'aug_type': 'yolov5',
          # Basic Augment
          'degrees': 0.0,
          'translate': 0.2,
          'scale': 0.9,
          'shear': 0.0,
          'perspective': 0.0,
          'hsv_h': 0.015,
          'hsv_s': 0.7,
          'hsv_v': 0.4,
          # Mosaic & Mixup
          'mosaic_prob': 1.0,
          'mixup_prob': 0.15,
          'mosaic_type': 'yolov5_mosaic',
          'mixup_type': 'yolov5_mixup',
          'mixup_scale': [0.5, 1.5]
      }
      # Alternative config (currently unused) with Mosaic and Mixup disabled.
      ssd_trans_config = {
          'aug_type': 'ssd',
          'mosaic_prob': 0.0,
          'mixup_prob': 0.0
      }
      transform = build_transform(img_size, yolov5_trans_config, is_train)
      dataset = COCODataset(
          img_size=img_size,
          data_dir=args.root,
          image_set='val2017',
          trans_config=yolov5_trans_config,
          transform=transform,
          is_train=is_train
      )

      np.random.seed(0)
      class_colors = [(np.random.randint(255),
                       np.random.randint(255),
                       np.random.randint(255)) for _ in range(80)]
      print('Data length: ', len(dataset))
      # Never index past the end of small or custom datasets.
      num_show = min(1000, len(dataset))
      for i in range(num_show):
          image, target, deltas = dataset.pull_item(i)
          # to numpy; assumes transform returns a CHW torch tensor - TODO confirm
          image = image.permute(1, 2, 0).numpy()
          # to uint8; .copy() hands cv2 a C-contiguous buffer to draw on
          image = image.astype(np.uint8)
          image = image.copy()
          for box, label in zip(target["boxes"], target["labels"]):
              x1, y1, x2, y2 = box
              cls_id = int(label)
              color = class_colors[cls_id]
              # contiguous label -> raw COCO category id -> class name
              cls_name = coco_class_labels[coco_class_index[cls_id]]
              image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2)
              # put the text on the bbox
              cv2.putText(image, cls_name, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
          cv2.imshow('gt', image)
          if cv2.waitKey(0) & 0xFF == 27:  # ESC stops the preview early
              break
      cv2.destroyAllWindows()