coco.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. import os
  2. import random
  3. import numpy as np
  4. import time
  5. import torch
  6. from torch.utils.data import Dataset
  7. import cv2
  8. try:
  9. from pycocotools.coco import COCO
  10. except:
  11. print("It seems that the COCOAPI is not installed.")
  12. try:
  13. from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  14. except:
  15. from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
  16. coco_class_labels = ('background',
  17. 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
  18. 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign',
  19. 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
  20. 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella',
  21. 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
  22. 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
  23. 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass',
  24. 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
  25. 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
  26. 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk',
  27. 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
  28. 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book',
  29. 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  30. coco_class_index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
  31. 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
  32. 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67,
  33. 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  34. class COCODataset(Dataset):
  35. """
  36. COCO dataset class.
  37. """
  38. def __init__(self,
  39. img_size=640,
  40. data_dir=None,
  41. image_set='train2017',
  42. trans_config=None,
  43. transform=None,
  44. is_train=False):
  45. """
  46. COCO dataset initialization. Annotation data are read into memory by COCO API.
  47. Args:
  48. data_dir (str): dataset root directory
  49. json_file (str): COCO json file name
  50. name (str): COCO data name (e.g. 'train2017' or 'val2017')
  51. debug (bool): if True, only one data id is selected from the dataset
  52. """
  53. if image_set == 'train2017':
  54. self.json_file='instances_train2017.json'
  55. elif image_set == 'val2017':
  56. self.json_file='instances_val2017.json'
  57. elif image_set == 'test2017':
  58. self.json_file='image_info_test-dev2017.json'
  59. self.img_size = img_size
  60. self.image_set = image_set
  61. self.data_dir = data_dir
  62. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  63. self.ids = self.coco.getImgIds()
  64. self.class_ids = sorted(self.coco.getCatIds())
  65. self.is_train = is_train
  66. # augmentation
  67. self.transform = transform
  68. self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
  69. self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
  70. self.trans_config = trans_config
  71. print('==============================')
  72. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  73. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  74. print('==============================')
  75. def __len__(self):
  76. return len(self.ids)
  77. def __getitem__(self, index):
  78. return self.pull_item(index)
  79. def load_image_target(self, index):
  80. # load an image
  81. image, _ = self.pull_image(index)
  82. height, width, channels = image.shape
  83. # load a target
  84. bboxes, labels = self.pull_anno(index)
  85. target = {
  86. "boxes": bboxes,
  87. "labels": labels,
  88. "orig_size": [height, width]
  89. }
  90. return image, target
  91. def load_mosaic(self, index):
  92. # load 4x mosaic image
  93. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  94. id1 = index
  95. id2, id3, id4 = random.sample(index_list, 3)
  96. indexs = [id1, id2, id3, id4]
  97. # load images and targets
  98. image_list = []
  99. target_list = []
  100. for index in indexs:
  101. img_i, target_i = self.load_image_target(index)
  102. image_list.append(img_i)
  103. target_list.append(target_i)
  104. # Mosaic
  105. if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
  106. image, target = yolov5_mosaic_augment(
  107. image_list, target_list, self.img_size, self.trans_config, self.is_train)
  108. return image, target
  109. def load_mixup(self, origin_image, origin_target):
  110. # YOLOv5 type Mixup
  111. if self.trans_config['mixup_type'] == 'yolov5_mixup':
  112. new_index = np.random.randint(0, len(self.ids))
  113. new_image, new_target = self.load_mosaic(new_index)
  114. image, target = yolov5_mixup_augment(
  115. origin_image, origin_target, new_image, new_target)
  116. # YOLOX type Mixup
  117. elif self.trans_config['mixup_type'] == 'yolox_mixup':
  118. new_index = np.random.randint(0, len(self.ids))
  119. new_image, new_target = self.load_image_target(new_index)
  120. image, target = yolox_mixup_augment(
  121. origin_image, origin_target, new_image, new_target, self.img_size, self.trans_config['mixup_scale'])
  122. return image, target
  123. def pull_item(self, index):
  124. if random.random() < self.mosaic_prob:
  125. # load a mosaic image
  126. mosaic = True
  127. image, target = self.load_mosaic(index)
  128. else:
  129. mosaic = False
  130. # load an image and target
  131. image, target = self.load_image_target(index)
  132. # MixUp
  133. if random.random() < self.mixup_prob:
  134. image, target = self.load_mixup(image, target)
  135. # augment
  136. image, target, deltas = self.transform(image, target, mosaic)
  137. return image, target, deltas
  138. def pull_image(self, index):
  139. img_id = self.ids[index]
  140. img_file = os.path.join(self.data_dir, self.image_set,
  141. '{:012}'.format(img_id) + '.jpg')
  142. image = cv2.imread(img_file)
  143. if self.json_file == 'instances_val5k.json' and image is None:
  144. img_file = os.path.join(self.data_dir, 'train2017',
  145. '{:012}'.format(img_id) + '.jpg')
  146. image = cv2.imread(img_file)
  147. assert image is not None
  148. return image, img_id
  149. def pull_anno(self, index):
  150. img_id = self.ids[index]
  151. im_ann = self.coco.loadImgs(img_id)[0]
  152. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  153. annotations = self.coco.loadAnns(anno_ids)
  154. # image infor
  155. width = im_ann['width']
  156. height = im_ann['height']
  157. #load a target
  158. bboxes = []
  159. labels = []
  160. for anno in annotations:
  161. if 'bbox' in anno and anno['area'] > 0:
  162. # bbox
  163. x1 = np.max((0, anno['bbox'][0]))
  164. y1 = np.max((0, anno['bbox'][1]))
  165. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  166. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  167. if x2 < x1 or y2 < y1:
  168. continue
  169. # class label
  170. cls_id = self.class_ids.index(anno['category_id'])
  171. bboxes.append([x1, y1, x2, y2])
  172. labels.append(cls_id)
  173. # guard against no boxes via resizing
  174. bboxes = np.array(bboxes).reshape(-1, 4)
  175. labels = np.array(labels).reshape(-1)
  176. return bboxes, labels
  177. if __name__ == "__main__":
  178. import argparse
  179. from build import build_transform
  180. parser = argparse.ArgumentParser(description='COCO-Dataset')
  181. # opt
  182. parser.add_argument('--root', default='D:\\python_work\\object-detection\\dataset\\COCO',
  183. help='data root')
  184. args = parser.parse_args()
  185. is_train = False
  186. img_size = 640
  187. yolov5_trans_config = {
  188. 'aug_type': 'yolov5',
  189. # Basic Augment
  190. 'degrees': 0.0,
  191. 'translate': 0.2,
  192. 'scale': 0.9,
  193. 'shear': 0.0,
  194. 'perspective': 0.0,
  195. 'hsv_h': 0.015,
  196. 'hsv_s': 0.7,
  197. 'hsv_v': 0.4,
  198. # Mosaic & Mixup
  199. 'mosaic_prob': 1.0,
  200. 'mixup_prob': 0.15,
  201. 'mosaic_type': 'yolov5_mosaic',
  202. 'mixup_type': 'yolov5_mixup',
  203. 'mixup_scale': [0.5, 1.5]
  204. }
  205. ssd_trans_config = {
  206. 'aug_type': 'ssd',
  207. 'mosaic_prob': 0.0,
  208. 'mixup_prob': 0.0
  209. }
  210. transform = build_transform(img_size, yolov5_trans_config, is_train)
  211. dataset = COCODataset(
  212. img_size=img_size,
  213. data_dir=args.root,
  214. image_set='val2017',
  215. trans_config=yolov5_trans_config,
  216. transform=transform,
  217. is_train=is_train
  218. )
  219. np.random.seed(0)
  220. class_colors = [(np.random.randint(255),
  221. np.random.randint(255),
  222. np.random.randint(255)) for _ in range(80)]
  223. print('Data length: ', len(dataset))
  224. for i in range(1000):
  225. image, target, deltas = dataset.pull_item(i)
  226. # to numpy
  227. image = image.permute(1, 2, 0).numpy()
  228. # to uint8
  229. image = image.astype(np.uint8)
  230. image = image.copy()
  231. img_h, img_w = image.shape[:2]
  232. boxes = target["boxes"]
  233. labels = target["labels"]
  234. for box, label in zip(boxes, labels):
  235. x1, y1, x2, y2 = box
  236. cls_id = int(label)
  237. color = class_colors[cls_id]
  238. # class name
  239. label = coco_class_labels[coco_class_index[cls_id]]
  240. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2)
  241. # put the test on the bbox
  242. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  243. cv2.imshow('gt', image)
  244. # cv2.imwrite(str(i)+'.jpg', img)
  245. cv2.waitKey(0)