coco.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import os
  2. import cv2
  3. import time
  4. import numpy as np
  5. from pycocotools.coco import COCO
  6. try:
  7. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  8. from .voc import VOCDataset
  9. except:
  10. from data_augment.strong_augment import MosaicAugment, MixupAugment
  11. from voc import VOCDataset
  12. coco_class_indexs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  13. coco_class_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  14. coco_json_files = {
  15. 'train2017' : 'instances_train2017.json',
  16. 'val2017' : 'instances_val2017.json',
  17. 'test2017' : 'image_info_test.json',
  18. }
  19. class COCODataset(VOCDataset):
  20. def __init__(self,
  21. cfg,
  22. data_dir :str = None,
  23. transform = None,
  24. is_train :bool = False,
  25. use_mask :bool = False,
  26. ):
  27. # ----------- Basic parameters -----------
  28. self.data_dir = data_dir
  29. self.image_set = "train2017" if is_train else "val2017"
  30. self.is_train = is_train
  31. self.use_mask = use_mask
  32. self.num_classes = 80
  33. # ----------- Data parameters -----------
  34. self.json_file = coco_json_files['{}'.format(self.image_set)]
  35. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  36. self.ids = self.coco.getImgIds()
  37. self.class_ids = sorted(self.coco.getCatIds())
  38. self.dataset_size = len(self.ids)
  39. self.class_labels = coco_class_labels
  40. self.class_indexs = coco_class_indexs
  41. # ----------- Transform parameters -----------
  42. self.transform = transform
  43. if is_train:
  44. self.mosaic_augment = MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
  45. self.mixup_augment = MixupAugment(cfg.train_img_size)
  46. self.mosaic_prob = cfg.mosaic_prob
  47. self.mixup_prob = cfg.mixup_prob
  48. self.copy_paste = cfg.copy_paste
  49. else:
  50. self.mosaic_prob = 0.0
  51. self.mixup_prob = 0.0
  52. self.copy_paste = 0.0
  53. self.mosaic_augment = None
  54. self.mixup_augment = None
  55. print(' ============ Strong augmentation info. ============ ')
  56. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  57. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  58. print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
  59. def pull_image(self, index):
  60. # get the image file name
  61. image_dict = self.coco.dataset['images'][index]
  62. image_id = image_dict["id"]
  63. filename = image_dict["file_name"]
  64. # load the image
  65. image_path = os.path.join(self.data_dir, self.image_set, filename)
  66. image = cv2.imread(image_path)
  67. assert image is not None
  68. return image, image_id
  69. def pull_anno(self, index):
  70. img_id = self.ids[index]
  71. # image infor
  72. im_ann = self.coco.loadImgs(img_id)[0]
  73. width = im_ann['width']
  74. height = im_ann['height']
  75. # load a target
  76. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  77. annotations = self.coco.loadAnns(anno_ids)
  78. bboxes = []
  79. labels = []
  80. for anno in annotations:
  81. if 'bbox' in anno and anno['area'] > 0:
  82. # bbox
  83. x1 = np.max((0, anno['bbox'][0]))
  84. y1 = np.max((0, anno['bbox'][1]))
  85. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  86. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  87. if x2 < x1 or y2 < y1:
  88. continue
  89. # class label
  90. cls_id = self.class_ids.index(anno['category_id'])
  91. bboxes.append([x1, y1, x2, y2])
  92. labels.append(cls_id)
  93. # guard against no boxes via resizing
  94. bboxes = np.array(bboxes).reshape(-1, 4)
  95. labels = np.array(labels).reshape(-1)
  96. return bboxes, labels
  97. if __name__ == "__main__":
  98. import time
  99. import argparse
  100. from build import build_transform
  101. parser = argparse.ArgumentParser(description='COCO-Dataset')
  102. # opt
  103. parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
  104. help='data root')
  105. parser.add_argument('--is_train', action="store_true", default=False,
  106. help='mixup augmentation.')
  107. parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
  108. help='yolo, ssd.')
  109. args = parser.parse_args()
  110. class YoloBaseConfig(object):
  111. def __init__(self) -> None:
  112. self.max_stride = 32
  113. # ---------------- Data process config ----------------
  114. self.mosaic_prob = 1.0
  115. self.mixup_prob = 0.15
  116. self.copy_paste = 0.3
  117. ## Pixel mean & std
  118. self.pixel_mean = [0., 0., 0.]
  119. self.pixel_std = [255., 255., 255.]
  120. ## Transforms
  121. self.train_img_size = 640
  122. self.test_img_size = 640
  123. self.use_ablu = True
  124. self.aug_type = 'yolo'
  125. self.affine_params = {
  126. 'degrees': 0.0,
  127. 'translate': 0.2,
  128. 'scale': [0.1, 2.0],
  129. 'shear': 0.0,
  130. 'perspective': 0.0,
  131. 'hsv_h': 0.015,
  132. 'hsv_s': 0.7,
  133. 'hsv_v': 0.4,
  134. }
  135. class SSDBaseConfig(object):
  136. def __init__(self) -> None:
  137. self.max_stride = 32
  138. # ---------------- Data process config ----------------
  139. self.mosaic_prob = 0.0
  140. self.mixup_prob = 0.0
  141. self.copy_paste = 0.0
  142. ## Pixel mean & std
  143. self.pixel_mean = [0., 0., 0.]
  144. self.pixel_std = [255., 255., 255.]
  145. ## Transforms
  146. self.train_img_size = 640
  147. self.test_img_size = 640
  148. self.aug_type = 'ssd'
  149. if args.aug_type == "yolo":
  150. cfg = YoloBaseConfig()
  151. elif args.aug_type == "ssd":
  152. cfg = SSDBaseConfig()
  153. transform = build_transform(cfg, args.is_train)
  154. dataset = COCODataset(cfg, args.root, transform, args.is_train)
  155. np.random.seed(0)
  156. class_colors = [(np.random.randint(255),
  157. np.random.randint(255),
  158. np.random.randint(255)) for _ in range(80)]
  159. print('Data length: ', len(dataset))
  160. for i in range(1000):
  161. t0 = time.time()
  162. image, target, deltas = dataset.pull_item(i)
  163. print("Load data: {} s".format(time.time() - t0))
  164. # to numpy
  165. image = image.permute(1, 2, 0).numpy()
  166. # denormalize
  167. image = image * cfg.pixel_std + cfg.pixel_mean
  168. # rgb -> bgr
  169. if transform.color_format == 'rgb':
  170. image = image[..., (2, 1, 0)]
  171. # to uint8
  172. image = image.astype(np.uint8)
  173. image = image.copy()
  174. img_h, img_w = image.shape[:2]
  175. boxes = target["boxes"]
  176. labels = target["labels"]
  177. for box, label in zip(boxes, labels):
  178. x1, y1, x2, y2 = box
  179. cls_id = int(label)
  180. color = class_colors[cls_id]
  181. # class name
  182. label = coco_class_labels[cls_id]
  183. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  184. # put the test on the bbox
  185. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  186. cv2.imshow('gt', image)
  187. # cv2.imwrite(str(i)+'.jpg', img)
  188. cv2.waitKey(0)