dsets.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. import csv
  2. import functools
  3. import glob
  4. import itertools
  5. import math
  6. import random
  7. import time
  8. import warnings
  9. import scipy.ndimage
  10. import SimpleITK as sitk
  11. import numpy as np
  12. import torch
  13. import torch.cuda
  14. from torch.utils.data import Dataset
  15. from util.disk import getCache
  16. from util.util import XyzTuple, xyz2irc
  17. from util.logconf import logging
  log = logging.getLogger(__name__)
  # Set the level once; switch to logging.INFO or logging.WARN to quiet this
  # module. (Calling setLevel(INFO) right before setLevel(DEBUG) had no effect.)
  log.setLevel(logging.DEBUG)

  # Memoization cache from util.disk.getCache (presumably disk-backed, going by
  # the module name -- confirm in util.disk); used via the @cache.memoize
  # decorators in this module.
  cache = getCache('part2')
  class Ct(object):
      """A single LUNA CT series, loaded from its ``.mhd`` file and converted from HU.

      Attributes:
          series_uid: the series identifier this volume was loaded for.
          ary: float32 voxel array (from ``sitk.GetArrayFromImage``), converted
              from HU and clamped to [0, 2].
          origin_xyz: image origin from the header, as an ``XyzTuple``.
          vxSize_xyz: voxel spacing from the header, as an ``XyzTuple``.
          direction_tup: the header's direction matrix with entries rounded to ints.
      """

      def __init__(self, series_uid):
          """Load the series ``series_uid`` from ``data/luna/subset*/``.

          Raises:
              IndexError: if no matching ``.mhd`` file exists on disk (the
                  result of indexing an empty glob result).
          """
          mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]
          ct_mhd = sitk.ReadImage(mhd_path)
          ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

          # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
          # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and
          # 1 g/cc (water) being 0. This converts HU to g/cc.
          ct_ary += 1000
          ct_ary /= 1000
          # Clamp in place to [0, 2]. The lower bound gets rid of the negative
          # density used to indicate out-of-FOV voxels. The upper bound nukes any
          # weird hotspots and clamps bone down.
          np.clip(ct_ary, 0, 2, out=ct_ary)

          self.series_uid = series_uid
          self.ary = ct_ary
          self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
          self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
          # Rounded to ints; presumably the direction cosines are 0/+-1 for these scans.
          self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())

      def getInputChunk(self, center_xyz, width_irc):
          """Crop a fixed-size voxel chunk from ``self.ary`` around ``center_xyz``.

          Args:
              center_xyz: chunk center in XYZ coordinates, as accepted by ``xyz2irc``.
              width_irc: chunk size in voxels for each array axis (index, row, col).

          Returns:
              ``(ct_chunk, center_irc)``: a basic-slicing view of ``self.ary`` and
              the center converted to voxel coordinates. If the window would run
              past an array edge, it is shifted inward. It then keeps its full
              width but is no longer centered (assuming the array is at least
              ``width_irc`` wide on that axis).
          """
          center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)

          slice_list = []
          for axis, center_val in enumerate(center_irc):
              start_ndx = int(round(center_val - width_irc[axis] / 2))
              end_ndx = int(start_ndx + width_irc[axis])

              assert 0 <= center_val < self.ary.shape[axis], repr(
                  [self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])

              # Slide the window back inside the volume instead of truncating it.
              if start_ndx < 0:
                  start_ndx = 0
                  end_ndx = int(width_irc[axis])

              if end_ndx > self.ary.shape[axis]:
                  end_ndx = self.ary.shape[axis]
                  start_ndx = int(self.ary.shape[axis] - width_irc[axis])

              slice_list.append(slice(start_ndx, end_ndx))

          # Multidimensional indexing needs a tuple. Indexing with a list of
          # slices was deprecated and is an error in NumPy >= 1.23.
          ct_chunk = self.ary[tuple(slice_list)]
          return ct_chunk, center_irc

      def getScaledInputChunk(self, center_xyz, width_mm, voxels_int):
          """Cut a ``width_mm`` cube around ``center_xyz`` and resample it to ``voxels_int`` per side.

          Any part of the cube outside the CT volume is zero-filled before the
          linear (``order=1``) resampling.

          Returns:
              ``(chunk_ary, center_irc)``: the resampled chunk and the center in
              voxel coordinates.
          """
          center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
          half_mm = width_mm / 2
          corner_lo = [int(round(i)) for i in xyz2irc(
              tuple(x - half_mm for x in center_xyz), self.origin_xyz, self.vxSize_xyz, self.direction_tup)]
          corner_hi = [int(round(i)) for i in xyz2irc(
              tuple(x + half_mm for x in center_xyz), self.origin_xyz, self.vxSize_xyz, self.direction_tup)]

          # A flipped direction cosine maps the low-XYZ corner to the higher voxel
          # index. Order each axis explicitly and make the end inclusive (+1)
          # afterwards. The old code added +1 before swapping, which dropped a
          # voxel at both ends of flipped axes.
          # NOTE: results already memoized by getScaledCtInputChunk keep the old
          # geometry until that cache is cleared.
          ct_start = [min(a, b) for a, b in zip(corner_lo, corner_hi)]
          ct_end = [max(a, b) + 1 for a, b in zip(corner_lo, corner_hi)]

          pad_start = [0, 0, 0]
          pad_end = [ct_end[axis] - ct_start[axis] for axis in range(3)]
          pad_ary = np.zeros(pad_end, dtype=np.float32)

          # Clip the source window to the CT volume. The destination window inside
          # the zero-filled pad array shrinks by the same amount.
          for axis in range(3):
              if ct_start[axis] < 0:
                  pad_start[axis] = -ct_start[axis]
                  ct_start[axis] = 0
              if ct_end[axis] > self.ary.shape[axis]:
                  pad_end[axis] -= ct_end[axis] - self.ary.shape[axis]
                  ct_end[axis] = self.ary.shape[axis]

          pad_slices = tuple(slice(s, e) for s, e in zip(pad_start, pad_end))
          ct_slices = tuple(slice(s, e) for s, e in zip(ct_start, ct_end))
          pad_ary[pad_slices] = self.ary[ct_slices]

          try:
              zoom_seq = tuple(voxels_int / pad_ary.shape[axis] for axis in range(3))
          except Exception:
              # Log the geometry that produced the failure, then propagate it.
              log.error([ct_end, ct_start, pad_end, center_irc, center_xyz, width_mm, self.vxSize_xyz])
              raise

          chunk_ary = scipy.ndimage.zoom(pad_ary, zoom_seq, order=1)
          return chunk_ary, center_irc
  @functools.lru_cache(maxsize=1, typed=True)
  def getCt(series_uid):
      """Load, or reuse, the ``Ct`` volume for ``series_uid``.

      The one-entry LRU cache makes back-to-back requests for the same series
      share a single ``Ct`` instance. Asking for a different series evicts it.
      """
      ct_volume = Ct(series_uid)
      return ct_volume
  @cache.memoize(typed=True)
  def getCtInputChunk(series_uid, center_xyz, width_irc):
      """Memoized wrapper around ``Ct.getInputChunk`` for the series ``series_uid``.

      Returns:
          ``(ct_chunk, center_irc)`` as produced by ``Ct.getInputChunk``.
      """
      return getCt(series_uid).getInputChunk(center_xyz, width_irc)
  @cache.memoize(typed=True)
  def getScaledCtInputChunk(series_uid, center_xyz, width_mm, voxels_int):
      """Memoized wrapper around ``Ct.getScaledInputChunk`` for the series ``series_uid``.

      Returns:
          ``(ct_chunk, center_irc)`` as produced by ``Ct.getScaledInputChunk``.
      """
      source_ct = getCt(series_uid)
      scaled_result = source_ct.getScaledInputChunk(center_xyz, width_mm, voxels_int)
      return scaled_result
  def augmentChunk_shift(ct_chunk):
      """Return ``ct_chunk`` plus zero-mean Gaussian noise (standard deviation 0.1).

      NOTE(review): despite the name, this has never shifted anything. The old
      body built a zero array, drew a random shift and built a slice list for
      axes 1 and 2, then discarded all of it and returned the noise-augmented
      chunk. That dead loop is removed. The returned value is unchanged, but the
      function no longer advances Python's ``random`` module state as a side
      effect. TODO: implement a real spatial shift or retire this function.
      """
      return ct_chunk + np.random.normal(scale=0.1, size=ct_chunk.shape)
  def augmentChunk_noise(ct_chunk):
      """Add zero-mean Gaussian noise with standard deviation 0.1 to every element."""
      noise = np.random.normal(loc=0.0, scale=0.1, size=ct_chunk.shape)
      return np.add(ct_chunk, noise)
  def augmentChunk_mirror(ct_chunk):
      """Flip the chunk along its last axis when ``random.random() > 0.5``; otherwise return it as-is."""
      flip_wanted = random.random() > 0.5
      return np.flip(ct_chunk, -1) if flip_wanted else ct_chunk
  def augmentChunk_rotate(ct_chunk):
      """Rotate the chunk by a uniformly random angle in [0, 360) degrees.

      The rotation is in the plane of the last two axes, i.e. around the
      head-foot axis per the original comment. The output keeps the input shape
      (``reshape=False``) and uses linear interpolation (``order=1``).
      """
      angle = 360 * random.random()
      # scipy.ndimage.interpolation is a deprecated namespace (SciPy >= 1.8,
      # slated for removal). scipy.ndimage.rotate is the public spelling of the
      # same function.
      return scipy.ndimage.rotate(
          ct_chunk,
          angle,
          axes=(-2, -1),
          reshape=False,
          order=1,
      )
  def augmentChunk_zoomAndCrop(ct_chunk, crop_size=16):
      """Zoom in by a random factor in [1.0, 1.2), then take a random cubic crop.

      The last three axes are treated as spatial. Any leading axes (such as a
      channel axis) are neither zoomed nor cropped.

      Args:
          ct_chunk: array whose last three axes are spatial. Each must be at
              least ``crop_size`` long after zooming.
          crop_size: edge length of the returned cube. Defaults to 16, the size
              this function always produced before it became a parameter.

      Returns:
          Array of shape ``ct_chunk.shape[:-3] + (crop_size,) * 3``.
      """
      n_lead = ct_chunk.ndim - 3
      zoom = 1.0 + 0.2 * random.random()
      # Zoom only the spatial axes. Zooming a leading channel axis too would
      # change the channel count once there is more than one channel
      # (e.g. round(3 * 1.2) == 4). For a single channel the result is identical.
      zoom_seq = (1.0,) * n_lead + (zoom,) * 3
      with warnings.catch_warnings():
          # Presumably silences SciPy's zoom() output-shape rounding warning
          # (emitted by older SciPy versions).
          warnings.simplefilter("ignore")
          # scipy.ndimage.zoom replaces the deprecated scipy.ndimage.interpolation.zoom.
          ct_chunk = scipy.ndimage.zoom(ct_chunk, zoom_seq, order=1)

      # One random start per spatial axis, drawn in axis order. randint is
      # inclusive, so start + crop_size never exceeds the axis length.
      crop_starts = [random.randint(0, ct_chunk.shape[axis] - crop_size)
                     for axis in range(n_lead, n_lead + 3)]
      # NumPy needs a tuple of slices; a list of slices is an error in NumPy >= 1.23.
      crop_slices = (slice(None),) * n_lead + tuple(
          slice(start, start + crop_size) for start in crop_starts)
      ct_chunk = ct_chunk[crop_slices]

      assert ct_chunk.shape[-3:] == (crop_size,) * 3, repr(ct_chunk.shape)
      return ct_chunk
  def augmentCtInputChunk(ct_chunk):
      """Apply mirror, rotate, noise, then zoom-and-crop augmentation, in that order.

      Each step receives the previous step's output; the final output is returned.
      """
      pipeline = (
          augmentChunk_mirror,
          augmentChunk_rotate,
          augmentChunk_noise,
          augmentChunk_zoomAndCrop,
      )
      return functools.reduce(lambda chunk, step: step(chunk), pipeline, ct_chunk)
  class LunaDataset(Dataset):
      """LUNA nodule candidates served as (CT chunk, malignancy label) samples.

      Candidates are read from ``data/luna/candidates.csv`` and restricted to
      series whose ``.mhd`` file is present under ``data/luna/subset*/``.
      """

      def __init__(self, test_stride=0, isTestSet_bool=None, series_uid=None,
                   balanced_bool=False,
                   scaled_bool=False,
                   augmented_bool=False,
                   ):
          """Build and split the sample list.

          Args:
              test_stride: when > 1, every ``test_stride``-th sample (after
                  sorting) goes to the test split and the rest to training.
              isTestSet_bool: truthy to keep the test split, else the training split.
              series_uid: when given, keep only candidates from that series.
              balanced_bool: alternate malignant/benign samples in ``__getitem__``.
              scaled_bool: use mm-scaled, resampled chunks instead of fixed
                  16x16x16 voxel crops.
              augmented_bool: apply random augmentation (only honoured when
                  ``scaled_bool`` is set).
          """
          import os.path  # local import: used only to parse the glob results below

          self.balanced_bool = balanced_bool
          self.scaled_bool = scaled_bool
          self.augmented_bool = augmented_bool

          # We construct a set with all series_uids that are present on disk.
          # This will let us use the data, even if we haven't downloaded all of
          # the subsets yet. os.path copes with the backslash separators glob
          # returns on Windows, which the old rsplit('/') did not.
          mhd_list = glob.glob('data/luna/subset*/*.mhd')
          present_set = {os.path.splitext(os.path.basename(p))[0] for p in mhd_list}

          sample_list = []
          # newline='' is how the csv module documents opening files for csv.reader.
          with open('data/luna/candidates.csv', "r", newline='') as f:
              csv_list = list(csv.reader(f))

          for row in csv_list[1:]:  # the first row is the header
              row_uid = row[0]
              if series_uid and series_uid != row_uid:
                  continue
              # If a row_uid isn't present, that means it's in a subset that we
              # don't have on disk, so we should skip it.
              if row_uid not in present_set:
                  continue
              center_xyz = tuple(float(x) for x in row[1:4])
              isMalignant_bool = bool(int(row[4]))
              sample_list.append((row_uid, center_xyz, isMalignant_bool))

          # Sorting makes the order (and therefore the stride split below)
          # independent of the CSV row order.
          sample_list.sort()
          if test_stride > 1:
              if isTestSet_bool:
                  sample_list = sample_list[::test_stride]
              else:
                  del sample_list[::test_stride]

          self.sample_list = sample_list
          self.benignIndex_list = [i for i, x in enumerate(sample_list) if not x[2]]
          self.malignantIndex_list = [i for i, x in enumerate(sample_list) if x[2]]
          self.shuffleSamples()

          log.info("{!r}: {} {} samples, {} ben, {} mal".format(
              self,
              len(sample_list),
              "testing" if isTestSet_bool else "training",
              len(self.benignIndex_list),
              len(self.malignantIndex_list),
          ))

      def shuffleSamples(self):
          """Reshuffle the per-class index lists in place (balanced mode only)."""
          if self.balanced_bool:
              log.warning("Shufflin'")
              random.shuffle(self.benignIndex_list)
              random.shuffle(self.malignantIndex_list)

      def __len__(self):
          """Return the number of samples per epoch.

          In balanced mode this is ``2 * 50 * min(#benign, #malignant)``. That is
          an enlarged epoch in which ``__getitem__`` cycles through each class
          list with a modulo.
          """
          if self.balanced_bool:
              return min(len(self.benignIndex_list), len(self.malignantIndex_list)) * 2 * 50
          else:
              return len(self.sample_list)

      def __getitem__(self, ndx):
          """Return ``(ct_tensor, malignant_tensor, series_uid, center_irc)`` for ``ndx``.

          ``ct_tensor`` is float32 with shape (1, 16, 16, 16): a channel axis
          followed by index, row and col.
          """
          if self.balanced_bool:
              # Even indices draw malignant samples, odd indices benign ones.
              if ndx % 2:
                  sample_ndx = self.benignIndex_list[(ndx // 2) % len(self.benignIndex_list)]
              else:
                  sample_ndx = self.malignantIndex_list[(ndx // 2) % len(self.malignantIndex_list)]
          else:
              sample_ndx = ndx

          series_uid, center_xyz, isMalignant_bool = self.sample_list[sample_ndx]

          if self.scaled_bool:
              # width_mm=12 resampled to voxels_int=20 per side.
              ct_chunk, center_irc = getScaledCtInputChunk(series_uid, center_xyz, 12, 20)
              # in: dim=3, Index x Row x Col
              # out: dim=4, Channel x Index x Row x Col
              ct_chunk = np.expand_dims(ct_chunk, 0)
              if self.augmented_bool:
                  ct_chunk = augmentCtInputChunk(ct_chunk)
              else:
                  # Trim 20^3 to the central 16^3, the same size the augmented path crops to.
                  ct_chunk = ct_chunk[:, 2:-2, 2:-2, 2:-2]
          else:
              ct_chunk, center_irc = getCtInputChunk(series_uid, center_xyz, (16, 16, 16))
              ct_chunk = np.expand_dims(ct_chunk, 0)

          assert ct_chunk.shape[-3:] == (16, 16, 16), repr(ct_chunk.shape)
          ct_tensor = torch.from_numpy(np.array(ct_chunk, dtype=np.float32))

          # NOTE(review): the array below has shape (1,), and unsqueeze(0) makes
          # it (1, 1). Confirm that this is what the training loss expects.
          malignant_tensor = torch.from_numpy(np.array([isMalignant_bool], dtype=np.float32))
          malignant_tensor = malignant_tensor.unsqueeze(0)

          # Unpacked as: input_tensor, answer_int, series_uid, center_irc
          return ct_tensor, malignant_tensor, series_uid, center_irc