dsets.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import csv
  2. import functools
  3. import glob
  4. import math
  5. import time
  6. import SimpleITK as sitk
  7. import numpy as np
  8. import torch
  9. import torch.cuda
  10. from torch.utils.data import Dataset
  11. from util.disk import getCache
  12. from util.util import XyzTuple, xyz2irc
  13. from util.logconf import logging
  14. log = logging.getLogger(__name__)
  15. # log.setLevel(logging.WARN)
  16. log.setLevel(logging.INFO)
  17. log.setLevel(logging.DEBUG)
  18. # cache = getCache('p2ch2')
  19. cache = getCache('part2')
  20. class Ct(object):
  21. def __init__(self, series_uid):
  22. mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]
  23. ct_mhd = sitk.ReadImage(mhd_path)
  24. ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
  25. # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
  26. # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
  27. # This converts HU to g/cc.
  28. ct_ary += 1000
  29. ct_ary /= 1000
  30. # This gets rid of negative density stuff used to indicate out-of-FOV
  31. ct_ary[ct_ary < 0] = 0
  32. # This nukes any weird hotspots and clamps bone down
  33. ct_ary[ct_ary > 2] = 2
  34. self.series_uid = series_uid
  35. self.ary = ct_ary
  36. self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
  37. self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
  38. self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
  39. def getInputChunk(self, center_xyz, width_irc):
  40. center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
  41. slice_list = []
  42. for axis, center_val in enumerate(center_irc):
  43. start_ndx = int(round(center_val - width_irc[axis]/2))
  44. end_ndx = int(start_ndx + width_irc[axis])
  45. assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
  46. if start_ndx < 0:
  47. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  48. # self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
  49. start_ndx = 0
  50. end_ndx = int(width_irc[axis])
  51. if end_ndx > self.ary.shape[axis]:
  52. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  53. # self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
  54. end_ndx = self.ary.shape[axis]
  55. start_ndx = int(self.ary.shape[axis] - width_irc[axis])
  56. slice_list.append(slice(start_ndx, end_ndx))
  57. ct_chunk = self.ary[slice_list]
  58. return ct_chunk, center_irc
  59. @functools.lru_cache(1, typed=True)
  60. def getCt(series_uid):
  61. return Ct(series_uid)
  62. @cache.memoize(typed=True)
  63. def getCtInputChunk(series_uid, center_xyz, width_irc):
  64. ct = getCt(series_uid)
  65. ct_chunk, center_irc = ct.getInputChunk(center_xyz, width_irc)
  66. return ct_chunk, center_irc
  67. class LunaDataset(Dataset):
  68. def __init__(self, test_stride=0, isTestSet_bool=None, series_uid=None):
  69. # We construct a set with all series_uids that are present on disk.
  70. # This will let us use the data, even if we haven't downloaded all of
  71. # the subsets yet.
  72. mhd_list = glob.glob('data/luna/subset*/*.mhd')
  73. present_set = {p.rsplit('/', 1)[-1][:-4] for p in mhd_list}
  74. sample_list = []
  75. with open('data/luna/candidates.csv', "r") as f:
  76. csv_list = list(csv.reader(f))
  77. for row in csv_list[1:]:
  78. row_uid = row[0]
  79. if series_uid and series_uid != row_uid:
  80. continue
  81. # If a row_uid isn't present, that means it's in a subset that we
  82. # don't have on disk, so we should skip it.
  83. if row_uid not in present_set:
  84. continue
  85. center_xyz = tuple([float(x) for x in row[1:4]])
  86. isMalignant_bool = bool(int(row[4]))
  87. sample_list.append((row_uid, center_xyz, isMalignant_bool))
  88. sample_list.sort()
  89. if test_stride > 1:
  90. if isTestSet_bool:
  91. sample_list = sample_list[::test_stride]
  92. else:
  93. del sample_list[::test_stride]
  94. log.info("{!r}: {} {} samples".format(self, len(sample_list), "testing" if isTestSet_bool else "training"))
  95. self.sample_list = sample_list
  96. def __len__(self):
  97. return len(self.sample_list)
  98. def __getitem__(self, ndx):
  99. series_uid, center_xyz, isMalignant_bool = self.sample_list[ndx]
  100. ct_chunk, center_irc = getCtInputChunk(series_uid, center_xyz, (16, 16, 16))
  101. # dim=3, Index x Row x Col
  102. ct_tensor = torch.from_numpy(np.array(ct_chunk, dtype=np.float32))
  103. # dim=1
  104. malignant_tensor = torch.from_numpy(np.array([isMalignant_bool], dtype=np.float32))
  105. # dim=4, Channel x Index x Row x Col
  106. ct_tensor = ct_tensor.unsqueeze(0)
  107. malignant_tensor = malignant_tensor.unsqueeze(0)
  108. # Unpacked as: input_tensor, answer_int, series_uid, center_irc
  109. return ct_tensor, malignant_tensor, series_uid, center_irc