dsets.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569
  1. import copy
  2. import csv
  3. import functools
  4. import glob
  5. import math
  6. import os
  7. import random
  8. from collections import namedtuple
  9. import SimpleITK as sitk
  10. import numpy as np
  11. import scipy.ndimage.morphology as morph
  12. import torch
  13. import torch.cuda
  14. from torch.utils.data import Dataset
  15. import torch.nn as nn
  16. import torch.nn.functional as F
  17. from util.disk import getCache
  18. from util.util import XyzTuple, xyz2irc
  19. from util.logconf import logging
# Module-level logger. DEBUG is the active verbosity; the commented-out
# lines are the alternative levels.
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# Cache handle whose ``memoize`` decorator is applied to getCtRawNodule and
# getCtSampleSize.
raw_cache = getCache('part2ch12_raw')

# One candidate location. getNoduleInfoList sorts these tuples in reverse, so
# the field order (malignancy flag first, then diameter) defines that ordering.
NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')

# The eight masks returned, in this order, by Ct.build2dLungMask and
# Ct.build3dLungMask.
MaskTuple = namedtuple('MaskTuple', 'air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask, ben_mask, mal_mask')
  27. @functools.lru_cache(1)
  28. def getNoduleInfoList(requireDataOnDisk_bool=True):
  29. # We construct a set with all series_uids that are present on disk.
  30. # This will let us use the data, even if we haven't downloaded all of
  31. # the subsets yet.
  32. mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
  33. dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
  34. diameter_dict = {}
  35. with open('data/part2/luna/annotations.csv', "r") as f:
  36. for row in list(csv.reader(f))[1:]:
  37. series_uid = row[0]
  38. annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
  39. annotationDiameter_mm = float(row[4])
  40. diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
  41. noduleInfo_list = []
  42. with open('data/part2/luna/candidates.csv', "r") as f:
  43. for row in list(csv.reader(f))[1:]:
  44. series_uid = row[0]
  45. if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
  46. continue
  47. isMalignant_bool = bool(int(row[4]))
  48. candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
  49. candidateDiameter_mm = 0.0
  50. for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
  51. for i in range(3):
  52. delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
  53. if delta_mm > annotationDiameter_mm / 4:
  54. break
  55. else:
  56. candidateDiameter_mm = annotationDiameter_mm
  57. break
  58. noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
  59. noduleInfo_list.sort(reverse=True)
  60. return noduleInfo_list
  61. class Ct(object):
  62. def __init__(self, series_uid, buildMasks_bool=True):
  63. mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
  64. ct_mhd = sitk.ReadImage(mhd_path)
  65. ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
  66. # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
  67. # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
  68. # This gets rid of negative density stuff used to indicate out-of-FOV
  69. ct_ary[ct_ary < -1000] = -1000
  70. # This nukes any weird hotspots and clamps bone down
  71. ct_ary[ct_ary > 1000] = 1000
  72. self.series_uid = series_uid
  73. self.ary = ct_ary
  74. self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
  75. self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
  76. self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
  77. noduleInfo_list = getNoduleInfoList()
  78. self.benignInfo_list = [ni_tup
  79. for ni_tup in noduleInfo_list
  80. if not ni_tup.isMalignant_bool
  81. and ni_tup.series_uid == self.series_uid]
  82. self.benign_mask = self.buildAnnotationMask(self.benignInfo_list)[0]
  83. self.benign_indexes = sorted(set(self.benign_mask.nonzero()[0]))
  84. self.malignantInfo_list = [ni_tup
  85. for ni_tup in noduleInfo_list
  86. if ni_tup.isMalignant_bool
  87. and ni_tup.series_uid == self.series_uid]
  88. self.malignant_mask = self.buildAnnotationMask(self.malignantInfo_list)[0]
  89. self.malignant_indexes = sorted(set(self.malignant_mask.nonzero()[0]))
  90. def buildAnnotationMask(self, noduleInfo_list, threshold_gcc = -500):
  91. boundingBox_ary = np.zeros_like(self.ary, dtype=np.bool)
  92. for noduleInfo_tup in noduleInfo_list:
  93. center_irc = xyz2irc(
  94. noduleInfo_tup.center_xyz,
  95. self.origin_xyz,
  96. self.vxSize_xyz,
  97. self.direction_tup,
  98. )
  99. ci = int(center_irc.index)
  100. cr = int(center_irc.row)
  101. cc = int(center_irc.col)
  102. index_radius = 2
  103. try:
  104. while self.ary[ci + index_radius, cr, cc] > threshold_gcc and \
  105. self.ary[ci - index_radius, cr, cc] > threshold_gcc:
  106. index_radius += 1
  107. except IndexError:
  108. index_radius -= 1
  109. row_radius = 2
  110. try:
  111. while self.ary[ci, cr + row_radius, cc] > threshold_gcc and \
  112. self.ary[ci, cr - row_radius, cc] > threshold_gcc:
  113. row_radius += 1
  114. except IndexError:
  115. row_radius -= 1
  116. col_radius = 2
  117. try:
  118. while self.ary[ci, cr, cc + col_radius] > threshold_gcc and \
  119. self.ary[ci, cr, cc - col_radius] > threshold_gcc:
  120. col_radius += 1
  121. except IndexError:
  122. col_radius -= 1
  123. # assert index_radius > 0, repr([noduleInfo_tup.center_xyz, center_irc, self.ary[ci, cr, cc]])
  124. # assert row_radius > 0
  125. # assert col_radius > 0
  126. slice_tup = (
  127. slice(ci - index_radius, ci + index_radius + 1),
  128. slice(cr - row_radius, cr + row_radius + 1),
  129. slice(cc - col_radius, cc + row_radius + 1),
  130. )
  131. boundingBox_ary[slice_tup] = True
  132. thresholded_ary = boundingBox_ary & (self.ary > threshold_gcc)
  133. mask_ary = morph.binary_dilation(thresholded_ary, iterations=2)
  134. return mask_ary, thresholded_ary, boundingBox_ary
  135. def build2dLungMask(self, mask_ndx, threshold_gcc = -300):
  136. dense_mask = self.ary[mask_ndx] > threshold_gcc
  137. denoise_mask = morph.binary_closing(dense_mask, iterations=2)
  138. tissue_mask = morph.binary_opening(denoise_mask, iterations=10)
  139. body_mask = morph.binary_fill_holes(tissue_mask)
  140. air_mask = morph.binary_fill_holes(body_mask & ~tissue_mask)
  141. lung_mask = morph.binary_dilation(air_mask, iterations=2)
  142. ben_mask = denoise_mask & air_mask
  143. ben_mask = morph.binary_dilation(ben_mask, iterations=2)
  144. ben_mask &= ~self.malignant_mask[mask_ndx]
  145. mal_mask = self.malignant_mask[mask_ndx]
  146. return MaskTuple(
  147. air_mask,
  148. lung_mask,
  149. dense_mask,
  150. denoise_mask,
  151. tissue_mask,
  152. body_mask,
  153. ben_mask,
  154. mal_mask,
  155. )
  156. def build3dLungMask(self):
  157. air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask, ben_mask, mal_mask = mask_list = \
  158. [np.zeros_like(self.ary, dtype=np.bool) for _ in range(7)]
  159. for mask_ndx in range(self.ary.shape[0]):
  160. for i, mask_ary in enumerate(self.build2dLungMask(mask_ndx)):
  161. mask_list[i][mask_ndx] = mask_ary
  162. return MaskTuple(air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask, ben_mask, mal_mask)
  163. def getRawNodule(self, center_xyz, width_irc):
  164. center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
  165. slice_list = []
  166. for axis, center_val in enumerate(center_irc):
  167. try:
  168. start_ndx = int(round(center_val - width_irc[axis]/2))
  169. except:
  170. log.debug([center_val, width_irc, center_xyz, center_irc])
  171. raise
  172. end_ndx = int(start_ndx + width_irc[axis])
  173. assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
  174. if start_ndx < 0:
  175. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  176. # self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
  177. start_ndx = 0
  178. end_ndx = int(width_irc[axis])
  179. if end_ndx > self.ary.shape[axis]:
  180. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  181. # self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
  182. end_ndx = self.ary.shape[axis]
  183. start_ndx = int(self.ary.shape[axis] - width_irc[axis])
  184. slice_list.append(slice(start_ndx, end_ndx))
  185. ct_chunk = self.ary[tuple(slice_list)]
  186. return ct_chunk, center_irc
  187. ctCache_depth = 5
  188. @functools.lru_cache(ctCache_depth, typed=True)
  189. def getCt(series_uid):
  190. return Ct(series_uid)
  191. @raw_cache.memoize(typed=True)
  192. def getCtRawNodule(series_uid, center_xyz, width_irc):
  193. ct = getCt(series_uid)
  194. ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
  195. return ct_chunk, center_irc
  196. @raw_cache.memoize(typed=True)
  197. def getCtSampleSize(series_uid):
  198. ct = Ct(series_uid, buildMasks_bool=False)
  199. return len(ct.benign_indexes)
  200. def getCtAugmentedNodule(
  201. augmentation_dict,
  202. series_uid, center_xyz, width_irc,
  203. use_cache=True):
  204. if use_cache:
  205. ct_chunk, center_irc = getCtRawNodule(series_uid, center_xyz, width_irc)
  206. else:
  207. ct = getCt(series_uid)
  208. ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
  209. ct_tensor = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
  210. transform_tensor = torch.eye(4).to(torch.float64)
  211. for i in range(3):
  212. if 'flip' in augmentation_dict:
  213. if random.random() > 0.5:
  214. transform_tensor[i,i] *= -1
  215. if 'offset' in augmentation_dict:
  216. offset_float = augmentation_dict['offset']
  217. random_float = (random.random() * 2 - 1)
  218. transform_tensor[3,i] = offset_float * random_float
  219. if 'scale' in augmentation_dict:
  220. scale_float = augmentation_dict['scale']
  221. random_float = (random.random() * 2 - 1)
  222. transform_tensor[i,i] *= 1.0 + scale_float * random_float
  223. if 'rotate' in augmentation_dict:
  224. angle_rad = random.random() * math.pi * 2
  225. s = math.sin(angle_rad)
  226. c = math.cos(angle_rad)
  227. rotation_tensor = torch.tensor([
  228. [c, -s, 0, 0],
  229. [s, c, 0, 0],
  230. [0, 0, 1, 0],
  231. [0, 0, 0, 1],
  232. ], dtype=torch.float64)
  233. transform_tensor @= rotation_tensor
  234. affine_tensor = F.affine_grid(
  235. transform_tensor[:3].unsqueeze(0).to(torch.float32),
  236. ct_tensor.size(),
  237. )
  238. augmented_chunk = F.grid_sample(
  239. ct_tensor,
  240. affine_tensor,
  241. padding_mode='border'
  242. ).to('cpu')
  243. if 'noise' in augmentation_dict:
  244. noise_tensor = torch.randn_like(augmented_chunk)
  245. noise_tensor *= augmentation_dict['noise']
  246. augmented_chunk += noise_tensor
  247. return augmented_chunk[0], center_irc
  248. class LunaDataset(Dataset):
  249. def __init__(self,
  250. test_stride=0,
  251. isTestSet_bool=None,
  252. series_uid=None,
  253. sortby_str='random',
  254. ratio_int=0,
  255. augmentation_dict=None,
  256. noduleInfo_list=None,
  257. ):
  258. self.ratio_int = ratio_int
  259. self.augmentation_dict = augmentation_dict
  260. if noduleInfo_list:
  261. self.noduleInfo_list = copy.copy(noduleInfo_list)
  262. self.use_cache = False
  263. else:
  264. self.noduleInfo_list = copy.copy(getNoduleInfoList())
  265. self.use_cache = True
  266. if series_uid:
  267. self.series_list = [series_uid]
  268. else:
  269. self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
  270. if isTestSet_bool:
  271. assert test_stride > 0, test_stride
  272. self.series_list = self.series_list[::test_stride]
  273. assert self.series_list
  274. elif test_stride > 0:
  275. del self.series_list[::test_stride]
  276. assert self.series_list
  277. series_set = set(self.series_list)
  278. self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid in series_set]
  279. if sortby_str == 'random':
  280. random.shuffle(self.noduleInfo_list)
  281. elif sortby_str == 'series_uid':
  282. self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sorting by series_uid, center_xyz)
  283. elif sortby_str == 'malignancy_size':
  284. pass
  285. else:
  286. raise Exception("Unknown sort: " + repr(sortby_str))
  287. self.benign_list = [nt for nt in self.noduleInfo_list if not nt.isMalignant_bool]
  288. self.malignant_list = [nt for nt in self.noduleInfo_list if nt.isMalignant_bool]
  289. log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
  290. self,
  291. len(self.noduleInfo_list),
  292. "testing" if isTestSet_bool else "training",
  293. len(self.benign_list),
  294. len(self.malignant_list),
  295. '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
  296. ))
  297. def shuffleSamples(self):
  298. if self.ratio_int:
  299. random.shuffle(self.benign_list)
  300. random.shuffle(self.malignant_list)
  301. def __len__(self):
  302. if self.ratio_int:
  303. # return 20000
  304. return 200000
  305. else:
  306. return len(self.noduleInfo_list)
  307. def __getitem__(self, ndx):
  308. if self.ratio_int:
  309. malignant_ndx = ndx // (self.ratio_int + 1)
  310. if ndx % (self.ratio_int + 1):
  311. benign_ndx = ndx - 1 - malignant_ndx
  312. nodule_tup = self.benign_list[benign_ndx % len(self.benign_list)]
  313. else:
  314. nodule_tup = self.malignant_list[malignant_ndx % len(self.malignant_list)]
  315. else:
  316. nodule_tup = self.noduleInfo_list[ndx]
  317. width_irc = (24, 48, 48)
  318. if self.augmentation_dict:
  319. nodule_t, center_irc = getCtAugmentedNodule(
  320. self.augmentation_dict,
  321. nodule_tup.series_uid,
  322. nodule_tup.center_xyz,
  323. width_irc,
  324. self.use_cache,
  325. )
  326. elif self.use_cache:
  327. nodule_ary, center_irc = getCtRawNodule(
  328. nodule_tup.series_uid,
  329. nodule_tup.center_xyz,
  330. width_irc,
  331. )
  332. nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
  333. nodule_t = nodule_t.unsqueeze(0)
  334. else:
  335. ct = getCt(nodule_tup.series_uid)
  336. nodule_ary, center_irc = ct.getRawNodule(
  337. nodule_tup.center_xyz,
  338. width_irc,
  339. )
  340. nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
  341. nodule_t = nodule_t.unsqueeze(0)
  342. malignant_tensor = torch.tensor([
  343. not nodule_tup.isMalignant_bool,
  344. nodule_tup.isMalignant_bool
  345. ],
  346. dtype=torch.long,
  347. )
  348. # log.debug([type(center_irc), center_irc])
  349. return nodule_t, malignant_tensor, nodule_tup.series_uid, torch.tensor(center_irc)
  350. class Luna2dSegmentationDataset(Dataset):
  351. def __init__(self,
  352. test_stride=0,
  353. isTestSet_bool=None,
  354. series_uid=None,
  355. contextSlices_count=2,
  356. augmentation_dict=None,
  357. fullCt_bool=False,
  358. ):
  359. self.contextSlices_count = contextSlices_count
  360. self.augmentation_dict = augmentation_dict
  361. if series_uid:
  362. self.series_list = [series_uid]
  363. else:
  364. self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
  365. if isTestSet_bool:
  366. assert test_stride > 0, test_stride
  367. self.series_list = self.series_list[::test_stride]
  368. assert self.series_list
  369. elif test_stride > 0:
  370. del self.series_list[::test_stride]
  371. assert self.series_list
  372. self.sample_list = []
  373. for series_uid in self.series_list:
  374. if fullCt_bool:
  375. self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCt(series_uid).ary.shape[0])])
  376. else:
  377. self.sample_list.extend([(series_uid, ct_ndx) for ct_ndx in range(getCtSampleSize(series_uid))])
  378. log.info("{!r}: {} {} series, {} slices".format(
  379. self,
  380. len(self.series_list),
  381. {None: 'general', True: 'testing', False: 'training'}[isTestSet_bool],
  382. len(self.sample_list),
  383. ))
  384. def __len__(self):
  385. return len(self.sample_list) #// 100
  386. def __getitem__(self, ndx):
  387. if isinstance(ndx, int):
  388. series_uid, sample_ndx = self.sample_list[ndx % len(self.sample_list)]
  389. ct = getCt(series_uid)
  390. ct_ndx = self.sample_list[sample_ndx][1]
  391. useAugmentation_bool = False
  392. else:
  393. series_uid, ct_ndx, useAugmentation_bool = ndx
  394. ct = getCt(series_uid)
  395. ct_tensor = torch.zeros((self.contextSlices_count * 2 + 1 + 1, 512, 512))
  396. start_ndx = ct_ndx - self.contextSlices_count
  397. end_ndx = ct_ndx + self.contextSlices_count + 1
  398. for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
  399. context_ndx = max(context_ndx, 0)
  400. context_ndx = min(context_ndx, ct.ary.shape[0] - 1)
  401. ct_tensor[i] = torch.from_numpy(ct.ary[context_ndx].astype(np.float32))
  402. ct_tensor /= 1000
  403. mask_tup = ct.build2dLungMask(ct_ndx)
  404. ct_tensor[-1] = torch.from_numpy(mask_tup.body_mask.astype(np.float32))
  405. nodule_tensor = torch.from_numpy(
  406. (mask_tup.mal_mask | mask_tup.ben_mask).astype(np.float32)
  407. ).unsqueeze(0)
  408. ben_tensor = torch.from_numpy(mask_tup.ben_mask.astype(np.float32))
  409. mal_tensor = torch.from_numpy(mask_tup.mal_mask.astype(np.float32))
  410. label_int = mal_tensor.max() + ben_tensor.max() * 2
  411. if self.augmentation_dict and useAugmentation_bool:
  412. if 'rotate' in self.augmentation_dict:
  413. if random.random() > 0.5:
  414. ct_tensor = ct_tensor.rot90(1, [1, 2])
  415. nodule_tensor = nodule_tensor.rot90(1, [1, 2])
  416. if 'flip' in self.augmentation_dict:
  417. dims = [d+1 for d in range(2) if random.random() > 0.5]
  418. if dims:
  419. ct_tensor = ct_tensor.flip(dims)
  420. nodule_tensor = nodule_tensor.flip(dims)
  421. if 'noise' in self.augmentation_dict:
  422. noise_tensor = torch.randn_like(ct_tensor)
  423. noise_tensor *= self.augmentation_dict['noise']
  424. ct_tensor += noise_tensor
  425. return ct_tensor, nodule_tensor, label_int, ben_tensor, mal_tensor, ct.series_uid, ct_ndx
  426. class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
  427. def __init__(self, *args, batch_size=80, **kwargs):
  428. self.needsShuffle_bool = True
  429. self.batch_size = batch_size
  430. # self.rotate_frac = 0.5 * len(self.series_list) / len(self)
  431. super().__init__(*args, **kwargs)
  432. def __len__(self):
  433. return 50000
  434. def __getitem__(self, ndx):
  435. if self.needsShuffle_bool:
  436. random.shuffle(self.series_list)
  437. self.needsShuffle_bool = False
  438. if isinstance(ndx, int):
  439. if ndx % self.batch_size == 0:
  440. self.series_list.append(self.series_list.pop(0))
  441. series_uid = self.series_list[ndx % ctCache_depth]
  442. ct = getCt(series_uid)
  443. if ndx % 3 == 0:
  444. ct_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
  445. elif ndx % 3 == 1:
  446. ct_ndx = random.choice(ct.benign_indexes)
  447. elif ndx % 3 == 2:
  448. ct_ndx = random.choice(list(range(ct.ary.shape[0])))
  449. useAugmentation_bool = True
  450. else:
  451. series_uid, ct_ndx, useAugmentation_bool = ndx
  452. return super().__getitem__((series_uid, ct_ndx, useAugmentation_bool))