augmentation.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206
  1. import random
  2. import warnings
  3. import numpy as np
  4. import scipy.ndimage
  5. from util.logconf import logging
  6. log = logging.getLogger(__name__)
  7. # log.setLevel(logging.WARN)
  8. # log.setLevel(logging.INFO)
  9. log.setLevel(logging.DEBUG)
  10. def cropToShape(image, new_shape, center_list=None, fill=0.0):
  11. # log.debug([image.shape, new_shape, center_list])
  12. # assert len(image.shape) == 3, repr(image.shape)
  13. if center_list is None:
  14. center_list = [int(image.shape[i] / 2) for i in range(3)]
  15. crop_list = []
  16. for i in range(0, 3):
  17. crop_int = center_list[i]
  18. if image.shape[i] > new_shape[i] and crop_int is not None:
  19. # We can't just do crop_int +/- shape/2 since shape might be odd
  20. # and ints round down.
  21. start_int = crop_int - int(new_shape[i]/2)
  22. end_int = start_int + new_shape[i]
  23. crop_list.append(slice(max(0, start_int), end_int))
  24. else:
  25. crop_list.append(slice(0, image.shape[i]))
  26. # log.debug([image.shape, crop_list])
  27. image = image[crop_list]
  28. crop_list = []
  29. for i in range(0, 3):
  30. if image.shape[i] < new_shape[i]:
  31. crop_int = int((new_shape[i] - image.shape[i]) / 2)
  32. crop_list.append(slice(crop_int, crop_int + image.shape[i]))
  33. else:
  34. crop_list.append(slice(0, image.shape[i]))
  35. # log.debug([image.shape, crop_list])
  36. new_image = np.zeros(new_shape, dtype=image.dtype)
  37. new_image[:] = fill
  38. new_image[crop_list] = image
  39. return new_image
  40. def zoomToShape(image, new_shape, square=True):
  41. # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
  42. if square and image.shape[0] != image.shape[1]:
  43. crop_int = min(image.shape[0], image.shape[1])
  44. new_shape = [crop_int, crop_int, image.shape[2]]
  45. image = cropToShape(image, new_shape)
  46. zoom_shape = [new_shape[i] / image.shape[i] for i in range(3)]
  47. with warnings.catch_warnings():
  48. warnings.simplefilter("ignore")
  49. image = scipy.ndimage.interpolation.zoom(
  50. image, zoom_shape,
  51. output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
  52. return image
  53. def randomOffset(image_list, offset_rows=0.125, offset_cols=0.125):
  54. center_list = [int(image_list[0].shape[i] / 2) for i in range(3)]
  55. center_list[0] += int(offset_rows * (random.random() - 0.5) * 2)
  56. center_list[1] += int(offset_cols * (random.random() - 0.5) * 2)
  57. center_list[2] = None
  58. new_list = []
  59. for image in image_list:
  60. new_image = cropToShape(image, image.shape, center_list)
  61. new_list.append(new_image)
  62. return new_list
  63. def randomZoom(image_list, scale=None, scale_min=0.8, scale_max=1.3):
  64. if scale is None:
  65. scale = scale_min + (scale_max - scale_min) * random.random()
  66. new_list = []
  67. for image in image_list:
  68. # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
  69. with warnings.catch_warnings():
  70. warnings.simplefilter("ignore")
  71. # log.info([image.shape])
  72. zimage = scipy.ndimage.interpolation.zoom(
  73. image, [scale, scale, 1.0],
  74. output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
  75. image = cropToShape(zimage, image.shape)
  76. new_list.append(image)
  77. return new_list
  78. _randomFlip_transform_list = [
  79. # lambda a: np.rot90(a, axes=(0, 1)),
  80. # lambda a: np.flip(a, 0),
  81. lambda a: np.flip(a, 1),
  82. ]
  83. def randomFlip(image_list, transform_bits=None):
  84. if transform_bits is None:
  85. transform_bits = random.randrange(0, 2 ** len(_randomFlip_transform_list))
  86. new_list = []
  87. for image in image_list:
  88. # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
  89. for n in range(len(_randomFlip_transform_list)):
  90. if transform_bits & 2**n:
  91. # prhist(image, 'before')
  92. image = _randomFlip_transform_list[n](image)
  93. # prhist(image, 'after ')
  94. new_list.append(image)
  95. return new_list
  96. def randomSpin(image_list, angle=None, range_tup=None, axes=(0, 1)):
  97. if range_tup is None:
  98. range_tup = (0, 360)
  99. if angle is None:
  100. angle = range_tup[0] + (range_tup[1] - range_tup[0]) * random.random()
  101. new_list = []
  102. for image in image_list:
  103. # assert image.shape[-1] in {1, 3, 4}, repr(image.shape)
  104. image = scipy.ndimage.interpolation.rotate(
  105. image, angle, axes=axes, reshape=False,
  106. output=None, order=0, mode='nearest', cval=0.0, prefilter=True)
  107. new_list.append(image)
  108. return new_list
  109. def randomNoise(image_list, noise_min=-0.1, noise_max=0.1):
  110. noise = np.zeros_like(image_list[0])
  111. noise += (noise_max - noise_min) * np.random.random_sample(image_list[0].shape) + noise_min
  112. noise *= 5
  113. noise = scipy.ndimage.filters.gaussian_filter(noise, 3)
  114. # noise += (noise_max - noise_min) * np.random.random_sample(image_hsv.shape) + noise_min
  115. new_list = []
  116. for image_hsv in image_list:
  117. image_hsv = image_hsv + noise
  118. new_list.append(image_hsv)
  119. return new_list
  120. def randomHsvShift(image_list, h=None, s=None, v=None,
  121. h_min=-0.1, h_max=0.1,
  122. s_min=0.5, s_max=2.0,
  123. v_min=0.5, v_max=2.0):
  124. if h is None:
  125. h = h_min + (h_max - h_min) * random.random()
  126. if s is None:
  127. s = s_min + (s_max - s_min) * random.random()
  128. if v is None:
  129. v = v_min + (v_max - v_min) * random.random()
  130. new_list = []
  131. for image_hsv in image_list:
  132. # assert image_hsv.shape[-1] == 3, repr(image_hsv.shape)
  133. image_hsv[:,:,0::3] += h
  134. image_hsv[:,:,1::3] = image_hsv[:,:,1::3] ** s
  135. image_hsv[:,:,2::3] = image_hsv[:,:,2::3] ** v
  136. new_list.append(image_hsv)
  137. return clampHsv(new_list)
  138. def clampHsv(image_list):
  139. new_list = []
  140. for image_hsv in image_list:
  141. image_hsv = image_hsv.clone()
  142. # Hue wraps around
  143. image_hsv[:,:,0][image_hsv[:,:,0] > 1] -= 1
  144. image_hsv[:,:,0][image_hsv[:,:,0] < 0] += 1
  145. # Everything else clamps between 0 and 1
  146. image_hsv[image_hsv > 1] = 1
  147. image_hsv[image_hsv < 0] = 0
  148. new_list.append(image_hsv)
  149. return new_list