# dsets.py
import copy
import csv
import functools
import glob
import itertools
import math
import os
import random

from collections import namedtuple

import SimpleITK as sitk

import scipy.ndimage.morphology
import numpy as np

import torch
import torch.cuda
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler

from util.disk import getCache
from util.util import XyzTuple, xyz2irc, IrcTuple
from util.logconf import logging
from util.affine import affine_grid_generator

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

raw_cache = getCache('part2ch11_raw')
cubic_cache = getCache('part2ch11_cubic')

NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
@functools.lru_cache(1)
def getNoduleInfoList(requireDataOnDisk_bool=True):
    # We construct a set with all series_uids that are present on disk.
    # This will let us use the data, even if we haven't downloaded all of
    # the subsets yet.
    mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

    diameter_dict = {}
    with open('data/part2/luna/annotations.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]
            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
            annotationDiameter_mm = float(row[4])

            diameter_dict.setdefault(series_uid, []).append(
                (annotationCenter_xyz, annotationDiameter_mm)
            )

    noduleInfo_list = []
    with open('data/part2/luna/candidates.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]

            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
                continue

            isMalignant_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])

            candidateDiameter_mm = 0.0
            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
                for i in range(3):
                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                    if delta_mm > annotationDiameter_mm / 4:
                        break
                else:
                    candidateDiameter_mm = annotationDiameter_mm
                    break

            noduleInfo_list.append(NoduleInfoTuple(
                isMalignant_bool,
                candidateDiameter_mm,
                series_uid,
                candidateCenter_xyz,
            ))

    noduleInfo_list.sort(reverse=True)
    return noduleInfo_list
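
# Illustrative sketch (not part of the original module): the for/else block
# above matches a candidate to an annotation only when every per-axis delta
# is within a quarter of the annotation's diameter. For a hypothetical
# candidate at (10.0, 10.0, 10.0) and annotation at (11.0, 10.5, 9.8) with
# diameter 8.0 mm, the deltas (1.0, 0.5, 0.2) are all within 8.0 / 4 = 2.0 mm,
# so the candidate inherits diameter_mm=8.0. The hypothetical helper below
# can be called manually to eyeball the resulting class balance.
def _demoNoduleInfoCounts():
    noduleInfo_list = getNoduleInfoList()
    malignant_count = sum(1 for nit in noduleInfo_list if nit.isMalignant_bool)
    log.info("{} nodule candidates, {} flagged malignant".format(
        len(noduleInfo_list), malignant_count))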
class Ct(object):
    def __init__(self, series_uid, buildMasks_bool=True):
        mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
        # This converts HU to g/cc: air at -1000 HU maps to 0.0, water at 0 HU maps to 1.0.
        ct_ary += 1000
        ct_ary /= 1000

        # This gets rid of negative density stuff used to indicate out-of-FOV
        ct_ary[ct_ary < 0] = 0

        # This nukes any weird hotspots and clamps bone down
        ct_ary[ct_ary > 2] = 2

        self.series_uid = series_uid
        self.ary = ct_ary

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())

        # Building the annotation masks is comparatively expensive, so callers
        # that only need the array (e.g. getCtSize) pass buildMasks_bool=False
        # to skip it.
        if buildMasks_bool:
            noduleInfo_list = getNoduleInfoList()

            self.benignInfo_list = [ni_tup
                                    for ni_tup in noduleInfo_list
                                    if not ni_tup.isMalignant_bool
                                    and ni_tup.series_uid == self.series_uid]
            self.benign_mask = self.buildAnnotationMask(self.benignInfo_list)[0]
            self.benign_indexes = sorted(set(self.benign_mask.nonzero()[0]))

            self.malignantInfo_list = [ni_tup
                                       for ni_tup in noduleInfo_list
                                       if ni_tup.isMalignant_bool
                                       and ni_tup.series_uid == self.series_uid]
            self.malignant_mask = self.buildAnnotationMask(self.malignantInfo_list)[0]
            self.malignant_indexes = sorted(set(self.malignant_mask.nonzero()[0]))
    def buildAnnotationMask(self, noduleInfo_list, threshold_gcc=0.5):
        boundingBox_ary = np.zeros_like(self.ary, dtype=bool)

        for noduleInfo_tup in noduleInfo_list:
            center_irc = xyz2irc(noduleInfo_tup.center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
            center_index = int(center_irc.index)
            center_row = int(center_irc.row)
            center_col = int(center_irc.col)

            # Grow a per-axis radius outward from the center until the density
            # on both sides drops below the threshold (or we run off the array).
            index_radius = 2
            try:
                while self.ary[center_index + index_radius, center_row, center_col] > threshold_gcc and \
                        self.ary[center_index - index_radius, center_row, center_col] > threshold_gcc:
                    index_radius += 1
            except IndexError:
                index_radius -= 1

            row_radius = 2
            try:
                while self.ary[center_index, center_row + row_radius, center_col] > threshold_gcc and \
                        self.ary[center_index, center_row - row_radius, center_col] > threshold_gcc:
                    row_radius += 1
            except IndexError:
                row_radius -= 1

            col_radius = 2
            try:
                while self.ary[center_index, center_row, center_col + col_radius] > threshold_gcc and \
                        self.ary[center_index, center_row, center_col - col_radius] > threshold_gcc:
                    col_radius += 1
            except IndexError:
                col_radius -= 1

            # assert index_radius > 0, repr([noduleInfo_tup.center_xyz, center_irc, self.ary[center_index, center_row, center_col]])
            # assert row_radius > 0
            # assert col_radius > 0

            slice_tup = (
                slice(
                    # max(0, center_index - index_radius),
                    center_index - index_radius,
                    center_index + index_radius + 1,
                ),
                slice(
                    # max(0, center_row - row_radius),
                    center_row - row_radius,
                    center_row + row_radius + 1,
                ),
                slice(
                    # max(0, center_col - col_radius),
                    center_col - col_radius,
                    center_col + col_radius + 1,
                ),
            )
            boundingBox_ary[slice_tup] = True

        thresholded_ary = boundingBox_ary & (self.ary > threshold_gcc)
        mask_ary = scipy.ndimage.morphology.binary_dilation(thresholded_ary, iterations=2)

        return mask_ary, thresholded_ary, boundingBox_ary
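
    # Worked example (hypothetical numbers): for a nodule centered at
    # irc=(60, 200, 250) whose density stays above 0.5 g/cc for 4 voxels on
    # both sides along the column axis, the while loop grows col_radius from
    # 2 to 4, and the resulting slice covers columns 246..254 inclusive before
    # thresholding and dilation refine the bounding box into the final mask.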
    def build2dLungMask(self, mask_ndx, threshold_gcc=0.7):
        # Threshold, then clean up the binary image: closing fills small gaps
        # in dense tissue, opening drops small speckles, fill_holes makes the
        # body outline solid, and the air cavities inside the body (the lungs)
        # are what remains after subtracting tissue.
        dense_mask = self.ary[mask_ndx] > threshold_gcc
        denoise_mask = scipy.ndimage.morphology.binary_closing(dense_mask, iterations=2)
        tissue_mask = scipy.ndimage.morphology.binary_opening(denoise_mask, iterations=10)
        body_mask = scipy.ndimage.morphology.binary_fill_holes(tissue_mask)
        air_mask = scipy.ndimage.morphology.binary_fill_holes(body_mask & ~tissue_mask)
        lung_mask = scipy.ndimage.morphology.binary_dilation(air_mask, iterations=2)

        return air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask

    def build3dLungMask(self):
        air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask = mask_list = \
            [np.zeros_like(self.ary, dtype=bool) for _ in range(6)]

        for mask_ndx in range(self.ary.shape[0]):
            for i, mask_ary in enumerate(self.build2dLungMask(mask_ndx)):
                mask_list[i][mask_ndx] = mask_ary

        return air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask
    def getRawNodule(self, center_xyz, width_irc):
        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis] / 2))
            end_ndx = int(start_ndx + width_irc[axis])

            assert center_val >= 0 and center_val < self.ary.shape[axis], \
                repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])

            if start_ndx < 0:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > self.ary.shape[axis]:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
                end_ndx = self.ary.shape[axis]
                start_ndx = int(self.ary.shape[axis] - width_irc[axis])

            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.ary[tuple(slice_list)]
        return ct_chunk, center_irc
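
    # Usage sketch (hypothetical candidate values): the crop is shifted back
    # inside the volume if it would overrun an edge, so the returned chunk
    # always has exactly the requested shape.
    #
    #   ct = getCt(series_uid)
    #   chunk_ary, center_irc = ct.getRawNodule(candidate_xyz, (32, 48, 48))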
    def getCubicInputChunk(self, center_xyz, maxWidth_mm):
        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
        ct_start = [int(round(i)) for i in xyz2irc(
            tuple(x - maxWidth_mm / 2 for x in center_xyz),
            self.origin_xyz, self.vxSize_xyz, self.direction_tup)]
        ct_end = [int(round(i)) + 1 for i in xyz2irc(
            tuple(x + maxWidth_mm / 2 for x in center_xyz),
            self.origin_xyz, self.vxSize_xyz, self.direction_tup)]

        # Depending on the direction matrix, the "start" corner can land past
        # the "end" corner in IRC space; swap so that start <= end per axis.
        for axis in range(3):
            if ct_start[axis] > ct_end[axis]:
                ct_start[axis], ct_end[axis] = ct_end[axis], ct_start[axis]

        pad_start = [0, 0, 0]
        pad_end = [ct_end[axis] - ct_start[axis] for axis in range(3)]
        # log.info([ct_end, ct_start, pad_end])
        chunk_ary = np.zeros(pad_end, dtype=np.float32)

        # Where the requested cube sticks out of the CT volume, shrink the CT
        # slice and offset into the zero-padded chunk instead.
        for axis in range(3):
            if ct_start[axis] < 0:
                pad_start[axis] = -ct_start[axis]
                ct_start[axis] = 0

            if ct_end[axis] > self.ary.shape[axis]:
                pad_end[axis] -= ct_end[axis] - self.ary.shape[axis]
                ct_end[axis] = self.ary.shape[axis]

        pad_slices = tuple(slice(s, e) for s, e in zip(pad_start, pad_end))
        ct_slices = tuple(slice(s, e) for s, e in zip(ct_start, ct_end))
        chunk_ary[pad_slices] = self.ary[ct_slices]

        return chunk_ary, center_irc
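
    # Worked example (hypothetical numbers): with maxWidth_mm=32 and 1 mm
    # voxels, a request near the top of the volume might give
    # ct_start=[-3, 100, 100] and ct_end=[30, 133, 133]. The chunk is
    # allocated as 33x33x33 zeros; on axis 0 the CT is read from rows 0..29
    # and written into chunk rows 3..32, so the out-of-volume margin stays
    # zero-padded.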
ctCache_depth = 3
@functools.lru_cache(ctCache_depth, typed=True)
def getCt(series_uid):
    return Ct(series_uid)

@raw_cache.memoize(typed=True)
def getCtSize(series_uid):
    ct = Ct(series_uid, buildMasks_bool=False)
    return tuple(ct.ary.shape)

# @raw_cache.memoize(typed=True)
# def getCtLungExtents(series_uid):
#     ct = getCt(series_uid)
#     return (int(min(ct.lung_indexes)), int(max(ct.lung_indexes)))

@raw_cache.memoize(typed=True)
def getCtRawNodule(series_uid, center_xyz, width_irc):
    ct = getCt(series_uid)
    ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
    return ct_chunk, center_irc

# clamp_value = 1.5

@functools.lru_cache(1, typed=True)
@cubic_cache.memoize(typed=True)
def getCtCubicChunk(series_uid, center_xyz, maxWidth_mm):
    ct = getCt(series_uid)
    ct_chunk, center_irc = ct.getCubicInputChunk(center_xyz, maxWidth_mm)
    # # ct_chunk has been clamped to [0, 2] at this point
    # # We are going to convert to uint8 to reduce size on disk and loading time
    # ct_chunk[ct_chunk > clamp_value] = clamp_value
    # ct_chunk *= 255/clamp_value
    # ct_chunk = np.array(ct_chunk, dtype=np.uint8)
    return ct_chunk, center_irc
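
# The stacked decorators on getCtCubicChunk give two cache levels:
# cubic_cache.memoize persists chunks to disk across runs, while the size-1
# lru_cache keeps the most recent chunk in memory, which helps when the same
# sample is fetched at several scales in a row.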
def getCtAugmentedNodule(
        augmentation_dict,
        series_uid, center_xyz, width_mm, voxels_int,
        maxWidth_mm=32.0,
        use_cache=True):
    assert width_mm <= maxWidth_mm

    if use_cache:
        cubic_chunk, center_irc = getCtCubicChunk(series_uid, center_xyz, maxWidth_mm)
    else:
        ct = getCt(series_uid)
        cubic_chunk, center_irc = ct.getCubicInputChunk(center_xyz, maxWidth_mm)

    # Center-crop the cube down from maxWidth_mm to width_mm.
    slice_list = []
    for axis in range(3):
        crop_size = cubic_chunk.shape[axis] * width_mm / maxWidth_mm
        crop_size = int(math.ceil(crop_size))
        start_ndx = (cubic_chunk.shape[axis] - crop_size) // 2
        end_ndx = start_ndx + crop_size
        slice_list.append(slice(start_ndx, end_ndx))
    cropped_chunk = cubic_chunk[tuple(slice_list)]

    # # inflate cropped_chunk back to float32
    # cropped_chunk = np.array(cropped_chunk, dtype=np.float32)
    # cropped_chunk *= clamp_value/255

    cropped_tensor = torch.tensor(cropped_chunk).unsqueeze(0).unsqueeze(0)

    transform_tensor = torch.eye(4).to(torch.float64)

    # Scale and Mirror
    for i in range(3):
        if 'scale' in augmentation_dict:
            scale_float = augmentation_dict['scale']
            transform_tensor[i, i] *= 1.0 - scale_float / 2.0 + (random.random() * scale_float)

        if 'mirror' in augmentation_dict:
            if random.random() > 0.5:
                transform_tensor[i, i] *= -1

    # Rotate
    if 'rotate' in augmentation_dict:
        angle_rad = random.random() * math.pi * 2
        s = math.sin(angle_rad)
        c = math.cos(angle_rad)
        c1 = 1 - c

        # Rotation about a random unit axis (axis-angle / Rodrigues form).
        axis_tensor = torch.rand([3], dtype=torch.float64)
        axis_tensor /= axis_tensor.pow(2).sum().pow(0.5)
        z, y, x = axis_tensor

        rotation_tensor = torch.tensor([
            [x*x*c1 + c,   y*x*c1 - z*s, z*x*c1 + y*s, 0],
            [x*y*c1 + z*s, y*y*c1 + c,   z*y*c1 - x*s, 0],
            [x*z*c1 - y*s, y*z*c1 + x*s, z*z*c1 + c,   0],
            [0, 0, 0, 1],
        ], dtype=torch.float64)

        transform_tensor @= rotation_tensor

    # Transform into final desired shape
    affine_tensor = affine_grid_generator(
        transform_tensor[:3].unsqueeze(0).to(torch.float32),
        torch.Size([1, 1, voxels_int, voxels_int, voxels_int]),
    )

    zoomed_chunk = torch.nn.functional.grid_sample(
        cropped_tensor,
        affine_tensor,
        padding_mode='border',
    ).to('cpu')

    # Noise
    if 'noise' in augmentation_dict:
        noise_tensor = torch.randn(
            zoomed_chunk.size(),
            dtype=zoomed_chunk.dtype,
        )
        noise_tensor *= augmentation_dict['noise']
        zoomed_chunk += noise_tensor

    return zoomed_chunk[0, 0], center_irc
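
# Usage sketch (hypothetical values, mirroring the augmentation_dict that
# LunaClassificationDataset builds below): fetch one augmented 32^3 sample
# with every augmentation enabled.
#
#   augmentation_dict = {'mirror': True, 'rotate': True, 'scale': 0.25, 'noise': 0.1}
#   nodule_tensor, center_irc = getCtAugmentedNodule(
#       augmentation_dict, series_uid, center_xyz,
#       width_mm=24.0, voxels_int=32)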
class LunaPrepcacheDataset(Dataset):
    def __init__(self):
        self.series_list = sorted(set(
            noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))

    def __len__(self):
        return len(self.series_list)

    def __getitem__(self, ndx):
        # Touching the memoized functions here is enough to populate the
        # on-disk cache; the returned value is a throwaway.
        getCtSize(self.series_list[ndx])
        # getCtLungExtents(self.series_list[ndx])
        return 0
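
# Cache-warming sketch (batch_size and num_workers are illustrative):
# iterating the dataset once forces every series through the memoized
# functions, in parallel across DataLoader workers.
#
#   prep_dl = DataLoader(LunaPrepcacheDataset(), batch_size=1, num_workers=8)
#   for _ in prep_dl:
#       pass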
class LunaClassificationDataset(Dataset):
    def __init__(self,
                 test_stride=0,
                 isTestSet_bool=None,
                 series_uid=None,
                 sortby_str='random',
                 ratio_int=0,
                 scaled_bool=False,
                 multiscaled_bool=False,
                 augmented_bool=False,
                 noduleInfo_list=None,
                 ):
        self.ratio_int = ratio_int
        self.scaled_bool = scaled_bool
        self.multiscaled_bool = multiscaled_bool

        if augmented_bool:
            self.augmentation_dict = {
                'mirror': True,
                'rotate': True,
            }
            if isTestSet_bool:
                self.augmentation_dict['scale'] = 0.25
            else:
                self.augmentation_dict['scale'] = 0.5
                self.augmentation_dict['noise'] = 0.1
        else:
            self.augmentation_dict = {}

        if noduleInfo_list:
            self.noduleInfo_list = copy.copy(noduleInfo_list)
            self.use_cache = False
        else:
            self.noduleInfo_list = copy.copy(getNoduleInfoList())
            self.use_cache = True

        if series_uid:
            self.noduleInfo_list = [x for x in self.noduleInfo_list if x[2] == series_uid]

        if test_stride > 1:
            if isTestSet_bool:
                self.noduleInfo_list = self.noduleInfo_list[::test_stride]
            else:
                del self.noduleInfo_list[::test_stride]

        if sortby_str == 'random':
            random.shuffle(self.noduleInfo_list)
        elif sortby_str == 'series_uid':
            self.noduleInfo_list.sort(key=lambda x: (x[2], x[3]))  # sort by series_uid, center_xyz
        elif sortby_str == 'malignancy_size':
            pass
        else:
            raise Exception("Unknown sort: " + repr(sortby_str))

        self.benignIndex_list = [i for i, x in enumerate(self.noduleInfo_list) if not x[0]]
        self.malignantIndex_list = [i for i, x in enumerate(self.noduleInfo_list) if x[0]]

        log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
            self,
            len(self.noduleInfo_list),
            "testing" if isTestSet_bool else "training",
            len(self.benignIndex_list),
            len(self.malignantIndex_list),
            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced',
        ))
    def shuffleSamples(self):
        if self.ratio_int:
            random.shuffle(self.benignIndex_list)
            random.shuffle(self.malignantIndex_list)

    def __len__(self):
        if self.ratio_int:
            # return 10000
            return 100000
        elif self.augmentation_dict:
            return len(self.noduleInfo_list) * 5
        else:
            return len(self.noduleInfo_list)

    def __getitem__(self, ndx):
        if self.ratio_int:
            # Interleave the classes at ratio_int benign samples per malignant
            # one, wrapping around each index list as needed.
            malignant_ndx = ndx // (self.ratio_int + 1)

            if ndx % (self.ratio_int + 1):
                benign_ndx = ndx - 1 - malignant_ndx
                nodule_ndx = self.benignIndex_list[benign_ndx % len(self.benignIndex_list)]
            else:
                nodule_ndx = self.malignantIndex_list[malignant_ndx % len(self.malignantIndex_list)]
            augmentation_dict = self.augmentation_dict
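            # Worked example with ratio_int=2: ndx 0, 3, 6, ... pull from the
            # malignant list and every other ndx pulls from the benign list,
            # yielding the repeating pattern mal, ben, ben, mal, ben, ben, ...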
        else:
            nodule_ndx = ndx % len(self.noduleInfo_list)

            if ndx < len(self.noduleInfo_list):
                augmentation_dict = {}
            else:
                augmentation_dict = self.augmentation_dict

        isMalignant_bool, _diameter_mm, series_uid, center_xyz = self.noduleInfo_list[nodule_ndx]

        if self.scaled_bool:
            channel_list = []
            voxels_int = 32
            if self.multiscaled_bool:
                width_list = [8, 16, 32]
            else:
                width_list = [24]

            for width_mm in width_list:
                nodule_ary, center_irc = getCtAugmentedNodule(
                    augmentation_dict, series_uid, center_xyz, width_mm, voxels_int)

                # in: dim=3, Index x Row x Col
                # out: dim=4, Channel x Index x Row x Col
                nodule_ary = nodule_ary.unsqueeze(0)
                channel_list.append(nodule_ary)

            nodule_tensor = torch.cat(channel_list)
        else:
            nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, (32, 32, 32))
            nodule_ary = np.expand_dims(nodule_ary, 0)
            nodule_tensor = torch.from_numpy(nodule_ary)

        # dim=1
        malignant_tensor = torch.tensor([isMalignant_bool], dtype=torch.float32)

        return nodule_tensor, malignant_tensor, series_uid, center_irc
        #
        # return malignant_tensor, diameter_mm, series_uid, center_irc, nodule_tensor
class Luna2dSegmentationDataset(Dataset):
    purpose_str = 'general'

    def __init__(self,
                 contextSlices_count=2,
                 series_uid=None,
                 test_stride=0,
                 ):
        self.contextSlices_count = contextSlices_count

        if series_uid:
            self.series_list = [series_uid]
        else:
            self.series_list = sorted(set(
                noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))

        self.cullTrainOrTestSeries(test_stride)

        self.sample_list = []
        for series_uid in self.series_list:
            self.sample_list.extend(
                [(series_uid, i) for i in range(int(getCtSize(series_uid)[0]))])

        log.info("{!r}: {} {} series, {} slices".format(
            self,
            len(self.series_list),
            self.purpose_str,
            len(self.sample_list),
        ))

    def cullTrainOrTestSeries(self, test_stride):
        assert test_stride == 0

    def __len__(self):
        return len(self.sample_list)  # // 100

    def __getitem__(self, ndx):
        if isinstance(ndx, int):
            series_uid, sample_ndx = self.sample_list[ndx % len(self.sample_list)]
        else:
            series_uid, sample_ndx = ndx

        ct = getCt(series_uid)

        # Channels: contextSlices_count slices on either side of the target
        # slice, the target slice itself, plus the 2d lung mask as the final
        # channel; contextSlices_count * 2 + 2 in total.
        ct_tensor = torch.zeros((self.contextSlices_count * 2 + 2, 512, 512))
        masks_tensor = torch.zeros((2, 512, 512))

        start_ndx = sample_ndx - self.contextSlices_count
        end_ndx = sample_ndx + self.contextSlices_count + 1
        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
            # Clamp out-of-range slice indexes by repeating the edge slices.
            context_ndx = max(context_ndx, 0)
            context_ndx = min(context_ndx, ct.ary.shape[0] - 1)
            ct_tensor[i] = torch.from_numpy(ct.ary[context_ndx].astype(np.float32))

        air_mask, lung_mask = ct.build2dLungMask(sample_ndx)[:2]
        ct_tensor[-1] = torch.from_numpy(lung_mask.astype(np.float32))

        mal_mask = ct.malignant_mask[sample_ndx] & lung_mask
        ben_mask = ct.benign_mask[sample_ndx] & air_mask

        masks_tensor[0] = torch.from_numpy(mal_mask.astype(np.float32))
        masks_tensor[1] = torch.from_numpy((mal_mask | ben_mask).astype(np.float32))
        # masks_tensor[1] = torch.from_numpy(ben_mask.astype(np.float32))

        return ct_tensor.contiguous(), masks_tensor.contiguous(), ct.series_uid, sample_ndx
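
# Shape sketch (with the default contextSlices_count=2): __getitem__ returns a
# (6, 512, 512) input tensor (5 CT slices plus the lung mask) and a
# (2, 512, 512) target tensor (malignant mask, then malignant-or-benign mask).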
class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
    purpose_str = 'training'

    def __init__(self, *args, **kwargs):
        self.needsShuffle_bool = True
        super().__init__(*args, **kwargs)

    def cullTrainOrTestSeries(self, test_stride):
        assert test_stride > 0, test_stride
        del self.series_list[::test_stride]
        assert self.series_list

    def __len__(self):
        # return 100
        # return 1000
        # return 10000
        return 20000
        # return 40000

    def __getitem__(self, ndx):
        if self.needsShuffle_bool:
            random.shuffle(self.series_list)
            self.needsShuffle_bool = False

        # Occasionally rotate the series list so training isn't limited to the
        # handful of CTs currently held in the getCt lru_cache.
        if random.random() < 0.01:
            self.series_list.append(self.series_list.pop(0))

        if isinstance(ndx, int):
            # Cycle through only the first ctCache_depth series so getCt hits
            # its in-memory cache, then pick a slice known to contain a nodule.
            series_uid = self.series_list[ndx % ctCache_depth]
            ct = getCt(series_uid)
            sample_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
            # series_uid, sample_ndx = self.sample_list[ndx % len(self.sample_list)]
        else:
            series_uid, sample_ndx = ndx

        # if ndx % 2 == 0:
        #     sample_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
        # else: #if ndx % 2 == 2:
        #     sample_ndx = random.choice(ct.benign_indexes)
        # else:
        #     sample_ndx = random.randint(*self.series2extents_dict[series_uid])

        return super().__getitem__((series_uid, sample_ndx))
class TestingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
    purpose_str = 'testing'

    def cullTrainOrTestSeries(self, test_stride):
        assert test_stride > 0
        self.series_list = self.series_list[::test_stride]
        assert self.series_list
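
# Smoke-test sketch (test_stride value is illustrative; run manually):
#
#   if __name__ == '__main__':
#       train_ds = TrainingLuna2dSegmentationDataset(test_stride=10)
#       ct_t, masks_t, series_uid, slice_ndx = train_ds[0]
#       log.info("input {} / target {} from {}:{}".format(
#           tuple(ct_t.shape), tuple(masks_t.shape), series_uid, slice_ndx))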