dsets.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. import csv
  2. import functools
  3. import glob
  4. import math
  5. import time
  6. import SimpleITK as sitk
  7. import numpy as np
  8. import torch
  9. import torch.cuda
  10. from torch.utils.data import Dataset
  11. from util.disk import getCache
  12. from util.util import XyzTuple, xyz2irc
  13. from util.logconf import logging
  14. log = logging.getLogger(__name__)
  15. # log.setLevel(logging.WARN)
  16. log.setLevel(logging.INFO)
  17. log.setLevel(logging.DEBUG)
  18. cache = getCache('p2ch1')
  19. class Ct(object):
  20. def __init__(self, series_uid):
  21. mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]
  22. ct_mhd = sitk.ReadImage(mhd_path)
  23. ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
  24. # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
  25. # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
  26. # This converts HU to g/cc.
  27. ct_ary += 1000
  28. ct_ary /= 1000
  29. # This gets rid of negative density stuff used to indicate out-of-FOV
  30. ct_ary[ct_ary < 0] = 0
  31. # This nukes any weird hotspots and clamps bone down
  32. ct_ary[ct_ary > 2] = 2
  33. self.series_uid = series_uid
  34. self.ary = ct_ary
  35. self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
  36. self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
  37. self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
  38. def getInputChunk(self, center_xyz, width_irc):
  39. center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
  40. slice_list = []
  41. for axis, center_val in enumerate(center_irc):
  42. start_ndx = int(round(center_val - width_irc[axis]/2))
  43. end_ndx = int(start_ndx + width_irc[axis])
  44. assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
  45. if start_ndx < 0:
  46. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  47. # self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
  48. start_ndx = 0
  49. end_ndx = int(width_irc[axis])
  50. if end_ndx > self.ary.shape[axis]:
  51. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  52. # self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
  53. end_ndx = self.ary.shape[axis]
  54. start_ndx = int(self.ary.shape[axis] - width_irc[axis])
  55. slice_list.append(slice(start_ndx, end_ndx))
  56. ct_chunk = self.ary[slice_list]
  57. return ct_chunk, center_irc
  58. @functools.lru_cache(1, typed=True)
  59. def getCt(series_uid):
  60. return Ct(series_uid)
  61. @cache.memoize(typed=True)
  62. def getCtInputChunk(series_uid, center_xyz, width_irc):
  63. ct = getCt(series_uid)
  64. ct_chunk, center_irc = ct.getInputChunk(center_xyz, width_irc)
  65. return ct_chunk, center_irc
  66. class LunaDataset(Dataset):
  67. def __init__(self, test_stride=0, isTestSet_bool=None, series_uid=None):
  68. # We construct a set with all series_uids that are present on disk.
  69. # This will let us use the data, even if we haven't downloaded all of
  70. # the subsets yet.
  71. mhd_list = glob.glob('data/luna/subset*/*.mhd')
  72. present_set = {p.rsplit('/', 1)[-1][:-4] for p in mhd_list}
  73. sample_list = []
  74. with open('data/luna/candidates.csv', "r") as f:
  75. csv_list = list(csv.reader(f))
  76. for row in csv_list[1:]:
  77. row_uid = row[0]
  78. if series_uid and series_uid != row_uid:
  79. continue
  80. # If a row_uid isn't present, that means it's in a subset that we
  81. # don't have on disk, so we should skip it.
  82. if row_uid not in present_set:
  83. continue
  84. center_xyz = tuple([float(x) for x in row[1:4]])
  85. isMalignant_bool = bool(int(row[4]))
  86. sample_list.append((row_uid, center_xyz, isMalignant_bool))
  87. sample_list.sort()
  88. if test_stride > 1:
  89. if isTestSet_bool:
  90. sample_list = sample_list[::test_stride]
  91. else:
  92. del sample_list[::test_stride]
  93. log.info("{!r}: {} {} samples".format(self, len(sample_list), "testing" if isTestSet_bool else "training"))
  94. self.sample_list = sample_list
  95. def __len__(self):
  96. return len(self.sample_list)
  97. def __getitem__(self, ndx):
  98. series_uid, center_xyz, isMalignant_bool = self.sample_list[ndx]
  99. ct_chunk, center_irc = getCtInputChunk(series_uid, center_xyz, (16, 16, 16))
  100. # dim=3, Index x Row x Col
  101. ct_tensor = torch.from_numpy(np.array(ct_chunk, dtype=np.float32))
  102. # dim=1
  103. malignant_tensor = torch.from_numpy(np.array([isMalignant_bool], dtype=np.float32))
  104. # dim=4, Channel x Index x Row x Col
  105. ct_tensor = ct_tensor.unsqueeze(0)
  106. malignant_tensor = malignant_tensor.unsqueeze(0)
  107. # Unpacked as: input_tensor, answer_int, series_uid, center_irc
  108. return ct_tensor, malignant_tensor, series_uid, center_irc