ssd_augment.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. # ------------------------------------------------------------
  2. # Data preprocessor for SSD
  3. # ------------------------------------------------------------
  4. import cv2
  5. import numpy as np
  6. import torch
  7. from numpy import random
  8. # ------------------------- Augmentations -------------------------
  9. class Compose(object):
  10. """Composes several augmentations together.
  11. Args:
  12. transforms (List[Transform]): list of transforms to compose.
  13. Example:
  14. >>> augmentations.Compose([
  15. >>> transforms.CenterCrop(10),
  16. >>> transforms.ToTensor(),
  17. >>> ])
  18. """
  19. def __init__(self, transforms):
  20. self.transforms = transforms
  21. def __call__(self, img, boxes=None, labels=None):
  22. for t in self.transforms:
  23. img, boxes, labels = t(img, boxes, labels)
  24. return img, boxes, labels
  25. ## Convert Image to float type
  26. class ConvertFromInts(object):
  27. def __call__(self, image, boxes=None, labels=None):
  28. return image.astype(np.float32), boxes, labels
  29. ## Convert color format
  30. class ConvertColor(object):
  31. def __init__(self, current='BGR', transform='HSV'):
  32. self.transform = transform
  33. self.current = current
  34. def __call__(self, image, boxes=None, labels=None):
  35. if self.current == 'BGR' and self.transform == 'HSV':
  36. image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
  37. elif self.current == 'HSV' and self.transform == 'BGR':
  38. image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
  39. else:
  40. raise NotImplementedError
  41. return image, boxes, labels
  42. ## Resize image
  43. class Resize(object):
  44. def __init__(self, img_size=640):
  45. self.img_size = img_size
  46. def __call__(self, image, boxes=None, labels=None):
  47. orig_h, orig_w = image.shape[:2]
  48. image = cv2.resize(image, (self.img_size, self.img_size))
  49. # rescale bbox
  50. if boxes is not None:
  51. img_h, img_w = image.shape[:2]
  52. boxes[..., [0, 2]] = boxes[..., [0, 2]] / orig_w * img_w
  53. boxes[..., [1, 3]] = boxes[..., [1, 3]] / orig_h * img_h
  54. return image, boxes, labels
  55. ## Random Saturation
  56. class RandomSaturation(object):
  57. def __init__(self, lower=0.5, upper=1.5):
  58. self.lower = lower
  59. self.upper = upper
  60. assert self.upper >= self.lower, "contrast upper must be >= lower."
  61. assert self.lower >= 0, "contrast lower must be non-negative."
  62. def __call__(self, image, boxes=None, labels=None):
  63. if random.randint(2):
  64. image[:, :, 1] *= random.uniform(self.lower, self.upper)
  65. return image, boxes, labels
  66. ## Random Hue
  67. class RandomHue(object):
  68. def __init__(self, delta=18.0):
  69. assert delta >= 0.0 and delta <= 360.0
  70. self.delta = delta
  71. def __call__(self, image, boxes=None, labels=None):
  72. if random.randint(2):
  73. image[:, :, 0] += random.uniform(-self.delta, self.delta)
  74. image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0
  75. image[:, :, 0][image[:, :, 0] < 0.0] += 360.0
  76. return image, boxes, labels
  77. ## Random Lighting noise
  78. class RandomLightingNoise(object):
  79. def __init__(self):
  80. self.perms = ((0, 1, 2), (0, 2, 1),
  81. (1, 0, 2), (1, 2, 0),
  82. (2, 0, 1), (2, 1, 0))
  83. def __call__(self, image, boxes=None, labels=None):
  84. if random.randint(2):
  85. swap = self.perms[random.randint(len(self.perms))]
  86. shuffle = SwapChannels(swap) # shuffle channels
  87. image = shuffle(image)
  88. return image, boxes, labels
  89. ## Random Contrast
  90. class RandomContrast(object):
  91. def __init__(self, lower=0.5, upper=1.5):
  92. self.lower = lower
  93. self.upper = upper
  94. assert self.upper >= self.lower, "contrast upper must be >= lower."
  95. assert self.lower >= 0, "contrast lower must be non-negative."
  96. # expects float image
  97. def __call__(self, image, boxes=None, labels=None):
  98. if random.randint(2):
  99. alpha = random.uniform(self.lower, self.upper)
  100. image *= alpha
  101. return image, boxes, labels
  102. ## Random Brightness
  103. class RandomBrightness(object):
  104. def __init__(self, delta=32):
  105. assert delta >= 0.0
  106. assert delta <= 255.0
  107. self.delta = delta
  108. def __call__(self, image, boxes=None, labels=None):
  109. if random.randint(2):
  110. delta = random.uniform(-self.delta, self.delta)
  111. image += delta
  112. return image, boxes, labels
  113. ## Random SampleCrop
  114. class RandomSampleCrop(object):
  115. """Crop
  116. Arguments:
  117. img (Image): the image being input during training
  118. boxes (Tensor): the original bounding boxes in pt form
  119. labels (Tensor): the class labels for each bbox
  120. mode (float tuple): the min and max jaccard overlaps
  121. Return:
  122. (img, boxes, classes)
  123. img (Image): the cropped image
  124. boxes (Tensor): the adjusted bounding boxes in pt form
  125. labels (Tensor): the class labels for each bbox
  126. """
  127. def __init__(self):
  128. self.sample_options = (
  129. # using entire original input image
  130. None,
  131. # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9
  132. (0.1, None),
  133. (0.3, None),
  134. (0.7, None),
  135. (0.9, None),
  136. # randomly sample a patch
  137. (None, None),
  138. )
  139. def intersect(self, box_a, box_b):
  140. max_xy = np.minimum(box_a[:, 2:], box_b[2:])
  141. min_xy = np.maximum(box_a[:, :2], box_b[:2])
  142. inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
  143. return inter[:, 0] * inter[:, 1]
  144. def compute_iou(self, box_a, box_b):
  145. inter = self.intersect(box_a, box_b)
  146. area_a = ((box_a[:, 2]-box_a[:, 0]) *
  147. (box_a[:, 3]-box_a[:, 1])) # [A,B]
  148. area_b = ((box_b[2]-box_b[0]) *
  149. (box_b[3]-box_b[1])) # [A,B]
  150. union = area_a + area_b - inter
  151. return inter / union # [A,B]
  152. def __call__(self, image, boxes=None, labels=None):
  153. height, width, _ = image.shape
  154. # check
  155. if len(boxes) == 0:
  156. return image, boxes, labels
  157. while True:
  158. # randomly choose a mode
  159. sample_id = np.random.randint(len(self.sample_options))
  160. mode = self.sample_options[sample_id]
  161. if mode is None:
  162. return image, boxes, labels
  163. min_iou, max_iou = mode
  164. if min_iou is None:
  165. min_iou = float('-inf')
  166. if max_iou is None:
  167. max_iou = float('inf')
  168. # max trails (50)
  169. for _ in range(50):
  170. current_image = image
  171. w = random.uniform(0.3 * width, width)
  172. h = random.uniform(0.3 * height, height)
  173. # aspect ratio constraint b/t .5 & 2
  174. if h / w < 0.5 or h / w > 2:
  175. continue
  176. left = random.uniform(width - w)
  177. top = random.uniform(height - h)
  178. # convert to integer rect x1,y1,x2,y2
  179. rect = np.array([int(left), int(top), int(left+w), int(top+h)])
  180. # calculate IoU (jaccard overlap) b/t the cropped and gt boxes
  181. overlap = self.compute_iou(boxes, rect)
  182. # is min and max overlap constraint satisfied? if not try again
  183. if overlap.min() < min_iou and max_iou < overlap.max():
  184. continue
  185. # cut the crop from the image
  186. current_image = current_image[rect[1]:rect[3], rect[0]:rect[2],
  187. :]
  188. # keep overlap with gt box IF center in sampled patch
  189. centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
  190. # mask in all gt boxes that above and to the left of centers
  191. m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
  192. # mask in all gt boxes that under and to the right of centers
  193. m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
  194. # mask in that both m1 and m2 are true
  195. mask = m1 * m2
  196. # have any valid boxes? try again if not
  197. if not mask.any():
  198. continue
  199. # take only matching gt boxes
  200. current_boxes = boxes[mask, :].copy()
  201. # take only matching gt labels
  202. current_labels = labels[mask]
  203. # should we use the box left and top corner or the crop's
  204. current_boxes[:, :2] = np.maximum(current_boxes[:, :2],
  205. rect[:2])
  206. # adjust to crop (by substracting crop's left,top)
  207. current_boxes[:, :2] -= rect[:2]
  208. current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:],
  209. rect[2:])
  210. # adjust to crop (by substracting crop's left,top)
  211. current_boxes[:, 2:] -= rect[:2]
  212. return current_image, current_boxes, current_labels
  213. ## Random scaling
  214. class Expand(object):
  215. def __call__(self, image, boxes, labels):
  216. if random.randint(2):
  217. return image, boxes, labels
  218. height, width, depth = image.shape
  219. ratio = random.uniform(1, 4)
  220. left = random.uniform(0, width*ratio - width)
  221. top = random.uniform(0, height*ratio - height)
  222. expand_image = np.zeros(
  223. (int(height*ratio), int(width*ratio), depth),
  224. dtype=image.dtype)
  225. expand_image[int(top):int(top + height),
  226. int(left):int(left + width)] = image
  227. image = expand_image
  228. boxes = boxes.copy()
  229. boxes[:, :2] += (int(left), int(top))
  230. boxes[:, 2:] += (int(left), int(top))
  231. return image, boxes, labels
  232. ## Random HFlip
  233. class RandomHorizontalFlip(object):
  234. def __call__(self, image, boxes, classes):
  235. _, width, _ = image.shape
  236. if random.randint(2):
  237. image = image[:, ::-1]
  238. boxes = boxes.copy()
  239. boxes[:, 0::2] = width - boxes[:, 2::-2]
  240. return image, boxes, classes
  241. ## Random swap channels
  242. class SwapChannels(object):
  243. """Transforms a tensorized image by swapping the channels in the order
  244. specified in the swap tuple.
  245. Args:
  246. swaps (int triple): final order of channels
  247. eg: (2, 1, 0)
  248. """
  249. def __init__(self, swaps):
  250. self.swaps = swaps
  251. def __call__(self, image):
  252. """
  253. Args:
  254. image (Tensor): image tensor to be transformed
  255. Return:
  256. a tensor with channels swapped according to swap
  257. """
  258. # if torch.is_tensor(image):
  259. # image = image.data.cpu().numpy()
  260. # else:
  261. # image = np.array(image)
  262. image = image[:, :, self.swaps]
  263. return image
  264. ## Random color jitter
  265. class PhotometricDistort(object):
  266. def __init__(self):
  267. self.pd = [
  268. RandomContrast(),
  269. ConvertColor(transform='HSV'),
  270. RandomSaturation(),
  271. RandomHue(),
  272. ConvertColor(current='HSV', transform='BGR'),
  273. RandomContrast()
  274. ]
  275. self.rand_brightness = RandomBrightness()
  276. def __call__(self, image, boxes, labels):
  277. im = image.copy()
  278. im, boxes, labels = self.rand_brightness(im, boxes, labels)
  279. if random.randint(2):
  280. distort = Compose(self.pd[:-1])
  281. else:
  282. distort = Compose(self.pd[1:])
  283. im, boxes, labels = distort(im, boxes, labels)
  284. return im, boxes, labels
  285. # ------------------------- Preprocessers -------------------------
  286. ## SSD-style Augmentation
  287. class SSDAugmentation(object):
  288. def __init__(self, img_size=640):
  289. self.img_size = img_size
  290. self.pixel_mean = [0., 0., 0.]
  291. self.pixel_std = [255., 255., 255.]
  292. self.color_format = 'bgr'
  293. self.augment = Compose([
  294. ConvertFromInts(), # 将int类型转换为float32类型
  295. PhotometricDistort(), # 图像颜色增强
  296. Expand(), # 扩充增强
  297. RandomSampleCrop(), # 随机剪裁
  298. RandomHorizontalFlip(), # 随机水平翻转
  299. Resize(self.img_size) # resize操作
  300. ])
  301. def __call__(self, image, target, mosaic=False):
  302. orig_h, orig_w = image.shape[:2]
  303. ratio = [self.img_size / orig_w, self.img_size / orig_h]
  304. # augment
  305. boxes = target['boxes'].copy()
  306. labels = target['labels'].copy()
  307. image, boxes, labels = self.augment(image, boxes, labels)
  308. # to tensor
  309. img_tensor = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
  310. target['boxes'] = torch.from_numpy(boxes).float()
  311. target['labels'] = torch.from_numpy(labels).float()
  312. # normalize image
  313. img_tensor /= 255.
  314. return img_tensor, target, ratio
  315. ## SSD-style valTransform
  316. class SSDBaseTransform(object):
  317. def __init__(self, img_size):
  318. self.img_size = img_size
  319. self.pixel_mean = [0., 0., 0.]
  320. self.pixel_std = [255., 255., 255.]
  321. self.color_format = 'bgr'
  322. def __call__(self, image, target=None, mosaic=False):
  323. # resize
  324. orig_h, orig_w = image.shape[:2]
  325. ratio = [self.img_size / orig_w, self.img_size / orig_h]
  326. image = cv2.resize(image, (self.img_size, self.img_size)).astype(np.float32)
  327. # scale targets
  328. if target is not None:
  329. boxes = target['boxes'].copy()
  330. labels = target['labels'].copy()
  331. img_h, img_w = image.shape[:2]
  332. boxes[..., [0, 2]] = boxes[..., [0, 2]] / orig_w * img_w
  333. boxes[..., [1, 3]] = boxes[..., [1, 3]] / orig_h * img_h
  334. target['boxes'] = boxes
  335. # to tensor
  336. img_tensor = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
  337. if target is not None:
  338. target['boxes'] = torch.from_numpy(boxes).float()
  339. target['labels'] = torch.from_numpy(labels).float()
  340. # normalize image
  341. img_tensor /= 255.
  342. return img_tensor, target, ratio