# coco.py
  1. import os
  2. import cv2
  3. import time
  4. import random
  5. import numpy as np
  6. from torch.utils.data import Dataset
  7. from pycocotools.coco import COCO
  8. try:
  9. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  10. except:
  11. from data_augment.strong_augment import MosaicAugment, MixupAugment
  12. coco_class_indexs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  13. coco_class_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  14. coco_json_files = {
  15. 'train2017_clean': 'instances_train2017_clean.json',
  16. 'val2017_clean' : 'instances_val2017_clean.json',
  17. 'train2017' : 'instances_train2017.json',
  18. 'val2017' : 'instances_val2017.json',
  19. 'test2017' : 'image_info_test.json',
  20. }
  21. class COCODataset(Dataset):
  22. def __init__(self,
  23. cfg,
  24. data_dir :str = None,
  25. image_set :str = 'train2017',
  26. transform = None,
  27. is_train :bool = False,
  28. use_mask :bool = False,
  29. ):
  30. # ----------- Basic parameters -----------
  31. self.data_dir = data_dir
  32. self.image_set = image_set
  33. self.is_train = is_train
  34. self.use_mask = use_mask
  35. self.num_classes = 80
  36. # ----------- Data parameters -----------
  37. try:
  38. self.json_file = coco_json_files['{}_clean'.format(image_set)]
  39. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  40. except:
  41. self.json_file = coco_json_files['{}'.format(image_set)]
  42. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  43. self.ids = self.coco.getImgIds()
  44. self.class_ids = sorted(self.coco.getCatIds())
  45. self.dataset_size = len(self.ids)
  46. self.class_labels = coco_class_labels
  47. self.class_indexs = coco_class_indexs
  48. # ----------- Transform parameters -----------
  49. self.transform = transform
  50. if is_train:
  51. self.mosaic_prob = cfg.mosaic_prob
  52. self.mixup_prob = cfg.mixup_prob
  53. self.copy_paste = cfg.copy_paste
  54. self.mosaic_augment = None if cfg.mosaic_prob == 0. else MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
  55. self.mixup_augment = None if cfg.mixup_prob == 0. and cfg.copy_paste == 0. else MixupAugment(cfg.train_img_size)
  56. else:
  57. self.mosaic_prob = 0.0
  58. self.mixup_prob = 0.0
  59. self.copy_paste = 0.0
  60. self.mosaic_augment = None
  61. self.mixup_augment = None
  62. print('==============================')
  63. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  64. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  65. print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
  66. # ------------ Basic dataset function ------------
  67. def __len__(self):
  68. return len(self.ids)
  69. def __getitem__(self, index):
  70. return self.pull_item(index)
  71. # ------------ Mosaic & Mixup ------------
  72. def load_mosaic(self, index):
  73. # ------------ Prepare 4 indexes of images ------------
  74. ## Load 4x mosaic image
  75. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  76. id1 = index
  77. id2, id3, id4 = random.sample(index_list, 3)
  78. indexs = [id1, id2, id3, id4]
  79. ## Load images and targets
  80. image_list = []
  81. target_list = []
  82. for index in indexs:
  83. img_i, target_i = self.load_image_target(index)
  84. image_list.append(img_i)
  85. target_list.append(target_i)
  86. # ------------ Mosaic augmentation ------------
  87. image, target = self.mosaic_augment(image_list, target_list)
  88. return image, target
  89. def load_mixup(self, origin_image, origin_target, yolox_style=False):
  90. # ------------ Load a new image & target ------------
  91. if yolox_style:
  92. new_index = np.random.randint(0, len(self.ids))
  93. new_image, new_target = self.load_image_target(new_index)
  94. else:
  95. new_index = np.random.randint(0, len(self.ids))
  96. new_image, new_target = self.load_mosaic(new_index)
  97. # ------------ Mixup augmentation ------------
  98. image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target, yolox_style)
  99. return image, target
  100. # ------------ Load data function ------------
  101. def load_image_target(self, index):
  102. # load an image
  103. image, _ = self.pull_image(index)
  104. height, width, channels = image.shape
  105. # load a target
  106. bboxes, labels = self.pull_anno(index)
  107. target = {
  108. "boxes": bboxes,
  109. "labels": labels,
  110. "orig_size": [height, width]
  111. }
  112. return image, target
  113. def pull_item(self, index):
  114. if random.random() < self.mosaic_prob:
  115. # load a mosaic image
  116. mosaic = True
  117. image, target = self.load_mosaic(index)
  118. else:
  119. mosaic = False
  120. # load an image and target
  121. image, target = self.load_image_target(index)
  122. # Yolov5-MixUp
  123. mixup = False
  124. if random.random() < self.mixup_prob:
  125. mixup = True
  126. image, target = self.load_mixup(image, target)
  127. # Copy-paste (use Yolox-Mixup to approximate copy-paste)
  128. if not mixup and random.random() < self.copy_paste:
  129. image, target = self.load_mixup(image, target, yolox_style=True)
  130. # augment
  131. image, target, deltas = self.transform(image, target, mosaic)
  132. return image, target, deltas
  133. def pull_image(self, index):
  134. img_id = self.ids[index]
  135. img_file = os.path.join(self.data_dir, self.image_set,
  136. '{:012}'.format(img_id) + '.jpg')
  137. image = cv2.imread(img_file)
  138. if self.json_file == 'instances_val5k.json' and image is None:
  139. img_file = os.path.join(self.data_dir, 'train2017',
  140. '{:012}'.format(img_id) + '.jpg')
  141. image = cv2.imread(img_file)
  142. assert image is not None
  143. return image, img_id
  144. def pull_anno(self, index):
  145. img_id = self.ids[index]
  146. im_ann = self.coco.loadImgs(img_id)[0]
  147. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  148. annotations = self.coco.loadAnns(anno_ids)
  149. # image infor
  150. width = im_ann['width']
  151. height = im_ann['height']
  152. #load a target
  153. bboxes = []
  154. labels = []
  155. for anno in annotations:
  156. if 'bbox' in anno and anno['area'] > 0:
  157. # bbox
  158. x1 = np.max((0, anno['bbox'][0]))
  159. y1 = np.max((0, anno['bbox'][1]))
  160. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  161. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  162. if x2 < x1 or y2 < y1:
  163. continue
  164. # class label
  165. cls_id = self.class_ids.index(anno['category_id'])
  166. bboxes.append([x1, y1, x2, y2])
  167. labels.append(cls_id)
  168. # guard against no boxes via resizing
  169. bboxes = np.array(bboxes).reshape(-1, 4)
  170. labels = np.array(labels).reshape(-1)
  171. return bboxes, labels
  172. if __name__ == "__main__":
  173. import time
  174. import argparse
  175. from build import build_transform
  176. parser = argparse.ArgumentParser(description='COCO-Dataset')
  177. # opt
  178. parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
  179. help='data root')
  180. parser.add_argument('--is_train', action="store_true", default=False,
  181. help='mixup augmentation.')
  182. parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
  183. help='yolo, ssd.')
  184. args = parser.parse_args()
  185. class YoloBaseConfig(object):
  186. def __init__(self) -> None:
  187. self.max_stride = 32
  188. # ---------------- Data process config ----------------
  189. self.box_format = 'xywh'
  190. self.normalize_coords = False
  191. self.mosaic_prob = 1.0
  192. self.mixup_prob = 0.15
  193. self.copy_paste = 0.3
  194. ## Pixel mean & std
  195. self.pixel_mean = [0., 0., 0.]
  196. self.pixel_std = [255., 255., 255.]
  197. ## Transforms
  198. self.train_img_size = 640
  199. self.test_img_size = 640
  200. self.use_ablu = True
  201. self.aug_type = 'yolo'
  202. self.affine_params = {
  203. 'degrees': 0.0,
  204. 'translate': 0.2,
  205. 'scale': [0.1, 2.0],
  206. 'shear': 0.0,
  207. 'perspective': 0.0,
  208. 'hsv_h': 0.015,
  209. 'hsv_s': 0.7,
  210. 'hsv_v': 0.4,
  211. }
  212. class SSDBaseConfig(object):
  213. def __init__(self) -> None:
  214. self.max_stride = 32
  215. # ---------------- Data process config ----------------
  216. self.box_format = 'xywh'
  217. self.normalize_coords = False
  218. self.mosaic_prob = 0.0
  219. self.mixup_prob = 0.0
  220. self.copy_paste = 0.0
  221. ## Pixel mean & std
  222. self.pixel_mean = [0., 0., 0.]
  223. self.pixel_std = [255., 255., 255.]
  224. ## Transforms
  225. self.train_img_size = 640
  226. self.test_img_size = 640
  227. self.aug_type = 'ssd'
  228. if args.aug_type == "yolo":
  229. cfg = YoloBaseConfig()
  230. elif args.aug_type == "ssd":
  231. cfg = SSDBaseConfig()
  232. transform = build_transform(cfg, args.is_train)
  233. dataset = COCODataset(cfg, args.root, 'val2017', transform, args.is_train)
  234. np.random.seed(0)
  235. class_colors = [(np.random.randint(255),
  236. np.random.randint(255),
  237. np.random.randint(255)) for _ in range(80)]
  238. print('Data length: ', len(dataset))
  239. for i in range(1000):
  240. t0 = time.time()
  241. image, target, deltas = dataset.pull_item(i)
  242. print("Load data: {} s".format(time.time() - t0))
  243. # to numpy
  244. image = image.permute(1, 2, 0).numpy()
  245. # denormalize
  246. image = image * cfg.pixel_std + cfg.pixel_mean
  247. # rgb -> bgr
  248. if transform.color_format == 'rgb':
  249. image = image[..., (2, 1, 0)]
  250. # to uint8
  251. image = image.astype(np.uint8)
  252. image = image.copy()
  253. img_h, img_w = image.shape[:2]
  254. boxes = target["boxes"]
  255. labels = target["labels"]
  256. for box, label in zip(boxes, labels):
  257. if cfg.box_format == 'xyxy':
  258. x1, y1, x2, y2 = box
  259. elif cfg.box_format == 'xywh':
  260. cx, cy, bw, bh = box
  261. x1 = cx - 0.5 * bw
  262. y1 = cy - 0.5 * bh
  263. x2 = cx + 0.5 * bw
  264. y2 = cy + 0.5 * bh
  265. if cfg.normalize_coords:
  266. x1 *= img_w
  267. y1 *= img_h
  268. x2 *= img_w
  269. y2 *= img_h
  270. cls_id = int(label)
  271. color = class_colors[cls_id]
  272. # class name
  273. label = coco_class_labels[cls_id]
  274. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  275. # put the test on the bbox
  276. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  277. cv2.imshow('gt', image)
  278. # cv2.imwrite(str(i)+'.jpg', img)
  279. cv2.waitKey(0)