# augmentation.py
  1. import math
  2. import random
  3. import warnings
  4. import numpy as np
  5. import scipy.ndimage
  6. import torch
  7. from torch.autograd import Function
  8. from torch.autograd.function import once_differentiable
  9. import torch.backends.cudnn as cudnn
  10. from util.logconf import logging
  11. log = logging.getLogger(__name__)
  12. # log.setLevel(logging.WARN)
  13. # log.setLevel(logging.INFO)
  14. log.setLevel(logging.DEBUG)
  15. def cropToShape(image, new_shape, center_list=None, fill=0.0):
  16. # log.debug([image.shape, new_shape, center_list])
  17. # assert len(image.shape) == 3, repr(image.shape)
  18. if center_list is None:
  19. center_list = [int(image.shape[i] / 2) for i in range(3)]
  20. crop_list = []
  21. for i in range(0, 3):
  22. crop_int = center_list[i]
  23. if image.shape[i] > new_shape[i] and crop_int is not None:
  24. # We can't just do crop_int +/- shape/2 since shape might be odd
  25. # and ints round down.
  26. start_int = crop_int - int(new_shape[i]/2)
  27. end_int = start_int + new_shape[i]
  28. crop_list.append(slice(max(0, start_int), end_int))
  29. else:
  30. crop_list.append(slice(0, image.shape[i]))
  31. # log.debug([image.shape, crop_list])
  32. image = image[crop_list]
  33. crop_list = []
  34. for i in range(0, 3):
  35. if image.shape[i] < new_shape[i]:
  36. crop_int = int((new_shape[i] - image.shape[i]) / 2)
  37. crop_list.append(slice(crop_int, crop_int + image.shape[i]))
  38. else:
  39. crop_list.append(slice(0, image.shape[i]))
  40. # log.debug([image.shape, crop_list])
  41. new_image = np.zeros(new_shape, dtype=image.dtype)
  42. new_image[:] = fill
  43. new_image[crop_list] = image
  44. return new_image
  45. def zoomToShape(image, new_shape, square=True):
  46. # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
  47. if square and image.shape[0] != image.shape[1]:
  48. crop_int = min(image.shape[0], image.shape[1])
  49. new_shape = [crop_int, crop_int, image.shape[2]]
  50. image = cropToShape(image, new_shape)
  51. zoom_shape = [new_shape[i] / image.shape[i] for i in range(3)]
  52. with warnings.catch_warnings():
  53. warnings.simplefilter("ignore")
  54. image = scipy.ndimage.interpolation.zoom(
  55. image, zoom_shape,
  56. output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
  57. return image
  58. def randomOffset(image_list, offset_rows=0.125, offset_cols=0.125):
  59. center_list = [int(image_list[0].shape[i] / 2) for i in range(3)]
  60. center_list[0] += int(offset_rows * (random.random() - 0.5) * 2)
  61. center_list[1] += int(offset_cols * (random.random() - 0.5) * 2)
  62. center_list[2] = None
  63. new_list = []
  64. for image in image_list:
  65. new_image = cropToShape(image, image.shape, center_list)
  66. new_list.append(new_image)
  67. return new_list
  68. def randomZoom(image_list, scale=None, scale_min=0.8, scale_max=1.3):
  69. if scale is None:
  70. scale = scale_min + (scale_max - scale_min) * random.random()
  71. new_list = []
  72. for image in image_list:
  73. # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
  74. with warnings.catch_warnings():
  75. warnings.simplefilter("ignore")
  76. # log.info([image.shape])
  77. zimage = scipy.ndimage.interpolation.zoom(
  78. image, [scale, scale, 1.0],
  79. output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
  80. image = cropToShape(zimage, image.shape)
  81. new_list.append(image)
  82. return new_list
  83. _randomFlip_transform_list = [
  84. # lambda a: np.rot90(a, axes=(0, 1)),
  85. # lambda a: np.flip(a, 0),
  86. lambda a: np.flip(a, 1),
  87. ]
  88. def randomFlip(image_list, transform_bits=None):
  89. if transform_bits is None:
  90. transform_bits = random.randrange(0, 2 ** len(_randomFlip_transform_list))
  91. new_list = []
  92. for image in image_list:
  93. # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
  94. for n in range(len(_randomFlip_transform_list)):
  95. if transform_bits & 2**n:
  96. # prhist(image, 'before')
  97. image = _randomFlip_transform_list[n](image)
  98. # prhist(image, 'after ')
  99. new_list.append(image)
  100. return new_list
  101. def randomSpin(image_list, angle=None, range_tup=None, axes=(0, 1)):
  102. if range_tup is None:
  103. range_tup = (0, 360)
  104. if angle is None:
  105. angle = range_tup[0] + (range_tup[1] - range_tup[0]) * random.random()
  106. new_list = []
  107. for image in image_list:
  108. # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
  109. image = scipy.ndimage.interpolation.rotate(
  110. image, angle, axes=axes, reshape=False,
  111. output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
  112. new_list.append(image)
  113. return new_list
  114. def randomNoise(image_list, noise_min=-0.1, noise_max=0.1):
  115. noise = np.zeros_like(image_list[0])
  116. noise += (noise_max - noise_min) * np.random.random_sample(image_list[0].shape) + noise_min
  117. noise *= 5
  118. noise = scipy.ndimage.filters.gaussian_filter(noise, 3)
  119. # noise += (noise_max - noise_min) * np.random.random_sample(image_hsv.shape) + noise_min
  120. new_list = []
  121. for image_hsv in image_list:
  122. image_hsv = image_hsv + noise
  123. new_list.append(image_hsv)
  124. return new_list
  125. def randomHsvShift(image_list, h=None, s=None, v=None,
  126. h_min=-0.1, h_max=0.1,
  127. s_min=0.5, s_max=2.0,
  128. v_min=0.5, v_max=2.0):
  129. if h is None:
  130. h = h_min + (h_max - h_min) * random.random()
  131. if s is None:
  132. s = s_min + (s_max - s_min) * random.random()
  133. if v is None:
  134. v = v_min + (v_max - v_min) * random.random()
  135. new_list = []
  136. for image_hsv in image_list:
  137. # assert image_hsv.shape[-1] == 3, repr(image_hsv.shape)
  138. image_hsv[:,:,0::3] += h
  139. image_hsv[:,:,1::3] = image_hsv[:,:,1::3] ** s
  140. image_hsv[:,:,2::3] = image_hsv[:,:,2::3] ** v
  141. new_list.append(image_hsv)
  142. return clampHsv(new_list)
  143. def clampHsv(image_list):
  144. new_list = []
  145. for image_hsv in image_list:
  146. image_hsv = image_hsv.clone()
  147. # Hue wraps around
  148. image_hsv[:,:,0][image_hsv[:,:,0] > 1] -= 1
  149. image_hsv[:,:,0][image_hsv[:,:,0] < 0] += 1
  150. # Everything else clamps between 0 and 1
  151. image_hsv[image_hsv > 1] = 1
  152. image_hsv[image_hsv < 0] = 0
  153. new_list.append(image_hsv)
  154. return new_list
  155. # def torch_augment(input):
  156. # theta = random.random() * math.pi * 2
  157. # s = math.sin(theta)
  158. # c = math.cos(theta)
  159. # c1 = 1 - c
  160. # axis_vector = torch.rand(3, device='cpu', dtype=torch.float64)
  161. # axis_vector -= 0.5
  162. # axis_vector /= axis_vector.abs().sum()
  163. # l, m, n = axis_vector
  164. #
  165. # matrix = torch.tensor([
  166. # [l*l*c1 + c, m*l*c1 - n*s, n*l*c1 + m*s, 0],
  167. # [l*m*c1 + n*s, m*m*c1 + c, n*m*c1 - l*s, 0],
  168. # [l*n*c1 - m*s, m*n*c1 + l*s, n*n*c1 + c, 0],
  169. # [0, 0, 0, 1],
  170. # ], device=input.device, dtype=torch.float32)
  171. #
  172. # return th_affine3d(input, matrix)
  173. # following from https://github.com/ncullen93/torchsample/blob/master/torchsample/utils.py
  174. # MIT licensed
  175. # def th_affine3d(input, matrix):
  176. # """
  177. # 3D Affine image transform on torch.Tensor
  178. # """
  179. # A = matrix[:3,:3]
  180. # b = matrix[:3,3]
  181. #
  182. # # make a meshgrid of normal coordinates
  183. # coords = th_iterproduct(input.size(-3), input.size(-2), input.size(-1), dtype=torch.float32)
  184. #
  185. # # shift the coordinates so center is the origin
  186. # coords[:,0] = coords[:,0] - (input.size(-3) / 2. - 0.5)
  187. # coords[:,1] = coords[:,1] - (input.size(-2) / 2. - 0.5)
  188. # coords[:,2] = coords[:,2] - (input.size(-1) / 2. - 0.5)
  189. #
  190. # # apply the coordinate transformation
  191. # new_coords = coords.mm(A.t().contiguous()) + b.expand_as(coords)
  192. #
  193. # # shift the coordinates back so origin is origin
  194. # new_coords[:,0] = new_coords[:,0] + (input.size(-3) / 2. - 0.5)
  195. # new_coords[:,1] = new_coords[:,1] + (input.size(-2) / 2. - 0.5)
  196. # new_coords[:,2] = new_coords[:,2] + (input.size(-1) / 2. - 0.5)
  197. #
  198. # # map new coordinates using bilinear interpolation
  199. # input_transformed = th_trilinear_interp3d(input, new_coords)
  200. #
  201. # return input_transformed
  202. #
  203. #
  204. # def th_trilinear_interp3d(input, coords):
  205. # """
  206. # trilinear interpolation of 3D torch.Tensor image
  207. # """
  208. # # take clamp then floor/ceil of x coords
  209. # x = torch.clamp(coords[:,0], 0, input.size(-3)-2)
  210. # x0 = x.floor()
  211. # x1 = x0 + 1
  212. # # take clamp then floor/ceil of y coords
  213. # y = torch.clamp(coords[:,1], 0, input.size(-2)-2)
  214. # y0 = y.floor()
  215. # y1 = y0 + 1
  216. # # take clamp then floor/ceil of z coords
  217. # z = torch.clamp(coords[:,2], 0, input.size(-1)-2)
  218. # z0 = z.floor()
  219. # z1 = z0 + 1
  220. #
  221. # stride = torch.tensor(input.stride()[-3:], dtype=torch.int64, device=input.device)
  222. # x0_ix = x0.mul(stride[0]).long()
  223. # x1_ix = x1.mul(stride[0]).long()
  224. # y0_ix = y0.mul(stride[1]).long()
  225. # y1_ix = y1.mul(stride[1]).long()
  226. # z0_ix = z0.mul(stride[2]).long()
  227. # z1_ix = z1.mul(stride[2]).long()
  228. #
  229. # # input_flat = th_flatten(input)
  230. # input_flat = x.contiguous().view(x[0], x[1], -1)
  231. #
  232. # vals_000 = input_flat[:, :, x0_ix+y0_ix+z0_ix]
  233. # vals_001 = input_flat[:, :, x0_ix+y0_ix+z1_ix]
  234. # vals_010 = input_flat[:, :, x0_ix+y1_ix+z0_ix]
  235. # vals_011 = input_flat[:, :, x0_ix+y1_ix+z1_ix]
  236. # vals_100 = input_flat[:, :, x1_ix+y0_ix+z0_ix]
  237. # vals_101 = input_flat[:, :, x1_ix+y0_ix+z1_ix]
  238. # vals_110 = input_flat[:, :, x1_ix+y1_ix+z0_ix]
  239. # vals_111 = input_flat[:, :, x1_ix+y1_ix+z1_ix]
  240. #
  241. # xd = x - x0
  242. # yd = y - y0
  243. # zd = z - z0
  244. # xm1 = 1 - xd
  245. # ym1 = 1 - yd
  246. # zm1 = 1 - zd
  247. #
  248. # x_mapped = (
  249. # vals_000.mul(xm1).mul(ym1).mul(zm1) +
  250. # vals_001.mul(xm1).mul(ym1).mul(zd) +
  251. # vals_010.mul(xm1).mul(yd).mul(zm1) +
  252. # vals_011.mul(xm1).mul(yd).mul(zd) +
  253. # vals_100.mul(xd).mul(ym1).mul(zm1) +
  254. # vals_101.mul(xd).mul(ym1).mul(zd) +
  255. # vals_110.mul(xd).mul(yd).mul(zm1) +
  256. # vals_111.mul(xd).mul(yd).mul(zd)
  257. # )
  258. #
  259. # return x_mapped.view_as(input)
  260. #
  261. # def th_iterproduct(*args, dtype=None):
  262. # return torch.from_numpy(np.indices(args).reshape((len(args),-1)).T)
  263. #
  264. # def th_flatten(x):
  265. # """Flatten tensor"""
  266. # return x.contiguous().view(x[0], x[1], -1)