| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648 |
- import copy
- import csv
- import functools
- import glob
- import itertools
- import math
- import os
- import random
- from collections import namedtuple
- import SimpleITK as sitk
- import scipy.ndimage.morphology
- import numpy as np
- import torch
- import torch.cuda
- from torch.utils.data import Dataset, DataLoader
- from torch.utils.data.sampler import Sampler
- from util.disk import getCache
- from util.util import XyzTuple, xyz2irc, IrcTuple
- from util.logconf import logging
- from util.affine import affine_grid_generator
- log = logging.getLogger(__name__)
- # log.setLevel(logging.WARN)
- log.setLevel(logging.INFO)
- log.setLevel(logging.DEBUG)
- raw_cache = getCache('part2ch11_raw')
- cubic_cache = getCache('part2ch11_cubic')
- NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
- @functools.lru_cache(1)
- def getNoduleInfoList(requireDataOnDisk_bool=True):
- # We construct a set with all series_uids that are present on disk.
- # This will let us use the data, even if we haven't downloaded all of
- # the subsets yet.
- mhd_list = glob.glob('data-unversioned/part2/luna/subset*/*.mhd')
- dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
- diameter_dict = {}
- with open('data/part2/luna/annotations.csv', "r") as f:
- for row in list(csv.reader(f))[1:]:
- series_uid = row[0]
- annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
- annotationDiameter_mm = float(row[4])
- diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
- noduleInfo_list = []
- with open('data/part2/luna/candidates.csv', "r") as f:
- for row in list(csv.reader(f))[1:]:
- series_uid = row[0]
- if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
- continue
- isMalignant_bool = bool(int(row[4]))
- candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
- candidateDiameter_mm = 0.0
- for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
- for i in range(3):
- delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
- if delta_mm > annotationDiameter_mm / 4:
- break
- else:
- candidateDiameter_mm = annotationDiameter_mm
- break
- noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
- noduleInfo_list.sort(reverse=True)
- return noduleInfo_list
- class Ct(object):
- def __init__(self, series_uid, buildMasks_bool=True):
- mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
- ct_mhd = sitk.ReadImage(mhd_path)
- ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
- # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
- # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
- # This converts HU to g/cc.
- ct_ary += 1000
- ct_ary /= 1000
- # This gets rid of negative density stuff used to indicate out-of-FOV
- ct_ary[ct_ary < 0] = 0
- # This nukes any weird hotspots and clamps bone down
- ct_ary[ct_ary > 2] = 2
- self.series_uid = series_uid
- self.ary = ct_ary
- self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
- self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
- self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
- noduleInfo_list = getNoduleInfoList()
- self.benignInfo_list = [ni_tup
- for ni_tup in noduleInfo_list
- if not ni_tup.isMalignant_bool
- and ni_tup.series_uid == self.series_uid]
- self.benign_mask = self.buildAnnotationMask(self.benignInfo_list)[0]
- self.benign_indexes = sorted(set(self.benign_mask.nonzero()[0]))
- self.malignantInfo_list = [ni_tup
- for ni_tup in noduleInfo_list
- if ni_tup.isMalignant_bool
- and ni_tup.series_uid == self.series_uid]
- self.malignant_mask = self.buildAnnotationMask(self.malignantInfo_list)[0]
- self.malignant_indexes = sorted(set(self.malignant_mask.nonzero()[0]))
- def buildAnnotationMask(self, noduleInfo_list, threshold_gcc = 0.5):
- boundingBox_ary = np.zeros_like(self.ary, dtype=np.bool)
- for noduleInfo_tup in noduleInfo_list:
- center_irc = xyz2irc(noduleInfo_tup.center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
- center_index = int(center_irc.index)
- center_row = int(center_irc.row)
- center_col = int(center_irc.col)
- index_radius = 2
- try:
- while self.ary[center_index + index_radius, center_row, center_col] > threshold_gcc and \
- self.ary[center_index - index_radius, center_row, center_col] > threshold_gcc:
- index_radius += 1
- except IndexError:
- index_radius -= 1
- row_radius = 2
- try:
- while self.ary[center_index, center_row + row_radius, center_col] > threshold_gcc and \
- self.ary[center_index, center_row - row_radius, center_col] > threshold_gcc:
- row_radius += 1
- except IndexError:
- row_radius -= 1
- col_radius = 2
- try:
- while self.ary[center_index, center_row, center_col + col_radius] > threshold_gcc and \
- self.ary[center_index, center_row, center_col - col_radius] > threshold_gcc:
- col_radius += 1
- except IndexError:
- col_radius -= 1
- # assert index_radius > 0, repr([noduleInfo_tup.center_xyz, center_irc, self.ary[center_index, center_row, center_col]])
- # assert row_radius > 0
- # assert col_radius > 0
- slice_tup = (
- slice(
- # max(0, center_index - index_radius),
- center_index - index_radius,
- center_index + index_radius + 1,
- ),
- slice(
- # max(0, center_row - row_radius),
- center_row - row_radius,
- center_row + row_radius + 1,
- ),
- slice(
- # max(0, center_col - col_radius),
- center_col - col_radius,
- center_col + row_radius + 1,
- ),
- )
- boundingBox_ary[slice_tup] = True
- thresholded_ary = boundingBox_ary & (self.ary > threshold_gcc)
- mask_ary = scipy.ndimage.morphology.binary_dilation(thresholded_ary, iterations=2)
- return mask_ary, thresholded_ary, boundingBox_ary
- def build2dLungMask(self, mask_ndx, threshold_gcc = 0.7):
- dense_mask = self.ary[mask_ndx] > threshold_gcc
- denoise_mask = scipy.ndimage.morphology.binary_closing(dense_mask, iterations=2)
- tissue_mask = scipy.ndimage.morphology.binary_opening(denoise_mask, iterations=10)
- body_mask = scipy.ndimage.morphology.binary_fill_holes(tissue_mask)
- air_mask = scipy.ndimage.morphology.binary_fill_holes(body_mask & ~tissue_mask)
- lung_mask = scipy.ndimage.morphology.binary_dilation(air_mask, iterations=2)
- return air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask
- def build3dLungMask(self):
- air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask = mask_list = \
- [np.zeros_like(self.ary, dtype=np.bool) for _ in range(6)]
- for mask_ndx in range(self.ary.shape[0]):
- for i, mask_ary in enumerate(self.build2dLungMask(mask_ndx)):
- mask_list[i][mask_ndx] = mask_ary
- return air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask
- def getRawNodule(self, center_xyz, width_irc):
- center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
- slice_list = []
- for axis, center_val in enumerate(center_irc):
- start_ndx = int(round(center_val - width_irc[axis]/2))
- end_ndx = int(start_ndx + width_irc[axis])
- assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
- if start_ndx < 0:
- # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
- # self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
- start_ndx = 0
- end_ndx = int(width_irc[axis])
- if end_ndx > self.ary.shape[axis]:
- # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
- # self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
- end_ndx = self.ary.shape[axis]
- start_ndx = int(self.ary.shape[axis] - width_irc[axis])
- slice_list.append(slice(start_ndx, end_ndx))
- ct_chunk = self.ary[tuple(slice_list)]
- return ct_chunk, center_irc
- def getCubicInputChunk(self, center_xyz, maxWidth_mm):
- center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
- ct_start = [int(round(i)) for i in xyz2irc(tuple(x - maxWidth_mm / 2 for x in center_xyz), self.origin_xyz, self.vxSize_xyz, self.direction_tup)]
- ct_end = [int(round(i)) + 1 for i in xyz2irc(tuple(x + maxWidth_mm / 2 for x in center_xyz), self.origin_xyz, self.vxSize_xyz, self.direction_tup)]
- for axis in range(3):
- if ct_start[axis] > ct_end[axis]:
- ct_start[axis], ct_end[axis] = ct_end[axis], ct_start[axis]
- pad_start = [0, 0, 0]
- pad_end = [ct_end[axis] - ct_start[axis] for axis in range(3)]
- # log.info([ct_end, ct_start, pad_end])
- chunk_ary = np.zeros(pad_end, dtype=np.float32)
- for axis in range(3):
- if ct_start[axis] < 0:
- pad_start[axis] = -ct_start[axis]
- ct_start[axis] = 0
- if ct_end[axis] > self.ary.shape[axis]:
- pad_end[axis] -= ct_end[axis] - self.ary.shape[axis]
- ct_end[axis] = self.ary.shape[axis]
- pad_slices = tuple(slice(s,e) for s, e in zip(pad_start, pad_end))
- ct_slices = tuple(slice(s,e) for s, e in zip(ct_start, ct_end))
- chunk_ary[pad_slices] = self.ary[ct_slices]
- return chunk_ary, center_irc
- ctCache_depth = 3
- @functools.lru_cache(ctCache_depth, typed=True)
- def getCt(series_uid):
- return Ct(series_uid)
- @raw_cache.memoize(typed=True)
- def getCtSize(series_uid):
- ct = Ct(series_uid, buildMasks_bool=False)
- return tuple(ct.ary.shape)
- # @raw_cache.memoize(typed=True)
- # def getCtLungExtents(series_uid):
- # ct = getCt(series_uid)
- # return (int(min(ct.lung_indexes)), int(max(ct.lung_indexes)))
- @raw_cache.memoize(typed=True)
- def getCtRawNodule(series_uid, center_xyz, width_irc):
- ct = getCt(series_uid)
- ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
- return ct_chunk, center_irc
- # clamp_value = 1.5
- @functools.lru_cache(1, typed=True)
- @cubic_cache.memoize(typed=True)
- def getCtCubicChunk(series_uid, center_xyz, maxWidth_mm):
- ct = getCt(series_uid)
- ct_chunk, center_irc = ct.getCubicInputChunk(center_xyz, maxWidth_mm)
- # # ct_chunk has been clamped to [0, 2] at this point
- # # We are going to convert to uint8 to reduce size on disk and loading time
- # ct_chunk[ct_chunk > clamp_value] = clamp_value
- # ct_chunk *= 255/clamp_value
- # ct_chunk = np.array(ct_chunk, dtype=np.uint8)
- return ct_chunk, center_irc
- def getCtAugmentedNodule(augmentation_dict, series_uid, center_xyz, width_mm, voxels_int, maxWidth_mm=32.0, use_cache=True):
- assert width_mm <= maxWidth_mm
- if use_cache:
- cubic_chunk, center_irc = getCtCubicChunk(series_uid, center_xyz, maxWidth_mm)
- else:
- ct = getCt(series_uid)
- ct_chunk, center_irc = ct.getCubicInputChunk(center_xyz, maxWidth_mm)
- slice_list = []
- for axis in range(3):
- crop_size = cubic_chunk.shape[axis] * width_mm / maxWidth_mm
- crop_size = int(math.ceil(crop_size))
- start_ndx = (cubic_chunk.shape[axis] - crop_size) // 2
- end_ndx = start_ndx + crop_size
- slice_list.append(slice(start_ndx, end_ndx))
- cropped_chunk = cubic_chunk[slice_list]
- # # inflate cropped_chunk back to float32
- # cropped_chunk = np.array(cropped_chunk, dtype=np.float32)
- # cropped_chunk *= clamp_value/255
- cropped_tensor = torch.tensor(cropped_chunk).unsqueeze(0).unsqueeze(0)
- transform_tensor = torch.eye(4).to(torch.float64)
- # Scale and Mirror
- for i in range(3):
- if 'scale' in augmentation_dict:
- scale_float = augmentation_dict['scale']
- transform_tensor[i,i] *= 1.0 - scale_float/2.0 + (random.random() * scale_float)
- if 'mirror' in augmentation_dict:
- if random.random() > 0.5:
- transform_tensor[i,i] *= -1
- # Rotate
- if 'rotate' in augmentation_dict:
- angle_rad = random.random() * math.pi * 2
- s = math.sin(angle_rad)
- c = math.cos(angle_rad)
- c1 = 1 - c
- axis_tensor = torch.rand([3], dtype=torch.float64)
- axis_tensor /= axis_tensor.pow(2).sum().pow(0.5)
- z, y, x = axis_tensor
- rotation_tensor = torch.tensor([
- [x*x*c1 + c, y*x*c1 - z*s, z*x*c1 + y*s, 0],
- [x*y*c1 + z*s, y*y*c1 + c, z*y*c1 - x*s, 0],
- [x*z*c1 - y*s, y*z*c1 + x*s, z*z*c1 + c, 0],
- [0, 0, 0, 1],
- ], dtype=torch.float64)
- transform_tensor @= rotation_tensor
- # Transform into final desired shape
- affine_tensor = affine_grid_generator(
- transform_tensor[:3].unsqueeze(0).to(torch.float32),
- torch.Size([1, 1, voxels_int, voxels_int, voxels_int])
- )
- zoomed_chunk = torch.nn.functional.grid_sample(
- cropped_tensor,
- affine_tensor,
- padding_mode='border'
- ).to('cpu')
- # Noise
- if 'noise' in augmentation_dict:
- noise_tensor = torch.randn(
- zoomed_chunk.size(),
- dtype=zoomed_chunk.dtype,
- )
- noise_tensor *= augmentation_dict['noise']
- zoomed_chunk += noise_tensor
- return zoomed_chunk[0,0], center_irc
- class LunaPrepcacheDataset(Dataset):
- def __init__(self):
- self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
- def __len__(self):
- return len(self.series_list)
- def __getitem__(self, ndx):
- getCtSize(self.series_list[ndx])
- # getCtLungExtents(self.series_list[ndx])
- return 0
- class LunaClassificationDataset(Dataset):
- def __init__(self,
- test_stride=0,
- isTestSet_bool=None,
- series_uid=None,
- sortby_str='random',
- ratio_int=0,
- scaled_bool=False,
- multiscaled_bool=False,
- augmented_bool=False,
- noduleInfo_list=None,
- ):
- self.ratio_int = ratio_int
- self.scaled_bool = scaled_bool
- self.multiscaled_bool = multiscaled_bool
- if augmented_bool:
- self.augmentation_dict = {
- 'mirror': True,
- 'rotate': True,
- }
- if isTestSet_bool:
- self.augmentation_dict['scale'] = 0.25
- else:
- self.augmentation_dict['scale'] = 0.5
- self.augmentation_dict['noise'] = 0.1
- else:
- self.augmentation_dict = {}
- if noduleInfo_list:
- self.noduleInfo_list = copy.copy(noduleInfo_list)
- self.use_cache = False
- else:
- self.noduleInfo_list = copy.copy(getNoduleInfoList())
- self.use_cache = True
- if series_uid:
- self.noduleInfo_list = [x for x in self.noduleInfo_list if x[2] == series_uid]
- if test_stride > 1:
- if isTestSet_bool:
- self.noduleInfo_list = self.noduleInfo_list[::test_stride]
- else:
- del self.noduleInfo_list[::test_stride]
- if sortby_str == 'random':
- random.shuffle(self.noduleInfo_list)
- elif sortby_str == 'series_uid':
- self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sorting by series_uid, center_xyz)
- elif sortby_str == 'malignancy_size':
- pass
- else:
- raise Exception("Unknown sort: " + repr(sortby_str))
- self.benignIndex_list = [i for i, x in enumerate(self.noduleInfo_list) if not x[0]]
- self.malignantIndex_list = [i for i, x in enumerate(self.noduleInfo_list) if x[0]]
- log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
- self,
- len(self.noduleInfo_list),
- "testing" if isTestSet_bool else "training",
- len(self.benignIndex_list),
- len(self.malignantIndex_list),
- '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
- ))
- def shuffleSamples(self):
- if self.ratio_int:
- random.shuffle(self.benignIndex_list)
- random.shuffle(self.malignantIndex_list)
- def __len__(self):
- if self.ratio_int:
- # return 10000
- return 100000
- elif self.augmentation_dict:
- return len(self.noduleInfo_list) * 5
- else:
- return len(self.noduleInfo_list)
- def __getitem__(self, ndx):
- if self.ratio_int:
- malignant_ndx = ndx // (self.ratio_int + 1)
- if ndx % (self.ratio_int + 1):
- benign_ndx = ndx - 1 - malignant_ndx
- nodule_ndx = self.benignIndex_list[benign_ndx % len(self.benignIndex_list)]
- else:
- nodule_ndx = self.malignantIndex_list[malignant_ndx % len(self.malignantIndex_list)]
- augmentation_dict = self.augmentation_dict
- else:
- nodule_ndx = ndx % len(self.noduleInfo_list)
- if ndx < len(self.noduleInfo_list):
- augmentation_dict = {}
- else:
- augmentation_dict = self.augmentation_dict
- isMalignant_bool, _diameter_mm, series_uid, center_xyz = self.noduleInfo_list[nodule_ndx]
- if self.scaled_bool:
- channel_list = []
- voxels_int = 32
- if self.multiscaled_bool:
- width_list = [8, 16, 32]
- else:
- width_list = [24]
- for width_mm in width_list:
- nodule_ary, center_irc = getCtAugmentedNodule(augmentation_dict, series_uid, center_xyz, width_mm, voxels_int)
- # in: dim=3, Index x Row x Col
- # out: dim=4, Channel x Index x Row x Col
- nodule_ary = nodule_ary.unsqueeze(0)
- channel_list.append(nodule_ary)
- nodule_tensor = torch.cat(channel_list)
- else:
- nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, (32, 32, 32))
- nodule_ary = np.expand_dims(nodule_ary, 0)
- nodule_tensor = torch.from_numpy(nodule_ary)
- # dim=1
- malignant_tensor = torch.tensor([isMalignant_bool], dtype=torch.float32)
- return nodule_tensor, malignant_tensor, series_uid, center_irc
- #
- # return malignant_tensor, diameter_mm, series_uid, center_irc, nodule_tensor
- class Luna2dSegmentationDataset(Dataset):
- purpose_str = 'general'
- def __init__(self,
- contextSlices_count=2,
- series_uid=None,
- test_stride=0,
- ):
- self.contextSlices_count = contextSlices_count
- if series_uid:
- self.series_list = [series_uid]
- else:
- self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
- self.cullTrainOrTestSeries(test_stride)
- self.sample_list = []
- for series_uid in self.series_list:
- self.sample_list.extend([(series_uid, i) for i in range(int(getCtSize(series_uid)[0]))])
- log.info("{!r}: {} {} series, {} slices".format(
- self,
- len(self.series_list),
- self.purpose_str,
- len(self.sample_list),
- ))
- def cullTrainOrTestSeries(self, test_stride):
- assert test_stride == 0
- def __len__(self):
- return len(self.sample_list) #// 100
- def __getitem__(self, ndx):
- if isinstance(ndx, int):
- series_uid, sample_ndx = self.sample_list[ndx % len(self.sample_list)]
- else:
- series_uid, sample_ndx = ndx
- ct = getCt(series_uid)
- ct_tensor = torch.zeros((self.contextSlices_count * 2 + 2, 512, 512))
- masks_tensor = torch.zeros((2, 512, 512))
- start_ndx = sample_ndx - self.contextSlices_count
- end_ndx = sample_ndx + self.contextSlices_count + 1
- for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
- context_ndx = max(context_ndx, 0)
- context_ndx = min(context_ndx, ct.ary.shape[0] - 1)
- ct_tensor[i] = torch.from_numpy(ct.ary[context_ndx].astype(np.float32))
- air_mask, lung_mask = ct.build2dLungMask(sample_ndx)[:2]
- ct_tensor[-1] = torch.from_numpy(lung_mask.astype(np.float32))
- mal_mask = ct.malignant_mask[sample_ndx] & lung_mask
- ben_mask = ct.benign_mask[sample_ndx] & air_mask
- masks_tensor[0] = torch.from_numpy(mal_mask.astype(np.float32))
- masks_tensor[1] = torch.from_numpy((mal_mask | ben_mask).astype(np.float32))
- # masks_tensor[1] = torch.from_numpy(ben_mask.astype(np.float32))
- return ct_tensor.contiguous(), masks_tensor.contiguous(), ct.series_uid, sample_ndx
- class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
- purpose_str = 'training'
- def __init__(self, *args, **kwargs):
- self.needsShuffle_bool = True
- super().__init__(*args, **kwargs)
- def cullTrainOrTestSeries(self, test_stride):
- assert test_stride > 0, test_stride
- del self.series_list[::test_stride]
- assert self.series_list
- def __len__(self):
- # return 100
- # return 1000
- # return 10000
- return 20000
- # return 40000
- def __getitem__(self, ndx):
- if self.needsShuffle_bool:
- random.shuffle(self.series_list)
- self.needsShuffle_bool = False
- if random.random() < 0.01:
- self.series_list.append(self.series_list.pop(0))
- if isinstance(ndx, int):
- series_uid = self.series_list[ndx % ctCache_depth]
- ct = getCt(series_uid)
- sample_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
- # series_uid, sample_ndx = self.sample_list[ndx % len(self.sample_list)]
- else:
- series_uid, sample_ndx = ndx
- # if ndx % 2 == 0:
- # sample_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
- # else: #if ndx % 2 == 2:
- # sample_ndx = random.choice(ct.benign_indexes)
- # else:
- # sample_ndx = random.randint(*self.series2extents_dict[series_uid])
- return super().__getitem__((series_uid, sample_ndx))
- class TestingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
- purpose_str = 'testing'
- def cullTrainOrTestSeries(self, test_stride):
- assert test_stride > 0
- self.series_list = self.series_list[::test_stride]
- assert self.series_list
|