coco.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. import os
  2. import cv2
  3. import time
  4. import numpy as np
  5. from pycocotools.coco import COCO
  6. try:
  7. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  8. from .voc import VOCDataset
  9. except:
  10. from data_augment.strong_augment import MosaicAugment, MixupAugment
  11. from voc import VOCDataset
  12. coco_class_indexs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
  13. coco_class_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
  14. coco_json_files = {
  15. 'train2017' : 'instances_train2017.json',
  16. 'val2017' : 'instances_val2017.json',
  17. 'test2017' : 'image_info_test.json',
  18. }
  19. class COCODataset(VOCDataset):
  20. def __init__(self,
  21. cfg,
  22. data_dir :str = None,
  23. transform = None,
  24. is_train :bool = False,
  25. use_mask :bool = False,
  26. ):
  27. # ----------- Basic parameters -----------
  28. self.data_dir = data_dir
  29. self.image_set = "train2017" if is_train else "val2017"
  30. self.is_train = is_train
  31. self.use_mask = use_mask
  32. self.num_classes = 80
  33. # ----------- Data parameters -----------
  34. self.json_file = coco_json_files['{}'.format(self.image_set)]
  35. self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
  36. self.ids = self.coco.getImgIds()
  37. self.class_ids = sorted(self.coco.getCatIds())
  38. self.dataset_size = len(self.ids)
  39. self.class_labels = coco_class_labels
  40. self.class_indexs = coco_class_indexs
  41. # ----------- Transform parameters -----------
  42. self.transform = transform
  43. if is_train:
  44. self.mosaic_augment = MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
  45. self.mixup_augment = MixupAugment(cfg.train_img_size)
  46. self.mosaic_prob = cfg.mosaic_prob
  47. self.mixup_prob = cfg.mixup_prob
  48. self.copy_paste = cfg.copy_paste
  49. else:
  50. self.mosaic_prob = 0.0
  51. self.mixup_prob = 0.0
  52. self.copy_paste = 0.0
  53. self.mosaic_augment = None
  54. self.mixup_augment = None
  55. print(' ============ Strong augmentation info. ============ ')
  56. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  57. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  58. print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
  59. def pull_image(self, index):
  60. # get the image file name
  61. image_dict = self.coco.dataset['images'][index]
  62. image_id = image_dict["id"]
  63. filename = image_dict["file_name"]
  64. # load the image
  65. image_path = os.path.join(self.data_dir, self.image_set, filename)
  66. image = cv2.imread(image_path)
  67. assert image is not None
  68. return image, image_id
  69. def pull_anno(self, index):
  70. img_id = self.ids[index]
  71. # image infor
  72. im_ann = self.coco.loadImgs(img_id)[0]
  73. width = im_ann['width']
  74. height = im_ann['height']
  75. # load a target
  76. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
  77. annotations = self.coco.loadAnns(anno_ids)
  78. bboxes = []
  79. labels = []
  80. for anno in annotations:
  81. if 'bbox' in anno and anno['area'] > 0:
  82. # bbox
  83. x1 = np.max((0, anno['bbox'][0]))
  84. y1 = np.max((0, anno['bbox'][1]))
  85. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  86. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  87. if x2 < x1 or y2 < y1:
  88. continue
  89. # class label
  90. cls_id = self.class_ids.index(anno['category_id'])
  91. bboxes.append([x1, y1, x2, y2])
  92. labels.append(cls_id)
  93. # guard against no boxes via resizing
  94. bboxes = np.array(bboxes).reshape(-1, 4)
  95. labels = np.array(labels).reshape(-1)
  96. return bboxes, labels
  97. if __name__ == "__main__":
  98. import time
  99. import argparse
  100. from build import build_transform
  101. parser = argparse.ArgumentParser(description='COCO-Dataset')
  102. # opt
  103. parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
  104. help='data root')
  105. parser.add_argument('--is_train', action="store_true", default=False,
  106. help='mixup augmentation.')
  107. parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
  108. help='yolo, ssd.')
  109. args = parser.parse_args()
  110. class YoloBaseConfig(object):
  111. def __init__(self) -> None:
  112. self.max_stride = 32
  113. # ---------------- Data process config ----------------
  114. self.box_format = 'xywh'
  115. self.normalize_coords = False
  116. self.mosaic_prob = 1.0
  117. self.mixup_prob = 0.15
  118. self.copy_paste = 0.3
  119. ## Pixel mean & std
  120. self.pixel_mean = [0., 0., 0.]
  121. self.pixel_std = [255., 255., 255.]
  122. ## Transforms
  123. self.train_img_size = 640
  124. self.test_img_size = 640
  125. self.use_ablu = True
  126. self.aug_type = 'yolo'
  127. self.affine_params = {
  128. 'degrees': 0.0,
  129. 'translate': 0.2,
  130. 'scale': [0.1, 2.0],
  131. 'shear': 0.0,
  132. 'perspective': 0.0,
  133. 'hsv_h': 0.015,
  134. 'hsv_s': 0.7,
  135. 'hsv_v': 0.4,
  136. }
  137. class SSDBaseConfig(object):
  138. def __init__(self) -> None:
  139. self.max_stride = 32
  140. # ---------------- Data process config ----------------
  141. self.box_format = 'xywh'
  142. self.normalize_coords = False
  143. self.mosaic_prob = 0.0
  144. self.mixup_prob = 0.0
  145. self.copy_paste = 0.0
  146. ## Pixel mean & std
  147. self.pixel_mean = [0., 0., 0.]
  148. self.pixel_std = [255., 255., 255.]
  149. ## Transforms
  150. self.train_img_size = 640
  151. self.test_img_size = 640
  152. self.aug_type = 'ssd'
  153. if args.aug_type == "yolo":
  154. cfg = YoloBaseConfig()
  155. elif args.aug_type == "ssd":
  156. cfg = SSDBaseConfig()
  157. transform = build_transform(cfg, args.is_train)
  158. dataset = COCODataset(cfg, args.root, transform, args.is_train)
  159. np.random.seed(0)
  160. class_colors = [(np.random.randint(255),
  161. np.random.randint(255),
  162. np.random.randint(255)) for _ in range(80)]
  163. print('Data length: ', len(dataset))
  164. for i in range(1000):
  165. t0 = time.time()
  166. image, target, deltas = dataset.pull_item(i)
  167. print("Load data: {} s".format(time.time() - t0))
  168. # to numpy
  169. image = image.permute(1, 2, 0).numpy()
  170. # denormalize
  171. image = image * cfg.pixel_std + cfg.pixel_mean
  172. # rgb -> bgr
  173. if transform.color_format == 'rgb':
  174. image = image[..., (2, 1, 0)]
  175. # to uint8
  176. image = image.astype(np.uint8)
  177. image = image.copy()
  178. img_h, img_w = image.shape[:2]
  179. boxes = target["boxes"]
  180. labels = target["labels"]
  181. for box, label in zip(boxes, labels):
  182. if cfg.box_format == 'xyxy':
  183. x1, y1, x2, y2 = box
  184. elif cfg.box_format == 'xywh':
  185. cx, cy, bw, bh = box
  186. x1 = cx - 0.5 * bw
  187. y1 = cy - 0.5 * bh
  188. x2 = cx + 0.5 * bw
  189. y2 = cy + 0.5 * bh
  190. if cfg.normalize_coords:
  191. x1 *= img_w
  192. y1 *= img_h
  193. x2 *= img_w
  194. y2 *= img_h
  195. cls_id = int(label)
  196. color = class_colors[cls_id]
  197. # class name
  198. label = coco_class_labels[cls_id]
  199. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  200. # put the test on the bbox
  201. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  202. cv2.imshow('gt', image)
  203. # cv2.imwrite(str(i)+'.jpg', img)
  204. cv2.waitKey(0)