coco.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. import os
  2. import cv2
  3. import time
  4. import random
  5. import numpy as np
  6. from torch.utils.data import Dataset
  7. from pycocotools.coco import COCO
  8. try:
  9. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  10. except:
  11. from data_augment.strong_augment import MosaicAugment, MixupAugment
# 80 integer ids in the range 1..90 (with gaps), exposed on the dataset as
# `class_indexs`.
# NOTE(review): presumably the original (non-contiguous) COCO category ids in
# the same order as `coco_class_labels` -- confirm against the annotation files.
coco_class_indexs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
# Human-readable class names, exposed on the dataset as `class_labels` and
# indexed by the contiguous label id when drawing boxes in the demo below.
coco_class_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
# Image-set name -> annotation file name; the dataset joins the value onto
# `<data_dir>/annotations/`. Only 'train2017' and 'val2017' are selected by
# COCODataset itself.
# NOTE(review): the 'test2017' file name lacks the year suffix the other two
# entries have -- verify it matches the file actually shipped on disk.
coco_json_files = {
    'train2017' : 'instances_train2017.json',
    'val2017' : 'instances_val2017.json',
    'test2017' : 'image_info_test.json',
}
  19. class COCODataset(Dataset):
  20. def __init__(self,
  21. cfg,
  22. data_dir :str = None,
  23. transform = None,
  24. is_train :bool = False,
  25. use_mask :bool = False,
  26. ):
  27. # ----------- Basic parameters -----------
  28. self.data_dir = data_dir
  29. self.image_set = "train2017" if is_train else "val2017"
  30. self.is_train = is_train
  31. self.use_mask = use_mask
  32. self.num_classes = 80
  33. # ----------- Data parameters -----------
  34. self.json_file = coco_json_files['{}'.format(self.image_set)]
  35. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  36. self.ids = self.coco.getImgIds()
  37. self.class_ids = sorted(self.coco.getCatIds())
  38. self.dataset_size = len(self.ids)
  39. self.class_labels = coco_class_labels
  40. self.class_indexs = coco_class_indexs
  41. # ----------- Transform parameters -----------
  42. self.transform = transform
  43. if is_train:
  44. self.mosaic_prob = cfg.mosaic_prob
  45. self.mixup_prob = cfg.mixup_prob
  46. self.copy_paste = cfg.copy_paste
  47. self.mosaic_augment = None if cfg.mosaic_prob == 0. else MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
  48. self.mixup_augment = None if cfg.mixup_prob == 0. and cfg.copy_paste == 0. else MixupAugment(cfg.train_img_size)
  49. else:
  50. self.mosaic_prob = 0.0
  51. self.mixup_prob = 0.0
  52. self.copy_paste = 0.0
  53. self.mosaic_augment = None
  54. self.mixup_augment = None
  55. print(' ============ Strong augmentation info. ============ ')
  56. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  57. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  58. print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
  59. # ------------ Basic dataset function ------------
  60. def __len__(self):
  61. return len(self.ids)
  62. def __getitem__(self, index):
  63. return self.pull_item(index)
  64. # ------------ Mosaic & Mixup ------------
  65. def load_mosaic(self, index):
  66. # ------------ Prepare 4 indexes of images ------------
  67. ## Load 4x mosaic image
  68. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  69. id1 = index
  70. id2, id3, id4 = random.sample(index_list, 3)
  71. indexs = [id1, id2, id3, id4]
  72. ## Load images and targets
  73. image_list = []
  74. target_list = []
  75. for index in indexs:
  76. img_i, target_i = self.load_image_target(index)
  77. image_list.append(img_i)
  78. target_list.append(target_i)
  79. # ------------ Mosaic augmentation ------------
  80. image, target = self.mosaic_augment(image_list, target_list)
  81. return image, target
  82. def load_mixup(self, origin_image, origin_target, yolox_style=False):
  83. # ------------ Load a new image & target ------------
  84. if yolox_style:
  85. new_index = np.random.randint(0, len(self.ids))
  86. new_image, new_target = self.load_image_target(new_index)
  87. else:
  88. new_index = np.random.randint(0, len(self.ids))
  89. new_image, new_target = self.load_mosaic(new_index)
  90. # ------------ Mixup augmentation ------------
  91. image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target, yolox_style)
  92. return image, target
  93. # ------------ Load data function ------------
  94. def load_image_target(self, index):
  95. # load an image
  96. image, _ = self.pull_image(index)
  97. height, width, channels = image.shape
  98. # load a target
  99. bboxes, labels = self.pull_anno(index)
  100. target = {
  101. "boxes": bboxes,
  102. "labels": labels,
  103. "orig_size": [height, width]
  104. }
  105. return image, target
  106. def pull_item(self, index):
  107. if random.random() < self.mosaic_prob:
  108. # load a mosaic image
  109. mosaic = True
  110. image, target = self.load_mosaic(index)
  111. else:
  112. mosaic = False
  113. # load an image and target
  114. image, target = self.load_image_target(index)
  115. # Yolov5-MixUp
  116. mixup = False
  117. if random.random() < self.mixup_prob:
  118. mixup = True
  119. image, target = self.load_mixup(image, target)
  120. # Copy-paste (use Yolox-Mixup to approximate copy-paste)
  121. if not mixup and random.random() < self.copy_paste:
  122. image, target = self.load_mixup(image, target, yolox_style=True)
  123. # augment
  124. image, target, deltas = self.transform(image, target, mosaic)
  125. return image, target, deltas
  126. def pull_image(self, index):
  127. # get the image file name
  128. image_dict = self.coco.dataset['images'][index]
  129. image_id = image_dict["id"]
  130. filename = image_dict["file_name"]
  131. # load the image
  132. image_path = os.path.join(self.data_dir, self.image_set, filename)
  133. image = cv2.imread(image_path)
  134. assert image is not None
  135. return image, image_id
  136. def pull_anno(self, index):
  137. img_id = self.ids[index]
  138. # image infor
  139. im_ann = self.coco.loadImgs(img_id)[0]
  140. width = im_ann['width']
  141. height = im_ann['height']
  142. # load a target
  143. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  144. annotations = self.coco.loadAnns(anno_ids)
  145. bboxes = []
  146. labels = []
  147. for anno in annotations:
  148. if 'bbox' in anno and anno['area'] > 0:
  149. # bbox
  150. x1 = np.max((0, anno['bbox'][0]))
  151. y1 = np.max((0, anno['bbox'][1]))
  152. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  153. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  154. if x2 < x1 or y2 < y1:
  155. continue
  156. # class label
  157. cls_id = self.class_ids.index(anno['category_id'])
  158. bboxes.append([x1, y1, x2, y2])
  159. labels.append(cls_id)
  160. # guard against no boxes via resizing
  161. bboxes = np.array(bboxes).reshape(-1, 4)
  162. labels = np.array(labels).reshape(-1)
  163. return bboxes, labels
  164. if __name__ == "__main__":
  165. import time
  166. import argparse
  167. from build import build_transform
  168. parser = argparse.ArgumentParser(description='COCO-Dataset')
  169. # opt
  170. parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
  171. help='data root')
  172. parser.add_argument('--is_train', action="store_true", default=False,
  173. help='mixup augmentation.')
  174. parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
  175. help='yolo, ssd.')
  176. args = parser.parse_args()
  177. class YoloBaseConfig(object):
  178. def __init__(self) -> None:
  179. self.max_stride = 32
  180. # ---------------- Data process config ----------------
  181. self.box_format = 'xywh'
  182. self.normalize_coords = False
  183. self.mosaic_prob = 1.0
  184. self.mixup_prob = 0.15
  185. self.copy_paste = 0.3
  186. ## Pixel mean & std
  187. self.pixel_mean = [0., 0., 0.]
  188. self.pixel_std = [255., 255., 255.]
  189. ## Transforms
  190. self.train_img_size = 640
  191. self.test_img_size = 640
  192. self.use_ablu = True
  193. self.aug_type = 'yolo'
  194. self.affine_params = {
  195. 'degrees': 0.0,
  196. 'translate': 0.2,
  197. 'scale': [0.1, 2.0],
  198. 'shear': 0.0,
  199. 'perspective': 0.0,
  200. 'hsv_h': 0.015,
  201. 'hsv_s': 0.7,
  202. 'hsv_v': 0.4,
  203. }
  204. class SSDBaseConfig(object):
  205. def __init__(self) -> None:
  206. self.max_stride = 32
  207. # ---------------- Data process config ----------------
  208. self.box_format = 'xywh'
  209. self.normalize_coords = False
  210. self.mosaic_prob = 0.0
  211. self.mixup_prob = 0.0
  212. self.copy_paste = 0.0
  213. ## Pixel mean & std
  214. self.pixel_mean = [0., 0., 0.]
  215. self.pixel_std = [255., 255., 255.]
  216. ## Transforms
  217. self.train_img_size = 640
  218. self.test_img_size = 640
  219. self.aug_type = 'ssd'
  220. if args.aug_type == "yolo":
  221. cfg = YoloBaseConfig()
  222. elif args.aug_type == "ssd":
  223. cfg = SSDBaseConfig()
  224. transform = build_transform(cfg, args.is_train)
  225. dataset = COCODataset(cfg, args.root, transform, args.is_train)
  226. np.random.seed(0)
  227. class_colors = [(np.random.randint(255),
  228. np.random.randint(255),
  229. np.random.randint(255)) for _ in range(80)]
  230. print('Data length: ', len(dataset))
  231. for i in range(1000):
  232. t0 = time.time()
  233. image, target, deltas = dataset.pull_item(i)
  234. print("Load data: {} s".format(time.time() - t0))
  235. # to numpy
  236. image = image.permute(1, 2, 0).numpy()
  237. # denormalize
  238. image = image * cfg.pixel_std + cfg.pixel_mean
  239. # rgb -> bgr
  240. if transform.color_format == 'rgb':
  241. image = image[..., (2, 1, 0)]
  242. # to uint8
  243. image = image.astype(np.uint8)
  244. image = image.copy()
  245. img_h, img_w = image.shape[:2]
  246. boxes = target["boxes"]
  247. labels = target["labels"]
  248. for box, label in zip(boxes, labels):
  249. if cfg.box_format == 'xyxy':
  250. x1, y1, x2, y2 = box
  251. elif cfg.box_format == 'xywh':
  252. cx, cy, bw, bh = box
  253. x1 = cx - 0.5 * bw
  254. y1 = cy - 0.5 * bh
  255. x2 = cx + 0.5 * bw
  256. y2 = cy + 0.5 * bh
  257. if cfg.normalize_coords:
  258. x1 *= img_w
  259. y1 *= img_h
  260. x2 *= img_w
  261. y2 *= img_h
  262. cls_id = int(label)
  263. color = class_colors[cls_id]
  264. # class name
  265. label = coco_class_labels[cls_id]
  266. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  267. # put the test on the bbox
  268. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  269. cv2.imshow('gt', image)
  270. # cv2.imwrite(str(i)+'.jpg', img)
  271. cv2.waitKey(0)