# dsets.py

import copy
import csv
import functools
import glob
import math
import os
import random
from collections import namedtuple

import SimpleITK as sitk
import numpy as np

import torch
import torch.cuda
import torch.nn.functional as F
from torch.utils.data import Dataset

from util.disk import getCache
from util.util import XyzTuple, xyz2irc
from util.logconf import logging

log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

raw_cache = getCache('part2ch12_raw')

CandidateInfoTuple = namedtuple(
    'CandidateInfoTuple',
    'isNodule_bool, diameter_mm, series_uid, center_xyz',
)
@functools.lru_cache(1)
def getCandidateInfoList(requireOnDisk_bool=True):
    # We construct a set with all series_uids that are present on disk.
    # This will let us use the data, even if we haven't downloaded all of
    # the subsets yet.
    mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

    diameter_dict = {}
    with open('data/part2/luna/annotations.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]
            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
            annotationDiameter_mm = float(row[4])

            diameter_dict.setdefault(series_uid, []).append(
                (annotationCenter_xyz, annotationDiameter_mm),
            )

    candidateInfo_list = []
    with open('data/part2/luna/candidates.csv', "r") as f:
        for row in list(csv.reader(f))[1:]:
            series_uid = row[0]

            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                continue

            isNodule_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])

            # Treat a candidate as matching an annotation only if the two
            # centers are within a quarter of the annotated diameter on every
            # axis; the for/else reaches the assignment only when no axis
            # triggered the break.
            candidateDiameter_mm = 0.0
            for annotation_tup in diameter_dict.get(series_uid, []):
                annotationCenter_xyz, annotationDiameter_mm = annotation_tup
                for i in range(3):
                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
                    if delta_mm > annotationDiameter_mm / 4:
                        break
                else:
                    candidateDiameter_mm = annotationDiameter_mm
                    break

            candidateInfo_list.append(CandidateInfoTuple(
                isNodule_bool,
                candidateDiameter_mm,
                series_uid,
                candidateCenter_xyz,
            ))

    # Reverse-sorting the namedtuples puts the actual nodules, largest
    # diameter first, at the front of the list.
    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list
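
# A minimal sketch of the shape of the returned data; the field values below
# are illustrative placeholders, not real LUNA16 entries:
#
#   >>> getCandidateInfoList()[0]
#   CandidateInfoTuple(isNodule_bool=True, diameter_mm=...,
#                      series_uid='1.3.6...', center_xyz=(x, y, z))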
class Ct:
    def __init__(self, series_uid):
        mhd_path = glob.glob(
            'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
        )[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        ct_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale:
        # air (roughly 0 g/cc) is about -1000 HU, and water (1 g/cc) is 0 HU.
        # The lower bound gets rid of the negative "density" values used to
        # mark out-of-FOV voxels; the upper bound clamps bone and any bright
        # artifacts, neither of which matters for finding nodules.
        ct_a.clip(-1000, 1000, ct_a)

        self.series_uid = series_uid
        self.hu_a = ct_a

        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)
    def getRawCandidate(self, center_xyz, width_irc):
        center_irc = xyz2irc(
            center_xyz,
            self.origin_xyz,
            self.vxSize_xyz,
            self.direction_a,
        )

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis]/2))
            end_ndx = int(start_ndx + width_irc[axis])

            assert center_val >= 0 and center_val < self.hu_a.shape[axis], \
                repr([self.series_uid, center_xyz, self.origin_xyz,
                      self.vxSize_xyz, center_irc, axis])

            # If the requested crop sticks out past either edge of the CT,
            # slide the window back inside so the chunk keeps its full width.
            if start_ndx < 0:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                start_ndx = 0
                end_ndx = int(width_irc[axis])

            if end_ndx > self.hu_a.shape[axis]:
                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
                #     self.series_uid, center_xyz, center_irc, self.hu_a.shape, width_irc))
                end_ndx = self.hu_a.shape[axis]
                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])

            slice_list.append(slice(start_ndx, end_ndx))

        ct_chunk = self.hu_a[tuple(slice_list)]
        return ct_chunk, center_irc
@functools.lru_cache(1, typed=True)
def getCt(series_uid):
    return Ct(series_uid)

@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    ct = getCt(series_uid)
    ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)
    return ct_chunk, center_irc
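
# Two caching layers work together here: functools.lru_cache keeps the single
# most recently used Ct in memory (loading a full CT is the expensive step),
# while raw_cache.memoize persists the small cropped chunks to disk, so that
# later epochs can serve candidates without loading the CT at all.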
def getCtAugmentedCandidate(
        augmentation_dict,
        series_uid, center_xyz, width_irc,
        use_cache=True):
    if use_cache:
        ct_chunk, center_irc = \
            getCtRawCandidate(series_uid, center_xyz, width_irc)
    else:
        ct = getCt(series_uid)
        ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)

    ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)

    transform_t = torch.eye(4)
    # ... <1>
    for i in range(3):
        # Mirror, shift, and rescale each axis independently.
        if 'flip' in augmentation_dict:
            if random.random() > 0.5:
                transform_t[i,i] *= -1

        if 'offset' in augmentation_dict:
            offset_float = augmentation_dict['offset']
            random_float = (random.random() * 2 - 1)
            transform_t[i,3] = offset_float * random_float

        if 'scale' in augmentation_dict:
            scale_float = augmentation_dict['scale']
            random_float = (random.random() * 2 - 1)
            transform_t[i,i] *= 1.0 + scale_float * random_float

    if 'rotate' in augmentation_dict:
        # Rotate only in the row/column plane; voxels are not cubic along
        # the index axis, so rotating through it would distort the nodule.
        angle_rad = random.random() * math.pi * 2
        s = math.sin(angle_rad)
        c = math.cos(angle_rad)

        rotation_t = torch.tensor([
            [c, -s, 0, 0],
            [s, c, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ])

        transform_t @= rotation_t

    affine_t = F.affine_grid(
        transform_t[:3].unsqueeze(0).to(torch.float32),
        ct_t.size(),
        align_corners=False,
    )

    augmented_chunk = F.grid_sample(
        ct_t,
        affine_t,
        padding_mode='border',
        align_corners=False,
    ).to('cpu')

    if 'noise' in augmentation_dict:
        noise_t = torch.randn_like(augmented_chunk)
        noise_t *= augmentation_dict['noise']
        augmented_chunk += noise_t

    return augmented_chunk[0], center_irc
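
# A hedged usage sketch; the dictionary keys are the ones this function
# checks, but the magnitudes below are illustrative, not tuned values:
#
#   augmentation_dict = {
#       'flip': True,    # randomly mirror each axis
#       'offset': 0.1,   # shift by up to 10% of the chunk extent
#       'scale': 0.2,    # rescale by up to +/-20%
#       'rotate': True,  # random rotation in the row/column plane
#       'noise': 25.0,   # additive Gaussian noise, in HU
#   }
#   chunk_t, center_irc = getCtAugmentedCandidate(
#       augmentation_dict, series_uid, center_xyz, (32, 48, 48))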
class LunaDataset(Dataset):
    def __init__(self,
                 val_stride=0,
                 isValSet_bool=None,
                 series_uid=None,
                 sortby_str='random',
                 ratio_int=0,
                 augmentation_dict=None,
                 candidateInfo_list=None,
                 ):
        self.ratio_int = ratio_int
        self.augmentation_dict = augmentation_dict

        if candidateInfo_list:
            self.candidateInfo_list = copy.copy(candidateInfo_list)
            self.use_cache = False
        else:
            self.candidateInfo_list = copy.copy(getCandidateInfoList())
            self.use_cache = True

        if series_uid:
            self.candidateInfo_list = [
                x for x in self.candidateInfo_list if x.series_uid == series_uid
            ]

        if isValSet_bool:
            assert val_stride > 0, val_stride
            # Keep only every val_stride-th candidate for validation...
            self.candidateInfo_list = self.candidateInfo_list[::val_stride]
            assert self.candidateInfo_list
        elif val_stride > 0:
            # ...and delete those same entries from the training set,
            # keeping the two splits disjoint.
            del self.candidateInfo_list[::val_stride]
            assert self.candidateInfo_list

        if sortby_str == 'random':
            random.shuffle(self.candidateInfo_list)
        elif sortby_str == 'series_uid':
            self.candidateInfo_list.sort(key=lambda x: (x.series_uid, x.center_xyz))
        elif sortby_str == 'label_and_size':
            pass
        else:
            raise Exception("Unknown sort: " + repr(sortby_str))

        self.negative_list = [
            nt for nt in self.candidateInfo_list if not nt.isNodule_bool
        ]
        self.pos_list = [
            nt for nt in self.candidateInfo_list if nt.isNodule_bool
        ]

        log.info("{!r}: {} {} samples, {} neg, {} pos, {} ratio".format(
            self,
            len(self.candidateInfo_list),
            "validation" if isValSet_bool else "training",
            len(self.negative_list),
            len(self.pos_list),
            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced',
        ))
    def shuffleSamples(self):
        if self.ratio_int:
            random.shuffle(self.negative_list)
            random.shuffle(self.pos_list)

    def __len__(self):
        if self.ratio_int:
            # When balancing, the epoch length is decoupled from the number
            # of candidates; 200,000 samples make up one fixed-size "epoch".
            return 200000
        else:
            return len(self.candidateInfo_list)
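
    # How the balancing index math below plays out, taking ratio_int=1:
    # ndx 0 -> pos_list[0], ndx 1 -> neg_list[0], ndx 2 -> pos_list[1], ...
    # i.e. one positive for every ratio_int negatives, with both lists
    # wrapping around (via %) since there are far fewer positives than
    # the 200,000 samples per epoch.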
    def __getitem__(self, ndx):
        if self.ratio_int:
            pos_ndx = ndx // (self.ratio_int + 1)

            if ndx % (self.ratio_int + 1):
                neg_ndx = ndx - 1 - pos_ndx
                neg_ndx %= len(self.negative_list)
                candidateInfo_tup = self.negative_list[neg_ndx]
            else:
                pos_ndx %= len(self.pos_list)
                candidateInfo_tup = self.pos_list[pos_ndx]
        else:
            candidateInfo_tup = self.candidateInfo_list[ndx]

        width_irc = (32, 48, 48)

        if self.augmentation_dict:
            candidate_t, center_irc = getCtAugmentedCandidate(
                self.augmentation_dict,
                candidateInfo_tup.series_uid,
                candidateInfo_tup.center_xyz,
                width_irc,
                self.use_cache,
            )
        elif self.use_cache:
            candidate_a, center_irc = getCtRawCandidate(
                candidateInfo_tup.series_uid,
                candidateInfo_tup.center_xyz,
                width_irc,
            )
            candidate_t = torch.from_numpy(candidate_a).to(torch.float32)
            candidate_t = candidate_t.unsqueeze(0)
        else:
            ct = getCt(candidateInfo_tup.series_uid)
            candidate_a, center_irc = ct.getRawCandidate(
                candidateInfo_tup.center_xyz,
                width_irc,
            )
            candidate_t = torch.from_numpy(candidate_a).to(torch.float32)
            candidate_t = candidate_t.unsqueeze(0)

        # Two-element class tensor: [not nodule, nodule].
        pos_t = torch.tensor([
                not candidateInfo_tup.isNodule_bool,
                candidateInfo_tup.isNodule_bool,
            ],
            dtype=torch.long,
        )

        return candidate_t, pos_t, candidateInfo_tup.series_uid, torch.tensor(center_irc)
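
# A minimal smoke-test sketch, not part of the original module; it assumes
# the LUNA16 subsets and CSVs are on disk at the paths hard-coded above.
if __name__ == '__main__':
    ds = LunaDataset(val_stride=10, isValSet_bool=True)
    candidate_t, pos_t, series_uid, center_irc = ds[0]
    # Expect a (1, 32, 48, 48) float32 chunk and a 2-element long tensor.
    log.info("chunk: {}, label: {}, uid: {}".format(
        tuple(candidate_t.shape), pos_t.tolist(), series_uid))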