custom.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. import os
  2. import cv2
  3. import time
  4. import numpy as np
  5. from pycocotools.coco import COCO
  6. try:
  7. from .data_augment.strong_augment import MosaicAugment, MixupAugment
  8. from .coco import COCODataset
  9. except:
  10. from data_augment.strong_augment import MosaicAugment, MixupAugment
  11. from coco import COCODataset
  12. custom_class_indexs = [0, 1, 2, 3, 4, 5, 6, 7, 8]
  13. custom_class_labels = ('bird', 'butterfly', 'cat', 'cow', 'dog', 'lion', 'person', 'pig', 'tiger', )
  14. class CustomDataset(COCODataset):
  15. def __init__(self,
  16. cfg,
  17. data_dir :str = None,
  18. transform = None,
  19. is_train :bool =False,
  20. ):
  21. # ----------- Basic parameters -----------
  22. self.image_set = "train" if is_train else "val"
  23. self.is_train = is_train
  24. self.num_classes = len(custom_class_labels)
  25. # ----------- Path parameters -----------
  26. self.data_dir = data_dir
  27. self.json_file = '{}.json'.format(self.image_set)
  28. # ----------- Data parameters -----------
  29. self.coco = COCO(os.path.join(self.data_dir, self.image_set, 'annotations', self.json_file))
  30. self.ids = self.coco.getImgIds()
  31. self.class_ids = sorted(self.coco.getCatIds())
  32. self.dataset_size = len(self.ids)
  33. self.class_labels = custom_class_labels
  34. self.class_indexs = custom_class_indexs
  35. # ----------- Transform parameters -----------
  36. self.transform = transform
  37. if is_train:
  38. if cfg.mosaic_prob == 0.:
  39. self.mosaic_augment = None
  40. else:
  41. self.mosaic_augment = MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
  42. self.mosaic_prob = cfg.mosaic_prob
  43. if cfg.mixup_prob == 0.:
  44. self.mixup_augment = None
  45. else:
  46. self.mixup_augment = MixupAugment(cfg.train_img_size)
  47. self.mixup_prob = cfg.mixup_prob
  48. self.copy_paste = cfg.copy_paste
  49. else:
  50. self.mosaic_prob = 0.0
  51. self.mixup_prob = 0.0
  52. self.copy_paste = 0.0
  53. self.mosaic_augment = None
  54. self.mixup_augment = None
  55. print(' ============ Strong augmentation info. ============ ')
  56. print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
  57. print('use Mixup Augmentation: {}'.format(self.mixup_prob))
  58. print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
  59. def pull_image(self, index):
  60. id_ = self.ids[index]
  61. im_ann = self.coco.loadImgs(id_)[0]
  62. img_file = os.path.join(
  63. self.data_dir, self.image_set, 'images', im_ann["file_name"])
  64. image = cv2.imread(img_file)
  65. return image, id_
  66. def pull_anno(self, index):
  67. img_id = self.ids[index]
  68. im_ann = self.coco.loadImgs(img_id)[0]
  69. anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=0)
  70. annotations = self.coco.loadAnns(anno_ids)
  71. # image infor
  72. width = im_ann['width']
  73. height = im_ann['height']
  74. #load a target
  75. bboxes = []
  76. labels = []
  77. for anno in annotations:
  78. if 'bbox' in anno and anno['area'] > 0:
  79. # bbox
  80. x1 = np.max((0, anno['bbox'][0]))
  81. y1 = np.max((0, anno['bbox'][1]))
  82. x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
  83. y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
  84. if x2 <= x1 or y2 <= y1:
  85. continue
  86. # class label
  87. cls_id = self.class_ids.index(anno['category_id'])
  88. bboxes.append([x1, y1, x2, y2])
  89. labels.append(cls_id)
  90. # guard against no boxes via resizing
  91. bboxes = np.array(bboxes).reshape(-1, 4)
  92. labels = np.array(labels).reshape(-1)
  93. return bboxes, labels
  94. if __name__ == "__main__":
  95. import time
  96. import argparse
  97. from build import build_transform
  98. parser = argparse.ArgumentParser(description='RT-ODLab')
  99. # opt
  100. parser.add_argument('--root', default='D:/python_work/dataset/AnimalDataset/',
  101. help='data root')
  102. parser.add_argument('--is_train', action="store_true", default=False,
  103. help='mixup augmentation.')
  104. parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
  105. help='yolo, ssd.')
  106. args = parser.parse_args()
  107. class YoloBaseConfig(object):
  108. def __init__(self) -> None:
  109. self.max_stride = 32
  110. # ---------------- Data process config ----------------
  111. self.box_format = 'xywh'
  112. self.normalize_coords = False
  113. self.mosaic_prob = 1.0
  114. self.mixup_prob = 0.15
  115. self.copy_paste = 0.3
  116. ## Pixel mean & std
  117. self.pixel_mean = [0., 0., 0.]
  118. self.pixel_std = [255., 255., 255.]
  119. ## Transforms
  120. self.train_img_size = 640
  121. self.test_img_size = 640
  122. self.use_ablu = True
  123. self.aug_type = 'yolo'
  124. self.affine_params = {
  125. 'degrees': 0.0,
  126. 'translate': 0.2,
  127. 'scale': [0.1, 2.0],
  128. 'shear': 0.0,
  129. 'perspective': 0.0,
  130. 'hsv_h': 0.015,
  131. 'hsv_s': 0.7,
  132. 'hsv_v': 0.4,
  133. }
  134. class SSDBaseConfig(object):
  135. def __init__(self) -> None:
  136. self.max_stride = 32
  137. # ---------------- Data process config ----------------
  138. self.box_format = 'xywh'
  139. self.normalize_coords = False
  140. self.mosaic_prob = 0.0
  141. self.mixup_prob = 0.0
  142. self.copy_paste = 0.0
  143. ## Pixel mean & std
  144. self.pixel_mean = [0., 0., 0.]
  145. self.pixel_std = [255., 255., 255.]
  146. ## Transforms
  147. self.train_img_size = 640
  148. self.test_img_size = 640
  149. self.aug_type = 'ssd'
  150. if args.aug_type == "yolo":
  151. cfg = YoloBaseConfig()
  152. elif args.aug_type == "ssd":
  153. cfg = SSDBaseConfig()
  154. transform = build_transform(cfg, args.is_train)
  155. dataset = CustomDataset(cfg, args.root, transform, args.is_train)
  156. np.random.seed(0)
  157. class_colors = [(np.random.randint(255),
  158. np.random.randint(255),
  159. np.random.randint(255)) for _ in range(80)]
  160. print('Data length: ', len(dataset))
  161. for i in range(1000):
  162. t0 = time.time()
  163. image, target, deltas = dataset.pull_item(i)
  164. print("Load data: {} s".format(time.time() - t0))
  165. # to numpy
  166. image = image.permute(1, 2, 0).numpy()
  167. # denormalize
  168. image = image * cfg.pixel_std + cfg.pixel_mean
  169. # rgb -> bgr
  170. if transform.color_format == 'rgb':
  171. image = image[..., (2, 1, 0)]
  172. # to uint8
  173. image = image.astype(np.uint8)
  174. image = image.copy()
  175. img_h, img_w = image.shape[:2]
  176. boxes = target["boxes"]
  177. labels = target["labels"]
  178. for box, label in zip(boxes, labels):
  179. if cfg.box_format == 'xyxy':
  180. x1, y1, x2, y2 = box
  181. elif cfg.box_format == 'xywh':
  182. cx, cy, bw, bh = box
  183. x1 = cx - 0.5 * bw
  184. y1 = cy - 0.5 * bh
  185. x2 = cx + 0.5 * bw
  186. y2 = cy + 0.5 * bh
  187. if cfg.normalize_coords:
  188. x1 *= img_w
  189. y1 *= img_h
  190. x2 *= img_w
  191. y2 *= img_h
  192. cls_id = int(label)
  193. color = class_colors[cls_id]
  194. # class name
  195. label = custom_class_labels[cls_id]
  196. image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
  197. # put the test on the bbox
  198. cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
  199. cv2.imshow('gt', image)
  200. # cv2.imwrite(str(i)+'.jpg', img)
  201. cv2.waitKey(0)