| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382 |
- import copy
- import csv
- import functools
- import glob
- import math
- import os
- import random
- from collections import namedtuple
- import SimpleITK as sitk
- import numpy as np
- import torch
- import torch.cuda
- import torch.nn.functional as F
- from torch.utils.data import Dataset
- from util.disk import getCache
- from util.util import XyzTuple, xyz2irc
- from util.logconf import logging
# Module-level logger; DEBUG is enabled here for development visibility.
log = logging.getLogger(__name__)
# log.setLevel(logging.WARN)
# log.setLevel(logging.INFO)
log.setLevel(logging.DEBUG)

# On-disk cache used to memoize expensive CT-chunk extraction
# (see getCtRawCandidate / getCtSampleSize below).
raw_cache = getCache('part2ch14_raw')
# One candidate location within a CT scan:
#   isNodule_bool      -- True if the candidate is an actual nodule
#   hasAnnotation_bool -- True if it came from the annotations csv (has a diameter)
#   isMal_bool         -- True if the annotation marks it malignant
#   diameter_mm        -- annotated diameter (0.0 for non-nodule candidates)
#   series_uid         -- CT series the candidate belongs to
#   center_xyz         -- center in patient-space (millimeter) coordinates
CandidateInfoTuple = namedtuple(
    'CandidateInfoTuple',
    'isNodule_bool, hasAnnotation_bool, isMal_bool, diameter_mm, series_uid, center_xyz',
)

# Named bundle of segmentation masks.
# NOTE(review): nothing in this module constructs a MaskTuple; it appears to be
# shared with (or copied from) the segmentation dataset code -- confirm usage.
MaskTuple = namedtuple(
    'MaskTuple',
    'raw_dense_mask, dense_mask, body_mask, air_mask, raw_candidate_mask, candidate_mask, lung_mask, neg_mask, pos_mask',
)
@functools.lru_cache(1)
def getCandidateInfoList(requireOnDisk_bool=True):
    """Build the sorted list of all candidate nodule locations.

    Nodules come from annotations_with_malignancy.csv (with diameter and
    malignancy); non-nodule candidates come from candidates.csv.

    :param requireOnDisk_bool: when True, skip candidates whose series .mhd
        file has not been downloaded yet.
    :return: list of CandidateInfoTuple, sorted descending so that nodules
        (isNodule_bool=True) with the largest diameters come first.
    """
    # Series uids present on disk; lets us work with partially-downloaded subsets.
    mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
    presentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}

    candidateInfo_list = []
    with open('data/part2/luna/annotations_with_malignancy.csv', "r") as f:
        reader = csv.reader(f)
        # Skip the header row by advancing the reader instead of
        # materializing the whole file into a list (was list(...)[1:]).
        next(reader)
        for row in reader:
            series_uid = row[0]
            annotationCenter_xyz = tuple(float(x) for x in row[1:4])
            annotationDiameter_mm = float(row[4])
            # Raises KeyError on unexpected values, which is deliberate.
            isMal_bool = {'False': False, 'True': True}[row[5]]

            candidateInfo_list.append(
                CandidateInfoTuple(
                    True,
                    True,
                    isMal_bool,
                    annotationDiameter_mm,
                    series_uid,
                    annotationCenter_xyz,
                )
            )

    with open('data/part2/luna/candidates.csv', "r") as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        for row in reader:
            series_uid = row[0]

            if series_uid not in presentOnDisk_set and requireOnDisk_bool:
                continue

            isNodule_bool = bool(int(row[4]))
            candidateCenter_xyz = tuple(float(x) for x in row[1:4])

            # Actual nodules were already added (with richer data) from the
            # annotations file above; only keep the non-nodule candidates.
            if not isNodule_bool:
                candidateInfo_list.append(
                    CandidateInfoTuple(
                        False,
                        False,
                        False,
                        0.0,
                        series_uid,
                        candidateCenter_xyz,
                    )
                )

    candidateInfo_list.sort(reverse=True)
    return candidateInfo_list
@functools.lru_cache(1)
def getCandidateInfoDict(requireOnDisk_bool=True):
    """Group the candidate list by series, keyed on series_uid.

    :param requireOnDisk_bool: forwarded to getCandidateInfoList.
    :return: dict mapping series_uid -> list of CandidateInfoTuple.
    """
    grouped = {}
    for info_tup in getCandidateInfoList(requireOnDisk_bool):
        grouped.setdefault(info_tup.series_uid, []).append(info_tup)
    return grouped
class Ct:
    """A single CT scan loaded from disk, held as clamped Hounsfield units."""

    def __init__(self, series_uid):
        # The scan may live in any subset directory; take the first match.
        path_matches = glob.glob(
            'data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid)
        )
        ct_mhd = sitk.ReadImage(path_matches[0])
        hu_a = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)

        # CTs are natively in Hounsfield units (air ~= -1000, water = 0).
        # The lower bound discards out-of-FOV filler values; the upper bound
        # clamps bone and bright artifacts. Clip is done in place.
        hu_a.clip(-1000, 1000, hu_a)

        self.series_uid = series_uid
        self.hu_a = hu_a
        # Metadata needed to convert between patient (xyz, mm) and array
        # (index/row/col) coordinate systems.
        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
        self.direction_a = np.array(ct_mhd.GetDirection()).reshape(3, 3)

    def getRawCandidate(self, center_xyz, width_irc):
        """Crop a width_irc-sized chunk roughly centered on center_xyz.

        :param center_xyz: candidate center in patient coordinates (mm).
        :param width_irc: desired (index, row, col) crop size in voxels.
        :return: (ct_chunk ndarray, center_irc voxel coordinates).
        """
        center_irc = xyz2irc(
            center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_a,
        )

        crop_slices = []
        for axis, center_val in enumerate(center_irc):
            start_ndx = int(round(center_val - width_irc[axis] / 2))
            end_ndx = int(start_ndx + width_irc[axis])

            # The requested center itself must lie inside the volume.
            assert center_val >= 0 and center_val < self.hu_a.shape[axis], repr(
                [self.series_uid, center_xyz, self.origin_xyz,
                 self.vxSize_xyz, center_irc, axis])

            # If the window spills past an edge, slide it back inside
            # (low edge first, then high edge) keeping the full width.
            if start_ndx < 0:
                start_ndx = 0
                end_ndx = int(width_irc[axis])
            if end_ndx > self.hu_a.shape[axis]:
                end_ndx = self.hu_a.shape[axis]
                start_ndx = int(self.hu_a.shape[axis] - width_irc[axis])

            crop_slices.append(slice(start_ndx, end_ndx))

        return self.hu_a[tuple(crop_slices)], center_irc
@functools.lru_cache(1, typed=True)
def getCt(series_uid):
    """Return the Ct for series_uid, keeping only the most recent scan cached
    in memory (full scans are large, so maxsize is 1)."""
    ct = Ct(series_uid)
    return ct
@raw_cache.memoize(typed=True)
def getCtRawCandidate(series_uid, center_xyz, width_irc):
    """Disk-cached candidate extraction.

    Loads the scan (via the in-memory getCt cache) and crops the requested
    chunk; results are memoized on disk by raw_cache so repeated epochs skip
    the expensive CT load entirely.
    """
    return getCt(series_uid).getRawCandidate(center_xyz, width_irc)
@raw_cache.memoize(typed=True)
def getCtSampleSize(series_uid):
    # NOTE(review): this function looks vestigial/broken in this module.
    # Ct.__init__ above accepts only series_uid (no buildMasks_bool) and never
    # defines negative_indexes, so calling this raises TypeError. It appears
    # to have been copied from the segmentation dataset module -- confirm, and
    # either port the mask-building Ct variant or remove this function.
    ct = Ct(series_uid, buildMasks_bool=False)
    return len(ct.negative_indexes)
def getCtAugmentedCandidate(
        augmentation_dict,
        series_uid, center_xyz, width_irc,
        use_cache=True):
    """Extract a candidate chunk and apply random affine + noise augmentation.

    :param augmentation_dict: which augmentations to apply; recognized keys
        are 'flip', 'offset', 'scale', 'rotate', 'noise'.
    :param use_cache: read the crop through the disk-memoized path when True.
    :return: (augmented chunk tensor of shape (1, I, R, C), center_irc).
    """
    if use_cache:
        ct_chunk, center_irc = getCtRawCandidate(series_uid, center_xyz, width_irc)
    else:
        ct = getCt(series_uid)
        ct_chunk, center_irc = ct.getRawCandidate(center_xyz, width_irc)

    # Shape (1, 1, I, R, C) float tensor, as required by affine_grid/grid_sample.
    ct_t = torch.tensor(ct_chunk).unsqueeze(0).unsqueeze(0).to(torch.float32)

    # Build a 4x4 homogeneous affine transform, one random draw per axis.
    transform_t = torch.eye(4)
    for axis in range(3):
        # Random mirror along this axis with probability 0.5.
        if 'flip' in augmentation_dict:
            if random.random() > 0.5:
                transform_t[axis, axis] *= -1
        # Random translation up to +/- 'offset' (normalized grid units).
        if 'offset' in augmentation_dict:
            offset_float = augmentation_dict['offset']
            transform_t[axis, 3] = offset_float * (random.random() * 2 - 1)
        # Random scaling by a factor within 1 +/- 'scale'.
        if 'scale' in augmentation_dict:
            scale_float = augmentation_dict['scale']
            transform_t[axis, axis] *= 1.0 + scale_float * (random.random() * 2 - 1)

    # Single random rotation in the row/col plane (slice axis unchanged,
    # since voxel spacing along the index axis differs from in-plane spacing).
    if 'rotate' in augmentation_dict:
        angle_rad = random.random() * math.pi * 2
        s = math.sin(angle_rad)
        c = math.cos(angle_rad)

        rotation_t = torch.tensor([
            [c, -s, 0, 0],
            [s, c, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ])
        transform_t @= rotation_t

    # Resample the chunk through the combined affine transform.
    affine_t = F.affine_grid(
        transform_t[:3].unsqueeze(0).to(torch.float32),
        ct_t.size(),
        align_corners=False,
    )
    augmented_chunk = F.grid_sample(
        ct_t,
        affine_t,
        padding_mode='border',
        align_corners=False,
    ).to('cpu')

    # Additive Gaussian noise, applied after resampling.
    if 'noise' in augmentation_dict:
        noise_t = torch.randn_like(augmented_chunk)
        noise_t *= augmentation_dict['noise']
        augmented_chunk += noise_t

    return augmented_chunk[0], center_irc
class LunaDataset(Dataset):
    """PyTorch Dataset of CT candidate crops for nodule classification.

    Each item is a 5-tuple:
      candidate_t -- float32 tensor of shape (1, 32, 48, 48) of HU data
      label_t     -- long tensor [not-nodule, nodule] (exactly one is 1)
      index_t     -- scalar class index (0 = not nodule, 1 = nodule)
      series_uid  -- the CT series the sample came from
      center_irc  -- crop center in (index, row, col) voxel coordinates
    """

    def __init__(self,
                 val_stride=0,          # keep (val) / drop (train) every Nth series
                 isValSet_bool=None,    # True -> validation split, falsy -> training
                 series_uid=None,       # restrict the dataset to one series
                 sortby_str='random',   # 'random' | 'series_uid' | 'label_and_size'
                 ratio_int=0,           # neg:pos balancing ratio; 0 = unbalanced
                 augmentation_dict=None,    # see getCtAugmentedCandidate
                 candidateInfo_list=None,   # pre-built candidate list (disables cache)
                 ):
        self.ratio_int = ratio_int
        self.augmentation_dict = augmentation_dict

        # NOTE(review): truthiness test means an explicitly passed EMPTY list
        # falls through to the full on-disk list -- confirm this is intended.
        if candidateInfo_list:
            self.candidateInfo_list = copy.copy(candidateInfo_list)
            self.use_cache = False
        else:
            self.candidateInfo_list = copy.copy(getCandidateInfoList())
            self.use_cache = True

        if series_uid:
            self.series_list = [series_uid]
        else:
            self.series_list = sorted(set(candidateInfo_tup.series_uid for candidateInfo_tup in self.candidateInfo_list))

        # The train/val split is by series (not by candidate) so the two
        # splits never share voxels from the same CT scan.
        if isValSet_bool:
            assert val_stride > 0, val_stride
            self.series_list = self.series_list[::val_stride]
            assert self.series_list
        elif val_stride > 0:
            del self.series_list[::val_stride]
            assert self.series_list

        # Keep only candidates whose series survived the split.
        series_set = set(self.series_list)
        self.candidateInfo_list = [x for x in self.candidateInfo_list if x.series_uid in series_set]

        if sortby_str == 'random':
            random.shuffle(self.candidateInfo_list)
        elif sortby_str == 'series_uid':
            self.candidateInfo_list.sort(key=lambda x: (x.series_uid, x.center_xyz))
        elif sortby_str == 'label_and_size':
            # Deliberate no-op: the list already carries the sort from
            # getCandidateInfoList (descending by label/diameter).
            pass
        else:
            raise Exception("Unknown sort: " + repr(sortby_str))

        # Per-class views used by the ratio-balanced sampling in __getitem__
        # (and by the malignancy subclass below).
        self.neg_list = \
            [nt for nt in self.candidateInfo_list if not nt.isNodule_bool]
        self.pos_list = \
            [nt for nt in self.candidateInfo_list if nt.isNodule_bool]
        self.ben_list = \
            [nt for nt in self.pos_list if not nt.isMal_bool]
        self.mal_list = \
            [nt for nt in self.pos_list if nt.isMal_bool]

        log.info("{!r}: {} {} samples, {} neg, {} pos, {} ratio".format(
            self,
            len(self.candidateInfo_list),
            "validation" if isValSet_bool else "training",
            len(self.neg_list),
            len(self.pos_list),
            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
        ))

    def shuffleSamples(self):
        # Intended to be called once per epoch when balancing is enabled, so
        # the modulo indexing in __getitem__ cycles through fresh orderings.
        if self.ratio_int:
            random.shuffle(self.candidateInfo_list)
            random.shuffle(self.neg_list)
            random.shuffle(self.pos_list)
            random.shuffle(self.ben_list)
            random.shuffle(self.mal_list)

    def __len__(self):
        # With balancing enabled, the "epoch" is a fixed number of draws
        # (with wraparound) rather than one pass over the candidates.
        if self.ratio_int:
            return 50000
        else:
            return len(self.candidateInfo_list)

    def __getitem__(self, ndx):
        # With ratio_int == N the sequence is: 1 positive, then N negatives,
        # repeating; both class lists wrap around via modulo.
        if self.ratio_int:
            pos_ndx = ndx // (self.ratio_int + 1)

            if ndx % (self.ratio_int + 1):
                neg_ndx = ndx - 1 - pos_ndx
                neg_ndx %= len(self.neg_list)
                candidateInfo_tup = self.neg_list[neg_ndx]
            else:
                pos_ndx %= len(self.pos_list)
                candidateInfo_tup = self.pos_list[pos_ndx]
        else:
            candidateInfo_tup = self.candidateInfo_list[ndx]

        return self.sampleFromCandidateInfo_tup(
            candidateInfo_tup, candidateInfo_tup.isNodule_bool
        )

    def sampleFromCandidateInfo_tup(self, candidateInfo_tup, label_bool):
        """Load (optionally augmented) crop data and labels for one candidate.

        :param candidateInfo_tup: the candidate to sample.
        :param label_bool: the positive-class flag for this dataset's task
            (isNodule_bool here; isMal_bool in MalignantLunaDataset).
        """
        width_irc = (32, 48, 48)

        if self.augmentation_dict:
            # Augmented path returns a ready-made (1, I, R, C) tensor.
            candidate_t, center_irc = getCtAugmentedCandidate(
                self.augmentation_dict,
                candidateInfo_tup.series_uid,
                candidateInfo_tup.center_xyz,
                width_irc,
                self.use_cache,
            )
        elif self.use_cache:
            candidate_a, center_irc = getCtRawCandidate(
                candidateInfo_tup.series_uid,
                candidateInfo_tup.center_xyz,
                width_irc,
            )
            candidate_t = torch.from_numpy(candidate_a).to(torch.float32)
            candidate_t = candidate_t.unsqueeze(0)
        else:
            # Uncached path: load the scan directly.
            ct = getCt(candidateInfo_tup.series_uid)
            candidate_a, center_irc = ct.getRawCandidate(
                candidateInfo_tup.center_xyz,
                width_irc,
            )
            candidate_t = torch.from_numpy(candidate_a).to(torch.float32)
            candidate_t = candidate_t.unsqueeze(0)

        # Two-element one-hot label plus the scalar class index.
        label_t = torch.tensor([False, False], dtype=torch.long)

        if not label_bool:
            label_t[0] = True
            index_t = 0
        else:
            label_t[1] = True
            index_t = 1

        return candidate_t, label_t, index_t, candidateInfo_tup.series_uid, torch.tensor(center_irc)
class MalignantLunaDataset(LunaDataset):
    """Dataset variant for the benign-vs-malignant classification task.

    Labels come from isMal_bool rather than isNodule_bool. When ratio_int is
    set, samples are drawn 1/2 malignant, 1/4 benign, 1/4 negative; otherwise
    the dataset enumerates benign nodules followed by malignant ones.
    """

    def __len__(self):
        if self.ratio_int:
            return 100000
        else:
            # Was len(self.ben_list + self.mal_list), which built a throwaway
            # concatenated list on every call; summing the lengths is
            # equivalent and allocation-free.
            return len(self.ben_list) + len(self.mal_list)

    def __getitem__(self, ndx):
        if self.ratio_int:
            # Balanced draw: odd indices -> malignant (half of all samples);
            # indices divisible by 4 -> benign; the rest (ndx % 4 == 2) -> negative.
            if ndx % 2 != 0:
                candidateInfo_tup = self.mal_list[(ndx // 2) % len(self.mal_list)]
            elif ndx % 4 == 0:
                candidateInfo_tup = self.ben_list[(ndx // 4) % len(self.ben_list)]
            else:
                candidateInfo_tup = self.neg_list[(ndx // 4) % len(self.neg_list)]
        else:
            # Unbalanced: benign samples first, then malignant.
            if ndx >= len(self.ben_list):
                candidateInfo_tup = self.mal_list[ndx - len(self.ben_list)]
            else:
                candidateInfo_tup = self.ben_list[ndx]

        return self.sampleFromCandidateInfo_tup(
            candidateInfo_tup, candidateInfo_tup.isMal_bool
        )
|