# dsets.py (14 KB)
  1. import copy
  2. import csv
  3. import functools
  4. import glob
  5. import math
  6. import os
  7. import random
  8. from collections import namedtuple
  9. import SimpleITK as sitk
  10. import numpy as np
  11. import torch
  12. import torch.cuda
  13. import torch.nn.functional as F
  14. from torch.utils.data import Dataset
  15. from util.disk import getCache
  16. from util.util import XyzTuple, xyz2irc
  17. from util.logconf import logging
  18. log = logging.getLogger(__name__)
  19. # log.setLevel(logging.WARN)
  20. # log.setLevel(logging.INFO)
  21. log.setLevel(logging.DEBUG)
  22. raw_cache = getCache('part2ch14_raw')
  23. CandidateInfoTuple = namedtuple('CandidateInfoTuple', 'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz')
  24. MaskTuple = namedtuple('MaskTuple', 'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask')
  25. @functools.lru_cache(1)
  26. def getCandidateInfoList(requireDataOnDisk_bool=True):
  27. # We construct a set with all series_uids that are present on disk.
  28. # This will let us use the data, even if we haven't downloaded all of
  29. # the subsets yet.
  30. mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
  31. dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
  32. candidateInfo_list = []
  33. with open('data/part2/luna/annotations_with_malignancy.csv', "r") as f:
  34. for row in list(csv.reader(f))[1:]:
  35. series_uid = row[0]
  36. annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
  37. annotationDiameter_mm = float(row[4])
  38. isMal_bool = {'False': False, 'True': True}[row[5]]
  39. candidateInfo_list.append(CandidateInfoTuple(True, True, isMal_bool, annotationDiameter_mm, series_uid, annotationCenter_xyz))
  40. with open('data/part2/luna/candidates.csv', "r") as f:
  41. for row in list(csv.reader(f))[1:]:
  42. series_uid = row[0]
  43. if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
  44. continue
  45. isNodule_bool = bool(int(row[4]))
  46. candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
  47. if not isNodule_bool:
  48. candidateInfo_list.append(CandidateInfoTuple(False, False, False, 0.0, series_uid, candidateCenter_xyz))
  49. candidateInfo_list.sort(reverse=True)
  50. return candidateInfo_list
  51. @functools.lru_cache(1)
  52. def getCandidateInfoDict(requireDataOnDisk_bool=True):
  53. candidateInfo_list = getCandidateInfoList(requireDataOnDisk_bool)
  54. candidateInfo_dict = {}
  55. for candidateInfo_tup in candidateInfo_list:
  56. candidateInfo_dict.setdefault(candidateInfo_tup.series_uid, []).append(candidateInfo_tup)
  57. return candidateInfo_dict
  58. class Ct:
  59. def __init__(self, series_uid):
  60. mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
  61. ct_mhd = sitk.ReadImage(mhd_path)
  62. ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
  63. # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
  64. # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
  65. # The lower bound gets rid of negative density stuff used to indicate out-of-FOV
  66. # The upper bound nukes any weird hotspots and clamps bone down
  67. ct_a.clip(-1000, 1000, ct_a)
  68. self.series_uid = series_uid
  69. self.hu_a = ct_a
  70. self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
  71. self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
  72. self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
  73. def getRawCandidate(self, center_xyz, width_irc):
  74. center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_a)
  75. slice_list = []
  76. for axis, center_val in enumerate(center_irc):
  77. start_ndx = int(round(center_val - width_irc[axis]/2))
  78. end_ndx = int(start_ndx + width_irc[axis])
  79. assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
  80. if start_ndx < 0:
  81. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  82. # self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
  83. start_ndx = 0
  84. end_ndx = int(width_irc[axis])
  85. if end_ndx > self.hu_a.shape[axis]:
  86. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  87. # self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
  88. end_ndx = self.hu_a.shape[axis]
  89. start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])
  90. slice_list.append(slice(start_ndx, end_ndx))
  91. ct_chunk = self.hu_a[tuple(slice_list)]
  92. return ct_chunk, center_irc
  93. @functools.lru_cache(1, typed=True)
  94. def getCt(series_uid):
  95. return Ct(series_uid)
  96. @raw_cache.memoize(typed=True)
  97. def getCtRawCandidate(series_uid, center_xyz, width_irc):
  98. ct = getCt(series_uid)
  99. ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
  100. return ct_chunk, center_irc
  101. @raw_cache.memoize(typed=True)
  102. def getCtSampleSize(series_uid):
  103. ct = Ct(series_uid, buildMasks_bool=False)
  104. return len(ct.negative_indexes)
  105. def getCtAugmentedCandidate(
  106. augmentation_dict,
  107. series_uid, center_xyz, width_irc,
  108. use_cache=True):
  109. if use_cache:
  110. ct_chunk, center_irc = getCtRawCandidate(series_uid, center_xyz, width_irc)
  111. else:
  112. ct = getCt(series_uid)
  113. ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
  114. ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
  115. transform_t = torch.eye(4)
  116. # ... <1>
  117. for i in range(3):
  118. if 'flip' in augmentation_dict:
  119. if random.random() > 0.5:
  120. transform_t[i,i] *= -1
  121. if 'offset' in augmentation_dict:
  122. offset_float = augmentation_dict['offset']
  123. random_float = (random.random() * 2 - 1)
  124. transform_t[i, 3] = offset_float * random_float
  125. if 'scale' in augmentation_dict:
  126. scale_float = augmentation_dict['scale']
  127. random_float = (random.random() * 2 - 1)
  128. transform_t[i,i] *= 1.0 + scale_float * random_float
  129. if 'rotate' in augmentation_dict:
  130. angle_rad = random.random() * math.pi * 2
  131. s = math.sin(angle_rad)
  132. c = math.cos(angle_rad)
  133. rotation_t = torch.tensor([
  134. [c, -s, 0, 0],
  135. [s, c, 0, 0],
  136. [0, 0, 1, 0],
  137. [0, 0, 0, 1],
  138. ])
  139. transform_t @= rotation_t
  140. affine_t = F.affine_grid(
  141. transform_t[:3].unsqueeze(0).to(torch.float32),
  142. ct_t.size(),
  143. align_corners=False,
  144. )
  145. augmented_chunk = F.grid_sample(
  146. ct_t,
  147. affine_t,
  148. padding_mode='border',
  149. align_corners=False,
  150. ).to('cpu')
  151. if 'noise' in augmentation_dict:
  152. noise_t = torch.randn_like(augmented_chunk)
  153. noise_t *= augmentation_dict['noise']
  154. augmented_chunk += noise_t
  155. return augmented_chunk[0], center_irc
  156. class LunaDataset(Dataset):
  157. def __init__(self,
  158. val_stride=0,
  159. isValSet_bool=None,
  160. series_uid=None,
  161. sortby_str='random',
  162. ratio_int=0,
  163. augmentation_dict=None,
  164. candidateInfo_list=None,
  165. ):
  166. self.ratio_int = ratio_int
  167. self.augmentation_dict = augmentation_dict
  168. if candidateInfo_list:
  169. self.candidateInfo_list = copy.copy(candidateInfo_list)
  170. self.use_cache = False
  171. else:
  172. self.candidateInfo_list = copy.copy(getCandidateInfoList())
  173. self.use_cache = True
  174. if series_uid:
  175. self.series_list = [series_uid]
  176. else:
  177. self.series_list = sorted(set(candidateInfo_tup.series_uid for candidateInfo_tup in self.candidateInfo_list))
  178. if isValSet_bool:
  179. assert val_stride > 0, val_stride
  180. self.series_list = self.series_list[::val_stride]
  181. assert self.series_list
  182. elif val_stride > 0:
  183. del self.series_list[::val_stride]
  184. assert self.series_list
  185. series_set = set(self.series_list)
  186. self.candidateInfo_list = [x for x in self.candidateInfo_list if x.series_uid in series_set]
  187. if sortby_str == 'random':
  188. random.shuffle(self.candidateInfo_list)
  189. elif sortby_str == 'series_uid':
  190. self.candidateInfo_list.sort(key=lambda x: (x.series_uid, x.center_xyz))
  191. elif sortby_str == 'label_and_size':
  192. pass
  193. else:
  194. raise Exception("Unknown sort: " + repr(sortby_str))
  195. self.neg_list = [nt for nt in self.candidateInfo_list if not nt.isNodule_bool]
  196. self.pos_list = [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
  197. self.ben_list = [nt for nt in self.pos_list if not nt.isMal_bool]
  198. self.mal_list = [nt for nt in self.pos_list if nt.isMal_bool]
  199. log.info("{!r}: {} {} samples, {} neg, {} pos, {} ratio".format(
  200. self,
  201. len(self.candidateInfo_list),
  202. "validation" if isValSet_bool else "training",
  203. len(self.neg_list),
  204. len(self.pos_list),
  205. '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
  206. ))
  207. def shuffleSamples(self):
  208. if self.ratio_int:
  209. random.shuffle(self.candidateInfo_list)
  210. random.shuffle(self.neg_list)
  211. random.shuffle(self.pos_list)
  212. random.shuffle(self.ben_list)
  213. random.shuffle(self.mal_list)
  214. def __len__(self):
  215. if self.ratio_int:
  216. return 50000
  217. # return 50000
  218. # return 200000
  219. else:
  220. return len(self.candidateInfo_list)
  221. def __getitem__(self, ndx):
  222. if self.ratio_int:
  223. pos_ndx = ndx // (self.ratio_int + 1)
  224. if ndx % (self.ratio_int + 1):
  225. neg_ndx = ndx - 1 - pos_ndx
  226. neg_ndx %= len(self.neg_list)
  227. candidateInfo_tup = self.neg_list[neg_ndx]
  228. else:
  229. pos_ndx %= len(self.pos_list)
  230. candidateInfo_tup = self.pos_list[pos_ndx]
  231. else:
  232. candidateInfo_tup = self.candidateInfo_list[ndx]
  233. return self.sampleFromCandidateInfo_tup(candidateInfo_tup, candidateInfo_tup.isNodule_bool)
  234. def sampleFromCandidateInfo_tup(self, candidateInfo_tup, label_bool):
  235. width_irc = (32, 48, 48)
  236. if self.augmentation_dict:
  237. candidate_t, center_irc = getCtAugmentedCandidate(
  238. self.augmentation_dict,
  239. candidateInfo_tup.series_uid,
  240. candidateInfo_tup.center_xyz,
  241. width_irc,
  242. self.use_cache,
  243. )
  244. elif self.use_cache:
  245. candidate_a, center_irc = getCtRawCandidate(
  246. candidateInfo_tup.series_uid,
  247. candidateInfo_tup.center_xyz,
  248. width_irc,
  249. )
  250. candidate_t = torch.from_numpy(candidate_a).to(torch.float32)
  251. candidate_t = candidate_t.unsqueeze(0)
  252. else:
  253. ct = getCt(candidateInfo_tup.series_uid)
  254. candidate_a, center_irc = ct.getRawCandidate(
  255. candidateInfo_tup.center_xyz,
  256. width_irc,
  257. )
  258. candidate_t = torch.from_numpy(candidate_a).to(torch.float32)
  259. candidate_t = candidate_t.unsqueeze(0)
  260. label_t = torch.tensor([False, False], dtype=torch.long)
  261. if not label_bool:
  262. label_t[0] = True
  263. index_t = 0
  264. else:
  265. label_t[1] = True
  266. index_t = 1
  267. return candidate_t, label_t, index_t, candidateInfo_tup.series_uid, torch.tensor(center_irc)
  268. # class MalignantLunaDataset(LunaDataset):
  269. # # tag::ds_balancing_len[]
  270. # def __len__(self):
  271. # if self.ratio_int:
  272. # return 10000
  273. # # return 50000
  274. # # return 200000
  275. # else:
  276. # return len(self.ben_list + self.mal_list)
  277. # # end::ds_balancing_len[]
  278. #
  279. # # tag::ds_balancing_getitem[]
  280. # def __getitem__(self, ndx):
  281. # if self.ratio_int:
  282. # mal_ndx = ndx // (self.ratio_int + 1)
  283. #
  284. # if ndx % (self.ratio_int + 1):
  285. # ben_ndx = ndx - 1 - mal_ndx
  286. # ben_ndx %= len(self.ben_list)
  287. # candidateInfo_tup = self.ben_list[ben_ndx]
  288. # else:
  289. # mal_ndx %= len(self.mal_list)
  290. # candidateInfo_tup = self.mal_list[mal_ndx]
  291. # else:
  292. # if ndx >= len(self.ben_list):
  293. # candidateInfo_tup = self.mal_list[ndx - len(self.ben_list)]
  294. # else:
  295. # candidateInfo_tup = self.ben_list[ndx]
  296. #
  297. # return self.sampleFromCandidateInfo_tup(candidateInfo_tup, candidateInfo_tup.isMal_bool)
  298. # # end::ds_balancing_getitem[]
  299. class MalignantLunaDataset(LunaDataset):
  300. def __len__(self):
  301. if self.ratio_int:
  302. # return 10000
  303. return 100000
  304. # return 50000
  305. # return 200000
  306. else:
  307. return len(self.ben_list + self.mal_list)
  308. def __getitem__(self, ndx):
  309. if self.ratio_int:
  310. if ndx % 2 != 0:
  311. candidateInfo_tup = self.mal_list[(ndx // 2) % len(self.mal_list)]
  312. elif ndx % 4 == 0:
  313. candidateInfo_tup = self.ben_list[(ndx // 4) % len(self.ben_list)]
  314. else:
  315. candidateInfo_tup = self.neg_list[(ndx // 4) % len(self.neg_list)]
  316. else:
  317. if ndx >= len(self.ben_list):
  318. candidateInfo_tup = self.mal_list[ndx - len(self.ben_list)]
  319. else:
  320. candidateInfo_tup = self.ben_list[ndx]
  321. return self.sampleFromCandidateInfo_tup(candidateInfo_tup, candidateInfo_tup.isMal_bool)