# customed.py (10 KB)
  1. import os
  2. import cv2
  3. import time
  4. import random
  5. import numpy as np
  6. from torch.utils.data import Dataset
  7. try:
  8. from pycocotools.coco import COCO
  9. except:
  10. print("It seems that the COCOAPI is not installed.")
  11. try:
  12. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  13. except:
  14. from data_augment.strong_augment import MosaicAugment, MixupAugment
  15. class CustomedDataset(Dataset):
  16. def __init__(self,
  17. img_size :int = 640,
  18. data_dir :str = None,
  19. image_set :str = 'train',
  20. transform = None,
  21. trans_config = None,
  22. is_train :bool =False,
  23. ):
  24. # ----------- Basic parameters -----------
  25. self.img_size = img_size
  26. self.image_set = image_set
  27. self.is_train = is_train
  28. # ----------- Path parameters -----------
  29. self.data_dir = data_dir
  30. self.json_file = '{}.json'.format(image_set)
  31. # ----------- Data parameters -----------
  32. self.coco = COCO(os.path.join(self.data_dir, image_set, 'annotations', self.json_file))
  33. self.ids = self.coco.getImgIds()
  34. self.class_ids = sorted(self.coco.getCatIds())
  35. self.dataset_size = len(self.ids)
  36. # ----------- Transform parameters -----------
  37. self.trans_config = trans_config
  38. self.transform = transform
  39. # ----------- Strong augmentation -----------
  40. if is_train:
  41. self.mosaic_prob = trans_config['mosaic_prob'] if trans_config else 0.0
  42. self.mixup_prob = trans_config['mixup_prob'] if trans_config else 0.0
  43. self.mosaic_augment = MosaicAugment(img_size, trans_config, is_train) if self.mosaic_prob > 0. else None
  44. self.mixup_augment = MixupAugment(img_size, trans_config) if self.mixup_prob > 0. else None
  45. else:
  46. self.mosaic_prob = 0.0
  47. self.mixup_prob = 0.0
  48. self.mosaic_augment = None
  49. self.mixup_augment = None
  50. print('==============================')
  51. print('Image Set: {}'.format(image_set))
  52. print('Json file: {}'.format(self.json_file))
  53. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  54. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  55. # ------------ Basic dataset function ------------
  56. def __len__(self):
  57. return len(self.ids)
  58. def __getitem__(self, index):
  59. return self.pull_item(index)
  60. # ------------ Mosaic & Mixup ------------
  61. def load_mosaic(self, index):
  62. # ------------ Prepare 4 indexes of images ------------
  63. ## Load 4x mosaic image
  64. index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
  65. id1 = index
  66. id2, id3, id4 = random.sample(index_list, 3)
  67. indexs = [id1, id2, id3, id4]
  68. ## Load images and targets
  69. image_list = []
  70. target_list = []
  71. for index in indexs:
  72. img_i, target_i = self.load_image_target(index)
  73. image_list.append(img_i)
  74. target_list.append(target_i)
  75. # ------------ Mosaic augmentation ------------
  76. image, target = self.mosaic_augment(image_list, target_list)
  77. return image, target
  78. def load_mixup(self, origin_image, origin_target):
  79. # ------------ Load a new image & target ------------
  80. if self.mixup_augment.mixup_type == 'yolov5':
  81. new_index = np.random.randint(0, len(self.ids))
  82. new_image, new_target = self.load_mosaic(new_index)
  83. elif self.mixup_augment.mixup_type == 'yolox':
  84. new_index = np.random.randint(0, len(self.ids))
  85. new_image, new_target = self.load_image_target(new_index)
  86. # ------------ Mixup augmentation ------------
  87. image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target)
  88. return image, target
  89. # ------------ Load data function ------------
  90. def load_image_target(self, index):
  91. # load an image
  92. image, _ = self.pull_image(index)
  93. height, width, channels = image.shape
  94. # load a target
  95. bboxes, labels = self.pull_anno(index)
  96. target = {
  97. "boxes": bboxes,
  98. "labels": labels,
  99. "orig_size": [height, width]
  100. }
  101. return image, target
  102. def pull_item(self, index):
  103. if random.random() < self.mosaic_prob:
  104. # load a mosaic image
  105. mosaic = True
  106. image, target = self.load_mosaic(index)
  107. else:
  108. mosaic = False
  109. # load an image and target
  110. image, target = self.load_image_target(index)
  111. # MixUp
  112. if random.random() < self.mixup_prob:
  113. image, target = self.load_mixup(image, target)
  114. # augment
  115. image, target, deltas = self.transform(image, target, mosaic)
  116. return image, target, deltas
  117. def pull_image(self, index):
  118. id_ = self.ids[index]
  119. im_ann = self.coco.loadImgs(id_)[0]
  120. img_file = os.path.join(
  121. self.data_dir, self.image_set, 'images', im_ann["file_name"])
  122. image = cv2.imread(img_file)
  123. return image, id_
  124. def pull_anno(self, index):
  125. img_id = self.ids[index]
  126. im_ann = self.coco.loadImgs(img_id)[0]
  127. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=0)
  128. annotations = self.coco.loadAnns(anno_ids)
  129. # image infor
  130. width = im_ann['width']
  131. height = im_ann['height']
  132. #load a target
  133. bboxes = []
  134. labels = []
  135. for anno in annotations:
  136. if 'bbox' in anno and anno['area'] > 0:
  137. # bbox
  138. x1 = np.max((0, anno['bbox'][0]))
  139. y1 = np.max((0, anno['bbox'][1]))
  140. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  141. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  142. if x2 <= x1 or y2 <= y1:
  143. continue
  144. # class label
  145. cls_id = self.class_ids.index(anno['category_id'])
  146. bboxes.append([x1, y1, x2, y2])
  147. labels.append(cls_id)
  148. # guard against no boxes via resizing
  149. bboxes = np.array(bboxes).reshape(-1, 4)
  150. labels = np.array(labels).reshape(-1)
  151. return bboxes, labels
  152. if __name__ == "__main__":
  153. import time
  154. import argparse
  155. from build import build_transform
  156. import sys
  157. sys.path.append("..")
  158. from config.data_config.dataset_config import dataset_cfg
  159. data_config = dataset_cfg["customed"]
  160. categories = data_config["class_names"]
  161. parser = argparse.ArgumentParser(description='RT-ODLab')
  162. # opt
  163. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/AnimalDataset/',
  164. help='data root')
  165. parser.add_argument('--split', default='train',
  166. help='data split')
  167. parser.add_argument('-size', '--img_size', default=640, type=int,
  168. help='input image size')
  169. parser.add_argument('--min_box_size', default=8.0, type=float,
  170. help='min size of target bounding box.')
  171. parser.add_argument('--mosaic', default=None, type=float,
  172. help='mosaic augmentation.')
  173. parser.add_argument('--mixup', default=None, type=float,
  174. help='mixup augmentation.')
  175. parser.add_argument('--is_train', action="store_true", default=False,
  176. help='mixup augmentation.')
  177. args = parser.parse_args()
  178. trans_config = {
  179. 'aug_type': args.aug_type, # optional: ssd, yolov5
  180. 'pixel_mean': [0., 0., 0.],
  181. 'pixel_std': [255., 255., 255.],
  182. # Basic Augment
  183. 'degrees': 0.0,
  184. 'translate': 0.2,
  185. 'scale': [0.1, 2.0],
  186. 'shear': 0.0,
  187. 'perspective': 0.0,
  188. 'hsv_h': 0.015,
  189. 'hsv_s': 0.7,
  190. 'hsv_v': 0.4,
  191. 'use_ablu': True,
  192. # Mosaic & Mixup
  193. 'mosaic_prob': args.mosaic,
  194. 'mixup_prob': args.mixup,
  195. 'mosaic_type': 'yolov5',
  196. 'mixup_type': 'yolov5',
  197. 'mixup_scale': [0.5, 1.5]
  198. }
  199. transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
  200. pixel_mean = transform.pixel_mean
  201. pixel_std = transform.pixel_std
  202. color_format = transform.color_format
  203. dataset = CustomedDataset(
  204. img_size=args.img_size,
  205. data_dir=args.root,
  206. image_set=args.split,
  207. transform=transform,
  208. trans_config=trans_config,
  209. is_train=args.is_train,
  210. load_cache=args.load_cache
  211. )
  212. np.random.seed(0)
  213. class_colors = [(np.random.randint(255),
  214. np.random.randint(255),
  215. np.random.randint(255)) for _ in range(80)]
  216. print('Data length: ', len(dataset))
  217. for i in range(1000):
  218. t0 = time.time()
  219. image, target, deltas = dataset.pull_item(i)
  220. print("Load data: {} s".format(time.time() - t0))
  221. # to numpy
  222. image = image.permute(1, 2, 0).numpy()
  223. # denormalize
  224. image = image * pixel_std + pixel_mean
  225. if color_format == 'rgb':
  226. # RGB to BGR
  227. image = image[..., (2, 1, 0)]
  228. # to uint8
  229. image = image.astype(np.uint8)
  230. image = image.copy()
  231. img_h, img_w = image.shape[:2]
  232. boxes = target["boxes"]
  233. labels = target["labels"]
  234. for box, label in zip(boxes, labels):
  235. x1, y1, x2, y2 = box
  236. cls_id = int(label)
  237. color = class_colors[cls_id]
  238. # class name
  239. label = categories[cls_id]
  240. if x2 - x1 > 0. and y2 - y1 > 0.:
  241. # draw bbox
  242. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  243. # put the test on the bbox
  244. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  245. cv2.imshow('gt', image)
  246. # cv2.imwrite(str(i)+'.jpg', img)
  247. cv2.waitKey(0)