dsets.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. import copy
  2. import csv
  3. import functools
  4. import glob
  5. import math
  6. import os
  7. import random
  8. from collections import namedtuple
  9. import SimpleITK as sitk
  10. import numpy as np
  11. import torch
  12. import torch.cuda
  13. import torch.nn.functional as F
  14. from torch.utils.data import Dataset
  15. from util.disk import getCache
  16. from util.util import XyzTuple, xyz2irc
  17. from util.logconf import logging
  log = logging.getLogger(__name__)
  # Module verbosity; lower to logging.INFO or logging.WARN for quieter runs.
  log.setLevel(logging.DEBUG)

  # Cache object from util.disk; getCtRawNodule is memoized through it.
  raw_cache = getCache('part2ch12_raw')

  # One candidate row from candidates.csv plus the matched annotation diameter.
  NoduleInfoTuple = namedtuple(
      'NoduleInfoTuple',
      ['isMalignant_bool', 'diameter_mm', 'series_uid', 'center_xyz'],
  )
  @functools.lru_cache(1)
  def getNoduleInfoList(requireDataOnDisk_bool=True):
      """Build the list of nodule candidates from the LUNA CSV files.

      Every data row of ``data/part2/luna/candidates.csv`` becomes a
      ``NoduleInfoTuple``. Its diameter comes from the first row of
      ``data/part2/luna/annotations.csv`` for the same series whose center lies
      within a quarter of that annotation's diameter on every axis. Candidates
      with no such annotation get a diameter of 0.0.

      Args:
          requireDataOnDisk_bool: when true, drop candidates whose series has no
              ``.mhd`` file under ``data-unversioned/part2/luna/subset*/``. This
              lets the data be used before every subset has been downloaded.

      Returns:
          A list of ``NoduleInfoTuple`` sorted in descending tuple order. Entries
          with ``isMalignant_bool`` true come first, followed by larger diameters.
          The result is memoized by ``lru_cache``, so repeated calls with the same
          argument return the same list object. Callers must copy it before
          mutating it.
      """
      # Series UIDs that have a CT on disk: the file name minus its ".mhd" suffix.
      mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
      dataPresentOnDisk_set = {
          os.path.splitext(os.path.basename(p))[0] for p in mhd_list
      }

      # series_uid -> [(annotationCenter_xyz, annotationDiameter_mm), ...]
      diameter_dict = {}
      # newline='' is what the csv module requires for correct line handling.
      with open('data/part2/luna/annotations.csv', "r", newline='') as f:
          reader = csv.reader(f)
          next(reader, None)  # skip the header row
          for row in reader:
              series_uid = row[0]
              annotationCenter_xyz = tuple(float(x) for x in row[1:4])
              annotationDiameter_mm = float(row[4])
              diameter_dict.setdefault(series_uid, []).append(
                  (annotationCenter_xyz, annotationDiameter_mm)
              )

      noduleInfo_list = []
      with open('data/part2/luna/candidates.csv', "r", newline='') as f:
          reader = csv.reader(f)
          next(reader, None)  # skip the header row
          for row in reader:
              series_uid = row[0]
              if requireDataOnDisk_bool and series_uid not in dataPresentOnDisk_set:
                  continue

              isMalignant_bool = bool(int(row[4]))
              candidateCenter_xyz = tuple(float(x) for x in row[1:4])

              candidateDiameter_mm = 0.0
              for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
                  # Match when no axis is farther than a quarter diameter away.
                  if not any(
                      abs(candidate_mm - annotation_mm) > annotationDiameter_mm / 4
                      for candidate_mm, annotation_mm in zip(candidateCenter_xyz, annotationCenter_xyz)
                  ):
                      candidateDiameter_mm = annotationDiameter_mm
                      break

              noduleInfo_list.append(NoduleInfoTuple(
                  isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz,
              ))

      noduleInfo_list.sort(reverse=True)
      return noduleInfo_list
  class Ct(object):
      """One LUNA CT scan, loaded from its ``.mhd`` file with HU values clamped.

      Attributes (set in ``__init__``):
          series_uid: the series UID the scan was loaded for.
          hu_a: float32 voxel array from ``sitk.GetArrayFromImage``, clipped to
              the range [-1000, 1000].
          origin_xyz: image origin as an ``XyzTuple``.
          vxSize_xyz: voxel spacing as an ``XyzTuple``.
          direction_tup: direction-matrix entries rounded to ints.
      """

      def __init__(self, series_uid):
          # Raises IndexError when no subset directory contains this series.
          mhd_path = glob.glob(
              'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
          )[0]

          ct_mhd = sitk.ReadImage(mhd_path)
          ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

          # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
          # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
          # The lower bound removes the negative-density values used to indicate
          # out-of-FOV voxels. The upper bound nukes weird hotspots and clamps bone
          # down. The clip is done in place.
          np.clip(ct_a, -1000, 1000, out=ct_a)

          self.series_uid = series_uid
          self.hu_a = ct_a

          self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
          self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
          self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())

      def getRawNodule(self, center_xyz, width_irc):
          """Crop a chunk of ``hu_a`` around a point given in xyz coordinates.

          Args:
              center_xyz: crop center. It is converted to array indices with
                  ``xyz2irc``, using this scan's origin, spacing and direction.
              width_irc: requested crop size per array axis, in voxels.

          Returns:
              ``(ct_chunk, center_irc)``. ``ct_chunk`` is the sliced region of
              ``hu_a``; ``center_irc`` is the center in array-index coordinates.
              A crop that would cross an array edge is shifted back inside the
              array instead of being shrunk. It is smaller than ``width_irc`` only
              on axes shorter than the requested width.

          Raises:
              AssertionError: if the converted center lies outside the array.
          """
          center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)

          slice_list = []
          for axis, center_val in enumerate(center_irc):
              start_ndx = int(round(center_val - width_irc[axis] / 2))
              end_ndx = int(start_ndx + width_irc[axis])

              axis_len = self.hu_a.shape[axis]
              assert center_val >= 0 and center_val < axis_len, repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])

              if start_ndx < 0:
                  # Crop runs off the low edge: shift it so it starts at index 0.
                  start_ndx = 0
                  end_ndx = int(width_irc[axis])

              if end_ndx > axis_len:
                  # Crop runs off the high edge: shift it so it ends at the last index.
                  # Clamp the start at 0. When the crop is wider than the axis, a
                  # negative start would be read as an offset from the end and
                  # silently drop voxels from the front of the axis.
                  end_ndx = axis_len
                  start_ndx = max(0, int(axis_len - width_irc[axis]))

              slice_list.append(slice(start_ndx, end_ndx))

          ct_chunk = self.hu_a[tuple(slice_list)]
          return ct_chunk, center_irc
  @functools.lru_cache(maxsize=1, typed=True)
  def getCt(series_uid):
      """Return the ``Ct`` for *series_uid*, memoizing only the latest request.

      With a one-entry cache, repeated requests for the same series share one
      ``Ct`` instance. Asking for a different series evicts the previous one.
      """
      loaded_ct = Ct(series_uid)
      return loaded_ct
  @raw_cache.memoize(typed=True)
  def getCtRawNodule(series_uid, center_xyz, width_irc):
      """Crop around *center_xyz* in the given series, memoized via ``raw_cache``.

      This is the same as ``getCt(series_uid).getRawNodule(center_xyz, width_irc)``.
      It returns the ``(ct_chunk, center_irc)`` pair.
      """
      return getCt(series_uid).getRawNodule(center_xyz, width_irc)
  def getCtAugmentedNodule(
          augmentation_dict,
          series_uid, center_xyz, width_irc,
          use_cache=True):
      """Return a randomly augmented crop around a nodule candidate.

      Args:
          augmentation_dict: selects the augmentations to apply. Recognised keys:
              'flip'   - mirror each axis independently with probability 0.5;
              'offset' - per-axis translation drawn uniformly from
                         [-offset, offset), in ``affine_grid``'s normalized
                         coordinates (the crop spans [-1, 1]);
              'scale'  - per-axis zoom factor 1 + scale * u, with u in [-1, 1);
              'rotate' - random angle in [0, 2*pi) in the plane of the first two
                         grid coordinates;
              'noise'  - add Gaussian noise with this standard deviation after
                         resampling.
          series_uid, center_xyz, width_irc: identify the crop (see
              ``Ct.getRawNodule``).
          use_cache: read the raw crop through the ``getCtRawNodule`` disk cache
              instead of loading the CT directly.

      Returns:
          ``(augmented_chunk, center_irc)``. ``augmented_chunk`` is the resampled
          float32 tensor with the batch dimension removed.
      """
      if use_cache:
          ct_chunk, center_irc = getCtRawNodule(series_uid, center_xyz, width_irc)
      else:
          ct = getCt(series_uid)
          ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)

      # Add batch and channel dims. This assumes ct_chunk is a 3-D array, so the
      # result is (1, 1, D, H, W) as grid_sample expects for volumes.
      ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)

      # 4x4 homogeneous transform; rows 0-2 become affine_grid's theta.
      transform_t = torch.eye(4).to(torch.float64)

      for i in range(3):
          if 'flip' in augmentation_dict:
              if random.random() > 0.5:
                  transform_t[i, i] *= -1

          if 'offset' in augmentation_dict:
              offset_float = augmentation_dict['offset']
              random_float = (random.random() * 2 - 1)
              # The translation belongs in the last column. The bottom row is
              # discarded by transform_t[:3], so writing there had no effect.
              transform_t[i, 3] = offset_float * random_float

          if 'scale' in augmentation_dict:
              scale_float = augmentation_dict['scale']
              random_float = (random.random() * 2 - 1)
              transform_t[i, i] *= 1.0 + scale_float * random_float

      if 'rotate' in augmentation_dict:
          angle_rad = random.random() * math.pi * 2
          s = math.sin(angle_rad)
          c = math.cos(angle_rad)

          rotation_t = torch.tensor([
              [c, -s, 0, 0],
              [s, c, 0, 0],
              [0, 0, 1, 0],
              [0, 0, 0, 1],
          ], dtype=torch.float64)

          transform_t @= rotation_t

      # align_corners=False matches PyTorch's default; passing it explicitly
      # silences the warning about the default.
      affine_t = F.affine_grid(
          transform_t[:3].unsqueeze(0).to(torch.float32),
          ct_t.size(),
          align_corners=False,
      )

      augmented_chunk = F.grid_sample(
          ct_t,
          affine_t,
          padding_mode='border',
          align_corners=False,
      ).to('cpu')

      if 'noise' in augmentation_dict:
          noise_t = torch.randn_like(augmented_chunk)
          noise_t *= augmentation_dict['noise']

          augmented_chunk += noise_t

      return augmented_chunk[0], center_irc
  class LunaDataset(Dataset):
      """PyTorch dataset of LUNA nodule candidates for malignancy classification.

      Each item is ``(nodule_t, malignant_t, series_uid, center_irc_t)``:
      - ``nodule_t``: a float32 crop of ``width_irc = (32, 48, 48)`` voxels
        with a leading channel dimension of size 1;
      - ``malignant_t``: a two-element long tensor
        ``[not isMalignant_bool, isMalignant_bool]``;
      - ``series_uid``: the candidate's series;
      - ``center_irc_t``: the crop center in array-index coordinates.

      When ``ratio_int`` is non-zero, indices interleave one malignant sample with
      ``ratio_int`` benign samples, cycling through each class list.
      """

      def __init__(self,
                   val_stride=0,
                   isValSet_bool=None,
                   series_uid=None,
                   sortby_str='random',
                   ratio_int=0,
                   augmentation_dict=None,
                   noduleInfo_list=None,
                   ):
          """Select and order the candidates for this dataset.

          Args:
              val_stride: every ``val_stride``-th candidate forms the validation
                  split; the remaining candidates form the training split.
              isValSet_bool: keep only the validation split (requires
                  ``val_stride > 0``).
              series_uid: if given, keep only candidates from this series.
              sortby_str: 'random' (shuffle), 'series_uid' (sort by series and
                  center), or 'malignancy_size' (keep the input order).
              ratio_int: benign samples per malignant sample; 0 disables
                  balancing.
              augmentation_dict: passed to ``getCtAugmentedNodule``; a falsy
                  value disables augmentation.
              noduleInfo_list: explicit candidate list. When supplied, the raw
                  crop disk cache is bypassed.

          Raises:
              AssertionError: if a requested split ends up empty, or
                  ``isValSet_bool`` is set without a positive ``val_stride``.
              Exception: for an unknown ``sortby_str``.
          """
          self.ratio_int = ratio_int
          self.augmentation_dict = augmentation_dict

          # Compare against None so an explicitly supplied empty list is honoured
          # instead of being silently replaced by the full candidate list.
          if noduleInfo_list is not None:
              self.noduleInfo_list = copy.copy(noduleInfo_list)
              self.use_cache = False
          else:
              # Copy because getNoduleInfoList returns a shared, cached list.
              self.noduleInfo_list = copy.copy(getNoduleInfoList())
              self.use_cache = True

          if series_uid:
              self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid == series_uid]

          if isValSet_bool:
              assert val_stride > 0, val_stride
              self.noduleInfo_list = self.noduleInfo_list[::val_stride]
              assert self.noduleInfo_list
          elif val_stride > 0:
              del self.noduleInfo_list[::val_stride]
              assert self.noduleInfo_list

          if sortby_str == 'random':
              random.shuffle(self.noduleInfo_list)
          elif sortby_str == 'series_uid':
              self.noduleInfo_list.sort(key=lambda x: (x.series_uid, x.center_xyz))
          elif sortby_str == 'malignancy_size':
              pass
          else:
              raise Exception("Unknown sort: " + repr(sortby_str))

          self.benign_list = [nt for nt in self.noduleInfo_list if not nt.isMalignant_bool]
          self.malignant_list = [nt for nt in self.noduleInfo_list if nt.isMalignant_bool]

          log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
              self,
              len(self.noduleInfo_list),
              "validation" if isValSet_bool else "training",
              len(self.benign_list),
              len(self.malignant_list),
              '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
          ))

      def shuffleSamples(self):
          """Reshuffle the per-class lists; this has no effect unless balancing is on."""
          if self.ratio_int:
              random.shuffle(self.benign_list)
              random.shuffle(self.malignant_list)

      def __len__(self):
          if self.ratio_int:
              # Balanced sampling cycles through the class lists, so the epoch
              # length is a fixed size rather than a function of the data.
              return 20000
          # NOTE(review): this exposes only 1/20th of the candidates per epoch, and
          # lists shorter than 20 entries yield length 0. Confirm the subsampling
          # is intentional.
          return len(self.noduleInfo_list) // 20

      def __getitem__(self, ndx):
          if self.ratio_int:
              # Every (ratio_int + 1)-th index is malignant; the others are benign.
              # Both class lists are indexed modulo their length.
              malignant_ndx = ndx // (self.ratio_int + 1)

              if ndx % (self.ratio_int + 1):
                  benign_ndx = ndx - 1 - malignant_ndx
                  benign_ndx %= len(self.benign_list)
                  nodule_tup = self.benign_list[benign_ndx]
              else:
                  malignant_ndx %= len(self.malignant_list)
                  nodule_tup = self.malignant_list[malignant_ndx]
          else:
              nodule_tup = self.noduleInfo_list[ndx]

          width_irc = (32, 48, 48)

          if self.augmentation_dict:
              nodule_t, center_irc = getCtAugmentedNodule(
                  self.augmentation_dict,
                  nodule_tup.series_uid,
                  nodule_tup.center_xyz,
                  width_irc,
                  self.use_cache,
              )
          else:
              if self.use_cache:
                  nodule_a, center_irc = getCtRawNodule(
                      nodule_tup.series_uid,
                      nodule_tup.center_xyz,
                      width_irc,
                  )
              else:
                  ct = getCt(nodule_tup.series_uid)
                  nodule_a, center_irc = ct.getRawNodule(
                      nodule_tup.center_xyz,
                      width_irc,
                  )
              # Add a leading channel dimension of size 1.
              nodule_t = torch.from_numpy(nodule_a).to(torch.float32).unsqueeze(0)

          malignant_t = torch.tensor([
                  not nodule_tup.isMalignant_bool,
                  nodule_tup.isMalignant_bool
              ],
              dtype=torch.long,
          )

          return nodule_t, malignant_t, nodule_tup.series_uid, torch.tensor(center_irc)