# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Transforms and data augmentation for both image + bbox.
"""
import random

import PIL.Image
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F


# ----------------- Basic transform functions -----------------
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # thin wrapper kept for backward compatibility; the original call went through
    # torchvision.ops.misc.interpolate, which recent torchvision releases removed
    return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)

def crop(image, target, region):
    cropped_image = F.crop(image, *region)

    target = target.copy()
    i, j, h, w = region  # (top, left, height, width)

    # should we do something wrt the original size?
    target["size"] = torch.tensor([h, w])

    fields = ["labels", "area", "iscrowd"]

    if "boxes" in target:
        boxes = target["boxes"]
        max_size = torch.as_tensor([w, h], dtype=torch.float32)
        cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
        cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
        cropped_boxes = cropped_boxes.clamp(min=0)
        area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
        target["boxes"] = cropped_boxes.reshape(-1, 4)
        target["area"] = area
        fields.append("boxes")

    if "masks" in target:
        # FIXME should we update the area here if there are no boxes?
        target['masks'] = target['masks'][:, i:i + h, j:j + w]
        fields.append("masks")

    # remove elements whose boxes or masks have zero area
    if "boxes" in target or "masks" in target:
        # favor box-based selection when deciding which elements to keep;
        # this is compatible with the previous implementation
        if "boxes" in target:
            cropped_boxes = target['boxes'].reshape(-1, 2, 2)
            keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
        else:
            keep = target['masks'].flatten(1).any(1)

        for field in fields:
            target[field] = target[field][keep]

    return cropped_image, target

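# Illustrative usage note (not from the original file): `region` follows
# torchvision's (top, left, height, width) convention, so
#     crop(img, target, (50, 30, 200, 300))
# keeps rows 50:250 and columns 30:330, shifts box x/y coordinates by
# (-30, -50), and drops annotations whose clipped boxes collapse to zero area.
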
def hflip(image, target):
    flipped_image = F.hflip(image)

    w, h = image.size

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        # mirror x-coordinates: (x1, y1, x2, y2) -> (w - x2, y1, w - x1, y2)
        boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
        target["boxes"] = boxes

    if "masks" in target:
        target['masks'] = target['masks'].flip(-1)

    return flipped_image, target

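# Illustrative example (not from the original file): on a 640-pixel-wide image,
# the xyxy box [100, 50, 200, 150] flips to [440, 50, 540, 150]
# (x1' = w - x2, x2' = w - x1; y-coordinates are unchanged).
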
def resize(image, target, size, max_size=None):
    # size can be min_size (scalar) or (w, h) tuple
    def get_size_with_aspect_ratio(image_size, size, max_size=None):
        w, h = image_size
        if max_size is not None:
            min_original_size = float(min((w, h)))
            max_original_size = float(max((w, h)))
            if max_original_size / min_original_size * size > max_size:
                size = int(round(max_size * min_original_size / max_original_size))

        if (w <= h and w == size) or (h <= w and h == size):
            return (h, w)

        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)

        return (oh, ow)

    def get_size(image_size, size, max_size=None):
        if isinstance(size, (list, tuple)):
            return size[::-1]
        else:
            return get_size_with_aspect_ratio(image_size, size, max_size)

    size = get_size(image.size, size, max_size)
    rescaled_image = F.resize(image, size)

    if target is None:
        return rescaled_image, None

    ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
    ratio_width, ratio_height = ratios

    target = target.copy()
    if "boxes" in target:
        boxes = target["boxes"]
        scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
        target["boxes"] = scaled_boxes

    if "area" in target:
        area = target["area"]
        scaled_area = area * (ratio_width * ratio_height)
        target["area"] = scaled_area

    h, w = size
    target["size"] = torch.tensor([h, w])

    if "masks" in target:
        target['masks'] = interpolate(
            target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5

    return rescaled_image, target

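# Illustrative example (not from the original file): resizing a 1000x500 (w x h)
# image with size=800 and max_size=1333. Scaling the short side to 800 would
# make the long side 1600 > 1333, so the target short side is reduced to
# round(1333 * 500 / 1000) = 666 and the output is 666 x 1332 (h x w).
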
def pad(image, target, padding):
    # assumes we only pad on the bottom-right; padding is (pad_width, pad_height)
    padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
    if target is None:
        return padded_image, None
    target = target.copy()
    # should we do something wrt the original size?
    target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in target:
        target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
    return padded_image, target

# ----------------- Basic transform classes -----------------
class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, img, target=None):
        region = T.RandomCrop.get_params(img, self.size)
        return crop(img, target, region)

class RandomSizeCrop(object):
    def __init__(self, min_size: int, max_size: int):
        self.min_size = min_size
        self.max_size = max_size

    def __call__(self, img: PIL.Image.Image, target: dict = None):
        w = random.randint(self.min_size, min(img.width, self.max_size))
        h = random.randint(self.min_size, min(img.height, self.max_size))
        region = T.RandomCrop.get_params(img, [h, w])
        return crop(img, target, region)

class RandomHorizontalFlip(object):
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, target=None):
        if random.random() < self.p:
            return hflip(img, target)
        return img, target

class RandomResize(object):
    def __init__(self, sizes, max_size=None):
        assert isinstance(sizes, (list, tuple))
        self.sizes = sizes
        self.max_size = max_size

    def __call__(self, img, target=None):
        size = random.choice(self.sizes)
        return resize(img, target, size, self.max_size)

class RandomShift(object):
    def __init__(self, p=0.5, max_shift=32):
        self.p = p
        self.max_shift = max_shift

    def __call__(self, image, target=None):
        if random.random() < self.p:
            img_h, img_w = image.height, image.width
            shift_x = random.randint(-self.max_shift, self.max_shift)
            shift_y = random.randint(-self.max_shift, self.max_shift)
            shifted_image = F.affine(image, translate=[shift_x, shift_y], angle=0, scale=1.0, shear=0)
            if target is not None:
                target = target.copy()
                # clone so the caller's boxes are not modified in place
                boxes = target["boxes"].clone()
                boxes[..., [0, 2]] = (boxes[..., [0, 2]] + shift_x).clip(0, img_w)
                boxes[..., [1, 3]] = (boxes[..., [1, 3]] + shift_y).clip(0, img_h)
                target["boxes"] = boxes
            return shifted_image, target
        return image, target

class RandomSelect(object):
    """
    Randomly selects between transforms1 and transforms2,
    with probability p for transforms1 and (1 - p) for transforms2.
    """
    def __init__(self, transforms1, transforms2, p=0.5):
        self.transforms1 = transforms1
        self.transforms2 = transforms2
        self.p = p

    def __call__(self, img, target=None):
        if random.random() < self.p:
            return self.transforms1(img, target)
        return self.transforms2(img, target)

class ToTensor(object):
    def __call__(self, img, target=None):
        return F.to_tensor(img), target

class Normalize(object):
    def __init__(self, mean, std, normalize_coords=False):
        self.mean = mean
        self.std = std
        self.normalize_coords = normalize_coords

    def __call__(self, image, target=None):
        image = F.normalize(image, mean=self.mean, std=self.std)
        if target is None:
            return image, None
        if self.normalize_coords:
            target = target.copy()
            h, w = image.shape[-2:]
            if "boxes" in target:
                boxes = target["boxes"]
                # rescale box coordinates to [0, 1] relative to the image size
                boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
                target["boxes"] = boxes
        return image, target

class RefineBBox(object):
    """Drop boxes (and their labels) whose shorter side is below min_box_size."""
    def __init__(self, min_box_size=1):
        self.min_box_size = min_box_size

    def __call__(self, img, target=None):
        if target is None:
            return img, target
        boxes = target["boxes"].clone()
        labels = target["labels"].clone()
        tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
        min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
        keep = (min_tgt_size >= self.min_box_size)
        target["boxes"] = boxes[keep]
        target["labels"] = labels[keep]
        return img, target

class ConvertBoxFormat(object):
    def __init__(self, box_format="xyxy"):
        self.box_format = box_format

    def __call__(self, image, target=None):
        # convert box format; note that "xywh" here means center format (cx, cy, w, h)
        if self.box_format == "xyxy" or target is None:
            pass
        elif self.box_format == "xywh":
            target = target.copy()
            if "boxes" in target:
                boxes_xyxy = target["boxes"]
                boxes_xywh = torch.zeros_like(boxes_xyxy)
                boxes_xywh[..., :2] = (boxes_xyxy[..., :2] + boxes_xyxy[..., 2:]) * 0.5  # cx, cy
                boxes_xywh[..., 2:] = boxes_xyxy[..., 2:] - boxes_xyxy[..., :2]          # w, h
                target["boxes"] = boxes_xywh
        else:
            raise NotImplementedError("Unknown box format: {}".format(self.box_format))
        return image, target

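# Illustrative example (not from the original file): with box_format="xywh",
# the xyxy box [0., 0., 10., 20.] converts to [5., 10., 10., 20.],
# i.e. center (5, 10) with width 10 and height 20.
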
class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target=None):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

    def __repr__(self):
        format_string = self.__class__.__name__ + "("
        for t in self.transforms:
            format_string += "\n"
            format_string += "    {0}".format(t)
        format_string += "\n)"
        return format_string

# ----------------- Build transforms -----------------
def build_transform(cfg, is_train=False):
    # ---------------- Transform for training ----------------
    if is_train:
        transforms = []
        trans_config = cfg.trans_config
        # build transform from the config list
        if not cfg.detr_style:
            for t in trans_config:
                if t['name'] == 'RandomHFlip':
                    transforms.append(RandomHorizontalFlip())
                elif t['name'] == 'RandomResize':
                    transforms.append(RandomResize(cfg.train_min_size, max_size=cfg.train_max_size))
                elif t['name'] == 'RandomSizeCrop':
                    transforms.append(RandomSizeCrop(t['min_crop_size'], max_size=t['max_crop_size']))
                elif t['name'] == 'RandomShift':
                    transforms.append(RandomShift(max_shift=t['max_shift']))
                elif t['name'] == 'RefineBBox':
                    transforms.append(RefineBBox(min_box_size=t['min_box_size']))
            transforms.extend([
                ToTensor(),
                Normalize(cfg.pixel_mean, cfg.pixel_std, cfg.normalize_coords),
                ConvertBoxFormat(cfg.box_format),
            ])
        # build transform for a DETR-style detector
        else:
            transforms = [
                RandomHorizontalFlip(),
                RandomSelect(
                    RandomResize(cfg.train_min_size, max_size=cfg.train_max_size),
                    Compose([
                        RandomResize(cfg.train_min_size2),
                        RandomSizeCrop(*cfg.random_crop_size),
                        RandomResize(cfg.train_min_size, max_size=cfg.train_max_size),
                    ])
                ),
                ToTensor(),
                Normalize(cfg.pixel_mean, cfg.pixel_std, cfg.normalize_coords),
                ConvertBoxFormat(cfg.box_format),
            ]
    # ---------------- Transform for evaluation ----------------
    else:
        transforms = [
            RandomResize(cfg.test_min_size, max_size=cfg.test_max_size),
            ToTensor(),
            Normalize(cfg.pixel_mean, cfg.pixel_std, cfg.normalize_coords),
            ConvertBoxFormat(cfg.box_format),
        ]

    return Compose(transforms)

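# Hedged usage sketch (not part of the original file). The cfg attribute names
# match those consumed by build_transform above; the concrete values are
# illustrative assumptions. Note that RandomResize expects a list of sizes.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        test_min_size=[800],
        test_max_size=1333,
        pixel_mean=[0.485, 0.456, 0.406],
        pixel_std=[0.229, 0.224, 0.225],
        normalize_coords=True,
        box_format="xywh",
    )
    transform = build_transform(cfg, is_train=False)

    img = PIL.Image.new("RGB", (640, 480))
    target = {
        "boxes": torch.tensor([[100., 100., 300., 200.]]),
        "labels": torch.tensor([1]),
        "area": torch.tensor([20000.]),
        "iscrowd": torch.tensor([0]),
    }
    img, target = transform(img, target)
    print(img.shape)        # a 640x480 input becomes torch.Size([3, 800, 1066])
    print(target["boxes"])  # normalized (cx, cy, w, h) boxes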