make_dataset.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. import os
  2. import cv2
  3. import torch
  4. import random
  5. import numpy as np
  6. from voc import VOCDataset, VOC_CLASSES
  7. from coco import COCODataset, coco_class_labels, coco_class_index
  8. dataset_cfg = {
  9. 'voc': {
  10. 'data_name': 'VOCdevkit',
  11. 'num_classes': 20,
  12. 'class_indexs': None,
  13. 'class_names': ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'),
  14. },
  15. 'coco':{
  16. 'data_name': 'COCO',
  17. 'num_classes': 80,
  18. 'class_indexs': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90],
  19. 'class_names': ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
  20. },
  21. }
  22. # ------------------------------ Dataset ------------------------------
  23. def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
  24. # ------------------------- Basic parameters -------------------------
  25. data_dir = os.path.join(args.root, data_cfg['data_name'])
  26. num_classes = data_cfg['num_classes']
  27. class_names = data_cfg['class_names']
  28. class_indexs = data_cfg['class_indexs']
  29. dataset_info = {
  30. 'num_classes': num_classes,
  31. 'class_names': class_names,
  32. 'class_indexs': class_indexs
  33. }
  34. # ------------------------- Build dataset -------------------------
  35. ## VOC dataset
  36. if args.dataset == 'voc':
  37. dataset = VOCDataset(
  38. img_size=args.img_size,
  39. data_dir=data_dir,
  40. image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
  41. transform=transform,
  42. trans_config=trans_config,
  43. load_cache=args.load_cache
  44. )
  45. ## COCO dataset
  46. elif args.dataset == 'coco':
  47. dataset = COCODataset(
  48. img_size=args.img_size,
  49. data_dir=data_dir,
  50. image_set='train2017' if is_train else 'val2017',
  51. transform=transform,
  52. trans_config=trans_config,
  53. load_cache=args.load_cache
  54. )
  55. return dataset, dataset_info
  56. def visualize(image, target, dataset_name="voc"):
  57. if dataset_name == "voc":
  58. class_labels = VOC_CLASSES
  59. class_indexs = None
  60. num_classes = 20
  61. elif dataset_name == "coco":
  62. class_labels = coco_class_labels
  63. class_indexs = coco_class_index
  64. num_classes = 80
  65. else:
  66. raise NotImplementedError
  67. class_colors = [(np.random.randint(255),
  68. np.random.randint(255),
  69. np.random.randint(255))
  70. for _ in range(num_classes)]
  71. # to numpy
  72. # image = image.permute(1, 2, 0).numpy()
  73. image = image.astype(np.uint8)
  74. image = image.copy()
  75. boxes = target["boxes"]
  76. labels = target["labels"]
  77. for box, label in zip(boxes, labels):
  78. x1, y1, x2, y2 = box
  79. if x2 - x1 > 1 and y2 - y1 > 1:
  80. cls_id = int(label)
  81. color = class_colors[cls_id]
  82. # class name
  83. if dataset_name == 'coco':
  84. assert class_indexs is not None
  85. class_name = class_labels[class_indexs[cls_id]]
  86. else:
  87. class_name = class_labels[cls_id]
  88. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2)
  89. # put the test on the bbox
  90. cv2.putText(image, class_name, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  91. cv2.imshow('gt', image)
  92. cv2.waitKey(0)
  93. def build_dataset_config(args):
  94. if args.dataset in ['coco', 'coco-val', 'coco-test']:
  95. cfg = dataset_cfg['coco']
  96. else:
  97. cfg = dataset_cfg[args.dataset]
  98. print('==============================')
  99. print('Dataset Config: {} \n'.format(cfg))
  100. return cfg
  101. if __name__ == "__main__":
  102. import argparse
  103. parser = argparse.ArgumentParser(description='VOC-Dataset')
  104. # Seed
  105. parser.add_argument('--seed', default=42, type=int)
  106. # Dataset
  107. parser.add_argument('--root', default='data/datasets/',
  108. help='data root')
  109. parser.add_argument('--dataset', type=str, default="voc",
  110. help='augmentation type.')
  111. parser.add_argument('--load_cache', action="store_true", default=False,
  112. help='load cached data.')
  113. parser.add_argument('--vis_tgt', action="store_true", default=False,
  114. help='load cached data.')
  115. parser.add_argument('--is_train', action="store_true", default=False,
  116. help='mixup augmentation.')
  117. # Image size
  118. parser.add_argument('-size', '--img_size', default=640, type=int,
  119. help='input image size.')
  120. # Augmentation
  121. parser.add_argument('--aug_type', type=str, default="yolov5_nano",
  122. help='augmentation type.')
  123. parser.add_argument('--mosaic', default=None, type=float,
  124. help='mosaic augmentation.')
  125. parser.add_argument('--mixup', default=None, type=float,
  126. help='mixup augmentation.')
  127. # Output
  128. parser.add_argument('--output_dir', type=str, default='cache_data/',
  129. help='data root')
  130. args = parser.parse_args()
  131. # ------------- Build transform config -------------
  132. dataset_cfg = build_dataset_config(args)
  133. # ------------- Build dataset -------------
  134. dataset, dataset_info = build_dataset(args, dataset_cfg, trans_config=None, transform=None, is_train=args.is_train)
  135. print('Data length: ', len(dataset))
  136. # ---------------------------- Main process ----------------------------
  137. # We only cache the taining data
  138. data_items = []
  139. for idx in range(len(dataset)):
  140. if idx % 2000 == 0:
  141. print("Caching images and targets : {} / {} ...".format(idx, len(dataset)))
  142. # load a data
  143. image, target = dataset.load_image_target(idx)
  144. orig_h, orig_w, _ = image.shape
  145. # resize image
  146. r = args.img_size / max(orig_h, orig_w)
  147. if r != 1:
  148. interp = cv2.INTER_LINEAR
  149. new_size = (int(orig_w * r), int(orig_h * r))
  150. image = cv2.resize(image, new_size, interpolation=interp)
  151. img_h, img_w = image.shape[:2]
  152. # rescale bbox
  153. boxes = target["boxes"].copy()
  154. boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
  155. boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
  156. target["boxes"] = boxes
  157. # visualize data
  158. if args.vis_tgt:
  159. print(image.shape)
  160. visualize(image, target, args.dataset)
  161. continue
  162. dict_item = {}
  163. dict_item["image"] = image
  164. dict_item["target"] = target
  165. data_items.append(dict_item)
  166. output_dir = os.path.join(args.output_dir, args.dataset)
  167. os.makedirs(output_dir, exist_ok=True)
  168. print('Cached data size: ', len(data_items))
  169. if args.is_train:
  170. save_file = os.path.join(output_dir, "{}_train.pth".format(args.dataset))
  171. else:
  172. save_file = os.path.join(output_dir, "{}_valid.pth".format(args.dataset))
  173. torch.save(data_items, save_file)