ssd_augment.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403
  1. import cv2
  2. import numpy as np
  3. import torch
  4. from numpy import random
  5. def intersect(box_a, box_b):
  6. max_xy = np.minimum(box_a[:, 2:], box_b[2:])
  7. min_xy = np.maximum(box_a[:, :2], box_b[:2])
  8. inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
  9. return inter[:, 0] * inter[:, 1]
  10. def jaccard_numpy(box_a, box_b):
  11. """Compute the jaccard overlap of two sets of boxes. The jaccard overlap
  12. is simply the intersection over union of two boxes.
  13. E.g.:
  14. A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
  15. Args:
  16. box_a: Multiple bounding boxes, Shape: [num_boxes,4]
  17. box_b: Single bounding box, Shape: [4]
  18. Return:
  19. jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]]
  20. """
  21. inter = intersect(box_a, box_b)
  22. area_a = ((box_a[:, 2]-box_a[:, 0]) *
  23. (box_a[:, 3]-box_a[:, 1])) # [A,B]
  24. area_b = ((box_b[2]-box_b[0]) *
  25. (box_b[3]-box_b[1])) # [A,B]
  26. union = area_a + area_b - inter
  27. return inter / union # [A,B]
  28. class Compose(object):
  29. """Composes several augmentations together.
  30. Args:
  31. transforms (List[Transform]): list of transforms to compose.
  32. Example:
  33. >>> augmentations.Compose([
  34. >>> transforms.CenterCrop(10),
  35. >>> transforms.ToTensor(),
  36. >>> ])
  37. """
  38. def __init__(self, transforms):
  39. self.transforms = transforms
  40. def __call__(self, img, boxes=None, labels=None):
  41. for t in self.transforms:
  42. img, boxes, labels = t(img, boxes, labels)
  43. return img, boxes, labels
  44. class ConvertFromInts(object):
  45. def __call__(self, image, boxes=None, labels=None):
  46. return image.astype(np.float32), boxes, labels
  47. class ConvertColor(object):
  48. def __init__(self, current='BGR', transform='HSV'):
  49. self.transform = transform
  50. self.current = current
  51. def __call__(self, image, boxes=None, labels=None):
  52. if self.current == 'BGR' and self.transform == 'HSV':
  53. image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
  54. elif self.current == 'HSV' and self.transform == 'BGR':
  55. image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
  56. else:
  57. raise NotImplementedError
  58. return image, boxes, labels
  59. class Resize(object):
  60. def __init__(self, img_size=640):
  61. self.img_size = img_size
  62. def __call__(self, image, boxes=None, labels=None):
  63. orig_h, orig_w = image.shape[:2]
  64. image = cv2.resize(image, (self.img_size, self.img_size))
  65. # normalize
  66. if boxes is not None:
  67. img_h, img_w = image.shape[:2]
  68. boxes[..., [0, 2]] = boxes[..., [0, 2]] / orig_w * img_w
  69. boxes[..., [1, 3]] = boxes[..., [1, 3]] / orig_h * img_h
  70. return image, boxes, labels
  71. class RandomSaturation(object):
  72. def __init__(self, lower=0.5, upper=1.5):
  73. self.lower = lower
  74. self.upper = upper
  75. assert self.upper >= self.lower, "contrast upper must be >= lower."
  76. assert self.lower >= 0, "contrast lower must be non-negative."
  77. def __call__(self, image, boxes=None, labels=None):
  78. if random.randint(2):
  79. image[:, :, 1] *= random.uniform(self.lower, self.upper)
  80. return image, boxes, labels
  81. class RandomHue(object):
  82. def __init__(self, delta=18.0):
  83. assert delta >= 0.0 and delta <= 360.0
  84. self.delta = delta
  85. def __call__(self, image, boxes=None, labels=None):
  86. if random.randint(2):
  87. image[:, :, 0] += random.uniform(-self.delta, self.delta)
  88. image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0
  89. image[:, :, 0][image[:, :, 0] < 0.0] += 360.0
  90. return image, boxes, labels
  91. class RandomLightingNoise(object):
  92. def __init__(self):
  93. self.perms = ((0, 1, 2), (0, 2, 1),
  94. (1, 0, 2), (1, 2, 0),
  95. (2, 0, 1), (2, 1, 0))
  96. def __call__(self, image, boxes=None, labels=None):
  97. if random.randint(2):
  98. swap = self.perms[random.randint(len(self.perms))]
  99. shuffle = SwapChannels(swap) # shuffle channels
  100. image = shuffle(image)
  101. return image, boxes, labels
  102. class RandomContrast(object):
  103. def __init__(self, lower=0.5, upper=1.5):
  104. self.lower = lower
  105. self.upper = upper
  106. assert self.upper >= self.lower, "contrast upper must be >= lower."
  107. assert self.lower >= 0, "contrast lower must be non-negative."
  108. # expects float image
  109. def __call__(self, image, boxes=None, labels=None):
  110. if random.randint(2):
  111. alpha = random.uniform(self.lower, self.upper)
  112. image *= alpha
  113. return image, boxes, labels
  114. class RandomBrightness(object):
  115. def __init__(self, delta=32):
  116. assert delta >= 0.0
  117. assert delta <= 255.0
  118. self.delta = delta
  119. def __call__(self, image, boxes=None, labels=None):
  120. if random.randint(2):
  121. delta = random.uniform(-self.delta, self.delta)
  122. image += delta
  123. return image, boxes, labels
  124. class RandomSampleCrop(object):
  125. """Crop
  126. Arguments:
  127. img (Image): the image being input during training
  128. boxes (Tensor): the original bounding boxes in pt form
  129. labels (Tensor): the class labels for each bbox
  130. mode (float tuple): the min and max jaccard overlaps
  131. Return:
  132. (img, boxes, classes)
  133. img (Image): the cropped image
  134. boxes (Tensor): the adjusted bounding boxes in pt form
  135. labels (Tensor): the class labels for each bbox
  136. """
  137. def __init__(self):
  138. self.sample_options = (
  139. # using entire original input image
  140. None,
  141. # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9
  142. (0.1, None),
  143. (0.3, None),
  144. (0.7, None),
  145. (0.9, None),
  146. # randomly sample a patch
  147. (None, None),
  148. )
  149. def __call__(self, image, boxes=None, labels=None):
  150. height, width, _ = image.shape
  151. while True:
  152. # randomly choose a mode
  153. sample_id = np.random.randint(len(self.sample_options))
  154. mode = self.sample_options[sample_id]
  155. if mode is None:
  156. return image, boxes, labels
  157. min_iou, max_iou = mode
  158. if min_iou is None:
  159. min_iou = float('-inf')
  160. if max_iou is None:
  161. max_iou = float('inf')
  162. # max trails (50)
  163. for _ in range(50):
  164. current_image = image
  165. w = random.uniform(0.3 * width, width)
  166. h = random.uniform(0.3 * height, height)
  167. # aspect ratio constraint b/t .5 & 2
  168. if h / w < 0.5 or h / w > 2:
  169. continue
  170. left = random.uniform(width - w)
  171. top = random.uniform(height - h)
  172. # convert to integer rect x1,y1,x2,y2
  173. rect = np.array([int(left), int(top), int(left+w), int(top+h)])
  174. # calculate IoU (jaccard overlap) b/t the cropped and gt boxes
  175. overlap = jaccard_numpy(boxes, rect)
  176. # is min and max overlap constraint satisfied? if not try again
  177. if overlap.min() < min_iou and max_iou < overlap.max():
  178. continue
  179. # cut the crop from the image
  180. current_image = current_image[rect[1]:rect[3], rect[0]:rect[2],
  181. :]
  182. # keep overlap with gt box IF center in sampled patch
  183. centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
  184. # mask in all gt boxes that above and to the left of centers
  185. m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
  186. # mask in all gt boxes that under and to the right of centers
  187. m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
  188. # mask in that both m1 and m2 are true
  189. mask = m1 * m2
  190. # have any valid boxes? try again if not
  191. if not mask.any():
  192. continue
  193. # take only matching gt boxes
  194. current_boxes = boxes[mask, :].copy()
  195. # take only matching gt labels
  196. current_labels = labels[mask]
  197. # should we use the box left and top corner or the crop's
  198. current_boxes[:, :2] = np.maximum(current_boxes[:, :2],
  199. rect[:2])
  200. # adjust to crop (by substracting crop's left,top)
  201. current_boxes[:, :2] -= rect[:2]
  202. current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:],
  203. rect[2:])
  204. # adjust to crop (by substracting crop's left,top)
  205. current_boxes[:, 2:] -= rect[:2]
  206. return current_image, current_boxes, current_labels
  207. class Expand(object):
  208. def __call__(self, image, boxes, labels):
  209. if random.randint(2):
  210. return image, boxes, labels
  211. height, width, depth = image.shape
  212. ratio = random.uniform(1, 4)
  213. left = random.uniform(0, width*ratio - width)
  214. top = random.uniform(0, height*ratio - height)
  215. expand_image = np.zeros(
  216. (int(height*ratio), int(width*ratio), depth),
  217. dtype=image.dtype)
  218. expand_image[int(top):int(top + height),
  219. int(left):int(left + width)] = image
  220. image = expand_image
  221. boxes = boxes.copy()
  222. boxes[:, :2] += (int(left), int(top))
  223. boxes[:, 2:] += (int(left), int(top))
  224. return image, boxes, labels
  225. class RandomHorizontalFlip(object):
  226. def __call__(self, image, boxes, classes):
  227. _, width, _ = image.shape
  228. if random.randint(2):
  229. image = image[:, ::-1]
  230. boxes = boxes.copy()
  231. boxes[:, 0::2] = width - boxes[:, 2::-2]
  232. return image, boxes, classes
  233. class SwapChannels(object):
  234. """Transforms a tensorized image by swapping the channels in the order
  235. specified in the swap tuple.
  236. Args:
  237. swaps (int triple): final order of channels
  238. eg: (2, 1, 0)
  239. """
  240. def __init__(self, swaps):
  241. self.swaps = swaps
  242. def __call__(self, image):
  243. """
  244. Args:
  245. image (Tensor): image tensor to be transformed
  246. Return:
  247. a tensor with channels swapped according to swap
  248. """
  249. # if torch.is_tensor(image):
  250. # image = image.data.cpu().numpy()
  251. # else:
  252. # image = np.array(image)
  253. image = image[:, :, self.swaps]
  254. return image
  255. class PhotometricDistort(object):
  256. def __init__(self):
  257. self.pd = [
  258. RandomContrast(),
  259. ConvertColor(transform='HSV'),
  260. RandomSaturation(),
  261. RandomHue(),
  262. ConvertColor(current='HSV', transform='BGR'),
  263. RandomContrast()
  264. ]
  265. self.rand_brightness = RandomBrightness()
  266. def __call__(self, image, boxes, labels):
  267. im = image.copy()
  268. im, boxes, labels = self.rand_brightness(im, boxes, labels)
  269. if random.randint(2):
  270. distort = Compose(self.pd[:-1])
  271. else:
  272. distort = Compose(self.pd[1:])
  273. im, boxes, labels = distort(im, boxes, labels)
  274. return im, boxes, labels
  275. # ----------------------- Main Functions
  276. ## SSD-style Augmentation
  277. class SSDAugmentation(object):
  278. def __init__(self, img_size=640):
  279. self.img_size = img_size
  280. self.augment = Compose([
  281. ConvertFromInts(), # 将int类型转换为float32类型
  282. PhotometricDistort(), # 图像颜色增强
  283. Expand(), # 扩充增强
  284. RandomSampleCrop(), # 随机剪裁
  285. RandomHorizontalFlip(), # 随机水平翻转
  286. Resize(self.img_size) # resize操作
  287. ])
  288. def __call__(self, image, target, mosaic=False):
  289. boxes = target['boxes'].copy()
  290. labels = target['labels'].copy()
  291. deltas = None
  292. # augment
  293. image, boxes, labels = self.augment(image, boxes, labels)
  294. # to tensor
  295. img_tensor = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
  296. target['boxes'] = torch.from_numpy(boxes).float()
  297. target['labels'] = torch.from_numpy(labels).float()
  298. return img_tensor, target, deltas
  299. ## SSD-style valTransform
  300. class SSDBaseTransform(object):
  301. def __init__(self, img_size):
  302. self.img_size = img_size
  303. def __call__(self, image, target=None, mosaic=False):
  304. deltas = None
  305. # resize
  306. orig_h, orig_w = image.shape[:2]
  307. image = cv2.resize(image, (self.img_size, self.img_size)).astype(np.float32)
  308. # scale targets
  309. if target is not None:
  310. boxes = target['boxes'].copy()
  311. labels = target['labels'].copy()
  312. img_h, img_w = image.shape[:2]
  313. boxes[..., [0, 2]] = boxes[..., [0, 2]] / orig_w * img_w
  314. boxes[..., [1, 3]] = boxes[..., [1, 3]] / orig_h * img_h
  315. target['boxes'] = boxes
  316. # to tensor
  317. img_tensor = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
  318. if target is not None:
  319. target['boxes'] = torch.from_numpy(boxes).float()
  320. target['labels'] = torch.from_numpy(labels).float()
  321. return img_tensor, target, deltas