dsets.py

import copy
import csv
import functools
import glob
import os
import random
from collections import namedtuple

import SimpleITK as sitk
import numpy as np

import torch
import torch.cuda
from torch.utils.data import Dataset

from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

raw_cache = getCache('part2ch11_raw')

CandidateInfoTuple = namedtuple(
    'CandidateInfoTuple',
    'isNodule_bool, diameter_mm, series_uid, center_xyz',
)
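
# Field order matters here: getCandidateInfoList() sorts these tuples with
# reverse=True, so actual nodules sort ahead of non-nodules, largest diameter
# first. A purely illustrative instance (placeholder values, not real data):
#
#   CandidateInfoTuple(
#       isNodule_bool=True,
#       diameter_mm=6.4,
#       series_uid='<series_uid from the LUNA metadata>',
#       center_xyz=(-100.0, 50.0, -200.0),
#   )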


@functools.lru_cache(1)
def getCandidateInfoList(requireOnDisk_bool=True):
    # We construct a set with all series_uids that are present on disk.
    # This will let us use the data, even if we haven't downloaded all of
    # the subsets yet.
    mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

    diameter_dict = {}
    with open('data/part2/luna/annotations.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]
            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
            annotationDiameter_mm = float(row[4])

            diameter_dict.setdefault(series_uid, []).append(
                (annotationCenter_xyz, annotationDiameter_mm),
            )

    candidateInfo_list = []
    with open('data/part2/luna/candidates.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]

            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                continue

            isNodule_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])

            candidateDiameter_mm = 0.0
            for annotation_tup in diameter_dict.get(series_uid, []):
                annotationCenter_xyz, annotationDiameter_mm = annotation_tup
                for i in range(3):
                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                    if delta_mm > annotationDiameter_mm / 4:
                        break
                else:
                    candidateDiameter_mm = annotationDiameter_mm
                    break

            candidateInfo_list.append(CandidateInfoTuple(
                isNodule_bool,
                candidateDiameter_mm,
                series_uid,
                candidateCenter_xyz,
            ))

    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list
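
# A minimal usage sketch (illustrative only, not part of the original module):
# a candidate keeps diameter_mm == 0.0 unless its center lies within a quarter
# of an annotated nodule's diameter along every axis, so the label/diameter
# split can be inspected like this:
#
#   candidateInfo_list = getCandidateInfoList(requireOnDisk_bool=True)
#   positive_list = [c for c in candidateInfo_list if c.isNodule_bool]
#   log.info("{} candidates, {} flagged as nodules".format(
#       len(candidateInfo_list), len(positive_list)))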


class Ct:
    def __init__(self, series_uid):
        mhd_path = glob.glob(
            'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
        )[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
        # The lower bound gets rid of negative density stuff used to indicate out-of-FOV
        # The upper bound nukes any weird hotspots and clamps bone down
        ct_a.clip(-1000, 1000, ct_a)

        self.series_uid = series_uid
        self.hu_a = ct_a

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)

    def getRawCandidate(self, center_xyz, width_irc):
        center_irc = xyz2irc(
            center_xyz,
            self.origin_xyz,
            self.vxSize_xyz,
            self.direction_a,
        )

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis]/2))
            end_ndx = int(start_ndx + width_irc[axis])

            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])

            if start_ndx < 0:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > self.hu_a.shape[axis]:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                end_ndx = self.hu_a.shape[axis]
                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])

            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.hu_a[tuple(slice_list)]

        return ct_chunk, center_irc


@functools.lru_cache(1, typed=True)
def getCt(series_uid):
    return Ct(series_uid)


@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    ct = getCt(series_uid)
    ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
    return ct_chunk, center_irc
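
# A minimal cache-priming sketch, assuming the LUNA files referenced above are
# on disk. getCt keeps the most recent Ct in memory while raw_cache.memoize
# persists each cropped chunk on disk, so a loop like the one below (illustrative
# only, not part of the original module) makes later epochs read from the cache:
#
#   for candidateInfo_tup in getCandidateInfoList():
#       getCtRawCandidate(
#           candidateInfo_tup.series_uid,
#           candidateInfo_tup.center_xyz,
#           (32, 48, 48),
#       )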


class LunaDataset(Dataset):
    def __init__(self,
                 val_stride=0,
                 isValSet_bool=None,
                 series_uid=None,
                 sortby_str='random',
                 ):
        self.candidateInfo_list = copy.copy(getCandidateInfoList())

        if series_uid:
            self.candidateInfo_list = [
                x for x in self.candidateInfo_list if x.series_uid == series_uid
            ]

        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.candidateInfo_list = self.candidateInfo_list[::val_stride]
            assert self.candidateInfo_list
        elif val_stride > 0:
            del self.candidateInfo_list[::val_stride]
            assert self.candidateInfo_list

        if sortby_str == 'random':
            random.shuffle(self.candidateInfo_list)
        elif sortby_str == 'series_uid':
            self.candidateInfo_list.sort(key=lambda x: (x.series_uid, x.center_xyz))
        elif sortby_str == 'label_and_size':
            pass
        else:
            raise Exception("Unknown sort: " + repr(sortby_str))

        log.info("{!r}: {} {} samples".format(
            self,
            len(self.candidateInfo_list),
            "validation" if isValSet_bool else "training",
        ))

    def __len__(self):
        return len(self.candidateInfo_list)

    def __getitem__(self, ndx):
        candidateInfo_tup = self.candidateInfo_list[ndx]
        width_irc = (32, 48, 48)

        candidate_a, center_irc = getCtRawCandidate(
            candidateInfo_tup.series_uid,
            candidateInfo_tup.center_xyz,
            width_irc,
        )

        candidate_t = torch.from_numpy(candidate_a).to(torch.float32)
        candidate_t = candidate_t.unsqueeze(0)

        pos_t = torch.tensor([
                not candidateInfo_tup.isNodule_bool,
                candidateInfo_tup.isNodule_bool,
            ],
            dtype=torch.long,
        )

        return candidate_t, pos_t, candidateInfo_tup.series_uid, torch.tensor(center_irc)
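

if __name__ == '__main__':
    # Minimal smoke-test sketch; not part of the original module. It assumes the
    # script is run from the repository root (so the util package is importable)
    # and that at least one LUNA subset is unpacked under
    # data-unversioned/part2/luna/.
    from torch.utils.data import DataLoader

    smoke_ds = LunaDataset(val_stride=10, isValSet_bool=False)
    smoke_dl = DataLoader(smoke_ds, batch_size=4, shuffle=True)

    # One batch: candidate_t is (4, 1, 32, 48, 48), pos_t is (4, 2),
    # series_uid_list is a list of strings, center_irc_t is (4, 3).
    candidate_t, pos_t, series_uid_list, center_irc_t = next(iter(smoke_dl))
    log.debug("candidate batch: {}, labels: {}".format(
        tuple(candidate_t.shape), tuple(pos_t.shape)))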