dsets.py

import csv
import functools
import glob
import itertools
import math
import random
import time

import scipy.ndimage
import SimpleITK as sitk

import numpy as np
import torch
import torch.cuda
from torch.utils.data import Dataset

from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# cache = getCache('p2ch3')
cache = getCache('part2')


class Ct(object):
    def __init__(self, series_uid):
        mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and
        # 1 g/cc (water) being 0. This converts HU to g/cc.
        ct_ary += 1000
        ct_ary /= 1000

        # This gets rid of negative density stuff used to indicate out-of-FOV
        ct_ary[ct_ary < 0] = 0

        # This nukes any weird hotspots and clamps bone down
        ct_ary[ct_ary > 2] = 2

        self.series_uid = series_uid
        self.ary = ct_ary

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())

    def getInputChunk(self, center_xyz, width_irc):
        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
                             self.direction_tup)

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis]/2))
            end_ndx = int(start_ndx + width_irc[axis])

            assert center_val >= 0 and center_val < self.ary.shape[axis], \
                repr([self.series_uid, center_xyz, self.origin_xyz,
                      self.vxSize_xyz, center_irc, axis])

            if start_ndx < 0:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > self.ary.shape[axis]:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
                end_ndx = self.ary.shape[axis]
                start_ndx = int(self.ary.shape[axis] - width_irc[axis])

            slice_list.append(slice(start_ndx, end_ndx))

        # Index with a tuple of slices; NumPy no longer accepts a list here.
        ct_chunk = self.ary[tuple(slice_list)]

        return ct_chunk, center_irc
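
# Illustrative only (not part of the pipeline above): cropping a fixed-size voxel
# chunk around a candidate center given in patient-space (X, Y, Z) millimeters.
# The series UID, coordinates, and crop size below are hypothetical placeholders.
#
#   ct = Ct('1.3.6.1.4.1.14519.5.2.1.6279.6001.example')
#   chunk, center_irc = ct.getInputChunk(
#       center_xyz=(-100.0, 50.0, -200.0),  # (X, Y, Z) in mm
#       width_irc=(32, 48, 48),             # (index, row, col) in voxels
#   )
#   # chunk.shape == (32, 48, 48); center_irc is the same center in voxel coordinates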


@functools.lru_cache(1, typed=True)
def getCt(series_uid):
    return Ct(series_uid)


@cache.memoize(typed=True)
def getCtInputChunk(series_uid, center_xyz, width_irc):
    ct = getCt(series_uid)
    ct_chunk, center_irc = ct.getInputChunk(center_xyz, width_irc)
    return ct_chunk, center_irc
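
# Two caching layers sit between LunaDataset and the raw .mhd files: getCt keeps
# only the most recently used Ct volume in memory (lru_cache(1)), while
# getCtInputChunk memoizes the much smaller cropped chunks through the cache
# created by getCache above, so repeated passes over the candidates avoid
# re-reading and re-parsing whole CT volumes.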


class LunaDataset(Dataset):
    def __init__(self, test_stride=0, isTestSet_bool=None, series_uid=None,
                 balanced_bool=False,
                 ):
        self.balanced_bool = balanced_bool

        # We construct a set with all series_uids that are present on disk.
        # This will let us use the data, even if we haven't downloaded all of
        # the subsets yet.
        mhd_list = glob.glob('data/luna/subset*/*.mhd')
        present_set = {p.rsplit('/', 1)[-1][:-4] for p in mhd_list}

        sample_list = []
        with open('data/luna/candidates.csv', "r") as f:
            csv_list = list(csv.reader(f))

        for row in csv_list[1:]:
            row_uid = row[0]

            if series_uid and series_uid != row_uid:
                continue

            # If a row_uid isn't present, that means it's in a subset that we
            # don't have on disk, so we should skip it.
            if row_uid not in present_set:
                continue

            center_xyz = tuple([float(x) for x in row[1:4]])
            isMalignant_bool = bool(int(row[4]))

            sample_list.append((row_uid, center_xyz, isMalignant_bool))

        sample_list.sort()

        if test_stride > 1:
            if isTestSet_bool:
                sample_list = sample_list[::test_stride]
            else:
                del sample_list[::test_stride]

        self.sample_list = sample_list
        self.benignIndex_list = [i for i, x in enumerate(sample_list) if not x[2]]
        self.malignantIndex_list = [i for i, x in enumerate(sample_list) if x[2]]

        self.shuffleSamples()

        log.info("{!r}: {} {} samples, {} ben, {} mal".format(
            self,
            len(sample_list),
            "testing" if isTestSet_bool else "training",
            len(self.benignIndex_list),
            len(self.malignantIndex_list),
        ))

    def shuffleSamples(self):
        if self.balanced_bool:
            log.warning("Shufflin'")
            random.shuffle(self.benignIndex_list)
            random.shuffle(self.malignantIndex_list)

    def __len__(self):
        if self.balanced_bool:
            return min(len(self.benignIndex_list),
                       len(self.malignantIndex_list)) * 2 * 50
        else:
            return len(self.sample_list)
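
    # When balanced_bool is set, the epoch length above is decoupled from the true
    # candidate count: __getitem__ alternates malignant and benign samples, so the
    # "* 2 * 50" stretches each epoch to 100x the size of the smaller class.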

    def __getitem__(self, ndx):
        if self.balanced_bool:
            if ndx % 2:
                sample_ndx = self.benignIndex_list[(ndx // 2) % len(self.benignIndex_list)]
            else:
                sample_ndx = self.malignantIndex_list[(ndx // 2) % len(self.malignantIndex_list)]
        else:
            sample_ndx = ndx

        series_uid, center_xyz, isMalignant_bool = self.sample_list[sample_ndx]

        ct_chunk, center_irc = getCtInputChunk(series_uid, center_xyz, (16, 16, 16))

        # dim=3, Index x Row x Col
        ct_tensor = torch.from_numpy(np.array(ct_chunk, dtype=np.float32))

        # dim=1
        malignant_tensor = torch.from_numpy(np.array([isMalignant_bool], dtype=np.float32))

        # dim=4, Channel x Index x Row x Col
        ct_tensor = ct_tensor.unsqueeze(0)
        malignant_tensor = malignant_tensor.unsqueeze(0)

        # Unpacked as: input_tensor, answer_int, series_uid, center_irc
        return ct_tensor, malignant_tensor, series_uid, center_irc
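

# A minimal usage sketch (illustrative; assumes the LUNA files are laid out under
# data/luna/ as the paths above expect): wrap the dataset in a standard
# torch.utils.data.DataLoader and pull a single batch. The stride and batch size
# are arbitrary example values.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    train_ds = LunaDataset(test_stride=10, isTestSet_bool=False, balanced_bool=True)
    train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)

    ct_t, malignant_t, series_uid_list, center_irc = next(iter(train_dl))
    # ct_t is (N, 1, 16, 16, 16); malignant_t is (N, 1)
    log.info("batch shapes: ct {} malignant {}".format(
        tuple(ct_t.shape), tuple(malignant_t.shape)))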