Eli Stevens 6 years ago
parent commit 43cf8e9f0f

BIN
data/p1ch3/ourpoints.hdf5


BIN
data/p1ch3/ourpoints.t


+ 30 - 4
p1ch3/1_tensors.ipynb

@@ -1262,20 +1262,46 @@
    "cell_type": "code",
    "execution_count": 71,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(torch.Size([3, 2]), torch.Size([2, 3]))"
+      ]
+     },
+     "execution_count": 71,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "a = torch.ones(3, 2)\n",
-    "a_t = torch.transpose(a, 0, 1)"
+    "a_t = torch.transpose(a, 0, 1)\n",
+    "\n",
+    "a.shape, a_t.shape"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 72,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(torch.Size([3, 2]), torch.Size([2, 3]))"
+      ]
+     },
+     "execution_count": 72,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "a = torch.ones(3, 2)\n",
-    "a_t = a.transpose(0, 1)"
+    "a_t = a.transpose(0, 1)\n",
+    "\n",
+    "a.shape, a_t.shape"
    ]
   },
   {

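Note: both updated cells now surface the shapes before and after the transpose. torch.transpose returns a view, not a copy, so the two tensors share one storage. A standalone sketch of the same behavior (assuming only that PyTorch is installed):

import torch

a = torch.ones(3, 2)
a_t = a.transpose(0, 1)                 # swap dims 0 and 1; no data is copied
print(a.shape, a_t.shape)               # torch.Size([3, 2]) torch.Size([2, 3])
print(a.data_ptr() == a_t.data_ptr())   # True: both view the same storage
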
+ 25 - 25
p1ch5/3_optimizers.ipynb

@@ -293,14 +293,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "t_u_train = t_u[train_indices]\n",
-    "t_c_train = t_c[train_indices]\n",
+    "train_t_u = t_u[train_indices]\n",
+    "train_t_c = t_c[train_indices]\n",
     "\n",
-    "t_u_val = t_u[val_indices]\n",
-    "t_c_val = t_c[val_indices]\n",
+    "val_t_u = t_u[val_indices]\n",
+    "val_t_c = t_c[val_indices]\n",
     "\n",
-    "t_un_train = 0.1 * t_u_train\n",
-    "t_un_val = 0.1 * t_u_val"
+    "train_t_un = 0.1 * train_t_u\n",
+    "val_t_un = 0.1 * val_t_u"
    ]
   },
   {
@@ -309,21 +309,21 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def training_loop(n_epochs, optimizer, params, t_u_train, t_u_val, t_c_train, t_c_val):\n",
+    "def training_loop(n_epochs, optimizer, params, train_t_u, val_t_u, train_t_c, val_t_c):\n",
     "    for epoch in range(1, n_epochs + 1):\n",
-    "        t_p_train = model(t_un_train, *params) # <1>\n",
-    "        loss_train = loss_fn(t_p_train, t_c_train)\n",
-    "\n",
-    "        t_p_val = model(t_un_val, *params) # <1>\n",
-    "        loss_val = loss_fn(t_p_val, t_c_val)\n",
+    "        train_t_p = model(train_t_u, *params) # <1>\n",
+    "        train_loss = loss_fn(train_t_p, train_t_c)\n",
+    "                             \n",
+    "        val_t_p = model(val_t_u, *params) # <1>\n",
+    "        val_loss = loss_fn(val_t_p, val_t_c)\n",
     "        \n",
     "        optimizer.zero_grad()\n",
-    "        loss_train.backward() # <2>\n",
+    "        train_loss.backward() # <2>\n",
     "        optimizer.step()\n",
     "\n",
     "        if epoch <= 3 or epoch % 500 == 0:\n",
     "            print('Epoch {}, Training loss {}, Validation loss {}'.format(\n",
-    "                epoch, float(loss_train), float(loss_val)))\n",
+    "                epoch, float(train_loss), float(val_loss)))\n",
     "            \n",
     "    return params"
    ]
@@ -368,10 +368,10 @@
     "    n_epochs = 3000, \n",
     "    optimizer = optimizer,\n",
     "    params = params,\n",
-    "    t_u_train = t_un_train, # <1> \n",
-    "    t_u_val = t_un_val, # <1> \n",
-    "    t_c_train = t_c_train,\n",
-    "    t_c_val = t_c_val)"
+    "    train_t_u = train_t_un, # <1> \n",
+    "    val_t_u = val_t_un, # <1> \n",
+    "    train_t_c = train_t_c,\n",
+    "    val_t_c = val_t_c)"
    ]
   },
   {
@@ -380,18 +380,18 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "def training_loop(n_epochs, optimizer, params, t_u_train, t_u_val, t_c_train, t_c_val):\n",
+    "def training_loop(n_epochs, optimizer, params, train_t_u, val_t_u, train_t_c, val_t_c):\n",
     "    for epoch in range(1, n_epochs + 1):\n",
-    "        t_p_train = model(t_un_train, *params)\n",
-    "        loss_train = loss_fn(t_p_train, t_c_train)\n",
+    "        train_t_p = model(train_t_u, *params)\n",
+    "        train_loss = loss_fn(train_t_p, train_t_c)\n",
     "\n",
     "        with torch.no_grad(): # <1>\n",
-    "            t_p_val = model(t_un_val, *params)\n",
-    "            loss_val = loss_fn(t_p_val, t_c_val)\n",
-    "            assert loss_val.requires_grad == False # <2>\n",
+    "            val_t_p = model(val_t_u, *params)\n",
+    "            val_loss = loss_fn(val_t_p, val_t_c)\n",
+    "            assert val_loss.requires_grad == False # <2>\n",
     "            \n",
     "        optimizer.zero_grad()\n",
-    "        loss_train.backward()\n",
+    "        train_loss.backward()\n",
     "        optimizer.step()"
    ]
   },

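Note: beyond the train_/val_ renaming, the second hunk wraps validation in torch.no_grad() so no autograd graph is built for it. A self-contained sketch of the resulting loop (the data here is synthetic and illustrative; the book's actual measurements are not reproduced):

import torch
import torch.optim as optim

def model(t_u, w, b):
    return w * t_u + b

def loss_fn(t_p, t_c):
    return ((t_p - t_c) ** 2).mean()

t_u = 20.0 + 80.0 * torch.rand(11)       # stand-in "unknown" readings
t_c = (t_u - 32.0) * 5.0 / 9.0           # a linear ground truth

n_val = 2
shuffled = torch.randperm(t_u.shape[0])
train_indices, val_indices = shuffled[:-n_val], shuffled[-n_val:]

train_t_un = 0.1 * t_u[train_indices]    # normalized, as in the hunk above
val_t_un = 0.1 * t_u[val_indices]
train_t_c, val_t_c = t_c[train_indices], t_c[val_indices]

params = torch.tensor([1.0, 0.0], requires_grad=True)
optimizer = optim.SGD([params], lr=1e-2)

for epoch in range(1, 3001):
    train_loss = loss_fn(model(train_t_un, *params), train_t_c)

    with torch.no_grad():                # validation builds no graph
        val_loss = loss_fn(model(val_t_un, *params), val_t_c)
        assert val_loss.requires_grad == False

    optimizer.zero_grad()
    train_loss.backward()                # backprop only the training loss
    optimizer.step()
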
File diff suppressed because it is too large
+ 573 - 0
p1ch6/1_neural_networks.ipynb


File diff suppressed because it is too large
+ 71 - 0
p1ch6/2_activation_functions.ipynb


+ 218 - 0
p1ch6/3_nn_module_subclassing.ipynb

@@ -0,0 +1,218 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%matplotlib inline\n",
+    "import numpy as np\n",
+    "import torch\n",
+    "import torch.optim as optim\n",
+    "import torch.nn as nn\n",
+    "\n",
+    "torch.set_printoptions(edgeitems=2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Sequential(\n",
+       "  (0): Linear(in_features=1, out_features=11, bias=True)\n",
+       "  (1): Tanh()\n",
+       "  (2): Linear(in_features=11, out_features=1, bias=True)\n",
+       ")"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "seq_model = nn.Sequential(\n",
+    "            nn.Linear(1, 11), # <1>\n",
+    "            nn.Tanh(),\n",
+    "            nn.Linear(11, 1)) # <2>\n",
+    "seq_model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "Sequential(\n",
+       "  (hidden_linear): Linear(in_features=1, out_features=12, bias=True)\n",
+       "  (hidden_activation): Tanh()\n",
+       "  (output_linear): Linear(in_features=12, out_features=1, bias=True)\n",
+       ")"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from collections import OrderedDict\n",
+    "\n",
+    "namedseq_model = nn.Sequential(OrderedDict([\n",
+    "    ('hidden_linear', nn.Linear(1, 12)),\n",
+    "    ('hidden_activation', nn.Tanh()),\n",
+    "    ('output_linear', nn.Linear(12 , 1))\n",
+    "]))\n",
+    "\n",
+    "namedseq_model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "SubclassModel(\n",
+       "  (hidden_linear): Linear(in_features=1, out_features=13, bias=True)\n",
+       "  (hidden_activation): Tanh()\n",
+       "  (output_linear): Linear(in_features=13, out_features=1, bias=True)\n",
+       ")"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "class SubclassModel(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        \n",
+    "        self.hidden_linear = nn.Linear(1, 13)\n",
+    "        self.hidden_activation = nn.Tanh()\n",
+    "        self.output_linear = nn.Linear(13, 1)\n",
+    "        \n",
+    "    def forward(self, input):\n",
+    "        hidden_t = self.hidden_linear(input)\n",
+    "        activated_t = self.hidden_activation(hidden_t)\n",
+    "        output_t = self.output_linear(activated_t)\n",
+    "        \n",
+    "        return output_t\n",
+    "    \n",
+    "subclass_model = SubclassModel()\n",
+    "subclass_model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "seq\n",
+      "0.weight              torch.Size([11, 1]) 11\n",
+      "0.bias                torch.Size([11])    11\n",
+      "2.weight              torch.Size([1, 11]) 11\n",
+      "2.bias                torch.Size([1])     1\n",
+      "\n",
+      "namedseq\n",
+      "hidden_linear.weight  torch.Size([12, 1]) 12\n",
+      "hidden_linear.bias    torch.Size([12])    12\n",
+      "output_linear.weight  torch.Size([1, 12]) 12\n",
+      "output_linear.bias    torch.Size([1])     1\n",
+      "\n",
+      "subclass\n",
+      "hidden_linear.weight  torch.Size([13, 1]) 13\n",
+      "hidden_linear.bias    torch.Size([13])    13\n",
+      "output_linear.weight  torch.Size([1, 13]) 13\n",
+      "output_linear.bias    torch.Size([1])     1\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "for type_str, model in [('seq', seq_model), ('namedseq', namedseq_model), ('subclass', subclass_model)]:\n",
+    "    print(type_str)\n",
+    "    for name_str, param in model.named_parameters():\n",
+    "        print(\"{:21} {:19} {}\".format(name_str, str(param.shape), param.numel()))\n",
+    "        \n",
+    "    print()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "SubclassFunctionalModel(\n",
+       "  (hidden_linear): Linear(in_features=1, out_features=14, bias=True)\n",
+       "  (output_linear): Linear(in_features=14, out_features=1, bias=True)\n",
+       ")"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "class SubclassFunctionalModel(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        \n",
+    "        self.hidden_linear = nn.Linear(1, 14)  \n",
+    "                                                # <1>\n",
+    "        self.output_linear = nn.Linear(14, 1)\n",
+    "        \n",
+    "    def forward(self, input):\n",
+    "        hidden_t = self.hidden_linear(input)\n",
+    "        activated_t = torch.tanh(hidden_t) # <2>\n",
+    "        output_t = self.output_linear(activated_t)\n",
+    "        \n",
+    "        return output_t\n",
+    "    \n",
+    "func_model = SubclassFunctionalModel()\n",
+    "func_model"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

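Note: the new notebook builds the same one-hidden-layer network three ways (anonymous Sequential, OrderedDict-named Sequential, and an nn.Module subclass); the parameter dump shows all three register their parameters identically, so any of them drops into the same training code. A hedged sketch (the optimizer, data, and loss are illustrative choices, not taken from this commit):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Sequential(nn.Linear(1, 11), nn.Tanh(), nn.Linear(11, 1))
optimizer = optim.SGD(model.parameters(), lr=1e-2)  # works for all three styles

x = torch.randn(20, 1)                   # toy inputs
y = 3.0 * x + 1.0                        # toy linear targets

for epoch in range(500):
    loss = nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
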
File diff suppressed because it is too large
+ 56 - 0
p1ch7/1_datasets.ipynb


File diff suppressed because it is too large
+ 208 - 0
p1ch7/2_birds_airplanes.ipynb


File diff suppressed because it is too large
+ 270 - 0
p1ch7/4_convolution.ipynb


File diff suppressed because it is too large
+ 270 - 0
p1ch8/1_convolution.ipynb


+ 6 - 41
p2ch09/dsets.py

@@ -21,7 +21,7 @@ log = logging.getLogger(__name__)
 log.setLevel(logging.INFO)
 log.setLevel(logging.DEBUG)
 
-raw_cache = getCache('part2ch10_raw')
+raw_cache = getCache('part2ch08_raw')
 
 @functools.lru_cache(1)
 def getNoduleInfoList(requireDataOnDisk_bool=True):
@@ -136,67 +136,32 @@ class LunaDataset(Dataset):
                  test_stride=0,
                  isTestSet_bool=None,
                  series_uid=None,
-                 sortby_str='random',
-                 ratio_int=0,
             ):
-        self.ratio_int = ratio_int
-
         self.noduleInfo_list = copy.copy(getNoduleInfoList())
 
         if series_uid:
             self.noduleInfo_list = [x for x in self.noduleInfo_list if x[2] == series_uid]
 
+        # __init__ continued...
         if test_stride > 1:
             if isTestSet_bool:
                 self.noduleInfo_list = self.noduleInfo_list[::test_stride]
             else:
                 del self.noduleInfo_list[::test_stride]
 
-        if sortby_str == 'random':
-            random.shuffle(self.noduleInfo_list)
-        elif sortby_str == 'series_uid':
-            self.noduleInfo_list.sort(key=lambda x: (x[2], x[3])) # sorting by series_uid, center_xyz)
-        elif sortby_str == 'malignancy_size':
-            pass
-        else:
-            raise Exception("Unknown sort: " + repr(sortby_str))
-
-        self.benignIndex_list = [i for i, x in enumerate(self.noduleInfo_list) if not x[0]]
-        self.malignantIndex_list = [i for i, x in enumerate(self.noduleInfo_list) if x[0]]
-
-        log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
+        log.info("{!r}: {} {} samples".format(
             self,
             len(self.noduleInfo_list),
             "testing" if isTestSet_bool else "training",
-            len(self.benignIndex_list),
-            len(self.malignantIndex_list),
-            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
         ))
 
-    def shuffleSamples(self):
-        if self.ratio_int:
-            random.shuffle(self.benignIndex_list)
-            random.shuffle(self.malignantIndex_list)
-
     def __len__(self):
-        if self.ratio_int:
-            return 100000
-        else:
-            return len(self.noduleInfo_list)
+        return len(self.noduleInfo_list)
 
     def __getitem__(self, ndx):
-        if self.ratio_int:
-            malignant_ndx = ndx // (self.ratio_int + 1)
-
-            if ndx % (self.ratio_int + 1):
-                benign_ndx = ndx - 1 - malignant_ndx
-                nodule_ndx = self.benignIndex_list[benign_ndx % len(self.benignIndex_list)]
-            else:
-                nodule_ndx = self.malignantIndex_list[malignant_ndx % len(self.malignantIndex_list)]
-        else:
-            nodule_ndx = ndx
+        sample_ndx = ndx
 
-        isMalignant_bool, _diameter_mm, series_uid, center_xyz = self.noduleInfo_list[nodule_ndx]
+        isMalignant_bool, diameter_mm, series_uid, center_xyz = self.noduleInfo_list[sample_ndx]
 
         nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, (32, 32, 32))
 

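Note: with the balancing logic stripped out, the surviving test_stride branch is the whole train/test split: every Nth sample becomes test data, and training deletes those same entries. A plain-Python illustration of the slicing (values are illustrative):

samples = list(range(10))
test_stride = 3

test_samples = samples[::test_stride]    # [0, 3, 6, 9]

train_samples = list(range(10))
del train_samples[::test_stride]         # leaves [1, 2, 4, 5, 7, 8]
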
+ 15 - 23
p2ch09/vis.py

@@ -1,20 +1,17 @@
 import matplotlib
-matplotlib.use('nbagg')
-
 import numpy as np
 import matplotlib.pyplot as plt
 
-from p2ch11_old.dsets import Ct, LunaDataset
+from p2ch09.dsets import Ct, LunaDataset
 
 clim=(0.0, 1.3)
 
-def findMalignantSamples(start_ndx=0, limit=10):
+def findMalignantSamples(start_ndx=0, limit=100):
     ds = LunaDataset()
 
     malignantSample_list = []
-    for sample_tup in ds.sample_list:
-        if sample_tup[2]:
-            print(len(malignantSample_list), sample_tup)
+    for sample_tup in ds.noduleInfo_list:
+        if sample_tup[0]:
             malignantSample_list.append(sample_tup)
 
         if len(malignantSample_list) >= limit:
@@ -22,9 +19,9 @@ def findMalignantSamples(start_ndx=0, limit=10):
 
     return malignantSample_list
 
-def showNodule(series_uid, batch_ndx=None, **kwargs):
-    ds = LunaDataset(series_uid=series_uid, **kwargs)
-    malignant_list = [i for i, x in enumerate(ds.sample_list) if x[2]]
+def showNodule(series_uid, batch_ndx=None):
+    ds = LunaDataset(series_uid=series_uid)
+    malignant_list = [i for i, x in enumerate(ds.noduleInfo_list) if x[0]]
 
     if batch_ndx is None:
         if malignant_list:
@@ -34,20 +31,15 @@ def showNodule(series_uid, batch_ndx=None, **kwargs):
             batch_ndx = 0
 
     ct = Ct(series_uid)
-    # ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
-    malignant_tensor, diameter_mm, series_uid, center_irc, nodule_tensor = ds[batch_ndx]
-    ct_ary = nodule_tensor[1].numpy()
-
+    ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
+    ct_ary = ct_tensor[0].numpy()
 
     fig = plt.figure(figsize=(15, 25))
 
     group_list = [
-        #[0,1,2],
-        [3,4,5],
-        [6,7,8],
-        [9,10,11],
-        #[12,13,14],
-        #[15]
+        [9,11,13],
+        [15, 16, 17],
+        [19,21,23],
     ]
 
     subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
@@ -78,9 +70,9 @@ def showNodule(series_uid, batch_ndx=None, **kwargs):
         for col, index in enumerate(index_list):
             subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
             subplot.set_title('slice {}'.format(index))
-            plt.imshow(ct_ary[index*2], clim=clim, cmap='gray')
+            plt.imshow(ct_ary[index], clim=clim, cmap='gray')
+
 
+    print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list)
 
-    print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list, ct.vxSize_xyz)
 
-    return ct_ary

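Note: the revised helpers are driven as in the exploration notebook below; index [2] picks the series_uid out of each (isMalignant_bool, diameter_mm, series_uid, center_xyz) tuple. A usage sketch (the sample index is arbitrary):

from p2ch09.vis import findMalignantSamples, showNodule

malignantSample_list = findMalignantSamples(limit=100)
series_uid = malignantSample_list[0][2]  # third field is the series_uid
showNodule(series_uid)
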
+ 361 - 0
p2ch09_explore_data.ipynb

@@ -0,0 +1,361 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%matplotlib inline\n",
+    "import numpy as np"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from p2ch09.dsets import getNoduleInfoList, getCt\n",
+    "noduleInfo_list = getNoduleInfoList(requireDataOnDisk_bool=False)\n",
+    "malignantInfo_list = [x for x in noduleInfo_list if x[0]]\n",
+    "diameter_list = [x[1] for x in malignantInfo_list]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1351\n",
+      "(True, 32.27003025, '1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886', (67.61451718, 85.02525992, -109.8084416))\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(len(malignantInfo_list))\n",
+    "print(malignantInfo_list[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "   0  32.3 mm\n",
+      " 100  17.7 mm\n",
+      " 200  13.0 mm\n",
+      " 300  10.0 mm\n",
+      " 400   8.2 mm\n",
+      " 500   7.0 mm\n",
+      " 600   6.3 mm\n",
+      " 700   5.7 mm\n",
+      " 800   5.1 mm\n",
+      " 900   4.7 mm\n",
+      "1000   4.0 mm\n",
+      "1100   0.0 mm\n",
+      "1200   0.0 mm\n",
+      "1300   0.0 mm\n"
+     ]
+    }
+   ],
+   "source": [
+    "for i in range(0, len(diameter_list), 100):\n",
+    "    print('{:4}  {:4.1f} mm'.format(i, diameter_list[i]))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(True, 32.27003025, '1.3.6.1.4.1.14519.5.2.1.6279.6001.287966244644280690737019247886', (67.61451718, 85.02525992, -109.8084416))\n",
+      "(True, 30.61040636, '1.3.6.1.4.1.14519.5.2.1.6279.6001.112740418331256326754121315800', (47.90350511, 37.60442008, -99.93417567))\n",
+      "(True, 30.61040636, '1.3.6.1.4.1.14519.5.2.1.6279.6001.112740418331256326754121315800', (44.19, 37.79, -107.01))\n",
+      "(True, 30.61040636, '1.3.6.1.4.1.14519.5.2.1.6279.6001.112740418331256326754121315800', (40.69, 32.19, -97.15))\n",
+      "(True, 27.44242293, '1.3.6.1.4.1.14519.5.2.1.6279.6001.943403138251347598519939390311', (-45.29440163, 74.86925386, -97.52812481))\n",
+      "(True, 27.07544345, '1.3.6.1.4.1.14519.5.2.1.6279.6001.481278873893653517789960724156', (-102.571208, -5.186558766, -205.1033412))\n",
+      "(True, 26.83708074, '1.3.6.1.4.1.14519.5.2.1.6279.6001.487268565754493433372433148666', (121.152909372, 12.9136003304, -159.399497186))\n",
+      "(True, 26.83708074, '1.3.6.1.4.1.14519.5.2.1.6279.6001.487268565754493433372433148666', (118.8539408, 11.54202797, -165.5042458))\n",
+      "(True, 25.87269662, '1.3.6.1.4.1.14519.5.2.1.6279.6001.177086402277715068525592995222', (-66.628286875, 57.151972075, -110.12035075))\n",
+      "(True, 25.41540526, '1.3.6.1.4.1.14519.5.2.1.6279.6001.219618492426142913407827034169', (-101.7504204, -95.65460516, -138.4943211))\n",
+      "(True, 0.0, '1.3.6.1.4.1.14519.5.2.1.6279.6001.107109359065300889765026303943', (-100.57, -66.23, -218.76))\n",
+      "(True, 0.0, '1.3.6.1.4.1.14519.5.2.1.6279.6001.106379658920626694402549886949', (-71.09, 68.3, -160.4))\n",
+      "(True, 0.0, '1.3.6.1.4.1.14519.5.2.1.6279.6001.102681962408431413578140925249', (106.18, 12.61, -96.81))\n",
+      "(True, 0.0, '1.3.6.1.4.1.14519.5.2.1.6279.6001.102681962408431413578140925249', (96.2846726653, 19.0348690723, -88.478440818))\n",
+      "(True, 0.0, '1.3.6.1.4.1.14519.5.2.1.6279.6001.100621383016233746780170740405', (89.32, 190.84, -516.82))\n",
+      "(True, 0.0, '1.3.6.1.4.1.14519.5.2.1.6279.6001.100621383016233746780170740405', (89.32, 143.23, -427.1))\n",
+      "(True, 0.0, '1.3.6.1.4.1.14519.5.2.1.6279.6001.100621383016233746780170740405', (85.12, 152.33, -425.7))\n",
+      "(True, 0.0, '1.3.6.1.4.1.14519.5.2.1.6279.6001.100621383016233746780170740405', (8.8, 174.74, -401.87))\n",
+      "(True, 0.0, '1.3.6.1.4.1.14519.5.2.1.6279.6001.100621383016233746780170740405', (5.99, 171.94, -398.37))\n",
+      "(True, 0.0, '1.3.6.1.4.1.14519.5.2.1.6279.6001.100621383016233746780170740405', (1.79, 166.34, -408.88))\n"
+     ]
+    }
+   ],
+   "source": [
+    "for nodule_tup in malignantInfo_list[:10]:\n",
+    "    print(nodule_tup)\n",
+    "for nodule_tup in malignantInfo_list[-10:]:\n",
+    "    print(nodule_tup)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(array([323, 466, 248, 111,  71,  57,  37,  29,   5,   4], dtype=int64),\n",
+       " array([ 0.        ,  3.22700302,  6.45400605,  9.68100907, 12.9080121 ,\n",
+       "        16.13501512, 19.36201815, 22.58902117, 25.8160242 , 29.04302722,\n",
+       "        32.27003025]))"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "np.histogram(diameter_list)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2018-12-11 11:24:57,865 INFO     pid:30236 p2ch09.dsets:201:__init__ <p2ch09.dsets.LunaDataset object at 0x000001B30A802438>: 551065 training samples\n"
+     ]
+    }
+   ],
+   "source": [
+    "from p2ch09.vis import findMalignantSamples, showNodule\n",
+    "malignantSample_list = findMalignantSamples()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2018-12-11 11:25:01,788 INFO     pid:30236 p2ch09.dsets:201:__init__ <p2ch09.dsets.LunaDataset object at 0x000001B30A7D6668>: 602 training samples\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1.3.6.1.4.1.14519.5.2.1.6279.6001.183982839679953938397312236359 0 True [0, 1, 2, 3, 4, 5, 6]\n"
+     ]
+    }
+   ],
+   "source": [
+    "series_uid = malignantSample_list[11][2]\n",
+    "showNodule(series_uid)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2018-12-11 11:25:08,042 INFO     pid:30236 p2ch09.dsets:201:__init__ <p2ch09.dsets.LunaDataset object at 0x000001B30E7C6BE0>: 605 training samples\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "1.3.6.1.4.1.14519.5.2.1.6279.6001.126264578931778258890371755354 0 True [0]\n"
+     ]
+    }
+   ],
+   "source": [
+    "series_uid = '1.3.6.1.4.1.14519.5.2.1.6279.6001.126264578931778258890371755354'\n",
+    "showNodule(series_uid)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "C:\\Users\\elis\\Miniconda3\\envs\\book\\lib\\site-packages\\ipyvolume\\serialize.py:81: RuntimeWarning: invalid value encountered in true_divide\n",
+      "  gradient = gradient / np.sqrt(gradient[0]**2 + gradient[1]**2 + gradient[2]**2)\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "339fa8710d8b459182d9d3afeb08f720",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "VBox(children=(VBox(children=(HBox(children=(Label(value='levels:'), FloatSlider(value=0.25, max=1.0, step=0.0…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "import ipyvolume as ipv\n",
+    "V = np.zeros((128,128,128)) # our 3d array\n",
+    "# outer box\n",
+    "V[30:-30,30:-30,30:-30] = 0.75\n",
+    "V[35:-35,35:-35,35:-35] = 0.0\n",
+    "# inner box\n",
+    "V[50:-50,50:-50,50:-50] = 0.25\n",
+    "V[55:-55,55:-55,55:-55] = 0.0\n",
+    "ipv.quickvolshow(V, level=[0.25, 0.75], opacity=0.03, level_width=0.1, data_min=0, data_max=1)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "C:\\Users\\elis\\Miniconda3\\envs\\book\\lib\\site-packages\\ipyvolume\\widgets.py:179: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n",
+      "  data_view = self.data_original[view]\n",
+      "C:\\Users\\elis\\Miniconda3\\envs\\book\\lib\\site-packages\\ipyvolume\\utils.py:204: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.\n",
+      "  data = (data[slices1] + data[slices2])/2\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "61e1eb24c53149d8bdc8bd6188257862",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "VBox(children=(VBox(children=(HBox(children=(Label(value='levels:'), FloatSlider(value=0.25, max=1.0, step=0.0…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "ct = getCt(series_uid)\n",
+    "ipv.quickvolshow(ct.ary, level=[0.25, 0.5, 0.9], opacity=0.1, level_width=0.1, data_min=0, data_max=2)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from p2ch10.dsets import getCt\n",
+    "ct = getCt(series_uid)\n",
+    "air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask = ct.build3dLungMask()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "C:\\Users\\elis\\Miniconda3\\envs\\book\\lib\\site-packages\\ipyvolume\\serialize.py:81: RuntimeWarning: invalid value encountered in sqrt\n",
+      "  gradient = gradient / np.sqrt(gradient[0]**2 + gradient[1]**2 + gradient[2]**2)\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "9c40dfe3a3cc4b49aebce4cbee8e3d62",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "VBox(children=(VBox(children=(HBox(children=(Label(value='levels:'), FloatSlider(value=0.17, max=1.0, step=0.0…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "bones = ct.ary * (ct.ary > 1.5)\n",
+    "lungs = ct.ary * air_mask\n",
+    "ipv.figure()\n",
+    "ipv.pylab.volshow(bones + lungs, level=[0.17, 0.17, 0.23], data_min=0.1, data_max=0.9)\n",
+    "ipv.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from mayavi import mlab\n",
+    "mlab.init_notebook()\n",
+    "mlab.test_plot3d()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

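Note: the np.histogram cell returns raw counts and bin edges; plotting the same distribution takes only a couple more lines. A sketch (assumes the notebook's diameter_list and matplotlib; the bin count is an arbitrary choice):

import numpy as np
import matplotlib.pyplot as plt

counts, edges = np.histogram(diameter_list, bins=10)
plt.bar(edges[:-1], counts, width=np.diff(edges), align='edge')
plt.xlabel('nodule diameter (mm)')
plt.ylabel('count')
plt.show()
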
+ 20 - 478
p2ch10/dsets.py

@@ -2,37 +2,26 @@ import copy
 import csv
 import functools
 import glob
-import itertools
-import math
 import os
 import random
 
-from collections import namedtuple
-
 import SimpleITK as sitk
 
-import scipy.ndimage.morphology
-
 import numpy as np
 import torch
 import torch.cuda
-from torch.utils.data import Dataset, DataLoader
-from torch.utils.data.sampler import Sampler
+from torch.utils.data import Dataset
 
 from util.disk import getCache
-from util.util import XyzTuple, xyz2irc, IrcTuple
+from util.util import XyzTuple, xyz2irc
 from util.logconf import logging
-from util.affine import affine_grid_generator
 
 log = logging.getLogger(__name__)
 # log.setLevel(logging.WARN)
 log.setLevel(logging.INFO)
 log.setLevel(logging.DEBUG)
 
-raw_cache = getCache('part2ch11_raw')
-cubic_cache = getCache('part2ch11_cubic')
-
-NoduleInfoTuple = namedtuple('NoduleInfoTuple', 'isMalignant_bool, diameter_mm, series_uid, center_xyz')
+raw_cache = getCache('part2ch09_raw')
 
 @functools.lru_cache(1)
 def getNoduleInfoList(requireDataOnDisk_bool=True):
@@ -72,13 +61,13 @@ def getNoduleInfoList(requireDataOnDisk_bool=True):
                     candidateDiameter_mm = annotationDiameter_mm
                     break
 
-            noduleInfo_list.append(NoduleInfoTuple(isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
+            noduleInfo_list.append((isMalignant_bool, candidateDiameter_mm, series_uid, candidateCenter_xyz))
 
     noduleInfo_list.sort(reverse=True)
     return noduleInfo_list
 
 class Ct(object):
-    def __init__(self, series_uid, buildMasks_bool=True):
+    def __init__(self, series_uid):
         mhd_path = glob.glob('data-unversioned/part2/luna/subset*/{}.mhd'.format(series_uid))[0]
 
         ct_mhd = sitk.ReadImage(mhd_path)
@@ -103,108 +92,6 @@ class Ct(object):
         self.vxSize_xyz = XyzTuple(*ct_mhd.GetSpacing())
         self.direction_tup = tuple(int(round(x)) for x in ct_mhd.GetDirection())
 
-        noduleInfo_list = getNoduleInfoList()
-        self.benignInfo_list = [ni_tup
-                                for ni_tup in noduleInfo_list
-                                    if not ni_tup.isMalignant_bool
-                                        and ni_tup.series_uid == self.series_uid]
-        self.benign_mask = self.buildAnnotationMask(self.benignInfo_list)[0]
-        self.benign_indexes = sorted(set(self.benign_mask.nonzero()[0]))
-
-        self.malignantInfo_list = [ni_tup
-                                   for ni_tup in noduleInfo_list
-                                        if ni_tup.isMalignant_bool
-                                            and ni_tup.series_uid == self.series_uid]
-        self.malignant_mask = self.buildAnnotationMask(self.malignantInfo_list)[0]
-        self.malignant_indexes = sorted(set(self.malignant_mask.nonzero()[0]))
-
-    def buildAnnotationMask(self, noduleInfo_list, threshold_gcc = 0.5):
-        boundingBox_ary = np.zeros_like(self.ary, dtype=np.bool)
-
-        for noduleInfo_tup in noduleInfo_list:
-            center_irc = xyz2irc(noduleInfo_tup.center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
-            center_index = int(center_irc.index)
-            center_row = int(center_irc.row)
-            center_col = int(center_irc.col)
-
-            index_radius = 2
-            try:
-                while self.ary[center_index + index_radius, center_row, center_col] > threshold_gcc and \
-                            self.ary[center_index - index_radius, center_row, center_col] > threshold_gcc:
-                    index_radius += 1
-            except IndexError:
-                index_radius -= 1
-
-            row_radius = 2
-            try:
-                while self.ary[center_index, center_row + row_radius, center_col] > threshold_gcc and \
-                            self.ary[center_index, center_row - row_radius, center_col] > threshold_gcc:
-                    row_radius += 1
-            except IndexError:
-                row_radius -= 1
-
-            col_radius = 2
-            try:
-                while self.ary[center_index, center_row, center_col + col_radius] > threshold_gcc and \
-                            self.ary[center_index, center_row, center_col - col_radius] > threshold_gcc:
-                    col_radius += 1
-            except IndexError:
-                col_radius -= 1
-
-            # assert index_radius > 0, repr([noduleInfo_tup.center_xyz, center_irc, self.ary[center_index, center_row, center_col]])
-            # assert row_radius > 0
-            # assert col_radius > 0
-
-
-            slice_tup = (
-                slice(
-                    # max(0, center_index - index_radius),
-                    center_index - index_radius,
-                    center_index + index_radius + 1,
-                ),
-                slice(
-                    # max(0, center_row - row_radius),
-                    center_row - row_radius,
-                    center_row + row_radius + 1,
-                ),
-                slice(
-                    # max(0, center_col - col_radius),
-                    center_col - col_radius,
-                    center_col + row_radius + 1,
-                ),
-            )
-
-            boundingBox_ary[slice_tup] = True
-
-        thresholded_ary = boundingBox_ary & (self.ary > threshold_gcc)
-        mask_ary = scipy.ndimage.morphology.binary_dilation(thresholded_ary, iterations=2)
-
-        return mask_ary, thresholded_ary, boundingBox_ary
-
-    def build2dLungMask(self, mask_ndx, threshold_gcc = 0.7):
-        dense_mask = self.ary[mask_ndx] > threshold_gcc
-        denoise_mask = scipy.ndimage.morphology.binary_closing(dense_mask, iterations=2)
-        tissue_mask = scipy.ndimage.morphology.binary_opening(denoise_mask, iterations=10)
-        body_mask = scipy.ndimage.morphology.binary_fill_holes(tissue_mask)
-        air_mask = scipy.ndimage.morphology.binary_fill_holes(body_mask & ~tissue_mask)
-
-        lung_mask = scipy.ndimage.morphology.binary_dilation(air_mask, iterations=2)
-
-        return air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask
-
-    def build3dLungMask(self):
-        air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask = mask_list = \
-            [np.zeros_like(self.ary, dtype=np.bool) for _ in range(6)]
-
-        for mask_ndx in range(self.ary.shape[0]):
-            for i, mask_ary in enumerate(self.build2dLungMask(mask_ndx)):
-                mask_list[i][mask_ndx] = mask_ary
-
-        return air_mask, lung_mask, dense_mask, denoise_mask, tissue_mask, body_mask
-
-
-
-
     def getRawNodule(self, center_xyz, width_irc):
         center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
 
@@ -229,213 +116,34 @@ class Ct(object):
 
             slice_list.append(slice(start_ndx, end_ndx))
 
-        ct_chunk = self.ary[tuple(slice_list)]
+        ct_chunk = self.ary[slice_list]
 
         return ct_chunk, center_irc
 
-    def getCubicInputChunk(self, center_xyz, maxWidth_mm):
-        center_irc = xyz2irc(center_xyz, self.origin_xyz, self.vxSize_xyz, self.direction_tup)
-
-        ct_start = [int(round(i)) for i in xyz2irc(tuple(x - maxWidth_mm / 2 for x in center_xyz), self.origin_xyz, self.vxSize_xyz, self.direction_tup)]
-        ct_end = [int(round(i)) + 1 for i in xyz2irc(tuple(x + maxWidth_mm / 2 for x in center_xyz), self.origin_xyz, self.vxSize_xyz, self.direction_tup)]
-
-        for axis in range(3):
-            if ct_start[axis] > ct_end[axis]:
-                ct_start[axis], ct_end[axis] = ct_end[axis], ct_start[axis]
-
-        pad_start = [0, 0, 0]
-        pad_end = [ct_end[axis] - ct_start[axis] for axis in range(3)]
-        # log.info([ct_end, ct_start, pad_end])
-        chunk_ary = np.zeros(pad_end, dtype=np.float32)
-
-        for axis in range(3):
-            if ct_start[axis] < 0:
-                pad_start[axis] = -ct_start[axis]
-                ct_start[axis] = 0
 
-            if ct_end[axis] > self.ary.shape[axis]:
-                pad_end[axis] -= ct_end[axis] - self.ary.shape[axis]
-                ct_end[axis] = self.ary.shape[axis]
-
-        pad_slices = tuple(slice(s,e) for s, e in zip(pad_start, pad_end))
-        ct_slices = tuple(slice(s,e) for s, e in zip(ct_start, ct_end))
-        chunk_ary[pad_slices] = self.ary[ct_slices]
-
-        return chunk_ary, center_irc
-
-
-ctCache_depth = 3
-@functools.lru_cache(ctCache_depth, typed=True)
+@functools.lru_cache(1, typed=True)
 def getCt(series_uid):
     return Ct(series_uid)
 
-@raw_cache.memoize(typed=True)
-def getCtSize(series_uid):
-    ct = Ct(series_uid, buildMasks_bool=False)
-    return tuple(ct.ary.shape)
-
-# @raw_cache.memoize(typed=True)
-# def getCtLungExtents(series_uid):
-#     ct = getCt(series_uid)
-#     return (int(min(ct.lung_indexes)), int(max(ct.lung_indexes)))
-
 @raw_cache.memoize(typed=True)
 def getCtRawNodule(series_uid, center_xyz, width_irc):
     ct = getCt(series_uid)
     ct_chunk, center_irc = ct.getRawNodule(center_xyz, width_irc)
-
-    return ct_chunk, center_irc
-
-# clamp_value = 1.5
-@functools.lru_cache(1, typed=True)
-@cubic_cache.memoize(typed=True)
-def getCtCubicChunk(series_uid, center_xyz, maxWidth_mm):
-    ct = getCt(series_uid)
-    ct_chunk, center_irc = ct.getCubicInputChunk(center_xyz, maxWidth_mm)
-
-    # # ct_chunk has been clamped to [0, 2] at this point
-    # # We are going to convert to uint8 to reduce size on disk and loading time
-    # ct_chunk[ct_chunk > clamp_value] = clamp_value
-    # ct_chunk *= 255/clamp_value
-    # ct_chunk = np.array(ct_chunk, dtype=np.uint8)
-
     return ct_chunk, center_irc
 
-def getCtAugmentedNodule(augmentation_dict, series_uid, center_xyz, width_mm, voxels_int, maxWidth_mm=32.0, use_cache=True):
-    assert width_mm <= maxWidth_mm
-
-    if use_cache:
-        cubic_chunk, center_irc = getCtCubicChunk(series_uid, center_xyz, maxWidth_mm)
-    else:
-        ct = getCt(series_uid)
-        ct_chunk, center_irc = ct.getCubicInputChunk(center_xyz, maxWidth_mm)
-
-    slice_list = []
-    for axis in range(3):
-        crop_size = cubic_chunk.shape[axis] * width_mm / maxWidth_mm
-        crop_size = int(math.ceil(crop_size))
-        start_ndx = (cubic_chunk.shape[axis] - crop_size) // 2
-        end_ndx = start_ndx + crop_size
-
-        slice_list.append(slice(start_ndx, end_ndx))
-
-    cropped_chunk = cubic_chunk[slice_list]
-
-    # # inflate cropped_chunk back to float32
-    # cropped_chunk = np.array(cropped_chunk, dtype=np.float32)
-    # cropped_chunk *= clamp_value/255
-    cropped_tensor = torch.tensor(cropped_chunk).unsqueeze(0).unsqueeze(0)
-
-    transform_tensor = torch.eye(4).to(torch.float64)
-
-    # Scale and Mirror
-    for i in range(3):
-        if 'scale' in augmentation_dict:
-            scale_float = augmentation_dict['scale']
-            transform_tensor[i,i] *= 1.0 - scale_float/2.0 + (random.random() * scale_float)
-
-        if 'mirror' in augmentation_dict:
-            if random.random() > 0.5:
-                transform_tensor[i,i] *= -1
-
-    # Rotate
-    if 'rotate' in augmentation_dict:
-        angle_rad = random.random() * math.pi * 2
-        s = math.sin(angle_rad)
-        c = math.cos(angle_rad)
-        c1 = 1 - c
-
-        axis_tensor = torch.rand([3], dtype=torch.float64)
-        axis_tensor /= axis_tensor.pow(2).sum().pow(0.5)
-
-        z, y, x = axis_tensor
-        rotation_tensor = torch.tensor([
-            [x*x*c1 +   c, y*x*c1 - z*s, z*x*c1 + y*s, 0],
-            [x*y*c1 + z*s, y*y*c1 +   c, z*y*c1 - x*s, 0],
-            [x*z*c1 - y*s, y*z*c1 + x*s, z*z*c1 +   c, 0],
-            [0, 0, 0, 1],
-        ], dtype=torch.float64)
-
-        transform_tensor @= rotation_tensor
-
-    # Transform into final desired shape
-    affine_tensor = affine_grid_generator(
-            transform_tensor[:3].unsqueeze(0).to(torch.float32),
-            torch.Size([1, 1, voxels_int, voxels_int, voxels_int])
-        )
-
-    zoomed_chunk = torch.nn.functional.grid_sample(
-            cropped_tensor,
-            affine_tensor,
-            padding_mode='border'
-        ).to('cpu')
-
-    # Noise
-    if 'noise' in augmentation_dict:
-        noise_tensor = torch.randn(
-                zoomed_chunk.size(),
-                dtype=zoomed_chunk.dtype,
-            )
-        noise_tensor *= augmentation_dict['noise']
-        zoomed_chunk += noise_tensor
-
-    return zoomed_chunk[0,0], center_irc
-
-
-class LunaPrepcacheDataset(Dataset):
-    def __init__(self):
-        self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
-
-    def __len__(self):
-        return len(self.series_list)
-
-    def __getitem__(self, ndx):
-        getCtSize(self.series_list[ndx])
-        # getCtLungExtents(self.series_list[ndx])
-
-        return 0
-
-
-class LunaClassificationDataset(Dataset):
+class LunaDataset(Dataset):
     def __init__(self,
                  test_stride=0,
                  isTestSet_bool=None,
                  series_uid=None,
                  sortby_str='random',
-                 ratio_int=0,
-                 scaled_bool=False,
-                 multiscaled_bool=False,
-                 augmented_bool=False,
-                 noduleInfo_list=None,
             ):
-        self.ratio_int = ratio_int
-        self.scaled_bool = scaled_bool
-        self.multiscaled_bool = multiscaled_bool
-
-        if augmented_bool:
-            self.augmentation_dict = {
-                'mirror': True,
-                'rotate': True,
-            }
-
-            if isTestSet_bool:
-                self.augmentation_dict['scale'] = 0.25
-            else:
-                self.augmentation_dict['scale'] = 0.5
-                self.augmentation_dict['noise'] = 0.1
-        else:
-            self.augmentation_dict = {}
-
-        if noduleInfo_list:
-            self.noduleInfo_list = copy.copy(noduleInfo_list)
-            self.use_cache = False
-        else:
-            self.noduleInfo_list = copy.copy(getNoduleInfoList())
-            self.use_cache = True
+        self.noduleInfo_list = copy.copy(getNoduleInfoList())
 
         if series_uid:
             self.noduleInfo_list = [x for x in self.noduleInfo_list if x[2] == series_uid]
 
+        # __init__ continued...
         if test_stride > 1:
             if isTestSet_bool:
                 self.noduleInfo_list = self.noduleInfo_list[::test_stride]
@@ -451,198 +159,32 @@ class LunaClassificationDataset(Dataset):
         else:
             raise Exception("Unknown sort: " + repr(sortby_str))
 
-        self.benignIndex_list = [i for i, x in enumerate(self.noduleInfo_list) if not x[0]]
-        self.malignantIndex_list = [i for i, x in enumerate(self.noduleInfo_list) if x[0]]
-
-        log.info("{!r}: {} {} samples, {} ben, {} mal, {} ratio".format(
+        log.info("{!r}: {} {} samples".format(
             self,
             len(self.noduleInfo_list),
             "testing" if isTestSet_bool else "training",
-            len(self.benignIndex_list),
-            len(self.malignantIndex_list),
-            '{}:1'.format(self.ratio_int) if self.ratio_int else 'unbalanced'
         ))
 
-    def shuffleSamples(self):
-        if self.ratio_int:
-            random.shuffle(self.benignIndex_list)
-            random.shuffle(self.malignantIndex_list)
 
     def __len__(self):
-        if self.ratio_int:
-            # return 10000
-            return 100000
-        elif self.augmentation_dict:
-            return len(self.noduleInfo_list) * 5
-        else:
-            return len(self.noduleInfo_list)
+        # if self.ratio_int:
+        #     return min(len(self.benignIndex_list), len(self.malignantIndex_list)) * 4 * 90
+        # else:
+        return len(self.noduleInfo_list)
 
     def __getitem__(self, ndx):
-        if self.ratio_int:
-            malignant_ndx = ndx // (self.ratio_int + 1)
-
-            if ndx % (self.ratio_int + 1):
-                benign_ndx = ndx - 1 - malignant_ndx
-                nodule_ndx = self.benignIndex_list[benign_ndx % len(self.benignIndex_list)]
-            else:
-                nodule_ndx = self.malignantIndex_list[malignant_ndx % len(self.malignantIndex_list)]
+        sample_ndx = ndx
 
-            augmentation_dict = self.augmentation_dict
-        else:
-            nodule_ndx = ndx % len(self.noduleInfo_list)
+        isMalignant_bool, _diameter_mm, series_uid, center_xyz = self.noduleInfo_list[sample_ndx]
 
-            if ndx < len(self.noduleInfo_list):
-                augmentation_dict = {}
-            else:
-                augmentation_dict = self.augmentation_dict
+        nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, (32, 32, 32))
 
-        isMalignant_bool, _diameter_mm, series_uid, center_xyz = self.noduleInfo_list[nodule_ndx]
+        nodule_tensor = torch.from_numpy(nodule_ary)
+        nodule_tensor = nodule_tensor.unsqueeze(0)
 
-        if self.scaled_bool:
-            channel_list = []
-            voxels_int = 32
-
-            if self.multiscaled_bool:
-                width_list = [8, 16, 32]
-            else:
-                width_list = [24]
-
-            for width_mm in width_list:
-                nodule_ary, center_irc = getCtAugmentedNodule(augmentation_dict, series_uid, center_xyz, width_mm, voxels_int)
-                # in:  dim=3, Index x Row x Col
-                # out: dim=4, Channel x Index x Row x Col
-                nodule_ary = nodule_ary.unsqueeze(0)
-                channel_list.append(nodule_ary)
-
-            nodule_tensor = torch.cat(channel_list)
-
-        else:
-            nodule_ary, center_irc = getCtRawNodule(series_uid, center_xyz, (32, 32, 32))
-            nodule_ary = np.expand_dims(nodule_ary, 0)
-            nodule_tensor = torch.from_numpy(nodule_ary)
-
-        # dim=1
         malignant_tensor = torch.tensor([isMalignant_bool], dtype=torch.float32)
 
         return nodule_tensor, malignant_tensor, series_uid, center_irc
-        #
-        # return malignant_tensor, diameter_mm, series_uid, center_irc, nodule_tensor
-
-class Luna2dSegmentationDataset(Dataset):
-    purpose_str = 'general'
-
-    def __init__(self,
-                 contextSlices_count=2,
-                 series_uid=None,
-                 test_stride=0,
-            ):
-        self.contextSlices_count = contextSlices_count
-        if series_uid:
-            self.series_list = [series_uid]
-        else:
-            self.series_list = sorted(set(noduleInfo_tup.series_uid for noduleInfo_tup in getNoduleInfoList()))
-        self.cullTrainOrTestSeries(test_stride)
-
-
-        self.sample_list = []
-        for series_uid in self.series_list:
-            self.sample_list.extend([(series_uid, i) for i in range(int(getCtSize(series_uid)[0]))])
-
-        log.info("{!r}: {} {} series, {} slices".format(
-            self,
-            len(self.series_list),
-            self.purpose_str,
-            len(self.sample_list),
-        ))
-
-    def cullTrainOrTestSeries(self, test_stride):
-        assert test_stride == 0
-
-    def __len__(self):
-        return len(self.sample_list) #// 100
-
-    def __getitem__(self, ndx):
-        if isinstance(ndx, int):
-            series_uid, sample_ndx = self.sample_list[ndx % len(self.sample_list)]
-        else:
-            series_uid, sample_ndx = ndx
-        ct = getCt(series_uid)
-
-        ct_tensor = torch.zeros((self.contextSlices_count * 2 + 2, 512, 512))
-        masks_tensor = torch.zeros((2, 512, 512))
-
-        start_ndx = sample_ndx - self.contextSlices_count
-        end_ndx = sample_ndx + self.contextSlices_count + 1
-        for i, context_ndx in enumerate(range(start_ndx, end_ndx)):
-            context_ndx = max(context_ndx, 0)
-            context_ndx = min(context_ndx, ct.ary.shape[0] - 1)
-
-            ct_tensor[i] = torch.from_numpy(ct.ary[context_ndx].astype(np.float32))
-
-        air_mask, lung_mask = ct.build2dLungMask(sample_ndx)[:2]
-
-        ct_tensor[-1] = torch.from_numpy(lung_mask.astype(np.float32))
-
-        mal_mask = ct.malignant_mask[sample_ndx] & lung_mask
-        ben_mask = ct.benign_mask[sample_ndx] & air_mask
-
-        masks_tensor[0] = torch.from_numpy(mal_mask.astype(np.float32))
-        masks_tensor[1] = torch.from_numpy((mal_mask | ben_mask).astype(np.float32))
-        # masks_tensor[1] = torch.from_numpy(ben_mask.astype(np.float32))
-
-        return ct_tensor.contiguous(), masks_tensor.contiguous(), ct.series_uid, sample_ndx
-
-
-class TrainingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
-    purpose_str = 'training'
-
-    def __init__(self, *args, **kwargs):
-        self.needsShuffle_bool = True
-        super().__init__(*args, **kwargs)
-
-    def cullTrainOrTestSeries(self, test_stride):
-        assert test_stride > 0, test_stride
-        del self.series_list[::test_stride]
-        assert self.series_list
-
-    def __len__(self):
-        # return 100
-        # return 1000
-        # return 10000
-        return 20000
-        # return 40000
-
-    def __getitem__(self, ndx):
-        if self.needsShuffle_bool:
-            random.shuffle(self.series_list)
-            self.needsShuffle_bool = False
-
-        if random.random() < 0.01:
-            self.series_list.append(self.series_list.pop(0))
-
-        if isinstance(ndx, int):
-            series_uid = self.series_list[ndx % ctCache_depth]
-            ct = getCt(series_uid)
-            sample_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
-            # series_uid, sample_ndx = self.sample_list[ndx % len(self.sample_list)]
-        else:
-            series_uid, sample_ndx = ndx
-
-        # if ndx % 2 == 0:
-        #     sample_ndx = random.choice(ct.malignant_indexes or ct.benign_indexes)
-        # else: #if ndx % 2 == 2:
-        #     sample_ndx = random.choice(ct.benign_indexes)
-        # else:
-        #     sample_ndx = random.randint(*self.series2extents_dict[series_uid])
-
-        return super().__getitem__((series_uid, sample_ndx))
-
 
-class TestingLuna2dSegmentationDataset(Luna2dSegmentationDataset):
-    purpose_str = 'testing'
 
-    def cullTrainOrTestSeries(self, test_stride):
-        assert test_stride > 0
-        self.series_list = self.series_list[::test_stride]
-        assert self.series_list
 

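Note: what survives the deletions is a two-tier cache: functools.lru_cache(1) keeps the most recently loaded Ct in memory, while raw_cache.memoize persists each extracted chunk to disk. A generic sketch of the pattern (the dict-backed memoize and load_ct below are stand-ins for the project's util.disk.getCache machinery, which is not shown in this commit):

import functools

_disk = {}                               # stand-in for an on-disk cache

def memoize(fn):
    @functools.wraps(fn)
    def wrapper(*args):
        if args not in _disk:
            _disk[args] = fn(*args)      # expensive call runs once per key
        return _disk[args]
    return wrapper

@functools.lru_cache(1, typed=True)      # at most one CT held in RAM
def get_ct(series_uid):
    return load_ct(series_uid)           # hypothetical loader for the sketch

@memoize                                 # each chunk computed once, reused after
def get_ct_raw_nodule(series_uid, center_xyz, width_irc):
    return get_ct(series_uid).getRawNodule(center_xyz, width_irc)
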
+ 0 - 16
p2ch10/model.py

@@ -3,7 +3,6 @@ import torch
 from torch import nn as nn
 
 from util.logconf import logging
-from util.unet import UNet
 
 log = logging.getLogger(__name__)
 # log.setLevel(logging.WARN)
@@ -51,18 +50,3 @@ class LunaModel(nn.Module):
 
         classifier_output = self.final(classifier_output)
         return classifier_output
-
-class UNetWrapper(nn.Module):
-    def __init__(self, **kwargs):
-        super().__init__()
-
-        self.batchnorm = nn.BatchNorm2d(kwargs['in_channels'])
-        self.unet = UNet(**kwargs)
-        self.hardtanh = nn.Hardtanh(min_val=0, max_val=1)
-
-    def forward(self, input):
-        bn_output = self.batchnorm(input)
-        un_output = self.unet(bn_output)
-        ht_output = self.hardtanh(un_output)
-
-        return ht_output

+ 6 - 15
p2ch10/prepcache.py

@@ -9,9 +9,9 @@ from torch.optim import SGD
 from torch.utils.data import DataLoader
 
 from util.util import enumerateWithEstimate
-from .dsets import LunaClassificationDataset, getCtSize
+from .dsets import LunaDataset
 from util.logconf import logging
-# from .model import LunaModel
+from .model import LunaModel
 
 log = logging.getLogger(__name__)
 # log.setLevel(logging.WARN)
@@ -28,7 +28,7 @@ class LunaPrepCacheApp(object):
         parser = argparse.ArgumentParser()
         parser.add_argument('--batch-size',
             help='Batch size to use for training',
-            default=32,
+            default=1024,
             type=int,
         )
         parser.add_argument('--num-workers',
@@ -36,11 +36,6 @@ class LunaPrepCacheApp(object):
             default=8,
             type=int,
         )
-        # parser.add_argument('--scaled',
-        #     help="Scale the CT chunks to square voxels.",
-        #     default=False,
-        #     action='store_true',
-        # )
 
         self.cli_args = parser.parse_args(sys_argv)
 
@@ -48,7 +43,7 @@ class LunaPrepCacheApp(object):
         log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
 
         self.prep_dl = DataLoader(
-            LunaClassificationDataset(
+            LunaDataset(
                 sortby_str='series_uid',
             ),
             batch_size=self.cli_args.batch_size,
@@ -60,12 +55,8 @@ class LunaPrepCacheApp(object):
             "Stuffing cache",
             start_ndx=self.prep_dl.num_workers,
         )
-        for batch_ndx, batch_tup in batch_iter:
-            _nodule_tensor, _malignant_tensor, series_list, _center_list = batch_tup
-            for series_uid in sorted(set(series_list)):
-                getCtSize(series_uid)
-            # input_tensor, label_tensor, _series_list, _start_list = batch_tup
-
+        for _ in batch_iter:
+            pass
 
 
 if __name__ == '__main__':

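Note: the simplified loop is a cache-stuffing idiom: draining the DataLoader forces every sample through LunaDataset.__getitem__, whose memoized helpers write the disk cache as a side effect; the loop body itself does nothing. Condensed (batch size and worker count echo the defaults above):

from torch.utils.data import DataLoader
from p2ch10.dsets import LunaDataset

prep_dl = DataLoader(
    LunaDataset(sortby_str='series_uid'),  # series order minimizes CT reloads
    batch_size=1024,
    num_workers=8,
)
for _ in prep_dl:
    pass                                   # the work happens in __getitem__
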
+ 102 - 596
p2ch10/training.py

@@ -1,7 +1,6 @@
 import argparse
 import datetime
 import os
-import socket
 import sys
 
 import numpy as np
@@ -9,56 +8,23 @@ from tensorboardX import SummaryWriter
 
 import torch
 import torch.nn as nn
-import torch.optim
-
-from torch.optim import SGD, Adam
+from torch.optim import SGD
 from torch.utils.data import DataLoader
 
 from util.util import enumerateWithEstimate
-from .dsets import TrainingLuna2dSegmentationDataset, TestingLuna2dSegmentationDataset, LunaClassificationDataset, getCt
+from .dsets import LunaDataset
 from util.logconf import logging
-from util.util import xyz2irc
-from .model import UNetWrapper, LunaModel
+from .model import LunaModel
 
 log = logging.getLogger(__name__)
 # log.setLevel(logging.WARN)
-# log.setLevel(logging.INFO)
-log.setLevel(logging.DEBUG)
-
-# Used for computeClassificationLoss and logMetrics to index into metrics_tensor/metrics_ary
-# METRICS_LABEL_NDX=0
-# METRICS_PRED_NDX=1
-# METRICS_LOSS_NDX=2
-# METRICS_MAL_LOSS_NDX=3
-# METRICS_BEN_LOSS_NDX=4
-# METRICS_LUNG_LOSS_NDX=5
-# METRICS_MASKLOSS_NDX=2
-# METRICS_MALLOSS_NDX=3
-
-
-METRICS_LOSS_NDX = 0
-METRICS_LABEL_NDX = 1
-METRICS_PRED_NDX = 2
-
-METRICS_MTP_NDX = 3
-METRICS_MFN_NDX = 4
-METRICS_MFP_NDX = 5
-METRICS_BTP_NDX = 6
-METRICS_BFN_NDX = 7
-METRICS_BFP_NDX = 8
-
-METRICS_MAL_LOSS_NDX = 9
-METRICS_BEN_LOSS_NDX = 10
-
-# METRICS_MFOUND_NDX = 2
-
-# METRICS_MOK_NDX = 2
-
-# METRICS_FLG_LOSS_NDX = 10
-METRICS_SIZE = 11
-
-
+log.setLevel(logging.INFO)
+# log.setLevel(logging.DEBUG)
 
+# Used for computeBatchLoss and logMetrics to index into metrics_tensor/metrics_ary
+METRICS_LABEL_NDX=0
+METRICS_PRED_NDX=1
+METRICS_LOSS_NDX=2
 
 class LunaTrainingApp(object):
     def __init__(self, sys_argv=None):
@@ -68,7 +34,7 @@ class LunaTrainingApp(object):
         parser = argparse.ArgumentParser()
         parser.add_argument('--batch-size',
             help='Batch size to use for training',
-            default=4,
+            default=32,
             type=int,
         )
         parser.add_argument('--num-workers',
@@ -81,179 +47,46 @@ class LunaTrainingApp(object):
             default=1,
             type=int,
         )
-        # parser.add_argument('--resume',
-        #     default=None,
-        #     help="File to resume training from.",
-        # )
-
-        parser.add_argument('--segmentation',
-            help="TODO", # TODO
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--balanced',
-            help="Balance the training data to half benign, half malignant.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--adaptive',
-            help="Balance the training data to start half benign, half malignant, and end at a 100:1 ratio.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--scaled',
-            help="Scale the CT chunks to square voxels.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--multiscaled',
-            help="Scale the CT chunks to square voxels.",
-            action='store_true',
-            default=False,
-        )
-        parser.add_argument('--augmented',
-            help="Augment the training data (implies --scaled).",
-            action='store_true',
-            default=False,
-        )
-
-        parser.add_argument('--tb-prefix',
-            default='p2ch10',
-            help="Data prefix to use for Tensorboard run. Defaults to chapter.",
-        )
-
-        parser.add_argument('comment',
-            help="Comment suffix for Tensorboard run.",
-            nargs='?',
-            default='none',
-        )
 
         self.cli_args = parser.parse_args(sys_argv)
-        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H.%M.%S')
+        self.time_str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
 
-        self.trn_writer = None
-        self.tst_writer = None
+    def main(self):
+        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
 
         self.use_cuda = torch.cuda.is_available()
         self.device = torch.device("cuda" if self.use_cuda else "cpu")
 
-        # TODO: remove this if block before print
-        # This is due to an odd setup that the author is using to test the code; please ignore for now
-        if socket.gethostname() == 'c2':
-            self.device = torch.device("cuda:1")
-
-        self.model = self.initModel()
-        self.optimizer = self.initOptimizer()
-
-        self.totalTrainingSamples_count = 0
-
-
-
-    def initModel(self):
-        if self.cli_args.segmentation:
-            model = UNetWrapper(in_channels=8, n_classes=2, depth=5, wf=6, padding=True, batch_norm=True, up_mode='upconv')
-        else:
-            model = LunaModel()
-
+        self.model = LunaModel()
         if self.use_cuda:
             if torch.cuda.device_count() > 1:
+                self.model = nn.DataParallel(self.model)
 
-                # TODO: remove this if block before print
-                # This is due to an odd setup that the author is using to test the code; please ignore for now
-                if socket.gethostname() == 'c2':
-                    model = nn.DataParallel(model, device_ids=[1, 0])
-                else:
-                    model = nn.DataParallel(model)
-
-            model = model.to(self.device)
-
-
-        return model
-
-    def initOptimizer(self):
-        return SGD(self.model.parameters(), lr=0.01, momentum=0.99)
-        # return Adam(self.model.parameters())
-
-
-    def initTrainDl(self):
-        if self.cli_args.segmentation:
-            train_ds = TrainingLuna2dSegmentationDataset(
-                    test_stride=10,
-                    contextSlices_count=3,
-                )
-        else:
-            train_ds = LunaClassificationDataset(
-                 test_stride=10,
-                 isTestSet_bool=False,
-                 # series_uid=None,
-                 # sortby_str='random',
-                 ratio_int=int(self.cli_args.balanced),
-                 # scaled_bool=False,
-                 # multiscaled_bool=False,
-                 # augmented_bool=False,
-                 # noduleInfo_list=None,
-            )
+            self.model = self.model.to(self.device)
+        self.optimizer = SGD(self.model.parameters(), lr=0.01, momentum=0.9)
 
         train_dl = DataLoader(
-            train_ds,
+            LunaDataset(
+                test_stride=10,
+                isTestSet_bool=False,
+            ),
             batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
             num_workers=self.cli_args.num_workers,
             pin_memory=self.use_cuda,
         )
 
-        return train_dl
-
-    def initTestDl(self):
-        if self.cli_args.segmentation:
-            test_ds = TestingLuna2dSegmentationDataset(
-                    test_stride=10,
-                    contextSlices_count=3,
-                )
-        else:
-            test_ds = LunaClassificationDataset(
-                 test_stride=10,
-                 isTestSet_bool=True,
-                 # series_uid=None,
-                 # sortby_str='random',
-                 # ratio_int=int(self.cli_args.balanced),
-                 # scaled_bool=False,
-                 # multiscaled_bool=False,
-                 # augmented_bool=False,
-                 # noduleInfo_list=None,
-            )
-
         test_dl = DataLoader(
-            test_ds,
+            LunaDataset(
+                test_stride=10,
+                isTestSet_bool=True,
+            ),
             batch_size=self.cli_args.batch_size * (torch.cuda.device_count() if self.use_cuda else 1),
             num_workers=self.cli_args.num_workers,
             pin_memory=self.use_cuda,
         )
 
-        return test_dl
-
-    def initTensorboardWriters(self):
-        if self.trn_writer is None:
-            log_dir = os.path.join('runs', self.cli_args.tb_prefix, self.time_str)
-
-            type_str = 'seg_' if self.cli_args.segmentation else 'cls_'
-
-            self.trn_writer = SummaryWriter(log_dir=log_dir + '_trn_' + type_str + self.cli_args.comment)
-            self.tst_writer = SummaryWriter(log_dir=log_dir + '_tst_' + type_str + self.cli_args.comment)
-
-
-
-    def main(self):
-        log.info("Starting {}, {}".format(type(self).__name__, self.cli_args))
-
-        train_dl = self.initTrainDl()
-        test_dl = self.initTestDl()
-
-        self.initTensorboardWriters()
-        self.logModelMetrics(self.model)
-
-        best_score = 0.0
-
         for epoch_ndx in range(1, self.cli_args.epochs + 1):
+
             log.info("Epoch {} of {}, {}/{} batches of size {}*{}".format(
                 epoch_ndx,
                 self.cli_args.epochs,
@@ -263,69 +96,37 @@ class LunaTrainingApp(object):
                 (torch.cuda.device_count() if self.use_cuda else 1),
             ))
 
-            trainingMetrics_tensor = self.doTraining(epoch_ndx, train_dl)
-            if epoch_ndx > 0:
-                self.logPerformanceMetrics(epoch_ndx, 'trn', trainingMetrics_tensor)
-
-            self.logModelMetrics(self.model)
-
-            if self.cli_args.segmentation:
-                self.logImages(epoch_ndx, train_dl, test_dl)
-
-            testingMetrics_tensor = self.doTesting(epoch_ndx, test_dl)
-            score = self.logPerformanceMetrics(epoch_ndx, 'tst', testingMetrics_tensor)
-            best_score = max(score, best_score)
-
-            self.saveModel('seg' if self.cli_args.segmentation else 'cls', epoch_ndx, score == best_score)
-
-        if hasattr(self, 'trn_writer'):
-            self.trn_writer.close()
-            self.tst_writer.close()
-
-    def doTraining(self, epoch_ndx, train_dl):
-        self.model.train()
-        trainingMetrics_tensor = torch.zeros(METRICS_SIZE, len(train_dl.dataset))
-        # train_dl.dataset.shuffleSamples()
-        batch_iter = enumerateWithEstimate(
-            train_dl,
-            "E{} Training".format(epoch_ndx),
-            start_ndx=train_dl.num_workers,
-        )
-        for batch_ndx, batch_tup in batch_iter:
-            self.optimizer.zero_grad()
-
-            if self.cli_args.segmentation:
-                loss_var = self.computeSegmentationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
-            else:
-                loss_var = self.computeClassificationLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
-
-            if loss_var is not None:
-                loss_var.backward()
-                self.optimizer.step()
-            del loss_var
-
-        self.totalTrainingSamples_count += trainingMetrics_tensor.size(1)
-
-        return trainingMetrics_tensor
-
-    def doTesting(self, epoch_ndx, test_dl):
-        with torch.no_grad():
-            self.model.eval()
-            testingMetrics_tensor = torch.zeros(METRICS_SIZE, len(test_dl.dataset))
+            # Training loop, very similar to below
+            self.model.train()
+            trainingMetrics_tensor = torch.zeros(3, len(train_dl.dataset), 1)
             batch_iter = enumerateWithEstimate(
-                test_dl,
-                "E{} Testing ".format(epoch_ndx),
-                start_ndx=test_dl.num_workers,
+                train_dl,
+                "E{} Training".format(epoch_ndx),
+                start_ndx=train_dl.num_workers,
             )
             for batch_ndx, batch_tup in batch_iter:
-                if self.cli_args.segmentation:
-                    self.computeSegmentationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
-                else:
-                    self.computeClassificationLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+                self.optimizer.zero_grad()
+                loss_var = self.computeBatchLoss(batch_ndx, batch_tup, train_dl.batch_size, trainingMetrics_tensor)
+                loss_var.backward()
+                self.optimizer.step()
+                del loss_var
+
+            # Testing loop, very similar to above, but simplified
+            with torch.no_grad():
+                self.model.eval()
+                testingMetrics_tensor = torch.zeros(3, len(test_dl.dataset), 1)
+                batch_iter = enumerateWithEstimate(
+                    test_dl,
+                    "E{} Testing ".format(epoch_ndx),
+                    start_ndx=test_dl.num_workers,
+                )
+                for batch_ndx, batch_tup in batch_iter:
+                    self.computeBatchLoss(batch_ndx, batch_tup, test_dl.batch_size, testingMetrics_tensor)
+
+            self.logMetrics(epoch_ndx, trainingMetrics_tensor, testingMetrics_tensor)
 
-        return testingMetrics_tensor
 
-    def computeClassificationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
+    def computeBatchLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
         input_tensor, label_tensor, _series_list, _center_list = batch_tup
 
         input_devtensor = input_tensor.to(self.device)
@@ -334,376 +135,81 @@ class LunaTrainingApp(object):
         prediction_devtensor = self.model(input_devtensor)
         loss_devtensor = nn.MSELoss(reduction='none')(prediction_devtensor, label_devtensor)
 
-        start_ndx = batch_ndx * batch_size
-        end_ndx = start_ndx + label_tensor.size(0)
-
-        with torch.no_grad():
-            # log.debug([metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx].shape, label_tensor.shape])
-
-            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor[:,0]
-            metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.to('cpu')[:,0]
-            # metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.to('cpu')
-
-
-
-
-            prediction_tensor = prediction_devtensor.to('cpu', non_blocking=True)
-            loss_tensor = loss_devtensor.to('cpu', non_blocking=True)[:,0]
-            malLabel_tensor = (label_tensor > 0.5)[:,0]
-            benLabel_tensor = ~malLabel_tensor
-
-
-            malPred_tensor = prediction_tensor > 0.5
-            benPred_tensor = ~malPred_tensor
-            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = (malLabel_tensor * malPred_tensor).sum(dim=1)
-            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = (malLabel_tensor * benPred_tensor).sum(dim=1)
-            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = (benLabel_tensor * malPred_tensor).sum(dim=1)
-
-            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = (benLabel_tensor * benPred_tensor).sum(dim=1)
-            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = (benLabel_tensor * malPred_tensor).sum(dim=1)
-            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = (malLabel_tensor * benPred_tensor).sum(dim=1)
-
-            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_tensor
-
-            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = loss_tensor * benLabel_tensor.type(torch.float32)
-            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = loss_tensor * malLabel_tensor.type(torch.float32)
-
-
-        # TODO: replace with torch.autograd.detect_anomaly
-        # assert np.isfinite(metrics_tensor).all()
-
-        return loss_devtensor.mean()
-
-    def computeSegmentationLoss(self, batch_ndx, batch_tup, batch_size, metrics_tensor):
-        input_tensor, label_tensor, _series_list, _start_list = batch_tup
-
-        # if label_tensor.max() < 0.5:
-        #     return None
-
-        input_devtensor = input_tensor.to(self.device)
-        label_devtensor = label_tensor.to(self.device)
-
-        prediction_devtensor = self.model(input_devtensor)
-
-        # assert prediction_devtensor.is_contiguous()
 
         start_ndx = batch_ndx * batch_size
         end_ndx = start_ndx + label_tensor.size(0)
-        max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
-        intersectionSum = lambda a, b: (a * b.to(torch.float32)).view(a.size(0), -1).sum(dim=1)
-
-        diceLoss_devtensor = self.diceLoss(label_devtensor, prediction_devtensor)
-        malLoss_devtensor = self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0])
-        benLoss_devtensor = self.diceLoss(label_devtensor[:,1], prediction_devtensor[:,1])
-
-        with torch.no_grad():
-            bPred_tensor = prediction_devtensor.to('cpu', non_blocking=True)
-            diceLoss_tensor = diceLoss_devtensor.to('cpu', non_blocking=True)
-            malLoss_tensor = malLoss_devtensor.to('cpu', non_blocking=True)
-            benLoss_tensor = benLoss_devtensor.to('cpu', non_blocking=True)
-
-            # flgLoss_devtensor = self.diceLoss(label_devtensor[:,0], label_devtensor[:,0] * prediction_devtensor[:,1])
-            # flgLoss_tensor = flgLoss_devtensor.to('cpu', non_blocking=True)#.unsqueeze(1)
-
-            metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = max2(label_tensor[:,0]) + max2(label_tensor[:,1]) * 2
-            # metrics_tensor[METRICS_MFOUND_NDX, start_ndx:end_ndx] = (max2(label_tensor[:, 0] * bPred_tensor[:, 1].to(torch.float32)) > 0.5)
-
-            # metrics_tensor[METRICS_MOK_NDX, start_ndx:end_ndx] = intersectionSum( label_tensor[:,0],  bPred_tensor[:,1])
-
-            bPred_tensor = bPred_tensor > 0.5
-            metrics_tensor[METRICS_MTP_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,0],  bPred_tensor[:,0])
-            metrics_tensor[METRICS_MFN_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,0], ~bPred_tensor[:,0])
-            metrics_tensor[METRICS_MFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,0],  bPred_tensor[:,0])
-
-            metrics_tensor[METRICS_BTP_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,1],  bPred_tensor[:,1])
-            metrics_tensor[METRICS_BFN_NDX, start_ndx:end_ndx] = intersectionSum(    label_tensor[:,1], ~bPred_tensor[:,1])
-            metrics_tensor[METRICS_BFP_NDX, start_ndx:end_ndx] = intersectionSum(1 - label_tensor[:,1],  bPred_tensor[:,1])
-
-            metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = diceLoss_tensor
-
-            metrics_tensor[METRICS_BEN_LOSS_NDX, start_ndx:end_ndx] = benLoss_tensor
-            metrics_tensor[METRICS_MAL_LOSS_NDX, start_ndx:end_ndx] = malLoss_tensor
-            # metrics_tensor[METRICS_FLG_LOSS_NDX, start_ndx:end_ndx] = flgLoss_tensor
-
-
-            # lungLoss_devtensor = self.diceLoss(label_devtensor[:,2], prediction_devtensor[:,2])
-            # lungLoss_tensor = lungLoss_devtensor.to('cpu').unsqueeze(1)
-            # metrics_tensor[METRICS_LUNG_LOSS_NDX, start_ndx:end_ndx] = lungLoss_tensor
+        metrics_tensor[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_tensor
+        metrics_tensor[METRICS_PRED_NDX, start_ndx:end_ndx] = prediction_devtensor.to('cpu')
+        metrics_tensor[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_devtensor.to('cpu')
 
         # TODO: replace with torch.autograd.detect_anomaly
         # assert np.isfinite(metrics_tensor).all()
 
-        # return nn.MSELoss()(prediction_devtensor, label_devtensor)
-
-        return malLoss_devtensor.mean() + benLoss_devtensor.mean()
-        # return self.diceLoss(label_devtensor[:,0], prediction_devtensor[:,0]).mean()
-
-    def diceLoss(self, label_devtensor, prediction_devtensor, epsilon=0.01, p=False):
-        # sum2 = lambda t: t.sum([1,2,3,4])
-        sum2 = lambda t: t.view(t.size(0), -1).sum(dim=1)
-        # max2 = lambda t: t.view(t.size(0), -1).max(dim=1)[0]
-
-        diceCorrect_devtensor = sum2(prediction_devtensor * label_devtensor)
-        dicePrediction_devtensor = sum2(prediction_devtensor)
-        diceLabel_devtensor = sum2(label_devtensor)
-        epsilon_devtensor = torch.ones_like(diceCorrect_devtensor) * epsilon
-        diceLoss_devtensor = 1 - (2 * diceCorrect_devtensor + epsilon_devtensor) / (dicePrediction_devtensor + diceLabel_devtensor + epsilon_devtensor)
-
-        if not torch.isfinite(diceLoss_devtensor).all():
-            log.debug('')
-            log.debug('diceLoss_devtensor')
-            log.debug(diceLoss_devtensor.to('cpu'))
-            log.debug('diceCorrect_devtensor')
-            log.debug(diceCorrect_devtensor.to('cpu'))
-            log.debug('dicePrediction_devtensor')
-            log.debug(dicePrediction_devtensor.to('cpu'))
-            log.debug('diceLabel_devtensor')
-            log.debug(diceLabel_devtensor.to('cpu'))
-
-        return diceLoss_devtensor
-
-
-
-    def logImages(self, epoch_ndx, train_dl, test_dl):
-        for mode_str, dl in [('trn', train_dl), ('tst', test_dl)]:
-            for i, series_uid in enumerate(sorted(dl.dataset.series_list)[:12]):
-                ct = getCt(series_uid)
-                noduleInfo_tup = (ct.malignantInfo_list or ct.benignInfo_list)[0]
-                center_irc = xyz2irc(noduleInfo_tup.center_xyz, ct.origin_xyz, ct.vxSize_xyz, ct.direction_tup)
-
-                sample_tup = dl.dataset[(series_uid, int(center_irc.index))]
-                input_tensor = sample_tup[0].unsqueeze(0)
-                label_tensor = sample_tup[1].unsqueeze(0)
-
-                input_devtensor = input_tensor.to(self.device)
-                label_devtensor = label_tensor.to(self.device)
-
-                prediction_devtensor = self.model(input_devtensor)
-                prediction_ary = prediction_devtensor.to('cpu').detach().numpy()
-
-                image_ary = np.zeros((512, 512, 3), dtype=np.float32)
-                image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
-                image_ary[:,:,0] += prediction_ary[0,0] * 0.5
-                image_ary[:,:,1] += prediction_ary[0,1] * 0.25
-                # image_ary[:,:,2] += prediction_ary[0,2] * 0.5
-
-                # log.debug([image_ary.__array_interface__['typestr']])
-
-                # image_ary = (image_ary * 255).astype(np.uint8)
-
-                # log.debug([image_ary.__array_interface__['typestr']])
-
-                writer = getattr(self, mode_str + '_writer')
-                try:
-                    writer.add_image('{}/{}_pred'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
-                except:
-                    log.debug([image_ary.shape, image_ary.dtype])
-                    raise
-
-                if epoch_ndx == 1:
-                    label_ary = label_tensor.numpy()
-
-                    image_ary = np.zeros((512, 512, 3), dtype=np.float32)
-                    image_ary[:,:,:] = (input_tensor[0,2].numpy().reshape((512,512,1))) * 0.25
-                    image_ary[:,:,0] += label_ary[0,0] * 0.5
-                    image_ary[:,:,1] += label_ary[0,1] * 0.25
-                    image_ary[:,:,2] += (input_tensor[0,-1].numpy() - (label_ary[0,0].astype(np.bool) | label_ary[0,1].astype(np.bool))) * 0.25
-
-                    # log.debug([image_ary.__array_interface__['typestr']])
-
-                    image_ary = (image_ary * 255).astype(np.uint8)
-
-                    # log.debug([image_ary.__array_interface__['typestr']])
-
-                    writer = getattr(self, mode_str + '_writer')
-                    writer.add_image('{}/{}_label'.format(mode_str, i), image_ary, self.totalTrainingSamples_count, dataformats='HWC')
+        return loss_devtensor.mean()
 
 
-    def logPerformanceMetrics(self,
-                              epoch_ndx,
-                              mode_str,
-                              metrics_tensor,
-                              # trainingMetrics_tensor,
-                              # testingMetrics_tensor,
-                              classificationThreshold_float=0.5,
-                              ):
+    def logMetrics(self,
+                   epoch_ndx,
+                   trainingMetrics_tensor,
+                   testingMetrics_tensor,
+                   classificationThreshold_float=0.5,
+                   ):
         log.info("E{} {}".format(
             epoch_ndx,
             type(self).__name__,
         ))
 
-        score = 0.0
-
-
-        # for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
-        metrics_ary = metrics_tensor.cpu().detach().numpy()
-        sum_ary = metrics_ary.sum(axis=1)
-        assert np.isfinite(metrics_ary).all()
-
-        malLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 1) | (metrics_ary[METRICS_LABEL_NDX] == 3)
-
-        if self.cli_args.segmentation:
-            benLabel_mask = (metrics_ary[METRICS_LABEL_NDX] == 2) | (metrics_ary[METRICS_LABEL_NDX] == 3)
-        else:
-            benLabel_mask = ~malLabel_mask
-        # malFound_mask = metrics_ary[METRICS_MFOUND_NDX] > classificationThreshold_float
-
-        # malLabel_mask = ~benLabel_mask
-        # malPred_mask = ~benPred_mask
-
-        benLabel_count = sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]
-        malLabel_count = sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]
-
-        trueNeg_count = benCorrect_count = sum_ary[METRICS_BTP_NDX]
-        truePos_count = malCorrect_count = sum_ary[METRICS_MTP_NDX]
-#
-#             falsePos_count = benLabel_count - benCorrect_count
-#             falseNeg_count = malLabel_count - malCorrect_count
-
+        for mode_str, metrics_tensor in [('trn', trainingMetrics_tensor), ('tst', testingMetrics_tensor)]:
+            metrics_ary = metrics_tensor.detach().numpy()[:,:,0]
+            assert np.isfinite(metrics_ary).all()
 
-        metrics_dict = {}
-        metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
-        # metrics_dict['loss/msk'] = metrics_ary[METRICS_MASKLOSS_NDX].mean()
-        # metrics_dict['loss/mal'] = metrics_ary[METRICS_MALLOSS_NDX].mean()
-        # metrics_dict['loss/lng'] = metrics_ary[METRICS_LUNG_LOSS_NDX, benLabel_mask].mean()
-        metrics_dict['loss/mal'] = np.nan_to_num(metrics_ary[METRICS_MAL_LOSS_NDX, malLabel_mask].mean())
-        metrics_dict['loss/ben'] = metrics_ary[METRICS_BEN_LOSS_NDX, benLabel_mask].mean()
-        # metrics_dict['loss/flg'] = metrics_ary[METRICS_FLG_LOSS_NDX].mean()
+            benLabel_mask = metrics_ary[METRICS_LABEL_NDX] <= classificationThreshold_float
+            benPred_mask = metrics_ary[METRICS_PRED_NDX] <= classificationThreshold_float
 
-        # metrics_dict['flagged/all'] = sum_ary[METRICS_MOK_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
-        # metrics_dict['flagged/slices'] = (malLabel_mask & malFound_mask).sum() / malLabel_mask.sum() * 100
+            malLabel_mask = ~benLabel_mask
+            malPred_mask = ~benPred_mask
 
-        metrics_dict['correct/mal'] = sum_ary[METRICS_MTP_NDX] / (sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) * 100
-        metrics_dict['correct/ben'] = sum_ary[METRICS_BTP_NDX] / (sum_ary[METRICS_BTP_NDX] + sum_ary[METRICS_BFN_NDX]) * 100
+            benLabel_count = benLabel_mask.sum()
+            malLabel_count = malLabel_mask.sum()
 
-        precision = metrics_dict['pr/precision'] = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFP_NDX]) or 1)
-        recall    = metrics_dict['pr/recall']    = sum_ary[METRICS_MTP_NDX] / ((sum_ary[METRICS_MTP_NDX] + sum_ary[METRICS_MFN_NDX]) or 1)
+            benCorrect_count = (benLabel_mask & benPred_mask).sum()
+            malCorrect_count = (malLabel_mask & malPred_mask).sum()
 
-        metrics_dict['pr/f1_score'] = 2 * (precision * recall) / ((precision + recall) or 1)
+            metrics_dict = {}
 
-        log.info(("E{} {:8} "
-                 + "{loss/all:.4f} loss, "
-                 # + "{loss/flg:.4f} flagged loss, "
-                 # + "{flagged/all:-5.1f}% pixels flagged, "
-                 # + "{flagged/slices:-5.1f}% slices flagged, "
-                 + "{pr/precision:.4f} precision, "
-                 + "{pr/recall:.4f} recall, "
-                 + "{pr/f1_score:.4f} f1 score"
-                  ).format(
-            epoch_ndx,
-            mode_str,
-            **metrics_dict,
-        ))
-        log.info(("E{} {:8} "
-                 + "{loss/mal:.4f} loss, "
-                 + "{correct/mal:-5.1f}% correct ({malCorrect_count:} of {malLabel_count:})"
-        ).format(
-            epoch_ndx,
-            mode_str + '_mal',
-            malCorrect_count=malCorrect_count,
-            malLabel_count=malLabel_count,
-            **metrics_dict,
-        ))
-        log.info(("E{} {:8} "
-                 + "{loss/ben:.4f} loss, "
-                 + "{correct/ben:-5.1f}% correct ({benCorrect_count:} of {benLabel_count:})"
-        ).format(
-            epoch_ndx,
-            mode_str + '_ben',
-            benCorrect_count=benCorrect_count,
-            benLabel_count=benLabel_count,
-            **metrics_dict,
-        ))
+            metrics_dict['loss/all'] = metrics_ary[METRICS_LOSS_NDX].mean()
+            metrics_dict['loss/ben'] = metrics_ary[METRICS_LOSS_NDX, benLabel_mask].mean()
+            metrics_dict['loss/mal'] = metrics_ary[METRICS_LOSS_NDX, malLabel_mask].mean()
 
-        writer = getattr(self, mode_str + '_writer')
+            metrics_dict['correct/all'] = (malCorrect_count + benCorrect_count) / metrics_ary.shape[1] * 100
+            metrics_dict['correct/ben'] = (benCorrect_count) / benLabel_count * 100
+            metrics_dict['correct/mal'] = (malCorrect_count) / malLabel_count * 100
 
-        prefix_str = 'seg_' if self.cli_args.segmentation else ''
 
-        for key, value in metrics_dict.items():
-            writer.add_scalar(prefix_str + key, value, self.totalTrainingSamples_count)
 
-            if not self.cli_args.segmentation:
-                writer.add_pr_curve(
-                    'pr',
-                    metrics_ary[METRICS_LABEL_NDX],
-                    metrics_ary[METRICS_PRED_NDX],
-                    self.totalTrainingSamples_count,
-                )
-
-                benHist_mask = benLabel_mask & (metrics_ary[METRICS_PRED_NDX] > 0.01)
-                malHist_mask = malLabel_mask & (metrics_ary[METRICS_PRED_NDX] < 0.99)
-
-                bins = [x/50.0 for x in range(51)]
-                writer.add_histogram(
-                    'is_ben',
-                    metrics_ary[METRICS_PRED_NDX, benHist_mask],
-                    self.totalTrainingSamples_count,
-                    bins=bins,
-                )
-                writer.add_histogram(
-                    'is_mal',
-                    metrics_ary[METRICS_PRED_NDX, malHist_mask],
-                    self.totalTrainingSamples_count,
-                    bins=bins,
-                )
 
-        score = 1 \
-            + metrics_dict['pr/f1_score'] \
-            - metrics_dict['loss/mal'] * 0.01 \
-            - metrics_dict['loss/all'] * 0.0001
-
-        return score
-
-    def logModelMetrics(self, model):
-        writer = getattr(self, 'trn_writer')
-
-        model = getattr(model, 'module', model)
-
-        for name, param in model.named_parameters():
-            if param.requires_grad:
-                min_data = float(param.data.min())
-                max_data = float(param.data.max())
-                max_extent = max(abs(min_data), abs(max_data))
-
-                bins = [x/50*max_extent for x in range(-50, 51)]
-
-                writer.add_histogram(
-                    name.rsplit('.', 1)[-1] + '/' + name,
-                    param.data.cpu().numpy(),
-                    # metrics_ary[METRICS_PRED_NDX, benHist_mask],
-                    self.totalTrainingSamples_count,
-                    bins=bins,
-                )
-
-                # print name, param.data
-
-    def saveModel(self, type_str, epoch_ndx, isBest=False):
-        file_path = os.path.join('data-unversioned', 'models', self.cli_args.tb_prefix, '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, self.totalTrainingSamples_count))
-
-        os.makedirs(os.path.dirname(file_path), mode=0o755, exist_ok=True)
-
-        model = self.model
-        if hasattr(model, 'module'):
-            model = model.module
-
-        state = {
-            'model_state': model.state_dict(),
-            'model_name': type(model).__name__,
-            'optimizer_state' : self.optimizer.state_dict(),
-            'optimizer_name': type(self.optimizer).__name__,
-            'epoch': epoch_ndx,
-            'totalTrainingSamples_count': self.totalTrainingSamples_count,
-            # 'resumed_from': self.cli_args.resume,
-        }
-        torch.save(state, file_path)
-
-        log.debug("Saved model params to {}".format(file_path))
-
-        if isBest:
-            file_path = os.path.join('data-unversioned', 'models', self.cli_args.tb_prefix, '{}_{}_{}.{}.state'.format(type_str, self.time_str, self.cli_args.comment, 'best'))
-            torch.save(state, file_path)
-
-            log.debug("Saved model params to {}".format(file_path))
+            log.info(("E{} {:8} "
+                     + "{loss/all:.4f} loss, "
+                     + "{correct/all:-5.1f}% correct"
+                      ).format(
+                epoch_ndx,
+                mode_str,
+                **metrics_dict,
+            ))
+            log.info(("E{} {:8} "
+                     + "{loss/ben:.4f} loss, "
+                     + "{correct/ben:-5.1f}% correct").format(
+                epoch_ndx,
+                mode_str + '_ben',
+                **metrics_dict,
+            ))
+            log.info(("E{} {:8} "
+                     + "{loss/mal:.4f} loss, "
+                     + "{correct/mal:-5.1f}% correct").format(
+                epoch_ndx,
+                mode_str + '_mal',
+                **metrics_dict,
+            ))
 
 
 if __name__ == '__main__':

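The simplified computeBatchLoss above relies on one bookkeeping trick: a preallocated metrics tensor indexed by absolute sample position, so start_ndx = batch_ndx * batch_size and end_ndx = start_ndx + label_tensor.size(0) stay correct even for a short final batch, and logMetrics can later mask and average per class. A self-contained sketch of that indexing (shapes here are 2-D for brevity, where the code above carries a trailing singleton dimension):

import torch

METRICS_LABEL_NDX = 0
METRICS_PRED_NDX = 1
METRICS_LOSS_NDX = 2

def fill_metrics(metrics_t, batch_ndx, batch_size, label_t, pred_t, loss_t):
    # absolute position of this batch's samples in the epoch-wide tensor
    start_ndx = batch_ndx * batch_size
    end_ndx = start_ndx + label_t.size(0)  # the last batch may be smaller
    metrics_t[METRICS_LABEL_NDX, start_ndx:end_ndx] = label_t
    metrics_t[METRICS_PRED_NDX, start_ndx:end_ndx] = pred_t
    metrics_t[METRICS_LOSS_NDX, start_ndx:end_ndx] = loss_t

metrics_t = torch.zeros(3, 6)
fill_metrics(metrics_t, 0, 4,
             torch.tensor([0., 1., 0., 1.]),
             torch.tensor([0.2, 0.8, 0.4, 0.6]),
             torch.tensor([0.04, 0.04, 0.16, 0.16]))
fill_metrics(metrics_t, 1, 4,  # short final batch: two samples
             torch.tensor([1., 0.]),
             torch.tensor([0.9, 0.1]),
             torch.tensor([0.01, 0.01]))

ben_mask = metrics_t[METRICS_LABEL_NDX] <= 0.5  # per-class masking, as in logMetrics
print(metrics_t[METRICS_LOSS_NDX, ben_mask].mean())

Note also the train/test symmetry above: the testing loop runs under torch.no_grad() with self.model.eval(), which skips gradient tracking and switches batch norm (and dropout, where present) to inference behavior.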
+ 86 - 0
p2ch10/vis.py

@@ -0,0 +1,86 @@
+import matplotlib
+matplotlib.use('nbagg')
+
+import numpy as np
+import matplotlib.pyplot as plt
+
+from p2ch11_old.dsets import Ct, LunaDataset
+
+clim=(0.0, 1.3)
+
+def findMalignantSamples(start_ndx=0, limit=10):
+    ds = LunaDataset()
+
+    malignantSample_list = []
+    for sample_tup in ds.sample_list:
+        if sample_tup[2]:
+            print(len(malignantSample_list), sample_tup)
+            malignantSample_list.append(sample_tup)
+
+        if len(malignantSample_list) >= limit:
+            break
+
+    return malignantSample_list
+
+def showNodule(series_uid, batch_ndx=None, **kwargs):
+    ds = LunaDataset(series_uid=series_uid, **kwargs)
+    malignant_list = [i for i, x in enumerate(ds.sample_list) if x[2]]
+
+    if batch_ndx is None:
+        if malignant_list:
+            batch_ndx = malignant_list[0]
+        else:
+            print("Warning: no malignant samples found; using first non-malignant sample.")
+            batch_ndx = 0
+
+    ct = Ct(series_uid)
+    # ct_tensor, malignant_tensor, series_uid, center_irc = ds[batch_ndx]
+    malignant_tensor, diameter_mm, series_uid, center_irc, nodule_tensor = ds[batch_ndx]
+    ct_ary = nodule_tensor[1].numpy()
+
+
+    fig = plt.figure(figsize=(15, 25))
+
+    group_list = [
+        #[0,1,2],
+        [3,4,5],
+        [6,7,8],
+        [9,10,11],
+        #[12,13,14],
+        #[15]
+    ]
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 1)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct.ary[int(center_irc.index)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 2)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct.ary[:,int(center_irc.row)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 3)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct.ary[:,:,int(center_irc.col)], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 4)
+    subplot.set_title('index {}'.format(int(center_irc.index)))
+    plt.imshow(ct_ary[ct_ary.shape[0]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 5)
+    subplot.set_title('row {}'.format(int(center_irc.row)))
+    plt.imshow(ct_ary[:,ct_ary.shape[1]//2], clim=clim, cmap='gray')
+
+    subplot = fig.add_subplot(len(group_list) + 2, 3, 6)
+    subplot.set_title('col {}'.format(int(center_irc.col)))
+    plt.imshow(ct_ary[:,:,ct_ary.shape[2]//2], clim=clim, cmap='gray')
+
+    for row, index_list in enumerate(group_list):
+        for col, index in enumerate(index_list):
+            subplot = fig.add_subplot(len(group_list) + 2, 3, row * 3 + col + 7)
+            subplot.set_title('slice {}'.format(index))
+            plt.imshow(ct_ary[index*2], clim=clim, cmap='gray')
+
+
+    print(series_uid, batch_ndx, bool(malignant_tensor[0]), malignant_list, ct.vxSize_xyz)
+
+    return ct_ary

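showNodule lays out three orthogonal cross-sections (index, row, col) of both the full CT and the cropped nodule chunk, followed by a grid of individual slices. The cross-section idea in isolation, with a random volume as a hypothetical stand-in for ct.ary:

import numpy as np
import matplotlib.pyplot as plt

vol = np.random.rand(32, 48, 48)  # (index, row, col); stand-in for a real CT array
clim = (0.0, 1.3)

fig = plt.figure(figsize=(9, 3))
slices = [('index', vol[vol.shape[0] // 2]),
          ('row',   vol[:, vol.shape[1] // 2]),
          ('col',   vol[:, :, vol.shape[2] // 2])]
for i, (title, sl) in enumerate(slices):
    ax = fig.add_subplot(1, 3, i + 1)
    ax.set_title(title)
    ax.imshow(sl, clim=clim, cmap='gray')
plt.show()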
+ 1 - 1
util/disk.py

@@ -83,7 +83,7 @@ def getCache(scope_str):
                        shards=128,
                        timeout=1,
                        size_limit=2e11,
-                       disk_min_file_size=2**20,
+                       # disk_min_file_size=2**20,
                        )
 
 # def disk_cache(base_path, memsize=2):

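For context, getCache in util/disk.py wraps the diskcache library's FanoutCache, whose .memoize() decorator is what makes the Dataset loading above cache to disk on first access. A compact sketch, assuming the diskcache package is installed; the directory and the toy function are illustrative:

from diskcache import FanoutCache

def getCache(scope_str):
    return FanoutCache('data-unversioned/cache/' + scope_str,
                       shards=128,       # many shards reduce lock contention
                       timeout=1,
                       size_limit=2e11)  # ~200 GB cap, as in the diff

raw_cache = getCache('demo')

@raw_cache.memoize(typed=True)
def slow_square(x):
    return x * x  # computed once per distinct argument, then read from disk

print(slow_square(12), slow_square(12))  # the second call is a cache hit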
+ 4 - 4
util/unet.py

@@ -97,15 +97,15 @@ class UNetConvBlock(nn.Module):
 
         block.append(nn.Conv2d(in_size, out_size, kernel_size=3,
                                padding=int(padding)))
-        block.append(nn.ReLU())
-        # block.append(nn.LeakyReLU())
+        # block.append(nn.ReLU())
+        block.append(nn.LeakyReLU())
         if batch_norm:
             block.append(nn.BatchNorm2d(out_size))
 
         block.append(nn.Conv2d(out_size, out_size, kernel_size=3,
                                padding=int(padding)))
-        block.append(nn.ReLU())
-        # block.append(nn.LeakyReLU())
+        # block.append(nn.ReLU())
+        block.append(nn.LeakyReLU())
         if batch_norm:
             block.append(nn.BatchNorm2d(out_size))
 

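The util/unet.py change is a one-for-one activation swap inside UNetConvBlock: each nn.ReLU() becomes nn.LeakyReLU(), keeping the conv -> activation -> batch-norm ordering. A runnable reduction of the block (channel sizes illustrative):

import torch
import torch.nn as nn

def conv_block(in_size, out_size, padding=1, batch_norm=True):
    layers = [nn.Conv2d(in_size, out_size, kernel_size=3, padding=padding),
              nn.LeakyReLU()]  # was nn.ReLU() before this commit
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_size))
    layers += [nn.Conv2d(out_size, out_size, kernel_size=3, padding=padding),
               nn.LeakyReLU()]
    if batch_norm:
        layers.append(nn.BatchNorm2d(out_size))
    return nn.Sequential(*layers)

block = conv_block(8, 16)
print(block(torch.randn(1, 8, 64, 64)).shape)  # torch.Size([1, 16, 64, 64])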
+ 48 - 48
util/util.py

@@ -254,51 +254,51 @@ so by default we double the gap between logging messages each time after the fir
         str(datetime.datetime.now()).rsplit('.', 1)[0],
     ))
 
-
-try:
-    import matplotlib
-    matplotlib.use('agg', warn=False)
-
-    import matplotlib.pyplot as plt
-    # matplotlib color maps
-    cdict = {'red':   ((0.0,  1.0, 1.0),
-                       # (0.5,  1.0, 1.0),
-                       (1.0,  1.0, 1.0)),
-
-             'green': ((0.0,  0.0, 0.0),
-                       (0.5,  0.0, 0.0),
-                       (1.0,  0.5, 0.5)),
-
-             'blue':  ((0.0,  0.0, 0.0),
-                       # (0.5,  0.5, 0.5),
-                       # (0.75, 0.0, 0.0),
-                       (1.0,  0.0, 0.0)),
-
-             'alpha':  ((0.0, 0.0, 0.0),
-                       (0.75, 0.5, 0.5),
-                       (1.0,  0.5, 0.5))}
-
-    plt.register_cmap(name='mask', data=cdict)
-
-    cdict = {'red':   ((0.0,  0.0, 0.0),
-                       (0.25,  1.0, 1.0),
-                       (1.0,  1.0, 1.0)),
-
-             'green': ((0.0,  1.0, 1.0),
-                       (0.25,  1.0, 1.0),
-                       (0.5, 0.0, 0.0),
-                       (1.0,  0.0, 0.0)),
-
-             'blue':  ((0.0,  0.0, 0.0),
-                       # (0.5,  0.5, 0.5),
-                       # (0.75, 0.0, 0.0),
-                       (1.0,  0.0, 0.0)),
-
-             'alpha':  ((0.0, 0.15, 0.15),
-                       (0.5,  0.3, 0.3),
-                       (0.8,  0.0, 0.0),
-                       (1.0,  0.0, 0.0))}
-
-    plt.register_cmap(name='maskinvert', data=cdict)
-except ImportError:
-    pass
+#
+# try:
+#     import matplotlib
+#     matplotlib.use('agg', warn=False)
+#
+#     import matplotlib.pyplot as plt
+#     # matplotlib color maps
+#     cdict = {'red':   ((0.0,  1.0, 1.0),
+#                        # (0.5,  1.0, 1.0),
+#                        (1.0,  1.0, 1.0)),
+#
+#              'green': ((0.0,  0.0, 0.0),
+#                        (0.5,  0.0, 0.0),
+#                        (1.0,  0.5, 0.5)),
+#
+#              'blue':  ((0.0,  0.0, 0.0),
+#                        # (0.5,  0.5, 0.5),
+#                        # (0.75, 0.0, 0.0),
+#                        (1.0,  0.0, 0.0)),
+#
+#              'alpha':  ((0.0, 0.0, 0.0),
+#                        (0.75, 0.5, 0.5),
+#                        (1.0,  0.5, 0.5))}
+#
+#     plt.register_cmap(name='mask', data=cdict)
+#
+#     cdict = {'red':   ((0.0,  0.0, 0.0),
+#                        (0.25,  1.0, 1.0),
+#                        (1.0,  1.0, 1.0)),
+#
+#              'green': ((0.0,  1.0, 1.0),
+#                        (0.25,  1.0, 1.0),
+#                        (0.5, 0.0, 0.0),
+#                        (1.0,  0.0, 0.0)),
+#
+#              'blue':  ((0.0,  0.0, 0.0),
+#                        # (0.5,  0.5, 0.5),
+#                        # (0.75, 0.0, 0.0),
+#                        (1.0,  0.0, 0.0)),
+#
+#              'alpha':  ((0.0, 0.15, 0.15),
+#                        (0.5,  0.3, 0.3),
+#                        (0.8,  0.0, 0.0),
+#                        (1.0,  0.0, 0.0))}
+#
+#     plt.register_cmap(name='maskinvert', data=cdict)
+# except ImportError:
+#     pass

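The block commented out above registered two custom colormaps with alpha channels so segmentation masks could be overlaid on CT slices via cmap='mask'. For reference, an abbreviated, self-contained version of what that code did, using the same legacy plt.register_cmap API (deprecated in recent matplotlib, where matplotlib.colormaps.register is the replacement; note also that the warn= argument to matplotlib.use was removed in matplotlib 3.x):

import matplotlib
matplotlib.use('agg')  # the original also passed warn=False, which newer matplotlib rejects
import matplotlib.pyplot as plt

# red held at 1.0, green rising after 0.5, alpha fading in: values taken from the diff
cdict = {'red':   ((0.0, 1.0, 1.0), (1.0, 1.0, 1.0)),
         'green': ((0.0, 0.0, 0.0), (0.5, 0.0, 0.0), (1.0, 0.5, 0.5)),
         'blue':  ((0.0, 0.0, 0.0), (1.0, 0.0, 0.0)),
         'alpha': ((0.0, 0.0, 0.0), (0.75, 0.5, 0.5), (1.0, 0.5, 0.5))}

plt.register_cmap(name='mask', data=cdict)  # legacy API, removed in matplotlib 3.9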
Too many files were changed in this diff, so some files are not shown.