coco.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. import os
  2. import cv2
  3. import time
  4. import random
  5. import numpy as np
  6. from torch.utils.data import Dataset
  7. from pycocotools.coco import COCO
  8. try:
  9. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  10. except:
  11. from data_augment.strong_augment import MosaicAugment, MixupAugment
  12. coco_class_indexs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  13. coco_class_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  14. coco_json_files = {
  15. 'train2017_clean': 'instances_train2017_clean.json',
  16. 'val2017_clean' : 'instances_val2017_clean.json',
  17. 'train2017' : 'instances_train2017.json',
  18. 'val2017' : 'instances_val2017.json',
  19. 'test2017' : 'image_info_test.json',
  20. }
  21. class COCODataset(Dataset):
  22. def __init__(self,
  23. cfg,
  24. data_dir :str = None,
  25. image_set :str = 'train2017',
  26. transform = None,
  27. is_train :bool = False,
  28. use_mask :bool = False,
  29. ):
  30. # ----------- Basic parameters -----------
  31. self.data_dir = data_dir
  32. self.image_set = image_set
  33. self.is_train = is_train
  34. self.use_mask = use_mask
  35. self.num_classes = 80
  36. # ----------- Data parameters -----------
  37. try:
  38. self.json_file = coco_json_files['{}_clean'.format(image_set)]
  39. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  40. except:
  41. self.json_file = coco_json_files['{}'.format(image_set)]
  42. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  43. self.ids = self.coco.getImgIds()
  44. self.class_ids = sorted(self.coco.getCatIds())
  45. self.dataset_size = len(self.ids)
  46. self.class_labels = coco_class_labels
  47. self.class_indexs = coco_class_indexs
  48. # ----------- Transform parameters -----------
  49. self.transform = transform
  50. if is_train:
  51. self.mosaic_prob = cfg.mosaic_prob
  52. self.mixup_prob = cfg.mixup_prob
  53. self.copy_paste = cfg.copy_paste
  54. self.mosaic_augment = None if cfg.mosaic_prob == 0. else MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
  55. self.mixup_augment = None if cfg.mixup_prob == 0. and cfg.copy_paste == 0. else MixupAugment(cfg.train_img_size)
  56. else:
  57. self.mosaic_prob = 0.0
  58. self.mixup_prob = 0.0
  59. self.copy_paste = 0.0
  60. self.mosaic_augment = None
  61. self.mixup_augment = None
  62. print('==============================')
  63. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  64. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  65. print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
  66. # ------------ Basic dataset function ------------
  67. def __len__(self):
  68. return len(self.ids)
  69. def __getitem__(self, index):
  70. return self.pull_item(index)
  71. # ------------ Mosaic & Mixup ------------
  72. def load_mosaic(self, index):
  73. # ------------ Prepare 4 indexes of images ------------
  74. ## Load 4x mosaic image
  75. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  76. id1 = index
  77. id2, id3, id4 = random.sample(index_list, 3)
  78. indexs = [id1, id2, id3, id4]
  79. ## Load images and targets
  80. image_list = []
  81. target_list = []
  82. for index in indexs:
  83. img_i, target_i = self.load_image_target(index)
  84. image_list.append(img_i)
  85. target_list.append(target_i)
  86. # ------------ Mosaic augmentation ------------
  87. image, target = self.mosaic_augment(image_list, target_list)
  88. return image, target
  89. def load_mixup(self, origin_image, origin_target, yolox_style=False):
  90. # ------------ Load a new image & target ------------
  91. new_index = np.random.randint(0, len(self.ids))
  92. new_image, new_target = self.load_mosaic(new_index)
  93. # ------------ Mixup augmentation ------------
  94. image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target, yolox_style)
  95. return image, target
  96. # ------------ Load data function ------------
  97. def load_image_target(self, index):
  98. # load an image
  99. image, _ = self.pull_image(index)
  100. height, width, channels = image.shape
  101. # load a target
  102. bboxes, labels = self.pull_anno(index)
  103. target = {
  104. "boxes": bboxes,
  105. "labels": labels,
  106. "orig_size": [height, width]
  107. }
  108. return image, target
  109. def pull_item(self, index):
  110. if random.random() < self.mosaic_prob:
  111. # load a mosaic image
  112. mosaic = True
  113. image, target = self.load_mosaic(index)
  114. else:
  115. mosaic = False
  116. # load an image and target
  117. image, target = self.load_image_target(index)
  118. # Yolov5-MixUp
  119. mixup = False
  120. if random.random() < self.mixup_prob:
  121. mixup = True
  122. image, target = self.load_mixup(image, target)
  123. # Copy-paste (use Yolox-Mixup to approximate copy-paste)
  124. if not mixup and random.random() < self.copy_paste:
  125. image, target = self.load_mixup(image, target, yolox_style=True)
  126. # augment
  127. image, target, deltas = self.transform(image, target, mosaic)
  128. return image, target, deltas
  129. def pull_image(self, index):
  130. img_id = self.ids[index]
  131. img_file = os.path.join(self.data_dir, self.image_set,
  132. '{:012}'.format(img_id) + '.jpg')
  133. image = cv2.imread(img_file)
  134. if self.json_file == 'instances_val5k.json' and image is None:
  135. img_file = os.path.join(self.data_dir, 'train2017',
  136. '{:012}'.format(img_id) + '.jpg')
  137. image = cv2.imread(img_file)
  138. assert image is not None
  139. return image, img_id
  140. def pull_anno(self, index):
  141. img_id = self.ids[index]
  142. im_ann = self.coco.loadImgs(img_id)[0]
  143. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  144. annotations = self.coco.loadAnns(anno_ids)
  145. # image infor
  146. width = im_ann['width']
  147. height = im_ann['height']
  148. #load a target
  149. bboxes = []
  150. labels = []
  151. for anno in annotations:
  152. if 'bbox' in anno and anno['area'] > 0:
  153. # bbox
  154. x1 = np.max((0, anno['bbox'][0]))
  155. y1 = np.max((0, anno['bbox'][1]))
  156. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  157. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  158. if x2 < x1 or y2 < y1:
  159. continue
  160. # class label
  161. cls_id = self.class_ids.index(anno['category_id'])
  162. bboxes.append([x1, y1, x2, y2])
  163. labels.append(cls_id)
  164. # guard against no boxes via resizing
  165. bboxes = np.array(bboxes).reshape(-1, 4)
  166. labels = np.array(labels).reshape(-1)
  167. return bboxes, labels
  168. if __name__ == "__main__":
  169. import time
  170. import argparse
  171. from build import build_transform
  172. parser = argparse.ArgumentParser(description='COCO-Dataset')
  173. # opt
  174. parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
  175. help='data root')
  176. parser.add_argument('--is_train', action="store_true", default=False,
  177. help='mixup augmentation.')
  178. parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
  179. help='yolo, ssd.')
  180. args = parser.parse_args()
  181. class YoloBaseConfig(object):
  182. def __init__(self) -> None:
  183. self.max_stride = 32
  184. # ---------------- Data process config ----------------
  185. self.box_format = 'xywh'
  186. self.normalize_coords = False
  187. self.mosaic_prob = 1.0
  188. self.mixup_prob = 0.15
  189. self.copy_paste = 0.3
  190. ## Pixel mean & std
  191. self.pixel_mean = [0., 0., 0.]
  192. self.pixel_std = [255., 255., 255.]
  193. ## Transforms
  194. self.train_img_size = 640
  195. self.test_img_size = 640
  196. self.use_ablu = True
  197. self.aug_type = 'yolo'
  198. self.affine_params = {
  199. 'degrees': 0.0,
  200. 'translate': 0.2,
  201. 'scale': [0.1, 2.0],
  202. 'shear': 0.0,
  203. 'perspective': 0.0,
  204. 'hsv_h': 0.015,
  205. 'hsv_s': 0.7,
  206. 'hsv_v': 0.4,
  207. }
  208. class SSDBaseConfig(object):
  209. def __init__(self) -> None:
  210. self.max_stride = 32
  211. # ---------------- Data process config ----------------
  212. self.box_format = 'xywh'
  213. self.normalize_coords = False
  214. self.mosaic_prob = 0.0
  215. self.mixup_prob = 0.0
  216. self.copy_paste = 0.0
  217. ## Pixel mean & std
  218. self.pixel_mean = [0., 0., 0.]
  219. self.pixel_std = [255., 255., 255.]
  220. ## Transforms
  221. self.train_img_size = 640
  222. self.test_img_size = 640
  223. self.aug_type = 'ssd'
  224. if args.aug_type == "yolo":
  225. cfg = YoloBaseConfig()
  226. elif args.aug_type == "ssd":
  227. cfg = SSDBaseConfig()
  228. transform = build_transform(cfg, args.is_train)
  229. dataset = COCODataset(cfg, args.root, 'val2017', transform, args.is_train)
  230. np.random.seed(0)
  231. class_colors = [(np.random.randint(255),
  232. np.random.randint(255),
  233. np.random.randint(255)) for _ in range(80)]
  234. print('Data length: ', len(dataset))
  235. for i in range(1000):
  236. t0 = time.time()
  237. image, target, deltas = dataset.pull_item(i)
  238. print("Load data: {} s".format(time.time() - t0))
  239. # to numpy
  240. image = image.permute(1, 2, 0).numpy()
  241. # denormalize
  242. image = image * cfg.pixel_std + cfg.pixel_mean
  243. # rgb -> bgr
  244. if transform.color_format == 'rgb':
  245. image = image[..., (2, 1, 0)]
  246. # to uint8
  247. image = image.astype(np.uint8)
  248. image = image.copy()
  249. img_h, img_w = image.shape[:2]
  250. boxes = target["boxes"]
  251. labels = target["labels"]
  252. for box, label in zip(boxes, labels):
  253. if cfg.box_format == 'xyxy':
  254. x1, y1, x2, y2 = box
  255. elif cfg.box_format == 'xywh':
  256. cx, cy, bw, bh = box
  257. x1 = cx - 0.5 * bw
  258. y1 = cy - 0.5 * bh
  259. x2 = cx + 0.5 * bw
  260. y2 = cy + 0.5 * bh
  261. if cfg.normalize_coords:
  262. x1 *= img_w
  263. y1 *= img_h
  264. x2 *= img_w
  265. y2 *= img_h
  266. cls_id = int(label)
  267. color = class_colors[cls_id]
  268. # class name
  269. label = coco_class_labels[cls_id]
  270. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  271. # put the test on the bbox
  272. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  273. cv2.imshow('gt', image)
  274. # cv2.imwrite(str(i)+'.jpg', img)
  275. cv2.waitKey(0)