# coco.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO dataset which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
from pathlib import Path

import torch
import torch.utils.data
import torchvision

try:
    # package-relative import when used as a module
    from .transforms import build_transform
except ImportError:
    # fallback for running this file as a standalone script
    from transforms import build_transform

coco_labels_91 = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
                  'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
                  'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe',
                  'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
                  'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife',
                  'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
                  'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk',
                  'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
                  'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
                  'toothbrush')
coco_labels_80 = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
                  'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
                  'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
                  'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle',
                  'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
                  'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
                  'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
                  'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
coco_indexs = [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21,
               22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
               46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
               67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
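
# How the three tables relate: the raw COCO detection annotations use sparse
# category ids in [1, 90] (coco_labels_91 names them, plus 'background'),
# while detectors usually predict 80 contiguous classes. coco_indexs lists
# the 80 ids that actually occur, so a raw category_id maps to a contiguous
# label via its position in the list. A quick sanity check (illustrative,
# not part of the original file):
#   assert coco_indexs.index(1) == 0 and coco_labels_80[0] == 'person'
#   assert coco_indexs.index(90) == 79 and coco_labels_80[79] == 'toothbrush'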


class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self.coco_labels = coco_labels_80  # the 80 class names used for the detection task
        self.coco_indexs = coco_indexs     # the original (sparse) COCO category ids
        self._transforms = transforms

    def prepare(self, image, target):
        w, h = image.size
        # image id for evaluation
        image_id = target["image_id"]
        image_id = torch.tensor([image_id])
        # annotations, with crowd regions discarded
        anno = target["annotations"]
        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
        # bbox target: convert [x, y, w, h] to [x1, y1, x2, y2] and clip to the image
        boxes = [obj["bbox"] for obj in anno]
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)
        # class target: map the sparse category id to a contiguous index in [0, 79]
        classes = [self.coco_indexs.index(obj["category_id"]) for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)
        # filter out degenerate boxes (non-positive width or height)
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["image_id"] = image_id
        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]
        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])
        return image, target

    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target
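
# A quick sketch of one sample as assembled by prepare() above
# (N is the number of annotations kept after filtering):
#   img, target = dataset[idx]
#   target["boxes"]    FloatTensor [N, 4], xyxy in absolute pixels
#   target["labels"]   LongTensor [N], contiguous class ids in [0, 79]
#   target["image_id"] LongTensor [1]
# plus "area", "iscrowd", "orig_size" and "size" for the COCO evaluation API.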


def build_coco(args, transform=None, is_train=False):
    root = Path(args.root)
    assert root.exists(), f'provided COCO path {root} does not exist'
    PATHS = {
        "train": (root / "train2017", root / "annotations" / 'instances_train2017.json'),
        "val": (root / "val2017", root / "annotations" / 'instances_val2017.json'),
    }
    image_set = "train" if is_train else "val"
    img_folder, ann_file = PATHS[image_set]
    # build the dataset
    dataset = CocoDetection(img_folder, ann_file, transform)
    return dataset
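
# Usage sketch (paths are illustrative): args.root is expected to hold the
# standard COCO 2017 layout referenced by PATHS above, i.e. train2017/,
# val2017/ and annotations/instances_{train,val}2017.json, e.g.:
#   dataset = build_coco(args, transform=None, is_train=False)  # val split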


if __name__ == "__main__":
    import argparse
    import cv2
    import numpy as np

    parser = argparse.ArgumentParser(description='COCO-Dataset')
    # opt
    parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
                        help='data root')
    parser.add_argument('--is_train', action="store_true", default=False,
                        help='load the train split instead of the val split.')
    args = parser.parse_args()

    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(80)]

    # config
    class BaseConfig(object):
        def __init__(self):
            # --------- Data process ---------
            ## input size
            self.train_min_size = [512]  # short edge of the image
            self.train_max_size = 736
            self.test_min_size = [512]
            self.test_max_size = 736
            ## pixel mean & std
            self.pixel_mean = [0.485, 0.456, 0.406]
            self.pixel_std = [0.229, 0.224, 0.225]
            ## transforms
            self.box_format = 'xyxy'
            self.normalize_coords = False
            self.detr_style = False
            self.trans_config = [
                {'name': 'RandomHFlip'},
                {'name': 'RandomResize'},
                {'name': 'RandomShift', 'max_shift': 32},
            ]

    cfg = BaseConfig()

    # build the transform and the dataset; train-time transforms are built
    # here so the augmentations can be inspected visually
    transform = build_transform(cfg, is_train=True)
    dataset = build_coco(args, transform, is_train=args.is_train)

    for index, (image, target) in enumerate(dataset):
        print("{} / {}".format(index, len(dataset)))
        # to numpy, CHW -> HWC
        image = image.permute(1, 2, 0).numpy()
        # denormalize and convert RGB to BGR for OpenCV
        image = (image * cfg.pixel_std + cfg.pixel_mean) * 255
        image = image.astype(np.uint8)[..., (2, 1, 0)].copy()
        orig_h, orig_w = image.shape[:2]

        tgt_bboxes = target["boxes"]
        tgt_labels = target["labels"]
        for box, label in zip(tgt_bboxes, tgt_labels):
            if cfg.normalize_coords:
                box[..., [0, 2]] *= orig_w
                box[..., [1, 3]] *= orig_h
            if cfg.box_format == 'xywh':
                box_x1y1 = box[..., :2] - box[..., 2:] * 0.5
                box_x2y2 = box[..., :2] + box[..., 2:] * 0.5
                box = torch.cat([box_x1y1, box_x2y2], dim=-1)
            # get box target
            x1, y1, x2, y2 = box.long()
            # get class label
            cls_name = coco_labels_80[label.item()]
            color = class_colors[label.item()]
            # draw bbox
            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            # put the class name on the bbox
            cv2.putText(image, cls_name, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
        cv2.imshow("data", image)
        cv2.waitKey(0)
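
    # close the display window after the loop (a small addition, not part of
    # the original script)
    cv2.destroyAllWindows()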