yolo_augment.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. import random
  2. import cv2
  3. import math
  4. import numpy as np
  5. import albumentations as albu
  6. import torch
  7. import torchvision.transforms.functional as F
  8. # ------------------------- Basic augmentations -------------------------
  9. ## Spatial transform
  10. def random_perspective(image,
  11. targets=(),
  12. degrees=10,
  13. translate=.1,
  14. scale=[0.1, 2.0],
  15. shear=10,
  16. perspective=0.0,
  17. border=(0, 0)):
  18. # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
  19. # targets = [cls, xyxy]
  20. height = image.shape[0] + border[0] * 2 # shape(h,w,c)
  21. width = image.shape[1] + border[1] * 2
  22. # Center
  23. C = np.eye(3)
  24. C[0, 2] = -image.shape[1] / 2 # x translation (pixels)
  25. C[1, 2] = -image.shape[0] / 2 # y translation (pixels)
  26. # Perspective
  27. P = np.eye(3)
  28. P[2, 0] = random.uniform(-perspective, perspective) # x perspective (about y)
  29. P[2, 1] = random.uniform(-perspective, perspective) # y perspective (about x)
  30. # Rotation and Scale
  31. R = np.eye(3)
  32. a = random.uniform(-degrees, degrees)
  33. # a += random.choice([-180, -90, 0, 90]) # add 90deg rotations to small rotations
  34. s = random.uniform(scale[0], scale[1])
  35. # s = 2 ** random.uniform(-scale, scale)
  36. R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
  37. # Shear
  38. S = np.eye(3)
  39. S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # x shear (deg)
  40. S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180) # y shear (deg)
  41. # Translation
  42. T = np.eye(3)
  43. T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width # x translation (pixels)
  44. T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height # y translation (pixels)
  45. # Combined rotation matrix
  46. M = T @ S @ R @ P @ C # order of operations (right to left) is IMPORTANT
  47. if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any(): # image changed
  48. if perspective:
  49. image = cv2.warpPerspective(image, M, dsize=(width, height), borderValue=(0, 0, 0))
  50. else: # affine
  51. image = cv2.warpAffine(image, M[:2], dsize=(width, height), borderValue=(0, 0, 0))
  52. # Transform label coordinates
  53. n = len(targets)
  54. if n:
  55. new = np.zeros((n, 4))
  56. # warp boxes
  57. xy = np.ones((n * 4, 3))
  58. xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
  59. xy = xy @ M.T # transform
  60. xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8) # perspective rescale or affine
  61. # create new boxes
  62. x = xy[:, [0, 2, 4, 6]]
  63. y = xy[:, [1, 3, 5, 7]]
  64. new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
  65. # clip
  66. new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
  67. new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
  68. targets[:, 1:5] = new
  69. return image, targets
  70. ## Color transform
  71. def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
  72. r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1 # random gains
  73. hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
  74. dtype = img.dtype # uint8
  75. x = np.arange(0, 256, dtype=np.int16)
  76. lut_hue = ((x * r[0]) % 180).astype(dtype)
  77. lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
  78. lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
  79. img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
  80. cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img) # no return needed
  81. return img
## Albumentations transform
  83. class Albumentations(object):
  84. def __init__(self, img_size=640):
  85. self.img_size = img_size
  86. self.transform = albu.Compose(
  87. [albu.Blur(p=0.01),
  88. albu.MedianBlur(p=0.01),
  89. albu.ToGray(p=0.01),
  90. albu.CLAHE(p=0.01),
  91. ],
  92. bbox_params=albu.BboxParams(format='pascal_voc', label_fields=['labels'])
  93. )
  94. def __call__(self, image, target=None):
  95. labels = target['labels']
  96. bboxes = target['boxes']
  97. if len(labels) > 0:
  98. new = self.transform(image=image, bboxes=bboxes, labels=labels)
  99. if len(new["labels"]) > 0:
  100. image = new['image']
  101. target['labels'] = np.array(new["labels"], dtype=labels.dtype)
  102. target['boxes'] = np.array(new["bboxes"], dtype=bboxes.dtype)
  103. return image, target
# ------------------------- Preprocessors -------------------------
  105. ## YOLO-style Transform for Train
  106. class YOLOAugmentation(object):
  107. def __init__(self,
  108. img_size=640,
  109. affine_params=None,
  110. use_ablu=False,
  111. pixel_mean = [0., 0., 0.],
  112. pixel_std = [255., 255., 255.],
  113. box_format='xyxy',
  114. normalize_coords=False):
  115. # Basic parameters
  116. self.img_size = img_size
  117. self.pixel_mean = pixel_mean
  118. self.pixel_std = pixel_std
  119. self.box_format = box_format
  120. self.affine_params = affine_params
  121. self.normalize_coords = normalize_coords
  122. self.color_format = 'bgr'
  123. # Albumentations
  124. self.ablu_trans = Albumentations(img_size) if use_ablu else None
  125. def __call__(self, image, target, mosaic=False):
  126. # --------------- Resize image ---------------
  127. orig_h, orig_w = image.shape[:2]
  128. ratio = self.img_size / max(orig_h, orig_w)
  129. if ratio != 1:
  130. new_shape = (int(round(orig_w * ratio)), int(round(orig_h * ratio)))
  131. image = cv2.resize(image, new_shape)
  132. img_h, img_w = image.shape[:2]
  133. # --------------- Filter bad targets ---------------
  134. tgt_boxes_wh = target["boxes"][..., 2:] - target["boxes"][..., :2]
  135. min_tgt_size = np.min(tgt_boxes_wh, axis=-1)
  136. keep = (min_tgt_size > 1)
  137. target["boxes"] = target["boxes"][keep]
  138. target["labels"] = target["labels"][keep]
  139. # --------------- Albumentations ---------------
  140. if self.ablu_trans is not None:
  141. image, target = self.ablu_trans(image, target)
  142. # --------------- HSV augmentations ---------------
  143. image = augment_hsv(image,
  144. hgain=self.affine_params['hsv_h'],
  145. sgain=self.affine_params['hsv_s'],
  146. vgain=self.affine_params['hsv_v'])
  147. # --------------- Spatial augmentations ---------------
  148. ## Random perspective
  149. if not mosaic:
  150. # rescale bbox
  151. target["boxes"][..., [0, 2]] = target["boxes"][..., [0, 2]] / orig_w * img_w
  152. target["boxes"][..., [1, 3]] = target["boxes"][..., [1, 3]] / orig_h * img_h
  153. # spatial augment
  154. target_ = np.concatenate((target['labels'][..., None], target['boxes']), axis=-1)
  155. image, target_ = random_perspective(image, target_,
  156. degrees = self.affine_params['degrees'],
  157. translate = self.affine_params['translate'],
  158. scale = self.affine_params['scale'],
  159. shear = self.affine_params['shear'],
  160. perspective = self.affine_params['perspective']
  161. )
  162. target['boxes'] = target_[..., 1:]
  163. target['labels'] = target_[..., 0]
  164. ## Random flip
  165. if random.random() < 0.5:
  166. w = image.shape[1]
  167. image = np.fliplr(image).copy()
  168. boxes = target['boxes'].copy()
  169. boxes[..., [0, 2]] = w - boxes[..., [2, 0]]
  170. target["boxes"] = boxes
  171. # --------------- To torch.Tensor ---------------
  172. image = F.to_tensor(image) * 255.
  173. image = F.normalize(image, self.pixel_mean, self.pixel_std)
  174. if target is not None:
  175. target["boxes"] = torch.as_tensor(target["boxes"]).float()
  176. target["labels"] = torch.as_tensor(target["labels"]).long()
  177. # normalize coords
  178. if self.normalize_coords:
  179. target["boxes"][..., [0, 2]] /= img_w
  180. target["boxes"][..., [1, 3]] /= img_h
  181. # xyxy -> xywh
  182. if self.box_format == "xywh":
  183. box_cxcy = (target["boxes"][..., :2] + target["boxes"][..., 2:]) * 0.5
  184. box_bwbh = target["boxes"][..., 2:] - target["boxes"][..., :2]
  185. target["boxes"] = torch.cat([box_cxcy, box_bwbh], dim=-1)
  186. # --------------- Pad Image ---------------
  187. img_h0, img_w0 = image.shape[1:]
  188. pad_image = torch.zeros([image.size(0), self.img_size, self.img_size]).float()
  189. pad_image[:, :img_h0, :img_w0] = image
  190. return pad_image, target, ratio
  191. ## YOLO-style Transform for Eval
  192. class YOLOBaseTransform(object):
  193. def __init__(self,
  194. img_size=640,
  195. max_stride=32,
  196. pixel_mean = [0., 0., 0.],
  197. pixel_std = [255., 255., 255.],
  198. box_format='xyxy',
  199. normalize_coords=False):
  200. self.img_size = img_size
  201. self.max_stride = max_stride
  202. self.pixel_mean = pixel_mean
  203. self.pixel_std = pixel_std
  204. self.box_format = box_format
  205. self.normalize_coords = normalize_coords
  206. self.color_format = 'bgr'
  207. def __call__(self, image, target=None, mosaic=False):
  208. # --------------- Resize image ---------------
  209. orig_h, orig_w = image.shape[:2]
  210. ratio = self.img_size / max(orig_h, orig_w)
  211. if ratio != 1:
  212. new_shape = (int(round(orig_w * ratio)), int(round(orig_h * ratio)))
  213. image = cv2.resize(image, new_shape)
  214. img_h, img_w = image.shape[:2]
  215. # --------------- Rescale bboxes ---------------
  216. if target is not None:
  217. # rescale bbox
  218. target["boxes"][..., [0, 2]] = target["boxes"][..., [0, 2]] / orig_w * img_w
  219. target["boxes"][..., [1, 3]] = target["boxes"][..., [1, 3]] / orig_h * img_h
  220. # --------------- To torch.Tensor ---------------
  221. image = F.to_tensor(image) * 255.
  222. image = F.normalize(image, self.pixel_mean, self.pixel_std)
  223. if target is not None:
  224. target["boxes"] = torch.as_tensor(target["boxes"]).float()
  225. target["labels"] = torch.as_tensor(target["labels"]).long()
  226. # normalize coords
  227. if self.normalize_coords:
  228. target["boxes"][..., [0, 2]] /= img_w
  229. target["boxes"][..., [1, 3]] /= img_h
  230. # xyxy -> xywh
  231. if self.box_format == "xywh":
  232. box_cxcy = (target["boxes"][..., :2] + target["boxes"][..., 2:]) * 0.5
  233. box_bwbh = target["boxes"][..., 2:] - target["boxes"][..., :2]
  234. target["boxes"] = torch.cat([box_cxcy, box_bwbh], dim=-1)
  235. # --------------- Pad image ---------------
  236. img_h0, img_w0 = image.shape[1:]
  237. dh = img_h0 % self.max_stride
  238. dw = img_w0 % self.max_stride
  239. dh = dh if dh == 0 else self.max_stride - dh
  240. dw = dw if dw == 0 else self.max_stride - dw
  241. pad_img_h = img_h0 + dh
  242. pad_img_w = img_w0 + dw
  243. pad_image = torch.zeros([image.size(0), pad_img_h, pad_img_w]).float()
  244. pad_image[:, :img_h0, :img_w0] = image
  245. return pad_image, target, ratio