strong_augment.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. import random
  2. import cv2
  3. import numpy as np
  4. from .yolov5_augment import random_perspective
  5. # ------------------------- Strong augmentations -------------------------
  6. ## Mosaic Augmentation
  7. class MosaicAugment(object):
  8. def __init__(self,
  9. img_size,
  10. transform_config,
  11. is_train=False,
  12. ) -> None:
  13. self.img_size = img_size
  14. self.is_train = is_train
  15. self.keep_ratio = transform_config['mosaic_keep_ratio']
  16. self.affine_params = transform_config['affine_params']
  17. self.mosaic_type = transform_config['mosaic_type']
  18. def yolov5_mosaic_augment(self, image_list, target_list):
  19. assert len(image_list) == 4
  20. mosaic_img = np.ones([self.img_size*2, self.img_size*2, image_list[0].shape[2]], dtype=np.uint8) * 114
  21. # mosaic center
  22. yc, xc = [int(random.uniform(-x, 2*self.img_size + x)) for x in [-self.img_size // 2, -self.img_size // 2]]
  23. # yc = xc = self.img_size
  24. mosaic_bboxes = []
  25. mosaic_labels = []
  26. for i in range(4):
  27. img_i, target_i = image_list[i], target_list[i]
  28. bboxes_i = target_i["boxes"]
  29. labels_i = target_i["labels"]
  30. orig_h, orig_w, _ = img_i.shape
  31. # resize
  32. if self.keep_ratio:
  33. r = self.img_size / max(orig_h, orig_w)
  34. if r != 1:
  35. interp = cv2.INTER_LINEAR if (self.is_train or r > 1) else cv2.INTER_AREA
  36. img_i = cv2.resize(img_i, (int(orig_w * r), int(orig_h * r)), interpolation=interp)
  37. else:
  38. interp = cv2.INTER_LINEAR if self.is_train else cv2.INTER_AREA
  39. img_i = cv2.resize(img_i, (self.img_size, self.img_size), interpolation=interp)
  40. h, w, _ = img_i.shape
  41. # place img in img4
  42. if i == 0: # top left
  43. x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc # xmin, ymin, xmax, ymax (large image)
  44. x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h # xmin, ymin, xmax, ymax (small image)
  45. elif i == 1: # top right
  46. x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, self.img_size * 2), yc
  47. x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
  48. elif i == 2: # bottom left
  49. x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(self.img_size * 2, yc + h)
  50. x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
  51. elif i == 3: # bottom right
  52. x1a, y1a, x2a, y2a = xc, yc, min(xc + w, self.img_size * 2), min(self.img_size * 2, yc + h)
  53. x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
  54. mosaic_img[y1a:y2a, x1a:x2a] = img_i[y1b:y2b, x1b:x2b]
  55. padw = x1a - x1b
  56. padh = y1a - y1b
  57. # labels
  58. bboxes_i_ = bboxes_i.copy()
  59. if len(bboxes_i) > 0:
  60. # a valid target, and modify it.
  61. bboxes_i_[:, 0] = (w * bboxes_i[:, 0] / orig_w + padw)
  62. bboxes_i_[:, 1] = (h * bboxes_i[:, 1] / orig_h + padh)
  63. bboxes_i_[:, 2] = (w * bboxes_i[:, 2] / orig_w + padw)
  64. bboxes_i_[:, 3] = (h * bboxes_i[:, 3] / orig_h + padh)
  65. mosaic_bboxes.append(bboxes_i_)
  66. mosaic_labels.append(labels_i)
  67. if len(mosaic_bboxes) == 0:
  68. mosaic_bboxes = np.array([]).reshape(-1, 4)
  69. mosaic_labels = np.array([]).reshape(-1)
  70. else:
  71. mosaic_bboxes = np.concatenate(mosaic_bboxes)
  72. mosaic_labels = np.concatenate(mosaic_labels)
  73. # clip
  74. mosaic_bboxes = mosaic_bboxes.clip(0, self.img_size * 2)
  75. # random perspective
  76. mosaic_targets = np.concatenate([mosaic_labels[..., None], mosaic_bboxes], axis=-1)
  77. mosaic_img, mosaic_targets = random_perspective(
  78. mosaic_img,
  79. mosaic_targets,
  80. self.affine_params['degrees'],
  81. translate=self.affine_params['translate'],
  82. scale=self.affine_params['scale'],
  83. shear=self.affine_params['shear'],
  84. perspective=self.affine_params['perspective'],
  85. border=[-self.img_size//2, -self.img_size//2]
  86. )
  87. # target
  88. mosaic_target = {
  89. "boxes": mosaic_targets[..., 1:],
  90. "labels": mosaic_targets[..., 0],
  91. "orig_size": [self.img_size, self.img_size]
  92. }
  93. return mosaic_img, mosaic_target
  94. def __call__(self, image_list, target_list):
  95. if self.mosaic_type == 'yolov5':
  96. return self.yolov5_mosaic_augment(image_list, target_list)
  97. else:
  98. raise NotImplementedError("Unknown mosaic type: {}".format(self.mosaic_type))
  99. ## Mixup Augmentation
  100. class MixupAugment(object):
  101. def __init__(self,
  102. img_size,
  103. transform_config,
  104. ) -> None:
  105. self.img_size = img_size
  106. self.mixup_type = transform_config['mixup_type']
  107. self.mixup_scale = transform_config['mixup_scale']
  108. def yolov5_mixup_augment(self, origin_image, origin_target, new_image, new_target):
  109. if origin_image.shape[:2] != new_image.shape[:2]:
  110. img_size = max(new_image.shape[:2])
  111. # origin_image is not a mosaic image
  112. orig_h, orig_w = origin_image.shape[:2]
  113. scale_ratio = img_size / max(orig_h, orig_w)
  114. if scale_ratio != 1:
  115. interp = cv2.INTER_LINEAR if scale_ratio > 1 else cv2.INTER_AREA
  116. resize_size = (int(orig_w * scale_ratio), int(orig_h * scale_ratio))
  117. origin_image = cv2.resize(origin_image, resize_size, interpolation=interp)
  118. # pad new image
  119. pad_origin_image = np.ones([img_size, img_size, origin_image.shape[2]], dtype=np.uint8) * 114
  120. pad_origin_image[:resize_size[1], :resize_size[0]] = origin_image
  121. origin_image = pad_origin_image.copy()
  122. del pad_origin_image
  123. r = np.random.beta(32.0, 32.0) # mixup ratio, alpha=beta=32.0
  124. mixup_image = r * origin_image.astype(np.float32) + \
  125. (1.0 - r)* new_image.astype(np.float32)
  126. mixup_image = mixup_image.astype(np.uint8)
  127. cls_labels = new_target["labels"].copy()
  128. box_labels = new_target["boxes"].copy()
  129. mixup_bboxes = np.concatenate([origin_target["boxes"], box_labels], axis=0)
  130. mixup_labels = np.concatenate([origin_target["labels"], cls_labels], axis=0)
  131. mixup_target = {
  132. "boxes": mixup_bboxes,
  133. "labels": mixup_labels,
  134. 'orig_size': mixup_image.shape[:2]
  135. }
  136. return mixup_image, mixup_target
  137. def yolox_mixup_augment(self, origin_image, origin_target, new_image, new_target):
  138. assert self.mixup_scale is not None, "You should set mixup_scale as a List type, such as [0.5, 1.5], not a NoneType."
  139. jit_factor = random.uniform(*self.mixup_scale)
  140. FLIP = random.uniform(0, 1) > 0.5
  141. # resize new image
  142. orig_h, orig_w = new_image.shape[:2]
  143. cp_scale_ratio = self.img_size / max(orig_h, orig_w)
  144. if cp_scale_ratio != 1:
  145. interp = cv2.INTER_LINEAR if cp_scale_ratio > 1 else cv2.INTER_AREA
  146. resized_new_img = cv2.resize(
  147. new_image, (int(orig_w * cp_scale_ratio), int(orig_h * cp_scale_ratio)), interpolation=interp)
  148. else:
  149. resized_new_img = new_image
  150. # pad new image
  151. cp_img = np.ones([self.img_size, self.img_size, new_image.shape[2]], dtype=np.uint8) * 114
  152. new_shape = (resized_new_img.shape[1], resized_new_img.shape[0])
  153. cp_img[:new_shape[1], :new_shape[0]] = resized_new_img
  154. # resize padded new image
  155. cp_img_h, cp_img_w = cp_img.shape[:2]
  156. cp_new_shape = (int(cp_img_w * jit_factor),
  157. int(cp_img_h * jit_factor))
  158. cp_img = cv2.resize(cp_img, (cp_new_shape[0], cp_new_shape[1]))
  159. cp_scale_ratio *= jit_factor
  160. # flip new image
  161. if FLIP:
  162. cp_img = cp_img[:, ::-1, :]
  163. # pad image
  164. origin_h, origin_w = cp_img.shape[:2]
  165. target_h, target_w = origin_image.shape[:2]
  166. padded_img = np.zeros(
  167. (max(origin_h, target_h), max(origin_w, target_w), 3), dtype=np.uint8
  168. )
  169. padded_img[:origin_h, :origin_w] = cp_img
  170. # crop padded image
  171. x_offset, y_offset = 0, 0
  172. if padded_img.shape[0] > target_h:
  173. y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
  174. if padded_img.shape[1] > target_w:
  175. x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
  176. padded_cropped_img = padded_img[
  177. y_offset: y_offset + target_h, x_offset: x_offset + target_w
  178. ]
  179. # process target
  180. new_boxes = new_target["boxes"]
  181. new_labels = new_target["labels"]
  182. new_boxes[:, 0::2] = np.clip(new_boxes[:, 0::2] * cp_scale_ratio, 0, origin_w)
  183. new_boxes[:, 1::2] = np.clip(new_boxes[:, 1::2] * cp_scale_ratio, 0, origin_h)
  184. if FLIP:
  185. new_boxes[:, 0::2] = (
  186. origin_w - new_boxes[:, 0::2][:, ::-1]
  187. )
  188. new_boxes[:, 0::2] = np.clip(
  189. new_boxes[:, 0::2] - x_offset, 0, target_w
  190. )
  191. new_boxes[:, 1::2] = np.clip(
  192. new_boxes[:, 1::2] - y_offset, 0, target_h
  193. )
  194. # mixup target
  195. mixup_boxes = np.concatenate([new_boxes, origin_target['boxes']], axis=0)
  196. mixup_labels = np.concatenate([new_labels, origin_target['labels']], axis=0)
  197. mixup_target = {
  198. 'boxes': mixup_boxes,
  199. 'labels': mixup_labels
  200. }
  201. # mixup images
  202. origin_image = origin_image.astype(np.float32)
  203. origin_image = 0.5 * origin_image + 0.5 * padded_cropped_img.astype(np.float32)
  204. return origin_image.astype(np.uint8), mixup_target
  205. def __call__(self, origin_image, origin_target, new_image, new_target):
  206. if self.mixup_type == "yolov5":
  207. return self.yolov5_mixup_augment(origin_image, origin_target, new_image, new_target)
  208. elif self.mixup_type == "yolox":
  209. return self.yolox_mixup_augment(origin_image, origin_target, new_image, new_target)
  210. else:
  211. raise NotImplementedError("Unknown mixup type: {}".format(self.mixup_type))