# make_dataset.py

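"""Cache preprocessed detection samples to disk.

The script iterates over a VOC or COCO split, resizes each image so that its
longer side matches --img_size, rescales the ground-truth boxes accordingly,
and saves the resulting (image, target) pairs to a single .pth file that can
later be read back with torch.load. Illustrative invocation (paths are
examples, not fixed):

    python make_dataset.py --root data/datasets/ --dataset voc --img_size 640 --is_train
"""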
import os
import cv2
import torch
import random
import numpy as np
import sys
sys.path.append("../")

from utils import distributed_utils
from dataset.voc import VOCDataset, VOC_CLASSES
from dataset.coco import COCODataset, coco_class_labels, coco_class_index
from config import build_trans_config, build_dataset_config

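# Seed each process with a rank-dependent offset so that DDP workers
# do not all draw the same random numbers.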
def fix_random_seed(args):
    seed = args.seed + distributed_utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


# ------------------------------ Dataset ------------------------------
def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
    # ------------------------- Basic parameters -------------------------
    data_dir = os.path.join(args.root, data_cfg['data_name'])
    num_classes = data_cfg['num_classes']
    class_names = data_cfg['class_names']
    class_indexs = data_cfg['class_indexs']
    dataset_info = {
        'num_classes': num_classes,
        'class_names': class_names,
        'class_indexs': class_indexs
    }

    # ------------------------- Build dataset -------------------------
    ## VOC dataset
    if args.dataset == 'voc':
        dataset = VOCDataset(
            img_size=args.img_size,
            data_dir=data_dir,
            image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
            transform=transform,
            trans_config=trans_config,
            load_cache=args.load_cache
            )
    ## COCO dataset
    elif args.dataset == 'coco':
        dataset = COCODataset(
            img_size=args.img_size,
            data_dir=data_dir,
            image_set='train2017' if is_train else 'val2017',
            transform=transform,
            trans_config=trans_config,
            load_cache=args.load_cache
            )
    else:
        raise NotImplementedError("Unsupported dataset: {}".format(args.dataset))

    return dataset, dataset_info

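
# ------------------------------ Visualization ------------------------------
# Draw the ground-truth boxes and class names on an image with OpenCV and
# display the result in a window (press any key to continue).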
def visualize(image, target, dataset_name="voc"):
    if dataset_name == "voc":
        class_labels = VOC_CLASSES
        class_indexs = None
        num_classes = 20
    elif dataset_name == "coco":
        class_labels = coco_class_labels
        class_indexs = coco_class_index
        num_classes = 80
    else:
        raise NotImplementedError

    # one random color per class (used for the text labels)
    class_colors = [(np.random.randint(255),
                     np.random.randint(255),
                     np.random.randint(255))
                    for _ in range(num_classes)]

    # to numpy
    # image = image.permute(1, 2, 0).numpy()
    image = image.astype(np.uint8)
    image = image.copy()

    boxes = target["boxes"]
    labels = target["labels"]
    for box, label in zip(boxes, labels):
        x1, y1, x2, y2 = box
        if x2 - x1 > 1 and y2 - y1 > 1:
            cls_id = int(label)
            color = class_colors[cls_id]
            # class name
            if dataset_name == 'coco':
                assert class_indexs is not None
                class_name = class_labels[class_indexs[cls_id]]
            else:
                class_name = class_labels[cls_id]
            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 0, 255), 2)
            # put the text on the bbox
            cv2.putText(image, class_name, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
    cv2.imshow('gt', image)
    cv2.waitKey(0)

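
# ------------------------------ Main: build and cache the dataset ------------------------------
# Loads every sample of the chosen split, resizes the longer image side to
# --img_size, rescales the boxes to match, and saves everything to a single
# {dataset}_{train|valid}.pth file under --output_dir/<dataset>/.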
if __name__ == "__main__":
    import argparse
    from build import build_transform

    parser = argparse.ArgumentParser(description='VOC-Dataset')
    # Seed
    parser.add_argument('--seed', default=42, type=int)
    # Dataset
    parser.add_argument('--root', default='data/datasets/',
                        help='data root.')
    parser.add_argument('--dataset', type=str, default="voc",
                        help='dataset name: voc or coco.')
    parser.add_argument('--load_cache', action="store_true", default=False,
                        help='load cached data.')
    parser.add_argument('--vis_tgt', action="store_true", default=False,
                        help='visualize the targets instead of caching them.')
    parser.add_argument('--is_train', action="store_true", default=False,
                        help='cache the training split (otherwise the validation/test split).')
    # Image size
    parser.add_argument('-size', '--img_size', default=640, type=int,
                        help='input image size.')
    # Augmentation
    parser.add_argument('--aug_type', type=str, default="yolov5_nano",
                        help='augmentation type.')
    parser.add_argument('--mosaic', default=None, type=float,
                        help='mosaic augmentation.')
    parser.add_argument('--mixup', default=None, type=float,
                        help='mixup augmentation.')
    # DDP train
    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
                        help='distributed training.')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training.')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes.')
    parser.add_argument('--sybn', action='store_true', default=False,
                        help='use sync batch norm.')
    # Output
    parser.add_argument('--output_dir', type=str, default='cache_data/',
                        help='directory where the cached .pth files are written.')
    args = parser.parse_args()

    assert args.aug_type in ["yolov5_pico", "yolov5_nano", "yolov5_small", "yolov5_medium", "yolov5_large", "yolov5_huge",
                             "yolox_pico", "yolox_nano", "yolox_small", "yolox_medium", "yolox_large", "yolox_huge"]

    # ------------- Build transform config -------------
    dataset_cfg = build_dataset_config(args)
    trans_config = build_trans_config(args.aug_type)

    # ------------- Build transform -------------
    transform, trans_config = build_transform(args, trans_config, max_stride=32, is_train=args.is_train)

    # ------------- Build dataset -------------
    dataset, dataset_info = build_dataset(args, dataset_cfg, trans_config, transform, is_train=args.is_train)
    print('Data length: ', len(dataset))

    # ---------------------------- Fix random seed ----------------------------
    fix_random_seed(args)

    # ---------------------------- Main process ----------------------------
    # We only cache the training data
    data_items = []
    for idx in range(len(dataset)):
        if idx % 2000 == 0:
            print("Caching images and targets : {} / {} ...".format(idx, len(dataset)))

        # load a sample
        image, target = dataset.load_image_target(idx)
        orig_h, orig_w, _ = image.shape

        # resize the image so that its longer side equals img_size
        r = args.img_size / max(orig_h, orig_w)
        if r != 1:
            interp = cv2.INTER_LINEAR
            new_size = (int(orig_w * r), int(orig_h * r))
            image = cv2.resize(image, new_size, interpolation=interp)
        img_h, img_w = image.shape[:2]

        # rescale the bboxes to the resized image
        boxes = target["boxes"].copy()
        boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
        boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
        target["boxes"] = boxes

        # visualize the targets instead of caching them
        if args.vis_tgt:
            print(image.shape)
            visualize(image, target, args.dataset)
            continue

        dict_item = {}
        dict_item["image"] = image
        dict_item["target"] = target
        data_items.append(dict_item)

    output_dir = os.path.join(args.output_dir, args.dataset)
    os.makedirs(output_dir, exist_ok=True)

    print('Cached data size: ', len(data_items))
    if args.is_train:
        save_file = os.path.join(output_dir, "{}_train.pth".format(args.dataset))
    else:
        save_file = os.path.join(output_dir, "{}_valid.pth".format(args.dataset))
    torch.save(data_items, save_file)
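
    # Note: the saved file can be read back with torch.load(save_file); each entry
    # is a dict with "image" (the resized image array) and "target" (holding the
    # rescaled "boxes" and the "labels"). Presumably this cache is what the
    # dataset classes consume when constructed with load_cache=True, though that
    # depends on their implementation.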