coco.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
COCO dataset which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
from pathlib import Path
import torch
import torch.utils.data
import torchvision
try:
    # package-relative import when used as a module of the project
    from .transforms import build_transform
except ImportError:
    # fallback so this file can also be run directly as a script
    from transforms import build_transform

# Full original label list (including classes without annotations), kept for reference:
# coco_labels = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
coco_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
coco_indexs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
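# The original COCO category ids are non-contiguous (a 91-id space with gaps);
# `coco_indexs.index(category_id)` remaps them to contiguous [0, 79] indices,
# e.g. category_id 90 -> index 79 -> coco_labels[79] == 'toothbrush'.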

class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms):
        super(CocoDetection, self).__init__(img_folder, ann_file)
        self.coco_labels = coco_labels # 80 coco labels for the detection task
        self.coco_indexs = coco_indexs # original (non-contiguous) coco category ids
        self._transforms = transforms

    def prepare(self, image, target):
        w, h = image.size
        # image id
        image_id = target["image_id"]
        image_id = torch.tensor([image_id])
        # load the annotations and drop crowd regions
        anno = target["annotations"]
        anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
        # bbox target: COCO stores [x, y, w, h]; convert to [x1, y1, x2, y2] and clip to the image
        boxes = [obj["bbox"] for obj in anno]
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)
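        # e.g. a COCO bbox [10, 20, 30, 40] becomes [10, 20, 40, 60] in xyxy form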
        # class target: remap original category ids to contiguous [0, 79]
        classes = [self.coco_indexs.index(obj["category_id"]) for obj in anno]
        classes = torch.tensor(classes, dtype=torch.int64)
        # filter degenerate boxes (zero width or height after clipping)
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]
        classes = classes[keep]
        target = {}
        target["boxes"] = boxes
        target["labels"] = classes
        target["image_id"] = image_id
        # for conversion to coco api
        area = torch.tensor([obj["area"] for obj in anno])
        iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
        target["area"] = area[keep]
        target["iscrowd"] = iscrowd[keep]
        target["orig_size"] = torch.as_tensor([int(h), int(w)])
        target["size"] = torch.as_tensor([int(h), int(w)])
        return image, target
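    # The target dict built above carries, per image: boxes (N, 4) float32 xyxy
    # pixel coords, labels (N,) int64 in [0, 79], plus image_id, area, iscrowd,
    # and orig_size / size as [h, w].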

    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target

def build_coco(args, transform=None, is_train=False):
    root = Path(args.root)
    assert root.exists(), f'provided COCO path {root} does not exist'
    PATHS = {
        "train": (root / "train2017", root / "annotations" / 'instances_train2017.json'),
        "val": (root / "val2017", root / "annotations" / 'instances_val2017.json'),
    }
    image_set = "train" if is_train else "val"
    img_folder, ann_file = PATHS[image_set]
    # build dataset
    dataset = CocoDetection(img_folder, ann_file, transform)
    return dataset
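
# A minimal usage sketch (not part of the original file; `collate_fn` here is a
# common choice, not this project's): images in a detection batch differ in
# shape, so the default collate cannot stack them into a single tensor.
#
#   from torch.utils.data import DataLoader
#
#   def collate_fn(batch):
#       return tuple(zip(*batch))  # (images, targets) as tuples
#
#   loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)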

if __name__ == "__main__":
    import argparse
    import cv2
    import numpy as np

    parser = argparse.ArgumentParser(description='COCO-Dataset')
    # opt
    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
                        help='data root')
    parser.add_argument('--is_train', action="store_true", default=False,
                        help='load the train2017 split instead of val2017.')
    args = parser.parse_args()

    # fixed seed so every class gets a stable random color
    np.random.seed(0)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255)) for _ in range(80)]

    # config
    cfg = {
        # input size
        'train_min_size': [800],
        'train_max_size': 1333,
        'test_min_size': 800,
        'test_max_size': 1333,
        # ImageNet normalization statistics
        'pixel_mean': [0.485, 0.456, 0.406],
        'pixel_std': [0.229, 0.224, 0.225],
        # trans config
        'detr_style': False,
        'trans_config': [
            {'name': 'RandomResize', 'random_sizes': [400, 500, 600, 700, 800], 'max_size': 1333},
            {'name': 'RandomHFlip'},
            {'name': 'RandomShift', 'max_shift': 100}
        ],
        'box_format': 'xywh',
        'normalize_coords': False,
    }

    # build the transform and the dataset; training transforms are built
    # unconditionally here so the augmentations can be visualized, while
    # --is_train only selects which split to load
    transform = build_transform(cfg, is_train=True)
    dataset = build_coco(args, transform, is_train=args.is_train)

    for index, (image, target) in enumerate(dataset):
        print("{} / {}".format(index, len(dataset)))
        # to numpy: (C, H, W) -> (H, W, C)
        image = image.permute(1, 2, 0).numpy()
        # denormalize with the pixel mean/std, then RGB -> BGR for OpenCV
        image = (image * cfg['pixel_std'] + cfg['pixel_mean']) * 255
        image = image.astype(np.uint8)[..., (2, 1, 0)].copy()
        orig_h, orig_w = image.shape[:2]

        tgt_bboxes = target["boxes"]
        tgt_labels = target["labels"]
        for box, label in zip(tgt_bboxes, tgt_labels):
            # rescale normalized coords back to pixels if needed
            if cfg['normalize_coords']:
                box[..., [0, 2]] *= orig_w
                box[..., [1, 3]] *= orig_h
            # convert [cx, cy, w, h] back to [x1, y1, x2, y2]
            if cfg['box_format'] == 'xywh':
                box_x1y1 = box[..., :2] - box[..., 2:] * 0.5
                box_x2y2 = box[..., :2] + box[..., 2:] * 0.5
                box = torch.cat([box_x1y1, box_x2y2], dim=-1)
            # get box target
            x1, y1, x2, y2 = box.long()
            # get class label
            cls_name = coco_labels[label.item()]
            color = class_colors[label.item()]
            # draw bbox
            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
            # put the class name on the bbox
            cv2.putText(image, cls_name, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
        cv2.imshow("data", image)
        cv2.waitKey(0)  # press any key to advance to the next sample