dsets.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. import copy
  2. import csv
  3. import functools
  4. import glob
  5. import math
  6. import os
  7. import random
  8. from collections import namedtuple
  9. import SimpleITK as sitk
  10. import numpy as np
  11. import torch
  12. import torch.cuda
  13. from torch.utils.data import Dataset
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. import numpy as np
  17. from util.disk import getCache
  18. from util.util import XyzTuple, xyz2irc
  19. from util.logconf import logging
  20. log = logging.getLogger(__name__)
  21. # log.setLevel(logging.WARN)
  22. # log.setLevel(logging.INFO)
  23. log.setLevel(logging.DEBUG)
  24. raw_cache = getCache('part2ch11_raw')
  25. NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
  26. @functools.lru_cache(1)
  27. def getNoduleInfoList(requireDataOnDisk_bool=True):
  28. # We construct a set with all series_uids that are present on disk.
  29. # This will let us use the data, even if we haven't downloaded all of
  30. # the subsets yet.
  31. mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
  32. dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
  33. diameter_dict = {}
  34. with open('data/part2/luna/annotations.csv', "r") as f:
  35. for row in list(csv.reader(f))[1:]:
  36. series_uid = row[0]
  37. annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
  38. annotationDiameter_mm = float(row[4])
  39. diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
  40. noduleInfo_list = []
  41. with open('data/part2/luna/candidates.csv', "r") as f:
  42. for row in list(csv.reader(f))[1:]:
  43. series_uid = row[0]
  44. if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
  45. continue
  46. isMalignant_bool = bool(int(row[4]))
  47. candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
  48. candidateDiameter_mm = 0.0
  49. for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
  50. for i in range(3):
  51. delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
  52. if delta_mm > annotationDiameter_mm / 4:
  53. break
  54. else:
  55. candidateDiameter_mm = annotationDiameter_mm
  56. break
  57. noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
  58. noduleInfo_list.sort(reverse=True)
  59. return noduleInfo_list
  60. class Ct(object):
  61. def __init__(self, series_uid):
  62. mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
  63. ct_mhd = sitk.ReadImage(mhd_path)
  64. ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
  65. # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
  66. # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
  67. # This gets rid of negative density stuff used to indicate out-of-FOV
  68. ct_ary[ct_ary < -1000] = -1000
  69. # This nukes any weird hotspots and clamps bone down
  70. ct_ary[ct_ary > 1000] = 1000
  71. self.series_uid = series_uid
  72. self.ary = ct_ary
  73. self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
  74. self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
  75. self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
  76. def getRawNodule(self, center_xyz, width_irc):
  77. center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
  78. slice_list = []
  79. for axis, center_val in enumerate(center_irc):
  80. start_ndx = int(round(center_val - width_irc[axis]/2))
  81. end_ndx = int(start_ndx + width_irc[axis])
  82. assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
  83. if start_ndx < 0:
  84. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  85. # self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
  86. start_ndx = 0
  87. end_ndx = int(width_irc[axis])
  88. if end_ndx > self.ary.shape[axis]:
  89. # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
  90. # self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
  91. end_ndx = self.ary.shape[axis]
  92. start_ndx = int(self.ary.shape[axis] - width_irc[axis])
  93. slice_list.append(slice(start_ndx, end_ndx))
  94. ct_chunk = self.ary[tuple(slice_list)]
  95. return ct_chunk, center_irc
  96. @functools.lru_cache(1, typed=True)
  97. def getCt(series_uid):
  98. return Ct(series_uid)
  99. @raw_cache.memoize(typed=True)
  100. def getCtRawNodule(series_uid, center_xyz, width_irc):
  101. ct = getCt(series_uid)
  102. ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
  103. return ct_chunk, center_irc
  104. def getCtAugmentedNodule(
  105. augmentation_dict,
  106. series_uid, center_xyz, width_irc,
  107. use_cache=True):
  108. if use_cache:
  109. ct_chunk, center_irc = getCtRawNodule(series_uid, center_xyz, width_irc)
  110. else:
  111. ct = getCt(series_uid)
  112. ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
  113. ct_tensor = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)
  114. transform_tensor = torch.eye(4).to(torch.float64)
  115. # ... <1>
  116. for i in range(3):
  117. if 'flip' in augmentation_dict:
  118. if random.random() > 0.5:
  119. transform_tensor[i,i] *= -1
  120. if 'offset' in augmentation_dict:
  121. offset_float = augmentation_dict['offset']
  122. random_float = (random.random() * 2 - 1)
  123. transform_tensor[3,i] = offset_float * random_float
  124. if 'scale' in augmentation_dict:
  125. scale_float = augmentation_dict['scale']
  126. random_float = (random.random() * 2 - 1)
  127. transform_tensor[i,i] *= 1.0 + scale_float * random_float
  128. if 'rotate' in augmentation_dict:
  129. angle_rad = random.random() * math.pi * 2
  130. s = math.sin(angle_rad)
  131. c = math.cos(angle_rad)
  132. rotation_tensor = torch.tensor([
  133. [c, -s, 0, 0],
  134. [s, c, 0, 0],
  135. [0, 0, 1, 0],
  136. [0, 0, 0, 1],
  137. ], dtype=torch.float64)
  138. transform_tensor @= rotation_tensor
  139. affine_tensor = F.affine_grid(
  140. transform_tensor[:3].unsqueeze(0).to(torch.float32),
  141. ct_tensor.size(),
  142. )
  143. augmented_chunk = F.grid_sample(
  144. ct_tensor,
  145. affine_tensor,
  146. padding_mode='border'
  147. ).to('cpu')
  148. if 'noise' in augmentation_dict:
  149. noise_tensor = torch.randn_like(augmented_chunk)
  150. noise_tensor *= augmentation_dict['noise']
  151. augmented_chunk += noise_tensor
  152. return augmented_chunk[0], center_irc
  153. class LunaDataset(Dataset):
  154. def __init__(self,
  155. test_stride=0,
  156. isTestSet_bool=None,
  157. series_uid=None,
  158. sortby_str='random',
  159. ratio_int=0,
  160. augmentation_dict=None,
  161. noduleInfo_list=None,
  162. ):
  163. self.ratio_int = ratio_int
  164. self.augmentation_dict = augmentation_dict
  165. if noduleInfo_list:
  166. self.noduleInfo_list = copy.copy(noduleInfo_list)
  167. self.use_cache = False
  168. else:
  169. self.noduleInfo_list = copy.copy(getNoduleInfoList())
  170. self.use_cache = True
  171. if series_uid:
  172. self.noduleInfo_list = [x for x in self.noduleInfo_list if x.series_uid == series_uid]
  173. if test_stride > 1:
  174. if isTestSet_bool:
  175. self.noduleInfo_list = self.noduleInfo_list[::test_stride]
  176. else:
  177. del self.noduleInfo_list[::test_stride]
  178. if sortby_str == 'random':
  179. random.shuffle(self.noduleInfo_list)
  180. elif sortby_str == 'series_uid':
  181. self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sorting by series_uid, center_xyz)
  182. elif sortby_str == 'malignancy_size':
  183. pass
  184. else:
  185. raise Exception("Unknown sort: " + repr(sortby_str))
  186. self.benign_list = [nt for nt in self.noduleInfo_list if not nt.isMalignant_bool]
  187. self.malignant_list = [nt for nt in self.noduleInfo_list if nt.isMalignant_bool]
  188. log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
  189. self,
  190. len(self.noduleInfo_list),
  191. "testing" if isTestSet_bool else "training",
  192. len(self.benign_list),
  193. len(self.malignant_list),
  194. '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
  195. ))
  196. def shuffleSamples(self):
  197. if self.ratio_int:
  198. random.shuffle(self.benign_list)
  199. random.shuffle(self.malignant_list)
  200. def __len__(self):
  201. if self.ratio_int:
  202. return 200000
  203. else:
  204. return len(self.noduleInfo_list)
  205. def __getitem__(self, ndx):
  206. if self.ratio_int:
  207. malignant_ndx = ndx // (self.ratio_int + 1)
  208. if ndx % (self.ratio_int + 1):
  209. benign_ndx = ndx - 1 - malignant_ndx
  210. benign_ndx %= len(self.benign_list)
  211. nodule_tup = self.benign_list[benign_ndx]
  212. else:
  213. malignant_ndx %= len(self.malignant_list)
  214. nodule_tup = self.malignant_list[malignant_ndx]
  215. else:
  216. nodule_tup = self.noduleInfo_list[ndx]
  217. width_irc = (24, 48, 48)
  218. if self.augmentation_dict:
  219. nodule_t, center_irc = getCtAugmentedNodule(
  220. self.augmentation_dict,
  221. nodule_tup.series_uid,
  222. nodule_tup.center_xyz,
  223. width_irc,
  224. self.use_cache,
  225. )
  226. elif self.use_cache:
  227. nodule_ary, center_irc = getCtRawNodule(
  228. nodule_tup.series_uid,
  229. nodule_tup.center_xyz,
  230. width_irc,
  231. )
  232. nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
  233. nodule_t = nodule_t.unsqueeze(0)
  234. else:
  235. ct = getCt(nodule_tup.series_uid)
  236. nodule_ary, center_irc = ct.getRawNodule(
  237. nodule_tup.center_xyz,
  238. width_irc,
  239. )
  240. nodule_t = torch.from_numpy(nodule_ary).to(torch.float32)
  241. nodule_t = nodule_t.unsqueeze(0)
  242. malignant_tensor = torch.tensor([
  243. not nodule_tup.isMalignant_bool,
  244. nodule_tup.isMalignant_bool
  245. ],
  246. dtype=torch.long,
  247. )
  248. return nodule_t, malignant_tensor, nodule_tup.series_uid, center_irc