dsets.py

import copy
import csv
import functools
import glob
import math
import os
import random

from collections import namedtuple

import SimpleITK as sitk
import numpy as np
import scipy.ndimage.morphology as morph

import torch
import torch.cuda
import torch.nn.functional as F
from torch.utils.data import Dataset

from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

raw_cache = getCache('part2ch13_raw')

NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_nodule_mask, nodule_mask, lung_mask, ben_mask, mal_mask')


@functools.lru_cache(1)
def getNoduleInfoList(requireDataOnDisk_bool=True):
    # We construct a set with all series_uids that are present on disk.
    # This will let us use the data, even if we haven't downloaded all of
    # the subsets yet.
    mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

    diameter_dict = {}
    with open('data/part2/luna/annotations.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]
            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
            annotationDiameter_mm = float(row[4])

            diameter_dict.setdefault(series_uid, []).append(
                (annotationCenter_xyz, annotationDiameter_mm),
            )

    noduleInfo_list = []
    with open('data/part2/luna/candidates.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]

            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
                continue

            isMalignant_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])

            candidateDiameter_mm = 0.0
            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
                for i in range(3):
                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                    if delta_mm > annotationDiameter_mm / 4:
                        break
                else:
                    candidateDiameter_mm = annotationDiameter_mm
                    break

            noduleInfo_list.append(NoduleInfoTuple(
                isMalignant_bool,
                candidateDiameter_mm,
                series_uid,
                candidateCenter_xyz,
            ))

    noduleInfo_list.sort(reverse=True)
    return noduleInfo_list
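

# A hedged usage sketch, not part of the original module: summarize the class
# balance of the candidate list returned above. The helper name is hypothetical.
def _summarizeNoduleInfoList():
    noduleInfo_list = getNoduleInfoList()
    malignant_count = sum(1 for nt in noduleInfo_list if nt.isMalignant_bool)

    log.info("{} candidates: {} malignant, {} benign".format(
        len(noduleInfo_list),
        malignant_count,
        len(noduleInfo_list) - malignant_count,
    ))
    return noduleInfo_list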


class Ct(object):
    def __init__(self, series_uid, buildMasks_bool=True):
        mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.

        # This gets rid of negative density stuff used to indicate out-of-FOV
        ct_a[ct_a < -1000] = -1000

        # This nukes any weird hotspots and clamps bone down
        ct_a[ct_a > 1000] = 1000

        self.series_uid = series_uid
        self.hu_a = ct_a

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())

        noduleInfo_list = getNoduleInfoList()

        # Note: buildMasks_bool is accepted for call-site compatibility, but the
        # annotation masks are currently always built.
        self.benignInfo_list = [ni_tup
                                for ni_tup in noduleInfo_list
                                if not ni_tup.isMalignant_bool
                                and ni_tup.series_uid == self.series_uid]
        self.benign_mask = self.buildAnnotationMask(self.benignInfo_list)[0]
        self.benign_indexes = sorted(set(self.benign_mask.nonzero()[0]))

        self.malignantInfo_list = [ni_tup
                                   for ni_tup in noduleInfo_list
                                   if ni_tup.isMalignant_bool
                                   and ni_tup.series_uid == self.series_uid]
        self.malignant_mask = self.buildAnnotationMask(self.malignantInfo_list)[0]
        self.malignant_indexes = sorted(set(self.malignant_mask.nonzero()[0]))

    def buildAnnotationMask(self, noduleInfo_list, threshold_hu=-500):
        boundingBox_a = np.zeros_like(self.hu_a, dtype=bool)

        for noduleInfo_tup in noduleInfo_list:
            center_irc = xyz2irc(
                noduleInfo_tup.center_xyz,
                self.origin_xyz,
                self.vxSize_xyz,
                self.direction_tup,
            )
            ci = int(center_irc.index)
            cr = int(center_irc.row)
            cc = int(center_irc.col)

            index_radius = 2
            try:
                while self.hu_a[ci + index_radius, cr, cc] > threshold_hu and \
                        self.hu_a[ci - index_radius, cr, cc] > threshold_hu:
                    index_radius += 1
            except IndexError:
                index_radius -= 1

            row_radius = 2
            try:
                while self.hu_a[ci, cr + row_radius, cc] > threshold_hu and \
                        self.hu_a[ci, cr - row_radius, cc] > threshold_hu:
                    row_radius += 1
            except IndexError:
                row_radius -= 1

            col_radius = 2
            try:
                while self.hu_a[ci, cr, cc + col_radius] > threshold_hu and \
                        self.hu_a[ci, cr, cc - col_radius] > threshold_hu:
                    col_radius += 1
            except IndexError:
                col_radius -= 1

            # assert index_radius > 0, repr([noduleInfo_tup.center_xyz, center_irc, self.hu_a[ci, cr, cc]])
            # assert row_radius > 0
            # assert col_radius > 0

            slice_tup = (
                slice(ci - index_radius, ci + index_radius + 1),
                slice(cr - row_radius, cr + row_radius + 1),
                slice(cc - col_radius, cc + col_radius + 1),
            )
            boundingBox_a[slice_tup] = True

        thresholded_a = boundingBox_a & (self.hu_a > threshold_hu)
        mask_a = morph.binary_dilation(thresholded_a, iterations=2)

        return mask_a, thresholded_a, boundingBox_a

    def build2dLungMask(self, mask_ndx):
        raw_dense_mask = self.hu_a[mask_ndx] > -300
        dense_mask = morph.binary_closing(raw_dense_mask, iterations=2)
        dense_mask = morph.binary_opening(dense_mask, iterations=2)
        body_mask = morph.binary_fill_holes(dense_mask)
        air_mask = morph.binary_fill_holes(body_mask & ~dense_mask)
        air_mask = morph.binary_erosion(air_mask, iterations=1)
        lung_mask = morph.binary_dilation(air_mask, iterations=5)

        raw_nodule_mask = self.hu_a[mask_ndx] > -600
        raw_nodule_mask &= air_mask
        nodule_mask = morph.binary_opening(raw_nodule_mask, iterations=1)

        ben_mask = morph.binary_dilation(nodule_mask, iterations=1)
        ben_mask &= ~self.malignant_mask[mask_ndx]

        mal_mask = self.malignant_mask[mask_ndx]

        return MaskTuple(
            raw_dense_mask,
            dense_mask,
            body_mask,
            air_mask,
            raw_nodule_mask,
            nodule_mask,
            lung_mask,
            ben_mask,
            mal_mask,
        )

    def getRawNodule(self, center_xyz, width_irc):
        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis]/2))
            end_ndx = int(start_ndx + width_irc[axis])

            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])

            if start_ndx < 0:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > self.hu_a.shape[axis]:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                end_ndx = self.hu_a.shape[axis]
                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])

            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.hu_a[tuple(slice_list)]

        return ct_chunk, center_irc
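

# A hedged usage sketch, not part of the original module: load one CT and crop a
# candidate out of it. The helper name and the default (32, 48, 48) width (index,
# row, col voxels) are illustrative; the width matches what LunaDataset uses below.
def _extractCandidateChunk(noduleInfo_tup, width_irc=(32, 48, 48)):
    ct = Ct(noduleInfo_tup.series_uid)
    ct_chunk, center_irc = ct.getRawNodule(noduleInfo_tup.center_xyz, width_irc)
    return ct_chunk, center_irc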


ctCache_depth = 5

@functools.lru_cache(ctCache_depth, typed=True)
def getCt(series_uid):
    return Ct(series_uid)


@raw_cache.memoize(typed=True)
def getCtRawNodule(series_uid, center_xyz, width_irc):
    ct = getCt(series_uid)
    ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
    return ct_chunk, center_irc


@raw_cache.memoize(typed=True)
def getCtSampleSize(series_uid):
    ct = Ct(series_uid, buildMasks_bool=False)
    return len(ct.benign_indexes)
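

# A hedged sketch of how the two cache layers compose (the helper name is
# hypothetical): getCt keeps the last few Ct volumes in memory via lru_cache,
# while getCtRawNodule persists individual crops to disk via raw_cache.memoize,
# so a repeated request for the same crop skips both the CT load and the slicing.
def _warmNoduleCache(noduleInfo_tup, width_irc=(32, 48, 48)):
    ct_chunk, center_irc = getCtRawNodule(
        noduleInfo_tup.series_uid,
        noduleInfo_tup.center_xyz,
        width_irc,
    )
    return ct_chunk.shape, center_irc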


def getCtAugmentedNodule(
        augmentation_dict,
        series_uid, center_xyz, width_irc,
        use_cache=True):
    if use_cache:
        ct_chunk, center_irc = getCtRawNodule(series_uid, center_xyz, width_irc)
    else:
        ct = getCt(series_uid)
        ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)

    ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)

    transform_t = torch.eye(4).to(torch.float64)
    for i in range(3):
        if 'flip' in augmentation_dict:
            if random.random() > 0.5:
                transform_t[i,i] *= -1

        if 'offset' in augmentation_dict:
            offset_float = augmentation_dict['offset']
            random_float = (random.random() * 2 - 1)
            # The translation component lives in the last column of the affine
            # matrix, which is what transform_t[:3] hands to affine_grid below.
            transform_t[i,3] = offset_float * random_float

        if 'scale' in augmentation_dict:
            scale_float = augmentation_dict['scale']
            random_float = (random.random() * 2 - 1)
            transform_t[i,i] *= 1.0 + scale_float * random_float

    if 'rotate' in augmentation_dict:
        angle_rad = random.random() * math.pi * 2
        s = math.sin(angle_rad)
        c = math.cos(angle_rad)

        rotation_t = torch.tensor([
            [c, -s, 0, 0],
            [s, c, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ], dtype=torch.float64)

        transform_t @= rotation_t

    affine_t = F.affine_grid(
        transform_t[:3].unsqueeze(0).to(torch.float32),
        ct_t.size(),
        align_corners=False,  # explicit; matches the current PyTorch default
    )

    augmented_chunk = F.grid_sample(
        ct_t,
        affine_t,
        padding_mode='border',
        align_corners=False,
    ).to('cpu')

    if 'noise' in augmentation_dict:
        noise_t = torch.randn_like(augmented_chunk)
        noise_t *= augmentation_dict['noise']

        augmented_chunk += noise_t

    return augmented_chunk[0], center_irc
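

# A hedged example, not part of the original module: the keys consumed by
# getCtAugmentedNodule are 'flip', 'offset', 'scale', 'rotate', and 'noise'.
# The values below are illustrative, not tuned settings.
_example_augmentation_dict = {
    'flip': True,     # mirror each axis with 50% probability
    'offset': 0.1,    # translate by up to 10% of the crop, in grid coordinates
    'scale': 0.2,     # rescale by up to +/-20%
    'rotate': True,   # random rotation in the row/col plane
    'noise': 25.0,    # additive Gaussian noise, in the chunk's HU units
}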


class LunaDataset(Dataset):
    def __init__(self,
                 val_stride=0,
                 isValSet_bool=None,
                 series_uid=None,
                 sortby_str='random',
                 ratio_int=0,
                 augmentation_dict=None,
                 noduleInfo_list=None,
            ):
        self.ratio_int = ratio_int
        self.augmentation_dict = augmentation_dict

        if noduleInfo_list:
            self.noduleInfo_list = copy.copy(noduleInfo_list)
            self.use_cache = False
        else:
            self.noduleInfo_list = copy.copy(getNoduleInfoList())
            self.use_cache = True

        if series_uid:
            self.series_list = [series_uid]
        else:
            self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))

        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.series_list = self.series_list[::val_stride]
            assert self.series_list
        elif val_stride > 0:
            del self.series_list[::val_stride]
            assert self.series_list

        series_set = set(self.series_list)
        self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid in series_set]

        if sortby_str == 'random':
            random.shuffle(self.noduleInfo_list)
        elif sortby_str == 'series_uid':
            self.noduleInfo_list.sort(key=lambda x: (x.series_uid, x.center_xyz))
        elif sortby_str == 'malignancy_size':
            pass
        else:
            raise Exception("Unknown sort: " + repr(sortby_str))

        self.benign_list = [nt for nt in self.noduleInfo_list if not nt.isMalignant_bool]
        self.malignant_list = [nt for nt in self.noduleInfo_list if nt.isMalignant_bool]

        log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
            self,
            len(self.noduleInfo_list),
            "validation" if isValSet_bool else "training",
            len(self.benign_list),
            len(self.malignant_list),
            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced',
        ))

    def shuffleSamples(self):
        if self.ratio_int:
            random.shuffle(self.benign_list)
            random.shuffle(self.malignant_list)

    def __len__(self):
        if self.ratio_int:
            # return 20000
            return 200000
        else:
            return len(self.noduleInfo_list)

    def __getitem__(self, ndx):
        if self.ratio_int:
            malignant_ndx = ndx // (self.ratio_int + 1)

            if ndx % (self.ratio_int + 1):
                benign_ndx = ndx - 1 - malignant_ndx
                benign_ndx %= len(self.benign_list)
                nodule_tup = self.benign_list[benign_ndx]
            else:
                malignant_ndx %= len(self.malignant_list)
                nodule_tup = self.malignant_list[malignant_ndx]
        else:
            nodule_tup = self.noduleInfo_list[ndx]

        width_irc = (32, 48, 48)

        if self.augmentation_dict:
            nodule_t, center_irc = getCtAugmentedNodule(
                self.augmentation_dict,
                nodule_tup.series_uid,
                nodule_tup.center_xyz,
                width_irc,
                self.use_cache,
            )
        elif self.use_cache:
            nodule_a, center_irc = getCtRawNodule(
                nodule_tup.series_uid,
                nodule_tup.center_xyz,
                width_irc,
            )
            nodule_t = torch.from_numpy(nodule_a).to(torch.float32)
            nodule_t = nodule_t.unsqueeze(0)
        else:
            ct = getCt(nodule_tup.series_uid)
            nodule_a, center_irc = ct.getRawNodule(
                nodule_tup.center_xyz,
                width_irc,
            )
            nodule_t = torch.from_numpy(nodule_a).to(torch.float32)
            nodule_t = nodule_t.unsqueeze(0)

        malignant_t = torch.tensor([
                not nodule_tup.isMalignant_bool,
                nodule_tup.isMalignant_bool
            ],
            dtype=torch.long,
        )

        # log.debug([type(center_irc), center_irc])

        return nodule_t, malignant_t, nodule_tup.series_uid, torch.tensor(center_irc)
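

# A hedged usage sketch, not part of the original module: wrap the dataset in a
# DataLoader. The stride, batch size, and worker count are illustrative only.
def _makeLunaDataLoader(batch_size=32, num_workers=4):
    from torch.utils.data import DataLoader

    train_ds = LunaDataset(
        val_stride=10,
        isValSet_bool=False,
        ratio_int=1,
    )
    return DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
        pin_memory=torch.cuda.is_available(),
    )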


class PrepcacheLunaDataset(LunaDataset):
    def __getitem__(self, ndx):
        nodule_t, malignant_t, series_uid, center_t = super().__getitem__(ndx)

        getCtSampleSize(series_uid)

        return nodule_t, malignant_t, series_uid, center_t


class Luna2dSegmentationDataset(Dataset):
    def __init__(self,
                 val_stride=0,
                 isValSet_bool=None,
                 series_uid=None,
                 contextSlices_count=2,
                 augmentation_dict=None,
                 fullCt_bool=False,
            ):
        self.contextSlices_count = contextSlices_count
        self.augmentation_dict = augmentation_dict

        if series_uid:
            self.series_list = [series_uid]
        else:
            self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))

        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.series_list = self.series_list[::val_stride]
            assert self.series_list
        elif val_stride > 0:
            del self.series_list[::val_stride]
            assert self.series_list

        self.sample_list = []
        for series_uid in self.series_list:
            if fullCt_bool:
                self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCt(series_uid).hu_a.shape[0])])
            else:
                self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCtSampleSize(series_uid))])

        log.info("{!r}: {} {} series, {} slices".format(
            self,
            len(self.series_list),
            {None: 'general', True: 'validation', False: 'training'}[isValSet_bool],
            len(self.sample_list),
        ))

    def __len__(self):
        return len(self.sample_list) #// 100

    def __getitem__(self, ndx):
        if isinstance(ndx, int):
            # Each sample_list entry is a (series_uid, ct_ndx) pair, so the slice
            # index comes straight out of the unpacking.
            series_uid, ct_ndx = self.sample_list[ndx % len(self.sample_list)]
            ct = getCt(series_uid)
            useAugmentation_bool = False
        else:
            series_uid, ct_ndx, useAugmentation_bool = ndx
            ct = getCt(series_uid)

        ct_t = torch.zeros((self.contextSlices_count * 2 + 1 + 1, 512, 512))

        start_ndx = ct_ndx - self.contextSlices_count
        end_ndx = ct_ndx + self.contextSlices_count + 1
        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
            context_ndx = max(context_ndx, 0)
            context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
            ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
        ct_t /= 1000

        mask_tup = ct.build2dLungMask(ct_ndx)
        ct_t[-1] = torch.from_numpy(mask_tup.lung_mask.astype(np.float32))

        nodule_t = torch.from_numpy(
            (mask_tup.mal_mask | mask_tup.ben_mask).astype(np.float32)
        ).unsqueeze(0)
        ben_t = torch.from_numpy(mask_tup.ben_mask.astype(np.float32)).unsqueeze(0)
        mal_t = torch.from_numpy(mask_tup.mal_mask.astype(np.float32)).unsqueeze(0)

        label_int = int(mal_t.max() + ben_t.max() * 2)

        if self.augmentation_dict and useAugmentation_bool:
            if 'rotate' in self.augmentation_dict:
                if random.random() > 0.5:
                    ct_t = ct_t.rot90(1, [1, 2])
                    nodule_t = nodule_t.rot90(1, [1, 2])

            if 'flip' in self.augmentation_dict:
                dims = [d+1 for d in range(2) if random.random() > 0.5]
                if dims:
                    ct_t = ct_t.flip(dims)
                    nodule_t = nodule_t.flip(dims)

            if 'noise' in self.augmentation_dict:
                noise_t = torch.randn_like(ct_t)
                noise_t *= self.augmentation_dict['noise']
                ct_t += noise_t

        return ct_t, nodule_t, label_int, ben_t, mal_t, ct.series_uid, ct_ndx
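

# A hedged usage sketch, not part of the original module: the segmentation
# dataset accepts either an integer index or an explicit
# (series_uid, ct_ndx, useAugmentation_bool) tuple; the tuple form is handy for
# pulling a specific, un-augmented slice. The helper name is hypothetical.
def _getSegmentationSlice(series_uid, ct_ndx):
    ds = Luna2dSegmentationDataset(series_uid=series_uid)
    ct_t, nodule_t = ds[(series_uid, ct_ndx, False)][:2]
    return ct_t, nodule_t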


class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
    def __init__(self, *args, batch_size=80, **kwargs):
        self.needsShuffle_bool = True
        self.batch_size = batch_size
        # self.rotate_frac = 0.5 * len(self.series_list) / len(self)

        super().__init__(*args, **kwargs)

    def __len__(self):
        return 50000

    def __getitem__(self, ndx):
        if self.needsShuffle_bool:
            random.shuffle(self.series_list)
            self.needsShuffle_bool = False

        if isinstance(ndx, int):
            if ndx % self.batch_size == 0:
                self.series_list.append(self.series_list.pop(0))

            # Cycle through only as many series as the getCt lru_cache can hold,
            # so consecutive samples reuse already-loaded CT volumes.
            series_uid = self.series_list[ndx % ctCache_depth]
            ct = getCt(series_uid)

            if ndx % 3 == 0:
                ct_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
            elif ndx % 3 == 1:
                ct_ndx = random.choice(ct.benign_indexes)
            elif ndx % 3 == 2:
                ct_ndx = random.choice(list(range(ct.hu_a.shape[0])))

            useAugmentation_bool = True
        else:
            series_uid, ct_ndx, useAugmentation_bool = ndx

        return super().__getitem__((series_uid, ct_ndx, useAugmentation_bool))
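

# A hedged sketch, not part of the original module, of how the prep-cache
# dataset might be driven to fill the disk cache before training; iterating
# once touches every candidate crop and every per-series sample size.
if __name__ == '__main__':
    prepcache_ds = PrepcacheLunaDataset(sortby_str='series_uid')
    for sample_ndx in range(len(prepcache_ds)):
        prepcache_ds[sample_ndx]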