
Update code repo to its state at the second review.

Eli Stevens committed 7 years ago
commit b480f7ab68

File diff suppressed because it is too large
+ 21931 - 0
p1ch5/p1ch5.ipynb


File diff suppressed because it is too large
+ 357 - 0
p1ch6/p1ch6.ipynb


+ 0 - 0
p2ch1/__init__.py → p2ch08/__init__.py


+ 176 - 0
p2ch08/dsets.py

@@ -0,0 +1,176 @@
+import copy
+import csv
+import functools
+import glob
+import os
+import random
+
+import SimpleITK as sitk
+
+import numpy as np
+import torch
+import torch.cuda
+from torch.utils.data import Dataset
+
+from util.disk import getCache
+from util.util import XyzTuple, xyz2irc
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+raw_cache = getCache('part2ch08_raw')
+
+@functools.lru_cache(1)
+def getNoduleInfoList(requireDataOnDisk_bool=True):
+    # We construct a set with all series_uids that are present on disk.
+    # This will let us use the data, even if we haven't downloaded all of
+    # the subsets yet.
+    mhd_list = glob.glob('data/luna/subset*/*.mhd')
+    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+
+    diameter_dict = {}
+    with open('data/luna/annotations.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
+            annotationDiameter_mm = float(row[4])
+
+            diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
+
+    noduleInfo_list = []
+    with open('data/luna/candidates.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+
+            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+                continue
+
+            isMalignant_bool = bool(int(row[4]))
+            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
+
+            candidateDiameter_mm = 0.0
+            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
+                for i in range(3):
+                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
+                    if delta_mm > annotationDiameter_mm / 4:
+                        break
+                else:
+                    candidateDiameter_mm = annotationDiameter_mm
+                    break
+
+            noduleInfo_list.append((isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+
+    noduleInfo_list.sort(reverse=True)
+    return noduleInfo_list
+
+class Ct(object):
+    def __init__(self, series_uid):
+        mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]
+
+        ct_mhd = sitk.ReadImage(mhd_path)
+        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+
+        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
+        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
+        # This converts HU to g/cc.
+        ct_ary += 1000
+        ct_ary /= 1000
+
+        # This gets rid of negative density stuff used to indicate out-of-FOV
+        ct_ary[ct_ary < 0] = 0
+
+        # This nukes any weird hotspots and clamps bone down
+        ct_ary[ct_ary > 2] = 2
+
+        self.series_uid = series_uid
+        self.ary = ct_ary
+
+        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
+        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
+        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
+
+    def getRawNodule(self, center_xyz, width_irc):
+        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
+
+        slice_list = []
+        for axis, center_val in enumerate(center_irc):
+            start_ndx = int(round(center_val - width_irc[axis]/2))
+            end_ndx = int(start_ndx + width_irc[axis])
+
+            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+
+            if start_ndx < 0:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                start_ndx = 0
+                end_ndx = int(width_irc[axis])
+
+            if end_ndx > self.ary.shape[axis]:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                end_ndx = self.ary.shape[axis]
+                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
+
+            slice_list.append(slice(start_ndx, end_ndx))
+
+        # Index with a tuple of slices; indexing with a list is deprecated in NumPy.
+        ct_chunk = self.ary[tuple(slice_list)]
+
+        return ct_chunk, center_irc
+
+
+@functools.lru_cache(1, typed=True)
+def getCt(series_uid):
+    return Ct(series_uid)
+
+@raw_cache.memoize(typed=True)
+def getCtRawNodule(series_uid, center_xyz, width_irc):
+    ct = getCt(series_uid)
+    ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
+    return ct_chunk, center_irc
+
+class LunaDataset(Dataset):
+    def __init__(self,
+                 test_stride=0,
+                 isTestSet_bool=None,
+                 series_uid=None,
+            ):
+        self.noduleInfo_list = copy.copy(getNoduleInfoList())
+
+        if series_uid:
+            self.noduleInfo_list = [x for x in self.noduleInfo_list if x[2] == series_uid]
+
+        # __init__ continued...
+        if test_stride > 1:
+            if isTestSet_bool:
+                self.noduleInfo_list = self.noduleInfo_list[::test_stride]
+            else:
+                del self.noduleInfo_list[::test_stride]
+
+        log.info("{!r}: {} {} samples".format(
+            self,
+            len(self.noduleInfo_list),
+            "testing" if isTestSet_bool else "training",
+        ))
+
+    def __len__(self):
+        return len(self.noduleInfo_list)
+
+    def __getitem__(self, ndx):
+        sample_ndx = ndx
+
+        isMalignant_bool, diameter_mm, series_uid, center_xyz = self.noduleInfo_list[sample_ndx]
+
+        nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, (32, 32, 32))
+
+        nodule_tensor = torch.from_numpy(nodule_ary)
+        nodule_tensor = nodule_tensor.unsqueeze(0)
+
+        malignant_tensor = torch.tensor([isMalignant_bool], dtype=torch.float32)
+
+        return nodule_tensor, malignant_tensor, series_uid, center_irc
+
+
+

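Note: the candidate-to-annotation matching in getNoduleInfoList leans on Python's for/else, where the else clause runs only when the inner loop completes without hitting break. A minimal standalone sketch of the rule (match_diameter is a hypothetical helper, not part of this commit):

    def match_diameter(candidate_xyz, annotation_list):
        # A candidate inherits an annotation's diameter only when its center
        # lies within a quarter of that diameter along every axis; this is a
        # cheap per-axis bounding-box test, not a Euclidean distance check.
        for annotationCenter_xyz, annotationDiameter_mm in annotation_list:
            for i in range(3):
                delta_mm = abs(candidate_xyz[i] - annotationCenter_xyz[i])
                if delta_mm > annotationDiameter_mm / 4:
                    break  # too far apart on this axis; try the next annotation
            else:  # no break: all three axes were close enough
                return annotationDiameter_mm
        return 0.0  # no annotation matched; the diameter stays unknown

    assert match_diameter((1.0, 1.0, 1.0), [((0.0, 0.0, 0.0), 10.0)]) == 10.0
    assert match_diameter((9.0, 0.0, 0.0), [((0.0, 0.0, 0.0), 10.0)]) == 0.0

Unmatched candidates keep a 0.0 diameter, which only affects how noduleInfo_list sorts, not the training label.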
+ 76 - 81
p2ch1/vis.py → p2ch08/vis.py

@@ -1,81 +1,76 @@
-import matplotlib
-import numpy as np
-import matplotlib.pyplot as plt
-
-from p2ch1.dsets import Ct, LunaDataset
-
-clim=(0.0, 1.3)
-
-def findMalignantSamples(start_ndx=0, limit=10):
-    ds = LunaDataset()
-
-    malignantSample_list = []
-    for sample_tup in ds.sample_list:
-        if sample_tup[2]:
-            malignantSample_list.append(sample_tup)
-            print(sample_tup)
-
-        if len(malignantSample_list) >= limit:
-            break
-
-    return malignantSample_list
-
-def showNodule(series_uid, batch_ndx=None):
-    ds = LunaDataset(series_uid=series_uid)
-    malignant_list = [i for i, x in enumerate(ds.sample_list) if x[2]]
-
-    if batch_ndx is None:
-        if malignant_list:
-            batch_ndx = malignant_list[0]
-        else:
-            print("Warning: no malignant samples found; using first non-malignant sample.")
-            batch_ndx = 0
-
-    ct = Ct(series_uid)
-    ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
-    ct_ary = ct_tensor[0].numpy()
-
-    fig = plt.figure(figsize=(15, 25))
-
-    group_list = [
-        #[0,1,2],
-        [3,4,5],
-        [6,7,8],
-        [9,10,11],
-        #[12,13,14],
-        #[15]
-    ]
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
-    subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
-    subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
-    subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
-    subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct_ary[7], clim=clim, cmap='gray')
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
-    subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct_ary[:,7], clim=clim, cmap='gray')
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
-    subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct_ary[:,:,7], clim=clim, cmap='gray')
-
-
-    for row, index_list in enumerate(group_list):
-        for col, index in enumerate(index_list):
-            subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
-            subplot.set_title('slice {}'.format(index))
-            plt.imshow(ct_ary[index], clim=clim, cmap='gray')
-
-
-    print(series_uid, batch_ndx, bool(malignant_tensor[0][0]), malignant_list)
+import matplotlib
+import numpy as np
+import matplotlib.pyplot as plt
+
+from p2ch08.dsets import Ct, LunaDataset
+
+clim=(0.0, 1.3)
+
+def findMalignantSamples(start_ndx=0, limit=100):
+    ds = LunaDataset()
+
+    malignantSample_list = []
+    for sample_tup in ds.noduleInfo_list:
+        if sample_tup[0]:
+            malignantSample_list.append(sample_tup)
+
+        if len(malignantSample_list) >= limit:
+            break
+
+    return malignantSample_list
+
+def showNodule(series_uid, batch_ndx=None):
+    ds = LunaDataset(series_uid=series_uid)
+    malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x[0]]
+
+    if batch_ndx is None:
+        if malignant_list:
+            batch_ndx = malignant_list[0]
+        else:
+            print("Warning: no malignant samples found; using first non-malignant sample.")
+            batch_ndx = 0
+
+    ct = Ct(series_uid)
+    ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
+    ct_ary = ct_tensor[0].numpy()
+
+    fig = plt.figure(figsize=(15, 25))
+
+    group_list = [
+        [9,11,13],
+        [15, 16, 17],
+        [19,21,23],
+    ]
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct_ary[ct_ary.shape[0]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct_ary[:,ct_ary.shape[1]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct_ary[:,:,ct_ary.shape[2]//2], clim=clim, cmap='gray')
+
+    for row, index_list in enumerate(group_list):
+        for col, index in enumerate(index_list):
+            subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
+            subplot.set_title('slice {}'.format(index))
+            plt.imshow(ct_ary[index], clim=clim, cmap='gray')
+
+
+    print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list)

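Note: a typical notebook session for this module might look like the sketch below (hypothetical usage, assuming the LUNA subsets are unpacked under data/luna/ so LunaDataset can find them):

    from p2ch08.vis import findMalignantSamples, showNodule

    samples = findMalignantSamples(limit=5)
    series_uid = samples[0][2]  # tuples are (isMalignant, diameter_mm, series_uid, center_xyz)
    showNodule(series_uid)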
File diff suppressed because it is too large
+ 175 - 0
p2ch08_explore_data.ipynb


+ 0 - 0
p2ch2/__init__.py → p2ch09/__init__.py


+ 190 - 0
p2ch09/dsets.py

@@ -0,0 +1,190 @@
+import copy
+import csv
+import functools
+import glob
+import os
+import random
+
+import SimpleITK as sitk
+
+import numpy as np
+import torch
+import torch.cuda
+from torch.utils.data import Dataset
+
+from util.disk import getCache
+from util.util import XyzTuple, xyz2irc
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+raw_cache = getCache('part2ch09_raw')
+
+@functools.lru_cache(1)
+def getNoduleInfoList(requireDataOnDisk_bool=True):
+    # We construct a set with all series_uids that are present on disk.
+    # This will let us use the data, even if we haven't downloaded all of
+    # the subsets yet.
+    mhd_list = glob.glob('data/luna/subset*/*.mhd')
+    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+
+    diameter_dict = {}
+    with open('data/luna/annotations.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
+            annotationDiameter_mm = float(row[4])
+
+            diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
+
+    noduleInfo_list = []
+    with open('data/luna/candidates.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+
+            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+                continue
+
+            isMalignant_bool = bool(int(row[4]))
+            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
+
+            candidateDiameter_mm = 0.0
+            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
+                for i in range(3):
+                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
+                    if delta_mm > annotationDiameter_mm / 4:
+                        break
+                else:
+                    candidateDiameter_mm = annotationDiameter_mm
+                    break
+
+            noduleInfo_list.append((isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+
+    noduleInfo_list.sort(reverse=True)
+    return noduleInfo_list
+
+class Ct(object):
+    def __init__(self, series_uid):
+        mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]
+
+        ct_mhd = sitk.ReadImage(mhd_path)
+        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+
+        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
+        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
+        # This converts HU to g/cc.
+        ct_ary += 1000
+        ct_ary /= 1000
+
+        # This gets rid of negative density stuff used to indicate out-of-FOV
+        ct_ary[ct_ary < 0] = 0
+
+        # This nukes any weird hotspots and clamps bone down
+        ct_ary[ct_ary > 2] = 2
+
+        self.series_uid = series_uid
+        self.ary = ct_ary
+
+        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
+        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
+        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
+
+    def getRawNodule(self, center_xyz, width_irc):
+        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
+
+        slice_list = []
+        for axis, center_val in enumerate(center_irc):
+            start_ndx = int(round(center_val - width_irc[axis]/2))
+            end_ndx = int(start_ndx + width_irc[axis])
+
+            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+
+            if start_ndx < 0:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                start_ndx = 0
+                end_ndx = int(width_irc[axis])
+
+            if end_ndx > self.ary.shape[axis]:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                end_ndx = self.ary.shape[axis]
+                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
+
+            slice_list.append(slice(start_ndx, end_ndx))
+
+        # Index with a tuple of slices; indexing with a list is deprecated in NumPy.
+        ct_chunk = self.ary[tuple(slice_list)]
+
+        return ct_chunk, center_irc
+
+
+@functools.lru_cache(1, typed=True)
+def getCt(series_uid):
+    return Ct(series_uid)
+
+@raw_cache.memoize(typed=True)
+def getCtRawNodule(series_uid, center_xyz, width_irc):
+    ct = getCt(series_uid)
+    ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
+    return ct_chunk, center_irc
+
+class LunaDataset(Dataset):
+    def __init__(self,
+                 test_stride=0,
+                 isTestSet_bool=None,
+                 series_uid=None,
+                 sortby_str='random',
+            ):
+        self.noduleInfo_list = copy.copy(getNoduleInfoList())
+
+        if series_uid:
+            self.noduleInfo_list = [x for x in self.noduleInfo_list if x[2] == series_uid]
+
+        # __init__ continued...
+        if test_stride > 1:
+            if isTestSet_bool:
+                self.noduleInfo_list = self.noduleInfo_list[::test_stride]
+            else:
+                del self.noduleInfo_list[::test_stride]
+
+        if sortby_str == 'random':
+            random.shuffle(self.noduleInfo_list)
+        elif sortby_str == 'series_uid':
+            self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sort by (series_uid, center_xyz)
+        elif sortby_str == 'malignancy_size':
+            pass
+        else:
+            raise Exception("Unknown sort: " + repr(sortby_str))
+
+        log.info("{!r}: {} {} samples".format(
+            self,
+            len(self.noduleInfo_list),
+            "testing" if isTestSet_bool else "training",
+        ))
+
+
+    def __len__(self):
+        # if self.ratio_int:
+        #     return min(len(self.benignIndex_list), len(self.malignantIndex_list)) * 4 * 90
+        # else:
+        return len(self.noduleInfo_list)
+
+    def __getitem__(self, ndx):
+        sample_ndx = ndx
+
+        isMalignant_bool, _diameter_mm, series_uid, center_xyz = self.noduleInfo_list[sample_ndx]
+
+        nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, (32, 32, 32))
+
+        nodule_tensor = torch.from_numpy(nodule_ary)
+        nodule_tensor = nodule_tensor.unsqueeze(0)
+
+        malignant_tensor = torch.tensor([isMalignant_bool], dtype=torch.float32)
+
+        return nodule_tensor, malignant_tensor, series_uid, center_irc
+
+
+

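Note: the test_stride logic splits train and test with complementary slices: the extended-slice expression noduleInfo_list[::test_stride] keeps every test_stride-th nodule, while del noduleInfo_list[::test_stride] removes exactly those same nodules. A minimal standalone sketch of the idea (not part of this commit):

    items = list(range(10))
    test_stride = 5

    test_items = items[::test_stride]   # [0, 5] -- every 5th item
    train_items = list(items)
    del train_items[::test_stride]      # [1, 2, 3, 4, 6, 7, 8, 9]

    assert sorted(test_items + train_items) == items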
+ 52 - 0
p2ch09/model.py

@@ -0,0 +1,52 @@
+
+import torch
+from torch import nn as nn
+
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+class LunaModel(nn.Module):
+    def __init__(self, layer_count=4, in_channels=1, conv_channels=8):
+        super().__init__()
+
+        layer_list = []
+        for layer_ndx in range(layer_count):
+            layer_list += [
+                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False),
+                nn.BatchNorm3d(conv_channels), # eli: will assume that p1ch6 doesn't use this
+                nn.LeakyReLU(inplace=True), # eli: will assume plain ReLU
+                nn.Dropout3d(p=0.2),  # eli: will assume that p1ch6 doesn't use this
+
+                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=False),
+                nn.BatchNorm3d(conv_channels),
+                nn.LeakyReLU(inplace=True),
+                nn.Dropout3d(p=0.2),
+
+                nn.MaxPool3d(2, 2),
+            ]
+
+            in_channels = conv_channels
+            conv_channels *= 2
+
+        self.convAndPool_seq = nn.Sequential(*layer_list)
+        self.fullyConnected_layer = nn.Linear(512, 1)
+        self.final = nn.Hardtanh(min_val=0.0, max_val=1.0)
+
+
+    def forward(self, input_batch):
+        conv_output = self.convAndPool_seq(input_batch)
+        conv_flat = conv_output.view(conv_output.size(0), -1)
+
+        try:
+            classifier_output = self.fullyConnected_layer(conv_flat)
+        except:
+            log.debug(conv_flat.size())
+            raise
+
+        classifier_output = self.final(classifier_output)
+        return classifier_output

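Note: the hardcoded nn.Linear(512, 1) follows from the shape arithmetic: each of the four blocks ends in MaxPool3d(2, 2), shrinking the 32x32x32 input to 2x2x2, while conv_channels doubles 8 -> 16 -> 32 -> 64, so the flattened features number 64 * 2 * 2 * 2 = 512. A quick sanity check (a sketch assuming this commit's layout is on the Python path):

    import torch
    from p2ch09.model import LunaModel

    model = LunaModel()
    with torch.no_grad():
        output = model(torch.zeros(2, 1, 32, 32, 32))
    print(output.shape)  # torch.Size([2, 1]), values clamped to [0, 1] by Hardtanh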
+ 63 - 61
p2ch3/prepcache.py → p2ch09/prepcache.py

@@ -1,61 +1,63 @@
-import argparse
-import sys
-
-import numpy as np
-
-import torch.nn as nn
-from torch.autograd import Variable
-from torch.optim import SGD
-from torch.utils.data import DataLoader
-
-from util.util import enumerateWithEstimate
-from .dsets import LunaDataset
-from util.logconf import logging
-from .model import LunaModel
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-log.setLevel(logging.INFO)
-# log.setLevel(logging.DEBUG)
-
-
-class LunaPrepCacheApp(object):
-    @classmethod
-    def __init__(self, sys_argv=None):
-        if sys_argv is None:
-            sys_argv = sys.argv[1:]
-
-        parser = argparse.ArgumentParser()
-        parser.add_argument('--batch-size',
-            help='Batch size to use for training',
-            default=256,
-            type=int,
-        )
-        parser.add_argument('--num-workers',
-            help='Number of worker processes for background data loading',
-            default=8,
-            type=int,
-        )
-
-        self.cli_args = parser.parse_args(sys_argv)
-
-    def main(self):
-        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
-        self.prep_dl = DataLoader(
-            LunaDataset(
-            ),
-            batch_size=self.cli_args.batch_size,
-            num_workers=self.cli_args.num_workers,
-        )
-
-        batch_iter = enumerateWithEstimate(
-            self.prep_dl,
-            "Stuffing cache",
-            start_ndx=self.prep_dl.num_workers,
-        )
-        for _ in batch_iter:
-            pass
-
-
-if __name__ == '__main__':
-    sys.exit(LunaPrepCacheApp().main() or 0)
+import argparse
+import sys
+
+import numpy as np
+
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset
+from util.logconf import logging
+from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+
+class LunaPrepCacheApp(object):
+    @classmethod
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=1024,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.prep_dl = DataLoader(
+            LunaDataset(
+                sortby_str='series_uid',
+            ),
+            batch_size=self.cli_args.batch_size,
+            num_workers=self.cli_args.num_workers,
+        )
+
+        batch_iter = enumerateWithEstimate(
+            self.prep_dl,
+            "Stuffing cache",
+            start_ndx=self.prep_dl.num_workers,
+        )
+        for _ in batch_iter:
+            pass
+
+
+if __name__ == '__main__':
+    sys.exit(LunaPrepCacheApp().main() or 0)

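Note: sortby_str='series_uid' matters here because getCt is wrapped in functools.lru_cache(1), so only one Ct volume stays in memory at a time. Grouping the samples by series makes each volume load once while the disk cache is being stuffed; random order would thrash the one-entry cache. A standalone sketch of the effect (load_volume is a hypothetical stand-in for Ct):

    import functools

    load_count = 0

    @functools.lru_cache(1)
    def load_volume(series_uid):
        global load_count
        load_count += 1
        return series_uid

    for uid in ['a', 'a', 'b', 'b']:  # grouped, as sortby_str='series_uid' yields
        load_volume(uid)
    assert load_count == 2            # each volume loaded once

    load_volume.cache_clear()
    load_count = 0
    for uid in ['a', 'b', 'a', 'b']:  # interleaved, as a random order could yield
        load_volume(uid)
    assert load_count == 4            # the one-entry cache reloads every time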
+ 216 - 0
p2ch09/training.py

@@ -0,0 +1,216 @@
+import argparse
+import datetime
+import os
+import sys
+
+import numpy as np
+from tensorboardX import SummaryWriter
+
+import torch
+import torch.nn as nn
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset
+from util.logconf import logging
+from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
+METRICS_LABEL_NDX=0
+METRICS_PRED_NDX=1
+METRICS_LOSS_NDX=2
+
+class LunaTrainingApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=32,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        parser.add_argument('--epochs',
+            help='Number of epochs to train for',
+            default=1,
+            type=int,
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+
+        self.model = LunaModel()
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                self.model = nn.DataParallel(self.model)
+
+            self.model = self.model.to(self.device)
+        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)
+
+        train_dl = DataLoader(
+            LunaDataset(
+                test_stride=10,
+                isTestSet_bool=False,
+            ),
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        test_dl = DataLoader(
+            LunaDataset(
+                test_stride=10,
+                isTestSet_bool=True,
+            ),
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        for epoch_ndx in range(1, self.cli_args.epochs + 1):
+
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                len(test_dl),
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            # Training loop, very similar to below
+            self.model.train()
+            trainingMetrics_tensor = torch.zeros(3, len(train_dl.dataset), 1)
+            batch_iter = enumerateWithEstimate(
+                train_dl,
+                "E{} Training".format(epoch_ndx),
+                start_ndx=train_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                self.optimizer.zero_grad()
+                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
+                loss_var.backward()
+                self.optimizer.step()
+                del loss_var
+
+            # Testing loop, very similar to above, but simplified
+            with torch.no_grad():
+                self.model.eval()
+                testingMetrics_tensor = torch.zeros(3, len(test_dl.dataset), 1)
+                batch_iter = enumerateWithEstimate(
+                    test_dl,
+                    "E{} Testing ".format(epoch_ndx),
+                    start_ndx=test_dl.num_workers,
+                )
+                for batch_ndx, batch_tup in batch_iter:
+                    self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+
+            self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)
+
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
+        input_tensor, label_tensor, _series_list, _center_list = batch_tup
+
+        input_devtensor = input_tensor.to(self.device)
+        label_devtensor = label_tensor.to(self.device)
+
+        prediction_devtensor = self.model(input_devtensor)
+        loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)
+
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
+        metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor
+        # detach() keeps these bookkeeping copies out of the autograd graph,
+        # so the per-batch graphs can be freed instead of accumulating all epoch.
+        metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.detach().to('cpu')
+        metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.detach().to('cpu')
+
+        # TODO: replace with torch.autograd.detect_anomaly
+        # assert np.isfinite(metrics_tensor).all()
+
+        return loss_devtensor.mean()
+
+
+    def logMetrics(self,
+                   epoch_ndx,
+                   trainingMetrics_tensor,
+                   testingMetrics_tensor,
+                   classificationThreshold_float=0.5,
+                   ):
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
+        for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
+            metrics_ary = metrics_tensor.detach().numpy()[:,:,0]
+            assert np.isfinite(metrics_ary).all()
+
+            benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= classificationThreshold_float
+            benPred_mask = metrics_ary[METRICS_PRED_NDX] <= classificationThreshold_float
+
+            malLabel_mask = ~benLabel_mask
+            malPred_mask = ~benPred_mask
+
+            benLabel_count = benLabel_mask.sum()
+            malLabel_count = malLabel_mask.sum()
+
+            benCorrect_count = (benLabel_mask & benPred_mask).sum()
+            malCorrect_count = (malLabel_mask & malPred_mask).sum()
+
+            metrics_dict = {}
+
+            metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
+            metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
+            metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
+
+            metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
+            metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
+            metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
+
+            log.info(("E{} {:8} "
+                     + "{loss/all:.4f} loss, "
+                     + "{correct/all:-5.1f}% correct"
+                      ).format(
+                epoch_ndx,
+                mode_str,
+                **metrics_dict,
+            ))
+            log.info(("E{} {:8} "
+                     + "{loss/ben:.4f} loss, "
+                     + "{correct/ben:-5.1f}% correct").format(
+                epoch_ndx,
+                mode_str + '_ben',
+                **metrics_dict,
+            ))
+            log.info(("E{} {:8} "
+                     + "{loss/mal:.4f} loss, "
+                     + "{correct/mal:-5.1f}% correct").format(
+                epoch_ndx,
+                mode_str + '_mal',
+                **metrics_dict,
+            ))
+
+
+if __name__ == '__main__':
+    sys.exit(LunaTrainingApp().main() or 0)

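Note: logMetrics gets all of its per-class numbers from boolean masks over the (3, N) metrics array, with no explicit loop. A minimal standalone sketch with four hand-picked samples (not part of this commit):

    import numpy as np

    METRICS_LABEL_NDX, METRICS_PRED_NDX, METRICS_LOSS_NDX = 0, 1, 2
    metrics_ary = np.array([
        [0.0, 0.0, 1.0, 1.0],      # labels
        [0.1, 0.7, 0.9, 0.2],      # predictions
        [0.01, 0.49, 0.01, 0.64],  # per-sample losses
    ])

    threshold = 0.5
    benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= threshold
    benPred_mask = metrics_ary[METRICS_PRED_NDX] <= threshold
    malLabel_mask = ~benLabel_mask
    malPred_mask = ~benPred_mask

    benCorrect_count = (benLabel_mask & benPred_mask).sum()  # 1: only sample 0
    malCorrect_count = (malLabel_mask & malPred_mask).sum()  # 1: only sample 2
    print(metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean())  # 0.325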
+ 86 - 80
p2ch4/vis.py → p2ch09/vis.py

@@ -1,80 +1,86 @@
-import matplotlib
-import numpy as np
-import matplotlib.pyplot as plt
-
-from p2ch4.dsets import Ct, LunaDataset
-
-clim=(0.0, 1.3)
-
-def findMalignantSamples(start_ndx=0, limit=10):
-    ds = LunaDataset()
-
-    malignantSample_list = []
-    for sample_tup in ds.sample_list:
-        if sample_tup[2]:
-            print(len(malignantSample_list), sample_tup)
-            malignantSample_list.append(sample_tup)
-
-        if len(malignantSample_list) >= limit:
-            break
-
-    return malignantSample_list
-
-def showNodule(series_uid, batch_ndx=None, **kwargs):
-    ds = LunaDataset(series_uid=series_uid, **kwargs)
-    malignant_list = [i for i, x in enumerate(ds.sample_list) if x[2]]
-
-    if batch_ndx is None:
-        if malignant_list:
-            batch_ndx = malignant_list[0]
-        else:
-            print("Warning: no malignant samples found; using first non-malignant sample.")
-            batch_ndx = 0
-
-    ct = Ct(series_uid)
-    ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
-    ct_ary = ct_tensor[0].numpy()
-
-    fig = plt.figure(figsize=(15, 25))
-
-    group_list = [
-        #[0,1,2],
-        [3,4,5],
-        [6,7,8],
-        [9,10,11],
-        #[12,13,14],
-        #[15]
-    ]
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
-    subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
-    subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
-    subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
-    subplot.set_title('index {}'.format(int(center_irc.index)))
-    plt.imshow(ct_ary[7], clim=clim, cmap='gray')
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
-    subplot.set_title('row {}'.format(int(center_irc.row)))
-    plt.imshow(ct_ary[:,7], clim=clim, cmap='gray')
-
-    subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
-    subplot.set_title('col {}'.format(int(center_irc.col)))
-    plt.imshow(ct_ary[:,:,7], clim=clim, cmap='gray')
-
-    for row, index_list in enumerate(group_list):
-        for col, index in enumerate(index_list):
-            subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
-            subplot.set_title('slice {}'.format(index))
-            plt.imshow(ct_ary[index], clim=clim, cmap='gray')
-
-
-    print(series_uid, batch_ndx, bool(malignant_tensor[0][0]), malignant_list, ct.vxSize_xyz)
+import matplotlib
+matplotlib.use('nbagg')
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from p2ch09.dsets import Ct, LunaDataset
+
+clim=(0.0, 1.3)
+
+def findMalignantSamples(start_ndx=0, limit=10):
+    ds = LunaDataset()
+
+    malignantSample_list = []
+    for sample_tup in ds.noduleInfo_list:
+        if sample_tup[0]:
+            print(len(malignantSample_list), sample_tup)
+            malignantSample_list.append(sample_tup)
+
+        if len(malignantSample_list) >= limit:
+            break
+
+    return malignantSample_list
+
+def showNodule(series_uid, batch_ndx=None, **kwargs):
+    ds = LunaDataset(series_uid=series_uid, **kwargs)
+    malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x[0]]
+
+    if batch_ndx is None:
+        if malignant_list:
+            batch_ndx = malignant_list[0]
+        else:
+            print("Warning: no malignant samples found; using first non-malignant sample.")
+            batch_ndx = 0
+
+    ct = Ct(series_uid)
+    nodule_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
+    ct_ary = nodule_tensor[0].numpy()
+
+
+    fig = plt.figure(figsize=(15, 25))
+
+    group_list = [
+        #[0,1,2],
+        [3,4,5],
+        [6,7,8],
+        [9,10,11],
+        #[12,13,14],
+        #[15]
+    ]
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct_ary[ct_ary.shape[0]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct_ary[:,ct_ary.shape[1]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct_ary[:,:,ct_ary.shape[2]//2], clim=clim, cmap='gray')
+
+    for row, index_list in enumerate(group_list):
+        for col, index in enumerate(index_list):
+            subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
+            subplot.set_title('slice {}'.format(index * 2))
+            plt.imshow(ct_ary[index*2], clim=clim, cmap='gray')
+
+
+    print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list, ct.vxSize_xyz)
+
+    return ct_ary

+ 0 - 147
p2ch1/dsets.py

@@ -1,147 +0,0 @@
-import csv
-import functools
-import glob
-import math
-import time
-
-import SimpleITK as sitk
-
-import numpy as np
-import torch
-import torch.cuda
-from torch.utils.data import Dataset
-
-from util.disk import getCache
-from util.util import XyzTuple, xyz2irc
-from util.logconf import logging
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-log.setLevel(logging.INFO)
-log.setLevel(logging.DEBUG)
-
-cache = getCache('p2ch1')
-
-class Ct(object):
-    def __init__(self, series_uid):
-        mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]
-
-        ct_mhd = sitk.ReadImage(mhd_path)
-        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
-
-        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
-        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
-        # This converts HU to g/cc.
-        ct_ary += 1000
-        ct_ary /= 1000
-
-        # This gets rid of negative density stuff used to indicate out-of-FOV
-        ct_ary[ct_ary < 0] = 0
-
-        # This nukes any weird hotspots and clamps bone down
-        ct_ary[ct_ary > 2] = 2
-
-        self.series_uid = series_uid
-        self.ary = ct_ary
-        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
-        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
-        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
-
-    def getInputChunk(self, center_xyz, width_irc):
-        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
-
-        slice_list = []
-        for axis, center_val in enumerate(center_irc):
-            start_ndx = int(round(center_val - width_irc[axis]/2))
-            end_ndx = int(start_ndx + width_irc[axis])
-
-            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
-
-            if start_ndx < 0:
-                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                start_ndx = 0
-                end_ndx = int(width_irc[axis])
-
-            if end_ndx > self.ary.shape[axis]:
-                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                end_ndx = self.ary.shape[axis]
-                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
-
-            slice_list.append(slice(start_ndx, end_ndx))
-
-        ct_chunk = self.ary[slice_list]
-
-        return ct_chunk, center_irc
-
-
-@functools.lru_cache(1, typed=True)
-def getCt(series_uid):
-    return Ct(series_uid)
-
-@cache.memoize(typed=True)
-def getCtInputChunk(series_uid, center_xyz, width_irc):
-    ct = getCt(series_uid)
-    ct_chunk, center_irc = ct.getInputChunk(center_xyz, width_irc)
-    return ct_chunk, center_irc
-
-class LunaDataset(Dataset):
-    def __init__(self, test_stride=0, isTestSet_bool=None, series_uid=None):
-        # We construct a set with all series_uids that are present on disk.
-        # This will let us use the data, even if we haven't downloaded all of
-        # the subsets yet.
-        mhd_list = glob.glob('data/luna/subset*/*.mhd')
-        present_set = {p.rsplit('/', 1)[-1][:-4] for p in mhd_list}
-
-        sample_list = []
-        with open('data/luna/candidates.csv', "r") as f:
-            csv_list = list(csv.reader(f))
-
-        for row in csv_list[1:]:
-            row_uid = row[0]
-
-            if series_uid and series_uid != row_uid:
-                continue
-
-            # If a row_uid isn't present, that means it's in a subset that we
-            # don't have on disk, so we should skip it.
-            if row_uid not in present_set:
-                continue
-
-            center_xyz = tuple([float(x) for x in row[1:4]])
-            isMalignant_bool = bool(int(row[4]))
-            sample_list.append((row_uid, center_xyz, isMalignant_bool))
-
-        sample_list.sort()
-        if test_stride > 1:
-            if isTestSet_bool:
-                sample_list = sample_list[::test_stride]
-            else:
-                del sample_list[::test_stride]
-
-        log.info("{!r}: {} {} samples".format(self, len(sample_list), "testing" if isTestSet_bool else "training"))
-        self.sample_list = sample_list
-
-    def __len__(self):
-        return len(self.sample_list)
-
-    def __getitem__(self, ndx):
-        series_uid, center_xyz, isMalignant_bool = self.sample_list[ndx]
-        ct_chunk, center_irc = getCtInputChunk(series_uid, center_xyz, (16, 16, 16))
-
-        # dim=3, Index x Row x Col
-        ct_tensor = torch.from_numpy(np.array(ct_chunk, dtype=np.float32))
-
-        # dim=1
-        malignant_tensor = torch.from_numpy(np.array([isMalignant_bool], dtype=np.float32))
-
-        # dim=4, Channel x Index x Row x Col
-        ct_tensor = ct_tensor.unsqueeze(0)
-        malignant_tensor = malignant_tensor.unsqueeze(0)
-
-        # Unpacked as: input_tensor, answer_int, series_uid, center_irc
-        return ct_tensor, malignant_tensor, series_uid, center_irc
-
-
-

File diff suppressed because it is too large
+ 59 - 0
p2ch10/1_final_metric_f1_score.ipynb


+ 0 - 0
p2ch3/__init__.py → p2ch10/__init__.py


+ 211 - 0
p2ch10/dsets.py

@@ -0,0 +1,211 @@
+import copy
+import csv
+import functools
+import glob
+import os
+import random
+
+import SimpleITK as sitk
+
+import numpy as np
+import torch
+import torch.cuda
+from torch.utils.data import Dataset
+
+from util.disk import getCache
+from util.util import XyzTuple, xyz2irc
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+raw_cache = getCache('part2ch10_raw')
+
+@functools.lru_cache(1)
+def getNoduleInfoList(requireDataOnDisk_bool=True):
+    # We construct a set with all series_uids that are present on disk.
+    # This will let us use the data, even if we haven't downloaded all of
+    # the subsets yet.
+    mhd_list = glob.glob('data/luna/subset*/*.mhd')
+    dataPresentOnDisk_set = {os.path.split(p)[-1][:-4] for p in mhd_list}
+
+    diameter_dict = {}
+    with open('data/luna/annotations.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+            annotationCenter_xyz = tuple([float(x) for x in row[1:4]])
+            annotationDiameter_mm = float(row[4])
+
+            diameter_dict.setdefault(series_uid, []).append((annotationCenter_xyz, annotationDiameter_mm))
+
+    noduleInfo_list = []
+    with open('data/luna/candidates.csv', "r") as f:
+        for row in list(csv.reader(f))[1:]:
+            series_uid = row[0]
+
+            if series_uid not in dataPresentOnDisk_set and requireDataOnDisk_bool:
+                continue
+
+            isMalignant_bool = bool(int(row[4]))
+            candidateCenter_xyz = tuple([float(x) for x in row[1:4]])
+
+            candidateDiameter_mm = 0.0
+            for annotationCenter_xyz, annotationDiameter_mm in diameter_dict.get(series_uid, []):
+                for i in range(3):
+                    delta_mm = abs(candidateCenter_xyz[i] - annotationCenter_xyz[i])
+                    if delta_mm > annotationDiameter_mm / 4:
+                        break
+                else:
+                    candidateDiameter_mm = annotationDiameter_mm
+                    break
+
+            noduleInfo_list.append((isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+
+    noduleInfo_list.sort(reverse=True)
+    return noduleInfo_list
+
+class Ct(object):
+    def __init__(self, series_uid):
+        mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]
+
+        ct_mhd = sitk.ReadImage(mhd_path)
+        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
+
+        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
+        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
+        # This converts HU to g/cc.
+        ct_ary += 1000
+        ct_ary /= 1000
+
+        # This gets rid of negative density stuff used to indicate out-of-FOV
+        ct_ary[ct_ary < 0] = 0
+
+        # This nukes any weird hotspots and clamps bone down
+        ct_ary[ct_ary > 2] = 2
+
+        self.series_uid = series_uid
+        self.ary = ct_ary
+
+        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
+        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
+        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
+
+    def getRawNodule(self, center_xyz, width_irc):
+        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
+
+        slice_list = []
+        for axis, center_val in enumerate(center_irc):
+            start_ndx = int(round(center_val - width_irc[axis]/2))
+            end_ndx = int(start_ndx + width_irc[axis])
+
+            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
+
+            if start_ndx < 0:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                start_ndx = 0
+                end_ndx = int(width_irc[axis])
+
+            if end_ndx > self.ary.shape[axis]:
+                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
+                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
+                end_ndx = self.ary.shape[axis]
+                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
+
+            slice_list.append(slice(start_ndx, end_ndx))
+
+        # Index with a tuple of slices; indexing with a list is deprecated in NumPy.
+        ct_chunk = self.ary[tuple(slice_list)]
+
+        return ct_chunk, center_irc
+
+
+@functools.lru_cache(1, typed=True)
+def getCt(series_uid):
+    return Ct(series_uid)
+
+@raw_cache.memoize(typed=True)
+def getCtRawNodule(series_uid, center_xyz, width_irc):
+    ct = getCt(series_uid)
+    ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
+    return ct_chunk, center_irc
+
+class LunaDataset(Dataset):
+    def __init__(self,
+                 test_stride=0,
+                 isTestSet_bool=None,
+                 series_uid=None,
+                 sortby_str='random',
+                 ratio_int=0,
+            ):
+        self.ratio_int = ratio_int
+
+        self.noduleInfo_list = copy.copy(getNoduleInfoList())
+
+        if series_uid:
+            self.noduleInfo_list = [x for x in self.noduleInfo_list if x[2] == series_uid]
+
+        if test_stride > 1:
+            if isTestSet_bool:
+                self.noduleInfo_list = self.noduleInfo_list[::test_stride]
+            else:
+                del self.noduleInfo_list[::test_stride]
+
+        if sortby_str == 'random':
+            random.shuffle(self.noduleInfo_list)
+        elif sortby_str == 'series_uid':
+            self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sort by (series_uid, center_xyz)
+        elif sortby_str == 'malignancy_size':
+            pass
+        else:
+            raise Exception("Unknown sort: " + repr(sortby_str))
+
+        self.benignIndex_list = [i for i, x in enumerate(self.noduleInfo_list) if not x[0]]
+        self.malignantIndex_list = [i for i, x in enumerate(self.noduleInfo_list) if x[0]]
+
+        log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
+            self,
+            len(self.noduleInfo_list),
+            "testing" if isTestSet_bool else "training",
+            len(self.benignIndex_list),
+            len(self.malignantIndex_list),
+            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
+        ))
+
+    def shuffleSamples(self):
+        if self.ratio_int:
+            random.shuffle(self.benignIndex_list)
+            random.shuffle(self.malignantIndex_list)
+
+    def __len__(self):
+        if self.ratio_int:
+            return 100000
+        else:
+            return len(self.noduleInfo_list)
+
+    def __getitem__(self, ndx):
+        if self.ratio_int:
+            malignant_ndx = ndx // (self.ratio_int + 1)
+
+            if ndx % (self.ratio_int + 1):
+                benign_ndx = ndx - 1 - malignant_ndx
+                nodule_ndx = self.benignIndex_list[benign_ndx % len(self.benignIndex_list)]
+            else:
+                nodule_ndx = self.malignantIndex_list[malignant_ndx % len(self.malignantIndex_list)]
+        else:
+            nodule_ndx = ndx
+
+        isMalignant_bool, _diameter_mm, series_uid, center_xyz = self.noduleInfo_list[nodule_ndx]
+
+        nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, (32, 32, 32))
+
+        nodule_tensor = torch.from_numpy(nodule_ary)
+        nodule_tensor = nodule_tensor.unsqueeze(0)
+
+        malignant_tensor = torch.tensor([isMalignant_bool], dtype=torch.float32)
+
+        return nodule_tensor, malignant_tensor, series_uid, center_irc
+
+
+

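Note: the ratio_int arithmetic in __getitem__ interleaves one malignant sample with ratio_int benign samples and wraps both index lists with modulo, which is why __len__ can report an arbitrary 100000. A standalone sketch of the mapping (pick is a hypothetical helper):

    def pick(ndx, ratio_int, benign_list, malignant_list):
        malignant_ndx = ndx // (ratio_int + 1)
        if ndx % (ratio_int + 1):
            benign_ndx = ndx - 1 - malignant_ndx
            return benign_list[benign_ndx % len(benign_list)]
        return malignant_list[malignant_ndx % len(malignant_list)]

    ben, mal = ['b0', 'b1', 'b2'], ['m0']
    print([pick(i, 1, ben, mal) for i in range(6)])
    # ['m0', 'b0', 'm0', 'b1', 'm0', 'b2'] -- 1:1 balance from a 3:1 pool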
+ 52 - 0
p2ch10/model.py

@@ -0,0 +1,52 @@
+
+import torch
+from torch import nn as nn
+
+from util.logconf import logging
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+class LunaModel(nn.Module):
+    def __init__(self, layer_count=4, in_channels=1, conv_channels=8):
+        super().__init__()
+
+        layer_list = []
+        for layer_ndx in range(layer_count):
+            layer_list += [
+                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=False),
+                nn.BatchNorm3d(conv_channels), # eli: will assume that p1ch6 doesn't use this
+                nn.LeakyReLU(inplace=True), # eli: will assume plain ReLU
+                nn.Dropout3d(p=0.2),  # eli: will assume that p1ch6 doesn't use this
+
+                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=False),
+                nn.BatchNorm3d(conv_channels),
+                nn.LeakyReLU(inplace=True),
+                nn.Dropout3d(p=0.2),
+
+                nn.MaxPool3d(2, 2),
+            ]
+
+            in_channels = conv_channels
+            conv_channels *= 2
+
+        self.convAndPool_seq = nn.Sequential(*layer_list)
+        self.fullyConnected_layer = nn.Linear(512, 1)
+        self.final = nn.Hardtanh(min_val=0.0, max_val=1.0)
+
+
+    def forward(self, input_batch):
+        conv_output = self.convAndPool_seq(input_batch)
+        conv_flat = conv_output.view(conv_output.size(0), -1)
+
+        try:
+            classifier_output = self.fullyConnected_layer(conv_flat)
+        except:
+            log.debug(conv_flat.size())
+            raise
+
+        classifier_output = self.final(classifier_output)
+        return classifier_output

+ 63 - 67
p2ch4/prepcache.py → p2ch10/prepcache.py

@@ -1,67 +1,63 @@
-import argparse
-import sys
-
-import numpy as np
-
-import torch.nn as nn
-from torch.autograd import Variable
-from torch.optim import SGD
-from torch.utils.data import DataLoader
-
-from util.util import enumerateWithEstimate
-from .dsets import LunaDataset
-from util.logconf import logging
-from .model import LunaModel
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-log.setLevel(logging.INFO)
-# log.setLevel(logging.DEBUG)
-
-
-class LunaPrepCacheApp(object):
-    @classmethod
-    def __init__(self, sys_argv=None):
-        if sys_argv is None:
-            sys_argv = sys.argv[1:]
-
-        parser = argparse.ArgumentParser()
-        parser.add_argument('--batch-size',
-            help='Batch size to use for training',
-            default=256,
-            type=int,
-        )
-        parser.add_argument('--num-workers',
-            help='Number of worker processes for background data loading',
-            default=8,
-            type=int,
-        )
-        parser.add_argument('--scaled',
-            help="Scale the CT chunks to square voxels.",
-            default=False,
-            action='store_true',
-        )
-
-        self.cli_args = parser.parse_args(sys_argv)
-
-    def main(self):
-        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
-        self.prep_dl = DataLoader(
-            LunaDataset(
-                scaled_bool=self.cli_args.scaled,
-            ),
-            batch_size=self.cli_args.batch_size,
-            num_workers=self.cli_args.num_workers,
-        )
-
-        batch_iter = enumerateWithEstimate(
-            self.prep_dl,
-            "Stuffing cache",
-            start_ndx=self.prep_dl.num_workers,
-        )
-        for _ in batch_iter:
-            pass
-
-
-if __name__ == '__main__':
-    sys.exit(LunaPrepCacheApp().main() or 0)
+import argparse
+import sys
+
+import numpy as np
+
+import torch.nn as nn
+from torch.autograd import Variable
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset
+from util.logconf import logging
+from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+
+class LunaPrepCacheApp(object):
+    @classmethod
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=1024,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.prep_dl = DataLoader(
+            LunaDataset(
+                sortby_str='series_uid',
+            ),
+            batch_size=self.cli_args.batch_size,
+            num_workers=self.cli_args.num_workers,
+        )
+
+        batch_iter = enumerateWithEstimate(
+            self.prep_dl,
+            "Stuffing cache",
+            start_ndx=self.prep_dl.num_workers,
+        )
+        for _ in batch_iter:
+            pass
+
+
+if __name__ == '__main__':
+    sys.exit(LunaPrepCacheApp().main() or 0)

+ 291 - 0
p2ch10/training.py

@@ -0,0 +1,291 @@
+import argparse
+import datetime
+import os
+import sys
+
+import numpy as np
+
+from tensorboardX import SummaryWriter
+
+import torch
+import torch.nn as nn
+from torch.optim import SGD
+from torch.utils.data import DataLoader
+
+from util.util import enumerateWithEstimate
+from .dsets import LunaDataset
+from util.logconf import logging
+from .model import LunaModel
+
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
+
+# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
+METRICS_LABEL_NDX=0
+METRICS_PRED_NDX=1
+METRICS_LOSS_NDX=2
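+# e.g. metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] holds the per-sample predictions for one batch.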
+
+class LunaTrainingApp(object):
+    def __init__(self, sys_argv=None):
+        if sys_argv is None:
+            sys_argv = sys.argv[1:]
+
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--batch-size',
+            help='Batch size to use for training',
+            default=32,
+            type=int,
+        )
+        parser.add_argument('--num-workers',
+            help='Number of worker processes for background data loading',
+            default=8,
+            type=int,
+        )
+        parser.add_argument('--epochs',
+            help='Number of epochs to train for',
+            default=1,
+            type=int,
+        )
+        parser.add_argument('--balanced',
+            help="Balance the training data to half benign, half malignant.",
+            action='store_true',
+            default=False,
+        )
+
+        parser.add_argument('--tb-prefix',
+            default='p2ch10',
+            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
+        )
+
+        parser.add_argument('comment',
+            help="Comment suffix for Tensorboard run.",
+            nargs='?',
+            default='none',
+        )
+
+        self.cli_args = parser.parse_args(sys_argv)
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
+
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
+
+        self.use_cuda = torch.cuda.is_available()
+        self.device = torch.device("cuda" if self.use_cuda else "cpu")
+        self.totalTrainingSamples_count = 0
+
+        self.model = LunaModel()
+        if self.use_cuda:
+            if torch.cuda.device_count() > 1:
+                self.model = nn.DataParallel(self.model)
+
+            self.model = self.model.to(self.device)
+
+        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)
+
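+        # int(self.cli_args.balanced) maps the --balanced flag to ratio_int=1, i.e. samples
+        # drawn half benign, half malignant; 0 leaves the class balance as-is.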
+        train_dl = DataLoader(
+            LunaDataset(
+                test_stride=10,
+                isTestSet_bool=False,
+                ratio_int=int(self.cli_args.balanced),
+            ),
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        test_dl = DataLoader(
+            LunaDataset(
+                test_stride=10,
+                isTestSet_bool=True,
+            ),
+            batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
+            num_workers=self.cli_args.num_workers,
+            pin_memory=self.use_cuda,
+        )
+
+        for epoch_ndx in range(1, self.cli_args.epochs + 1):
+
+            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
+                epoch_ndx,
+                self.cli_args.epochs,
+                len(train_dl),
+                len(test_dl),
+                self.cli_args.batch_size,
+                (torch.cuda.device_count() if self.use_cuda else 1),
+            ))
+
+            # Training loop, very similar to below
+            self.model.train()
+            trainingMetrics_tensor = torch.zeros(3, len(train_dl.dataset), 1)
+            train_dl.dataset.shuffleSamples()
+            batch_iter = enumerateWithEstimate(
+                train_dl,
+                "E{} Training".format(epoch_ndx),
+                start_ndx=train_dl.num_workers,
+            )
+            for batch_ndx, batch_tup in batch_iter:
+                self.optimizer.zero_grad()
+                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
+                loss_var.backward()
+                self.optimizer.step()
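+                # Drop the reference so the loss tensor and its graph can be freed before the next batch.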
+                del loss_var
+
+            # Testing loop, very similar to above, but simplified
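+            # no_grad disables gradient tracking; eval() puts the model in inference mode.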
+            with torch.no_grad():
+                self.model.eval()
+                testingMetrics_tensor = torch.zeros(3, len(test_dl.dataset), 1)
+                batch_iter = enumerateWithEstimate(
+                    test_dl,
+                    "E{} Testing ".format(epoch_ndx),
+                    start_ndx=test_dl.num_workers,
+                )
+                for batch_ndx, batch_tup in batch_iter:
+                    self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+
+            self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)
+
+        if hasattr(self, 'trn_writer'):
+            self.trn_writer.close()
+            self.tst_writer.close()
+
+
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
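+        # Runs one batch forward, records per-sample label/prediction/loss into
+        # metrics_tensor, and returns the mean loss for the training loop to backpropagate.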
+        input_tensor, label_tensor, _series_list, _center_list = batch_tup
+
+        input_devtensor = input_tensor.to(self.device)
+        label_devtensor = label_tensor.to(self.device)
+
+        prediction_devtensor = self.model(input_devtensor)
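+        # reduction='none' yields one loss value per sample, so logMetrics can slice losses by class.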
+        loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)
+
+        start_ndx = batch_ndx * batch_size
+        end_ndx = start_ndx + label_tensor.size(0)
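+        # detach() keeps these bookkeeping copies from retaining the autograd graph for the whole epoch.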
+        metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor
+        metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.detach().to('cpu')
+        metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.detach().to('cpu')
+
+        # TODO: replace with torch.autograd.detect_anomaly
+        # assert np.isfinite(metrics_tensor).all()
+
+        return loss_devtensor.mean()
+
+
+    def logMetrics(self,
+                   epoch_ndx,
+                   trainingMetrics_tensor,
+                   testingMetrics_tensor,
+                   classificationThreshold_float=0.5,
+                   ):
+        log.info("E{} {}".format(
+            epoch_ndx,
+            type(self).__name__,
+        ))
+
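+        # The SummaryWriters are created lazily at epoch 2; epoch 1 stays out of
+        # TensorBoard entirely (see the epoch_ndx > 1 check below).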
+        if epoch_ndx == 2:
+            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
+
+            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_' + self.cli_args.comment)
+            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_' + self.cli_args.comment)
+
+        self.totalTrainingSamples_count += trainingMetrics_tensor.size(1)
+
+        for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
+            metrics_ary = metrics_tensor.cpu().detach().numpy()[:,:,0]
+            assert np.isfinite(metrics_ary).all()
+
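+            # At or below the 0.5 threshold counts as benign; above it counts as malignant.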
+            benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= classificationThreshold_float
+            benPred_mask = metrics_ary[METRICS_PRED_NDX] <= classificationThreshold_float
+
+            malLabel_mask = ~benLabel_mask
+            malPred_mask = ~benPred_mask
+
+            benLabel_count = benLabel_mask.sum()
+            malLabel_count = malLabel_mask.sum()
+
+            trueNeg_count = benCorrect_count = (benLabel_mask & benPred_mask).sum()
+            truePos_count = malCorrect_count = (malLabel_mask & malPred_mask).sum()
+
+            falsePos_count = benLabel_count - benCorrect_count
+            falseNeg_count = malLabel_count - malCorrect_count
+
+            metrics_dict = {}
+            metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
+            metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
+            metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
+
+            metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
+            metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
+            metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
+
+            precision = metrics_dict['pr/precision'] = truePos_count / (truePos_count + falsePos_count)
+            recall    = metrics_dict['pr/recall']    = truePos_count / (truePos_count + falseNeg_count)
+
+            metrics_dict['pr/f1_score'] = 2 * (precision * recall) / (precision + recall)
+
+            log.info(("E{} {:8} "
+                     + "{loss/all:.4f} loss, "
+                     + "{correct/all:-5.1f}% correct, "
+                     + "{pr/precision:.4f} precision, "
+                     + "{pr/recall:.4f} recall, "
+                     + "{pr/f1_score:.4f} f1 score"
+                      ).format(
+                epoch_ndx,
+                mode_str,
+                **metrics_dict,
+            ))
+            log.info(("E{} {:8} "
+                     + "{loss/ben:.4f} loss, "
+                     + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})").format(
+                epoch_ndx,
+                mode_str + '_ben',
+                benCorrect_count=benCorrect_count,
+                benLabel_count=benLabel_count,
+                **metrics_dict,
+            ))
+            log.info(("E{} {:8} "
+                     + "{loss/mal:.4f} loss, "
+                     + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})").format(
+                epoch_ndx,
+                mode_str + '_mal',
+                malCorrect_count=malCorrect_count,
+                malLabel_count=malLabel_count,
+                **metrics_dict,
+            ))
+
+            if epoch_ndx > 1:
+                writer = getattr(self, mode_str + '_writer')
+
+                for key, value in metrics_dict.items():
+                    writer.add_scalar(key, value, self.totalTrainingSamples_count)
+
+                writer.add_pr_curve(
+                    'pr',
+                    metrics_ary[METRICS_LABEL_NDX],
+                    metrics_ary[METRICS_PRED_NDX],
+                    self.totalTrainingSamples_count,
+                )
+
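+                # Exclude the most confident predictions so the histograms aren't
+                # swamped by the spikes at 0 and 1.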
+                benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
+                malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
+
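+                # 51 edges from 0.0 to 1.0: 50 bins, each 0.02 wide.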
+                bins = [x/50.0 for x in range(51)]
+                writer.add_histogram(
+                    'is_ben',
+                    metrics_ary[METRICS_PRED_NDX, benHist_mask],
+                    self.totalTrainingSamples_count,
+                    bins=bins,
+                )
+                writer.add_histogram(
+                    'is_mal',
+                    metrics_ary[METRICS_PRED_NDX, malHist_mask],
+                    self.totalTrainingSamples_count,
+                    bins=bins,
+                )
+
+
+if __name__ == '__main__':
+    sys.exit(LunaTrainingApp().main() or 0)

+ 86 - 0
p2ch10/vis.py

@@ -0,0 +1,86 @@
+import matplotlib
+matplotlib.use('nbagg')
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from p2ch11.dsets import Ct, LunaDataset
+
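+# imshow display window, in the g/cc-like units that the Ct class normalizes to.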
+clim=(0.0, 1.3)
+
+def findMalignantSamples(start_ndx=0, limit=10):
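+    # Collect up to `limit` malignant samples, scanning from start_ndx onward.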
+    ds = LunaDataset()
+
+    malignantSample_list = []
+    for sample_tup in ds.sample_list[start_ndx:]:
+        if sample_tup[2]:
+            print(len(malignantSample_list), sample_tup)
+            malignantSample_list.append(sample_tup)
+
+        if len(malignantSample_list) >= limit:
+            break
+
+    return malignantSample_list
+
+def showNodule(series_uid, batch_ndx=None, **kwargs):
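+    # Plots three orthogonal full-CT slices through the nodule center, the same
+    # three views of the cropped chunk, and a grid of axial slices through the chunk.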
+    ds = LunaDataset(series_uid=series_uid, **kwargs)
+    malignant_list = [i for i, x in enumerate(ds.sample_list) if x[2]]
+
+    if batch_ndx is None:
+        if malignant_list:
+            batch_ndx = malignant_list[0]
+        else:
+            print("Warning: no malignant samples found; using first non-malignant sample.")
+            batch_ndx = 0
+
+    ct = Ct(series_uid)
+    # ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
+    malignant_tensor, diameter_mm, series_uid, center_irc, nodule_tensor = ds[batch_ndx]
+    ct_ary = nodule_tensor[1].numpy()
+
+    fig = plt.figure(figsize=(15, 25))
+
+    group_list = [
+        #[0,1,2],
+        [3,4,5],
+        [6,7,8],
+        [9,10,11],
+        #[12,13,14],
+        #[15]
+    ]
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct_ary[ct_ary.shape[0]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct_ary[:,ct_ary.shape[1]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct_ary[:,:,ct_ary.shape[2]//2], clim=clim, cmap='gray')
+
+    for row, index_list in enumerate(group_list):
+        for col, index in enumerate(index_list):
+            subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
+            subplot.set_title('slice {}'.format(index))
+            plt.imshow(ct_ary[index*2], clim=clim, cmap='gray')
+
+    print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list, ct.vxSize_xyz)
+
+    return ct_ary

+ 0 - 148
p2ch2/dsets.py

@@ -1,148 +0,0 @@
-import csv
-import functools
-import glob
-import math
-import time
-
-import SimpleITK as sitk
-
-import numpy as np
-import torch
-import torch.cuda
-from torch.utils.data import Dataset
-
-from util.disk import getCache
-from util.util import XyzTuple, xyz2irc
-from util.logconf import logging
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-log.setLevel(logging.INFO)
-log.setLevel(logging.DEBUG)
-
-# cache = getCache('p2ch2')
-cache = getCache('part2')
-
-class Ct(object):
-    def __init__(self, series_uid):
-        mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]
-
-        ct_mhd = sitk.ReadImage(mhd_path)
-        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
-
-        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
-        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
-        # This converts HU to g/cc.
-        ct_ary += 1000
-        ct_ary /= 1000
-
-        # This gets rid of negative density stuff used to indicate out-of-FOV
-        ct_ary[ct_ary < 0] = 0
-
-        # This nukes any weird hotspots and clamps bone down
-        ct_ary[ct_ary > 2] = 2
-
-        self.series_uid = series_uid
-        self.ary = ct_ary
-        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
-        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
-        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
-
-    def getInputChunk(self, center_xyz, width_irc):
-        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
-
-        slice_list = []
-        for axis, center_val in enumerate(center_irc):
-            start_ndx = int(round(center_val - width_irc[axis]/2))
-            end_ndx = int(start_ndx + width_irc[axis])
-
-            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
-
-            if start_ndx < 0:
-                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                start_ndx = 0
-                end_ndx = int(width_irc[axis])
-
-            if end_ndx > self.ary.shape[axis]:
-                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                end_ndx = self.ary.shape[axis]
-                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
-
-            slice_list.append(slice(start_ndx, end_ndx))
-
-        ct_chunk = self.ary[slice_list]
-
-        return ct_chunk, center_irc
-
-
-@functools.lru_cache(1, typed=True)
-def getCt(series_uid):
-    return Ct(series_uid)
-
-@cache.memoize(typed=True)
-def getCtInputChunk(series_uid, center_xyz, width_irc):
-    ct = getCt(series_uid)
-    ct_chunk, center_irc = ct.getInputChunk(center_xyz, width_irc)
-    return ct_chunk, center_irc
-
-class LunaDataset(Dataset):
-    def __init__(self, test_stride=0, isTestSet_bool=None, series_uid=None):
-        # We construct a set with all series_uids that are present on disk.
-        # This will let us use the data, even if we haven't downloaded all of
-        # the subsets yet.
-        mhd_list = glob.glob('data/luna/subset*/*.mhd')
-        present_set = {p.rsplit('/', 1)[-1][:-4] for p in mhd_list}
-
-        sample_list = []
-        with open('data/luna/candidates.csv', "r") as f:
-            csv_list = list(csv.reader(f))
-
-        for row in csv_list[1:]:
-            row_uid = row[0]
-
-            if series_uid and series_uid != row_uid:
-                continue
-
-            # If a row_uid isn't present, that means it's in a subset that we
-            # don't have on disk, so we should skip it.
-            if row_uid not in present_set:
-                continue
-
-            center_xyz = tuple([float(x) for x in row[1:4]])
-            isMalignant_bool = bool(int(row[4]))
-            sample_list.append((row_uid, center_xyz, isMalignant_bool))
-
-        sample_list.sort()
-        if test_stride > 1:
-            if isTestSet_bool:
-                sample_list = sample_list[::test_stride]
-            else:
-                del sample_list[::test_stride]
-
-        log.info("{!r}: {} {} samples".format(self, len(sample_list), "testing" if isTestSet_bool else "training"))
-        self.sample_list = sample_list
-
-    def __len__(self):
-        return len(self.sample_list)
-
-    def __getitem__(self, ndx):
-        series_uid, center_xyz, isMalignant_bool = self.sample_list[ndx]
-        ct_chunk, center_irc = getCtInputChunk(series_uid, center_xyz, (16, 16, 16))
-
-        # dim=3, Index x Row x Col
-        ct_tensor = torch.from_numpy(np.array(ct_chunk, dtype=np.float32))
-
-        # dim=1
-        malignant_tensor = torch.from_numpy(np.array([isMalignant_bool], dtype=np.float32))
-
-        # dim=4, Channel x Index x Row x Col
-        ct_tensor = ct_tensor.unsqueeze(0)
-        malignant_tensor = malignant_tensor.unsqueeze(0)
-
-        # Unpacked as: input_tensor, answer_int, series_uid, center_irc
-        return ct_tensor, malignant_tensor, series_uid, center_irc
-
-
-

+ 0 - 43
p2ch2/model.py

@@ -1,43 +0,0 @@
-from torch import nn as nn
-
-from util.logconf import logging
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-# log.setLevel(logging.INFO)
-log.setLevel(logging.DEBUG)
-
-class LunaModel(nn.Module):
-    def __init__(self, layer_count, in_channels, conv_channels):
-        super().__init__()
-
-        layer_list = []
-        for layer_ndx in range(layer_count):
-            layer_list += [
-                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
-                nn.ReLU(inplace=True),
-
-                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True),
-                nn.ReLU(inplace=True),
-
-                nn.MaxPool3d(2, 2),
-            ]
-
-            in_channels = conv_channels
-            conv_channels *= 2
-
-        self.convAndPool_seq = nn.Sequential(*layer_list)
-        self.fullyConnected_layer = nn.Linear(256, 1)
-
-
-    def forward(self, x):
-        conv_out = self.convAndPool_seq(x)
-        flattened_out = conv_out.view(conv_out.size(0), -1)
-
-        try:
-            classification_out = self.fullyConnected_layer(flattened_out)
-        except:
-            log.debug(flattened_out.size())
-            raise
-
-        return classification_out

+ 0 - 192
p2ch2/training.py

@@ -1,192 +0,0 @@
-import argparse
-import sys
-
-import numpy as np
-
-import torch
-import torch.nn as nn
-from torch.autograd import Variable
-from torch.optim import SGD
-from torch.utils.data import DataLoader
-
-from util.util import enumerateWithEstimate
-from .dsets import LunaDataset
-from util.logconf import logging
-from .model import LunaModel
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-log.setLevel(logging.INFO)
-# log.setLevel(logging.DEBUG)
-
-# Used for metrics_ary index 0
-LABEL=0
-PRED=1
-LOSS=2
-# ...
-
-class LunaTrainingApp(object):
-    @classmethod
-    def __init__(self, sys_argv=None):
-        if sys_argv is None:
-            sys_argv = sys.argv[1:]
-
-        parser = argparse.ArgumentParser()
-        parser.add_argument('--batch-size',
-            help='Batch size to use for training',
-            default=256,
-            type=int,
-        )
-        parser.add_argument('--num-workers',
-            help='Number of worker processes for background data loading',
-            default=8,
-            type=int,
-        )
-        parser.add_argument('--epochs',
-            help='Number of epochs to train for',
-            default=10,
-            type=int,
-        )
-        parser.add_argument('--layers',
-            help='Number of layers to the model',
-            default=3,
-            type=int,
-        )
-        parser.add_argument('--channels',
-            help="Number of channels for the first layer's convolutions to the model (doubles each layer)",
-            default=8,
-            type=int,
-        )
-
-        self.cli_args = parser.parse_args(sys_argv)
-
-    def main(self):
-        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
-        self.train_dl = DataLoader(
-            LunaDataset(
-                test_stride=10,
-                isTestSet_bool=False,
-            ),
-            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
-            num_workers=self.cli_args.num_workers,
-            pin_memory=True,
-        )
-        self.test_dl = DataLoader(
-            LunaDataset(
-                test_stride=10,
-                isTestSet_bool=True,
-            ),
-            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
-            num_workers=self.cli_args.num_workers,
-            pin_memory=True,
-        )
-
-        self.model = LunaModel(self.cli_args.layers, 1, self.cli_args.channels)
-        self.model = nn.DataParallel(self.model)
-        self.model = self.model.cuda()
-
-        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)
-
-        for epoch_ndx in range(1, self.cli_args.epochs + 1):
-            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
-                epoch_ndx,
-                self.cli_args.epochs,
-                len(self.train_dl),
-                len(self.test_dl),
-                self.cli_args.batch_size,
-                torch.cuda.device_count(),
-            ))
-            
-            # Training loop, very similar to below
-            self.model.train()
-            batch_iter = enumerateWithEstimate(
-                self.train_dl,
-                "E{} Training".format(epoch_ndx),
-                start_ndx=self.train_dl.num_workers,
-            )
-            trainingMetrics_ary = np.zeros((3, len(self.train_dl.dataset)), dtype=np.float32)
-            for batch_ndx, batch_tup in batch_iter:
-                self.optimizer.zero_grad()
-                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, self.train_dl.batch_size, trainingMetrics_ary)
-                loss_var.backward()
-                self.optimizer.step()
-                del loss_var
-
-            # Testing loop, very similar to above, but simplified
-            # ...
-            self.model.eval()
-            batch_iter = enumerateWithEstimate(
-                self.test_dl,
-                "E{} Testing ".format(epoch_ndx),
-                start_ndx=self.test_dl.num_workers,
-            )
-            testingMetrics_ary = np.zeros((3, len(self.test_dl.dataset)), dtype=np.float32)
-            for batch_ndx, batch_tup in batch_iter:
-                self.computeBatchLoss(batch_ndx, batch_tup, self.test_dl.batch_size, testingMetrics_ary)
-
-            self.logMetrics(epoch_ndx, trainingMetrics_ary, testingMetrics_ary)
-
-
-    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_ary):
-        input_tensor, label_tensor, series_list, center_list = batch_tup
-
-        input_var = Variable(input_tensor.cuda())
-        label_var = Variable(label_tensor.cuda())
-        prediction_var = self.model(input_var)
-        # ...
-
-        start_ndx = batch_ndx * batch_size
-        end_ndx = start_ndx + label_tensor.size(0)
-        metrics_ary[LABEL, start_ndx:end_ndx] = label_tensor.numpy()[:,0,0]
-        metrics_ary[PRED,  start_ndx:end_ndx] = prediction_var.data.cpu().numpy()[:,0]
-
-        for sample_ndx in range(label_tensor.size(0)):
-            subloss_var = nn.MSELoss()(prediction_var[sample_ndx], label_var[sample_ndx])
-            metrics_ary[LOSS, start_ndx+sample_ndx] = subloss_var.data[0]
-            del subloss_var
-
-        loss_var = nn.MSELoss()(prediction_var, label_var)
-        return loss_var
-
-
-    def logMetrics(self, epoch_ndx, trainingMetrics_ary, testingMetrics_ary):
-        log.info("E{} {}".format(
-            epoch_ndx,
-            type(self).__name__,
-        ))
-
-        for mode_str, metrics_ary in [('trn', trainingMetrics_ary), ('tst', testingMetrics_ary)]:
-            pos_mask = metrics_ary[LABEL] > 0.5
-            neg_mask = ~pos_mask
-
-            truePos_count = (metrics_ary[PRED, pos_mask] > 0.5).sum()
-            trueNeg_count = (metrics_ary[PRED, neg_mask] < 0.5).sum()
-
-            metrics_dict = {}
-            metrics_dict['loss/all'] = metrics_ary[LOSS].mean()
-            metrics_dict['loss/ben'] = metrics_ary[LOSS, neg_mask].mean()
-            metrics_dict['loss/mal'] = metrics_ary[LOSS, pos_mask].mean()
-
-            metrics_dict['correct/all'] = (truePos_count + trueNeg_count) / metrics_ary.shape[1] * 100
-            metrics_dict['correct/ben'] = (trueNeg_count) / neg_mask.sum() * 100
-            metrics_dict['correct/mal'] = (truePos_count) / pos_mask.sum() * 100
-
-            log.info("E{} {:8} {loss/all:.4f} loss, {correct/all:-5.1f}% correct".format(
-                epoch_ndx,
-                mode_str,
-                **metrics_dict,
-            ))
-            log.info("E{} {:8} {loss/ben:.4f} loss, {correct/ben:-5.1f}% correct".format(
-                epoch_ndx,
-                mode_str + '_ben',
-                **metrics_dict,
-            ))
-            log.info("E{} {:8} {loss/mal:.4f} loss, {correct/mal:-5.1f}% correct".format(
-                epoch_ndx,
-                mode_str + '_mal',
-                **metrics_dict,
-            ))
-
-
-if __name__ == '__main__':
-    sys.exit(LunaTrainingApp().main() or 0)

+ 0 - 184
p2ch3/dsets.py

@@ -1,184 +0,0 @@
-import csv
-import functools
-import glob
-import itertools
-import math
-import random
-import time
-
-import scipy.ndimage
-import SimpleITK as sitk
-
-import numpy as np
-import torch
-import torch.cuda
-from torch.utils.data import Dataset
-
-from util.disk import getCache
-from util.util import XyzTuple, xyz2irc
-from util.logconf import logging
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-log.setLevel(logging.INFO)
-log.setLevel(logging.DEBUG)
-
-# cache = getCache('p2ch3')
-cache = getCache('part2')
-
-class Ct(object):
-    def __init__(self, series_uid):
-        mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]
-
-        ct_mhd = sitk.ReadImage(mhd_path)
-        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
-
-        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
-        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
-        # This converts HU to g/cc.
-        ct_ary += 1000
-        ct_ary /= 1000
-
-        # This gets rid of negative density stuff used to indicate out-of-FOV
-        ct_ary[ct_ary < 0] = 0
-
-        # This nukes any weird hotspots and clamps bone down
-        ct_ary[ct_ary > 2] = 2
-
-        self.series_uid = series_uid
-        self.ary = ct_ary
-        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
-        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
-        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
-
-    def getInputChunk(self, center_xyz, width_irc):
-        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
-
-        slice_list = []
-        for axis, center_val in enumerate(center_irc):
-            start_ndx = int(round(center_val - width_irc[axis]/2))
-            end_ndx = int(start_ndx + width_irc[axis])
-
-            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
-
-            if start_ndx < 0:
-                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                start_ndx = 0
-                end_ndx = int(width_irc[axis])
-
-            if end_ndx > self.ary.shape[axis]:
-                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                end_ndx = self.ary.shape[axis]
-                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
-
-            slice_list.append(slice(start_ndx, end_ndx))
-
-        ct_chunk = self.ary[slice_list]
-
-        return ct_chunk, center_irc
-
-
-@functools.lru_cache(1, typed=True)
-def getCt(series_uid):
-    return Ct(series_uid)
-
-@cache.memoize(typed=True)
-def getCtInputChunk(series_uid, center_xyz, width_irc):
-    ct = getCt(series_uid)
-    ct_chunk, center_irc = ct.getInputChunk(center_xyz, width_irc)
-    return ct_chunk, center_irc
-
-class LunaDataset(Dataset):
-    def __init__(self, test_stride=0, isTestSet_bool=None, series_uid=None,
-                 balanced_bool=False,
-                 ):
-        self.balanced_bool = balanced_bool
-
-        # We construct a set with all series_uids that are present on disk.
-        # This will let us use the data, even if we haven't downloaded all of
-        # the subsets yet.
-        mhd_list = glob.glob('data/luna/subset*/*.mhd')
-        present_set = {p.rsplit('/', 1)[-1][:-4] for p in mhd_list}
-
-        sample_list = []
-        with open('data/luna/candidates.csv', "r") as f:
-            csv_list = list(csv.reader(f))
-
-        for row in csv_list[1:]:
-            row_uid = row[0]
-
-            if series_uid and series_uid != row_uid:
-                continue
-
-            # If a row_uid isn't present, that means it's in a subset that we
-            # don't have on disk, so we should skip it.
-            if row_uid not in present_set:
-                continue
-
-            center_xyz = tuple([float(x) for x in row[1:4]])
-            isMalignant_bool = bool(int(row[4]))
-            sample_list.append((row_uid, center_xyz, isMalignant_bool))
-
-        sample_list.sort()
-        if test_stride > 1:
-            if isTestSet_bool:
-                sample_list = sample_list[::test_stride]
-            else:
-                del sample_list[::test_stride]
-
-        self.sample_list = sample_list
-        self.benignIndex_list = [i for i, x in enumerate(sample_list) if not x[2]]
-        self.malignantIndex_list = [i for i, x in enumerate(sample_list) if x[2]]
-
-        self.shuffleSamples()
-
-        log.info("{!r}: {} {} samples, {} ben, {} mal".format(
-            self,
-            len(sample_list),
-            "testing" if isTestSet_bool else "training",
-            len(self.benignIndex_list),
-            len(self.malignantIndex_list),
-        ))
-
-
-    def shuffleSamples(self):
-        if self.balanced_bool:
-            log.warning("Shufflin'")
-            random.shuffle(self.benignIndex_list)
-            random.shuffle(self.malignantIndex_list)
-
-    def __len__(self):
-        if self.balanced_bool:
-            return min(len(self.benignIndex_list), len(self.malignantIndex_list)) * 2 * 50
-        else:
-            return len(self.sample_list)
-
-    def __getitem__(self, ndx):
-        if self.balanced_bool:
-            if ndx % 2:
-                sample_ndx = self.benignIndex_list[(ndx // 2) % len(self.benignIndex_list)]
-            else:
-                sample_ndx = self.malignantIndex_list[(ndx // 2) % len(self.malignantIndex_list)]
-        else:
-            sample_ndx = ndx
-
-        series_uid, center_xyz, isMalignant_bool = self.sample_list[sample_ndx]
-        ct_chunk, center_irc = getCtInputChunk(series_uid, center_xyz, (16, 16, 16))
-
-        # dim=3, Index x Row x Col
-        ct_tensor = torch.from_numpy(np.array(ct_chunk, dtype=np.float32))
-
-        # dim=1
-        malignant_tensor = torch.from_numpy(np.array([isMalignant_bool], dtype=np.float32))
-
-        # dim=4, Channel x Index x Row x Col
-        ct_tensor = ct_tensor.unsqueeze(0)
-        malignant_tensor = malignant_tensor.unsqueeze(0)
-
-        # Unpacked as: input_tensor, answer_int, series_uid, center_irc
-        return ct_tensor, malignant_tensor, series_uid, center_irc
-
-
-

+ 0 - 45
p2ch3/model.py

@@ -1,45 +0,0 @@
-from torch import nn as nn
-
-from util.logconf import logging
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-# log.setLevel(logging.INFO)
-log.setLevel(logging.DEBUG)
-
-class LunaModel(nn.Module):
-    def __init__(self, layer_count, in_channels, conv_channels):
-        super().__init__()
-
-        layer_list = []
-        for layer_ndx in range(layer_count):
-            layer_list += [
-                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
-                # nn.BatchNorm3d(conv_channels),
-                nn.ReLU(inplace=True),
-
-                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True),
-                # nn.BatchNorm3d(conv_channels),
-                nn.ReLU(inplace=True),
-
-                nn.MaxPool3d(2, 2),
-            ]
-
-            in_channels = conv_channels
-            conv_channels *= 2
-
-        self.convAndPool_seq = nn.Sequential(*layer_list)
-        self.fullyConnected_layer = nn.Linear(256, 1)
-
-
-    def forward(self, x):
-        conv_out = self.convAndPool_seq(x)
-        flattened_out = conv_out.view(conv_out.size(0), -1)
-
-        try:
-            classification_out = self.fullyConnected_layer(flattened_out)
-        except:
-            log.debug(flattened_out.size())
-            raise
-
-        return classification_out

+ 0 - 241
p2ch3/training.py

@@ -1,241 +0,0 @@
-import argparse
-import datetime
-import os
-import sys
-
-import numpy as np
-from tensorboardX import SummaryWriter
-
-import torch
-import torch.nn as nn
-from torch.autograd import Variable
-from torch.optim import SGD
-from torch.utils.data import DataLoader
-
-from util.util import enumerateWithEstimate
-from .dsets import LunaDataset
-from util.logconf import logging
-from .model import LunaModel
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-log.setLevel(logging.INFO)
-# log.setLevel(logging.DEBUG)
-
-# Used for metrics_ary index 0
-LABEL=0
-PRED=1
-LOSS=2
-# ...
-
-class LunaTrainingApp(object):
-    @classmethod
-    def __init__(self, sys_argv=None):
-        if sys_argv is None:
-            sys_argv = sys.argv[1:]
-
-        parser = argparse.ArgumentParser()
-        parser.add_argument('--batch-size',
-            help='Batch size to use for training',
-            default=256,
-            type=int,
-        )
-        parser.add_argument('--num-workers',
-            help='Number of worker processes for background data loading',
-            default=8,
-            type=int,
-        )
-        parser.add_argument('--epochs',
-            help='Number of epochs to train for',
-            default=10,
-            type=int,
-        )
-        parser.add_argument('--layers',
-            help='Number of layers to the model',
-            default=3,
-            type=int,
-        )
-        parser.add_argument('--channels',
-            help="Number of channels for the first layer's convolutions to the model (doubles each layer)",
-            default=8,
-            type=int,
-        )
-        parser.add_argument('--balanced',
-            help="Balance the training data to half benign, half malignant.",
-            action='store_true',
-            default=False,
-        )
-
-        parser.add_argument('--tb-prefix',
-            help="Data prefix to use for Tensorboard. Defaults to chapter.",
-            default='p2ch3',
-        )
-
-        self.cli_args = parser.parse_args(sys_argv)
-
-    def main(self):
-        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
-        self.train_dl = DataLoader(
-            LunaDataset(
-                test_stride=10,
-                isTestSet_bool=False,
-                balanced_bool=self.cli_args.balanced,
-            ),
-            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
-            num_workers=self.cli_args.num_workers,
-            pin_memory=True,
-        )
-        self.test_dl = DataLoader(
-            LunaDataset(
-                test_stride=10,
-                isTestSet_bool=True,
-            ),
-            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
-            num_workers=self.cli_args.num_workers,
-            pin_memory=True,
-        )
-
-        self.model = LunaModel(self.cli_args.layers, 1, self.cli_args.channels)
-        self.model = nn.DataParallel(self.model)
-        self.model = self.model.cuda()
-
-        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)
-
-        time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
-        log_dir = os.path.join('runs', self.cli_args.tb_prefix, time_str)
-        self.trn_writer = SummaryWriter(log_dir=log_dir + '_train')
-        self.tst_writer = SummaryWriter(log_dir=log_dir + '_test')
-
-        for epoch_ndx in range(1, self.cli_args.epochs + 1):
-            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
-                epoch_ndx,
-                self.cli_args.epochs,
-                len(self.train_dl),
-                len(self.test_dl),
-                self.cli_args.batch_size,
-                torch.cuda.device_count(),
-            ))
-            
-            # Training loop, very similar to below
-            self.model.train()
-            self.train_dl.dataset.shuffleSamples()
-            batch_iter = enumerateWithEstimate(
-                self.train_dl,
-                "E{} Training".format(epoch_ndx),
-                start_ndx=self.train_dl.num_workers,
-            )
-            trainingMetrics_ary = np.zeros((3, len(self.train_dl.dataset)), dtype=np.float32)
-            for batch_ndx, batch_tup in batch_iter:
-                self.optimizer.zero_grad()
-                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, self.train_dl.batch_size, trainingMetrics_ary)
-                loss_var.backward()
-                self.optimizer.step()
-                del loss_var
-
-            # Testing loop, very similar to above, but simplified
-            # ...
-            self.model.eval()
-            self.test_dl.dataset.shuffleSamples()
-            batch_iter = enumerateWithEstimate(
-                self.test_dl,
-                "E{} Testing ".format(epoch_ndx),
-                start_ndx=self.test_dl.num_workers,
-            )
-            testingMetrics_ary = np.zeros((3, len(self.test_dl.dataset)), dtype=np.float32)
-            for batch_ndx, batch_tup in batch_iter:
-                self.computeBatchLoss(batch_ndx, batch_tup, self.test_dl.batch_size, testingMetrics_ary)
-
-            self.logMetrics(epoch_ndx, trainingMetrics_ary, testingMetrics_ary)
-
-        self.trn_writer.close()
-        self.tst_writer.close()
-
-    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_ary):
-        input_tensor, label_tensor, series_list, center_list = batch_tup
-
-        input_var = Variable(input_tensor.cuda())
-        label_var = Variable(label_tensor.cuda())
-        prediction_var = self.model(input_var)
-        # ...
-
-        start_ndx = batch_ndx * batch_size
-        end_ndx = start_ndx + label_tensor.size(0)
-        metrics_ary[LABEL, start_ndx:end_ndx] = label_tensor.numpy()[:,0,0]
-        metrics_ary[PRED,  start_ndx:end_ndx] = prediction_var.data.cpu().numpy()[:,0]
-
-        for sample_ndx in range(label_tensor.size(0)):
-            subloss_var = nn.MSELoss()(prediction_var[sample_ndx], label_var[sample_ndx])
-            metrics_ary[LOSS, start_ndx+sample_ndx] = subloss_var.data[0]
-            del subloss_var
-
-        loss_var = nn.MSELoss()(prediction_var, label_var)
-        return loss_var
-
-
-    def logMetrics(self, epoch_ndx, trainingMetrics_ary, testingMetrics_ary):
-        log.info("E{} {}".format(
-            epoch_ndx,
-            type(self).__name__,
-        ))
-
-        for mode_str, metrics_ary in [('trn', trainingMetrics_ary), ('tst', testingMetrics_ary)]:
-            pos_mask = metrics_ary[LABEL] > 0.5
-            neg_mask = ~pos_mask
-
-            truePos_count = (metrics_ary[PRED, pos_mask] > 0.5).sum()
-            trueNeg_count = (metrics_ary[PRED, neg_mask] < 0.5).sum()
-            falseNeg_count = pos_mask.sum() - truePos_count
-            falsePos_count = neg_mask.sum() - trueNeg_count
-
-            metrics_dict = {}
-            metrics_dict['pr/precision'] = p = truePos_count / (truePos_count + falsePos_count)
-            metrics_dict['pr/recall'] = r = truePos_count / (truePos_count + falseNeg_count)
-
-            # https://en.wikipedia.org/wiki/F1_score
-            for n in [0.5, 1, 2]:
-                metrics_dict['pr/f{}_score'.format(n)] = \
-                    (1 + n**2) * (p * r / (n**2 * p + r))
-
-            metrics_dict['loss/all'] = metrics_ary[LOSS].mean()
-            metrics_dict['loss/ben'] = metrics_ary[LOSS, neg_mask].mean()
-            metrics_dict['loss/mal'] = metrics_ary[LOSS, pos_mask].mean()
-
-            metrics_dict['correct/all'] = (truePos_count + trueNeg_count) / metrics_ary.shape[1] * 100
-            metrics_dict['correct/ben'] = (trueNeg_count) / neg_mask.sum() * 100
-            metrics_dict['correct/mal'] = (truePos_count) / pos_mask.sum() * 100
-
-            log.info(("E{} {:8} "
-                     + "{loss/all:.4f} loss, "
-                     + "{correct/all:-5.1f}% correct, "
-                     + "{pr/precision:.4f} precision, "
-                     + "{pr/recall:.4f} recall").format(
-                epoch_ndx,
-                mode_str,
-                **metrics_dict,
-            ))
-            log.info(("E{} {:8} "
-                     + "{loss/ben:.4f} loss, "
-                     + "{correct/ben:-5.1f}% correct").format(
-                epoch_ndx,
-                mode_str + '_ben',
-                **metrics_dict,
-            ))
-            log.info(("E{} {:8} "
-                     + "{loss/mal:.4f} loss, "
-                     + "{correct/mal:-5.1f}% correct").format(
-                epoch_ndx,
-                mode_str + '_mal',
-                **metrics_dict,
-            ))
-
-            writer = getattr(self, mode_str + '_writer')
-            tb_ndx = epoch_ndx * trainingMetrics_ary.shape[1]
-            for key, value in metrics_dict.items():
-                writer.add_scalar(key, value, tb_ndx)
-            writer.add_pr_curve('pr', metrics_ary[LABEL], metrics_ary[PRED], tb_ndx)
-            writer.add_histogram('is_mal', metrics_ary[PRED, pos_mask], tb_ndx)
-            writer.add_histogram('is_ben', metrics_ary[PRED, neg_mask], tb_ndx)
-
-
-if __name__ == '__main__':
-    sys.exit(LunaTrainingApp().main() or 0)

+ 0 - 0
p2ch4/__init__.py


+ 0 - 320
p2ch4/dsets.py

@@ -1,320 +0,0 @@
-import csv
-import functools
-import glob
-import itertools
-import math
-import random
-import time
-import warnings
-
-import scipy.ndimage
-import SimpleITK as sitk
-
-import numpy as np
-import torch
-import torch.cuda
-from torch.utils.data import Dataset
-
-from util.disk import getCache
-from util.util import XyzTuple, xyz2irc
-from util.logconf import logging
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-log.setLevel(logging.INFO)
-log.setLevel(logging.DEBUG)
-
-cache = getCache('part2')
-
-class Ct(object):
-    def __init__(self, series_uid):
-        mhd_path = glob.glob('data/luna/subset*/{}.mhd'.format(series_uid))[0]
-
-        ct_mhd = sitk.ReadImage(mhd_path)
-        ct_ary = np.array(sitk.GetArrayFromImage(ct_mhd), dtype=np.float32)
-
-        # CTs are natively expressed in https://en.wikipedia.org/wiki/Hounsfield_scale
-        # HU are scaled oddly, with 0 g/cc (air, approximately) being -1000 and 1 g/cc (water) being 0.
-        # This converts HU to g/cc.
-        ct_ary += 1000
-        ct_ary /= 1000
-
-        # This gets rid of negative density stuff used to indicate out-of-FOV
-        ct_ary[ct_ary < 0] = 0
-
-        # This nukes any weird hotspots and clamps bone down
-        ct_ary[ct_ary > 2] = 2
-
-        self.series_uid = series_uid
-        self.ary = ct_ary
-        self.origin_xyz = XyzTuple(*ct_mhd.GetOrigin())
-        self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
-        self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
-
-    def getInputChunk(self, center_xyz, width_irc):
-        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
-
-        slice_list = []
-        for axis, center_val in enumerate(center_irc):
-            start_ndx = int(round(center_val - width_irc[axis]/2))
-            end_ndx = int(start_ndx + width_irc[axis])
-
-            assert center_val >= 0 and center_val < self.ary.shape[axis], repr([self.series_uid, center_xyz, self.origin_xyz, self.vxSize_xyz, center_irc, axis])
-
-            if start_ndx < 0:
-                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                start_ndx = 0
-                end_ndx = int(width_irc[axis])
-
-            if end_ndx > self.ary.shape[axis]:
-                # log.warning("Crop outside of CT array: {} {}, center:{} shape:{} width:{}".format(
-                #     self.series_uid, center_xyz, center_irc, self.ary.shape, width_irc))
-                end_ndx = self.ary.shape[axis]
-                start_ndx = int(self.ary.shape[axis] - width_irc[axis])
-
-            slice_list.append(slice(start_ndx, end_ndx))
-
-        ct_chunk = self.ary[slice_list]
-
-        return ct_chunk, center_irc
-
-    def getScaledInputChunk(self, center_xyz, width_mm, voxels_int):
-        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
-
-        ct_start = [int(round(i)) for i in xyz2irc(tuple(x - width_mm/2 for x in center_xyz), self.origin_xyz, self.vxSize_xyz, self.direction_tup)]
-        ct_end = [int(round(i)) + 1 for i in xyz2irc(tuple(x + width_mm/2 for x in center_xyz), self.origin_xyz, self.vxSize_xyz, self.direction_tup)]
-
-        for axis in range(3):
-            if ct_start[axis] > ct_end[axis]:
-                ct_start[axis], ct_end[axis] = ct_end[axis], ct_start[axis]
-
-        pad_start = [0, 0, 0]
-        pad_end = [ct_end[axis] - ct_start[axis] for axis in range(3)]
-        # log.info([ct_end, ct_start, pad_end])
-        pad_ary = np.zeros(pad_end, dtype=np.float32)
-
-        for axis in range(3):
-            if ct_start[axis] < 0:
-                pad_start[axis] = -ct_start[axis]
-                ct_start[axis] = 0
-
-            if ct_end[axis] > self.ary.shape[axis]:
-                pad_end[axis] -= ct_end[axis] - self.ary.shape[axis]
-                ct_end[axis] = self.ary.shape[axis]
-
-        pad_slices = tuple(slice(s,e) for s, e in zip(pad_start, pad_end))
-        ct_slices = tuple(slice(s,e) for s, e in zip(ct_start, ct_end))
-        pad_ary[pad_slices] = self.ary[ct_slices]
-
-        try:
-            zoom_seq = tuple(voxels_int/pad_ary.shape[axis] for axis in range(3))
-        except:
-            log.error([ct_end, ct_start, pad_end, center_irc, center_xyz, width_mm, self.vxSize_xyz])
-            raise
-
-        chunk_ary = scipy.ndimage.zoom(pad_ary, zoom_seq, order=1)
-
-        # log.info("chunk_ary.shape {}".format([chunk_ary.shape, pad_ary.shape, zoom_seq, voxels_int]))
-
-        return chunk_ary, center_irc
-
-
-@functools.lru_cache(1, typed=True)
-def getCt(series_uid):
-    return Ct(series_uid)
-
-@cache.memoize(typed=True)
-def getCtInputChunk(series_uid, center_xyz, width_irc):
-    ct = getCt(series_uid)
-    ct_chunk, center_irc = ct.getInputChunk(center_xyz, width_irc)
-    return ct_chunk, center_irc
-
-@cache.memoize(typed=True)
-def getScaledCtInputChunk(series_uid, center_xyz, width_mm, voxels_int):
-    # log.info([series_uid, center_xyz, width_mm, voxels_int])
-    ct = getCt(series_uid)
-    ct_chunk, center_irc = ct.getScaledInputChunk(center_xyz, width_mm, voxels_int)
-    return ct_chunk, center_irc
-
-
-def augmentChunk_shift(ct_chunk):
-
-    for axis in range(1,3):
-        new_chunk = np.zeros_like(ct_chunk)
-        shift = random.randint(0, 2)
-
-        slice_list = [slice(None)] * ct_chunk.ndim
-
-        new_chunk
-
-
-    return ct_chunk + np.random.normal(scale=0.1, size=ct_chunk.shape)
-
-def augmentChunk_noise(ct_chunk):
-    return ct_chunk + np.random.normal(scale=0.1, size=ct_chunk.shape)
-
-def augmentChunk_mirror(ct_chunk):
-    if random.random() > 0.5:
-        ct_chunk = np.flip(ct_chunk, -1)
-    return ct_chunk
-
-def augmentChunk_rotate(ct_chunk):
-    # Rotate the nodule around the head-foot axis
-    angle = 360 * random.random()
-    # https://docs.scipy.org/doc/scipy-0.16.1/reference/generated/scipy.ndimage.interpolation.rotate.html
-    ct_chunk = scipy.ndimage.interpolation.rotate(
-        ct_chunk,
-        angle,
-        axes=(-2, -1),
-        reshape=False,
-        order=1,
-    )
-    return ct_chunk
-
-def augmentChunk_zoomAndCrop(ct_chunk):
-    # log.info([ct_chunk.shape])
-    zoom = 1.0 + 0.2 * random.random()
-
-    with warnings.catch_warnings():
-        warnings.simplefilter("ignore")
-        # https://docs.scipy.org/doc/scipy-0.16.1/reference/generated/scipy.ndimage.interpolation.zoom.html
-        ct_chunk = scipy.ndimage.interpolation.zoom(
-            ct_chunk,
-            zoom,
-            order=1
-        )
-
-    crop_list = [random.randint(0, ct_chunk.shape[axis]-16) for axis in range(1,4)]
-    slice_list = [slice(None)] + [slice(start, start+16) for start in crop_list]
-
-    ct_chunk = ct_chunk[slice_list]
-
-    assert ct_chunk.shape[-3:] == (16, 16, 16), repr(ct_chunk.shape)
-
-    return ct_chunk
-
-def augmentCtInputChunk(ct_chunk):
-    augment_list = [
-        augmentChunk_mirror,
-        augmentChunk_rotate,
-        augmentChunk_noise,
-        augmentChunk_zoomAndCrop,
-    ]
-
-    for augment_func in augment_list:
-        ct_chunk = augment_func(ct_chunk)
-
-    return ct_chunk
-
-
-class LunaDataset(Dataset):
-    def __init__(self, test_stride=0, isTestSet_bool=None, series_uid=None,
-                 balanced_bool=False,
-                 scaled_bool=False,
-                 augmented_bool=False,
-                 ):
-        self.balanced_bool = balanced_bool
-        self.scaled_bool = scaled_bool
-        self.augmented_bool = augmented_bool
-
-        # We construct a set with all series_uids that are present on disk.
-        # This will let us use the data, even if we haven't downloaded all of
-        # the subsets yet.
-        mhd_list = glob.glob('data/luna/subset*/*.mhd')
-        present_set = {p.rsplit('/', 1)[-1][:-4] for p in mhd_list}
-
-        sample_list = []
-        with open('data/luna/candidates.csv', "r") as f:
-            csv_list = list(csv.reader(f))
-
-        for row in csv_list[1:]:
-            row_uid = row[0]
-
-            if series_uid and series_uid != row_uid:
-                continue
-
-            # If a row_uid isn't present, that means it's in a subset that we
-            # don't have on disk, so we should skip it.
-            if row_uid not in present_set:
-                continue
-
-            center_xyz = tuple([float(x) for x in row[1:4]])
-            isMalignant_bool = bool(int(row[4]))
-            sample_list.append((row_uid, center_xyz, isMalignant_bool))
-
-        sample_list.sort()
-        if test_stride > 1:
-            if isTestSet_bool:
-                sample_list = sample_list[::test_stride]
-            else:
-                del sample_list[::test_stride]
-
-        self.sample_list = sample_list
-        self.benignIndex_list = [i for i, x in enumerate(sample_list) if not x[2]]
-        self.malignantIndex_list = [i for i, x in enumerate(sample_list) if x[2]]
-
-        self.shuffleSamples()
-
-        log.info("{!r}: {} {} samples, {} ben, {} mal".format(
-            self,
-            len(sample_list),
-            "testing" if isTestSet_bool else "training",
-            len(self.benignIndex_list),
-            len(self.malignantIndex_list),
-        ))
-
-
-    def shuffleSamples(self):
-        if self.balanced_bool:
-            log.warning("Shufflin'")
-            random.shuffle(self.benignIndex_list)
-            random.shuffle(self.malignantIndex_list)
-
-    def __len__(self):
-        if self.balanced_bool:
-            return min(len(self.benignIndex_list), len(self.malignantIndex_list)) * 2 * 50
-        else:
-            return len(self.sample_list)
-
-    def __getitem__(self, ndx):
-        if self.balanced_bool:
-            if ndx % 2:
-                sample_ndx = self.benignIndex_list[(ndx // 2) % len(self.benignIndex_list)]
-            else:
-                sample_ndx = self.malignantIndex_list[(ndx // 2) % len(self.malignantIndex_list)]
-        else:
-            sample_ndx = ndx
-
-        series_uid, center_xyz, isMalignant_bool = self.sample_list[sample_ndx]
-
-        if self.scaled_bool:
-            ct_chunk, center_irc = getScaledCtInputChunk(series_uid, center_xyz, 12, 20)
-            # in:  dim=3, Index x Row x Col
-            # out: dim=4, Channel x Index x Row x Col
-            ct_chunk = np.expand_dims(ct_chunk, 0)
-
-            if self.augmented_bool:
-                ct_chunk = augmentCtInputChunk(ct_chunk)
-            else:
-                ct_chunk = ct_chunk[:, 2:-2, 2:-2, 2:-2]
-
-        else:
-            ct_chunk, center_irc = getCtInputChunk(series_uid, center_xyz, (16, 16, 16))
-            ct_chunk = np.expand_dims(ct_chunk, 0)
-
-        assert ct_chunk.shape[-3:] == (16, 16, 16), repr(ct_chunk.shape)
-
-
-        ct_tensor = torch.from_numpy(np.array(ct_chunk, dtype=np.float32))
-        # ct_tensor = ct_tensor.unsqueeze(0)
-
-        # dim=1
-        malignant_tensor = torch.from_numpy(np.array([isMalignant_bool], dtype=np.float32))
-        malignant_tensor = malignant_tensor.unsqueeze(0)
-
-        # Unpacked as: input_tensor, answer_int, series_uid, center_irc
-        return ct_tensor, malignant_tensor, series_uid, center_irc
-
-
-

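A minimal sketch (not part of this commit) of the class-balancing scheme the removed __getitem__ implements: even dataset indices draw from the malignant pool, odd indices from the benign pool, and the modulo wraps whichever pool is smaller so its samples repeat. The two lists below are illustrative stand-ins for benignIndex_list and malignantIndex_list.

    # Illustrative index pools; in the dataset these are shuffled each epoch.
    benign_list = [0, 1, 2, 3, 4, 5]
    malignant_list = [6, 7]  # minority class

    def balanced_index(ndx):
        # Odd indices -> benign, even -> malignant; modulo wraps each pool.
        if ndx % 2:
            return benign_list[(ndx // 2) % len(benign_list)]
        return malignant_list[(ndx // 2) % len(malignant_list)]

    print([balanced_index(i) for i in range(8)])  # [6, 0, 7, 1, 6, 2, 7, 3]

Combined with the inflated __len__ (min pool size * 2 * 50), this lets one epoch cycle through many balanced pairs even when malignant samples are rare.
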
+ 0 - 45
p2ch4/model.py

@@ -1,45 +0,0 @@
-from torch import nn as nn
-
-from util.logconf import logging
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-# log.setLevel(logging.INFO)
-log.setLevel(logging.DEBUG)
-
-class LunaModel(nn.Module):
-    def __init__(self, layer_count, in_channels, conv_channels):
-        super().__init__()
-
-        layer_list = []
-        for layer_ndx in range(layer_count):
-            layer_list += [
-                nn.Conv3d(in_channels, conv_channels, kernel_size=3, padding=1, bias=True),
-                # nn.BatchNorm3d(conv_channels),
-                nn.ReLU(inplace=True),
-
-                nn.Conv3d(conv_channels, conv_channels, kernel_size=3, padding=1, bias=True),
-                # nn.BatchNorm3d(conv_channels),
-                nn.ReLU(inplace=True),
-
-                nn.MaxPool3d(2, 2),
-            ]
-
-            in_channels = conv_channels
-            conv_channels *= 2
-
-        self.convAndPool_seq = nn.Sequential(*layer_list)
-        self.fullyConnected_layer = nn.Linear(256, 1)
-
-
-    def forward(self, x):
-        conv_out = self.convAndPool_seq(x)
-        flattened_out = conv_out.view(conv_out.size(0), -1)
-
-        try:
-            classification_out = self.fullyConnected_layer(flattened_out)
-        except:
-            log.debug(flattened_out.size())
-            raise
-
-        return classification_out

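For reference, a back-of-envelope sketch (assuming the 16^3 input chunks asserted in the removed dataset code above, plus the training defaults below of --layers 3 and --channels 8) of why this removed model can hard-code nn.Linear(256, 1): each block halves the spatial extent with MaxPool3d(2, 2) while doubling conv_channels.

    side, out_channels = 16, 8       # 16^3 chunk, 8 first-layer channels
    for _ in range(3):               # three conv/conv/pool blocks
        side //= 2                   # MaxPool3d(2, 2) halves I, R, C
        flat = out_channels * side**3
        out_channels *= 2            # conv_channels doubles per block
    print(flat)                      # 32 * 2**3 == 256 -> nn.Linear(256, 1)

Any other combination of layers, channels, or chunk size would need a different hard-coded feature count, which is why forward logs the flattened size before re-raising on a mismatch.
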
+ 0 - 255
p2ch4/training.py

@@ -1,255 +0,0 @@
-import argparse
-import datetime
-import os
-import sys
-
-import numpy as np
-from tensorboardX import SummaryWriter
-
-import torch
-import torch.nn as nn
-from torch.autograd import Variable
-from torch.optim import SGD
-from torch.utils.data import DataLoader
-
-from util.util import enumerateWithEstimate
-from .dsets import LunaDataset
-from util.logconf import logging
-from .model import LunaModel
-
-log = logging.getLogger(__name__)
-# log.setLevel(logging.WARN)
-log.setLevel(logging.INFO)
-# log.setLevel(logging.DEBUG)
-
-# Used for metrics_ary index 0
-LABEL=0
-PRED=1
-LOSS=2
-# ...
-
-class LunaTrainingApp(object):
-    def __init__(self, sys_argv=None):
-        if sys_argv is None:
-            sys_argv = sys.argv[1:]
-
-        parser = argparse.ArgumentParser()
-        parser.add_argument('--batch-size',
-            help='Batch size to use for training',
-            default=256,
-            type=int,
-        )
-        parser.add_argument('--num-workers',
-            help='Number of worker processes for background data loading',
-            default=8,
-            type=int,
-        )
-        parser.add_argument('--epochs',
-            help='Number of epochs to train for',
-            default=10,
-            type=int,
-        )
-        parser.add_argument('--layers',
-            help='Number of layers to the model',
-            default=3,
-            type=int,
-        )
-        parser.add_argument('--channels',
-            help="Number of channels for the first layer's convolutions to the model (doubles each layer)",
-            default=8,
-            type=int,
-        )
-        parser.add_argument('--balanced',
-            help="Balance the training data to half benign, half malignant.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--scaled',
-            help="Scale the CT chunks to square voxels.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--augmented',
-            help="Augment the training data (implies --scaled).",
-            action='store_true',
-            default=False,
-        )
-
-        parser.add_argument('--tb-prefix',
-            help="Data prefix to use for Tensorboard. Defaults to chapter.",
-            default='p2ch4',
-        )
-
-        self.cli_args = parser.parse_args(sys_argv)
-
-    def main(self):
-        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
-        self.train_dl = DataLoader(
-            LunaDataset(
-                test_stride=10,
-                isTestSet_bool=False,
-                balanced_bool=self.cli_args.balanced,
-                scaled_bool=self.cli_args.scaled or self.cli_args.augmented,
-                augmented_bool=self.cli_args.augmented,
-            ),
-            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
-            num_workers=self.cli_args.num_workers,
-            pin_memory=True,
-        )
-        self.test_dl = DataLoader(
-            LunaDataset(
-                test_stride=10,
-                isTestSet_bool=True,
-                scaled_bool=self.cli_args.scaled or self.cli_args.augmented,
-                # augmented_bool=self.cli_args.augmented,
-            ),
-            batch_size=self.cli_args.batch_size * torch.cuda.device_count(),
-            num_workers=self.cli_args.num_workers,
-            pin_memory=True,
-        )
-
-        self.model = LunaModel(self.cli_args.layers, 1, self.cli_args.channels)
-        self.model = nn.DataParallel(self.model)
-        self.model = self.model.cuda()
-
-        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)
-
-        time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
-        log_dir = os.path.join('runs', self.cli_args.tb_prefix, time_str)
-        self.trn_writer = SummaryWriter(log_dir=log_dir + '_train')
-        self.tst_writer = SummaryWriter(log_dir=log_dir + '_test')
-
-        for epoch_ndx in range(1, self.cli_args.epochs + 1):
-            log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
-                epoch_ndx,
-                self.cli_args.epochs,
-                len(self.train_dl),
-                len(self.test_dl),
-                self.cli_args.batch_size,
-                torch.cuda.device_count(),
-            ))
-
-            # Training loop, very similar to below
-            self.model.train()
-            self.train_dl.dataset.shuffleSamples()
-            batch_iter = enumerateWithEstimate(
-                self.train_dl,
-                "E{} Training".format(epoch_ndx),
-                start_ndx=self.train_dl.num_workers,
-            )
-            trainingMetrics_ary = np.zeros((3, len(self.train_dl.dataset)), dtype=np.float32)
-            for batch_ndx, batch_tup in batch_iter:
-                self.optimizer.zero_grad()
-                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, self.train_dl.batch_size, trainingMetrics_ary)
-                loss_var.backward()
-                self.optimizer.step()
-                del loss_var
-
-            # Testing loop, very similar to above, but simplified
-            # ...
-            self.model.eval()
-            self.test_dl.dataset.shuffleSamples()
-            batch_iter = enumerateWithEstimate(
-                self.test_dl,
-                "E{} Testing ".format(epoch_ndx),
-                start_ndx=self.test_dl.num_workers,
-            )
-            testingMetrics_ary = np.zeros((3, len(self.test_dl.dataset)), dtype=np.float32)
-            for batch_ndx, batch_tup in batch_iter:
-                self.computeBatchLoss(batch_ndx, batch_tup, self.test_dl.batch_size, testingMetrics_ary)
-
-            self.logMetrics(epoch_ndx, trainingMetrics_ary, testingMetrics_ary)
-
-        self.trn_writer.close()
-        self.tst_writer.close()
-
-    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_ary):
-        input_tensor, label_tensor, series_list, center_list = batch_tup
-
-        input_var = Variable(input_tensor.cuda())
-        label_var = Variable(label_tensor.cuda())
-        prediction_var = self.model(input_var)
-        # ...
-
-        start_ndx = batch_ndx * batch_size
-        end_ndx = start_ndx + label_tensor.size(0)
-        metrics_ary[LABEL, start_ndx:end_ndx] = label_tensor.numpy()[:,0,0]
-        metrics_ary[PRED,  start_ndx:end_ndx] = prediction_var.data.cpu().numpy()[:,0]
-
-        for sample_ndx in range(label_tensor.size(0)):
-            subloss_var = nn.MSELoss()(prediction_var[sample_ndx], label_var[sample_ndx])
-            metrics_ary[LOSS, start_ndx+sample_ndx] = subloss_var.data[0]
-            del subloss_var
-
-        loss_var = nn.MSELoss()(prediction_var, label_var)
-        return loss_var
-
-
-    def logMetrics(self, epoch_ndx, trainingMetrics_ary, testingMetrics_ary):
-        log.info("E{} {}".format(
-            epoch_ndx,
-            type(self).__name__,
-        ))
-
-        for mode_str, metrics_ary in [('trn', trainingMetrics_ary), ('tst', testingMetrics_ary)]:
-            pos_mask = metrics_ary[LABEL] > 0.5
-            neg_mask = ~pos_mask
-
-            truePos_count = (metrics_ary[PRED, pos_mask] > 0.5).sum()
-            trueNeg_count = (metrics_ary[PRED, neg_mask] < 0.5).sum()
-            falseNeg_count = pos_mask.sum() - truePos_count
-            falsePos_count = neg_mask.sum() - trueNeg_count
-
-            metrics_dict = {}
-            metrics_dict['pr/precision'] = p = truePos_count / (truePos_count + falsePos_count)
-            metrics_dict['pr/recall'] = r = truePos_count / (truePos_count + falseNeg_count)
-
-            # https://en.wikipedia.org/wiki/F1_score
-            for n in [0.5, 1, 2]:
-                metrics_dict['pr/f{}_score'.format(n)] = \
-                    (1 + n**2) * (p * r / (n**2 * p + r))
-
-            metrics_dict['loss/all'] = metrics_ary[LOSS].mean()
-            metrics_dict['loss/ben'] = metrics_ary[LOSS, neg_mask].mean()
-            metrics_dict['loss/mal'] = metrics_ary[LOSS, pos_mask].mean()
-
-            metrics_dict['correct/all'] = (truePos_count + trueNeg_count) / metrics_ary.shape[1] * 100
-            metrics_dict['correct/ben'] = (trueNeg_count) / neg_mask.sum() * 100
-            metrics_dict['correct/mal'] = (truePos_count) / pos_mask.sum() * 100
-
-            log.info(("E{} {:8} "
-                     + "{loss/all:.4f} loss, "
-                     + "{correct/all:-5.1f}% correct, "
-                     + "{pr/precision:.4f} precision, "
-                     + "{pr/recall:.4f} recall").format(
-                epoch_ndx,
-                mode_str,
-                **metrics_dict,
-            ))
-            log.info(("E{} {:8} "
-                     + "{loss/ben:.4f} loss, "
-                     + "{correct/ben:-5.1f}% correct").format(
-                epoch_ndx,
-                mode_str + '_ben',
-                **metrics_dict,
-            ))
-            log.info(("E{} {:8} "
-                     + "{loss/mal:.4f} loss, "
-                     + "{correct/mal:-5.1f}% correct").format(
-                epoch_ndx,
-                mode_str + '_mal',
-                **metrics_dict,
-            ))
-
-            writer = getattr(self, mode_str + '_writer')
-            tb_ndx = epoch_ndx * trainingMetrics_ary.shape[1]
-            for key, value in metrics_dict.items():
-                writer.add_scalar(key, value, tb_ndx)
-            writer.add_pr_curve('pr', metrics_ary[LABEL], metrics_ary[PRED], tb_ndx)
-            writer.add_histogram('is_mal', metrics_ary[PRED, pos_mask], tb_ndx)
-            writer.add_histogram('is_ben', metrics_ary[PRED, neg_mask], tb_ndx)
-
-
-if __name__ == '__main__':
-    sys.exit(LunaTrainingApp().main() or 0)

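The precision/recall combination in the removed logMetrics is the F-beta score, F_beta = (1 + beta^2) * p * r / (beta^2 * p + r); beta = 1 gives the harmonic mean (F1), while beta = 2 weights recall more heavily. A small sketch with illustrative values:

    def f_beta(p, r, beta):
        return (1 + beta**2) * (p * r) / (beta**2 * p + r)

    print(f_beta(0.8, 0.5, 1))   # ~0.615 (F1, harmonic mean)
    print(f_beta(0.8, 0.5, 2))   # ~0.541 (recall-weighted F2)
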
+ 106 - 0
util/affine.py

@@ -0,0 +1,106 @@
+import torch
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+import torch.backends.cudnn as cudnn
+
+from util.logconf import logging
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+
+def affine_grid_generator(theta, size):
+    if theta.data.is_cuda and len(size) == 4:
+        if not cudnn.enabled:
+            raise RuntimeError("AffineGridGenerator needs CuDNN for "
+                               "processing CUDA inputs, but CuDNN is not enabled")
+        if not cudnn.is_acceptable(theta.data):
+            raise RuntimeError("AffineGridGenerator generator theta not acceptable for CuDNN")
+        N, C, H, W = size
+        return torch.cudnn_affine_grid_generator(theta, N, C, H, W)
+    else:
+        return AffineGridGenerator.apply(theta, size)
+
+class AffineGridGenerator(Function):
+    @staticmethod
+    def _enforce_cudnn(input):
+        if not cudnn.enabled:
+            raise RuntimeError("AffineGridGenerator needs CuDNN for "
+                               "processing CUDA inputs, but CuDNN is not enabled")
+        assert cudnn.is_acceptable(input)
+
+    @staticmethod
+    def forward(ctx, theta, size):
+        assert type(size) == torch.Size
+
+        if len(size) == 5:
+            N, C, D, H, W = size
+            ctx.size = size
+            ctx.is_cuda = theta.is_cuda
+            base_grid = theta.new(N, D, H, W, 4)
+
+            w_points = (torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1]))
+            h_points = (torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])).unsqueeze(-1)
+            d_points = (torch.linspace(-1, 1, D) if D > 1 else torch.Tensor([-1])).unsqueeze(-1).unsqueeze(-1)
+
+            base_grid[:, :, :, :, 0] = w_points
+            base_grid[:, :, :, :, 1] = h_points
+            base_grid[:, :, :, :, 2] = d_points
+            base_grid[:, :, :, :, 3] = 1
+            ctx.base_grid = base_grid
+            grid = torch.bmm(base_grid.view(N, D * H * W, 4), theta.transpose(1, 2))
+            grid = grid.view(N, D, H, W, 3)
+
+        elif len(size) == 4:
+            N, C, H, W = size
+            ctx.size = size
+            if theta.is_cuda:
+                AffineGridGenerator._enforce_cudnn(theta)
+                assert False
+            ctx.is_cuda = False
+            base_grid = theta.new(N, H, W, 3)
+            linear_points = torch.linspace(-1, 1, W) if W > 1 else torch.Tensor([-1])
+            base_grid[:, :, :, 0] = torch.ger(torch.ones(H), linear_points).expand_as(base_grid[:, :, :, 0])
+            linear_points = torch.linspace(-1, 1, H) if H > 1 else torch.Tensor([-1])
+            base_grid[:, :, :, 1] = torch.ger(linear_points, torch.ones(W)).expand_as(base_grid[:, :, :, 1])
+            base_grid[:, :, :, 2] = 1
+            ctx.base_grid = base_grid
+            grid = torch.bmm(base_grid.view(N, H * W, 3), theta.transpose(1, 2))
+            grid = grid.view(N, H, W, 2)
+        else:
+            raise RuntimeError("AffineGridGenerator needs 4d (spatial) or 5d (volumetric) inputs.")
+
+        return grid
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_grid):
+        if len(ctx.size) == 5:
+            N, C, D, H, W = ctx.size
+            assert grad_grid.size() == torch.Size([N, D, H, W, 3])
+            assert ctx.is_cuda == grad_grid.is_cuda
+            # if grad_grid.is_cuda:
+            #     AffineGridGenerator._enforce_cudnn(grad_grid)
+            #     assert False
+            base_grid = ctx.base_grid
+            grad_theta = torch.bmm(
+                base_grid.view(N, D * H * W, 4).transpose(1, 2),
+                grad_grid.view(N, D * H * W, 3))
+            grad_theta = grad_theta.transpose(1, 2)
+        elif len(ctx.size) == 4:
+            N, C, H, W = ctx.size
+            assert grad_grid.size() == torch.Size([N, H, W, 2])
+            assert ctx.is_cuda == grad_grid.is_cuda
+            if grad_grid.is_cuda:
+                AffineGridGenerator._enforce_cudnn(grad_grid)
+                assert False
+            base_grid = ctx.base_grid
+            grad_theta = torch.bmm(
+                base_grid.view(N, H * W, 3).transpose(1, 2),
+                grad_grid.view(N, H * W, 2))
+            grad_theta = grad_theta.transpose(1, 2)
+        else:
+            assert False
+
+        return grad_theta, None

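A hedged usage sketch for the 5D branch added above (assuming a PyTorch version of this era with volumetric grid_sample support; the tensors are illustrative): an identity theta should make grid_sample reproduce the input volume up to interpolation error.

    import torch
    from util.affine import affine_grid_generator

    vol = torch.randn(1, 1, 8, 8, 8)                 # N, C, D, H, W
    theta = torch.eye(3, 4).unsqueeze(0)             # identity 3x4 transform
    grid = affine_grid_generator(theta, vol.size())  # (1, 8, 8, 8, 3)
    out = torch.nn.functional.grid_sample(vol, grid, padding_mode='border')
    print((out - vol).abs().max())                   # ~0
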
+ 1 - 1
util/disk.py

@@ -78,7 +78,7 @@ class GzipDisk(Disk):
         return value
 
 def getCache(scope_str):
-    return FanoutCache('data/cache/' + scope_str, disk=GzipDisk, shards=32, timeout=1, size_limit=8e10)
+    return FanoutCache('data/cache/' + scope_str, disk=GzipDisk, shards=32, timeout=1, size_limit=2e11)
 
 # def disk_cache(base_path, memsize=2):
 #     def disk_cache_decorator(f):

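The size_limit bump from 8e10 (~80 GB) to 2e11 (~200 GB) makes room to cache raw CT chunks on disk. A hedged usage sketch (memoize is standard diskcache API; the decorated loader is illustrative, not from this commit):

    from util.disk import getCache

    raw_cache = getCache('part2ch08_raw')

    @raw_cache.memoize(typed=True)
    def slow_load(series_uid):
        ...  # e.g. read and preprocess one CT; result is gzip-cached on disk
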
+ 572 - 0
util/test_affine.py

@@ -0,0 +1,572 @@
+import math
+import random
+
+import numpy as np
+import scipy.ndimage
+
+import torch
+
+import pytest
+
+from util.logconf import logging
+log = logging.getLogger(__name__)
+# log.setLevel(logging.WARN)
+# log.setLevel(logging.INFO)
+log.setLevel(logging.DEBUG)
+
+from .affine import affine_grid_generator
+
+
+if torch.cuda.is_available():
+    @pytest.fixture(params=['cpu', 'cuda'])
+    def device(request):
+        return request.param
+else:
+    @pytest.fixture(params=['cpu'])
+    def device(request):
+        return request.param
+
+# @pytest.fixture(params=[0., 0.25])
+@pytest.fixture(params=[0.0, 0.5, 0.25, 0.125, 'random'])
+def angle_rad(request):
+    if request.param == 'random':
+        return random.random() * math.pi * 2
+    return request.param * math.pi * 2
+
+@pytest.fixture(params=[(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0), 'random'])
+def axis_vector(request):
+    if request.param == 'random':
+        t = (random.random(), random.random(), random.random())
+        l = sum(x**2 for x in t)**0.5
+        return tuple(x/l for x in t)
+    return request.param
+
+@pytest.fixture(params=[torch.nn.functional.affine_grid, affine_grid_generator])
+def affine_func2d(request):
+    return request.param
+
+@pytest.fixture(params=[affine_grid_generator])
+def affine_func3d(request):
+    return request.param
+
+# @pytest.fixture(params=[[1, 1, 3, 5], [1, 1, 3, 3]])
+@pytest.fixture(params=[[1, 1, 3, 5], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 3, 4]])
+def input_size2d(request):
+    return request.param
+
+# @pytest.fixture(params=[[1, 1, 5, 3], [1, 1, 3, 5], [1, 1, 5, 5]])
+@pytest.fixture(params=[[1, 1, 5, 3], [1, 1, 3, 5], [1, 1, 4, 3], [1, 1, 5, 5], [1, 1, 6, 6]])
+def output_size2d(request):
+    return request.param
+
+@pytest.fixture(params=[[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 6, 6], ])
+def input_size2dsq(request):
+    return request.param
+
+@pytest.fixture(params=[[1, 1, 2, 2], [1, 1, 3, 3], [1, 1, 4, 4], [1, 1, 5, 5], [1, 1, 6, 6], ])
+def output_size2dsq(request):
+    return request.param
+
+
+# @pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 2, 3, 4]])
+@pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 2, 3, 4], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 3, 4, 5]])
+def input_size3d(request):
+    return request.param
+
+@pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 6, 6, 6]])
+def input_size3dsq(request):
+    return request.param
+
+@pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 4, 4, 4], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]])
+def output_size3dsq(request):
+    return request.param
+
+# @pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 3, 4, 5]])
+@pytest.fixture(params=[[1, 1, 2, 2, 2], [1, 1, 3, 3, 3], [1, 1, 3, 4, 5], [1, 1, 4, 3, 2], [1, 1, 5, 5, 5], [1, 1, 6, 6, 6]])
+def output_size3d(request):
+    return request.param
+
+
+def _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad):
+    print("_buildEquivalentTransforms2d", device, input_size, output_size, angle_rad * 180 / math.pi)
+    input_center = [(x-1)/2 for x in input_size]
+    output_center = [(x-1)/2 for x in output_size]
+
+    s = math.sin(angle_rad)
+    c = math.cos(angle_rad)
+
+    intrans_ary = np.array([
+        [1, 0, input_center[2]],
+        [0, 1, input_center[3]],
+        [0, 0, 1],
+    ], dtype=np.float64)
+
+    inscale_ary = np.array([
+        [input_center[2], 0, 0],
+        [0, input_center[3], 0],
+        [0, 0, 1],
+    ], dtype=np.float64)
+
+    rotation_ary = np.array([
+        [c, -s, 0],
+        [s,  c, 0],
+        [0,  0, 1],
+    ], dtype=np.float64)
+
+    outscale_ary = np.array([
+        [1/output_center[2], 0, 0],
+        [0, 1/output_center[3], 0],
+        [0, 0, 1],
+    ], dtype=np.float64)
+
+    outtrans_ary = np.array([
+        [1, 0, -output_center[2]],
+        [0, 1, -output_center[3]],
+        [0, 0, 1],
+    ], dtype=np.float64)
+
+    reorder_ary = np.array([
+        [0, 1, 0],
+        [1, 0, 0],
+        [0, 0, 1],
+    ], dtype=np.float64)
+
+    transform_ary = intrans_ary @ inscale_ary @ rotation_ary.T @ outscale_ary @ outtrans_ary
+    grid_ary = reorder_ary @ rotation_ary.T @ outscale_ary @ outtrans_ary
+    transform_tensor = torch.from_numpy((rotation_ary)).to(device, torch.float32)
+
+    transform_tensor = transform_tensor[:2].unsqueeze(0)
+
+    print('transform_tensor', transform_tensor.size(), transform_tensor.dtype, transform_tensor.device)
+    print(transform_tensor)
+    print('outtrans_ary', outtrans_ary.shape, outtrans_ary.dtype)
+    print(outtrans_ary.round(3))
+    print('outscale_ary', outscale_ary.shape, outscale_ary.dtype)
+    print(outscale_ary.round(3))
+    print('rotation_ary', rotation_ary.shape, rotation_ary.dtype)
+    print(rotation_ary.round(3))
+    print('inscale_ary', inscale_ary.shape, inscale_ary.dtype)
+    print(inscale_ary.round(3))
+    print('intrans_ary', intrans_ary.shape, intrans_ary.dtype)
+    print(intrans_ary.round(3))
+    print('transform_ary', transform_ary.shape, transform_ary.dtype)
+    print(transform_ary.round(3))
+    print('grid_ary', grid_ary.shape, grid_ary.dtype)
+    print(grid_ary.round(3))
+
+    def prtf(pt):
+        print(pt, 'transformed', (transform_ary @ (pt + [1]))[:2].round(3))
+
+    prtf([0, 0])
+    prtf([1, 0])
+    prtf([2, 0])
+
+    print('')
+
+    prtf([0, 0])
+    prtf([0, 1])
+    prtf([0, 2])
+    prtf(output_center[2:])
+
+    return transform_tensor, transform_ary, grid_ary
+
+def _buildEquivalentTransforms3d(device, input_size, output_size, angle_rad, axis_vector):
+    print("_buildEquivalentTransforms2d", device, input_size, output_size, angle_rad * 180 / math.pi, axis_vector)
+    input_center = [(x-1)/2 for x in input_size]
+    output_center = [(x-1)/2 for x in output_size]
+
+    s = math.sin(angle_rad)
+    c = math.cos(angle_rad)
+    c1 = 1 - c
+
+    intrans_ary = np.array([
+        [1, 0, 0, input_center[2]],
+        [0, 1, 0, input_center[3]],
+        [0, 0, 1, input_center[4]],
+        [0, 0, 0, 1],
+    ], dtype=np.float64)
+
+    inscale_ary = np.array([
+        [input_center[2], 0, 0, 0],
+        [0, input_center[3], 0, 0],
+        [0, 0, input_center[4], 0],
+        [0, 0, 0, 1],
+    ], dtype=np.float64)
+
+    l, m, n = axis_vector
+    scipyRotation_ary = np.array([
+        [l*l*c1 +   c, m*l*c1 - n*s, n*l*c1 + m*s, 0],
+        [l*m*c1 + n*s, m*m*c1 +   c, n*m*c1 - l*s, 0],
+        [l*n*c1 - m*s, m*n*c1 + l*s, n*n*c1 +   c, 0],
+        [0, 0, 0, 1],
+    ], dtype=np.float64)
+
+    z, y, x = axis_vector
+    torchRotation_ary = np.array([
+        [x*x*c1 +   c, y*x*c1 - z*s, z*x*c1 + y*s, 0],
+        [x*y*c1 + z*s, y*y*c1 +   c, z*y*c1 - x*s, 0],
+        [x*z*c1 - y*s, y*z*c1 + x*s, z*z*c1 +   c, 0],
+        [0, 0, 0, 1],
+    ], dtype=np.float64)
+
+    outscale_ary = np.array([
+        [1/output_center[2], 0, 0, 0],
+        [0, 1/output_center[3], 0, 0],
+        [0, 0, 1/output_center[4], 0],
+        [0, 0, 0, 1],
+    ], dtype=np.float64)
+
+    outtrans_ary = np.array([
+        [1, 0, 0, -output_center[2]],
+        [0, 1, 0, -output_center[3]],
+        [0, 0, 1, -output_center[4]],
+        [0, 0, 0, 1],
+    ], dtype=np.float64)
+
+    reorder_ary = np.array([
+        [0, 0, 1, 0],
+        [0, 1, 0, 0],
+        [1, 0, 0, 0],
+        [0, 0, 0, 1],
+    ], dtype=np.float64)
+
+    transform_ary = intrans_ary @ inscale_ary @ np.linalg.inv(scipyRotation_ary) @ outscale_ary @ outtrans_ary
+    grid_ary = reorder_ary @ np.linalg.inv(scipyRotation_ary) @ outscale_ary @ outtrans_ary
+    transform_tensor = torch.from_numpy((torchRotation_ary)).to(device, torch.float32)
+    transform_tensor = transform_tensor[:3].unsqueeze(0)
+
+    print('transform_tensor', transform_tensor.size(), transform_tensor.dtype, transform_tensor.device)
+    print(transform_tensor)
+    print('outtrans_ary', outtrans_ary.shape, outtrans_ary.dtype)
+    print(outtrans_ary.round(3))
+    print('outscale_ary', outscale_ary.shape, outscale_ary.dtype)
+    print(outscale_ary.round(3))
+    print('rotation_ary', scipyRotation_ary.shape, scipyRotation_ary.dtype, axis_vector, angle_rad)
+    print(scipyRotation_ary.round(3))
+    print('inscale_ary', inscale_ary.shape, inscale_ary.dtype)
+    print(inscale_ary.round(3))
+    print('intrans_ary', intrans_ary.shape, intrans_ary.dtype)
+    print(intrans_ary.round(3))
+    print('transform_ary', transform_ary.shape, transform_ary.dtype)
+    print(transform_ary.round(3))
+    print('grid_ary', grid_ary.shape, grid_ary.dtype)
+    print(grid_ary.round(3))
+
+    def prtf(pt):
+        print(pt, 'transformed', (transform_ary @ (pt + [1]))[:3].round(3))
+
+    prtf([0, 0, 0])
+    prtf([1, 0, 0])
+    prtf([2, 0, 0])
+
+    print('')
+
+    prtf([0, 0, 0])
+    prtf([0, 1, 0])
+    prtf([0, 2, 0])
+
+    print('')
+
+    prtf([0, 0, 0])
+    prtf([0, 0, 1])
+    prtf([0, 0, 2])
+
+    prtf(output_center[2:])
+
+    return transform_tensor, transform_ary, grid_ary
+
+
+def test_affine_2d_rotate0(device, affine_func2d):
+    input_size = [1, 1, 3, 3]
+    input_ary = np.array(np.random.random(input_size), dtype=np.float32)
+    output_size = [1, 1, 5, 5]
+    angle_rad = 0.
+
+    transform_tensor, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
+
+    # reference
+    # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
+    scipy_ary = scipy.ndimage.affine_transform(
+        input_ary[0,0],
+        transform_ary,
+        offset=offset,
+        output_shape=output_size[2:],
+        # output=None,
+        order=1,
+        mode='nearest',
+        # cval=0.0,
+        prefilter=False)
+
+    print('input_ary', input_ary.shape, input_ary.dtype)
+    print(input_ary)
+    print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
+    print(scipy_ary)
+
+    affine_tensor = affine_func2d(
+            transform_tensor,
+            torch.Size(output_size)
+        )
+
+    print('affine_tensor', affine_tensor.size(), affine_tensor.dtype, affine_tensor.device)
+    print(affine_tensor)
+
+    gridsample_ary = torch.nn.functional.grid_sample(
+            torch.tensor(input_ary, device=device).to(device),
+            affine_tensor,
+            padding_mode='border'
+        ).to('cpu').numpy()
+
+    print('input_ary', input_ary.shape, input_ary.dtype)
+    print(input_ary)
+    print('gridsample_ary', gridsample_ary.shape, gridsample_ary.dtype)
+    print(gridsample_ary)
+    print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
+    print(scipy_ary)
+
+    assert np.abs(scipy_ary.mean() - gridsample_ary.mean()) < 1e-6
+    assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6
+    # assert False
+
+def test_affine_2d_rotate90(device, affine_func2d, input_size2dsq, output_size2dsq):
+    input_size = input_size2dsq
+    input_ary = np.array(np.random.random(input_size), dtype=np.float32)
+    output_size = output_size2dsq
+    angle_rad = 0.25 * math.pi * 2
+
+    transform_tensor, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
+
+    # reference
+    # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
+    scipy_ary = scipy.ndimage.affine_transform(
+        input_ary[0,0],
+        transform_ary,
+        offset=offset,
+        output_shape=output_size[2:],
+        # output=None,
+        order=1,
+        mode='nearest',
+        # cval=0.0,
+        prefilter=True)
+
+    print('input_ary', input_ary.shape, input_ary.dtype, input_ary.mean())
+    print(input_ary)
+    print('scipy_ary', scipy_ary.shape, scipy_ary.dtype, scipy_ary.mean())
+    print(scipy_ary)
+
+    if input_size2dsq == output_size2dsq:
+        assert np.abs(scipy_ary.mean() - input_ary.mean()) < 1e-6
+    assert np.abs(scipy_ary[0,0] - input_ary[0,0,0,-1]).max() < 1e-6
+    assert np.abs(scipy_ary[0,-1] - input_ary[0,0,-1,-1]).max() < 1e-6
+    assert np.abs(scipy_ary[-1,-1] - input_ary[0,0,-1,0]).max() < 1e-6
+    assert np.abs(scipy_ary[-1,0] - input_ary[0,0,0,0]).max() < 1e-6
+
+    affine_tensor = affine_func2d(
+            transform_tensor,
+            torch.Size(output_size)
+        )
+
+    print('affine_tensor', affine_tensor.size(), affine_tensor.dtype, affine_tensor.device)
+    print(affine_tensor)
+
+    gridsample_ary = torch.nn.functional.grid_sample(
+            torch.tensor(input_ary, device=device).to(device),
+            affine_tensor,
+            padding_mode='border'
+        ).to('cpu').numpy()
+
+    print('input_ary', input_ary.shape, input_ary.dtype)
+    print(input_ary)
+    print('gridsample_ary', gridsample_ary.shape, gridsample_ary.dtype)
+    print(gridsample_ary)
+    print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
+    print(scipy_ary)
+
+    assert np.abs(scipy_ary.mean() - gridsample_ary.mean()) < 1e-6
+    assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6
+    # assert False
+
+def test_affine_2d_rotate45(device, affine_func2d):
+    input_size = [1, 1, 3, 3]
+    input_ary = np.array(np.zeros(input_size), dtype=np.float32)
+    input_ary[0,0,0,:] = 0.5
+    input_ary[0,0,2,2] = 1.0
+    output_size = [1, 1, 3, 3]
+    angle_rad = 0.125 * math.pi * 2
+
+    transform_tensor, transform_ary, offset = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
+
+    # reference
+    # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
+    scipy_ary = scipy.ndimage.affine_transform(
+        input_ary[0,0],
+        transform_ary,
+        offset=offset,
+        output_shape=output_size[2:],
+        # output=None,
+        order=1,
+        mode='nearest',
+        # cval=0.0,
+        prefilter=False)
+
+    print('input_ary', input_ary.shape, input_ary.dtype)
+    print(input_ary)
+    print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
+    print(scipy_ary)
+
+    affine_tensor = affine_func2d(
+            transform_tensor,
+            torch.Size(output_size)
+        )
+
+    print('affine_tensor', affine_tensor.size(), affine_tensor.dtype, affine_tensor.device)
+    print(affine_tensor)
+
+    gridsample_ary = torch.nn.functional.grid_sample(
+            torch.tensor(input_ary, device=device).to(device),
+            affine_tensor,
+            padding_mode='border'
+        ).to('cpu').numpy()
+
+    print('input_ary', input_ary.shape, input_ary.dtype)
+    print(input_ary)
+    print('gridsample_ary', gridsample_ary.shape, gridsample_ary.dtype)
+    print(gridsample_ary)
+    print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
+    print(scipy_ary)
+
+    assert np.abs(scipy_ary - gridsample_ary).max() < 1e-6
+    # assert False
+
+def test_affine_2d_rotateRandom(device, affine_func2d, angle_rad, input_size2d, output_size2d):
+    input_size = input_size2d
+    input_ary = np.array(np.random.random(input_size), dtype=np.float32).round(3)
+    output_size = output_size2d
+
+    input_ary[0,0,0,0] = 2
+    input_ary[0,0,0,-1] = 4
+    input_ary[0,0,-1,0] = 6
+    input_ary[0,0,-1,-1] = 8
+
+    transform_tensor, transform_ary, grid_ary = _buildEquivalentTransforms2d(device, input_size, output_size, angle_rad)
+
+    # reference
+    # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
+    scipy_ary = scipy.ndimage.affine_transform(
+        input_ary[0,0],
+        transform_ary,
+        # offset=offset,
+        output_shape=output_size[2:],
+        # output=None,
+        order=1,
+        mode='nearest',
+        # cval=0.0,
+        prefilter=False)
+
+    affine_tensor = affine_func2d(
+            transform_tensor,
+            torch.Size(output_size)
+        )
+
+    print('affine_tensor', affine_tensor.size(), affine_tensor.dtype, affine_tensor.device)
+    print(affine_tensor)
+
+    for r in range(affine_tensor.size(1)):
+        for c in range(affine_tensor.size(2)):
+            grid_out = grid_ary @ [r, c, 1]
+            print(r, c, 'affine:', affine_tensor[0,r,c], 'grid:', grid_out[:2])
+
+    gridsample_ary = torch.nn.functional.grid_sample(
+            torch.tensor(input_ary, device=device).to(device),
+            affine_tensor,
+            padding_mode='border'
+        ).to('cpu').numpy()
+
+    print('input_ary', input_ary.shape, input_ary.dtype)
+    print(input_ary.round(3))
+    print('gridsample_ary', gridsample_ary.shape, gridsample_ary.dtype)
+    print(gridsample_ary.round(3))
+    print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
+    print(scipy_ary.round(3))
+
+    for r in range(affine_tensor.size(1)):
+        for c in range(affine_tensor.size(2)):
+            grid_out = grid_ary @ [r, c, 1]
+
+            try:
+                assert np.allclose(affine_tensor[0,r,c], grid_out[:2], atol=1e-5)
+            except Exception:
+                print(r, c, 'affine:', affine_tensor[0,r,c], 'grid:', grid_out[:2])
+                raise
+
+    assert np.abs(scipy_ary - gridsample_ary).max() < 1e-5
+    # assert False
+
+def test_affine_3d_rotateRandom(device, affine_func3d, angle_rad, axis_vector, input_size3d, output_size3d):
+    input_size = input_size3d
+    input_ary = np.array(np.random.random(input_size), dtype=np.float32)
+    output_size = output_size3d
+
+    input_ary[0,0,  0,  0,  0] = 2
+    input_ary[0,0,  0,  0, -1] = 3
+    input_ary[0,0,  0, -1,  0] = 4
+    input_ary[0,0,  0, -1, -1] = 5
+    input_ary[0,0, -1,  0,  0] = 6
+    input_ary[0,0, -1,  0, -1] = 7
+    input_ary[0,0, -1, -1,  0] = 8
+    input_ary[0,0, -1, -1, -1] = 9
+
+    transform_tensor, transform_ary, grid_ary = _buildEquivalentTransforms3d(device, input_size, output_size, angle_rad, axis_vector)
+
+    # reference
+    # https://stackoverflow.com/questions/20161175/how-can-i-use-scipy-ndimage-interpolation-affine-transform-to-rotate-an-image-ab
+    scipy_ary = scipy.ndimage.affine_transform(
+        input_ary[0,0],
+        transform_ary,
+        # offset=offset,
+        output_shape=output_size[2:],
+        # output=None,
+        order=1,
+        mode='nearest',
+        # cval=0.0,
+        prefilter=False)
+
+    affine_tensor = affine_func3d(
+            transform_tensor,
+            torch.Size(output_size)
+        )
+
+    print('affine_tensor', affine_tensor.size(), affine_tensor.dtype, affine_tensor.device)
+    print(affine_tensor)
+
+    for i in range(affine_tensor.size(1)):
+        for r in range(affine_tensor.size(2)):
+            for c in range(affine_tensor.size(3)):
+                grid_out = grid_ary @ [i, r, c, 1]
+                print(i, r, c, 'affine:', affine_tensor[0,i,r,c], 'grid:', grid_out[:3].round(3))
+
+    print('input_ary', input_ary.shape, input_ary.dtype)
+    print(input_ary.round(3))
+
+    gridsample_ary = torch.nn.functional.grid_sample(
+            torch.tensor(input_ary, device=device).to(device),
+            affine_tensor,
+            padding_mode='border'
+        ).to('cpu').numpy()
+
+    print('gridsample_ary', gridsample_ary.shape, gridsample_ary.dtype)
+    print(gridsample_ary.round(3))
+    print('scipy_ary', scipy_ary.shape, scipy_ary.dtype)
+    print(scipy_ary.round(3))
+
+    for i in range(affine_tensor.size(1)):
+        for r in range(affine_tensor.size(2)):
+            for c in range(affine_tensor.size(3)):
+                grid_out = grid_ary @ [i, r, c, 1]
+                try:
+                    assert np.allclose(affine_tensor[0,i,r,c], grid_out[:3], atol=1e-5)
+                except Exception:
+                    print(i, r, c, 'affine:', affine_tensor[0,i,r,c], 'grid:', grid_out[:3].round(3))
+                    raise
+
+    assert np.abs(scipy_ary - gridsample_ary).max() < 1e-5
+    # assert False

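The 3x3 block built by _buildEquivalentTransforms3d is the Rodrigues axis-angle rotation matrix, so for any unit axis (l, m, n) it should be orthonormal with determinant +1. A quick sanity sketch (axis and angle are illustrative):

    import math
    import numpy as np

    angle = 0.3 * math.pi
    l, m, n = 0.0, 0.6, 0.8                      # unit axis vector
    s, c = math.sin(angle), math.cos(angle)
    c1 = 1 - c
    R = np.array([
        [l*l*c1 +   c, m*l*c1 - n*s, n*l*c1 + m*s],
        [l*m*c1 + n*s, m*m*c1 +   c, n*m*c1 - l*s],
        [l*n*c1 - m*s, m*n*c1 + l*s, n*n*c1 +   c],
    ])
    print(np.allclose(R @ R.T, np.eye(3)))       # True
    print(round(float(np.linalg.det(R)), 6))     # 1.0
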
+ 11 - 9
util/util.py

@@ -147,6 +147,7 @@ def enumerateWithEstimate(iter, desc_str, start_ndx=0, print_ndx=4, backoff=2, i
     if iter_len is None:
         iter_len = len(iter)
 
+    assert backoff >= 2
     while print_ndx < start_ndx * backoff:
         print_ndx *= backoff
 
@@ -154,14 +155,14 @@ def enumerateWithEstimate(iter, desc_str, start_ndx=0, print_ndx=4, backoff=2, i
         desc_str,
         iter_len,
     ))
+    start_ts = time.time()
     for (current_ndx, item) in enumerate(iter):
-        if current_ndx == start_ndx:
-            start_ts = time.time()
-        elif current_ndx == print_ndx:
-            # ... <1>
-            duration_sec = ((time.time() - start_ts) *
-                            (iter_len-start_ndx) /
-                            (current_ndx-start_ndx))
+        yield (current_ndx, item)
+        if current_ndx == print_ndx:
+            duration_sec = ((time.time() - start_ts)
+                            / (current_ndx - start_ndx + 1)
+                            * (iter_len-start_ndx)
+                            )
 
             done_dt = datetime.datetime.fromtimestamp(start_ts + duration_sec)
             done_td = datetime.timedelta(seconds=duration_sec)
@@ -176,7 +177,8 @@ def enumerateWithEstimate(iter, desc_str, start_ndx=0, print_ndx=4, backoff=2, i
 
             print_ndx *= backoff
 
-        yield (current_ndx, item)
+        if current_ndx + 1 == start_ndx:
+            start_ts = time.time()
 
     log.warning("{} ----/{}, done at {}".format(
         desc_str,
@@ -187,7 +189,7 @@ def enumerateWithEstimate(iter, desc_str, start_ndx=0, print_ndx=4, backoff=2, i
 
 try:
     import matplotlib
-    matplotlib.use('agg')
+    matplotlib.use('agg', warn=False)
 
     import matplotlib.pyplot as plt
     # matplotlib color maps

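The reworked enumerateWithEstimate yields each item before checking the clock, and the estimate divides elapsed time by the current_ndx - start_ndx + 1 items completed since start_ts was (re)set. A sketch of the linear extrapolation with illustrative numbers:

    start_ndx, current_ndx, iter_len = 8, 64, 10000
    elapsed = 3.5                                     # seconds since start_ts
    per_item = elapsed / (current_ndx - start_ndx + 1)
    duration_sec = per_item * (iter_len - start_ndx)
    print(round(duration_sec, 1))                     # ~613.5 projected seconds
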
Some files were not shown because too many files changed in this diff