| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211 |
- import copy
- import csv
- import functools
- import glob
- import os
- import random
- import SimpleITK as sitk
- import numpy as np
- import torch
- import torch.cuda
- from torch.utils.data import Dataset
- from util.disk import getCache
- from util.util import XyzTuple, xyz2irc
- from util.logconf import logging
log = logging.getLogger(__name__)
# Module verbosity. The effective level is DEBUG; change this single call to
# logging.INFO or logging.WARN for quieter output.
log.setLevel(logging.DEBUG)

# Cache handle used to memoize getCtRawNodule results.
# NOTE(review): presumably a persistent on-disk cache keyed by the scope name
# 'part2ch10_raw' -- confirm against util.disk.getCache.
raw_cache = getCache('part2ch10_raw')
@functools.lru_cache(1)
def getNoduleInfoList(requireDataOnDisk_bool=True):
    """Build the list of candidate nodules from the LUNA CSV files.

    Reads ``data/luna/annotations.csv`` and ``data/luna/candidates.csv``
    (the first row of each is skipped as a header) and pairs every candidate
    with the diameter of the first annotation of the same series whose
    center is within a quarter of that diameter on all three axes.
    Candidates without such an annotation get a diameter of ``0.0``.

    Args:
        requireDataOnDisk_bool: when true, candidates whose series has no
            ``.mhd`` file under ``data/luna/subset*/`` are dropped, so a
            partial download of the subsets can still be used.

    Returns:
        A list of ``(isMalignant_bool, candidateDiameter_mm, series_uid,
        candidateCenter_xyz)`` tuples sorted in descending order. The result
        is memoized by ``lru_cache``, so every caller receives the same list
        object and should copy it before mutating.
    """
    # Series UIDs that are actually available: file name minus '.mhd'.
    onDisk_set = {
        os.path.split(mhd_path)[-1][:-4]
        for mhd_path in glob.glob('data/luna/subset*/*.mhd')
    }

    with open('data/luna/annotations.csv', "r") as f:
        annotation_rows = list(csv.reader(f))[1:]

    # series_uid -> [(center_xyz, diameter_mm), ...] in file order.
    annotation_dict = {}
    for row in annotation_rows:
        center_xyz = tuple([float(value) for value in row[1:4]])
        diameter_mm = float(row[4])
        annotation_dict.setdefault(row[0], []).append((center_xyz, diameter_mm))

    with open('data/luna/candidates.csv', "r") as f:
        candidate_rows = list(csv.reader(f))[1:]

    info_list = []
    for row in candidate_rows:
        uid = row[0]
        if requireDataOnDisk_bool and uid not in onDisk_set:
            continue

        malignant_flag = bool(int(row[4]))
        candidate_xyz = tuple([float(value) for value in row[1:4]])

        matched_mm = 0.0
        for center_xyz, diameter_mm in annotation_dict.get(uid, []):
            limit_mm = diameter_mm / 4
            tooFar_bool = any(
                abs(candidate_val - center_val) > limit_mm
                for candidate_val, center_val in zip(candidate_xyz, center_xyz)
            )
            if not tooFar_bool:
                # First annotation close enough on every axis wins.
                matched_mm = diameter_mm
                break

        info_list.append((malignant_flag, matched_mm, uid, candidate_xyz))

    return sorted(info_list, reverse=True)
class Ct(object):
    """One CT scan loaded from ``data/luna/subset*/<series_uid>.mhd``.

    Attributes set by ``__init__``:
        series_uid: the identifier the scan was loaded with.
        ary: float32 NumPy array of voxel values, rescaled and clamped
            to the range [0, 2].
        origin_xyz: ``XyzTuple`` built from the image's ``GetOrigin()``.
        vxSize_xyz: ``XyzTuple`` built from the image's ``GetSpacing()``.
        direction_tup: the image's ``GetDirection()`` entries rounded to ints.
    """

    def __init__(self, series_uid):
        """Load and normalize the scan identified by ``series_uid``.

        Raises:
            IndexError: if no matching ``.mhd`` file is found on disk.
        """
        mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]

        ct_mhd = sitk.ReadImage(mhd_path)
        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000
        # and 1 g/cc (water) being 0. This converts HU to g/cc.
        ct_ary += 1000
        ct_ary /= 1000

        # This gets rid of negative density stuff used to indicate out-of-FOV.
        ct_ary[ct_ary < 0] = 0

        # This nukes any weird hotspots and clamps bone down.
        ct_ary[ct_ary > 2] = 2

        self.series_uid = series_uid
        self.ary = ct_ary
        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())

    @staticmethod
    def _cropSlice(center_val, width, axis_len):
        """Return the slice of ``width`` voxels around ``center_val`` on one axis.

        A window that would stick out of ``[0, axis_len)`` is shifted back
        inside; if ``width`` exceeds ``axis_len`` the whole axis is returned.
        """
        start_ndx = int(round(center_val - width / 2))
        end_ndx = int(start_ndx + width)

        if start_ndx < 0:
            start_ndx = 0
            end_ndx = int(width)

        if end_ndx > axis_len:
            end_ndx = axis_len
            # Clamp at 0: a negative start would be read as "from the end"
            # and silently crop the wrong voxels when width > axis_len.
            start_ndx = max(0, int(axis_len - width))

        return slice(start_ndx, end_ndx)

    def getRawNodule(self, center_xyz, width_irc):
        """Crop a box of voxels around a point of the scan.

        Args:
            center_xyz: center of the crop; converted to array indices by
                ``xyz2irc`` using this scan's origin, voxel size and direction.
            width_irc: crop width in voxels for each array axis.

        Returns:
            ``(ct_chunk, center_irc)``: the cropped region of ``self.ary``
            (obtained by basic slicing, so it shares memory with ``self.ary``)
            and the converted center.

        Raises:
            AssertionError: if the converted center lies outside the array
                on any axis.
        """
        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)

        slice_list = []
        for axis, center_val in enumerate(center_irc):
            axis_len = self.ary.shape[axis]
            assert 0 <= center_val < axis_len, repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
            slice_list.append(self._cropSlice(center_val, width_irc[axis], axis_len))

        # NumPy requires a tuple for multi-axis indexing; a list of slices is
        # deprecated and is rejected or misinterpreted by current releases.
        ct_chunk = self.ary[tuple(slice_list)]

        return ct_chunk, center_irc
@functools.lru_cache(maxsize=1, typed=True)
def getCt(series_uid):
    """Return the ``Ct`` for ``series_uid``, keeping only the latest one in memory.

    The cache holds a single entry, so repeated requests for the same series
    reuse the loaded scan, while asking for a different series replaces it.
    """
    ct = Ct(series_uid)
    return ct
@raw_cache.memoize(typed=True)
def getCtRawNodule(series_uid, center_xyz, width_irc):
    """Return ``(ct_chunk, center_irc)`` for a crop of the given series.

    Loads the scan through ``getCt`` and delegates the crop to
    ``Ct.getRawNodule``. Results are memoized through ``raw_cache``.
    NOTE(review): presumably that cache persists across runs -- confirm
    against util.disk.getCache.
    """
    return getCt(series_uid).getRawNodule(center_xyz, width_irc)
class LunaDataset(Dataset):
    """Dataset of 32x32x32 CT crops around the candidates of getNoduleInfoList.

    Each entry of ``noduleInfo_list`` is an ``(isMalignant_bool, diameter_mm,
    series_uid, center_xyz)`` tuple. When ``ratio_int`` is non-zero the
    dataset reports a fixed length of 100000 and interleaves one malignant
    sample with ``ratio_int`` benign samples; otherwise it is indexed
    directly by position in ``noduleInfo_list``.
    """

    def __init__(self,
                 test_stride=0,
                 isTestSet_bool=None,
                 series_uid=None,
                 sortby_str='random',
                 ratio_int=0,
                 ):
        """Select, split and order the candidate list.

        Args:
            test_stride: when greater than 1, every ``test_stride``-th entry
                forms the test split and the remaining entries the training
                split.
            isTestSet_bool: truthy keeps the test split, falsy the training
                split; only consulted when ``test_stride`` is greater than 1.
            series_uid: if given, only candidates of this series are kept.
            sortby_str: ``'random'`` shuffles, ``'series_uid'`` sorts by
                (series_uid, center_xyz), ``'malignancy_size'`` keeps the
                order produced by ``getNoduleInfoList``.
            ratio_int: number of benign samples per malignant sample, or 0
                for no balancing.

        Raises:
            Exception: if ``sortby_str`` is not one of the values above.
        """
        self.ratio_int = ratio_int

        # Work on a shallow copy: getNoduleInfoList() returns a cached list.
        info_list = copy.copy(getNoduleInfoList())

        if series_uid:
            info_list = [info for info in info_list if info[2] == series_uid]

        if test_stride > 1:
            if isTestSet_bool:
                info_list = info_list[::test_stride]
            else:
                del info_list[::test_stride]

        if sortby_str == 'random':
            random.shuffle(info_list)
        elif sortby_str == 'series_uid':
            # Order by series_uid, then center_xyz.
            info_list.sort(key=lambda info: (info[2], info[3]))
        elif sortby_str != 'malignancy_size':
            raise Exception("Unknown sort: " + repr(sortby_str))

        self.noduleInfo_list = info_list

        # Positions in noduleInfo_list, split by the malignancy flag.
        benign_list = []
        malignant_list = []
        for position, info in enumerate(info_list):
            if info[0]:
                malignant_list.append(position)
            else:
                benign_list.append(position)
        self.benignIndex_list = benign_list
        self.malignantIndex_list = malignant_list

        split_str = "testing" if isTestSet_bool else "training"
        ratio_str = '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
        log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
            self,
            len(self.noduleInfo_list),
            split_str,
            len(self.benignIndex_list),
            len(self.malignantIndex_list),
            ratio_str,
        ))

    def shuffleSamples(self):
        """Reshuffle the benign and malignant index lists (balanced mode only)."""
        if not self.ratio_int:
            return
        random.shuffle(self.benignIndex_list)
        random.shuffle(self.malignantIndex_list)

    def __len__(self):
        """Return 100000 in balanced mode, else the number of candidates."""
        return 100000 if self.ratio_int else len(self.noduleInfo_list)

    def __getitem__(self, ndx):
        """Return ``(nodule_tensor, malignant_tensor, series_uid, center_irc)``.

        ``nodule_tensor`` is the crop with a leading dimension of size 1
        added; ``malignant_tensor`` is a one-element float32 tensor holding
        the malignancy flag.
        """
        if self.ratio_int:
            # Each group of ratio_int + 1 consecutive indices starts with one
            # malignant sample followed by ratio_int benign ones; both index
            # lists wrap around when exhausted.
            malignant_ndx, offset = divmod(ndx, self.ratio_int + 1)
            if offset:
                benign_ndx = ndx - 1 - malignant_ndx
                nodule_ndx = self.benignIndex_list[benign_ndx % len(self.benignIndex_list)]
            else:
                nodule_ndx = self.malignantIndex_list[malignant_ndx % len(self.malignantIndex_list)]
        else:
            nodule_ndx = ndx

        isMalignant_bool, _diameter_mm, series_uid, center_xyz = self.noduleInfo_list[nodule_ndx]

        nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, (32, 32, 32))
        nodule_tensor = torch.from_numpy(nodule_ary).unsqueeze(0)
        malignant_tensor = torch.tensor([isMalignant_bool], dtype=torch.float32)

        return nodule_tensor, malignant_tensor, series_uid, center_irc
|