coco.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. import os
  2. import cv2
  3. import time
  4. import numpy as np
  5. from pycocotools.coco import COCO
  6. try:
  7. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  8. from .voc import VOCDataset
  9. except:
  10. from data_augment.strong_augment import MosaicAugment, MixupAugment
  11. from voc import VOCDataset
  12. coco_class_indexs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  13. coco_class_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  14. coco_json_files = {
  15. 'train2017' : 'instances_train2017.json',
  16. 'val2017' : 'instances_val2017.json',
  17. 'test2017' : 'image_info_test.json',
  18. }
  19. class COCODataset(VOCDataset):
  20. def __init__(self,
  21. cfg,
  22. data_dir :str = None,
  23. transform = None,
  24. is_train :bool = False,
  25. use_mask :bool = False,
  26. ):
  27. # ----------- Basic parameters -----------
  28. self.data_dir = data_dir
  29. self.image_set = "train2017" if is_train else "val2017"
  30. self.is_train = is_train
  31. self.use_mask = use_mask
  32. self.num_classes = 80
  33. # ----------- Data parameters -----------
  34. self.json_file = coco_json_files['{}'.format(self.image_set)]
  35. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  36. self.ids = self.coco.getImgIds()
  37. self.class_ids = sorted(self.coco.getCatIds())
  38. self.dataset_size = len(self.ids)
  39. self.class_labels = coco_class_labels
  40. self.class_indexs = coco_class_indexs
  41. # ----------- Transform parameters -----------
  42. self.transform = transform
  43. if is_train:
  44. if cfg.mosaic_prob == 0.:
  45. self.mosaic_augment = None
  46. else:
  47. self.mosaic_augment = MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
  48. self.mosaic_prob = cfg.mosaic_prob
  49. if cfg.mixup_prob == 0.:
  50. self.mixup_augment = None
  51. else:
  52. self.mixup_augment = MixupAugment(cfg.train_img_size)
  53. self.mixup_prob = cfg.mixup_prob
  54. self.copy_paste = cfg.copy_paste
  55. else:
  56. self.mosaic_prob = 0.0
  57. self.mixup_prob = 0.0
  58. self.copy_paste = 0.0
  59. self.mosaic_augment = None
  60. self.mixup_augment = None
  61. print(' ============ Strong augmentation info. ============ ')
  62. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  63. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  64. print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
  65. def pull_image(self, index):
  66. # get the image file name
  67. image_dict = self.coco.dataset['images'][index]
  68. image_id = image_dict["id"]
  69. filename = image_dict["file_name"]
  70. # load the image
  71. image_path = os.path.join(self.data_dir, self.image_set, filename)
  72. image = cv2.imread(image_path)
  73. assert image is not None
  74. return image, image_id
  75. def pull_anno(self, index):
  76. img_id = self.ids[index]
  77. # image infor
  78. im_ann = self.coco.loadImgs(img_id)[0]
  79. width = im_ann['width']
  80. height = im_ann['height']
  81. # load a target
  82. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  83. annotations = self.coco.loadAnns(anno_ids)
  84. bboxes = []
  85. labels = []
  86. for anno in annotations:
  87. if 'bbox' in anno and anno['area'] > 0:
  88. # bbox
  89. x1 = np.max((0, anno['bbox'][0]))
  90. y1 = np.max((0, anno['bbox'][1]))
  91. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  92. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  93. if x2 < x1 or y2 < y1:
  94. continue
  95. # class label
  96. cls_id = self.class_ids.index(anno['category_id'])
  97. bboxes.append([x1, y1, x2, y2])
  98. labels.append(cls_id)
  99. # guard against no boxes via resizing
  100. bboxes = np.array(bboxes).reshape(-1, 4)
  101. labels = np.array(labels).reshape(-1)
  102. return bboxes, labels
  103. if __name__ == "__main__":
  104. import time
  105. import argparse
  106. from build import build_transform
  107. parser = argparse.ArgumentParser(description='COCO-Dataset')
  108. # opt
  109. parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
  110. help='data root')
  111. parser.add_argument('--is_train', action="store_true", default=False,
  112. help='mixup augmentation.')
  113. parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
  114. help='yolo, ssd.')
  115. args = parser.parse_args()
  116. class YoloBaseConfig(object):
  117. def __init__(self) -> None:
  118. self.max_stride = 32
  119. # ---------------- Data process config ----------------
  120. self.box_format = 'xywh'
  121. self.normalize_coords = False
  122. self.mosaic_prob = 1.0
  123. self.mixup_prob = 0.15
  124. self.copy_paste = 0.3
  125. ## Pixel mean & std
  126. self.pixel_mean = [0., 0., 0.]
  127. self.pixel_std = [255., 255., 255.]
  128. ## Transforms
  129. self.train_img_size = 640
  130. self.test_img_size = 640
  131. self.use_ablu = True
  132. self.aug_type = 'yolo'
  133. self.affine_params = {
  134. 'degrees': 0.0,
  135. 'translate': 0.2,
  136. 'scale': [0.1, 2.0],
  137. 'shear': 0.0,
  138. 'perspective': 0.0,
  139. 'hsv_h': 0.015,
  140. 'hsv_s': 0.7,
  141. 'hsv_v': 0.4,
  142. }
  143. class SSDBaseConfig(object):
  144. def __init__(self) -> None:
  145. self.max_stride = 32
  146. # ---------------- Data process config ----------------
  147. self.box_format = 'xywh'
  148. self.normalize_coords = False
  149. self.mosaic_prob = 0.0
  150. self.mixup_prob = 0.0
  151. self.copy_paste = 0.0
  152. ## Pixel mean & std
  153. self.pixel_mean = [0., 0., 0.]
  154. self.pixel_std = [255., 255., 255.]
  155. ## Transforms
  156. self.train_img_size = 640
  157. self.test_img_size = 640
  158. self.aug_type = 'ssd'
  159. if args.aug_type == "yolo":
  160. cfg = YoloBaseConfig()
  161. elif args.aug_type == "ssd":
  162. cfg = SSDBaseConfig()
  163. transform = build_transform(cfg, args.is_train)
  164. dataset = COCODataset(cfg, args.root, transform, args.is_train)
  165. np.random.seed(0)
  166. class_colors = [(np.random.randint(255),
  167. np.random.randint(255),
  168. np.random.randint(255)) for _ in range(80)]
  169. print('Data length: ', len(dataset))
  170. for i in range(1000):
  171. t0 = time.time()
  172. image, target, deltas = dataset.pull_item(i)
  173. print("Load data: {} s".format(time.time() - t0))
  174. # to numpy
  175. image = image.permute(1, 2, 0).numpy()
  176. # denormalize
  177. image = image * cfg.pixel_std + cfg.pixel_mean
  178. # rgb -> bgr
  179. if transform.color_format == 'rgb':
  180. image = image[..., (2, 1, 0)]
  181. # to uint8
  182. image = image.astype(np.uint8)
  183. image = image.copy()
  184. img_h, img_w = image.shape[:2]
  185. boxes = target["boxes"]
  186. labels = target["labels"]
  187. for box, label in zip(boxes, labels):
  188. if cfg.box_format == 'xyxy':
  189. x1, y1, x2, y2 = box
  190. elif cfg.box_format == 'xywh':
  191. cx, cy, bw, bh = box
  192. x1 = cx - 0.5 * bw
  193. y1 = cy - 0.5 * bh
  194. x2 = cx + 0.5 * bw
  195. y2 = cy + 0.5 * bh
  196. if cfg.normalize_coords:
  197. x1 *= img_w
  198. y1 *= img_h
  199. x2 *= img_w
  200. y2 *= img_h
  201. cls_id = int(label)
  202. color = class_colors[cls_id]
  203. # class name
  204. label = coco_class_labels[cls_id]
  205. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  206. # put the test on the bbox
  207. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  208. cv2.imshow('gt', image)
  209. # cv2.imwrite(str(i)+'.jpg', img)
  210. cv2.waitKey(0)