coco.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324
  1. import os
  2. import cv2
  3. import time
  4. import random
  5. import numpy as np
  6. from torch.utils.data import Dataset
  7. from pycocotools.coco import COCO
  8. try:
  9. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  10. except:
  11. from data_augment.strong_augment import MosaicAugment, MixupAugment
  12. coco_class_indexs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  13. coco_class_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  14. coco_json_files = {
  15. 'train2017' : 'instances_train2017.json',
  16. 'val2017' : 'instances_val2017.json',
  17. 'test2017' : 'image_info_test.json',
  18. }
  19. class COCODataset(Dataset):
  20. def __init__(self,
  21. cfg,
  22. data_dir :str = None,
  23. image_set :str = 'train2017',
  24. transform = None,
  25. is_train :bool = False,
  26. use_mask :bool = False,
  27. ):
  28. # ----------- Basic parameters -----------
  29. self.data_dir = data_dir
  30. self.image_set = image_set
  31. self.is_train = is_train
  32. self.use_mask = use_mask
  33. self.num_classes = 80
  34. # ----------- Data parameters -----------
  35. self.json_file = coco_json_files['{}'.format(image_set)]
  36. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  37. self.ids = self.coco.getImgIds()
  38. self.class_ids = sorted(self.coco.getCatIds())
  39. self.dataset_size = len(self.ids)
  40. self.class_labels = coco_class_labels
  41. self.class_indexs = coco_class_indexs
  42. # ----------- Transform parameters -----------
  43. self.transform = transform
  44. if is_train:
  45. self.mosaic_prob = cfg.mosaic_prob
  46. self.mixup_prob = cfg.mixup_prob
  47. self.copy_paste = cfg.copy_paste
  48. self.mosaic_augment = None if cfg.mosaic_prob == 0. else MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
  49. self.mixup_augment = None if cfg.mixup_prob == 0. and cfg.copy_paste == 0. else MixupAugment(cfg.train_img_size)
  50. else:
  51. self.mosaic_prob = 0.0
  52. self.mixup_prob = 0.0
  53. self.copy_paste = 0.0
  54. self.mosaic_augment = None
  55. self.mixup_augment = None
  56. print('==============================')
  57. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  58. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  59. print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
  60. # ------------ Basic dataset function ------------
  61. def __len__(self):
  62. return len(self.ids)
  63. def __getitem__(self, index):
  64. return self.pull_item(index)
  65. # ------------ Mosaic & Mixup ------------
  66. def load_mosaic(self, index):
  67. # ------------ Prepare 4 indexes of images ------------
  68. ## Load 4x mosaic image
  69. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  70. id1 = index
  71. id2, id3, id4 = random.sample(index_list, 3)
  72. indexs = [id1, id2, id3, id4]
  73. ## Load images and targets
  74. image_list = []
  75. target_list = []
  76. for index in indexs:
  77. img_i, target_i = self.load_image_target(index)
  78. image_list.append(img_i)
  79. target_list.append(target_i)
  80. # ------------ Mosaic augmentation ------------
  81. image, target = self.mosaic_augment(image_list, target_list)
  82. return image, target
  83. def load_mixup(self, origin_image, origin_target, yolox_style=False):
  84. # ------------ Load a new image & target ------------
  85. if yolox_style:
  86. new_index = np.random.randint(0, len(self.ids))
  87. new_image, new_target = self.load_image_target(new_index)
  88. else:
  89. new_index = np.random.randint(0, len(self.ids))
  90. new_image, new_target = self.load_mosaic(new_index)
  91. # ------------ Mixup augmentation ------------
  92. image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target, yolox_style)
  93. return image, target
  94. # ------------ Load data function ------------
  95. def load_image_target(self, index):
  96. # load an image
  97. image, _ = self.pull_image(index)
  98. height, width, channels = image.shape
  99. # load a target
  100. bboxes, labels = self.pull_anno(index)
  101. target = {
  102. "boxes": bboxes,
  103. "labels": labels,
  104. "orig_size": [height, width]
  105. }
  106. return image, target
  107. def pull_item(self, index):
  108. if random.random() < self.mosaic_prob:
  109. # load a mosaic image
  110. mosaic = True
  111. image, target = self.load_mosaic(index)
  112. else:
  113. mosaic = False
  114. # load an image and target
  115. image, target = self.load_image_target(index)
  116. # Yolov5-MixUp
  117. mixup = False
  118. if random.random() < self.mixup_prob:
  119. mixup = True
  120. image, target = self.load_mixup(image, target)
  121. # Copy-paste (use Yolox-Mixup to approximate copy-paste)
  122. if not mixup and random.random() < self.copy_paste:
  123. image, target = self.load_mixup(image, target, yolox_style=True)
  124. # augment
  125. image, target, deltas = self.transform(image, target, mosaic)
  126. return image, target, deltas
  127. def pull_image(self, index):
  128. img_id = self.ids[index]
  129. img_file = os.path.join(self.data_dir, self.image_set,
  130. '{:012}'.format(img_id) + '.jpg')
  131. image = cv2.imread(img_file)
  132. if self.json_file == 'instances_val5k.json' and image is None:
  133. img_file = os.path.join(self.data_dir, 'train2017',
  134. '{:012}'.format(img_id) + '.jpg')
  135. image = cv2.imread(img_file)
  136. assert image is not None
  137. return image, img_id
  138. def pull_anno(self, index):
  139. img_id = self.ids[index]
  140. im_ann = self.coco.loadImgs(img_id)[0]
  141. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  142. annotations = self.coco.loadAnns(anno_ids)
  143. # image infor
  144. width = im_ann['width']
  145. height = im_ann['height']
  146. #load a target
  147. bboxes = []
  148. labels = []
  149. for anno in annotations:
  150. if 'bbox' in anno and anno['area'] > 0:
  151. # bbox
  152. x1 = np.max((0, anno['bbox'][0]))
  153. y1 = np.max((0, anno['bbox'][1]))
  154. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  155. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  156. if x2 < x1 or y2 < y1:
  157. continue
  158. # class label
  159. cls_id = self.class_ids.index(anno['category_id'])
  160. bboxes.append([x1, y1, x2, y2])
  161. labels.append(cls_id)
  162. # guard against no boxes via resizing
  163. bboxes = np.array(bboxes).reshape(-1, 4)
  164. labels = np.array(labels).reshape(-1)
  165. return bboxes, labels
  166. if __name__ == "__main__":
  167. import time
  168. import argparse
  169. from build import build_transform
  170. parser = argparse.ArgumentParser(description='COCO-Dataset')
  171. # opt
  172. parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
  173. help='data root')
  174. parser.add_argument('--is_train', action="store_true", default=False,
  175. help='mixup augmentation.')
  176. parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
  177. help='yolo, ssd.')
  178. args = parser.parse_args()
  179. class YoloBaseConfig(object):
  180. def __init__(self) -> None:
  181. self.max_stride = 32
  182. # ---------------- Data process config ----------------
  183. self.box_format = 'xywh'
  184. self.normalize_coords = False
  185. self.mosaic_prob = 1.0
  186. self.mixup_prob = 0.15
  187. self.copy_paste = 0.3
  188. ## Pixel mean & std
  189. self.pixel_mean = [0., 0., 0.]
  190. self.pixel_std = [255., 255., 255.]
  191. ## Transforms
  192. self.train_img_size = 640
  193. self.test_img_size = 640
  194. self.use_ablu = True
  195. self.aug_type = 'yolo'
  196. self.affine_params = {
  197. 'degrees': 0.0,
  198. 'translate': 0.2,
  199. 'scale': [0.1, 2.0],
  200. 'shear': 0.0,
  201. 'perspective': 0.0,
  202. 'hsv_h': 0.015,
  203. 'hsv_s': 0.7,
  204. 'hsv_v': 0.4,
  205. }
  206. class SSDBaseConfig(object):
  207. def __init__(self) -> None:
  208. self.max_stride = 32
  209. # ---------------- Data process config ----------------
  210. self.box_format = 'xywh'
  211. self.normalize_coords = False
  212. self.mosaic_prob = 0.0
  213. self.mixup_prob = 0.0
  214. self.copy_paste = 0.0
  215. ## Pixel mean & std
  216. self.pixel_mean = [0., 0., 0.]
  217. self.pixel_std = [255., 255., 255.]
  218. ## Transforms
  219. self.train_img_size = 640
  220. self.test_img_size = 640
  221. self.aug_type = 'ssd'
  222. if args.aug_type == "yolo":
  223. cfg = YoloBaseConfig()
  224. elif args.aug_type == "ssd":
  225. cfg = SSDBaseConfig()
  226. transform = build_transform(cfg, args.is_train)
  227. dataset = COCODataset(cfg, args.root, 'val2017', transform, args.is_train)
  228. np.random.seed(0)
  229. class_colors = [(np.random.randint(255),
  230. np.random.randint(255),
  231. np.random.randint(255)) for _ in range(80)]
  232. print('Data length: ', len(dataset))
  233. for i in range(1000):
  234. t0 = time.time()
  235. image, target, deltas = dataset.pull_item(i)
  236. print("Load data: {} s".format(time.time() - t0))
  237. # to numpy
  238. image = image.permute(1, 2, 0).numpy()
  239. # denormalize
  240. image = image * cfg.pixel_std + cfg.pixel_mean
  241. # rgb -> bgr
  242. if transform.color_format == 'rgb':
  243. image = image[..., (2, 1, 0)]
  244. # to uint8
  245. image = image.astype(np.uint8)
  246. image = image.copy()
  247. img_h, img_w = image.shape[:2]
  248. boxes = target["boxes"]
  249. labels = target["labels"]
  250. for box, label in zip(boxes, labels):
  251. if cfg.box_format == 'xyxy':
  252. x1, y1, x2, y2 = box
  253. elif cfg.box_format == 'xywh':
  254. cx, cy, bw, bh = box
  255. x1 = cx - 0.5 * bw
  256. y1 = cy - 0.5 * bh
  257. x2 = cx + 0.5 * bw
  258. y2 = cy + 0.5 * bh
  259. if cfg.normalize_coords:
  260. x1 *= img_w
  261. y1 *= img_h
  262. x2 *= img_w
  263. y2 *= img_h
  264. cls_id = int(label)
  265. color = class_colors[cls_id]
  266. # class name
  267. label = coco_class_labels[cls_id]
  268. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  269. # put the test on the bbox
  270. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  271. cv2.imshow('gt', image)
  272. # cv2.imwrite(str(i)+'.jpg', img)
  273. cv2.waitKey(0)