# transforms.py

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Transforms and data augmentation for both image + bbox.
"""
import random

import PIL.Image
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F


# ----------------- Basic transform functions -----------------
def box_xyxy_to_cxcywh(x):
    """Convert boxes from (x0, y0, x1, y1) corners to (cx, cy, w, h) center format."""
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2,
         (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)
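
# Example (illustrative): a 10x10 box anchored at the origin,
#   box_xyxy_to_cxcywh(torch.tensor([0., 0., 10., 10.]))
# returns tensor([5., 5., 10., 10.]), i.e. center (5, 5) with width/height 10.
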
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # Thin wrapper around torch.nn.functional.interpolate, which has the same
    # signature. (Newer torchvision releases no longer ship
    # torchvision.ops.misc.interpolate, which this wrapper originally called.)
    return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)

def crop(image, target, region):
    # region follows torchvision's (top, left, height, width) convention
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    fields = ["labels", "area", "iscrowd"]

    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target['masks'] = target['masks'][:, i:i + h, j:j + w]
        fields.append("masks")

    # remove elements whose boxes or masks have zero area
    if "boxes" in target or "masks" in target:
        # favor boxes selection when defining which elements to keep
        # this is compatible with previous implementation
        if "boxes" in target:
            cropped_boxes = target['boxes'].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target['masks'].flatten(1).any(1)

        for field in fields:
            target[field] = target[field][keep]

    return cropped_image, target
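
# Example (a minimal sketch): cropping a 100x100 image carrying one box.
# The target must provide the "labels", "area" and "iscrowd" fields:
#
#   img = PIL.Image.new("RGB", (100, 100))
#   tgt = {"boxes": torch.tensor([[10., 10., 60., 60.]]),
#          "labels": torch.tensor([1]),
#          "area": torch.tensor([2500.]),
#          "iscrowd": torch.tensor([0])}
#   img, tgt = crop(img, tgt, (20, 20, 50, 50))
#   # tgt["boxes"] is now tensor([[0., 0., 40., 40.]]): shifted by the crop
#   # offset and clipped to the 50x50 crop window.
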

def hflip(image, target):
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
        target["boxes"] = boxes

    if "masks" in target:
        target['masks'] = target['masks'].flip(-1)

    return flipped_image, target

def resize(image, target, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple

    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))

        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
        target["boxes"] = scaled_boxes

    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        target['masks'] = interpolate(
            target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5

    return rescaled_image, target
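
# Example (illustrative): with the DETR-style setting size=800, max_size=1333,
# a 640x480 (w x h) image keeps its aspect ratio. Since 800 * 640 / 480 = 1066
# does not exceed 1333, the output size is (oh, ow) = (800, 1066).
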

def pad(image, target, padding):
    # assumes that we only pad on the bottom and right
    padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    if target is None:
        return padded_image, None

    target = target.copy()
    # should we do something wrt the original size?
    target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in target:
        target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))

    return padded_image, target

# ----------------- Basic transform classes -----------------
class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target=None):
        region = T.RandomCrop.get_params(img, self.size)
        return crop(img, target, region)

class RandomSizeCrop(object):
    def __init__(self, min_size: int, max_size: int):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img: PIL.Image.Image, target: dict = None):
        w = random.randint(self.min_size, min(img.width, self.max_size))
        h = random.randint(self.min_size, min(img.height, self.max_size))
        region = T.RandomCrop.get_params(img, [h, w])
        return crop(img, target, region)

class RandomHorizontalFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target=None):
        if random.random() < self.p:
            return hflip(img, target)
        return img, target

class RandomResize(object):
    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img, target=None):
        size = random.choice(self.sizes)
        return resize(img, target, size, self.max_size)

class RandomShift(object):
    def __init__(self, p=0.5, max_shift=32):
        self.p = p
        self.max_shift = max_shift

    def __call__(self, image, target=None):
        if random.random() < self.p:
            img_h, img_w = image.height, image.width
            shift_x = random.randint(-self.max_shift, self.max_shift)
            shift_y = random.randint(-self.max_shift, self.max_shift)
            shifted_image = F.affine(image, translate=[shift_x, shift_y], angle=0, scale=1.0, shear=0)
            if target is None:
                return shifted_image, None

            # shift the boxes by the same offset, then clip them to the image
            # bounds; clone first so the caller's tensors are not mutated
            target = target.copy()
            boxes = target["boxes"].clone()
            boxes[..., [0, 2]] = (boxes[..., [0, 2]] + shift_x).clip(0, img_w)
            boxes[..., [1, 3]] = (boxes[..., [1, 3]] + shift_y).clip(0, img_h)
            target["boxes"] = boxes
            return shifted_image, target

        return image, target

class RandomSelect(object):
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2.
    """
    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, img, target=None):
        if random.random() < self.p:
            return self.transforms1(img, target)
        return self.transforms2(img, target)

class ToTensor(object):
    def __call__(self, img, target=None):
        return F.to_tensor(img), target

class Normalize(object):
    def __init__(self, mean, std, normalize_coords=False):
        self.mean = mean
        self.std = std
        self.normalize_coords = normalize_coords

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None

        if self.normalize_coords:
            # scale box coordinates to [0, 1] relative to the image size
            target = target.copy()
            h, w = image.shape[-2:]
            if "boxes" in target:
                boxes = target["boxes"]
                boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
                target["boxes"] = boxes

        return image, target

class RefineBBox(object):
    """Filter out boxes (and their labels) whose shorter side is below min_box_size."""
    def __init__(self, min_box_size=1):
        self.min_box_size = min_box_size

    def __call__(self, img, target):
        boxes = target["boxes"].clone()
        labels = target["labels"].clone()
        tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
        min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
        keep = (min_tgt_size >= self.min_box_size)

        target["boxes"] = boxes[keep]
        target["labels"] = labels[keep]

        return img, target

class ConvertBoxFormat(object):
    def __init__(self, box_format="xyxy"):
        self.box_format = box_format

    def __call__(self, image, target=None):
        # convert box format
        if self.box_format == "xyxy" or target is None:
            pass
        elif self.box_format == "xywh":
            # note: "xywh" here means the center format (cx, cy, bw, bh),
            # i.e. the same conversion as box_xyxy_to_cxcywh
            target = target.copy()
            if "boxes" in target:
                boxes_xyxy = target["boxes"]
                boxes_xywh = torch.zeros_like(boxes_xyxy)
                boxes_xywh[..., :2] = (boxes_xyxy[..., :2] + boxes_xyxy[..., 2:]) * 0.5  # cx, cy
                boxes_xywh[..., 2:] = boxes_xyxy[..., 2:] - boxes_xyxy[..., :2]          # bw, bh
                target["boxes"] = boxes_xywh
        else:
            raise NotImplementedError("Unknown box format: {}".format(self.box_format))

        return image, target

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target=None):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string

# build transforms
def build_transform(cfg=None, is_train=False):
    # ---------------- Transform for Training ----------------
    if is_train:
        transforms = []
        trans_config = cfg['trans_config']
        # build transform
        if not cfg['detr_style']:
            for t in trans_config:
                if t['name'] == 'RandomHFlip':
                    transforms.append(RandomHorizontalFlip())
                elif t['name'] == 'RandomResize':
                    transforms.append(RandomResize(cfg['train_min_size'], max_size=cfg['train_max_size']))
                elif t['name'] == 'RandomSizeCrop':
                    transforms.append(RandomSizeCrop(t['min_crop_size'], max_size=t['max_crop_size']))
                elif t['name'] == 'RandomShift':
                    transforms.append(RandomShift(max_shift=t['max_shift']))
                elif t['name'] == 'RefineBBox':
                    transforms.append(RefineBBox(min_box_size=t['min_box_size']))
            transforms.extend([
                ToTensor(),
                Normalize(cfg['pixel_mean'], cfg['pixel_std'], cfg['normalize_coords']),
                ConvertBoxFormat(cfg['box_format'])
            ])
        # build transform for DETR-style detector
        else:
            transforms = [
                RandomHorizontalFlip(),
                RandomSelect(
                    RandomResize(cfg['train_min_size'], max_size=cfg['train_max_size']),
                    Compose([
                        RandomResize(cfg['train_min_size2']),
                        RandomSizeCrop(*cfg['random_crop_size']),
                        RandomResize(cfg['train_min_size'], max_size=cfg['train_max_size']),
                    ])
                ),
                ToTensor(),
                Normalize(cfg['pixel_mean'], cfg['pixel_std'], cfg['normalize_coords']),
                ConvertBoxFormat(cfg['box_format'])
            ]
    # ---------------- Transform for Evaluating ----------------
    else:
        transforms = [
            RandomResize(cfg['test_min_size'], max_size=cfg['test_max_size']),
            ToTensor(),
            Normalize(cfg['pixel_mean'], cfg['pixel_std'], cfg['normalize_coords']),
            ConvertBoxFormat(cfg['box_format'])
        ]

    return Compose(transforms)
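
# Usage sketch (illustrative): the config keys mirror those read above, but the
# values below are assumptions, not taken from any real config file.
#
#   cfg = {
#       'detr_style': False,
#       'trans_config': [{'name': 'RandomHFlip'}, {'name': 'RandomResize'}],
#       'train_min_size': [640, 672, 704],   # RandomResize picks one at random
#       'train_max_size': 900,
#       'pixel_mean': [0.485, 0.456, 0.406],
#       'pixel_std': [0.229, 0.224, 0.225],
#       'normalize_coords': False,
#       'box_format': 'xyxy',
#   }
#   transform = build_transform(cfg, is_train=True)
#   image, target = transform(pil_image, target)   # pil_image: PIL.Image, target: dict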