dsets.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. import copy
  2. import csv
  3. import functools
  4. import glob
  5. import math
  6. import os
  7. import random
  8. from collections import namedtuple
  9. import SimpleITK as sitk
  10. import numpy as np
  11. import scipy.ndimage.morphology as morph
  12. import torch
  13. import torch.cuda
  14. import torch.nn.functional as F
  15. from torch.utils.data import Dataset
  16. from util.disk import getCache
  17. from util.util import XyzTuple, xyz2irc
  18. from util.logconf import logging
  19. log = logging.getLogger(__name__)
  20. # log.setLevel(logging.WARN)
  21. # log.setLevel(logging.INFO)
  22. log.setLevel(logging.DEBUG)
  23. raw_cache = getCache('part2ch13_raw')
  24. MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
  25. CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz')
  26. @functools.lru_cache(1)
  27. def getCandidateInfoList(requireOnDisk_bool=True):
  28. # We construct a set with all series_uids that are present on disk.
  29. # This will let us use the data, even if we haven't downloaded all of
  30. # the subsets yet.
  31. mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
  32. presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
  33. candidateInfo_list = []
  34. with open('data/part2/luna/annotations_with_malignancy.csv', "r") as f:
  35. for row in list(csv.reader(f))[1:]:
  36. series_uid = row[0]
  37. annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
  38. annotationDiameter_mm = float(row[4])
  39. isMal_bool = {'False': False, 'True': True}[row[5]]
  40. candidateInfo_list.append(
  41. CandidateInfoTuple(
  42. True,
  43. True,
  44. isMal_bool,
  45. annotationDiameter_mm,
  46. series_uid,
  47. annotationCenter_xyz,
  48. )
  49. )
  50. with open('data/part2/luna/candidates.csv', "r") as f:
  51. for row in list(csv.reader(f))[1:]:
  52. series_uid = row[0]
  53. if series_uid not in presentOnDisk_set and requireOnDisk_bool:
  54. continue
  55. isNodule_bool = bool(int(row[4]))
  56. candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
  57. if not isNodule_bool:
  58. candidateInfo_list.append(
  59. CandidateInfoTuple(
  60. False,
  61. False,
  62. False,
  63. 0.0,
  64. series_uid,
  65. candidateCenter_xyz,
  66. )
  67. )
  68. candidateInfo_list.sort(reverse=True)
  69. return candidateInfo_list
  70. @functools.lru_cache(1)
  71. def getCandidateInfoDict(requireOnDisk_bool=True):
  72. candidateInfo_list = getCandidateInfoList(requireOnDisk_bool)
  73. candidateInfo_dict = {}
  74. for candidateInfo_tup in candidateInfo_list:
  75. candidateInfo_dict.setdefault(candidateInfo_tup.series_uid,
  76. []).append(candidateInfo_tup)
  77. return candidateInfo_dict
  78. class Ct:
  79. def __init__(self, series_uid):
  80. mhd_path = glob.glob(
  81. 'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
  82. )[0]
  83. ct_mhd = sitk.ReadImage(mhd_path)
  84. self.hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
  85. # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
  86. # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
  87. self.series_uid = series_uid
  88. self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
  89. self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
  90. self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
  91. candidateInfo_list = getCandidateInfoDict()[self.series_uid]
  92. self.positiveInfo_list = [
  93. candidate_tup
  94. for candidate_tup in candidateInfo_list
  95. if candidate_tup.isNodule_bool
  96. ]
  97. self.positive_mask = self.buildAnnotationMask(self.positiveInfo_list)
  98. self.positive_indexes = (self.positive_mask.sum(axis=(1,2))
  99. .nonzero()[0].tolist())
  100. def buildAnnotationMask(self, positiveInfo_list, threshold_hu = -700):
  101. boundingBox_a = np.zeros_like(self.hu_a, dtype=np.bool)
  102. for candidateInfo_tup in positiveInfo_list:
  103. center_irc = xyz2irc(
  104. candidateInfo_tup.center_xyz,
  105. self.origin_xyz,
  106. self.vxSize_xyz,
  107. self.direction_a,
  108. )
  109. ci = int(center_irc.index)
  110. cr = int(center_irc.row)
  111. cc = int(center_irc.col)
  112. index_radius = 2
  113. try:
  114. while self.hu_a[ci + index_radius, cr, cc] > threshold_hu and \
  115. self.hu_a[ci - index_radius, cr, cc] > threshold_hu:
  116. index_radius += 1
  117. except IndexError:
  118. index_radius -= 1
  119. row_radius = 2
  120. try:
  121. while self.hu_a[ci, cr + row_radius, cc] > threshold_hu and \
  122. self.hu_a[ci, cr - row_radius, cc] > threshold_hu:
  123. row_radius += 1
  124. except IndexError:
  125. row_radius -= 1
  126. col_radius = 2
  127. try:
  128. while self.hu_a[ci, cr, cc + col_radius] > threshold_hu and \
  129. self.hu_a[ci, cr, cc - col_radius] > threshold_hu:
  130. col_radius += 1
  131. except IndexError:
  132. col_radius -= 1
  133. # assert index_radius > 0, repr([candidateInfo_tup.center_xyz, center_irc, self.hu_a[ci, cr, cc]])
  134. # assert row_radius > 0
  135. # assert col_radius > 0
  136. boundingBox_a[
  137. ci - index_radius: ci + index_radius + 1,
  138. cr - row_radius: cr + row_radius + 1,
  139. cc - col_radius: cc + col_radius + 1] = True
  140. mask_a = boundingBox_a & (self.hu_a > threshold_hu)
  141. return mask_a
  142. def getRawCandidate(self, center_xyz, width_irc):
  143. center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz,
  144. self.direction_a)
  145. slice_list = []
  146. for axis, center_val in enumerate(center_irc):
  147. start_ndx = int(round(center_val - width_irc[axis]/2))
  148. end_ndx = int(start_ndx + width_irc[axis])
  149. assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
  150. if start_ndx < 0:
  151. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  152. # self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
  153. start_ndx = 0
  154. end_ndx = int(width_irc[axis])
  155. if end_ndx > self.hu_a.shape[axis]:
  156. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  157. # self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
  158. end_ndx = self.hu_a.shape[axis]
  159. start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
  160. slice_list.append(slice(start_ndx, end_ndx))
  161. ct_chunk = self.hu_a[tuple(slice_list)]
  162. pos_chunk = self.positive_mask[tuple(slice_list)]
  163. return ct_chunk, pos_chunk, center_irc
  164. @functools.lru_cache(1, typed=True)
  165. def getCt(series_uid):
  166. return Ct(series_uid)
  167. @raw_cache.memoize(typed=True)
  168. def getCtRawCandidate(series_uid, center_xyz, width_irc):
  169. ct = getCt(series_uid)
  170. ct_chunk, pos_chunk, center_irc = ct.getRawCandidate(center_xyz,
  171. width_irc)
  172. ct_chunk.clip(-1000, 1000, ct_chunk)
  173. return ct_chunk, pos_chunk, center_irc
  174. @raw_cache.memoize(typed=True)
  175. def getCtSampleSize(series_uid):
  176. ct = Ct(series_uid)
  177. return int(ct.hu_a.shape[0]), ct.positive_indexes
  178. class Luna2dSegmentationDataset(Dataset):
  179. def __init__(self,
  180. val_stride=0,
  181. isValSet_bool=None,
  182. series_uid=None,
  183. contextSlices_count=3,
  184. fullCt_bool=False,
  185. ):
  186. self.contextSlices_count = contextSlices_count
  187. self.fullCt_bool = fullCt_bool
  188. if series_uid:
  189. self.series_list = [series_uid]
  190. else:
  191. self.series_list = sorted(getCandidateInfoDict().keys())
  192. if isValSet_bool:
  193. assert val_stride > 0, val_stride
  194. self.series_list = self.series_list[::val_stride]
  195. assert self.series_list
  196. elif val_stride > 0:
  197. del self.series_list[::val_stride]
  198. assert self.series_list
  199. self.sample_list = []
  200. for series_uid in self.series_list:
  201. index_count, positive_indexes = getCtSampleSize(series_uid)
  202. if self.fullCt_bool:
  203. self.sample_list += [(series_uid, slice_ndx)
  204. for slice_ndx in range(index_count)]
  205. else:
  206. self.sample_list += [(series_uid, slice_ndx)
  207. for slice_ndx in positive_indexes]
  208. self.candidateInfo_list = getCandidateInfoList()
  209. series_set = set(self.series_list)
  210. self.candidateInfo_list = [cit for cit in self.candidateInfo_list
  211. if cit.series_uid in series_set]
  212. self.pos_list = [nt for nt in self.candidateInfo_list
  213. if nt.isNodule_bool]
  214. log.info("{!r}: {} {} series, {} slices, {} nodules".format(
  215. self,
  216. len(self.series_list),
  217. {None: 'general', True: 'validation', False: 'training'}[isValSet_bool],
  218. len(self.sample_list),
  219. len(self.pos_list),
  220. ))
  221. def __len__(self):
  222. return len(self.sample_list)
  223. def __getitem__(self, ndx):
  224. series_uid, slice_ndx = self.sample_list[ndx % len(self.sample_list)]
  225. return self.getitem_fullSlice(series_uid, slice_ndx)
  226. def getitem_fullSlice(self, series_uid, slice_ndx):
  227. ct = getCt(series_uid)
  228. ct_t = torch.zeros((self.contextSlices_count * 2 + 1, 512, 512))
  229. start_ndx = slice_ndx - self.contextSlices_count
  230. end_ndx = slice_ndx + self.contextSlices_count + 1
  231. for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
  232. context_ndx = max(context_ndx, 0)
  233. context_ndx = min(context_ndx, ct.hu_a.shape[0] - 1)
  234. ct_t[i] = torch.from_numpy(ct.hu_a[context_ndx].astype(np.float32))
  235. # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
  236. # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
  237. # The lower bound gets rid of negative density stuff used to indicate out-of-FOV
  238. # The upper bound nukes any weird hotspots and clamps bone down
  239. ct_t.clamp_(-1000, 1000)
  240. pos_t = torch.from_numpy(ct.positive_mask[slice_ndx]).unsqueeze(0)
  241. return ct_t, pos_t, ct.series_uid, slice_ndx
  242. class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
  243. def __init__(self, *args, **kwargs):
  244. super().__init__(*args, **kwargs)
  245. self.ratio_int = 2
  246. def __len__(self):
  247. return 300000
  248. def shuffleSamples(self):
  249. random.shuffle(self.candidateInfo_list)
  250. random.shuffle(self.pos_list)
  251. def __getitem__(self, ndx):
  252. candidateInfo_tup = self.pos_list[ndx % len(self.pos_list)]
  253. return self.getitem_trainingCrop(candidateInfo_tup)
  254. def getitem_trainingCrop(self, candidateInfo_tup):
  255. ct_a, pos_a, center_irc = getCtRawCandidate(
  256. candidateInfo_tup.series_uid,
  257. candidateInfo_tup.center_xyz,
  258. (7, 96, 96),
  259. )
  260. pos_a = pos_a[3:4]
  261. row_offset = random.randrange(0,32)
  262. col_offset = random.randrange(0,32)
  263. ct_t = torch.from_numpy(ct_a[:, row_offset:row_offset+64,
  264. col_offset:col_offset+64]).to(torch.float32)
  265. pos_t = torch.from_numpy(pos_a[:, row_offset:row_offset+64,
  266. col_offset:col_offset+64]).to(torch.long)
  267. slice_ndx = center_irc.index
  268. return ct_t, pos_t, candidateInfo_tup.series_uid, slice_ndx
  269. class PrepcacheLunaDataset(Dataset):
  270. def __init__(self, *args, **kwargs):
  271. super().__init__(*args, **kwargs)
  272. self.candidateInfo_list = getCandidateInfoList()
  273. self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
  274. self.seen_set = set()
  275. self.candidateInfo_list.sort(key=lambda x: x.series_uid)
  276. def __len__(self):
  277. return len(self.candidateInfo_list)
  278. def __getitem__(self, ndx):
  279. # candidate_t, pos_t, series_uid, center_t = super().__getitem__(ndx)
  280. candidateInfo_tup = self.candidateInfo_list[ndx]
  281. getCtRawCandidate(candidateInfo_tup.series_uid, candidateInfo_tup.center_xyz, (7, 96, 96))
  282. series_uid = candidateInfo_tup.series_uid
  283. if series_uid not in self.seen_set:
  284. self.seen_set.add(series_uid)
  285. getCtSampleSize(series_uid)
  286. # ct = getCt(series_uid)
  287. # for mask_ndx in ct.positive_indexes:
  288. # build2dLungMask(series_uid, mask_ndx)
  289. return 0, 1 #candidate_t, pos_t, series_uid, center_t
  290. class TvTrainingLuna2dSegmentationDataset(torch.utils.data.Dataset):
  291. def __init__(self, isValSet_bool=False, val_stride=10, contextSlices_count=3):
  292. assert contextSlices_count == 3
  293. data = torch.load('./imgs_and_masks.pt')
  294. suids = list(set(data['suids']))
  295. trn_mask_suids = torch.arange(len(suids)) % val_stride < (val_stride - 1)
  296. trn_suids = {s for i, s in zip(trn_mask_suids, suids) if i}
  297. trn_mask = torch.tensor([(s in trn_suids) for s in data["suids"]])
  298. if not isValSet_bool:
  299. self.imgs = data["imgs"][trn_mask]
  300. self.masks = data["masks"][trn_mask]
  301. self.suids = [s for s, i in zip(data["suids"], trn_mask) if i]
  302. else:
  303. self.imgs = data["imgs"][~trn_mask]
  304. self.masks = data["masks"][~trn_mask]
  305. self.suids = [s for s, i in zip(data["suids"], trn_mask) if not i]
  306. # discard spurious hotspots and clamp bone
  307. self.imgs.clamp_(-1000, 1000)
  308. self.imgs /= 1000
  309. def __len__(self):
  310. return len(self.imgs)
  311. def __getitem__(self, i):
  312. oh, ow = torch.randint(0, 32, (2,))
  313. sl = self.masks.size(1)//2
  314. return self.imgs[i, :, oh: oh + 64, ow: ow + 64], 1, self.masks[i, sl: sl+1, oh: oh + 64, ow: ow + 64].to(torch.float32), self.suids[i], 9999